├── .gitignore ├── .gitmodules ├── LICENSE ├── LICENSE_inria.md ├── README.md ├── arguments └── __init__.py ├── assets └── teaser.gif ├── convert.py ├── environment.yml ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── requirements.txt ├── run_cf3dgs.py ├── scene ├── __init__.py ├── camera_model.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py └── gaussian_model_cf.py ├── train.py ├── trainer ├── cf3dgs_trainer.py ├── losses.py └── trainer.py └── utils ├── camera_conversion.py ├── camera_utils.py ├── general_utils.py ├── geometry_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py ├── system_utils.py ├── utils_poses ├── ATE │ ├── align_trajectory.py │ ├── align_utils.py │ ├── compute_trajectory_errors.py │ ├── results_writer.py │ ├── trajectory_utils.py │ └── transformations.py ├── align_traj.py ├── comp_ate.py ├── lie_group_helper.py └── vis_cam_traj.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | demo 10 | checkpoints 11 | data/ 12 | external/ 13 | vis_ply/ 14 | vis/ 15 | debug/ 16 | debug_vis/ 17 | output/ 18 | output_sup/ 19 | wandb/ 20 | *.pth 21 | *.ply 22 | *.png 23 | *.jpg 24 | *.mp4 25 | *.ipynb_checkpoints 26 | *.ipynb -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/ashawkey/diff-gaussian-rasterization 7 | [submodule "submodules/SIBR_viewers"] 8 | path = submodules/SIBR_viewers 9 | url = https://gitlab.inria.fr/sibr/sibr_core 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | NVIDIA Source Code License for RADIO 4 | 5 | ======================================================================= 6 | 7 | 1. Definitions 8 | 9 | “Licensor” means any person or entity that distributes its Work. 10 | 11 | “Work” means (a) the original work of authorship made available under 12 | this license, which may include software, documentation, or other files, 13 | and (b) any additions to or derivative works thereof that are made 14 | available under this license. 15 | 16 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” 17 | have the meaning as provided under U.S. copyright law; provided, however, 18 | that for the purposes of this license, derivative works shall not include works 19 | that remain separable from, or merely link (or bind by name) to the 20 | interfaces of, the Work. 21 | 22 | Works are “made available” under this license by including in or with the Work 23 | either (a) a copyright notice referencing the applicability of 24 | this license to the Work, or (b) a copy of this license. 25 | 26 | 2. License Grant 27 | 28 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this license, each 29 | Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, 30 | copyright license to use, reproduce, prepare derivative works of, publicly display, 31 | publicly perform, sublicense and distribute its Work and any resulting derivative 32 | works in any form. 33 | 34 | 3. Limitations 35 | 36 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under 37 | this license, (b) you include a complete copy of this license with your distribution, 38 | and (c) you retain without modification any copyright, patent, trademark, or 39 | attribution notices that are present in the Work. 40 | 41 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, 42 | reproduction, and distribution of your derivative works of the Work (“Your Terms”) only 43 | if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative 44 | works, and (b) you identify the specific derivative works that are subject to Your Terms. 45 | Notwithstanding Your Terms, this license (including the redistribution requirements in 46 | Section 3.1) will continue to apply to the Work itself. 47 | 48 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or 49 | intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation 50 | and its affiliates may use the Work and any derivative works commercially. 51 | As used herein, “non-commercially” means for research or evaluation purposes only. 52 | 53 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor 54 | (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that 55 | you allege are infringed by any Work, then your rights under this license from 56 | such Licensor (including the grant in Section 2.1) will terminate immediately. 57 | 58 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its 59 | affiliates’ names, logos, or trademarks, except as necessary to reproduce 60 | the notices described in this license. 61 | 62 | 3.6 Termination. If you violate any term of this license, then your rights under 63 | this license (including the grant in Section 2.1) will terminate immediately. 64 | 65 | 4. Disclaimer of Warranty. 66 | 67 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 68 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 69 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 70 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 71 | 72 | 5. Limitation of Liability. 73 | 74 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, 75 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR 76 | BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, 77 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 78 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS 79 | INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY 80 | OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE 81 | POSSIBILITY OF SUCH DAMAGES. 
82 | 83 | ======================================================================= 84 | -------------------------------------------------------------------------------- /LICENSE_inria.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. 
In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

4 | 5 |

COLMAP-Free 3D Gaussian Splatting

6 |

7 | Yang Fu 8 | · 9 | Sifei Liu 10 | · 11 | Amey Kulkarni 12 | · 13 | Jan Kautz 14 |
15 | Alexei A. Efros 16 | · 17 | Xiaolong Wang 18 |

19 |

Paper | Video | Project Page

20 |
21 |

22 | 23 |

24 | 25 | Logo 26 | 27 |

28 | 29 | ## Installation 30 | 31 | ##### (Recommended) 32 | The code has been tested with Python 3.10 and CUDA >= 11.6. The simplest way to install all dependencies is to use [anaconda](https://www.anaconda.com/) and [pip](https://pypi.org/project/pip/) in the following steps: 33 | 34 | ```bash 35 | conda create -n cf3dgs python=3.10 36 | conda activate cf3dgs 37 | conda install conda-forge::cudatoolkit-dev=11.7.0 38 | conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.7 -c pytorch -c nvidia 39 | git clone --recursive git@github.com:NVlabs/CF-3DGS.git 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | ## Dataset Preparation 44 | DATAROOT is `./data` by default. Please first create the data folder with `mkdir data`. 45 | 46 | ### Tanks and Temples 47 | 48 | Download the data preprocessed by [Nope-NeRF](https://github.com/ActiveVisionLab/nope-nerf/?tab=readme-ov-file#Data) as below, and save it into the `./data/Tanks` folder. 49 | ```bash 50 | wget https://www.robots.ox.ac.uk/~wenjing/Tanks.zip 51 | ``` 52 | 53 | ### CO3D 54 | Download our preprocessed [data](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/yafu_ucsd_edu/EftJV9Xpn0hNjmOiGKZuzyIBW5j6hAVEGhewc8aUcFShEA?e=x1aXVx), and put it into the `./data/co3d` folder. 55 | 56 | 57 | ## Run 58 | 59 | ### Training 60 | ```bash 61 | python run_cf3dgs.py -s data/Tanks/Francis \ # change the scene path 62 | --mode train \ 63 | --data_type tanks 64 | ``` 65 | 66 | ### Evaluation 67 | ```bash 68 | # pose estimation 69 | python run_cf3dgs.py --source data/Tanks/Francis \ 70 | --mode eval_pose \ 71 | --data_type tanks \ 72 | --model_path ${CKPT_PATH} 73 | # by default the checkpoint is stored in "./output/progressive/Tanks_Francis/chkpnt/ep00_init.pth" 74 | # novel view synthesis 75 | python run_cf3dgs.py --source data/Tanks/Francis \ 76 | --mode eval_nvs \ 77 | --data_type tanks \ 78 | --model_path ${CKPT_PATH} 79 | ``` 80 | We release some of the novel view synthesis results ([gdrive](https://drive.google.com/drive/folders/1p3WljCN90zrm1N5lO-24OLHmUFmFWntt?usp=sharing)) for comparison with future work. 81 | 82 | ### Run on your own video 83 | 84 | * To run CF-3DGS on your own video, you first need to convert the video to frames and save them to `./data/$CUSTOM_DATA/images/` 85 | 86 | 87 | * Camera intrinsics can be obtained by running COLMAP (check details in `convert.py`). Otherwise, we provide a heuristic camera setting which should work for most landscape videos. 88 | 89 | * Run the following commands: 90 | 91 | ```bash 92 | python run_cf3dgs.py -s ./data/$CUSTOM_DATA/ \ # change to your data path 93 | --mode train \ 94 | --data_type custom 95 | ``` 96 | 97 | ## Acknowledgement 98 | Our renderer is built upon [3DGS](https://github.com/graphdeco-inria/gaussian-splatting). The data processing and visualization code is partially borrowed from [Nope-NeRF](https://github.com/ActiveVisionLab/nope-nerf/). We thank all the authors for their great repos. 99 | 100 | ## Citation 101 | 102 | If you find this code helpful, please cite: 103 | 104 | ``` 105 | @InProceedings{Fu_2024_CVPR, 106 | author = {Fu, Yang and Liu, Sifei and Kulkarni, Amey and Kautz, Jan and Efros, Alexei A.
and Wang, Xiaolong}, 107 | title = {COLMAP-Free 3D Gaussian Splatting}, 108 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 109 | month = {June}, 110 | year = {2024}, 111 | pages = {20796-20805} 112 | } 113 | ``` -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument( 34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true") 35 | else: 36 | group.add_argument( 37 | "--" + key, ("-" + key[0:1]), default=value, type=t) 38 | else: 39 | if t == bool: 40 | group.add_argument( 41 | "--" + key, default=value, action="store_true") 42 | else: 43 | group.add_argument("--" + key, default=value, type=t) 44 | 45 | def extract(self, args): 46 | group = GroupParams() 47 | for arg in vars(args).items(): 48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 49 | setattr(group, arg[0], arg[1]) 50 | return group 51 | 52 | 53 | class ModelParams(ParamGroup): 54 | def __init__(self, parser, sentinel=False): 55 | self.sh_degree = 3 56 | self._source_path = "" 57 | self._model_path = "" 58 | self._images = "images" 59 | self._resolution = -1 60 | self._white_background = False 61 | self.data_device = "cuda" 62 | self.eval = True 63 | self.rot_type = "6d" 64 | self.view_dependent = True 65 | # self.model_type = 'original' 66 | self.data_type = "tanks" 67 | self.depth_model_type = "dpt" 68 | self.mode = "train" 69 | # self.eval_nvs = False 70 | # self.eval_pose = False 71 | # self.vis_mesh = False 72 | # self.render_nvs = False 73 | self.traj_opt = "bspline" 74 | super().__init__(parser, "Loading Parameters", sentinel) 75 | 76 | def extract(self, args): 77 | g = super().extract(args) 78 | g.source_path = os.path.abspath(g.source_path) 79 | return g 80 | 81 | 82 | class PipelineParams(ParamGroup): 83 | def __init__(self, parser): 84 | self.convert_SHs_python = False 85 | self.compute_cov3D_python = False 86 | self.debug = False 87 | # self.mode = "color" 88 | self.use_gt_pcd = False 89 | self.use_mask = False 90 | self.use_ref_img = False 91 | self.init_mode = "rand" 92 | self.use_mono = True 93 | self.interval = 15 94 | self.expname = "" 95 | self.use_sampon = False 96 | self.refine = False 97 | self.distortion = False 98 | super().__init__(parser, "Pipeline Parameters") 99 | 100 | 101 | class OptimizationParams(ParamGroup): 102 | def __init__(self, parser): 103 | self.iterations = 30_000 104 | self.position_lr_init = 0.00016 105 | self.position_lr_final = 0.0000016 106 | self.position_lr_delay_mult = 0.01 107 | self.position_lr_max_steps = 30_000 
108 | self.feature_lr = 0.0025 109 | self.opacity_lr = 0.05 110 | self.scaling_lr = 0.005 111 | self.rotation_lr = 0.001 112 | self.percent_dense = 0.01 113 | self.lambda_dssim = 0.2 114 | self.lambda_depth = 0.0 115 | self.lambda_dist_2nd_loss = 0.0 116 | self.lambda_pc = 0.0 117 | self.lambda_rgb_s = 0.0 118 | self.depth_loss_type = "invariant" 119 | # self.depth_loss_type = "l1" 120 | self.match_method = "dense" 121 | self.densification_interval = 100 122 | self.densify_interval = 500 123 | self.prune_interval = 2000 124 | self.opacity_reset_interval = 3000 125 | self.densify_from_iter = 500 126 | self.densify_until_iter = 15_000 127 | self.reset_until_iter = 15_000 128 | self.densify_grad_threshold = 0.0002 129 | super().__init__(parser, "Optimization Parameters") 130 | 131 | 132 | def get_combined_args(parser: ArgumentParser): 133 | cmdlne_string = sys.argv[1:] 134 | cfgfile_string = "Namespace()" 135 | args_cmdline = parser.parse_args(cmdlne_string) 136 | 137 | try: 138 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 139 | print("Looking for config file in", cfgfilepath) 140 | with open(cfgfilepath) as cfg_file: 141 | print("Config file found: {}".format(cfgfilepath)) 142 | cfgfile_string = cfg_file.read() 143 | except TypeError: 144 | print("Config file not found at") 145 | pass 146 | args_cfgfile = eval(cfgfile_string) 147 | 148 | merged_dict = vars(args_cfgfile).copy() 149 | for k, v in vars(args_cmdline).items(): 150 | if v != None: 151 | merged_dict[k] = v 152 | return Namespace(**merged_dict) 153 | -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CF-3DGS/886df529c7dabd337f0abbd660e81321a9a1c047/assets/teaser.gif -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import os 11 | import logging 12 | from argparse import ArgumentParser 13 | import shutil 14 | 15 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
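# Example invocation (paths are illustrative; see the README's "Run on your own video" section):
#   python convert.py -s ./data/$CUSTOM_DATA --resize
# The script expects raw frames in <source_path>/input, writes the COLMAP database and an
# initial sparse model under <source_path>/distorted/, then undistorts the images into
# <source_path>/images and moves the pinhole camera model into <source_path>/sparse/0.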
16 | parser = ArgumentParser("Colmap converter") 17 | parser.add_argument("--no_gpu", action='store_true') 18 | parser.add_argument("--skip_matching", action='store_true') 19 | parser.add_argument("--source_path", "-s", required=True, type=str) 20 | parser.add_argument("--camera", default="OPENCV", type=str) 21 | parser.add_argument("--colmap_executable", default="", type=str) 22 | parser.add_argument("--resize", action="store_true") 23 | parser.add_argument("--magick_executable", default="", type=str) 24 | args = parser.parse_args() 25 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 26 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 27 | use_gpu = 1 if not args.no_gpu else 0 28 | 29 | if not args.skip_matching: 30 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 31 | 32 | ## Feature extraction 33 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 34 | "--database_path " + args.source_path + "/distorted/database.db \ 35 | --image_path " + args.source_path + "/input \ 36 | --ImageReader.single_camera 1 \ 37 | --ImageReader.camera_model " + args.camera + " \ 38 | --SiftExtraction.use_gpu " + str(use_gpu) 39 | exit_code = os.system(feat_extracton_cmd) 40 | if exit_code != 0: 41 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 42 | exit(exit_code) 43 | 44 | ## Feature matching 45 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 46 | --database_path " + args.source_path + "/distorted/database.db \ 47 | --SiftMatching.use_gpu " + str(use_gpu) 48 | exit_code = os.system(feat_matching_cmd) 49 | if exit_code != 0: 50 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 51 | exit(exit_code) 52 | 53 | ### Bundle adjustment 54 | # The default Mapper tolerance is unnecessarily large, 55 | # decreasing it speeds up bundle adjustment steps. 56 | mapper_cmd = (colmap_command + " mapper \ 57 | --database_path " + args.source_path + "/distorted/database.db \ 58 | --image_path " + args.source_path + "/input \ 59 | --output_path " + args.source_path + "/distorted/sparse \ 60 | --Mapper.ba_global_function_tolerance=0.000001") 61 | exit_code = os.system(mapper_cmd) 62 | if exit_code != 0: 63 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 64 | exit(exit_code) 65 | 66 | ### Image undistortion 67 | ## We need to undistort our images into ideal pinhole intrinsics. 68 | img_undist_cmd = (colmap_command + " image_undistorter \ 69 | --image_path " + args.source_path + "/input \ 70 | --input_path " + args.source_path + "/distorted/sparse/0 \ 71 | --output_path " + args.source_path + "\ 72 | --output_type COLMAP") 73 | exit_code = os.system(img_undist_cmd) 74 | if exit_code != 0: 75 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 76 | exit(exit_code) 77 | 78 | files = os.listdir(args.source_path + "/sparse") 79 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 80 | # Copy each file from the source directory to the destination directory 81 | for file in files: 82 | if file == '0': 83 | continue 84 | source_file = os.path.join(args.source_path, "sparse", file) 85 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 86 | shutil.move(source_file, destination_file) 87 | 88 | if(args.resize): 89 | print("Copying and resizing...") 90 | 91 | # Resize images. 
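# Note: the resize step below shells out to ImageMagick's `mogrify` to build images_2 / images_4 /
# images_8 (50%, 25%, 12.5% of the original size); if `magick` is not on PATH, point
# --magick_executable at it.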
92 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 93 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 94 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 95 | # Get the list of files in the source directory 96 | files = os.listdir(args.source_path + "/images") 97 | # Copy each file from the source directory to the destination directory 98 | for file in files: 99 | source_file = os.path.join(args.source_path, "images", file) 100 | 101 | destination_file = os.path.join(args.source_path, "images_2", file) 102 | shutil.copy2(source_file, destination_file) 103 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 104 | if exit_code != 0: 105 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 106 | exit(exit_code) 107 | 108 | destination_file = os.path.join(args.source_path, "images_4", file) 109 | shutil.copy2(source_file, destination_file) 110 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 111 | if exit_code != 0: 112 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 113 | exit(exit_code) 114 | 115 | destination_file = os.path.join(args.source_path, "images_8", file) 116 | shutil.copy2(source_file, destination_file) 117 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 118 | if exit_code != 0: 119 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 120 | exit(exit_code) 121 | 122 | print("Done.") 123 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussian_splatting 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - pip: 16 | - submodules/diff-gaussian-rasterization 17 | - submodules/simple-knn -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import torch 11 | import math 12 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 13 | from scene.gaussian_model import GaussianModel 14 | from utils.sh_utils import eval_sh 15 | from utils.image_utils import colorize 16 | from PIL import Image 17 | import pdb 18 | 19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, 20 | scaling_modifier = 1.0, override_color = None): 21 | """ 22 | Render the scene. 23 | 24 | Background tensor (bg_color) must be on GPU! 25 | """ 26 | 27 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 28 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 29 | try: 30 | screenspace_points.retain_grad() 31 | except: 32 | pass 33 | 34 | # Set up rasterization configuration 35 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 36 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 37 | 38 | raster_settings = GaussianRasterizationSettings( 39 | image_height=int(viewpoint_camera.image_height), 40 | image_width=int(viewpoint_camera.image_width), 41 | tanfovx=tanfovx, 42 | tanfovy=tanfovy, 43 | bg=bg_color, 44 | scale_modifier=scaling_modifier, 45 | viewmatrix=viewpoint_camera.world_view_transform, 46 | projmatrix=viewpoint_camera.full_proj_transform, 47 | sh_degree=pc.active_sh_degree, 48 | campos=viewpoint_camera.camera_center, 49 | prefiltered=False, 50 | debug=pipe.debug 51 | ) 52 | 53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 54 | 55 | means3D = pc.get_xyz 56 | means2D = screenspace_points 57 | opacity = pc.get_opacity 58 | 59 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 60 | # scaling / rotation by the rasterizer. 61 | scales = None 62 | rotations = None 63 | cov3D_precomp = None 64 | if pipe.compute_cov3D_python: 65 | cov3D_precomp = pc.get_covariance(scaling_modifier) 66 | else: 67 | scales = pc.get_scaling 68 | rotations = pc.get_rotation 69 | 70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 72 | shs = None 73 | colors_precomp = None 74 | if override_color is None: 75 | if pipe.convert_SHs_python: 76 | pc_features = pc.get_features.transpose(1, 2) 77 | shs_view = pc_features.view(pc_features.shape[0], -1, (pc.max_sh_degree+1)**2) 78 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 79 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 80 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 81 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 82 | else: 83 | shs = pc.get_features 84 | else: 85 | colors_precomp = override_color 86 | 87 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 88 | render_out = rasterizer( 89 | means3D = means3D, 90 | means2D = means2D, 91 | shs = shs, 92 | colors_precomp = colors_precomp, 93 | opacities = opacity, 94 | scales = scales, 95 | rotations = rotations, 96 | cov3D_precomp = cov3D_precomp) 97 | 98 | if len(render_out) > 2: 99 | rendered_image, radii, rendered_depth, rendered_alpha = render_out 100 | return {"render": rendered_image, 101 | "viewspace_points": screenspace_points, 102 | "visibility_filter" : radii > 0, 103 | "radii": radii, 104 | "depth": rendered_depth, 105 | "alpha": rendered_alpha} 106 | else: 107 | rendered_image, radii = render_out 108 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 109 | # They will be excluded from value updates used in the splitting criteria. 
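# `radii` holds each Gaussian's rendered screen-space extent, so `visibility_filter`
# (radii > 0) marks the Gaussians that actually contributed to this view.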
110 | return {"render": rendered_image, 111 | "viewspace_points": screenspace_points, 112 | "visibility_filter" : radii > 0, 113 | "radii": radii} 114 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import torch 11 | import traceback 12 | import socket 13 | import json 14 | from scene.cameras import MiniCam 15 | 16 | host = "127.0.0.1" 17 | port = 6009 18 | 19 | conn = None 20 | addr = None 21 | 22 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 23 | 24 | def init(wish_host, wish_port): 25 | global host, port, listener 26 | host = wish_host 27 | port = wish_port 28 | listener.bind((host, port)) 29 | listener.listen() 30 | listener.settimeout(0) 31 | 32 | def try_connect(): 33 | global conn, addr, listener 34 | try: 35 | conn, addr = listener.accept() 36 | print(f"\nConnected by {addr}") 37 | conn.settimeout(None) 38 | except Exception as inst: 39 | pass 40 | 41 | def read(): 42 | global conn 43 | messageLength = conn.recv(4) 44 | messageLength = int.from_bytes(messageLength, 'little') 45 | message = conn.recv(messageLength) 46 | return json.loads(message.decode("utf-8")) 47 | 48 | def send(message_bytes, verify): 49 | global conn 50 | if message_bytes != None: 51 | conn.sendall(message_bytes) 52 | conn.sendall(len(verify).to_bytes(4, 'little')) 53 | conn.sendall(bytes(verify, 'ascii')) 54 | 55 | def receive(): 56 | message = read() 57 | 58 | width = message["resolution_x"] 59 | height = message["resolution_y"] 60 | 61 | if width != 0 and height != 0: 62 | try: 63 | do_training = bool(message["train"]) 64 | fovy = message["fov_y"] 65 | fovx = message["fov_x"] 66 | znear = message["z_near"] 67 | zfar = message["z_far"] 68 | do_shs_python = bool(message["shs_python"]) 69 | do_rot_scale_python = bool(message["rot_scale_python"]) 70 | keep_alive = bool(message["keep_alive"]) 71 | scaling_modifier = message["scaling_modifier"] 72 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 73 | world_view_transform[:,1] = -world_view_transform[:,1] 74 | world_view_transform[:,2] = -world_view_transform[:,2] 75 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 76 | full_proj_transform[:,1] = -full_proj_transform[:,1] 77 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 78 | except Exception as e: 79 | print("") 80 | traceback.print_exc() 81 | raise e 82 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 83 | else: 84 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 
12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | 
output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 
7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | from pathlib import Path 11 | import os 12 | from PIL import Image 13 | import torch 14 | import torchvision.transforms.functional as tf 15 | from utils.loss_utils import ssim 16 | from lpipsPyTorch import lpips 17 | import json 18 | from tqdm import tqdm 19 | from utils.image_utils import psnr 20 | from argparse import ArgumentParser 21 | 22 | def readImages(renders_dir, gt_dir): 23 | renders = [] 24 | gts = [] 25 | image_names = [] 26 | for fname in os.listdir(renders_dir): 27 | render = Image.open(renders_dir / fname) 28 | gt = Image.open(gt_dir / fname) 29 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 30 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 31 | image_names.append(fname) 32 | return renders, gts, image_names 33 | 34 | def evaluate(model_paths): 35 | 36 | full_dict = {} 37 | per_view_dict = {} 38 | full_dict_polytopeonly = {} 39 | per_view_dict_polytopeonly = {} 40 | print("") 41 | 42 | for scene_dir in model_paths: 43 | try: 44 | print("Scene:", scene_dir) 45 | full_dict[scene_dir] = {} 46 | per_view_dict[scene_dir] = {} 47 | full_dict_polytopeonly[scene_dir] = {} 48 | per_view_dict_polytopeonly[scene_dir] = {} 49 | 50 | test_dir = Path(scene_dir) / "test" 51 | 52 | for method in os.listdir(test_dir): 53 | print("Method:", method) 54 | 55 | full_dict[scene_dir][method] = {} 56 | per_view_dict[scene_dir][method] = {} 57 | full_dict_polytopeonly[scene_dir][method] = {} 58 | per_view_dict_polytopeonly[scene_dir][method] = {} 59 | 60 | method_dir = test_dir / method 61 | gt_dir = method_dir/ "gt" 62 | renders_dir = method_dir / "renders" 63 | renders, gts, image_names = readImages(renders_dir, gt_dir) 64 | 65 | ssims = [] 66 | psnrs = [] 67 | lpipss = [] 68 | 69 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 70 | ssims.append(ssim(renders[idx], gts[idx])) 71 | psnrs.append(psnr(renders[idx], gts[idx])) 72 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 73 | 74 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 75 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 76 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 77 | print("") 78 | 79 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 80 | "PSNR": torch.tensor(psnrs).mean().item(), 81 | "LPIPS": torch.tensor(lpipss).mean().item()}) 82 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 83 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 84 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 85 | 86 | with open(scene_dir + "/results.json", 'w') as fp: 87 | json.dump(full_dict[scene_dir], fp, indent=True) 88 | with open(scene_dir + "/per_view.json", 'w') as fp: 89 | json.dump(per_view_dict[scene_dir], fp, indent=True) 90 | except: 91 | print("Unable to compute metrics for model", scene_dir) 92 | 93 | if __name__ == "__main__": 94 | device = torch.device("cuda:0") 95 | torch.cuda.set_device(device) 96 | 97 | # Set up command line argument parser 98 | parser = ArgumentParser(description="Training script parameters") 99 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 100 | args = parser.parse_args() 101 | evaluate(args.model_paths) 102 | 
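The `evaluate` function above can also be imported and called directly; a minimal sketch, assuming a model directory (the path below is only an example) that already contains `test/<method>/renders` and `test/<method>/gt` image folders:

```python
# Minimal sketch: compute SSIM / PSNR / LPIPS for one trained model directory.
# The path is an example; it must contain test/<method>/{renders, gt} produced by rendering.
from metrics import evaluate

evaluate(["./output/progressive/Tanks_Francis"])  # writes results.json and per_view.json there
```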
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import torch 11 | from scene import Scene 12 | import os 13 | from tqdm import tqdm 14 | from os import makedirs 15 | from gaussian_renderer import render 16 | import torchvision 17 | from utils.general_utils import safe_state 18 | from argparse import ArgumentParser 19 | from arguments import ModelParams, PipelineParams, get_combined_args 20 | from gaussian_renderer import GaussianModel 21 | 22 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 23 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 24 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 25 | 26 | makedirs(render_path, exist_ok=True) 27 | makedirs(gts_path, exist_ok=True) 28 | 29 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 30 | rendering = render(view, gaussians, pipeline, background)["render"] 31 | gt = view.original_image[0:3, :, :] 32 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 33 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 34 | 35 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): 36 | with torch.no_grad(): 37 | gaussians = GaussianModel(dataset.sh_degree) 38 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 39 | 40 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 41 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 42 | 43 | if not skip_train: 44 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 45 | 46 | if not skip_test: 47 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 48 | 49 | if __name__ == "__main__": 50 | # Set up command line argument parser 51 | parser = ArgumentParser(description="Testing script parameters") 52 | model = ModelParams(parser, sentinel=True) 53 | pipeline = PipelineParams(parser) 54 | parser.add_argument("--iteration", default=-1, type=int) 55 | parser.add_argument("--skip_train", action="store_true") 56 | parser.add_argument("--skip_test", action="store_true") 57 | parser.add_argument("--quiet", action="store_true") 58 | args = get_combined_args(parser) 59 | print("Rendering " + args.model_path) 60 | 61 | # Initialize system state (RNG) 62 | safe_state(args.quiet) 63 | 64 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.65.0 2 | Pillow 3 | opencv-python 4 | imageio 5 | matplotlib==3.7 6 | pyyaml 7 | plyfile 8 | kornia 9 | open3d 10 | timm==0.6.11 11 | submodules/diff-gaussian-rasterization 12 | submodules/simple-knn 13 | 
git+https://github.com/facebookresearch/pytorch3d.git@stable 14 | git+https://github.com/princeton-vl/lietorch.git -------------------------------------------------------------------------------- /run_cf3dgs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import sys 11 | from argparse import ArgumentParser, Namespace 12 | 13 | from trainer.cf3dgs_trainer import CFGaussianTrainer 14 | from arguments import ModelParams, PipelineParams, OptimizationParams 15 | 16 | import torch 17 | import pdb 18 | from datetime import datetime 19 | 20 | 21 | def contruct_pose(poses): 22 | n_trgt = poses.shape[0] 23 | # for i in range(n_trgt-1): 24 | # poses[i+1] = poses[i+1] @ torch.inverse(poses[i]) 25 | for i in range(n_trgt-1, 0, -1): 26 | poses = torch.cat( 27 | (poses[:i], poses[[i-1]]@poses[i:]), 0) 28 | return poses 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = ArgumentParser(description="Training script parameters") 33 | lp = ModelParams(parser) 34 | op = OptimizationParams(parser) 35 | pp = PipelineParams(parser) 36 | args = parser.parse_args(sys.argv[1:]) 37 | model_cfg = lp.extract(args) 38 | pipe_cfg = pp.extract(args) 39 | optim_cfg = op.extract(args) 40 | # hydrant/615_99120_197713 41 | # hydrant/106_12648_23157 42 | # teddybear/34_1403_4393 43 | data_path = model_cfg.source_path 44 | trainer = CFGaussianTrainer(data_path, model_cfg, pipe_cfg, optim_cfg) 45 | start_time = datetime.now() 46 | if model_cfg.mode == "train": 47 | trainer.train_from_progressive() 48 | elif model_cfg.mode == "render": 49 | trainer.render_nvs(traj_opt=model_cfg.traj_opt) 50 | elif model_cfg.mode == "eval_nvs": 51 | trainer.eval_nvs() 52 | elif model_cfg.mode == "eval_pose": 53 | trainer.eval_pose() 54 | end_time = datetime.now() 55 | print('Duration: {}'.format(end_time - start_time)) 56 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 26 | """b 27 | :param path: Path to colmap scene main folder. 
28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | if load_iteration: 34 | if load_iteration == -1: 35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 36 | else: 37 | self.loaded_iter = load_iteration 38 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 39 | 40 | self.train_cameras = {} 41 | self.test_cameras = {} 42 | 43 | if os.path.exists(os.path.join(args.source_path, "sparse")): 44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 45 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 46 | print("Found transforms_train.json file, assuming Blender data set!") 47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 48 | else: 49 | assert False, "Could not recognize scene type!" 50 | 51 | if not self.loaded_iter: 52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 53 | dest_file.write(src_file.read()) 54 | json_cams = [] 55 | camlist = [] 56 | if scene_info.test_cameras: 57 | camlist.extend(scene_info.test_cameras) 58 | if scene_info.train_cameras: 59 | camlist.extend(scene_info.train_cameras) 60 | for id, cam in enumerate(camlist): 61 | json_cams.append(camera_to_JSON(id, cam)) 62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 63 | json.dump(json_cams, file) 64 | 65 | if shuffle: 66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 68 | 69 | self.cameras_extent = scene_info.nerf_normalization["radius"] 70 | 71 | for resolution_scale in resolution_scales: 72 | print("Loading Training Cameras") 73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 74 | print("Loading Test Cameras") 75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 76 | 77 | if self.loaded_iter: 78 | self.gaussians.load_ply(os.path.join(self.model_path, 79 | "point_cloud", 80 | "iteration_" + str(self.loaded_iter), 81 | "point_cloud.ply")) 82 | else: 83 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 84 | 85 | def save(self, iteration): 86 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 87 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 88 | 89 | def getTrainCameras(self, scale=1.0): 90 | return self.train_cameras[scale] 91 | 92 | def getTestCameras(self, scale=1.0): 93 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/camera_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from lietorch import SO3, SE3 13 | class CameraModel(nn.Module): 14 | def __init__(self, num_cams, pose_dim=9, init_pose=None): 15 | """ 16 | :param num_cams: 17 | :param learn_R: True/False 18 | :param learn_t: True/False 19 | :param cfg: config argument options 20 | :param init_c2w: (N, 4, 4) torch tensor 21 | """ 22 | super(CameraModel, self).__init__() 23 | self.register_buffer('init_pose', init_pose) 24 | 25 | img_emb_dim = 1 26 | pose_dim = 9 27 | self.img_embedding = nn.Embedding(num_cams, img_emb_dim) 28 | self.cam_mlp = nn.Sequential( 29 | nn.Linear(img_emb_dim+pose_dim, 32), 30 | nn.ReLU(), 31 | nn.Linear(32, pose_dim)) 32 | self.factor = 0.01 33 | 34 | def forward(self, cam_id): 35 | cam_id = torch.int(cam_id).cuda() 36 | img_emb = self.img_embedding(cam_id) 37 | init_cam = self.init_pose[cam_id] 38 | cam_residual = self.cam_mlp(torch.cat([img_emb, init_cam], dim=1)) 39 | cam = init_cam + self.factor * cam_residual 40 | cam = SE3(cam) 41 | 42 | return cam 43 | 44 | 45 | def vec2skew(v): 46 | """ 47 | :param v: (3, ) torch tensor 48 | :return: (3, 3) 49 | """ 50 | zero = torch.zeros(1, dtype=torch.float32, device=v.device) 51 | skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1) 52 | skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]]) 53 | skew_v2 = torch.cat([-v[1:2], v[0:1], zero]) 54 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3) 55 | return skew_v # (3, 3) 56 | 57 | def Exp(r): 58 | """so(3) vector to SO(3) matrix 59 | :param r: (3, ) axis-angle, torch tensor 60 | :return: (3, 3) 61 | """ 62 | skew_r = vec2skew(r) # (3, 3) 63 | norm_r = r.norm() + 1e-15 64 | eye = torch.eye(3, dtype=torch.float32, device=r.device) 65 | R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r) 66 | return R 67 | 68 | def make_c2w(r, t): 69 | """ 70 | :param r: (3, ) axis-angle torch tensor 71 | :param t: (3, ) translation vector torch tensor 72 | :return: (4, 4) 73 | """ 74 | R = Exp(r) # (3, 3) 75 | c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4) 76 | c2w = convert3x4_4x4(c2w) # (4, 4) 77 | return c2w 78 | 79 | def convert3x4_4x4(input): 80 | """ 81 | :param input: (N, 3, 4) or (3, 4) torch or np 82 | :return: (N, 4, 4) or (4, 4) torch or np 83 | """ 84 | if torch.is_tensor(input): 85 | if len(input.shape) == 3: 86 | output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4) 87 | output[:, 3, 3] = 1.0 88 | else: 89 | output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 90 | else: 91 | if len(input.shape) == 3: 92 | output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 93 | output[:, 3, 3] = 1.0 94 | else: 95 | output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4) 96 | output[3, 3] = 1.0 97 | return output -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 
7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | 11 | import torch 12 | from torch import nn 13 | import numpy as np 14 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getWorld2View3 15 | 16 | class Camera(nn.Module): 17 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 18 | image_name, uid, intrinsics=None, 19 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", 20 | do_grad=False, is_co3d=False, 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | self.intrinsics = intrinsics.astype(np.float32) 32 | 33 | try: 34 | self.data_device = torch.device(data_device) 35 | except Exception as e: 36 | print(e) 37 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 38 | self.data_device = torch.device("cuda") 39 | 40 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 41 | self.image_width = self.original_image.shape[2] 42 | self.image_height = self.original_image.shape[1] 43 | 44 | if gt_alpha_mask is not None: 45 | self.original_image *= gt_alpha_mask.to(self.data_device) 46 | else: 47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 48 | 49 | self.gt_alpha_mask = gt_alpha_mask 50 | 51 | self.zfar = 100.0 52 | self.znear = 0.01 53 | 54 | self.trans = trans 55 | self.scale = scale 56 | 57 | if is_co3d: 58 | self.world_view_transform = torch.tensor(getWorld2View3(R, T, trans, scale)).transpose(0, 1).cuda() 59 | w, h = self.image_width, self.image_height 60 | fx, fy, cx, cy = self.intrinsics[0, 0], self.intrinsics[1, 1], \ 61 | self.intrinsics[0, 2], self.intrinsics[1, 2] 62 | far, near = self.zfar, self.znear 63 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0], 64 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0], 65 | [0.0, 0.0, far / (far - near), -(far * near) / (far - near)], 66 | [0.0, 0.0, 1.0, 0.0]]).cuda().float().transpose(0, 1) 67 | self.projection_matrix = opengl_proj 68 | 69 | else: 70 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 71 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 72 | 73 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 74 | self.camera_center = self.world_view_transform.inverse()[3, :3] 75 | 76 | class MiniCam: 77 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 78 | self.image_width = width 79 | self.image_height = height 80 | self.FoVy = fovy 81 | self.FoVx = fovx 82 | self.znear = znear 83 | self.zfar = zfar 84 | self.world_view_transform = world_view_transform 85 | self.full_proj_transform = full_proj_transform 86 | view_inv = torch.inverse(self.world_view_transform) 87 | self.camera_center = view_inv[3][:3] 88 | 89 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 
4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import numpy as np 11 | import collections 12 | import struct 13 | 14 | CameraModel = collections.namedtuple( 15 | "CameraModel", ["model_id", "model_name", "num_params"]) 16 | Camera = collections.namedtuple( 17 | "Camera", ["id", "model", "width", "height", "params"]) 18 | BaseImage = collections.namedtuple( 19 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 20 | Point3D = collections.namedtuple( 21 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 22 | CAMERA_MODELS = { 23 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 24 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 25 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 26 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 27 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 28 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 29 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 30 | CameraModel(model_id=7, model_name="FOV", num_params=5), 31 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 32 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 33 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 34 | } 35 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 36 | for camera_model in CAMERA_MODELS]) 37 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | 40 | 41 | def qvec2rotmat(qvec): 42 | return np.array([ 43 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 44 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 45 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 46 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 47 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 48 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 49 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 50 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 51 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 52 | 53 | def rotmat2qvec(R): 54 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 55 | K = np.array([ 56 | [Rxx - Ryy - Rzz, 0, 0, 0], 57 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 58 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 59 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 60 | eigvals, eigvecs = np.linalg.eigh(K) 61 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 62 | if qvec[0] < 0: 63 | qvec *= -1 64 | return qvec 65 | 66 | class Image(BaseImage): 67 | def qvec2rotmat(self): 68 | return qvec2rotmat(self.qvec) 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
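    Example (illustrative): read_next_bytes(fid, 24, "ddq") reads 24 bytes and
    unpacks them as two little-endian doubles followed by one signed 8-byte
    integer, matching how the 2-D point entries are parsed further below.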
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | def read_points3D_text(path): 82 | """ 83 | see: src/base/reconstruction.cc 84 | void Reconstruction::ReadPoints3DText(const std::string& path) 85 | void Reconstruction::WritePoints3DText(const std::string& path) 86 | """ 87 | xyzs = None 88 | rgbs = None 89 | errors = None 90 | with open(path, "r") as fid: 91 | while True: 92 | line = fid.readline() 93 | if not line: 94 | break 95 | line = line.strip() 96 | if len(line) > 0 and line[0] != "#": 97 | elems = line.split() 98 | xyz = np.array(tuple(map(float, elems[1:4]))) 99 | rgb = np.array(tuple(map(int, elems[4:7]))) 100 | error = np.array(float(elems[7])) 101 | if xyzs is None: 102 | xyzs = xyz[None, ...] 103 | rgbs = rgb[None, ...] 104 | errors = error[None, ...] 105 | else: 106 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 107 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 108 | errors = np.append(errors, error[None, ...], axis=0) 109 | return xyzs, rgbs, errors 110 | 111 | def read_points3D_binary(path_to_model_file): 112 | """ 113 | see: src/base/reconstruction.cc 114 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 115 | void Reconstruction::WritePoints3DBinary(const std::string& path) 116 | """ 117 | 118 | 119 | with open(path_to_model_file, "rb") as fid: 120 | num_points = read_next_bytes(fid, 8, "Q")[0] 121 | 122 | xyzs = np.empty((num_points, 3)) 123 | rgbs = np.empty((num_points, 3)) 124 | errors = np.empty((num_points, 1)) 125 | 126 | for p_id in range(num_points): 127 | binary_point_line_properties = read_next_bytes( 128 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 129 | xyz = np.array(binary_point_line_properties[1:4]) 130 | rgb = np.array(binary_point_line_properties[4:7]) 131 | error = np.array(binary_point_line_properties[7]) 132 | track_length = read_next_bytes( 133 | fid, num_bytes=8, format_char_sequence="Q")[0] 134 | track_elems = read_next_bytes( 135 | fid, num_bytes=8*track_length, 136 | format_char_sequence="ii"*track_length) 137 | xyzs[p_id] = xyz 138 | rgbs[p_id] = rgb 139 | errors[p_id] = error 140 | return xyzs, rgbs, errors 141 | 142 | def read_intrinsics_text(path): 143 | """ 144 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 145 | """ 146 | cameras = {} 147 | with open(path, "r") as fid: 148 | while True: 149 | line = fid.readline() 150 | if not line: 151 | break 152 | line = line.strip() 153 | if len(line) > 0 and line[0] != "#": 154 | elems = line.split() 155 | camera_id = int(elems[0]) 156 | model = elems[1] 157 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 158 | width = int(elems[2]) 159 | height = int(elems[3]) 160 | params = np.array(tuple(map(float, elems[4:]))) 161 | cameras[camera_id] = Camera(id=camera_id, model=model, 162 | width=width, height=height, 163 | params=params) 164 | return cameras 165 | 166 | def read_extrinsics_binary(path_to_model_file): 167 | """ 168 | see: src/base/reconstruction.cc 169 | void Reconstruction::ReadImagesBinary(const std::string& path) 170 | void Reconstruction::WriteImagesBinary(const std::string& path) 171 | """ 172 | images = {} 173 | with open(path_to_model_file, "rb") as fid: 174 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 175 | for _ in range(num_reg_images): 176 | binary_image_properties = read_next_bytes( 177 | fid, num_bytes=64, format_char_sequence="idddddddi") 178 | image_id = 
binary_image_properties[0] 179 | qvec = np.array(binary_image_properties[1:5]) 180 | tvec = np.array(binary_image_properties[5:8]) 181 | camera_id = binary_image_properties[8] 182 | image_name = "" 183 | current_char = read_next_bytes(fid, 1, "c")[0] 184 | while current_char != b"\x00": # look for the ASCII 0 entry 185 | image_name += current_char.decode("utf-8") 186 | current_char = read_next_bytes(fid, 1, "c")[0] 187 | num_points2D = read_next_bytes(fid, num_bytes=8, 188 | format_char_sequence="Q")[0] 189 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 190 | format_char_sequence="ddq"*num_points2D) 191 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 192 | tuple(map(float, x_y_id_s[1::3]))]) 193 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 194 | images[image_id] = Image( 195 | id=image_id, qvec=qvec, tvec=tvec, 196 | camera_id=camera_id, name=image_name, 197 | xys=xys, point3D_ids=point3D_ids) 198 | return images 199 | 200 | 201 | def read_intrinsics_binary(path_to_model_file): 202 | """ 203 | see: src/base/reconstruction.cc 204 | void Reconstruction::WriteCamerasBinary(const std::string& path) 205 | void Reconstruction::ReadCamerasBinary(const std::string& path) 206 | """ 207 | cameras = {} 208 | with open(path_to_model_file, "rb") as fid: 209 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 210 | for _ in range(num_cameras): 211 | camera_properties = read_next_bytes( 212 | fid, num_bytes=24, format_char_sequence="iiQQ") 213 | camera_id = camera_properties[0] 214 | model_id = camera_properties[1] 215 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 216 | width = camera_properties[2] 217 | height = camera_properties[3] 218 | num_params = CAMERA_MODEL_IDS[model_id].num_params 219 | params = read_next_bytes(fid, num_bytes=8*num_params, 220 | format_char_sequence="d"*num_params) 221 | cameras[camera_id] = Camera(id=camera_id, 222 | model=model_name, 223 | width=width, 224 | height=height, 225 | params=np.array(params)) 226 | assert len(cameras) == num_cameras 227 | return cameras 228 | 229 | 230 | def read_extrinsics_text(path): 231 | """ 232 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 233 | """ 234 | images = {} 235 | with open(path, "r") as fid: 236 | while True: 237 | line = fid.readline() 238 | if not line: 239 | break 240 | line = line.strip() 241 | if len(line) > 0 and line[0] != "#": 242 | elems = line.split() 243 | image_id = int(elems[0]) 244 | qvec = np.array(tuple(map(float, elems[1:5]))) 245 | tvec = np.array(tuple(map(float, elems[5:8]))) 246 | camera_id = int(elems[8]) 247 | image_name = elems[9] 248 | elems = fid.readline().split() 249 | xys = np.column_stack([tuple(map(float, elems[0::3])), 250 | tuple(map(float, elems[1::3]))]) 251 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 252 | images[image_id] = Image( 253 | id=image_id, qvec=qvec, tvec=tvec, 254 | camera_id=camera_id, name=image_name, 255 | xys=xys, point3D_ids=point3D_ids) 256 | return images 257 | 258 | 259 | def read_colmap_bin_array(path): 260 | """ 261 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 262 | 263 | :param path: path to the colmap binary file. 
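    The file starts with an ASCII header of the form "width&height&channels&",
    followed by the raw float32 values in column-major (Fortran) order, which
    is exactly what the parsing code below assumes.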
264 | :return: nd array with the floating point values in the value 265 | """ 266 | with open(path, "rb") as fid: 267 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 268 | usecols=(0, 1, 2), dtype=int) 269 | fid.seek(0) 270 | num_delimiter = 0 271 | byte = fid.read(1) 272 | while True: 273 | if byte == b"&": 274 | num_delimiter += 1 275 | if num_delimiter >= 3: 276 | break 277 | byte = fid.read(1) 278 | array = np.fromfile(fid, np.float32) 279 | array = array.reshape((width, height, channels), order="F") 280 | return np.transpose(array, (1, 0, 2)).squeeze() 281 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import os 11 | import sys 12 | from PIL import Image 13 | from typing import NamedTuple 14 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 15 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 16 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 17 | import numpy as np 18 | import json 19 | from pathlib import Path 20 | from plyfile import PlyData, PlyElement 21 | from utils.sh_utils import SH2RGB 22 | from scene.gaussian_model import BasicPointCloud 23 | 24 | 25 | class CameraInfo(NamedTuple): 26 | uid: int 27 | R: np.array 28 | T: np.array 29 | FovY: np.array 30 | FovX: np.array 31 | image: np.array 32 | intrinsics: np.array 33 | image_path: str 34 | image_name: str 35 | width: int 36 | height: int 37 | 38 | 39 | class SceneInfo(NamedTuple): 40 | point_cloud: BasicPointCloud 41 | train_cameras: list 42 | test_cameras: list 43 | nerf_normalization: dict 44 | ply_path: str 45 | 46 | 47 | def getNerfppNorm(cam_info): 48 | def get_center_and_diag(cam_centers): 49 | cam_centers = np.hstack(cam_centers) 50 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 51 | center = avg_cam_center 52 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 53 | diagonal = np.max(dist) 54 | return center.flatten(), diagonal 55 | 56 | cam_centers = [] 57 | 58 | for cam in cam_info: 59 | W2C = getWorld2View2(cam.R, cam.T) 60 | C2W = np.linalg.inv(W2C) 61 | cam_centers.append(C2W[:3, 3:4]) 62 | 63 | center, diagonal = get_center_and_diag(cam_centers) 64 | radius = diagonal * 1.1 65 | 66 | translate = -center 67 | 68 | return {"translate": translate, "radius": radius} 69 | 70 | 71 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 72 | cam_infos = [] 73 | for idx, key in enumerate(cam_extrinsics): 74 | sys.stdout.write('\r') 75 | # the exact output you're looking for: 76 | sys.stdout.write( 77 | "Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 78 | sys.stdout.flush() 79 | 80 | extr = cam_extrinsics[key] 81 | intr = cam_intrinsics[extr.camera_id] 82 | height = intr.height 83 | width = intr.width 84 | 85 | uid = intr.id 86 | R = np.transpose(qvec2rotmat(extr.qvec)) 87 | T = np.array(extr.tvec) 88 | if intr.model == "SIMPLE_PINHOLE" or intr.model == "SIMPLE_RADIAL": 89 | focal_length_x = intr.params[0] 90 | FovY = focal2fov(focal_length_x, height) 91 | FovX = 
focal2fov(focal_length_x, width) 92 | intr_mat = np.array( 93 | [[focal_length_x, 0, width/2], [0, focal_length_x, height/2], [0, 0, 1]]) 94 | elif intr.model == "PINHOLE": 95 | focal_length_x = intr.params[0] 96 | focal_length_y = intr.params[1] 97 | FovY = focal2fov(focal_length_y, height) 98 | FovX = focal2fov(focal_length_x, width) 99 | intr_mat = np.array( 100 | [[focal_length_x, 0, width/2], [0, focal_length_y, height/2], [0, 0, 1]]) 101 | else: 102 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 103 | 104 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 105 | image_name = os.path.basename(image_path).split(".")[0] 106 | image = Image.open(image_path) 107 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, intrinsics=intr_mat, 108 | image_path=image_path, image_name=image_name, width=width, height=height) 109 | cam_infos.append(cam_info) 110 | sys.stdout.write('\n') 111 | return cam_infos 112 | 113 | 114 | def fetchPly(path): 115 | plydata = PlyData.read(path) 116 | vertices = plydata['vertex'] 117 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 118 | colors = np.vstack([vertices['red'], vertices['green'], 119 | vertices['blue']]).T / 255.0 120 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 121 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 122 | 123 | 124 | def storePly(path, xyz, rgb): 125 | # Define the dtype for the structured array 126 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 127 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 128 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 129 | 130 | normals = np.zeros_like(xyz) 131 | 132 | elements = np.empty(xyz.shape[0], dtype=dtype) 133 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 134 | elements[:] = list(map(tuple, attributes)) 135 | 136 | # Create the PlyData object and write to file 137 | vertex_element = PlyElement.describe(elements, 'vertex') 138 | ply_data = PlyData([vertex_element]) 139 | ply_data.write(path) 140 | 141 | 142 | def readColmapSceneInfo(path, images, eval, llffhold=8): 143 | try: 144 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 145 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 146 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 147 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 148 | except: 149 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 150 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 151 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 152 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 153 | 154 | reading_dir = "images" if images == None else images 155 | cam_infos_unsorted = readColmapCameras( 156 | cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 157 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) 158 | 159 | if eval: 160 | # train_cam_infos = [c for idx, c in enumerate( 161 | # cam_infos) if idx % llffhold != 0] 162 | # test_cam_infos = [c for idx, c in enumerate( 163 | # cam_infos) if idx % llffhold == 0] 164 | sample_rate = 2 if "Family" in path else 8 165 | # sample_rate = 8 166 | ids = np.arange(len(cam_infos)) 167 | i_test = ids[int(sample_rate/2)::sample_rate] 168 | i_train = np.array([i for i in ids if i not in 
i_test])
169 |         train_cam_infos = [cam_infos[i] for i in i_train]
170 |         test_cam_infos = [cam_infos[i] for i in i_test]
171 |     else:
172 |         train_cam_infos = cam_infos
173 |         test_cam_infos = []
174 | 
175 |     nerf_normalization = getNerfppNorm(train_cam_infos)
176 | 
177 |     ply_path = os.path.join(path, "sparse/0/points3D.ply")
178 |     bin_path = os.path.join(path, "sparse/0/points3D.bin")
179 |     txt_path = os.path.join(path, "sparse/0/points3D.txt")
180 |     if not os.path.exists(ply_path):
181 |         print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
182 |         try:
183 |             xyz, rgb, _ = read_points3D_binary(bin_path)
184 |         except:
185 |             xyz, rgb, _ = read_points3D_text(txt_path)
186 |         storePly(ply_path, xyz, rgb)
187 |     try:
188 |         pcd = fetchPly(ply_path)
189 |     except:
190 |         pcd = None
191 | 
192 |     scene_info = SceneInfo(point_cloud=pcd,
193 |                            train_cameras=train_cam_infos,
194 |                            test_cameras=test_cam_infos,
195 |                            nerf_normalization=nerf_normalization,
196 |                            ply_path=ply_path)
197 |     return scene_info
198 | 
199 | 
200 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
201 |     cam_infos = []
202 | 
203 |     with open(os.path.join(path, transformsfile)) as json_file:
204 |         contents = json.load(json_file)
205 |         fovx = contents["camera_angle_x"]
206 | 
207 |         frames = contents["frames"]
208 |         for idx, frame in enumerate(frames):
209 |             cam_name = os.path.join(path, frame["file_path"] + extension)
210 | 
211 |             # NeRF 'transform_matrix' is a camera-to-world transform
212 |             c2w = np.array(frame["transform_matrix"])
213 |             # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
214 |             c2w[:3, 1:3] *= -1
215 | 
216 |             # get the world-to-camera transform and set R, T
217 |             w2c = np.linalg.inv(c2w)
218 |             # R is stored transposed due to 'glm' in CUDA code
219 |             R = np.transpose(w2c[:3, :3])
220 |             T = w2c[:3, 3]
221 | 
222 |             image_path = os.path.join(path, cam_name)
223 |             image_name = Path(cam_name).stem
224 |             image = Image.open(image_path)
225 | 
226 |             im_data = np.array(image.convert("RGBA"))
227 | 
228 |             bg = np.array(
229 |                 [1, 1, 1]) if white_background else np.array([0, 0, 0])
230 | 
231 |             norm_data = im_data / 255.0
232 |             arr = norm_data[:, :, :3] * norm_data[:, :,
233 |                                                   3:4] + bg * (1 - norm_data[:, :, 3:4])
234 |             image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
235 | 
236 |             fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
237 |             FovY = fovy
238 |             FovX = fovx
239 |             intr_mat = np.array([[fov2focal(FovX, image.size[0]), 0, image.size[0] / 2], [0, fov2focal(FovY, image.size[1]), image.size[1] / 2], [0, 0, 1]])  # pinhole intrinsics (CameraInfo requires them)
240 |             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, intrinsics=intr_mat,
241 |                                         image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
242 | 
243 |     return cam_infos
244 | 
245 | 
246 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
247 |     print("Reading Training Transforms")
248 |     train_cam_infos = readCamerasFromTransforms(
249 |         path, "transforms_train.json", white_background, extension)
250 |     print("Reading Test Transforms")
251 |     test_cam_infos = readCamerasFromTransforms(
252 |         path, "transforms_test.json", white_background, extension)
253 | 
254 |     if not eval:
255 |         train_cam_infos.extend(test_cam_infos)
256 |         test_cam_infos = []
257 | 
258 |     nerf_normalization = getNerfppNorm(train_cam_infos)
259 | 
260 |     ply_path = os.path.join(path, "points3d.ply")
261 |     if not os.path.exists(ply_path):
262 |         # Since this data set has no colmap data, we start with random points
263 |         num_pts = 100_000
264 |         print(f"Generating random point cloud ({num_pts})...")
265 | 
266 |         # 
We create random points inside the bounds of the synthetic Blender scenes 267 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 268 | shs = np.random.random((num_pts, 3)) / 255.0 269 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 270 | shs), normals=np.zeros((num_pts, 3))) 271 | 272 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 273 | try: 274 | pcd = fetchPly(ply_path) 275 | except: 276 | pcd = None 277 | 278 | scene_info = SceneInfo(point_cloud=pcd, 279 | train_cameras=train_cam_infos, 280 | test_cameras=test_cam_infos, 281 | nerf_normalization=nerf_normalization, 282 | ply_path=ply_path) 283 | return scene_info 284 | 285 | 286 | sceneLoadTypeCallbacks = { 287 | "Colmap": readColmapSceneInfo, 288 | "Blender": readNerfSyntheticInfo 289 | } 290 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import os 11 | import torch 12 | from random import randint 13 | from utils.loss_utils import l1_loss, ssim 14 | from gaussian_renderer import render, network_gui 15 | import sys 16 | from scene import Scene, GaussianModel 17 | from utils.general_utils import safe_state 18 | import uuid 19 | from tqdm import tqdm 20 | from utils.image_utils import psnr 21 | from argparse import ArgumentParser, Namespace 22 | from arguments import ModelParams, PipelineParams, OptimizationParams 23 | try: 24 | from torch.utils.tensorboard import SummaryWriter 25 | TENSORBOARD_FOUND = True 26 | except ImportError: 27 | TENSORBOARD_FOUND = False 28 | 29 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 30 | first_iter = 0 31 | tb_writer = prepare_output_and_logger(dataset) 32 | gaussians = GaussianModel(dataset.sh_degree) 33 | scene = Scene(dataset, gaussians) 34 | gaussians.training_setup(opt) 35 | if checkpoint: 36 | (model_params, first_iter) = torch.load(checkpoint) 37 | gaussians.restore(model_params, opt) 38 | 39 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 40 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 41 | 42 | iter_start = torch.cuda.Event(enable_timing = True) 43 | iter_end = torch.cuda.Event(enable_timing = True) 44 | 45 | viewpoint_stack = None 46 | ema_loss_for_log = 0.0 47 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 48 | first_iter += 1 49 | for iteration in range(first_iter, opt.iterations + 1): 50 | if network_gui.conn == None: 51 | network_gui.try_connect() 52 | while network_gui.conn != None: 53 | try: 54 | net_image_bytes = None 55 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 56 | if custom_cam != None: 57 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 58 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 59 | network_gui.send(net_image_bytes, dataset.source_path) 60 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 61 | break 62 | except Exception as e: 63 | 
network_gui.conn = None 64 | 65 | iter_start.record() 66 | 67 | gaussians.update_learning_rate(iteration) 68 | 69 | # Every 1000 its we increase the levels of SH up to a maximum degree 70 | if iteration % 1000 == 0: 71 | gaussians.oneupSHdegree() 72 | 73 | # Pick a random Camera 74 | if not viewpoint_stack: 75 | viewpoint_stack = scene.getTrainCameras().copy() 76 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 77 | 78 | # Render 79 | if (iteration - 1) == debug_from: 80 | pipe.debug = True 81 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 82 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 83 | 84 | # Loss 85 | gt_image = viewpoint_cam.original_image.cuda() 86 | Ll1 = l1_loss(image, gt_image) 87 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 88 | loss.backward() 89 | 90 | iter_end.record() 91 | 92 | with torch.no_grad(): 93 | # Progress bar 94 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 95 | if iteration % 10 == 0: 96 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 97 | progress_bar.update(10) 98 | if iteration == opt.iterations: 99 | progress_bar.close() 100 | 101 | # Log and save 102 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 103 | if (iteration in saving_iterations): 104 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 105 | scene.save(iteration) 106 | 107 | # Densification 108 | if iteration < opt.densify_until_iter: 109 | # Keep track of max radii in image-space for pruning 110 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 111 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 112 | 113 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 114 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 115 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 116 | 117 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 118 | gaussians.reset_opacity() 119 | 120 | # Optimizer step 121 | if iteration < opt.iterations: 122 | gaussians.optimizer.step() 123 | gaussians.optimizer.zero_grad(set_to_none = True) 124 | 125 | if (iteration in checkpoint_iterations): 126 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 127 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 128 | 129 | def prepare_output_and_logger(args): 130 | if not args.model_path: 131 | if os.getenv('OAR_JOB_ID'): 132 | unique_str=os.getenv('OAR_JOB_ID') 133 | else: 134 | unique_str = str(uuid.uuid4()) 135 | args.model_path = os.path.join("./output/", unique_str[0:10]) 136 | 137 | # Set up output folder 138 | print("Output folder: {}".format(args.model_path)) 139 | os.makedirs(args.model_path, exist_ok = True) 140 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 141 | cfg_log_f.write(str(Namespace(**vars(args)))) 142 | 143 | # Create Tensorboard writer 144 | tb_writer = None 145 | if TENSORBOARD_FOUND: 146 | tb_writer = SummaryWriter(args.model_path) 147 | else: 148 | print("Tensorboard not available: not logging progress") 
149 | return tb_writer 150 | 151 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 152 | if tb_writer: 153 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 154 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 155 | tb_writer.add_scalar('iter_time', elapsed, iteration) 156 | 157 | # Report test and samples of training set 158 | if iteration in testing_iterations: 159 | torch.cuda.empty_cache() 160 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 161 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 162 | 163 | for config in validation_configs: 164 | if config['cameras'] and len(config['cameras']) > 0: 165 | l1_test = 0.0 166 | psnr_test = 0.0 167 | for idx, viewpoint in enumerate(config['cameras']): 168 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 169 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 170 | if tb_writer and (idx < 5): 171 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 172 | if iteration == testing_iterations[0]: 173 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 174 | l1_test += l1_loss(image, gt_image).mean().double() 175 | psnr_test += psnr(image, gt_image).mean().double() 176 | psnr_test /= len(config['cameras']) 177 | l1_test /= len(config['cameras']) 178 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 179 | if tb_writer: 180 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 181 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 182 | 183 | if tb_writer: 184 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 185 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 186 | torch.cuda.empty_cache() 187 | 188 | if __name__ == "__main__": 189 | # Set up command line argument parser 190 | parser = ArgumentParser(description="Training script parameters") 191 | lp = ModelParams(parser) 192 | op = OptimizationParams(parser) 193 | pp = PipelineParams(parser) 194 | parser.add_argument('--ip', type=str, default="127.0.0.1") 195 | parser.add_argument('--port', type=int, default=6009) 196 | parser.add_argument('--debug_from', type=int, default=-1) 197 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 198 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) 199 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 200 | parser.add_argument("--quiet", action="store_true") 201 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 202 | parser.add_argument("--start_checkpoint", type=str, default = None) 203 | args = parser.parse_args(sys.argv[1:]) 204 | args.save_iterations.append(args.iterations) 205 | 206 | print("Optimizing " + args.model_path) 207 | 208 | # Initialize system state (RNG) 209 | safe_state(args.quiet) 210 | 211 | # Start GUI server, configure and run training 212 | network_gui.init(args.ip, args.port) 213 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 214 | 
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 215 | 216 | # All done 217 | print("\nTraining complete.") 218 | -------------------------------------------------------------------------------- /trainer/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.autograd import Variable 13 | 14 | import numpy as np 15 | from math import exp 16 | 17 | 18 | import pdb 19 | 20 | 21 | class Loss_Eval(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, rgb_pred, rgb_gt): 26 | loss = F.mse_loss(rgb_pred, rgb_gt) 27 | return_dict = { 28 | 'loss': loss 29 | } 30 | return return_dict 31 | 32 | 33 | class Loss(nn.Module): 34 | def __init__(self, cfg=None): 35 | super().__init__() 36 | 37 | self.depth_loss_type = cfg.depth_loss_type 38 | 39 | self.l1_loss = nn.L1Loss(reduction='sum') 40 | self.l2_loss = nn.MSELoss(reduction='sum') 41 | self.scale_inv_loss = ScaleAndShiftInvariantLoss(alpha=0.5, 42 | scales=1) 43 | 44 | # ssim_loss = ssim 45 | self.ssim_loss = SSIM_V2() 46 | 47 | self.cfg = cfg 48 | 49 | def get_rgb_full_loss(self, rgb_values, rgb_gt, rgb_loss_type='l2'): 50 | num_pixels = rgb_values.shape[1] * rgb_values.shape[2] 51 | if rgb_loss_type == 'l1': 52 | rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(num_pixels) 53 | elif rgb_loss_type == 'l2': 54 | rgb_loss = self.l2_loss(rgb_values, rgb_gt) / float(num_pixels) 55 | return rgb_loss 56 | 57 | def depth_loss_dpt(self, pred_depth, gt_depth, weight=None): 58 | """ 59 | :param pred_depth: (H, W) 60 | :param gt_depth: (H, W) 61 | :param weight: (H, W) 62 | :return: scalar 63 | """ 64 | t_pred = torch.median(pred_depth) 65 | s_pred = torch.mean(torch.abs(pred_depth - t_pred)) 66 | 67 | t_gt = torch.median(gt_depth) 68 | s_gt = torch.mean(torch.abs(gt_depth - t_gt)) 69 | 70 | pred_depth_n = (pred_depth - t_pred) / s_pred 71 | gt_depth_n = (gt_depth - t_gt) / s_gt 72 | 73 | if weight is not None: 74 | loss = F.mse_loss(pred_depth_n, gt_depth_n, reduction='none') 75 | loss = loss * weight 76 | loss = loss.sum() / (weight.sum() + 1e-8) 77 | else: 78 | 79 | depth_error = (pred_depth_n - gt_depth_n) ** 2 80 | depth_error[depth_error > torch.quantile(depth_error, 0.8)] = 0 81 | loss = depth_error.mean() 82 | # loss = F.mse_loss(pred_depth_n, gt_depth_n) 83 | 84 | return loss 85 | 86 | def get_depth_loss(self, depth_pred, depth_gt): 87 | num_pixels = depth_pred.shape[0] * depth_pred.shape[1] 88 | if self.depth_loss_type == 'l1': 89 | loss = self.l1_loss(depth_pred, depth_gt) / float(num_pixels) 90 | elif self.depth_loss_type == 'invariant': 91 | # loss = self.depth_loss_dpt(1.0/depth_pred, 1.0/depth_gt) 92 | mask = (depth_gt > 0.02).float() 93 | loss = self.scale_inv_loss( 94 | depth_pred[None], depth_gt[None], mask[None]) 95 | return loss 96 | 97 | 98 | def forward(self, rgb_pred, rgb_gt, depth_pred=None, 
depth_gt=None, 99 | rgb_loss_type='l1', **kwargs): 100 | 101 | rgb_gt = rgb_gt.cuda() 102 | 103 | lambda_dssim = self.cfg.lambda_dssim 104 | lambda_depth = self.cfg.lambda_depth 105 | 106 | # rgb_full_loss = self.get_rgb_full_loss(rgb_pred, rgb_gt, rgb_loss_type) 107 | rgb_full_loss = (1 - lambda_dssim) * l1_loss(rgb_pred, rgb_gt) 108 | if lambda_dssim != 0.0: 109 | # dssim_loss = compute_ssim_loss(rgb_pred, rgb_gt) 110 | # pdb.set_trace() 111 | # dssim_loss = 1 - ssim(rgb_pred, rgb_gt) 112 | dssim_loss = 1 - self.ssim_loss(rgb_pred, rgb_gt) 113 | 114 | if lambda_depth != 0.0 and depth_pred is not None and depth_gt is not None: 115 | depth_gt = depth_gt.cuda() 116 | depth_pred[depth_pred < 0.02] = 0.02 117 | depth_pred[depth_pred > 20.0] = 20.0 118 | depth_loss = self.get_depth_loss( 119 | depth_pred.squeeze(), depth_gt.squeeze()) 120 | else: 121 | depth_loss = torch.tensor(0.0).cuda().float() 122 | 123 | loss = rgb_full_loss + lambda_dssim * dssim_loss +\ 124 | lambda_depth * depth_loss 125 | 126 | if torch.isnan(loss): 127 | breakpoint() 128 | 129 | return_dict = { 130 | 'loss': loss, 131 | 'loss_rgb': rgb_full_loss, 132 | 'loss_dssim': dssim_loss, 133 | 'loss_depth': depth_loss, 134 | } 135 | 136 | return return_dict 137 | 138 | 139 | def l1_loss(network_output, gt): 140 | return torch.abs((network_output - gt)).mean() 141 | 142 | 143 | def l2_loss(network_output, gt): 144 | return ((network_output - gt) ** 2).mean() 145 | 146 | 147 | def gaussian(window_size, sigma): 148 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / 149 | float(2 * sigma ** 2)) for x in range(window_size)]) 150 | return gauss / gauss.sum() 151 | 152 | 153 | def create_window(window_size, channel): 154 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 155 | _2D_window = _1D_window.mm( 156 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 157 | window = Variable(_2D_window.expand( 158 | channel, 1, window_size, window_size).contiguous()) 159 | return window 160 | 161 | 162 | 163 | 164 | def _ssim( 165 | img1, img2, window, window_size, channel, mask=None, size_average=True 166 | ): 167 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 168 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 169 | 170 | mu1_sq = mu1.pow(2) 171 | mu2_sq = mu2.pow(2) 172 | mu1_mu2 = mu1 * mu2 173 | 174 | sigma1_sq = ( 175 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 176 | - mu1_sq 177 | ) 178 | sigma2_sq = ( 179 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 180 | - mu2_sq 181 | ) 182 | sigma12 = ( 183 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 184 | - mu1_mu2 185 | ) 186 | 187 | C1 = (0.01) ** 2 188 | C2 = (0.03) ** 2 189 | 190 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 191 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 192 | ) 193 | 194 | if not (mask is None): 195 | b = mask.size(0) 196 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 197 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 198 | dim=1 199 | ).clamp(min=1) 200 | return ssim_map 201 | 202 | # import pdb 203 | 204 | # pdb.set_trace 205 | 206 | if size_average: 207 | return ssim_map.mean() 208 | else: 209 | return ssim_map.mean(1).mean(1).mean(1) 210 | 211 | 212 | class SSIM_V2(torch.nn.Module): 213 | def __init__(self, window_size=11, size_average=True): 214 | super(SSIM_V2, self).__init__() 215 | self.window_size = window_size 216 | self.size_average = size_average 217 | 
self.channel = 1 218 | self.window = create_window(window_size, self.channel) 219 | 220 | def forward(self, img1, img2, mask=None): 221 | if img1.dim() == 3: 222 | img1 = img1.unsqueeze(0) 223 | if img2.dim() == 3: 224 | img2 = img2.unsqueeze(0) 225 | 226 | (_, channel, _, _) = img1.size() 227 | 228 | if ( 229 | channel == self.channel 230 | and self.window.data.type() == img1.data.type() 231 | ): 232 | window = self.window 233 | else: 234 | window = create_window(self.window_size, channel) 235 | 236 | if img1.is_cuda: 237 | window = window.cuda(img1.get_device()) 238 | window = window.type_as(img1) 239 | 240 | self.window = window 241 | self.channel = channel 242 | 243 | return _ssim( 244 | img1, 245 | img2, 246 | window, 247 | self.window_size, 248 | channel, 249 | mask, 250 | self.size_average, 251 | ) 252 | 253 | 254 | 255 | 256 | 257 | 258 | # copy from MiDaS and MonoSDF 259 | def compute_scale_and_shift(prediction, target, mask): 260 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 261 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 262 | a_01 = torch.sum(mask * prediction, (1, 2)) 263 | a_11 = torch.sum(mask, (1, 2)) 264 | 265 | # right hand side: b = [b_0, b_1] 266 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 267 | b_1 = torch.sum(mask * target, (1, 2)) 268 | 269 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 270 | x_0 = torch.zeros_like(b_0) 271 | x_1 = torch.zeros_like(b_1) 272 | 273 | det = a_00 * a_11 - a_01 * a_01 274 | valid = det.nonzero() 275 | 276 | x_0[valid] = (a_11[valid] * b_0[valid] - 277 | a_01[valid] * b_1[valid]) / det[valid] 278 | x_1[valid] = (-a_01[valid] * b_0[valid] + 279 | a_00[valid] * b_1[valid]) / det[valid] 280 | 281 | return x_0, x_1 282 | 283 | 284 | def reduction_batch_based(image_loss, M): 285 | # average of all valid pixels of the batch 286 | 287 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 288 | divisor = torch.sum(M) 289 | 290 | if divisor == 0: 291 | return 0 292 | else: 293 | return torch.sum(image_loss) / divisor 294 | 295 | 296 | def reduction_image_based(image_loss, M): 297 | # mean of average of valid pixels of an image 298 | 299 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 300 | valid = M.nonzero() 301 | 302 | image_loss[valid] = image_loss[valid] / M[valid] 303 | 304 | return torch.mean(image_loss) 305 | 306 | 307 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 308 | 309 | M = torch.sum(mask, (1, 2)) 310 | res = prediction - target 311 | image_loss = torch.sum(mask * res * res, (1, 2)) 312 | 313 | return reduction(image_loss, 2 * M) 314 | 315 | 316 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 317 | 318 | M = torch.sum(mask, (1, 2)) 319 | 320 | diff = prediction - target 321 | diff = torch.mul(mask, diff) 322 | 323 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 324 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 325 | grad_x = torch.mul(mask_x, grad_x) 326 | 327 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 328 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 329 | grad_y = torch.mul(mask_y, grad_y) 330 | 331 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 332 | 333 | return reduction(image_loss, M) 334 | 335 | 336 | class MSELoss(nn.Module): 337 | def __init__(self, reduction='batch-based'): 338 | super().__init__() 339 | 340 | if reduction == 'batch-based': 341 | self.__reduction = reduction_batch_based 342 | 
else: 343 | self.__reduction = reduction_image_based 344 | 345 | def forward(self, prediction, target, mask): 346 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 347 | 348 | 349 | class GradientLoss(nn.Module): 350 | def __init__(self, scales=4, reduction='batch-based'): 351 | super().__init__() 352 | 353 | if reduction == 'batch-based': 354 | self.__reduction = reduction_batch_based 355 | else: 356 | self.__reduction = reduction_image_based 357 | 358 | self.__scales = scales 359 | 360 | def forward(self, prediction, target, mask): 361 | total = 0 362 | 363 | for scale in range(self.__scales): 364 | step = pow(2, scale) 365 | 366 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 367 | mask[:, ::step, ::step], reduction=self.__reduction) 368 | 369 | return total 370 | 371 | 372 | class ScaleAndShiftInvariantLoss(nn.Module): 373 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): 374 | super().__init__() 375 | 376 | self.__data_loss = MSELoss(reduction=reduction) 377 | self.__regularization_loss = GradientLoss( 378 | scales=scales, reduction=reduction) 379 | self.__alpha = alpha 380 | 381 | self.__prediction_ssi = None 382 | 383 | def forward(self, prediction, target, mask): 384 | scale, shift = compute_scale_and_shift(prediction, target, mask) 385 | self.__prediction_ssi = scale.view(-1, 1, 1) * \ 386 | prediction + shift.view(-1, 1, 1) 387 | 388 | total = self.__data_loss(self.__prediction_ssi, target, mask) 389 | if self.__alpha > 0: 390 | total += self.__alpha * \ 391 | self.__regularization_loss(self.__prediction_ssi, target, mask) 392 | 393 | return total 394 | 395 | def __get_prediction_ssi(self): 396 | return self.__prediction_ssi 397 | 398 | prediction_ssi = property(__get_prediction_ssi) 399 | -------------------------------------------------------------------------------- /utils/camera_conversion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # License file: https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Union 8 | import numpy as np 9 | 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | Device = Union[str, torch.device] 15 | 16 | """ 17 | The transformation matrices returned from the functions in this file assume 18 | the points on which the transformation will be applied are column vectors. 19 | i.e. the R matrix is structured as 20 | 21 | R = [ 22 | [Rxx, Rxy, Rxz], 23 | [Ryx, Ryy, Ryz], 24 | [Rzx, Rzy, Rzz], 25 | ] # (3, 3) 26 | 27 | This matrix can be applied to column vectors by post multiplication 28 | by the points e.g. 29 | 30 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point 31 | transformed_points = R * points 32 | 33 | To apply the same matrix to points which are row vectors, the R matrix 34 | can be transposed and pre multiplied by the points: 35 | 36 | e.g. 37 | points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point 38 | transformed_points = points * R.transpose(1, 0) 39 | """ 40 | 41 | 42 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Convert rotations given as quaternions to rotation matrices. 45 | 46 | Args: 47 | quaternions: quaternions with real part first, 48 | as tensor of shape (..., 4). 
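            Inputs are expected to be non-zero (normally unit) quaternions;
            for example, the identity quaternion (1, 0, 0, 0) maps to the
            3x3 identity matrix.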
49 | 50 | Returns: 51 | Rotation matrices as tensor of shape (..., 3, 3). 52 | """ 53 | r, i, j, k = torch.unbind(quaternions, -1) 54 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 55 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 56 | 57 | o = torch.stack( 58 | ( 59 | 1 - two_s * (j * j + k * k), 60 | two_s * (i * j - k * r), 61 | two_s * (i * k + j * r), 62 | two_s * (i * j + k * r), 63 | 1 - two_s * (i * i + k * k), 64 | two_s * (j * k - i * r), 65 | two_s * (i * k - j * r), 66 | two_s * (j * k + i * r), 67 | 1 - two_s * (i * i + j * j), 68 | ), 69 | -1, 70 | ) 71 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 72 | 73 | 74 | def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 75 | """ 76 | Return a tensor where each element has the absolute value taken from the, 77 | corresponding element of a, with sign taken from the corresponding 78 | element of b. This is like the standard copysign floating-point operation, 79 | but is not careful about negative 0 and NaN. 80 | 81 | Args: 82 | a: source tensor. 83 | b: tensor whose signs will be used, of the same shape as a. 84 | 85 | Returns: 86 | Tensor of the same shape as a with the signs of b. 87 | """ 88 | signs_differ = (a < 0) != (b < 0) 89 | return torch.where(signs_differ, -a, a) 90 | 91 | 92 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 93 | """ 94 | Returns torch.sqrt(torch.max(0, x)) 95 | but with a zero subgradient where x is 0. 96 | """ 97 | ret = torch.zeros_like(x) 98 | positive_mask = x > 0 99 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 100 | return ret 101 | 102 | 103 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 104 | """ 105 | Convert rotations given as rotation matrices to quaternions. 106 | 107 | Args: 108 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 109 | 110 | Returns: 111 | quaternions with real part first, as tensor of shape (..., 4). 112 | """ 113 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 114 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 115 | 116 | batch_dim = matrix.shape[:-2] 117 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 118 | matrix.reshape(batch_dim + (9,)), dim=-1 119 | ) 120 | 121 | q_abs = _sqrt_positive_part( 122 | torch.stack( 123 | [ 124 | 1.0 + m00 + m11 + m22, 125 | 1.0 + m00 - m11 - m22, 126 | 1.0 - m00 + m11 - m22, 127 | 1.0 - m00 - m11 + m22, 128 | ], 129 | dim=-1, 130 | ) 131 | ) 132 | 133 | # we produce the desired quaternion multiplied by each of r, i, j, k 134 | quat_by_rijk = torch.stack( 135 | [ 136 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 137 | # `int`. 138 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, 139 | m02 - m20, m10 - m01], dim=-1), 140 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 141 | # `int`. 142 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, 143 | m10 + m01, m02 + m20], dim=-1), 144 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 145 | # `int`. 146 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] 147 | ** 2, m12 + m21], dim=-1), 148 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 149 | # `int`. 150 | torch.stack([m10 - m01, m20 + m02, m21 + m12, 151 | q_abs[..., 3] ** 2], dim=-1), 152 | ], 153 | dim=-2, 154 | ) 155 | 156 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 157 | # the candidate won't be picked. 
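    # Each candidate row of `quat_by_rijk` is the target quaternion scaled by
    # 4*q_r, 4*q_i, 4*q_j or 4*q_k respectively, while q_abs holds 2*|q_r|,
    # 2*|q_i|, 2*|q_j|, 2*|q_k|; dividing by 2*q_abs therefore recovers the
    # quaternion (up to sign) from whichever candidate is best conditioned,
    # selected via the argmax below.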
158 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 159 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 160 | 161 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 162 | # forall i; we pick the best-conditioned one (with the largest denominator) 163 | 164 | out = quat_candidates[ 165 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 166 | ].reshape(batch_dim + (4,)) 167 | return standardize_quaternion(out) 168 | 169 | 170 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 171 | """ 172 | Convert a unit quaternion to a standard form: one in which the real 173 | part is non negative. 174 | 175 | Args: 176 | quaternions: Quaternions with real part first, 177 | as tensor of shape (..., 4). 178 | 179 | Returns: 180 | Standardized quaternions as tensor of shape (..., 4). 181 | """ 182 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 183 | 184 | 185 | def convert3x4_4x4(input): 186 | """ 187 | :param input: (N, 3, 4) or (3, 4) torch or np 188 | :return: (N, 4, 4) or (4, 4) torch or np 189 | """ 190 | if torch.is_tensor(input): 191 | if len(input.shape) == 3: 192 | output = torch.cat([input, torch.zeros_like( 193 | input[:, 0:1])], dim=1) # (N, 4, 4) 194 | output[:, 3, 3] = 1.0 195 | else: 196 | output = torch.cat([input, torch.tensor( 197 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 198 | else: 199 | if len(input.shape) == 3: 200 | output = np.concatenate( 201 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 202 | output[:, 3, 3] = 1.0 203 | else: 204 | output = np.concatenate( 205 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4) 206 | output[3, 3] = 1.0 207 | return output 208 | 209 | 210 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: 211 | """ 212 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 213 | using Gram--Schmidt orthogonalization per Section B of [1]. 214 | Args: 215 | d6: 6D rotation representation, of size (*, 6) 216 | 217 | Returns: 218 | batch of rotation matrices of size (*, 3, 3) 219 | 220 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 221 | On the Continuity of Rotation Representations in Neural Networks. 222 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 223 | Retrieved from http://arxiv.org/abs/1812.07035 224 | """ 225 | 226 | a1, a2 = d6[..., :3], d6[..., 3:] 227 | b1 = F.normalize(a1, dim=-1) 228 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 229 | b2 = F.normalize(b2, dim=-1) 230 | b3 = torch.cross(b1, b2, dim=-1) 231 | return torch.stack((b1, b2, b3), dim=-2) 232 | 233 | 234 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 235 | """ 236 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 237 | by dropping the last row. Note that 6D representation is not unique. 238 | Args: 239 | matrix: batch of rotation matrices of size (*, 3, 3) 240 | 241 | Returns: 242 | 6D rotation representation, of size (*, 6) 243 | 244 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 245 | On the Continuity of Rotation Representations in Neural Networks. 246 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
247 | Retrieved from http://arxiv.org/abs/1812.07035 248 | """ 249 | batch_dim = matrix.size()[:-2] 250 | return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) 251 | 252 | 253 | class Pose(): 254 | """ 255 | A class of operations on camera poses (PyTorch tensors with shape [...,3,4]) 256 | each [3,4] camera pose takes the form of [R|t] 257 | """ 258 | 259 | def __call__(self, R=None, t=None): 260 | # construct a camera pose from the given R and/or t 261 | assert (R is not None or t is not None) 262 | if R is None: 263 | if not isinstance(t, torch.Tensor): 264 | t = torch.tensor(t) 265 | R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1) 266 | elif t is None: 267 | if not isinstance(R, torch.Tensor): 268 | R = torch.tensor(R) 269 | t = torch.zeros(R.shape[:-1], device=R.device) 270 | else: 271 | if not isinstance(R, torch.Tensor): 272 | R = torch.tensor(R) 273 | if not isinstance(t, torch.Tensor): 274 | t = torch.tensor(t) 275 | assert (R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3)) 276 | R = R.float() 277 | t = t.float() 278 | pose = torch.cat([R, t[..., None]], dim=-1) # [...,3,4] 279 | assert (pose.shape[-2:] == (3, 4)) 280 | return pose 281 | 282 | def invert(self, pose, use_inverse=False): 283 | # invert a camera pose 284 | R, t = pose[..., :3], pose[..., 3:] 285 | R_inv = R.inverse() if use_inverse else R.transpose(-1, -2) 286 | t_inv = (-R_inv@t)[..., 0] 287 | pose_inv = self(R=R_inv, t=t_inv) 288 | return pose_inv 289 | 290 | def compose(self, pose_list): 291 | # compose a sequence of poses together 292 | # pose_new(x) = poseN o ... o pose2 o pose1(x) 293 | pose_new = pose_list[0] 294 | for pose in pose_list[1:]: 295 | pose_new = self.compose_pair(pose_new, pose) 296 | return pose_new 297 | 298 | def compose_pair(self, pose_a, pose_b): 299 | # pose_new(x) = pose_b o pose_a(x) 300 | R_a, t_a = pose_a[..., :3], pose_a[..., 3:] 301 | R_b, t_b = pose_b[..., :3], pose_b[..., 3:] 302 | R_new = R_b@R_a 303 | t_new = (R_b@t_a+t_b)[..., 0] 304 | pose_new = self(R=R_new, t=t_new) 305 | return pose_new 306 | 307 | 308 | class Lie(): 309 | """ 310 | Lie algebra for SO(3) and SE(3) operations in PyTorch 311 | """ 312 | 313 | def so3_to_SO3(self, w): # [...,3] 314 | wx = self.skew_symmetric(w) 315 | theta = w.norm(dim=-1)[..., None, None] 316 | I = torch.eye(3, device=w.device, dtype=torch.float32) 317 | A = self.taylor_A(theta) 318 | B = self.taylor_B(theta) 319 | R = I+A*wx+B*wx@wx 320 | return R 321 | 322 | def SO3_to_so3(self, R, eps=1e-7): # [...,3,3] 323 | trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2] 324 | # ln(R) will explode if theta==pi 325 | theta = ((trace-1)/2).clamp(-1+eps, 1 - 326 | eps).acos_()[..., None, None] % np.pi 327 | lnR = 1/(2*self.taylor_A(theta)+1e-8) * \ 328 | (R-R.transpose(-2, -1)) # FIXME: wei-chiu finds it weird 329 | w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0] 330 | w = torch.stack([w0, w1, w2], dim=-1) 331 | return w 332 | 333 | def se3_to_SE3(self, wu): # [...,3] 334 | w, u = wu.split([3, 3], dim=-1) 335 | wx = self.skew_symmetric(w) 336 | theta = w.norm(dim=-1)[..., None, None] 337 | I = torch.eye(3, device=w.device, dtype=torch.float32) 338 | A = self.taylor_A(theta) 339 | B = self.taylor_B(theta) 340 | C = self.taylor_C(theta) 341 | R = I+A*wx+B*wx@wx 342 | V = I+B*wx+C*wx@wx 343 | Rt = torch.cat([R, (V@u[..., None])], dim=-1) 344 | return Rt 345 | 346 | def SE3_to_se3(self, Rt, eps=1e-8): # [...,3,4] 347 | R, t = Rt.split([3, 1], dim=-1) 348 | w = self.SO3_to_so3(R) 349 | wx = self.skew_symmetric(w) 350 | theta = 
w.norm(dim=-1)[..., None, None] 351 | I = torch.eye(3, device=w.device, dtype=torch.float32) 352 | A = self.taylor_A(theta) 353 | B = self.taylor_B(theta) 354 | invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx 355 | u = (invV@t)[..., 0] 356 | wu = torch.cat([w, u], dim=-1) 357 | return wu 358 | 359 | def skew_symmetric(self, w): 360 | w0, w1, w2 = w.unbind(dim=-1) 361 | O = torch.zeros_like(w0) 362 | wx = torch.stack([torch.stack([O, -w2, w1], dim=-1), 363 | torch.stack([w2, O, -w0], dim=-1), 364 | torch.stack([-w1, w0, O], dim=-1)], dim=-2) 365 | return wx 366 | 367 | def taylor_A(self, x, nth=10): 368 | # Taylor expansion of sin(x)/x 369 | ans = torch.zeros_like(x) 370 | denom = 1. 371 | for i in range(nth+1): 372 | if i > 0: 373 | denom *= (2*i)*(2*i+1) 374 | ans = ans+(-1)**i*x**(2*i)/denom 375 | return ans 376 | 377 | def taylor_B(self, x, nth=10): 378 | # Taylor expansion of (1-cos(x))/x**2 379 | ans = torch.zeros_like(x) 380 | denom = 1. 381 | for i in range(nth+1): 382 | denom *= (2*i+1)*(2*i+2) 383 | ans = ans+(-1)**i*x**(2*i)/denom 384 | return ans 385 | 386 | def taylor_C(self, x, nth=10): 387 | # Taylor expansion of (x-sin(x))/x**3 388 | ans = torch.zeros_like(x) 389 | denom = 1. 390 | for i in range(nth+1): 391 | denom *= (2*i+2)*(2*i+3) 392 | ans = ans+(-1)**i*x**(2*i)/denom 393 | return ans 394 | 395 | 396 | class Quaternion(): 397 | 398 | def q_to_R(self, q): 399 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion 400 | qa, qb, qc, qd = q.unbind(dim=-1) 401 | R = torch.stack([torch.stack([1-2*(qc**2+qd**2), 2*(qb*qc-qa*qd), 2*(qa*qc+qb*qd)], dim=-1), 402 | torch.stack( 403 | [2*(qb*qc+qa*qd), 1-2*(qb**2+qd**2), 2*(qc*qd-qa*qb)], dim=-1), 404 | torch.stack([2*(qb*qd-qa*qc), 2*(qa*qb+qc*qd), 1-2*(qb**2+qc**2)], dim=-1)], dim=-2) 405 | return R 406 | 407 | def R_to_q(self, R, eps=1e-8): # [B,3,3] 408 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion 409 | # FIXME: this function seems a bit problematic, need to double-check 410 | row0, row1, row2 = R.unbind(dim=-2) 411 | R00, R01, R02 = row0.unbind(dim=-1) 412 | R10, R11, R12 = row1.unbind(dim=-1) 413 | R20, R21, R22 = row2.unbind(dim=-1) 414 | t = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2] 415 | r = (1+t+eps).sqrt() 416 | qa = 0.5*r 417 | qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt() 418 | qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt() 419 | qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt() 420 | q = torch.stack([qa, qb, qc, qd], dim=-1) 421 | for i, qi in enumerate(q): 422 | if torch.isnan(qi).any(): 423 | K = torch.stack([torch.stack([R00-R11-R22, R10+R01, R20+R02, R12-R21], dim=-1), 424 | torch.stack( 425 | [R10+R01, R11-R00-R22, R21+R12, R20-R02], dim=-1), 426 | torch.stack( 427 | [R20+R02, R21+R12, R22-R00-R11, R01-R10], dim=-1), 428 | torch.stack([R12-R21, R20-R02, R01-R10, R00+R11+R22], dim=-1)], dim=-2)/3.0 429 | K = K[i] 430 | eigval, eigvec = torch.linalg.eigh(K) 431 | V = eigvec[:, eigval.argmax()] 432 | q[i] = torch.stack([V[3], V[0], V[1], V[2]]) 433 | return q 434 | 435 | def invert(self, q): 436 | qa, qb, qc, qd = q.unbind(dim=-1) 437 | norm = q.norm(dim=-1, keepdim=True) 438 | q_inv = torch.stack([qa, -qb, -qc, -qd], dim=-1)/norm**2 439 | return q_inv 440 | 441 | def product(self, q1, q2): # [B,4] 442 | q1a, q1b, q1c, q1d = q1.unbind(dim=-1) 443 | q2a, q2b, q2c, q2d = q2.unbind(dim=-1) 444 | hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d, 445 | q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c, 446 | q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b, 447 | q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a], dim=-1) 448 | 
return hamil_prod 449 | 450 | 451 | lie = Lie() 452 | pose = Pose() 453 | quaternion = Quaternion() 454 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | 11 | from scene.cameras import Camera 12 | import numpy as np 13 | from utils.general_utils import PILtoTorch 14 | from utils.graphics_utils import fov2focal 15 | import copy 16 | WARNED = False 17 | 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution) 24 | ), round(orig_h/(resolution_scale * args.resolution)) 25 | else: # should be a type that converts to float 26 | if args.resolution == -1: 27 | if orig_w > 1600: 28 | global WARNED 29 | if not WARNED: 30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 32 | WARNED = True 33 | global_down = orig_w / 1600 34 | else: 35 | global_down = 1 36 | else: 37 | global_down = orig_w / args.resolution 38 | 39 | scale = float(global_down) * float(resolution_scale) 40 | resolution = (int(orig_w / scale), int(orig_h / scale)) 41 | # intrinsics = copy.copy(cam_info.intrinsics) 42 | # intrinsics[:2, :] /= scale 43 | # import pdb; pdb.set_trace() 44 | focal_length_x = fov2focal(cam_info.FovX, orig_w) 45 | focal_length_y = fov2focal(cam_info.FovY, orig_h) 46 | scale = int(orig_w / resolution[0]) 47 | intrinsics = np.array( 48 | [[focal_length_x//scale, 0, resolution[0]/2], 49 | [0, focal_length_y//scale, resolution[1]/2], [0, 0, 1]]).astype(np.float32) 50 | 51 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 52 | 53 | gt_image = resized_image_rgb[:3, ...] 54 | loaded_mask = None 55 | 56 | if resized_image_rgb.shape[1] == 4: 57 | loaded_mask = resized_image_rgb[3:4, ...] 
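    # Note: PILtoTorch returns a (C, H, W) tensor, so an alpha channel shows up
    # in shape[0]; the shape[1] == 4 check above only fires when the resized
    # height happens to be 4, so loaded_mask usually stays None.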
58 | 59 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 60 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 61 | image=gt_image, gt_alpha_mask=loaded_mask, intrinsics=intrinsics, 62 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 63 | 64 | 65 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 66 | camera_list = [] 67 | 68 | for id, c in enumerate(cam_infos): 69 | camera_list.append(loadCam(args, id, c, resolution_scale)) 70 | 71 | return camera_list 72 | 73 | 74 | def camera_to_JSON(id, camera: Camera): 75 | Rt = np.zeros((4, 4)) 76 | Rt[:3, :3] = camera.R.transpose() 77 | Rt[:3, 3] = camera.T 78 | Rt[3, 3] = 1.0 79 | 80 | W2C = np.linalg.inv(Rt) 81 | pos = W2C[:3, 3] 82 | rot = W2C[:3, :3] 83 | serializable_array_2d = [x.tolist() for x in rot] 84 | camera_entry = { 85 | 'id': id, 86 | 'img_name': camera.image_name, 87 | 'width': camera.width, 88 | 'height': camera.height, 89 | 'position': pos.tolist(), 90 | 'rotation': serializable_array_2d, 91 | 'fy': fov2focal(camera.FovY, camera.height), 92 | 'fx': fov2focal(camera.FovX, camera.width) 93 | } 94 | return camera_entry 95 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import torch 11 | import sys 12 | from datetime import datetime 13 | import numpy as np 14 | import random 15 | 16 | def inverse_sigmoid(x): 17 | return torch.log(x/(1-x)) 18 | 19 | def PILtoTorch(pil_image, resolution): 20 | resized_image_PIL = pil_image.resize(resolution) 21 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 22 | if len(resized_image.shape) == 3: 23 | return resized_image.permute(2, 0, 1) 24 | else: 25 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 26 | 27 | def get_expon_lr_func( 28 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 29 | ): 30 | """ 31 | Copied from Plenoxels 32 | 33 | Continuous learning rate decay function. Adapted from JaxNeRF 34 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 35 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 36 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 37 | function of lr_delay_mult, such that the initial learning rate is 38 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 39 | to the normal learning rate when steps>lr_delay_steps. 40 | :param conf: config subtree 'lr' or similar 41 | :param max_steps: int, the number of steps during optimization. 42 | :return HoF which takes step as input 43 | """ 44 | 45 | def helper(step): 46 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 47 | # Disable this parameter 48 | return 0.0 49 | if lr_delay_steps > 0: 50 | # A kind of reverse cosine decay. 
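            # Concretely, the sine ramps delay_rate from lr_delay_mult (at step 0)
            # up to 1.0 once step >= lr_delay_steps, so the schedule starts at
            # lr_init * lr_delay_mult and eases into the plain log-linear decay
            # computed below.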
51 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 52 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 53 | ) 54 | else: 55 | delay_rate = 1.0 56 | t = np.clip(step / max_steps, 0, 1) 57 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 58 | return delay_rate * log_lerp 59 | 60 | return helper 61 | 62 | def strip_lowerdiag(L): 63 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 64 | 65 | uncertainty[:, 0] = L[:, 0, 0] 66 | uncertainty[:, 1] = L[:, 0, 1] 67 | uncertainty[:, 2] = L[:, 0, 2] 68 | uncertainty[:, 3] = L[:, 1, 1] 69 | uncertainty[:, 4] = L[:, 1, 2] 70 | uncertainty[:, 5] = L[:, 2, 2] 71 | return uncertainty 72 | 73 | def strip_symmetric(sym): 74 | return strip_lowerdiag(sym) 75 | 76 | def build_rotation(r): 77 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 78 | 79 | q = r / norm[:, None] 80 | 81 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 82 | 83 | r = q[:, 0] 84 | x = q[:, 1] 85 | y = q[:, 2] 86 | z = q[:, 3] 87 | 88 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 89 | R[:, 0, 1] = 2 * (x*y - r*z) 90 | R[:, 0, 2] = 2 * (x*z + r*y) 91 | R[:, 1, 0] = 2 * (x*y + r*z) 92 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 93 | R[:, 1, 2] = 2 * (y*z - r*x) 94 | R[:, 2, 0] = 2 * (x*z - r*y) 95 | R[:, 2, 1] = 2 * (y*z + r*x) 96 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 97 | return R 98 | 99 | def build_scaling_rotation(s, r): 100 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 101 | R = build_rotation(r) 102 | 103 | L[:,0,0] = s[:,0] 104 | L[:,1,1] = s[:,1] 105 | L[:,2,2] = s[:,2] 106 | 107 | L = R @ L 108 | return L 109 | 110 | def safe_state(silent): 111 | old_f = sys.stdout 112 | class F: 113 | def __init__(self, silent): 114 | self.silent = silent 115 | 116 | def write(self, x): 117 | if not self.silent: 118 | if x.endswith("\n"): 119 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 120 | else: 121 | old_f.write(x) 122 | 123 | def flush(self): 124 | old_f.flush() 125 | 126 | sys.stdout = F(silent) 127 | 128 | random.seed(0) 129 | np.random.seed(0) 130 | torch.manual_seed(0) 131 | torch.cuda.set_device(torch.device("cuda:0")) 132 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 
7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | 11 | import torch 12 | import math 13 | import numpy as np 14 | from typing import NamedTuple 15 | 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points: np.array 19 | colors: np.array 20 | normals: np.array 21 | 22 | def update(self, points): 23 | self.points = points 24 | 25 | 26 | class SegmentedPointCloud(NamedTuple): 27 | points: np.array 28 | colors: np.array 29 | normals: np.array 30 | labels: np.array 31 | 32 | 33 | class SimilarityTransform(NamedTuple): 34 | R: torch.Tensor 35 | T: torch.Tensor 36 | s: torch.Tensor 37 | 38 | 39 | def geom_transform_points(points, transf_matrix): 40 | P, _ = points.shape 41 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 42 | points_hom = torch.cat([points, ones], dim=1) 43 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 44 | 45 | denom = points_out[..., 3:] + 0.0000001 46 | return (points_out[..., :3] / denom).squeeze(dim=0) 47 | 48 | 49 | def getWorld2View(R, t): 50 | Rt = np.zeros((4, 4)) 51 | Rt[:3, :3] = R.transpose() 52 | Rt[:3, 3] = t 53 | Rt[3, 3] = 1.0 54 | return np.float32(Rt) 55 | 56 | 57 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 58 | Rt = np.zeros((4, 4)) 59 | Rt[:3, :3] = R.transpose() 60 | Rt[:3, 3] = t 61 | Rt[3, 3] = 1.0 62 | 63 | C2W = np.linalg.inv(Rt) 64 | cam_center = C2W[:3, 3] 65 | cam_center = (cam_center + translate) * scale 66 | C2W[:3, 3] = cam_center 67 | Rt = np.linalg.inv(C2W) 68 | return np.float32(Rt) 69 | 70 | 71 | def getWorld2View3(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 72 | Rt = np.zeros((4, 4)) 73 | Rt[:3, :3] = R 74 | Rt[:3, 3] = t 75 | Rt[3, 3] = 1.0 76 | 77 | C2W = np.linalg.inv(Rt) 78 | cam_center = C2W[:3, 3] 79 | cam_center = (cam_center + translate) * scale 80 | C2W[:3, 3] = cam_center 81 | Rt = np.linalg.inv(C2W) 82 | return np.float32(Rt) 83 | 84 | 85 | def getProjectionMatrix(znear, zfar, fovX, fovY): 86 | tanHalfFovY = math.tan((fovY / 2)) 87 | tanHalfFovX = math.tan((fovX / 2)) 88 | 89 | top = tanHalfFovY * znear 90 | bottom = -top 91 | right = tanHalfFovX * znear 92 | left = -right 93 | 94 | P = torch.zeros(4, 4) 95 | 96 | z_sign = 1.0 97 | 98 | P[0, 0] = 2.0 * znear / (right - left) 99 | P[1, 1] = 2.0 * znear / (top - bottom) 100 | P[0, 2] = (right + left) / (right - left) 101 | P[1, 2] = (top + bottom) / (top - bottom) 102 | P[3, 2] = z_sign 103 | P[2, 2] = z_sign * zfar / (zfar - znear) 104 | P[2, 3] = -(zfar * znear) / (zfar - znear) 105 | return P 106 | 107 | 108 | def fov2focal(fov, pixels): 109 | return pixels / (2 * math.tan(fov / 2)) 110 | 111 | 112 | def focal2fov(focal, pixels): 113 | return 2*math.atan(pixels/(2*focal)) 114 | 115 | 116 | def procrustes(S1, S2): 117 | ''' 118 | Computes a similarity transform (sR, t) that takes 119 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2, 120 | where R is an 3x3 rotation matrix, t 3x1 translation, s scale. 121 | i.e. solves the orthogonal Procrutes problem. 122 | ''' 123 | transposed = False 124 | if S1.shape[0] != 3 and S1.shape[0] != 2: 125 | S1 = S1.T 126 | S2 = S2.T 127 | transposed = True 128 | assert (S2.shape[1] == S1.shape[1]) 129 | 130 | # 1. Remove mean. 131 | mu1 = S1.mean(axis=1, keepdims=True) 132 | mu2 = S2.mean(axis=1, keepdims=True) 133 | X1 = S1 - mu1 134 | X2 = S2 - mu2 135 | 136 | # print('X1', X1.shape) 137 | 138 | # 2. Compute variance of X1 used for scale. 139 | var1 = torch.sum(X1 ** 2) 140 | 141 | # print('var', var1.shape) 142 | 143 | # 3. 
The outer product of X1 and X2. 144 | K = X1.mm(X2.T) 145 | 146 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are 147 | # singular vectors of K. 148 | U, s, V = torch.svd(K) 149 | # V = Vh.T 150 | # Construct Z that fixes the orientation of R to get det(R)=1. 151 | Z = torch.eye(U.shape[0], device=S1.device) 152 | Z[-1, -1] *= torch.sign(torch.det(U @ V.T)) 153 | # Construct R. 154 | R = V.mm(Z.mm(U.T)) 155 | 156 | # print('R', X1.shape) 157 | 158 | # 5. Recover scale. 159 | scale = torch.trace(R.mm(K)) / var1 160 | # print(R.shape, mu1.shape) 161 | # 6. Recover translation. 162 | t = mu2 - scale * (R.mm(mu1)) 163 | # t = mu2 - (R.mm(mu1)) 164 | # print(t.shape) 165 | 166 | # 7. Error: 167 | S1_hat = scale * R.mm(S1) + t 168 | # S1_hat = R.mm(S1) + t 169 | 170 | if transposed: 171 | S1_hat = S1_hat.T 172 | 173 | R_ = torch.eye(4).to(S1) 174 | R_[:3, :3] = R 175 | T_ = torch.eye(4).to(S1) 176 | T_[:3, -1] = t.squeeze(-1) 177 | S_ = torch.eye(4).to(S1) 178 | transf = T_@S_@R_ 179 | 180 | return S1_hat, transf 181 | 182 | 183 | # def procrustes(S1, S2,weights=None): 184 | # ''' 185 | # Computes a similarity transform (sR, t) that takes 186 | # a set of 3D points S1 (BxNx3) closest to a set of 3D points, S2, 187 | # where R is an 3x3 rotation matrix, t 3x1 translation, s scale. / mod : assuming scale is 1 188 | # i.e. solves the orthogonal Procrutes problem. 189 | # ''' 190 | # transposed = False 191 | # if S1.shape[0] != 3 and S1.shape[0] != 2: 192 | # S1 = S1.T 193 | # S2 = S2.T 194 | # transposed = True 195 | # assert (S2.shape[1] == S1.shape[1]) 196 | 197 | # if weights is None: 198 | # weights = torch.ones_like(S1[:1,:]) 199 | 200 | # # 1. Remove mean. 201 | # weights_norm = weights/(weights.sum(-1, keepdim=True) + 1e-6) 202 | # mu1 = (S1*weights_norm).sum(1, keepdim=True) 203 | # mu2 = (S2*weights_norm).sum(1, keepdim=True) 204 | 205 | # X1 = S1 - mu1 206 | # X2 = S2 - mu2 207 | 208 | # # diags = torch.stack([torch.diag(w.squeeze(0)) for w in weights]) # does batched version exist? 209 | # diags = torch.diag(weights.squeeze()) 210 | 211 | # # 3. The outer product of X1 and X2. 212 | # K = (X1@diags).mm(X2.T) 213 | # # K = (X1@diags).bmm(X2.permute(0,2,1)) 214 | 215 | # # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K. 216 | # U, s, V = torch.svd(K) 217 | 218 | # # Construct Z that fixes the orientation of R to get det(R)=1. 219 | # Z = torch.eye(U.shape[0], device=S1.device) 220 | # Z[-1, -1] *= torch.sign(torch.det(U @ V.T)) 221 | # # Construct R. 222 | # R = V.mm(Z.mm(U.T)) 223 | 224 | # # 6. Recover translation. 225 | # t = mu2 - ((R.mm(mu1))) 226 | 227 | # # 7. Error: 228 | # S1_hat = R.mm(S1) + t 229 | # if transposed: 230 | # S1_hat = S1_hat.T 231 | 232 | # # Combine recovered transformation as single matrix 233 | # R_=torch.eye(4).to(S1) 234 | # R_[:3, :3]=R 235 | # T_=torch.eye(4).to(S1) 236 | # T_[:3, -1]=t.squeeze(-1) 237 | # S_=torch.eye(4).to(S1) 238 | # transf = T_@S_@R_ 239 | 240 | # return (S1_hat-S2).square().mean(),transf 241 | 242 | 243 | def convert3x4_4x4(input): 244 | """ 245 | Make into homogeneous cordinates by adding [0, 0, 0, 1] to the bottom. 
246 | :param input: (N, 3, 4) or (3, 4) torch or np 247 | :return: (N, 4, 4) or (4, 4) torch or np 248 | """ 249 | if torch.is_tensor(input): 250 | if len(input.shape) == 3: 251 | output = torch.cat([input, torch.zeros_like( 252 | input[:, 0:1])], dim=1) # (N, 4, 4) 253 | output[:, 3, 3] = 1.0 254 | else: 255 | output = torch.cat([input, torch.tensor( 256 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 257 | else: 258 | if len(input.shape) == 3: 259 | output = np.concatenate( 260 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 261 | output[:, 3, 3] = 1.0 262 | else: 263 | output = np.concatenate( 264 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4) 265 | output[3, 3] = 1.0 266 | return output 267 | 268 | 269 | def align_umeyama(model, data, known_scale=False): 270 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation 271 | of Transformation Parameters Between Two Point Patterns, 272 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991. 273 | 274 | model = s * R * data + t 275 | 276 | Input: 277 | model -- first trajectory (nx3), numpy array type 278 | data -- second trajectory (nx3), numpy array type 279 | 280 | Output: 281 | s -- scale factor (scalar) 282 | R -- rotation matrix (3x3) 283 | t -- translation vector (3x1) 284 | t_error -- translational error per point (1xn) 285 | 286 | """ 287 | 288 | # substract mean 289 | mu_M = model.mean(0) 290 | mu_D = data.mean(0) 291 | model_zerocentered = model - mu_M 292 | data_zerocentered = data - mu_D 293 | n = np.shape(model)[0] 294 | 295 | # correlation 296 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered) 297 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum() 298 | U_svd, D_svd, V_svd = np.linalg.linalg.svd(C) 299 | 300 | D_svd = np.diag(D_svd) 301 | V_svd = np.transpose(V_svd) 302 | 303 | S = np.eye(3) 304 | if (np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0): 305 | S[2, 2] = -1 306 | 307 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd))) 308 | 309 | if known_scale: 310 | s = 1 311 | else: 312 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S)) 313 | 314 | t = mu_M-s*np.dot(R, mu_D) 315 | 316 | return s, R, t 317 | 318 | 319 | def _getIndices(n_aligned, total_n): 320 | if n_aligned == -1: 321 | idxs = np.arange(0, total_n) 322 | else: 323 | assert n_aligned <= total_n and n_aligned >= 1 324 | idxs = np.arange(0, n_aligned) 325 | return idxs 326 | 327 | 328 | # align by similarity transformation 329 | def align_sim3(p_es, p_gt, n_aligned=-1): 330 | ''' 331 | calculate s, R, t so that: 332 | gt = R * s * est + t 333 | ''' 334 | idxs = _getIndices(n_aligned, p_es.shape[0]) 335 | est_pos = p_es[idxs, 0:3] 336 | gt_pos = p_gt[idxs, 0:3] 337 | try: 338 | s, R, t = align_umeyama(gt_pos, est_pos) # note the order 339 | except: 340 | print('[WARNING] align_poses.py: SVD did not converge!') 341 | s, R, t = 1.0, np.eye(3), np.zeros(3) 342 | return s, R, t 343 | 344 | 345 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None): 346 | """Align c to b using the sim3 from a to b. 
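    The sim3 (s, R, t) is estimated from the camera centers of traj_a and traj_b
    via Umeyama alignment and then applied to every pose of traj_c.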
347 | :param traj_a: (N0, 3/4, 4) torch tensor 348 | :param traj_b: (N0, 3/4, 4) torch tensor 349 | :param traj_c: None or (N1, 3/4, 4) torch tensor 350 | :return: (N1, 4, 4) torch tensor 351 | """ 352 | device = traj_a.device 353 | if traj_c is None: 354 | traj_c = traj_a.clone() 355 | 356 | traj_a = traj_a.float().cpu().numpy() 357 | traj_b = traj_b.float().cpu().numpy() 358 | traj_c = traj_c.float().cpu().numpy() 359 | 360 | R_a = traj_a[:, :3, :3] # (N0, 3, 3) 361 | t_a = traj_a[:, :3, 3] # (N0, 3) 362 | 363 | R_b = traj_b[:, :3, :3] # (N0, 3, 3) 364 | t_b = traj_b[:, :3, 3] # (N0, 3) 365 | 366 | # This function works in quaternion. 367 | # scalar, (3, 3), (3, ) gt = R * s * est + t. 368 | s, R, t = align_sim3(t_a, t_b) 369 | 370 | # reshape tensors 371 | R = R[None, :, :].astype(np.float32) # (1, 3, 3) 372 | t = t[None, :, None].astype(np.float32) # (1, 3, 1) 373 | s = float(s) 374 | 375 | R_c = traj_c[:, :3, :3] # (N1, 3, 3) 376 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1) 377 | 378 | R_c_aligned = R @ R_c # (N1, 3, 3) 379 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1) 380 | traj_c_aligned = np.concatenate( 381 | [R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4) 382 | 383 | # append the last row 384 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4) 385 | 386 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device) 387 | 388 | return traj_c_aligned # (N1, 4, 4) 389 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 
7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | 11 | import torch 12 | 13 | def mse(img1, img2): 14 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 15 | 16 | def psnr(img1, img2): 17 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 18 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 19 | 20 | import numpy as np 21 | import matplotlib 22 | import matplotlib as mpl 23 | import matplotlib.pyplot as plt 24 | from matplotlib.patches import Patch 25 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 26 | 27 | class CameraPoseVisualizer: 28 | def __init__(self, xlim=None, ylim=None, zlim=None): 29 | self.fig = plt.figure(figsize=(18, 7)) 30 | self.ax = self.fig.add_subplot(projection='3d') 31 | self.ax.set_aspect("auto") 32 | if xlim is not None and ylim is not None and zlim is not None: 33 | self.ax.set_xlim(xlim) 34 | self.ax.set_ylim(ylim) 35 | self.ax.set_zlim(zlim) 36 | 37 | self.ax.view_init(elev=10., azim=45) 38 | self.ax.xaxis.set_tick_params(labelbottom=False) 39 | self.ax.yaxis.set_tick_params(labelleft=False) 40 | self.ax.zaxis.set_tick_params(labelleft=False) 41 | self.ax.set_xlabel('x') 42 | self.ax.set_ylabel('y') 43 | self.ax.set_zlabel('z') 44 | # plt.tight_layout() 45 | print('initialize camera pose visualizer') 46 | 47 | def extrinsic2pyramid(self, extrinsic, color='r', focal_len_scaled=5, aspect_ratio=0.3): 48 | vertex_std = np.array([[0, 0, 0, 1], 49 | [focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 50 | [focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 51 | [-focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 52 | [-focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1]]) 53 | vertex_transformed = vertex_std @ extrinsic.T 54 | meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]], 55 | [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], 56 | [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], 57 | [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], 58 | [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] 59 | self.ax.add_collection3d( 60 | Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35)) 61 | 62 | def customize_legend(self, list_label): 63 | list_handle = [] 64 | for idx, label in enumerate(list_label): 65 | color = plt.cm.rainbow(idx / len(list_label)) 66 | patch = Patch(color=color, label=label) 67 | list_handle.append(patch) 68 | plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle) 69 | 70 | def colorbar(self, max_frame_length): 71 | cmap = mpl.cm.rainbow 72 | norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length) 73 | self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), orientation='vertical', label='Frame Number') 74 | 75 | def show(self): 76 | plt.title('Extrinsic Parameters') 77 | plt.show() 78 | 79 | def add_traj(self, poses, c='r', alpha=0.5): 80 | x = [float(pose[0, 3]) for pose in poses] 81 | y = [float(pose[1, 3]) for pose in poses] 82 | z = [float(pose[2, 3]) for pose in poses] 83 | self.ax.plot(x,y,z, c=c, alpha=alpha) 84 | 85 | def save(self, path): 86 | plt.tight_layout() 87 | self.fig.savefig(path, bbox_inches='tight') 88 | 89 | 90 | def colorize(value, 
vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): 91 | """Converts a depth map to a color image. 92 | 93 | Args: 94 | value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed 95 | vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. 96 | vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. 97 | cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. 98 | invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. 99 | invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. 100 | background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). 101 | gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. 102 | value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. 103 | 104 | Returns: 105 | numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4) 106 | """ 107 | if isinstance(value, torch.Tensor): 108 | value = value.detach().cpu().numpy() 109 | 110 | value = value.squeeze() 111 | if invalid_mask is None: 112 | invalid_mask = value == invalid_val 113 | mask = np.logical_not(invalid_mask) 114 | 115 | # normalize 116 | vmin = np.percentile(value[mask],2) if vmin is None else vmin 117 | vmax = np.percentile(value[mask],85) if vmax is None else vmax 118 | if vmin != vmax: 119 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 120 | else: 121 | # Avoid 0-division 122 | value = value * 0. 123 | 124 | # squeeze last dim if it exists 125 | # grey out the invalid values 126 | 127 | value[invalid_mask] = np.nan 128 | cmapper = mpl.cm.get_cmap(cmap) 129 | if value_transform: 130 | value = value_transform(value) 131 | # value = value / value.max() 132 | value = cmapper(value, bytes=True) # (nxmx4) 133 | 134 | # img = value[:, :, :] 135 | img = value[...] 136 | img[invalid_mask] = background_color 137 | 138 | # return img.transpose((2, 0, 1)) 139 | if gamma_corrected: 140 | # gamma correction 141 | img = img / 255 142 | img = np.power(img, 2.2) 143 | img = img * 255 144 | img = img.astype(np.uint8) 145 | return img 146 | 147 | 148 | 149 | def cm_prune(x_): 150 | """Custom colormap to visualize pruning""" 151 | if isinstance(x_, torch.Tensor): 152 | x_ = x_.cpu().numpy() 153 | max_i = max(x_) 154 | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) 155 | return cm_BlRdGn(norm_x) 156 | 157 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 
7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from math import exp 14 | 15 | def l1_loss(network_output, gt): 16 | return torch.abs((network_output - gt)).mean() 17 | 18 | def l2_loss(network_output, gt): 19 | return ((network_output - gt) ** 2).mean() 20 | 21 | def gaussian(window_size, sigma): 22 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 23 | return gauss / gauss.sum() 24 | 25 | def create_window(window_size, channel): 26 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 27 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 28 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 29 | return window 30 | 31 | def ssim(img1, img2, window_size=11, size_average=True): 32 | channel = img1.size(-3) 33 | window = create_window(window_size, channel) 34 | 35 | if img1.is_cuda: 36 | window = window.cuda(img1.get_device()) 37 | window = window.type_as(img1) 38 | 39 | return _ssim(img1, img2, window, window_size, channel, size_average) 40 | 41 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 42 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 43 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 44 | 45 | mu1_sq = mu1.pow(2) 46 | mu2_sq = mu2.pow(2) 47 | mu1_mu2 = mu1 * mu2 48 | 49 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 50 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 51 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 52 | 53 | C1 = 0.01 ** 2 54 | C2 = 0.03 ** 2 55 | 56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 57 | 58 | if size_average: 59 | return ssim_map.mean() 60 | else: 61 | return ssim_map.mean(1).mean(1).mean(1) 62 | 63 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Inria 2 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 3 | # All rights reserved. 4 | # 5 | # This software is free for non-commercial, research and evaluation use 6 | # under the terms of the LICENSE_inria.md file. 7 | # 8 | # For inquiries contact george.drettakis@inria.fr 9 | 10 | from errno import EEXIST 11 | from os import makedirs, path 12 | import os 13 | 14 | def mkdir_p(folder_path): 15 | # Creates a directory. equivalent to using mkdir -p on the command line 16 | try: 17 | makedirs(folder_path) 18 | except OSError as exc: # Python >2.5 19 | if exc.errno == EEXIST and path.isdir(folder_path): 20 | pass 21 | else: 22 | raise 23 | 24 | def searchForMaxIteration(folder): 25 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 26 | return max(saved_iters) 27 | -------------------------------------------------------------------------------- /utils/utils_poses/ATE/align_trajectory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | #!/usr/bin/env python2 11 | # -*- coding: utf-8 -*- 12 | 13 | import numpy as np 14 | import utils.utils_poses.ATE.transformations as tfs 15 | 16 | 17 | def get_best_yaw(C): 18 | ''' 19 | maximize trace(Rz(theta) * C) 20 | ''' 21 | assert C.shape == (3, 3) 22 | 23 | A = C[0, 1] - C[1, 0] 24 | B = C[0, 0] + C[1, 1] 25 | theta = np.pi / 2 - np.arctan2(B, A) 26 | 27 | return theta 28 | 29 | 30 | def rot_z(theta): 31 | R = tfs.rotation_matrix(theta, [0, 0, 1]) 32 | R = R[0:3, 0:3] 33 | 34 | return R 35 | 36 | 37 | def align_umeyama(model, data, known_scale=False, yaw_only=False): 38 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation 39 | of Transformation Parameters Between Two Point Patterns, 40 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991. 
41 | 42 | model = s * R * data + t 43 | 44 | Input: 45 | model -- first trajectory (nx3), numpy array type 46 | data -- second trajectory (nx3), numpy array type 47 | 48 | Output: 49 | s -- scale factor (scalar) 50 | R -- rotation matrix (3x3) 51 | t -- translation vector (3x1) 52 | t_error -- translational error per point (1xn) 53 | 54 | """ 55 | 56 | # substract mean 57 | mu_M = model.mean(0) 58 | mu_D = data.mean(0) 59 | model_zerocentered = model - mu_M 60 | data_zerocentered = data - mu_D 61 | n = np.shape(model)[0] 62 | 63 | # correlation 64 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered) 65 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum() 66 | U_svd, D_svd, V_svd = np.linalg.linalg.svd(C) 67 | 68 | D_svd = np.diag(D_svd) 69 | V_svd = np.transpose(V_svd) 70 | 71 | S = np.eye(3) 72 | if(np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0): 73 | S[2, 2] = -1 74 | 75 | if yaw_only: 76 | rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered) 77 | theta = get_best_yaw(rot_C) 78 | R = rot_z(theta) 79 | else: 80 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd))) 81 | 82 | if known_scale: 83 | s = 1 84 | else: 85 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S)) 86 | 87 | t = mu_M-s*np.dot(R, mu_D) 88 | 89 | return s, R, t 90 | -------------------------------------------------------------------------------- /utils/utils_poses/ATE/align_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
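#
# The helpers below recover an alignment between an estimated trajectory and
# ground truth.  A rough usage sketch (p_es, p_gt are (N, 3) positions and
# q_es, q_gt are (N, 4) quaternions, as in the functions defined here):
#
#     s, R, t = alignTrajectory(p_es, p_gt, q_es, q_gt, method='sim3')
#     p_es_aligned = s * (R @ p_es.T).T + t   # now comparable to p_gt
#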
8 | 9 | #!/usr/bin/env python2 10 | # -*- coding: utf-8 -*- 11 | 12 | import numpy as np 13 | 14 | import utils.utils_poses.ATE.transformations as tfs 15 | import utils.utils_poses.ATE.align_trajectory as align 16 | 17 | 18 | def _getIndices(n_aligned, total_n): 19 | if n_aligned == -1: 20 | idxs = np.arange(0, total_n) 21 | else: 22 | assert n_aligned <= total_n and n_aligned >= 1 23 | idxs = np.arange(0, n_aligned) 24 | return idxs 25 | 26 | 27 | def alignPositionYawSingle(p_es, p_gt, q_es, q_gt): 28 | ''' 29 | calcualte the 4DOF transformation: yaw R and translation t so that: 30 | gt = R * est + t 31 | ''' 32 | 33 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :] 34 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :] 35 | g_rot = tfs.quaternion_matrix(q_gt_0) 36 | g_rot = g_rot[0:3, 0:3] 37 | est_rot = tfs.quaternion_matrix(q_es_0) 38 | est_rot = est_rot[0:3, 0:3] 39 | 40 | C_R = np.dot(est_rot, g_rot.transpose()) 41 | theta = align.get_best_yaw(C_R) 42 | R = align.rot_z(theta) 43 | t = p_gt_0 - np.dot(R, p_es_0) 44 | 45 | return R, t 46 | 47 | 48 | def alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned=1): 49 | if n_aligned == 1: 50 | R, t = alignPositionYawSingle(p_es, p_gt, q_es, q_gt) 51 | return R, t 52 | else: 53 | idxs = _getIndices(n_aligned, p_es.shape[0]) 54 | est_pos = p_es[idxs, 0:3] 55 | gt_pos = p_gt[idxs, 0:3] 56 | _, R, t = align.align_umeyama(gt_pos, est_pos, known_scale=True, 57 | yaw_only=True) # note the order 58 | t = np.array(t) 59 | t = t.reshape((3, )) 60 | R = np.array(R) 61 | return R, t 62 | 63 | 64 | # align by a SE3 transformation 65 | def alignSE3Single(p_es, p_gt, q_es, q_gt): 66 | ''' 67 | Calculate SE3 transformation R and t so that: 68 | gt = R * est + t 69 | Using only the first poses of est and gt 70 | ''' 71 | 72 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :] 73 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :] 74 | 75 | g_rot = tfs.quaternion_matrix(q_gt_0) 76 | g_rot = g_rot[0:3, 0:3] 77 | est_rot = tfs.quaternion_matrix(q_es_0) 78 | est_rot = est_rot[0:3, 0:3] 79 | 80 | R = np.dot(g_rot, np.transpose(est_rot)) 81 | t = p_gt_0 - np.dot(R, p_es_0) 82 | 83 | return R, t 84 | 85 | 86 | def alignSE3(p_es, p_gt, q_es, q_gt, n_aligned=-1): 87 | ''' 88 | Calculate SE3 transformation R and t so that: 89 | gt = R * est + t 90 | ''' 91 | if n_aligned == 1: 92 | R, t = alignSE3Single(p_es, p_gt, q_es, q_gt) 93 | return R, t 94 | else: 95 | idxs = _getIndices(n_aligned, p_es.shape[0]) 96 | est_pos = p_es[idxs, 0:3] 97 | gt_pos = p_gt[idxs, 0:3] 98 | s, R, t = align.align_umeyama(gt_pos, est_pos, 99 | known_scale=True) # note the order 100 | t = np.array(t) 101 | t = t.reshape((3, )) 102 | R = np.array(R) 103 | return R, t 104 | 105 | 106 | # align by similarity transformation 107 | def alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned=-1): 108 | ''' 109 | calculate s, R, t so that: 110 | gt = R * s * est + t 111 | ''' 112 | idxs = _getIndices(n_aligned, p_es.shape[0]) 113 | est_pos = p_es[idxs, 0:3] 114 | gt_pos = p_gt[idxs, 0:3] 115 | s, R, t = align.align_umeyama(gt_pos, est_pos) # note the order 116 | return s, R, t 117 | 118 | 119 | # a general interface 120 | def alignTrajectory(p_es, p_gt, q_es, q_gt, method, n_aligned=-1): 121 | ''' 122 | calculate s, R, t so that: 123 | gt = R * s * est + t 124 | method can be: sim3, se3, posyaw, none; 125 | n_aligned: -1 means using all the frames 126 | ''' 127 | assert p_es.shape[1] == 3 128 | assert p_gt.shape[1] == 3 129 | assert q_es.shape[1] == 4 130 | assert q_gt.shape[1] == 4 131 | 132 | s = 1 133 | R = None 134 | t = None 135 | if method == 
'sim3': 136 | assert n_aligned >= 2 or n_aligned == -1, "sim3 uses at least 2 frames" 137 | s, R, t = alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned) 138 | elif method == 'se3': 139 | R, t = alignSE3(p_es, p_gt, q_es, q_gt, n_aligned) 140 | elif method == 'posyaw': 141 | R, t = alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned) 142 | elif method == 'none': 143 | R = np.identity(3) 144 | t = np.zeros((3, )) 145 | else: 146 | assert False, 'unknown alignment method' 147 | 148 | return s, R, t 149 | 150 | 151 | if __name__ == '__main__': 152 | pass 153 | -------------------------------------------------------------------------------- /utils/utils_poses/ATE/compute_trajectory_errors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #!/usr/bin/env python2 10 | 11 | import os 12 | import numpy as np 13 | 14 | import utils.utils_poses.ATE.trajectory_utils as tu 15 | import utils.utils_poses.ATE.transformations as tf 16 | 17 | 18 | def compute_relative_error(p_es, q_es, p_gt, q_gt, T_cm, dist, max_dist_diff, 19 | accum_distances=[], 20 | scale=1.0): 21 | 22 | if len(accum_distances) == 0: 23 | accum_distances = tu.get_distance_from_start(p_gt) 24 | comparisons = tu.compute_comparison_indices_length( 25 | accum_distances, dist, max_dist_diff) 26 | 27 | n_samples = len(comparisons) 28 | print('number of samples = {0} '.format(n_samples)) 29 | if n_samples < 2: 30 | print("Too few samples! 
Will not compute.") 31 | return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]),\ 32 | np.array([]), np.array([]) 33 | 34 | T_mc = np.linalg.inv(T_cm) 35 | errors = [] 36 | for idx, c in enumerate(comparisons): 37 | if not c == -1: 38 | T_c1 = tu.get_rigid_body_trafo(q_es[idx, :], p_es[idx, :]) 39 | T_c2 = tu.get_rigid_body_trafo(q_es[c, :], p_es[c, :]) 40 | T_c1_c2 = np.dot(np.linalg.inv(T_c1), T_c2) 41 | T_c1_c2[:3, 3] *= scale 42 | 43 | T_m1 = tu.get_rigid_body_trafo(q_gt[idx, :], p_gt[idx, :]) 44 | T_m2 = tu.get_rigid_body_trafo(q_gt[c, :], p_gt[c, :]) 45 | T_m1_m2 = np.dot(np.linalg.inv(T_m1), T_m2) 46 | 47 | T_m1_m2_in_c1 = np.dot(T_cm, np.dot(T_m1_m2, T_mc)) 48 | T_error_in_c2 = np.dot(np.linalg.inv(T_m1_m2_in_c1), T_c1_c2) 49 | T_c2_rot = np.eye(4) 50 | T_c2_rot[0:3, 0:3] = T_c2[0:3, 0:3] 51 | T_error_in_w = np.dot(T_c2_rot, np.dot( 52 | T_error_in_c2, np.linalg.inv(T_c2_rot))) 53 | errors.append(T_error_in_w) 54 | 55 | error_trans_norm = [] 56 | error_trans_perc = [] 57 | error_yaw = [] 58 | error_gravity = [] 59 | e_rot = [] 60 | e_rot_deg_per_m = [] 61 | for e in errors: 62 | tn = np.linalg.norm(e[0:3, 3]) 63 | error_trans_norm.append(tn) 64 | error_trans_perc.append(tn / dist * 100) 65 | ypr_angles = tf.euler_from_matrix(e, 'rzyx') 66 | e_rot.append(tu.compute_angle(e)) 67 | error_yaw.append(abs(ypr_angles[0])*180.0/np.pi) 68 | error_gravity.append( 69 | np.sqrt(ypr_angles[1]**2+ypr_angles[2]**2)*180.0/np.pi) 70 | e_rot_deg_per_m.append(e_rot[-1] / dist) 71 | return errors, np.array(error_trans_norm), np.array(error_trans_perc),\ 72 | np.array(error_yaw), np.array(error_gravity), np.array(e_rot),\ 73 | np.array(e_rot_deg_per_m) 74 | 75 | 76 | def compute_absolute_error(p_es_aligned, q_es_aligned, p_gt, q_gt): 77 | e_trans_vec = (p_gt-p_es_aligned) 78 | e_trans = np.sqrt(np.sum(e_trans_vec**2, 1)) 79 | 80 | 81 | # orientation error 82 | e_rot = np.zeros((len(e_trans,))) 83 | e_ypr = np.zeros(np.shape(p_es_aligned)) 84 | for i in range(np.shape(p_es_aligned)[0]): 85 | R_we = tf.matrix_from_quaternion(q_es_aligned[i, :]) 86 | R_wg = tf.matrix_from_quaternion(q_gt[i, :]) 87 | e_R = np.dot(R_we, np.linalg.inv(R_wg)) 88 | e_ypr[i, :] = tf.euler_from_matrix(e_R, 'rzyx') 89 | e_rot[i] = np.rad2deg(np.linalg.norm(tf.logmap_so3(e_R[:3, :3]))) 90 | # scale drift 91 | motion_gt = np.diff(p_gt, 0) 92 | motion_es = np.diff(p_es_aligned, 0) 93 | dist_gt = np.sqrt(np.sum(np.multiply(motion_gt, motion_gt), 1)) 94 | dist_es = np.sqrt(np.sum(np.multiply(motion_es, motion_es), 1)) 95 | e_scale_perc = np.abs((np.divide(dist_es, dist_gt)-1.0) * 100) 96 | # ate = np.sqrt(np.mean(np.asarray(e_trans) ** 2)) 97 | return e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc 98 | -------------------------------------------------------------------------------- /utils/utils_poses/ATE/results_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #!/usr/bin/env python2 10 | import os 11 | # import yaml 12 | import numpy as np 13 | 14 | 15 | def compute_statistics(data_vec): 16 | stats = dict() 17 | if len(data_vec) > 0: 18 | stats['rmse'] = float( 19 | np.sqrt(np.dot(data_vec, data_vec) / len(data_vec))) 20 | stats['mean'] = float(np.mean(data_vec)) 21 | stats['median'] = float(np.median(data_vec)) 22 | stats['std'] = float(np.std(data_vec)) 23 | stats['min'] = float(np.min(data_vec)) 24 | stats['max'] = float(np.max(data_vec)) 25 | stats['num_samples'] = int(len(data_vec)) 26 | else: 27 | stats['rmse'] = 0 28 | stats['mean'] = 0 29 | stats['median'] = 0 30 | stats['std'] = 0 31 | stats['min'] = 0 32 | stats['max'] = 0 33 | stats['num_samples'] = 0 34 | 35 | return stats 36 | 37 | 38 | # def update_and_save_stats(new_stats, label, yaml_filename): 39 | # stats = dict() 40 | # if os.path.exists(yaml_filename): 41 | # stats = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader) 42 | # stats[label] = new_stats 43 | # 44 | # with open(yaml_filename, 'w') as outfile: 45 | # outfile.write(yaml.dump(stats, default_flow_style=False)) 46 | # 47 | # return 48 | # 49 | # 50 | # def compute_and_save_statistics(data_vec, label, yaml_filename): 51 | # new_stats = compute_statistics(data_vec) 52 | # update_and_save_stats(new_stats, label, yaml_filename) 53 | # 54 | # return new_stats 55 | # 56 | # 57 | # def write_tex_table(list_values, rows, cols, outfn): 58 | # ''' 59 | # write list_values[row_idx][col_idx] to a table that is ready to be pasted 60 | # into latex source 61 | # 62 | # list_values is a list of row values 63 | # 64 | # The value should be string of desired format 65 | # ''' 66 | # 67 | # assert len(rows) >= 1 68 | # assert len(cols) >= 1 69 | # 70 | # with open(outfn, 'w') as f: 71 | # # write header 72 | # f.write(' & ') 73 | # for col_i in cols[:-1]: 74 | # f.write(col_i + ' & ') 75 | # f.write(' ' + cols[-1]+'\n') 76 | # 77 | # # write each row 78 | # for row_idx, row_i in enumerate(list_values): 79 | # f.write(rows[row_idx] + ' & ') 80 | # row_values = list_values[row_idx] 81 | # for col_idx in range(len(row_values) - 1): 82 | # f.write(row_values[col_idx] + ' & ') 83 | # f.write(' ' + row_values[-1]+' \n') 84 | -------------------------------------------------------------------------------- /utils/utils_poses/ATE/trajectory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
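#
# compute_angle() below recovers the geodesic rotation angle (in degrees) of a
# 4x4 rigid transform from the trace identity trace(R) = 1 + 2*cos(theta):
#
#     theta = arccos(clamp((trace(T[0:3, 0:3]) - 1) / 2, -1, 1)) * 180 / pi
#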
8 | 9 | #!/usr/bin/env python2 10 | """ 11 | @author: Christian Forster 12 | """ 13 | 14 | import os 15 | import numpy as np 16 | import utils.utils_poses.ATE.transformations as tf 17 | 18 | 19 | def get_rigid_body_trafo(quat, trans): 20 | T = tf.quaternion_matrix(quat) 21 | T[0:3, 3] = trans 22 | return T 23 | 24 | 25 | def get_distance_from_start(gt_translation): 26 | distances = np.diff(gt_translation[:, 0:3], axis=0) 27 | distances = np.sqrt(np.sum(np.multiply(distances, distances), 1)) 28 | distances = np.cumsum(distances) 29 | distances = np.concatenate(([0], distances)) 30 | return distances 31 | 32 | 33 | def compute_comparison_indices_length(distances, dist, max_dist_diff): 34 | max_idx = len(distances) 35 | comparisons = [] 36 | for idx, d in enumerate(distances): 37 | best_idx = -1 38 | error = max_dist_diff 39 | for i in range(idx, max_idx): 40 | if np.abs(distances[i]-(d+dist)) < error: 41 | best_idx = i 42 | error = np.abs(distances[i] - (d+dist)) 43 | if best_idx != -1: 44 | comparisons.append(best_idx) 45 | return comparisons 46 | 47 | 48 | def compute_angle(transform): 49 | """ 50 | Compute the rotation angle from a 4x4 homogeneous matrix. 51 | """ 52 | # an invitation to 3-d vision, p 27 53 | return np.arccos( 54 | min(1, max(-1, (np.trace(transform[0:3, 0:3]) - 1)/2)))*180.0/np.pi 55 | -------------------------------------------------------------------------------- /utils/utils_poses/align_traj.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from utils.utils_poses.ATE.align_utils import alignTrajectory 13 | from utils.utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4 14 | 15 | 16 | def pts_dist_max(pts): 17 | """ 18 | :param pts: (N, 3) torch or np 19 | :return: scalar 20 | """ 21 | if torch.is_tensor(pts): 22 | dist = pts.unsqueeze(0) - pts.unsqueeze(1) # (1, N, 3) - (N, 1, 3) -> (N, N, 3) 23 | dist = dist[0] # (N, 3) 24 | dist = dist.norm(dim=1) # (N, ) 25 | max_dist = dist.max() 26 | else: 27 | dist = pts[None, :, :] - pts[:, None, :] # (1, N, 3) - (N, 1, 3) -> (N, N, 3) 28 | dist = dist[0] # (N, 3) 29 | dist = np.linalg.norm(dist, axis=1) # (N, ) 30 | max_dist = dist.max() 31 | return max_dist 32 | 33 | 34 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None, method='sim3'): 35 | """Align c to b using the sim3 from a to b. 36 | :param traj_a: (N0, 3/4, 4) torch tensor 37 | :param traj_b: (N0, 3/4, 4) torch tensor 38 | :param traj_c: None or (N1, 3/4, 4) torch tensor 39 | :return: (N1, 4, 4) torch tensor 40 | """ 41 | device = traj_a.device 42 | if traj_c is None: 43 | traj_c = traj_a.clone() 44 | 45 | traj_a = traj_a.float().cpu().numpy() 46 | traj_b = traj_b.float().cpu().numpy() 47 | traj_c = traj_c.float().cpu().numpy() 48 | 49 | R_a = traj_a[:, :3, :3] # (N0, 3, 3) 50 | t_a = traj_a[:, :3, 3] # (N0, 3) 51 | quat_a = SO3_to_quat(R_a) # (N0, 4) 52 | 53 | R_b = traj_b[:, :3, :3] # (N0, 3, 3) 54 | t_b = traj_b[:, :3, 3] # (N0, 3) 55 | quat_b = SO3_to_quat(R_b) # (N0, 4) 56 | 57 | # This function works in quaternion. 
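    # It estimates one similarity transform (s, R, t) from the two trajectories
    # and then applies it to every pose of traj_c:
    #     R_c_aligned = R @ R_c
    #     t_c_aligned = s * (R @ t_c) + t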
58 | # scalar, (3, 3), (3, ) gt = R * s * est + t. 59 | s, R, t = alignTrajectory(t_a, t_b, quat_a, quat_b, method=method) 60 | 61 | # reshape tensors 62 | R = R[None, :, :].astype(np.float32) # (1, 3, 3) 63 | t = t[None, :, None].astype(np.float32) # (1, 3, 1) 64 | s = float(s) 65 | 66 | R_c = traj_c[:, :3, :3] # (N1, 3, 3) 67 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1) 68 | 69 | R_c_aligned = R @ R_c # (N1, 3, 3) 70 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1) 71 | traj_c_aligned = np.concatenate([R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4) 72 | 73 | # append the last row 74 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4) 75 | 76 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device) 77 | return traj_c_aligned # (N1, 4, 4) 78 | 79 | 80 | 81 | def align_scale_c2b_use_a2b(traj_a, traj_b, traj_c=None): 82 | '''Scale c to b using the scale from a to b. 83 | :param traj_a: (N0, 3/4, 4) torch tensor 84 | :param traj_b: (N0, 3/4, 4) torch tensor 85 | :param traj_c: None or (N1, 3/4, 4) torch tensor 86 | :return: 87 | scaled_traj_c (N1, 4, 4) torch tensor 88 | scale scalar 89 | ''' 90 | if traj_c is None: 91 | traj_c = traj_a.clone() 92 | 93 | t_a = traj_a[:, :3, 3] # (N, 3) 94 | t_b = traj_b[:, :3, 3] # (N, 3) 95 | 96 | # scale estimated poses to colmap scale 97 | # s_a2b: a*s ~ b 98 | scale_a2b = pts_dist_max(t_b) / pts_dist_max(t_a) 99 | 100 | traj_c[:, :3, 3] *= scale_a2b 101 | 102 | if traj_c.shape[1] == 3: 103 | traj_c = convert3x4_4x4(traj_c) # (N, 4, 4) 104 | 105 | return traj_c, scale_a2b # (N, 4, 4) 106 | -------------------------------------------------------------------------------- /utils/utils_poses/comp_ate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
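#
# Two standard trajectory metrics are implemented below:
#   * compute_ATE: RMSE of the per-frame translation gap ||t_gt - t_pred||.
#   * compute_rpe: mean translation / rotation error of consecutive relative
#     poses, i.e. of inv(gt_i^-1 @ gt_i+1) @ (pred_i^-1 @ pred_i+1).
#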
8 | 9 | import numpy as np 10 | 11 | import utils.utils_poses.ATE.trajectory_utils as tu 12 | import utils.utils_poses.ATE.transformations as tf 13 | def rotation_error(pose_error): 14 | """Compute rotation error 15 | Args: 16 | pose_error (4x4 array): relative pose error 17 | Returns: 18 | rot_error (float): rotation error 19 | """ 20 | a = pose_error[0, 0] 21 | b = pose_error[1, 1] 22 | c = pose_error[2, 2] 23 | d = 0.5*(a+b+c-1.0) 24 | rot_error = np.arccos(max(min(d, 1.0), -1.0)) 25 | return rot_error 26 | 27 | def translation_error(pose_error): 28 | """Compute translation error 29 | Args: 30 | pose_error (4x4 array): relative pose error 31 | Returns: 32 | trans_error (float): translation error 33 | """ 34 | dx = pose_error[0, 3] 35 | dy = pose_error[1, 3] 36 | dz = pose_error[2, 3] 37 | trans_error = np.sqrt(dx**2+dy**2+dz**2) 38 | return trans_error 39 | 40 | def compute_rpe(gt, pred): 41 | trans_errors = [] 42 | rot_errors = [] 43 | for i in range(len(gt)-1): 44 | gt1 = gt[i] 45 | gt2 = gt[i+1] 46 | gt_rel = np.linalg.inv(gt1) @ gt2 47 | 48 | pred1 = pred[i] 49 | pred2 = pred[i+1] 50 | pred_rel = np.linalg.inv(pred1) @ pred2 51 | rel_err = np.linalg.inv(gt_rel) @ pred_rel 52 | 53 | trans_errors.append(translation_error(rel_err)) 54 | rot_errors.append(rotation_error(rel_err)) 55 | rpe_trans = np.mean(np.asarray(trans_errors)) 56 | rpe_rot = np.mean(np.asarray(rot_errors)) 57 | return rpe_trans, rpe_rot 58 | 59 | def compute_ATE(gt, pred): 60 | """Compute RMSE of ATE 61 | Args: 62 | gt: ground-truth poses 63 | pred: predicted poses 64 | """ 65 | errors = [] 66 | 67 | for i in range(len(pred)): 68 | # cur_gt = np.linalg.inv(gt_0) @ gt[i] 69 | cur_gt = gt[i] 70 | gt_xyz = cur_gt[:3, 3] 71 | 72 | # cur_pred = np.linalg.inv(pred_0) @ pred[i] 73 | cur_pred = pred[i] 74 | pred_xyz = cur_pred[:3, 3] 75 | 76 | align_err = gt_xyz - pred_xyz 77 | 78 | errors.append(np.sqrt(np.sum(align_err ** 2))) 79 | ate = np.sqrt(np.mean(np.asarray(errors) ** 2)) 80 | return ate 81 | 82 | -------------------------------------------------------------------------------- /utils/utils_poses/lie_group_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
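# Illustrative usage sketch (the pose lists below are hypothetical): the RPE/ATE helpers
# above compare trajectories frame by frame, so the estimate should already be aligned
# to the reference (e.g. with align_ate_c2b_use_a2b).
#
#   import numpy as np
#   from utils.utils_poses.comp_ate import compute_rpe, compute_ATE
#
#   gt   = [np.eye(4) for _ in range(10)]        # ground-truth 4x4 c2w poses
#   pred = [np.eye(4) for _ in range(10)]        # aligned estimates, same length and order
#   rpe_trans, rpe_rot = compute_rpe(gt, pred)   # mean relative errors; rpe_rot in radians
#   ate_rmse = compute_ATE(gt, pred)             # RMSE of per-frame translation error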
8 | 9 | import numpy as np 10 | import torch 11 | from scipy.spatial.transform import Rotation as RotLib 12 | 13 | 14 | def SO3_to_quat(R): 15 | """ 16 | :param R: (N, 3, 3) or (3, 3) np 17 | :return: (N, 4, ) or (4, ) np 18 | """ 19 | x = RotLib.from_matrix(R) 20 | quat = x.as_quat() 21 | return quat 22 | 23 | 24 | def quat_to_SO3(quat): 25 | """ 26 | :param quat: (N, 4, ) or (4, ) np 27 | :return: (N, 3, 3) or (3, 3) np 28 | """ 29 | x = RotLib.from_quat(quat) 30 | R = x.as_matrix() 31 | return R 32 | 33 | 34 | def convert3x4_4x4(input): 35 | """ 36 | :param input: (N, 3, 4) or (3, 4) torch or np 37 | :return: (N, 4, 4) or (4, 4) torch or np 38 | """ 39 | if torch.is_tensor(input): 40 | if len(input.shape) == 3: 41 | output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4) 42 | output[:, 3, 3] = 1.0 43 | else: 44 | output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 45 | else: 46 | if len(input.shape) == 3: 47 | output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 48 | output[:, 3, 3] = 1.0 49 | else: 50 | output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4) 51 | output[3, 3] = 1.0 52 | return output 53 | 54 | 55 | def vec2skew(v): 56 | """ 57 | :param v: (3, ) torch tensor 58 | :return: (3, 3) 59 | """ 60 | zero = torch.zeros(1, dtype=torch.float32, device=v.device) 61 | skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1) 62 | skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]]) 63 | skew_v2 = torch.cat([-v[1:2], v[0:1], zero]) 64 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3) 65 | return skew_v # (3, 3) 66 | 67 | 68 | def Exp(r): 69 | """so(3) vector to SO(3) matrix 70 | :param r: (3, ) axis-angle, torch tensor 71 | :return: (3, 3) 72 | """ 73 | skew_r = vec2skew(r) # (3, 3) 74 | norm_r = r.norm() + 1e-15 75 | eye = torch.eye(3, dtype=torch.float32, device=r.device) 76 | R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r) 77 | return R 78 | 79 | 80 | def make_c2w(r, t): 81 | """ 82 | :param r: (3, ) axis-angle torch tensor 83 | :param t: (3, ) translation vector torch tensor 84 | :return: (4, 4) 85 | """ 86 | R = Exp(r) # (3, 3) 87 | c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4) 88 | c2w = convert3x4_4x4(c2w) # (4, 4) 89 | return c2w 90 | -------------------------------------------------------------------------------- /utils/utils_poses/vis_cam_traj.py: -------------------------------------------------------------------------------- 1 | ''' 2 | BSD 2-Clause License 3 | 4 | Copyright (c) 2020, the NeRF++ authors 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | ''' 28 | 29 | import numpy as np 30 | 31 | try: 32 | import open3d as o3d 33 | except ImportError: 34 | pass 35 | 36 | 37 | def frustums2lineset(frustums): 38 | N = len(frustums) 39 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 40 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 41 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 42 | 43 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 44 | merged_points[i*5:(i+1)*5, :] = frustum_points 45 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 46 | merged_colors[i*8:(i+1)*8, :] = frustum_colors 47 | 48 | lineset = o3d.geometry.LineSet() 49 | lineset.points = o3d.utility.Vector3dVector(merged_points) 50 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 51 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 52 | 53 | return lineset 54 | 55 | 56 | def get_camera_frustum_opengl_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])): 57 | '''X right, Y up, Z backward to the observer. 58 | :param H, W: 59 | :param fx, fy: 60 | :param W2C: (4, 4) matrix 61 | :param frustum_length: scalar: scale the frustum 62 | :param color: (3,) list, frustum line color 63 | :return: 64 | frustum_points: (5, 3) frustum points in world coordinate 65 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index. 66 | frustum_colors: (8, 3) colors for 8 lines. 67 | ''' 68 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.) 69 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.) 70 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 71 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 72 | 73 | # build view frustum in camera space in homogenous coordinate (5, 4) 74 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin 75 | [-half_w, half_h, -frustum_length, 1.0], # top-left image corner 76 | [half_w, half_h, -frustum_length, 1.0], # top-right image corner 77 | [half_w, -half_h, -frustum_length, 1.0], # bottom-right image corner 78 | [-half_w, -half_h, -frustum_length, 1.0]]) # bottom-left image corner 79 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2) 80 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3) 81 | 82 | # transform view frustum from camera space to world space 83 | C2W = np.linalg.inv(W2C) 84 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4) 85 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogenous coordinate 86 | return frustum_points, frustum_lines, frustum_colors 87 | 88 | def get_camera_frustum_opencv_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])): 89 | '''X right, Y up, Z backward to the observer. 
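    Note: unlike the OpenGL variant above, the frustum points constructed below follow the
    OpenCV convention (X right, Y down, Z forward into the scene).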
90 | :param H, W: 91 | :param fx, fy: 92 | :param W2C: (4, 4) matrix 93 | :param frustum_length: scalar: scale the frustum 94 | :param color: (3,) list, frustum line color 95 | :return: 96 | frustum_points: (5, 3) frustum points in world coordinate 97 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index. 98 | frustum_colors: (8, 3) colors for 8 lines. 99 | ''' 100 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.) 101 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.) 102 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 103 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 104 | 105 | # build view frustum in camera space in homogenous coordinate (5, 4) 106 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin 107 | [-half_w, -half_h, frustum_length, 1.0], # top-left image corner 108 | [ half_w, -half_h, frustum_length, 1.0], # top-right image corner 109 | [ half_w, half_h, frustum_length, 1.0], # bottom-right image corner 110 | [-half_w, +half_h, frustum_length, 1.0]]) # bottom-left image corner 111 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2) 112 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3) 113 | 114 | # transform view frustum from camera space to world space 115 | C2W = np.linalg.inv(W2C) 116 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4) 117 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogenous coordinate 118 | return frustum_points, frustum_lines, frustum_colors 119 | 120 | 121 | 122 | def draw_camera_frustum_geometry(c2ws, H, W, fx=600.0, fy=600.0, frustum_length=0.5, 123 | color=np.array([29.0, 53.0, 87.0])/255.0, draw_now=False, coord='opengl'): 124 | ''' 125 | :param c2ws: (N, 4, 4) np.array 126 | :param H: scalar 127 | :param W: scalar 128 | :param fx: scalar 129 | :param fy: scalar 130 | :param frustum_length: scalar 131 | :param color: None or (N, 3) or (3, ) or (1, 3) or (3, 1) np array 132 | :param draw_now: True/False call o3d vis now 133 | :return: 134 | ''' 135 | N = c2ws.shape[0] 136 | 137 | num_ele = color.flatten().shape[0] 138 | if num_ele == 3: 139 | color = color.reshape(1, 3) 140 | color = np.tile(color, (N, 1)) 141 | 142 | frustum_list = [] 143 | if coord == 'opengl': 144 | for i in range(N): 145 | frustum_list.append(get_camera_frustum_opengl_coord(H, W, fx, fy, 146 | W2C=np.linalg.inv(c2ws[i]), 147 | frustum_length=frustum_length, 148 | color=color[i])) 149 | elif coord == 'opencv': 150 | for i in range(N): 151 | frustum_list.append(get_camera_frustum_opencv_coord(H, W, fx, fy, 152 | W2C=np.linalg.inv(c2ws[i]), 153 | frustum_length=frustum_length, 154 | color=color[i])) 155 | else: 156 | print('Undefined coordinate system. Exit') 157 | exit() 158 | 159 | frustums_geometry = frustums2lineset(frustum_list) 160 | 161 | if draw_now: 162 | o3d.visualization.draw_geometries([frustums_geometry]) 163 | 164 | return frustums_geometry # this is an o3d geometry object. 165 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | from matplotlib import tight_layout 13 | import copy 14 | from evo.core.trajectory import PosePath3D, PoseTrajectory3D 15 | from evo.main_ape import ape 16 | from evo.tools import plot 17 | from evo.core import sync 18 | from evo.tools import file_interface 19 | from evo.core import metrics 20 | import evo 21 | import torch 22 | import numpy as np 23 | from scipy.spatial.transform import Slerp 24 | from scipy.spatial.transform import Rotation as R 25 | import scipy.interpolate as si 26 | 27 | 28 | def interp_poses(c2ws, N_views): 29 | N_inputs = c2ws.shape[0] 30 | trans = c2ws[:, :3, 3:].permute(2, 1, 0) 31 | rots = c2ws[:, :3, :3] 32 | render_poses = [] 33 | rots = R.from_matrix(rots) 34 | slerp = Slerp(np.linspace(0, 1, N_inputs), rots) 35 | interp_rots = torch.tensor( 36 | slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32)) 37 | interp_trans = torch.nn.functional.interpolate( 38 | trans, size=N_views, mode='linear').permute(2, 1, 0) 39 | render_poses = torch.cat([interp_rots, interp_trans], dim=2) 40 | render_poses = convert3x4_4x4(render_poses) 41 | return render_poses 42 | 43 | 44 | def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree): 45 | target_trans = torch.tensor(scipy_bspline( 46 | c2ws[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False).astype(np.float32)).unsqueeze(2) 47 | rots = R.from_matrix(c2ws[:, :3, :3]) 48 | slerp = Slerp(input_times, rots) 49 | target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs) 50 | target_rots = torch.tensor( 51 | slerp(target_times).as_matrix().astype(np.float32)) 52 | target_poses = torch.cat([target_rots, target_trans], dim=2) 53 | target_poses = convert3x4_4x4(target_poses) 54 | return target_poses 55 | 56 | 57 | def poses_avg(poses): 58 | 59 | hwf = poses[0, :3, -1:] 60 | 61 | center = poses[:, :3, 3].mean(0) 62 | vec2 = normalize(poses[:, :3, 2].sum(0)) 63 | up = poses[:, :3, 1].sum(0) 64 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 65 | 66 | return c2w 67 | 68 | 69 | def normalize(v): 70 | """Normalize a vector.""" 71 | return v / np.linalg.norm(v) 72 | 73 | 74 | def viewmatrix(z, up, pos): 75 | vec2 = normalize(z) 76 | vec1_avg = up 77 | vec0 = normalize(np.cross(vec1_avg, vec2)) 78 | vec1 = normalize(np.cross(vec2, vec0)) 79 | m = np.stack([vec0, vec1, vec2, pos], 1) 80 | return m 81 | 82 | 83 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 84 | render_poses = [] 85 | rads = np.array(list(rads) + [1.]) 86 | hwf = c2w[:, 4:5] 87 | 88 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 89 | # c = np.dot(c2w[:3,:4], np.array([0.7*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.1, 1.]) * rads) 90 | # c = np.dot(c2w[:3,:4], np.array([0.3*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.01, 1.]) * rads) 91 | c = np.dot(c2w[:3, :4], np.array( 92 | [0.2*np.cos(theta), -0.2*np.sin(theta), -np.sin(theta*zrate) * 0.1, 1.]) * rads) 93 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 94 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 95 | return render_poses 96 | 97 | 98 | def scipy_bspline(cv, n=100, degree=3, periodic=False): 99 | """ Calculate n samples on a bspline 100 | 101 | cv : Array ov control vertices 102 | n : Number of samples to return 103 | degree: Curve degree 104 | periodic: True - Curve is closed 105 | """ 106 | cv = np.asarray(cv) 107 | count = cv.shape[0] 108 | 109 | # Closed curve 110 | if periodic: 111 | kv = np.arange(-degree, count+degree+1) 112 | factor, fraction = divmod(count+degree+1, count) 113 | cv = np.roll(np.concatenate( 114 | (cv,) * factor + (cv[:fraction],)), -1, axis=0) 115 | degree = np.clip(degree, 1, degree) 116 | 117 | # Opened curve 118 | else: 119 | degree = np.clip(degree, 1, count-1) 120 | kv = np.clip(np.arange(count+degree+1)-degree, 0, count-degree) 121 | 122 | # Return samples 123 | max_param = count - (degree * (1-periodic)) 124 | spl = si.BSpline(kv, cv, degree) 125 | return spl(np.linspace(0, max_param, n)) 126 | 127 | 128 | def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf): 129 | learned_poses_ = np.concatenate((learned_poses[:, :3, :4].detach( 130 | ).cpu().numpy(), hwf[:len(learned_poses)]), axis=-1) 131 | c2w = poses_avg(learned_poses_) 132 | print('recentered', c2w.shape) 133 | # Get spiral 134 | # Get average pose 135 | up = normalize(learned_poses_[:, :3, 1].sum(0)) 136 | # Find a reasonable "focus depth" for this dataset 137 | 138 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 
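    # The lines below choose the spiral's focal depth as a weighted harmonic mean of the
    # near/far scene bounds (dt weights the far plane), the heuristic used in the original
    # NeRF/LLFF spiral-path code to pick a plausible depth for the spiral cameras to look at.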
139 | dt = .75 140 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 141 | focal = mean_dz 142 | 143 | # Get radii for spiral path 144 | shrink_factor = .8 145 | zdelta = close_depth * .2 146 | tt = learned_poses_[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 147 | rads = np.percentile(np.abs(tt), 90, 0) 148 | c2w_path = c2w 149 | N_rots = 2 150 | c2ws = render_path_spiral( 151 | c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views) 152 | c2ws = torch.tensor(np.stack(c2ws).astype(np.float32)) 153 | c2ws = c2ws[:, :3, :4] 154 | c2ws = convert3x4_4x4(c2ws) 155 | return c2ws 156 | 157 | 158 | def convert3x4_4x4(input): 159 | """ 160 | :param input: (N, 3, 4) or (3, 4) torch or np 161 | :return: (N, 4, 4) or (4, 4) torch or np 162 | """ 163 | if torch.is_tensor(input): 164 | if len(input.shape) == 3: 165 | output = torch.cat([input, torch.zeros_like( 166 | input[:, 0:1])], dim=1) # (N, 4, 4) 167 | output[:, 3, 3] = 1.0 168 | else: 169 | output = torch.cat([input, torch.tensor( 170 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4) 171 | else: 172 | if len(input.shape) == 3: 173 | output = np.concatenate( 174 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4) 175 | output[:, 3, 3] = 1.0 176 | else: 177 | output = np.concatenate( 178 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4) 179 | output[3, 3] = 1.0 180 | return output 181 | 182 | 183 | plt.rc('legend', fontsize=20) # using a named size 184 | 185 | 186 | def plot_pose(ref_poses, est_poses, output_path, vid=False): 187 | ref_poses = [pose for pose in ref_poses] 188 | if isinstance(est_poses, dict): 189 | est_poses = [pose for k, pose in est_poses.items()] 190 | else: 191 | est_poses = [pose for pose in est_poses] 192 | traj_ref = PosePath3D(poses_se3=ref_poses) 193 | traj_est = PosePath3D(poses_se3=est_poses) 194 | traj_est_aligned = copy.deepcopy(traj_est) 195 | traj_est_aligned.align(traj_ref, correct_scale=True, 196 | correct_only_scale=False) 197 | if vid: 198 | for p_idx in range(len(ref_poses)): 199 | fig = plt.figure() 200 | current_est_aligned = traj_est_aligned.poses_se3[:p_idx+1] 201 | current_ref = traj_ref.poses_se3[:p_idx+1] 202 | current_est_aligned = PosePath3D(poses_se3=current_est_aligned) 203 | current_ref = PosePath3D(poses_se3=current_ref) 204 | traj_by_label = { 205 | # "estimate (not aligned)": traj_est, 206 | "Ours (aligned)": current_est_aligned, 207 | "Ground-truth": current_ref 208 | } 209 | plot_mode = plot.PlotMode.xyz 210 | # ax = plot.prepare_axis(fig, plot_mode, 111) 211 | ax = fig.add_subplot(111, projection="3d") 212 | ax.xaxis.set_tick_params(labelbottom=False) 213 | ax.yaxis.set_tick_params(labelleft=False) 214 | ax.zaxis.set_tick_params(labelleft=False) 215 | colors = ['r', 'b'] 216 | styles = ['-', '--'] 217 | 218 | for idx, (label, traj) in enumerate(traj_by_label.items()): 219 | plot.traj(ax, plot_mode, traj, 220 | styles[idx], colors[idx], label) 221 | # break 222 | # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz) 223 | ax.view_init(elev=10., azim=45) 224 | plt.tight_layout() 225 | os.makedirs(os.path.join(os.path.dirname( 226 | output_path), 'pose_vid'), exist_ok=True) 227 | pose_vis_path = os.path.join(os.path.dirname( 228 | output_path), 'pose_vid', 'pose_vis_{:03d}.png'.format(p_idx)) 229 | print(pose_vis_path) 230 | fig.savefig(pose_vis_path) 231 | 232 | # else: 233 | 234 | fig = plt.figure() 235 | traj_by_label = { 236 | # "estimate (not aligned)": traj_est, 237 | "Ours (aligned)": traj_est_aligned, 238 
| "Ground-truth": traj_ref 239 | } 240 | plot_mode = plot.PlotMode.xyz 241 | # ax = plot.prepare_axis(fig, plot_mode, 111) 242 | ax = fig.add_subplot(111, projection="3d") 243 | ax.xaxis.set_tick_params(labelbottom=False) 244 | ax.yaxis.set_tick_params(labelleft=False) 245 | ax.zaxis.set_tick_params(labelleft=False) 246 | colors = ['r', 'b'] 247 | styles = ['-', '--'] 248 | 249 | for idx, (label, traj) in enumerate(traj_by_label.items()): 250 | plot.traj(ax, plot_mode, traj, 251 | styles[idx], colors[idx], label) 252 | # break 253 | # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz) 254 | ax.view_init(elev=10., azim=45) 255 | plt.tight_layout() 256 | pose_vis_path = os.path.join(os.path.dirname(output_path), 'pose_vis.png') 257 | fig.savefig(pose_vis_path) 258 | --------------------------------------------------------------------------------