├── .gitignore
├── LICENSE
├── LICENSE_gaussian_splatting.md
├── README.md
├── assets
│   └── pipeline.png
├── configs
│   ├── base.yaml
│   ├── kitti_nvs.yaml
│   ├── kitti_reconstruction.yaml
│   ├── waymo_nvs.yaml
│   └── waymo_reconstruction.yaml
├── evaluate.py
├── gaussian_renderer
│   └── __init__.py
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── requirements-data.txt
├── requirements.txt
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── envlight.py
│   ├── gaussian_model.py
│   ├── kittimot_loader.py
│   ├── scene_utils.py
│   └── waymo_loader.py
├── scripts
│   ├── extract_kitti_metric_nvs.py
│   ├── extract_kitti_metric_reconstruction.py
│   ├── extract_mask_kitti.py
│   ├── extract_mask_waymo.py
│   ├── extract_scenes_waymo.py
│   ├── extract_waymo_metric_nvs.py
│   ├── extract_waymo_metric_reconstruction.py
│   ├── run_kitti_nvs_all.sh
│   ├── run_kitti_reconstruction_all.sh
│   ├── run_waymo_nvs_all.sh
│   ├── run_waymo_reconstruction_all.sh
│   └── waymo_converter.py
├── separate.py
├── train.py
└── utils
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── loss_utils.py
    ├── sh_utils.py
    └── system_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.so
3 | build/
4 | *.egg-info/
5 | .vscode
6 |
7 |
8 | build
9 | data
10 | output
11 | eval_output
12 | diff-gaussian-rasterization
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Fudan Zhang Vision Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE_gaussian_splatting.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
85 | ## 6. Files subject to permissive licenses
86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license.
87 |
88 | Title: pytorch-ssim\
89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\
90 | Copyright Evan Su, 2017\
91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)
92 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering
2 | ### [[Project]](https://fudan-zvg.github.io/PVG) [[Paper]](https://arxiv.org/abs/2311.18561)
3 |
4 | > [**Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering**](https://arxiv.org/abs/2311.18561),
5 | > Yurui Chen, [Chun Gu](https://sulvxiangxin.github.io/), Junzhe Jiang, [Xiatian Zhu](https://surrey-uplab.github.io/), [Li Zhang](https://lzrobots.github.io)
6 | > **arXiv preprint**
7 |
8 | **Official implementation of "Periodic Vibration Gaussian:
9 | Dynamic Urban Scene Reconstruction and Real-time Rendering".**
10 |
11 |
12 | ## 🛠️ Pipeline
13 |
14 | ![pipeline](assets/pipeline.png)
15 |
16 |
17 | ## Get started
18 | ### Environment
19 | ```
20 | # Clone the repo.
21 | git clone https://github.com/fudan-zvg/PVG.git
22 | cd PVG
23 |
24 | # Make a conda environment.
25 | conda create --name pvg python=3.9
26 | conda activate pvg
27 |
28 | # Install requirements.
29 | pip install -r requirements.txt
30 |
31 | # Install simple-knn
32 | git clone https://gitlab.inria.fr/bkerbl/simple-knn.git
33 | pip install ./simple-knn
34 |
35 | # a modified gaussian splatting (for feature rendering)
36 | git clone --recursive https://github.com/SuLvXiangXin/diff-gaussian-rasterization
37 | pip install ./diff-gaussian-rasterization
38 |
39 | # Install nvdiffrast (for Envlight)
40 | git clone https://github.com/NVlabs/nvdiffrast
41 | pip install ./nvdiffrast
42 |
43 | ```
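If the installation succeeds, the CUDA extensions used by the code should import cleanly. A minimal sanity check (a sketch; it only exercises the imports used elsewhere in this repository):

```
import torch
import nvdiffrast.torch as dr                      # used by scene/envlight.py
from simple_knn._C import distCUDA2                # used by scene/gaussian_model.py
from diff_gaussian_rasterization import (          # used by gaussian_renderer/__init__.py
    GaussianRasterizationSettings, GaussianRasterizer,
)

print(torch.__version__, "| CUDA available:", torch.cuda.is_available())
```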
44 |
45 | ### Data preparation
46 | Create a directory for the data: `mkdir data`.
47 | #### Waymo dataset
48 |
49 | The 4 preprocessed Waymo scenes used for the results in Table 1 of our paper can be downloaded [here](https://drive.google.com/file/d/1eTNJz7WeYrB3IctVlUmJIY0z8qhjR_qF/view?usp=sharing) (optional: [corresponding labels](https://drive.google.com/file/d/1rkOzYqD1wdwILq_tUNvXBcXMe5YwtI2k/view?usp=drive_link)). Please unzip them into the `data` directory.
50 |
51 | First prepare the kitti-format Waymo dataset:
52 | ```
53 | # Given the following dataset, we convert it to kitti-format
54 | # data
55 | # └── waymo
56 | # └── waymo_format
57 | # └── training
58 | # └── segment-xxxxxx
59 |
60 | # install some optional package
61 | pip install -r requirements-data.txt
62 |
63 | # Convert the waymo dataset to kitti-format
64 | python scripts/waymo_converter.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo
65 | ```
66 | Then use the example script `scripts/extract_scenes_waymo.py` to extract the scenes listed in StreetSurf from the kitti-format Waymo dataset.
67 |
68 | Following [StreetSurf](https://github.com/PJLab-ADG/neuralsim), we use [Segformer](https://github.com/NVlabs/SegFormer) to extract the sky masks and organize the data as follows:
69 | ```
70 | data
71 | └── waymo_scenes
72 | └── sequence_id
73 | ├── calib
74 | │ └── frame_id.txt
75 | ├── image_0{0, 1, 2, 3, 4}
76 | │ └── frame_id.png
77 | ├── sky_0{0, 1, 2, 3, 4}
78 | │ └── frame_id.png
79 | |── pose
80 | | └── frame_id.txt
81 | └── velodyne
82 | └── frame_id.bin
83 | ```
84 | We provide an example script `scripts/extract_mask_waymo.py` to extract the sky masks from the extracted Waymo dataset; follow the instructions [here](https://github.com/PJLab-ADG/neuralsim/blob/main/dataio/autonomous_driving/waymo/README.md#extract-mask-priors----for-sky-pedestrian-etc) to set up the Segformer environment.
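Before training, it can be worth checking that a prepared sequence matches the layout above. A minimal sketch (the sequence path is just an example, e.g. one of the preprocessed scenes):

```
import os

# Hypothetical layout check for one prepared Waymo sequence.
seq = "data/waymo_scenes/0145050"  # example sequence directory
expected = (["calib", "pose", "velodyne"]
            + [f"image_0{i}" for i in range(5)]
            + [f"sky_0{i}" for i in range(5)])
missing = [d for d in expected if not os.path.isdir(os.path.join(seq, d))]
print("missing folders:", missing or "none")
```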
85 |
86 | #### KITTI dataset
87 | The 3 preprocessed KITTI scenes used for the results in Table 1 of our paper can be downloaded [here](https://drive.google.com/file/d/1y6elRlFdRXW02oUOHdS9inVHK3U4xBXZ/view?usp=sharing). Please unzip them into the `data` directory.
88 |
89 | Put the [KITTI-MOT](https://www.cvlibs.net/datasets/kitti/eval_tracking.php) dataset in the `data` directory.
90 | Following [StreetSurf](https://github.com/PJLab-ADG/neuralsim), we use [Segformer](https://github.com/NVlabs/SegFormer) to extract the sky masks and organize the data as follows:
91 | ```
92 | data
93 | └── kitti_mot
94 | └── training
95 | ├── calib
96 | │ └── sequence_id.txt
97 | ├── image_0{2, 3}
98 | │ └── sequence_id
99 | │ └── frame_id.png
100 | ├── sky_0{2, 3}
101 | │ └── sequence_id
102 | │ └── frame_id.png
103 | |── oxts
104 | | └── sequence_id.txt
105 | └── velodyne
106 | └── sequence_id
107 | └── frame_id.bin
108 | ```
109 | We also provide an example script `scripts/extract_mask_kitti.py` to extract the sky mask from the KITTI dataset.
110 |
111 |
112 | ### Training
113 | ```
114 | # Waymo image reconstruction
115 | CUDA_VISIBLE_DEVICES=0 python train.py \
116 | --config configs/waymo_reconstruction.yaml \
117 | source_path=data/waymo_scenes/0145050 \
118 | model_path=eval_output/waymo_reconstruction/0145050
119 |
120 | # Waymo novel view synthesis
121 | CUDA_VISIBLE_DEVICES=0 python train.py \
122 | --config configs/waymo_nvs.yaml \
123 | source_path=data/waymo_scenes/0145050 \
124 | model_path=eval_output/waymo_nvs/0145050
125 |
126 | # KITTI image reconstruction
127 | CUDA_VISIBLE_DEVICES=0 python train.py \
128 | --config configs/kitti_reconstruction.yaml \
129 | source_path=data/kitti_mot/training/image_02/0001 \
130 | model_path=eval_output/kitti_reconstruction/0001 \
131 | start_frame=380 end_frame=431
132 |
133 | # KITTI novel view synthesis
134 | CUDA_VISIBLE_DEVICES=0 python train.py \
135 | --config configs/kitti_nvs.yaml \
136 | source_path=data/kitti_mot/training/image_02/0001 \
137 | model_path=eval_output/kitti_nvs/0001 \
138 | start_frame=380 end_frame=431
139 | ```
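The `key=value` arguments above are OmegaConf overrides applied on top of `configs/base.yaml` and the chosen scene config. A minimal sketch of how they are resolved, mirroring the merge in `evaluate.py` (train.py accepts the same `--config` plus `key=value` interface):

```
from omegaconf import OmegaConf

# Three-way merge: base.yaml < scene config < command-line key=value overrides.
base_conf = OmegaConf.load("configs/base.yaml")
scene_conf = OmegaConf.load("configs/waymo_reconstruction.yaml")
cli_conf = OmegaConf.from_dotlist([
    "source_path=data/waymo_scenes/0145050",
    "model_path=eval_output/waymo_reconstruction/0145050",
])
args = OmegaConf.merge(base_conf, scene_conf, cli_conf)
print(args.iterations, args.source_path, args.model_path)
```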
140 |
141 | After training, evaluation results can be found in the `{EXPERIMENT_DIR}/eval` directory.
142 |
143 | ### Evaluating
144 | You can also use the following command to evaluate.
145 | ```
146 | CUDA_VISIBLE_DEVICES=0 python evaluate.py \
147 | --config configs/kitti_reconstruction.yaml \
148 | source_path=data/kitti_mot/training/image_02/0001 \
149 | model_path=eval_output/kitti_reconstruction/0001 \
150 | start_frame=380 end_frame=431
151 | ```
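Each evaluation run writes a `metrics.json` (PSNR, SSIM, LPIPS) under `{model_path}/eval/<split>_<iteration>_render/`. A small sketch for collecting these files across experiments (the glob pattern assumes the `eval_output/...` layout used in the commands above):

```
import glob
import json

# Gather the metrics.json files written by evaluate.py.
for path in sorted(glob.glob("eval_output/*/*/eval/*_render/metrics.json")):
    with open(path) as f:
        m = json.load(f)
    print(f"{path}: split={m['split']} iter={m['iteration']} "
          f"PSNR={m['psnr']:.2f} SSIM={m['ssim']:.3f} LPIPS={m['lpips']:.3f}")
```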
152 |
153 | ### Automatically removing the dynamics
154 | You can use the following command to automatically remove the dynamic objects; the rendered results will be saved in the `{EXPERIMENT_DIR}/separation` directory.
155 | ```
156 | CUDA_VISIBLE_DEVICES=1 python separate.py \
157 | --config configs/waymo_reconstruction.yaml \
158 | source_path=data/waymo_scenes/0158150 \
159 | model_path=eval_output/waymo_reconstruction/0158150
160 | ```
161 |
162 |
163 | ## 🎥 Videos
164 | ### 🎞️ Demo
165 | [Demo video](https://www.youtube.com/embed/jJCCkdpDkRQ)
166 |
167 |
168 | ### 🎞️ Rendered RGB, Depth and Semantic
169 |
170 | https://github.com/fudan-zvg/PVG/assets/83005605/60337a98-f92c-4465-ab45-2ee121413114
171 |
172 | https://github.com/fudan-zvg/PVG/assets/83005605/f45c0a91-26b6-46d9-895c-bf13786f94d2
173 |
174 | https://github.com/fudan-zvg/PVG/assets/83005605/0ed679d6-5e62-4923-b2cb-02c587ed468c
175 |
176 | https://github.com/fudan-zvg/PVG/assets/83005605/3ffda292-1b73-43d3-916a-b524f143f0c9
177 |
178 | ### 🎞️ Image Reconstruction on Waymo
179 | #### Comparison with static methods
180 |
181 | https://github.com/fudan-zvg/PVG/assets/83005605/93e32945-7e9a-454a-8c31-5563125de95b
182 |
183 | https://github.com/fudan-zvg/PVG/assets/83005605/f3c02e43-bb86-428d-b27b-73c4a7857bc7
184 |
185 | #### Comparison with dynamic methods
186 |
187 | https://github.com/fudan-zvg/PVG/assets/83005605/73a82171-9e78-416f-a770-f6f4239d80ca
188 |
189 | https://github.com/fudan-zvg/PVG/assets/83005605/e579f8b8-d31e-456b-a943-b39d56073b94
190 |
191 | ### 🎞️ Novel View Synthesis on Waymo
192 |
193 | https://github.com/fudan-zvg/PVG/assets/83005605/37393332-5d34-4bd0-8285-40bf938b849f
194 |
195 | ## 📜 BibTeX
196 | ```bibtex
197 | @article{chen2023periodic,
198 | title={Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering},
199 | author={Chen, Yurui and Gu, Chun and Jiang, Junzhe and Zhu, Xiatian and Zhang, Li},
200 | journal={arXiv:2311.18561},
201 | year={2023},
202 | }
203 | ```
204 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PVG/b4162a9135282e0f3c929054f16be1b3fbacd77a/assets/pipeline.png
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | test_iterations: [7000, 30000]
2 | save_iterations: [7000, 30000]
3 | checkpoint_iterations: [7000, 30000]
4 | exhaust_test: false
5 | test_interval: 5000
6 | render_static: false
7 | vis_step: 500
8 | start_checkpoint: null
9 | seed: 0
10 |
11 | # ModelParams
12 | sh_degree: 3
13 | scene_type: "Waymo"
14 | source_path: ???
15 | start_frame: 65 # for kitti
16 | end_frame: 120 # for kitti
17 | model_path: ???
18 | resolution_scales: [1]
19 | resolution: -1
20 | white_background: false
21 | data_device: "cuda"
22 | eval: false
23 | debug_cuda: false
24 | cam_num: 3 # for waymo
25 | t_init: 0.2
26 | cycle: 0.2
27 | velocity_decay: 1.0
28 | random_init_point: 200000
29 | fix_radius: 0.0
30 | time_duration: [-0.5, 0.5]
31 | num_pts: 100000
32 | frame_interval: 0.02
33 | testhold: 4 # NVS
34 | env_map_res: 1024
35 | separate_scaling_t: 0.1
36 | neg_fov: true
37 |
38 | # PipelineParams
39 | convert_SHs_python: false
40 | compute_cov3D_python: false
41 | debug: false
42 | depth_blend_mode: 0
43 | env_optimize_until: 1000000000
44 | env_optimize_from: 0
45 |
46 |
47 | # OptimizationParams
48 | iterations: 30000
49 | position_lr_init: 0.00016
50 | position_lr_final: 0.0000016
51 | position_lr_delay_mult: 0.01
52 | t_lr_init: 0.0008
53 | position_lr_max_steps: 30_000
54 | feature_lr: 0.0025
55 | opacity_lr: 0.05
56 | scaling_lr: 0.005
57 | scaling_t_lr: 0.002
58 | velocity_lr: 0.001
59 | rotation_lr: 0.001
60 | envmap_lr: 0.01
61 |
62 | time_split_frac: 0.5
63 | percent_dense: 0.01
64 | thresh_opa_prune: 0.005
65 | densification_interval: 100
66 | opacity_reset_interval: 3000
67 | densify_from_iter: 500
68 | densify_until_iter: 15_000
69 | densify_grad_threshold: 0.0002
70 | densify_grad_t_threshold: 0.002
71 | densify_until_num_points: 3000000
72 | sh_increase_interval: 1000
73 | scale_increase_interval: 5000
74 | prune_big_point: 1
75 | size_threshold: 20
76 | big_point_threshold: 0.1
77 | t_grad: true
78 | no_time_split: true
79 | contract: true
80 |
81 | lambda_dssim: 0.2
82 | lambda_opa: 0.0
83 | lambda_sky_opa: 0.05
84 | lambda_opacity_entropy: 0.05
85 | lambda_inv_depth: 0.001
86 | lambda_self_supervision: 0.5
87 | lambda_t_reg: 0.0
88 | lambda_v_reg: 0.0
89 | lambda_lidar: 0.1
90 | lidar_decay: 1.0
91 | lambda_v_smooth: 0.0
--------------------------------------------------------------------------------
/configs/kitti_nvs.yaml:
--------------------------------------------------------------------------------
1 | exhaust_test: false
2 |
3 |
4 | # ModelParams
5 | scene_type: "KittiMot"
6 | start_frame: 380
7 | end_frame: 431
8 | num_pts: 1000000
9 | cam_num: 3
10 | resolution_scales: [1, 2, 4, 8]
11 | eval: true
12 | t_init: 0.006
13 | fix_radius: 10.0
14 |
15 | # PipelineParams
16 |
17 | # OptimizationParams
18 | iterations: 40000
19 | densify_until_iter: 20000
20 | densify_until_num_points: 10000000
21 |
22 | opacity_lr: 0.007
23 |
24 | densification_interval: 200
25 | sh_increase_interval: 2000
26 | opacity_reset_interval: 5000
27 | densify_grad_threshold: 0.00015
28 | size_threshold: 100
29 | big_point_threshold: 0.4
30 |
31 | lidar_decay: 0.3
32 | lambda_v_reg: 0.001
33 | lambda_t_reg: 0.01
34 |
--------------------------------------------------------------------------------
/configs/kitti_reconstruction.yaml:
--------------------------------------------------------------------------------
1 | exhaust_test: false
2 |
3 |
4 | # ModelParams
5 | scene_type: "KittiMot"
6 | start_frame: 380
7 | end_frame: 431
8 | num_pts: 1000000
9 | cam_num: 3
10 | resolution_scales: [1, 2, 4, 8]
11 | eval: false
12 | t_init: 0.006
13 | fix_radius: 15.0
14 |
15 | # PipelineParams
16 |
17 | # OptimizationParams
18 | iterations: 40000
19 | densify_until_iter: 20000
20 | densify_until_num_points: 10000000
21 |
22 | opacity_lr: 0.007
23 |
24 | densification_interval: 200
25 | sh_increase_interval: 2000
26 | opacity_reset_interval: 5000
27 | densify_grad_threshold: 0.00015
28 | size_threshold: 100
29 | big_point_threshold: 0.4
30 |
31 | lidar_decay: 0.3
32 | lambda_v_reg: 0.001
33 |
--------------------------------------------------------------------------------
/configs/waymo_nvs.yaml:
--------------------------------------------------------------------------------
1 | exhaust_test: false
2 |
3 |
4 | # ModelParams
5 | scene_type: "Waymo"
6 | resolution_scales: [1, 2, 4, 8, 16]
7 | cam_num: 3
8 | eval: true
9 | num_pts: 600000
10 | t_init: 0.1
11 | separate_scaling_t: 0.2
12 |
13 | # PipelineParams
14 |
15 |
16 | # OptimizationParams
17 | iterations: 30000
18 |
19 | opacity_lr: 0.005
20 |
21 | densify_until_iter: 15000
22 | densify_grad_threshold: 0.00017
23 | sh_increase_interval: 2000
24 |
25 |
26 | lambda_v_reg: 0.01
27 |
28 |
29 |
--------------------------------------------------------------------------------
/configs/waymo_reconstruction.yaml:
--------------------------------------------------------------------------------
1 | exhaust_test: false
2 |
3 |
4 | # ModelParams
5 | scene_type: "Waymo"
6 | resolution_scales: [1, 2, 4, 8, 16]
7 | cam_num: 3
8 | eval: false
9 | num_pts: 600000
10 | t_init: 0.1
11 | separate_scaling_t: 0.2
12 |
13 | # PipelineParams
14 |
15 |
16 | # OptimizationParams
17 | iterations: 30000
18 |
19 | opacity_lr: 0.005
20 |
21 | densify_until_iter: 15000
22 | densify_grad_threshold: 0.00017
23 | sh_increase_interval: 2000
24 |
25 |
26 | lambda_v_reg: 0.01
27 |
28 |
29 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import glob
12 | import json
13 | import os
14 | import torch
15 | import torch.nn.functional as F
16 | from utils.loss_utils import psnr, ssim
17 | from gaussian_renderer import render
18 | from scene import Scene, GaussianModel, EnvLight
19 | from utils.general_utils import seed_everything, visualize_depth
20 | from tqdm import tqdm
21 | from argparse import ArgumentParser
22 | from torchvision.utils import make_grid, save_image
23 | from omegaconf import OmegaConf
24 |
25 | EPS = 1e-5
26 |
27 | @torch.no_grad()
28 | def evaluation(iteration, scene : Scene, renderFunc, renderArgs, env_map=None):
29 | from lpipsPyTorch import lpips
30 |
31 | scale = scene.resolution_scales[0]
32 | if "kitti" in args.model_path:
33 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766
34 | num = len(scene.getTrainCameras())//2
35 | eval_train_frame = num//5
36 | traincamera = sorted(scene.getTrainCameras(), key =lambda x: x.colmap_id)
37 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
38 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]})
39 | else:
40 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
41 | {'name': 'train', 'cameras': scene.getTrainCameras()})
42 |
43 | for config in validation_configs:
44 | if config['cameras'] and len(config['cameras']) > 0:
45 | l1_test = 0.0
46 | psnr_test = 0.0
47 | ssim_test = 0.0
48 | lpips_test = 0.0
49 | outdir = os.path.join(args.model_path, "eval", config['name'] + f"_{iteration}" + "_render")
50 | os.makedirs(outdir,exist_ok=True)
51 | for idx, viewpoint in enumerate(tqdm(config['cameras'])):
52 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map)
53 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
54 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
55 |
56 | depth = render_pkg['depth']
57 | alpha = render_pkg['alpha']
58 | sky_depth = 900
59 | depth = depth / alpha.clamp_min(EPS)
60 | if env_map is not None:
61 | if args.depth_blend_mode == 0: # harmonic mean
62 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS)
63 | elif args.depth_blend_mode == 1:
64 | depth = alpha * depth + (1 - alpha) * sky_depth
65 |
66 | depth = visualize_depth(depth)
67 | alpha = alpha.repeat(3, 1, 1)
68 |
69 | grid = [gt_image, image, alpha, depth]
70 | grid = make_grid(grid, nrow=2)
71 |
72 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png"))
73 |
74 | l1_test += F.l1_loss(image, gt_image).double()
75 | psnr_test += psnr(image, gt_image).double()
76 | ssim_test += ssim(image, gt_image).double()
77 | lpips_test += lpips(image, gt_image, net_type='vgg').double() # very slow
78 |
79 | psnr_test /= len(config['cameras'])
80 | l1_test /= len(config['cameras'])
81 | ssim_test /= len(config['cameras'])
82 | lpips_test /= len(config['cameras'])
83 |
84 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test))
85 | with open(os.path.join(outdir, "metrics.json"), "w") as f:
86 | json.dump({"split": config['name'], "iteration": iteration, "psnr": psnr_test.item(), "ssim": ssim_test.item(), "lpips": lpips_test.item()}, f)
87 |
88 |
89 | if __name__ == "__main__":
90 | # Set up command line argument parser
91 | parser = ArgumentParser(description="Training script parameters")
92 | parser.add_argument("--config", type=str, required=True)
93 | parser.add_argument("--base_config", type=str, default = "configs/base.yaml")
94 | args, _ = parser.parse_known_args()
95 |
96 | base_conf = OmegaConf.load(args.base_config)
97 | second_conf = OmegaConf.load(args.config)
98 | cli_conf = OmegaConf.from_cli()
99 | args = OmegaConf.merge(base_conf, second_conf, cli_conf)
100 | args.resolution_scales = args.resolution_scales[:1]
101 | print(args)
102 |
103 | seed_everything(args.seed)
104 |
105 | sep_path = os.path.join(args.model_path, 'separation')
106 | os.makedirs(sep_path, exist_ok=True)
107 |
108 | gaussians = GaussianModel(args)
109 | scene = Scene(args, gaussians, shuffle=False)
110 |
111 | if args.env_map_res > 0:
112 | env_map = EnvLight(resolution=args.env_map_res).cuda()
113 | env_map.training_setup(args)
114 | else:
115 | env_map = None
116 |
117 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
118 | assert len(checkpoints) > 0, "No checkpoints found."
119 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
120 | (model_params, first_iter) = torch.load(checkpoint)
121 | gaussians.restore(model_params, args)
122 |
123 | if env_map is not None:
124 | env_checkpoint = os.path.join(os.path.dirname(checkpoint),
125 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt"))
126 | (light_params, _) = torch.load(env_checkpoint)
127 | env_map.restore(light_params)
128 |
129 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
130 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
131 | evaluation(first_iter, scene, render, (args, background), env_map=env_map)
132 |
133 | print("Evaluation complete.")
134 |
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | from scene.gaussian_model import GaussianModel
16 | from scene.cameras import Camera
17 | from utils.sh_utils import eval_sh
18 |
19 |
20 | def render(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0,
21 | override_color=None, env_map=None,
22 | time_shift=None, other=[], mask=None, is_training=False):
23 | """
24 | Render the scene.
25 |
26 | Background tensor (bg_color) must be on GPU!
27 | """
28 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
29 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
30 | try:
31 | screenspace_points.retain_grad()
32 | except:
33 | pass
34 |
35 | # Set up rasterization configuration
36 | if pipe.neg_fov:
37 |         # we find that setting the fov to -1 slightly improves the results
38 | tanfovx = math.tan(-0.5)
39 | tanfovy = math.tan(-0.5)
40 | else:
41 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
42 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
43 |
44 | raster_settings = GaussianRasterizationSettings(
45 | image_height=int(viewpoint_camera.image_height),
46 | image_width=int(viewpoint_camera.image_width),
47 | tanfovx=tanfovx,
48 | tanfovy=tanfovy,
49 | bg=bg_color if env_map is not None else torch.zeros(3, device="cuda"),
50 | scale_modifier=scaling_modifier,
51 | viewmatrix=viewpoint_camera.world_view_transform,
52 | projmatrix=viewpoint_camera.full_proj_transform,
53 | sh_degree=pc.active_sh_degree,
54 | campos=viewpoint_camera.camera_center,
55 | prefiltered=False,
56 | debug=pipe.debug
57 | )
58 |
59 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
60 |
61 | means3D = pc.get_xyz
62 | means2D = screenspace_points
63 | opacity = pc.get_opacity
64 | scales = None
65 | rotations = None
66 | cov3D_precomp = None
67 |
68 | if time_shift is not None:
69 | means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp-time_shift)
70 | means3D = means3D + pc.get_inst_velocity * time_shift
71 | marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp-time_shift)
72 | else:
73 | means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp)
74 | marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp)
75 | opacity = opacity * marginal_t
76 |
77 | if pipe.compute_cov3D_python:
78 | cov3D_precomp = pc.get_covariance(scaling_modifier)
79 | else:
80 | scales = pc.get_scaling
81 | rotations = pc.get_rotation
82 |
83 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
84 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
85 | shs = None
86 | colors_precomp = None
87 | if override_color is None:
88 | if pipe.convert_SHs_python:
89 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, pc.get_max_sh_channels)
90 | dir_pp = (means3D.detach() - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)).detach()
91 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
92 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
93 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
94 | else:
95 | shs = pc.get_features
96 | else:
97 | colors_precomp = override_color
98 |
99 | feature_list = other
100 |
101 | if len(feature_list) > 0:
102 | features = torch.cat(feature_list, dim=1)
103 | S_other = features.shape[1]
104 | else:
105 | features = torch.zeros_like(means3D[:, :0])
106 | S_other = 0
107 |
108 | # Prefilter
109 | if mask is None:
110 | mask = marginal_t[:, 0] > 0.05
111 | else:
112 | mask = mask & (marginal_t[:, 0] > 0.05)
113 | masked_means3D = means3D[mask]
114 | masked_xyz_homo = torch.cat([masked_means3D, torch.ones_like(masked_means3D[:, :1])], dim=1)
115 | masked_depth = (masked_xyz_homo @ viewpoint_camera.world_view_transform[:, 2:3])
116 | depth_alpha = torch.zeros(means3D.shape[0], 2, dtype=torch.float32, device=means3D.device)
117 | depth_alpha[mask] = torch.cat([
118 | masked_depth,
119 | torch.ones_like(masked_depth)
120 | ], dim=1)
121 | features = torch.cat([features, depth_alpha], dim=1)
122 |
123 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
124 | contrib, rendered_image, rendered_feature, radii = rasterizer(
125 | means3D = means3D,
126 | means2D = means2D,
127 | shs = shs,
128 | colors_precomp = colors_precomp,
129 | features = features,
130 | opacities = opacity,
131 | scales = scales,
132 | rotations = rotations,
133 | cov3D_precomp = cov3D_precomp,
134 | mask = mask)
135 |
136 | rendered_other, rendered_depth, rendered_opacity = rendered_feature.split([S_other, 1, 1], dim=0)
137 | rendered_image_before = rendered_image
138 | if env_map is not None:
139 | bg_color_from_envmap = env_map(viewpoint_camera.get_world_directions(is_training).permute(1, 2, 0)).permute(2, 0, 1)
140 | rendered_image = rendered_image + (1 - rendered_opacity) * bg_color_from_envmap
141 |
142 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
143 | # They will be excluded from value updates used in the splitting criteria.
144 | return {"render": rendered_image,
145 | "render_nobg": rendered_image_before,
146 | "viewspace_points": screenspace_points,
147 | "visibility_filter": radii > 0,
148 | "radii": radii,
149 | "contrib": contrib,
150 | "depth": rendered_depth,
151 | "alpha": rendered_opacity,
152 | "feature": rendered_other}
153 |
154 |
155 |
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y).mean()
22 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/requirements-data.txt:
--------------------------------------------------------------------------------
1 | waymo-open-dataset-tf-2-4-0
2 | mmcv-full
3 |
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision==0.15.2
3 | tqdm
4 | imageio
5 | imageio[ffmpeg]
6 | kornia
7 | trimesh
8 | Pillow
9 | ninja
10 | omegaconf
11 | plyfile
12 | opencv_python
13 | opencv_contrib_python
14 | tensorboardX
15 | matplotlib
16 |
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.gaussian_model import GaussianModel
17 | from scene.envlight import EnvLight
18 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
19 | from scene.waymo_loader import readWaymoInfo
20 | from scene.kittimot_loader import readKittiMotInfo
21 |
22 | sceneLoadTypeCallbacks = {
23 | "Waymo": readWaymoInfo,
24 | "KittiMot": readKittiMotInfo,
25 | }
26 |
27 | class Scene:
28 |
29 | gaussians : GaussianModel
30 |
31 | def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True):
32 | self.model_path = args.model_path
33 | self.loaded_iter = None
34 | self.gaussians = gaussians
35 | self.white_background = args.white_background
36 |
37 | if load_iteration:
38 | if load_iteration == -1:
39 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
40 | else:
41 | self.loaded_iter = load_iteration
42 | print("Loading trained model at iteration {}".format(self.loaded_iter))
43 |
44 | self.train_cameras = {}
45 | self.test_cameras = {}
46 |
47 | scene_info = sceneLoadTypeCallbacks[args.scene_type](args)
48 |
49 | self.time_interval = args.frame_interval
50 | self.gaussians.time_duration = scene_info.time_duration
51 | print("time duration: ", scene_info.time_duration)
52 | print("frame interval: ", self.time_interval)
53 |
54 | if not self.loaded_iter:
55 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
56 | dest_file.write(src_file.read())
57 | json_cams = []
58 | camlist = []
59 | if scene_info.test_cameras:
60 | camlist.extend(scene_info.test_cameras)
61 | if scene_info.train_cameras:
62 | camlist.extend(scene_info.train_cameras)
63 | for id, cam in enumerate(camlist):
64 | json_cams.append(camera_to_JSON(id, cam))
65 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
66 | json.dump(json_cams, file)
67 |
68 | if shuffle:
69 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
70 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
71 |
72 | self.cameras_extent = scene_info.nerf_normalization["radius"]
73 | self.resolution_scales = args.resolution_scales
74 | self.scale_index = len(self.resolution_scales) - 1
75 | for resolution_scale in self.resolution_scales:
76 | print("Loading Training Cameras")
77 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
78 | print("Loading Test Cameras")
79 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
80 |
81 | if self.loaded_iter:
82 | self.gaussians.load_ply(os.path.join(self.model_path,
83 | "point_cloud",
84 | "iteration_" + str(self.loaded_iter),
85 | "point_cloud.ply"))
86 | else:
87 | self.gaussians.create_from_pcd(scene_info.point_cloud, 1)
88 |
89 | def upScale(self):
90 | self.scale_index = max(0, self.scale_index - 1)
91 |
92 | def getTrainCameras(self):
93 | return self.train_cameras[self.resolution_scales[self.scale_index]]
94 |
95 | def getTestCameras(self, scale=1.0):
96 | return self.test_cameras[scale]
97 |
98 |
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import math
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 | import numpy as np
17 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrixCenterShift
18 | import kornia
19 |
20 |
21 | class Camera(nn.Module):
22 | def __init__(self, colmap_id, R, T, FoVx=None, FoVy=None, cx=None, cy=None, fx=None, fy=None,
23 | image=None,
24 | image_name=None, uid=0,
25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", timestamp=0.0,
26 | resolution=None, image_path=None,
27 | pts_depth=None, sky_mask=None
28 | ):
29 | super(Camera, self).__init__()
30 |
31 | self.uid = uid
32 | self.colmap_id = colmap_id
33 | self.R = R
34 | self.T = T
35 | self.FoVx = FoVx
36 | self.FoVy = FoVy
37 | self.image_name = image_name
38 | self.image = image
39 | self.cx = cx
40 | self.cy = cy
41 | self.fx = fx
42 | self.fy = fy
43 | self.resolution = resolution
44 | self.image_path = image_path
45 |
46 | try:
47 | self.data_device = torch.device(data_device)
48 | except Exception as e:
49 | print(e)
50 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
51 | self.data_device = torch.device("cuda")
52 |
53 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
54 | self.sky_mask = sky_mask.to(self.data_device) > 0 if sky_mask is not None else sky_mask
55 | self.pts_depth = pts_depth.to(self.data_device) if pts_depth is not None else pts_depth
56 |
57 | self.image_width = resolution[0]
58 | self.image_height = resolution[1]
59 |
60 | self.zfar = 1000.0
61 | self.znear = 0.01
62 |
63 | self.trans = trans
64 | self.scale = scale
65 |
66 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
67 | if cx is not None:
68 | self.FoVx = 2 * math.atan(0.5*self.image_width / fx)
69 | self.FoVy = 2 * math.atan(0.5*self.image_height / fy)
70 | self.projection_matrix = getProjectionMatrixCenterShift(self.znear, self.zfar, cx, cy, fx, fy,
71 | self.image_width, self.image_height).transpose(0, 1).cuda()
72 | else:
73 | self.cx = self.image_width / 2
74 | self.cy = self.image_height / 2
75 | self.fx = self.image_width / (2 * np.tan(self.FoVx * 0.5))
76 | self.fy = self.image_height / (2 * np.tan(self.FoVy * 0.5))
77 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx,
78 | fovY=self.FoVy).transpose(0, 1).cuda()
79 | self.full_proj_transform = (
80 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
81 | self.camera_center = self.world_view_transform.inverse()[3, :3]
82 | self.c2w = self.world_view_transform.transpose(0, 1).inverse()
83 | self.timestamp = timestamp
84 | self.grid = kornia.utils.create_meshgrid(self.image_height, self.image_width, normalized_coordinates=False, device='cuda')[0]
85 |
86 | def get_world_directions(self, train=False):
87 | u, v = self.grid.unbind(-1)
88 | if train:
89 | directions = torch.stack([(u-self.cx+torch.rand_like(u))/self.fx,
90 | (v-self.cy+torch.rand_like(v))/self.fy,
91 | torch.ones_like(u)], dim=0)
92 | else:
93 | directions = torch.stack([(u-self.cx+0.5)/self.fx,
94 | (v-self.cy+0.5)/self.fy,
95 | torch.ones_like(u)], dim=0)
96 | directions = F.normalize(directions, dim=0)
97 | directions = (self.c2w[:3, :3] @ directions.reshape(3, -1)).reshape(3, self.image_height, self.image_width)
98 | return directions
99 |
100 |
--------------------------------------------------------------------------------
/scene/envlight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import nvdiffrast.torch as dr
3 |
4 |
5 | class EnvLight(torch.nn.Module):
6 |
7 | def __init__(self, resolution=1024):
8 | super().__init__()
9 | self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")
10 | self.base = torch.nn.Parameter(
11 | 0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),
12 | )
13 |
14 | def capture(self):
15 | return (
16 | self.base,
17 | self.optimizer.state_dict(),
18 | )
19 |
20 | def restore(self, model_args, training_args=None):
21 | self.base, opt_dict = model_args
22 | if training_args is not None:
23 | self.training_setup(training_args)
24 | self.optimizer.load_state_dict(opt_dict)
25 |
26 | def training_setup(self, training_args):
27 | self.optimizer = torch.optim.Adam(self.parameters(), lr=training_args.envmap_lr, eps=1e-15)
28 |
29 | def forward(self, l):
30 | l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape)
31 | l = l.contiguous()
32 | prefix = l.shape[:-1]
33 | if len(prefix) != 3: # reshape to [B, H, W, -1]
34 | l = l.reshape(1, 1, -1, l.shape[-1])
35 |
36 | light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube')
37 | light = light.view(*prefix, -1)
38 |
39 | return light
40 |
--------------------------------------------------------------------------------
/scene/gaussian_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import math
12 | import torch
13 | import numpy as np
14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation, get_step_lr_func
15 | from torch import nn
16 | import os
17 | from plyfile import PlyData, PlyElement
18 | from utils.sh_utils import RGB2SH
19 | from simple_knn._C import distCUDA2
20 | from utils.graphics_utils import BasicPointCloud
21 | from utils.general_utils import strip_symmetric, build_scaling_rotation
22 |
23 | class GaussianModel:
24 |
25 | def setup_functions(self):
26 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
27 | L = build_scaling_rotation(scaling_modifier * scaling, rotation)
28 | actual_covariance = L @ L.transpose(1, 2)
29 | symm = strip_symmetric(actual_covariance)
30 | return symm
31 |
32 | self.scaling_activation = torch.exp
33 | self.scaling_inverse_activation = torch.log
34 |
35 | self.scaling_t_activation = torch.exp
36 | self.scaling_t_inverse_activation = torch.log
37 |
38 | self.covariance_activation = build_covariance_from_scaling_rotation
39 |
40 | self.opacity_activation = torch.sigmoid
41 | self.inverse_opacity_activation = inverse_sigmoid
42 |
43 | self.rotation_activation = torch.nn.functional.normalize
44 |
45 | def __init__(self, args):
46 | self.active_sh_degree = 0
47 | self.max_sh_degree = args.sh_degree
48 | self._xyz = torch.empty(0)
49 | self._features_dc = torch.empty(0)
50 | self._features_rest = torch.empty(0)
51 | self._scaling = torch.empty(0)
52 | self._rotation = torch.empty(0)
53 | self._opacity = torch.empty(0)
54 | self._t = torch.empty(0)
55 | self._scaling_t = torch.empty(0)
56 | self._velocity = torch.empty(0)
57 |
58 | self.max_radii2D = torch.empty(0)
59 | self.xyz_gradient_accum = torch.empty(0)
60 | self.t_gradient_accum = torch.empty(0)
61 | self.denom = torch.empty(0)
62 |
63 | self.optimizer = None
64 | self.percent_dense = 0
65 | self.spatial_lr_scale = 0
66 |
67 | self.time_duration = args.time_duration
68 | self.no_time_split = args.no_time_split
69 | self.t_grad = args.t_grad
70 | self.contract = args.contract
71 | self.t_init = args.t_init
72 | self.big_point_threshold = args.big_point_threshold
73 |
74 | self.T = args.cycle
75 | self.velocity_decay = args.velocity_decay
76 | self.random_init_point = args.random_init_point
77 |
78 | self.setup_functions()
79 |
80 | def capture(self):
81 | return (
82 | self.active_sh_degree,
83 | self._xyz,
84 | self._features_dc,
85 | self._features_rest,
86 | self._scaling,
87 | self._rotation,
88 | self._opacity,
89 | self._t,
90 | self._scaling_t,
91 | self._velocity,
92 | self.max_radii2D,
93 | self.xyz_gradient_accum,
94 | self.t_gradient_accum,
95 | self.denom,
96 | self.optimizer.state_dict(),
97 | self.spatial_lr_scale,
98 | self.T,
99 | self.velocity_decay,
100 | )
101 |
102 | def restore(self, model_args, training_args=None):
103 | (self.active_sh_degree,
104 | self._xyz,
105 | self._features_dc,
106 | self._features_rest,
107 | self._scaling,
108 | self._rotation,
109 | self._opacity,
110 | self._t,
111 | self._scaling_t,
112 | self._velocity,
113 | self.max_radii2D,
114 | xyz_gradient_accum,
115 | t_gradient_accum,
116 | denom,
117 | opt_dict,
118 | self.spatial_lr_scale,
119 | self.T,
120 | self.velocity_decay,
121 | ) = model_args
122 | self.setup_functions()
123 | if training_args is not None:
124 | self.training_setup(training_args)
125 | self.xyz_gradient_accum = xyz_gradient_accum
126 | self.t_gradient_accum = t_gradient_accum
127 | self.denom = denom
128 | self.optimizer.load_state_dict(opt_dict)
129 |
130 | @property
131 | def get_scaling(self):
132 | return self.scaling_activation(self._scaling)
133 |
134 | @property
135 | def get_scaling_t(self):
136 | return self.scaling_t_activation(self._scaling_t)
137 |
138 | @property
139 | def get_rotation(self):
140 | return self.rotation_activation(self._rotation)
141 |
142 | def get_xyz_SHM(self, t):
143 |         a = 1/self.T * np.pi * 2  # angular frequency of the vibration: a = 2*pi / cycle length T
144 |         return self._xyz + self._velocity*torch.sin((t-self._t)*a)/a  # periodic vibration of the mean position over time
145 |
146 | @property
147 | def get_inst_velocity(self):
148 | return self._velocity*torch.exp(-self.get_scaling_t/self.T/2*self.velocity_decay)
149 |
150 | @property
151 | def get_xyz(self):
152 | return self._xyz
153 |
154 | @property
155 | def get_t(self):
156 | return self._t
157 |
158 | @property
159 | def get_features(self):
160 | features_dc = self._features_dc
161 | features_rest = self._features_rest
162 | return torch.cat((features_dc, features_rest), dim=1)
163 |
164 | @property
165 | def get_opacity(self):
166 | return self.opacity_activation(self._opacity)
167 |
168 | @property
169 | def get_max_sh_channels(self):
170 | return (self.max_sh_degree + 1) ** 2
171 |
172 | def get_marginal_t(self, timestamp):
173 |         return torch.exp(-0.5 * (self.get_t - timestamp) ** 2 / self.get_scaling_t ** 2)  # temporal Gaussian weight (point lifespan) at the query timestamp
174 |
175 | def get_covariance(self, scaling_modifier=1):
176 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
177 |
178 | def oneupSHdegree(self):
179 | if self.active_sh_degree < self.max_sh_degree:
180 | self.active_sh_degree += 1
181 |
182 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
183 | self.spatial_lr_scale = spatial_lr_scale
184 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
185 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
186 | features = torch.zeros((fused_color.shape[0], 3, self.get_max_sh_channels)).float().cuda()
187 | features[:, :3, 0] = fused_color
188 | features[:, 3:, 1:] = 0.0
189 |
190 | ## random up and far
191 | r_max = 100000
192 | r_min = 2
193 | num_sph = self.random_init_point
194 |
195 | theta = 2*torch.pi*torch.rand(num_sph)
196 | phi = (torch.pi/2*0.99*torch.rand(num_sph))**1.5 # x**a decay
197 | s = torch.rand(num_sph)
198 | r_1 = s*1/r_min+(1-s)*1/r_max
199 | r = 1/r_1
200 | pts_sph = torch.stack([r*torch.cos(theta)*torch.cos(phi), r*torch.sin(theta)*torch.cos(phi), r*torch.sin(phi)],dim=-1).cuda()
201 |
202 | r_rec = r_min
203 | num_rec = self.random_init_point
204 | pts_rec = torch.stack([r_rec*(torch.rand(num_rec)-0.5),r_rec*(torch.rand(num_rec)-0.5),
205 | r_rec*(torch.rand(num_rec))],dim=-1).cuda()
206 |
207 | pts_sph = torch.cat([pts_rec, pts_sph], dim=0)
208 | pts_sph[:,2] = -pts_sph[:,2]+1
209 |
210 | fused_point_cloud = torch.cat([fused_point_cloud, pts_sph], dim=0)
211 | features = torch.cat([features,
212 | torch.zeros([pts_sph.size(0), features.size(1), features.size(2)]).float().cuda()],
213 | dim=0)
214 |
215 | if pcd.time is None or pcd.time.shape[0] != fused_point_cloud.shape[0]:
216 | if pcd.time is None:
217 | time = (np.random.rand(pcd.points.shape[0], 1) * 1.2 - 0.1) * (
218 | self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
219 | else:
220 | time = pcd.time
221 |
222 | if self.t_init < 1:
223 | random_times = (torch.rand(fused_point_cloud.shape[0]-pcd.points.shape[0], 1, device="cuda") * 1.2 - 0.1) * (
224 | self.time_duration[1] - self.time_duration[0]) + self.time_duration[0]
225 | pts_times = torch.from_numpy(time.copy()).float().cuda()
226 | fused_times = torch.cat([pts_times, random_times], dim=0)
227 | else:
228 | fused_times = torch.full_like(fused_point_cloud[..., :1],
229 | 0.5 * (self.time_duration[1] + self.time_duration[0]))
230 | else:
231 | fused_times = torch.from_numpy(np.asarray(pcd.time.copy())).cuda().float()
232 | fused_times_sh = torch.full_like(pts_sph[..., :1], 0.5 * (self.time_duration[1] + self.time_duration[0]))
233 | fused_times = torch.cat([fused_times, fused_times_sh], dim=0)
234 |
235 | print("Number of points at initialization : ", fused_point_cloud.shape[0])
236 |
237 | dist2 = torch.clamp_min(distCUDA2(fused_point_cloud), 0.0000001)
238 | scales = self.scaling_inverse_activation(torch.sqrt(dist2))[..., None].repeat(1, 3)
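# Editor's note (assumption, following the original 3DGS initialization): distCUDA2 from the
# simple-knn package returns, to our understanding, the mean squared distance to a few nearest
# neighbors, so each Gaussian is initialized roughly to its local point spacing, isotropically
# repeated over the three spatial axes.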
239 |
240 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
241 | rots[:, 0] = 1
242 |
243 | dist_t = torch.full_like(fused_times, (self.time_duration[1] - self.time_duration[0])*self.t_init)
244 | scales_t = self.scaling_t_inverse_activation(torch.sqrt(dist_t))
245 | velocity = torch.full((fused_point_cloud.shape[0], 3), 0., device="cuda")
246 |
247 | opacities = inverse_sigmoid(0.01 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
248 |
249 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
250 | self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
251 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
252 | self._scaling = nn.Parameter(scales.requires_grad_(True))
253 | self._rotation = nn.Parameter(rots.requires_grad_(True))
254 | self._opacity = nn.Parameter(opacities.requires_grad_(True))
255 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
256 | self._t = nn.Parameter(fused_times.requires_grad_(True))
257 | self._scaling_t = nn.Parameter(scales_t.requires_grad_(True))
258 | self._velocity = nn.Parameter(velocity.requires_grad_(True))
259 |
260 | def training_setup(self, training_args):
261 | self.percent_dense = training_args.percent_dense
262 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
263 | self.t_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
264 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
265 |
266 | l = [
267 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
268 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
269 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
270 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
271 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
272 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
273 | {'params': [self._t], 'lr': training_args.t_lr_init, "name": "t"},
274 | {'params': [self._scaling_t], 'lr': training_args.scaling_t_lr, "name": "scaling_t"},
275 | {'params': [self._velocity], 'lr': training_args.velocity_lr * self.spatial_lr_scale, "name": "velocity"},
276 | ]
277 |
278 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
279 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
280 | lr_final=training_args.position_lr_final * self.spatial_lr_scale,
281 | lr_delay_mult=training_args.position_lr_delay_mult,
282 | max_steps=training_args.iterations)
283 |
284 | final_decay = training_args.position_lr_final / training_args.position_lr_init
285 |
286 | self.t_scheduler_args = get_expon_lr_func(lr_init=training_args.t_lr_init,
287 | lr_final=training_args.t_lr_init * final_decay,
288 | lr_delay_mult=training_args.position_lr_delay_mult,
289 | max_steps=training_args.iterations)
290 |
291 | def update_learning_rate(self, iteration):
292 | ''' Learning rate scheduling per step '''
293 | for param_group in self.optimizer.param_groups:
294 | if param_group["name"] == "xyz":
295 | lr = self.xyz_scheduler_args(iteration)
296 | param_group['lr'] = lr
297 | if param_group["name"] == "t":
298 | lr = self.t_scheduler_args(iteration)
299 | param_group['lr'] = lr
300 |
301 | def reset_opacity(self):
302 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01))
303 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
304 | self._opacity = optimizable_tensors["opacity"]
305 |
306 | def replace_tensor_to_optimizer(self, tensor, name):
307 | optimizable_tensors = {}
308 | for group in self.optimizer.param_groups:
309 | if group["name"] == name:
310 | stored_state = self.optimizer.state.get(group['params'][0], None)
311 | stored_state["exp_avg"] = torch.zeros_like(tensor)
312 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
313 |
314 | del self.optimizer.state[group['params'][0]]
315 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
316 | self.optimizer.state[group['params'][0]] = stored_state
317 |
318 | optimizable_tensors[group["name"]] = group["params"][0]
319 | return optimizable_tensors
320 |
321 | def _prune_optimizer(self, mask):
322 | optimizable_tensors = {}
323 | for group in self.optimizer.param_groups:
324 | stored_state = self.optimizer.state.get(group['params'][0], None)
325 | if stored_state is not None:
326 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
327 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
328 |
329 | del self.optimizer.state[group['params'][0]]
330 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
331 | self.optimizer.state[group['params'][0]] = stored_state
332 |
333 | optimizable_tensors[group["name"]] = group["params"][0]
334 | else:
335 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
336 | optimizable_tensors[group["name"]] = group["params"][0]
337 | return optimizable_tensors
338 |
339 | def prune_points(self, mask):
340 | valid_points_mask = ~mask
341 | optimizable_tensors = self._prune_optimizer(valid_points_mask)
342 |
343 | self._xyz = optimizable_tensors["xyz"]
344 | self._features_dc = optimizable_tensors["f_dc"]
345 | self._features_rest = optimizable_tensors["f_rest"]
346 | self._opacity = optimizable_tensors["opacity"]
347 | self._scaling = optimizable_tensors["scaling"]
348 | self._rotation = optimizable_tensors["rotation"]
349 |
350 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
351 |
352 | self.denom = self.denom[valid_points_mask]
353 | self.max_radii2D = self.max_radii2D[valid_points_mask]
354 |
355 | self._t = optimizable_tensors['t']
356 | self._scaling_t = optimizable_tensors['scaling_t']
357 | self._velocity = optimizable_tensors['velocity']
358 | self.t_gradient_accum = self.t_gradient_accum[valid_points_mask]
359 |
360 | def cat_tensors_to_optimizer(self, tensors_dict):
361 | optimizable_tensors = {}
362 | for group in self.optimizer.param_groups:
363 | assert len(group["params"]) == 1
364 | extension_tensor = tensors_dict[group["name"]]
365 | stored_state = self.optimizer.state.get(group['params'][0], None)
366 | if stored_state is not None:
367 |
368 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)),
369 | dim=0)
370 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)),
371 | dim=0)
372 |
373 | del self.optimizer.state[group['params'][0]]
374 | group["params"][0] = nn.Parameter(
375 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
376 | self.optimizer.state[group['params'][0]] = stored_state
377 |
378 | optimizable_tensors[group["name"]] = group["params"][0]
379 | else:
380 | group["params"][0] = nn.Parameter(
381 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
382 | optimizable_tensors[group["name"]] = group["params"][0]
383 |
384 | return optimizable_tensors
385 |
386 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
387 | new_rotation, new_t, new_scaling_t, new_velocity):
388 | d = {"xyz": new_xyz,
389 | "f_dc": new_features_dc,
390 | "f_rest": new_features_rest,
391 | "opacity": new_opacities,
392 | "scaling": new_scaling,
393 | "rotation": new_rotation,
394 | "t": new_t,
395 | "scaling_t": new_scaling_t,
396 | "velocity": new_velocity,
397 | }
398 |
399 | optimizable_tensors = self.cat_tensors_to_optimizer(d)
400 | self._xyz = optimizable_tensors["xyz"]
401 | self._features_dc = optimizable_tensors["f_dc"]
402 | self._features_rest = optimizable_tensors["f_rest"]
403 | self._opacity = optimizable_tensors["opacity"]
404 | self._scaling = optimizable_tensors["scaling"]
405 | self._rotation = optimizable_tensors["rotation"]
406 | self._t = optimizable_tensors['t']
407 | self._scaling_t = optimizable_tensors['scaling_t']
408 | self._velocity = optimizable_tensors['velocity']
409 | self.t_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
410 |
411 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
412 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
413 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
414 |
415 | def densify_and_split(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, N=2, time_split=False,
416 | joint_sample=True):
417 | n_init_points = self.get_xyz.shape[0]
418 | # Extract points that satisfy the gradient condition
419 | padded_grad = torch.zeros((n_init_points), device="cuda")
420 | padded_grad[:grads.shape[0]] = grads.squeeze()
421 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
422 |
423 | if self.contract:
424 | scale_factor = self._xyz.norm(dim=-1)*scene_extent-1 # -0
425 | scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/scene_extent
426 | else:
427 | scale_factor = torch.ones_like(self._xyz)[:,0]/scene_extent
428 |
429 | selected_pts_mask = torch.logical_and(selected_pts_mask,
430 | torch.max(self.get_scaling,
431 | dim=1).values > self.percent_dense * scene_extent*scale_factor)
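# Editor's note: with contraction enabled, the split-size threshold
# percent_dense * scene_extent * scale_factor grows with a point's distance from the origin
# (scale_factor is at least 1/scene_extent), so distant, contracted Gaussians may stay larger
# before being split; otherwise the threshold reduces to the plain percent_dense criterion.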
432 | decay_factor = N*0.8
433 | if not self.no_time_split:
434 | N = N+1
435 |
436 | if time_split:
437 | padded_grad_t = torch.zeros((n_init_points), device="cuda")
438 | padded_grad_t[:grads_t.shape[0]] = grads_t.squeeze()
439 | selected_time_mask = torch.where(padded_grad_t >= grad_t_threshold, True, False)
440 | extend_thresh = self.percent_dense
441 |
442 | selected_time_mask = torch.logical_and(selected_time_mask,
443 | torch.max(self.get_scaling_t, dim=1).values > extend_thresh)
444 | if joint_sample:
445 | selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)
446 |
447 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (decay_factor))
448 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
449 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
450 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
451 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
452 |
453 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
454 | means = torch.zeros((stds.size(0), 3), device="cuda")
455 | samples = torch.normal(mean=means, std=stds)
456 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
457 | xyz = self.get_xyz[selected_pts_mask]
458 |
459 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + xyz.repeat(N, 1)
460 |
461 | new_t = None
462 | new_scaling_t = None
463 | new_velocity = None
464 | stds_t = self.get_scaling_t[selected_pts_mask].repeat(N, 1)
465 | means_t = torch.zeros((stds_t.size(0), 1), device="cuda")
466 | samples_t = torch.normal(mean=means_t, std=stds_t)
467 | new_t = samples_t+self.get_t[selected_pts_mask].repeat(N, 1)
468 |
469 | new_scaling_t = self.scaling_t_inverse_activation(
470 | self.get_scaling_t[selected_pts_mask].repeat(N, 1)/ (decay_factor))
471 |
472 |
473 |
474 | new_velocity = self._velocity[selected_pts_mask].repeat(N, 1)
475 |
476 | new_xyz = new_xyz + self.get_inst_velocity[selected_pts_mask].repeat(N, 1) * (samples_t)
477 |
478 | not_split_xyz_mask = torch.max(self.get_scaling[selected_pts_mask], dim=1).values < \
479 | self.percent_dense * scene_extent*scale_factor[selected_pts_mask]
480 | new_scaling[not_split_xyz_mask.repeat(N)] = self.scaling_inverse_activation(
481 | self.get_scaling[selected_pts_mask].repeat(N, 1))[not_split_xyz_mask.repeat(N)]
482 |
483 | if time_split:
484 | not_split_t_mask = self.get_scaling_t[selected_pts_mask].squeeze() < extend_thresh
485 | new_scaling_t[not_split_t_mask.repeat(N)] = self.scaling_t_inverse_activation(
486 | self.get_scaling_t[selected_pts_mask].repeat(N, 1))[not_split_t_mask.repeat(N)]
487 |
488 | if self.no_time_split:
489 | new_scaling_t = self.scaling_t_inverse_activation(
490 | self.get_scaling_t[selected_pts_mask].repeat(N, 1))
491 |
492 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation,
493 | new_t, new_scaling_t, new_velocity)
494 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
495 | self.prune_points(prune_filter)
496 |
497 | def densify_and_clone(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, time_clone=False):
498 | t_scale_factor=self.get_scaling_t.clamp(0,self.T)
499 | t_scale_factor=torch.exp(-t_scale_factor/self.T).squeeze()
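# Editor's note: t_scale_factor is computed here but does not appear to be used anywhere
# else in this method as written.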
500 |
501 | if self.contract:
502 | scale_factor = self._xyz.norm(dim=-1)*scene_extent-1
503 | scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/scene_extent
504 | else:
505 | scale_factor = torch.ones_like(self._xyz)[:,0]/scene_extent
506 |
507 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
508 | selected_pts_mask = torch.logical_and(selected_pts_mask,torch.max(self.get_scaling,dim=1).values <= self.percent_dense * scene_extent*scale_factor)
509 | if time_clone:
510 | selected_time_mask = torch.where(torch.norm(grads_t, dim=-1) >= grad_t_threshold, True, False)
511 | extend_thresh = self.percent_dense
512 | selected_time_mask = torch.logical_and(selected_time_mask,
513 | torch.max(self.get_scaling_t, dim=1).values <= extend_thresh)
514 | selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)
515 |
516 | new_xyz = self._xyz[selected_pts_mask]
517 | new_features_dc = self._features_dc[selected_pts_mask]
518 | new_features_rest = self._features_rest[selected_pts_mask]
519 | new_opacities = self._opacity[selected_pts_mask]
520 | new_scaling = self._scaling[selected_pts_mask]
521 | new_rotation = self._rotation[selected_pts_mask]
522 | new_t = None
523 | new_scaling_t = None
524 | new_velocity = None
525 | new_t = self._t[selected_pts_mask]
526 | new_scaling_t = self._scaling_t[selected_pts_mask]
527 | new_velocity = self._velocity[selected_pts_mask]
528 |
529 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
530 | new_rotation, new_t, new_scaling_t, new_velocity)
531 |
532 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, max_grad_t=None, prune_only=False):
533 | if not prune_only:
534 | grads = self.xyz_gradient_accum / self.denom
535 | grads[grads.isnan()] = 0.0
536 | grads_t = self.t_gradient_accum / self.denom
537 | grads_t[grads_t.isnan()] = 0.0
538 |
539 | if self.t_grad:
540 | self.densify_and_clone(grads, max_grad, extent, grads_t, max_grad_t, time_clone=True)
541 | self.densify_and_split(grads, max_grad, extent, grads_t, max_grad_t, time_split=True)
542 | else:
543 | self.densify_and_clone(grads, max_grad, extent, grads_t, max_grad_t, time_clone=False)
544 | self.densify_and_split(grads, max_grad, extent, grads_t, max_grad_t, time_split=False)
545 |
546 | prune_mask = (self.get_opacity < min_opacity).squeeze()
547 |
548 | if self.contract:
549 | scale_factor = self._xyz.norm(dim=-1)*extent-1
550 | scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/extent
551 | else:
552 | scale_factor = torch.ones_like(self._xyz)[:,0]/extent
553 |
554 | if max_screen_size:
555 | big_points_vs = self.max_radii2D > max_screen_size
556 | big_points_ws = self.get_scaling.max(dim=1).values > self.big_point_threshold * extent * scale_factor ## originally 0.1
557 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
558 | self.prune_points(prune_mask)
559 |
560 | torch.cuda.empty_cache()
561 |
562 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
563 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1,
564 | keepdim=True)
565 | self.denom[update_filter] += 1
566 | self.t_gradient_accum[update_filter] += self._t.grad.clone()[update_filter]
567 |
--------------------------------------------------------------------------------
/scene/kittimot_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 |
4 | import imageio
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from PIL import Image
9 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
10 | from pathlib import Path
11 | camera_ls = [2, 3]
12 |
13 | """
14 | Most functions are adapted from MARS
15 | https://github.com/OPEN-AIR-SUN/mars/blob/69b9bf9d992e6b9f4027dfdc2a741c2a33eef174/mars/data/mars_kitti_dataparser.py
16 | """
17 |
18 | def pad_poses(p):
19 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
20 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
21 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
22 |
23 |
24 | def unpad_poses(p):
25 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
26 | return p[..., :3, :4]
27 |
28 |
29 | def transform_poses_pca(poses, fix_radius=0):
30 | """Transforms poses so principal components lie on XYZ axes.
31 |
32 | Args:
33 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
34 |
35 | Returns:
36 | A tuple (poses_recentered, transform, scale_factor): the transformed poses, the applied
37 | world-space transform, and the scale factor used to normalize the scene.
38 |
39 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
40 | """
41 | t = poses[:, :3, 3]
42 | t_mean = t.mean(axis=0)
43 | t = t - t_mean
44 |
45 | eigval, eigvec = np.linalg.eig(t.T @ t)
46 | # Sort eigenvectors in order of largest to smallest eigenvalue.
47 | inds = np.argsort(eigval)[::-1]
48 | eigvec = eigvec[:, inds]
49 | rot = eigvec.T
50 | if np.linalg.det(rot) < 0:
51 | rot = np.diag(np.array([1, 1, -1])) @ rot
52 |
53 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
54 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
55 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
56 |
57 | # Flip coordinate system if z component of y-axis is negative
58 | if poses_recentered.mean(axis=0)[2, 1] < 0:
59 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
60 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
61 |
62 | # Just make sure it's in the [-1, 1]^3 cube
63 | if fix_radius>0:
64 | scale_factor = 1./fix_radius
65 | else:
66 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
67 | scale_factor = min(1 / 10, scale_factor)
68 |
69 | poses_recentered[:, :3, 3] *= scale_factor
70 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
71 |
72 | return poses_recentered, transform, scale_factor
73 |
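# Editor's sketch (not from the original repo) of how transform_poses_pca is consumed later
# in this file: the returned 'transform' carries world-space points into the same
# recentered/rescaled frame as the poses, e.g.
#   c2ws, transform, scale = transform_poses_pca(c2ws, fix_radius=0)
#   pts_new = (np.pad(pts, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3]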
74 | def kitti_string_to_float(str):
75 | return float(str.split("e")[0]) * 10 ** int(str.split("e")[1])
76 |
77 |
78 | def get_rotation(roll, pitch, heading):
79 | s_heading = np.sin(heading)
80 | c_heading = np.cos(heading)
81 | rot_z = np.array([[c_heading, -s_heading, 0], [s_heading, c_heading, 0], [0, 0, 1]])
82 |
83 | s_pitch = np.sin(pitch)
84 | c_pitch = np.cos(pitch)
85 | rot_y = np.array([[c_pitch, 0, s_pitch], [0, 1, 0], [-s_pitch, 0, c_pitch]])
86 |
87 | s_roll = np.sin(roll)
88 | c_roll = np.cos(roll)
89 | rot_x = np.array([[1, 0, 0], [0, c_roll, -s_roll], [0, s_roll, c_roll]])
90 |
91 | rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))
92 |
93 | return rot
94 |
95 |
96 | def tracking_calib_from_txt(calibration_path):
97 | """
98 | Extract tracking calibration information from a KITTI tracking calibration file.
99 |
100 | This function reads a KITTI tracking calibration file and extracts the relevant
101 | calibration information, including projection matrices and transformation matrices
102 | for camera, LiDAR, and IMU coordinate systems.
103 |
104 | Args:
105 | calibration_path (str): Path to the KITTI tracking calibration file.
106 |
107 | Returns:
108 | dict: A dictionary containing the following calibration information:
109 | P0, P1, P2, P3 (np.array): 3x4 projection matrices for the cameras.
110 | Tr_cam2camrect (np.array): 4x4 transformation matrix from camera to rectified camera coordinates.
111 | Tr_velo2cam (np.array): 4x4 transformation matrix from LiDAR to camera coordinates.
112 | Tr_imu2velo (np.array): 4x4 transformation matrix from IMU to LiDAR coordinates.
113 | """
114 | # Read the calibration file
115 | f = open(calibration_path)
116 | calib_str = f.read().splitlines()
117 |
118 | # Process the calibration data
119 | calibs = []
120 | for calibration in calib_str:
121 | calibs.append(np.array([kitti_string_to_float(val) for val in calibration.split()[1:]]))
122 |
123 | # Extract the projection matrices
124 | P0 = np.reshape(calibs[0], [3, 4])
125 | P1 = np.reshape(calibs[1], [3, 4])
126 | P2 = np.reshape(calibs[2], [3, 4])
127 | P3 = np.reshape(calibs[3], [3, 4])
128 |
129 | # Extract the transformation matrix for camera to rectified camera coordinates
130 | Tr_cam2camrect = np.eye(4)
131 | R_rect = np.reshape(calibs[4], [3, 3])
132 | Tr_cam2camrect[:3, :3] = R_rect
133 |
134 | # Extract the transformation matrices for LiDAR to camera and IMU to LiDAR coordinates
135 | Tr_velo2cam = np.concatenate([np.reshape(calibs[5], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
136 | Tr_imu2velo = np.concatenate([np.reshape(calibs[6], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
137 |
138 | return {
139 | "P0": P0,
140 | "P1": P1,
141 | "P2": P2,
142 | "P3": P3,
143 | "Tr_cam2camrect": Tr_cam2camrect,
144 | "Tr_velo2cam": Tr_velo2cam,
145 | "Tr_imu2velo": Tr_imu2velo,
146 | }
147 |
148 |
149 | def calib_from_txt(calibration_path):
150 | """
151 | Read the calibration files and extract the required transformation matrices and focal length.
152 |
153 | Args:
154 | calibration_path (str): The path to the directory containing the calibration files.
155 |
156 | Returns:
157 | tuple: A tuple containing the following elements:
158 | traimu2v (np.array): 4x4 transformation matrix from IMU to Velodyne coordinates.
159 | v2c (np.array): 4x4 transformation matrix from Velodyne to left camera coordinates.
160 | c2leftRGB (np.array): 4x4 transformation matrix from left camera to rectified left camera coordinates.
161 | c2rightRGB (np.array): 4x4 transformation matrix from right camera to rectified right camera coordinates.
162 | focal (float): Focal length of the left camera.
163 | """
164 | c2c = []
165 |
166 | # Read and parse the camera-to-camera calibration file
167 | f = open(os.path.join(calibration_path, "calib_cam_to_cam.txt"), "r")
168 | cam_to_cam_str = f.read()
169 | [left_cam, right_cam] = cam_to_cam_str.split("S_02: ")[1].split("S_03: ")
170 | cam_to_cam_ls = [left_cam, right_cam]
171 |
172 | # Extract the transformation matrices for left and right cameras
173 | for i, cam_str in enumerate(cam_to_cam_ls):
174 | r_str, t_str = cam_str.split("R_0" + str(i + 2) + ": ")[1].split("\nT_0" + str(i + 2) + ": ")
175 | t_str = t_str.split("\n")[0]
176 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
177 | R = np.reshape(R, [3, 3])
178 | t = np.array([kitti_string_to_float(t) for t in t_str.split(" ")])
179 | Tr = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])
180 |
181 | t_str_rect, s_rect_part = cam_str.split("\nT_0" + str(i + 2) + ": ")[1].split("\nS_rect_0" + str(i + 2) + ": ")
182 | s_rect_str, r_rect_part = s_rect_part.split("\nR_rect_0" + str(i + 2) + ": ")
183 | r_rect_str = r_rect_part.split("\nP_rect_0" + str(i + 2) + ": ")[0]
184 | R_rect = np.array([kitti_string_to_float(r) for r in r_rect_str.split(" ")])
185 | R_rect = np.reshape(R_rect, [3, 3])
186 | t_rect = np.array([kitti_string_to_float(t) for t in t_str_rect.split(" ")])
187 | Tr_rect = np.concatenate(
188 | [np.concatenate([R_rect, t_rect[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]
189 | )
190 |
191 | c2c.append(Tr_rect)
192 |
193 | c2leftRGB = c2c[0]
194 | c2rightRGB = c2c[1]
195 |
196 | # Read and parse the Velodyne-to-camera calibration file
197 | f = open(os.path.join(calibration_path, "calib_velo_to_cam.txt"), "r")
198 | velo_to_cam_str = f.read()
199 | r_str, t_str = velo_to_cam_str.split("R: ")[1].split("\nT: ")
200 | t_str = t_str.split("\n")[0]
201 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
202 | R = np.reshape(R, [3, 3])
203 | t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
204 | v2c = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])
205 |
206 | # Read and parse the IMU-to-Velodyne calibration file
207 | f = open(os.path.join(calibration_path, "calib_imu_to_velo.txt"), "r")
208 | imu_to_velo_str = f.read()
209 | r_str, t_str = imu_to_velo_str.split("R: ")[1].split("\nT: ")
210 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
211 | R = np.reshape(R, [3, 3])
212 | t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
213 | imu2v = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])
214 |
215 | # Extract the focal length of the left camera
216 | focal = kitti_string_to_float(left_cam.split("P_rect_02: ")[1].split()[0])
217 |
218 | return imu2v, v2c, c2leftRGB, c2rightRGB, focal
219 |
220 |
221 | def get_poses_calibration(basedir, oxts_path_tracking=None, selected_frames=None):
222 | """
223 | Extract poses and calibration information from the KITTI dataset.
224 |
225 | This function processes the OXTS data (GPS/IMU) and extracts the
226 | pose information (translation and rotation) for each frame. It also
227 | retrieves the calibration information (transformation matrices and focal length)
228 | required for further processing.
229 |
230 | Args:
231 | basedir (str): The base directory containing the KITTI dataset.
232 | oxts_path_tracking (str, optional): Path to the OXTS data file for tracking sequences.
233 | If not provided, the function will look for OXTS data in the basedir.
234 | selected_frames (list, optional): A list of frame indices to process.
235 | If not provided, all frames in the dataset will be processed.
236 |
237 | Returns:
238 | tuple: A tuple containing the following elements:
239 | poses (np.array): An array of 4x4 pose matrices representing the vehicle's
240 | position and orientation for each frame (IMU pose).
241 | calibrations (dict): A dictionary containing the transformation matrices
242 | and focal length obtained from the calibration files.
243 | focal (float): The focal length of the left camera.
244 | """
245 |
246 | def oxts_to_pose(oxts):
247 | """
248 | OXTS (Oxford Technical Solutions) data typically refers to the data generated by an Inertial and GPS Navigation System (INS/GPS) that is used to provide accurate position, orientation, and velocity information for a moving platform, such as a vehicle. In the context of the KITTI dataset, OXTS data is used to provide the ground truth for the vehicle's trajectory and 6 degrees of freedom (6-DoF) motion, which is essential for evaluating and benchmarking various computer vision and robotics algorithms, such as visual odometry, SLAM, and object detection.
249 |
250 | The OXTS data contains several important measurements:
251 |
252 | 1. Latitude, longitude, and altitude: These are the global coordinates of the moving platform.
253 | 2. Roll, pitch, and yaw (heading): These are the orientation angles of the platform, usually given in Euler angles.
254 | 3. Velocity (north, east, and down): These are the linear velocities of the platform in the local navigation frame.
255 | 4. Accelerations (ax, ay, az): These are the linear accelerations in the platform's body frame.
256 | 5. Angular rates (wx, wy, wz): These are the angular rates (also known as angular velocities) of the platform in its body frame.
257 |
258 | In the KITTI dataset, the OXTS data is stored as plain text files with each line corresponding to a timestamp. Each line in the file contains the aforementioned measurements, which are used to compute the ground truth trajectory and 6-DoF motion of the vehicle. This information can be further used for calibration, data synchronization, and performance evaluation of various algorithms.
259 | """
260 | poses = []
261 |
262 | def latlon_to_mercator(lat, lon, s):
263 | """
264 | Converts latitude and longitude coordinates to Mercator coordinates (x, y) using the given scale factor.
265 |
266 | The Mercator projection is a widely used cylindrical map projection that represents the Earth's surface
267 | as a flat, rectangular grid, distorting the size of geographical features in higher latitudes.
268 | This function uses the scale factor 's' to control the amount of distortion in the projection.
269 |
270 | Args:
271 | lat (float): Latitude in degrees, range: -90 to 90.
272 | lon (float): Longitude in degrees, range: -180 to 180.
273 | s (float): Scale factor, typically the cosine of the reference latitude.
274 |
275 | Returns:
276 | list: A list containing the Mercator coordinates [x, y] in meters.
277 | """
278 | r = 6378137.0 # the Earth's equatorial radius in meters
279 | x = s * r * ((np.pi * lon) / 180)
280 | y = s * r * np.log(np.tan((np.pi * (90 + lat)) / 360))
281 | return [x, y]
282 |
283 | # Compute the initial scale and pose based on the selected frames
284 | if selected_frames is None:
285 | lat0 = oxts[0][0]
286 | scale = np.cos(lat0 * np.pi / 180)
287 | pose_0_inv = None
288 | else:
289 | oxts0 = oxts[selected_frames[0][0]]
290 | lat0 = oxts0[0]
291 | scale = np.cos(lat0 * np.pi / 180)
292 |
293 | pose_i = np.eye(4)
294 |
295 | [x, y] = latlon_to_mercator(oxts0[0], oxts0[1], scale)
296 | z = oxts0[2]
297 | translation = np.array([x, y, z])
298 | rotation = get_rotation(oxts0[3], oxts0[4], oxts0[5])
299 | pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1)
300 | pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])
301 |
302 | # Iterate through the OXTS data and compute the corresponding pose matrices
303 | for oxts_val in oxts:
304 | pose_i = np.zeros([4, 4])
305 | pose_i[3, 3] = 1
306 |
307 | [x, y] = latlon_to_mercator(oxts_val[0], oxts_val[1], scale)
308 | z = oxts_val[2]
309 | translation = np.array([x, y, z])
310 |
311 | roll = oxts_val[3]
312 | pitch = oxts_val[4]
313 | heading = oxts_val[5]
314 | rotation = get_rotation(roll, pitch, heading) # (3,3)
315 |
316 | pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1) # (4, 4)
317 | if pose_0_inv is None:
318 | pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])
319 |
320 | pose_i = np.matmul(pose_0_inv, pose_i)
321 | poses.append(pose_i)
322 |
323 | return np.array(poses)
324 |
325 | # If there is no tracking path specified, use the default path
326 | if oxts_path_tracking is None:
327 | oxts_path = os.path.join(basedir, "oxts/data")
328 | oxts = np.array([np.loadtxt(os.path.join(oxts_path, file)) for file in sorted(os.listdir(oxts_path))])
329 | calibration_path = os.path.dirname(basedir)
330 |
331 | calibrations = calib_from_txt(calibration_path)
332 |
333 | focal = calibrations[4]
334 |
335 | poses = oxts_to_pose(oxts)
336 |
337 | # If a tracking path is specified, use it to load OXTS data and compute the poses
338 | else:
339 | oxts_tracking = np.loadtxt(oxts_path_tracking)
340 | poses = oxts_to_pose(oxts_tracking) # (n_frames, 4, 4)
341 | calibrations = None
342 | focal = None
343 | # Set velodyne close to z = 0
344 | # poses[:, 2, 3] -= 0.8
345 |
346 | # Return the poses, calibrations, and focal length
347 | return poses, calibrations, focal
348 |
349 |
350 | def invert_transformation(rot, t):
351 | t = np.matmul(-rot.T, t)
352 | inv_translation = np.concatenate([rot.T, t[:, None]], axis=1)
353 | return np.concatenate([inv_translation, np.array([[0.0, 0.0, 0.0, 1.0]])])
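# Editor's note: this is the closed-form inverse of a rigid transform,
# inv([R | t]) = [R^T | -R^T t], returned as a full 4x4 homogeneous matrix.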
354 |
355 |
356 | def get_camera_poses_tracking(poses_velo_w_tracking, tracking_calibration, selected_frames, scene_no=None):
357 | exp = False
358 | camera_poses = []
359 |
360 | opengl2kitti = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
361 |
362 | start_frame = selected_frames[0]
363 | end_frame = selected_frames[1]
364 |
365 | #####################
366 | # Debug Camera offset
367 | if scene_no == 2:
368 | yaw = np.deg2rad(0.7) ## Affects camera rig roll: High --> counterclockwise
369 | pitch = np.deg2rad(-0.5) ## Affects camera rig yaw: High --> Turn Right
370 | # pitch = np.deg2rad(-0.97)
371 | roll = np.deg2rad(0.9) ## Affects camera rig pitch: High --> up
372 | # roll = np.deg2rad(1.2)
373 | elif scene_no == 1:
374 | if exp:
375 | yaw = np.deg2rad(0.3) ## Affects camera rig roll: High --> counterclockwise
376 | pitch = np.deg2rad(-0.6) ## Affects camera rig yaw: High --> Turn Right
377 | # pitch = np.deg2rad(-0.97)
378 | roll = np.deg2rad(0.75) ## Affects camera rig pitch: High --> up
379 | # roll = np.deg2rad(1.2)
380 | else:
381 | yaw = np.deg2rad(0.5) ## Affects camera rig roll: High --> counterclockwise
382 | pitch = np.deg2rad(-0.5) ## Affects camera rig yaw: High --> Turn Right
383 | roll = np.deg2rad(0.75) ## Affects camera rig pitch: High --> up
384 | else:
385 | yaw = np.deg2rad(0.05)
386 | pitch = np.deg2rad(-0.75)
387 | # pitch = np.deg2rad(-0.97)
388 | roll = np.deg2rad(1.05)
389 | # roll = np.deg2rad(1.2)
390 |
391 | cam_debug = np.eye(4)
392 | cam_debug[:3, :3] = get_rotation(roll, pitch, yaw)
393 |
394 | Tr_cam2camrect = tracking_calibration["Tr_cam2camrect"]
395 | Tr_cam2camrect = np.matmul(Tr_cam2camrect, cam_debug)
396 | Tr_camrect2cam = invert_transformation(Tr_cam2camrect[:3, :3], Tr_cam2camrect[:3, 3])
397 | Tr_velo2cam = tracking_calibration["Tr_velo2cam"]
398 | Tr_cam2velo = invert_transformation(Tr_velo2cam[:3, :3], Tr_velo2cam[:3, 3])
399 |
400 | camera_poses_imu = []
401 | for cam in camera_ls:
402 | Tr_camrect2cam_i = tracking_calibration["Tr_camrect2cam0" + str(cam)]
403 | Tr_cam_i2camrect = invert_transformation(Tr_camrect2cam_i[:3, :3], Tr_camrect2cam_i[:3, 3])
404 | # transform camera axis from kitti to opengl for nerf:
405 | cam_i_camrect = np.matmul(Tr_cam_i2camrect, opengl2kitti)
406 | cam_i_cam0 = np.matmul(Tr_camrect2cam, cam_i_camrect)
407 | cam_i_velo = np.matmul(Tr_cam2velo, cam_i_cam0)
408 |
409 | cam_i_w = np.matmul(poses_velo_w_tracking, cam_i_velo)
410 | camera_poses_imu.append(cam_i_w)
411 |
412 | for i, cam in enumerate(camera_ls):
413 | for frame_no in range(start_frame, end_frame + 1):
414 | camera_poses.append(camera_poses_imu[i][frame_no])
415 |
416 | return np.array(camera_poses)
417 |
418 |
419 | def get_scene_images_tracking(tracking_path, sequence, selected_frames):
420 | [start_frame, end_frame] = selected_frames
421 | img_name = []
422 | sky_name = []
423 |
424 | left_img_path = os.path.join(os.path.join(tracking_path, "image_02"), sequence)
425 | right_img_path = os.path.join(os.path.join(tracking_path, "image_03"), sequence)
426 |
427 | left_sky_path = os.path.join(os.path.join(tracking_path, "sky_02"), sequence)
428 | right_sky_path = os.path.join(os.path.join(tracking_path, "sky_03"), sequence)
429 |
430 | for frame_dir in [left_img_path, right_img_path]:
431 | for frame_no in range(len(os.listdir(left_img_path))):
432 | if start_frame <= frame_no <= end_frame:
433 | frame = sorted(os.listdir(frame_dir))[frame_no]
434 | fname = os.path.join(frame_dir, frame)
435 | img_name.append(fname)
436 |
437 | for frame_dir in [left_sky_path, right_sky_path]:
438 | for frame_no in range(len(os.listdir(left_sky_path))):
439 | if start_frame <= frame_no <= end_frame:
440 | frame = sorted(os.listdir(frame_dir))[frame_no]
441 | fname = os.path.join(frame_dir, frame)
442 | sky_name.append(fname)
443 |
444 | return img_name, sky_name
445 |
446 | def rotation_matrix(a, b):
447 | """Compute the rotation matrix that rotates vector a to vector b.
448 |
449 | Args:
450 | a: The vector to rotate.
451 | b: The vector to rotate to.
452 | Returns:
453 | The rotation matrix.
454 | """
455 | a = a / torch.linalg.norm(a)
456 | b = b / torch.linalg.norm(b)
457 | v = torch.cross(a, b)
458 | c = torch.dot(a, b)
459 | # If vectors are exactly opposite, we add a little noise to one of them
460 | if c < -1 + 1e-8:
461 | eps = (torch.rand(3) - 0.5) * 0.01
462 | return rotation_matrix(a + eps, b)
463 | s = torch.linalg.norm(v)
464 | skew_sym_mat = torch.Tensor(
465 | [
466 | [0, -v[2], v[1]],
467 | [v[2], 0, -v[0]],
468 | [-v[1], v[0], 0],
469 | ]
470 | )
471 | return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))
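# Editor's note: the return value is the standard Rodrigues-style rotation taking a to b,
# R = I + [v]_x + [v]_x^2 * (1 - c) / s^2, with v = a x b, c = a.b and s = |v|;
# the 1e-8 term guards against division by zero when a and b are already parallel.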
472 |
473 | def auto_orient_and_center_poses(
474 | poses,
475 | ):
476 | """
477 | From nerfstudio
478 | https://github.com/nerfstudio-project/nerfstudio/blob/8e0c68754b2c440e2d83864fac586cddcac52dc4/nerfstudio/cameras/camera_utils.py#L515
479 | """
480 | origins = poses[..., :3, 3]
481 | mean_origin = torch.mean(origins, dim=0)
482 | translation = mean_origin
483 | up = torch.mean(poses[:, :3, 1], dim=0)
484 | up = up / torch.linalg.norm(up)
485 | rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
486 | transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
487 | oriented_poses = transform @ poses
488 | return oriented_poses, transform
489 |
490 | def readKittiMotInfo(args):
491 | cam_infos = []
492 | points = []
493 | points_time = []
494 | scale_factor = 1.0
495 |
496 | basedir = args.source_path
497 | scene_id = basedir[-4:] # check
498 | kitti_scene_no = int(scene_id)
499 | tracking_path = basedir[:-13] # check
500 | calibration_path = os.path.join(os.path.join(tracking_path, "calib"), scene_id + ".txt")
501 | oxts_path_tracking = os.path.join(os.path.join(tracking_path, "oxts"), scene_id + ".txt")
502 |
503 | tracking_calibration = tracking_calib_from_txt(calibration_path)
504 | focal_X = tracking_calibration["P2"][0, 0]
505 | focal_Y = tracking_calibration["P2"][1, 1]
506 | poses_imu_w_tracking, _, _ = get_poses_calibration(basedir, oxts_path_tracking) # (n_frames, 4, 4) imu pose
507 |
508 | tr_imu2velo = tracking_calibration["Tr_imu2velo"]
509 | tr_velo2imu = invert_transformation(tr_imu2velo[:3, :3], tr_imu2velo[:3, 3])
510 | poses_velo_w_tracking = np.matmul(poses_imu_w_tracking, tr_velo2imu) # (n_frames, 4, 4) velodyne pose
511 |
512 | # Get camera poses, camera ids: 02, 03
513 | for cam_i in range(2):
514 | transformation = np.eye(4)
515 | projection = tracking_calibration["P" + str(cam_i + 2)] # rectified camera coordinate system -> image
516 | K_inv = np.linalg.inv(projection[:3, :3])
517 | R_t = projection[:3, 3]
518 | t_crect2c = np.matmul(K_inv, R_t)
519 | transformation[:3, 3] = t_crect2c
520 | tracking_calibration["Tr_camrect2cam0" + str(cam_i + 2)] = transformation
521 |
522 | first_frame = args.start_frame
523 | last_frame = args.end_frame
524 |
525 | frame_num = last_frame-first_frame+1
526 | if args.frame_interval > 0:
527 | time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2]
528 | else:
529 | time_duration = args.time_duration
530 |
531 | selected_frames = [first_frame, last_frame]
532 | sequ_frames = selected_frames
533 |
534 | cam_poses_tracking = get_camera_poses_tracking(
535 | poses_velo_w_tracking, tracking_calibration, sequ_frames, kitti_scene_no
536 | )
537 | poses_velo_w_tracking = poses_velo_w_tracking[first_frame:last_frame + 1]
538 |
539 | # Orients and centers the poses
540 | oriented = torch.from_numpy(np.array(cam_poses_tracking).astype(np.float32)) # (n_frames, 3, 4)
541 | oriented, transform_matrix = auto_orient_and_center_poses(
542 | oriented
543 | ) # oriented (n_frames, 3, 4), transform_matrix (3, 4)
544 | row = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
545 | zeros = torch.zeros(oriented.shape[0], 1, 4)
546 | oriented = torch.cat([oriented, zeros], dim=1)
547 | oriented[:, -1] = row # (n_frames, 4, 4)
548 | transform_matrix = torch.cat([transform_matrix, row[None, :]], dim=0) # (4, 4)
549 | cam_poses_tracking = oriented.numpy()
550 | transform_matrix = transform_matrix.numpy()
551 |
552 | image_filenames, sky_filenames = get_scene_images_tracking(
553 | tracking_path, scene_id, sequ_frames)
554 |
555 | # Align axes with the vkitti axes
556 | poses = cam_poses_tracking.astype(np.float32)
557 | poses[:, :, 1:3] *= -1
558 |
559 | test_load_image = imageio.imread(image_filenames[0])
560 | image_height, image_width = test_load_image.shape[:2]
561 | cx, cy = image_width / 2.0, image_height / 2.0
562 | poses[..., :3, 3] *= scale_factor
563 |
564 | c2ws = poses
565 | for idx in tqdm(range(len(c2ws)), desc="Loading data"):
566 | c2w = c2ws[idx]
567 | w2c = np.linalg.inv(c2w)
568 | image_path = image_filenames[idx]
569 | image_name = os.path.basename(image_path)[:-4]
570 | sky_path = sky_filenames[idx]
571 | im_data = Image.open(image_path)
572 | W, H = im_data.size
573 | image = np.array(im_data) / 255.
574 |
575 | sky_mask = cv2.imread(sky_path)
576 |
577 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * (idx % (len(c2ws) // 2)) / (len(c2ws) // 2 - 1)
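# Editor's note: image_filenames lists all left-camera (image_02) frames first and then all
# right-camera (image_03) frames, so idx % (len(c2ws) // 2) recovers the frame index and both
# cameras of the same frame share a single timestamp.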
578 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
579 | T = w2c[:3, 3]
580 |
581 | if idx < len(c2ws) / 2:
582 | point = np.fromfile(os.path.join(tracking_path, "velodyne", scene_id, image_name + ".bin"), dtype=np.float32).reshape(-1, 4)
583 | point_xyz = point[:, :3]
584 | point_xyz_world = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ poses_velo_w_tracking[idx].T)[:, :3]
585 | points.append(point_xyz_world)
586 | point_time = np.full_like(point_xyz_world[:, :1], timestamp)
587 | points_time.append(point_time)
588 | frame_num = len(c2ws) // 2
589 | point_xyz = points[idx%frame_num]
590 | point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1)@ transform_matrix.T @ w2c.T)[:, :3]*scale_factor
591 |
592 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T,
593 | image=image,
594 | image_path=image_filenames[idx], image_name=image_filenames[idx],
595 | width=W, height=H, timestamp=timestamp,
596 | fx=focal_X, fy=focal_Y, cx=cx, cy=cy, sky_mask=sky_mask,
597 | pointcloud_camera=point_camera))
598 |
599 | if args.debug_cuda and idx > 5:
600 | break
601 | pointcloud = np.concatenate(points, axis=0)
602 | pointcloud = (np.concatenate([pointcloud, np.ones_like(pointcloud[:,:1])], axis=-1) @ transform_matrix.T)[:, :3]
603 |
604 | pointcloud_timestamp = np.concatenate(points_time, axis=0)
605 |
606 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True)
607 | pointcloud = pointcloud[indices]
608 | pointcloud_timestamp = pointcloud_timestamp[indices]
609 |
610 | # normalize poses
611 | w2cs = np.zeros((len(cam_infos), 4, 4))
612 | Rs = np.stack([c.R for c in cam_infos], axis=0)
613 | Ts = np.stack([c.T for c in cam_infos], axis=0)
614 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1))
615 | w2cs[:, :3, 3] = Ts
616 | w2cs[:, 3, 3] = 1
617 | c2ws = unpad_poses(np.linalg.inv(w2cs))
618 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius)
619 | c2ws = pad_poses(c2ws)
620 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data")):
621 | c2w = c2ws[idx]
622 | w2c = np.linalg.inv(c2w)
623 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
624 | cam_info.T[:] = w2c[:3, 3]
625 | cam_info.pointcloud_camera[:] *= scale_factor
626 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3]
627 |
628 | if args.eval:
629 | num_frame = len(cam_infos)//2
630 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx % num_frame + 1) % args.testhold != 0]
631 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx % num_frame + 1) % args.testhold == 0]
632 | else:
633 | train_cam_infos = cam_infos
634 | test_cam_infos = []
635 |
636 | # KITTI contains some static-ego sequences, so we don't compute the radius here
637 | nerf_normalization = getNerfppNorm(train_cam_infos)
638 | nerf_normalization['radius'] = 1
639 |
640 | ply_path = os.path.join(args.source_path, "points3d.ply")
641 | if not os.path.exists(ply_path):
642 | rgbs = np.random.random((pointcloud.shape[0], 3))
643 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp)
644 | try:
645 | pcd = fetchPly(ply_path)
646 | except:
647 | pcd = None
648 |
649 | time_interval = (time_duration[1] - time_duration[0]) / (frame_num - 1)
650 |
651 |
652 | scene_info = SceneInfo(point_cloud=pcd,
653 | train_cameras=train_cam_infos,
654 | test_cameras=test_cam_infos,
655 | nerf_normalization=nerf_normalization,
656 | ply_path=ply_path,
657 | time_interval=time_interval)
658 |
659 | return scene_info
--------------------------------------------------------------------------------
/scene/scene_utils.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 | import numpy as np
3 | from utils.graphics_utils import getWorld2View2
4 | from scene.gaussian_model import BasicPointCloud
5 | from plyfile import PlyData, PlyElement
6 |
7 |
8 | class CameraInfo(NamedTuple):
9 | uid: int
10 | R: np.array
11 | T: np.array
12 | image: np.array
13 | image_path: str
14 | image_name: str
15 | width: int
16 | height: int
17 | sky_mask: np.array = None
18 | timestamp: float = 0.0
19 | FovY: float = None
20 | FovX: float = None
21 | fx: float = None
22 | fy: float = None
23 | cx: float = None
24 | cy: float = None
25 | pointcloud_camera: np.array = None
26 |
27 | class SceneInfo(NamedTuple):
28 | point_cloud: BasicPointCloud
29 | train_cameras: list
30 | test_cameras: list
31 | nerf_normalization: dict
32 | ply_path: str
33 | time_interval: float = 0.02
34 | time_duration: list = [-0.5, 0.5]
35 |
36 | def getNerfppNorm(cam_info):
37 | def get_center_and_diag(cam_centers):
38 | cam_centers = np.hstack(cam_centers)
39 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
40 | center = avg_cam_center
41 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
42 | diagonal = np.max(dist)
43 | return center.flatten(), diagonal
44 |
45 | cam_centers = []
46 |
47 | for cam in cam_info:
48 | W2C = getWorld2View2(cam.R, cam.T)
49 | C2W = np.linalg.inv(W2C)
50 | cam_centers.append(C2W[:3, 3:4])
51 |
52 | center, diagonal = get_center_and_diag(cam_centers)
53 | radius = diagonal * 1.1
54 |
55 | translate = -center
56 |
57 | return {"translate": translate, "radius": radius}
58 |
59 |
60 | def fetchPly(path):
61 | plydata = PlyData.read(path)
62 | vertices = plydata['vertex']
63 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
64 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
65 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
66 | if 'time' in vertices:
67 | timestamp = vertices['time'][:, None]
68 | else:
69 | timestamp = None
70 | return BasicPointCloud(points=positions, colors=colors, normals=normals, time=timestamp)
71 |
72 |
73 | def storePly(path, xyz, rgb, timestamp=None):
74 | # Define the dtype for the structured array
75 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
76 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
77 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
78 | ('time', 'f4')]
79 |
80 | normals = np.zeros_like(xyz)
81 | if timestamp is None:
82 | timestamp = np.zeros_like(xyz[:, :1])
83 |
84 | elements = np.empty(xyz.shape[0], dtype=dtype)
85 | attributes = np.concatenate((xyz, normals, rgb, timestamp), axis=1)
86 | elements[:] = list(map(tuple, attributes))
87 |
88 | # Create the PlyData object and write to file
89 | vertex_element = PlyElement.describe(elements, 'vertex')
90 | ply_data = PlyData([vertex_element])
91 | ply_data.write(path)
92 |
--------------------------------------------------------------------------------
/scene/waymo_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
6 | from utils.graphics_utils import BasicPointCloud
7 |
8 |
9 | def pad_poses(p):
10 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
11 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
12 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
13 |
14 |
15 | def unpad_poses(p):
16 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
17 | return p[..., :3, :4]
18 |
19 |
20 | def transform_poses_pca(poses, fix_radius=0):
21 | """Transforms poses so principal components lie on XYZ axes.
22 |
23 | Args:
24 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
25 |
26 | Returns:
27 | A tuple (poses, transform), with the transformed poses and the applied
28 | camera_to_world transforms.
29 |
30 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
31 | """
32 | t = poses[:, :3, 3]
33 | t_mean = t.mean(axis=0)
34 | t = t - t_mean
35 |
36 | eigval, eigvec = np.linalg.eig(t.T @ t)
37 | # Sort eigenvectors in order of largest to smallest eigenvalue.
38 | inds = np.argsort(eigval)[::-1]
39 | eigvec = eigvec[:, inds]
40 | rot = eigvec.T
41 | if np.linalg.det(rot) < 0:
42 | rot = np.diag(np.array([1, 1, -1])) @ rot
43 |
44 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
45 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
46 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
47 |
48 | # Flip coordinate system if z component of y-axis is negative
49 | if poses_recentered.mean(axis=0)[2, 1] < 0:
50 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
51 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
52 |
53 | # Just make sure it's in the [-1, 1]^3 cube
54 | if fix_radius>0:
55 | scale_factor = 1./fix_radius
56 | else:
57 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
58 | scale_factor = min(1 / 10, scale_factor)
59 |
60 | poses_recentered[:, :3, 3] *= scale_factor
61 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
62 |
63 | return poses_recentered, transform, scale_factor
64 |
65 |
66 | def readWaymoInfo(args):
67 | cam_infos = []
68 | car_list = [f[:-4] for f in sorted(os.listdir(os.path.join(args.source_path, "calib"))) if f.endswith('.txt')]
69 | points = []
70 | points_time = []
71 |
72 | frame_num = len(car_list)
73 | if args.frame_interval > 0:
74 | time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2]
75 | else:
76 | time_duration = args.time_duration
77 |
78 | for idx, car_id in tqdm(enumerate(car_list), desc="Loading data"):
79 | ego_pose = np.loadtxt(os.path.join(args.source_path, 'pose', car_id + '.txt'))
80 |
81 | # CAMERA DIRECTION: RIGHT DOWN FORWARDS
82 | with open(os.path.join(args.source_path, 'calib', car_id + '.txt')) as f:
83 | calib_data = f.readlines()
84 | L = [list(map(float, line.split()[1:])) for line in calib_data]
85 | Ks = np.array(L[:5]).reshape(-1, 3, 4)[:, :, :3]
86 | lidar2cam = np.array(L[-5:]).reshape(-1, 3, 4)
87 | lidar2cam = pad_poses(lidar2cam)
88 |
89 | cam2lidar = np.linalg.inv(lidar2cam)
90 | c2w = ego_pose @ cam2lidar
91 | w2c = np.linalg.inv(c2w)
92 | images = []
93 | image_paths = []
94 | HWs = []
95 | for subdir in ['image_0', 'image_1', 'image_2', 'image_3', 'image_4'][:args.cam_num]:
96 | image_path = os.path.join(args.source_path, subdir, car_id + '.png')
97 | im_data = Image.open(image_path)
98 | W, H = im_data.size
99 | image = np.array(im_data) / 255.
100 | HWs.append((H, W))
101 | images.append(image)
102 | image_paths.append(image_path)
103 |
104 | sky_masks = []
105 | for subdir in ['sky_0', 'sky_1', 'sky_2', 'sky_3', 'sky_4'][:args.cam_num]:
106 | sky_data = np.array(Image.open(os.path.join(args.source_path, subdir, car_id + '.png')))
107 | sky_mask = sky_data>0
108 | sky_masks.append(sky_mask.astype(np.float32))
109 |
110 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (len(car_list) - 1)
111 | point = np.fromfile(os.path.join(args.source_path, "velodyne", car_id + ".bin"),
112 | dtype=np.float32, count=-1).reshape(-1, 6)
113 | point_xyz, intensity, elongation, timestamp_pts = np.split(point, [3, 4, 5], axis=1)
114 | point_xyz_world = (np.pad(point_xyz, (0, 1), constant_values=1) @ ego_pose.T)[:, :3]
115 | points.append(point_xyz_world)
116 | point_time = np.full_like(point_xyz_world[:, :1], timestamp)
117 | points_time.append(point_time)
118 | for j in range(args.cam_num):
119 | point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ lidar2cam[j].T)[:, :3]
120 | R = np.transpose(w2c[j, :3, :3]) # R is stored transposed due to 'glm' in CUDA code
121 | T = w2c[j, :3, 3]
122 | K = Ks[j]
123 | fx = float(K[0, 0])
124 | fy = float(K[1, 1])
125 | cx = float(K[0, 2])
126 | cy = float(K[1, 2])
127 | FovX = FovY = -1.0
128 | cam_infos.append(CameraInfo(uid=idx * 5 + j, R=R, T=T, FovY=FovY, FovX=FovX,
129 | image=images[j],
130 | image_path=image_paths[j], image_name=car_id,
131 | width=HWs[j][1], height=HWs[j][0], timestamp=timestamp,
132 | pointcloud_camera = point_camera,
133 | fx=fx, fy=fy, cx=cx, cy=cy,
134 | sky_mask=sky_masks[j]))
135 |
136 | if args.debug_cuda:
137 | break
138 |
139 | pointcloud = np.concatenate(points, axis=0)
140 | pointcloud_timestamp = np.concatenate(points_time, axis=0)
141 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True)
142 | pointcloud = pointcloud[indices]
143 | pointcloud_timestamp = pointcloud_timestamp[indices]
144 |
145 | w2cs = np.zeros((len(cam_infos), 4, 4))
146 | Rs = np.stack([c.R for c in cam_infos], axis=0)
147 | Ts = np.stack([c.T for c in cam_infos], axis=0)
148 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1))
149 | w2cs[:, :3, 3] = Ts
150 | w2cs[:, 3, 3] = 1
151 | c2ws = unpad_poses(np.linalg.inv(w2cs))
152 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius)
153 |
154 | c2ws = pad_poses(c2ws)
155 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data")):
156 | c2w = c2ws[idx]
157 | w2c = np.linalg.inv(c2w)
158 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
159 | cam_info.T[:] = w2c[:3, 3]
160 | cam_info.pointcloud_camera[:] *= scale_factor
161 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3]
162 | if args.eval:
163 | # ## for snerf scene
164 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0]
165 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0]
166 |
167 | # for dynamic scene
168 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0]
169 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0]
170 |
171 | # for emernerf comparison [testhold::testhold]
172 | if args.testhold == 10:
173 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0]
174 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0]
175 | else:
176 | train_cam_infos = cam_infos
177 | test_cam_infos = []
178 |
179 | nerf_normalization = getNerfppNorm(train_cam_infos)
180 | nerf_normalization['radius'] = 1/nerf_normalization['radius']
181 |
182 | ply_path = os.path.join(args.source_path, "points3d.ply")
183 | if not os.path.exists(ply_path):
184 | rgbs = np.random.random((pointcloud.shape[0], 3))
185 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp)
186 | try:
187 | pcd = fetchPly(ply_path)
188 | except:
189 | pcd = None
190 |
191 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp)
192 | time_interval = (time_duration[1] - time_duration[0]) / (len(car_list) - 1)
193 |
194 | scene_info = SceneInfo(point_cloud=pcd,
195 | train_cameras=train_cam_infos,
196 | test_cameras=test_cam_infos,
197 | nerf_normalization=nerf_normalization,
198 | ply_path=ply_path,
199 | time_interval=time_interval,
200 | time_duration=time_duration)
201 |
202 | return scene_info
203 |
--------------------------------------------------------------------------------
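As a quick aside on the split logic in waymo_loader.py above: cameras are grouped into frames via idx // cam_num, and every testhold-th frame is held out for testing; when testhold == 10 the loader switches to the EmerNeRF-style split, which keeps frame 0 in training and holds out frames 10, 20, and so on. Below is a minimal sketch with hypothetical cam_num/testhold values, not taken from the configs:

# Minimal sketch of the frame-level NVS split above (values are illustrative).
cam_num, testhold, num_frames = 3, 4, 12
cam_infos = [(frame, cam) for frame in range(num_frames) for cam in range(cam_num)]

# "dynamic scene" split: hold out every testhold-th frame (1-based)
train = [c for idx, c in enumerate(cam_infos) if (idx // cam_num + 1) % testhold != 0]
test = [c for idx, c in enumerate(cam_infos) if (idx // cam_num + 1) % testhold == 0]

print(sorted({frame for frame, _ in test}))  # [3, 7, 11]: all cameras of those frames go to the test set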
/scripts/extract_kitti_metric_nvs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 |
5 | root = "eval_output/kitti_nvs"
6 | scenes = ["0001", "0002", "0006"]
7 |
8 | eval_dict = {
9 | "TEST": {"psnr": [], "ssim": [], "lpips": []},
10 | }
11 | for scene in scenes:
12 | eval_dir = os.path.join(root, scene, "eval")
13 | dirs = os.listdir(eval_dir)
14 | test_path = sorted([d for d in dirs if d.startswith("test")], key=lambda x: int(x.split("_")[1]))[-1]
15 | for name, path in [("TEST", test_path)]:
16 | psnrs = []
17 | ssims = []
18 | lpipss = []
19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f:
20 | data = json.load(f)
21 | eval_dict[name]["psnr"].append(data["psnr"])
22 | eval_dict[name]["ssim"].append(data["ssim"])
23 | eval_dict[name]["lpips"].append(data["lpips"])
24 |
25 | print(f'TEST PSNR:{np.mean(eval_dict["TEST"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TEST"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TEST"]["lpips"]):.3f}')
26 |
--------------------------------------------------------------------------------
/scripts/extract_kitti_metric_reconstruction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 |
5 | root = "eval_output/kitti_reconstruction"
6 | scenes = ["0001", "0002", "0006"]
7 |
8 | eval_dict = {
9 | "TRAIN": {"psnr": [], "ssim": [], "lpips": []},
10 | }
11 | for scene in scenes:
12 | eval_dir = os.path.join(root, scene, "eval")
13 | dirs = os.listdir(eval_dir)
14 | test_path = sorted([d for d in dirs if d.startswith("train")], key=lambda x: int(x.split("_")[1]))[-1]
15 | for name, path in [("TRAIN", test_path)]:
16 | psnrs = []
17 | ssims = []
18 | lpipss = []
19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f:
20 | data = json.load(f)
21 | eval_dict[name]["psnr"].append(data["psnr"])
22 | eval_dict[name]["ssim"].append(data["ssim"])
23 | eval_dict[name]["lpips"].append(data["lpips"])
24 |
25 | print(f'TRAIN PSNR:{np.mean(eval_dict["TRAIN"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TRAIN"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TRAIN"]["lpips"]):.3f}')
26 |
--------------------------------------------------------------------------------
/scripts/extract_mask_kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_masks.py
3 | @author Jianfei Guo, Shanghai AI Lab
4 | @brief Extract semantic mask
5 |
6 | Using SegFormer (2021), which reports 83.2% on Cityscapes.
7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9)
8 |
9 | Installation:
10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env.
11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9;
12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num'
13 | Hence, a separate conda env is needed.
14 |
15 | git clone https://github.com/NVlabs/SegFormer
16 |
17 | conda create -n segformer python=3.8
18 | conda activate segformer
19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge
20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
21 |
22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf
23 | pip install mmcv-full==1.2.7 --no-cache-dir
24 |
25 | cd SegFormer
26 | pip install .
27 |
28 | Usage:
29 | Run this script directly in the newly created conda env.
30 | """
31 | import os
32 | import numpy as np
33 | import cv2
34 | from tqdm import tqdm
35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
36 | from mmseg.core.evaluation import get_palette
37 |
38 | if __name__ == "__main__":
39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer'
40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 'segformer.b5.1024x1024.city.160k.py')
41 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth')
42 | model = init_segmentor(config, checkpoint, device='cuda')
43 |
44 | root = 'data/kitti_mot/training'
45 |
46 | for cam_id in ['2', '3']:
47 | image_dir = os.path.join(root, f'image_0{cam_id}')
48 | sky_dir = os.path.join(root, f'sky_0{cam_id}')
49 | for seq in sorted(os.listdir(image_dir)):
50 | seq_dir = os.path.join(image_dir, seq)
51 | mask_dir = os.path.join(sky_dir, seq)
52 | if not os.path.isdir(seq_dir):
53 | continue
54 |
55 | os.makedirs(image_dir, exist_ok=True)
56 | os.makedirs(mask_dir, exist_ok=True)
57 | for image_name in sorted(os.listdir(seq_dir)):
58 | image_path = os.path.join(seq_dir, image_name)
59 | print(image_path)
60 | mask_path = os.path.join(mask_dir, image_name)
61 | if not image_path.endswith(".png"):
62 | continue
63 | result = inference_segmentor(model, image_path)
64 | mask = result[0].astype(np.uint8)
65 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8)  # Cityscapes train class 10 = sky
66 | cv2.imwrite(os.path.join(mask_dir, image_name), mask)
67 |
--------------------------------------------------------------------------------
/scripts/extract_mask_waymo.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_masks.py
3 | @author Jianfei Guo, Shanghai AI Lab
4 | @brief Extract semantic mask
5 |
6 | Using SegFormer (2021), which reports 83.2% on Cityscapes.
7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9)
8 |
9 | Installation:
10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env.
11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9;
12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num'
13 | Hence, a separate conda env is needed.
14 |
15 | git clone https://github.com/NVlabs/SegFormer
16 |
17 | conda create -n segformer python=3.8
18 | conda activate segformer
19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge
20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
21 |
22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf
23 | pip install mmcv-full==1.2.7 --no-cache-dir
24 |
25 | cd SegFormer
26 | pip install .
27 |
28 | Usage:
29 | Run this script directly in the newly created conda env.
30 | """
31 | import os
32 | import numpy as np
33 | import cv2
34 | from tqdm import tqdm
35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
36 | from mmseg.core.evaluation import get_palette
37 |
38 | if __name__ == "__main__":
39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer'
40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5',
41 | 'segformer.b5.1024x1024.city.160k.py')
42 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth')
43 | model = init_segmentor(config, checkpoint, device='cuda')
44 |
45 | root = 'data/waymo_scenes'
46 |
47 | scenes = sorted(os.listdir(root))
48 |
49 | for scene in scenes:
50 | for cam_id in range(5):
51 | image_dir = os.path.join(root, scene, f'image_{cam_id}')
52 | sky_dir = os.path.join(root, scene, f'sky_{cam_id}')
53 | os.makedirs(sky_dir, exist_ok=True)
54 | for image_name in tqdm(sorted(os.listdir(image_dir))):
55 | if not image_name.endswith(".png"):
56 | continue
57 | image_path = os.path.join(image_dir, image_name)
58 | mask_path = os.path.join(sky_dir, image_name)
59 | result = inference_segmentor(model, image_path)
60 | mask = result[0].astype(np.uint8)
61 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8)  # Cityscapes train class 10 = sky
62 | cv2.imwrite(mask_path, mask)
63 |
--------------------------------------------------------------------------------
/scripts/extract_scenes_waymo.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | from tqdm import tqdm
4 |
5 | data_root = '/HDD_DISK/datasets/waymo/kitti_format/training'
6 |
7 | tags = ['image_0','image_1','image_2','image_3','image_4','calib','velodyne','pose']
8 | posts = ['.png','.png','.png','.png','.png','.txt', '.bin','.txt']
9 |
10 | out_dir = 'data/waymo_scenes_streetsurf'
11 |
12 | scene_ids = [3, 19, 36, 69, 81, 126, 139, 140, 146,
13 | 148, 157, 181, 200, 204, 226, 232, 237,
14 | 241, 245, 246, 271, 297, 302, 312, 314,
15 | 362, 482, 495, 524, 527]
16 | scene_nums = [
17 | [0, 163],
18 | [0, 198],
19 | [0, 198],
20 | [0, 198],
21 | [0, 198],
22 | [0, 198],
23 | [0, 198],
24 | [17, 198],
25 | [0, 198],
26 | [0, 198],
27 | [0, 140],
28 | [24, 198],
29 | [0, 198],
30 | [0, 198],
31 | [0, 198],
32 | [0, 198],
33 | [0, 198],
34 | [30, 198],
35 | [80, 198],
36 | [0, 170],
37 | [70, 198],
38 | [0, 198],
39 | [0, 198],
40 | [0, 120],
41 | [0, 198],
42 | [0, 198],
43 | [0, 198],
44 | [0, 198],
45 | [0, 198],
46 | [0, 90],
47 | ]
48 | os.makedirs(out_dir, exist_ok=True)
49 |
50 | for scene_idx, scene_id in enumerate(scene_ids):
51 | scene_dir = join(out_dir, f'{scene_id:04d}001')
52 | os.makedirs(scene_dir, exist_ok=True)
53 |
54 | for tag in tags:
55 | os.makedirs(join(scene_dir, tag), exist_ok=True)
56 | for post, tag in zip(posts,tags):
57 | for i in tqdm(range(scene_nums[scene_idx][0], scene_nums[scene_idx][1])):
58 | cmd = "cp {} {}".format(join(data_root,tag,f'{scene_id:04d}{i:03d}'+post),
59 | join(scene_dir, tag, f'{scene_id:04d}{i:03d}'+post))
60 | os.system(cmd)
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/scripts/extract_waymo_metric_nvs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 |
5 | root = "eval_output/waymo_nvs"
6 | scenes = ["0017085", "0145050", "0147030", "0158150"]
7 |
8 | eval_dict = {
9 | "TEST": {"psnr": [], "ssim": [], "lpips": []},
10 | }
11 | for scene in scenes:
12 | eval_dir = os.path.join(root, scene, "eval")
13 | dirs = os.listdir(eval_dir)
14 | test_path = sorted([d for d in dirs if d.startswith("test")], key=lambda x: int(x.split("_")[1]))[-1]
15 | for name, path in [("TEST", test_path)]:
16 | psnrs = []
17 | ssims = []
18 | lpipss = []
19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f:
20 | data = json.load(f)
21 | eval_dict[name]["psnr"].append(data["psnr"])
22 | eval_dict[name]["ssim"].append(data["ssim"])
23 | eval_dict[name]["lpips"].append(data["lpips"])
24 |
25 | print(f'TEST PSNR:{np.mean(eval_dict["TEST"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TEST"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TEST"]["lpips"]):.3f}')
26 |
--------------------------------------------------------------------------------
/scripts/extract_waymo_metric_reconstruction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 |
5 | root = "eval_output/waymo_reconstruction"
6 | scenes = ["0017085", "0145050", "0147030", "0158150"]
7 |
8 | eval_dict = {
9 | "TRAIN": {"psnr": [], "ssim": [], "lpips": []},
10 | }
11 | for scene in scenes:
12 | eval_dir = os.path.join(root, scene, "eval")
13 | dirs = os.listdir(eval_dir)
14 | test_path = sorted([d for d in dirs if d.startswith("train")], key=lambda x: int(x.split("_")[1]))[-1]
15 | for name, path in [("TRAIN", test_path)]:
16 | psnrs = []
17 | ssims = []
18 | lpipss = []
19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f:
20 | data = json.load(f)
21 | eval_dict[name]["psnr"].append(data["psnr"])
22 | eval_dict[name]["ssim"].append(data["ssim"])
23 | eval_dict[name]["lpips"].append(data["lpips"])
24 |
25 | print(f'TRAIN PSNR:{np.mean(eval_dict["TRAIN"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TRAIN"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TRAIN"]["lpips"]):.3f}')
26 |
--------------------------------------------------------------------------------
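The four extract_*_metric_* scripts above are identical except for the eval root, the scene list, and whether they pick the newest test_* or train_* folder (they also carry unused psnrs/ssims/lpipss lists). A hedged sketch of an equivalent generic helper follows; the function name and the example call are my own, not part of the repo:

import os
import json
import numpy as np

def aggregate_metrics(root, scenes, split="test"):
    """Average psnr/ssim/lpips over the newest <split>_<iter>_render folder of each scene."""
    acc = {"psnr": [], "ssim": [], "lpips": []}
    for scene in scenes:
        eval_dir = os.path.join(root, scene, "eval")
        latest = sorted(
            [d for d in os.listdir(eval_dir) if d.startswith(split)],
            key=lambda x: int(x.split("_")[1]),
        )[-1]
        with open(os.path.join(eval_dir, latest, "metrics.json")) as f:
            data = json.load(f)
        for k in acc:
            acc[k].append(data[k])
    return {k: float(np.mean(v)) for k, v in acc.items()}

# e.g. aggregate_metrics("eval_output/waymo_nvs", ["0017085", "0145050"], split="test")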
/scripts/run_kitti_nvs_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python train.py \
2 | --config configs/kitti_nvs.yaml \
3 | source_path=data/kitti_mot/training/image_02/0001 \
4 | model_path=eval_output/kitti_nvs/0001 \
5 | start_frame=380 end_frame=431
6 |
7 | CUDA_VISIBLE_DEVICES=2 python train.py \
8 | --config configs/kitti_nvs.yaml \
9 | source_path=data/kitti_mot/training/image_02/0002 \
10 | model_path=eval_output/kitti_nvs/0002 \
11 | start_frame=140 end_frame=224
12 |
13 | CUDA_VISIBLE_DEVICES=2 python train.py \
14 | --config configs/kitti_nvs.yaml \
15 | source_path=data/kitti_mot/training/image_02/0006 \
16 | model_path=eval_output/kitti_nvs/0006 \
17 | start_frame=65 end_frame=120
18 |
--------------------------------------------------------------------------------
/scripts/run_kitti_reconstruction_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=3 python train.py \
2 | --config configs/kitti_reconstruction.yaml \
3 | source_path=data/kitti_mot/training/image_02/0001 \
4 | model_path=eval_output/kitti_reconstruction/0001 \
5 | start_frame=380 end_frame=431
6 |
7 | CUDA_VISIBLE_DEVICES=3 python train.py \
8 | --config configs/kitti_reconstruction.yaml \
9 | source_path=data/kitti_mot/training/image_02/0002 \
10 | model_path=eval_output/kitti_reconstruction/0002 \
11 | start_frame=140 end_frame=224
12 |
13 | CUDA_VISIBLE_DEVICES=3 python train.py \
14 | --config configs/kitti_reconstruction.yaml \
15 | source_path=data/kitti_mot/training/image_02/0006 \
16 | model_path=eval_output/kitti_reconstruction/0006 \
17 | start_frame=65 end_frame=120
18 |
--------------------------------------------------------------------------------
/scripts/run_waymo_nvs_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train.py \
2 | --config configs/waymo_nvs.yaml \
3 | source_path=data/waymo_scenes/0145050 \
4 | model_path=eval_output/waymo_nvs/0145050
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --config configs/waymo_nvs.yaml \
8 | source_path=data/waymo_scenes/0147030 \
9 | model_path=eval_output/waymo_nvs/0147030
10 |
11 | CUDA_VISIBLE_DEVICES=0 python train.py \
12 | --config configs/waymo_nvs.yaml \
13 | source_path=data/waymo_scenes/0158150 \
14 | model_path=eval_output/waymo_nvs/0158150
15 |
16 | CUDA_VISIBLE_DEVICES=0 python train.py \
17 | --config configs/waymo_nvs.yaml \
18 | source_path=data/waymo_scenes/0017085 \
19 | model_path=eval_output/waymo_nvs/0017085
20 |
--------------------------------------------------------------------------------
/scripts/run_waymo_reconstruction_all.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python train.py \
2 | --config configs/waymo_reconstruction.yaml \
3 | source_path=data/waymo_scenes/0145050 \
4 | model_path=eval_output/waymo_reconstruction/0145050
5 |
6 | CUDA_VISIBLE_DEVICES=1 python train.py \
7 | --config configs/waymo_reconstruction.yaml \
8 | source_path=data/waymo_scenes/0147030 \
9 | model_path=eval_output/waymo_reconstruction/0147030
10 |
11 | CUDA_VISIBLE_DEVICES=1 python train.py \
12 | --config configs/waymo_reconstruction.yaml \
13 | source_path=data/waymo_scenes/0158150 \
14 | model_path=eval_output/waymo_reconstruction/0158150
15 |
16 | CUDA_VISIBLE_DEVICES=1 python train.py \
17 | --config configs/waymo_reconstruction.yaml \
18 | source_path=data/waymo_scenes/0017085 \
19 | model_path=eval_output/waymo_reconstruction/0017085
20 |
--------------------------------------------------------------------------------
/scripts/waymo_converter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | r"""Adapted from `Waymo to KITTI converter
3 | `_.
4 | """
5 |
6 | try:
7 | from waymo_open_dataset import dataset_pb2
8 | except ImportError:
9 | raise ImportError(
10 | 'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
11 | 'to install the official devkit first.')
12 |
13 | from glob import glob
14 | from os.path import join
15 |
16 | import mmcv
17 | import numpy as np
18 | import tensorflow as tf
19 | from waymo_open_dataset.utils import range_image_utils, transform_utils
20 | from waymo_open_dataset.utils.frame_utils import \
21 | parse_range_image_and_camera_projection
22 |
23 |
24 | class Waymo2KITTI(object):
25 | """Waymo to KITTI converter.
26 |
27 | This class serves as the converter to change the waymo raw data to KITTI
28 | format.
29 |
30 | Args:
31 | load_dir (str): Directory to load waymo raw data.
32 | save_dir (str): Directory to save data in KITTI format.
33 | prefix (str): Prefix of filename. In general, 0 for training, 1 for
34 | validation and 2 for testing.
35 | workers (int, optional): Number of workers for the parallel process.
36 | test_mode (bool, optional): Whether in the test_mode. Default: False.
37 | """
38 |
39 | def __init__(self,
40 | load_dir,
41 | save_dir,
42 | prefix,
43 | workers=64,
44 | test_mode=False):
45 | self.filter_empty_3dboxes = True
46 | self.filter_no_label_zone_points = True
47 |
48 | self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST']
49 |
50 | # Only data collected in specific locations will be converted
51 | # If set None, this filter is disabled
52 | # Available options: location_sf (main dataset)
53 | self.selected_waymo_locations = None
54 | self.save_track_id = False
55 |
56 | # turn on eager execution for older tensorflow versions
57 | if int(tf.__version__.split('.')[0]) < 2:
58 | tf.enable_eager_execution()
59 |
60 | self.lidar_list = [
61 | '_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT',
62 | '_SIDE_LEFT'
63 | ]
64 | self.type_list = [
65 | 'UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'
66 | ]
67 | self.waymo_to_kitti_class_map = {
68 | 'UNKNOWN': 'DontCare',
69 | 'PEDESTRIAN': 'Pedestrian',
70 | 'VEHICLE': 'Car',
71 | 'CYCLIST': 'Cyclist',
72 | 'SIGN': 'Sign' # not in kitti
73 | }
74 |
75 | self.load_dir = load_dir
76 | self.save_dir = save_dir
77 | self.prefix = prefix
78 | self.workers = int(workers)
79 | self.test_mode = test_mode
80 |
81 | self.tfrecord_pathnames = sorted(
82 | glob(join(self.load_dir, '*.tfrecord')))
83 |
84 | self.label_save_dir = f'{self.save_dir}/label_'
85 | self.label_all_save_dir = f'{self.save_dir}/label_all'
86 | self.image_save_dir = f'{self.save_dir}/image_'
87 | self.calib_save_dir = f'{self.save_dir}/calib'
88 | self.point_cloud_save_dir = f'{self.save_dir}/velodyne'
89 | self.pose_save_dir = f'{self.save_dir}/pose'
90 | self.timestamp_save_dir = f'{self.save_dir}/timestamp'
91 |
92 | self.create_folder()
93 |
94 | def convert(self):
95 | """Convert action."""
96 | print('Start converting ...')
97 | mmcv.track_parallel_progress(self.convert_one, range(len(self)),
98 | self.workers)
99 | print('\nFinished ...')
100 |
101 | def convert_one(self, file_idx):
102 | """Convert action for single file.
103 |
104 | Args:
105 | file_idx (int): Index of the file to be converted.
106 | """
107 | pathname = self.tfrecord_pathnames[file_idx]
108 | dataset = tf.data.TFRecordDataset(pathname, compression_type='')
109 |
110 | for frame_idx, data in enumerate(dataset):
111 |
112 | frame = dataset_pb2.Frame()
113 | frame.ParseFromString(bytearray(data.numpy()))
114 | if (self.selected_waymo_locations is not None
115 | and frame.context.stats.location
116 | not in self.selected_waymo_locations):
117 | continue
118 |
119 | self.save_image(frame, file_idx, frame_idx)
120 | self.save_calib(frame, file_idx, frame_idx)
121 | self.save_lidar(frame, file_idx, frame_idx)
122 | self.save_pose(frame, file_idx, frame_idx)
123 | self.save_timestamp(frame, file_idx, frame_idx)
124 |
125 | if not self.test_mode:
126 | self.save_label(frame, file_idx, frame_idx)
127 |
128 | def __len__(self):
129 | """Length of the filename list."""
130 | return len(self.tfrecord_pathnames)
131 |
132 | def save_image(self, frame, file_idx, frame_idx):
133 | """Parse and save the images in jpg format. Jpg is the original format
134 | used by Waymo Open dataset. Saving in png format will cause huge (~3x)
135 | unnesssary storage waste.
136 |
137 | Args:
138 | frame (:obj:`Frame`): Open dataset frame proto.
139 | file_idx (int): Current file index.
140 | frame_idx (int): Current frame index.
141 | """
142 | for img in frame.images:
143 | img_path = f'{self.image_save_dir}{str(img.name - 1)}/' + \
144 | f'{self.prefix}{str(file_idx).zfill(3)}' + \
145 | f'{str(frame_idx).zfill(3)}.jpg'
146 | with open(img_path, 'wb') as fp:
147 | fp.write(img.image)
148 |
149 | def save_calib(self, frame, file_idx, frame_idx):
150 | """Parse and save the calibration data.
151 |
152 | Args:
153 | frame (:obj:`Frame`): Open dataset frame proto.
154 | file_idx (int): Current file index.
155 | frame_idx (int): Current frame index.
156 | """
157 | # waymo front camera to kitti reference camera
158 | T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0],
159 | [1.0, 0.0, 0.0]])
160 | camera_calibs = []
161 | R0_rect = [f'{i:e}' for i in np.eye(3).flatten()]
162 | Tr_velo_to_cams = []
163 | calib_context = ''
164 |
165 | for camera in frame.context.camera_calibrations:
166 | # extrinsic parameters
167 | T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape(
168 | 4, 4)
169 | T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle)
170 | Tr_velo_to_cam = \
171 | self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam
172 | if camera.name == 1: # FRONT = 1, see dataset.proto for details
173 | self.T_velo_to_front_cam = Tr_velo_to_cam.copy()
174 | Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12, ))
175 | Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam])
176 |
177 | # intrinsic parameters
178 | camera_calib = np.zeros((3, 4))
179 | camera_calib[0, 0] = camera.intrinsic[0]
180 | camera_calib[1, 1] = camera.intrinsic[1]
181 | camera_calib[0, 2] = camera.intrinsic[2]
182 | camera_calib[1, 2] = camera.intrinsic[3]
183 | camera_calib[2, 2] = 1
184 | camera_calib = list(camera_calib.reshape(12))
185 | camera_calib = [f'{i:e}' for i in camera_calib]
186 | camera_calibs.append(camera_calib)
187 |
188 | # all camera ids are saved as id-1 in the result because
189 | # camera 0 is unknown in the proto
190 | for i in range(5):
191 | calib_context += 'P' + str(i) + ': ' + \
192 | ' '.join(camera_calibs[i]) + '\n'
193 | calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n'
194 | for i in range(5):
195 | calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \
196 | ' '.join(Tr_velo_to_cams[i]) + '\n'
197 |
198 | with open(
199 | f'{self.calib_save_dir}/{self.prefix}' +
200 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt',
201 | 'w+') as fp_calib:
202 | fp_calib.write(calib_context)
203 | fp_calib.close()
204 |
205 | def save_lidar(self, frame, file_idx, frame_idx):
206 | """Parse and save the lidar data in psd format.
207 |
208 | Args:
209 | frame (:obj:`Frame`): Open dataset frame proto.
210 | file_idx (int): Current file index.
211 | frame_idx (int): Current frame index.
212 | """
213 | range_images, camera_projections, range_image_top_pose = \
214 | parse_range_image_and_camera_projection(frame)
215 |
216 | # First return
217 | points_0, cp_points_0, intensity_0, elongation_0, mask_indices_0 = \
218 | self.convert_range_image_to_point_cloud(
219 | frame,
220 | range_images,
221 | camera_projections,
222 | range_image_top_pose,
223 | ri_index=0
224 | )
225 | points_0 = np.concatenate(points_0, axis=0)
226 | intensity_0 = np.concatenate(intensity_0, axis=0)
227 | elongation_0 = np.concatenate(elongation_0, axis=0)
228 | mask_indices_0 = np.concatenate(mask_indices_0, axis=0)
229 |
230 | # Second return
231 | points_1, cp_points_1, intensity_1, elongation_1, mask_indices_1 = \
232 | self.convert_range_image_to_point_cloud(
233 | frame,
234 | range_images,
235 | camera_projections,
236 | range_image_top_pose,
237 | ri_index=1
238 | )
239 | points_1 = np.concatenate(points_1, axis=0)
240 | intensity_1 = np.concatenate(intensity_1, axis=0)
241 | elongation_1 = np.concatenate(elongation_1, axis=0)
242 | mask_indices_1 = np.concatenate(mask_indices_1, axis=0)
243 |
244 | points = np.concatenate([points_0, points_1], axis=0)
245 | intensity = np.concatenate([intensity_0, intensity_1], axis=0)
246 | elongation = np.concatenate([elongation_0, elongation_1], axis=0)
247 | mask_indices = np.concatenate([mask_indices_0, mask_indices_1], axis=0)
248 |
249 | # timestamp = frame.timestamp_micros * np.ones_like(intensity)
250 |
251 | # concatenate x, y, z, intensity, elongation, mask_indices (6-dim)
252 | point_cloud = np.column_stack(
253 | (points, intensity, elongation, mask_indices))
254 |
255 | pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \
256 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin'
257 | point_cloud.astype(np.float32).tofile(pc_path)
258 |
259 | def save_label(self, frame, file_idx, frame_idx):
260 | """Parse and save the label data in txt format.
261 | The relation between waymo and kitti coordinates is noteworthy:
262 | 1. x, y, z correspond to l, w, h (waymo) -> l, h, w (kitti)
263 | 2. x-y-z: front-left-up (waymo) -> right-down-front(kitti)
264 | 3. bbox origin at volumetric center (waymo) -> bottom center (kitti)
265 | 4. rotation: +x around y-axis (kitti) -> +x around z-axis (waymo)
266 |
267 | Args:
268 | frame (:obj:`Frame`): Open dataset frame proto.
269 | file_idx (int): Current file index.
270 | frame_idx (int): Current frame index.
271 | """
272 | fp_label_all = open(
273 | f'{self.label_all_save_dir}/{self.prefix}' +
274 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+')
275 | id_to_bbox = dict()
276 | id_to_name = dict()
277 | for labels in frame.projected_lidar_labels:
278 | name = labels.name
279 | for label in labels.labels:
280 | # TODO: need a workaround as bbox may not belong to front cam
281 | bbox = [
282 | label.box.center_x - label.box.length / 2,
283 | label.box.center_y - label.box.width / 2,
284 | label.box.center_x + label.box.length / 2,
285 | label.box.center_y + label.box.width / 2
286 | ]
287 | id_to_bbox[label.id] = bbox
288 | id_to_name[label.id] = name - 1
289 |
290 | for obj in frame.laser_labels:
291 | bounding_box = None
292 | name = None
293 | id = obj.id
294 | for lidar in self.lidar_list:
295 | if id + lidar in id_to_bbox:
296 | bounding_box = id_to_bbox.get(id + lidar)
297 | name = str(id_to_name.get(id + lidar))
298 | break
299 |
300 | if bounding_box is None or name is None:
301 | name = '0'
302 | bounding_box = (0, 0, 0, 0)
303 |
304 | my_type = self.type_list[obj.type]
305 |
306 | if my_type not in self.selected_waymo_classes:
307 | continue
308 |
309 | if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1:
310 | continue
311 |
312 | my_type = self.waymo_to_kitti_class_map[my_type]
313 |
314 | height = obj.box.height
315 | width = obj.box.width
316 | length = obj.box.length
317 |
318 | x = obj.box.center_x
319 | y = obj.box.center_y
320 | z = obj.box.center_z - height / 2
321 |
322 | # project bounding box to the virtual reference frame
323 | pt_ref = self.T_velo_to_front_cam @ \
324 | np.array([x, y, z, 1]).reshape((4, 1))
325 | x, y, z, _ = pt_ref.flatten().tolist()
326 |
327 | rotation_y = -obj.box.heading - np.pi / 2
328 | track_id = obj.id
329 |
330 | # not available
331 | truncated = 0
332 | occluded = 0
333 | alpha = -10
334 |
335 | line = my_type + \
336 | ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format(
337 | round(truncated, 2), occluded, round(alpha, 2),
338 | round(bounding_box[0], 2), round(bounding_box[1], 2),
339 | round(bounding_box[2], 2), round(bounding_box[3], 2),
340 | round(height, 2), round(width, 2), round(length, 2),
341 | round(x, 2), round(y, 2), round(z, 2),
342 | round(rotation_y, 2))
343 |
344 | if self.save_track_id:
345 | line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n'
346 | else:
347 | line_all = line[:-1] + ' ' + name + '\n'
348 |
349 | fp_label = open(
350 | f'{self.label_save_dir}{name}/{self.prefix}' +
351 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'a')
352 | fp_label.write(line)
353 | fp_label.close()
354 |
355 | fp_label_all.write(line_all)
356 |
357 | fp_label_all.close()
358 |
359 | def save_pose(self, frame, file_idx, frame_idx):
360 | """Parse and save the pose data.
361 |
362 | Note that SDC's own pose is not included in the regular training
363 | of KITTI dataset. KITTI raw dataset contains ego motion files
364 | but are not often used. Pose is important for algorithms that
365 | take advantage of the temporal information.
366 |
367 | Args:
368 | frame (:obj:`Frame`): Open dataset frame proto.
369 | file_idx (int): Current file index.
370 | frame_idx (int): Current frame index.
371 | """
372 | pose = np.array(frame.pose.transform).reshape(4, 4)
373 | np.savetxt(
374 | join(f'{self.pose_save_dir}/{self.prefix}' +
375 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
376 | pose)
377 |
378 | def save_timestamp(self, frame, file_idx, frame_idx):
379 | """Save the timestamp data in a separate file instead of the
380 | pointcloud.
381 |
382 | Note that the frame timestamp (in microseconds) is stored as plain
383 | text in its own file rather than being appended to the point cloud,
384 | so the velodyne .bin files keep the fixed 6-dim layout of x, y, z,
385 | intensity, elongation and range-image mask index.
386 |
387 | Args:
388 | frame (:obj:`Frame`): Open dataset frame proto.
389 | file_idx (int): Current file index.
390 | frame_idx (int): Current frame index.
391 | """
392 | with open(
393 | join(f'{self.timestamp_save_dir}/{self.prefix}' +
394 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'),
395 | 'w') as f:
396 | f.write(str(frame.timestamp_micros))
397 |
398 | def create_folder(self):
399 | """Create folder for data preprocessing."""
400 | if not self.test_mode:
401 | dir_list1 = [
402 | self.label_all_save_dir, self.calib_save_dir,
403 | self.point_cloud_save_dir, self.pose_save_dir,
404 | self.timestamp_save_dir
405 | ]
406 | dir_list2 = [self.label_save_dir, self.image_save_dir]
407 | else:
408 | dir_list1 = [
409 | self.calib_save_dir, self.point_cloud_save_dir,
410 | self.pose_save_dir, self.timestamp_save_dir
411 | ]
412 | dir_list2 = [self.image_save_dir]
413 | for d in dir_list1:
414 | mmcv.mkdir_or_exist(d)
415 | for d in dir_list2:
416 | for i in range(5):
417 | mmcv.mkdir_or_exist(f'{d}{str(i)}')
418 |
419 | def convert_range_image_to_point_cloud(self,
420 | frame,
421 | range_images,
422 | camera_projections,
423 | range_image_top_pose,
424 | ri_index=0):
425 | """Convert range images to point cloud.
426 |
427 | Args:
428 | frame (:obj:`Frame`): Open dataset frame.
429 | range_images (dict): Mapping from laser_name to list of two
430 | range images corresponding with two returns.
431 | camera_projections (dict): Mapping from laser_name to list of two
432 | camera projections corresponding with two returns.
433 | range_image_top_pose (:obj:`Transform`): Range image pixel pose for
434 | top lidar.
435 | ri_index (int, optional): 0 for the first return,
436 | 1 for the second return. Default: 0.
437 |
438 | Returns:
439 | tuple[list[np.ndarray]]: (List of points with shape [N, 3],
440 | camera projections of points with shape [N, 6], intensity
441 | with shape [N, 1], elongation with shape [N, 1], points'
442 | position in the depth map (element offset if points come from
443 | the main lidar otherwise -1) with shape[N, 1]). All the
444 | lists have the length of lidar numbers (5).
445 | """
446 | calibrations = sorted(
447 | frame.context.laser_calibrations, key=lambda c: c.name)
448 | points = []
449 | cp_points = []
450 | intensity = []
451 | elongation = []
452 | mask_indices = []
453 |
454 | frame_pose = tf.convert_to_tensor(
455 | value=np.reshape(np.array(frame.pose.transform), [4, 4]))
456 | # [H, W, 6]
457 | range_image_top_pose_tensor = tf.reshape(
458 | tf.convert_to_tensor(value=range_image_top_pose.data),
459 | range_image_top_pose.shape.dims)
460 | # [H, W, 3, 3]
461 | range_image_top_pose_tensor_rotation = \
462 | transform_utils.get_rotation_matrix(
463 | range_image_top_pose_tensor[..., 0],
464 | range_image_top_pose_tensor[..., 1],
465 | range_image_top_pose_tensor[..., 2])
466 | range_image_top_pose_tensor_translation = \
467 | range_image_top_pose_tensor[..., 3:]
468 | range_image_top_pose_tensor = transform_utils.get_transform(
469 | range_image_top_pose_tensor_rotation,
470 | range_image_top_pose_tensor_translation)
471 | for c in calibrations:
472 | range_image = range_images[c.name][ri_index]
473 | if len(c.beam_inclinations) == 0:
474 | beam_inclinations = range_image_utils.compute_inclination(
475 | tf.constant(
476 | [c.beam_inclination_min, c.beam_inclination_max]),
477 | height=range_image.shape.dims[0])
478 | else:
479 | beam_inclinations = tf.constant(c.beam_inclinations)
480 |
481 | beam_inclinations = tf.reverse(beam_inclinations, axis=[-1])
482 | extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4])
483 |
484 | range_image_tensor = tf.reshape(
485 | tf.convert_to_tensor(value=range_image.data),
486 | range_image.shape.dims)
487 | pixel_pose_local = None
488 | frame_pose_local = None
489 | if c.name == dataset_pb2.LaserName.TOP:
490 | pixel_pose_local = range_image_top_pose_tensor
491 | pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0)
492 | frame_pose_local = tf.expand_dims(frame_pose, axis=0)
493 | range_image_mask = range_image_tensor[..., 0] > 0
494 |
495 | if self.filter_no_label_zone_points:
496 | nlz_mask = range_image_tensor[..., 3] != 1.0 # 1.0: in NLZ
497 | range_image_mask = range_image_mask & nlz_mask
498 |
499 | range_image_cartesian = \
500 | range_image_utils.extract_point_cloud_from_range_image(
501 | tf.expand_dims(range_image_tensor[..., 0], axis=0),
502 | tf.expand_dims(extrinsic, axis=0),
503 | tf.expand_dims(tf.convert_to_tensor(
504 | value=beam_inclinations), axis=0),
505 | pixel_pose=pixel_pose_local,
506 | frame_pose=frame_pose_local)
507 |
508 | mask_index = tf.where(range_image_mask)
509 |
510 | range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0)
511 | points_tensor = tf.gather_nd(range_image_cartesian, mask_index)
512 |
513 | cp = camera_projections[c.name][ri_index]
514 | cp_tensor = tf.reshape(
515 | tf.convert_to_tensor(value=cp.data), cp.shape.dims)
516 | cp_points_tensor = tf.gather_nd(cp_tensor, mask_index)
517 | points.append(points_tensor.numpy())
518 | cp_points.append(cp_points_tensor.numpy())
519 |
520 | intensity_tensor = tf.gather_nd(range_image_tensor[..., 1],
521 | mask_index)
522 | intensity.append(intensity_tensor.numpy())
523 |
524 | elongation_tensor = tf.gather_nd(range_image_tensor[..., 2],
525 | mask_index)
526 | elongation.append(elongation_tensor.numpy())
527 | if c.name == 1:
528 | mask_index = (ri_index * range_image_mask.shape[0] +
529 | mask_index[:, 0]
530 | ) * range_image_mask.shape[1] + mask_index[:, 1]
531 | mask_index = mask_index.numpy().astype(elongation[-1].dtype)
532 | else:
533 | mask_index = np.full_like(elongation[-1], -1)
534 |
535 | mask_indices.append(mask_index)
536 |
537 | return points, cp_points, intensity, elongation, mask_indices
538 |
539 | def cart_to_homo(self, mat):
540 | """Convert transformation matrix in Cartesian coordinates to
541 | homogeneous format.
542 |
543 | Args:
544 | mat (np.ndarray): Transformation matrix in Cartesian.
545 | The input matrix shape is 3x3 or 3x4.
546 |
547 | Returns:
548 | np.ndarray: Transformation matrix in homogeneous format.
549 | The matrix shape is 4x4.
550 | """
551 | ret = np.eye(4)
552 | if mat.shape == (3, 3):
553 | ret[:3, :3] = mat
554 | elif mat.shape == (3, 4):
555 | ret[:3, :] = mat
556 | else:
557 | raise ValueError(mat.shape)
558 | return ret
559 |
560 | if __name__ == '__main__':
561 | import argparse
562 |
563 | parser = argparse.ArgumentParser(description='Data converter arg parser')
564 | parser.add_argument('dataset', metavar='kitti', help='name of the dataset')
565 | parser.add_argument(
566 | '--root-path',
567 | type=str,
568 | default='./data/kitti',
569 | help='specify the root path of dataset')
570 | parser.add_argument(
571 | '--version',
572 | type=str,
573 | default='v1.0',
574 | required=False,
575 | help='specify the dataset version, no need for kitti')
576 | parser.add_argument(
577 | '--max-sweeps',
578 | type=int,
579 | default=10,
580 | required=False,
581 | help='specify sweeps of lidar per example')
582 | parser.add_argument(
583 | '--with-plane',
584 | action='store_true',
585 | help='Whether to use plane information for kitti.')
586 | parser.add_argument(
587 | '--num-points',
588 | type=int,
589 | default=-1,
590 | help='Number of points to sample for indoor datasets.')
591 | parser.add_argument(
592 | '--out-dir',
593 | type=str,
594 | default='./data/kitti',
595 | required=False,
596 | help='name of info pkl')
597 | parser.add_argument('--extra-tag', type=str, default='kitti')
598 | parser.add_argument(
599 | '--workers', type=int, default=4, help='number of threads to be used')
600 | args = parser.parse_args()
601 |
602 | from os import path as osp
603 | splits = ['training', 'validation', 'testing']
604 | for i, split in enumerate(splits):
605 | load_dir = osp.join(args.root_path, 'waymo_format', split)
606 | if split == 'validation':
607 | save_dir = osp.join(args.out_dir, 'kitti_format', 'training')
608 | else:
609 | save_dir = osp.join(args.out_dir, 'kitti_format', split)
610 | converter = Waymo2KITTI(
611 | load_dir,
612 | save_dir,
613 | prefix=str(i),
614 | workers=args.workers,
615 | test_mode=(split == 'testing'))
616 | converter.convert()
617 | # Generate waymo infos
618 | # out_dir = osp.join(out_dir, 'kitti_format')
619 | # kitti.create_waymo_info_file(
620 | # out_dir, info_prefix, max_sweeps=max_sweeps, workers=workers)
621 | # GTDatabaseCreater(
622 | # 'WaymoDataset',
623 | # out_dir,
624 | # info_prefix,
625 | # f'{out_dir}/{info_prefix}_infos_train.pkl',
626 | # relative_path=False,
627 | # with_mask=False,
628 | # num_worker=workers).create()
--------------------------------------------------------------------------------
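The seven-digit file stems produced by the converter above are built as prefix + file_idx + frame_idx (prefix 0/1/2 for training/validation/testing, three zero-padded digits each for the tfrecord index and the frame index); this is presumably where scene/frame names such as 0017085 in the run scripts originate. A tiny sketch of the naming with made-up indices:

# Illustrative only: how waymo_converter.py builds per-frame file stems.
prefix, file_idx, frame_idx = 0, 17, 85  # hypothetical: training split, 18th tfrecord, frame 85
stem = f"{prefix}{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}"
print(stem)  # "0017085"
# e.g. images go to image_<cam>/0017085.jpg, lidar to velodyne/0017085.bin, pose to pose/0017085.txt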
/separate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import glob
12 | import os
13 | import torch
14 | from gaussian_renderer import render
15 | from scene import Scene, GaussianModel, EnvLight
16 | from utils.general_utils import seed_everything
17 | from tqdm import tqdm
18 | from argparse import ArgumentParser
19 | from torchvision.utils import save_image
20 | from omegaconf import OmegaConf
21 |
22 | EPS = 1e-5
23 |
24 | @torch.no_grad()
25 | def separation(scene : Scene, renderFunc, renderArgs, env_map=None):
26 | scale = scene.resolution_scales[0]
27 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
28 | {'name': 'train', 'cameras': scene.getTrainCameras()})
29 |
30 | # we suppose areas with altitude > 0.5 are static
31 | # the z axis points downward here, hence the condition gaussians.get_xyz[:, 2] < -0.5
32 | high_mask = gaussians.get_xyz[:, 2] < -0.5
33 | # import pdb;pdb.set_trace()
34 | mask = (gaussians.get_scaling_t[:, 0] > args.separate_scaling_t) | high_mask
35 | for config in validation_configs:
36 | if config['cameras'] and len(config['cameras']) > 0:
37 | outdir = os.path.join(args.model_path, "separation", config['name'])
38 | os.makedirs(outdir,exist_ok=True)
39 | for idx, viewpoint in enumerate(tqdm(config['cameras'])):
40 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map)
41 | render_pkg_static = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=mask)
42 |
43 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
44 | image_static = torch.clamp(render_pkg_static["render"], 0.0, 1.0)
45 |
46 | save_image(image, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png"))
47 | save_image(image_static, os.path.join(outdir, f"{viewpoint.colmap_id:03d}_static.png"))
48 |
49 |
50 | if __name__ == "__main__":
51 | # Set up command line argument parser
52 | parser = ArgumentParser(description="Training script parameters")
53 | parser.add_argument("--config", type=str, required=True)
54 | parser.add_argument("--base_config", type=str, default = "configs/base.yaml")
55 | args, _ = parser.parse_known_args()
56 |
57 | base_conf = OmegaConf.load(args.base_config)
58 | second_conf = OmegaConf.load(args.config)
59 | cli_conf = OmegaConf.from_cli()
60 | args = OmegaConf.merge(base_conf, second_conf, cli_conf)
61 | args.resolution_scales = args.resolution_scales[:1]
62 | print(args)
63 |
64 | seed_everything(args.seed)
65 |
66 | sep_path = os.path.join(args.model_path, 'separation')
67 | os.makedirs(sep_path, exist_ok=True)
68 |
69 | gaussians = GaussianModel(args)
70 | scene = Scene(args, gaussians, shuffle=False)
71 |
72 | if args.env_map_res > 0:
73 | env_map = EnvLight(resolution=args.env_map_res).cuda()
74 | env_map.training_setup(args)
75 | else:
76 | env_map = None
77 |
78 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
79 | assert len(checkpoints) > 0, "No checkpoints found."
80 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
81 | (model_params, first_iter) = torch.load(checkpoint)
82 | gaussians.restore(model_params, args)
83 |
84 | if env_map is not None:
85 | env_checkpoint = os.path.join(os.path.dirname(checkpoint),
86 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt"))
87 | (light_params, _) = torch.load(env_checkpoint)
88 | env_map.restore(light_params)
89 |
90 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
91 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
92 | separation(scene, render, (args, background), env_map=env_map)
93 |
94 | print("\Rendering statics and dynamics complete.")
95 |
--------------------------------------------------------------------------------
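To make the separation criterion in separation() above concrete, here is a toy sketch; the random tensors and the threshold are illustrative stand-ins for the trained model and args.separate_scaling_t:

import torch

# Toy illustration of the static mask used in separate.py (values are made up).
xyz = torch.randn(1000, 3)           # stand-in for gaussians.get_xyz
scaling_t = torch.rand(1000, 1) * 3  # stand-in for gaussians.get_scaling_t
separate_scaling_t = 2.0             # assumption: plays the role of args.separate_scaling_t

high_mask = xyz[:, 2] < -0.5         # z points downward, so this means altitude > 0.5
static_mask = (scaling_t[:, 0] > separate_scaling_t) | high_mask
print(static_mask.float().mean())    # fraction of Gaussians rendered as "static"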
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import json
12 | import os
13 | from collections import defaultdict
14 | import torch
15 | import torch.nn.functional as F
16 | from random import randint
17 | from utils.loss_utils import psnr, ssim
18 | from gaussian_renderer import render
19 | from scene import Scene, GaussianModel, EnvLight
20 | from utils.general_utils import seed_everything, visualize_depth
21 | from tqdm import tqdm
22 | from argparse import ArgumentParser
23 | from torchvision.utils import make_grid, save_image
24 | import numpy as np
25 | import kornia
26 | from omegaconf import OmegaConf
27 | try:
28 | from torch.utils.tensorboard import SummaryWriter
29 | TENSORBOARD_FOUND = True
30 | except ImportError:
31 | TENSORBOARD_FOUND = False
32 |
33 | EPS = 1e-5
34 | def training(args):
35 |
36 | if TENSORBOARD_FOUND:
37 | tb_writer = SummaryWriter(args.model_path)
38 | else:
39 | tb_writer = None
40 | print("Tensorboard not available: not logging progress")
41 | vis_path = os.path.join(args.model_path, 'visualization')
42 | os.makedirs(vis_path, exist_ok=True)
43 |
44 | gaussians = GaussianModel(args)
45 |
46 | scene = Scene(args, gaussians)
47 |
48 | gaussians.training_setup(args)
49 |
50 | if args.env_map_res > 0:
51 | env_map = EnvLight(resolution=args.env_map_res).cuda()
52 | env_map.training_setup(args)
53 | else:
54 | env_map = None
55 |
56 | first_iter = 0
57 | if args.start_checkpoint:
58 | (model_params, first_iter) = torch.load(args.start_checkpoint)
59 | gaussians.restore(model_params, args)
60 |
61 | if env_map is not None:
62 | env_checkpoint = os.path.join(os.path.dirname(args.start_checkpoint),
63 | os.path.basename(args.start_checkpoint).replace("chkpnt", "env_light_chkpnt"))
64 | (light_params, _) = torch.load(env_checkpoint)
65 | env_map.restore(light_params)
66 |
67 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
68 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
69 |
70 | iter_start = torch.cuda.Event(enable_timing = True)
71 | iter_end = torch.cuda.Event(enable_timing = True)
72 |
73 | viewpoint_stack = None
74 |
75 | ema_dict_for_log = defaultdict(int)
76 | progress_bar = tqdm(range(first_iter + 1, args.iterations + 1), desc="Training progress")
77 |
78 | for iteration in progress_bar:
79 | iter_start.record()
80 | gaussians.update_learning_rate(iteration)
81 |
82 | # Every sh_increase_interval iterations we increase the SH degree up to the maximum
83 | if iteration % args.sh_increase_interval == 0:
84 | gaussians.oneupSHdegree()
85 |
86 | if not viewpoint_stack:
87 | viewpoint_stack = list(range(len(scene.getTrainCameras())))
88 | viewpoint_cam = scene.getTrainCameras()[viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))]
89 |
90 | # render v and t scale map
91 | v = gaussians.get_inst_velocity
92 | t_scale = gaussians.get_scaling_t.clamp_max(2)
93 | other = [t_scale, v]
94 |
95 | if np.random.random() < args.lambda_self_supervision:
96 | time_shift = 3*(np.random.random() - 0.5) * scene.time_interval
97 | else:
98 | time_shift = None
99 |
100 | render_pkg = render(viewpoint_cam, gaussians, args, background, env_map=env_map, other=other, time_shift=time_shift, is_training=True)
101 |
102 | image = render_pkg["render"]
103 | depth = render_pkg["depth"]
104 | alpha = render_pkg["alpha"]
105 | viewspace_point_tensor = render_pkg["viewspace_points"]
106 | visibility_filter = render_pkg["visibility_filter"]
107 | radii = render_pkg["radii"]
108 | log_dict = {}
109 |
110 | feature = render_pkg['feature'] / alpha.clamp_min(EPS)
111 | t_map = feature[0:1]
112 | v_map = feature[1:]
113 |
114 | sky_mask = viewpoint_cam.sky_mask.cuda() if viewpoint_cam.sky_mask is not None else torch.zeros_like(alpha, dtype=torch.bool)
115 |
116 | sky_depth = 900
117 | depth = depth / alpha.clamp_min(EPS)
118 | if env_map is not None:
119 | if args.depth_blend_mode == 0: # harmonic mean
120 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS)
121 | elif args.depth_blend_mode == 1:
122 | depth = alpha * depth + (1 - alpha) * sky_depth
123 |
124 | gt_image = viewpoint_cam.original_image.cuda()
125 |
126 | loss_l1 = F.l1_loss(image, gt_image)
127 | log_dict['loss_l1'] = loss_l1.item()
128 | loss_ssim = 1.0 - ssim(image, gt_image)
129 | log_dict['loss_ssim'] = loss_ssim.item()
130 | loss = (1.0 - args.lambda_dssim) * loss_l1 + args.lambda_dssim * loss_ssim
131 |
132 | if args.lambda_lidar > 0:
133 | assert viewpoint_cam.pts_depth is not None
134 | pts_depth = viewpoint_cam.pts_depth.cuda()
135 |
136 | mask = pts_depth > 0
137 | loss_lidar = torch.abs(1 / (pts_depth[mask] + 1e-5) - 1 / (depth[mask] + 1e-5)).mean()
138 | if args.lidar_decay > 0:
139 | iter_decay = np.exp(-iteration / 8000 * args.lidar_decay)
140 | else:
141 | iter_decay = 1
142 | log_dict['loss_lidar'] = loss_lidar.item()
143 | loss += iter_decay * args.lambda_lidar * loss_lidar
144 |
145 | if args.lambda_t_reg > 0:
146 | loss_t_reg = -torch.abs(t_map).mean()
147 | log_dict['loss_t_reg'] = loss_t_reg.item()
148 | loss += args.lambda_t_reg * loss_t_reg
149 |
150 | if args.lambda_v_reg > 0:
151 | loss_v_reg = torch.abs(v_map).mean()
152 | log_dict['loss_v_reg'] = loss_v_reg.item()
153 | loss += args.lambda_v_reg * loss_v_reg
154 |
155 | if args.lambda_inv_depth > 0:
156 | inverse_depth = 1 / (depth + 1e-5)
157 | loss_inv_depth = kornia.losses.inverse_depth_smoothness_loss(inverse_depth[None], gt_image[None])
158 | log_dict['loss_inv_depth'] = loss_inv_depth.item()
159 | loss = loss + args.lambda_inv_depth * loss_inv_depth
160 |
161 | if args.lambda_v_smooth > 0:
162 | loss_v_smooth = kornia.losses.inverse_depth_smoothness_loss(v_map[None], gt_image[None])
163 | log_dict['loss_v_smooth'] = loss_v_smooth.item()
164 | loss = loss + args.lambda_v_smooth * loss_v_smooth
165 |
166 | if args.lambda_sky_opa > 0:
167 | o = alpha.clamp(1e-6, 1-1e-6)
168 | sky = sky_mask.float()
169 | loss_sky_opa = (-sky * torch.log(1 - o)).mean()
170 | log_dict['loss_sky_opa'] = loss_sky_opa.item()
171 | loss = loss + args.lambda_sky_opa * loss_sky_opa
172 |
173 | if args.lambda_opacity_entropy > 0:
174 | o = alpha.clamp(1e-6, 1 - 1e-6)
175 | loss_opacity_entropy = -(o*torch.log(o)).mean()
176 | log_dict['loss_opacity_entropy'] = loss_opacity_entropy.item()
177 | loss = loss + args.lambda_opacity_entropy * loss_opacity_entropy
178 |
179 | loss.backward()
180 | log_dict['loss'] = loss.item()
181 |
182 | iter_end.record()
183 |
184 | with torch.no_grad():
185 | psnr_for_log = psnr(image, gt_image).double()
186 | log_dict["psnr"] = psnr_for_log
187 | for key in ['loss', "loss_l1", "psnr"]:
188 | ema_dict_for_log[key] = 0.4 * log_dict[key] + 0.6 * ema_dict_for_log[key]
189 |
190 | if iteration % 10 == 0:
191 | postfix = {k[5:] if k.startswith("loss_") else k:f"{ema_dict_for_log[k]:.{5}f}" for k, v in ema_dict_for_log.items()}
192 | postfix["scale"] = scene.resolution_scales[scene.scale_index]
193 | progress_bar.set_postfix(postfix)
194 |
195 | log_dict['iter_time'] = iter_start.elapsed_time(iter_end)
196 | log_dict['total_points'] = gaussians.get_xyz.shape[0]
197 | # Log and save
198 | complete_eval(tb_writer, iteration, args.test_iterations, scene, render, (args, background),
199 | log_dict, env_map=env_map)
200 |
201 | # Densification
202 | if iteration > args.densify_until_iter * args.time_split_frac:
203 | gaussians.no_time_split = False
204 |
205 | if iteration < args.densify_until_iter and (args.densify_until_num_points < 0 or gaussians.get_xyz.shape[0] < args.densify_until_num_points):
206 | # Keep track of max radii in image-space for pruning
207 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
208 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
209 | if iteration > args.densify_from_iter and iteration % args.densification_interval == 0:
210 | size_threshold = args.size_threshold if (iteration > args.opacity_reset_interval and args.prune_big_point > 0) else None
211 |
212 | if size_threshold is not None:
213 | size_threshold = size_threshold // scene.resolution_scales[0]
214 |
215 | gaussians.densify_and_prune(args.densify_grad_threshold, args.thresh_opa_prune, scene.cameras_extent, size_threshold, args.densify_grad_t_threshold)
216 |
217 | if iteration % args.opacity_reset_interval == 0 or (args.white_background and iteration == args.densify_from_iter):
218 | gaussians.reset_opacity()
219 |
220 | gaussians.optimizer.step()
221 | gaussians.optimizer.zero_grad(set_to_none = True)
222 | if env_map is not None and iteration < args.env_optimize_until:
223 | env_map.optimizer.step()
224 | env_map.optimizer.zero_grad(set_to_none = True)
225 | torch.cuda.empty_cache()
226 |
227 | if iteration % args.vis_step == 0 or iteration == 1:
228 | other_img = []
229 | feature = render_pkg['feature'] / alpha.clamp_min(1e-5)
230 | t_map = feature[0:1]
231 | v_map = feature[1:]
232 | v_norm_map = v_map.norm(dim=0, keepdim=True)
233 |
234 | et_color = visualize_depth(t_map, near=0.01, far=1)
235 | v_color = visualize_depth(v_norm_map, near=0.01, far=1)
236 | other_img.append(et_color)
237 | other_img.append(v_color)
238 |
239 | if viewpoint_cam.pts_depth is not None:
240 | pts_depth_vis = visualize_depth(viewpoint_cam.pts_depth)
241 | other_img.append(pts_depth_vis)
242 |
243 | grid = make_grid([
244 | image,
245 | gt_image,
246 | alpha.repeat(3, 1, 1),
247 | torch.logical_not(sky_mask[:1]).float().repeat(3, 1, 1),
248 | visualize_depth(depth),
249 | ] + other_img, nrow=4)
250 |
251 | save_image(grid, os.path.join(vis_path, f"{iteration:05d}_{viewpoint_cam.colmap_id:03d}.png"))
252 |
253 | if iteration % args.scale_increase_interval == 0:
254 | scene.upScale()
255 |
256 | if iteration in args.checkpoint_iterations:
257 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
258 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
259 | torch.save((env_map.capture(), iteration), scene.model_path + "/env_light_chkpnt" + str(iteration) + ".pth")
260 |
261 |
262 | def complete_eval(tb_writer, iteration, test_iterations, scene : Scene, renderFunc, renderArgs, log_dict, env_map=None):
263 | from lpipsPyTorch import lpips
264 |
265 | if tb_writer:
266 | for key, value in log_dict.items():
267 | tb_writer.add_scalar(f'train/{key}', value, iteration)
268 |
269 | if iteration in test_iterations:
270 | scale = scene.resolution_scales[scene.scale_index]
271 | if iteration < args.iterations:
272 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},)
273 | else:
274 | if "kitti" in args.model_path:
275 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766
276 | num = len(scene.getTrainCameras())//2
277 | eval_train_frame = num//5
278 | traincamera = sorted(scene.getTrainCameras(), key=lambda x: x.colmap_id)
279 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
280 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]})
281 | else:
282 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
283 | {'name': 'train', 'cameras': scene.getTrainCameras()})
284 |
285 |
286 |
287 | for config in validation_configs:
288 | if config['cameras'] and len(config['cameras']) > 0:
289 | l1_test = 0.0
290 | psnr_test = 0.0
291 | ssim_test = 0.0
292 | lpips_test = 0.0
293 | outdir = os.path.join(args.model_path, "eval", config['name'] + f"_{iteration}" + "_render")
294 | os.makedirs(outdir,exist_ok=True)
295 | for idx, viewpoint in enumerate(tqdm(config['cameras'])):
296 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map)
297 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
298 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
299 | depth = render_pkg['depth']
300 | alpha = render_pkg['alpha']
301 | sky_depth = 900
302 | depth = depth / alpha.clamp_min(EPS)
303 | if env_map is not None:
304 | if args.depth_blend_mode == 0: # harmonic mean
305 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS)
306 | elif args.depth_blend_mode == 1:
307 | depth = alpha * depth + (1 - alpha) * sky_depth
308 |
309 | depth = visualize_depth(depth)
310 | alpha = alpha.repeat(3, 1, 1)
311 |
312 | grid = [gt_image, image, alpha, depth]
313 | grid = make_grid(grid, nrow=2)
314 |
315 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png"))
316 |
317 | l1_test += F.l1_loss(image, gt_image).double()
318 | psnr_test += psnr(image, gt_image).double()
319 | ssim_test += ssim(image, gt_image).double()
320 | lpips_test += lpips(image, gt_image, net_type='vgg').double() # very slow
321 |
322 | psnr_test /= len(config['cameras'])
323 | l1_test /= len(config['cameras'])
324 | ssim_test /= len(config['cameras'])
325 | lpips_test /= len(config['cameras'])
326 |
327 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test))
328 | if tb_writer:
329 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
330 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
331 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration)
332 | with open(os.path.join(outdir, "metrics.json"), "w") as f:
333 | json.dump({"split": config['name'], "iteration": iteration, "psnr": psnr_test.item(), "ssim": ssim_test.item(), "lpips": lpips_test.item()}, f)
334 | torch.cuda.empty_cache()
335 |
336 |
337 | if __name__ == "__main__":
338 | # Set up command line argument parser
339 | parser = ArgumentParser(description="Training script parameters")
340 | parser.add_argument("--config", type=str, required=True)
341 |     parser.add_argument("--base_config", type=str, default="configs/base.yaml")
342 | args, _ = parser.parse_known_args()
343 |
344 | base_conf = OmegaConf.load(args.base_config)
345 | second_conf = OmegaConf.load(args.config)
346 | cli_conf = OmegaConf.from_cli()
347 | args = OmegaConf.merge(base_conf, second_conf, cli_conf)
348 | print(args)
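    # Illustrative invocation (keys come from configs/base.yaml and the chosen scene config), e.g.:
    #   python train.py --config configs/waymo_reconstruction.yaml \
    #       source_path=<data dir> model_path=eval_output/<run name>
    # key=value pairs passed on the command line are parsed by OmegaConf.from_cli() and
    # override the YAML settings.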
349 |
350 | args.save_iterations.append(args.iterations)
351 | args.checkpoint_iterations.append(args.iterations)
352 | args.test_iterations.append(args.iterations)
353 |
354 | if args.exhaust_test:
355 |         args.test_iterations += list(range(0, args.iterations, args.test_interval))
356 |
357 | print("Optimizing " + args.model_path)
358 |
359 | seed_everything(args.seed)
360 |
361 | training(args)
362 |
363 | # All done
364 | print("\nTraining complete.")
365 |
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import cv2
14 | from scene.cameras import Camera
15 | import numpy as np
16 | from scene.scene_utils import CameraInfo
17 | from tqdm import tqdm
18 | from .graphics_utils import fov2focal
19 |
20 |
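# Build a Camera from a CameraInfo: optionally downscale the image and intrinsics by
# resolution_scale * args.resolution, resize the sky mask to match, and splat the camera-frame
# point cloud into a sparse depth map (pts_depth) aligned with the resized image.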
21 | def loadCam(args, id, cam_info: CameraInfo, resolution_scale):
22 | orig_w, orig_h = cam_info.width, cam_info.height # cam_info.image.size
23 |
24 | if args.resolution in [1, 2, 3, 4, 8, 16, 32]:
25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round(
26 | orig_h / (resolution_scale * args.resolution)
27 | )
28 | scale = resolution_scale * args.resolution
29 |     else:  # args.resolution is -1 (keep original size) or a target image width in pixels
30 | if args.resolution == -1:
31 | global_down = 1
32 | else:
33 | global_down = orig_w / args.resolution
34 |
35 | scale = float(global_down) * float(resolution_scale)
36 | resolution = (int(orig_w / scale), int(orig_h / scale))
37 |
38 | if cam_info.cx:
39 | cx = cam_info.cx / scale
40 | cy = cam_info.cy / scale
41 | fy = cam_info.fy / scale
42 | fx = cam_info.fx / scale
43 | else:
44 | cx = None
45 | cy = None
46 | fy = None
47 | fx = None
48 |
49 | if cam_info.image.shape[:2] != resolution[::-1]:
50 | image_rgb = cv2.resize(cam_info.image, resolution)
51 | else:
52 | image_rgb = cam_info.image
53 | image_rgb = torch.from_numpy(image_rgb).float().permute(2, 0, 1)
54 | gt_image = image_rgb[:3, ...]
55 |
56 | if cam_info.sky_mask is not None:
57 | if cam_info.sky_mask.shape[:2] != resolution[::-1]:
58 | sky_mask = cv2.resize(cam_info.sky_mask, resolution)
59 | else:
60 | sky_mask = cam_info.sky_mask
61 | if len(sky_mask.shape) == 2:
62 | sky_mask = sky_mask[..., None]
63 | sky_mask = torch.from_numpy(sky_mask).float().permute(2, 0, 1)
64 | else:
65 | sky_mask = None
66 |
67 | if cam_info.pointcloud_camera is not None:
68 | h, w = gt_image.shape[1:]
69 | K = np.eye(3)
70 | if cam_info.cx:
71 | K[0, 0] = fx
72 | K[1, 1] = fy
73 | K[0, 2] = cx
74 | K[1, 2] = cy
75 | else:
76 | K[0, 0] = fov2focal(cam_info.FovX, w)
77 | K[1, 1] = fov2focal(cam_info.FovY, h)
78 | K[0, 2] = cam_info.width / 2
79 | K[1, 2] = cam_info.height / 2
80 | pts_depth = np.zeros([1, h, w])
81 |         point_camera = cam_info.pointcloud_camera  # 3D points already expressed in the camera frame
82 |         uvz = point_camera[point_camera[:, 2] > 0]  # keep points in front of the camera
83 |         uvz = uvz @ K.T  # project with the pinhole intrinsics
84 |         uvz[:, :2] /= uvz[:, 2:]  # perspective divide -> pixel coordinates (u, v) with depth z
85 |         uvz = uvz[uvz[:, 1] >= 0]
86 |         uvz = uvz[uvz[:, 1] < h]
87 |         uvz = uvz[uvz[:, 0] >= 0]
88 |         uvz = uvz[uvz[:, 0] < w]  # drop projections that fall outside the image
89 |         uv = uvz[:, :2]
90 |         uv = uv.astype(int)
91 |         # TODO: may need to consider overlap
92 |         pts_depth[0, uv[:, 1], uv[:, 0]] = uvz[:, 2]  # write metric depth into the sparse depth map
93 | pts_depth = torch.from_numpy(pts_depth).float()
94 | else:
95 | pts_depth = None
96 |
97 | return Camera(
98 | colmap_id=cam_info.uid,
99 | uid=id,
100 | R=cam_info.R,
101 | T=cam_info.T,
102 | FoVx=cam_info.FovX,
103 | FoVy=cam_info.FovY,
104 | cx=cx,
105 | cy=cy,
106 | fx=fx,
107 | fy=fy,
108 | image=gt_image,
109 | image_name=cam_info.image_name,
110 | data_device=args.data_device,
111 | timestamp=cam_info.timestamp,
112 | resolution=resolution,
113 | image_path=cam_info.image_path,
114 | pts_depth=pts_depth,
115 | sky_mask=sky_mask,
116 | )
117 |
118 |
119 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
120 | camera_list = []
121 |
122 | for id, c in enumerate(tqdm(cam_infos)):
123 | camera_list.append(loadCam(args, id, c, resolution_scale))
124 |
125 | return camera_list
126 |
127 |
128 | def camera_to_JSON(id, camera: Camera):
129 | Rt = np.zeros((4, 4))
130 | Rt[:3, :3] = camera.R.transpose()
131 | Rt[:3, 3] = camera.T
132 | Rt[3, 3] = 1.0
133 |
134 |     W2C = np.linalg.inv(Rt)  # despite the name, this is the camera-to-world transform
135 |     pos = W2C[:3, 3]  # camera position in world coordinates
136 | rot = W2C[:3, :3]
137 | serializable_array_2d = [x.tolist() for x in rot]
138 |
139 | if camera.cx is None:
140 | camera_entry = {
141 | "id": id,
142 | "img_name": camera.image_name,
143 | "width": camera.width,
144 | "height": camera.height,
145 | "position": pos.tolist(),
146 | "rotation": serializable_array_2d,
147 | "FoVx": camera.FovX,
148 | "FoVy": camera.FovY,
149 | }
150 | else:
151 | camera_entry = {
152 | "id": id,
153 | "img_name": camera.image_name,
154 | "width": camera.width,
155 | "height": camera.height,
156 | "position": pos.tolist(),
157 | "rotation": serializable_array_2d,
158 | "fx": camera.fx,
159 | "fy": camera.fy,
160 | "cx": camera.cx,
161 | "cy": camera.cy,
162 | }
163 | return camera_entry
164 |
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | import numpy as np
15 | import random
16 | from matplotlib import cm
17 |
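# Color-map a depth tensor of shape [1, H, W] into an RGB image [3, H, W] using 'turbo';
# depths are compressed with -log(x) (or -x when linear=True) and normalized between near and far.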
18 | def visualize_depth(depth, near=0.2, far=13, linear=False):
19 | depth = depth[0].clone().detach().cpu().numpy()
20 | colormap = cm.get_cmap('turbo')
21 | curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps)
22 | if linear:
23 | curve_fn = lambda x: -x
24 | eps = np.finfo(np.float32).eps
25 | near = near if near else depth.min()
26 | far = far if far else depth.max()
27 | near -= eps
28 | far += eps
29 | near, far, depth = [curve_fn(x) for x in [near, far, depth]]
30 | depth = np.nan_to_num(
31 | np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1))
32 | vis = colormap(depth)[:, :, :3]
33 | out_depth = np.clip(np.nan_to_num(vis), 0., 1.) * 255
34 | out_depth = torch.from_numpy(out_depth).permute(2, 0, 1).float().cuda() / 255
35 | return out_depth
36 |
37 |
38 | def inverse_sigmoid(x):
39 | return torch.log(x / (1 - x))
40 |
41 |
42 | def PILtoTorch(pil_image, resolution):
43 | resized_image_PIL = pil_image.resize(resolution)
44 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
45 | if len(resized_image.shape) == 3:
46 | return resized_image.permute(2, 0, 1)
47 | else:
48 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
49 |
50 |
51 | def get_step_lr_func(lr_init, lr_final, start_step):
52 | def helper(step):
53 | if step < start_step:
54 | return lr_init
55 | else:
56 | return lr_final
57 | return helper
58 |
59 | def get_expon_lr_func(
60 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
61 | ):
62 | """
63 | Copied from Plenoxels
64 |
65 | Continuous learning rate decay function. Adapted from JaxNeRF
66 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
67 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
68 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
69 | function of lr_delay_mult, such that the initial learning rate is
70 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
71 | to the normal learning rate when steps>lr_delay_steps.
72 |     :param lr_init, lr_final: learning rate at step 0 and at max_steps, respectively.
73 |     :param max_steps: int, the number of steps during optimization.
74 |     :return: a function that maps a step index to a learning rate.
75 | """
76 |
77 | def helper(step):
78 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
79 | # Disable this parameter
80 | return 0.0
81 | if lr_delay_steps > 0:
82 | # A kind of reverse cosine decay.
83 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
84 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
85 | )
86 | else:
87 | delay_rate = 1.0
88 | t = np.clip(step / max_steps, 0, 1)
89 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
90 | return delay_rate * log_lerp
91 |
92 | return helper
93 |
94 |
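# Pack the six unique entries of [N, 3, 3] symmetric matrices (e.g. 3D covariances) into [N, 6].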
95 | def strip_lowerdiag(L):
96 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
97 |
98 | uncertainty[:, 0] = L[:, 0, 0]
99 | uncertainty[:, 1] = L[:, 0, 1]
100 | uncertainty[:, 2] = L[:, 0, 2]
101 | uncertainty[:, 3] = L[:, 1, 1]
102 | uncertainty[:, 4] = L[:, 1, 2]
103 | uncertainty[:, 5] = L[:, 2, 2]
104 | return uncertainty
105 |
106 |
107 | def strip_symmetric(sym):
108 | return strip_lowerdiag(sym)
109 |
110 |
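# Convert [N, 4] quaternions in (w, x, y, z) order into [N, 3, 3] rotation matrices
# (inputs are normalized first).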
111 | def build_rotation(r):
112 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
113 |
114 | q = r / norm[:, None]
115 |
116 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
117 |
118 | r = q[:, 0]
119 | x = q[:, 1]
120 | y = q[:, 2]
121 | z = q[:, 3]
122 |
123 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
124 | R[:, 0, 1] = 2 * (x * y - r * z)
125 | R[:, 0, 2] = 2 * (x * z + r * y)
126 | R[:, 1, 0] = 2 * (x * y + r * z)
127 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
128 | R[:, 1, 2] = 2 * (y * z - r * x)
129 | R[:, 2, 0] = 2 * (x * z - r * y)
130 | R[:, 2, 1] = 2 * (y * z + r * x)
131 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
132 | return R
133 |
134 |
135 | def build_scaling_rotation(s, r):
136 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
137 | R = build_rotation(r)
138 |
139 | L[:, 0, 0] = s[:, 0]
140 | L[:, 1, 1] = s[:, 1]
141 | L[:, 2, 2] = s[:, 2]
142 |
143 | L = R @ L
144 | return L
145 |
146 |
147 | def seed_everything(seed):
148 | random.seed(seed)
149 | os.environ['PYTHONHASHSEED'] = str(seed)
150 | np.random.seed(seed)
151 | torch.manual_seed(seed)
152 | torch.cuda.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 |
155 |
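# Minimal usage sketch with illustrative values: get_expon_lr_func interpolates log-linearly
# between lr_init and lr_final, so the value at the halfway step is the geometric mean of the
# two endpoints; get_step_lr_func is a hard switch at start_step.
if __name__ == "__main__":
    expon_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30000)
    print(expon_fn(0))      # 1.6e-4 at step 0
    print(expon_fn(15000))  # ~sqrt(1.6e-4 * 1.6e-6) = 1.6e-5 at the midpoint
    print(expon_fn(30000))  # 1.6e-6 at the final step

    step_fn = get_step_lr_func(lr_init=1e-3, lr_final=1e-4, start_step=10000)
    print(step_fn(5000), step_fn(20000))  # 1e-3 before start_step, 1e-4 afterwards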
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | class BasicPointCloud(NamedTuple):
18 | points : np.array
19 | colors : np.array
20 | normals : np.array
21 | time : np.array = None
22 |
23 | def getWorld2View(R, t):
24 | Rt = np.zeros((4, 4))
25 | Rt[:3, :3] = R.transpose()
26 | Rt[:3, 3] = t
27 | Rt[3, 3] = 1.0
28 | return np.float32(Rt)
29 |
30 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
31 | Rt = np.zeros((4, 4))
32 | Rt[:3, :3] = R.transpose()
33 | Rt[:3, 3] = t
34 | Rt[3, 3] = 1.0
35 |
36 | C2W = np.linalg.inv(Rt)
37 | cam_center = C2W[:3, 3]
38 | cam_center = (cam_center + translate) * scale
39 | C2W[:3, 3] = cam_center
40 | Rt = np.linalg.inv(C2W)
41 | return np.float32(Rt)
42 |
43 | def getProjectionMatrix(znear, zfar, fovX, fovY):
44 | tanHalfFovY = math.tan((fovY / 2))
45 | tanHalfFovX = math.tan((fovX / 2))
46 |
47 | top = tanHalfFovY * znear
48 | bottom = -top
49 | right = tanHalfFovX * znear
50 | left = -right
51 |
52 | P = torch.zeros(4, 4)
53 |
54 | z_sign = 1.0
55 |
56 | P[0, 0] = 2.0 * znear / (right - left)
57 | P[1, 1] = 2.0 * znear / (top - bottom)
58 | P[0, 2] = (right + left) / (right - left)
59 | P[1, 2] = (top + bottom) / (top - bottom)
60 | P[3, 2] = z_sign
61 | P[2, 2] = z_sign * zfar / (zfar - znear)
62 | P[2, 3] = -(zfar * znear) / (zfar - znear)
63 | return P
64 |
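# Off-center perspective projection: the frustum extents are derived directly from the pinhole
# intrinsics (fx, fy, cx, cy) and the image size, so a principal point away from the image
# center is supported.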
65 | def getProjectionMatrixCenterShift(znear, zfar, cx, cy, fx, fy, w, h):
66 | top = cy / fy * znear
67 | bottom = -(h-cy) / fy * znear
68 |
69 | left = -(w-cx) / fx * znear
70 | right = cx / fx * znear
71 |
72 | P = torch.zeros(4, 4)
73 |
74 | z_sign = 1.0
75 |
76 | P[0, 0] = 2.0 * znear / (right - left)
77 | P[1, 1] = 2.0 * znear / (top - bottom)
78 | P[0, 2] = (right + left) / (right - left)
79 | P[1, 2] = (top + bottom) / (top - bottom)
80 | P[3, 2] = z_sign
81 | P[2, 2] = z_sign * zfar / (zfar - znear)
82 | P[2, 3] = -(zfar * znear) / (zfar - znear)
83 | return P
84 |
85 | def fov2focal(fov, pixels):
86 | return pixels / (2 * math.tan(fov / 2))
87 |
88 | def focal2fov(focal, pixels):
89 | return 2*math.atan(pixels/(2*focal))
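
# Minimal usage sketch with illustrative values: fov2focal and focal2fov are inverses of one
# another, and getProjectionMatrixCenterShift reduces to getProjectionMatrix when the principal
# point sits exactly at the image center with fx = fy.
if __name__ == "__main__":
    w, h = 1920, 1280
    fov_x = math.radians(60)
    f = fov2focal(fov_x, w)
    assert abs(focal2fov(f, w) - fov_x) < 1e-9

    P_sym = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fov_x, fovY=focal2fov(f, h))
    P_shift = getProjectionMatrixCenterShift(0.01, 100.0, cx=w / 2, cy=h / 2, fx=f, fy=f, w=w, h=h)
    print(torch.allclose(P_sym, P_shift))  # True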
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
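# PSNR = 20 * log10(MAX / sqrt(MSE)); images are assumed to be normalized to [0, 1], so MAX = 1.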
17 | def psnr(img1, img2):
18 | mse = F.mse_loss(img1, img2)
19 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
20 |
21 | def gaussian(window_size, sigma):
22 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
23 | return gauss / gauss.sum()
24 |
25 | def create_window(window_size, channel):
26 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
27 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
28 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
29 | return window
30 |
31 | def ssim(img1, img2, window_size=11, size_average=True):
32 | channel = img1.size(-3)
33 | window = create_window(window_size, channel)
34 |
35 | if img1.is_cuda:
36 | window = window.cuda(img1.get_device())
37 | window = window.type_as(img1)
38 |
39 | return _ssim(img1, img2, window, window_size, channel, size_average)
40 |
41 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
42 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
43 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
44 |
45 | mu1_sq = mu1.pow(2)
46 | mu2_sq = mu2.pow(2)
47 | mu1_mu2 = mu1 * mu2
48 |
49 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
50 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
51 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
52 |
53 |     C1 = 0.01 ** 2  # (K1 * L)^2 with K1 = 0.01 and dynamic range L = 1
54 |     C2 = 0.03 ** 2  # (K2 * L)^2 with K2 = 0.03 and dynamic range L = 1
55 |
56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
57 |
58 | if size_average:
59 | return ssim_map.mean()
60 | else:
61 | return ssim_map.mean(1).mean(1).mean(1)
62 |
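# Total-variation smoothness term on a [C, H, W] map: mean squared difference between
# neighboring pixels along height and width.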
63 | def tv_loss(depth):
64 | c, h, w = depth.shape[0], depth.shape[1], depth.shape[2]
65 | count_h = c * (h - 1) * w
66 | count_w = c * h * (w - 1)
67 | h_tv = torch.square(depth[..., 1:, :] - depth[..., :h-1, :]).sum()
68 | w_tv = torch.square(depth[..., :, 1:] - depth[..., :, :w-1]).sum()
69 | return 2 * (h_tv / count_h + w_tv / count_w)
70 |
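# Minimal usage sketch with random tensors: the metrics expect values in [0, 1]; identical
# inputs give PSNR = inf and SSIM = 1, while tv_loss penalizes differences between
# neighboring pixels.
if __name__ == "__main__":
    img = torch.rand(3, 64, 64)
    noisy = torch.clamp(img + 0.05 * torch.randn_like(img), 0.0, 1.0)
    print(psnr(img, noisy))                            # finite value in dB
    print(ssim(img.unsqueeze(0), noisy.unsqueeze(0)))  # slightly below 1
    print(tv_loss(torch.rand(1, 64, 64)))              # scalar regularization term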
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 |     `...` can be 0 or more batch dimensions.
63 |     Args:
64 |         deg: int SH degree. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
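
# Minimal usage sketch with random coefficients: degree-3 SH with 3 color channels uses
# (3 + 1) ** 2 = 16 coefficients per channel; 3DGS-style pipelines typically add 0.5 to the
# result (cf. SH2RGB) and clamp it before using it as RGB.
if __name__ == "__main__":
    N = 4
    sh = torch.randn(N, 3, 16)
    dirs = torch.nn.functional.normalize(torch.randn(N, 3), dim=-1)  # unit viewing directions
    rgb = torch.clamp(eval_sh(3, sh, dirs) + 0.5, 0.0, 1.0)
    print(rgb.shape)  # torch.Size([4, 3])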
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 |
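# Return the largest iteration number among saved checkpoint folders (named e.g. "iteration_30000").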
15 | def searchForMaxIteration(folder):
16 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
17 | return max(saved_iters)
18 |
19 |
20 | class Timing:
21 | """
22 | From https://github.com/sxyu/svox2/blob/ee80e2c4df8f29a407fda5729a494be94ccf9234/svox2/utils.py#L611
23 |
24 | Timing environment
25 | usage:
26 | with Timing("message"):
27 | your commands here
28 | will print CUDA runtime in ms
29 | """
30 |
31 | def __init__(self, name):
32 | self.name = name
33 |
34 | def __enter__(self):
35 | self.start = torch.cuda.Event(enable_timing=True)
36 | self.end = torch.cuda.Event(enable_timing=True)
37 | self.start.record()
38 |
39 | def __exit__(self, type, value, traceback):
40 | self.end.record()
41 | torch.cuda.synchronize()
42 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms")
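
# Note: CUDA kernels launch asynchronously, so __exit__ records an end event and calls
# torch.cuda.synchronize() before reading elapsed_time(); the printed value is GPU time in ms.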
--------------------------------------------------------------------------------