├── .gitignore ├── LICENSE ├── LICENSE_gaussian_splatting.md ├── README.md ├── assets └── pipeline.png ├── configs ├── base.yaml ├── kitti_nvs.yaml ├── kitti_reconstruction.yaml ├── waymo_nvs.yaml └── waymo_reconstruction.yaml ├── evaluate.py ├── gaussian_renderer └── __init__.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── requirements-data.txt ├── requirements.txt ├── scene ├── __init__.py ├── cameras.py ├── envlight.py ├── gaussian_model.py ├── kittimot_loader.py ├── scene_utils.py └── waymo_loader.py ├── scripts ├── extract_kitti_metric_nvs.py ├── extract_kitti_metric_reconstruction.py ├── extract_mask_kitti.py ├── extract_mask_waymo.py ├── extract_scenes_waymo.py ├── extract_waymo_metric_nvs.py ├── extract_waymo_metric_reconstruction.py ├── run_kitti_nvs_all.sh ├── run_kitti_reconstruction_all.sh ├── run_waymo_nvs_all.sh ├── run_waymo_reconstruction_all.sh └── waymo_converter.py ├── separate.py ├── train.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.so 3 | build/ 4 | *.egg-info/ 5 | .vscode 6 | 7 | 8 | build 9 | data 10 | output 11 | eval_output 12 | diff-gaussian-rasterization -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fudan Zhang Vision Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE_gaussian_splatting.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 
17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. 
UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering 2 | ### [[Project]](https://fudan-zvg.github.io/PVG) [[Paper]](https://arxiv.org/abs/2311.18561) 3 | 4 | > [**Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering**](https://arxiv.org/abs/2311.18561), 5 | > Yurui Chen, [Chun Gu](https://sulvxiangxin.github.io/), Junzhe Jiang, [Xiatian Zhu](https://surrey-uplab.github.io/), [Li Zhang](https://lzrobots.github.io) 6 | > **Arxiv preprint** 7 | 8 | **Official implementation of "Periodic Vibration Gaussian: 9 | Dynamic Urban Scene Reconstruction and Real-time Rendering".** 10 | 11 | 12 | ## 🛠️ Pipeline 13 |
14 | <div align="center"><img src="assets/pipeline.png"/></div>
15 |
16 |
17 | ## Get started
18 | ### Environment
19 | ```
20 | # Clone the repo.
21 | git clone https://github.com/fudan-zvg/PVG.git
22 | cd PVG
23 |
24 | # Make a conda environment.
25 | conda create --name pvg python=3.9
26 | conda activate pvg
27 |
28 | # Install requirements.
29 | pip install -r requirements.txt
30 |
31 | # Install simple-knn.
32 | git clone https://gitlab.inria.fr/bkerbl/simple-knn.git
33 | pip install ./simple-knn
34 |
35 | # Install a modified Gaussian splatting (for feature rendering).
36 | git clone --recursive https://github.com/SuLvXiangXin/diff-gaussian-rasterization
37 | pip install ./diff-gaussian-rasterization
38 |
39 | # Install nvdiffrast (for EnvLight).
40 | git clone https://github.com/NVlabs/nvdiffrast
41 | pip install ./nvdiffrast
42 |
43 | ```
44 |
45 | ### Data preparation
46 | Create a directory for the data: `mkdir data`.
47 | #### Waymo dataset
48 |
49 | The 4 preprocessed Waymo scenes used for the results in Table 1 of our paper can be downloaded [here](https://drive.google.com/file/d/1eTNJz7WeYrB3IctVlUmJIY0z8qhjR_qF/view?usp=sharing) (optional: [corresponding labels](https://drive.google.com/file/d/1rkOzYqD1wdwILq_tUNvXBcXMe5YwtI2k/view?usp=drive_link)). Please unzip them into the `data` directory.
50 |
51 | First prepare the kitti-format Waymo dataset:
52 | ```
53 | # Given the following dataset layout, we convert it to kitti format:
54 | # data
55 | # └── waymo
56 | #     └── waymo_format
57 | #         └── training
58 | #             └── segment-xxxxxx
59 |
60 | # Install some optional packages.
61 | pip install -r requirements-data.txt
62 |
63 | # Convert the Waymo dataset to kitti format.
64 | python scripts/waymo_converter.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo
65 | ```
66 | Then use the example script `scripts/extract_scenes_waymo.py` to extract the scenes listed in StreetSurf from the kitti-format Waymo dataset; these are the scenes we use.
67 |
68 | Following [StreetSurf](https://github.com/PJLab-ADG/neuralsim), we use [Segformer](https://github.com/NVlabs/SegFormer) to extract the sky masks and arrange them as follows:
69 | ```
70 | data
71 | └── waymo_scenes
72 |     └── sequence_id
73 |         ├── calib
74 |         │   └── frame_id.txt
75 |         ├── image_0{0, 1, 2, 3, 4}
76 |         │   └── frame_id.png
77 |         ├── sky_0{0, 1, 2, 3, 4}
78 |         │   └── frame_id.png
79 |         ├── pose
80 |         │   └── frame_id.txt
81 |         └── velodyne
82 |             └── frame_id.bin
83 | ```
84 | We provide an example script `scripts/extract_mask_waymo.py` to extract the sky masks from the extracted Waymo dataset; follow the instructions [here](https://github.com/PJLab-ADG/neuralsim/blob/main/dataio/autonomous_driving/waymo/README.md#extract-mask-priors----for-sky-pedestrian-etc) to set up the Segformer environment.
85 |
86 | #### KITTI dataset
87 | The 3 preprocessed KITTI scenes used for the results in Table 1 of our paper can be downloaded [here](https://drive.google.com/file/d/1y6elRlFdRXW02oUOHdS9inVHK3U4xBXZ/view?usp=sharing). Please unzip them into the `data` directory.
88 |
89 | Put the [KITTI-MOT](https://www.cvlibs.net/datasets/kitti/eval_tracking.php) dataset in the `data` directory.
90 | Following [StreetSurf](https://github.com/PJLab-ADG/neuralsim), we use [Segformer](https://github.com/NVlabs/SegFormer) to extract the sky masks and arrange them as follows:
91 | ```
92 | data
93 | └── kitti_mot
94 |     └── training
95 |         ├── calib
96 |         │   └── sequence_id.txt
97 |         ├── image_0{2, 3}
98 |         │   └── sequence_id
99 |         │       └── frame_id.png
100 |         ├── sky_0{2, 3}
101 |         │   └── sequence_id
102 |         │       └── frame_id.png
103 |         ├── oxts
104 |         │   └── sequence_id.txt
105 |         └── velodyne
106 |             └── sequence_id
107 |                 └── frame_id.bin
108 | ```
109 | We also provide an example script `scripts/extract_mask_kitti.py` to extract the sky masks from the KITTI dataset.
110 |
111 |
112 | ### Training
113 | ```
114 | # Waymo image reconstruction
115 | CUDA_VISIBLE_DEVICES=0 python train.py \
116 | --config configs/waymo_reconstruction.yaml \
117 | source_path=data/waymo_scenes/0145050 \
118 | model_path=eval_output/waymo_reconstruction/0145050
119 |
120 | # Waymo novel view synthesis
121 | CUDA_VISIBLE_DEVICES=0 python train.py \
122 | --config configs/waymo_nvs.yaml \
123 | source_path=data/waymo_scenes/0145050 \
124 | model_path=eval_output/waymo_nvs/0145050
125 |
126 | # KITTI image reconstruction
127 | CUDA_VISIBLE_DEVICES=0 python train.py \
128 | --config configs/kitti_reconstruction.yaml \
129 | source_path=data/kitti_mot/training/image_02/0001 \
130 | model_path=eval_output/kitti_reconstruction/0001 \
131 | start_frame=380 end_frame=431
132 |
133 | # KITTI novel view synthesis
134 | CUDA_VISIBLE_DEVICES=0 python train.py \
135 | --config configs/kitti_nvs.yaml \
136 | source_path=data/kitti_mot/training/image_02/0001 \
137 | model_path=eval_output/kitti_nvs/0001 \
138 | start_frame=380 end_frame=431
139 | ```
140 |
141 | After training, evaluation results can be found in the `{EXPERIMENT_DIR}/eval` directory.
142 |
143 | ### Evaluating
144 | You can also use the following command to evaluate:
145 | ```
146 | CUDA_VISIBLE_DEVICES=0 python evaluate.py \
147 | --config configs/kitti_reconstruction.yaml \
148 | source_path=data/kitti_mot/training/image_02/0001 \
149 | model_path=eval_output/kitti_reconstruction/0001 \
150 | start_frame=380 end_frame=431
151 | ```
152 |
153 | ### Automatically removing the dynamics
154 | You can use the following command to automatically remove the dynamics; the rendered results will be saved in the `{EXPERIMENT_DIR}/separation` directory.
155 | ``` 156 | CUDA_VISIBLE_DEVICES=1 python separate.py \ 157 | --config configs/waymo_reconstruction.yaml \ 158 | source_path=data/waymo_scenes/0158150 \ 159 | model_path=eval_output/waymo_reconstruction/0158150 160 | ``` 161 | 162 | 163 | ## 🎥 Videos 164 | ### 🎞️ Demo 165 | [![Demo Video](https://i3.ytimg.com/vi/jJCCkdpDkRQ/maxresdefault.jpg)](https://www.youtube.com/embed/jJCCkdpDkRQ) 166 | 167 | 168 | ### 🎞️ Rendered RGB, Depth and Semantic 169 | 170 | https://github.com/fudan-zvg/PVG/assets/83005605/60337a98-f92c-4465-ab45-2ee121413114 171 | 172 | https://github.com/fudan-zvg/PVG/assets/83005605/f45c0a91-26b6-46d9-895c-bf13786f94d2 173 | 174 | https://github.com/fudan-zvg/PVG/assets/83005605/0ed679d6-5e62-4923-b2cb-02c587ed468c 175 | 176 | https://github.com/fudan-zvg/PVG/assets/83005605/3ffda292-1b73-43d3-916a-b524f143f0c9 177 | 178 | ### 🎞️ Image Reconstruction on Waymo 179 | #### Comparison with static methods 180 | 181 | https://github.com/fudan-zvg/PVG/assets/83005605/93e32945-7e9a-454a-8c31-5563125de95b 182 | 183 | https://github.com/fudan-zvg/PVG/assets/83005605/f3c02e43-bb86-428d-b27b-73c4a7857bc7 184 | 185 | #### Comparison with dynamic methods 186 | 187 | https://github.com/fudan-zvg/PVG/assets/83005605/73a82171-9e78-416f-a770-f6f4239d80ca 188 | 189 | https://github.com/fudan-zvg/PVG/assets/83005605/e579f8b8-d31e-456b-a943-b39d56073b94 190 | 191 | ### 🎞️ Novel View Synthesis on Waymo 192 | 193 | https://github.com/fudan-zvg/PVG/assets/83005605/37393332-5d34-4bd0-8285-40bf938b849f 194 | 195 | ## 📜 BibTeX 196 | ```bibtex 197 | @article{chen2023periodic, 198 | title={Periodic Vibration Gaussian: Dynamic Urban Scene Reconstruction and Real-time Rendering}, 199 | author={Chen, Yurui and Gu, Chun and Jiang, Junzhe and Zhu, Xiatian and Zhang, Li}, 200 | journal={arXiv:2311.18561}, 201 | year={2023}, 202 | } 203 | ``` 204 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PVG/b4162a9135282e0f3c929054f16be1b3fbacd77a/assets/pipeline.png -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | test_iterations: [7000, 30000] 2 | save_iterations: [7000, 30000] 3 | checkpoint_iterations: [7000, 30000] 4 | exhaust_test: false 5 | test_interval: 5000 6 | render_static: false 7 | vis_step: 500 8 | start_checkpoint: null 9 | seed: 0 10 | 11 | # ModelParams 12 | sh_degree: 3 13 | scene_type: "Waymo" 14 | source_path: ??? 15 | start_frame: 65 # for kitti 16 | end_frame: 120 # for kitti 17 | model_path: ??? 
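# Note: '???' is OmegaConf's mandatory-value marker, so source_path and model_path
# must be supplied at runtime -- e.g. as `key=value` overrides on the command line,
# which are merged in via OmegaConf.from_cli() (see evaluate.py below).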
18 | resolution_scales: [1] 19 | resolution: -1 20 | white_background: false 21 | data_device: "cuda" 22 | eval: false 23 | debug_cuda: false 24 | cam_num: 3 # for waymo 25 | t_init: 0.2 26 | cycle: 0.2 27 | velocity_decay: 1.0 28 | random_init_point: 200000 29 | fix_radius: 0.0 30 | time_duration: [-0.5, 0.5] 31 | num_pts: 100000 32 | frame_interval: 0.02 33 | testhold: 4 # NVS 34 | env_map_res: 1024 35 | separate_scaling_t: 0.1 36 | neg_fov: true 37 | 38 | # PipelineParams 39 | convert_SHs_python: false 40 | compute_cov3D_python: false 41 | debug: false 42 | depth_blend_mode: 0 43 | env_optimize_until: 1000000000 44 | env_optimize_from: 0 45 | 46 | 47 | # OptimizationParams 48 | iterations: 30000 49 | position_lr_init: 0.00016 50 | position_lr_final: 0.0000016 51 | position_lr_delay_mult: 0.01 52 | t_lr_init: 0.0008 53 | position_lr_max_steps: 30_000 54 | feature_lr: 0.0025 55 | opacity_lr: 0.05 56 | scaling_lr: 0.005 57 | scaling_t_lr: 0.002 58 | velocity_lr: 0.001 59 | rotation_lr: 0.001 60 | envmap_lr: 0.01 61 | 62 | time_split_frac: 0.5 63 | percent_dense: 0.01 64 | thresh_opa_prune: 0.005 65 | densification_interval: 100 66 | opacity_reset_interval: 3000 67 | densify_from_iter: 500 68 | densify_until_iter: 15_000 69 | densify_grad_threshold: 0.0002 70 | densify_grad_t_threshold: 0.002 71 | densify_until_num_points: 3000000 72 | sh_increase_interval: 1000 73 | scale_increase_interval: 5000 74 | prune_big_point: 1 75 | size_threshold: 20 76 | big_point_threshold: 0.1 77 | t_grad: true 78 | no_time_split: true 79 | contract: true 80 | 81 | lambda_dssim: 0.2 82 | lambda_opa: 0.0 83 | lambda_sky_opa: 0.05 84 | lambda_opacity_entropy: 0.05 85 | lambda_inv_depth: 0.001 86 | lambda_self_supervision: 0.5 87 | lambda_t_reg: 0.0 88 | lambda_v_reg: 0.0 89 | lambda_lidar: 0.1 90 | lidar_decay: 1.0 91 | lambda_v_smooth: 0.0 -------------------------------------------------------------------------------- /configs/kitti_nvs.yaml: -------------------------------------------------------------------------------- 1 | exhaust_test: false 2 | 3 | 4 | # ModelParams 5 | scene_type: "KittiMot" 6 | start_frame: 380 7 | end_frame: 431 8 | num_pts: 1000000 9 | cam_num: 3 10 | resolution_scales: [1, 2, 4, 8] 11 | eval: true 12 | t_init: 0.006 13 | fix_radius: 10.0 14 | 15 | # PipelineParams 16 | 17 | # OptimizationParams 18 | iterations: 40000 19 | densify_until_iter: 20000 20 | densify_until_num_points: 10000000 21 | 22 | opacity_lr: 0.007 23 | 24 | densification_interval: 200 25 | sh_increase_interval: 2000 26 | opacity_reset_interval: 5000 27 | densify_grad_threshold: 0.00015 28 | size_threshold: 100 29 | big_point_threshold: 0.4 30 | 31 | lidar_decay: 0.3 32 | lambda_v_reg: 0.001 33 | lambda_t_reg: 0.01 34 | -------------------------------------------------------------------------------- /configs/kitti_reconstruction.yaml: -------------------------------------------------------------------------------- 1 | exhaust_test: false 2 | 3 | 4 | # ModelParams 5 | scene_type: "KittiMot" 6 | start_frame: 380 7 | end_frame: 431 8 | num_pts: 1000000 9 | cam_num: 3 10 | resolution_scales: [1, 2, 4, 8] 11 | eval: false 12 | t_init: 0.006 13 | fix_radius: 15.0 14 | 15 | # PipelineParams 16 | 17 | # OptimizationParams 18 | iterations: 40000 19 | densify_until_iter: 20000 20 | densify_until_num_points: 10000000 21 | 22 | opacity_lr: 0.007 23 | 24 | densification_interval: 200 25 | sh_increase_interval: 2000 26 | opacity_reset_interval: 5000 27 | densify_grad_threshold: 0.00015 28 | size_threshold: 100 29 | 
big_point_threshold: 0.4 30 | 31 | lidar_decay: 0.3 32 | lambda_v_reg: 0.001 33 | -------------------------------------------------------------------------------- /configs/waymo_nvs.yaml: -------------------------------------------------------------------------------- 1 | exhaust_test: false 2 | 3 | 4 | # ModelParams 5 | scene_type: "Waymo" 6 | resolution_scales: [1, 2, 4, 8, 16] 7 | cam_num: 3 8 | eval: true 9 | num_pts: 600000 10 | t_init: 0.1 11 | separate_scaling_t: 0.2 12 | 13 | # PipelineParams 14 | 15 | 16 | # OptimizationParams 17 | iterations: 30000 18 | 19 | opacity_lr: 0.005 20 | 21 | densify_until_iter: 15000 22 | densify_grad_threshold: 0.00017 23 | sh_increase_interval: 2000 24 | 25 | 26 | lambda_v_reg: 0.01 27 | 28 | 29 | -------------------------------------------------------------------------------- /configs/waymo_reconstruction.yaml: -------------------------------------------------------------------------------- 1 | exhaust_test: false 2 | 3 | 4 | # ModelParams 5 | scene_type: "Waymo" 6 | resolution_scales: [1, 2, 4, 8, 16] 7 | cam_num: 3 8 | eval: false 9 | num_pts: 600000 10 | t_init: 0.1 11 | separate_scaling_t: 0.2 12 | 13 | # PipelineParams 14 | 15 | 16 | # OptimizationParams 17 | iterations: 30000 18 | 19 | opacity_lr: 0.005 20 | 21 | densify_until_iter: 15000 22 | densify_grad_threshold: 0.00017 23 | sh_increase_interval: 2000 24 | 25 | 26 | lambda_v_reg: 0.01 27 | 28 | 29 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import glob 12 | import json 13 | import os 14 | import torch 15 | import torch.nn.functional as F 16 | from utils.loss_utils import psnr, ssim 17 | from gaussian_renderer import render 18 | from scene import Scene, GaussianModel, EnvLight 19 | from utils.general_utils import seed_everything, visualize_depth 20 | from tqdm import tqdm 21 | from argparse import ArgumentParser 22 | from torchvision.utils import make_grid, save_image 23 | from omegaconf import OmegaConf 24 | 25 | EPS = 1e-5 26 | 27 | @torch.no_grad() 28 | def evaluation(iteration, scene : Scene, renderFunc, renderArgs, env_map=None): 29 | from lpipsPyTorch import lpips 30 | 31 | scale = scene.resolution_scales[0] 32 | if "kitti" in args.model_path: 33 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766 34 | num = len(scene.getTrainCameras())//2 35 | eval_train_frame = num//5 36 | traincamera = sorted(scene.getTrainCameras(), key =lambda x: x.colmap_id) 37 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 38 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]}) 39 | else: 40 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 41 | {'name': 'train', 'cameras': scene.getTrainCameras()}) 42 | 43 | for config in validation_configs: 44 | if config['cameras'] and len(config['cameras']) > 0: 45 | l1_test = 0.0 46 | psnr_test = 0.0 47 | ssim_test = 0.0 48 | lpips_test = 0.0 49 | outdir = os.path.join(args.model_path, "eval", config['name'] + f"_{iteration}" + "_render") 50 | os.makedirs(outdir,exist_ok=True) 51 | for idx, viewpoint in enumerate(tqdm(config['cameras'])): 52 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map) 53 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 54 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 55 | 56 | depth = render_pkg['depth'] 57 | alpha = render_pkg['alpha'] 58 | sky_depth = 900 59 | depth = depth / alpha.clamp_min(EPS) 60 | if env_map is not None: 61 | if args.depth_blend_mode == 0: # harmonic mean 62 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS) 63 | elif args.depth_blend_mode == 1: 64 | depth = alpha * depth + (1 - alpha) * sky_depth 65 | 66 | depth = visualize_depth(depth) 67 | alpha = alpha.repeat(3, 1, 1) 68 | 69 | grid = [gt_image, image, alpha, depth] 70 | grid = make_grid(grid, nrow=2) 71 | 72 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png")) 73 | 74 | l1_test += F.l1_loss(image, gt_image).double() 75 | psnr_test += psnr(image, gt_image).double() 76 | ssim_test += ssim(image, gt_image).double() 77 | lpips_test += lpips(image, gt_image, net_type='vgg').double() # very slow 78 | 79 | psnr_test /= len(config['cameras']) 80 | l1_test /= len(config['cameras']) 81 | ssim_test /= len(config['cameras']) 82 | lpips_test /= len(config['cameras']) 83 | 84 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test)) 85 | with open(os.path.join(outdir, "metrics.json"), "w") as f: 86 | json.dump({"split": config['name'], "iteration": iteration, "psnr": psnr_test.item(), "ssim": ssim_test.item(), "lpips": lpips_test.item()}, f) 87 | 88 | 89 | if __name__ == "__main__": 90 | # Set up command 
line argument parser 91 | parser = ArgumentParser(description="Training script parameters") 92 | parser.add_argument("--config", type=str, required=True) 93 | parser.add_argument("--base_config", type=str, default = "configs/base.yaml") 94 | args, _ = parser.parse_known_args() 95 | 96 | base_conf = OmegaConf.load(args.base_config) 97 | second_conf = OmegaConf.load(args.config) 98 | cli_conf = OmegaConf.from_cli() 99 | args = OmegaConf.merge(base_conf, second_conf, cli_conf) 100 | args.resolution_scales = args.resolution_scales[:1] 101 | print(args) 102 | 103 | seed_everything(args.seed) 104 | 105 | sep_path = os.path.join(args.model_path, 'separation') 106 | os.makedirs(sep_path, exist_ok=True) 107 | 108 | gaussians = GaussianModel(args) 109 | scene = Scene(args, gaussians, shuffle=False) 110 | 111 | if args.env_map_res > 0: 112 | env_map = EnvLight(resolution=args.env_map_res).cuda() 113 | env_map.training_setup(args) 114 | else: 115 | env_map = None 116 | 117 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth")) 118 | assert len(checkpoints) > 0, "No checkpoints found." 119 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1] 120 | (model_params, first_iter) = torch.load(checkpoint) 121 | gaussians.restore(model_params, args) 122 | 123 | if env_map is not None: 124 | env_checkpoint = os.path.join(os.path.dirname(checkpoint), 125 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt")) 126 | (light_params, _) = torch.load(env_checkpoint) 127 | env_map.restore(light_params) 128 | 129 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0] 130 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 131 | evaluation(first_iter, scene, render, (args, background), env_map=env_map) 132 | 133 | print("Evaluation complete.") 134 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from scene.cameras import Camera 17 | from utils.sh_utils import eval_sh 18 | 19 | 20 | def render(viewpoint_camera: Camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, 21 | override_color=None, env_map=None, 22 | time_shift=None, other=[], mask=None, is_training=False): 23 | """ 24 | Render the scene. 25 | 26 | Background tensor (bg_color) must be on GPU! 27 | """ 28 | # Create zero tensor. 
We will use it to make PyTorch return gradients of the 2D (screen-space) means
29 |     screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
30 |     try:
31 |         screenspace_points.retain_grad()
32 |     except:
33 |         pass
34 |
35 |     # Set up rasterization configuration
36 |     if pipe.neg_fov:
37 |         # we find that setting the fov to -1 slightly improves the results
38 |         tanfovx = math.tan(-0.5)
39 |         tanfovy = math.tan(-0.5)
40 |     else:
41 |         tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
42 |         tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
43 |
44 |     raster_settings = GaussianRasterizationSettings(
45 |         image_height=int(viewpoint_camera.image_height),
46 |         image_width=int(viewpoint_camera.image_width),
47 |         tanfovx=tanfovx,
48 |         tanfovy=tanfovy,
49 |         bg=bg_color if env_map is not None else torch.zeros(3, device="cuda"),
50 |         scale_modifier=scaling_modifier,
51 |         viewmatrix=viewpoint_camera.world_view_transform,
52 |         projmatrix=viewpoint_camera.full_proj_transform,
53 |         sh_degree=pc.active_sh_degree,
54 |         campos=viewpoint_camera.camera_center,
55 |         prefiltered=False,
56 |         debug=pipe.debug
57 |     )
58 |
59 |     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
60 |
61 |     means3D = pc.get_xyz
62 |     means2D = screenspace_points
63 |     opacity = pc.get_opacity
64 |     scales = None
65 |     rotations = None
66 |     cov3D_precomp = None
67 |
68 |     if time_shift is not None:
69 |         means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp-time_shift)  # centers under the periodic vibration (SHM) motion model
70 |         means3D = means3D + pc.get_inst_velocity * time_shift
71 |         marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp-time_shift)
72 |     else:
73 |         means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp)
74 |         marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp)
75 |     opacity = opacity * marginal_t  # modulate opacity by the temporal (lifespan) marginal
76 |
77 |     if pipe.compute_cov3D_python:
78 |         cov3D_precomp = pc.get_covariance(scaling_modifier)
79 |     else:
80 |         scales = pc.get_scaling
81 |         rotations = pc.get_rotation
82 |
83 |     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
84 |     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
85 | shs = None 86 | colors_precomp = None 87 | if override_color is None: 88 | if pipe.convert_SHs_python: 89 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, pc.get_max_sh_channels) 90 | dir_pp = (means3D.detach() - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)).detach() 91 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 92 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 93 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 94 | else: 95 | shs = pc.get_features 96 | else: 97 | colors_precomp = override_color 98 | 99 | feature_list = other 100 | 101 | if len(feature_list) > 0: 102 | features = torch.cat(feature_list, dim=1) 103 | S_other = features.shape[1] 104 | else: 105 | features = torch.zeros_like(means3D[:, :0]) 106 | S_other = 0 107 | 108 | # Prefilter 109 | if mask is None: 110 | mask = marginal_t[:, 0] > 0.05 111 | else: 112 | mask = mask & (marginal_t[:, 0] > 0.05) 113 | masked_means3D = means3D[mask] 114 | masked_xyz_homo = torch.cat([masked_means3D, torch.ones_like(masked_means3D[:, :1])], dim=1) 115 | masked_depth = (masked_xyz_homo @ viewpoint_camera.world_view_transform[:, 2:3]) 116 | depth_alpha = torch.zeros(means3D.shape[0], 2, dtype=torch.float32, device=means3D.device) 117 | depth_alpha[mask] = torch.cat([ 118 | masked_depth, 119 | torch.ones_like(masked_depth) 120 | ], dim=1) 121 | features = torch.cat([features, depth_alpha], dim=1) 122 | 123 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 124 | contrib, rendered_image, rendered_feature, radii = rasterizer( 125 | means3D = means3D, 126 | means2D = means2D, 127 | shs = shs, 128 | colors_precomp = colors_precomp, 129 | features = features, 130 | opacities = opacity, 131 | scales = scales, 132 | rotations = rotations, 133 | cov3D_precomp = cov3D_precomp, 134 | mask = mask) 135 | 136 | rendered_other, rendered_depth, rendered_opacity = rendered_feature.split([S_other, 1, 1], dim=0) 137 | rendered_image_before = rendered_image 138 | if env_map is not None: 139 | bg_color_from_envmap = env_map(viewpoint_camera.get_world_directions(is_training).permute(1, 2, 0)).permute(2, 0, 1) 140 | rendered_image = rendered_image + (1 - rendered_opacity) * bg_color_from_envmap 141 | 142 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 143 | # They will be excluded from value updates used in the splitting criteria. 144 | return {"render": rendered_image, 145 | "render_nobg": rendered_image_before, 146 | "viewspace_points": screenspace_points, 147 | "visibility_filter": radii > 0, 148 | "radii": radii, 149 | "contrib": contrib, 150 | "depth": rendered_depth, 151 | "alpha": rendered_opacity, 152 | "feature": rendered_other} 153 | 154 | 155 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
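
    Example (an illustrative sketch only -- `x`/`y` are [B, 3, H, W] image
    tensors on the same device; random inputs here just demonstrate the call):
        >>> x = torch.rand(1, 3, 64, 64)
        >>> y = torch.rand(1, 3, 64, 64)
        >>> score = lpips(x, y, net_type='vgg')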
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y).mean() 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
71 |         self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 |         self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 |         self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 |     def __init__(self):
79 |         super(AlexNet, self).__init__()
80 |
81 |         self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
82 |         self.target_layers = [2, 5, 8, 10, 12]
83 |         self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 |         self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 |     def __init__(self):
90 |         super(VGG16, self).__init__()
91 |
92 |         self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 |         self.target_layers = [4, 9, 16, 23, 30]
94 |         self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 |         self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 |     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 |     return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 |     # build url
13 |     url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 |         + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 |     # download
17 |     old_state_dict = torch.hub.load_state_dict_from_url(
18 |         url, progress=True,
19 |         map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 |     )
21 |
22 |     # rename keys
23 |     new_state_dict = OrderedDict()
24 |     for key, val in old_state_dict.items():
25 |         new_key = key
26 |         new_key = new_key.replace('lin', '')
27 |         new_key = new_key.replace('model.', '')
28 |         new_state_dict[new_key] = val
29 |
30 |     return new_state_dict
--------------------------------------------------------------------------------
/requirements-data.txt:
--------------------------------------------------------------------------------
1 | waymo-open-dataset-tf-2-4-0
2 | mmcv-full
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision==0.15.2
3 | tqdm
4 | imageio
5 | imageio[ffmpeg]
6 | kornia
7 | trimesh
8 | Pillow
9 | ninja
10 | omegaconf
11 | plyfile
12 | opencv_python
13 | opencv_contrib_python
14 | tensorboardX
15 | matplotlib
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.gaussian_model import GaussianModel 17 | from scene.envlight import EnvLight 18 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 19 | from scene.waymo_loader import readWaymoInfo 20 | from scene.kittimot_loader import readKittiMotInfo 21 | 22 | sceneLoadTypeCallbacks = { 23 | "Waymo": readWaymoInfo, 24 | "KittiMot": readKittiMotInfo, 25 | } 26 | 27 | class Scene: 28 | 29 | gaussians : GaussianModel 30 | 31 | def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True): 32 | self.model_path = args.model_path 33 | self.loaded_iter = None 34 | self.gaussians = gaussians 35 | self.white_background = args.white_background 36 | 37 | if load_iteration: 38 | if load_iteration == -1: 39 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 40 | else: 41 | self.loaded_iter = load_iteration 42 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 43 | 44 | self.train_cameras = {} 45 | self.test_cameras = {} 46 | 47 | scene_info = sceneLoadTypeCallbacks[args.scene_type](args) 48 | 49 | self.time_interval = args.frame_interval 50 | self.gaussians.time_duration = scene_info.time_duration 51 | print("time duration: ", scene_info.time_duration) 52 | print("frame interval: ", self.time_interval) 53 | 54 | if not self.loaded_iter: 55 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 56 | dest_file.write(src_file.read()) 57 | json_cams = [] 58 | camlist = [] 59 | if scene_info.test_cameras: 60 | camlist.extend(scene_info.test_cameras) 61 | if scene_info.train_cameras: 62 | camlist.extend(scene_info.train_cameras) 63 | for id, cam in enumerate(camlist): 64 | json_cams.append(camera_to_JSON(id, cam)) 65 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 66 | json.dump(json_cams, file) 67 | 68 | if shuffle: 69 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 70 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 71 | 72 | self.cameras_extent = scene_info.nerf_normalization["radius"] 73 | self.resolution_scales = args.resolution_scales 74 | self.scale_index = len(self.resolution_scales) - 1 75 | for resolution_scale in self.resolution_scales: 76 | print("Loading Training Cameras") 77 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 78 | print("Loading Test Cameras") 79 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 80 | 81 | if self.loaded_iter: 82 | self.gaussians.load_ply(os.path.join(self.model_path, 83 | "point_cloud", 84 | "iteration_" + str(self.loaded_iter), 85 | "point_cloud.ply")) 86 | else: 87 | self.gaussians.create_from_pcd(scene_info.point_cloud, 1) 88 | 89 | def upScale(self): 90 | self.scale_index = max(0, self.scale_index - 1) 91 | 92 | def getTrainCameras(self): 93 | return self.train_cameras[self.resolution_scales[self.scale_index]] 94 | 95 | def getTestCameras(self, scale=1.0): 96 | return self.test_cameras[scale] 97 | 98 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 
| # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | import numpy as np 17 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrixCenterShift 18 | import kornia 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__(self, colmap_id, R, T, FoVx=None, FoVy=None, cx=None, cy=None, fx=None, fy=None, 23 | image=None, 24 | image_name=None, uid=0, 25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", timestamp=0.0, 26 | resolution=None, image_path=None, 27 | pts_depth=None, sky_mask=None 28 | ): 29 | super(Camera, self).__init__() 30 | 31 | self.uid = uid 32 | self.colmap_id = colmap_id 33 | self.R = R 34 | self.T = T 35 | self.FoVx = FoVx 36 | self.FoVy = FoVy 37 | self.image_name = image_name 38 | self.image = image 39 | self.cx = cx 40 | self.cy = cy 41 | self.fx = fx 42 | self.fy = fy 43 | self.resolution = resolution 44 | self.image_path = image_path 45 | 46 | try: 47 | self.data_device = torch.device(data_device) 48 | except Exception as e: 49 | print(e) 50 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 51 | self.data_device = torch.device("cuda") 52 | 53 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 54 | self.sky_mask = sky_mask.to(self.data_device) > 0 if sky_mask is not None else sky_mask 55 | self.pts_depth = pts_depth.to(self.data_device) if pts_depth is not None else pts_depth 56 | 57 | self.image_width = resolution[0] 58 | self.image_height = resolution[1] 59 | 60 | self.zfar = 1000.0 61 | self.znear = 0.01 62 | 63 | self.trans = trans 64 | self.scale = scale 65 | 66 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 67 | if cx is not None: 68 | self.FoVx = 2 * math.atan(0.5*self.image_width / fx) 69 | self.FoVy = 2 * math.atan(0.5*self.image_height / fy) 70 | self.projection_matrix = getProjectionMatrixCenterShift(self.znear, self.zfar, cx, cy, fx, fy, 71 | self.image_width, self.image_height).transpose(0, 1).cuda() 72 | else: 73 | self.cx = self.image_width / 2 74 | self.cy = self.image_height / 2 75 | self.fx = self.image_width / (2 * np.tan(self.FoVx * 0.5)) 76 | self.fy = self.image_height / (2 * np.tan(self.FoVy * 0.5)) 77 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, 78 | fovY=self.FoVy).transpose(0, 1).cuda() 79 | self.full_proj_transform = ( 80 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 81 | self.camera_center = self.world_view_transform.inverse()[3, :3] 82 | self.c2w = self.world_view_transform.transpose(0, 1).inverse() 83 | self.timestamp = timestamp 84 | self.grid = kornia.utils.create_meshgrid(self.image_height, self.image_width, normalized_coordinates=False, device='cuda')[0] 85 | 86 | def get_world_directions(self, train=False): 87 | u, v = self.grid.unbind(-1) 88 | if train: 89 | directions = torch.stack([(u-self.cx+torch.rand_like(u))/self.fx, 90 | (v-self.cy+torch.rand_like(v))/self.fy, 91 | torch.ones_like(u)], dim=0) 92 | else: 93 | directions = torch.stack([(u-self.cx+0.5)/self.fx, 94 | (v-self.cy+0.5)/self.fy, 95 | 
torch.ones_like(u)], dim=0) 96 | directions = F.normalize(directions, dim=0) 97 | directions = (self.c2w[:3, :3] @ directions.reshape(3, -1)).reshape(3, self.image_height, self.image_width) 98 | return directions 99 | 100 | -------------------------------------------------------------------------------- /scene/envlight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import nvdiffrast.torch as dr 3 | 4 | 5 | class EnvLight(torch.nn.Module): 6 | 7 | def __init__(self, resolution=1024): 8 | super().__init__() 9 | self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda") 10 | self.base = torch.nn.Parameter( 11 | 0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True), 12 | ) 13 | 14 | def capture(self): 15 | return ( 16 | self.base, 17 | self.optimizer.state_dict(), 18 | ) 19 | 20 | def restore(self, model_args, training_args=None): 21 | self.base, opt_dict = model_args 22 | if training_args is not None: 23 | self.training_setup(training_args) 24 | self.optimizer.load_state_dict(opt_dict) 25 | 26 | def training_setup(self, training_args): 27 | self.optimizer = torch.optim.Adam(self.parameters(), lr=training_args.envmap_lr, eps=1e-15) 28 | 29 | def forward(self, l): 30 | l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape) 31 | l = l.contiguous() 32 | prefix = l.shape[:-1] 33 | if len(prefix) != 3: # reshape to [B, H, W, -1] 34 | l = l.reshape(1, 1, -1, l.shape[-1]) 35 | 36 | light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube') 37 | light = light.view(*prefix, -1) 38 | 39 | return light 40 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import math 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation, get_step_lr_func 15 | from torch import nn 16 | import os 17 | from plyfile import PlyData, PlyElement 18 | from utils.sh_utils import RGB2SH 19 | from simple_knn._C import distCUDA2 20 | from utils.graphics_utils import BasicPointCloud 21 | from utils.general_utils import strip_symmetric, build_scaling_rotation 22 | 23 | class GaussianModel: 24 | 25 | def setup_functions(self): 26 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 27 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 28 | actual_covariance = L @ L.transpose(1, 2) 29 | symm = strip_symmetric(actual_covariance) 30 | return symm 31 | 32 | self.scaling_activation = torch.exp 33 | self.scaling_inverse_activation = torch.log 34 | 35 | self.scaling_t_activation = torch.exp 36 | self.scaling_t_inverse_activation = torch.log 37 | 38 | self.covariance_activation = build_covariance_from_scaling_rotation 39 | 40 | self.opacity_activation = torch.sigmoid 41 | self.inverse_opacity_activation = inverse_sigmoid 42 | 43 | self.rotation_activation = torch.nn.functional.normalize 44 | 45 | def __init__(self, args): 46 | self.active_sh_degree = 0 47 | self.max_sh_degree = args.sh_degree 48 | self._xyz = torch.empty(0) 49 | self._features_dc = torch.empty(0) 50 | self._features_rest = torch.empty(0) 51 | self._scaling = torch.empty(0) 52 | self._rotation = torch.empty(0) 53 | self._opacity = torch.empty(0) 54 | self._t = torch.empty(0) 55 | self._scaling_t = torch.empty(0) 56 | self._velocity = torch.empty(0) 57 | 58 | self.max_radii2D = torch.empty(0) 59 | self.xyz_gradient_accum = torch.empty(0) 60 | self.t_gradient_accum = torch.empty(0) 61 | self.denom = torch.empty(0) 62 | 63 | self.optimizer = None 64 | self.percent_dense = 0 65 | self.spatial_lr_scale = 0 66 | 67 | self.time_duration = args.time_duration 68 | self.no_time_split = args.no_time_split 69 | self.t_grad = args.t_grad 70 | self.contract = args.contract 71 | self.t_init = args.t_init 72 | self.big_point_threshold = args.big_point_threshold 73 | 74 | self.T = args.cycle 75 | self.velocity_decay = args.velocity_decay 76 | self.random_init_point = args.random_init_point 77 | 78 | self.setup_functions() 79 | 80 | def capture(self): 81 | return ( 82 | self.active_sh_degree, 83 | self._xyz, 84 | self._features_dc, 85 | self._features_rest, 86 | self._scaling, 87 | self._rotation, 88 | self._opacity, 89 | self._t, 90 | self._scaling_t, 91 | self._velocity, 92 | self.max_radii2D, 93 | self.xyz_gradient_accum, 94 | self.t_gradient_accum, 95 | self.denom, 96 | self.optimizer.state_dict(), 97 | self.spatial_lr_scale, 98 | self.T, 99 | self.velocity_decay, 100 | ) 101 | 102 | def restore(self, model_args, training_args=None): 103 | (self.active_sh_degree, 104 | self._xyz, 105 | self._features_dc, 106 | self._features_rest, 107 | self._scaling, 108 | self._rotation, 109 | self._opacity, 110 | self._t, 111 | self._scaling_t, 112 | self._velocity, 113 | self.max_radii2D, 114 | xyz_gradient_accum, 115 | t_gradient_accum, 116 | denom, 117 | opt_dict, 118 | self.spatial_lr_scale, 119 | self.T, 120 | self.velocity_decay, 121 | ) = model_args 122 | self.setup_functions() 123 | if training_args is not None: 124 | self.training_setup(training_args) 125 | self.xyz_gradient_accum = xyz_gradient_accum 126 | self.t_gradient_accum = 
t_gradient_accum 127 | self.denom = denom 128 | self.optimizer.load_state_dict(opt_dict) 129 | 130 | @property 131 | def get_scaling(self): 132 | return self.scaling_activation(self._scaling) 133 | 134 | @property 135 | def get_scaling_t(self): 136 | return self.scaling_t_activation(self._scaling_t) 137 | 138 | @property 139 | def get_rotation(self): 140 | return self.rotation_activation(self._rotation) 141 | 142 | def get_xyz_SHM(self, t): 143 | a = 1/self.T * np.pi * 2 144 | return self._xyz + self._velocity*torch.sin((t-self._t)*a)/a 145 | 146 | @property 147 | def get_inst_velocity(self): 148 | return self._velocity*torch.exp(-self.get_scaling_t/self.T/2*self.velocity_decay) 149 | 150 | @property 151 | def get_xyz(self): 152 | return self._xyz 153 | 154 | @property 155 | def get_t(self): 156 | return self._t 157 | 158 | @property 159 | def get_features(self): 160 | features_dc = self._features_dc 161 | features_rest = self._features_rest 162 | return torch.cat((features_dc, features_rest), dim=1) 163 | 164 | @property 165 | def get_opacity(self): 166 | return self.opacity_activation(self._opacity) 167 | 168 | @property 169 | def get_max_sh_channels(self): 170 | return (self.max_sh_degree + 1) ** 2 171 | 172 | def get_marginal_t(self, timestamp): 173 | return torch.exp(-0.5 * (self.get_t - timestamp) ** 2 / self.get_scaling_t ** 2) 174 | 175 | def get_covariance(self, scaling_modifier=1): 176 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 177 | 178 | def oneupSHdegree(self): 179 | if self.active_sh_degree < self.max_sh_degree: 180 | self.active_sh_degree += 1 181 | 182 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): 183 | self.spatial_lr_scale = spatial_lr_scale 184 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 185 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 186 | features = torch.zeros((fused_color.shape[0], 3, self.get_max_sh_channels)).float().cuda() 187 | features[:, :3, 0] = fused_color 188 | features[:, 3:, 1:] = 0.0 189 | 190 | ## random up and far 191 | r_max = 100000 192 | r_min = 2 193 | num_sph = self.random_init_point 194 | 195 | theta = 2*torch.pi*torch.rand(num_sph) 196 | phi = (torch.pi/2*0.99*torch.rand(num_sph))**1.5 # x**a decay 197 | s = torch.rand(num_sph) 198 | r_1 = s*1/r_min+(1-s)*1/r_max 199 | r = 1/r_1 200 | pts_sph = torch.stack([r*torch.cos(theta)*torch.cos(phi), r*torch.sin(theta)*torch.cos(phi), r*torch.sin(phi)],dim=-1).cuda() 201 | 202 | r_rec = r_min 203 | num_rec = self.random_init_point 204 | pts_rec = torch.stack([r_rec*(torch.rand(num_rec)-0.5),r_rec*(torch.rand(num_rec)-0.5), 205 | r_rec*(torch.rand(num_rec))],dim=-1).cuda() 206 | 207 | pts_sph = torch.cat([pts_rec, pts_sph], dim=0) 208 | pts_sph[:,2] = -pts_sph[:,2]+1 209 | 210 | fused_point_cloud = torch.cat([fused_point_cloud, pts_sph], dim=0) 211 | features = torch.cat([features, 212 | torch.zeros([pts_sph.size(0), features.size(1), features.size(2)]).float().cuda()], 213 | dim=0) 214 | 215 | if pcd.time is None or pcd.time.shape[0] != fused_point_cloud.shape[0]: 216 | if pcd.time is None: 217 | time = (np.random.rand(pcd.points.shape[0], 1) * 1.2 - 0.1) * ( 218 | self.time_duration[1] - self.time_duration[0]) + self.time_duration[0] 219 | else: 220 | time = pcd.time 221 | 222 | if self.t_init < 1: 223 | random_times = (torch.rand(fused_point_cloud.shape[0]-pcd.points.shape[0], 1, device="cuda") * 1.2 - 0.1) * ( 224 | self.time_duration[1] - self.time_duration[0]) + 
self.time_duration[0] 225 | pts_times = torch.from_numpy(time.copy()).float().cuda() 226 | fused_times = torch.cat([pts_times, random_times], dim=0) 227 | else: 228 | fused_times = torch.full_like(fused_point_cloud[..., :1], 229 | 0.5 * (self.time_duration[1] + self.time_duration[0])) 230 | else: 231 | fused_times = torch.from_numpy(np.asarray(pcd.time.copy())).cuda().float() 232 | fused_times_sh = torch.full_like(pts_sph[..., :1], 0.5 * (self.time_duration[1] + self.time_duration[0])) 233 | fused_times = torch.cat([fused_times, fused_times_sh], dim=0) 234 | 235 | print("Number of points at initialization : ", fused_point_cloud.shape[0]) 236 | 237 | dist2 = torch.clamp_min(distCUDA2(fused_point_cloud), 0.0000001) 238 | scales = self.scaling_inverse_activation(torch.sqrt(dist2))[..., None].repeat(1, 3) 239 | 240 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 241 | rots[:, 0] = 1 242 | 243 | dist_t = torch.full_like(fused_times, (self.time_duration[1] - self.time_duration[0])*self.t_init) 244 | scales_t = self.scaling_t_inverse_activation(torch.sqrt(dist_t)) 245 | velocity = torch.full((fused_point_cloud.shape[0], 3), 0., device="cuda") 246 | 247 | opacities = inverse_sigmoid(0.01 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 248 | 249 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 250 | self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) 251 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) 252 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 253 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 254 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 255 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 256 | self._t = nn.Parameter(fused_times.requires_grad_(True)) 257 | self._scaling_t = nn.Parameter(scales_t.requires_grad_(True)) 258 | self._velocity = nn.Parameter(velocity.requires_grad_(True)) 259 | 260 | def training_setup(self, training_args): 261 | self.percent_dense = training_args.percent_dense 262 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 263 | self.t_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 264 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 265 | 266 | l = [ 267 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 268 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 269 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 270 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 271 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 272 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}, 273 | {'params': [self._t], 'lr': training_args.t_lr_init, "name": "t"}, 274 | {'params': [self._scaling_t], 'lr': training_args.scaling_t_lr, "name": "scaling_t"}, 275 | {'params': [self._velocity], 'lr': training_args.velocity_lr * self.spatial_lr_scale, "name": "velocity"}, 276 | ] 277 | 278 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 279 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 280 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 281 | 
lr_delay_mult=training_args.position_lr_delay_mult, 282 | max_steps=training_args.iterations) 283 | 284 | final_decay = training_args.position_lr_final / training_args.position_lr_init 285 | 286 | self.t_scheduler_args = get_expon_lr_func(lr_init=training_args.t_lr_init, 287 | lr_final=training_args.t_lr_init * final_decay, 288 | lr_delay_mult=training_args.position_lr_delay_mult, 289 | max_steps=training_args.iterations) 290 | 291 | def update_learning_rate(self, iteration): 292 | ''' Learning rate scheduling per step ''' 293 | for param_group in self.optimizer.param_groups: 294 | if param_group["name"] == "xyz": 295 | lr = self.xyz_scheduler_args(iteration) 296 | param_group['lr'] = lr 297 | if param_group["name"] == "t": 298 | lr = self.t_scheduler_args(iteration) 299 | param_group['lr'] = lr 300 | 301 | def reset_opacity(self): 302 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)) 303 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 304 | self._opacity = optimizable_tensors["opacity"] 305 | 306 | def replace_tensor_to_optimizer(self, tensor, name): 307 | optimizable_tensors = {} 308 | for group in self.optimizer.param_groups: 309 | if group["name"] == name: 310 | stored_state = self.optimizer.state.get(group['params'][0], None) 311 | stored_state["exp_avg"] = torch.zeros_like(tensor) 312 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 313 | 314 | del self.optimizer.state[group['params'][0]] 315 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 316 | self.optimizer.state[group['params'][0]] = stored_state 317 | 318 | optimizable_tensors[group["name"]] = group["params"][0] 319 | return optimizable_tensors 320 | 321 | def _prune_optimizer(self, mask): 322 | optimizable_tensors = {} 323 | for group in self.optimizer.param_groups: 324 | stored_state = self.optimizer.state.get(group['params'][0], None) 325 | if stored_state is not None: 326 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 327 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 328 | 329 | del self.optimizer.state[group['params'][0]] 330 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 331 | self.optimizer.state[group['params'][0]] = stored_state 332 | 333 | optimizable_tensors[group["name"]] = group["params"][0] 334 | else: 335 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 336 | optimizable_tensors[group["name"]] = group["params"][0] 337 | return optimizable_tensors 338 | 339 | def prune_points(self, mask): 340 | valid_points_mask = ~mask 341 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 342 | 343 | self._xyz = optimizable_tensors["xyz"] 344 | self._features_dc = optimizable_tensors["f_dc"] 345 | self._features_rest = optimizable_tensors["f_rest"] 346 | self._opacity = optimizable_tensors["opacity"] 347 | self._scaling = optimizable_tensors["scaling"] 348 | self._rotation = optimizable_tensors["rotation"] 349 | 350 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 351 | 352 | self.denom = self.denom[valid_points_mask] 353 | self.max_radii2D = self.max_radii2D[valid_points_mask] 354 | 355 | self._t = optimizable_tensors['t'] 356 | self._scaling_t = optimizable_tensors['scaling_t'] 357 | self._velocity = optimizable_tensors['velocity'] 358 | self.t_gradient_accum = self.t_gradient_accum[valid_points_mask] 359 | 360 | def cat_tensors_to_optimizer(self, tensors_dict): 361 | 
optimizable_tensors = {} 362 | for group in self.optimizer.param_groups: 363 | assert len(group["params"]) == 1 364 | extension_tensor = tensors_dict[group["name"]] 365 | stored_state = self.optimizer.state.get(group['params'][0], None) 366 | if stored_state is not None: 367 | 368 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), 369 | dim=0) 370 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), 371 | dim=0) 372 | 373 | del self.optimizer.state[group['params'][0]] 374 | group["params"][0] = nn.Parameter( 375 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 376 | self.optimizer.state[group['params'][0]] = stored_state 377 | 378 | optimizable_tensors[group["name"]] = group["params"][0] 379 | else: 380 | group["params"][0] = nn.Parameter( 381 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 382 | optimizable_tensors[group["name"]] = group["params"][0] 383 | 384 | return optimizable_tensors 385 | 386 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 387 | new_rotation, new_t, new_scaling_t, new_velocity): 388 | d = {"xyz": new_xyz, 389 | "f_dc": new_features_dc, 390 | "f_rest": new_features_rest, 391 | "opacity": new_opacities, 392 | "scaling": new_scaling, 393 | "rotation": new_rotation, 394 | "t": new_t, 395 | "scaling_t": new_scaling_t, 396 | "velocity": new_velocity, 397 | } 398 | 399 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 400 | self._xyz = optimizable_tensors["xyz"] 401 | self._features_dc = optimizable_tensors["f_dc"] 402 | self._features_rest = optimizable_tensors["f_rest"] 403 | self._opacity = optimizable_tensors["opacity"] 404 | self._scaling = optimizable_tensors["scaling"] 405 | self._rotation = optimizable_tensors["rotation"] 406 | self._t = optimizable_tensors['t'] 407 | self._scaling_t = optimizable_tensors['scaling_t'] 408 | self._velocity = optimizable_tensors['velocity'] 409 | self.t_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 410 | 411 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 412 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 413 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 414 | 415 | def densify_and_split(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, N=2, time_split=False, 416 | joint_sample=True): 417 | n_init_points = self.get_xyz.shape[0] 418 | # Extract points that satisfy the gradient condition 419 | padded_grad = torch.zeros((n_init_points), device="cuda") 420 | padded_grad[:grads.shape[0]] = grads.squeeze() 421 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 422 | 423 | if self.contract: 424 | scale_factor = self._xyz.norm(dim=-1)*scene_extent-1 # -0 425 | scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/scene_extent 426 | else: 427 | scale_factor = torch.ones_like(self._xyz)[:,0]/scene_extent 428 | 429 | selected_pts_mask = torch.logical_and(selected_pts_mask, 430 | torch.max(self.get_scaling, 431 | dim=1).values > self.percent_dense * scene_extent*scale_factor) 432 | decay_factor = N*0.8 433 | if not self.no_time_split: 434 | N = N+1 435 | 436 | if time_split: 437 | padded_grad_t = torch.zeros((n_init_points), device="cuda") 438 | padded_grad_t[:grads_t.shape[0]] = grads_t.squeeze() 439 | selected_time_mask = torch.where(padded_grad_t >= 
grad_t_threshold, True, False) 440 | extend_thresh = self.percent_dense 441 | 442 | selected_time_mask = torch.logical_and(selected_time_mask, 443 | torch.max(self.get_scaling_t, dim=1).values > extend_thresh) 444 | if joint_sample: 445 | selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask) 446 | 447 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (decay_factor)) 448 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) 449 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1) 450 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) 451 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) 452 | 453 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1) 454 | means = torch.zeros((stds.size(0), 3), device="cuda") 455 | samples = torch.normal(mean=means, std=stds) 456 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) 457 | xyz = self.get_xyz[selected_pts_mask] 458 | 459 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + xyz.repeat(N, 1) 460 | 461 | new_t = None 462 | new_scaling_t = None 463 | new_velocity = None 464 | stds_t = self.get_scaling_t[selected_pts_mask].repeat(N, 1) 465 | means_t = torch.zeros((stds_t.size(0), 1), device="cuda") 466 | samples_t = torch.normal(mean=means_t, std=stds_t) 467 | new_t = samples_t+self.get_t[selected_pts_mask].repeat(N, 1) 468 | 469 | new_scaling_t = self.scaling_t_inverse_activation( 470 | self.get_scaling_t[selected_pts_mask].repeat(N, 1)/ (decay_factor)) 471 | 472 | 473 | 474 | new_velocity = self._velocity[selected_pts_mask].repeat(N, 1) 475 | 476 | new_xyz = new_xyz + self.get_inst_velocity[selected_pts_mask].repeat(N, 1) * (samples_t) 477 | 478 | not_split_xyz_mask = torch.max(self.get_scaling[selected_pts_mask], dim=1).values < \ 479 | self.percent_dense * scene_extent*scale_factor[selected_pts_mask] 480 | new_scaling[not_split_xyz_mask.repeat(N)] = self.scaling_inverse_activation( 481 | self.get_scaling[selected_pts_mask].repeat(N, 1))[not_split_xyz_mask.repeat(N)] 482 | 483 | if time_split: 484 | not_split_t_mask = self.get_scaling_t[selected_pts_mask].squeeze() < extend_thresh 485 | new_scaling_t[not_split_t_mask.repeat(N)] = self.scaling_t_inverse_activation( 486 | self.get_scaling_t[selected_pts_mask].repeat(N, 1))[not_split_t_mask.repeat(N)] 487 | 488 | if self.no_time_split: 489 | new_scaling_t = self.scaling_t_inverse_activation( 490 | self.get_scaling_t[selected_pts_mask].repeat(N, 1)) 491 | 492 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, 493 | new_t, new_scaling_t, new_velocity) 494 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 495 | self.prune_points(prune_filter) 496 | 497 | def densify_and_clone(self, grads, grad_threshold, scene_extent, grads_t, grad_t_threshold, time_clone=False): 498 | t_scale_factor=self.get_scaling_t.clamp(0,self.T) 499 | t_scale_factor=torch.exp(-t_scale_factor/self.T).squeeze() 500 | 501 | if self.contract: 502 | scale_factor = self._xyz.norm(dim=-1)*scene_extent-1 503 | scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/scene_extent 504 | else: 505 | scale_factor = torch.ones_like(self._xyz)[:,0]/scene_extent 506 | 507 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 508 | selected_pts_mask = 
torch.logical_and(selected_pts_mask,torch.max(self.get_scaling,dim=1).values <= self.percent_dense * scene_extent*scale_factor)
509 |         if time_clone:
510 |             selected_time_mask = torch.where(torch.norm(grads_t, dim=-1) >= grad_t_threshold, True, False)
511 |             extend_thresh = self.percent_dense
512 |             selected_time_mask = torch.logical_and(selected_time_mask,
513 |                                                    torch.max(self.get_scaling_t, dim=1).values <= extend_thresh)
514 |             selected_pts_mask = torch.logical_or(selected_pts_mask, selected_time_mask)
515 | 
516 |         new_xyz = self._xyz[selected_pts_mask]
517 |         new_features_dc = self._features_dc[selected_pts_mask]
518 |         new_features_rest = self._features_rest[selected_pts_mask]
519 |         new_opacities = self._opacity[selected_pts_mask]
520 |         new_scaling = self._scaling[selected_pts_mask]
521 |         new_rotation = self._rotation[selected_pts_mask]
522 |         new_t = None
523 |         new_scaling_t = None
524 |         new_velocity = None
525 |         new_t = self._t[selected_pts_mask]
526 |         new_scaling_t = self._scaling_t[selected_pts_mask]
527 |         new_velocity = self._velocity[selected_pts_mask]
528 | 
529 |         self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling,
530 |                                    new_rotation, new_t, new_scaling_t, new_velocity)
531 | 
532 |     def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, max_grad_t=None, prune_only=False):
533 |         if not prune_only:
534 |             grads = self.xyz_gradient_accum / self.denom
535 |             grads[grads.isnan()] = 0.0
536 |             grads_t = self.t_gradient_accum / self.denom
537 |             grads_t[grads_t.isnan()] = 0.0
538 | 
539 |             if self.t_grad:
540 |                 self.densify_and_clone(grads, max_grad, extent, grads_t, max_grad_t, time_clone=True)
541 |                 self.densify_and_split(grads, max_grad, extent, grads_t, max_grad_t, time_split=True)
542 |             else:
543 |                 self.densify_and_clone(grads, max_grad, extent, grads_t, max_grad_t, time_clone=False)
544 |                 self.densify_and_split(grads, max_grad, extent, grads_t, max_grad_t, time_split=False)
545 | 
546 |         prune_mask = (self.get_opacity < min_opacity).squeeze()
547 | 
548 |         if self.contract:
549 |             scale_factor = self._xyz.norm(dim=-1)*extent-1
550 |             scale_factor = torch.where(scale_factor<=1, 1, scale_factor)/extent
551 |         else:
552 |             scale_factor = torch.ones_like(self._xyz)[:,0]/extent
553 | 
554 |         if max_screen_size:
555 |             big_points_vs = self.max_radii2D > max_screen_size
556 |             big_points_ws = self.get_scaling.max(dim=1).values > self.big_point_threshold * extent * scale_factor  # big_point_threshold replaces the original fixed 0.1
557 |             prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
558 |         self.prune_points(prune_mask)
559 | 
560 |         torch.cuda.empty_cache()
561 | 
562 |     def add_densification_stats(self, viewspace_point_tensor, update_filter):
563 |         self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1,
564 |                                                              keepdim=True)
565 |         self.denom[update_filter] += 1
566 |         self.t_gradient_accum[update_filter] += self._t.grad.clone()[update_filter]
567 | 
-------------------------------------------------------------------------------- /scene/kittimot_loader.py: --------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | 
4 | import imageio
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from PIL import Image
9 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
10 | from pathlib import Path
11 | camera_ls = [2, 3]
12 | 
13 | """
14 | Most functions are adapted from MARS
15 | https://github.com/OPEN-AIR-SUN/mars/blob/69b9bf9d992e6b9f4027dfdc2a741c2a33eef174/mars/data/mars_kitti_dataparser.py
16 | """
17 | 
18 | def pad_poses(p):
19 |     """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
20 |     bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
21 |     return np.concatenate([p[..., :3, :4], bottom], axis=-2)
22 | 
23 | 
24 | def unpad_poses(p):
25 |     """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
26 |     return p[..., :3, :4]
27 | 
28 | 
29 | def transform_poses_pca(poses, fix_radius=0):
30 |     """Transforms poses so principal components lie on XYZ axes.
31 | 
32 |     Args:
33 |         poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
34 | 
35 |     Returns:
36 |         A tuple (poses, transform, scale_factor), with the transformed poses, the
37 |         applied camera_to_world transform, and the translation scale factor.
38 | 
39 |     From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
40 |     """
41 |     t = poses[:, :3, 3]
42 |     t_mean = t.mean(axis=0)
43 |     t = t - t_mean
44 | 
45 |     eigval, eigvec = np.linalg.eig(t.T @ t)
46 |     # Sort eigenvectors in order of largest to smallest eigenvalue.
47 |     inds = np.argsort(eigval)[::-1]
48 |     eigvec = eigvec[:, inds]
49 |     rot = eigvec.T
50 |     if np.linalg.det(rot) < 0:
51 |         rot = np.diag(np.array([1, 1, -1])) @ rot
52 | 
53 |     transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
54 |     poses_recentered = unpad_poses(transform @ pad_poses(poses))
55 |     transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
56 | 
57 |     # Flip coordinate system if z component of y-axis is negative
58 |     if poses_recentered.mean(axis=0)[2, 1] < 0:
59 |         poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
60 |         transform = np.diag(np.array([1, -1, -1, 1])) @ transform
61 | 
62 |     # Just make sure it's in the [-1, 1]^3 cube
63 |     if fix_radius>0:
64 |         scale_factor = 1./fix_radius
65 |     else:
66 |         scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
67 |         scale_factor = min(1 / 10, scale_factor)
68 | 
69 |     poses_recentered[:, :3, 3] *= scale_factor
70 |     transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
71 | 
72 |     return poses_recentered, transform, scale_factor
73 | 
74 | def kitti_string_to_float(str):
75 |     return float(str.split("e")[0]) * 10 ** int(str.split("e")[1])  # e.g. "7.215377e+02" -> 721.5377
76 | 
77 | 
78 | def get_rotation(roll, pitch, heading):
79 |     s_heading = np.sin(heading)
80 |     c_heading = np.cos(heading)
81 |     rot_z = np.array([[c_heading, -s_heading, 0], [s_heading, c_heading, 0], [0, 0, 1]])
82 | 
83 |     s_pitch = np.sin(pitch)
84 |     c_pitch = np.cos(pitch)
85 |     rot_y = np.array([[c_pitch, 0, s_pitch], [0, 1, 0], [-s_pitch, 0, c_pitch]])
86 | 
87 |     s_roll = np.sin(roll)
88 |     c_roll = np.cos(roll)
89 |     rot_x = np.array([[1, 0, 0], [0, c_roll, -s_roll], [0, s_roll, c_roll]])
90 | 
91 |     rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))
92 | 
93 |     return rot
94 | 
95 | 
96 | def tracking_calib_from_txt(calibration_path):
97 |     """
98 |     Extract tracking calibration information from a KITTI tracking calibration file.
99 | 
100 |     This function reads a KITTI tracking calibration file and extracts the relevant
101 |     calibration information, including projection matrices and transformation matrices
102 |     for camera, LiDAR, and IMU coordinate systems.
103 | 
104 |     Args:
105 |         calibration_path (str): Path to the KITTI tracking calibration file.
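            Each data line is expected to look like 'KEY: v1 v2 ... vN' (the four
            camera projection matrices, the rectification rotation, and the
            LiDAR/IMU transforms, in that order); the leading key token is dropped
            and the remaining values are parsed with kitti_string_to_float.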
106 | 107 | Returns: 108 | dict: A dictionary containing the following calibration information: 109 | P0, P1, P2, P3 (np.array): 3x4 projection matrices for the cameras. 110 | Tr_cam2camrect (np.array): 4x4 transformation matrix from camera to rectified camera coordinates. 111 | Tr_velo2cam (np.array): 4x4 transformation matrix from LiDAR to camera coordinates. 112 | Tr_imu2velo (np.array): 4x4 transformation matrix from IMU to LiDAR coordinates. 113 | """ 114 | # Read the calibration file 115 | f = open(calibration_path) 116 | calib_str = f.read().splitlines() 117 | 118 | # Process the calibration data 119 | calibs = [] 120 | for calibration in calib_str: 121 | calibs.append(np.array([kitti_string_to_float(val) for val in calibration.split()[1:]])) 122 | 123 | # Extract the projection matrices 124 | P0 = np.reshape(calibs[0], [3, 4]) 125 | P1 = np.reshape(calibs[1], [3, 4]) 126 | P2 = np.reshape(calibs[2], [3, 4]) 127 | P3 = np.reshape(calibs[3], [3, 4]) 128 | 129 | # Extract the transformation matrix for camera to rectified camera coordinates 130 | Tr_cam2camrect = np.eye(4) 131 | R_rect = np.reshape(calibs[4], [3, 3]) 132 | Tr_cam2camrect[:3, :3] = R_rect 133 | 134 | # Extract the transformation matrices for LiDAR to camera and IMU to LiDAR coordinates 135 | Tr_velo2cam = np.concatenate([np.reshape(calibs[5], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0) 136 | Tr_imu2velo = np.concatenate([np.reshape(calibs[6], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0) 137 | 138 | return { 139 | "P0": P0, 140 | "P1": P1, 141 | "P2": P2, 142 | "P3": P3, 143 | "Tr_cam2camrect": Tr_cam2camrect, 144 | "Tr_velo2cam": Tr_velo2cam, 145 | "Tr_imu2velo": Tr_imu2velo, 146 | } 147 | 148 | 149 | def calib_from_txt(calibration_path): 150 | """ 151 | Read the calibration files and extract the required transformation matrices and focal length. 152 | 153 | Args: 154 | calibration_path (str): The path to the directory containing the calibration files. 155 | 156 | Returns: 157 | tuple: A tuple containing the following elements: 158 | traimu2v (np.array): 4x4 transformation matrix from IMU to Velodyne coordinates. 159 | v2c (np.array): 4x4 transformation matrix from Velodyne to left camera coordinates. 160 | c2leftRGB (np.array): 4x4 transformation matrix from left camera to rectified left camera coordinates. 161 | c2rightRGB (np.array): 4x4 transformation matrix from right camera to rectified right camera coordinates. 162 | focal (float): Focal length of the left camera. 
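    Example (a minimal sketch; the directory is assumed to contain
    calib_cam_to_cam.txt, calib_velo_to_cam.txt and calib_imu_to_velo.txt,
    as in the KITTI raw layout):

        imu2v, v2c, c2leftRGB, c2rightRGB, focal = calib_from_txt("data/kitti_raw/2011_09_26")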
163 | """ 164 | c2c = [] 165 | 166 | # Read and parse the camera-to-camera calibration file 167 | f = open(os.path.join(calibration_path, "calib_cam_to_cam.txt"), "r") 168 | cam_to_cam_str = f.read() 169 | [left_cam, right_cam] = cam_to_cam_str.split("S_02: ")[1].split("S_03: ") 170 | cam_to_cam_ls = [left_cam, right_cam] 171 | 172 | # Extract the transformation matrices for left and right cameras 173 | for i, cam_str in enumerate(cam_to_cam_ls): 174 | r_str, t_str = cam_str.split("R_0" + str(i + 2) + ": ")[1].split("\nT_0" + str(i + 2) + ": ") 175 | t_str = t_str.split("\n")[0] 176 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")]) 177 | R = np.reshape(R, [3, 3]) 178 | t = np.array([kitti_string_to_float(t) for t in t_str.split(" ")]) 179 | Tr = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]) 180 | 181 | t_str_rect, s_rect_part = cam_str.split("\nT_0" + str(i + 2) + ": ")[1].split("\nS_rect_0" + str(i + 2) + ": ") 182 | s_rect_str, r_rect_part = s_rect_part.split("\nR_rect_0" + str(i + 2) + ": ") 183 | r_rect_str = r_rect_part.split("\nP_rect_0" + str(i + 2) + ": ")[0] 184 | R_rect = np.array([kitti_string_to_float(r) for r in r_rect_str.split(" ")]) 185 | R_rect = np.reshape(R_rect, [3, 3]) 186 | t_rect = np.array([kitti_string_to_float(t) for t in t_str_rect.split(" ")]) 187 | Tr_rect = np.concatenate( 188 | [np.concatenate([R_rect, t_rect[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]] 189 | ) 190 | 191 | c2c.append(Tr_rect) 192 | 193 | c2leftRGB = c2c[0] 194 | c2rightRGB = c2c[1] 195 | 196 | # Read and parse the Velodyne-to-camera calibration file 197 | f = open(os.path.join(calibration_path, "calib_velo_to_cam.txt"), "r") 198 | velo_to_cam_str = f.read() 199 | r_str, t_str = velo_to_cam_str.split("R: ")[1].split("\nT: ") 200 | t_str = t_str.split("\n")[0] 201 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")]) 202 | R = np.reshape(R, [3, 3]) 203 | t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")]) 204 | v2c = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]) 205 | 206 | # Read and parse the IMU-to-Velodyne calibration file 207 | f = open(os.path.join(calibration_path, "calib_imu_to_velo.txt"), "r") 208 | imu_to_velo_str = f.read() 209 | r_str, t_str = imu_to_velo_str.split("R: ")[1].split("\nT: ") 210 | R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")]) 211 | R = np.reshape(R, [3, 3]) 212 | t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")]) 213 | imu2v = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]) 214 | 215 | # Extract the focal length of the left camera 216 | focal = kitti_string_to_float(left_cam.split("P_rect_02: ")[1].split()[0]) 217 | 218 | return imu2v, v2c, c2leftRGB, c2rightRGB, focal 219 | 220 | 221 | def get_poses_calibration(basedir, oxts_path_tracking=None, selected_frames=None): 222 | """ 223 | Extract poses and calibration information from the KITTI dataset. 224 | 225 | This function processes the OXTS data (GPS/IMU) and extracts the 226 | pose information (translation and rotation) for each frame. It also 227 | retrieves the calibration information (transformation matrices and focal length) 228 | required for further processing. 229 | 230 | Args: 231 | basedir (str): The base directory containing the KITTI dataset. 232 | oxts_path_tracking (str, optional): Path to the OXTS data file for tracking sequences. 
233 | If not provided, the function will look for OXTS data in the basedir. 234 | selected_frames (list, optional): A list of frame indices to process. 235 | If not provided, all frames in the dataset will be processed. 236 | 237 | Returns: 238 | tuple: A tuple containing the following elements: 239 | poses (np.array): An array of 4x4 pose matrices representing the vehicle's 240 | position and orientation for each frame (IMU pose). 241 | calibrations (dict): A dictionary containing the transformation matrices 242 | and focal length obtained from the calibration files. 243 | focal (float): The focal length of the left camera. 244 | """ 245 | 246 | def oxts_to_pose(oxts): 247 | """ 248 | OXTS (Oxford Technical Solutions) data typically refers to the data generated by an Inertial and GPS Navigation System (INS/GPS) that is used to provide accurate position, orientation, and velocity information for a moving platform, such as a vehicle. In the context of the KITTI dataset, OXTS data is used to provide the ground truth for the vehicle's trajectory and 6 degrees of freedom (6-DoF) motion, which is essential for evaluating and benchmarking various computer vision and robotics algorithms, such as visual odometry, SLAM, and object detection. 249 | 250 | The OXTS data contains several important measurements: 251 | 252 | 1. Latitude, longitude, and altitude: These are the global coordinates of the moving platform. 253 | 2. Roll, pitch, and yaw (heading): These are the orientation angles of the platform, usually given in Euler angles. 254 | 3. Velocity (north, east, and down): These are the linear velocities of the platform in the local navigation frame. 255 | 4. Accelerations (ax, ay, az): These are the linear accelerations in the platform's body frame. 256 | 5. Angular rates (wx, wy, wz): These are the angular rates (also known as angular velocities) of the platform in its body frame. 257 | 258 | In the KITTI dataset, the OXTS data is stored as plain text files with each line corresponding to a timestamp. Each line in the file contains the aforementioned measurements, which are used to compute the ground truth trajectory and 6-DoF motion of the vehicle. This information can be further used for calibration, data synchronization, and performance evaluation of various algorithms. 259 | """ 260 | poses = [] 261 | 262 | def latlon_to_mercator(lat, lon, s): 263 | """ 264 | Converts latitude and longitude coordinates to Mercator coordinates (x, y) using the given scale factor. 265 | 266 | The Mercator projection is a widely used cylindrical map projection that represents the Earth's surface 267 | as a flat, rectangular grid, distorting the size of geographical features in higher latitudes. 268 | This function uses the scale factor 's' to control the amount of distortion in the projection. 269 | 270 | Args: 271 | lat (float): Latitude in degrees, range: -90 to 90. 272 | lon (float): Longitude in degrees, range: -180 to 180. 273 | s (float): Scale factor, typically the cosine of the reference latitude. 274 | 275 | Returns: 276 | list: A list containing the Mercator coordinates [x, y] in meters. 
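            Example (worked, with scale s = 1): latlon_to_mercator(0.0, 180.0, 1.0)
            gives x = 1 * r * pi * 180 / 180 = pi * 6378137.0 and
            y = r * ln(tan(pi * 90 / 360)) = r * ln(1) = 0 at the equator.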
277 | """ 278 | r = 6378137.0 # the Earth's equatorial radius in meters 279 | x = s * r * ((np.pi * lon) / 180) 280 | y = s * r * np.log(np.tan((np.pi * (90 + lat)) / 360)) 281 | return [x, y] 282 | 283 | # Compute the initial scale and pose based on the selected frames 284 | if selected_frames is None: 285 | lat0 = oxts[0][0] 286 | scale = np.cos(lat0 * np.pi / 180) 287 | pose_0_inv = None 288 | else: 289 | oxts0 = oxts[selected_frames[0][0]] 290 | lat0 = oxts0[0] 291 | scale = np.cos(lat0 * np.pi / 180) 292 | 293 | pose_i = np.eye(4) 294 | 295 | [x, y] = latlon_to_mercator(oxts0[0], oxts0[1], scale) 296 | z = oxts0[2] 297 | translation = np.array([x, y, z]) 298 | rotation = get_rotation(oxts0[3], oxts0[4], oxts0[5]) 299 | pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1) 300 | pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3]) 301 | 302 | # Iterate through the OXTS data and compute the corresponding pose matrices 303 | for oxts_val in oxts: 304 | pose_i = np.zeros([4, 4]) 305 | pose_i[3, 3] = 1 306 | 307 | [x, y] = latlon_to_mercator(oxts_val[0], oxts_val[1], scale) 308 | z = oxts_val[2] 309 | translation = np.array([x, y, z]) 310 | 311 | roll = oxts_val[3] 312 | pitch = oxts_val[4] 313 | heading = oxts_val[5] 314 | rotation = get_rotation(roll, pitch, heading) # (3,3) 315 | 316 | pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1) # (4, 4) 317 | if pose_0_inv is None: 318 | pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3]) 319 | 320 | pose_i = np.matmul(pose_0_inv, pose_i) 321 | poses.append(pose_i) 322 | 323 | return np.array(poses) 324 | 325 | # If there is no tracking path specified, use the default path 326 | if oxts_path_tracking is None: 327 | oxts_path = os.path.join(basedir, "oxts/data") 328 | oxts = np.array([np.loadtxt(os.path.join(oxts_path, file)) for file in sorted(os.listdir(oxts_path))]) 329 | calibration_path = os.path.dirname(basedir) 330 | 331 | calibrations = calib_from_txt(calibration_path) 332 | 333 | focal = calibrations[4] 334 | 335 | poses = oxts_to_pose(oxts) 336 | 337 | # If a tracking path is specified, use it to load OXTS data and compute the poses 338 | else: 339 | oxts_tracking = np.loadtxt(oxts_path_tracking) 340 | poses = oxts_to_pose(oxts_tracking) # (n_frames, 4, 4) 341 | calibrations = None 342 | focal = None 343 | # Set velodyne close to z = 0 344 | # poses[:, 2, 3] -= 0.8 345 | 346 | # Return the poses, calibrations, and focal length 347 | return poses, calibrations, focal 348 | 349 | 350 | def invert_transformation(rot, t): 351 | t = np.matmul(-rot.T, t) 352 | inv_translation = np.concatenate([rot.T, t[:, None]], axis=1) 353 | return np.concatenate([inv_translation, np.array([[0.0, 0.0, 0.0, 1.0]])]) 354 | 355 | 356 | def get_camera_poses_tracking(poses_velo_w_tracking, tracking_calibration, selected_frames, scene_no=None): 357 | exp = False 358 | camera_poses = [] 359 | 360 | opengl2kitti = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 361 | 362 | start_frame = selected_frames[0] 363 | end_frame = selected_frames[1] 364 | 365 | ##################### 366 | # Debug Camera offset 367 | if scene_no == 2: 368 | yaw = np.deg2rad(0.7) ## Affects camera rig roll: High --> counterclockwise 369 | pitch = np.deg2rad(-0.5) ## Affects camera rig yaw: High --> Turn Right 370 | # pitch = np.deg2rad(-0.97) 371 | roll = np.deg2rad(0.9) ## Affects camera rig pitch: High --> up 372 | # roll = np.deg2rad(1.2) 373 | elif scene_no == 1: 374 | if exp: 375 | yaw = 
np.deg2rad(0.3) ## Affects camera rig roll: High --> counterclockwise 376 | pitch = np.deg2rad(-0.6) ## Affects camera rig yaw: High --> Turn Right 377 | # pitch = np.deg2rad(-0.97) 378 | roll = np.deg2rad(0.75) ## Affects camera rig pitch: High --> up 379 | # roll = np.deg2rad(1.2) 380 | else: 381 | yaw = np.deg2rad(0.5) ## Affects camera rig roll: High --> counterclockwise 382 | pitch = np.deg2rad(-0.5) ## Affects camera rig yaw: High --> Turn Right 383 | roll = np.deg2rad(0.75) ## Affects camera rig pitch: High --> up 384 | else: 385 | yaw = np.deg2rad(0.05) 386 | pitch = np.deg2rad(-0.75) 387 | # pitch = np.deg2rad(-0.97) 388 | roll = np.deg2rad(1.05) 389 | # roll = np.deg2rad(1.2) 390 | 391 | cam_debug = np.eye(4) 392 | cam_debug[:3, :3] = get_rotation(roll, pitch, yaw) 393 | 394 | Tr_cam2camrect = tracking_calibration["Tr_cam2camrect"] 395 | Tr_cam2camrect = np.matmul(Tr_cam2camrect, cam_debug) 396 | Tr_camrect2cam = invert_transformation(Tr_cam2camrect[:3, :3], Tr_cam2camrect[:3, 3]) 397 | Tr_velo2cam = tracking_calibration["Tr_velo2cam"] 398 | Tr_cam2velo = invert_transformation(Tr_velo2cam[:3, :3], Tr_velo2cam[:3, 3]) 399 | 400 | camera_poses_imu = [] 401 | for cam in camera_ls: 402 | Tr_camrect2cam_i = tracking_calibration["Tr_camrect2cam0" + str(cam)] 403 | Tr_cam_i2camrect = invert_transformation(Tr_camrect2cam_i[:3, :3], Tr_camrect2cam_i[:3, 3]) 404 | # transform camera axis from kitti to opengl for nerf: 405 | cam_i_camrect = np.matmul(Tr_cam_i2camrect, opengl2kitti) 406 | cam_i_cam0 = np.matmul(Tr_camrect2cam, cam_i_camrect) 407 | cam_i_velo = np.matmul(Tr_cam2velo, cam_i_cam0) 408 | 409 | cam_i_w = np.matmul(poses_velo_w_tracking, cam_i_velo) 410 | camera_poses_imu.append(cam_i_w) 411 | 412 | for i, cam in enumerate(camera_ls): 413 | for frame_no in range(start_frame, end_frame + 1): 414 | camera_poses.append(camera_poses_imu[i][frame_no]) 415 | 416 | return np.array(camera_poses) 417 | 418 | 419 | def get_scene_images_tracking(tracking_path, sequence, selected_frames): 420 | [start_frame, end_frame] = selected_frames 421 | img_name = [] 422 | sky_name = [] 423 | 424 | left_img_path = os.path.join(os.path.join(tracking_path, "image_02"), sequence) 425 | right_img_path = os.path.join(os.path.join(tracking_path, "image_03"), sequence) 426 | 427 | left_sky_path = os.path.join(os.path.join(tracking_path, "sky_02"), sequence) 428 | right_sky_path = os.path.join(os.path.join(tracking_path, "sky_03"), sequence) 429 | 430 | for frame_dir in [left_img_path, right_img_path]: 431 | for frame_no in range(len(os.listdir(left_img_path))): 432 | if start_frame <= frame_no <= end_frame: 433 | frame = sorted(os.listdir(frame_dir))[frame_no] 434 | fname = os.path.join(frame_dir, frame) 435 | img_name.append(fname) 436 | 437 | for frame_dir in [left_sky_path, right_sky_path]: 438 | for frame_no in range(len(os.listdir(left_sky_path))): 439 | if start_frame <= frame_no <= end_frame: 440 | frame = sorted(os.listdir(frame_dir))[frame_no] 441 | fname = os.path.join(frame_dir, frame) 442 | sky_name.append(fname) 443 | 444 | return img_name, sky_name 445 | 446 | def rotation_matrix(a, b): 447 | """Compute the rotation matrix that rotates vector a to vector b. 448 | 449 | Args: 450 | a: The vector to rotate. 451 | b: The vector to rotate to. 452 | Returns: 453 | The rotation matrix. 
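    Example (worked): rotation_matrix(torch.Tensor([1., 0., 0.]), torch.Tensor([0., 1., 0.]))
    returns, up to floating point error,
        [[0., -1., 0.],
         [1.,  0., 0.],
         [0.,  0., 1.]]
    i.e. the 90-degree rotation about +z that carries the x-axis onto the y-axis.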
454 | """ 455 | a = a / torch.linalg.norm(a) 456 | b = b / torch.linalg.norm(b) 457 | v = torch.cross(a, b) 458 | c = torch.dot(a, b) 459 | # If vectors are exactly opposite, we add a little noise to one of them 460 | if c < -1 + 1e-8: 461 | eps = (torch.rand(3) - 0.5) * 0.01 462 | return rotation_matrix(a + eps, b) 463 | s = torch.linalg.norm(v) 464 | skew_sym_mat = torch.Tensor( 465 | [ 466 | [0, -v[2], v[1]], 467 | [v[2], 0, -v[0]], 468 | [-v[1], v[0], 0], 469 | ] 470 | ) 471 | return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8)) 472 | 473 | def auto_orient_and_center_poses( 474 | poses, 475 | ): 476 | """ 477 | From nerfstudio 478 | https://github.com/nerfstudio-project/nerfstudio/blob/8e0c68754b2c440e2d83864fac586cddcac52dc4/nerfstudio/cameras/camera_utils.py#L515 479 | """ 480 | origins = poses[..., :3, 3] 481 | mean_origin = torch.mean(origins, dim=0) 482 | translation = mean_origin 483 | up = torch.mean(poses[:, :3, 1], dim=0) 484 | up = up / torch.linalg.norm(up) 485 | rotation = rotation_matrix(up, torch.Tensor([0, 0, 1])) 486 | transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1) 487 | oriented_poses = transform @ poses 488 | return oriented_poses, transform 489 | 490 | def readKittiMotInfo(args): 491 | cam_infos = [] 492 | points = [] 493 | points_time = [] 494 | scale_factor = 1.0 495 | 496 | basedir = args.source_path 497 | scene_id = basedir[-4:] # check 498 | kitti_scene_no = int(scene_id) 499 | tracking_path = basedir[:-13] # check 500 | calibration_path = os.path.join(os.path.join(tracking_path, "calib"), scene_id + ".txt") 501 | oxts_path_tracking = os.path.join(os.path.join(tracking_path, "oxts"), scene_id + ".txt") 502 | 503 | tracking_calibration = tracking_calib_from_txt(calibration_path) 504 | focal_X = tracking_calibration["P2"][0, 0] 505 | focal_Y = tracking_calibration["P2"][1, 1] 506 | poses_imu_w_tracking, _, _ = get_poses_calibration(basedir, oxts_path_tracking) # (n_frames, 4, 4) imu pose 507 | 508 | tr_imu2velo = tracking_calibration["Tr_imu2velo"] 509 | tr_velo2imu = invert_transformation(tr_imu2velo[:3, :3], tr_imu2velo[:3, 3]) 510 | poses_velo_w_tracking = np.matmul(poses_imu_w_tracking, tr_velo2imu) # (n_frames, 4, 4) velodyne pose 511 | 512 | # Get camera Poses camare id: 02, 03 513 | for cam_i in range(2): 514 | transformation = np.eye(4) 515 | projection = tracking_calibration["P" + str(cam_i + 2)] # rectified camera coordinate system -> image 516 | K_inv = np.linalg.inv(projection[:3, :3]) 517 | R_t = projection[:3, 3] 518 | t_crect2c = np.matmul(K_inv, R_t) 519 | transformation[:3, 3] = t_crect2c 520 | tracking_calibration["Tr_camrect2cam0" + str(cam_i + 2)] = transformation 521 | 522 | first_frame = args.start_frame 523 | last_frame = args.end_frame 524 | 525 | frame_num = last_frame-first_frame+1 526 | if args.frame_interval > 0: 527 | time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2] 528 | else: 529 | time_duration = args.time_duration 530 | 531 | selected_frames = [first_frame, last_frame] 532 | sequ_frames = selected_frames 533 | 534 | cam_poses_tracking = get_camera_poses_tracking( 535 | poses_velo_w_tracking, tracking_calibration, sequ_frames, kitti_scene_no 536 | ) 537 | poses_velo_w_tracking = poses_velo_w_tracking[first_frame:last_frame + 1] 538 | 539 | # Orients and centers the poses 540 | oriented = torch.from_numpy(np.array(cam_poses_tracking).astype(np.float32)) # (n_frames, 3, 4) 541 | oriented, transform_matrix = 
auto_orient_and_center_poses( 542 | oriented 543 | ) # oriented (n_frames, 3, 4), transform_matrix (3, 4) 544 | row = torch.tensor([0, 0, 0, 1], dtype=torch.float32) 545 | zeros = torch.zeros(oriented.shape[0], 1, 4) 546 | oriented = torch.cat([oriented, zeros], dim=1) 547 | oriented[:, -1] = row # (n_frames, 4, 4) 548 | transform_matrix = torch.cat([transform_matrix, row[None, :]], dim=0) # (4, 4) 549 | cam_poses_tracking = oriented.numpy() 550 | transform_matrix = transform_matrix.numpy() 551 | 552 | image_filenames, sky_filenames = get_scene_images_tracking( 553 | tracking_path, scene_id, sequ_frames) 554 | 555 | # # Align Axis with vkitti axis 556 | poses = cam_poses_tracking.astype(np.float32) 557 | poses[:, :, 1:3] *= -1 558 | 559 | test_load_image = imageio.imread(image_filenames[0]) 560 | image_height, image_width = test_load_image.shape[:2] 561 | cx, cy = image_width / 2.0, image_height / 2.0 562 | poses[..., :3, 3] *= scale_factor 563 | 564 | c2ws = poses 565 | for idx in tqdm(range(len(c2ws)), desc="Loading data"): 566 | c2w = c2ws[idx] 567 | w2c = np.linalg.inv(c2w) 568 | image_path = image_filenames[idx] 569 | image_name = os.path.basename(image_path)[:-4] 570 | sky_path = sky_filenames[idx] 571 | im_data = Image.open(image_path) 572 | W, H = im_data.size 573 | image = np.array(im_data) / 255. 574 | 575 | sky_mask = cv2.imread(sky_path) 576 | 577 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * (idx % (len(c2ws) // 2)) / (len(c2ws) // 2 - 1) 578 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 579 | T = w2c[:3, 3] 580 | 581 | if idx < len(c2ws) / 2: 582 | point = np.fromfile(os.path.join(tracking_path, "velodyne", scene_id, image_name + ".bin"), dtype=np.float32).reshape(-1, 4) 583 | point_xyz = point[:, :3] 584 | point_xyz_world = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ poses_velo_w_tracking[idx].T)[:, :3] 585 | points.append(point_xyz_world) 586 | point_time = np.full_like(point_xyz_world[:, :1], timestamp) 587 | points_time.append(point_time) 588 | frame_num = len(c2ws) // 2 589 | point_xyz = points[idx%frame_num] 590 | point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1)@ transform_matrix.T @ w2c.T)[:, :3]*scale_factor 591 | 592 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, 593 | image=image, 594 | image_path=image_filenames[idx], image_name=image_filenames[idx], 595 | width=W, height=H, timestamp=timestamp, 596 | fx=focal_X, fy=focal_Y, cx=cx, cy=cy, sky_mask=sky_mask, 597 | pointcloud_camera=point_camera)) 598 | 599 | if args.debug_cuda and idx > 5: 600 | break 601 | pointcloud = np.concatenate(points, axis=0) 602 | pointcloud = (np.concatenate([pointcloud, np.ones_like(pointcloud[:,:1])], axis=-1) @ transform_matrix.T)[:, :3] 603 | 604 | pointcloud_timestamp = np.concatenate(points_time, axis=0) 605 | 606 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True) 607 | pointcloud = pointcloud[indices] 608 | pointcloud_timestamp = pointcloud_timestamp[indices] 609 | 610 | # normalize poses 611 | w2cs = np.zeros((len(cam_infos), 4, 4)) 612 | Rs = np.stack([c.R for c in cam_infos], axis=0) 613 | Ts = np.stack([c.T for c in cam_infos], axis=0) 614 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1)) 615 | w2cs[:, :3, 3] = Ts 616 | w2cs[:, 3, 3] = 1 617 | c2ws = unpad_poses(np.linalg.inv(w2cs)) 618 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius) 619 | c2ws = pad_poses(c2ws) 620 | for idx, cam_info in enumerate(tqdm(cam_infos, 
desc="Transform data")): 621 | c2w = c2ws[idx] 622 | w2c = np.linalg.inv(c2w) 623 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 624 | cam_info.T[:] = w2c[:3, 3] 625 | cam_info.pointcloud_camera[:] *= scale_factor 626 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3] 627 | 628 | if args.eval: 629 | num_frame = len(cam_infos)//2 630 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx % num_frame + 1) % args.testhold != 0] 631 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx % num_frame + 1) % args.testhold == 0] 632 | else: 633 | train_cam_infos = cam_infos 634 | test_cam_infos = [] 635 | 636 | # for kitti have some static ego videos, we dont calculate radius here 637 | nerf_normalization = getNerfppNorm(train_cam_infos) 638 | nerf_normalization['radius'] = 1 639 | 640 | ply_path = os.path.join(args.source_path, "points3d.ply") 641 | if not os.path.exists(ply_path): 642 | rgbs = np.random.random((pointcloud.shape[0], 3)) 643 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp) 644 | try: 645 | pcd = fetchPly(ply_path) 646 | except: 647 | pcd = None 648 | 649 | time_interval = (time_duration[1] - time_duration[0]) / (frame_num - 1) 650 | 651 | 652 | scene_info = SceneInfo(point_cloud=pcd, 653 | train_cameras=train_cam_infos, 654 | test_cameras=test_cam_infos, 655 | nerf_normalization=nerf_normalization, 656 | ply_path=ply_path, 657 | time_interval=time_interval) 658 | 659 | return scene_info -------------------------------------------------------------------------------- /scene/scene_utils.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | import numpy as np 3 | from utils.graphics_utils import getWorld2View2 4 | from scene.gaussian_model import BasicPointCloud 5 | from plyfile import PlyData, PlyElement 6 | 7 | 8 | class CameraInfo(NamedTuple): 9 | uid: int 10 | R: np.array 11 | T: np.array 12 | image: np.array 13 | image_path: str 14 | image_name: str 15 | width: int 16 | height: int 17 | sky_mask: np.array = None 18 | timestamp: float = 0.0 19 | FovY: float = None 20 | FovX: float = None 21 | fx: float = None 22 | fy: float = None 23 | cx: float = None 24 | cy: float = None 25 | pointcloud_camera: np.array = None 26 | 27 | class SceneInfo(NamedTuple): 28 | point_cloud: BasicPointCloud 29 | train_cameras: list 30 | test_cameras: list 31 | nerf_normalization: dict 32 | ply_path: str 33 | time_interval: float = 0.02 34 | time_duration: list = [-0.5, 0.5] 35 | 36 | def getNerfppNorm(cam_info): 37 | def get_center_and_diag(cam_centers): 38 | cam_centers = np.hstack(cam_centers) 39 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 40 | center = avg_cam_center 41 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 42 | diagonal = np.max(dist) 43 | return center.flatten(), diagonal 44 | 45 | cam_centers = [] 46 | 47 | for cam in cam_info: 48 | W2C = getWorld2View2(cam.R, cam.T) 49 | C2W = np.linalg.inv(W2C) 50 | cam_centers.append(C2W[:3, 3:4]) 51 | 52 | center, diagonal = get_center_and_diag(cam_centers) 53 | radius = diagonal * 1.1 54 | 55 | translate = -center 56 | 57 | return {"translate": translate, "radius": radius} 58 | 59 | 60 | def fetchPly(path): 61 | plydata = PlyData.read(path) 62 | vertices = plydata['vertex'] 63 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 64 | colors = np.vstack([vertices['red'], vertices['green'], 
vertices['blue']]).T / 255.0
65 |     normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
66 |     if 'time' in vertices:
67 |         timestamp = vertices['time'][:, None]
68 |     else:
69 |         timestamp = None
70 |     return BasicPointCloud(points=positions, colors=colors, normals=normals, time=timestamp)
71 | 
72 | 
73 | def storePly(path, xyz, rgb, timestamp=None):
74 |     # Define the dtype for the structured array
75 |     dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
76 |              ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
77 |              ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
78 |              ('time', 'f4')]
79 | 
80 |     normals = np.zeros_like(xyz)
81 |     if timestamp is None:
82 |         timestamp = np.zeros_like(xyz[:, :1])
83 | 
84 |     elements = np.empty(xyz.shape[0], dtype=dtype)
85 |     attributes = np.concatenate((xyz, normals, rgb, timestamp), axis=1)
86 |     elements[:] = list(map(tuple, attributes))
87 | 
88 |     # Create the PlyData object and write to file
89 |     vertex_element = PlyElement.describe(elements, 'vertex')
90 |     ply_data = PlyData([vertex_element])
91 |     ply_data.write(path)
92 | 
-------------------------------------------------------------------------------- /scene/waymo_loader.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
6 | from utils.graphics_utils import BasicPointCloud
7 | 
8 | 
9 | def pad_poses(p):
10 |     """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
11 |     bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
12 |     return np.concatenate([p[..., :3, :4], bottom], axis=-2)
13 | 
14 | 
15 | def unpad_poses(p):
16 |     """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
17 |     return p[..., :3, :4]
18 | 
19 | 
20 | def transform_poses_pca(poses, fix_radius=0):
21 |     """Transforms poses so principal components lie on XYZ axes.
22 | 
23 |     Args:
24 |         poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
25 | 
26 |     Returns:
27 |         A tuple (poses, transform, scale_factor), with the transformed poses, the
28 |         applied camera_to_world transform, and the translation scale factor.
29 | 
30 |     From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
31 |     """
32 |     t = poses[:, :3, 3]
33 |     t_mean = t.mean(axis=0)
34 |     t = t - t_mean
35 | 
36 |     eigval, eigvec = np.linalg.eig(t.T @ t)
37 |     # Sort eigenvectors in order of largest to smallest eigenvalue.
38 |     inds = np.argsort(eigval)[::-1]
39 |     eigvec = eigvec[:, inds]
40 |     rot = eigvec.T
41 |     if np.linalg.det(rot) < 0:
42 |         rot = np.diag(np.array([1, 1, -1])) @ rot
43 | 
44 |     transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
45 |     poses_recentered = unpad_poses(transform @ pad_poses(poses))
46 |     transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
47 | 
48 |     # Flip coordinate system if z component of y-axis is negative
49 |     if poses_recentered.mean(axis=0)[2, 1] < 0:
50 |         poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
51 |         transform = np.diag(np.array([1, -1, -1, 1])) @ transform
52 | 
53 |     # Just make sure it's in the [-1, 1]^3 cube
54 |     if fix_radius>0:
55 |         scale_factor = 1./fix_radius
56 |     else:
57 |         scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
58 |         scale_factor = min(1 / 10, scale_factor)
59 | 
60 |     poses_recentered[:, :3, 3] *= scale_factor
61 |     transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
62 | 
63 |     return poses_recentered, transform, scale_factor
64 | 
65 | 
66 | def readWaymoInfo(args):
67 |     cam_infos = []
68 |     car_list = [f[:-4] for f in sorted(os.listdir(os.path.join(args.source_path, "calib"))) if f.endswith('.txt')]
69 |     points = []
70 |     points_time = []
71 | 
72 |     frame_num = len(car_list)
73 |     if args.frame_interval > 0:
74 |         time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2]
75 |     else:
76 |         time_duration = args.time_duration
77 | 
78 |     for idx, car_id in tqdm(enumerate(car_list), desc="Loading data"):
79 |         ego_pose = np.loadtxt(os.path.join(args.source_path, 'pose', car_id + '.txt'))
80 | 
81 |         # CAMERA DIRECTION: RIGHT DOWN FORWARDS
82 |         with open(os.path.join(args.source_path, 'calib', car_id + '.txt')) as f:
83 |             calib_data = f.readlines()
84 |             L = [list(map(float, line.split()[1:])) for line in calib_data]
85 |         Ks = np.array(L[:5]).reshape(-1, 3, 4)[:, :, :3]
86 |         lidar2cam = np.array(L[-5:]).reshape(-1, 3, 4)
87 |         lidar2cam = pad_poses(lidar2cam)
88 | 
89 |         cam2lidar = np.linalg.inv(lidar2cam)
90 |         c2w = ego_pose @ cam2lidar
91 |         w2c = np.linalg.inv(c2w)
92 |         images = []
93 |         image_paths = []
94 |         HWs = []
95 |         for subdir in ['image_0', 'image_1', 'image_2', 'image_3', 'image_4'][:args.cam_num]:
96 |             image_path = os.path.join(args.source_path, subdir, car_id + '.png')
97 |             im_data = Image.open(image_path)
98 |             W, H = im_data.size
99 |             image = np.array(im_data) / 255.
100 |             HWs.append((H, W))
101 |             images.append(image)
102 |             image_paths.append(image_path)
103 | 
104 |         sky_masks = []
105 |         for subdir in ['sky_0', 'sky_1', 'sky_2', 'sky_3', 'sky_4'][:args.cam_num]:
106 |             sky_data = np.array(Image.open(os.path.join(args.source_path, subdir, car_id + '.png')))
107 |             sky_mask = sky_data>0
108 |             sky_masks.append(sky_mask.astype(np.float32))
109 | 
110 |         timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (len(car_list) - 1)
111 |         point = np.fromfile(os.path.join(args.source_path, "velodyne", car_id + ".bin"),
112 |                             dtype=np.float32, count=-1).reshape(-1, 6)
113 |         point_xyz, intensity, elongation, timestamp_pts = np.split(point, [3, 4, 5], axis=1)
114 |         point_xyz_world = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ ego_pose.T)[:, :3]  # pad only the coordinate axis: (N, 3) -> (N, 4) homogeneous points
115 |         points.append(point_xyz_world)
116 |         point_time = np.full_like(point_xyz_world[:, :1], timestamp)
117 |         points_time.append(point_time)
118 |         for j in range(args.cam_num):
119 |             point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ lidar2cam[j].T)[:, :3]
120 |             R = np.transpose(w2c[j, :3, :3])  # R is stored transposed due to 'glm' in CUDA code
121 |             T = w2c[j, :3, 3]
122 |             K = Ks[j]
123 |             fx = float(K[0, 0])
124 |             fy = float(K[1, 1])
125 |             cx = float(K[0, 2])
126 |             cy = float(K[1, 2])
127 |             FovX = FovY = -1.0
128 |             cam_infos.append(CameraInfo(uid=idx * 5 + j, R=R, T=T, FovY=FovY, FovX=FovX,
129 |                                         image=images[j],
130 |                                         image_path=image_paths[j], image_name=car_id,
131 |                                         width=HWs[j][1], height=HWs[j][0], timestamp=timestamp,
132 |                                         pointcloud_camera = point_camera,
133 |                                         fx=fx, fy=fy, cx=cx, cy=cy,
134 |                                         sky_mask=sky_masks[j]))
135 | 
136 |         if args.debug_cuda:
137 |             break
138 | 
139 |     pointcloud = np.concatenate(points, axis=0)
140 |     pointcloud_timestamp = np.concatenate(points_time, axis=0)
141 |     indices = 
np.random.choice(pointcloud.shape[0], args.num_pts, replace=True) 142 | pointcloud = pointcloud[indices] 143 | pointcloud_timestamp = pointcloud_timestamp[indices] 144 | 145 | w2cs = np.zeros((len(cam_infos), 4, 4)) 146 | Rs = np.stack([c.R for c in cam_infos], axis=0) 147 | Ts = np.stack([c.T for c in cam_infos], axis=0) 148 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1)) 149 | w2cs[:, :3, 3] = Ts 150 | w2cs[:, 3, 3] = 1 151 | c2ws = unpad_poses(np.linalg.inv(w2cs)) 152 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius) 153 | 154 | c2ws = pad_poses(c2ws) 155 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data")): 156 | c2w = c2ws[idx] 157 | w2c = np.linalg.inv(c2w) 158 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 159 | cam_info.T[:] = w2c[:3, 3] 160 | cam_info.pointcloud_camera[:] *= scale_factor 161 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3] 162 | if args.eval: 163 | # ## for snerf scene 164 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0] 165 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0] 166 | 167 | # for dynamic scene 168 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0] 169 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0] 170 | 171 | # for emernerf comparison [testhold::testhold] 172 | if args.testhold == 10: 173 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0] 174 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0] 175 | else: 176 | train_cam_infos = cam_infos 177 | test_cam_infos = [] 178 | 179 | nerf_normalization = getNerfppNorm(train_cam_infos) 180 | nerf_normalization['radius'] = 1/nerf_normalization['radius'] 181 | 182 | ply_path = os.path.join(args.source_path, "points3d.ply") 183 | if not os.path.exists(ply_path): 184 | rgbs = np.random.random((pointcloud.shape[0], 3)) 185 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp) 186 | try: 187 | pcd = fetchPly(ply_path) 188 | except: 189 | pcd = None 190 | 191 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp) 192 | time_interval = (time_duration[1] - time_duration[0]) / (len(car_list) - 1) 193 | 194 | scene_info = SceneInfo(point_cloud=pcd, 195 | train_cameras=train_cam_infos, 196 | test_cameras=test_cam_infos, 197 | nerf_normalization=nerf_normalization, 198 | ply_path=ply_path, 199 | time_interval=time_interval, 200 | time_duration=time_duration) 201 | 202 | return scene_info 203 | -------------------------------------------------------------------------------- /scripts/extract_kitti_metric_nvs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | root = "eval_output/kitti_nvs" 6 | scenes = ["0001", "0002", "0006"] 7 | 8 | eval_dict = { 9 | "TEST": {"psnr": [], "ssim": [], "lpips": []}, 10 | } 11 | for scene in scenes: 12 | eval_dir = os.path.join(root, scene, "eval") 13 | dirs = os.listdir(eval_dir) 14 | test_path = sorted([d for d in dirs if d.startswith("test")], key=lambda x: int(x.split("_")[1]))[-1] 15 | for name, path in [("TEST", 
test_path)]: 16 | psnrs = [] 17 | ssims = [] 18 | lpipss = [] 19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f: 20 | data = json.load(f) 21 | eval_dict[name]["psnr"].append(data["psnr"]) 22 | eval_dict[name]["ssim"].append(data["ssim"]) 23 | eval_dict[name]["lpips"].append(data["lpips"]) 24 | 25 | print(f'TEST PSNR:{np.mean(eval_dict["TEST"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TEST"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TEST"]["lpips"]):.3f}') 26 | -------------------------------------------------------------------------------- /scripts/extract_kitti_metric_reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | root = "eval_output/kitti_reconstruction" 6 | scenes = ["0001", "0002", "0006"] 7 | 8 | eval_dict = { 9 | "TRAIN": {"psnr": [], "ssim": [], "lpips": []}, 10 | } 11 | for scene in scenes: 12 | eval_dir = os.path.join(root, scene, "eval") 13 | dirs = os.listdir(eval_dir) 14 | test_path = sorted([d for d in dirs if d.startswith("train")], key=lambda x: int(x.split("_")[1]))[-1] 15 | for name, path in [("TRAIN", test_path)]: 16 | psnrs = [] 17 | ssims = [] 18 | lpipss = [] 19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f: 20 | data = json.load(f) 21 | eval_dict[name]["psnr"].append(data["psnr"]) 22 | eval_dict[name]["ssim"].append(data["ssim"]) 23 | eval_dict[name]["lpips"].append(data["lpips"]) 24 | 25 | print(f'TRAIN PSNR:{np.mean(eval_dict["TRAIN"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TRAIN"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TRAIN"]["lpips"]):.3f}') 26 | -------------------------------------------------------------------------------- /scripts/extract_mask_kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_masks.py 3 | @author Jianfei Guo, Shanghai AI Lab 4 | @brief Extract semantic sky masks 5 | 6 | Using SegFormer, 2021. Cityscapes 83.2% 7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9) 8 | 9 | Installation: 10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env. 11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9; 12 | it will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num' 13 | Hence, a separate conda env is needed. 14 | 15 | git clone https://github.com/NVlabs/SegFormer 16 | 17 | conda create -n segformer python=3.8 18 | conda activate segformer 19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge 20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf 23 | pip install mmcv-full==1.2.7 --no-cache-dir 24 | 25 | cd SegFormer 26 | pip install . 27 | 28 | Usage: 29 | Run this script directly in the newly created conda env.
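        For example (illustrative; assumes the data/kitti_mot layout used below):
            conda activate segformer
            python scripts/extract_mask_kitti.py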
30 | """ 31 | import os 32 | import numpy as np 33 | import cv2 34 | from tqdm import tqdm 35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot 36 | from mmseg.core.evaluation import get_palette 37 | 38 | if __name__ == "__main__": 39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer' 40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 'segformer.b5.1024x1024.city.160k.py') 41 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth') 42 | model = init_segmentor(config, checkpoint, device='cuda') 43 | 44 | root = 'data/kitti_mot/training' 45 | 46 | for cam_id in ['2', '3']: 47 | image_dir = os.path.join(root, f'image_0{cam_id}') 48 | sky_dir = os.path.join(root, f'sky_0{cam_id}') 49 | for seq in sorted(os.listdir(image_dir)): 50 | seq_dir = os.path.join(image_dir, seq) 51 | mask_dir = os.path.join(sky_dir, seq) 52 | if not os.path.isdir(seq_dir): 53 | continue 54 | 55 | os.makedirs(image_dir, exist_ok=True) 56 | os.makedirs(mask_dir, exist_ok=True) 57 | for image_name in sorted(os.listdir(seq_dir)): 58 | image_path = os.path.join(seq_dir, image_name) 59 | print(image_path) 60 | mask_path = os.path.join(mask_dir, image_name) 61 | if not image_path.endswith(".png"): 62 | continue 63 | result = inference_segmentor(model, image_path) 64 | mask = result[0].astype(np.uint8) 65 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8) 66 | cv2.imwrite(os.path.join(mask_dir, image_name), mask) 67 | -------------------------------------------------------------------------------- /scripts/extract_mask_waymo.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_masks.py 3 | @author Jianfei Guo, Shanghai AI Lab 4 | @brief Extract semantic mask 5 | 6 | Using SegFormer, 2021. Cityscapes 83.2% 7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9) 8 | 9 | Installation: 10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env. 11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9; 12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num' 13 | Hence, a seperate conda env is needed. 14 | 15 | git clone https://github.com/NVlabs/SegFormer 16 | 17 | conda create -n segformer python=3.8 18 | conda activate segformer 19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge 20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf 23 | pip install mmcv-full==1.2.7 --no-cache-dir 24 | 25 | cd SegFormer 26 | pip install . 27 | 28 | Usage: 29 | Direct run this script in the newly set conda env. 
30 | """ 31 | import os 32 | import numpy as np 33 | import cv2 34 | from tqdm import tqdm 35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot 36 | from mmseg.core.evaluation import get_palette 37 | 38 | if __name__ == "__main__": 39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer' 40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 41 | 'segformer.b5.1024x1024.city.160k.py') 42 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth') 43 | model = init_segmentor(config, checkpoint, device='cuda') 44 | 45 | root = 'data/waymo_scenes' 46 | 47 | scenes = sorted(os.listdir(root)) 48 | 49 | for scene in scenes: 50 | for cam_id in range(5): 51 | image_dir = os.path.join(root, scene, f'image_{cam_id}') 52 | sky_dir = os.path.join(root, scene, f'sky_{cam_id}') 53 | os.makedirs(sky_dir, exist_ok=True) 54 | for image_name in tqdm(sorted(os.listdir(image_dir))): 55 | if not image_name.endswith(".png"): 56 | continue 57 | image_path = os.path.join(image_dir, image_name) 58 | mask_path = os.path.join(sky_dir, image_name) 59 | result = inference_segmentor(model, image_path) 60 | mask = result[0].astype(np.uint8) 61 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8) 62 | cv2.imwrite(mask_path, mask) 63 | -------------------------------------------------------------------------------- /scripts/extract_scenes_waymo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from tqdm import tqdm 4 | 5 | data_root = '/HDD_DISK/datasets/waymo/kitti_format/training' 6 | 7 | tags = ['image_0','image_1','image_2','image_3','image_4','calib','velodyne','pose'] 8 | posts = ['.png','.png','.png','.png','.png','.txt', '.bin','.txt'] 9 | 10 | out_dir = 'data/waymo_scenes_streetsurf' 11 | 12 | scene_ids = [3, 19, 36, 69, 81, 126, 139, 140, 146, 13 | 148, 157, 181, 200, 204, 226, 232, 237, 14 | 241, 245, 246, 271, 297, 302, 312, 314, 15 | 362, 482, 495, 524, 527] 16 | scene_nums = [ 17 | [0, 163], 18 | [0, 198], 19 | [0, 198], 20 | [0, 198], 21 | [0, 198], 22 | [0, 198], 23 | [0, 198], 24 | [17, 198], 25 | [0, 198], 26 | [0, 198], 27 | [0, 140], 28 | [24, 198], 29 | [0, 198], 30 | [0, 198], 31 | [0, 198], 32 | [0, 198], 33 | [0, 198], 34 | [30, 198], 35 | [80, 198], 36 | [0, 170], 37 | [70, 198], 38 | [0, 198], 39 | [0, 198], 40 | [0, 120], 41 | [0, 198], 42 | [0, 198], 43 | [0, 198], 44 | [0, 198], 45 | [0, 198], 46 | [0, 90], 47 | ] 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | for scene_idx, scene_id in enumerate(scene_ids): 51 | scene_dir = join(out_dir, f'{scene_id:04d}001') 52 | os.makedirs(scene_dir, exist_ok=True) 53 | 54 | for tag in tags: 55 | os.makedirs(join(scene_dir, tag), exist_ok=True) 56 | for post, tag in zip(posts,tags): 57 | for i in tqdm(range(scene_nums[scene_idx][0], scene_nums[scene_idx][1])): 58 | cmd = "cp {} {}".format(join(data_root,tag,f'{scene_id:04d}{i:03d}'+post), 59 | join(scene_dir, tag, f'{scene_id:04d}{i:03d}'+post)) 60 | os.system(cmd) 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /scripts/extract_waymo_metric_nvs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | root = "eval_output/waymo_nvs" 6 | scenes = ["0017085", "0145050", "0147030", "0158150"] 7 | 8 | eval_dict = { 9 | "TEST": {"psnr": [], "ssim": [], "lpips": []}, 10 | } 
11 | for scene in scenes: 12 | eval_dir = os.path.join(root, scene, "eval") 13 | dirs = os.listdir(eval_dir) 14 | test_path = sorted([d for d in dirs if d.startswith("test")], key=lambda x: int(x.split("_")[1]))[-1] 15 | for name, path in [("TEST", test_path)]: 16 | psnrs = [] 17 | ssims = [] 18 | lpipss = [] 19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f: 20 | data = json.load(f) 21 | eval_dict[name]["psnr"].append(data["psnr"]) 22 | eval_dict[name]["ssim"].append(data["ssim"]) 23 | eval_dict[name]["lpips"].append(data["lpips"]) 24 | 25 | print(f'TEST PSNR:{np.mean(eval_dict["TEST"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TEST"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TEST"]["lpips"]):.3f}') 26 | -------------------------------------------------------------------------------- /scripts/extract_waymo_metric_reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | root = "eval_output/waymo_reconstruction" 6 | scenes = ["0017085", "0145050", "0147030", "0158150"] 7 | 8 | eval_dict = { 9 | "TRAIN": {"psnr": [], "ssim": [], "lpips": []}, 10 | } 11 | for scene in scenes: 12 | eval_dir = os.path.join(root, scene, "eval") 13 | dirs = os.listdir(eval_dir) 14 | test_path = sorted([d for d in dirs if d.startswith("train")], key=lambda x: int(x.split("_")[1]))[-1] 15 | for name, path in [("TRAIN", test_path)]: 16 | psnrs = [] 17 | ssims = [] 18 | lpipss = [] 19 | with open(os.path.join(eval_dir, path, "metrics.json"), "r") as f: 20 | data = json.load(f) 21 | eval_dict[name]["psnr"].append(data["psnr"]) 22 | eval_dict[name]["ssim"].append(data["ssim"]) 23 | eval_dict[name]["lpips"].append(data["lpips"]) 24 | 25 | print(f'TRAIN PSNR:{np.mean(eval_dict["TRAIN"]["psnr"]):.3f} SSIM:{np.mean(eval_dict["TRAIN"]["ssim"]):.3f} LPIPS:{np.mean(eval_dict["TRAIN"]["lpips"]):.3f}') 26 | -------------------------------------------------------------------------------- /scripts/run_kitti_nvs_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python train.py \ 2 | --config configs/kitti_nvs.yaml \ 3 | source_path=data/kitti_mot/training/image_02/0001 \ 4 | model_path=eval_output/kitti_nvs/0001 \ 5 | start_frame=380 end_frame=431 6 | 7 | CUDA_VISIBLE_DEVICES=2 python train.py \ 8 | --config configs/kitti_nvs.yaml \ 9 | source_path=data/kitti_mot/training/image_02/0002 \ 10 | model_path=eval_output/kitti_nvs/0002 \ 11 | start_frame=140 end_frame=224 12 | 13 | CUDA_VISIBLE_DEVICES=2 python train.py \ 14 | --config configs/kitti_nvs.yaml \ 15 | source_path=data/kitti_mot/training/image_02/0006 \ 16 | model_path=eval_output/kitti_nvs/0006 \ 17 | start_frame=65 end_frame=120 18 | -------------------------------------------------------------------------------- /scripts/run_kitti_reconstruction_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=3 python train.py \ 2 | --config configs/kitti_reconstruction.yaml \ 3 | source_path=data/kitti_mot/training/image_02/0001 \ 4 | model_path=eval_output/kitti_reconstruction/0001 \ 5 | start_frame=380 end_frame=431 6 | 7 | CUDA_VISIBLE_DEVICES=3 python train.py \ 8 | --config configs/kitti_reconstruction.yaml \ 9 | source_path=data/kitti_mot/training/image_02/0002 \ 10 | model_path=eval_output/kitti_reconstruction/0002 \ 11 | start_frame=140 end_frame=224 12 | 13 | CUDA_VISIBLE_DEVICES=3 python train.py \ 14 | --config 
configs/kitti_reconstruction.yaml \ 15 | source_path=data/kitti_mot/training/image_02/0006 \ 16 | model_path=eval_output/kitti_reconstruction/0006 \ 17 | start_frame=65 end_frame=120 18 | -------------------------------------------------------------------------------- /scripts/run_waymo_nvs_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --config configs/waymo_nvs.yaml \ 3 | source_path=data/waymo_scenes/0145050 \ 4 | model_path=eval_output/waymo_nvs/0145050 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --config configs/waymo_nvs.yaml \ 8 | source_path=data/waymo_scenes/0147030 \ 9 | model_path=eval_output/waymo_nvs/0147030 10 | 11 | CUDA_VISIBLE_DEVICES=0 python train.py \ 12 | --config configs/waymo_nvs.yaml \ 13 | source_path=data/waymo_scenes/0158150 \ 14 | model_path=eval_output/waymo_nvs/0158150 15 | 16 | CUDA_VISIBLE_DEVICES=0 python train.py \ 17 | --config configs/waymo_nvs.yaml \ 18 | source_path=data/waymo_scenes/0017085 \ 19 | model_path=eval_output/waymo_nvs/0017085 20 | -------------------------------------------------------------------------------- /scripts/run_waymo_reconstruction_all.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python train.py \ 2 | --config configs/waymo_reconstruction.yaml \ 3 | source_path=data/waymo_scenes/0145050 \ 4 | model_path=eval_output/waymo_reconstruction/0145050 5 | 6 | CUDA_VISIBLE_DEVICES=1 python train.py \ 7 | --config configs/waymo_reconstruction.yaml \ 8 | source_path=data/waymo_scenes/0147030 \ 9 | model_path=eval_output/waymo_reconstruction/0147030 10 | 11 | CUDA_VISIBLE_DEVICES=1 python train.py \ 12 | --config configs/waymo_reconstruction.yaml \ 13 | source_path=data/waymo_scenes/0158150 \ 14 | model_path=eval_output/waymo_reconstruction/0158150 15 | 16 | CUDA_VISIBLE_DEVICES=1 python train.py \ 17 | --config configs/waymo_reconstruction.yaml \ 18 | source_path=data/waymo_scenes/0017085 \ 19 | model_path=eval_output/waymo_reconstruction/0017085 20 | -------------------------------------------------------------------------------- /scripts/waymo_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | r"""Adapted from `Waymo to KITTI converter 3 | `_. 4 | """ 5 | 6 | try: 7 | from waymo_open_dataset import dataset_pb2 8 | except ImportError: 9 | raise ImportError( 10 | 'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" ' 11 | 'to install the official devkit first.') 12 | 13 | from glob import glob 14 | from os.path import join 15 | 16 | import mmcv 17 | import numpy as np 18 | import tensorflow as tf 19 | from waymo_open_dataset.utils import range_image_utils, transform_utils 20 | from waymo_open_dataset.utils.frame_utils import \ 21 | parse_range_image_and_camera_projection 22 | 23 | 24 | class Waymo2KITTI(object): 25 | """Waymo to KITTI converter. 26 | 27 | This class serves as the converter to change the waymo raw data to KITTI 28 | format. 29 | 30 | Args: 31 | load_dir (str): Directory to load waymo raw data. 32 | save_dir (str): Directory to save data in KITTI format. 33 | prefix (str): Prefix of filename. In general, 0 for training, 1 for 34 | validation and 2 for testing. 35 | workers (int, optional): Number of workers for the parallel process. 36 | test_mode (bool, optional): Whether in the test_mode. Default: False. 
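    Example (a minimal sketch; the paths are illustrative placeholders, not
    shipped data)::

        converter = Waymo2KITTI('data/waymo_format/training',
                                'data/kitti_format/training',
                                prefix='0', workers=4)
        converter.convert()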
37 | """ 38 | 39 | def __init__(self, 40 | load_dir, 41 | save_dir, 42 | prefix, 43 | workers=64, 44 | test_mode=False): 45 | self.filter_empty_3dboxes = True 46 | self.filter_no_label_zone_points = True 47 | 48 | self.selected_waymo_classes = ['VEHICLE', 'PEDESTRIAN', 'CYCLIST'] 49 | 50 | # Only data collected in specific locations will be converted 51 | # If set None, this filter is disabled 52 | # Available options: location_sf (main dataset) 53 | self.selected_waymo_locations = None 54 | self.save_track_id = False 55 | 56 | # turn on eager execution for older tensorflow versions 57 | if int(tf.__version__.split('.')[0]) < 2: 58 | tf.enable_eager_execution() 59 | 60 | self.lidar_list = [ 61 | '_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT', 62 | '_SIDE_LEFT' 63 | ] 64 | self.type_list = [ 65 | 'UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST' 66 | ] 67 | self.waymo_to_kitti_class_map = { 68 | 'UNKNOWN': 'DontCare', 69 | 'PEDESTRIAN': 'Pedestrian', 70 | 'VEHICLE': 'Car', 71 | 'CYCLIST': 'Cyclist', 72 | 'SIGN': 'Sign' # not in kitti 73 | } 74 | 75 | self.load_dir = load_dir 76 | self.save_dir = save_dir 77 | self.prefix = prefix 78 | self.workers = int(workers) 79 | self.test_mode = test_mode 80 | 81 | self.tfrecord_pathnames = sorted( 82 | glob(join(self.load_dir, '*.tfrecord'))) 83 | 84 | self.label_save_dir = f'{self.save_dir}/label_' 85 | self.label_all_save_dir = f'{self.save_dir}/label_all' 86 | self.image_save_dir = f'{self.save_dir}/image_' 87 | self.calib_save_dir = f'{self.save_dir}/calib' 88 | self.point_cloud_save_dir = f'{self.save_dir}/velodyne' 89 | self.pose_save_dir = f'{self.save_dir}/pose' 90 | self.timestamp_save_dir = f'{self.save_dir}/timestamp' 91 | 92 | self.create_folder() 93 | 94 | def convert(self): 95 | """Convert action.""" 96 | print('Start converting ...') 97 | mmcv.track_parallel_progress(self.convert_one, range(len(self)), 98 | self.workers) 99 | print('\nFinished ...') 100 | 101 | def convert_one(self, file_idx): 102 | """Convert action for single file. 103 | 104 | Args: 105 | file_idx (int): Index of the file to be converted. 106 | """ 107 | pathname = self.tfrecord_pathnames[file_idx] 108 | dataset = tf.data.TFRecordDataset(pathname, compression_type='') 109 | 110 | for frame_idx, data in enumerate(dataset): 111 | 112 | frame = dataset_pb2.Frame() 113 | frame.ParseFromString(bytearray(data.numpy())) 114 | if (self.selected_waymo_locations is not None 115 | and frame.context.stats.location 116 | not in self.selected_waymo_locations): 117 | continue 118 | 119 | self.save_image(frame, file_idx, frame_idx) 120 | self.save_calib(frame, file_idx, frame_idx) 121 | self.save_lidar(frame, file_idx, frame_idx) 122 | self.save_pose(frame, file_idx, frame_idx) 123 | self.save_timestamp(frame, file_idx, frame_idx) 124 | 125 | if not self.test_mode: 126 | self.save_label(frame, file_idx, frame_idx) 127 | 128 | def __len__(self): 129 | """Length of the filename list.""" 130 | return len(self.tfrecord_pathnames) 131 | 132 | def save_image(self, frame, file_idx, frame_idx): 133 | """Parse and save the images in jpg format. Jpg is the original format 134 | used by Waymo Open dataset. Saving in png format will cause huge (~3x) 135 | unnesssary storage waste. 136 | 137 | Args: 138 | frame (:obj:`Frame`): Open dataset frame proto. 139 | file_idx (int): Current file index. 140 | frame_idx (int): Current frame index. 
141 | """ 142 | for img in frame.images: 143 | img_path = f'{self.image_save_dir}{str(img.name - 1)}/' + \ 144 | f'{self.prefix}{str(file_idx).zfill(3)}' + \ 145 | f'{str(frame_idx).zfill(3)}.jpg' 146 | with open(img_path, 'wb') as fp: 147 | fp.write(img.image) 148 | 149 | def save_calib(self, frame, file_idx, frame_idx): 150 | """Parse and save the calibration data. 151 | 152 | Args: 153 | frame (:obj:`Frame`): Open dataset frame proto. 154 | file_idx (int): Current file index. 155 | frame_idx (int): Current frame index. 156 | """ 157 | # waymo front camera to kitti reference camera 158 | T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], 159 | [1.0, 0.0, 0.0]]) 160 | camera_calibs = [] 161 | R0_rect = [f'{i:e}' for i in np.eye(3).flatten()] 162 | Tr_velo_to_cams = [] 163 | calib_context = '' 164 | 165 | for camera in frame.context.camera_calibrations: 166 | # extrinsic parameters 167 | T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape( 168 | 4, 4) 169 | T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle) 170 | Tr_velo_to_cam = \ 171 | self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam 172 | if camera.name == 1: # FRONT = 1, see dataset.proto for details 173 | self.T_velo_to_front_cam = Tr_velo_to_cam.copy() 174 | Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12, )) 175 | Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam]) 176 | 177 | # intrinsic parameters 178 | camera_calib = np.zeros((3, 4)) 179 | camera_calib[0, 0] = camera.intrinsic[0] 180 | camera_calib[1, 1] = camera.intrinsic[1] 181 | camera_calib[0, 2] = camera.intrinsic[2] 182 | camera_calib[1, 2] = camera.intrinsic[3] 183 | camera_calib[2, 2] = 1 184 | camera_calib = list(camera_calib.reshape(12)) 185 | camera_calib = [f'{i:e}' for i in camera_calib] 186 | camera_calibs.append(camera_calib) 187 | 188 | # all camera ids are saved as id-1 in the result because 189 | # camera 0 is unknown in the proto 190 | for i in range(5): 191 | calib_context += 'P' + str(i) + ': ' + \ 192 | ' '.join(camera_calibs[i]) + '\n' 193 | calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n' 194 | for i in range(5): 195 | calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \ 196 | ' '.join(Tr_velo_to_cams[i]) + '\n' 197 | 198 | with open( 199 | f'{self.calib_save_dir}/{self.prefix}' + 200 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 201 | 'w+') as fp_calib: 202 | fp_calib.write(calib_context) 203 | fp_calib.close() 204 | 205 | def save_lidar(self, frame, file_idx, frame_idx): 206 | """Parse and save the lidar data in psd format. 207 | 208 | Args: 209 | frame (:obj:`Frame`): Open dataset frame proto. 210 | file_idx (int): Current file index. 211 | frame_idx (int): Current frame index. 
212 | """ 213 | range_images, camera_projections, range_image_top_pose = \ 214 | parse_range_image_and_camera_projection(frame) 215 | 216 | # First return 217 | points_0, cp_points_0, intensity_0, elongation_0, mask_indices_0 = \ 218 | self.convert_range_image_to_point_cloud( 219 | frame, 220 | range_images, 221 | camera_projections, 222 | range_image_top_pose, 223 | ri_index=0 224 | ) 225 | points_0 = np.concatenate(points_0, axis=0) 226 | intensity_0 = np.concatenate(intensity_0, axis=0) 227 | elongation_0 = np.concatenate(elongation_0, axis=0) 228 | mask_indices_0 = np.concatenate(mask_indices_0, axis=0) 229 | 230 | # Second return 231 | points_1, cp_points_1, intensity_1, elongation_1, mask_indices_1 = \ 232 | self.convert_range_image_to_point_cloud( 233 | frame, 234 | range_images, 235 | camera_projections, 236 | range_image_top_pose, 237 | ri_index=1 238 | ) 239 | points_1 = np.concatenate(points_1, axis=0) 240 | intensity_1 = np.concatenate(intensity_1, axis=0) 241 | elongation_1 = np.concatenate(elongation_1, axis=0) 242 | mask_indices_1 = np.concatenate(mask_indices_1, axis=0) 243 | 244 | points = np.concatenate([points_0, points_1], axis=0) 245 | intensity = np.concatenate([intensity_0, intensity_1], axis=0) 246 | elongation = np.concatenate([elongation_0, elongation_1], axis=0) 247 | mask_indices = np.concatenate([mask_indices_0, mask_indices_1], axis=0) 248 | 249 | # timestamp = frame.timestamp_micros * np.ones_like(intensity) 250 | 251 | # concatenate x,y,z, intensity, elongation, timestamp (6-dim) 252 | point_cloud = np.column_stack( 253 | (points, intensity, elongation, mask_indices)) 254 | 255 | pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ 256 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin' 257 | point_cloud.astype(np.float32).tofile(pc_path) 258 | 259 | def save_label(self, frame, file_idx, frame_idx): 260 | """Parse and save the label data in txt format. 261 | The relation between waymo and kitti coordinates is noteworthy: 262 | 1. x, y, z correspond to l, w, h (waymo) -> l, h, w (kitti) 263 | 2. x-y-z: front-left-up (waymo) -> right-down-front(kitti) 264 | 3. bbox origin at volumetric center (waymo) -> bottom center (kitti) 265 | 4. rotation: +x around y-axis (kitti) -> +x around z-axis (waymo) 266 | 267 | Args: 268 | frame (:obj:`Frame`): Open dataset frame proto. 269 | file_idx (int): Current file index. 270 | frame_idx (int): Current frame index. 
271 | """ 272 | fp_label_all = open( 273 | f'{self.label_all_save_dir}/{self.prefix}' + 274 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') 275 | id_to_bbox = dict() 276 | id_to_name = dict() 277 | for labels in frame.projected_lidar_labels: 278 | name = labels.name 279 | for label in labels.labels: 280 | # TODO: need a workaround as bbox may not belong to front cam 281 | bbox = [ 282 | label.box.center_x - label.box.length / 2, 283 | label.box.center_y - label.box.width / 2, 284 | label.box.center_x + label.box.length / 2, 285 | label.box.center_y + label.box.width / 2 286 | ] 287 | id_to_bbox[label.id] = bbox 288 | id_to_name[label.id] = name - 1 289 | 290 | for obj in frame.laser_labels: 291 | bounding_box = None 292 | name = None 293 | id = obj.id 294 | for lidar in self.lidar_list: 295 | if id + lidar in id_to_bbox: 296 | bounding_box = id_to_bbox.get(id + lidar) 297 | name = str(id_to_name.get(id + lidar)) 298 | break 299 | 300 | if bounding_box is None or name is None: 301 | name = '0' 302 | bounding_box = (0, 0, 0, 0) 303 | 304 | my_type = self.type_list[obj.type] 305 | 306 | if my_type not in self.selected_waymo_classes: 307 | continue 308 | 309 | if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: 310 | continue 311 | 312 | my_type = self.waymo_to_kitti_class_map[my_type] 313 | 314 | height = obj.box.height 315 | width = obj.box.width 316 | length = obj.box.length 317 | 318 | x = obj.box.center_x 319 | y = obj.box.center_y 320 | z = obj.box.center_z - height / 2 321 | 322 | # project bounding box to the virtual reference frame 323 | pt_ref = self.T_velo_to_front_cam @ \ 324 | np.array([x, y, z, 1]).reshape((4, 1)) 325 | x, y, z, _ = pt_ref.flatten().tolist() 326 | 327 | rotation_y = -obj.box.heading - np.pi / 2 328 | track_id = obj.id 329 | 330 | # not available 331 | truncated = 0 332 | occluded = 0 333 | alpha = -10 334 | 335 | line = my_type + \ 336 | ' {} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format( 337 | round(truncated, 2), occluded, round(alpha, 2), 338 | round(bounding_box[0], 2), round(bounding_box[1], 2), 339 | round(bounding_box[2], 2), round(bounding_box[3], 2), 340 | round(height, 2), round(width, 2), round(length, 2), 341 | round(x, 2), round(y, 2), round(z, 2), 342 | round(rotation_y, 2)) 343 | 344 | if self.save_track_id: 345 | line_all = line[:-1] + ' ' + name + ' ' + track_id + '\n' 346 | else: 347 | line_all = line[:-1] + ' ' + name + '\n' 348 | 349 | fp_label = open( 350 | f'{self.label_save_dir}{name}/{self.prefix}' + 351 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'a') 352 | fp_label.write(line) 353 | fp_label.close() 354 | 355 | fp_label_all.write(line_all) 356 | 357 | fp_label_all.close() 358 | 359 | def save_pose(self, frame, file_idx, frame_idx): 360 | """Parse and save the pose data. 361 | 362 | Note that SDC's own pose is not included in the regular training 363 | of KITTI dataset. KITTI raw dataset contains ego motion files 364 | but are not often used. Pose is important for algorithms that 365 | take advantage of the temporal information. 366 | 367 | Args: 368 | frame (:obj:`Frame`): Open dataset frame proto. 369 | file_idx (int): Current file index. 370 | frame_idx (int): Current frame index. 
371 | """ 372 | pose = np.array(frame.pose.transform).reshape(4, 4) 373 | np.savetxt( 374 | join(f'{self.pose_save_dir}/{self.prefix}' + 375 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'), 376 | pose) 377 | 378 | def save_timestamp(self, frame, file_idx, frame_idx): 379 | """Save the timestamp data in a separate file instead of the 380 | pointcloud. 381 | 382 | Note that SDC's own pose is not included in the regular training 383 | of KITTI dataset. KITTI raw dataset contains ego motion files 384 | but are not often used. Pose is important for algorithms that 385 | take advantage of the temporal information. 386 | 387 | Args: 388 | frame (:obj:`Frame`): Open dataset frame proto. 389 | file_idx (int): Current file index. 390 | frame_idx (int): Current frame index. 391 | """ 392 | with open( 393 | join(f'{self.timestamp_save_dir}/{self.prefix}' + 394 | f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'), 395 | 'w') as f: 396 | f.write(str(frame.timestamp_micros)) 397 | 398 | def create_folder(self): 399 | """Create folder for data preprocessing.""" 400 | if not self.test_mode: 401 | dir_list1 = [ 402 | self.label_all_save_dir, self.calib_save_dir, 403 | self.point_cloud_save_dir, self.pose_save_dir, 404 | self.timestamp_save_dir 405 | ] 406 | dir_list2 = [self.label_save_dir, self.image_save_dir] 407 | else: 408 | dir_list1 = [ 409 | self.calib_save_dir, self.point_cloud_save_dir, 410 | self.pose_save_dir, self.timestamp_save_dir 411 | ] 412 | dir_list2 = [self.image_save_dir] 413 | for d in dir_list1: 414 | mmcv.mkdir_or_exist(d) 415 | for d in dir_list2: 416 | for i in range(5): 417 | mmcv.mkdir_or_exist(f'{d}{str(i)}') 418 | 419 | def convert_range_image_to_point_cloud(self, 420 | frame, 421 | range_images, 422 | camera_projections, 423 | range_image_top_pose, 424 | ri_index=0): 425 | """Convert range images to point cloud. 426 | 427 | Args: 428 | frame (:obj:`Frame`): Open dataset frame. 429 | range_images (dict): Mapping from laser_name to list of two 430 | range images corresponding with two returns. 431 | camera_projections (dict): Mapping from laser_name to list of two 432 | camera projections corresponding with two returns. 433 | range_image_top_pose (:obj:`Transform`): Range image pixel pose for 434 | top lidar. 435 | ri_index (int, optional): 0 for the first return, 436 | 1 for the second return. Default: 0. 437 | 438 | Returns: 439 | tuple[list[np.ndarray]]: (List of points with shape [N, 3], 440 | camera projections of points with shape [N, 6], intensity 441 | with shape [N, 1], elongation with shape [N, 1], points' 442 | position in the depth map (element offset if points come from 443 | the main lidar otherwise -1) with shape[N, 1]). All the 444 | lists have the length of lidar numbers (5). 
445 | """ 446 | calibrations = sorted( 447 | frame.context.laser_calibrations, key=lambda c: c.name) 448 | points = [] 449 | cp_points = [] 450 | intensity = [] 451 | elongation = [] 452 | mask_indices = [] 453 | 454 | frame_pose = tf.convert_to_tensor( 455 | value=np.reshape(np.array(frame.pose.transform), [4, 4])) 456 | # [H, W, 6] 457 | range_image_top_pose_tensor = tf.reshape( 458 | tf.convert_to_tensor(value=range_image_top_pose.data), 459 | range_image_top_pose.shape.dims) 460 | # [H, W, 3, 3] 461 | range_image_top_pose_tensor_rotation = \ 462 | transform_utils.get_rotation_matrix( 463 | range_image_top_pose_tensor[..., 0], 464 | range_image_top_pose_tensor[..., 1], 465 | range_image_top_pose_tensor[..., 2]) 466 | range_image_top_pose_tensor_translation = \ 467 | range_image_top_pose_tensor[..., 3:] 468 | range_image_top_pose_tensor = transform_utils.get_transform( 469 | range_image_top_pose_tensor_rotation, 470 | range_image_top_pose_tensor_translation) 471 | for c in calibrations: 472 | range_image = range_images[c.name][ri_index] 473 | if len(c.beam_inclinations) == 0: 474 | beam_inclinations = range_image_utils.compute_inclination( 475 | tf.constant( 476 | [c.beam_inclination_min, c.beam_inclination_max]), 477 | height=range_image.shape.dims[0]) 478 | else: 479 | beam_inclinations = tf.constant(c.beam_inclinations) 480 | 481 | beam_inclinations = tf.reverse(beam_inclinations, axis=[-1]) 482 | extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4]) 483 | 484 | range_image_tensor = tf.reshape( 485 | tf.convert_to_tensor(value=range_image.data), 486 | range_image.shape.dims) 487 | pixel_pose_local = None 488 | frame_pose_local = None 489 | if c.name == dataset_pb2.LaserName.TOP: 490 | pixel_pose_local = range_image_top_pose_tensor 491 | pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0) 492 | frame_pose_local = tf.expand_dims(frame_pose, axis=0) 493 | range_image_mask = range_image_tensor[..., 0] > 0 494 | 495 | if self.filter_no_label_zone_points: 496 | nlz_mask = range_image_tensor[..., 3] != 1.0 # 1.0: in NLZ 497 | range_image_mask = range_image_mask & nlz_mask 498 | 499 | range_image_cartesian = \ 500 | range_image_utils.extract_point_cloud_from_range_image( 501 | tf.expand_dims(range_image_tensor[..., 0], axis=0), 502 | tf.expand_dims(extrinsic, axis=0), 503 | tf.expand_dims(tf.convert_to_tensor( 504 | value=beam_inclinations), axis=0), 505 | pixel_pose=pixel_pose_local, 506 | frame_pose=frame_pose_local) 507 | 508 | mask_index = tf.where(range_image_mask) 509 | 510 | range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) 511 | points_tensor = tf.gather_nd(range_image_cartesian, mask_index) 512 | 513 | cp = camera_projections[c.name][ri_index] 514 | cp_tensor = tf.reshape( 515 | tf.convert_to_tensor(value=cp.data), cp.shape.dims) 516 | cp_points_tensor = tf.gather_nd(cp_tensor, mask_index) 517 | points.append(points_tensor.numpy()) 518 | cp_points.append(cp_points_tensor.numpy()) 519 | 520 | intensity_tensor = tf.gather_nd(range_image_tensor[..., 1], 521 | mask_index) 522 | intensity.append(intensity_tensor.numpy()) 523 | 524 | elongation_tensor = tf.gather_nd(range_image_tensor[..., 2], 525 | mask_index) 526 | elongation.append(elongation_tensor.numpy()) 527 | if c.name == 1: 528 | mask_index = (ri_index * range_image_mask.shape[0] + 529 | mask_index[:, 0] 530 | ) * range_image_mask.shape[1] + mask_index[:, 1] 531 | mask_index = mask_index.numpy().astype(elongation[-1].dtype) 532 | else: 533 | mask_index = np.full_like(elongation[-1], -1) 534 | 535 
| mask_indices.append(mask_index) 536 | 537 | return points, cp_points, intensity, elongation, mask_indices 538 | 539 | def cart_to_homo(self, mat): 540 | """Convert transformation matrix in Cartesian coordinates to 541 | homogeneous format. 542 | 543 | Args: 544 | mat (np.ndarray): Transformation matrix in Cartesian. 545 | The input matrix shape is 3x3 or 3x4. 546 | 547 | Returns: 548 | np.ndarray: Transformation matrix in homogeneous format. 549 | The matrix shape is 4x4. 550 | """ 551 | ret = np.eye(4) 552 | if mat.shape == (3, 3): 553 | ret[:3, :3] = mat 554 | elif mat.shape == (3, 4): 555 | ret[:3, :] = mat 556 | else: 557 | raise ValueError(mat.shape) 558 | return ret 559 | 560 | if __name__ == '__main__': 561 | import argparse 562 | 563 | parser = argparse.ArgumentParser(description='Data converter arg parser') 564 | parser.add_argument('dataset', metavar='kitti', help='name of the dataset') 565 | parser.add_argument( 566 | '--root-path', 567 | type=str, 568 | default='./data/kitti', 569 | help='specify the root path of dataset') 570 | parser.add_argument( 571 | '--version', 572 | type=str, 573 | default='v1.0', 574 | required=False, 575 | help='specify the dataset version, no need for kitti') 576 | parser.add_argument( 577 | '--max-sweeps', 578 | type=int, 579 | default=10, 580 | required=False, 581 | help='specify sweeps of lidar per example') 582 | parser.add_argument( 583 | '--with-plane', 584 | action='store_true', 585 | help='Whether to use plane information for kitti.') 586 | parser.add_argument( 587 | '--num-points', 588 | type=int, 589 | default=-1, 590 | help='Number of points to sample for indoor datasets.') 591 | parser.add_argument( 592 | '--out-dir', 593 | type=str, 594 | default='./data/kitti', 595 | required=False, 596 | help='name of info pkl') 597 | parser.add_argument('--extra-tag', type=str, default='kitti') 598 | parser.add_argument( 599 | '--workers', type=int, default=4, help='number of threads to be used') 600 | args = parser.parse_args() 601 | 602 | from os import path as osp 603 | splits = ['training', 'validation', 'testing'] 604 | for i, split in enumerate(splits): 605 | load_dir = osp.join(args.root_path, 'waymo_format', split) 606 | if split == 'validation': 607 | save_dir = osp.join(args.out_dir, 'kitti_format', 'training') 608 | else: 609 | save_dir = osp.join(args.out_dir, 'kitti_format', split) 610 | converter = Waymo2KITTI( 611 | load_dir, 612 | save_dir, 613 | prefix=str(i), 614 | workers=args.workers, 615 | test_mode=(split == 'testing')) 616 | converter.convert() 617 | # Generate waymo infos 618 | # out_dir = osp.join(out_dir, 'kitti_format') 619 | # kitti.create_waymo_info_file( 620 | # out_dir, info_prefix, max_sweeps=max_sweeps, workers=workers) 621 | # GTDatabaseCreater( 622 | # 'WaymoDataset', 623 | # out_dir, 624 | # info_prefix, 625 | # f'{out_dir}/{info_prefix}_infos_train.pkl', 626 | # relative_path=False, 627 | # with_mask=False, 628 | # num_worker=workers).create() -------------------------------------------------------------------------------- /separate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import glob 12 | import os 13 | import torch 14 | from gaussian_renderer import render 15 | from scene import Scene, GaussianModel, EnvLight 16 | from utils.general_utils import seed_everything 17 | from tqdm import tqdm 18 | from argparse import ArgumentParser 19 | from torchvision.utils import save_image 20 | from omegaconf import OmegaConf 21 | 22 | EPS = 1e-5 23 | 24 | @torch.no_grad() 25 | def separation(scene : Scene, renderFunc, renderArgs, env_map=None): 26 | scale = scene.resolution_scales[0] 27 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 28 | {'name': 'train', 'cameras': scene.getTrainCameras()}) 29 | 30 | # we suppose that the area with altitude > 0.5 is static; 31 | # the z axis points downward here, so the test is gaussians.get_xyz[:, 2] < -0.5 32 | high_mask = gaussians.get_xyz[:, 2] < -0.5 33 | 34 | mask = (gaussians.get_scaling_t[:, 0] > args.separate_scaling_t) | high_mask 35 | for config in validation_configs: 36 | if config['cameras'] and len(config['cameras']) > 0: 37 | outdir = os.path.join(args.model_path, "separation", config['name']) 38 | os.makedirs(outdir, exist_ok=True) 39 | for idx, viewpoint in enumerate(tqdm(config['cameras'])): 40 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map) 41 | render_pkg_static = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=mask) 42 | 43 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 44 | image_static = torch.clamp(render_pkg_static["render"], 0.0, 1.0) 45 | 46 | save_image(image, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png")) 47 | save_image(image_static, os.path.join(outdir, f"{viewpoint.colmap_id:03d}_static.png")) 48 | 49 | 50 | if __name__ == "__main__": 51 | # Set up command line argument parser 52 | parser = ArgumentParser(description="Training script parameters") 53 | parser.add_argument("--config", type=str, required=True) 54 | parser.add_argument("--base_config", type=str, default="configs/base.yaml") 55 | args, _ = parser.parse_known_args() 56 | 57 | base_conf = OmegaConf.load(args.base_config) 58 | second_conf = OmegaConf.load(args.config) 59 | cli_conf = OmegaConf.from_cli() 60 | args = OmegaConf.merge(base_conf, second_conf, cli_conf) 61 | args.resolution_scales = args.resolution_scales[:1] 62 | print(args) 63 | 64 | seed_everything(args.seed) 65 | 66 | sep_path = os.path.join(args.model_path, 'separation') 67 | os.makedirs(sep_path, exist_ok=True) 68 | 69 | gaussians = GaussianModel(args) 70 | scene = Scene(args, gaussians, shuffle=False) 71 | 72 | if args.env_map_res > 0: 73 | env_map = EnvLight(resolution=args.env_map_res).cuda() 74 | env_map.training_setup(args) 75 | else: 76 | env_map = None 77 | 78 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth")) 79 | assert len(checkpoints) > 0, "No checkpoints found."
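    # Checkpoints are named chkpnt<iteration>.pth (see train.py), so sort by the
    # integer iteration suffix and resume from the most recent one.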
80 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1] 81 | (model_params, first_iter) = torch.load(checkpoint) 82 | gaussians.restore(model_params, args) 83 | 84 | if env_map is not None: 85 | env_checkpoint = os.path.join(os.path.dirname(checkpoint), 86 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt")) 87 | (light_params, _) = torch.load(env_checkpoint) 88 | env_map.restore(light_params) 89 | 90 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0] 91 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 92 | separation(scene, render, (args, background), env_map=env_map) 93 | 94 | print("\nRendering statics and dynamics complete.") 95 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import json 12 | import os 13 | from collections import defaultdict 14 | import torch 15 | import torch.nn.functional as F 16 | from random import randint 17 | from utils.loss_utils import psnr, ssim 18 | from gaussian_renderer import render 19 | from scene import Scene, GaussianModel, EnvLight 20 | from utils.general_utils import seed_everything, visualize_depth 21 | from tqdm import tqdm 22 | from argparse import ArgumentParser 23 | from torchvision.utils import make_grid, save_image 24 | import numpy as np 25 | import kornia 26 | from omegaconf import OmegaConf 27 | try: 28 | from torch.utils.tensorboard import SummaryWriter 29 | TENSORBOARD_FOUND = True 30 | except ImportError: 31 | TENSORBOARD_FOUND = False 32 | 33 | EPS = 1e-5 34 | def training(args): 35 | 36 | if TENSORBOARD_FOUND: 37 | tb_writer = SummaryWriter(args.model_path) 38 | else: 39 | tb_writer = None 40 | print("Tensorboard not available: not logging progress") 41 | vis_path = os.path.join(args.model_path, 'visualization') 42 | os.makedirs(vis_path, exist_ok=True) 43 | 44 | gaussians = GaussianModel(args) 45 | 46 | scene = Scene(args, gaussians) 47 | 48 | gaussians.training_setup(args) 49 | 50 | if args.env_map_res > 0: 51 | env_map = EnvLight(resolution=args.env_map_res).cuda() 52 | env_map.training_setup(args) 53 | else: 54 | env_map = None 55 | 56 | first_iter = 0 57 | if args.start_checkpoint: 58 | (model_params, first_iter) = torch.load(args.start_checkpoint) 59 | gaussians.restore(model_params, args) 60 | 61 | if env_map is not None: 62 | env_checkpoint = os.path.join(os.path.dirname(args.start_checkpoint), 63 | os.path.basename(args.start_checkpoint).replace("chkpnt", "env_light_chkpnt")) 64 | (light_params, _) = torch.load(env_checkpoint) 65 | env_map.restore(light_params) 66 | 67 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0] 68 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 69 | 70 | iter_start = torch.cuda.Event(enable_timing=True) 71 | iter_end = torch.cuda.Event(enable_timing=True) 72 | 73 | viewpoint_stack = None 74 | 75 | ema_dict_for_log = defaultdict(int) 76 | progress_bar = tqdm(range(first_iter + 1, args.iterations + 1), desc="Training progress") 77 | 78 | for iteration in progress_bar: 79 | iter_start.record() 80 |
gaussians.update_learning_rate(iteration) 81 | 82 | # Every 1000 its we increase the levels of SH up to a maximum degree 83 | if iteration % args.sh_increase_interval == 0: 84 | gaussians.oneupSHdegree() 85 | 86 | if not viewpoint_stack: 87 | viewpoint_stack = list(range(len(scene.getTrainCameras()))) 88 | viewpoint_cam = scene.getTrainCameras()[viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))] 89 | 90 | # render v and t scale map 91 | v = gaussians.get_inst_velocity 92 | t_scale = gaussians.get_scaling_t.clamp_max(2) 93 | other = [t_scale, v] 94 | 95 | if np.random.random() < args.lambda_self_supervision: 96 | time_shift = 3*(np.random.random() - 0.5) * scene.time_interval 97 | else: 98 | time_shift = None 99 | 100 | render_pkg = render(viewpoint_cam, gaussians, args, background, env_map=env_map, other=other, time_shift=time_shift, is_training=True) 101 | 102 | image = render_pkg["render"] 103 | depth = render_pkg["depth"] 104 | alpha = render_pkg["alpha"] 105 | viewspace_point_tensor = render_pkg["viewspace_points"] 106 | visibility_filter = render_pkg["visibility_filter"] 107 | radii = render_pkg["radii"] 108 | log_dict = {} 109 | 110 | feature = render_pkg['feature'] / alpha.clamp_min(EPS) 111 | t_map = feature[0:1] 112 | v_map = feature[1:] 113 | 114 | sky_mask = viewpoint_cam.sky_mask.cuda() if viewpoint_cam.sky_mask is not None else torch.zeros_like(alpha, dtype=torch.bool) 115 | 116 | sky_depth = 900 117 | depth = depth / alpha.clamp_min(EPS) 118 | if env_map is not None: 119 | if args.depth_blend_mode == 0: # harmonic mean 120 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS) 121 | elif args.depth_blend_mode == 1: 122 | depth = alpha * depth + (1 - alpha) * sky_depth 123 | 124 | gt_image = viewpoint_cam.original_image.cuda() 125 | 126 | loss_l1 = F.l1_loss(image, gt_image) 127 | log_dict['loss_l1'] = loss_l1.item() 128 | loss_ssim = 1.0 - ssim(image, gt_image) 129 | log_dict['loss_ssim'] = loss_ssim.item() 130 | loss = (1.0 - args.lambda_dssim) * loss_l1 + args.lambda_dssim * loss_ssim 131 | 132 | if args.lambda_lidar > 0: 133 | assert viewpoint_cam.pts_depth is not None 134 | pts_depth = viewpoint_cam.pts_depth.cuda() 135 | 136 | mask = pts_depth > 0 137 | loss_lidar = torch.abs(1 / (pts_depth[mask] + 1e-5) - 1 / (depth[mask] + 1e-5)).mean() 138 | if args.lidar_decay > 0: 139 | iter_decay = np.exp(-iteration / 8000 * args.lidar_decay) 140 | else: 141 | iter_decay = 1 142 | log_dict['loss_lidar'] = loss_lidar.item() 143 | loss += iter_decay * args.lambda_lidar * loss_lidar 144 | 145 | if args.lambda_t_reg > 0: 146 | loss_t_reg = -torch.abs(t_map).mean() 147 | log_dict['loss_t_reg'] = loss_t_reg.item() 148 | loss += args.lambda_t_reg * loss_t_reg 149 | 150 | if args.lambda_v_reg > 0: 151 | loss_v_reg = torch.abs(v_map).mean() 152 | log_dict['loss_v_reg'] = loss_v_reg.item() 153 | loss += args.lambda_v_reg * loss_v_reg 154 | 155 | if args.lambda_inv_depth > 0: 156 | inverse_depth = 1 / (depth + 1e-5) 157 | loss_inv_depth = kornia.losses.inverse_depth_smoothness_loss(inverse_depth[None], gt_image[None]) 158 | log_dict['loss_inv_depth'] = loss_inv_depth.item() 159 | loss = loss + args.lambda_inv_depth * loss_inv_depth 160 | 161 | if args.lambda_v_smooth > 0: 162 | loss_v_smooth = kornia.losses.inverse_depth_smoothness_loss(v_map[None], gt_image[None]) 163 | log_dict['loss_v_smooth'] = loss_v_smooth.item() 164 | loss = loss + args.lambda_v_smooth * loss_v_smooth 165 | 166 | if args.lambda_sky_opa > 0: 167 | o = alpha.clamp(1e-6, 1-1e-6) 168 
| sky = sky_mask.float() 169 | loss_sky_opa = (-sky * torch.log(1 - o)).mean() 170 | log_dict['loss_sky_opa'] = loss_sky_opa.item() 171 | loss = loss + args.lambda_sky_opa * loss_sky_opa 172 | 173 | if args.lambda_opacity_entropy > 0: 174 | o = alpha.clamp(1e-6, 1 - 1e-6) 175 | loss_opacity_entropy = -(o*torch.log(o)).mean() 176 | log_dict['loss_opacity_entropy'] = loss_opacity_entropy.item() 177 | loss = loss + args.lambda_opacity_entropy * loss_opacity_entropy 178 | 179 | loss.backward() 180 | log_dict['loss'] = loss.item() 181 | 182 | iter_end.record() 183 | 184 | with torch.no_grad(): 185 | psnr_for_log = psnr(image, gt_image).double() 186 | log_dict["psnr"] = psnr_for_log 187 | for key in ['loss', "loss_l1", "psnr"]: 188 | ema_dict_for_log[key] = 0.4 * log_dict[key] + 0.6 * ema_dict_for_log[key] 189 | 190 | if iteration % 10 == 0: 191 | postfix = {k[5:] if k.startswith("loss_") else k:f"{ema_dict_for_log[k]:.{5}f}" for k, v in ema_dict_for_log.items()} 192 | postfix["scale"] = scene.resolution_scales[scene.scale_index] 193 | progress_bar.set_postfix(postfix) 194 | 195 | log_dict['iter_time'] = iter_start.elapsed_time(iter_end) 196 | log_dict['total_points'] = gaussians.get_xyz.shape[0] 197 | # Log and save 198 | complete_eval(tb_writer, iteration, args.test_iterations, scene, render, (args, background), 199 | log_dict, env_map=env_map) 200 | 201 | # Densification 202 | if iteration > args.densify_until_iter * args.time_split_frac: 203 | gaussians.no_time_split = False 204 | 205 | if iteration < args.densify_until_iter and (args.densify_until_num_points < 0 or gaussians.get_xyz.shape[0] < args.densify_until_num_points): 206 | # Keep track of max radii in image-space for pruning 207 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 208 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 209 | if iteration > args.densify_from_iter and iteration % args.densification_interval == 0: 210 | size_threshold = args.size_threshold if (iteration > args.opacity_reset_interval and args.prune_big_point > 0) else None 211 | 212 | if size_threshold is not None: 213 | size_threshold = size_threshold // scene.resolution_scales[0] 214 | 215 | gaussians.densify_and_prune(args.densify_grad_threshold, args.thresh_opa_prune, scene.cameras_extent, size_threshold, args.densify_grad_t_threshold) 216 | 217 | if iteration % args.opacity_reset_interval == 0 or (args.white_background and iteration == args.densify_from_iter): 218 | gaussians.reset_opacity() 219 | 220 | gaussians.optimizer.step() 221 | gaussians.optimizer.zero_grad(set_to_none = True) 222 | if env_map is not None and iteration < args.env_optimize_until: 223 | env_map.optimizer.step() 224 | env_map.optimizer.zero_grad(set_to_none = True) 225 | torch.cuda.empty_cache() 226 | 227 | if iteration % args.vis_step == 0 or iteration == 1: 228 | other_img = [] 229 | feature = render_pkg['feature'] / alpha.clamp_min(1e-5) 230 | t_map = feature[0:1] 231 | v_map = feature[1:] 232 | v_norm_map = v_map.norm(dim=0, keepdim=True) 233 | 234 | et_color = visualize_depth(t_map, near=0.01, far=1) 235 | v_color = visualize_depth(v_norm_map, near=0.01, far=1) 236 | other_img.append(et_color) 237 | other_img.append(v_color) 238 | 239 | if viewpoint_cam.pts_depth is not None: 240 | pts_depth_vis = visualize_depth(viewpoint_cam.pts_depth) 241 | other_img.append(pts_depth_vis) 242 | 243 | grid = make_grid([ 244 | image, 245 | gt_image, 246 | alpha.repeat(3, 1, 1), 247 | 
torch.logical_not(sky_mask[:1]).float().repeat(3, 1, 1), 248 | visualize_depth(depth), 249 | ] + other_img, nrow=4) 250 | 251 | save_image(grid, os.path.join(vis_path, f"{iteration:05d}_{viewpoint_cam.colmap_id:03d}.png")) 252 | 253 | if iteration % args.scale_increase_interval == 0: 254 | scene.upScale() 255 | 256 | if iteration in args.checkpoint_iterations: 257 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 258 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 259 | torch.save((env_map.capture(), iteration), scene.model_path + "/env_light_chkpnt" + str(iteration) + ".pth") 260 | 261 | 262 | def complete_eval(tb_writer, iteration, test_iterations, scene : Scene, renderFunc, renderArgs, log_dict, env_map=None): 263 | from lpipsPyTorch import lpips 264 | 265 | if tb_writer: 266 | for key, value in log_dict.items(): 267 | tb_writer.add_scalar(f'train/{key}', value, iteration) 268 | 269 | if iteration in test_iterations: 270 | scale = scene.resolution_scales[scene.scale_index] 271 | if iteration < args.iterations: 272 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},) 273 | else: 274 | if "kitti" in args.model_path: 275 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766 276 | num = len(scene.getTrainCameras())//2 277 | eval_train_frame = num//5 278 | traincamera = sorted(scene.getTrainCameras(), key =lambda x: x.colmap_id) 279 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 280 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]}) 281 | else: 282 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 283 | {'name': 'train', 'cameras': scene.getTrainCameras()}) 284 | 285 | 286 | 287 | for config in validation_configs: 288 | if config['cameras'] and len(config['cameras']) > 0: 289 | l1_test = 0.0 290 | psnr_test = 0.0 291 | ssim_test = 0.0 292 | lpips_test = 0.0 293 | outdir = os.path.join(args.model_path, "eval", config['name'] + f"_{iteration}" + "_render") 294 | os.makedirs(outdir,exist_ok=True) 295 | for idx, viewpoint in enumerate(tqdm(config['cameras'])): 296 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map) 297 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 298 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 299 | depth = render_pkg['depth'] 300 | alpha = render_pkg['alpha'] 301 | sky_depth = 900 302 | depth = depth / alpha.clamp_min(EPS) 303 | if env_map is not None: 304 | if args.depth_blend_mode == 0: # harmonic mean 305 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS) 306 | elif args.depth_blend_mode == 1: 307 | depth = alpha * depth + (1 - alpha) * sky_depth 308 | 309 | depth = visualize_depth(depth) 310 | alpha = alpha.repeat(3, 1, 1) 311 | 312 | grid = [gt_image, image, alpha, depth] 313 | grid = make_grid(grid, nrow=2) 314 | 315 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png")) 316 | 317 | l1_test += F.l1_loss(image, gt_image).double() 318 | psnr_test += psnr(image, gt_image).double() 319 | ssim_test += ssim(image, gt_image).double() 320 | lpips_test += lpips(image, gt_image, net_type='vgg').double() # very slow 321 | 322 | psnr_test /= len(config['cameras']) 323 | l1_test /= len(config['cameras']) 324 | ssim_test /= 
len(config['cameras']) 325 | lpips_test /= len(config['cameras']) 326 | 327 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test)) 328 | if tb_writer: 329 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 330 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 331 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', ssim_test, iteration) 332 | with open(os.path.join(outdir, "metrics.json"), "w") as f: 333 | json.dump({"split": config['name'], "iteration": iteration, "psnr": psnr_test.item(), "ssim": ssim_test.item(), "lpips": lpips_test.item()}, f) 334 | torch.cuda.empty_cache() 335 | 336 | 337 | if __name__ == "__main__": 338 | # Set up command line argument parser 339 | parser = ArgumentParser(description="Training script parameters") 340 | parser.add_argument("--config", type=str, required=True) 341 | parser.add_argument("--base_config", type=str, default = "configs/base.yaml") 342 | args, _ = parser.parse_known_args() 343 | 344 | base_conf = OmegaConf.load(args.base_config) 345 | second_conf = OmegaConf.load(args.config) 346 | cli_conf = OmegaConf.from_cli() 347 | args = OmegaConf.merge(base_conf, second_conf, cli_conf) 348 | print(args) 349 | 350 | args.save_iterations.append(args.iterations) 351 | args.checkpoint_iterations.append(args.iterations) 352 | args.test_iterations.append(args.iterations) 353 | 354 | if args.exhaust_test: 355 | args.test_iterations += [i for i in range(0,args.iterations, args.test_interval)] 356 | 357 | print("Optimizing " + args.model_path) 358 | 359 | seed_everything(args.seed) 360 | 361 | training(args) 362 | 363 | # All done 364 | print("\nTraining complete.") 365 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
-------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import cv2 14 | from scene.cameras import Camera 15 | import numpy as np 16 | from scene.scene_utils import CameraInfo 17 | from tqdm import tqdm 18 | from .graphics_utils import fov2focal 19 | 20 | 21 | def loadCam(args, id, cam_info: CameraInfo, resolution_scale): 22 | orig_w, orig_h = cam_info.width, cam_info.height # cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 3, 4, 8, 16, 32]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution) 27 | ) 28 | scale = resolution_scale * args.resolution 29 | else: # should be a type that converts to float 30 | if args.resolution == -1: 31 | global_down = 1 32 | else: 33 | global_down = orig_w / args.resolution 34 | 35 | scale = float(global_down) * float(resolution_scale) 36 | resolution = (int(orig_w / scale), int(orig_h / scale)) 37 | 38 | if cam_info.cx: 39 | cx = cam_info.cx / scale 40 | cy = cam_info.cy / scale 41 | fy = cam_info.fy / scale 42 | fx = cam_info.fx / scale 43 | else: 44 | cx = None 45 | cy = None 46 | fy = None 47 | fx = None 48 | 49 | if cam_info.image.shape[:2] != resolution[::-1]: 50 | image_rgb = cv2.resize(cam_info.image, resolution) 51 | else: 52 | image_rgb = cam_info.image 53 | image_rgb = torch.from_numpy(image_rgb).float().permute(2, 0, 1) 54 | gt_image = image_rgb[:3, ...] 55 | 56 | if cam_info.sky_mask is not None: 57 | if cam_info.sky_mask.shape[:2] != resolution[::-1]: 58 | sky_mask = cv2.resize(cam_info.sky_mask, resolution) 59 | else: 60 | sky_mask = cam_info.sky_mask 61 | if len(sky_mask.shape) == 2: 62 | sky_mask = sky_mask[..., None] 63 | sky_mask = torch.from_numpy(sky_mask).float().permute(2, 0, 1) 64 | else: 65 | sky_mask = None 66 | 67 | if cam_info.pointcloud_camera is not None: 68 | h, w = gt_image.shape[1:] 69 | K = np.eye(3) 70 | if cam_info.cx: 71 | K[0, 0] = fx 72 | K[1, 1] = fy 73 | K[0, 2] = cx 74 | K[1, 2] = cy 75 | else: 76 | K[0, 0] = fov2focal(cam_info.FovX, w) 77 | K[1, 1] = fov2focal(cam_info.FovY, h) 78 | K[0, 2] = cam_info.width / 2 79 | K[1, 2] = cam_info.height / 2 80 | pts_depth = np.zeros([1, h, w]) 81 | point_camera = cam_info.pointcloud_camera 82 | uvz = point_camera[point_camera[:, 2] > 0] # keep only points in front of the camera 83 | uvz = uvz @ K.T # pinhole projection to (u*z, v*z, z) 84 | uvz[:, :2] /= uvz[:, 2:] # perspective divide 85 | uvz = uvz[uvz[:, 1] >= 0] # drop projections outside the image bounds 86 | uvz = uvz[uvz[:, 1] < h] 87 | uvz = uvz[uvz[:, 0] >= 0] 88 | uvz = uvz[uvz[:, 0] < w] 89 | uv = uvz[:, :2] 90 | uv = uv.astype(int) 91 | # TODO: may need to consider overlap 92 | pts_depth[0, uv[:, 1], uv[:, 0]] = uvz[:, 2] 93 | pts_depth = torch.from_numpy(pts_depth).float() 94 | else: 95 | pts_depth = None 96 | 97 | return Camera( 98 | colmap_id=cam_info.uid, 99 | uid=id, 100 | R=cam_info.R, 101 | T=cam_info.T, 102 | FoVx=cam_info.FovX, 103 | FoVy=cam_info.FovY, 104 | cx=cx, 105 | cy=cy, 106 | fx=fx, 107 | fy=fy, 108 | image=gt_image, 109 | image_name=cam_info.image_name, 110 | data_device=args.data_device, 111 | timestamp=cam_info.timestamp, 112 | resolution=resolution, 113 | image_path=cam_info.image_path, 114 | pts_depth=pts_depth, 115 | sky_mask=sky_mask, 116 | ) 117 | 118 | 119 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 120 | camera_list = [] 121 | 122 | for id, c in enumerate(tqdm(cam_infos)): 123 | camera_list.append(loadCam(args, id, c, resolution_scale)) 124 | 125 | return camera_list 126 | 127 | 128 | def camera_to_JSON(id, camera: Camera): 129 | Rt = np.zeros((4, 4)) 130 | Rt[:3, :3] = camera.R.transpose() 131 | Rt[:3, 3] = camera.T 132 | Rt[3, 3] = 1.0 133 | 134 | W2C = np.linalg.inv(Rt) 135 | pos = W2C[:3, 3] 136 | rot = W2C[:3, :3] 137 | serializable_array_2d = [x.tolist() for x in rot] 138 | 139 | if camera.cx is None: 140 | camera_entry = { 141 | "id": id, 142 | "img_name": camera.image_name, 143 | "width": camera.width, 144 | "height": camera.height, 145 | "position": pos.tolist(), 146 | "rotation": serializable_array_2d, 147 | "FoVx": camera.FovX, 148 | "FoVy": camera.FovY, 149 | } 150 | else: 151 | camera_entry = { 152 | "id": id, 153 | "img_name": camera.image_name, 154 | "width": camera.width, 155 | "height": camera.height, 156 | "position": pos.tolist(), 157 | "rotation": serializable_array_2d, 158 | "fx": camera.fx, 159 | "fy": camera.fy, 160 | "cx": camera.cx, 161 | "cy": camera.cy, 162 | } 163 | return camera_entry 164 |
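The depth-map rasterization in loadCam above is a plain pinhole projection, and the arithmetic is easy to sanity-check in isolation. A minimal sketch with a made-up intrinsic matrix (editor's example, not repo code):

    import numpy as np

    K = np.array([[500.0, 0.0, 320.0],
                  [0.0, 500.0, 240.0],
                  [0.0, 0.0, 1.0]])
    point_camera = np.array([[1.0, 0.5, 5.0]])  # one LiDAR point 5 m in front of the camera
    uvz = point_camera @ K.T                    # -> (u*z, v*z, z)
    uvz[:, :2] /= uvz[:, 2:]                    # perspective divide
    # u = 500 * 1.0 / 5 + 320 = 420, v = 500 * 0.5 / 5 + 240 = 290, depth = 5
    assert np.allclose(uvz[0], [420.0, 290.0, 5.0])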
-------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import numpy as np 15 | import random 16 | from matplotlib import cm 17 | 18 | def visualize_depth(depth, near=0.2, far=13, linear=False): 19 | depth = depth[0].clone().detach().cpu().numpy() 20 | colormap = cm.get_cmap('turbo') # deprecated in matplotlib >= 3.7; matplotlib.colormaps['turbo'] is the replacement 21 | curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) 22 | if linear: 23 | curve_fn = lambda x: -x 24 | eps = np.finfo(np.float32).eps 25 | near = near if near else depth.min() 26 | far = far if far else depth.max() 27 | near -= eps 28 | far += eps 29 | near, far, depth = [curve_fn(x) for x in [near, far, depth]] 30 | depth = np.nan_to_num( 31 | np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1)) 32 | vis = colormap(depth)[:, :, :3] 33 | out_depth = np.clip(np.nan_to_num(vis), 0., 1.) * 255 34 | out_depth = torch.from_numpy(out_depth).permute(2, 0, 1).float().cuda() / 255 35 | return out_depth 36 | 37 | 38 | def inverse_sigmoid(x): 39 | return torch.log(x / (1 - x)) 40 | 41 | 42 | def PILtoTorch(pil_image, resolution): 43 | resized_image_PIL = pil_image.resize(resolution) 44 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 45 | if len(resized_image.shape) == 3: 46 | return resized_image.permute(2, 0, 1) 47 | else: 48 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 49 | 50 | 51 | def get_step_lr_func(lr_init, lr_final, start_step): 52 | def helper(step): 53 | if step < start_step: 54 | return lr_init 55 | else: 56 | return lr_final 57 | return helper 58 | 59 | def get_expon_lr_func( 60 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 61 | ): 62 | """ 63 | Copied from Plenoxels 64 | 65 | Continuous learning rate decay function. Adapted from JaxNeRF 66 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 67 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 68 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 69 | function of lr_delay_mult, such that the initial learning rate is 70 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 71 | to the normal learning rate when steps>lr_delay_steps. 72 | :param lr_init: float, initial learning rate. :param lr_final: float, final learning rate at max_steps. 73 | :param max_steps: int, the number of steps during optimization. 74 | :return HoF which takes step as input 75 | """ 76 | 77 | def helper(step): 78 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 79 | # Disable this parameter 80 | return 0.0 81 | if lr_delay_steps > 0: 82 | # A kind of reverse cosine decay. 83 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 84 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 85 | ) 86 | else: 87 | delay_rate = 1.0 88 | t = np.clip(step / max_steps, 0, 1) 89 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 90 | return delay_rate * log_lerp 91 | 92 | return helper 93 | 94 | 95 | def strip_lowerdiag(L): 96 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 97 | 98 | uncertainty[:, 0] = L[:, 0, 0] 99 | uncertainty[:, 1] = L[:, 0, 1] 100 | uncertainty[:, 2] = L[:, 0, 2] 101 | uncertainty[:, 3] = L[:, 1, 1] 102 | uncertainty[:, 4] = L[:, 1, 2] 103 | uncertainty[:, 5] = L[:, 2, 2] 104 | return uncertainty 105 | 106 | 107 | def strip_symmetric(sym): 108 | return strip_lowerdiag(sym) 109 | 110 | 111 | def build_rotation(r): 112 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 113 | 114 | q = r / norm[:, None] 115 | 116 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 117 | 118 | r = q[:, 0] 119 | x = q[:, 1] 120 | y = q[:, 2] 121 | z = q[:, 3] 122 | 123 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 124 | R[:, 0, 1] = 2 * (x * y - r * z) 125 | R[:, 0, 2] = 2 * (x * z + r * y) 126 | R[:, 1, 0] = 2 * (x * y + r * z) 127 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 128 | R[:, 1, 2] = 2 * (y * z - r * x) 129 | R[:, 2, 0] = 2 * (x * z - r * y) 130 | R[:, 2, 1] = 2 * (y * z + r * x) 131 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 132 | return R 133 | 134 | 135 | def build_scaling_rotation(s, r): 136 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 137 | R = build_rotation(r) 138 | 139 | L[:, 0, 0] = s[:, 0] 140 | L[:, 1, 1] = s[:, 1] 141 | L[:, 2, 2] = s[:, 2] 142 | 143 | L = R @ L 144 | return L 145 | 146 | 147 | def seed_everything(seed): 148 | random.seed(seed) 149 | os.environ['PYTHONHASHSEED'] = str(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.cuda.manual_seed(seed) 153 | torch.cuda.manual_seed_all(seed) 154 | 155 |
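The log-linear schedule in get_expon_lr_func above is easiest to see with numbers. A small check with made-up rates (editor's example): with no delay, the rate at the halfway step is the geometric mean of the two endpoints.

    lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30000)
    assert abs(lr_fn(0) - 1.6e-4) < 1e-12
    assert abs(lr_fn(15000) - 1.6e-5) < 1e-12  # sqrt(1.6e-4 * 1.6e-6)
    assert abs(lr_fn(30000) - 1.6e-6) < 1e-12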
-------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points: np.ndarray 19 | colors: np.ndarray 20 | normals: np.ndarray 21 | time: np.ndarray = None 22 | 23 | def getWorld2View(R, t): 24 | Rt = np.zeros((4, 4)) 25 | Rt[:3, :3] = R.transpose() 26 | Rt[:3, 3] = t 27 | Rt[3, 3] = 1.0 28 | return np.float32(Rt) 29 | 30 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 31 | Rt = np.zeros((4, 4)) 32 | Rt[:3, :3] = R.transpose() 33 | Rt[:3, 3] = t 34 | Rt[3, 3] = 1.0 35 | 36 | C2W = np.linalg.inv(Rt) 37 | cam_center = C2W[:3, 3] 38 | cam_center = (cam_center + translate) * scale 39 | C2W[:3, 3] = cam_center 40 | Rt = np.linalg.inv(C2W) 41 | return np.float32(Rt) 42 | 43 | def getProjectionMatrix(znear, zfar, fovX, fovY): 44 | tanHalfFovY = math.tan((fovY / 2)) 45 | tanHalfFovX = math.tan((fovX / 2)) 46 | 47 | top = tanHalfFovY * znear 48 | bottom = -top 49 | right = tanHalfFovX * znear 50 | left = -right 51 | 52 | P = torch.zeros(4, 4) 53 | 54 | z_sign = 1.0 55 | 56 | P[0, 0] = 2.0 * znear / (right - left) 57 | P[1, 1] = 2.0 * znear / (top - bottom) 58 | P[0, 2] = (right + left) / (right - left) 59 | P[1, 2] = (top + bottom) / (top - bottom) 60 | P[3, 2] = z_sign 61 | P[2, 2] = z_sign * zfar / (zfar - znear) 62 | P[2, 3] = -(zfar * znear) / (zfar - znear) 63 | return P 64 | 65 | def getProjectionMatrixCenterShift(znear, zfar, cx, cy, fx, fy, w, h): 66 | top = cy / fy * znear 67 | bottom = -(h-cy) / fy * znear 68 | 69 | left = -(w-cx) / fx * znear 70 | right = cx / fx * znear 71 | 72 | P = torch.zeros(4, 4) 73 | 74 | z_sign = 1.0 75 | 76 | P[0, 0] = 2.0 * znear / (right - left) 77 | P[1, 1] = 2.0 * znear / (top - bottom) 78 | P[0, 2] = (right + left) / (right - left) 79 | P[1, 2] = (top + bottom) / (top - bottom) 80 | P[3, 2] = z_sign 81 | P[2, 2] = z_sign * zfar / (zfar - znear) 82 | P[2, 3] = -(zfar * znear) / (zfar - znear) 83 | return P 84 | 85 | def fov2focal(fov, pixels): 86 | return pixels / (2 * math.tan(fov / 2)) 87 | 88 | def focal2fov(focal, pixels): 89 | return 2 * math.atan(pixels / (2 * focal))
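fov2focal and focal2fov above are exact inverses; a quick numeric anchor (editor's example): a 90-degree field of view across 1920 pixels corresponds to a focal length of 960 pixels.

    import math

    fx = fov2focal(math.pi / 2, 1920)  # 1920 / (2 * tan(pi/4)) = 960.0
    assert abs(fx - 960.0) < 1e-6
    assert abs(focal2fov(fx, 1920) - math.pi / 2) < 1e-9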
-------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def psnr(img1, img2): 18 | mse = F.mse_loss(img1, img2) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | 21 | def gaussian(window_size, sigma): 22 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 23 | return gauss / gauss.sum() 24 | 25 | def create_window(window_size, channel): 26 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 27 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 28 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) # Variable is a no-op in modern PyTorch; kept for compatibility 29 | return window 30 | 31 | def ssim(img1, img2, window_size=11, size_average=True): 32 | channel = img1.size(-3) 33 | window = create_window(window_size, channel) 34 | 35 | if img1.is_cuda: 36 | window = window.cuda(img1.get_device()) 37 | window = window.type_as(img1) 38 | 39 | return _ssim(img1, img2, window, window_size, channel, size_average) 40 | 41 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 42 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 43 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 44 | 45 | mu1_sq = mu1.pow(2) 46 | mu2_sq = mu2.pow(2) 47 | mu1_mu2 = mu1 * mu2 48 | 49 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 50 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 51 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 52 | 53 | C1 = 0.01 ** 2 54 | C2 = 0.03 ** 2 55 | 56 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 57 | 58 | if size_average: 59 | return ssim_map.mean() 60 | else: 61 | return ssim_map.mean(1).mean(1).mean(1) 62 | 63 | def tv_loss(depth): 64 | c, h, w = depth.shape[0], depth.shape[1], depth.shape[2] 65 | count_h = c * (h - 1) * w 66 | count_w = c * h * (w - 1) 67 | h_tv = torch.square(depth[..., 1:, :] - depth[..., :h-1, :]).sum() 68 | w_tv = torch.square(depth[..., :, 1:] - depth[..., :, :w-1]).sum() 69 | return 2 * (h_tv / count_h + w_tv / count_w) 70 |
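A quick numeric anchor for the psnr helper above (editor's example): two images that differ by a constant 0.1 everywhere have MSE 0.01 and therefore a PSNR of exactly 20 dB.

    import torch

    img1 = torch.zeros(3, 4, 4)
    img2 = torch.full((3, 4, 4), 0.1)
    assert torch.isclose(psnr(img1, img2), torch.tensor(20.0))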
-------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-4 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5
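RGB2SH and SH2RGB above encode the DC convention that pairs with eval_sh: at degree 0 the evaluation reduces to C0 * sh[..., 0], so a color stored with RGB2SH round-trips once the 0.5 offset is added back (editor's sketch):

    import torch

    rgb = torch.tensor([0.2, 0.5, 0.8])
    sh = RGB2SH(rgb).unsqueeze(-1)               # one DC coefficient per channel: [C=3, 1]
    recovered = eval_sh(0, sh, dirs=None) + 0.5  # dirs is not touched at degree 0
    assert torch.allclose(recovered, rgb)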
-------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | 15 | def searchForMaxIteration(folder): 16 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 17 | return max(saved_iters) 18 | 19 | 20 | class Timing: 21 | """ 22 | From https://github.com/sxyu/svox2/blob/ee80e2c4df8f29a407fda5729a494be94ccf9234/svox2/utils.py#L611 23 | 24 | Timing environment 25 | usage: 26 | with Timing("message"): 27 | your commands here 28 | will print CUDA runtime in ms 29 | """ 30 | 31 | def __init__(self, name): 32 | self.name = name 33 | 34 | def __enter__(self): 35 | self.start = torch.cuda.Event(enable_timing=True) 36 | self.end = torch.cuda.Event(enable_timing=True) 37 | self.start.record() 38 | 39 | def __exit__(self, exc_type, exc_value, traceback): 40 | self.end.record() 41 | torch.cuda.synchronize() 42 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms")
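searchForMaxIteration above assumes every entry in the folder ends in an underscore followed by an iteration number (the point_cloud/iteration_30000 layout that Gaussian-splatting snapshots typically use) and will raise ValueError on anything else. A small illustration (editor's example):

    # for a folder containing "iteration_7000" and "iteration_30000":
    names = ["iteration_7000", "iteration_30000"]
    assert max(int(n.split("_")[-1]) for n in names) == 30000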
--------------------------------------------------------------------------------