├── .gitignore
├── .gitmodules
├── LICENSE
├── LICENSE_inria.md
├── README.md
├── arguments
│   └── __init__.py
├── assets
│   └── teaser.gif
├── convert.py
├── environment.yml
├── gaussian_renderer
│   ├── __init__.py
│   └── network_gui.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── requirements.txt
├── run_cf3dgs.py
├── scene
│   ├── __init__.py
│   ├── camera_model.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   ├── gaussian_model.py
│   └── gaussian_model_cf.py
├── train.py
├── trainer
│   ├── cf3dgs_trainer.py
│   ├── losses.py
│   └── trainer.py
└── utils
    ├── camera_conversion.py
    ├── camera_utils.py
    ├── general_utils.py
    ├── geometry_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── sh_utils.py
    ├── system_utils.py
    ├── utils_poses
    │   ├── ATE
    │   │   ├── align_trajectory.py
    │   │   ├── align_utils.py
    │   │   ├── compute_trajectory_errors.py
    │   │   ├── results_writer.py
    │   │   ├── trajectory_utils.py
    │   │   └── transformations.py
    │   ├── align_traj.py
    │   ├── comp_ate.py
    │   ├── lie_group_helper.py
    │   └── vis_cam_traj.py
    └── vis_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | output
4 | build
5 | diff_rasterization/diff_rast.egg-info
6 | diff_rasterization/dist
7 | tensorboard_3d
8 | screenshots
9 | demo
10 | checkpoints
11 | data/
12 | external/
13 | vis_ply/
14 | vis/
15 | debug/
16 | debug_vis/
17 | output/
18 | output_sup/
19 | wandb/
20 | *.pth
21 | *.ply
22 | *.png
23 | *.jpg
24 | *.mp4
25 | *.ipynb_checkpoints
26 | *.ipynb
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/simple-knn"]
2 | path = submodules/simple-knn
3 | url = https://gitlab.inria.fr/bkerbl/simple-knn
4 | [submodule "submodules/diff-gaussian-rasterization"]
5 | path = submodules/diff-gaussian-rasterization
6 | url = https://github.com/ashawkey/diff-gaussian-rasterization
7 | [submodule "submodules/SIBR_viewers"]
8 | path = submodules/SIBR_viewers
9 | url = https://gitlab.inria.fr/sibr/sibr_core
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 |
3 | NVIDIA Source Code License for RADIO
4 |
5 | =======================================================================
6 |
7 | 1. Definitions
8 |
9 | “Licensor” means any person or entity that distributes its Work.
10 |
11 | “Work” means (a) the original work of authorship made available under
12 | this license, which may include software, documentation, or other files,
13 | and (b) any additions to or derivative works thereof that are made
14 | available under this license.
15 |
16 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution”
17 | have the meaning as provided under U.S. copyright law; provided, however,
18 | that for the purposes of this license, derivative works shall not include works
19 | that remain separable from, or merely link (or bind by name) to the
20 | interfaces of, the Work.
21 |
22 | Works are “made available” under this license by including in or with the Work
23 | either (a) a copyright notice referencing the applicability of
24 | this license to the Work, or (b) a copy of this license.
25 |
26 | 2. License Grant
27 |
28 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each
29 | Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free,
30 | copyright license to use, reproduce, prepare derivative works of, publicly display,
31 | publicly perform, sublicense and distribute its Work and any resulting derivative
32 | works in any form.
33 |
34 | 3. Limitations
35 |
36 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under
37 | this license, (b) you include a complete copy of this license with your distribution,
38 | and (c) you retain without modification any copyright, patent, trademark, or
39 | attribution notices that are present in the Work.
40 |
41 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use,
42 | reproduction, and distribution of your derivative works of the Work (“Your Terms”) only
43 | if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative
44 | works, and (b) you identify the specific derivative works that are subject to Your Terms.
45 | Notwithstanding Your Terms, this license (including the redistribution requirements in
46 | Section 3.1) will continue to apply to the Work itself.
47 |
48 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or
49 | intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation
50 | and its affiliates may use the Work and any derivative works commercially.
51 | As used herein, “non-commercially” means for research or evaluation purposes only.
52 |
53 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor
54 | (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that
55 | you allege are infringed by any Work, then your rights under this license from
56 | such Licensor (including the grant in Section 2.1) will terminate immediately.
57 |
58 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its
59 | affiliates’ names, logos, or trademarks, except as necessary to reproduce
60 | the notices described in this license.
61 |
62 | 3.6 Termination. If you violate any term of this license, then your rights under
63 | this license (including the grant in Section 2.1) will terminate immediately.
64 |
65 | 4. Disclaimer of Warranty.
66 |
67 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
68 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
69 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
70 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
71 |
72 | 5. Limitation of Liability.
73 |
74 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
75 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR
76 | BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL,
77 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR
78 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS
79 | INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY
80 | OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
81 | POSSIBILITY OF SUCH DAMAGES.
82 |
83 | =======================================================================
84 |
--------------------------------------------------------------------------------
/LICENSE_inria.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
85 | ## 6. Files subject to permissive licenses
86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license.
87 |
88 | Title: pytorch-ssim\
89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\
90 | Copyright Evan Su, 2017\
91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # COLMAP-Free 3D Gaussian Splatting
2 |
3 | Yang Fu · Sifei Liu · Amey Kulkarni · Jan Kautz · Alexei A. Efros · Xiaolong Wang
4 |
29 | ## Installation
30 |
31 | ##### (Recommended)
32 | The code has been tested with Python 3.10 and CUDA >= 11.6. The simplest way to install all dependencies is to use [anaconda](https://www.anaconda.com/) and [pip](https://pypi.org/project/pip/) with the following steps:
33 |
34 | ```bash
35 | conda create -n cf3dgs python=3.10
36 | conda activate cf3dgs
37 | conda install conda-forge::cudatoolkit-dev=11.7.0
38 | conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.7 -c pytorch -c nvidia
39 | git clone --recursive git@github.com:NVlabs/CF-3DGS.git
40 | cd CF-3DGS
41 | pip install -r requirements.txt
42 | ```
42 |
43 | ## Dataset Preparation
44 | DATAROOT is `./data` by default. Please create the data folder first with `mkdir data`.
45 |
46 | ### Tanks and Temples
47 |
48 | Download the data preprocessed by [Nope-NeRF](https://github.com/ActiveVisionLab/nope-nerf/?tab=readme-ov-file#Data) as shown below, and extract it into the `./data/Tanks` folder.
49 | ```bash
50 | wget https://www.robots.ox.ac.uk/~wenjing/Tanks.zip
51 | ```
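
If you prefer to script the download, here is a minimal Python sketch (equivalent to `wget` + `unzip`). It assumes the archive unpacks to a top-level `Tanks/` directory; adjust the paths if your layout differs:

```python
# Hypothetical helper: fetch and unpack the Nope-NeRF Tanks data under ./data.
import pathlib
import urllib.request
import zipfile

data_root = pathlib.Path("data")
data_root.mkdir(exist_ok=True)
archive = data_root / "Tanks.zip"

urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~wenjing/Tanks.zip", archive)
with zipfile.ZipFile(archive) as zf:
    zf.extractall(data_root)  # expected to yield ./data/Tanks/<scene>/...
```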
52 |
53 | ### CO3D
54 | Download our preprocessed [data](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/yafu_ucsd_edu/EftJV9Xpn0hNjmOiGKZuzyIBW5j6hAVEGhewc8aUcFShEA?e=x1aXVx), and extract it into the `./data/co3d` folder.
55 |
56 |
57 | ## Run
58 |
59 | ### Training
60 | ```bash
61 | python run_cf3dgs.py -s data/Tanks/Francis \
62 | --mode train \
63 | --data_type tanks  # change the -s scene path as needed
64 | ```
65 |
66 | ### Evaluation
67 | ```bash
68 | # pose estimation
69 | python run_cf3dgs.py --source data/Tanks/Francis \
70 | --mode eval_pose \
71 | --data_type tanks \
72 | --model_path ${CKPT_PATH}
73 | # by default, the checkpoint is stored at "./output/progressive/Tanks_Francis/chkpnt/ep00_init.pth"
74 | # novel view synthesis
75 | python run_cf3dgs.py --source data/Tanks/Francis \
76 | --mode eval_nvs \
77 | --data_type tanks \
78 | --model_path ${CKPT_PATH}
79 | ```
80 | We release some of the novel view synthesis results ([gdrive](https://drive.google.com/drive/folders/1p3WljCN90zrm1N5lO-24OLHmUFmFWntt?usp=sharing)) for comparison with future works.
81 |
82 | ### Run on your own video
83 |
84 | * To run CF-3DGS on your own video, first convert the video to frames and save them to `./data/$CUSTOM_DATA/images/` (see the frame-extraction sketch after the commands below).
86 |
87 | * Camera intrinsics can be obtained by running COLMAP (see `convert.py` for details). Otherwise, we provide a heuristic camera setting that should work for most landscape videos.
88 |
89 | * Run the following commands:
90 |
91 | ```bash
92 | python run_cf3dgs.py -s ./data/$CUSTOM_DATA/ \
93 | --mode train \
94 | --data_type custom  # change the -s path to your own data folder
95 | ```
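
For the frame-extraction step mentioned above, here is a minimal sketch using OpenCV (already listed in `requirements.txt`). The video path, output folder name, and frame stride are placeholders to adapt to your own data:

```python
# Sketch: dump video frames into the layout CF-3DGS expects (./data/$CUSTOM_DATA/images/).
import os
import cv2

video_path = "my_video.mp4"            # placeholder: your input video
out_dir = "data/custom_scene/images"   # placeholder: $CUSTOM_DATA = custom_scene
os.makedirs(out_dir, exist_ok=True)

cap = cv2.VideoCapture(video_path)
idx = saved = 0
stride = 3                             # keep every 3rd frame; tune for your video length
while True:
    ok, frame = cap.read()
    if not ok:
        break
    if idx % stride == 0:
        cv2.imwrite(os.path.join(out_dir, f"{saved:05d}.jpg"), frame)
        saved += 1
    idx += 1
cap.release()
print(f"Wrote {saved} frames to {out_dir}")
```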
96 |
97 | ## Acknowledgement
98 | Our renderer is built upon [3DGS](https://github.com/graphdeco-inria/gaussian-splatting). The data processing and visualization code is partially borrowed from [Nope-NeRF](https://github.com/ActiveVisionLab/nope-nerf/). We thank all the authors for their great repos.
99 |
100 | ## Citation
101 |
102 | If you find this code helpful, please cite:
103 |
104 | ```
105 | @InProceedings{Fu_2024_CVPR,
106 | author = {Fu, Yang and Liu, Sifei and Kulkarni, Amey and Kautz, Jan and Efros, Alexei A. and Wang, Xiaolong},
107 | title = {COLMAP-Free 3D Gaussian Splatting},
108 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
109 | month = {June},
110 | year = {2024},
111 | pages = {20796-20805}
112 | }
113 | ```
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument(
34 | "--" + key, ("-" + key[0:1]), default=value, action="store_true")
35 | else:
36 | group.add_argument(
37 | "--" + key, ("-" + key[0:1]), default=value, type=t)
38 | else:
39 | if t == bool:
40 | group.add_argument(
41 | "--" + key, default=value, action="store_true")
42 | else:
43 | group.add_argument("--" + key, default=value, type=t)
44 |
45 | def extract(self, args):
46 | group = GroupParams()
47 | for arg in vars(args).items():
48 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
49 | setattr(group, arg[0], arg[1])
50 | return group
51 |
52 |
53 | class ModelParams(ParamGroup):
54 | def __init__(self, parser, sentinel=False):
55 | self.sh_degree = 3
56 | self._source_path = ""
57 | self._model_path = ""
58 | self._images = "images"
59 | self._resolution = -1
60 | self._white_background = False
61 | self.data_device = "cuda"
62 | self.eval = True
63 | self.rot_type = "6d"
64 | self.view_dependent = True
65 | # self.model_type = 'original'
66 | self.data_type = "tanks"
67 | self.depth_model_type = "dpt"
68 | self.mode = "train"
69 | # self.eval_nvs = False
70 | # self.eval_pose = False
71 | # self.vis_mesh = False
72 | # self.render_nvs = False
73 | self.traj_opt = "bspline"
74 | super().__init__(parser, "Loading Parameters", sentinel)
75 |
76 | def extract(self, args):
77 | g = super().extract(args)
78 | g.source_path = os.path.abspath(g.source_path)
79 | return g
80 |
81 |
82 | class PipelineParams(ParamGroup):
83 | def __init__(self, parser):
84 | self.convert_SHs_python = False
85 | self.compute_cov3D_python = False
86 | self.debug = False
87 | # self.mode = "color"
88 | self.use_gt_pcd = False
89 | self.use_mask = False
90 | self.use_ref_img = False
91 | self.init_mode = "rand"
92 | self.use_mono = True
93 | self.interval = 15
94 | self.expname = ""
95 | self.use_sampon = False
96 | self.refine = False
97 | self.distortion = False
98 | super().__init__(parser, "Pipeline Parameters")
99 |
100 |
101 | class OptimizationParams(ParamGroup):
102 | def __init__(self, parser):
103 | self.iterations = 30_000
104 | self.position_lr_init = 0.00016
105 | self.position_lr_final = 0.0000016
106 | self.position_lr_delay_mult = 0.01
107 | self.position_lr_max_steps = 30_000
108 | self.feature_lr = 0.0025
109 | self.opacity_lr = 0.05
110 | self.scaling_lr = 0.005
111 | self.rotation_lr = 0.001
112 | self.percent_dense = 0.01
113 | self.lambda_dssim = 0.2
114 | self.lambda_depth = 0.0
115 | self.lambda_dist_2nd_loss = 0.0
116 | self.lambda_pc = 0.0
117 | self.lambda_rgb_s = 0.0
118 | self.depth_loss_type = "invariant"
119 | # self.depth_loss_type = "l1"
120 | self.match_method = "dense"
121 | self.densification_interval = 100
122 | self.densify_interval = 500
123 | self.prune_interval = 2000
124 | self.opacity_reset_interval = 3000
125 | self.densify_from_iter = 500
126 | self.densify_until_iter = 15_000
127 | self.reset_until_iter = 15_000
128 | self.densify_grad_threshold = 0.0002
129 | super().__init__(parser, "Optimization Parameters")
130 |
131 |
132 | def get_combined_args(parser: ArgumentParser):
133 | cmdlne_string = sys.argv[1:]
134 | cfgfile_string = "Namespace()"
135 | args_cmdline = parser.parse_args(cmdlne_string)
136 |
137 | try:
138 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
139 | print("Looking for config file in", cfgfilepath)
140 | with open(cfgfilepath) as cfg_file:
141 | print("Config file found: {}".format(cfgfilepath))
142 | cfgfile_string = cfg_file.read()
143 | except TypeError:
144 | print("Config file not found at")
145 | pass
146 | args_cfgfile = eval(cfgfile_string)
147 |
148 | merged_dict = vars(args_cfgfile).copy()
149 | for k, v in vars(args_cmdline).items():
150 | if v != None:
151 | merged_dict[k] = v
152 | return Namespace(**merged_dict)
153 |
--------------------------------------------------------------------------------
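
A short usage sketch for the parameter groups above; it mirrors how `run_cf3dgs.py` wires them into an `ArgumentParser` (the example flag values are arbitrary):

```python
# Sketch: ParamGroup attributes become CLI flags; a leading "_" also adds a one-letter
# shorthand (e.g. _source_path -> --source_path / -s, _model_path -> --model_path / -m).
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)

args = parser.parse_args(["-s", "data/Tanks/Francis", "--mode", "train", "--data_type", "tanks"])
model_cfg = lp.extract(args)   # extract() keeps only the keys each group declared
print(model_cfg.source_path, model_cfg.mode, model_cfg.data_type)
```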
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/CF-3DGS/886df529c7dabd337f0abbd660e81321a9a1c047/assets/teaser.gif
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import os
11 | import logging
12 | from argparse import ArgumentParser
13 | import shutil
14 |
15 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
16 | parser = ArgumentParser("Colmap converter")
17 | parser.add_argument("--no_gpu", action='store_true')
18 | parser.add_argument("--skip_matching", action='store_true')
19 | parser.add_argument("--source_path", "-s", required=True, type=str)
20 | parser.add_argument("--camera", default="OPENCV", type=str)
21 | parser.add_argument("--colmap_executable", default="", type=str)
22 | parser.add_argument("--resize", action="store_true")
23 | parser.add_argument("--magick_executable", default="", type=str)
24 | args = parser.parse_args()
25 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
26 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
27 | use_gpu = 1 if not args.no_gpu else 0
28 |
29 | if not args.skip_matching:
30 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
31 |
32 | ## Feature extraction
33 | feat_extraction_cmd = colmap_command + " feature_extractor "\
34 | "--database_path " + args.source_path + "/distorted/database.db \
35 | --image_path " + args.source_path + "/input \
36 | --ImageReader.single_camera 1 \
37 | --ImageReader.camera_model " + args.camera + " \
38 | --SiftExtraction.use_gpu " + str(use_gpu)
39 | exit_code = os.system(feat_extraction_cmd)
40 | if exit_code != 0:
41 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
42 | exit(exit_code)
43 |
44 | ## Feature matching
45 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
46 | --database_path " + args.source_path + "/distorted/database.db \
47 | --SiftMatching.use_gpu " + str(use_gpu)
48 | exit_code = os.system(feat_matching_cmd)
49 | if exit_code != 0:
50 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
51 | exit(exit_code)
52 |
53 | ### Bundle adjustment
54 | # The default Mapper tolerance is unnecessarily large,
55 | # decreasing it speeds up bundle adjustment steps.
56 | mapper_cmd = (colmap_command + " mapper \
57 | --database_path " + args.source_path + "/distorted/database.db \
58 | --image_path " + args.source_path + "/input \
59 | --output_path " + args.source_path + "/distorted/sparse \
60 | --Mapper.ba_global_function_tolerance=0.000001")
61 | exit_code = os.system(mapper_cmd)
62 | if exit_code != 0:
63 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
64 | exit(exit_code)
65 |
66 | ### Image undistortion
67 | ## We need to undistort our images into ideal pinhole intrinsics.
68 | img_undist_cmd = (colmap_command + " image_undistorter \
69 | --image_path " + args.source_path + "/input \
70 | --input_path " + args.source_path + "/distorted/sparse/0 \
71 | --output_path " + args.source_path + "\
72 | --output_type COLMAP")
73 | exit_code = os.system(img_undist_cmd)
74 | if exit_code != 0:
75 | logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
76 | exit(exit_code)
77 |
78 | files = os.listdir(args.source_path + "/sparse")
79 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
80 | # Copy each file from the source directory to the destination directory
81 | for file in files:
82 | if file == '0':
83 | continue
84 | source_file = os.path.join(args.source_path, "sparse", file)
85 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
86 | shutil.move(source_file, destination_file)
87 |
88 | if(args.resize):
89 | print("Copying and resizing...")
90 |
91 | # Resize images.
92 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
93 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
94 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
95 | # Get the list of files in the source directory
96 | files = os.listdir(args.source_path + "/images")
97 | # Copy each file from the source directory to the destination directory
98 | for file in files:
99 | source_file = os.path.join(args.source_path, "images", file)
100 |
101 | destination_file = os.path.join(args.source_path, "images_2", file)
102 | shutil.copy2(source_file, destination_file)
103 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
104 | if exit_code != 0:
105 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
106 | exit(exit_code)
107 |
108 | destination_file = os.path.join(args.source_path, "images_4", file)
109 | shutil.copy2(source_file, destination_file)
110 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
111 | if exit_code != 0:
112 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
113 | exit(exit_code)
114 |
115 | destination_file = os.path.join(args.source_path, "images_8", file)
116 | shutil.copy2(source_file, destination_file)
117 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
118 | if exit_code != 0:
119 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
120 | exit(exit_code)
121 |
122 | print("Done.")
123 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gaussian_splatting
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.6
8 | - plyfile=0.8.1
9 | - python=3.7.13
10 | - pip=22.3.1
11 | - pytorch=1.12.1
12 | - torchaudio=0.12.1
13 | - torchvision=0.13.1
14 | - tqdm
15 | - pip:
16 | - submodules/diff-gaussian-rasterization
17 | - submodules/simple-knn
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import torch
11 | import math
12 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
13 | from scene.gaussian_model import GaussianModel
14 | from utils.sh_utils import eval_sh
15 | from utils.image_utils import colorize
16 | from PIL import Image
17 | import pdb
18 |
19 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
20 | scaling_modifier = 1.0, override_color = None):
21 | """
22 | Render the scene.
23 |
24 | Background tensor (bg_color) must be on GPU!
25 | """
26 |
27 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
28 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
29 | try:
30 | screenspace_points.retain_grad()
31 | except:
32 | pass
33 |
34 | # Set up rasterization configuration
35 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
36 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
37 |
38 | raster_settings = GaussianRasterizationSettings(
39 | image_height=int(viewpoint_camera.image_height),
40 | image_width=int(viewpoint_camera.image_width),
41 | tanfovx=tanfovx,
42 | tanfovy=tanfovy,
43 | bg=bg_color,
44 | scale_modifier=scaling_modifier,
45 | viewmatrix=viewpoint_camera.world_view_transform,
46 | projmatrix=viewpoint_camera.full_proj_transform,
47 | sh_degree=pc.active_sh_degree,
48 | campos=viewpoint_camera.camera_center,
49 | prefiltered=False,
50 | debug=pipe.debug
51 | )
52 |
53 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
54 |
55 | means3D = pc.get_xyz
56 | means2D = screenspace_points
57 | opacity = pc.get_opacity
58 |
59 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
60 | # scaling / rotation by the rasterizer.
61 | scales = None
62 | rotations = None
63 | cov3D_precomp = None
64 | if pipe.compute_cov3D_python:
65 | cov3D_precomp = pc.get_covariance(scaling_modifier)
66 | else:
67 | scales = pc.get_scaling
68 | rotations = pc.get_rotation
69 |
70 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
71 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
72 | shs = None
73 | colors_precomp = None
74 | if override_color is None:
75 | if pipe.convert_SHs_python:
76 | pc_features = pc.get_features.transpose(1, 2)
77 | shs_view = pc_features.view(pc_features.shape[0], -1, (pc.max_sh_degree+1)**2)
78 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
79 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
80 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
81 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
82 | else:
83 | shs = pc.get_features
84 | else:
85 | colors_precomp = override_color
86 |
87 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
88 | render_out = rasterizer(
89 | means3D = means3D,
90 | means2D = means2D,
91 | shs = shs,
92 | colors_precomp = colors_precomp,
93 | opacities = opacity,
94 | scales = scales,
95 | rotations = rotations,
96 | cov3D_precomp = cov3D_precomp)
97 |
98 | if len(render_out) > 2:
99 | rendered_image, radii, rendered_depth, rendered_alpha = render_out
100 | return {"render": rendered_image,
101 | "viewspace_points": screenspace_points,
102 | "visibility_filter" : radii > 0,
103 | "radii": radii,
104 | "depth": rendered_depth,
105 | "alpha": rendered_alpha}
106 | else:
107 | rendered_image, radii = render_out
108 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
109 | # They will be excluded from value updates used in the splitting criteria.
110 | return {"render": rendered_image,
111 | "viewspace_points": screenspace_points,
112 | "visibility_filter" : radii > 0,
113 | "radii": radii}
114 |
--------------------------------------------------------------------------------
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import torch
11 | import traceback
12 | import socket
13 | import json
14 | from scene.cameras import MiniCam
15 |
16 | host = "127.0.0.1"
17 | port = 6009
18 |
19 | conn = None
20 | addr = None
21 |
22 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
23 |
24 | def init(wish_host, wish_port):
25 | global host, port, listener
26 | host = wish_host
27 | port = wish_port
28 | listener.bind((host, port))
29 | listener.listen()
30 | listener.settimeout(0)
31 |
32 | def try_connect():
33 | global conn, addr, listener
34 | try:
35 | conn, addr = listener.accept()
36 | print(f"\nConnected by {addr}")
37 | conn.settimeout(None)
38 | except Exception as inst:
39 | pass
40 |
41 | def read():
42 | global conn
43 | messageLength = conn.recv(4)
44 | messageLength = int.from_bytes(messageLength, 'little')
45 | message = conn.recv(messageLength)
46 | return json.loads(message.decode("utf-8"))
47 |
48 | def send(message_bytes, verify):
49 | global conn
50 | if message_bytes != None:
51 | conn.sendall(message_bytes)
52 | conn.sendall(len(verify).to_bytes(4, 'little'))
53 | conn.sendall(bytes(verify, 'ascii'))
54 |
55 | def receive():
56 | message = read()
57 |
58 | width = message["resolution_x"]
59 | height = message["resolution_y"]
60 |
61 | if width != 0 and height != 0:
62 | try:
63 | do_training = bool(message["train"])
64 | fovy = message["fov_y"]
65 | fovx = message["fov_x"]
66 | znear = message["z_near"]
67 | zfar = message["z_far"]
68 | do_shs_python = bool(message["shs_python"])
69 | do_rot_scale_python = bool(message["rot_scale_python"])
70 | keep_alive = bool(message["keep_alive"])
71 | scaling_modifier = message["scaling_modifier"]
72 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
73 | world_view_transform[:,1] = -world_view_transform[:,1]
74 | world_view_transform[:,2] = -world_view_transform[:,2]
75 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
76 | full_proj_transform[:,1] = -full_proj_transform[:,1]
77 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
78 | except Exception as e:
79 | print("")
80 | traceback.print_exc()
81 | raise e
82 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
83 | else:
84 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
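
A minimal usage sketch for the `lpips()` helper above, mirroring how `metrics.py` calls it (the random tensors are placeholders for a rendered image and its ground truth; a CUDA device is assumed, as elsewhere in the repo):

```python
# Sketch: LPIPS distance between two (N, 3, H, W) image tensors; lower means more similar.
import torch
from lpipsPyTorch import lpips

render = torch.rand(1, 3, 256, 256).cuda()  # placeholder for a rendered image in [0, 1]
gt = torch.rand(1, 3, 256, 256).cuda()      # placeholder for the ground-truth image in [0, 1]

with torch.no_grad():
    score = lpips(render, gt, net_type='vgg')
print(score.item())
```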
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | from pathlib import Path
11 | import os
12 | from PIL import Image
13 | import torch
14 | import torchvision.transforms.functional as tf
15 | from utils.loss_utils import ssim
16 | from lpipsPyTorch import lpips
17 | import json
18 | from tqdm import tqdm
19 | from utils.image_utils import psnr
20 | from argparse import ArgumentParser
21 |
22 | def readImages(renders_dir, gt_dir):
23 | renders = []
24 | gts = []
25 | image_names = []
26 | for fname in os.listdir(renders_dir):
27 | render = Image.open(renders_dir / fname)
28 | gt = Image.open(gt_dir / fname)
29 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
30 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
31 | image_names.append(fname)
32 | return renders, gts, image_names
33 |
34 | def evaluate(model_paths):
35 |
36 | full_dict = {}
37 | per_view_dict = {}
38 | full_dict_polytopeonly = {}
39 | per_view_dict_polytopeonly = {}
40 | print("")
41 |
42 | for scene_dir in model_paths:
43 | try:
44 | print("Scene:", scene_dir)
45 | full_dict[scene_dir] = {}
46 | per_view_dict[scene_dir] = {}
47 | full_dict_polytopeonly[scene_dir] = {}
48 | per_view_dict_polytopeonly[scene_dir] = {}
49 |
50 | test_dir = Path(scene_dir) / "test"
51 |
52 | for method in os.listdir(test_dir):
53 | print("Method:", method)
54 |
55 | full_dict[scene_dir][method] = {}
56 | per_view_dict[scene_dir][method] = {}
57 | full_dict_polytopeonly[scene_dir][method] = {}
58 | per_view_dict_polytopeonly[scene_dir][method] = {}
59 |
60 | method_dir = test_dir / method
61 | gt_dir = method_dir/ "gt"
62 | renders_dir = method_dir / "renders"
63 | renders, gts, image_names = readImages(renders_dir, gt_dir)
64 |
65 | ssims = []
66 | psnrs = []
67 | lpipss = []
68 |
69 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
70 | ssims.append(ssim(renders[idx], gts[idx]))
71 | psnrs.append(psnr(renders[idx], gts[idx]))
72 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
73 |
74 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
75 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
76 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
77 | print("")
78 |
79 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
80 | "PSNR": torch.tensor(psnrs).mean().item(),
81 | "LPIPS": torch.tensor(lpipss).mean().item()})
82 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
83 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
84 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
85 |
86 | with open(scene_dir + "/results.json", 'w') as fp:
87 | json.dump(full_dict[scene_dir], fp, indent=True)
88 | with open(scene_dir + "/per_view.json", 'w') as fp:
89 | json.dump(per_view_dict[scene_dir], fp, indent=True)
90 | except:
91 | print("Unable to compute metrics for model", scene_dir)
92 |
93 | if __name__ == "__main__":
94 | device = torch.device("cuda:0")
95 | torch.cuda.set_device(device)
96 |
97 | # Set up command line argument parser
98 | parser = ArgumentParser(description="Training script parameters")
99 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
100 | args = parser.parse_args()
101 | evaluate(args.model_paths)
102 |
--------------------------------------------------------------------------------
/render.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import torch
11 | from scene import Scene
12 | import os
13 | from tqdm import tqdm
14 | from os import makedirs
15 | from gaussian_renderer import render
16 | import torchvision
17 | from utils.general_utils import safe_state
18 | from argparse import ArgumentParser
19 | from arguments import ModelParams, PipelineParams, get_combined_args
20 | from gaussian_renderer import GaussianModel
21 |
22 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
23 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
24 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
25 |
26 | makedirs(render_path, exist_ok=True)
27 | makedirs(gts_path, exist_ok=True)
28 |
29 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
30 | rendering = render(view, gaussians, pipeline, background)["render"]
31 | gt = view.original_image[0:3, :, :]
32 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
33 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
34 |
35 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
36 | with torch.no_grad():
37 | gaussians = GaussianModel(dataset.sh_degree)
38 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
39 |
40 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
41 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
42 |
43 | if not skip_train:
44 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
45 |
46 | if not skip_test:
47 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
48 |
49 | if __name__ == "__main__":
50 | # Set up command line argument parser
51 | parser = ArgumentParser(description="Testing script parameters")
52 | model = ModelParams(parser, sentinel=True)
53 | pipeline = PipelineParams(parser)
54 | parser.add_argument("--iteration", default=-1, type=int)
55 | parser.add_argument("--skip_train", action="store_true")
56 | parser.add_argument("--skip_test", action="store_true")
57 | parser.add_argument("--quiet", action="store_true")
58 | args = get_combined_args(parser)
59 | print("Rendering " + args.model_path)
60 |
61 | # Initialize system state (RNG)
62 | safe_state(args.quiet)
63 |
64 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.65.0
2 | Pillow
3 | opencv-python
4 | imageio
5 | matplotlib==3.7
6 | pyyaml
7 | plyfile
8 | kornia
9 | open3d
10 | timm==0.6.11
11 | submodules/diff-gaussian-rasterization
12 | submodules/simple-knn
13 | git+https://github.com/facebookresearch/pytorch3d.git@stable
14 | git+https://github.com/princeton-vl/lietorch.git
--------------------------------------------------------------------------------
/run_cf3dgs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import sys
11 | from argparse import ArgumentParser, Namespace
12 |
13 | from trainer.cf3dgs_trainer import CFGaussianTrainer
14 | from arguments import ModelParams, PipelineParams, OptimizationParams
15 |
16 | import torch
17 | import pdb
18 | from datetime import datetime
19 |
20 |
21 | def contruct_pose(poses):
22 | n_trgt = poses.shape[0]
23 | # for i in range(n_trgt-1):
24 | # poses[i+1] = poses[i+1] @ torch.inverse(poses[i])
25 | for i in range(n_trgt-1, 0, -1):
26 | poses = torch.cat(
27 | (poses[:i], poses[[i-1]]@poses[i:]), 0)
28 | return poses
29 |
30 |
31 | if __name__ == "__main__":
32 | parser = ArgumentParser(description="Training script parameters")
33 | lp = ModelParams(parser)
34 | op = OptimizationParams(parser)
35 | pp = PipelineParams(parser)
36 | args = parser.parse_args(sys.argv[1:])
37 | model_cfg = lp.extract(args)
38 | pipe_cfg = pp.extract(args)
39 | optim_cfg = op.extract(args)
40 | # hydrant/615_99120_197713
41 | # hydrant/106_12648_23157
42 | # teddybear/34_1403_4393
43 | data_path = model_cfg.source_path
44 | trainer = CFGaussianTrainer(data_path, model_cfg, pipe_cfg, optim_cfg)
45 | start_time = datetime.now()
46 | if model_cfg.mode == "train":
47 | trainer.train_from_progressive()
48 | elif model_cfg.mode == "render":
49 | trainer.render_nvs(traj_opt=model_cfg.traj_opt)
50 | elif model_cfg.mode == "eval_nvs":
51 | trainer.eval_nvs()
52 | elif model_cfg.mode == "eval_pose":
53 | trainer.eval_pose()
54 | end_time = datetime.now()
55 | print('Duration: {}'.format(end_time - start_time))
56 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 |
21 | class Scene:
22 |
23 | gaussians : GaussianModel
24 |
25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
26 | """b
27 | :param path: Path to colmap scene main folder.
28 | """
29 | self.model_path = args.model_path
30 | self.loaded_iter = None
31 | self.gaussians = gaussians
32 |
33 | if load_iteration:
34 | if load_iteration == -1:
35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
36 | else:
37 | self.loaded_iter = load_iteration
38 | print("Loading trained model at iteration {}".format(self.loaded_iter))
39 |
40 | self.train_cameras = {}
41 | self.test_cameras = {}
42 |
43 | if os.path.exists(os.path.join(args.source_path, "sparse")):
44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
45 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
46 | print("Found transforms_train.json file, assuming Blender data set!")
47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
48 | else:
49 | assert False, "Could not recognize scene type!"
50 |
51 | if not self.loaded_iter:
52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
53 | dest_file.write(src_file.read())
54 | json_cams = []
55 | camlist = []
56 | if scene_info.test_cameras:
57 | camlist.extend(scene_info.test_cameras)
58 | if scene_info.train_cameras:
59 | camlist.extend(scene_info.train_cameras)
60 | for id, cam in enumerate(camlist):
61 | json_cams.append(camera_to_JSON(id, cam))
62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
63 | json.dump(json_cams, file)
64 |
65 | if shuffle:
66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
68 |
69 | self.cameras_extent = scene_info.nerf_normalization["radius"]
70 |
71 | for resolution_scale in resolution_scales:
72 | print("Loading Training Cameras")
73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
74 | print("Loading Test Cameras")
75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
76 |
77 | if self.loaded_iter:
78 | self.gaussians.load_ply(os.path.join(self.model_path,
79 | "point_cloud",
80 | "iteration_" + str(self.loaded_iter),
81 | "point_cloud.ply"))
82 | else:
83 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
84 |
85 | def save(self, iteration):
86 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
87 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
88 |
89 | def getTrainCameras(self, scale=1.0):
90 | return self.train_cameras[scale]
91 |
92 | def getTestCameras(self, scale=1.0):
93 | return self.test_cameras[scale]
--------------------------------------------------------------------------------
/scene/camera_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | from lietorch import SO3, SE3
13 | class CameraModel(nn.Module):
14 | def __init__(self, num_cams, pose_dim=9, init_pose=None):
15 | """
16 | :param num_cams:
17 | :param learn_R: True/False
18 | :param learn_t: True/False
19 | :param cfg: config argument options
20 | :param init_c2w: (N, 4, 4) torch tensor
21 | """
22 | super(CameraModel, self).__init__()
23 | self.register_buffer('init_pose', init_pose)
24 |
25 | img_emb_dim = 1
26 | pose_dim = 9
27 | self.img_embedding = nn.Embedding(num_cams, img_emb_dim)
28 | self.cam_mlp = nn.Sequential(
29 | nn.Linear(img_emb_dim+pose_dim, 32),
30 | nn.ReLU(),
31 | nn.Linear(32, pose_dim))
32 | self.factor = 0.01
33 |
34 | def forward(self, cam_id):
35 | cam_id = torch.as_tensor(cam_id, dtype=torch.long).cuda()  # embedding lookup expects long indices
36 | img_emb = self.img_embedding(cam_id)
37 | init_cam = self.init_pose[cam_id]
38 | cam_residual = self.cam_mlp(torch.cat([img_emb, init_cam], dim=1))
39 | cam = init_cam + self.factor * cam_residual
40 | cam = SE3(cam)
41 |
42 | return cam
43 |
44 |
45 | def vec2skew(v):
46 | """
47 | :param v: (3, ) torch tensor
48 | :return: (3, 3)
49 | """
50 | zero = torch.zeros(1, dtype=torch.float32, device=v.device)
51 | skew_v0 = torch.cat([ zero, -v[2:3], v[1:2]]) # (3, 1)
52 | skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]])
53 | skew_v2 = torch.cat([-v[1:2], v[0:1], zero])
54 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3)
55 | return skew_v # (3, 3)
56 |
57 | def Exp(r):
58 | """so(3) vector to SO(3) matrix
59 | :param r: (3, ) axis-angle, torch tensor
60 | :return: (3, 3)
61 | """
62 | skew_r = vec2skew(r) # (3, 3)
63 | norm_r = r.norm() + 1e-15
64 | eye = torch.eye(3, dtype=torch.float32, device=r.device)
65 | R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)
66 | return R
67 |
68 | def make_c2w(r, t):
69 | """
70 | :param r: (3, ) axis-angle torch tensor
71 | :param t: (3, ) translation vector torch tensor
72 | :return: (4, 4)
73 | """
74 | R = Exp(r) # (3, 3)
75 | c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4)
76 | c2w = convert3x4_4x4(c2w) # (4, 4)
77 | return c2w
78 |
79 | def convert3x4_4x4(input):
80 | """
81 | :param input: (N, 3, 4) or (3, 4) torch or np
82 | :return: (N, 4, 4) or (4, 4) torch or np
83 | """
84 | if torch.is_tensor(input):
85 | if len(input.shape) == 3:
86 | output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4)
87 | output[:, 3, 3] = 1.0
88 | else:
89 | output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
90 | else:
91 | if len(input.shape) == 3:
92 | output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
93 | output[:, 3, 3] = 1.0
94 | else:
95 | output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4)
96 | output[3, 3] = 1.0
97 | return output
--------------------------------------------------------------------------------
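
A small sketch of the axis-angle helpers above (assuming the repository root is on PYTHONPATH): Exp applies the Rodrigues formula to an so(3) vector, and make_c2w stacks the resulting rotation with a translation into a homogeneous camera-to-world matrix.

    import torch
    from scene.camera_model import Exp, make_c2w

    r = torch.tensor([0.0, 0.0, 0.1])   # small rotation about the z-axis (axis-angle)
    t = torch.tensor([1.0, 0.0, 0.0])   # unit translation along x
    R = Exp(r)                          # (3, 3) rotation matrix
    c2w = make_c2w(r, t)                # (4, 4) camera-to-world matrix
    assert torch.allclose(c2w[:3, :3], R) and torch.allclose(c2w[:3, 3], t)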
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 |
11 | import torch
12 | from torch import nn
13 | import numpy as np
14 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getWorld2View3
15 |
16 | class Camera(nn.Module):
17 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
18 | image_name, uid, intrinsics=None,
19 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda",
20 | do_grad=False, is_co3d=False,
21 | ):
22 | super(Camera, self).__init__()
23 |
24 | self.uid = uid
25 | self.colmap_id = colmap_id
26 | self.R = R
27 | self.T = T
28 | self.FoVx = FoVx
29 | self.FoVy = FoVy
30 | self.image_name = image_name
31 | self.intrinsics = intrinsics.astype(np.float32)
32 |
33 | try:
34 | self.data_device = torch.device(data_device)
35 | except Exception as e:
36 | print(e)
37 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
38 | self.data_device = torch.device("cuda")
39 |
40 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
41 | self.image_width = self.original_image.shape[2]
42 | self.image_height = self.original_image.shape[1]
43 |
44 | if gt_alpha_mask is not None:
45 | self.original_image *= gt_alpha_mask.to(self.data_device)
46 | else:
47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
48 |
49 | self.gt_alpha_mask = gt_alpha_mask
50 |
51 | self.zfar = 100.0
52 | self.znear = 0.01
53 |
54 | self.trans = trans
55 | self.scale = scale
56 |
57 | if is_co3d:
58 | self.world_view_transform = torch.tensor(getWorld2View3(R, T, trans, scale)).transpose(0, 1).cuda()
59 | w, h = self.image_width, self.image_height
60 | fx, fy, cx, cy = self.intrinsics[0, 0], self.intrinsics[1, 1], \
61 | self.intrinsics[0, 2], self.intrinsics[1, 2]
62 | far, near = self.zfar, self.znear
63 | opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
64 | [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
65 | [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],
66 | [0.0, 0.0, 1.0, 0.0]]).cuda().float().transpose(0, 1)
67 | self.projection_matrix = opengl_proj
68 |
69 | else:
70 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
71 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
72 |
73 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
74 | self.camera_center = self.world_view_transform.inverse()[3, :3]
75 |
76 | class MiniCam:
77 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
78 | self.image_width = width
79 | self.image_height = height
80 | self.FoVy = fovy
81 | self.FoVx = fovx
82 | self.znear = znear
83 | self.zfar = zfar
84 | self.world_view_transform = world_view_transform
85 | self.full_proj_transform = full_proj_transform
86 | view_inv = torch.inverse(self.world_view_transform)
87 | self.camera_center = view_inv[3][:3]
88 |
89 |
--------------------------------------------------------------------------------
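
MiniCam above is a lightweight stand-in for Camera that only carries precomputed transforms. A sketch with placeholder values, using an identity pose so the camera center lands at the origin:

    import math
    import torch
    from scene.cameras import MiniCam

    width, height = 800, 600
    fovx = fovy = math.radians(60)
    world_view = torch.eye(4)               # identity world-to-view, for illustration
    full_proj = world_view @ torch.eye(4)   # placeholder projection
    cam = MiniCam(width, height, fovy, fovx, 0.01, 100.0, world_view, full_proj)
    print(cam.camera_center)                # tensor([0., 0., 0.]) for the identity pose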
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import numpy as np
11 | import collections
12 | import struct
13 |
14 | CameraModel = collections.namedtuple(
15 | "CameraModel", ["model_id", "model_name", "num_params"])
16 | Camera = collections.namedtuple(
17 | "Camera", ["id", "model", "width", "height", "params"])
18 | BaseImage = collections.namedtuple(
19 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
20 | Point3D = collections.namedtuple(
21 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
22 | CAMERA_MODELS = {
23 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
24 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
25 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
26 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
27 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
28 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
29 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
30 | CameraModel(model_id=7, model_name="FOV", num_params=5),
31 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
32 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
33 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
34 | }
35 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
36 | for camera_model in CAMERA_MODELS])
37 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 |
40 |
41 | def qvec2rotmat(qvec):
42 | return np.array([
43 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
44 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
45 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
46 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
47 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
48 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
49 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
50 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
51 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
52 |
53 | def rotmat2qvec(R):
54 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
55 | K = np.array([
56 | [Rxx - Ryy - Rzz, 0, 0, 0],
57 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
58 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
59 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
60 | eigvals, eigvecs = np.linalg.eigh(K)
61 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
62 | if qvec[0] < 0:
63 | qvec *= -1
64 | return qvec
65 |
66 | class Image(BaseImage):
67 | def qvec2rotmat(self):
68 | return qvec2rotmat(self.qvec)
69 |
70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
71 | """Read and unpack the next bytes from a binary file.
72 | :param fid:
73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
75 | :param endian_character: Any of {@, =, <, >, !}
76 | :return: Tuple of read and unpacked values.
77 | """
78 | data = fid.read(num_bytes)
79 | return struct.unpack(endian_character + format_char_sequence, data)
80 |
81 | def read_points3D_text(path):
82 | """
83 | see: src/base/reconstruction.cc
84 | void Reconstruction::ReadPoints3DText(const std::string& path)
85 | void Reconstruction::WritePoints3DText(const std::string& path)
86 | """
87 | xyzs = None
88 | rgbs = None
89 | errors = None
90 | with open(path, "r") as fid:
91 | while True:
92 | line = fid.readline()
93 | if not line:
94 | break
95 | line = line.strip()
96 | if len(line) > 0 and line[0] != "#":
97 | elems = line.split()
98 | xyz = np.array(tuple(map(float, elems[1:4])))
99 | rgb = np.array(tuple(map(int, elems[4:7])))
100 | error = np.array(float(elems[7]))
101 | if xyzs is None:
102 | xyzs = xyz[None, ...]
103 | rgbs = rgb[None, ...]
104 | errors = error[None, ...]
105 | else:
106 | xyzs = np.append(xyzs, xyz[None, ...], axis=0)
107 | rgbs = np.append(rgbs, rgb[None, ...], axis=0)
108 | errors = np.append(errors, error[None, ...], axis=0)
109 | return xyzs, rgbs, errors
110 |
111 | def read_points3D_binary(path_to_model_file):
112 | """
113 | see: src/base/reconstruction.cc
114 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
115 | void Reconstruction::WritePoints3DBinary(const std::string& path)
116 | """
117 |
118 |
119 | with open(path_to_model_file, "rb") as fid:
120 | num_points = read_next_bytes(fid, 8, "Q")[0]
121 |
122 | xyzs = np.empty((num_points, 3))
123 | rgbs = np.empty((num_points, 3))
124 | errors = np.empty((num_points, 1))
125 |
126 | for p_id in range(num_points):
127 | binary_point_line_properties = read_next_bytes(
128 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
129 | xyz = np.array(binary_point_line_properties[1:4])
130 | rgb = np.array(binary_point_line_properties[4:7])
131 | error = np.array(binary_point_line_properties[7])
132 | track_length = read_next_bytes(
133 | fid, num_bytes=8, format_char_sequence="Q")[0]
134 | track_elems = read_next_bytes(
135 | fid, num_bytes=8*track_length,
136 | format_char_sequence="ii"*track_length)
137 | xyzs[p_id] = xyz
138 | rgbs[p_id] = rgb
139 | errors[p_id] = error
140 | return xyzs, rgbs, errors
141 |
142 | def read_intrinsics_text(path):
143 | """
144 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
145 | """
146 | cameras = {}
147 | with open(path, "r") as fid:
148 | while True:
149 | line = fid.readline()
150 | if not line:
151 | break
152 | line = line.strip()
153 | if len(line) > 0 and line[0] != "#":
154 | elems = line.split()
155 | camera_id = int(elems[0])
156 | model = elems[1]
157 |                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
158 | width = int(elems[2])
159 | height = int(elems[3])
160 | params = np.array(tuple(map(float, elems[4:])))
161 | cameras[camera_id] = Camera(id=camera_id, model=model,
162 | width=width, height=height,
163 | params=params)
164 | return cameras
165 |
166 | def read_extrinsics_binary(path_to_model_file):
167 | """
168 | see: src/base/reconstruction.cc
169 | void Reconstruction::ReadImagesBinary(const std::string& path)
170 | void Reconstruction::WriteImagesBinary(const std::string& path)
171 | """
172 | images = {}
173 | with open(path_to_model_file, "rb") as fid:
174 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
175 | for _ in range(num_reg_images):
176 | binary_image_properties = read_next_bytes(
177 | fid, num_bytes=64, format_char_sequence="idddddddi")
178 | image_id = binary_image_properties[0]
179 | qvec = np.array(binary_image_properties[1:5])
180 | tvec = np.array(binary_image_properties[5:8])
181 | camera_id = binary_image_properties[8]
182 | image_name = ""
183 | current_char = read_next_bytes(fid, 1, "c")[0]
184 | while current_char != b"\x00": # look for the ASCII 0 entry
185 | image_name += current_char.decode("utf-8")
186 | current_char = read_next_bytes(fid, 1, "c")[0]
187 | num_points2D = read_next_bytes(fid, num_bytes=8,
188 | format_char_sequence="Q")[0]
189 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
190 | format_char_sequence="ddq"*num_points2D)
191 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
192 | tuple(map(float, x_y_id_s[1::3]))])
193 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
194 | images[image_id] = Image(
195 | id=image_id, qvec=qvec, tvec=tvec,
196 | camera_id=camera_id, name=image_name,
197 | xys=xys, point3D_ids=point3D_ids)
198 | return images
199 |
200 |
201 | def read_intrinsics_binary(path_to_model_file):
202 | """
203 | see: src/base/reconstruction.cc
204 | void Reconstruction::WriteCamerasBinary(const std::string& path)
205 | void Reconstruction::ReadCamerasBinary(const std::string& path)
206 | """
207 | cameras = {}
208 | with open(path_to_model_file, "rb") as fid:
209 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
210 | for _ in range(num_cameras):
211 | camera_properties = read_next_bytes(
212 | fid, num_bytes=24, format_char_sequence="iiQQ")
213 | camera_id = camera_properties[0]
214 | model_id = camera_properties[1]
215 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
216 | width = camera_properties[2]
217 | height = camera_properties[3]
218 | num_params = CAMERA_MODEL_IDS[model_id].num_params
219 | params = read_next_bytes(fid, num_bytes=8*num_params,
220 | format_char_sequence="d"*num_params)
221 | cameras[camera_id] = Camera(id=camera_id,
222 | model=model_name,
223 | width=width,
224 | height=height,
225 | params=np.array(params))
226 | assert len(cameras) == num_cameras
227 | return cameras
228 |
229 |
230 | def read_extrinsics_text(path):
231 | """
232 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
233 | """
234 | images = {}
235 | with open(path, "r") as fid:
236 | while True:
237 | line = fid.readline()
238 | if not line:
239 | break
240 | line = line.strip()
241 | if len(line) > 0 and line[0] != "#":
242 | elems = line.split()
243 | image_id = int(elems[0])
244 | qvec = np.array(tuple(map(float, elems[1:5])))
245 | tvec = np.array(tuple(map(float, elems[5:8])))
246 | camera_id = int(elems[8])
247 | image_name = elems[9]
248 | elems = fid.readline().split()
249 | xys = np.column_stack([tuple(map(float, elems[0::3])),
250 | tuple(map(float, elems[1::3]))])
251 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
252 | images[image_id] = Image(
253 | id=image_id, qvec=qvec, tvec=tvec,
254 | camera_id=camera_id, name=image_name,
255 | xys=xys, point3D_ids=point3D_ids)
256 | return images
257 |
258 |
259 | def read_colmap_bin_array(path):
260 | """
261 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
262 |
263 | :param path: path to the colmap binary file.
264 |     :return: nd array with the floating point values read from the file
265 | """
266 | with open(path, "rb") as fid:
267 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
268 | usecols=(0, 1, 2), dtype=int)
269 | fid.seek(0)
270 | num_delimiter = 0
271 | byte = fid.read(1)
272 | while True:
273 | if byte == b"&":
274 | num_delimiter += 1
275 | if num_delimiter >= 3:
276 | break
277 | byte = fid.read(1)
278 | array = np.fromfile(fid, np.float32)
279 | array = array.reshape((width, height, channels), order="F")
280 | return np.transpose(array, (1, 0, 2)).squeeze()
281 |
--------------------------------------------------------------------------------
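
A usage sketch for the COLMAP readers above; "scene/sparse/0" is a placeholder for a real sparse-model directory produced by COLMAP.

    from scene.colmap_loader import (read_extrinsics_binary, read_intrinsics_binary,
                                     read_points3D_binary, qvec2rotmat)

    sparse_dir = "scene/sparse/0"                                   # placeholder path
    images = read_extrinsics_binary(f"{sparse_dir}/images.bin")     # {image_id: Image}
    cameras = read_intrinsics_binary(f"{sparse_dir}/cameras.bin")   # {camera_id: Camera}
    xyz, rgb, err = read_points3D_binary(f"{sparse_dir}/points3D.bin")

    first = next(iter(images.values()))
    R_w2c = qvec2rotmat(first.qvec)   # (3, 3) world-to-camera rotation
    t_w2c = first.tvec                # (3,) world-to-camera translation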
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import os
11 | import sys
12 | from PIL import Image
13 | from typing import NamedTuple
14 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
15 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
16 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
17 | import numpy as np
18 | import json
19 | from pathlib import Path
20 | from plyfile import PlyData, PlyElement
21 | from utils.sh_utils import SH2RGB
22 | from scene.gaussian_model import BasicPointCloud
23 |
24 |
25 | class CameraInfo(NamedTuple):
26 | uid: int
27 | R: np.array
28 | T: np.array
29 | FovY: np.array
30 | FovX: np.array
31 | image: np.array
32 | intrinsics: np.array
33 | image_path: str
34 | image_name: str
35 | width: int
36 | height: int
37 |
38 |
39 | class SceneInfo(NamedTuple):
40 | point_cloud: BasicPointCloud
41 | train_cameras: list
42 | test_cameras: list
43 | nerf_normalization: dict
44 | ply_path: str
45 |
46 |
47 | def getNerfppNorm(cam_info):
48 | def get_center_and_diag(cam_centers):
49 | cam_centers = np.hstack(cam_centers)
50 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
51 | center = avg_cam_center
52 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
53 | diagonal = np.max(dist)
54 | return center.flatten(), diagonal
55 |
56 | cam_centers = []
57 |
58 | for cam in cam_info:
59 | W2C = getWorld2View2(cam.R, cam.T)
60 | C2W = np.linalg.inv(W2C)
61 | cam_centers.append(C2W[:3, 3:4])
62 |
63 | center, diagonal = get_center_and_diag(cam_centers)
64 | radius = diagonal * 1.1
65 |
66 | translate = -center
67 |
68 | return {"translate": translate, "radius": radius}
69 |
70 |
71 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
72 | cam_infos = []
73 | for idx, key in enumerate(cam_extrinsics):
74 | sys.stdout.write('\r')
75 |         # print reading progress on a single, updating line
76 | sys.stdout.write(
77 | "Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
78 | sys.stdout.flush()
79 |
80 | extr = cam_extrinsics[key]
81 | intr = cam_intrinsics[extr.camera_id]
82 | height = intr.height
83 | width = intr.width
84 |
85 | uid = intr.id
86 | R = np.transpose(qvec2rotmat(extr.qvec))
87 | T = np.array(extr.tvec)
88 | if intr.model == "SIMPLE_PINHOLE" or intr.model == "SIMPLE_RADIAL":
89 | focal_length_x = intr.params[0]
90 | FovY = focal2fov(focal_length_x, height)
91 | FovX = focal2fov(focal_length_x, width)
92 | intr_mat = np.array(
93 | [[focal_length_x, 0, width/2], [0, focal_length_x, height/2], [0, 0, 1]])
94 | elif intr.model == "PINHOLE":
95 | focal_length_x = intr.params[0]
96 | focal_length_y = intr.params[1]
97 | FovY = focal2fov(focal_length_y, height)
98 | FovX = focal2fov(focal_length_x, width)
99 | intr_mat = np.array(
100 | [[focal_length_x, 0, width/2], [0, focal_length_y, height/2], [0, 0, 1]])
101 | else:
102 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
103 |
104 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
105 | image_name = os.path.basename(image_path).split(".")[0]
106 | image = Image.open(image_path)
107 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, intrinsics=intr_mat,
108 | image_path=image_path, image_name=image_name, width=width, height=height)
109 | cam_infos.append(cam_info)
110 | sys.stdout.write('\n')
111 | return cam_infos
112 |
113 |
114 | def fetchPly(path):
115 | plydata = PlyData.read(path)
116 | vertices = plydata['vertex']
117 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
118 | colors = np.vstack([vertices['red'], vertices['green'],
119 | vertices['blue']]).T / 255.0
120 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
121 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
122 |
123 |
124 | def storePly(path, xyz, rgb):
125 | # Define the dtype for the structured array
126 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
127 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
128 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
129 |
130 | normals = np.zeros_like(xyz)
131 |
132 | elements = np.empty(xyz.shape[0], dtype=dtype)
133 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
134 | elements[:] = list(map(tuple, attributes))
135 |
136 | # Create the PlyData object and write to file
137 | vertex_element = PlyElement.describe(elements, 'vertex')
138 | ply_data = PlyData([vertex_element])
139 | ply_data.write(path)
140 |
141 |
142 | def readColmapSceneInfo(path, images, eval, llffhold=8):
143 | try:
144 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
145 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
146 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
147 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
148 | except:
149 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
150 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
151 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
152 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
153 |
154 | reading_dir = "images" if images == None else images
155 | cam_infos_unsorted = readColmapCameras(
156 | cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
157 | cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
158 |
159 | if eval:
160 | # train_cam_infos = [c for idx, c in enumerate(
161 | # cam_infos) if idx % llffhold != 0]
162 | # test_cam_infos = [c for idx, c in enumerate(
163 | # cam_infos) if idx % llffhold == 0]
164 | sample_rate = 2 if "Family" in path else 8
165 | # sample_rate = 8
166 | ids = np.arange(len(cam_infos))
167 | i_test = ids[int(sample_rate/2)::sample_rate]
168 | i_train = np.array([i for i in ids if i not in i_test])
169 | train_cam_infos = [cam_infos[i] for i in i_train]
170 | test_cam_infos = [cam_infos[i] for i in i_test]
171 | else:
172 | train_cam_infos = cam_infos
173 | test_cam_infos = []
174 |
175 | nerf_normalization = getNerfppNorm(train_cam_infos)
176 |
177 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
178 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
179 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
180 | if not os.path.exists(ply_path):
181 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
182 | try:
183 | xyz, rgb, _ = read_points3D_binary(bin_path)
184 | except:
185 | xyz, rgb, _ = read_points3D_text(txt_path)
186 | storePly(ply_path, xyz, rgb)
187 | try:
188 | pcd = fetchPly(ply_path)
189 | except:
190 | pcd = None
191 |
192 | scene_info = SceneInfo(point_cloud=pcd,
193 | train_cameras=train_cam_infos,
194 | test_cameras=test_cam_infos,
195 | nerf_normalization=nerf_normalization,
196 | ply_path=ply_path)
197 | return scene_info
198 |
199 |
200 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
201 | cam_infos = []
202 |
203 | with open(os.path.join(path, transformsfile)) as json_file:
204 | contents = json.load(json_file)
205 | fovx = contents["camera_angle_x"]
206 |
207 | frames = contents["frames"]
208 | for idx, frame in enumerate(frames):
209 | cam_name = os.path.join(path, frame["file_path"] + extension)
210 |
211 | # NeRF 'transform_matrix' is a camera-to-world transform
212 | c2w = np.array(frame["transform_matrix"])
213 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
214 | c2w[:3, 1:3] *= -1
215 |
216 | # get the world-to-camera transform and set R, T
217 | w2c = np.linalg.inv(c2w)
218 | # R is stored transposed due to 'glm' in CUDA code
219 | R = np.transpose(w2c[:3, :3])
220 | T = w2c[:3, 3]
221 |
222 | image_path = os.path.join(path, cam_name)
223 | image_name = Path(cam_name).stem
224 | image = Image.open(image_path)
225 |
226 | im_data = np.array(image.convert("RGBA"))
227 |
228 | bg = np.array(
229 | [1, 1, 1]) if white_background else np.array([0, 0, 0])
230 |
231 | norm_data = im_data / 255.0
232 | arr = norm_data[:, :, :3] * norm_data[:, :,
233 | 3:4] + bg * (1 - norm_data[:, :, 3:4])
234 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
235 |
236 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
237 | FovY = fovy
238 | FovX = fovx
239 |
240 |             intr_mat = np.array([[fov2focal(fovx, image.size[0]), 0, image.size[0] / 2], [0, fov2focal(fovy, image.size[1]), image.size[1] / 2], [0, 0, 1]])
241 |             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, intrinsics=intr_mat, image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
242 |
243 | return cam_infos
244 |
245 |
246 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
247 | print("Reading Training Transforms")
248 | train_cam_infos = readCamerasFromTransforms(
249 | path, "transforms_train.json", white_background, extension)
250 | print("Reading Test Transforms")
251 | test_cam_infos = readCamerasFromTransforms(
252 | path, "transforms_test.json", white_background, extension)
253 |
254 | if not eval:
255 | train_cam_infos.extend(test_cam_infos)
256 | test_cam_infos = []
257 |
258 | nerf_normalization = getNerfppNorm(train_cam_infos)
259 |
260 | ply_path = os.path.join(path, "points3d.ply")
261 | if not os.path.exists(ply_path):
262 | # Since this data set has no colmap data, we start with random points
263 | num_pts = 100_000
264 | print(f"Generating random point cloud ({num_pts})...")
265 |
266 | # We create random points inside the bounds of the synthetic Blender scenes
267 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
268 | shs = np.random.random((num_pts, 3)) / 255.0
269 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(
270 | shs), normals=np.zeros((num_pts, 3)))
271 |
272 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
273 | try:
274 | pcd = fetchPly(ply_path)
275 | except:
276 | pcd = None
277 |
278 | scene_info = SceneInfo(point_cloud=pcd,
279 | train_cameras=train_cam_infos,
280 | test_cameras=test_cam_infos,
281 | nerf_normalization=nerf_normalization,
282 | ply_path=ply_path)
283 | return scene_info
284 |
285 |
286 | sceneLoadTypeCallbacks = {
287 | "Colmap": readColmapSceneInfo,
288 | "Blender": readNerfSyntheticInfo
289 | }
290 |
--------------------------------------------------------------------------------
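
A sketch of the dispatch table above; "data/my_scene" is a placeholder for a COLMAP-processed scene directory.

    from scene.dataset_readers import sceneLoadTypeCallbacks

    scene_info = sceneLoadTypeCallbacks["Colmap"]("data/my_scene", images=None, eval=True)
    print(len(scene_info.train_cameras), len(scene_info.test_cameras))
    print(scene_info.nerf_normalization["radius"])   # becomes Scene.cameras_extent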
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import os
11 | import torch
12 | from random import randint
13 | from utils.loss_utils import l1_loss, ssim
14 | from gaussian_renderer import render, network_gui
15 | import sys
16 | from scene import Scene, GaussianModel
17 | from utils.general_utils import safe_state
18 | import uuid
19 | from tqdm import tqdm
20 | from utils.image_utils import psnr
21 | from argparse import ArgumentParser, Namespace
22 | from arguments import ModelParams, PipelineParams, OptimizationParams
23 | try:
24 | from torch.utils.tensorboard import SummaryWriter
25 | TENSORBOARD_FOUND = True
26 | except ImportError:
27 | TENSORBOARD_FOUND = False
28 |
29 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
30 | first_iter = 0
31 | tb_writer = prepare_output_and_logger(dataset)
32 | gaussians = GaussianModel(dataset.sh_degree)
33 | scene = Scene(dataset, gaussians)
34 | gaussians.training_setup(opt)
35 | if checkpoint:
36 | (model_params, first_iter) = torch.load(checkpoint)
37 | gaussians.restore(model_params, opt)
38 |
39 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
40 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
41 |
42 | iter_start = torch.cuda.Event(enable_timing = True)
43 | iter_end = torch.cuda.Event(enable_timing = True)
44 |
45 | viewpoint_stack = None
46 | ema_loss_for_log = 0.0
47 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
48 | first_iter += 1
49 | for iteration in range(first_iter, opt.iterations + 1):
50 | if network_gui.conn == None:
51 | network_gui.try_connect()
52 | while network_gui.conn != None:
53 | try:
54 | net_image_bytes = None
55 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
56 | if custom_cam != None:
57 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
58 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
59 | network_gui.send(net_image_bytes, dataset.source_path)
60 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
61 | break
62 | except Exception as e:
63 | network_gui.conn = None
64 |
65 | iter_start.record()
66 |
67 | gaussians.update_learning_rate(iteration)
68 |
69 | # Every 1000 its we increase the levels of SH up to a maximum degree
70 | if iteration % 1000 == 0:
71 | gaussians.oneupSHdegree()
72 |
73 | # Pick a random Camera
74 | if not viewpoint_stack:
75 | viewpoint_stack = scene.getTrainCameras().copy()
76 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
77 |
78 | # Render
79 | if (iteration - 1) == debug_from:
80 | pipe.debug = True
81 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
82 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
83 |
84 | # Loss
85 | gt_image = viewpoint_cam.original_image.cuda()
86 | Ll1 = l1_loss(image, gt_image)
87 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
88 | loss.backward()
89 |
90 | iter_end.record()
91 |
92 | with torch.no_grad():
93 | # Progress bar
94 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
95 | if iteration % 10 == 0:
96 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
97 | progress_bar.update(10)
98 | if iteration == opt.iterations:
99 | progress_bar.close()
100 |
101 | # Log and save
102 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
103 | if (iteration in saving_iterations):
104 | print("\n[ITER {}] Saving Gaussians".format(iteration))
105 | scene.save(iteration)
106 |
107 | # Densification
108 | if iteration < opt.densify_until_iter:
109 | # Keep track of max radii in image-space for pruning
110 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
111 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
112 |
113 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
114 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
115 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
116 |
117 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
118 | gaussians.reset_opacity()
119 |
120 | # Optimizer step
121 | if iteration < opt.iterations:
122 | gaussians.optimizer.step()
123 | gaussians.optimizer.zero_grad(set_to_none = True)
124 |
125 | if (iteration in checkpoint_iterations):
126 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
127 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
128 |
129 | def prepare_output_and_logger(args):
130 | if not args.model_path:
131 | if os.getenv('OAR_JOB_ID'):
132 | unique_str=os.getenv('OAR_JOB_ID')
133 | else:
134 | unique_str = str(uuid.uuid4())
135 | args.model_path = os.path.join("./output/", unique_str[0:10])
136 |
137 | # Set up output folder
138 | print("Output folder: {}".format(args.model_path))
139 | os.makedirs(args.model_path, exist_ok = True)
140 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
141 | cfg_log_f.write(str(Namespace(**vars(args))))
142 |
143 | # Create Tensorboard writer
144 | tb_writer = None
145 | if TENSORBOARD_FOUND:
146 | tb_writer = SummaryWriter(args.model_path)
147 | else:
148 | print("Tensorboard not available: not logging progress")
149 | return tb_writer
150 |
151 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
152 | if tb_writer:
153 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
154 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
155 | tb_writer.add_scalar('iter_time', elapsed, iteration)
156 |
157 | # Report test and samples of training set
158 | if iteration in testing_iterations:
159 | torch.cuda.empty_cache()
160 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
161 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
162 |
163 | for config in validation_configs:
164 | if config['cameras'] and len(config['cameras']) > 0:
165 | l1_test = 0.0
166 | psnr_test = 0.0
167 | for idx, viewpoint in enumerate(config['cameras']):
168 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
169 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
170 | if tb_writer and (idx < 5):
171 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
172 | if iteration == testing_iterations[0]:
173 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
174 | l1_test += l1_loss(image, gt_image).mean().double()
175 | psnr_test += psnr(image, gt_image).mean().double()
176 | psnr_test /= len(config['cameras'])
177 | l1_test /= len(config['cameras'])
178 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
179 | if tb_writer:
180 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
181 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
182 |
183 | if tb_writer:
184 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
185 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
186 | torch.cuda.empty_cache()
187 |
188 | if __name__ == "__main__":
189 | # Set up command line argument parser
190 | parser = ArgumentParser(description="Training script parameters")
191 | lp = ModelParams(parser)
192 | op = OptimizationParams(parser)
193 | pp = PipelineParams(parser)
194 | parser.add_argument('--ip', type=str, default="127.0.0.1")
195 | parser.add_argument('--port', type=int, default=6009)
196 | parser.add_argument('--debug_from', type=int, default=-1)
197 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
198 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
199 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
200 | parser.add_argument("--quiet", action="store_true")
201 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
202 | parser.add_argument("--start_checkpoint", type=str, default = None)
203 | args = parser.parse_args(sys.argv[1:])
204 | args.save_iterations.append(args.iterations)
205 |
206 | print("Optimizing " + args.model_path)
207 |
208 | # Initialize system state (RNG)
209 | safe_state(args.quiet)
210 |
211 | # Start GUI server, configure and run training
212 | network_gui.init(args.ip, args.port)
213 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
214 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
215 |
216 | # All done
217 | print("\nTraining complete.")
218 |
--------------------------------------------------------------------------------
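
A typical launch of the script above, with placeholder paths. The --test_iterations, --save_iterations and --port flags are defined in the parser shown here; the -s/--source_path, -m/--model_path and --eval flags are assumed to come from ModelParams in arguments/__init__.py.

    python train.py -s data/my_scene -m output/my_scene --eval \
        --test_iterations 7000 30000 --save_iterations 7000 30000 --port 6009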
/trainer/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torch.autograd import Variable
13 |
14 | import numpy as np
15 | from math import exp
16 |
17 |
18 | import pdb
19 |
20 |
21 | class Loss_Eval(nn.Module):
22 | def __init__(self):
23 | super().__init__()
24 |
25 | def forward(self, rgb_pred, rgb_gt):
26 | loss = F.mse_loss(rgb_pred, rgb_gt)
27 | return_dict = {
28 | 'loss': loss
29 | }
30 | return return_dict
31 |
32 |
33 | class Loss(nn.Module):
34 | def __init__(self, cfg=None):
35 | super().__init__()
36 |
37 | self.depth_loss_type = cfg.depth_loss_type
38 |
39 | self.l1_loss = nn.L1Loss(reduction='sum')
40 | self.l2_loss = nn.MSELoss(reduction='sum')
41 | self.scale_inv_loss = ScaleAndShiftInvariantLoss(alpha=0.5,
42 | scales=1)
43 |
44 | # ssim_loss = ssim
45 | self.ssim_loss = SSIM_V2()
46 |
47 | self.cfg = cfg
48 |
49 | def get_rgb_full_loss(self, rgb_values, rgb_gt, rgb_loss_type='l2'):
50 | num_pixels = rgb_values.shape[1] * rgb_values.shape[2]
51 | if rgb_loss_type == 'l1':
52 | rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(num_pixels)
53 | elif rgb_loss_type == 'l2':
54 | rgb_loss = self.l2_loss(rgb_values, rgb_gt) / float(num_pixels)
55 | return rgb_loss
56 |
57 | def depth_loss_dpt(self, pred_depth, gt_depth, weight=None):
58 | """
59 | :param pred_depth: (H, W)
60 | :param gt_depth: (H, W)
61 | :param weight: (H, W)
62 | :return: scalar
63 | """
64 | t_pred = torch.median(pred_depth)
65 | s_pred = torch.mean(torch.abs(pred_depth - t_pred))
66 |
67 | t_gt = torch.median(gt_depth)
68 | s_gt = torch.mean(torch.abs(gt_depth - t_gt))
69 |
70 | pred_depth_n = (pred_depth - t_pred) / s_pred
71 | gt_depth_n = (gt_depth - t_gt) / s_gt
72 |
73 | if weight is not None:
74 | loss = F.mse_loss(pred_depth_n, gt_depth_n, reduction='none')
75 | loss = loss * weight
76 | loss = loss.sum() / (weight.sum() + 1e-8)
77 | else:
78 |
79 | depth_error = (pred_depth_n - gt_depth_n) ** 2
80 | depth_error[depth_error > torch.quantile(depth_error, 0.8)] = 0
81 | loss = depth_error.mean()
82 | # loss = F.mse_loss(pred_depth_n, gt_depth_n)
83 |
84 | return loss
85 |
86 | def get_depth_loss(self, depth_pred, depth_gt):
87 | num_pixels = depth_pred.shape[0] * depth_pred.shape[1]
88 | if self.depth_loss_type == 'l1':
89 | loss = self.l1_loss(depth_pred, depth_gt) / float(num_pixels)
90 | elif self.depth_loss_type == 'invariant':
91 | # loss = self.depth_loss_dpt(1.0/depth_pred, 1.0/depth_gt)
92 | mask = (depth_gt > 0.02).float()
93 | loss = self.scale_inv_loss(
94 | depth_pred[None], depth_gt[None], mask[None])
95 | return loss
96 |
97 |
98 | def forward(self, rgb_pred, rgb_gt, depth_pred=None, depth_gt=None,
99 | rgb_loss_type='l1', **kwargs):
100 |
101 | rgb_gt = rgb_gt.cuda()
102 |
103 | lambda_dssim = self.cfg.lambda_dssim
104 | lambda_depth = self.cfg.lambda_depth
105 |
106 | # rgb_full_loss = self.get_rgb_full_loss(rgb_pred, rgb_gt, rgb_loss_type)
107 | rgb_full_loss = (1 - lambda_dssim) * l1_loss(rgb_pred, rgb_gt)
108 |         if lambda_dssim != 0.0:
109 |             # dssim_loss = compute_ssim_loss(rgb_pred, rgb_gt)
110 |             # dssim_loss = 1 - ssim(rgb_pred, rgb_gt)
111 |             dssim_loss = 1 - self.ssim_loss(rgb_pred, rgb_gt)
112 |         else:
113 |             dssim_loss = torch.tensor(0.0).cuda().float()
114 | if lambda_depth != 0.0 and depth_pred is not None and depth_gt is not None:
115 | depth_gt = depth_gt.cuda()
116 | depth_pred[depth_pred < 0.02] = 0.02
117 | depth_pred[depth_pred > 20.0] = 20.0
118 | depth_loss = self.get_depth_loss(
119 | depth_pred.squeeze(), depth_gt.squeeze())
120 | else:
121 | depth_loss = torch.tensor(0.0).cuda().float()
122 |
123 | loss = rgb_full_loss + lambda_dssim * dssim_loss +\
124 | lambda_depth * depth_loss
125 |
126 | if torch.isnan(loss):
127 | breakpoint()
128 |
129 | return_dict = {
130 | 'loss': loss,
131 | 'loss_rgb': rgb_full_loss,
132 | 'loss_dssim': dssim_loss,
133 | 'loss_depth': depth_loss,
134 | }
135 |
136 | return return_dict
137 |
138 |
139 | def l1_loss(network_output, gt):
140 | return torch.abs((network_output - gt)).mean()
141 |
142 |
143 | def l2_loss(network_output, gt):
144 | return ((network_output - gt) ** 2).mean()
145 |
146 |
147 | def gaussian(window_size, sigma):
148 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /
149 | float(2 * sigma ** 2)) for x in range(window_size)])
150 | return gauss / gauss.sum()
151 |
152 |
153 | def create_window(window_size, channel):
154 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
155 | _2D_window = _1D_window.mm(
156 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
157 | window = Variable(_2D_window.expand(
158 | channel, 1, window_size, window_size).contiguous())
159 | return window
160 |
161 |
162 |
163 |
164 | def _ssim(
165 | img1, img2, window, window_size, channel, mask=None, size_average=True
166 | ):
167 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
168 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
169 |
170 | mu1_sq = mu1.pow(2)
171 | mu2_sq = mu2.pow(2)
172 | mu1_mu2 = mu1 * mu2
173 |
174 | sigma1_sq = (
175 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
176 | - mu1_sq
177 | )
178 | sigma2_sq = (
179 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
180 | - mu2_sq
181 | )
182 | sigma12 = (
183 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
184 | - mu1_mu2
185 | )
186 |
187 | C1 = (0.01) ** 2
188 | C2 = (0.03) ** 2
189 |
190 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
191 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
192 | )
193 |
194 | if not (mask is None):
195 | b = mask.size(0)
196 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
197 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
198 | dim=1
199 | ).clamp(min=1)
200 | return ssim_map
201 |
202 | # import pdb
203 |
204 | # pdb.set_trace
205 |
206 | if size_average:
207 | return ssim_map.mean()
208 | else:
209 | return ssim_map.mean(1).mean(1).mean(1)
210 |
211 |
212 | class SSIM_V2(torch.nn.Module):
213 | def __init__(self, window_size=11, size_average=True):
214 | super(SSIM_V2, self).__init__()
215 | self.window_size = window_size
216 | self.size_average = size_average
217 | self.channel = 1
218 | self.window = create_window(window_size, self.channel)
219 |
220 | def forward(self, img1, img2, mask=None):
221 | if img1.dim() == 3:
222 | img1 = img1.unsqueeze(0)
223 | if img2.dim() == 3:
224 | img2 = img2.unsqueeze(0)
225 |
226 | (_, channel, _, _) = img1.size()
227 |
228 | if (
229 | channel == self.channel
230 | and self.window.data.type() == img1.data.type()
231 | ):
232 | window = self.window
233 | else:
234 | window = create_window(self.window_size, channel)
235 |
236 | if img1.is_cuda:
237 | window = window.cuda(img1.get_device())
238 | window = window.type_as(img1)
239 |
240 | self.window = window
241 | self.channel = channel
242 |
243 | return _ssim(
244 | img1,
245 | img2,
246 | window,
247 | self.window_size,
248 | channel,
249 | mask,
250 | self.size_average,
251 | )
252 |
253 |
254 |
255 |
256 |
257 |
258 | # copy from MiDaS and MonoSDF
259 | def compute_scale_and_shift(prediction, target, mask):
260 | # system matrix: A = [[a_00, a_01], [a_10, a_11]]
261 | a_00 = torch.sum(mask * prediction * prediction, (1, 2))
262 | a_01 = torch.sum(mask * prediction, (1, 2))
263 | a_11 = torch.sum(mask, (1, 2))
264 |
265 | # right hand side: b = [b_0, b_1]
266 | b_0 = torch.sum(mask * prediction * target, (1, 2))
267 | b_1 = torch.sum(mask * target, (1, 2))
268 |
269 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
270 | x_0 = torch.zeros_like(b_0)
271 | x_1 = torch.zeros_like(b_1)
272 |
273 | det = a_00 * a_11 - a_01 * a_01
274 | valid = det.nonzero()
275 |
276 | x_0[valid] = (a_11[valid] * b_0[valid] -
277 | a_01[valid] * b_1[valid]) / det[valid]
278 | x_1[valid] = (-a_01[valid] * b_0[valid] +
279 | a_00[valid] * b_1[valid]) / det[valid]
280 |
281 | return x_0, x_1
282 |
283 |
284 | def reduction_batch_based(image_loss, M):
285 | # average of all valid pixels of the batch
286 |
287 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
288 | divisor = torch.sum(M)
289 |
290 | if divisor == 0:
291 | return 0
292 | else:
293 | return torch.sum(image_loss) / divisor
294 |
295 |
296 | def reduction_image_based(image_loss, M):
297 | # mean of average of valid pixels of an image
298 |
299 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
300 | valid = M.nonzero()
301 |
302 | image_loss[valid] = image_loss[valid] / M[valid]
303 |
304 | return torch.mean(image_loss)
305 |
306 |
307 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
308 |
309 | M = torch.sum(mask, (1, 2))
310 | res = prediction - target
311 | image_loss = torch.sum(mask * res * res, (1, 2))
312 |
313 | return reduction(image_loss, 2 * M)
314 |
315 |
316 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
317 |
318 | M = torch.sum(mask, (1, 2))
319 |
320 | diff = prediction - target
321 | diff = torch.mul(mask, diff)
322 |
323 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
324 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
325 | grad_x = torch.mul(mask_x, grad_x)
326 |
327 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
328 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
329 | grad_y = torch.mul(mask_y, grad_y)
330 |
331 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
332 |
333 | return reduction(image_loss, M)
334 |
335 |
336 | class MSELoss(nn.Module):
337 | def __init__(self, reduction='batch-based'):
338 | super().__init__()
339 |
340 | if reduction == 'batch-based':
341 | self.__reduction = reduction_batch_based
342 | else:
343 | self.__reduction = reduction_image_based
344 |
345 | def forward(self, prediction, target, mask):
346 | return mse_loss(prediction, target, mask, reduction=self.__reduction)
347 |
348 |
349 | class GradientLoss(nn.Module):
350 | def __init__(self, scales=4, reduction='batch-based'):
351 | super().__init__()
352 |
353 | if reduction == 'batch-based':
354 | self.__reduction = reduction_batch_based
355 | else:
356 | self.__reduction = reduction_image_based
357 |
358 | self.__scales = scales
359 |
360 | def forward(self, prediction, target, mask):
361 | total = 0
362 |
363 | for scale in range(self.__scales):
364 | step = pow(2, scale)
365 |
366 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
367 | mask[:, ::step, ::step], reduction=self.__reduction)
368 |
369 | return total
370 |
371 |
372 | class ScaleAndShiftInvariantLoss(nn.Module):
373 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'):
374 | super().__init__()
375 |
376 | self.__data_loss = MSELoss(reduction=reduction)
377 | self.__regularization_loss = GradientLoss(
378 | scales=scales, reduction=reduction)
379 | self.__alpha = alpha
380 |
381 | self.__prediction_ssi = None
382 |
383 | def forward(self, prediction, target, mask):
384 | scale, shift = compute_scale_and_shift(prediction, target, mask)
385 | self.__prediction_ssi = scale.view(-1, 1, 1) * \
386 | prediction + shift.view(-1, 1, 1)
387 |
388 | total = self.__data_loss(self.__prediction_ssi, target, mask)
389 | if self.__alpha > 0:
390 | total += self.__alpha * \
391 | self.__regularization_loss(self.__prediction_ssi, target, mask)
392 |
393 | return total
394 |
395 | def __get_prediction_ssi(self):
396 | return self.__prediction_ssi
397 |
398 | prediction_ssi = property(__get_prediction_ssi)
399 |
--------------------------------------------------------------------------------
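
A small sketch of the scale-and-shift-invariant depth loss above, with random tensors standing in for real depth maps:

    import torch
    from trainer.losses import ScaleAndShiftInvariantLoss

    pred = torch.rand(1, 64, 64)        # (B, H, W) predicted depth
    gt = torch.rand(1, 64, 64)          # (B, H, W) reference (e.g. monocular) depth
    mask = (gt > 0.02).float()          # valid-pixel mask, mirroring Loss.get_depth_loss

    criterion = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)
    loss = criterion(pred, gt, mask)    # per-image scale/shift alignment, then masked MSE + gradient term
    print(loss.item())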
/utils/camera_conversion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # License file: https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Union
8 | import numpy as np
9 |
10 |
11 | import torch
12 | import torch.nn.functional as F
13 |
14 | Device = Union[str, torch.device]
15 |
16 | """
17 | The transformation matrices returned from the functions in this file assume
18 | the points on which the transformation will be applied are column vectors.
19 | i.e. the R matrix is structured as
20 |
21 | R = [
22 | [Rxx, Rxy, Rxz],
23 | [Ryx, Ryy, Ryz],
24 | [Rzx, Rzy, Rzz],
25 | ] # (3, 3)
26 |
27 | This matrix can be applied to column vectors by post multiplication
28 | by the points e.g.
29 |
30 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
31 | transformed_points = R * points
32 |
33 | To apply the same matrix to points which are row vectors, the R matrix
34 | can be transposed and pre multiplied by the points:
35 |
36 | e.g.
37 | points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
38 | transformed_points = points * R.transpose(1, 0)
39 | """
40 |
41 |
42 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
43 | """
44 | Convert rotations given as quaternions to rotation matrices.
45 |
46 | Args:
47 | quaternions: quaternions with real part first,
48 | as tensor of shape (..., 4).
49 |
50 | Returns:
51 | Rotation matrices as tensor of shape (..., 3, 3).
52 | """
53 | r, i, j, k = torch.unbind(quaternions, -1)
54 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
55 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
56 |
57 | o = torch.stack(
58 | (
59 | 1 - two_s * (j * j + k * k),
60 | two_s * (i * j - k * r),
61 | two_s * (i * k + j * r),
62 | two_s * (i * j + k * r),
63 | 1 - two_s * (i * i + k * k),
64 | two_s * (j * k - i * r),
65 | two_s * (i * k - j * r),
66 | two_s * (j * k + i * r),
67 | 1 - two_s * (i * i + j * j),
68 | ),
69 | -1,
70 | )
71 | return o.reshape(quaternions.shape[:-1] + (3, 3))
72 |
73 |
74 | def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
75 | """
76 |     Return a tensor where each element has the absolute value taken from the
77 | corresponding element of a, with sign taken from the corresponding
78 | element of b. This is like the standard copysign floating-point operation,
79 | but is not careful about negative 0 and NaN.
80 |
81 | Args:
82 | a: source tensor.
83 | b: tensor whose signs will be used, of the same shape as a.
84 |
85 | Returns:
86 | Tensor of the same shape as a with the signs of b.
87 | """
88 | signs_differ = (a < 0) != (b < 0)
89 | return torch.where(signs_differ, -a, a)
90 |
91 |
92 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
93 | """
94 | Returns torch.sqrt(torch.max(0, x))
95 | but with a zero subgradient where x is 0.
96 | """
97 | ret = torch.zeros_like(x)
98 | positive_mask = x > 0
99 | ret[positive_mask] = torch.sqrt(x[positive_mask])
100 | return ret
101 |
102 |
103 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
104 | """
105 | Convert rotations given as rotation matrices to quaternions.
106 |
107 | Args:
108 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
109 |
110 | Returns:
111 | quaternions with real part first, as tensor of shape (..., 4).
112 | """
113 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
114 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
115 |
116 | batch_dim = matrix.shape[:-2]
117 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
118 | matrix.reshape(batch_dim + (9,)), dim=-1
119 | )
120 |
121 | q_abs = _sqrt_positive_part(
122 | torch.stack(
123 | [
124 | 1.0 + m00 + m11 + m22,
125 | 1.0 + m00 - m11 - m22,
126 | 1.0 - m00 + m11 - m22,
127 | 1.0 - m00 - m11 + m22,
128 | ],
129 | dim=-1,
130 | )
131 | )
132 |
133 | # we produce the desired quaternion multiplied by each of r, i, j, k
134 | quat_by_rijk = torch.stack(
135 | [
136 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
137 | # `int`.
138 | torch.stack([q_abs[..., 0] ** 2, m21 - m12,
139 | m02 - m20, m10 - m01], dim=-1),
140 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
141 | # `int`.
142 | torch.stack([m21 - m12, q_abs[..., 1] ** 2,
143 | m10 + m01, m02 + m20], dim=-1),
144 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
145 | # `int`.
146 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]
147 | ** 2, m12 + m21], dim=-1),
148 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
149 | # `int`.
150 | torch.stack([m10 - m01, m20 + m02, m21 + m12,
151 | q_abs[..., 3] ** 2], dim=-1),
152 | ],
153 | dim=-2,
154 | )
155 |
156 | # We floor here at 0.1 but the exact level is not important; if q_abs is small,
157 | # the candidate won't be picked.
158 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
159 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
160 |
161 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
162 | # forall i; we pick the best-conditioned one (with the largest denominator)
163 |
164 | out = quat_candidates[
165 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
166 | ].reshape(batch_dim + (4,))
167 | return standardize_quaternion(out)
168 |
169 |
170 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
171 | """
172 | Convert a unit quaternion to a standard form: one in which the real
173 | part is non negative.
174 |
175 | Args:
176 | quaternions: Quaternions with real part first,
177 | as tensor of shape (..., 4).
178 |
179 | Returns:
180 | Standardized quaternions as tensor of shape (..., 4).
181 | """
182 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
183 |
184 |
185 | def convert3x4_4x4(input):
186 | """
187 | :param input: (N, 3, 4) or (3, 4) torch or np
188 | :return: (N, 4, 4) or (4, 4) torch or np
189 | """
190 | if torch.is_tensor(input):
191 | if len(input.shape) == 3:
192 | output = torch.cat([input, torch.zeros_like(
193 | input[:, 0:1])], dim=1) # (N, 4, 4)
194 | output[:, 3, 3] = 1.0
195 | else:
196 | output = torch.cat([input, torch.tensor(
197 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
198 | else:
199 | if len(input.shape) == 3:
200 | output = np.concatenate(
201 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
202 | output[:, 3, 3] = 1.0
203 | else:
204 | output = np.concatenate(
205 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4)
206 | output[3, 3] = 1.0
207 | return output
208 |
209 |
210 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
211 | """
212 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
213 | using Gram--Schmidt orthogonalization per Section B of [1].
214 | Args:
215 | d6: 6D rotation representation, of size (*, 6)
216 |
217 | Returns:
218 | batch of rotation matrices of size (*, 3, 3)
219 |
220 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
221 | On the Continuity of Rotation Representations in Neural Networks.
222 | IEEE Conference on Computer Vision and Pattern Recognition, 2019.
223 | Retrieved from http://arxiv.org/abs/1812.07035
224 | """
225 |
226 | a1, a2 = d6[..., :3], d6[..., 3:]
227 | b1 = F.normalize(a1, dim=-1)
228 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
229 | b2 = F.normalize(b2, dim=-1)
230 | b3 = torch.cross(b1, b2, dim=-1)
231 | return torch.stack((b1, b2, b3), dim=-2)
232 |
233 |
234 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
235 | """
236 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
237 | by dropping the last row. Note that 6D representation is not unique.
238 | Args:
239 | matrix: batch of rotation matrices of size (*, 3, 3)
240 |
241 | Returns:
242 | 6D rotation representation, of size (*, 6)
243 |
244 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
245 | On the Continuity of Rotation Representations in Neural Networks.
246 | IEEE Conference on Computer Vision and Pattern Recognition, 2019.
247 | Retrieved from http://arxiv.org/abs/1812.07035
248 | """
249 | batch_dim = matrix.size()[:-2]
250 | return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
251 |
252 |
253 | class Pose():
254 | """
255 | A class of operations on camera poses (PyTorch tensors with shape [..., 3, 4]);
256 | each [3, 4] camera pose takes the form [R|t].
257 | """
258 |
259 | def __call__(self, R=None, t=None):
260 | # construct a camera pose from the given R and/or t
261 | assert (R is not None or t is not None)
262 | if R is None:
263 | if not isinstance(t, torch.Tensor):
264 | t = torch.tensor(t)
265 | R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)
266 | elif t is None:
267 | if not isinstance(R, torch.Tensor):
268 | R = torch.tensor(R)
269 | t = torch.zeros(R.shape[:-1], device=R.device)
270 | else:
271 | if not isinstance(R, torch.Tensor):
272 | R = torch.tensor(R)
273 | if not isinstance(t, torch.Tensor):
274 | t = torch.tensor(t)
275 | assert (R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3))
276 | R = R.float()
277 | t = t.float()
278 | pose = torch.cat([R, t[..., None]], dim=-1) # [...,3,4]
279 | assert (pose.shape[-2:] == (3, 4))
280 | return pose
281 |
282 | def invert(self, pose, use_inverse=False):
283 | # invert a camera pose
284 | R, t = pose[..., :3], pose[..., 3:]
285 | R_inv = R.inverse() if use_inverse else R.transpose(-1, -2)
286 | t_inv = (-R_inv@t)[..., 0]
287 | pose_inv = self(R=R_inv, t=t_inv)
288 | return pose_inv
289 |
290 | def compose(self, pose_list):
291 | # compose a sequence of poses together
292 | # pose_new(x) = poseN o ... o pose2 o pose1(x)
293 | pose_new = pose_list[0]
294 | for pose in pose_list[1:]:
295 | pose_new = self.compose_pair(pose_new, pose)
296 | return pose_new
297 |
298 | def compose_pair(self, pose_a, pose_b):
299 | # pose_new(x) = pose_b o pose_a(x)
300 | R_a, t_a = pose_a[..., :3], pose_a[..., 3:]
301 | R_b, t_b = pose_b[..., :3], pose_b[..., 3:]
302 | R_new = R_b@R_a
303 | t_new = (R_b@t_a+t_b)[..., 0]
304 | pose_new = self(R=R_new, t=t_new)
305 | return pose_new
306 |
307 |
308 | class Lie():
309 | """
310 | Lie algebra for SO(3) and SE(3) operations in PyTorch
311 | """
312 |
313 | def so3_to_SO3(self, w): # [...,3]
314 | wx = self.skew_symmetric(w)
315 | theta = w.norm(dim=-1)[..., None, None]
316 | I = torch.eye(3, device=w.device, dtype=torch.float32)
317 | A = self.taylor_A(theta)
318 | B = self.taylor_B(theta)
319 | R = I+A*wx+B*wx@wx
320 | return R
321 |
322 | def SO3_to_so3(self, R, eps=1e-7): # [...,3,3]
323 | trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]
324 | # ln(R) will explode if theta==pi
325 | theta = ((trace-1)/2).clamp(-1+eps, 1 -
326 | eps).acos_()[..., None, None] % np.pi
327 | lnR = 1/(2*self.taylor_A(theta)+1e-8) * \
328 | (R-R.transpose(-2, -1)) # FIXME: wei-chiu finds it weird
329 | w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
330 | w = torch.stack([w0, w1, w2], dim=-1)
331 | return w
332 |
333 | def se3_to_SE3(self, wu): # [...,3]
334 | w, u = wu.split([3, 3], dim=-1)
335 | wx = self.skew_symmetric(w)
336 | theta = w.norm(dim=-1)[..., None, None]
337 | I = torch.eye(3, device=w.device, dtype=torch.float32)
338 | A = self.taylor_A(theta)
339 | B = self.taylor_B(theta)
340 | C = self.taylor_C(theta)
341 | R = I+A*wx+B*wx@wx
342 | V = I+B*wx+C*wx@wx
343 | Rt = torch.cat([R, (V@u[..., None])], dim=-1)
344 | return Rt
345 |
346 | def SE3_to_se3(self, Rt, eps=1e-8): # [...,3,4]
347 | R, t = Rt.split([3, 1], dim=-1)
348 | w = self.SO3_to_so3(R)
349 | wx = self.skew_symmetric(w)
350 | theta = w.norm(dim=-1)[..., None, None]
351 | I = torch.eye(3, device=w.device, dtype=torch.float32)
352 | A = self.taylor_A(theta)
353 | B = self.taylor_B(theta)
354 | invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
355 | u = (invV@t)[..., 0]
356 | wu = torch.cat([w, u], dim=-1)
357 | return wu
358 |
359 | def skew_symmetric(self, w):
360 | w0, w1, w2 = w.unbind(dim=-1)
361 | O = torch.zeros_like(w0)
362 | wx = torch.stack([torch.stack([O, -w2, w1], dim=-1),
363 | torch.stack([w2, O, -w0], dim=-1),
364 | torch.stack([-w1, w0, O], dim=-1)], dim=-2)
365 | return wx
366 |
367 | def taylor_A(self, x, nth=10):
368 | # Taylor expansion of sin(x)/x
369 | ans = torch.zeros_like(x)
370 | denom = 1.
371 | for i in range(nth+1):
372 | if i > 0:
373 | denom *= (2*i)*(2*i+1)
374 | ans = ans+(-1)**i*x**(2*i)/denom
375 | return ans
376 |
377 | def taylor_B(self, x, nth=10):
378 | # Taylor expansion of (1-cos(x))/x**2
379 | ans = torch.zeros_like(x)
380 | denom = 1.
381 | for i in range(nth+1):
382 | denom *= (2*i+1)*(2*i+2)
383 | ans = ans+(-1)**i*x**(2*i)/denom
384 | return ans
385 |
386 | def taylor_C(self, x, nth=10):
387 | # Taylor expansion of (x-sin(x))/x**3
388 | ans = torch.zeros_like(x)
389 | denom = 1.
390 | for i in range(nth+1):
391 | denom *= (2*i+2)*(2*i+3)
392 | ans = ans+(-1)**i*x**(2*i)/denom
393 | return ans
394 |
395 |
396 | class Quaternion():
397 |
398 | def q_to_R(self, q):
399 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
400 | qa, qb, qc, qd = q.unbind(dim=-1)
401 | R = torch.stack([torch.stack([1-2*(qc**2+qd**2), 2*(qb*qc-qa*qd), 2*(qa*qc+qb*qd)], dim=-1),
402 | torch.stack(
403 | [2*(qb*qc+qa*qd), 1-2*(qb**2+qd**2), 2*(qc*qd-qa*qb)], dim=-1),
404 | torch.stack([2*(qb*qd-qa*qc), 2*(qa*qb+qc*qd), 1-2*(qb**2+qc**2)], dim=-1)], dim=-2)
405 | return R
406 |
407 | def R_to_q(self, R, eps=1e-8): # [B,3,3]
408 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
409 | # FIXME: this function seems a bit problematic, need to double-check
410 | row0, row1, row2 = R.unbind(dim=-2)
411 | R00, R01, R02 = row0.unbind(dim=-1)
412 | R10, R11, R12 = row1.unbind(dim=-1)
413 | R20, R21, R22 = row2.unbind(dim=-1)
414 | t = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]
415 | r = (1+t+eps).sqrt()
416 | qa = 0.5*r
417 | qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()
418 | qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()
419 | qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()
420 | q = torch.stack([qa, qb, qc, qd], dim=-1)
421 | for i, qi in enumerate(q):
422 | if torch.isnan(qi).any():
423 | K = torch.stack([torch.stack([R00-R11-R22, R10+R01, R20+R02, R12-R21], dim=-1),
424 | torch.stack(
425 | [R10+R01, R11-R00-R22, R21+R12, R20-R02], dim=-1),
426 | torch.stack(
427 | [R20+R02, R21+R12, R22-R00-R11, R01-R10], dim=-1),
428 | torch.stack([R12-R21, R20-R02, R01-R10, R00+R11+R22], dim=-1)], dim=-2)/3.0
429 | K = K[i]
430 | eigval, eigvec = torch.linalg.eigh(K)
431 | V = eigvec[:, eigval.argmax()]
432 | q[i] = torch.stack([V[3], V[0], V[1], V[2]])
433 | return q
434 |
435 | def invert(self, q):
436 | qa, qb, qc, qd = q.unbind(dim=-1)
437 | norm = q.norm(dim=-1, keepdim=True)
438 | q_inv = torch.stack([qa, -qb, -qc, -qd], dim=-1)/norm**2
439 | return q_inv
440 |
441 | def product(self, q1, q2): # [B,4]
442 | q1a, q1b, q1c, q1d = q1.unbind(dim=-1)
443 | q2a, q2b, q2c, q2d = q2.unbind(dim=-1)
444 | hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
445 | q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
446 | q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
447 | q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a], dim=-1)
448 | return hamil_prod
449 |
450 |
451 | lie = Lie()
452 | pose = Pose()
453 | quaternion = Quaternion()
454 |
--------------------------------------------------------------------------------
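
A minimal sketch of how the Lie, Pose, and Quaternion helpers above fit together: an axis-angle vector is exponentiated to a rotation, packed into a [3,4] pose, composed with its inverse, and round-tripped through a quaternion. Illustrative only; it assumes the module is importable as utils.camera_conversion with the repository root on PYTHONPATH.

import torch
from utils.camera_conversion import lie, pose, quaternion

w = torch.tensor([[0.1, -0.2, 0.3]])      # so(3) axis-angle vector, batch of 1
R = lie.so3_to_SO3(w)                     # [1,3,3] rotation via the Rodrigues formula
t = torch.tensor([[1.0, 2.0, 3.0]])
P = pose(R=R, t=t)                        # [1,3,4] camera pose [R|t]
P_inv = pose.invert(P)
identity = pose.compose_pair(P, P_inv)    # ~[I|0] up to numerical error
q = quaternion.R_to_q(R)                  # [1,4] quaternion, scalar (w) first
R_back = quaternion.q_to_R(q)
print(torch.allclose(R, R_back, atol=1e-5))
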
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 |
11 | from scene.cameras import Camera
12 | import numpy as np
13 | from utils.general_utils import PILtoTorch
14 | from utils.graphics_utils import fov2focal
15 | import copy
16 | WARNED = False
17 |
18 |
19 | def loadCam(args, id, cam_info, resolution_scale):
20 | orig_w, orig_h = cam_info.image.size
21 |
22 | if args.resolution in [1, 2, 4, 8]:
23 | resolution = round(orig_w/(resolution_scale * args.resolution)
24 | ), round(orig_h/(resolution_scale * args.resolution))
25 | else: # should be a type that converts to float
26 | if args.resolution == -1:
27 | if orig_w > 1600:
28 | global WARNED
29 | if not WARNED:
30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
32 | WARNED = True
33 | global_down = orig_w / 1600
34 | else:
35 | global_down = 1
36 | else:
37 | global_down = orig_w / args.resolution
38 |
39 | scale = float(global_down) * float(resolution_scale)
40 | resolution = (int(orig_w / scale), int(orig_h / scale))
41 | # intrinsics = copy.copy(cam_info.intrinsics)
42 | # intrinsics[:2, :] /= scale
43 | # import pdb; pdb.set_trace()
44 | focal_length_x = fov2focal(cam_info.FovX, orig_w)
45 | focal_length_y = fov2focal(cam_info.FovY, orig_h)
46 | scale = int(orig_w / resolution[0])
47 | intrinsics = np.array(
48 | [[focal_length_x//scale, 0, resolution[0]/2],
49 | [0, focal_length_y//scale, resolution[1]/2], [0, 0, 1]]).astype(np.float32)
50 |
51 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
52 |
53 | gt_image = resized_image_rgb[:3, ...]
54 | loaded_mask = None
55 |
56 | if resized_image_rgb.shape[1] == 4:
57 | loaded_mask = resized_image_rgb[3:4, ...]
58 |
59 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
60 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
61 | image=gt_image, gt_alpha_mask=loaded_mask, intrinsics=intrinsics,
62 | image_name=cam_info.image_name, uid=id, data_device=args.data_device)
63 |
64 |
65 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
66 | camera_list = []
67 |
68 | for id, c in enumerate(cam_infos):
69 | camera_list.append(loadCam(args, id, c, resolution_scale))
70 |
71 | return camera_list
72 |
73 |
74 | def camera_to_JSON(id, camera: Camera):
75 | Rt = np.zeros((4, 4))
76 | Rt[:3, :3] = camera.R.transpose()
77 | Rt[:3, 3] = camera.T
78 | Rt[3, 3] = 1.0
79 |
80 | W2C = np.linalg.inv(Rt)
81 | pos = W2C[:3, 3]
82 | rot = W2C[:3, :3]
83 | serializable_array_2d = [x.tolist() for x in rot]
84 | camera_entry = {
85 | 'id': id,
86 | 'img_name': camera.image_name,
87 | 'width': camera.width,
88 | 'height': camera.height,
89 | 'position': pos.tolist(),
90 | 'rotation': serializable_array_2d,
91 | 'fy': fov2focal(camera.FovY, camera.height),
92 | 'fx': fov2focal(camera.FovX, camera.width)
93 | }
94 | return camera_entry
95 |
--------------------------------------------------------------------------------
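
The intrinsics block in loadCam above derives a pinhole K matrix from the field of view and the downscaled resolution. A standalone sketch of that computation with illustrative values, reusing fov2focal from utils/graphics_utils.py:

import numpy as np
from utils.graphics_utils import fov2focal

orig_w, orig_h = 1920, 1080
FovX, FovY = 1.2, 0.8                      # field of view in radians
resolution = (orig_w // 2, orig_h // 2)    # e.g. args.resolution == 2

focal_x = fov2focal(FovX, orig_w)          # focal length in pixels at full resolution
focal_y = fov2focal(FovY, orig_h)
scale = int(orig_w / resolution[0])        # 2
K = np.array([[focal_x // scale, 0, resolution[0] / 2],
              [0, focal_y // scale, resolution[1] / 2],
              [0, 0, 1]], dtype=np.float32)
print(K)
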
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import torch
11 | import sys
12 | from datetime import datetime
13 | import numpy as np
14 | import random
15 |
16 | def inverse_sigmoid(x):
17 | return torch.log(x/(1-x))
18 |
19 | def PILtoTorch(pil_image, resolution):
20 | resized_image_PIL = pil_image.resize(resolution)
21 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
22 | if len(resized_image.shape) == 3:
23 | return resized_image.permute(2, 0, 1)
24 | else:
25 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
26 |
27 | def get_expon_lr_func(
28 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
29 | ):
30 | """
31 | Copied from Plenoxels
32 |
33 | Continuous learning rate decay function. Adapted from JaxNeRF
34 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
35 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
36 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
37 | function of lr_delay_mult, such that the initial learning rate is
38 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
39 | to the normal learning rate when steps>lr_delay_steps.
40 | :param lr_init, lr_final: initial and final learning rates.
41 | :param max_steps: int, the number of steps during optimization.
42 | :return: helper function that maps a step index to the learning rate.
43 | """
44 |
45 | def helper(step):
46 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
47 | # Disable this parameter
48 | return 0.0
49 | if lr_delay_steps > 0:
50 | # A kind of reverse cosine decay.
51 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
52 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
53 | )
54 | else:
55 | delay_rate = 1.0
56 | t = np.clip(step / max_steps, 0, 1)
57 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
58 | return delay_rate * log_lerp
59 |
60 | return helper
61 |
62 | def strip_lowerdiag(L):
63 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
64 |
65 | uncertainty[:, 0] = L[:, 0, 0]
66 | uncertainty[:, 1] = L[:, 0, 1]
67 | uncertainty[:, 2] = L[:, 0, 2]
68 | uncertainty[:, 3] = L[:, 1, 1]
69 | uncertainty[:, 4] = L[:, 1, 2]
70 | uncertainty[:, 5] = L[:, 2, 2]
71 | return uncertainty
72 |
73 | def strip_symmetric(sym):
74 | return strip_lowerdiag(sym)
75 |
76 | def build_rotation(r):
77 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
78 |
79 | q = r / norm[:, None]
80 |
81 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
82 |
83 | r = q[:, 0]
84 | x = q[:, 1]
85 | y = q[:, 2]
86 | z = q[:, 3]
87 |
88 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
89 | R[:, 0, 1] = 2 * (x*y - r*z)
90 | R[:, 0, 2] = 2 * (x*z + r*y)
91 | R[:, 1, 0] = 2 * (x*y + r*z)
92 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
93 | R[:, 1, 2] = 2 * (y*z - r*x)
94 | R[:, 2, 0] = 2 * (x*z - r*y)
95 | R[:, 2, 1] = 2 * (y*z + r*x)
96 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
97 | return R
98 |
99 | def build_scaling_rotation(s, r):
100 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
101 | R = build_rotation(r)
102 |
103 | L[:,0,0] = s[:,0]
104 | L[:,1,1] = s[:,1]
105 | L[:,2,2] = s[:,2]
106 |
107 | L = R @ L
108 | return L
109 |
110 | def safe_state(silent):
111 | old_f = sys.stdout
112 | class F:
113 | def __init__(self, silent):
114 | self.silent = silent
115 |
116 | def write(self, x):
117 | if not self.silent:
118 | if x.endswith("\n"):
119 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
120 | else:
121 | old_f.write(x)
122 |
123 | def flush(self):
124 | old_f.flush()
125 |
126 | sys.stdout = F(silent)
127 |
128 | random.seed(0)
129 | np.random.seed(0)
130 | torch.manual_seed(0)
131 | torch.cuda.set_device(torch.device("cuda:0"))
132 |
--------------------------------------------------------------------------------
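
A short sketch of the schedule returned by get_expon_lr_func above: the rate decays log-linearly (i.e. exponentially) from lr_init to lr_final over max_steps, with an optional warm-up controlled by lr_delay_steps and lr_delay_mult. The values below are illustrative, not the trainer's actual settings:

from utils.general_utils import get_expon_lr_func

lr_func = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
for step in (0, 15_000, 30_000):
    print(step, lr_func(step))   # 1.6e-4, ~1.6e-5 (geometric mean), 1.6e-6
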
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 |
11 | import torch
12 | import math
13 | import numpy as np
14 | from typing import NamedTuple
15 |
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points: np.array
19 | colors: np.array
20 | normals: np.array
21 |
22 | def update(self, points):
23 | return self._replace(points=points)  # NamedTuple fields are immutable; return an updated copy
24 |
25 |
26 | class SegmentedPointCloud(NamedTuple):
27 | points: np.array
28 | colors: np.array
29 | normals: np.array
30 | labels: np.array
31 |
32 |
33 | class SimilarityTransform(NamedTuple):
34 | R: torch.Tensor
35 | T: torch.Tensor
36 | s: torch.Tensor
37 |
38 |
39 | def geom_transform_points(points, transf_matrix):
40 | P, _ = points.shape
41 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
42 | points_hom = torch.cat([points, ones], dim=1)
43 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
44 |
45 | denom = points_out[..., 3:] + 0.0000001
46 | return (points_out[..., :3] / denom).squeeze(dim=0)
47 |
48 |
49 | def getWorld2View(R, t):
50 | Rt = np.zeros((4, 4))
51 | Rt[:3, :3] = R.transpose()
52 | Rt[:3, 3] = t
53 | Rt[3, 3] = 1.0
54 | return np.float32(Rt)
55 |
56 |
57 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
58 | Rt = np.zeros((4, 4))
59 | Rt[:3, :3] = R.transpose()
60 | Rt[:3, 3] = t
61 | Rt[3, 3] = 1.0
62 |
63 | C2W = np.linalg.inv(Rt)
64 | cam_center = C2W[:3, 3]
65 | cam_center = (cam_center + translate) * scale
66 | C2W[:3, 3] = cam_center
67 | Rt = np.linalg.inv(C2W)
68 | return np.float32(Rt)
69 |
70 |
71 | def getWorld2View3(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
72 | Rt = np.zeros((4, 4))
73 | Rt[:3, :3] = R
74 | Rt[:3, 3] = t
75 | Rt[3, 3] = 1.0
76 |
77 | C2W = np.linalg.inv(Rt)
78 | cam_center = C2W[:3, 3]
79 | cam_center = (cam_center + translate) * scale
80 | C2W[:3, 3] = cam_center
81 | Rt = np.linalg.inv(C2W)
82 | return np.float32(Rt)
83 |
84 |
85 | def getProjectionMatrix(znear, zfar, fovX, fovY):
86 | tanHalfFovY = math.tan((fovY / 2))
87 | tanHalfFovX = math.tan((fovX / 2))
88 |
89 | top = tanHalfFovY * znear
90 | bottom = -top
91 | right = tanHalfFovX * znear
92 | left = -right
93 |
94 | P = torch.zeros(4, 4)
95 |
96 | z_sign = 1.0
97 |
98 | P[0, 0] = 2.0 * znear / (right - left)
99 | P[1, 1] = 2.0 * znear / (top - bottom)
100 | P[0, 2] = (right + left) / (right - left)
101 | P[1, 2] = (top + bottom) / (top - bottom)
102 | P[3, 2] = z_sign
103 | P[2, 2] = z_sign * zfar / (zfar - znear)
104 | P[2, 3] = -(zfar * znear) / (zfar - znear)
105 | return P
106 |
107 |
108 | def fov2focal(fov, pixels):
109 | return pixels / (2 * math.tan(fov / 2))
110 |
111 |
112 | def focal2fov(focal, pixels):
113 | return 2*math.atan(pixels/(2*focal))
114 |
115 |
116 | def procrustes(S1, S2):
117 | '''
118 | Computes a similarity transform (sR, t) that takes
119 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
120 | where R is a 3x3 rotation matrix, t a 3x1 translation, and s a scale factor,
121 | i.e. it solves the orthogonal Procrustes problem.
122 | '''
123 | transposed = False
124 | if S1.shape[0] != 3 and S1.shape[0] != 2:
125 | S1 = S1.T
126 | S2 = S2.T
127 | transposed = True
128 | assert (S2.shape[1] == S1.shape[1])
129 |
130 | # 1. Remove mean.
131 | mu1 = S1.mean(axis=1, keepdims=True)
132 | mu2 = S2.mean(axis=1, keepdims=True)
133 | X1 = S1 - mu1
134 | X2 = S2 - mu2
135 |
136 | # print('X1', X1.shape)
137 |
138 | # 2. Compute variance of X1 used for scale.
139 | var1 = torch.sum(X1 ** 2)
140 |
141 | # print('var', var1.shape)
142 |
143 | # 3. The outer product of X1 and X2.
144 | K = X1.mm(X2.T)
145 |
146 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
147 | # singular vectors of K.
148 | U, s, V = torch.svd(K)
149 | # V = Vh.T
150 | # Construct Z that fixes the orientation of R to get det(R)=1.
151 | Z = torch.eye(U.shape[0], device=S1.device)
152 | Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
153 | # Construct R.
154 | R = V.mm(Z.mm(U.T))
155 |
156 | # print('R', X1.shape)
157 |
158 | # 5. Recover scale.
159 | scale = torch.trace(R.mm(K)) / var1
160 | # print(R.shape, mu1.shape)
161 | # 6. Recover translation.
162 | t = mu2 - scale * (R.mm(mu1))
163 | # t = mu2 - (R.mm(mu1))
164 | # print(t.shape)
165 |
166 | # 7. Error:
167 | S1_hat = scale * R.mm(S1) + t
168 | # S1_hat = R.mm(S1) + t
169 |
170 | if transposed:
171 | S1_hat = S1_hat.T
172 |
173 | R_ = torch.eye(4).to(S1)
174 | R_[:3, :3] = R
175 | T_ = torch.eye(4).to(S1)
176 | T_[:3, -1] = t.squeeze(-1)
177 | S_ = torch.eye(4).to(S1)
178 | transf = T_@S_@R_
179 |
180 | return S1_hat, transf
181 |
182 |
183 | # def procrustes(S1, S2,weights=None):
184 | # '''
185 | # Computes a similarity transform (sR, t) that takes
186 | # a set of 3D points S1 (BxNx3) closest to a set of 3D points, S2,
187 | # where R is an 3x3 rotation matrix, t 3x1 translation, s scale. / mod : assuming scale is 1
188 | # i.e. solves the orthogonal Procrutes problem.
189 | # '''
190 | # transposed = False
191 | # if S1.shape[0] != 3 and S1.shape[0] != 2:
192 | # S1 = S1.T
193 | # S2 = S2.T
194 | # transposed = True
195 | # assert (S2.shape[1] == S1.shape[1])
196 |
197 | # if weights is None:
198 | # weights = torch.ones_like(S1[:1,:])
199 |
200 | # # 1. Remove mean.
201 | # weights_norm = weights/(weights.sum(-1, keepdim=True) + 1e-6)
202 | # mu1 = (S1*weights_norm).sum(1, keepdim=True)
203 | # mu2 = (S2*weights_norm).sum(1, keepdim=True)
204 |
205 | # X1 = S1 - mu1
206 | # X2 = S2 - mu2
207 |
208 | # # diags = torch.stack([torch.diag(w.squeeze(0)) for w in weights]) # does batched version exist?
209 | # diags = torch.diag(weights.squeeze())
210 |
211 | # # 3. The outer product of X1 and X2.
212 | # K = (X1@diags).mm(X2.T)
213 | # # K = (X1@diags).bmm(X2.permute(0,2,1))
214 |
215 | # # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
216 | # U, s, V = torch.svd(K)
217 |
218 | # # Construct Z that fixes the orientation of R to get det(R)=1.
219 | # Z = torch.eye(U.shape[0], device=S1.device)
220 | # Z[-1, -1] *= torch.sign(torch.det(U @ V.T))
221 | # # Construct R.
222 | # R = V.mm(Z.mm(U.T))
223 |
224 | # # 6. Recover translation.
225 | # t = mu2 - ((R.mm(mu1)))
226 |
227 | # # 7. Error:
228 | # S1_hat = R.mm(S1) + t
229 | # if transposed:
230 | # S1_hat = S1_hat.T
231 |
232 | # # Combine recovered transformation as single matrix
233 | # R_=torch.eye(4).to(S1)
234 | # R_[:3, :3]=R
235 | # T_=torch.eye(4).to(S1)
236 | # T_[:3, -1]=t.squeeze(-1)
237 | # S_=torch.eye(4).to(S1)
238 | # transf = T_@S_@R_
239 |
240 | # return (S1_hat-S2).square().mean(),transf
241 |
242 |
243 | def convert3x4_4x4(input):
244 | """
245 | Make into homogeneous coordinates by adding [0, 0, 0, 1] to the bottom.
246 | :param input: (N, 3, 4) or (3, 4) torch or np
247 | :return: (N, 4, 4) or (4, 4) torch or np
248 | """
249 | if torch.is_tensor(input):
250 | if len(input.shape) == 3:
251 | output = torch.cat([input, torch.zeros_like(
252 | input[:, 0:1])], dim=1) # (N, 4, 4)
253 | output[:, 3, 3] = 1.0
254 | else:
255 | output = torch.cat([input, torch.tensor(
256 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
257 | else:
258 | if len(input.shape) == 3:
259 | output = np.concatenate(
260 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
261 | output[:, 3, 3] = 1.0
262 | else:
263 | output = np.concatenate(
264 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4)
265 | output[3, 3] = 1.0
266 | return output
267 |
268 |
269 | def align_umeyama(model, data, known_scale=False):
270 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation
271 | of Transformation Parameters Between Two Point Patterns,
272 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991.
273 |
274 | model = s * R * data + t
275 |
276 | Input:
277 | model -- first trajectory (nx3), numpy array type
278 | data -- second trajectory (nx3), numpy array type
279 |
280 | Output:
281 | s -- scale factor (scalar)
282 | R -- rotation matrix (3x3)
283 | t -- translation vector (3x1)
284 | (t_error from the reference implementation is not returned here)
285 |
286 | """
287 |
288 | # subtract mean
289 | mu_M = model.mean(0)
290 | mu_D = data.mean(0)
291 | model_zerocentered = model - mu_M
292 | data_zerocentered = data - mu_D
293 | n = np.shape(model)[0]
294 |
295 | # correlation
296 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered)
297 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum()
298 | U_svd, D_svd, V_svd = np.linalg.svd(C)
299 |
300 | D_svd = np.diag(D_svd)
301 | V_svd = np.transpose(V_svd)
302 |
303 | S = np.eye(3)
304 | if (np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0):
305 | S[2, 2] = -1
306 |
307 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))
308 |
309 | if known_scale:
310 | s = 1
311 | else:
312 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S))
313 |
314 | t = mu_M-s*np.dot(R, mu_D)
315 |
316 | return s, R, t
317 |
318 |
319 | def _getIndices(n_aligned, total_n):
320 | if n_aligned == -1:
321 | idxs = np.arange(0, total_n)
322 | else:
323 | assert n_aligned <= total_n and n_aligned >= 1
324 | idxs = np.arange(0, n_aligned)
325 | return idxs
326 |
327 |
328 | # align by similarity transformation
329 | def align_sim3(p_es, p_gt, n_aligned=-1):
330 | '''
331 | calculate s, R, t so that:
332 | gt = R * s * est + t
333 | '''
334 | idxs = _getIndices(n_aligned, p_es.shape[0])
335 | est_pos = p_es[idxs, 0:3]
336 | gt_pos = p_gt[idxs, 0:3]
337 | try:
338 | s, R, t = align_umeyama(gt_pos, est_pos) # note the order
339 | except:
340 | print('[WARNING] align_poses.py: SVD did not converge!')
341 | s, R, t = 1.0, np.eye(3), np.zeros(3)
342 | return s, R, t
343 |
344 |
345 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None):
346 | """Align c to b using the sim3 from a to b.
347 | :param traj_a: (N0, 3/4, 4) torch tensor
348 | :param traj_b: (N0, 3/4, 4) torch tensor
349 | :param traj_c: None or (N1, 3/4, 4) torch tensor
350 | :return: (N1, 4, 4) torch tensor
351 | """
352 | device = traj_a.device
353 | if traj_c is None:
354 | traj_c = traj_a.clone()
355 |
356 | traj_a = traj_a.float().cpu().numpy()
357 | traj_b = traj_b.float().cpu().numpy()
358 | traj_c = traj_c.float().cpu().numpy()
359 |
360 | R_a = traj_a[:, :3, :3] # (N0, 3, 3)
361 | t_a = traj_a[:, :3, 3] # (N0, 3)
362 |
363 | R_b = traj_b[:, :3, :3] # (N0, 3, 3)
364 | t_b = traj_b[:, :3, 3] # (N0, 3)
365 |
366 | # Estimate the sim3 alignment from the translations only.
367 | # scalar, (3, 3), (3, ) gt = R * s * est + t.
368 | s, R, t = align_sim3(t_a, t_b)
369 |
370 | # reshape tensors
371 | R = R[None, :, :].astype(np.float32) # (1, 3, 3)
372 | t = t[None, :, None].astype(np.float32) # (1, 3, 1)
373 | s = float(s)
374 |
375 | R_c = traj_c[:, :3, :3] # (N1, 3, 3)
376 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1)
377 |
378 | R_c_aligned = R @ R_c # (N1, 3, 3)
379 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1)
380 | traj_c_aligned = np.concatenate(
381 | [R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4)
382 |
383 | # append the last row
384 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4)
385 |
386 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device)
387 |
388 | return traj_c_aligned # (N1, 4, 4)
389 |
--------------------------------------------------------------------------------
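
A quick consistency sketch for the camera helpers above: fov2focal and focal2fov are inverses of each other, and getWorld2View2 with an identity rotation and zero translation yields the 4x4 identity. Purely illustrative:

import math
import numpy as np
from utils.graphics_utils import fov2focal, focal2fov, getWorld2View2

focal, width = 1111.0, 800
fov = focal2fov(focal, width)
assert math.isclose(fov2focal(fov, width), focal, rel_tol=1e-9)

W2V = getWorld2View2(np.eye(3), np.zeros(3))
assert np.allclose(W2V, np.eye(4))
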
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 |
11 | import torch
12 |
13 | def mse(img1, img2):
14 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
15 |
16 | def psnr(img1, img2):
17 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
18 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
19 |
20 | import numpy as np
21 | import matplotlib
22 | import matplotlib as mpl
23 | import matplotlib.pyplot as plt
24 | from matplotlib.patches import Patch
25 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
26 |
27 | class CameraPoseVisualizer:
28 | def __init__(self, xlim=None, ylim=None, zlim=None):
29 | self.fig = plt.figure(figsize=(18, 7))
30 | self.ax = self.fig.add_subplot(projection='3d')
31 | self.ax.set_aspect("auto")
32 | if xlim is not None and ylim is not None and zlim is not None:
33 | self.ax.set_xlim(xlim)
34 | self.ax.set_ylim(ylim)
35 | self.ax.set_zlim(zlim)
36 |
37 | self.ax.view_init(elev=10., azim=45)
38 | self.ax.xaxis.set_tick_params(labelbottom=False)
39 | self.ax.yaxis.set_tick_params(labelleft=False)
40 | self.ax.zaxis.set_tick_params(labelleft=False)
41 | self.ax.set_xlabel('x')
42 | self.ax.set_ylabel('y')
43 | self.ax.set_zlabel('z')
44 | # plt.tight_layout()
45 | print('initialize camera pose visualizer')
46 |
47 | def extrinsic2pyramid(self, extrinsic, color='r', focal_len_scaled=5, aspect_ratio=0.3):
48 | vertex_std = np.array([[0, 0, 0, 1],
49 | [focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1],
50 | [focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1],
51 | [-focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1],
52 | [-focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1]])
53 | vertex_transformed = vertex_std @ extrinsic.T
54 | meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
55 | [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
56 | [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
57 | [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
58 | [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]]
59 | self.ax.add_collection3d(
60 | Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
61 |
62 | def customize_legend(self, list_label):
63 | list_handle = []
64 | for idx, label in enumerate(list_label):
65 | color = plt.cm.rainbow(idx / len(list_label))
66 | patch = Patch(color=color, label=label)
67 | list_handle.append(patch)
68 | plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
69 |
70 | def colorbar(self, max_frame_length):
71 | cmap = mpl.cm.rainbow
72 | norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
73 | self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), orientation='vertical', label='Frame Number')
74 |
75 | def show(self):
76 | plt.title('Extrinsic Parameters')
77 | plt.show()
78 |
79 | def add_traj(self, poses, c='r', alpha=0.5):
80 | x = [float(pose[0, 3]) for pose in poses]
81 | y = [float(pose[1, 3]) for pose in poses]
82 | z = [float(pose[2, 3]) for pose in poses]
83 | self.ax.plot(x,y,z, c=c, alpha=alpha)
84 |
85 | def save(self, path):
86 | plt.tight_layout()
87 | self.fig.savefig(path, bbox_inches='tight')
88 |
89 |
90 | def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
91 | """Converts a depth map to a color image.
92 |
93 | Args:
94 | value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
95 | vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
96 | vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
97 | cmap (str, optional): matplotlib colormap to use. Defaults to 'gray_r'.
98 | invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
99 | invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
100 | background_color (tuple[int], optional): 4-tuple RGBA color to give to invalid pixels. Defaults to (128, 128, 128, 255).
101 | gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
102 | value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
103 |
104 | Returns:
105 | numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
106 | """
107 | if isinstance(value, torch.Tensor):
108 | value = value.detach().cpu().numpy()
109 |
110 | value = value.squeeze()
111 | if invalid_mask is None:
112 | invalid_mask = value == invalid_val
113 | mask = np.logical_not(invalid_mask)
114 |
115 | # normalize
116 | vmin = np.percentile(value[mask],2) if vmin is None else vmin
117 | vmax = np.percentile(value[mask],85) if vmax is None else vmax
118 | if vmin != vmax:
119 | value = (value - vmin) / (vmax - vmin) # vmin..vmax
120 | else:
121 | # Avoid 0-division
122 | value = value * 0.
123 |
124 | # squeeze last dim if it exists
125 | # grey out the invalid values
126 |
127 | value[invalid_mask] = np.nan
128 | cmapper = mpl.cm.get_cmap(cmap)
129 | if value_transform:
130 | value = value_transform(value)
131 | # value = value / value.max()
132 | value = cmapper(value, bytes=True) # (nxmx4)
133 |
134 | # img = value[:, :, :]
135 | img = value[...]
136 | img[invalid_mask] = background_color
137 |
138 | # return img.transpose((2, 0, 1))
139 | if gamma_corrected:
140 | # gamma correction
141 | img = img / 255
142 | img = np.power(img, 2.2)
143 | img = img * 255
144 | img = img.astype(np.uint8)
145 | return img
146 |
147 |
148 |
149 | def cm_prune(x_):
150 | """Custom colormap to visualize pruning"""
151 | if isinstance(x_, torch.Tensor):
152 | x_ = x_.cpu().numpy()
153 | max_i = max(x_)
154 | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
155 | return cm_BlRdGn(norm_x)  # assumes a blue-red-green colormap helper (cm_BlRdGn) defined elsewhere
156 |
157 |
--------------------------------------------------------------------------------
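
The psnr helper above flattens each image in the batch and converts its MSE to decibels. A tiny sanity sketch with illustrative CPU tensors: a uniform error of 0.1 on a [0, 1] image gives MSE = 0.01 and therefore 20 dB:

import torch
from utils.image_utils import psnr

img1 = torch.zeros(1, 3, 4, 4)
img2 = torch.full((1, 3, 4, 4), 0.1)
print(psnr(img1, img2))   # tensor([[20.]])
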
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from math import exp
14 |
15 | def l1_loss(network_output, gt):
16 | return torch.abs((network_output - gt)).mean()
17 |
18 | def l2_loss(network_output, gt):
19 | return ((network_output - gt) ** 2).mean()
20 |
21 | def gaussian(window_size, sigma):
22 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
23 | return gauss / gauss.sum()
24 |
25 | def create_window(window_size, channel):
26 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
27 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
28 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
29 | return window
30 |
31 | def ssim(img1, img2, window_size=11, size_average=True):
32 | channel = img1.size(-3)
33 | window = create_window(window_size, channel)
34 |
35 | if img1.is_cuda:
36 | window = window.cuda(img1.get_device())
37 | window = window.type_as(img1)
38 |
39 | return _ssim(img1, img2, window, window_size, channel, size_average)
40 |
41 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
42 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
43 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
44 |
45 | mu1_sq = mu1.pow(2)
46 | mu2_sq = mu2.pow(2)
47 | mu1_mu2 = mu1 * mu2
48 |
49 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
50 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
51 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
52 |
53 | C1 = 0.01 ** 2
54 | C2 = 0.03 ** 2
55 |
56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
57 |
58 | if size_average:
59 | return ssim_map.mean()
60 | else:
61 | return ssim_map.mean(1).mean(1).mean(1)
62 |
63 |
--------------------------------------------------------------------------------
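
The two terms above are typically blended into the photometric objective used for Gaussian splatting training, L = (1 - lambda) * L1 + lambda * (1 - SSIM). The sketch below shows that combination with the conventional weight 0.2; the weight actually used by the trainer lives in its configuration, so treat the value here as an assumption:

import torch
from utils.loss_utils import l1_loss, ssim

lambda_dssim = 0.2
render = torch.rand(1, 3, 64, 64)
gt = torch.rand(1, 3, 64, 64)
loss = (1.0 - lambda_dssim) * l1_loss(render, gt) + lambda_dssim * (1.0 - ssim(render, gt))
print(loss.item())
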
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 | deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
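
At degree 0 the eval_sh routine above reduces to C0 * sh[..., 0]; paired with RGB2SH this yields rgb - 0.5, which is why renderers typically add 0.5 back after SH evaluation. A small illustrative check:

import torch
from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([[0.2, 0.5, 0.8]])                    # (N=1, C=3)
sh = RGB2SH(rgb).unsqueeze(-1)                           # (1, 3, 1): one degree-0 coefficient per channel
dirs = torch.tensor([[0.0, 0.0, 1.0]])                   # unit direction (unused at degree 0)
print(torch.allclose(eval_sh(0, sh, dirs) + 0.5, rgb))   # True
print(torch.allclose(SH2RGB(RGB2SH(rgb)), rgb))          # True
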
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Inria
2 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
3 | # All rights reserved.
4 | #
5 | # This software is free for non-commercial, research and evaluation use
6 | # under the terms of the LICENSE_inria.md file.
7 | #
8 | # For inquiries contact george.drettakis@inria.fr
9 |
10 | from errno import EEXIST
11 | from os import makedirs, path
12 | import os
13 |
14 | def mkdir_p(folder_path):
15 | # Creates a directory; equivalent to using mkdir -p on the command line.
16 | try:
17 | makedirs(folder_path)
18 | except OSError as exc: # Python >2.5
19 | if exc.errno == EEXIST and path.isdir(folder_path):
20 | pass
21 | else:
22 | raise
23 |
24 | def searchForMaxIteration(folder):
25 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
26 | return max(saved_iters)
27 |
--------------------------------------------------------------------------------
/utils/utils_poses/ATE/align_trajectory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | #!/usr/bin/env python2
11 | # -*- coding: utf-8 -*-
12 |
13 | import numpy as np
14 | import utils.utils_poses.ATE.transformations as tfs
15 |
16 |
17 | def get_best_yaw(C):
18 | '''
19 | maximize trace(Rz(theta) * C)
20 | '''
21 | assert C.shape == (3, 3)
22 |
23 | A = C[0, 1] - C[1, 0]
24 | B = C[0, 0] + C[1, 1]
25 | theta = np.pi / 2 - np.arctan2(B, A)
26 |
27 | return theta
28 |
29 |
30 | def rot_z(theta):
31 | R = tfs.rotation_matrix(theta, [0, 0, 1])
32 | R = R[0:3, 0:3]
33 |
34 | return R
35 |
36 |
37 | def align_umeyama(model, data, known_scale=False, yaw_only=False):
38 | """Implementation of the paper: S. Umeyama, Least-Squares Estimation
39 | of Transformation Parameters Between Two Point Patterns,
40 | IEEE Trans. Pattern Anal. Mach. Intell., vol. 13, no. 4, 1991.
41 |
42 | model = s * R * data + t
43 |
44 | Input:
45 | model -- first trajectory (nx3), numpy array type
46 | data -- second trajectory (nx3), numpy array type
47 |
48 | Output:
49 | s -- scale factor (scalar)
50 | R -- rotation matrix (3x3)
51 | t -- translation vector (3x1)
52 | (t_error from the reference implementation is not returned here)
53 |
54 | """
55 |
56 | # subtract mean
57 | mu_M = model.mean(0)
58 | mu_D = data.mean(0)
59 | model_zerocentered = model - mu_M
60 | data_zerocentered = data - mu_D
61 | n = np.shape(model)[0]
62 |
63 | # correlation
64 | C = 1.0/n*np.dot(model_zerocentered.transpose(), data_zerocentered)
65 | sigma2 = 1.0/n*np.multiply(data_zerocentered, data_zerocentered).sum()
66 | U_svd, D_svd, V_svd = np.linalg.svd(C)
67 |
68 | D_svd = np.diag(D_svd)
69 | V_svd = np.transpose(V_svd)
70 |
71 | S = np.eye(3)
72 | if(np.linalg.det(U_svd)*np.linalg.det(V_svd) < 0):
73 | S[2, 2] = -1
74 |
75 | if yaw_only:
76 | rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered)
77 | theta = get_best_yaw(rot_C)
78 | R = rot_z(theta)
79 | else:
80 | R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))
81 |
82 | if known_scale:
83 | s = 1
84 | else:
85 | s = 1.0/sigma2*np.trace(np.dot(D_svd, S))
86 |
87 | t = mu_M-s*np.dot(R, mu_D)
88 |
89 | return s, R, t
90 |
--------------------------------------------------------------------------------
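
A small synthetic check of align_umeyama above: build a point set, apply a known similarity transform, and confirm the estimated s, R, t recover it. Illustrative only, with noise-free data:

import numpy as np
from utils.utils_poses.ATE.align_trajectory import align_umeyama

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))                    # "estimated" positions
angle = 0.3
R_true = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                   [np.sin(angle),  np.cos(angle), 0.0],
                   [0.0, 0.0, 1.0]])
s_true, t_true = 2.0, np.array([1.0, -2.0, 0.5])
model = s_true * data @ R_true.T + t_true           # model = s * R * data + t, row-wise

s, R, t = align_umeyama(model, data)
print(np.isclose(s, s_true), np.allclose(R, R_true), np.allclose(t, t_true))
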
/utils/utils_poses/ATE/align_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #!/usr/bin/env python2
10 | # -*- coding: utf-8 -*-
11 |
12 | import numpy as np
13 |
14 | import utils.utils_poses.ATE.transformations as tfs
15 | import utils.utils_poses.ATE.align_trajectory as align
16 |
17 |
18 | def _getIndices(n_aligned, total_n):
19 | if n_aligned == -1:
20 | idxs = np.arange(0, total_n)
21 | else:
22 | assert n_aligned <= total_n and n_aligned >= 1
23 | idxs = np.arange(0, n_aligned)
24 | return idxs
25 |
26 |
27 | def alignPositionYawSingle(p_es, p_gt, q_es, q_gt):
28 | '''
29 | calculate the 4-DoF transformation: yaw rotation R and translation t so that:
30 | gt = R * est + t
31 | '''
32 |
33 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
34 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
35 | g_rot = tfs.quaternion_matrix(q_gt_0)
36 | g_rot = g_rot[0:3, 0:3]
37 | est_rot = tfs.quaternion_matrix(q_es_0)
38 | est_rot = est_rot[0:3, 0:3]
39 |
40 | C_R = np.dot(est_rot, g_rot.transpose())
41 | theta = align.get_best_yaw(C_R)
42 | R = align.rot_z(theta)
43 | t = p_gt_0 - np.dot(R, p_es_0)
44 |
45 | return R, t
46 |
47 |
48 | def alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned=1):
49 | if n_aligned == 1:
50 | R, t = alignPositionYawSingle(p_es, p_gt, q_es, q_gt)
51 | return R, t
52 | else:
53 | idxs = _getIndices(n_aligned, p_es.shape[0])
54 | est_pos = p_es[idxs, 0:3]
55 | gt_pos = p_gt[idxs, 0:3]
56 | _, R, t = align.align_umeyama(gt_pos, est_pos, known_scale=True,
57 | yaw_only=True) # note the order
58 | t = np.array(t)
59 | t = t.reshape((3, ))
60 | R = np.array(R)
61 | return R, t
62 |
63 |
64 | # align by a SE3 transformation
65 | def alignSE3Single(p_es, p_gt, q_es, q_gt):
66 | '''
67 | Calculate SE3 transformation R and t so that:
68 | gt = R * est + t
69 | Using only the first poses of est and gt
70 | '''
71 |
72 | p_es_0, q_es_0 = p_es[0, :], q_es[0, :]
73 | p_gt_0, q_gt_0 = p_gt[0, :], q_gt[0, :]
74 |
75 | g_rot = tfs.quaternion_matrix(q_gt_0)
76 | g_rot = g_rot[0:3, 0:3]
77 | est_rot = tfs.quaternion_matrix(q_es_0)
78 | est_rot = est_rot[0:3, 0:3]
79 |
80 | R = np.dot(g_rot, np.transpose(est_rot))
81 | t = p_gt_0 - np.dot(R, p_es_0)
82 |
83 | return R, t
84 |
85 |
86 | def alignSE3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
87 | '''
88 | Calculate SE3 transformation R and t so that:
89 | gt = R * est + t
90 | '''
91 | if n_aligned == 1:
92 | R, t = alignSE3Single(p_es, p_gt, q_es, q_gt)
93 | return R, t
94 | else:
95 | idxs = _getIndices(n_aligned, p_es.shape[0])
96 | est_pos = p_es[idxs, 0:3]
97 | gt_pos = p_gt[idxs, 0:3]
98 | s, R, t = align.align_umeyama(gt_pos, est_pos,
99 | known_scale=True) # note the order
100 | t = np.array(t)
101 | t = t.reshape((3, ))
102 | R = np.array(R)
103 | return R, t
104 |
105 |
106 | # align by similarity transformation
107 | def alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned=-1):
108 | '''
109 | calculate s, R, t so that:
110 | gt = R * s * est + t
111 | '''
112 | idxs = _getIndices(n_aligned, p_es.shape[0])
113 | est_pos = p_es[idxs, 0:3]
114 | gt_pos = p_gt[idxs, 0:3]
115 | s, R, t = align.align_umeyama(gt_pos, est_pos) # note the order
116 | return s, R, t
117 |
118 |
119 | # a general interface
120 | def alignTrajectory(p_es, p_gt, q_es, q_gt, method, n_aligned=-1):
121 | '''
122 | calculate s, R, t so that:
123 | gt = R * s * est + t
124 | method can be: sim3, se3, posyaw, none;
125 | n_aligned: -1 means using all the frames
126 | '''
127 | assert p_es.shape[1] == 3
128 | assert p_gt.shape[1] == 3
129 | assert q_es.shape[1] == 4
130 | assert q_gt.shape[1] == 4
131 |
132 | s = 1
133 | R = None
134 | t = None
135 | if method == 'sim3':
136 | assert n_aligned >= 2 or n_aligned == -1, "sim3 uses at least 2 frames"
137 | s, R, t = alignSIM3(p_es, p_gt, q_es, q_gt, n_aligned)
138 | elif method == 'se3':
139 | R, t = alignSE3(p_es, p_gt, q_es, q_gt, n_aligned)
140 | elif method == 'posyaw':
141 | R, t = alignPositionYaw(p_es, p_gt, q_es, q_gt, n_aligned)
142 | elif method == 'none':
143 | R = np.identity(3)
144 | t = np.zeros((3, ))
145 | else:
146 | assert False, 'unknown alignment method'
147 |
148 | return s, R, t
149 |
150 |
151 | if __name__ == '__main__':
152 | pass
153 |
--------------------------------------------------------------------------------
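
alignTrajectory above dispatches on the chosen method; for 'sim3' only the positions are used (the quaternions are just shape-checked), so a minimal call looks like the sketch below, with illustrative identity orientations:

import numpy as np
from utils.utils_poses.ATE.align_utils import alignTrajectory

p_es = np.random.rand(50, 3)
p_gt = 1.5 * p_es + np.array([0.3, -0.1, 2.0])      # scaled and shifted copy
q_id = np.tile([1.0, 0.0, 0.0, 0.0], (50, 1))       # identity quaternions, scalar first

s, R, t = alignTrajectory(p_es, p_gt, q_id, q_id, method='sim3')
print(round(float(s), 3))                           # ~1.5
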
/utils/utils_poses/ATE/compute_trajectory_errors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #!/usr/bin/env python2
10 |
11 | import os
12 | import numpy as np
13 |
14 | import utils.utils_poses.ATE.trajectory_utils as tu
15 | import utils.utils_poses.ATE.transformations as tf
16 |
17 |
18 | def compute_relative_error(p_es, q_es, p_gt, q_gt, T_cm, dist, max_dist_diff,
19 | accum_distances=[],
20 | scale=1.0):
21 |
22 | if len(accum_distances) == 0:
23 | accum_distances = tu.get_distance_from_start(p_gt)
24 | comparisons = tu.compute_comparison_indices_length(
25 | accum_distances, dist, max_dist_diff)
26 |
27 | n_samples = len(comparisons)
28 | print('number of samples = {0} '.format(n_samples))
29 | if n_samples < 2:
30 | print("Too few samples! Will not compute.")
31 | return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]),\
32 | np.array([]), np.array([])
33 |
34 | T_mc = np.linalg.inv(T_cm)
35 | errors = []
36 | for idx, c in enumerate(comparisons):
37 | if not c == -1:
38 | T_c1 = tu.get_rigid_body_trafo(q_es[idx, :], p_es[idx, :])
39 | T_c2 = tu.get_rigid_body_trafo(q_es[c, :], p_es[c, :])
40 | T_c1_c2 = np.dot(np.linalg.inv(T_c1), T_c2)
41 | T_c1_c2[:3, 3] *= scale
42 |
43 | T_m1 = tu.get_rigid_body_trafo(q_gt[idx, :], p_gt[idx, :])
44 | T_m2 = tu.get_rigid_body_trafo(q_gt[c, :], p_gt[c, :])
45 | T_m1_m2 = np.dot(np.linalg.inv(T_m1), T_m2)
46 |
47 | T_m1_m2_in_c1 = np.dot(T_cm, np.dot(T_m1_m2, T_mc))
48 | T_error_in_c2 = np.dot(np.linalg.inv(T_m1_m2_in_c1), T_c1_c2)
49 | T_c2_rot = np.eye(4)
50 | T_c2_rot[0:3, 0:3] = T_c2[0:3, 0:3]
51 | T_error_in_w = np.dot(T_c2_rot, np.dot(
52 | T_error_in_c2, np.linalg.inv(T_c2_rot)))
53 | errors.append(T_error_in_w)
54 |
55 | error_trans_norm = []
56 | error_trans_perc = []
57 | error_yaw = []
58 | error_gravity = []
59 | e_rot = []
60 | e_rot_deg_per_m = []
61 | for e in errors:
62 | tn = np.linalg.norm(e[0:3, 3])
63 | error_trans_norm.append(tn)
64 | error_trans_perc.append(tn / dist * 100)
65 | ypr_angles = tf.euler_from_matrix(e, 'rzyx')
66 | e_rot.append(tu.compute_angle(e))
67 | error_yaw.append(abs(ypr_angles[0])*180.0/np.pi)
68 | error_gravity.append(
69 | np.sqrt(ypr_angles[1]**2+ypr_angles[2]**2)*180.0/np.pi)
70 | e_rot_deg_per_m.append(e_rot[-1] / dist)
71 | return errors, np.array(error_trans_norm), np.array(error_trans_perc),\
72 | np.array(error_yaw), np.array(error_gravity), np.array(e_rot),\
73 | np.array(e_rot_deg_per_m)
74 |
75 |
76 | def compute_absolute_error(p_es_aligned, q_es_aligned, p_gt, q_gt):
77 | e_trans_vec = (p_gt-p_es_aligned)
78 | e_trans = np.sqrt(np.sum(e_trans_vec**2, 1))
79 |
80 |
81 | # orientation error
82 | e_rot = np.zeros((len(e_trans,)))
83 | e_ypr = np.zeros(np.shape(p_es_aligned))
84 | for i in range(np.shape(p_es_aligned)[0]):
85 | R_we = tf.matrix_from_quaternion(q_es_aligned[i, :])
86 | R_wg = tf.matrix_from_quaternion(q_gt[i, :])
87 | e_R = np.dot(R_we, np.linalg.inv(R_wg))
88 | e_ypr[i, :] = tf.euler_from_matrix(e_R, 'rzyx')
89 | e_rot[i] = np.rad2deg(np.linalg.norm(tf.logmap_so3(e_R[:3, :3])))
90 | # scale drift
91 | motion_gt = np.diff(p_gt, axis=0)  # frame-to-frame motion (axis=0; a positional 0 would set n=0)
92 | motion_es = np.diff(p_es_aligned, axis=0)
93 | dist_gt = np.sqrt(np.sum(np.multiply(motion_gt, motion_gt), 1))
94 | dist_es = np.sqrt(np.sum(np.multiply(motion_es, motion_es), 1))
95 | e_scale_perc = np.abs((np.divide(dist_es, dist_gt)-1.0) * 100)
96 | # ate = np.sqrt(np.mean(np.asarray(e_trans) ** 2))
97 | return e_trans, e_trans_vec, e_rot, e_ypr, e_scale_perc
98 |
--------------------------------------------------------------------------------
/utils/utils_poses/ATE/results_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #!/usr/bin/env python2
10 | import os
11 | # import yaml
12 | import numpy as np
13 |
14 |
15 | def compute_statistics(data_vec):
16 | stats = dict()
17 | if len(data_vec) > 0:
18 | stats['rmse'] = float(
19 | np.sqrt(np.dot(data_vec, data_vec) / len(data_vec)))
20 | stats['mean'] = float(np.mean(data_vec))
21 | stats['median'] = float(np.median(data_vec))
22 | stats['std'] = float(np.std(data_vec))
23 | stats['min'] = float(np.min(data_vec))
24 | stats['max'] = float(np.max(data_vec))
25 | stats['num_samples'] = int(len(data_vec))
26 | else:
27 | stats['rmse'] = 0
28 | stats['mean'] = 0
29 | stats['median'] = 0
30 | stats['std'] = 0
31 | stats['min'] = 0
32 | stats['max'] = 0
33 | stats['num_samples'] = 0
34 |
35 | return stats
36 |
37 |
38 | # def update_and_save_stats(new_stats, label, yaml_filename):
39 | # stats = dict()
40 | # if os.path.exists(yaml_filename):
41 | # stats = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader)
42 | # stats[label] = new_stats
43 | #
44 | # with open(yaml_filename, 'w') as outfile:
45 | # outfile.write(yaml.dump(stats, default_flow_style=False))
46 | #
47 | # return
48 | #
49 | #
50 | # def compute_and_save_statistics(data_vec, label, yaml_filename):
51 | # new_stats = compute_statistics(data_vec)
52 | # update_and_save_stats(new_stats, label, yaml_filename)
53 | #
54 | # return new_stats
55 | #
56 | #
57 | # def write_tex_table(list_values, rows, cols, outfn):
58 | # '''
59 | # write list_values[row_idx][col_idx] to a table that is ready to be pasted
60 | # into latex source
61 | #
62 | # list_values is a list of row values
63 | #
64 | # The value should be string of desired format
65 | # '''
66 | #
67 | # assert len(rows) >= 1
68 | # assert len(cols) >= 1
69 | #
70 | # with open(outfn, 'w') as f:
71 | # # write header
72 | # f.write(' & ')
73 | # for col_i in cols[:-1]:
74 | # f.write(col_i + ' & ')
75 | # f.write(' ' + cols[-1]+'\n')
76 | #
77 | # # write each row
78 | # for row_idx, row_i in enumerate(list_values):
79 | # f.write(rows[row_idx] + ' & ')
80 | # row_values = list_values[row_idx]
81 | # for col_idx in range(len(row_values) - 1):
82 | # f.write(row_values[col_idx] + ' & ')
83 | # f.write(' ' + row_values[-1]+' \n')
84 |
--------------------------------------------------------------------------------
/utils/utils_poses/ATE/trajectory_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #!/usr/bin/env python2
10 | """
11 | @author: Christian Forster
12 | """
13 |
14 | import os
15 | import numpy as np
16 | import utils.utils_poses.ATE.transformations as tf
17 |
18 |
19 | def get_rigid_body_trafo(quat, trans):
20 | T = tf.quaternion_matrix(quat)
21 | T[0:3, 3] = trans
22 | return T
23 |
24 |
25 | def get_distance_from_start(gt_translation):
26 | distances = np.diff(gt_translation[:, 0:3], axis=0)
27 | distances = np.sqrt(np.sum(np.multiply(distances, distances), 1))
28 | distances = np.cumsum(distances)
29 | distances = np.concatenate(([0], distances))
30 | return distances
31 |
32 |
33 | def compute_comparison_indices_length(distances, dist, max_dist_diff):
34 | max_idx = len(distances)
35 | comparisons = []
36 | for idx, d in enumerate(distances):
37 | best_idx = -1
38 | error = max_dist_diff
39 | for i in range(idx, max_idx):
40 | if np.abs(distances[i]-(d+dist)) < error:
41 | best_idx = i
42 | error = np.abs(distances[i] - (d+dist))
43 | if best_idx != -1:
44 | comparisons.append(best_idx)
45 | return comparisons
46 |
47 |
48 | def compute_angle(transform):
49 | """
50 | Compute the rotation angle from a 4x4 homogeneous matrix.
51 | """
52 | # an invitation to 3-d vision, p 27
53 | return np.arccos(
54 | min(1, max(-1, (np.trace(transform[0:3, 0:3]) - 1)/2)))*180.0/np.pi
55 |
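56 | # Minimal usage sketch: an identity rotation (assuming the (w, x, y, z) quaternion
57 | # ordering of transformations.quaternion_matrix) plus a translation; the rotation
58 | # angle of the resulting transform is zero.
59 | if __name__ == "__main__":
60 |     T = get_rigid_body_trafo(np.array([1.0, 0.0, 0.0, 0.0]), np.array([1.0, 2.0, 3.0]))
61 |     print(compute_angle(T))  # 0.0 degrees for the identity rotation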
--------------------------------------------------------------------------------
/utils/utils_poses/align_traj.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from utils.utils_poses.ATE.align_utils import alignTrajectory
13 | from utils.utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4
14 |
15 |
16 | def pts_dist_max(pts):
17 | """
18 | :param pts: (N, 3) torch or np
19 |     :return: scalar; max distance of any point from the first point, used as a scale measure
20 | """
21 | if torch.is_tensor(pts):
22 | dist = pts.unsqueeze(0) - pts.unsqueeze(1) # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
23 | dist = dist[0] # (N, 3)
24 | dist = dist.norm(dim=1) # (N, )
25 | max_dist = dist.max()
26 | else:
27 | dist = pts[None, :, :] - pts[:, None, :] # (1, N, 3) - (N, 1, 3) -> (N, N, 3)
28 | dist = dist[0] # (N, 3)
29 | dist = np.linalg.norm(dist, axis=1) # (N, )
30 | max_dist = dist.max()
31 | return max_dist
32 |
33 |
34 | def align_ate_c2b_use_a2b(traj_a, traj_b, traj_c=None, method='sim3'):
35 | """Align c to b using the sim3 from a to b.
36 | :param traj_a: (N0, 3/4, 4) torch tensor
37 | :param traj_b: (N0, 3/4, 4) torch tensor
38 | :param traj_c: None or (N1, 3/4, 4) torch tensor
39 | :return: (N1, 4, 4) torch tensor
40 | """
41 | device = traj_a.device
42 | if traj_c is None:
43 | traj_c = traj_a.clone()
44 |
45 | traj_a = traj_a.float().cpu().numpy()
46 | traj_b = traj_b.float().cpu().numpy()
47 | traj_c = traj_c.float().cpu().numpy()
48 |
49 | R_a = traj_a[:, :3, :3] # (N0, 3, 3)
50 | t_a = traj_a[:, :3, 3] # (N0, 3)
51 | quat_a = SO3_to_quat(R_a) # (N0, 4)
52 |
53 | R_b = traj_b[:, :3, :3] # (N0, 3, 3)
54 | t_b = traj_b[:, :3, 3] # (N0, 3)
55 | quat_b = SO3_to_quat(R_b) # (N0, 4)
56 |
57 |     # alignTrajectory operates on translations and quaternions and returns
58 |     # scalar s, (3, 3) R, (3, ) t such that gt ~ s * R @ est + t.
59 | s, R, t = alignTrajectory(t_a, t_b, quat_a, quat_b, method=method)
60 |
61 | # reshape tensors
62 | R = R[None, :, :].astype(np.float32) # (1, 3, 3)
63 | t = t[None, :, None].astype(np.float32) # (1, 3, 1)
64 | s = float(s)
65 |
66 | R_c = traj_c[:, :3, :3] # (N1, 3, 3)
67 | t_c = traj_c[:, :3, 3:4] # (N1, 3, 1)
68 |
69 | R_c_aligned = R @ R_c # (N1, 3, 3)
70 | t_c_aligned = s * (R @ t_c) + t # (N1, 3, 1)
71 | traj_c_aligned = np.concatenate([R_c_aligned, t_c_aligned], axis=2) # (N1, 3, 4)
72 |
73 | # append the last row
74 | traj_c_aligned = convert3x4_4x4(traj_c_aligned) # (N1, 4, 4)
75 |
76 | traj_c_aligned = torch.from_numpy(traj_c_aligned).to(device)
77 | return traj_c_aligned # (N1, 4, 4)
78 |
79 |
80 |
81 | def align_scale_c2b_use_a2b(traj_a, traj_b, traj_c=None):
82 | '''Scale c to b using the scale from a to b.
83 | :param traj_a: (N0, 3/4, 4) torch tensor
84 | :param traj_b: (N0, 3/4, 4) torch tensor
85 | :param traj_c: None or (N1, 3/4, 4) torch tensor
86 | :return:
87 | scaled_traj_c (N1, 4, 4) torch tensor
88 | scale scalar
89 | '''
90 | if traj_c is None:
91 | traj_c = traj_a.clone()
92 |
93 | t_a = traj_a[:, :3, 3] # (N, 3)
94 | t_b = traj_b[:, :3, 3] # (N, 3)
95 |
96 | # scale estimated poses to colmap scale
97 | # s_a2b: a*s ~ b
98 | scale_a2b = pts_dist_max(t_b) / pts_dist_max(t_a)
99 |
100 | traj_c[:, :3, 3] *= scale_a2b
101 |
102 | if traj_c.shape[1] == 3:
103 | traj_c = convert3x4_4x4(traj_c) # (N, 4, 4)
104 |
105 | return traj_c, scale_a2b # (N, 4, 4)
106 |
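107 | # Minimal usage sketch: aligning a toy trajectory to itself should recover
108 | # (approximately) the identity sim(3) transform, so the output matches the input.
109 | if __name__ == "__main__":
110 |     traj = torch.eye(4).unsqueeze(0).repeat(4, 1, 1)
111 |     traj[:, :3, 3] = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
112 |     aligned = align_ate_c2b_use_a2b(traj, traj.clone(), method='sim3')
113 |     print(torch.allclose(aligned, traj, atol=1e-5))  # expected: True (up to numerical precision)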
--------------------------------------------------------------------------------
/utils/utils_poses/comp_ate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 |
11 | import utils.utils_poses.ATE.trajectory_utils as tu
12 | import utils.utils_poses.ATE.transformations as tf
13 | def rotation_error(pose_error):
14 | """Compute rotation error
15 | Args:
16 | pose_error (4x4 array): relative pose error
17 | Returns:
18 |         rot_error (float): rotation error in radians
19 | """
20 | a = pose_error[0, 0]
21 | b = pose_error[1, 1]
22 | c = pose_error[2, 2]
23 | d = 0.5*(a+b+c-1.0)
24 | rot_error = np.arccos(max(min(d, 1.0), -1.0))
25 | return rot_error
26 |
27 | def translation_error(pose_error):
28 | """Compute translation error
29 | Args:
30 | pose_error (4x4 array): relative pose error
31 | Returns:
32 | trans_error (float): translation error
33 | """
34 | dx = pose_error[0, 3]
35 | dy = pose_error[1, 3]
36 | dz = pose_error[2, 3]
37 | trans_error = np.sqrt(dx**2+dy**2+dz**2)
38 | return trans_error
39 |
40 | def compute_rpe(gt, pred):
41 | trans_errors = []
42 | rot_errors = []
43 | for i in range(len(gt)-1):
44 | gt1 = gt[i]
45 | gt2 = gt[i+1]
46 | gt_rel = np.linalg.inv(gt1) @ gt2
47 |
48 | pred1 = pred[i]
49 | pred2 = pred[i+1]
50 | pred_rel = np.linalg.inv(pred1) @ pred2
51 | rel_err = np.linalg.inv(gt_rel) @ pred_rel
52 |
53 | trans_errors.append(translation_error(rel_err))
54 | rot_errors.append(rotation_error(rel_err))
55 | rpe_trans = np.mean(np.asarray(trans_errors))
56 | rpe_rot = np.mean(np.asarray(rot_errors))
57 | return rpe_trans, rpe_rot
58 |
59 | def compute_ATE(gt, pred):
60 | """Compute RMSE of ATE
61 | Args:
62 | gt: ground-truth poses
63 | pred: predicted poses
64 | """
65 | errors = []
66 |
67 | for i in range(len(pred)):
68 | # cur_gt = np.linalg.inv(gt_0) @ gt[i]
69 | cur_gt = gt[i]
70 | gt_xyz = cur_gt[:3, 3]
71 |
72 | # cur_pred = np.linalg.inv(pred_0) @ pred[i]
73 | cur_pred = pred[i]
74 | pred_xyz = cur_pred[:3, 3]
75 |
76 | align_err = gt_xyz - pred_xyz
77 |
78 | errors.append(np.sqrt(np.sum(align_err ** 2)))
79 | ate = np.sqrt(np.mean(np.asarray(errors) ** 2))
80 | return ate
81 |
82 |
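83 | # Minimal usage sketch: identical ground-truth and predicted toy trajectories
84 | # give zero ATE and zero relative pose error.
85 | if __name__ == "__main__":
86 |     poses = [np.eye(4) for _ in range(5)]
87 |     for i, pose in enumerate(poses):
88 |         pose[:3, 3] = [float(i), 0.0, 0.0]
89 |     print(compute_ATE(poses, poses), compute_rpe(poses, poses))  # 0.0 (0.0, 0.0)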
--------------------------------------------------------------------------------
/utils/utils_poses/lie_group_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | from scipy.spatial.transform import Rotation as RotLib
12 |
13 |
14 | def SO3_to_quat(R):
15 | """
16 | :param R: (N, 3, 3) or (3, 3) np
17 | :return: (N, 4, ) or (4, ) np
18 | """
19 | x = RotLib.from_matrix(R)
20 | quat = x.as_quat()
21 | return quat
22 |
23 |
24 | def quat_to_SO3(quat):
25 | """
26 | :param quat: (N, 4, ) or (4, ) np
27 | :return: (N, 3, 3) or (3, 3) np
28 | """
29 | x = RotLib.from_quat(quat)
30 | R = x.as_matrix()
31 | return R
32 |
33 |
34 | def convert3x4_4x4(input):
35 | """
36 | :param input: (N, 3, 4) or (3, 4) torch or np
37 | :return: (N, 4, 4) or (4, 4) torch or np
38 | """
39 | if torch.is_tensor(input):
40 | if len(input.shape) == 3:
41 | output = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1) # (N, 4, 4)
42 | output[:, 3, 3] = 1.0
43 | else:
44 | output = torch.cat([input, torch.tensor([[0,0,0,1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
45 | else:
46 | if len(input.shape) == 3:
47 | output = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
48 | output[:, 3, 3] = 1.0
49 | else:
50 | output = np.concatenate([input, np.array([[0,0,0,1]], dtype=input.dtype)], axis=0) # (4, 4)
51 | output[3, 3] = 1.0
52 | return output
53 |
54 |
55 | def vec2skew(v):
56 | """
57 | :param v: (3, ) torch tensor
58 | :return: (3, 3)
59 | """
60 | zero = torch.zeros(1, dtype=torch.float32, device=v.device)
61 |     skew_v0 = torch.cat([ zero,    -v[2:3],   v[1:2]])  # (3, )
62 | skew_v1 = torch.cat([ v[2:3], zero, -v[0:1]])
63 | skew_v2 = torch.cat([-v[1:2], v[0:1], zero])
64 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0) # (3, 3)
65 | return skew_v # (3, 3)
66 |
67 |
68 | def Exp(r):
69 | """so(3) vector to SO(3) matrix
70 | :param r: (3, ) axis-angle, torch tensor
71 | :return: (3, 3)
72 | """
73 | skew_r = vec2skew(r) # (3, 3)
74 | norm_r = r.norm() + 1e-15
75 | eye = torch.eye(3, dtype=torch.float32, device=r.device)
76 | R = eye + (torch.sin(norm_r) / norm_r) * skew_r + ((1 - torch.cos(norm_r)) / norm_r**2) * (skew_r @ skew_r)
77 | return R
78 |
79 |
80 | def make_c2w(r, t):
81 | """
82 | :param r: (3, ) axis-angle torch tensor
83 | :param t: (3, ) translation vector torch tensor
84 | :return: (4, 4)
85 | """
86 | R = Exp(r) # (3, 3)
87 | c2w = torch.cat([R, t.unsqueeze(1)], dim=1) # (3, 4)
88 | c2w = convert3x4_4x4(c2w) # (4, 4)
89 | return c2w
90 |
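91 | # Minimal usage sketch: a 90-degree rotation about the z-axis plus a unit translation
92 | # along x, composed into a 4x4 camera-to-world matrix.
93 | if __name__ == "__main__":
94 |     r = torch.tensor([0.0, 0.0, np.pi / 2.0], dtype=torch.float32)
95 |     t = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)
96 |     print(make_c2w(r, t))  # rotation block approximately [[0, -1, 0], [1, 0, 0], [0, 0, 1]]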
--------------------------------------------------------------------------------
/utils/utils_poses/vis_cam_traj.py:
--------------------------------------------------------------------------------
1 | '''
2 | BSD 2-Clause License
3 |
4 | Copyright (c) 2020, the NeRF++ authors
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | 1. Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | 2. Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | '''
28 |
29 | import numpy as np
30 |
31 | try:
32 | import open3d as o3d
33 | except ImportError:
34 | pass
35 |
36 |
37 | def frustums2lineset(frustums):
38 | N = len(frustums)
39 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum
40 |     merged_lines = np.zeros((N*8, 2), dtype=np.int32)   # 8 lines per frustum; integer vertex indices for Vector2iVector
41 | merged_colors = np.zeros((N*8, 3)) # each line gets a color
42 |
43 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums):
44 | merged_points[i*5:(i+1)*5, :] = frustum_points
45 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5
46 | merged_colors[i*8:(i+1)*8, :] = frustum_colors
47 |
48 | lineset = o3d.geometry.LineSet()
49 | lineset.points = o3d.utility.Vector3dVector(merged_points)
50 | lineset.lines = o3d.utility.Vector2iVector(merged_lines)
51 | lineset.colors = o3d.utility.Vector3dVector(merged_colors)
52 |
53 | return lineset
54 |
55 |
56 | def get_camera_frustum_opengl_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
57 | '''X right, Y up, Z backward to the observer.
58 | :param H, W:
59 | :param fx, fy:
60 | :param W2C: (4, 4) matrix
61 | :param frustum_length: scalar: scale the frustum
62 | :param color: (3,) list, frustum line color
63 | :return:
64 | frustum_points: (5, 3) frustum points in world coordinate
65 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
66 | frustum_colors: (8, 3) colors for 8 lines.
67 | '''
68 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
69 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
70 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
71 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
72 |
73 | # build view frustum in camera space in homogenous coordinate (5, 4)
74 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
75 | [-half_w, half_h, -frustum_length, 1.0], # top-left image corner
76 | [half_w, half_h, -frustum_length, 1.0], # top-right image corner
77 | [half_w, -half_h, -frustum_length, 1.0], # bottom-right image corner
78 | [-half_w, -half_h, -frustum_length, 1.0]]) # bottom-left image corner
79 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
80 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
81 |
82 | # transform view frustum from camera space to world space
83 | C2W = np.linalg.inv(W2C)
84 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
85 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogenous coordinate
86 | return frustum_points, frustum_lines, frustum_colors
87 |
88 | def get_camera_frustum_opencv_coord(H, W, fx, fy, W2C, frustum_length=0.5, color=np.array([0., 1., 0.])):
89 |     '''X right, Y down, Z forward, away from the observer (OpenCV camera convention).
90 | :param H, W:
91 | :param fx, fy:
92 | :param W2C: (4, 4) matrix
93 | :param frustum_length: scalar: scale the frustum
94 | :param color: (3,) list, frustum line color
95 | :return:
96 | frustum_points: (5, 3) frustum points in world coordinate
97 | frustum_lines: (8, 2) 8 lines connect 5 frustum points, specified in line start/end index.
98 | frustum_colors: (8, 3) colors for 8 lines.
99 | '''
100 | hfov = np.rad2deg(np.arctan(W / 2. / fx) * 2.)
101 | vfov = np.rad2deg(np.arctan(H / 2. / fy) * 2.)
102 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.))
103 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.))
104 |
105 | # build view frustum in camera space in homogenous coordinate (5, 4)
106 | frustum_points = np.array([[0., 0., 0., 1.0], # frustum origin
107 | [-half_w, -half_h, frustum_length, 1.0], # top-left image corner
108 | [ half_w, -half_h, frustum_length, 1.0], # top-right image corner
109 | [ half_w, half_h, frustum_length, 1.0], # bottom-right image corner
110 | [-half_w, +half_h, frustum_length, 1.0]]) # bottom-left image corner
111 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) # (8, 2)
112 | frustum_colors = np.tile(color.reshape((1, 3)), (frustum_lines.shape[0], 1)) # (8, 3)
113 |
114 | # transform view frustum from camera space to world space
115 | C2W = np.linalg.inv(W2C)
116 | frustum_points = np.matmul(C2W, frustum_points.T).T # (5, 4)
117 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] # (5, 3) remove homogenous coordinate
118 | return frustum_points, frustum_lines, frustum_colors
119 |
120 |
121 |
122 | def draw_camera_frustum_geometry(c2ws, H, W, fx=600.0, fy=600.0, frustum_length=0.5,
123 | color=np.array([29.0, 53.0, 87.0])/255.0, draw_now=False, coord='opengl'):
124 | '''
125 | :param c2ws: (N, 4, 4) np.array
126 | :param H: scalar
127 | :param W: scalar
128 | :param fx: scalar
129 | :param fy: scalar
130 | :param frustum_length: scalar
131 | :param color: None or (N, 3) or (3, ) or (1, 3) or (3, 1) np array
132 | :param draw_now: True/False call o3d vis now
133 | :return:
134 | '''
135 | N = c2ws.shape[0]
136 |
137 | num_ele = color.flatten().shape[0]
138 | if num_ele == 3:
139 | color = color.reshape(1, 3)
140 | color = np.tile(color, (N, 1))
141 |
142 | frustum_list = []
143 | if coord == 'opengl':
144 | for i in range(N):
145 | frustum_list.append(get_camera_frustum_opengl_coord(H, W, fx, fy,
146 | W2C=np.linalg.inv(c2ws[i]),
147 | frustum_length=frustum_length,
148 | color=color[i]))
149 | elif coord == 'opencv':
150 | for i in range(N):
151 | frustum_list.append(get_camera_frustum_opencv_coord(H, W, fx, fy,
152 | W2C=np.linalg.inv(c2ws[i]),
153 | frustum_length=frustum_length,
154 | color=color[i]))
155 | else:
156 | print('Undefined coordinate system. Exit')
157 | exit()
158 |
159 | frustums_geometry = frustums2lineset(frustum_list)
160 |
161 | if draw_now:
162 | o3d.visualization.draw_geometries([frustums_geometry])
163 |
164 | return frustums_geometry # this is an o3d geometry object.
165 |
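166 | # Minimal usage sketch (requires open3d): three toy cameras spread along the x-axis,
167 | # rendered as green OpenGL-convention frustums.
168 | if __name__ == "__main__":
169 |     c2ws = np.tile(np.eye(4), (3, 1, 1))
170 |     c2ws[:, 0, 3] = np.arange(3)  # translate each camera along x
171 |     geometry = draw_camera_frustum_geometry(c2ws, H=480, W=640, fx=500.0, fy=500.0,
172 |                                              frustum_length=0.3, color=np.array([0., 1., 0.]),
173 |                                              draw_now=False, coord='opengl')
174 |     print(geometry)  # an open3d LineSet with 3 * 5 points and 3 * 8 lines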
--------------------------------------------------------------------------------
/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 | # from matplotlib import tight_layout  # unused; module removed in matplotlib >= 3.6 (plt.tight_layout() is called directly)
13 | import copy
14 | from evo.core.trajectory import PosePath3D, PoseTrajectory3D
15 | from evo.main_ape import ape
16 | from evo.tools import plot
17 | from evo.core import sync
18 | from evo.tools import file_interface
19 | from evo.core import metrics
20 | import evo
21 | import torch
22 | import numpy as np
23 | from scipy.spatial.transform import Slerp
24 | from scipy.spatial.transform import Rotation as R
25 | import scipy.interpolate as si
26 |
27 |
28 | def interp_poses(c2ws, N_views):
29 | N_inputs = c2ws.shape[0]
30 | trans = c2ws[:, :3, 3:].permute(2, 1, 0)
31 | rots = c2ws[:, :3, :3]
32 | render_poses = []
33 | rots = R.from_matrix(rots)
34 | slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
35 | interp_rots = torch.tensor(
36 | slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32))
37 | interp_trans = torch.nn.functional.interpolate(
38 | trans, size=N_views, mode='linear').permute(2, 1, 0)
39 | render_poses = torch.cat([interp_rots, interp_trans], dim=2)
40 | render_poses = convert3x4_4x4(render_poses)
41 | return render_poses
42 |
43 |
44 | def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
45 | target_trans = torch.tensor(scipy_bspline(
46 | c2ws[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False).astype(np.float32)).unsqueeze(2)
47 | rots = R.from_matrix(c2ws[:, :3, :3])
48 | slerp = Slerp(input_times, rots)
49 | target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
50 | target_rots = torch.tensor(
51 | slerp(target_times).as_matrix().astype(np.float32))
52 | target_poses = torch.cat([target_rots, target_trans], dim=2)
53 | target_poses = convert3x4_4x4(target_poses)
54 | return target_poses
55 |
56 |
57 | def poses_avg(poses):
58 |
59 | hwf = poses[0, :3, -1:]
60 |
61 | center = poses[:, :3, 3].mean(0)
62 | vec2 = normalize(poses[:, :3, 2].sum(0))
63 | up = poses[:, :3, 1].sum(0)
64 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
65 |
66 | return c2w
67 |
68 |
69 | def normalize(v):
70 | """Normalize a vector."""
71 | return v / np.linalg.norm(v)
72 |
73 |
74 | def viewmatrix(z, up, pos):
75 | vec2 = normalize(z)
76 | vec1_avg = up
77 | vec0 = normalize(np.cross(vec1_avg, vec2))
78 | vec1 = normalize(np.cross(vec2, vec0))
79 | m = np.stack([vec0, vec1, vec2, pos], 1)
80 | return m
81 |
82 |
83 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
84 | render_poses = []
85 | rads = np.array(list(rads) + [1.])
86 | hwf = c2w[:, 4:5]
87 |
88 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
89 | # c = np.dot(c2w[:3,:4], np.array([0.7*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.1, 1.]) * rads)
90 | # c = np.dot(c2w[:3,:4], np.array([0.3*np.cos(theta) , -0.3*np.sin(theta) , -np.sin(theta*zrate) *0.01, 1.]) * rads)
91 | c = np.dot(c2w[:3, :4], np.array(
92 | [0.2*np.cos(theta), -0.2*np.sin(theta), -np.sin(theta*zrate) * 0.1, 1.]) * rads)
93 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
94 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
95 | return render_poses
96 |
97 |
98 | def scipy_bspline(cv, n=100, degree=3, periodic=False):
99 | """ Calculate n samples on a bspline
100 |
101 |     cv :      Array of control vertices
102 | n : Number of samples to return
103 | degree: Curve degree
104 | periodic: True - Curve is closed
105 | """
106 | cv = np.asarray(cv)
107 | count = cv.shape[0]
108 |
109 | # Closed curve
110 | if periodic:
111 | kv = np.arange(-degree, count+degree+1)
112 | factor, fraction = divmod(count+degree+1, count)
113 | cv = np.roll(np.concatenate(
114 | (cv,) * factor + (cv[:fraction],)), -1, axis=0)
115 | degree = np.clip(degree, 1, degree)
116 |
117 | # Opened curve
118 | else:
119 | degree = np.clip(degree, 1, count-1)
120 | kv = np.clip(np.arange(count+degree+1)-degree, 0, count-degree)
121 |
122 | # Return samples
123 | max_param = count - (degree * (1-periodic))
124 | spl = si.BSpline(kv, cv, degree)
125 | return spl(np.linspace(0, max_param, n))
126 |
127 |
128 | def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
129 | learned_poses_ = np.concatenate((learned_poses[:, :3, :4].detach(
130 | ).cpu().numpy(), hwf[:len(learned_poses)]), axis=-1)
131 | c2w = poses_avg(learned_poses_)
132 | print('recentered', c2w.shape)
133 | # Get spiral
134 | # Get average pose
135 | up = normalize(learned_poses_[:, :3, 1].sum(0))
136 | # Find a reasonable "focus depth" for this dataset
137 |
138 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
139 | dt = .75
140 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
141 | focal = mean_dz
142 |
143 | # Get radii for spiral path
144 | shrink_factor = .8
145 | zdelta = close_depth * .2
146 | tt = learned_poses_[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
147 | rads = np.percentile(np.abs(tt), 90, 0)
148 | c2w_path = c2w
149 | N_rots = 2
150 | c2ws = render_path_spiral(
151 | c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views)
152 | c2ws = torch.tensor(np.stack(c2ws).astype(np.float32))
153 | c2ws = c2ws[:, :3, :4]
154 | c2ws = convert3x4_4x4(c2ws)
155 | return c2ws
156 |
157 |
158 | def convert3x4_4x4(input):
159 | """
160 | :param input: (N, 3, 4) or (3, 4) torch or np
161 | :return: (N, 4, 4) or (4, 4) torch or np
162 | """
163 | if torch.is_tensor(input):
164 | if len(input.shape) == 3:
165 | output = torch.cat([input, torch.zeros_like(
166 | input[:, 0:1])], dim=1) # (N, 4, 4)
167 | output[:, 3, 3] = 1.0
168 | else:
169 | output = torch.cat([input, torch.tensor(
170 | [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0) # (4, 4)
171 | else:
172 | if len(input.shape) == 3:
173 | output = np.concatenate(
174 | [input, np.zeros_like(input[:, 0:1])], axis=1) # (N, 4, 4)
175 | output[:, 3, 3] = 1.0
176 | else:
177 | output = np.concatenate(
178 | [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0) # (4, 4)
179 | output[3, 3] = 1.0
180 | return output
181 |
182 |
183 | plt.rc('legend', fontsize=20)  # global legend font size for the pose plots
184 |
185 |
186 | def plot_pose(ref_poses, est_poses, output_path, vid=False):
187 | ref_poses = [pose for pose in ref_poses]
188 | if isinstance(est_poses, dict):
189 | est_poses = [pose for k, pose in est_poses.items()]
190 | else:
191 | est_poses = [pose for pose in est_poses]
192 | traj_ref = PosePath3D(poses_se3=ref_poses)
193 | traj_est = PosePath3D(poses_se3=est_poses)
194 | traj_est_aligned = copy.deepcopy(traj_est)
195 | traj_est_aligned.align(traj_ref, correct_scale=True,
196 | correct_only_scale=False)
197 | if vid:
198 | for p_idx in range(len(ref_poses)):
199 | fig = plt.figure()
200 | current_est_aligned = traj_est_aligned.poses_se3[:p_idx+1]
201 | current_ref = traj_ref.poses_se3[:p_idx+1]
202 | current_est_aligned = PosePath3D(poses_se3=current_est_aligned)
203 | current_ref = PosePath3D(poses_se3=current_ref)
204 | traj_by_label = {
205 | # "estimate (not aligned)": traj_est,
206 | "Ours (aligned)": current_est_aligned,
207 | "Ground-truth": current_ref
208 | }
209 | plot_mode = plot.PlotMode.xyz
210 | # ax = plot.prepare_axis(fig, plot_mode, 111)
211 | ax = fig.add_subplot(111, projection="3d")
212 | ax.xaxis.set_tick_params(labelbottom=False)
213 | ax.yaxis.set_tick_params(labelleft=False)
214 | ax.zaxis.set_tick_params(labelleft=False)
215 | colors = ['r', 'b']
216 | styles = ['-', '--']
217 |
218 | for idx, (label, traj) in enumerate(traj_by_label.items()):
219 | plot.traj(ax, plot_mode, traj,
220 | styles[idx], colors[idx], label)
221 | # break
222 | # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
223 | ax.view_init(elev=10., azim=45)
224 | plt.tight_layout()
225 | os.makedirs(os.path.join(os.path.dirname(
226 | output_path), 'pose_vid'), exist_ok=True)
227 | pose_vis_path = os.path.join(os.path.dirname(
228 | output_path), 'pose_vid', 'pose_vis_{:03d}.png'.format(p_idx))
229 | print(pose_vis_path)
230 | fig.savefig(pose_vis_path)
231 |
232 | # else:
233 |
234 | fig = plt.figure()
235 | traj_by_label = {
236 | # "estimate (not aligned)": traj_est,
237 | "Ours (aligned)": traj_est_aligned,
238 | "Ground-truth": traj_ref
239 | }
240 | plot_mode = plot.PlotMode.xyz
241 | # ax = plot.prepare_axis(fig, plot_mode, 111)
242 | ax = fig.add_subplot(111, projection="3d")
243 | ax.xaxis.set_tick_params(labelbottom=False)
244 | ax.yaxis.set_tick_params(labelleft=False)
245 | ax.zaxis.set_tick_params(labelleft=False)
246 | colors = ['r', 'b']
247 | styles = ['-', '--']
248 |
249 | for idx, (label, traj) in enumerate(traj_by_label.items()):
250 | plot.traj(ax, plot_mode, traj,
251 | styles[idx], colors[idx], label)
252 | # break
253 | # plot.trajectories(fig, traj_by_label, plot.PlotMode.xyz)
254 | ax.view_init(elev=10., azim=45)
255 | plt.tight_layout()
256 | pose_vis_path = os.path.join(os.path.dirname(output_path), 'pose_vis.png')
257 | fig.savefig(pose_vis_path)
258 |
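259 | # Minimal usage sketch: sample 20 points on a cubic B-spline through four made-up
260 | # control points (the same routine interp_poses_bspline uses for camera centers).
261 | if __name__ == "__main__":
262 |     control_pts = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]])
263 |     samples = scipy_bspline(control_pts, n=20, degree=3, periodic=False)
264 |     print(samples.shape)  # (20, 3)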
--------------------------------------------------------------------------------