├── .gitignore ├── README.md ├── assets ├── combined_normal.gif ├── combined_render.gif ├── combined_video.gif ├── depth.gif ├── pipeline.png ├── qual.png └── teaser.png ├── configs ├── base.yaml ├── emer_reconstruction_stage1.yaml └── emer_reconstruction_stage2.yaml ├── docs └── dataset_prep.md ├── evaluate.py ├── gaussian_renderer ├── __init__.py ├── gs_render.py └── pvg_render.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── requirements.txt ├── scene ├── __init__.py ├── cameras.py ├── dinov2.py ├── dynamic_model.py ├── emer_waymo_loader.py ├── emernerf_loader.py ├── envlight.py ├── fit3d.py ├── gaussian_model.py ├── kittimot_loader.py ├── scene_utils.py └── waymo_loader.py ├── scripts ├── extract_mask_kitti.py ├── extract_mask_waymo.py ├── extract_mono_cues_kitti.py ├── extract_mono_cues_notr.py ├── extract_mono_cues_waymo.py ├── waymo_converter.py └── waymo_download.py ├── separate.py ├── train.py ├── utils ├── camera_utils.py ├── feature_extractor.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py └── visualize_gs.py /.gitignore: -------------------------------------------------------------------------------- 1 | work_dir*/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | data/ 11 | paper_code/ 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | __pycache__/ 32 | *.so 33 | build/ 34 | *.egg-info/ 35 | .vscode 36 | 37 | 38 | build 39 | data/ 40 | dataset/ 41 | output 42 | eval_output 43 | diff-gaussian-rasterization 44 | nvdiffrast 45 | simple-knn 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | cover/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 
112 | #Pipfile.lock 113 | 114 | # poetry 115 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 116 | # This is especially recommended for binary packages to ensure reproducibility, and is more 117 | # commonly ignored for libraries. 118 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 119 | #poetry.lock 120 | 121 | # pdm 122 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 123 | #pdm.lock 124 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 125 | # in version control. 126 | # https://pdm.fming.dev/#use-with-ide 127 | .pdm.toml 128 | 129 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 130 | __pypackages__/ 131 | 132 | # Celery stuff 133 | celerybeat-schedule 134 | celerybeat.pid 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # Cython debug symbols 170 | cython_debug/ 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 177 | #.idea/ 178 | 179 | 180 | # custom 181 | # logs 182 | data/ 183 | local_scripts/ 184 | package_hacks/ 185 | 186 | # caches 187 | *.pyc 188 | *.swp 189 | 190 | # media 191 | *.mp4 192 | *.png 193 | *.jpg 194 | 195 | 196 | # wandb 197 | wandb/ 198 | 199 | # work in progress 200 | *wip* 201 | 202 | *results* 203 | *debug* 204 | 205 | network/src/* 206 | 207 | network/threestudio_*/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# DeSiRe-GS

![teaser](assets/teaser.png)
8 | 9 | 10 | > [**DeSiRe-GS: 4D Street Gaussians for Static-Dynamic Decomposition and Surface Reconstruction for Urban Driving Scenes**](https://arxiv.org/abs/2411.11921) 11 | > 12 | > [Chensheng Peng](https://pholypeng.github.io/), [Chengwei Zhang](https://chengweialan.github.io/), [Yixiao Wang](https://yixiaowang7.github.io/), [Chenfeng Xu](https://www.chenfengx.com/), [Yichen Xie](https://scholar.google.com/citations?user=SdX6DaEAAAAJ), [Wenzhao Zheng](https://wzzheng.net/), [Kurt Keutzer](https://people.eecs.berkeley.edu/~keutzer/), [Masayoshi Tomizuka](https://me.berkeley.edu/people/masayoshi-tomizuka/), [Wei Zhan](https://zhanwei.site/) 13 | > 14 | > **Arxiv preprint** 15 | 16 | 17 | 18 | ## 📖 Overview 19 | 20 |
![pipeline](assets/pipeline.png)
25 | In the first stage, we extract 2D motion masks based on the observation that 3D Gaussian Splatting inherently can reconstruct only the static regions in dynamic environments. These extracted 2D motion priors are then mapped into the Gaussian space in a differentiable manner, leveraging an efficient formulation of dynamic Gaussians in the second stage. 26 | 27 | 28 | ## 🛠️ Installation 29 | 30 | We test our code on Ubuntu 20.04 using Python 3.10 and PyTorch 2.2.0. We recommend using conda to install all the dependencies. 31 | 32 | 33 | 1. Create the conda environment and install requirements. 34 | 35 | ``` 36 | # Clone the repo. 37 | git clone https://github.com/chengweialan/DeSiRe-GS.git 38 | cd DeSiRe-GS 39 | 40 | # Create the conda environment. 41 | conda create -n DeSiReGS python==3.10 42 | conda activate DeSiReGS 43 | 44 | # Install torch. 45 | pip install torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu118 # replace with your own CUDA version 46 | 47 | # Install requirements. 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | 2. Install the submodules. The repository contains the same submodules as [PVG](https://github.com/fudan-zvg/PVG). 52 | 53 | ``` 54 | # Install simple-knn 55 | git clone https://gitlab.inria.fr/bkerbl/simple-knn.git 56 | pip install ./simple-knn 57 | 58 | # Install a modified gaussian splatting (for feature rendering) 59 | git clone --recursive https://github.com/SuLvXiangXin/diff-gaussian-rasterization 60 | pip install ./diff-gaussian-rasterization 61 | 62 | # Install nvdiffrast (for EnvLight) 63 | git clone https://github.com/NVlabs/nvdiffrast 64 | pip install ./nvdiffrast 65 | ``` 66 | 67 | 68 | 69 | ## 💾 Data Preparation 70 | 71 | Create a directory to save the data. Run ```mkdir dataset```. 72 | 73 | We provide a sample sequence on [Google Drive](https://drive.google.com/drive/u/0/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH); you may download it and unzip it into `dataset`. 74 |
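If you prefer to script the download, `gdown` is already listed in `requirements.txt` and can fetch the shared folder. The snippet below is only a sketch: the folder ID comes from the link above, while the output directory and the assumption that the download arrives as `.zip` archives are mine.

```python
# Sketch: fetch the sample sequence with gdown (listed in requirements.txt) and unzip it.
# Assumption: the shared folder is publicly readable and contains .zip archives.
import glob
import zipfile

import gdown

gdown.download_folder(
    "https://drive.google.com/drive/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH",
    output="dataset",
    quiet=False,
)
for archive in glob.glob("dataset/**/*.zip", recursive=True):
    with zipfile.ZipFile(archive) as zf:
        zf.extractall("dataset")  # e.g. produces dataset/084 used in the commands below
```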
75 | **Waymo Dataset** 76 | 77 | 78 | | Source | Number of Sequences | Scene Type | Description | 79 | | :---: | :---: | :---: | --- | 80 | | [PVG](https://github.com/fudan-zvg/PVG) | 4 | Dynamic | • Refer to [this page](https://github.com/fudan-zvg/PVG?tab=readme-ov-file#data-preparation). | 81 | | [OmniRe](https://ziyc.github.io/omnire/) | 8 | Dynamic | • Described as highly complex.<br>• Refer to [this page](https://github.com/ziyc/drivestudio/blob/main/docs/Waymo.md). | 82 | | [EmerNeRF](https://github.com/NVlabs/EmerNeRF) | 64 | 32 dynamic<br>32 static | • Contains 32 static, 32 dynamic and 56 diverse scenes.<br>• We test our code on the 32 static and 32 dynamic scenes.<br>• See [this page](https://github.com/NVlabs/EmerNeRF?tab=readme-ov-file#dataset-preparation) for detailed instructions. | 83 |
84 | 85 |
86 | **KITTI Dataset** 87 | 88 | | Source | Number of Sequences | Scene Type | Description | 89 | | :---: | :---: | :---: | --- | 90 | | [PVG](https://github.com/fudan-zvg/PVG) | 3 | Dynamic | • Refer to [this page](https://github.com/fudan-zvg/PVG?tab=readme-ov-file#kitti-dataset). | 91 |
92 | 93 | 94 | ## :memo: Training and Evaluation 95 | 96 | ### Training 97 | 98 | First, run the following command to train Stage I: 99 | 100 | ```sh 101 | # Stage 1 102 | python train.py \ 103 | --config configs/emer_reconstruction_stage1.yaml \ 104 | source_path=dataset/084 \ 105 | model_path=eval_output/waymo_reconstruction/084_stage1 106 | ``` 107 | 108 | After running the command, the uncertainty model will be saved to ```${YOUR_MODEL_PATH}``` by default (e.g., ```uncertainty_model30000.pth```, where the suffix is the saving iteration). Then train Stage II, loading the Stage I uncertainty model: 109 | 110 | ```sh 111 | # Stage 2 112 | python train.py \ 113 | --config configs/emer_reconstruction_stage2.yaml \ 114 | source_path=dataset/084 \ 115 | model_path=eval_output/waymo_reconstruction/084_stage2 \ 116 | uncertainty_model_path=eval_output/waymo_reconstruction/084_stage1/uncertainty_model30000.pth 117 | ``` 118 | 119 | ### Evaluating 120 | 121 | We provide checkpoints on [Google Drive](https://drive.google.com/drive/u/0/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH); you may download and unzip them under `${PROJECT_FOLDER}`. 122 | 123 | ```sh 124 | python evaluate.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml 125 | ``` 126 | 127 | ### Static-Dynamic Decomposition 128 | 129 | We provide ```separate.py``` for static-dynamic decomposition. Run 130 | 131 | ``` 132 | python separate.py --config_path ${YOUR_MODEL_PATH}/config.yaml 133 | ``` 134 | 135 | For instance, 136 | 137 | ``` 138 | # example 139 | python separate.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml 140 | ``` 141 | 142 | The decomposition results will be saved in `${MODEL_PATH}/separation`. 143 | 144 | 145 | 146 |
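A note on the `key=value` arguments in the commands above: the configs are OmegaConf YAML files, and `evaluate.py` in this repo loads them with `OmegaConf`. Below is a minimal sketch of the load-then-override pattern, under the assumption that `train.py` merges its extra command-line pairs the same way.

```python
# Sketch of the OmegaConf "YAML + dotlist override" pattern used by the training commands.
# Assumption: train.py merges key=value pairs from the CLI into the YAML config like this.
from omegaconf import OmegaConf

base = OmegaConf.load("configs/emer_reconstruction_stage1.yaml")
overrides = OmegaConf.from_dotlist([
    "source_path=dataset/084",
    "model_path=eval_output/waymo_reconstruction/084_stage1",
])
cfg = OmegaConf.merge(base, overrides)
print(cfg.model_path)  # eval_output/waymo_reconstruction/084_stage1
```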
Qualitative results (see the GIFs in `assets/`): Rendered Image, Decomposed Static, Rendered Depth, and Rendered Normal.
193 | ## :clapper: Visualization 194 | 195 | ### 3D Gaussians Visualization 196 | 197 | Following [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), we use the [SIBR](https://sibr.gitlabpages.inria.fr/) framework, developed by the GRAPHDECO group, as an interactive viewer to visualize the Gaussian ellipsoids. Refer to [this page](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#interactive-viewers) for installation details. 198 | 199 | We provide ```visualize_gs.py``` for Gaussian ellipsoid visualization. For example, run 200 | 201 | ``` 202 | # Save gaussian point cloud. 203 | python visualize_gs.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml 204 | ``` 205 | 206 | The ```.ply``` file, which contains the visible Gaussians, will be saved to ```${YOUR_MODEL_PATH}/point_cloud/point_cloud.ply```. You can then use the SIBR viewer to visualize the Gaussians directly from your model path folder. For example, 207 | 208 | ``` 209 | # Enter your SIBR folder. 210 | cd ${SIBR_FOLDER}/SIBR_viewers/install/bin 211 | 212 | # Visualize the gaussians. 213 | ./SIBR_gaussianViewer_app -m ${YOUR_MODEL_PATH}/ 214 | 215 | # Example 216 | ./SIBR_gaussianViewer_app -m ${PROJECT_FOLDER}/eval_output/waymo_reconstruction/084_stage2/ 217 | ``` 218 |
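If you only want a quick look at the exported point cloud without building SIBR, `open3d` (already in `requirements.txt`) can display the Gaussian centers. This is a sketch rather than part of the repo; it reads only the `x/y/z` vertex positions from the `.ply`, so you will see raw points, not rendered splats.

```python
# Quick-look sketch: display the exported Gaussian centers as a plain point cloud.
# Note: this ignores the scales, rotations, opacities and SH colors stored in the .ply.
import open3d as o3d

ply_path = "eval_output/waymo_reconstruction/084_stage2/point_cloud/point_cloud.ply"
pcd = o3d.io.read_point_cloud(ply_path)
print(pcd)  # e.g. "PointCloud with N points."
o3d.visualization.draw_geometries([pcd])
```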
![qual](assets/qual.png)
223 | 224 | 225 | ## 📜 BibTeX 226 | 227 | ```bibtex 228 | @misc{peng2024desiregs4dstreetgaussians, 229 | title={DeSiRe-GS: 4D Street Gaussians for Static-Dynamic Decomposition and Surface Reconstruction for Urban Driving Scenes}, 230 | author={Chensheng Peng and Chengwei Zhang and Yixiao Wang and Chenfeng Xu and Yichen Xie and Wenzhao Zheng and Kurt Keutzer and Masayoshi Tomizuka and Wei Zhan}, 231 | year={2024}, 232 | eprint={2411.11921}, 233 | archivePrefix={arXiv}, 234 | primaryClass={cs.CV}, 235 | } 236 | ``` 237 | 238 | 239 | 240 | 241 | ## :pray: Acknowledgements 242 | 243 | We adapted some codes from some awesome repositories including [PVG](https://github.com/fudan-zvg/PVG) and [PGSR](https://github.com/zju3dv/PGSR/). 244 | 245 | -------------------------------------------------------------------------------- /assets/combined_normal.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_normal.gif -------------------------------------------------------------------------------- /assets/combined_render.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_render.gif -------------------------------------------------------------------------------- /assets/combined_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_video.gif -------------------------------------------------------------------------------- /assets/depth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/depth.gif -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/pipeline.png -------------------------------------------------------------------------------- /assets/qual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/qual.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/teaser.png -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | test_iterations: [7000, 30000] 2 | save_iterations: [7000, 30000] 3 | checkpoint_iterations: [7000, 30000] 4 | exhaust_test: false 5 | test_interval: 5000 6 | render_static: false 7 | vis_step: 500 8 | start_checkpoint: null 9 | seed: 0 10 | resume: false 11 | 12 | # ModelParams 13 | sh_degree: 3 14 | scene_type: "Waymo" 15 | source_path: ??? 16 | start_frame: 65 # for kitti 17 | end_frame: 120 # for kitti 18 | model_path: ??? 
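# Note: "???" is OmegaConf's mandatory-value marker; source_path and model_path have no default
# and must be supplied on the command line (e.g. `source_path=dataset/084
# model_path=eval_output/waymo_reconstruction/084_stage1`), otherwise resolving the config
# with throw_on_missing raises a missing-value error.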
19 | resolution_scales: [1] 20 | resolution: -1 21 | white_background: false 22 | data_device: "cuda" 23 | eval: false 24 | debug_cuda: false 25 | cam_num: 3 # for waymo 26 | t_init: 0.2 27 | cycle: 0.2 28 | velocity_decay: 1.0 29 | random_init_point: 200000 30 | fix_radius: 0.0 31 | time_duration: [-0.5, 0.5] 32 | num_pts: 100000 33 | frame_interval: 0.02 34 | testhold: 4 # NVS 35 | env_map_res: 1024 36 | separate_scaling_t: 0.1 37 | neg_fov: true 38 | 39 | # PipelineParams 40 | convert_SHs_python: false 41 | compute_cov3D_python: false 42 | debug: false 43 | depth_blend_mode: 0 44 | env_optimize_until: 1000000000 45 | env_optimize_from: 0 46 | 47 | 48 | # OptimizationParams 49 | iterations: 30000 50 | position_lr_init: 0.00016 51 | position_lr_final: 0.0000016 52 | position_lr_delay_mult: 0.01 53 | t_lr_init: 0.0008 54 | position_lr_max_steps: 30_000 55 | feature_lr: 0.0025 56 | opacity_lr: 0.05 57 | scaling_lr: 0.005 58 | scaling_t_lr: 0.002 59 | velocity_lr: 0.001 60 | rotation_lr: 0.001 61 | envmap_lr: 0.01 62 | normal_lr: 0.001 63 | 64 | time_split_frac: 0.5 65 | percent_dense: 0.01 66 | thresh_opa_prune: 0.005 67 | densification_interval: 100 68 | opacity_reset_interval: 3000 69 | densify_from_iter: 500 70 | densify_until_iter: 15_000 71 | densify_grad_threshold: 0.0002 72 | densify_grad_t_threshold: 0.002 73 | densify_until_num_points: 3000000 74 | sh_increase_interval: 1000 75 | scale_increase_interval: 5000 76 | prune_big_point: 1 77 | size_threshold: 20 78 | big_point_threshold: 0.1 79 | t_grad: true 80 | no_time_split: true 81 | contract: true 82 | 83 | lambda_dssim: 0.2 84 | lambda_opa: 0.0 85 | lambda_sky_opa: 0.05 86 | lambda_opacity_entropy: 0.05 87 | lambda_inv_depth: 0.001 88 | lambda_self_supervision: 0.5 # for random sampling to self-supervise 89 | lambda_t_reg: 0.0 90 | lambda_v_reg: 0.0 91 | lambda_lidar: 0.1 92 | lidar_decay: 1.0 93 | lambda_v_smooth: 0.01 94 | lambda_t_smooth: 0.0 95 | lambda_normal: 0.0 -------------------------------------------------------------------------------- /configs/emer_reconstruction_stage1.yaml: -------------------------------------------------------------------------------- 1 | test_iterations: [30000] 2 | save_iterations: [30000] 3 | checkpoint_iterations: [30000] 4 | isotropic: false 5 | exhaust_test: false 6 | 7 | enable_dynamic: false 8 | scale_increase_interval: 1500 9 | uncertainty_warmup_iters: 1800 10 | uncertainty_warmup_start: 6000 11 | 12 | # ModelParams 13 | scene_type: "EmerWaymo" 14 | resolution_scales: [1, 2, 4, 8, 16] 15 | cam_num: 3 16 | eval: false 17 | num_pts: 600000 18 | t_init: 0.1 19 | time_duration: [0, 1] 20 | separate_scaling_t: 0.2 21 | start_time: 0 22 | end_time: 49 23 | stride: 0 24 | original_start_time: 0 25 | 26 | load_sky_mask: true 27 | load_panoptic_mask: false 28 | load_sam_mask: false 29 | load_dynamic_mask: true 30 | load_feat_map: false 31 | load_normal_map: false 32 | 33 | load_intrinsic: false 34 | load_c2w: false 35 | 36 | save_occ_grid: true 37 | occ_voxel_size: 0.4 38 | recompute_occ_grid: false 39 | use_bg_gs: false 40 | white_background: false 41 | # PipelineParams 42 | 43 | 44 | # OptimizationParams 45 | iterations: 30000 46 | 47 | dynamic_mask_epoch: 50000 48 | 49 | opacity_lr: 0.005 50 | 51 | densify_until_iter: 15000 52 | densify_grad_threshold: 0.00017 53 | sh_increase_interval: 2000 54 | 55 | 56 | lambda_v_reg: 0.01 57 | lambda_normal: 0.0 58 | lambda_lidar: 0.0 59 | lambda_scaling: 0.0 60 | lambda_min_scale: 0.0 61 | lambda_max_scale: 0.0 62 | 63 | 64 | # uncertainty model 65 | 
uncertainty_stage: stage1 66 | uncertainty_mode: "dino" # ["disabled", "l2reg", "l1reg", "dino", "dino+mssim"] 67 | uncertainty_backbone: "dinov2_vits14_reg" 68 | uncertainty_regularizer_weight: 0.5 69 | uncertainty_clip_min: 0.1 70 | uncertainty_mask_clip_max: null 71 | uncertainty_dssim_clip_max: 1.0 # 0.05 -> 0.005 72 | uncertainty_lr: 0.001 73 | uncertainty_dropout: 0.1 74 | uncertainty_dino_max_size: null 75 | uncertainty_scale_grad: false 76 | uncertainty_center_mult: false 77 | uncertainty_after_opacity_reset: 1000 78 | uncertainty_protected_iters: 500 79 | uncertainty_preserve_sky: false 80 | 81 | 82 | render_type: pvg 83 | 84 | 85 | multi_view_weight_from_iter: 50000 86 | multi_view_patch_size: 3 87 | multi_view_sample_num: 102400 88 | multi_view_ncc_weight: 0.15 89 | multi_view_geo_weight: 0.0 90 | multi_view_pixel_noise_th: 1.0 91 | -------------------------------------------------------------------------------- /configs/emer_reconstruction_stage2.yaml: -------------------------------------------------------------------------------- 1 | exhaust_test: false 2 | isotropic: false 3 | enable_dynamic: true 4 | 5 | scale_increase_interval: 5000 6 | # ModelParams 7 | scene_type: "EmerWaymo" 8 | resolution_scales: [1, 2, 4, 8, 16] 9 | cam_num: 3 10 | eval: false 11 | num_pts: 600000 12 | t_init: 0.1 13 | time_duration: [0, 1] 14 | separate_scaling_t: 0.2 15 | separate_velocity: 10 16 | start_time: 0 17 | end_time: 49 18 | stride: 0 19 | original_start_time: 0 20 | 21 | load_sky_mask: true 22 | load_panoptic_mask: false 23 | load_sam_mask: false 24 | load_dynamic_mask: true 25 | load_feat_map: false 26 | load_normal_map: true 27 | 28 | load_intrinsic: false 29 | load_c2w: false 30 | 31 | save_occ_grid: true 32 | occ_voxel_size: 0.4 33 | recompute_occ_grid: false 34 | use_bg_gs: false 35 | white_background: false 36 | # PipelineParams 37 | 38 | 39 | # OptimizationParams 40 | iterations: 50000 41 | 42 | 43 | 44 | opacity_lr: 0.005 45 | 46 | densify_until_iter: 15000 47 | densify_grad_threshold: 0.00017 48 | sh_increase_interval: 2000 49 | 50 | 51 | lambda_t_reg: 0.0 52 | lambda_v_reg: 0.01 53 | lambda_normal: 0.1 54 | lambda_lidar: 0.1 55 | lambda_scaling: 0.0 56 | lambda_depth_var: 0.0 57 | lambda_min_scale: 10.0 58 | min_scale: 0.001 59 | max_scale: 0.2 # corresponding to 4.0m 60 | lambda_max_scale: 1.0 61 | 62 | neg_fov: false 63 | # uncertainty model 64 | dynamic_mask_epoch: 30000 65 | uncertainty_stage: stage2 66 | uncertainty_mode: "dino" # ["disabled", "l2reg", "l1reg", "dino", "dino+mssim"] 67 | uncertainty_backbone: "dinov2_vits14_reg" 68 | uncertainty_model_path: null 69 | uncertainty_regularizer_weight: 0.5 70 | uncertainty_clip_min: 0.1 71 | uncertainty_mask_clip_max: null 72 | uncertainty_dssim_clip_max: 1.0 # 0.05 -> 0.005 73 | uncertainty_lr: 0.001 74 | uncertainty_dropout: 0.1 75 | uncertainty_dino_max_size: null 76 | uncertainty_scale_grad: false 77 | uncertainty_center_mult: false 78 | uncertainty_after_opacity_reset: 1000 79 | uncertainty_protected_iters: 500 80 | uncertainty_preserve_sky: false 81 | 82 | 83 | render_type: pvg 84 | 85 | 86 | multi_view_weight_from_iter: 20000 87 | multi_view_patch_size: 3 88 | multi_view_sample_num: 102400 89 | multi_view_ncc_weight: 0.0 90 | multi_view_geo_weight: 0.03 91 | multi_view_pixel_noise_th: 1.0 -------------------------------------------------------------------------------- /docs/dataset_prep.md: -------------------------------------------------------------------------------- 1 | # Preparing Normals 2 | 3 | We use 
[Omnidata](https://github.com/EPFL-VILAB/omnidata) as the foundation model to extract monocular normals. Feel free to tryother models, such as [Metric3D](https://github.com/YvanYin/Metric3D). 4 | 5 | ## Installation 6 | 7 | 8 | ```sh 9 | git clone https://github.com/EPFL-VILAB/omnidata.git 10 | cd omnidata/omnidata_tools/torch 11 | mkdir -p pretrained_models && cd pretrained_models 12 | wget 'https://zenodo.org/records/10447888/files/omnidata_dpt_depth_v2.ckpt' 13 | wget 'https://zenodo.org/records/10447888/files/omnidata_dpt_normal_v2.ckpt' 14 | ``` 15 | 16 | ## Processing 17 | 18 | ### Waymo dataset 19 | 20 | If you download the dataset from [PVG](https://drive.google.com/file/d/1eTNJz7WeYrB3IctVlUmJIY0z8qhjR_qF/view?usp=sharing) to `dataset`, run the following command to generate the normal maps 21 | 22 | ```sh 23 | python scripts/extract_mono_cues_waymo.py --data_root ./dataset/waymo_scenes --task normal 24 | ``` 25 | 26 | ### KITTI dataset 27 | 28 | ```sh 29 | python scripts/extract_mono_cues_kitti.py --data_root ./dataset/kitti_mot/training --task normal 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import glob 12 | import json 13 | import os 14 | import torch 15 | import torch.nn.functional as F 16 | from utils.loss_utils import psnr, ssim 17 | from gaussian_renderer import get_renderer 18 | from scene import Scene, GaussianModel, EnvLight 19 | from utils.general_utils import seed_everything, visualize_depth 20 | from tqdm import tqdm 21 | from argparse import ArgumentParser 22 | from torchvision.utils import make_grid, save_image 23 | from omegaconf import OmegaConf 24 | from pprint import pprint, pformat 25 | from omegaconf import OmegaConf 26 | from texttable import Texttable 27 | import cv2 28 | import numpy as np 29 | import warnings 30 | warnings.filterwarnings("ignore") 31 | EPS = 1e-5 32 | non_zero_mean = ( 33 | lambda x: sum(x) / len(x) if len(x) > 0 else -1 34 | ) 35 | @torch.no_grad() 36 | def evaluation(iteration, scene : Scene, renderFunc, renderArgs, env_map=None): 37 | from lpipsPyTorch import lpips 38 | 39 | scale = scene.resolution_scales[0] 40 | if "kitti" in args.model_path: 41 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766 42 | num = len(scene.getTrainCameras())//2 43 | eval_train_frame = num//5 44 | traincamera = sorted(scene.getTrainCameras(), key =lambda x: x.colmap_id) 45 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 46 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]}) 47 | else: 48 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 49 | {'name': 'train', 'cameras': scene.getTrainCameras()}) 50 | 51 | for config in validation_configs: 52 | if config['cameras'] and len(config['cameras']) > 0: 53 | l1_test = [] 54 | psnr_test = [] 55 | ssim_test = [] 56 | lpips_test = [] 57 | masked_psnr_test = [] 58 | masked_ssim_test = [] 59 | outdir = 
os.path.join(args.model_path, "eval", "eval_on_" + config['name'] + "data" + "_render") 60 | image_folder = os.path.join(args.model_path, "images") 61 | os.makedirs(outdir,exist_ok=True) 62 | os.makedirs(image_folder,exist_ok=True) 63 | # opaciity_mask = scene.gaussians.get_opacity[:, 0] > 0.01 64 | # print("number of valid gaussians: {:d}".format(opaciity_mask.sum().item())) 65 | for camera_id, viewpoint in enumerate(tqdm(config['cameras'], bar_format="{l_bar}{bar:50}{r_bar}")): 66 | # render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=opaciity_mask) 67 | _, _, render_pkg = renderFunc(args, viewpoint, gaussians, background, scene.time_interval, env_map, iteration, camera_id) 68 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 69 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 70 | cv2.imwrite(os.path.join(image_folder, f"{viewpoint.colmap_id:03d}_gt.png"), (gt_image[[2,1,0], :, :].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) 71 | cv2.imwrite(os.path.join(image_folder, f"{viewpoint.colmap_id:03d}_render.png"), (image[[2,1,0], :, :].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) 72 | 73 | depth = render_pkg['depth'] 74 | alpha = render_pkg['alpha'] 75 | sky_depth = 900 76 | depth = depth / alpha.clamp_min(EPS) 77 | if env_map is not None: 78 | if args.depth_blend_mode == 0: # harmonic mean 79 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS) 80 | elif args.depth_blend_mode == 1: 81 | depth = alpha * depth + (1 - alpha) * sky_depth 82 | sky_mask = viewpoint.sky_mask.to("cuda") 83 | # dynamic_mask = viewpoint.dynamic_mask.to("cuda") if viewpoint.dynamic_mask is not None else torch.zeros_like(alpha, dtype=torch.bool) 84 | dynamic_mask = render_pkg['dynamic_mask'] 85 | depth = visualize_depth(depth) 86 | alpha = alpha.repeat(3, 1, 1) 87 | 88 | grid = [image, alpha, depth, gt_image, torch.logical_not(sky_mask[:1]).float().repeat(3, 1, 1), dynamic_mask.float().repeat(3, 1, 1)] 89 | grid = make_grid(grid, nrow=3) 90 | 91 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png")) 92 | 93 | l1_test.append(F.l1_loss(image, gt_image).double().item()) 94 | psnr_test.append(psnr(image, gt_image).double().item()) 95 | ssim_test.append(ssim(image, gt_image).double().item()) 96 | lpips_test.append(lpips(image, gt_image, net_type='alex').double().item()) # very slow 97 | # get the binary dynamic mask 98 | 99 | if dynamic_mask.sum() > 0: 100 | dynamic_mask = dynamic_mask.repeat(3, 1, 1) > 0 # (C, H, W) 101 | masked_psnr_test.append(psnr(image[dynamic_mask], gt_image[dynamic_mask]).double().item()) 102 | unaveraged_ssim = ssim(image, gt_image, size_average=False) # (C, H, W) 103 | masked_ssim_test.append(unaveraged_ssim[dynamic_mask].mean().double().item()) 104 | 105 | 106 | psnr_test = non_zero_mean(psnr_test) 107 | l1_test = non_zero_mean(l1_test) 108 | ssim_test = non_zero_mean(ssim_test) 109 | lpips_test = non_zero_mean(lpips_test) 110 | masked_psnr_test = non_zero_mean(masked_psnr_test) 111 | masked_ssim_test = non_zero_mean(masked_ssim_test) 112 | 113 | 114 | t = Texttable() 115 | t.add_rows([["PSNR", "SSIM", "LPIPS", "L1", "PSNR (dynamic)", "SSIM (dynamic)"], 116 | [f"{psnr_test:.4f}", f"{ssim_test:.4f}", f"{lpips_test:.4f}", f"{l1_test:.4f}", f"{masked_psnr_test:.4f}", f"{masked_ssim_test:.4f}"]]) 117 | print(t.draw()) 118 | with open(os.path.join(outdir, "metrics.json"), "w") as f: 119 | json.dump({"split": config['name'], "iteration": iteration, 120 | "psnr": 
psnr_test, "ssim": ssim_test, "lpips": lpips_test, "masked_psnr": masked_psnr_test, "masked_ssim": masked_ssim_test, 121 | }, f) 122 | 123 | 124 | if __name__ == "__main__": 125 | # Set up command line argument parser 126 | parser = ArgumentParser(description="Training script parameters") 127 | parser.add_argument("--config_path", type=str, required=True) 128 | params, _ = parser.parse_known_args() 129 | 130 | args = OmegaConf.load(params.config_path) 131 | args.resolution_scales = args.resolution_scales[:1] 132 | print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True)))) 133 | 134 | seed_everything(args.seed) 135 | 136 | sep_path = os.path.join(args.model_path, 'separation') 137 | os.makedirs(sep_path, exist_ok=True) 138 | 139 | gaussians = GaussianModel(args) 140 | scene = Scene(args, gaussians, shuffle=False) 141 | 142 | if args.env_map_res > 0: 143 | env_map = EnvLight(resolution=args.env_map_res).cuda() 144 | env_map.training_setup(args) 145 | else: 146 | env_map = None 147 | 148 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth")) 149 | assert len(checkpoints) > 0, "No checkpoints found." 150 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1] 151 | print(f"Loading checkpoint {checkpoint}") 152 | (model_params, first_iter) = torch.load(checkpoint) 153 | gaussians.restore(model_params, args) 154 | 155 | if env_map is not None: 156 | env_checkpoint = os.path.join(os.path.dirname(checkpoint), 157 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt")) 158 | (light_params, _) = torch.load(env_checkpoint) 159 | env_map.restore(light_params) 160 | uncertainty_model_path = os.path.join(os.path.dirname(checkpoint), 161 | os.path.basename(checkpoint).replace("chkpnt", "uncertainty_model")) 162 | state_dict = torch.load(uncertainty_model_path) 163 | gaussians.uncertainty_model.load_state_dict(state_dict, strict=False) 164 | 165 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0] 166 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 167 | render_func, render_wrapper = get_renderer(args.render_type) 168 | evaluation(first_iter, scene, render_wrapper, (args, background), env_map=env_map) 169 | 170 | print("Evaluation complete.") 171 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from .gs_render import render_original_gs, render_gs_origin_wrapper 13 | from .pvg_render import render_pvg, render_pvg_wrapper 14 | 15 | EPS = 1e-5 16 | 17 | rendererTypeCallbacks = { 18 | "gs": render_original_gs, 19 | "pvg": render_pvg 20 | } 21 | 22 | renderWrapperTypeCallbacks = { 23 | "gs": render_gs_origin_wrapper, 24 | "pvg": render_pvg_wrapper, 25 | } 26 | 27 | 28 | def get_renderer(render_type: str): 29 | return rendererTypeCallbacks[render_type], renderWrapperTypeCallbacks[render_type] 30 | -------------------------------------------------------------------------------- /gaussian_renderer/gs_render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | 13 | """ 14 | conda activate gs 15 | 16 | """ 17 | 18 | import torch 19 | import math 20 | from diff_gaussian_rasterization import ( 21 | GaussianRasterizationSettings, 22 | GaussianRasterizer, 23 | ) 24 | from scene.gaussian_model import GaussianModel 25 | from utils.sh_utils import eval_sh 26 | from scene.cameras import Camera 27 | import torch.nn.functional as F 28 | import numpy as np 29 | import kornia 30 | from utils.loss_utils import psnr, ssim, tv_loss 31 | 32 | EPS = 1e-5 33 | 34 | 35 | def render_original_gs( 36 | viewpoint_camera: Camera, 37 | pc: GaussianModel, 38 | pipe, 39 | bg_color: torch.Tensor, 40 | scaling_modifier=1.0, 41 | override_color=None, 42 | env_map=None, 43 | time_shift=None, 44 | other=[], 45 | mask=None, 46 | is_training=False, 47 | return_opacity=True, 48 | return_depth=False, 49 | ): 50 | """ 51 | Render the scene. 52 | 53 | Background tensor (bg_color) must be on GPU! 54 | """ 55 | 56 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 57 | screenspace_points = ( 58 | torch.zeros_like( 59 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" 60 | ) 61 | + 0 62 | ) 63 | try: 64 | screenspace_points.retain_grad() 65 | except: 66 | pass 67 | 68 | # Set up rasterization configuration 69 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 70 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 71 | 72 | raster_settings = GaussianRasterizationSettings( 73 | image_height=int(viewpoint_camera.image_height), 74 | image_width=int(viewpoint_camera.image_width), 75 | tanfovx=tanfovx, 76 | tanfovy=tanfovy, 77 | bg=bg_color, 78 | scale_modifier=scaling_modifier, 79 | viewmatrix=viewpoint_camera.world_view_transform, 80 | projmatrix=viewpoint_camera.full_proj_transform, 81 | sh_degree=pc.active_sh_degree, 82 | campos=viewpoint_camera.camera_center, 83 | prefiltered=False, 84 | debug=pipe.debug, 85 | ) 86 | 87 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 88 | 89 | means3D = pc.get_xyz 90 | means2D = screenspace_points 91 | opacity = pc.get_opacity 92 | 93 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 94 | # scaling / rotation by the rasterizer. 
95 | scales = None 96 | rotations = None 97 | cov3D_precomp = None 98 | 99 | # if time_shift is not None: 100 | # means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp-time_shift) 101 | # means3D = means3D + pc.get_inst_velocity * time_shift 102 | # marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp-time_shift) 103 | # else: 104 | # means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp) 105 | # marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp) 106 | # opacity = opacity * marginal_t 107 | 108 | if pipe.compute_cov3D_python: 109 | cov3D_precomp = pc.get_covariance(scaling_modifier) 110 | else: 111 | scales = pc.get_scaling 112 | rotations = pc.get_rotation 113 | 114 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 115 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 116 | shs = None 117 | colors_precomp = None 118 | if override_color is None: 119 | if pipe.convert_SHs_python: 120 | shs_view = pc.get_features.transpose(1, 2).view( 121 | -1, 3, pc.get_max_sh_channels 122 | ) 123 | dir_pp = ( 124 | means3D.detach() 125 | - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1) 126 | ).detach() 127 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 128 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 129 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 130 | else: 131 | shs = pc.get_features 132 | else: 133 | colors_precomp = override_color 134 | 135 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 136 | rendered_image, radii = rasterizer( 137 | means3D=means3D, 138 | means2D=means2D, 139 | shs=shs, 140 | colors_precomp=colors_precomp, 141 | opacities=opacity, 142 | scales=scales, 143 | rotations=rotations, 144 | cov3D_precomp=cov3D_precomp, 145 | ) 146 | H, W = rendered_image.shape[1], rendered_image.shape[2] 147 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 148 | # They will be excluded from value updates used in the splitting criteria. 
149 | return_dict = { 150 | "render_nobg": rendered_image, 151 | "normal": torch.zeros_like(rendered_image), 152 | "feature": torch.zeros([2, H, W]).to(rendered_image.device), 153 | "viewspace_points": screenspace_points, 154 | "visibility_filter": radii > 0, 155 | "radii": radii, 156 | } 157 | 158 | if return_opacity: 159 | density = torch.ones_like(means3D) 160 | 161 | render_opacity, _ = rasterizer( 162 | means3D=means3D, 163 | means2D=means2D, 164 | shs=None, 165 | colors_precomp=density, 166 | opacities=opacity, 167 | scales=scales, 168 | rotations=rotations, 169 | cov3D_precomp=cov3D_precomp, 170 | ) 171 | render_opacity = render_opacity.mean(dim=0, keepdim=True) # (1, H, W) 172 | return_dict.update({"alpha": render_opacity}) 173 | 174 | if return_depth: 175 | projvect1 = viewpoint_camera.world_view_transform[:, 2][:3].detach() 176 | projvect2 = viewpoint_camera.world_view_transform[:, 2][-1].detach() 177 | means3D_depth = (means3D * projvect1.unsqueeze(0)).sum( 178 | dim=-1, keepdim=True 179 | ) + projvect2 180 | means3D_depth = means3D_depth.repeat(1, 3) 181 | render_depth, _ = rasterizer( 182 | means3D=means3D, 183 | means2D=means2D, 184 | shs=None, 185 | colors_precomp=means3D_depth, 186 | opacities=opacity, 187 | scales=scales, 188 | rotations=rotations, 189 | cov3D_precomp=cov3D_precomp, 190 | ) 191 | render_depth = render_depth.mean(dim=0, keepdim=True) 192 | return_dict.update({"depth": render_depth}) 193 | else: 194 | return_dict.update({"depth": torch.zeros_like(render_opacity)}) 195 | 196 | if env_map is not None: 197 | bg_color_from_envmap = env_map( 198 | viewpoint_camera.get_world_directions(is_training).permute(1, 2, 0) 199 | ).permute(2, 0, 1) 200 | rendered_image = rendered_image + (1 - render_opacity) * bg_color_from_envmap 201 | return_dict.update({"render": rendered_image}) 202 | 203 | return return_dict 204 | 205 | 206 | def calculate_loss( 207 | gaussians: GaussianModel, 208 | viewpoint_camera: Camera, 209 | args, 210 | render_pkg: dict, 211 | env_map, 212 | iteration, 213 | camera_id, 214 | ): 215 | log_dict = {} 216 | 217 | image = render_pkg["render"] 218 | depth = render_pkg["depth"] 219 | alpha = render_pkg["alpha"] 220 | 221 | sky_mask = ( 222 | viewpoint_camera.sky_mask.cuda() 223 | if viewpoint_camera.sky_mask is not None 224 | else torch.zeros_like(alpha, dtype=torch.bool) 225 | ) 226 | 227 | sky_depth = 900 228 | depth = depth / alpha.clamp_min(EPS) 229 | if env_map is not None: 230 | if args.depth_blend_mode == 0: # harmonic mean 231 | depth = 1 / ( 232 | alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth 233 | ).clamp_min(EPS) 234 | elif args.depth_blend_mode == 1: 235 | depth = alpha * depth + (1 - alpha) * sky_depth 236 | 237 | gt_image = viewpoint_camera.original_image.cuda() 238 | 239 | loss_l1 = F.l1_loss(image, gt_image, reduction="none") # [3, H, W] 240 | loss_ssim = 1.0 - ssim(image, gt_image, size_average=False) # [3, H, W] 241 | 242 | log_dict["loss_l1"] = loss_l1.mean().item() 243 | log_dict["loss_ssim"] = loss_ssim.mean().item() 244 | 245 | loss = ( 246 | 1.0 - args.lambda_dssim 247 | ) * loss_l1.mean() + args.lambda_dssim * loss_ssim.mean() 248 | 249 | psnr_for_log = psnr(image, gt_image).double() 250 | log_dict["psnr"] = psnr_for_log 251 | 252 | if args.lambda_lidar > 0: 253 | assert viewpoint_camera.pts_depth is not None 254 | pts_depth = viewpoint_camera.pts_depth.cuda() 255 | 256 | mask = pts_depth > 0 257 | loss_lidar = torch.abs( 258 | 1 / (pts_depth[mask] + 1e-5) - 1 / (depth[mask] + 1e-5) 259 | ).mean() 260 | if 
args.lidar_decay > 0: 261 | iter_decay = np.exp(-iteration / 8000 * args.lidar_decay) 262 | else: 263 | iter_decay = 1 264 | log_dict["loss_lidar"] = loss_lidar.item() 265 | loss += iter_decay * args.lambda_lidar * loss_lidar 266 | 267 | # if args.lambda_normal > 0 and args.load_normal_map: 268 | # alpha_mask = (alpha.data > EPS).repeat(3, 1, 1) # (3, H, W) detached 269 | # rendered_normal = render_pkg['normal'] # (3, H, W) 270 | # gt_normal = viewpoint_camera.normal_map.cuda() 271 | # loss_normal = F.l1_loss(rendered_normal[alpha_mask], gt_normal[alpha_mask]) 272 | # loss_normal += tv_loss(rendered_normal) 273 | # log_dict['loss_normal'] = loss_normal.item() 274 | # loss += args.lambda_normal * loss_normal 275 | 276 | # if args.lambda_v_reg > 0 and args.enable_dynamic: 277 | # loss_v_reg = (torch.abs(v_map) * loss_mult).mean() 278 | # log_dict['loss_v_reg'] = loss_v_reg.item() 279 | # loss += args.lambda_v_reg * loss_v_reg 280 | 281 | # loss_mult[alpha.data < EPS] = 0.0 282 | # if args.lambda_t_reg > 0 and args.enable_dynamic: 283 | # loss_t_reg = (-torch.abs(t_map) * loss_mult).mean() 284 | # log_dict['loss_t_reg'] = loss_t_reg.item() 285 | # loss += args.lambda_t_reg * loss_t_reg 286 | 287 | if args.lambda_inv_depth > 0: 288 | inverse_depth = 1 / (depth + 1e-5) 289 | loss_inv_depth = kornia.losses.inverse_depth_smoothness_loss( 290 | inverse_depth[None], gt_image[None] 291 | ) 292 | log_dict["loss_inv_depth"] = loss_inv_depth.item() 293 | loss = loss + args.lambda_inv_depth * loss_inv_depth 294 | 295 | if args.lambda_sky_opa > 0: 296 | o = alpha.clamp(1e-6, 1 - 1e-6) 297 | sky = sky_mask.float() 298 | loss_sky_opa = (-sky * torch.log(1 - o)).mean() 299 | log_dict["loss_sky_opa"] = loss_sky_opa.item() 300 | loss = loss + args.lambda_sky_opa * loss_sky_opa 301 | 302 | if args.lambda_opacity_entropy > 0: 303 | o = alpha.clamp(1e-6, 1 - 1e-6) 304 | loss_opacity_entropy = -(o * torch.log(o)).mean() 305 | log_dict["loss_opacity_entropy"] = loss_opacity_entropy.item() 306 | loss = loss + args.lambda_opacity_entropy * loss_opacity_entropy 307 | 308 | extra_render_pkg = {} 309 | extra_render_pkg["t_map"] = torch.zeros_like(alpha) 310 | extra_render_pkg["v_map"] = torch.zeros_like(alpha) 311 | # extra_render_pkg['depth'] = torch.zeros_like(alpha) 312 | extra_render_pkg["dynamic_mask"] = torch.zeros_like(alpha) 313 | extra_render_pkg["dino_cosine"] = torch.zeros_like(alpha) 314 | 315 | return loss, log_dict, extra_render_pkg 316 | 317 | 318 | def render_gs_origin_wrapper( 319 | args, 320 | viewpoint_camera: Camera, 321 | gaussians: GaussianModel, 322 | background: torch.Tensor, 323 | time_interval: float, 324 | env_map, 325 | iterations, 326 | camera_id, 327 | ): 328 | 329 | render_pkg = render_original_gs( 330 | viewpoint_camera, gaussians, args, background, env_map=env_map, is_training=True 331 | ) 332 | 333 | loss, log_dict, extra_render_pkg = calculate_loss( 334 | gaussians, viewpoint_camera, args, render_pkg, env_map, iterations, camera_id 335 | ) 336 | 337 | render_pkg.update(extra_render_pkg) 338 | 339 | return loss, log_dict, render_pkg 340 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 
12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y).mean() 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | 
output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | imageio 3 | imageio[ffmpeg] 4 | kornia 5 | trimesh 6 | Pillow 7 | ninja 8 | omegaconf 9 | plyfile 10 | opencv_python 11 | opencv_contrib_python 12 | tensorboardX 13 | matplotlib 14 | texttable 15 | albumentations 16 | timm==0.9.10 17 | scikit-learn 18 | torch-kmeans 19 | albumentations 20 | albucore 21 | open3d 22 | tensorboard 23 | gdown 24 | h5py 25 | pytorch_lightning -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | import torch 16 | from tqdm import tqdm 17 | import numpy as np 18 | from utils.system_utils import searchForMaxIteration 19 | from scene.gaussian_model import GaussianModel 20 | from scene.envlight import EnvLight 21 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, calculate_mean_and_std 22 | from scene.waymo_loader import readWaymoInfo 23 | from scene.kittimot_loader import readKittiMotInfo 24 | from scene.emer_waymo_loader import readEmerWaymoInfo 25 | import logging 26 | sceneLoadTypeCallbacks = { 27 | "Waymo": readWaymoInfo, 28 | "KittiMot": readKittiMotInfo, 29 | 'EmerWaymo': readEmerWaymoInfo, 30 | } 31 | 32 | class Scene: 33 | 34 | gaussians : GaussianModel 35 | 36 | def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True): 37 | self.model_path = args.model_path 38 | self.loaded_iter = None 39 | self.gaussians = gaussians 40 | self.white_background = args.white_background 41 | 42 | if load_iteration: 43 | if load_iteration == -1: 44 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 45 | else: 46 | self.loaded_iter = load_iteration 47 | logging.info("Loading trained model at iteration {}".format(self.loaded_iter)) 48 | 49 | self.train_cameras = {} 50 | self.test_cameras = {} 51 | 52 | scene_info = sceneLoadTypeCallbacks[args.scene_type](args) 53 | 54 | self.time_interval = args.frame_interval 55 | self.gaussians.time_duration = scene_info.time_duration 56 | # print("time duration: ", scene_info.time_duration) 57 | # print("frame interval: ", self.time_interval) 58 | 59 | 60 | if not self.loaded_iter: 61 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 62 | dest_file.write(src_file.read()) 63 | json_cams = [] 64 | camlist = [] 65 | if scene_info.test_cameras: 66 | camlist.extend(scene_info.test_cameras) 67 | if scene_info.train_cameras: 68 | camlist.extend(scene_info.train_cameras) 69 | for id, cam in enumerate(camlist): 70 | json_cams.append(camera_to_JSON(id, cam)) 71 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 72 | json.dump(json_cams, file) 73 | 74 | if shuffle: 75 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 76 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 77 | 78 | self.cameras_extent = scene_info.nerf_normalization["radius"] 79 | self.resolution_scales = args.resolution_scales 80 | self.scale_index = len(self.resolution_scales) - 1 81 | for resolution_scale in self.resolution_scales: 82 | logging.info("Loading Training Cameras at resolution scale {}".format(resolution_scale)) 83 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 84 | logging.info("Loading Test Cameras") 85 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 86 | logging.info("Computing nearest_id") 87 | image_name_to_id = {"train": {cam.image_name: id for id, cam in enumerate(self.train_cameras[resolution_scale])}, 88 | "test": {cam.image_name: id for id, cam in enumerate(self.test_cameras[resolution_scale])}} 89 | with open(os.path.join(self.model_path, "multi_view.json"), 'w') as file: 90 | json_d = [] 91 | for id, cur_cam in enumerate(tqdm(self.train_cameras[resolution_scale], 
bar_format='{l_bar}{bar:50}{r_bar}')): 92 | image_name_to_id_map = image_name_to_id["train"] 93 | # cur_image_name = cur_cam.image_name 94 | cur_colmap_id = cur_cam.colmap_id 95 | nearest_colmap_id_candidate = [cur_colmap_id - 10, cur_colmap_id + 10, cur_colmap_id - 20, cur_colmap_id + 20] 96 | 97 | for colmap_id in nearest_colmap_id_candidate: 98 | near_image_name = "{:03d}_{:1d}".format(colmap_id // 10, colmap_id % 10) 99 | if near_image_name in image_name_to_id_map: 100 | cur_cam.nearest_id.append(image_name_to_id_map[near_image_name]) 101 | cur_cam.nearest_names.append(near_image_name) 102 | 103 | json_d.append({'ref_name' : cur_cam.image_name, 'nearest_name': cur_cam.nearest_names, "id": id, 'nearest_id': cur_cam.nearest_id}) 104 | json.dump(json_d, file) 105 | 106 | 107 | if resolution_scale == 1.0: 108 | logging.info("Computing mean and std of dataset") 109 | mean = [] 110 | std = [] 111 | all_cameras = self.train_cameras[resolution_scale] + self.test_cameras[resolution_scale] 112 | for idx, viewpoint in enumerate(tqdm(all_cameras, bar_format="{l_bar}{bar:50}{r_bar}")): 113 | gt_image = viewpoint.original_image # [3, H, W] 114 | mean.append(gt_image.mean(dim=[1, 2]).cpu().numpy()) 115 | std.append(gt_image.std(dim=[1, 2]).cpu().numpy()) 116 | mean = np.array(mean) 117 | std = np.array(std) 118 | # calculate mean and std of dataset 119 | mean_dataset, std_rgb_dataset = calculate_mean_and_std(mean, std) 120 | 121 | if gaussians.uncertainty_model is not None: 122 | gaussians.uncertainty_model.img_norm_mean = torch.from_numpy(mean_dataset) 123 | gaussians.uncertainty_model.img_norm_std = torch.from_numpy(std_rgb_dataset) 124 | 125 | if self.loaded_iter: 126 | self.gaussians.load_ply(os.path.join(self.model_path, 127 | "point_cloud", 128 | "iteration_" + str(self.loaded_iter), 129 | "point_cloud.ply")) 130 | else: 131 | self.gaussians.create_from_pcd(scene_info.point_cloud, 1) 132 | 133 | def upScale(self): 134 | self.scale_index = max(0, self.scale_index - 1) 135 | 136 | def getTrainCameras(self): 137 | return self.train_cameras[self.resolution_scales[self.scale_index]] 138 | 139 | def getTestCameras(self, scale=1.0): 140 | return self.test_cameras[scale] 141 | 142 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | import numpy as np 17 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrixCenterShift 18 | import kornia 19 | 20 | 21 | class Camera(nn.Module): 22 | def __init__(self, colmap_id, R, T, FoVx=None, FoVy=None, cx=None, cy=None, fx=None, fy=None, 23 | image=None, 24 | image_name=None, uid=0, 25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", timestamp=0.0, 26 | resolution=None, image_path=None, 27 | pts_depth=None, sky_mask=None, dynamic_mask=None, normal_map=None, ncc_scale=1.0, 28 | ): 29 | super(Camera, self).__init__() 30 | 31 | self.uid = uid 32 | self.colmap_id = colmap_id 33 | self.R = R 34 | self.T = T 35 | self.FoVx = FoVx 36 | self.FoVy = FoVy 37 | self.image_name = image_name 38 | self.image = image 39 | self.cx = cx 40 | self.cy = cy 41 | self.fx = fx 42 | self.fy = fy 43 | self.resolution = resolution 44 | self.image_path = image_path 45 | self.ncc_scale = ncc_scale 46 | self.nearest_id = [] 47 | self.nearest_names = [] 48 | 49 | try: 50 | self.data_device = torch.device(data_device) 51 | except Exception as e: 52 | print(e) 53 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 54 | self.data_device = torch.device("cuda") 55 | 56 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 57 | self.sky_mask = sky_mask.to(self.data_device) > 0 if sky_mask is not None else sky_mask 58 | self.pts_depth = pts_depth.to(self.data_device) if pts_depth is not None else pts_depth 59 | self.dynamic_mask = dynamic_mask.to(self.data_device) > 0 if dynamic_mask is not None else dynamic_mask 60 | self.normal_map = normal_map.to(self.data_device) if normal_map is not None else normal_map 61 | self.image_width = resolution[0] 62 | self.image_height = resolution[1] 63 | 64 | self.zfar = 1000.0 65 | self.znear = 0.01 66 | 67 | self.trans = trans 68 | self.scale = scale 69 | 70 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 71 | if cx is not None: 72 | self.FoVx = 2 * math.atan(0.5*self.image_width / fx) if FoVx is None else FoVx 73 | self.FoVy = 2 * math.atan(0.5*self.image_height / fy) if FoVy is None else FoVy 74 | self.projection_matrix = getProjectionMatrixCenterShift(self.znear, self.zfar, cx, cy, fx, fy, 75 | self.image_width, self.image_height).transpose(0, 1).cuda() 76 | else: 77 | self.cx = self.image_width / 2 78 | self.cy = self.image_height / 2 79 | self.fx = self.image_width / (2 * np.tan(self.FoVx * 0.5)) 80 | self.fy = self.image_height / (2 * np.tan(self.FoVy * 0.5)) 81 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, 82 | fovY=self.FoVy).transpose(0, 1).cuda() 83 | self.full_proj_transform = ( 84 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 85 | self.camera_center = self.world_view_transform.inverse()[3, :3] 86 | self.c2w = self.world_view_transform.transpose(0, 1).inverse() 87 | self.timestamp = timestamp 88 | self.grid = kornia.utils.create_meshgrid(self.image_height, self.image_width, normalized_coordinates=False, device='cuda')[0] 89 | 90 | def get_world_directions(self, train=False): 91 | u, v = self.grid.unbind(-1) 92 | if train: 93 | directions = torch.stack([(u-self.cx+torch.rand_like(u))/self.fx, 94 | (v-self.cy+torch.rand_like(v))/self.fy, 95 | 
torch.ones_like(u)], dim=0) 96 | else: 97 | directions = torch.stack([(u-self.cx+0.5)/self.fx, 98 | (v-self.cy+0.5)/self.fy, 99 | torch.ones_like(u)], dim=0) 100 | directions = F.normalize(directions, dim=0) 101 | directions = (self.c2w[:3, :3] @ directions.reshape(3, -1)).reshape(3, self.image_height, self.image_width) 102 | return directions 103 | 104 | def get_image(self): 105 | original_image = self.original_image # [3, H, W] 106 | 107 | # convert the image tensor to grayscale L = R * 299/1000 + G * 587/1000 + B * 114/1000 108 | # image_gray = (0.299 * original_image[0] + 0.587 * original_image[1] + 0.114 * original_image[2])[None] 109 | image_gray = original_image[0:1] * 0.299 + original_image[1:2] * 0.587 + original_image[2:3] * 0.114 110 | 111 | return original_image.cuda(), image_gray.cuda() 112 | 113 | 114 | def get_calib_matrix_nerf(self, scale=1.0): 115 | intrinsic_matrix = torch.tensor([[self.fx/scale, 0, self.cx/scale], [0, self.fy/scale, self.cy/scale], [0, 0, 1]]).float() 116 | extrinsic_matrix = self.world_view_transform.transpose(0,1).contiguous() # cam2world 117 | return intrinsic_matrix, extrinsic_matrix 118 | 119 | def get_rays(self, scale=1.0): 120 | W, H = int(self.image_width/scale), int(self.image_height/scale) 121 | ix, iy = torch.meshgrid( 122 | torch.arange(W), torch.arange(H), indexing='xy') 123 | rays_d = torch.stack( 124 | [(ix-self.cx/scale) / self.fx * scale, 125 | (iy-self.cy/scale) / self.fy * scale, 126 | torch.ones_like(ix)], -1).float().cuda() 127 | return rays_d 128 | 129 | def get_k(self, scale=1.0): 130 | K = torch.tensor([[self.fx / scale, 0, self.cx / scale], 131 | [0, self.fy / scale, self.cy / scale], 132 | [0, 0, 1]]).cuda() 133 | return K 134 | 135 | def get_inv_k(self, scale=1.0): 136 | K_T = torch.tensor([[scale/self.fx, 0, -self.cx/self.fx], 137 | [0, scale/self.fy, -self.cy/self.fy], 138 | [0, 0, 1]]).cuda() 139 | return K_T -------------------------------------------------------------------------------- /scene/emer_waymo_loader.py: -------------------------------------------------------------------------------- 1 | # Description: Load the EmerWaymo dataset for training and testing 2 | # adapted from the PVG datareader for the data from EmerNeRF 3 | 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly 9 | from utils.graphics_utils import BasicPointCloud, focal2fov 10 | 11 | def pad_poses(p): 12 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 13 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) 14 | return np.concatenate([p[..., :3, :4], bottom], axis=-2) 15 | 16 | 17 | def unpad_poses(p): 18 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 19 | return p[..., :3, :4] 20 | 21 | 22 | def transform_poses_pca(poses, fix_radius=0): 23 | """Transforms poses so principal components lie on XYZ axes. 24 | 25 | Args: 26 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 27 | 28 | Returns: 29 | A tuple (poses, transform), with the transformed poses and the applied 30 | camera_to_world transforms. 31 | 32 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161 33 | """ 34 | t = poses[:, :3, 3] 35 | t_mean = t.mean(axis=0) 36 | t = t - t_mean 37 | 38 | eigval, eigvec = np.linalg.eig(t.T @ t) 39 | # Sort eigenvectors in order of largest to smallest eigenvalue. 
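# t.T @ t is the 3x3 scatter matrix of the centered camera positions, so its eigenvectors
# are the principal axes of the camera trajectory. rot = eigvec.T (largest-eigenvalue axis
# first) rotates those axes onto X/Y/Z, and the det check below keeps the rotation right-handed.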
40 | inds = np.argsort(eigval)[::-1] 41 | eigvec = eigvec[:, inds] 42 | rot = eigvec.T 43 | if np.linalg.det(rot) < 0: 44 | rot = np.diag(np.array([1, 1, -1])) @ rot 45 | 46 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 47 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 48 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 49 | 50 | # Flip coordinate system if z component of y-axis is negative 51 | if poses_recentered.mean(axis=0)[2, 1] < 0: 52 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 53 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 54 | 55 | # Just make sure it's it in the [-1, 1]^3 cube 56 | if fix_radius>0: 57 | scale_factor = 1./fix_radius 58 | else: 59 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5) 60 | scale_factor = min(1 / 10, scale_factor) 61 | 62 | poses_recentered[:, :3, 3] *= scale_factor 63 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 64 | 65 | return poses_recentered, transform, scale_factor 66 | 67 | 68 | def readEmerWaymoInfo(args): 69 | 70 | eval = args.eval 71 | load_sky_mask = args.load_sky_mask 72 | load_panoptic_mask = args.load_panoptic_mask 73 | load_sam_mask = args.load_sam_mask 74 | load_dynamic_mask = args.load_dynamic_mask 75 | load_normal_map = args.load_normal_map 76 | load_feat_map = args.load_feat_map 77 | load_intrinsic = args.load_intrinsic 78 | load_c2w = args.load_c2w 79 | save_occ_grid = args.save_occ_grid 80 | occ_voxel_size = args.occ_voxel_size 81 | recompute_occ_grid = args.recompute_occ_grid 82 | use_bg_gs = args.use_bg_gs 83 | white_background = args.white_background 84 | neg_fov = args.neg_fov 85 | 86 | 87 | num_pts = args.num_pts 88 | start_time = args.start_time 89 | end_time = args.end_time 90 | stride = args.stride 91 | original_start_time = args.original_start_time 92 | 93 | ORIGINAL_SIZE = [[1280, 1920], [1280, 1920], [1280, 1920], [884, 1920], [884, 1920]] 94 | OPENCV2DATASET = np.array( 95 | [[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]] 96 | ) 97 | load_size = [640, 960] 98 | 99 | cam_infos = [] 100 | points = [] 101 | points_time = [] 102 | 103 | data_root = args.source_path 104 | image_folder = os.path.join(data_root, "images") 105 | num_seqs = len(os.listdir(image_folder))/5 106 | if end_time == -1: 107 | end_time = int(num_seqs) 108 | else: 109 | end_time += 1 110 | 111 | frame_num = end_time - start_time 112 | # assert frame_num == 50, "frame_num should be 50" 113 | time_duration = args.time_duration 114 | time_interval = (time_duration[1] - time_duration[0]) / (end_time - start_time) 115 | 116 | camera_list = [0, 1, 2] 117 | truncated_min_range, truncated_max_range = -2, 80 118 | 119 | 120 | # --------------------------------------------- 121 | # load poses: intrinsic, c2w, l2w per camera 122 | # --------------------------------------------- 123 | _intrinsics = [] 124 | cam_to_egos = [] 125 | for i in camera_list: 126 | # load intrinsics 127 | intrinsic = np.loadtxt(os.path.join(data_root, "intrinsics", f"{i}.txt")) 128 | fx, fy, cx, cy = intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3] 129 | # scale intrinsics w.r.t. 
load size 130 | fx, fy = ( 131 | fx * load_size[1] / ORIGINAL_SIZE[i][1], 132 | fy * load_size[0] / ORIGINAL_SIZE[i][0], 133 | ) 134 | cx, cy = ( 135 | cx * load_size[1] / ORIGINAL_SIZE[i][1], 136 | cy * load_size[0] / ORIGINAL_SIZE[i][0], 137 | ) 138 | intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 139 | _intrinsics.append(intrinsic) 140 | # load extrinsics 141 | cam_to_ego = np.loadtxt(os.path.join(data_root, "extrinsics", f"{i}.txt")) 142 | # opencv coordinate system: x right, y down, z front 143 | # waymo coordinate system: x front, y left, z up 144 | cam_to_egos.append(cam_to_ego @ OPENCV2DATASET) # opencv_cam -> waymo_cam -> waymo_ego 145 | 146 | 147 | # --------------------------------------------- 148 | # get c2w and w2c transformation per frame and camera 149 | # --------------------------------------------- 150 | # compute per-image poses and intrinsics 151 | cam_to_worlds, ego_to_worlds = [], [] 152 | intrinsics, cam_ids = [], [] 153 | lidar_to_worlds = [] 154 | # ===! for waymo, we simplify timestamps as the time indices 155 | timestamps, timesteps = [], [] 156 | # we tranform the camera poses w.r.t. the first timestep to make the translation vector of 157 | # the first ego pose as the origin of the world coordinate system. 158 | ego_to_world_start = np.loadtxt(os.path.join(data_root, "ego_pose", f"{start_time:03d}.txt")) 159 | for t in range(start_time, end_time): 160 | ego_to_world_current = np.loadtxt(os.path.join(data_root, "ego_pose", f"{t:03d}.txt")) 161 | # ego to world transformation: cur_ego -> world -> start_ego(world) 162 | ego_to_world = np.linalg.inv(ego_to_world_start) @ ego_to_world_current 163 | ego_to_worlds.append(ego_to_world) 164 | for cam_id in camera_list: 165 | cam_ids.append(cam_id) 166 | # transformation: 167 | # opencv_cam -> waymo_cam -> waymo_cur_ego -> world -> start_ego(world) 168 | cam2world = ego_to_world @ cam_to_egos[cam_id] 169 | cam_to_worlds.append(cam2world) 170 | intrinsics.append(_intrinsics[cam_id]) 171 | # ===! we use time indices as the timestamp for waymo dataset for simplicity 172 | # ===! we can use the actual timestamps if needed 173 | # to be improved 174 | timestamps.append(t - start_time) 175 | timesteps.append(t - start_time) 176 | # lidar to world : lidar = ego in waymo 177 | lidar_to_worlds.append(ego_to_world) 178 | # convert to numpy arrays 179 | intrinsics = np.stack(intrinsics, axis=0) 180 | cam_to_worlds = np.stack(cam_to_worlds, axis=0) 181 | ego_to_worlds = np.stack(ego_to_worlds, axis=0) 182 | lidar_to_worlds = np.stack(lidar_to_worlds, axis=0) 183 | cam_ids = np.array(cam_ids) 184 | timestamps = np.array(timestamps) 185 | timesteps = np.array(timesteps) 186 | 187 | 188 | # --------------------------------------------- 189 | # get image, sky_mask, lidar per frame and camera 190 | # --------------------------------------------- 191 | accumulated_num_original_rays = 0 192 | accumulated_num_rays = 0 193 | 194 | for idx, t in enumerate(tqdm(range(start_time, end_time), desc="Loading data", bar_format='{l_bar}{bar:50}{r_bar}')): 195 | 196 | images = [] 197 | image_paths = [] 198 | HWs = [] 199 | sky_masks = [] 200 | dynamic_masks = [] 201 | normal_maps = [] 202 | 203 | for cam_idx in camera_list: 204 | image_path = os.path.join(args.source_path, "images", f"{t:03d}_{cam_idx}.jpg") 205 | im_data = Image.open(image_path) 206 | im_data = im_data.resize((load_size[1], load_size[0]), Image.BILINEAR) # PIL resize: (W, H) 207 | W, H = im_data.size 208 | image = np.array(im_data) / 255. 
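# Note: PIL's Image.size is (W, H), so HWs below stores the post-resize (H, W) = (640, 960);
# pixel values are floats in [0, 1] after the division by 255.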
209 | HWs.append((H, W)) 210 | images.append(image) 211 | image_paths.append(image_path) 212 | 213 | sky_path = os.path.join(args.source_path, "sky_masks", f"{t:03d}_{cam_idx}.png") 214 | sky_data = Image.open(sky_path) 215 | sky_data = sky_data.resize((load_size[1], load_size[0]), Image.NEAREST) # PIL resize: (W, H) 216 | sky_mask = np.array(sky_data)>0 217 | sky_masks.append(sky_mask.astype(np.float32)) 218 | 219 | if load_normal_map: 220 | normal_path = os.path.join(args.source_path, "normals", f"{t:03d}_{cam_idx}.jpg") 221 | normal_data = Image.open(normal_path) 222 | normal_data = normal_data.resize((load_size[1], load_size[0]), Image.BILINEAR) 223 | normal_map = (np.array(normal_data)) / 255. # [0, 1] 224 | normal_maps.append(normal_map) 225 | 226 | if load_dynamic_mask: 227 | dynamic_path = os.path.join(args.source_path, "dynamic_masks", f"{t:03d}_{cam_idx}.png") 228 | dynamic_data = Image.open(dynamic_path) 229 | dynamic_data = dynamic_data.resize((load_size[1], load_size[0]), Image.BILINEAR) 230 | dynamic_mask = np.array(dynamic_data)>0 231 | dynamic_masks.append(dynamic_mask.astype(np.float32)) 232 | 233 | 234 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (frame_num - 1) 235 | lidar_info = np.memmap( 236 | os.path.join(data_root, "lidar", f"{t:03d}.bin"), 237 | dtype=np.float32, 238 | mode="r", 239 | ).reshape(-1, 10) 240 | original_length = len(lidar_info) 241 | accumulated_num_original_rays += original_length 242 | lidar_origins = lidar_info[:, :3] 243 | lidar_points = lidar_info[:, 3:6] 244 | lidar_ids = lidar_info[:, -1] 245 | # select lidar points based on a truncated ego-forward-directional range 246 | # make sure most of lidar points are within the range of the camera 247 | valid_mask = lidar_points[:, 0] < truncated_max_range 248 | valid_mask = valid_mask & (lidar_points[:, 0] > truncated_min_range) 249 | lidar_origins = lidar_origins[valid_mask] 250 | lidar_points = lidar_points[valid_mask] 251 | lidar_ids = lidar_ids[valid_mask] 252 | # transform lidar points to world coordinate system 253 | lidar_origins = ( 254 | lidar_to_worlds[idx][:3, :3] @ lidar_origins.T 255 | + lidar_to_worlds[idx][:3, 3:4] 256 | ).T 257 | lidar_points = ( 258 | lidar_to_worlds[idx][:3, :3] @ lidar_points.T 259 | + lidar_to_worlds[idx][:3, 3:4] 260 | ).T # point_xyz_world 261 | 262 | points.append(lidar_points) 263 | point_time = np.full_like(lidar_points[:, :1], timestamp) 264 | points_time.append(point_time) 265 | 266 | for cam_idx in camera_list: 267 | # world-lidar-pts --> camera-pts : w2c 268 | c2w = cam_to_worlds[int(len(camera_list))*idx + cam_idx] 269 | w2c = np.linalg.inv(c2w) 270 | point_camera = ( 271 | w2c[:3, :3] @ lidar_points.T 272 | + w2c[:3, 3:4] 273 | ).T 274 | 275 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 276 | T = w2c[:3, 3] 277 | K = _intrinsics[cam_idx] 278 | fx = float(K[0, 0]) 279 | fy = float(K[1, 1]) 280 | cx = float(K[0, 2]) 281 | cy = float(K[1, 2]) 282 | height, width = HWs[cam_idx] 283 | if neg_fov: 284 | FovY = -1.0 285 | FovX = -1.0 286 | else: 287 | FovY = focal2fov(fy, height) 288 | FovX = focal2fov(fx, width) 289 | cam_infos.append(CameraInfo(uid=idx * 10 + cam_idx, R=R, T=T, FovY=FovY, FovX=FovX, 290 | image=images[cam_idx], 291 | image_path=image_paths[cam_idx], image_name=f"{t:03d}_{cam_idx}", 292 | width=width, height=height, timestamp=timestamp, 293 | pointcloud_camera = point_camera, 294 | fx=fx, fy=fy, cx=cx, cy=cy, 295 | sky_mask=sky_masks[cam_idx], 296 | 
dynamic_mask=dynamic_masks[cam_idx] if load_dynamic_mask else None, 297 | normal_map=normal_maps[cam_idx] if load_normal_map else None,)) 298 | 299 | if args.debug_cuda: 300 | break 301 | 302 | pointcloud = np.concatenate(points, axis=0) 303 | pointcloud_timestamp = np.concatenate(points_time, axis=0) 304 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True) 305 | pointcloud = pointcloud[indices] 306 | pointcloud_timestamp = pointcloud_timestamp[indices] 307 | 308 | w2cs = np.zeros((len(cam_infos), 4, 4)) 309 | Rs = np.stack([c.R for c in cam_infos], axis=0) 310 | Ts = np.stack([c.T for c in cam_infos], axis=0) 311 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1)) 312 | w2cs[:, :3, 3] = Ts 313 | w2cs[:, 3, 3] = 1 314 | c2ws = unpad_poses(np.linalg.inv(w2cs)) 315 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius) 316 | 317 | c2ws = pad_poses(c2ws) 318 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data", bar_format='{l_bar}{bar:50}{r_bar}')): 319 | c2w = c2ws[idx] 320 | w2c = np.linalg.inv(c2w) 321 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code 322 | cam_info.T[:] = w2c[:3, 3] 323 | cam_info.pointcloud_camera[:] *= scale_factor 324 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3] 325 | if args.eval: 326 | # ## for snerf scene 327 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0] 328 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0] 329 | 330 | # for dynamic scene 331 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0] 332 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0] 333 | 334 | # for emernerf comparison [testhold::testhold] 335 | if args.testhold == 10: 336 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0] 337 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0] 338 | else: 339 | train_cam_infos = cam_infos 340 | test_cam_infos = [] 341 | 342 | nerf_normalization = getNerfppNorm(train_cam_infos) 343 | nerf_normalization['radius'] = 1/nerf_normalization['radius'] 344 | 345 | ply_path = os.path.join(args.source_path, "points3d.ply") 346 | if not os.path.exists(ply_path): 347 | rgbs = np.random.random((pointcloud.shape[0], 3)) 348 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp) 349 | try: 350 | pcd = fetchPly(ply_path) 351 | except: 352 | pcd = None 353 | 354 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp) 355 | 356 | scene_info = SceneInfo(point_cloud=pcd, 357 | train_cameras=train_cam_infos, 358 | test_cameras=test_cam_infos, 359 | nerf_normalization=nerf_normalization, 360 | ply_path=ply_path, 361 | time_interval=time_interval, 362 | time_duration=time_duration) 363 | 364 | return scene_info 365 | -------------------------------------------------------------------------------- /scene/envlight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import nvdiffrast.torch as dr 3 | 4 | 5 | class EnvLight(torch.nn.Module): 6 | 7 | def __init__(self, resolution=1024): 8 | super().__init__() 9 | self.to_opengl = torch.tensor([[1, 
0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda") 10 | self.base = torch.nn.Parameter( 11 | 0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True), 12 | ) 13 | 14 | def capture(self): 15 | return ( 16 | self.base, 17 | self.optimizer.state_dict(), 18 | ) 19 | 20 | def restore(self, model_args, training_args=None): 21 | self.base, opt_dict = model_args 22 | if training_args is not None: 23 | self.training_setup(training_args) 24 | self.optimizer.load_state_dict(opt_dict) 25 | 26 | def training_setup(self, training_args): 27 | self.optimizer = torch.optim.Adam(self.parameters(), lr=training_args.envmap_lr, eps=1e-15) 28 | 29 | def forward(self, l): 30 | l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape) 31 | l = l.contiguous() 32 | prefix = l.shape[:-1] 33 | if len(prefix) != 3: # reshape to [B, H, W, -1] 34 | l = l.reshape(1, 1, -1, l.shape[-1]) 35 | 36 | light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube') 37 | light = light.view(*prefix, -1) 38 | 39 | return light 40 | -------------------------------------------------------------------------------- /scene/fit3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | timm 0.9.10 must be installed to use the get_intermediate_layers method. 3 | pip install timm==0.9.10 4 | pip install torch_kmeans 5 | 6 | """ 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import os 11 | import requests 12 | import itertools 13 | import timm 14 | import torch 15 | import types 16 | import albumentations as A 17 | from torch.nn import functional as F 18 | 19 | from PIL import Image 20 | from sklearn.decomposition import PCA 21 | from torch_kmeans import KMeans, CosineSimilarity 22 | from scene.dinov2 import dinov2_vits14_reg 23 | import cv2 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | cmap = plt.get_cmap("tab20") 27 | MEAN = np.array([0.3609696, 0.38405442, 0.4348492]) 28 | STD = np.array([0.19669543, 0.20297967, 0.22123419]) 29 | 30 | options = ['DINOv2-reg'] 31 | 32 | timm_model_card = { 33 | "DINOv2": "vit_small_patch14_dinov2.lvd142m", 34 | "DINOv2-reg": "vit_small_patch14_reg4_dinov2.lvd142m", 35 | "CLIP": "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k", 36 | "MAE": "vit_base_patch16_224.mae", 37 | "DeiT-III": "deit3_base_patch16_224.fb_in1k" 38 | } 39 | 40 | our_model_card = { 41 | "DINOv2": "dinov2_small_fine", 42 | "DINOv2-reg": "dinov2_reg_small_fine", 43 | "CLIP": "clip_base_fine", 44 | "MAE": "mae_base_fine", 45 | "DeiT-III": "deit3_base_fine" 46 | } 47 | 48 | 49 | 50 | # transforms = A.Compose([ 51 | # A.Normalize(mean=list(MEAN), std=list(STD)), 52 | # ]) 53 | 54 | 55 | def get_intermediate_layers( 56 | self, 57 | x: torch.Tensor, 58 | n=1, 59 | reshape: bool = False, 60 | return_prefix_tokens: bool = False, 61 | return_class_token: bool = False, 62 | norm: bool = True, 63 | ): 64 | 65 | outputs = self._intermediate_layers(x, n) 66 | if norm: 67 | outputs = [self.norm(out) for out in outputs] 68 | if return_class_token: 69 | prefix_tokens = [out[:, 0] for out in outputs] 70 | else: 71 | prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] 72 | outputs = [out[:, self.num_prefix_tokens :] for out in outputs] 73 | 74 | if reshape: 75 | B, C, H, W = x.shape 76 | grid_size = ( 77 | (H - self.patch_embed.patch_size[0]) 78 | // self.patch_embed.proj.stride[0] 79 | + 1, 80 | (W - self.patch_embed.patch_size[1]) 81 | // self.patch_embed.proj.stride[1] 82 | + 1, 
83 | ) 84 | outputs = [ 85 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) 86 | .permute(0, 3, 1, 2) 87 | .contiguous() 88 | for out in outputs 89 | ] 90 | 91 | if return_prefix_tokens or return_class_token: 92 | return tuple(zip(outputs, prefix_tokens)) 93 | return tuple(outputs) 94 | 95 | 96 | def viz_feat(feat): 97 | 98 | _, _, h, w = feat.shape 99 | feat = feat.squeeze(0).permute((1,2,0)) 100 | projected_featmap = feat.reshape(-1, feat.shape[-1]).cpu() 101 | 102 | pca = PCA(n_components=3) 103 | pca.fit(projected_featmap) 104 | pca_features = pca.transform(projected_featmap) 105 | pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min()) 106 | pca_features = pca_features * 255 107 | res_pred = Image.fromarray(pca_features.reshape(h, w, 3).astype(np.uint8)) 108 | 109 | return res_pred 110 | 111 | 112 | def plot_feats(image, model_option, ori_feats, fine_feats, ori_labels=None, fine_labels=None): 113 | 114 | ori_feats_map = viz_feat(ori_feats) 115 | fine_feats_map = viz_feat(fine_feats) 116 | 117 | if ori_labels is not None: 118 | fig, ax = plt.subplots(2, 3, figsize=(10, 5)) 119 | ax[0][0].imshow(image) 120 | ax[0][0].set_title("Input image", fontsize=15) 121 | ax[0][1].imshow(ori_feats_map) 122 | ax[0][1].set_title("Original " + model_option, fontsize=15) 123 | ax[0][2].imshow(fine_feats_map) 124 | ax[0][2].set_title("Ours", fontsize=15) 125 | ax[1][1].imshow(ori_labels) 126 | ax[1][2].imshow(fine_labels) 127 | for xx in ax: 128 | for x in xx: 129 | x.xaxis.set_major_formatter(plt.NullFormatter()) 130 | x.yaxis.set_major_formatter(plt.NullFormatter()) 131 | x.set_xticks([]) 132 | x.set_yticks([]) 133 | x.axis('off') 134 | 135 | else: 136 | fig, ax = plt.subplots(1, 3, figsize=(30, 8)) 137 | ax[0].imshow(image) 138 | ax[0].set_title("Input image", fontsize=15) 139 | ax[1].imshow(ori_feats_map) 140 | ax[1].set_title("Original " + model_option, fontsize=15) 141 | ax[2].imshow(fine_feats_map) 142 | ax[2].set_title("FiT3D", fontsize=15) 143 | 144 | for x in ax: 145 | x.xaxis.set_major_formatter(plt.NullFormatter()) 146 | x.yaxis.set_major_formatter(plt.NullFormatter()) 147 | x.set_xticks([]) 148 | x.set_yticks([]) 149 | x.axis('off') 150 | 151 | plt.tight_layout() 152 | plt.savefig("output2.png") 153 | # plt.close(fig) 154 | return fig 155 | 156 | 157 | def download_image(url, save_path): 158 | response = requests.get(url) 159 | with open(save_path, 'wb') as file: 160 | file.write(response.content) 161 | 162 | 163 | def process_image(image, stride, transforms): 164 | transformed = transforms(image=np.array(image)) 165 | image_tensor = torch.tensor(transformed['image']) 166 | image_tensor = image_tensor.permute(2,0,1) 167 | image_tensor = image_tensor.unsqueeze(0).to(device) 168 | 169 | h, w = image_tensor.shape[2:] 170 | 171 | height_int = ((h + stride-1) // stride)*stride 172 | width_int = ((w+stride-1) // stride)*stride 173 | 174 | image_resized = torch.nn.functional.interpolate(image_tensor, size=(height_int, width_int), mode='bilinear') 175 | 176 | return image_resized 177 | 178 | 179 | def kmeans_clustering(feats_map, n_clusters=20): 180 | 181 | B, D, h, w = feats_map.shape 182 | feats_map_flattened = feats_map.permute((0, 2, 3, 1)).reshape(B, -1, D) 183 | 184 | kmeans_engine = KMeans(n_clusters=n_clusters, distance=CosineSimilarity) 185 | kmeans_engine.fit(feats_map_flattened) 186 | labels = kmeans_engine.predict( 187 | feats_map_flattened 188 | ) 189 | labels = labels.reshape( 190 | B, h, w 191 | ).float() 192 | labels = 
labels[0].cpu().numpy() 193 | 194 | label_map = cmap(labels / n_clusters)[..., :3] 195 | label_map = np.uint8(label_map * 255) 196 | label_map = Image.fromarray(label_map) 197 | 198 | return label_map 199 | 200 | 201 | def run_demo(original_model, fine_model, image_path, kmeans=20): 202 | """ 203 | Run the demo for a given model option and image 204 | model_option: ['DINOv2', 'DINOv2-reg', 'CLIP', 'MAE', 'DeiT-III'] 205 | image_path: path to the image 206 | kmeans: number of clusters for kmeans. Default is 20. -1 means no kmeans. 207 | """ 208 | p = original_model.patch_embed.patch_size 209 | stride = p if isinstance(p, int) else p[0] 210 | image = Image.open(image_path) 211 | transforms = A.Compose([ 212 | A.Normalize(mean=list(MEAN), std=list(STD)), 213 | ]) 214 | image_resized = process_image(image, stride, transforms) 215 | with torch.no_grad(): 216 | ori_feats = original_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True, 217 | return_class_token=False, norm=True) 218 | fine_feats = fine_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True, 219 | return_class_token=False, norm=True) 220 | 221 | ori_feats = ori_feats[-1] 222 | fine_feats = fine_feats[-1] 223 | if kmeans != -1: 224 | ori_labels = kmeans_clustering(ori_feats, kmeans) 225 | fine_labels = kmeans_clustering(fine_feats, kmeans) 226 | else: 227 | ori_labels = None 228 | fine_labels = None 229 | print("image shape: ", image.size) 230 | print("image_resized shape: ", image_resized.shape) 231 | print("ori_feats shape: ", ori_feats.shape) 232 | print("fine_feats shape: ", fine_feats.shape) 233 | 234 | 235 | return plot_feats(image, "DINOv2-reg", ori_feats, fine_feats, ori_labels, fine_labels),ori_feats,fine_feats 236 | 237 | def Dinov2RegExtractor(original_model, fine_model, image,transforms=None, kmeans=20,only_fine_feats:bool=False): 238 | """ 239 | Run the demo for a given model option and image 240 | model_option: ['DINOv2', 'DINOv2-reg', 'CLIP', 'MAE', 'DeiT-III'] 241 | image_path: path to the image 242 | kmeans: number of clusters for kmeans. Default is 20. -1 means no kmeans. 
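    Note: unlike run_demo above, `image` here is a [1, 3, H, W] tensor with values in [0, 1]
    (it is converted back to a PIL image internally). Returns (ori_feats, fine_feats);
    ori_feats is None when only_fine_feats=True.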
243 | """ 244 | p = original_model.patch_embed.patch_size 245 | stride = p if isinstance(p, int) else p[0] 246 | image=image.cpu().numpy() 247 | image_array = (image * 255).astype(np.uint8) 248 | image_array=image_array.squeeze(0).transpose(1,2,0) 249 | image = Image.fromarray(image_array) 250 | fine_feats=None 251 | ori_feats=None 252 | if transforms is not None: 253 | image_resized = process_image(image, stride, transforms) 254 | else: 255 | image_resized=image 256 | with torch.no_grad(): 257 | 258 | fine_feats = fine_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True, 259 | return_class_token=False, norm=True) 260 | if not only_fine_feats: 261 | ori_feats = original_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True, 262 | return_class_token=False, norm=True) 263 | 264 | 265 | fine_feats = fine_feats[-1] 266 | if not only_fine_feats: 267 | ori_feats = ori_feats[-1] 268 | # For semantic segmentation 269 | # if kmeans != -1: 270 | # ori_labels = kmeans_clustering(ori_feats, kmeans) 271 | # fine_labels = kmeans_clustering(fine_feats, kmeans) 272 | # else: 273 | # ori_labels = None 274 | # fine_labels = None 275 | # print("image shape: ", image.size) 276 | # print("image_resized shape: ", image_resized.shape) 277 | # print("ori_feats shape: ", ori_feats.shape) 278 | # print("fine_feats shape: ", fine_feats.shape) 279 | 280 | 281 | return ori_feats,fine_feats 282 | 283 | 284 | def LoadDinov2Model(): 285 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 286 | original_model = timm.create_model( 287 | "vit_small_patch14_reg4_dinov2.lvd142m", 288 | pretrained=True, 289 | num_classes=0, 290 | dynamic_img_size=True, 291 | dynamic_img_pad=False, 292 | ).to(device) 293 | original_model.get_intermediate_layers = types.MethodType( 294 | get_intermediate_layers, 295 | original_model 296 | ) 297 | fine_model = torch.hub.load("ywyue/FiT3D", "dinov2_reg_small_fine").to(device) 298 | fine_model.get_intermediate_layers = types.MethodType( 299 | get_intermediate_layers, 300 | fine_model 301 | ) 302 | return original_model,fine_model 303 | 304 | 305 | def GetDinov2RegFeats(original_model,fine_model,image,transforms=None,kmeans=20): 306 | ori_feats,fine_feats=Dinov2RegExtractor(original_model,fine_model,image,transforms) 307 | return fine_feats 308 | -------------------------------------------------------------------------------- /scene/scene_utils.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | import numpy as np 3 | from utils.graphics_utils import getWorld2View2 4 | from scene.gaussian_model import BasicPointCloud 5 | from plyfile import PlyData, PlyElement 6 | 7 | 8 | class CameraInfo(NamedTuple): 9 | uid: int 10 | R: np.array 11 | T: np.array 12 | image: np.array 13 | image_path: str 14 | image_name: str 15 | width: int 16 | height: int 17 | sky_mask: np.array = None 18 | timestamp: float = 0.0 19 | FovY: float = None 20 | FovX: float = None 21 | fx: float = None 22 | fy: float = None 23 | cx: float = None 24 | cy: float = None 25 | pointcloud_camera: np.array = None 26 | 27 | 28 | depth_map: np.array = None 29 | semantic_mash: np.array = None 30 | instance_mask: np.array = None 31 | sam_mask: np.array = None 32 | dynamic_mask: np.array = None 33 | feat_map: np.array = None 34 | normal_map: np.array = None 35 | 36 | objects: np.array = None 37 | intrinsic: np.array = None 38 | c2w: np.array = None 39 | 40 | class SceneInfo(NamedTuple): 41 | point_cloud: 
BasicPointCloud 42 | train_cameras: list 43 | test_cameras: list 44 | nerf_normalization: dict 45 | ply_path: str 46 | time_interval: float = 0.02 47 | time_duration: list = [0, 1] 48 | 49 | 50 | full_cameras : list = None 51 | bg_point_cloud: BasicPointCloud = None 52 | bg_ply_path: str = None 53 | cam_frustum_aabb: np.array = None 54 | num_panoptic_objects: int = 0 55 | panoptic_id_to_idx: dict = None 56 | panoptic_object_ids: list = None 57 | occ_grid: np.array = None 58 | 59 | def getNerfppNorm(cam_info): 60 | def get_center_and_diag(cam_centers): 61 | cam_centers = np.hstack(cam_centers) 62 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 63 | center = avg_cam_center 64 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 65 | diagonal = np.max(dist) 66 | return center.flatten(), diagonal 67 | 68 | cam_centers = [] 69 | 70 | for cam in cam_info: 71 | W2C = getWorld2View2(cam.R, cam.T) 72 | C2W = np.linalg.inv(W2C) 73 | cam_centers.append(C2W[:3, 3:4]) 74 | 75 | center, diagonal = get_center_and_diag(cam_centers) 76 | radius = diagonal * 1.1 77 | 78 | translate = -center 79 | 80 | return {"translate": translate, "radius": radius} 81 | 82 | 83 | def fetchPly(path): 84 | plydata = PlyData.read(path) 85 | vertices = plydata['vertex'] 86 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 87 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 88 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 89 | if 'time' in vertices: 90 | timestamp = vertices['time'][:, None] 91 | else: 92 | timestamp = None 93 | return BasicPointCloud(points=positions, colors=colors, normals=normals, time=timestamp) 94 | 95 | 96 | def storePly(path, xyz, rgb, timestamp=None): 97 | # Define the dtype for the structured array 98 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 99 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 100 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'), 101 | ('time', 'f4')] 102 | 103 | normals = np.zeros_like(xyz) 104 | if timestamp is None: 105 | timestamp = np.zeros_like(xyz[:, :1]) 106 | 107 | elements = np.empty(xyz.shape[0], dtype=dtype) 108 | attributes = np.concatenate((xyz, normals, rgb, timestamp), axis=1) 109 | elements[:] = list(map(tuple, attributes)) 110 | 111 | # Create the PlyData object and write to file 112 | vertex_element = PlyElement.describe(elements, 'vertex') 113 | ply_data = PlyData([vertex_element]) 114 | ply_data.write(path) 115 | -------------------------------------------------------------------------------- /scene/waymo_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from PIL import Image 5 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly 6 | from utils.graphics_utils import BasicPointCloud, focal2fov 7 | 8 | 9 | def pad_poses(p): 10 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" 11 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) 12 | return np.concatenate([p[..., :3, :4], bottom], axis=-2) 13 | 14 | 15 | def unpad_poses(p): 16 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" 17 | return p[..., :3, :4] 18 | 19 | 20 | def transform_poses_pca(poses, fix_radius=0): 21 | """Transforms poses so principal components lie on XYZ axes. 22 | 23 | Args: 24 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 
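    fix_radius: if > 0, use 1/fix_radius as the scale factor (returned alongside the poses
      and transform) instead of auto-fitting the scene into roughly the [-1, 1]^3 cube.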
25 | 26 | Returns: 27 | A tuple (poses, transform), with the transformed poses and the applied 28 | camera_to_world transforms. 29 | 30 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161 31 | """ 32 | t = poses[:, :3, 3] 33 | t_mean = t.mean(axis=0) 34 | t = t - t_mean 35 | 36 | eigval, eigvec = np.linalg.eig(t.T @ t) 37 | # Sort eigenvectors in order of largest to smallest eigenvalue. 38 | inds = np.argsort(eigval)[::-1] 39 | eigvec = eigvec[:, inds] 40 | rot = eigvec.T 41 | if np.linalg.det(rot) < 0: 42 | rot = np.diag(np.array([1, 1, -1])) @ rot 43 | 44 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 45 | poses_recentered = unpad_poses(transform @ pad_poses(poses)) 46 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) 47 | 48 | # Flip coordinate system if z component of y-axis is negative 49 | if poses_recentered.mean(axis=0)[2, 1] < 0: 50 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 51 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 52 | 53 | # Just make sure it's it in the [-1, 1]^3 cube 54 | if fix_radius>0: 55 | scale_factor = 1./fix_radius 56 | else: 57 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5) 58 | scale_factor = min(1 / 10, scale_factor) 59 | 60 | poses_recentered[:, :3, 3] *= scale_factor 61 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform 62 | 63 | return poses_recentered, transform, scale_factor 64 | 65 | 66 | def readWaymoInfo(args): 67 | neg_fov = args.neg_fov 68 | cam_infos = [] 69 | car_list = [f[:-4] for f in sorted(os.listdir(os.path.join(args.source_path, "calib"))) if f.endswith('.txt')] 70 | points = [] 71 | points_time = [] 72 | 73 | load_size = [640, 960] 74 | ORIGINAL_SIZE = [[1280, 1920], [1280, 1920], [1280, 1920], [884, 1920], [884, 1920]] 75 | frame_num = len(car_list) 76 | if args.frame_interval > 0: 77 | time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2] 78 | else: 79 | time_duration = args.time_duration 80 | 81 | for idx, car_id in tqdm(enumerate(car_list), desc="Loading data"): 82 | ego_pose = np.loadtxt(os.path.join(args.source_path, 'pose', car_id + '.txt')) 83 | 84 | # CAMERA DIRECTION: RIGHT DOWN FORWARDS 85 | with open(os.path.join(args.source_path, 'calib', car_id + '.txt')) as f: 86 | calib_data = f.readlines() 87 | L = [list(map(float, line.split()[1:])) for line in calib_data] 88 | Ks = np.array(L[:5]).reshape(-1, 3, 4)[:, :, :3] 89 | lidar2cam = np.array(L[-5:]).reshape(-1, 3, 4) 90 | lidar2cam = pad_poses(lidar2cam) 91 | 92 | cam2lidar = np.linalg.inv(lidar2cam) 93 | c2w = ego_pose @ cam2lidar 94 | w2c = np.linalg.inv(c2w) 95 | images = [] 96 | image_paths = [] 97 | HWs = [] 98 | for subdir in ['image_0', 'image_1', 'image_2', 'image_3', 'image_4'][:args.cam_num]: 99 | image_path = os.path.join(args.source_path, subdir, car_id + '.png') 100 | im_data = Image.open(image_path) 101 | im_data = im_data.resize((load_size[1], load_size[0]), Image.BILINEAR) # PIL resize: (W, H) 102 | W, H = im_data.size 103 | image = np.array(im_data) / 255. 
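# Only the first args.cam_num of the five cameras (image_0 ... image_4) are loaded; each
# image is resized to 960x640 (W x H) and scaled to floats in [0, 1].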
104 | HWs.append((H, W)) 105 | images.append(image) 106 | image_paths.append(image_path) 107 | 108 | sky_masks = [] 109 | for subdir in ['sky_0', 'sky_1', 'sky_2', 'sky_3', 'sky_4'][:args.cam_num]: 110 | sky_data = Image.open(os.path.join(args.source_path, subdir, car_id + '.png')) 111 | sky_data = sky_data.resize((load_size[1], load_size[0]), Image.NEAREST) # PIL resize: (W, H) 112 | sky_mask = np.array(sky_data)>0 113 | sky_masks.append(sky_mask.astype(np.float32)) 114 | 115 | normal_maps = [] 116 | if args.load_normal_map: 117 | for subdir in ['normal_0', 'normal_1', 'normal_2', 'normal_3', 'normal_4'][:args.cam_num]: 118 | normal_data = Image.open(os.path.join(args.source_path, subdir, car_id + '.jpg')) 119 | normal_data = normal_data.resize((load_size[1], load_size[0]), Image.BILINEAR) 120 | normal_map = (np.array(normal_data)) / 255. 121 | normal_maps.append(normal_map) 122 | 123 | 124 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (len(car_list) - 1) 125 | point = np.fromfile(os.path.join(args.source_path, "velodyne", car_id + ".bin"), 126 | dtype=np.float32, count=-1).reshape(-1, 6) 127 | point_xyz, intensity, elongation, timestamp_pts = np.split(point, [3, 4, 5], axis=1) 128 | point_xyz_world = (np.pad(point_xyz, (0, 1), constant_values=1) @ ego_pose.T)[:, :3] 129 | points.append(point_xyz_world) 130 | point_time = np.full_like(point_xyz_world[:, :1], timestamp) 131 | points_time.append(point_time) 132 | for j in range(args.cam_num): 133 | point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ lidar2cam[j].T)[:, :3] 134 | R = np.transpose(w2c[j, :3, :3]) # R is stored transposed due to 'glm' in CUDA code 135 | T = w2c[j, :3, 3] 136 | K = Ks[j] 137 | fx = float(K[0, 0]) * load_size[1] / ORIGINAL_SIZE[j][1] 138 | fy = float(K[1, 1]) * load_size[0] / ORIGINAL_SIZE[j][0] 139 | cx = float(K[0, 2]) * load_size[1] / ORIGINAL_SIZE[j][1] 140 | cy = float(K[1, 2]) * load_size[0] / ORIGINAL_SIZE[j][0] 141 | width=HWs[j][1] 142 | height=HWs[j][0] 143 | if neg_fov: 144 | FovY = -1.0 145 | FovX = -1.0 146 | else: 147 | FovY = focal2fov(fy, height) 148 | FovX = focal2fov(fx, width) 149 | cam_infos.append(CameraInfo(uid=idx * 5 + j, R=R, T=T, FovY=FovY, FovX=FovX, 150 | image=images[j], 151 | image_path=image_paths[j], image_name=car_id, 152 | width=HWs[j][1], height=HWs[j][0], timestamp=timestamp, 153 | pointcloud_camera = point_camera, 154 | fx=fx, fy=fy, cx=cx, cy=cy, 155 | sky_mask=sky_masks[j], 156 | normal_map=normal_maps[j] if args.load_normal_map else None)) 157 | 158 | if args.debug_cuda: 159 | break 160 | 161 | pointcloud = np.concatenate(points, axis=0) 162 | pointcloud_timestamp = np.concatenate(points_time, axis=0) 163 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True) 164 | pointcloud = pointcloud[indices] 165 | pointcloud_timestamp = pointcloud_timestamp[indices] 166 | 167 | w2cs = np.zeros((len(cam_infos), 4, 4)) 168 | Rs = np.stack([c.R for c in cam_infos], axis=0) 169 | Ts = np.stack([c.T for c in cam_infos], axis=0) 170 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1)) 171 | w2cs[:, :3, 3] = Ts 172 | w2cs[:, 3, 3] = 1 173 | c2ws = unpad_poses(np.linalg.inv(w2cs)) 174 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius) 175 | 176 | c2ws = pad_poses(c2ws) 177 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data")): 178 | c2w = c2ws[idx] 179 | w2c = np.linalg.inv(c2w) 180 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA 
code 181 | cam_info.T[:] = w2c[:3, 3] 182 | cam_info.pointcloud_camera[:] *= scale_factor 183 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3] 184 | if args.eval: 185 | # ## for snerf scene 186 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0] 187 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0] 188 | 189 | # for dynamic scene 190 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0] 191 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0] 192 | 193 | # for emernerf comparison [testhold::testhold] 194 | if args.testhold == 10: 195 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0] 196 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0] 197 | else: 198 | train_cam_infos = cam_infos 199 | test_cam_infos = [] 200 | 201 | nerf_normalization = getNerfppNorm(train_cam_infos) 202 | nerf_normalization['radius'] = 1/nerf_normalization['radius'] 203 | 204 | ply_path = os.path.join(args.source_path, "points3d.ply") 205 | if not os.path.exists(ply_path): 206 | rgbs = np.random.random((pointcloud.shape[0], 3)) 207 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp) 208 | try: 209 | pcd = fetchPly(ply_path) 210 | except: 211 | pcd = None 212 | 213 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp) 214 | time_interval = (time_duration[1] - time_duration[0]) / (len(car_list) - 1) 215 | 216 | scene_info = SceneInfo(point_cloud=pcd, 217 | train_cameras=train_cam_infos, 218 | test_cameras=test_cam_infos, 219 | nerf_normalization=nerf_normalization, 220 | ply_path=ply_path, 221 | time_interval=time_interval, 222 | time_duration=time_duration) 223 | 224 | return scene_info 225 | -------------------------------------------------------------------------------- /scripts/extract_mask_kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_masks.py 3 | @author Jianfei Guo, Shanghai AI Lab 4 | @brief Extract semantic mask 5 | 6 | Using SegFormer, 2021. Cityscapes 83.2% 7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9) 8 | 9 | Installation: 10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env. 11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9; 12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num' 13 | Hence, a seperate conda env is needed. 14 | 15 | git clone https://github.com/NVlabs/SegFormer 16 | 17 | conda create -n segformer python=3.8 18 | conda activate segformer 19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge 20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf 23 | pip install mmcv-full==1.2.7 --no-cache-dir 24 | 25 | cd SegFormer 26 | pip install . 27 | 28 | Usage: 29 | Direct run this script in the newly set conda env. 
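    e.g. `conda activate segformer && python scripts/extract_mask_kitti.py`
    (the SegFormer path `segformer_path` and the data `root` are hard-coded in __main__ below;
    edit them before running).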
30 | """ 31 | import os 32 | import numpy as np 33 | import cv2 34 | from tqdm import tqdm 35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot 36 | from mmseg.core.evaluation import get_palette 37 | 38 | if __name__ == "__main__": 39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer' 40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 'segformer.b5.1024x1024.city.160k.py') 41 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth') 42 | model = init_segmentor(config, checkpoint, device='cuda') 43 | 44 | root = 'data/kitti_mot/training' 45 | 46 | for cam_id in ['2', '3']: 47 | image_dir = os.path.join(root, f'image_0{cam_id}') 48 | sky_dir = os.path.join(root, f'sky_0{cam_id}') 49 | for seq in sorted(os.listdir(image_dir)): 50 | seq_dir = os.path.join(image_dir, seq) 51 | mask_dir = os.path.join(sky_dir, seq) 52 | if not os.path.isdir(seq_dir): 53 | continue 54 | 55 | os.makedirs(image_dir, exist_ok=True) 56 | os.makedirs(mask_dir, exist_ok=True) 57 | for image_name in sorted(os.listdir(seq_dir)): 58 | image_path = os.path.join(seq_dir, image_name) 59 | print(image_path) 60 | mask_path = os.path.join(mask_dir, image_name) 61 | if not image_path.endswith(".png"): 62 | continue 63 | result = inference_segmentor(model, image_path) 64 | mask = result[0].astype(np.uint8) 65 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8) 66 | cv2.imwrite(os.path.join(mask_dir, image_name), mask) 67 | -------------------------------------------------------------------------------- /scripts/extract_mask_waymo.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_masks.py 3 | @author Jianfei Guo, Shanghai AI Lab 4 | @brief Extract semantic mask 5 | 6 | Using SegFormer, 2021. Cityscapes 83.2% 7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9) 8 | 9 | Installation: 10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env. 11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9; 12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num' 13 | Hence, a seperate conda env is needed. 14 | 15 | git clone https://github.com/NVlabs/SegFormer 16 | 17 | conda create -n segformer python=3.8 18 | conda activate segformer 19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge 20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf 23 | pip install mmcv-full==1.2.7 --no-cache-dir 24 | 25 | cd SegFormer 26 | pip install . 27 | 28 | Usage: 29 | Direct run this script in the newly set conda env. 
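    e.g. `conda activate segformer && python scripts/extract_mask_waymo.py`
    (edit the hard-coded `segformer_path` and data `root` in __main__ below first).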
30 | """ 31 | import os 32 | import numpy as np 33 | import cv2 34 | from tqdm import tqdm 35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot 36 | from mmseg.core.evaluation import get_palette 37 | 38 | if __name__ == "__main__": 39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer' 40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 41 | 'segformer.b5.1024x1024.city.160k.py') 42 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth') 43 | model = init_segmentor(config, checkpoint, device='cuda') 44 | 45 | root = 'data/waymo_scenes' 46 | 47 | scenes = sorted(os.listdir(root)) 48 | 49 | for scene in scenes: 50 | for cam_id in range(5): 51 | image_dir = os.path.join(root, scene, f'image_{cam_id}') 52 | sky_dir = os.path.join(root, scene, f'sky_{cam_id}') 53 | os.makedirs(sky_dir, exist_ok=True) 54 | for image_name in tqdm(sorted(os.listdir(image_dir))): 55 | if not image_name.endswith(".png"): 56 | continue 57 | image_path = os.path.join(image_dir, image_name) 58 | mask_path = os.path.join(sky_dir, image_name) 59 | result = inference_segmentor(model, image_path) 60 | mask = result[0].astype(np.uint8) 61 | mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8) 62 | cv2.imwrite(mask_path, mask) 63 | -------------------------------------------------------------------------------- /scripts/extract_mono_cues_kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_mono_cues.py 3 | @brief extract monocular cues (normal & depth) 4 | Adapted from https://github.com/EPFL-VILAB/omnidata 5 | 6 | Installation: 7 | git clone https://github.com/EPFL-VILAB/omnidata 8 | 9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning 10 | """ 11 | 12 | import os 13 | import sys 14 | import argparse 15 | import imageio 16 | import numpy as np 17 | from glob import glob 18 | from tqdm import tqdm 19 | from typing import Literal 20 | 21 | import PIL 22 | import skimage 23 | from PIL import Image 24 | import matplotlib.pyplot as plt 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | from torchvision import transforms 29 | import torchvision.transforms.functional as transF 30 | 31 | def list_contains(l: list, v): 32 | """ 33 | Whether any item in `l` contains `v` 34 | """ 35 | for item in l: 36 | if v in item: 37 | return True 38 | else: 39 | return False 40 | 41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1): 42 | if mask_valid is not None: 43 | img[~mask_valid] = torch.nan 44 | sorted_img = torch.sort(torch.flatten(img))[0] 45 | # Remove nan, nan at the end of sort 46 | num_nan = sorted_img.isnan().sum() 47 | if num_nan > 0: 48 | sorted_img = sorted_img[:-num_nan] 49 | # Remove outliers 50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))] 51 | trunc_mean = trunc_img.mean() 52 | trunc_var = trunc_img.var() 53 | eps = 1e-6 54 | # Replace nan by mean 55 | img = torch.nan_to_num(img, nan=trunc_mean) 56 | # Standardize 57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps) 58 | return img 59 | 60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True): 61 | with torch.no_grad(): 62 | # img = Image.open(img_path) 63 | img = imageio.imread(img_path) 64 | img = skimage.img_as_float32(img) 65 | H, W, _ = img.shape 66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device) 67 | if H 
< W: 68 | H_ = ref_img_size 69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32 70 | else: 71 | W_ = ref_img_size 72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32 73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True) 74 | 75 | if img_tensor.shape[1] == 1: 76 | img_tensor = img_tensor.repeat_interleave(3,1) 77 | 78 | output = model(img_tensor).clamp(min=0, max=1) 79 | 80 | if args.task == 'depth': 81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0) 82 | output = output.clamp(0,1).data / output.max() 83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 86 | 87 | # np.savez_compressed(f"{output_path_base}.npz", output) 88 | #output = 1 - output 89 | # output = standardize_depth_map(output) 90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis') 91 | if verbose: 92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size 93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here. 94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16)) 95 | 96 | else: 97 | output = output.data.clamp(0,1).squeeze(0) 98 | # Resize to original shape 99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here. 100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 101 | 102 | # np.savez_compressed(f"{output_path_base}.npz", output) 103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5) 104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size 105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow 106 | # np.save(f"{output_path_base}.npy", output) 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals') 110 | 111 | # Dataset specific configs 112 | parser.add_argument('--data_root', type=str, default='./data/kitti_mot/training') 113 | parser.add_argument('--seq_list', type=str, default=['0001', '0002', '0006'], help='specify --seq_list if you want to limit the list of seqs') 114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization") 115 | parser.add_argument('--ignore_existing', action='store_true') 116 | parser.add_argument('--rgb_dirname', type=str, default="image") 117 | parser.add_argument('--depth_dirname', type=str, default="depth") 118 | parser.add_argument('--normals_dirname', type=str, default="normal") 119 | 120 | # Algorithm configs 121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth") 122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model", 123 | default='omnidata/omnidata_tools/torch') 124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models", 125 | default='omnidata/omnidata_tools/torch/pretrained_models') 126 | parser.add_argument('--ref_img_size', dest='ref_img_size', type=int, default=512, 127 | help="image size 
when inference (will still save full-scale output)") 128 | args = parser.parse_args() 129 | 130 | #----------------------------------------------- 131 | #-- Original preparation 132 | #----------------------------------------------- 133 | if args.pretrained_models is None: 134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/' 135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models") 136 | 137 | sys.path.append(args.omnidata_path) 138 | print(sys.path) 139 | from modules.unet import UNet 140 | from modules.midas.dpt_depth import DPTDepthModel 141 | from data.transforms import get_transform 142 | 143 | trans_topil = transforms.ToPILImage() 144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu') 145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 146 | 147 | # get target task and model 148 | if args.task == 'normal' or args.task == 'normals': 149 | image_size = 384 150 | 151 | #---- Version 1 model 152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth') 153 | # model = UNet(in_channels=3, out_channels=3) 154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 155 | 156 | # if 'state_dict' in checkpoint: 157 | # state_dict = {} 158 | # for k, v in checkpoint['state_dict'].items(): 159 | # state_dict[k.replace('model.', '')] = v 160 | # else: 161 | # state_dict = checkpoint 162 | 163 | 164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt') 165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid 166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 167 | if 'state_dict' in checkpoint: 168 | state_dict = {} 169 | for k, v in checkpoint['state_dict'].items(): 170 | state_dict[k[6:]] = v 171 | else: 172 | state_dict = checkpoint 173 | 174 | model.load_state_dict(state_dict) 175 | model.to(device) 176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 177 | # transforms.CenterCrop(image_size), 178 | # get_transform('rgb', image_size=None)]) 179 | 180 | trans_totensor = transforms.Compose([ 181 | get_transform('rgb', image_size=None)]) 182 | 183 | elif args.task == 'depth': 184 | image_size = 384 185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt' 186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large 187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid 188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 189 | if 'state_dict' in checkpoint: 190 | state_dict = {} 191 | for k, v in checkpoint['state_dict'].items(): 192 | state_dict[k[6:]] = v 193 | else: 194 | state_dict = checkpoint 195 | model.load_state_dict(state_dict) 196 | model.to(device) 197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR), 198 | # transforms.CenterCrop(image_size), 199 | # transforms.ToTensor(), 200 | # transforms.Normalize(mean=0.5, std=0.5)]) 201 | trans_totensor = transforms.Compose([ 202 | transforms.ToTensor(), 203 | transforms.Normalize(mean=0.5, std=0.5)]) 204 | 205 | else: 206 | print("task should be one of the following: normal, depth") 207 | sys.exit() 208 | 209 | #----------------------------------------------- 210 | #--- Dataset Specific processing 211 | 
#----------------------------------------------- 212 | 213 | select_scene_ids = [f'{args.rgb_dirname}_02', f'{args.rgb_dirname}_03'] 214 | 215 | 216 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')): 217 | for seq in args.seq_list: 218 | image_dir = os.path.join(args.data_root, scene_id, seq) 219 | obs_id_list = sorted(os.listdir(image_dir)) 220 | 221 | if args.task == 'depth': 222 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname) 223 | elif args.task == 'normal': 224 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname) 225 | else: 226 | raise RuntimeError(f"Invalid task={args.task}") 227 | 228 | os.makedirs(output_dir, exist_ok=True) 229 | 230 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')): 231 | img_path = os.path.join(image_dir, obs_id) 232 | fbase = os.path.splitext(os.path.basename(img_path))[0] 233 | 234 | if args.task == 'depth': 235 | output_base = os.path.join(output_dir, fbase) 236 | elif args.task == 'normal': 237 | output_base = os.path.join(output_dir, fbase) 238 | else: 239 | raise RuntimeError(f"Invalid task={args.task}") 240 | 241 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase): 242 | continue 243 | 244 | #---- Inference and save outputs 245 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose) -------------------------------------------------------------------------------- /scripts/extract_mono_cues_notr.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_mono_cues.py 3 | @brief extract monocular cues (normal & depth) 4 | Adapted from https://github.com/EPFL-VILAB/omnidata 5 | 6 | Installation: 7 | git clone https://github.com/EPFL-VILAB/omnidata 8 | 9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning 10 | """ 11 | 12 | import os 13 | import sys 14 | import argparse 15 | import imageio 16 | import numpy as np 17 | from glob import glob 18 | from tqdm import tqdm 19 | from typing import Literal 20 | 21 | import PIL 22 | import skimage 23 | from PIL import Image 24 | import matplotlib.pyplot as plt 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | from torchvision import transforms 29 | import torchvision.transforms.functional as transF 30 | 31 | def list_contains(l: list, v): 32 | """ 33 | Whether any item in `l` contains `v` 34 | """ 35 | for item in l: 36 | if v in item: 37 | return True 38 | else: 39 | return False 40 | 41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1): 42 | if mask_valid is not None: 43 | img[~mask_valid] = torch.nan 44 | sorted_img = torch.sort(torch.flatten(img))[0] 45 | # Remove nan, nan at the end of sort 46 | num_nan = sorted_img.isnan().sum() 47 | if num_nan > 0: 48 | sorted_img = sorted_img[:-num_nan] 49 | # Remove outliers 50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))] 51 | trunc_mean = trunc_img.mean() 52 | trunc_var = trunc_img.var() 53 | eps = 1e-6 54 | # Replace nan by mean 55 | img = torch.nan_to_num(img, nan=trunc_mean) 56 | # Standardize 57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps) 58 | return img 59 | 60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True): 61 | with torch.no_grad(): 62 | # img = Image.open(img_path) 63 | img = imageio.imread(img_path) 64 | img = skimage.img_as_float32(img) 65 | H, W, _ = img.shape 
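        # NOTE (editor): the block below prepares the network input. The shorter image
        # side is scaled to `ref_img_size`, the other side is scaled to keep the aspect
        # ratio and then floored to a multiple of 32 (presumably because the DPT backbone
        # expects spatial dimensions divisible by 32). Single-channel inputs are
        # replicated to three channels so they match the RGB-trained model.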
66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device) 67 | if H < W: 68 | H_ = ref_img_size 69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32 70 | else: 71 | W_ = ref_img_size 72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32 73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True) 74 | 75 | if img_tensor.shape[1] == 1: 76 | img_tensor = img_tensor.repeat_interleave(3,1) 77 | 78 | output = model(img_tensor).clamp(min=0, max=1) 79 | 80 | if args.task == 'depth': 81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0) 82 | output = output.clamp(0,1).data / output.max() 83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 86 | 87 | # np.savez_compressed(f"{output_path_base}.npz", output) 88 | #output = 1 - output 89 | # output = standardize_depth_map(output) 90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis') 91 | if verbose: 92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size 93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here. 94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16)) 95 | 96 | else: 97 | output = output.data.clamp(0,1).squeeze(0) 98 | # Resize to original shape 99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here. 100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 101 | 102 | # np.savez_compressed(f"{output_path_base}.npz", output) 103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5) 104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size 105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow 106 | # np.save(f"{output_path_base}.npy", output) 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals') 110 | 111 | # Dataset specific configs 112 | parser.add_argument('--data_root', type=str, default='./data/notr/') 113 | parser.add_argument('--seq_list', type=str, default=['084'], help='specify --seq_list if you want to limit the list of seqs') 114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization") 115 | parser.add_argument('--ignore_existing', action='store_true') 116 | parser.add_argument('--rgb_dirname', type=str, default="images") 117 | parser.add_argument('--depth_dirname', type=str, default="depths") 118 | parser.add_argument('--normals_dirname', type=str, default="normals") 119 | 120 | # Algorithm configs 121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth") 122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model", 123 | default='omnidata/omnidata_tools/torch') 124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models", 125 | default='omnidata/omnidata_tools/torch/pretrained_models') 126 | parser.add_argument('--ref_img_size', 
dest='ref_img_size', type=int, default=512, 127 | help="image size when inference (will still save full-scale output)") 128 | args = parser.parse_args() 129 | 130 | #----------------------------------------------- 131 | #-- Original preparation 132 | #----------------------------------------------- 133 | if args.pretrained_models is None: 134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/' 135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models") 136 | 137 | sys.path.append(args.omnidata_path) 138 | print(sys.path) 139 | from modules.unet import UNet 140 | from modules.midas.dpt_depth import DPTDepthModel 141 | from data.transforms import get_transform 142 | 143 | trans_topil = transforms.ToPILImage() 144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu') 145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 146 | 147 | # get target task and model 148 | if args.task == 'normal' or args.task == 'normals': 149 | image_size = 384 150 | 151 | #---- Version 1 model 152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth') 153 | # model = UNet(in_channels=3, out_channels=3) 154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 155 | 156 | # if 'state_dict' in checkpoint: 157 | # state_dict = {} 158 | # for k, v in checkpoint['state_dict'].items(): 159 | # state_dict[k.replace('model.', '')] = v 160 | # else: 161 | # state_dict = checkpoint 162 | 163 | 164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt') 165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid 166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 167 | if 'state_dict' in checkpoint: 168 | state_dict = {} 169 | for k, v in checkpoint['state_dict'].items(): 170 | state_dict[k[6:]] = v 171 | else: 172 | state_dict = checkpoint 173 | 174 | model.load_state_dict(state_dict) 175 | model.to(device) 176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 177 | # transforms.CenterCrop(image_size), 178 | # get_transform('rgb', image_size=None)]) 179 | 180 | trans_totensor = transforms.Compose([ 181 | get_transform('rgb', image_size=None)]) 182 | 183 | elif args.task == 'depth': 184 | image_size = 384 185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt' 186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large 187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid 188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 189 | if 'state_dict' in checkpoint: 190 | state_dict = {} 191 | for k, v in checkpoint['state_dict'].items(): 192 | state_dict[k[6:]] = v 193 | else: 194 | state_dict = checkpoint 195 | model.load_state_dict(state_dict) 196 | model.to(device) 197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR), 198 | # transforms.CenterCrop(image_size), 199 | # transforms.ToTensor(), 200 | # transforms.Normalize(mean=0.5, std=0.5)]) 201 | trans_totensor = transforms.Compose([ 202 | transforms.ToTensor(), 203 | transforms.Normalize(mean=0.5, std=0.5)]) 204 | 205 | else: 206 | print("task should be one of the following: normal, depth") 207 | sys.exit() 208 | 209 | 
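# NOTE (editor): in both branches above, `k[6:]` strips the 6-character "model." prefix
# that the training wrapper prepends to parameter names in the released checkpoints
# (the commented-out v1 path does the same thing via k.replace('model.', '')), so the
# weights load into the bare DPTDepthModel.
#
# Usage sketch (values below are just the argparse defaults, adjust to your setup):
#   python scripts/extract_mono_cues_notr.py --task normal --data_root ./data/notr/ --seq_list 084
#   python scripts/extract_mono_cues_notr.py --task depth  --data_root ./data/notr/ --seq_list 084 --verbose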
#----------------------------------------------- 210 | #--- Dataset Specific processing 211 | #----------------------------------------------- 212 | if isinstance(args.seq_list, list): 213 | select_scene_ids = args.seq_list 214 | else: 215 | select_scene_ids = [args.seq_list] 216 | 217 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')): 218 | 219 | image_dir = os.path.join(args.data_root, scene_id, f"{args.rgb_dirname}") 220 | obs_id_list = sorted(os.listdir(image_dir)) 221 | 222 | if args.task == 'depth': 223 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname) 224 | elif args.task == 'normal': 225 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname) 226 | else: 227 | raise RuntimeError(f"Invalid task={args.task}") 228 | 229 | os.makedirs(output_dir, exist_ok=True) 230 | 231 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')): 232 | img_path = os.path.join(image_dir, obs_id) 233 | fbase = os.path.splitext(os.path.basename(img_path))[0] 234 | 235 | if args.task == 'depth': 236 | output_base = os.path.join(output_dir, fbase) 237 | elif args.task == 'normal': 238 | output_base = os.path.join(output_dir, fbase) 239 | else: 240 | raise RuntimeError(f"Invalid task={args.task}") 241 | 242 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase): 243 | continue 244 | 245 | #---- Inference and save outputs 246 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose) -------------------------------------------------------------------------------- /scripts/extract_mono_cues_waymo.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file extract_mono_cues.py 3 | @brief extract monocular cues (normal & depth) 4 | Adapted from https://github.com/EPFL-VILAB/omnidata 5 | 6 | Installation: 7 | git clone https://github.com/EPFL-VILAB/omnidata 8 | 9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning 10 | """ 11 | 12 | import os 13 | import sys 14 | import argparse 15 | import imageio 16 | import numpy as np 17 | from glob import glob 18 | from tqdm import tqdm 19 | from typing import Literal 20 | 21 | import PIL 22 | import skimage 23 | from PIL import Image 24 | import matplotlib.pyplot as plt 25 | 26 | import torch 27 | import torch.nn.functional as F 28 | from torchvision import transforms 29 | import torchvision.transforms.functional as transF 30 | 31 | def list_contains(l: list, v): 32 | """ 33 | Whether any item in `l` contains `v` 34 | """ 35 | for item in l: 36 | if v in item: 37 | return True 38 | else: 39 | return False 40 | 41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1): 42 | if mask_valid is not None: 43 | img[~mask_valid] = torch.nan 44 | sorted_img = torch.sort(torch.flatten(img))[0] 45 | # Remove nan, nan at the end of sort 46 | num_nan = sorted_img.isnan().sum() 47 | if num_nan > 0: 48 | sorted_img = sorted_img[:-num_nan] 49 | # Remove outliers 50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))] 51 | trunc_mean = trunc_img.mean() 52 | trunc_var = trunc_img.var() 53 | eps = 1e-6 54 | # Replace nan by mean 55 | img = torch.nan_to_num(img, nan=trunc_mean) 56 | # Standardize 57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps) 58 | return img 59 | 60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True): 61 | with torch.no_grad(): 62 | 
# img = Image.open(img_path) 63 | img = imageio.imread(img_path) 64 | img = skimage.img_as_float32(img) 65 | H, W, _ = img.shape 66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device) 67 | if H < W: 68 | H_ = ref_img_size 69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32 70 | else: 71 | W_ = ref_img_size 72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32 73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True) 74 | 75 | if img_tensor.shape[1] == 1: 76 | img_tensor = img_tensor.repeat_interleave(3,1) 77 | 78 | output = model(img_tensor).clamp(min=0, max=1) 79 | 80 | if args.task == 'depth': 81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0) 82 | output = output.clamp(0,1).data / output.max() 83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy() 86 | 87 | # np.savez_compressed(f"{output_path_base}.npz", output) 88 | #output = 1 - output 89 | # output = standardize_depth_map(output) 90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis') 91 | if verbose: 92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size 93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here. 94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16)) 95 | 96 | else: 97 | output = output.data.clamp(0,1).squeeze(0) 98 | # Resize to original shape 99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here. 
100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy() 101 | 102 | # np.savez_compressed(f"{output_path_base}.npz", output) 103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5) 104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size 105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow 106 | # np.save(f"{output_path_base}.npy", output) 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals') 110 | 111 | # Dataset specific configs 112 | parser.add_argument('--data_root', type=str, default='./data/waymo_scenes') 113 | parser.add_argument('--seq_list', type=str, default=['0017085', '0145050', '0147030', '0158150'], help='specify --seq_list if you want to limit the list of seqs') 114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization") 115 | parser.add_argument('--ignore_existing', action='store_true') 116 | parser.add_argument('--rgb_dirname', type=str, default="image") 117 | parser.add_argument('--depth_dirname', type=str, default="depth") 118 | parser.add_argument('--normals_dirname', type=str, default="normal") 119 | 120 | # Algorithm configs 121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth") 122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model", 123 | default='omnidata/omnidata_tools/torch') 124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models", 125 | default='omnidata/omnidata_tools/torch/pretrained_models') 126 | parser.add_argument('--ref_img_size', dest='ref_img_size', type=int, default=512, 127 | help="image size when inference (will still save full-scale output)") 128 | args = parser.parse_args() 129 | 130 | #----------------------------------------------- 131 | #-- Original preparation 132 | #----------------------------------------------- 133 | if args.pretrained_models is None: 134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/' 135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models") 136 | 137 | sys.path.append(args.omnidata_path) 138 | print(sys.path) 139 | from modules.unet import UNet 140 | from modules.midas.dpt_depth import DPTDepthModel 141 | from data.transforms import get_transform 142 | 143 | trans_topil = transforms.ToPILImage() 144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu') 145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 146 | 147 | # get target task and model 148 | if args.task == 'normal' or args.task == 'normals': 149 | image_size = 384 150 | 151 | #---- Version 1 model 152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth') 153 | # model = UNet(in_channels=3, out_channels=3) 154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 155 | 156 | # if 'state_dict' in checkpoint: 157 | # state_dict = {} 158 | # for k, v in checkpoint['state_dict'].items(): 159 | # state_dict[k.replace('model.', '')] = v 160 | # else: 161 | # state_dict = checkpoint 162 | 163 | 164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt') 165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # 
DPT Hybrid 166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 167 | if 'state_dict' in checkpoint: 168 | state_dict = {} 169 | for k, v in checkpoint['state_dict'].items(): 170 | state_dict[k[6:]] = v 171 | else: 172 | state_dict = checkpoint 173 | 174 | model.load_state_dict(state_dict) 175 | model.to(device) 176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 177 | # transforms.CenterCrop(image_size), 178 | # get_transform('rgb', image_size=None)]) 179 | 180 | trans_totensor = transforms.Compose([ 181 | get_transform('rgb', image_size=None)]) 182 | 183 | elif args.task == 'depth': 184 | image_size = 384 185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt' 186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large 187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid 188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 189 | if 'state_dict' in checkpoint: 190 | state_dict = {} 191 | for k, v in checkpoint['state_dict'].items(): 192 | state_dict[k[6:]] = v 193 | else: 194 | state_dict = checkpoint 195 | model.load_state_dict(state_dict) 196 | model.to(device) 197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR), 198 | # transforms.CenterCrop(image_size), 199 | # transforms.ToTensor(), 200 | # transforms.Normalize(mean=0.5, std=0.5)]) 201 | trans_totensor = transforms.Compose([ 202 | transforms.ToTensor(), 203 | transforms.Normalize(mean=0.5, std=0.5)]) 204 | 205 | else: 206 | print("task should be one of the following: normal, depth") 207 | sys.exit() 208 | 209 | #----------------------------------------------- 210 | #--- Dataset Specific processing 211 | #----------------------------------------------- 212 | if isinstance(args.seq_list, list): 213 | select_scene_ids = args.seq_list 214 | else: 215 | select_scene_ids = [args.seq_list] 216 | 217 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')): 218 | for i in range(5): 219 | image_dir = os.path.join(args.data_root, scene_id, f"{args.rgb_dirname}_{i}") 220 | obs_id_list = sorted(os.listdir(image_dir)) 221 | 222 | if args.task == 'depth': 223 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname) 224 | elif args.task == 'normal': 225 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname) 226 | else: 227 | raise RuntimeError(f"Invalid task={args.task}") 228 | 229 | os.makedirs(output_dir, exist_ok=True) 230 | 231 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')): 232 | img_path = os.path.join(image_dir, obs_id) 233 | fbase = os.path.splitext(os.path.basename(img_path))[0] 234 | 235 | if args.task == 'depth': 236 | output_base = os.path.join(output_dir, fbase) 237 | elif args.task == 'normal': 238 | output_base = os.path.join(output_dir, fbase) 239 | else: 240 | raise RuntimeError(f"Invalid task={args.task}") 241 | 242 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase): 243 | continue 244 | 245 | #---- Inference and save outputs 246 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose) -------------------------------------------------------------------------------- /scripts/waymo_download.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
os 3 | import subprocess 4 | from concurrent.futures import ThreadPoolExecutor 5 | from typing import List 6 | 7 | 8 | def download_file(filename, target_dir, source): 9 | result = subprocess.run( 10 | [ 11 | "gsutil", 12 | "cp", 13 | "-n", 14 | f"{source}/{filename}.tfrecord", 15 | target_dir, 16 | ], 17 | capture_output=True, # To capture stderr and stdout for detailed error information 18 | text=True, 19 | ) 20 | 21 | # Check the return code of the gsutil command 22 | if result.returncode != 0: 23 | raise Exception( 24 | result.stderr 25 | ) # Raise an exception with the error message from the gsutil command 26 | 27 | 28 | def download_files( 29 | file_names: List[str], 30 | target_dir: str, 31 | source: str = "gs://waymo_open_dataset_scene_flow/train", 32 | ) -> None: 33 | """ 34 | Downloads a list of files from a given source to a target directory using multiple threads. 35 | 36 | Args: 37 | file_names (List[str]): A list of file names to download. 38 | target_dir (str): The target directory to save the downloaded files. 39 | source (str, optional): The source directory to download the files from. Defaults to "gs://waymo_open_dataset_scene_flow/train". 40 | """ 41 | # Get the total number of file_names 42 | total_files = len(file_names) 43 | 44 | # Use ThreadPoolExecutor to manage concurrent downloads 45 | with ThreadPoolExecutor(max_workers=10) as executor: 46 | futures = [ 47 | executor.submit(download_file, filename, target_dir, source) 48 | for filename in file_names 49 | ] 50 | 51 | for counter, future in enumerate(futures, start=1): 52 | # Wait for the download to complete and handle any exceptions 53 | try: 54 | # inspects the result of the future and raises an exception if one occurred during execution 55 | future.result() 56 | print(f"[{counter}/{total_files}] Downloaded successfully!") 57 | except Exception as e: 58 | print(f"[{counter}/{total_files}] Failed to download. Error: {e}") 59 | 60 | 61 | if __name__ == "__main__": 62 | print("note: `gcloud auth login` is required before running this script") 63 | print("Downloading Waymo dataset from Google Cloud Storage...") 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument( 66 | "--target_dir", 67 | type=str, 68 | default="dataset/waymo/raw", 69 | help="Path to the target directory", 70 | ) 71 | parser.add_argument( 72 | "--scene_ids", type=int, nargs="+", help="scene ids to download" 73 | ) 74 | parser.add_argument( 75 | "--split_file", type=str, default=None, help="split file in data/waymo_splits" 76 | ) 77 | args = parser.parse_args() 78 | os.makedirs(args.target_dir, exist_ok=True) 79 | total_list = open("data/waymo_train_list.txt", "r").readlines() 80 | if args.split_file is None: 81 | file_names = [total_list[i].strip() for i in args.scene_ids] 82 | else: 83 | # parse the split file 84 | split_file = open(args.split_file, "r").readlines()[1:] 85 | scene_ids = [int(line.strip().split(",")[0]) for line in split_file] 86 | file_names = [total_list[i].strip() for i in scene_ids] 87 | download_files(file_names, args.target_dir) 88 | 89 | 90 | # python scripts/waymo_download.py --scene_ids 23 114 327 621 703 172 552 788 -------------------------------------------------------------------------------- /separate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import glob 12 | import os 13 | import torch 14 | from gaussian_renderer import get_renderer 15 | from scene import Scene, GaussianModel, EnvLight 16 | from utils.general_utils import seed_everything 17 | from tqdm import tqdm 18 | from argparse import ArgumentParser 19 | from torchvision.utils import save_image 20 | from omegaconf import OmegaConf 21 | from pprint import pformat 22 | EPS = 1e-5 23 | 24 | @torch.no_grad() 25 | def separation(scene : Scene, renderFunc, renderArgs, env_map=None): 26 | scale = scene.resolution_scales[0] 27 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)}, 28 | {'name': 'train', 'cameras': scene.getTrainCameras()}) 29 | 30 | # we supppose area with altitude>0.5 is static 31 | # here z axis is downward so is gaussians.get_xyz[:, 2] < -0.5 32 | high_mask = gaussians.get_xyz[:, 2] < 0 33 | # import pdb;pdb.set_trace() 34 | mask = (gaussians.get_scaling_t[:, 0] > args.separate_scaling_t) | high_mask 35 | for config in validation_configs: 36 | if config['cameras'] and len(config['cameras']) > 0: 37 | outdir = os.path.join(args.model_path, "separation", config['name']) 38 | os.makedirs(outdir,exist_ok=True) 39 | for idx, viewpoint in enumerate(tqdm(config['cameras'])): 40 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map) 41 | render_pkg_static = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=mask) 42 | 43 | image = torch.clamp(render_pkg["render"], 0.0, 1.0) 44 | image_static = torch.clamp(render_pkg_static["render"], 0.0, 1.0) 45 | 46 | save_image(image, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png")) 47 | save_image(image_static, os.path.join(outdir, f"{viewpoint.colmap_id:03d}_static.png")) 48 | 49 | 50 | if __name__ == "__main__": 51 | # Set up command line argument parser 52 | parser = ArgumentParser(description="Training script parameters") 53 | parser.add_argument("--config_path", type=str, required=True) 54 | params, _ = parser.parse_known_args() 55 | 56 | args = OmegaConf.load(params.config_path) 57 | args.resolution_scales = args.resolution_scales[:1] 58 | print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True)))) 59 | 60 | 61 | seed_everything(args.seed) 62 | 63 | sep_path = os.path.join(args.model_path, 'separation') 64 | os.makedirs(sep_path, exist_ok=True) 65 | 66 | gaussians = GaussianModel(args) 67 | scene = Scene(args, gaussians, shuffle=False) 68 | 69 | if args.env_map_res > 0: 70 | env_map = EnvLight(resolution=args.env_map_res).cuda() 71 | env_map.training_setup(args) 72 | else: 73 | env_map = None 74 | 75 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth")) 76 | assert len(checkpoints) > 0, "No checkpoints found." 
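    # NOTE (editor): checkpoint files are named "chkpnt<iteration>.pth", so the line
    # below parses the iteration out of each filename and loads the most recent one.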
77 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1] 78 | print(f"Loading checkpoint {checkpoint}") 79 | (model_params, first_iter) = torch.load(checkpoint) 80 | gaussians.restore(model_params, args) 81 | 82 | if env_map is not None: 83 | env_checkpoint = os.path.join(os.path.dirname(checkpoint), 84 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt")) 85 | (light_params, _) = torch.load(env_checkpoint) 86 | env_map.restore(light_params) 87 | 88 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0] 89 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 90 | render_func, _ = get_renderer(args.render_type) 91 | separation(scene, render_func, (args, background), env_map=env_map) 92 | 93 | print("Rendering statics and dynamics complete.") 94 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import cv2 14 | from scene.cameras import Camera 15 | import numpy as np 16 | from scene.scene_utils import CameraInfo 17 | from tqdm import tqdm 18 | from .graphics_utils import fov2focal 19 | 20 | 21 | def loadCam(args, id, cam_info: CameraInfo, resolution_scale): 22 | orig_w, orig_h = cam_info.width, cam_info.height # cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 3, 4, 8, 16, 32]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution) 27 | ) 28 | scale = resolution_scale * args.resolution 29 | else: # should be a type that converts to float 30 | if args.resolution == -1: 31 | global_down = 1 32 | else: 33 | global_down = orig_w / args.resolution 34 | 35 | scale = float(global_down) * float(resolution_scale) 36 | resolution = (int(orig_w / scale), int(orig_h / scale)) 37 | 38 | if cam_info.cx: 39 | cx = cam_info.cx / scale 40 | cy = cam_info.cy / scale 41 | fy = cam_info.fy / scale 42 | fx = cam_info.fx / scale 43 | else: 44 | cx = None 45 | cy = None 46 | fy = None 47 | fx = None 48 | 49 | if cam_info.image.shape[:2] != resolution[::-1]: 50 | image_rgb = cv2.resize(cam_info.image, resolution) 51 | else: 52 | image_rgb = cam_info.image 53 | image_rgb = torch.from_numpy(image_rgb).float().permute(2, 0, 1) 54 | gt_image = image_rgb[:3, ...] 
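    # NOTE (editor): the blocks below attach the optional per-frame supervision signals
    # (sky mask, sparse LiDAR depth rasterized through K or a precomputed depth map,
    # dynamic-object mask, monocular normal map), each brought to the same resolution
    # as the RGB image and converted to CHW tensors for the Camera object.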
55 | 56 | if cam_info.sky_mask is not None: 57 | if cam_info.sky_mask.shape[:2] != resolution[::-1]: 58 | sky_mask = cv2.resize(cam_info.sky_mask, resolution) 59 | else: 60 | sky_mask = cam_info.sky_mask 61 | if len(sky_mask.shape) == 2: 62 | sky_mask = sky_mask[..., None] 63 | sky_mask = torch.from_numpy(sky_mask).float().permute(2, 0, 1) 64 | else: 65 | sky_mask = None 66 | 67 | if cam_info.pointcloud_camera is not None: 68 | h, w = gt_image.shape[1:] 69 | K = np.eye(3) 70 | if cam_info.cx: 71 | K[0, 0] = fx 72 | K[1, 1] = fy 73 | K[0, 2] = cx 74 | K[1, 2] = cy 75 | else: 76 | K[0, 0] = fov2focal(cam_info.FovX, w) 77 | K[1, 1] = fov2focal(cam_info.FovY, h) 78 | K[0, 2] = cam_info.width / 2 79 | K[1, 2] = cam_info.height / 2 80 | pts_depth = np.zeros([1, h, w]) 81 | point_camera = cam_info.pointcloud_camera 82 | uvz = point_camera[point_camera[:, 2] > 0] 83 | uvz = uvz @ K.T 84 | uvz[:, :2] /= uvz[:, 2:] 85 | uvz = uvz[uvz[:, 1] >= 0] 86 | uvz = uvz[uvz[:, 1] < h] 87 | uvz = uvz[uvz[:, 0] >= 0] 88 | uvz = uvz[uvz[:, 0] < w] 89 | uv = uvz[:, :2] 90 | uv = uv.astype(int) 91 | # TODO: may need to consider overlap 92 | pts_depth[0, uv[:, 1], uv[:, 0]] = uvz[:, 2] 93 | pts_depth = torch.from_numpy(pts_depth).float() 94 | 95 | elif cam_info.depth_map is not None: 96 | if cam_info.depth_map.shape[:2] != resolution[::-1]: 97 | depth_map = cv2.resize(cam_info.depth_map, resolution) 98 | else: 99 | depth_map = cam_info.depth_map 100 | depth_map = torch.from_numpy(depth_map).float() 101 | pts_depth = depth_map.unsqueeze(0) 102 | 103 | else: 104 | pts_depth = None 105 | 106 | if cam_info.dynamic_mask is not None: 107 | if cam_info.dynamic_mask.shape[:2] != resolution[::-1]: 108 | dynamic_mask = cv2.resize(cam_info.dynamic_mask, resolution) 109 | else: 110 | dynamic_mask = cam_info.dynamic_mask 111 | if len(dynamic_mask.shape) == 2: 112 | dynamic_mask = dynamic_mask[..., None] 113 | dynamic_mask = torch.from_numpy(dynamic_mask).float().permute(2, 0, 1) 114 | else: 115 | dynamic_mask = None 116 | 117 | if cam_info.normal_map is not None: 118 | if cam_info.normal_map.shape[:2] != resolution[::-1]: 119 | normal_map = cv2.resize(cam_info.normal_map, resolution) 120 | else: 121 | normal_map = cam_info.normal_map 122 | normal_map = torch.from_numpy(normal_map).float().permute(2, 0, 1) # H, W, 3-> 3, H, W 123 | else: 124 | normal_map = None 125 | 126 | return Camera( 127 | colmap_id=cam_info.uid, 128 | uid=id, 129 | R=cam_info.R, 130 | T=cam_info.T, 131 | FoVx=cam_info.FovX, 132 | FoVy=cam_info.FovY, 133 | cx=cx, 134 | cy=cy, 135 | fx=fx, 136 | fy=fy, 137 | image=gt_image, 138 | image_name=cam_info.image_name, 139 | data_device=args.data_device, 140 | timestamp=cam_info.timestamp, 141 | resolution=resolution, 142 | image_path=cam_info.image_path, 143 | pts_depth=pts_depth, 144 | sky_mask=sky_mask, 145 | dynamic_mask = dynamic_mask, 146 | normal_map = normal_map, 147 | ) 148 | 149 | 150 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 151 | camera_list = [] 152 | 153 | for id, c in enumerate(tqdm(cam_infos, bar_format='{l_bar}{bar:50}{r_bar}')): 154 | camera_list.append(loadCam(args, id, c, resolution_scale)) 155 | 156 | return camera_list 157 | 158 | 159 | def camera_to_JSON(id, camera: Camera): 160 | Rt = np.zeros((4, 4)) 161 | Rt[:3, :3] = camera.R.transpose() 162 | Rt[:3, 3] = camera.T 163 | Rt[3, 3] = 1.0 164 | 165 | W2C = np.linalg.inv(Rt) 166 | pos = W2C[:3, 3] 167 | rot = W2C[:3, :3] 168 | serializable_array_2d = [x.tolist() for x in rot] 169 | 170 | if camera.cx is None: 171 | 
camera_entry = { 172 | "id": id, 173 | "img_name": camera.image_name, 174 | "width": camera.width, 175 | "height": camera.height, 176 | "position": pos.tolist(), 177 | "rotation": serializable_array_2d, 178 | "FoVx": camera.FovX, 179 | "FoVy": camera.FovY, 180 | } 181 | else: 182 | camera_entry = { 183 | "id": id, 184 | "img_name": camera.image_name, 185 | "width": camera.width, 186 | "height": camera.height, 187 | "position": pos.tolist(), 188 | "rotation": serializable_array_2d, 189 | "fx": camera.fx, 190 | "fy": camera.fy, 191 | "cx": camera.cx, 192 | "cy": camera.cy, 193 | } 194 | return camera_entry 195 | 196 | 197 | def calculate_mean_and_std(mean_per_image, std_per_image): 198 | # Calculate mean RGB across dataset 199 | mean_dataset = np.mean(mean_per_image, axis=0) 200 | 201 | # Calculate variance of each image 202 | variances = std_per_image**2 203 | 204 | # Calculate overall variance across the dataset 205 | overall_variance = np.mean(variances, axis=0) + np.mean((mean_per_image - mean_dataset)**2, axis=0) # (C,) 206 | 207 | # Calculate std RGB across dataset 208 | std_rgb_dataset = np.sqrt(overall_variance) 209 | 210 | return mean_dataset, std_rgb_dataset -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import numpy as np 15 | import random 16 | from matplotlib import cm 17 | 18 | def GridSample3D(in_pc,in_shs, voxel_size=0.013): 19 | in_pc_ = in_pc[:,:3].copy() 20 | quantized_pc = np.around(in_pc_ / voxel_size) 21 | quantized_pc -= np.min(quantized_pc, axis=0) 22 | pc_boundary = np.max(quantized_pc, axis=0) - np.min(quantized_pc, axis=0) 23 | 24 | voxel_index = quantized_pc[:,0] * pc_boundary[1] * pc_boundary[2] + quantized_pc[:,1] * pc_boundary[2] + quantized_pc[:,2] 25 | 26 | split_point, index = get_split_point(voxel_index) 27 | 28 | in_points = in_pc[index,:] 29 | out_points = in_points[split_point[:-1],:] 30 | 31 | in_colors = in_shs[index] 32 | out_colors = in_colors[split_point[:-1]] 33 | 34 | return out_points,out_colors 35 | 36 | def get_split_point(labels): 37 | index = np.argsort(labels) 38 | label = labels[index] 39 | label_shift = label.copy() 40 | 41 | label_shift[1:] = label[:-1] 42 | remain = label - label_shift 43 | step_index = np.where(remain > 0)[0].tolist() 44 | step_index.insert(0,0) 45 | step_index.append(labels.shape[0]) 46 | return step_index,index 47 | 48 | def sample_on_aabb_surface(aabb_center, aabb_size, n_pts=1000, above_half=False): 49 | """ 50 | 0:立方体的左面(x轴负方向) 51 | 1:立方体的右面(x轴正方向) 52 | 2:立方体的下面(y轴负方向) 53 | 3:立方体的上面(y轴正方向) 54 | 4:立方体的后面(z轴负方向) 55 | 5:立方体的前面(z轴正方向) 56 | """ 57 | # Choose a face randomly 58 | faces = np.random.randint(0, 6, size=n_pts) 59 | 60 | # Generate two random numbers 61 | r_ = np.random.random((n_pts, 2)) 62 | 63 | # Create an array to store the points 64 | points = np.zeros((n_pts, 3)) 65 | 66 | # Define the offsets for each face 67 | offsets = np.array([ 68 | [-aabb_size[0]/2, 0, 0], 69 | [aabb_size[0]/2, 0, 0], 70 | [0, -aabb_size[1]/2, 0], 71 | [0, aabb_size[1]/2, 0], 72 | [0, 0, -aabb_size[2]/2], 73 | [0, 0, aabb_size[2]/2] 
74 | ]) 75 | 76 | # Define the scales for each face 77 | scales = np.array([ 78 | [aabb_size[1], aabb_size[2]], 79 | [aabb_size[1], aabb_size[2]], 80 | [aabb_size[0], aabb_size[2]], 81 | [aabb_size[0], aabb_size[2]], 82 | [aabb_size[0], aabb_size[1]], 83 | [aabb_size[0], aabb_size[1]] 84 | ]) 85 | 86 | # Define the positions of the zero column for each face 87 | zero_column_positions = [0, 0, 1, 1, 2, 2] 88 | # Define the indices of the aabb_size components for each face 89 | aabb_size_indices = [[1, 2], [1, 2], [0, 2], [0, 2], [0, 1], [0, 1]] 90 | # Calculate the coordinates of the points for each face 91 | for i in range(6): 92 | mask = faces == i 93 | r_scaled = r_[mask] * scales[i] 94 | r_scaled = np.insert(r_scaled, zero_column_positions[i], 0, axis=1) 95 | aabb_size_adjusted = np.insert(aabb_size[aabb_size_indices[i]] / 2, zero_column_positions[i], 0) 96 | points[mask] = aabb_center + offsets[i] + r_scaled - aabb_size_adjusted 97 | #visualize_points(points[mask], aabb_center, aabb_size) 98 | #visualize_points(points, aabb_center, aabb_size) 99 | 100 | # 提取上半部分的点 101 | if above_half: 102 | points = points[points[:, -1] > aabb_center[-1]] 103 | return points 104 | 105 | def get_OccGrid(pts, aabb, occ_voxel_size): 106 | # 计算网格的大小 107 | grid_size = np.ceil((aabb[1] - aabb[0]) / occ_voxel_size).astype(int) 108 | assert pts.min() >= aabb[0].min() and pts.max() <= aabb[1].max(), "Points are outside the AABB" 109 | 110 | # 创建一个空的网格 111 | voxel_grid = np.zeros(grid_size, dtype=np.uint8) 112 | 113 | # 将点云转换为网格坐标 114 | grid_pts = ((pts - aabb[0]) / occ_voxel_size).astype(int) 115 | 116 | # 将网格中的点设置为1 117 | voxel_grid[grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2]] = 1 118 | 119 | # check 120 | #voxel_coords = np.floor((pts - aabb[0]) / occ_voxel_size).astype(int) 121 | #occ = voxel_grid[voxel_coords[:, 0], voxel_coords[:, 1], voxel_coords[:, 2]] 122 | 123 | return voxel_grid 124 | 125 | def visualize_depth(depth, near=0.2, far=13, linear=False): 126 | depth = depth[0].clone().detach().cpu().numpy() 127 | colormap = cm.get_cmap('turbo') 128 | curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) 129 | if linear: 130 | curve_fn = lambda x: -x 131 | eps = np.finfo(np.float32).eps 132 | near = near if near else depth.min() 133 | far = far if far else depth.max() 134 | near -= eps 135 | far += eps 136 | near, far, depth = [curve_fn(x) for x in [near, far, depth]] 137 | depth = np.nan_to_num( 138 | np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1)) 139 | vis = colormap(depth)[:, :, :3] 140 | out_depth = np.clip(np.nan_to_num(vis), 0., 1.) * 255 141 | out_depth = torch.from_numpy(out_depth).permute(2, 0, 1).float().cuda() / 255 142 | return out_depth 143 | 144 | 145 | def inverse_sigmoid(x): 146 | return torch.log(x / (1 - x)) 147 | 148 | 149 | def PILtoTorch(pil_image, resolution): 150 | resized_image_PIL = pil_image.resize(resolution) 151 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 152 | if len(resized_image.shape) == 3: 153 | return resized_image.permute(2, 0, 1) 154 | else: 155 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 156 | 157 | 158 | def get_step_lr_func(lr_init, lr_final, start_step): 159 | def helper(step): 160 | if step < start_step: 161 | return lr_init 162 | else: 163 | return lr_final 164 | return helper 165 | 166 | def get_expon_lr_func( 167 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 168 | ): 169 | """ 170 | Copied from Plenoxels 171 | 172 | Continuous learning rate decay function. 
Adapted from JaxNeRF 173 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 174 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 175 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 176 | function of lr_delay_mult, such that the initial learning rate is 177 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 178 | to the normal learning rate when steps>lr_delay_steps. 179 | :param conf: config subtree 'lr' or similar 180 | :param max_steps: int, the number of steps during optimization. 181 | :return HoF which takes step as input 182 | """ 183 | 184 | def helper(step): 185 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 186 | # Disable this parameter 187 | return 0.0 188 | if lr_delay_steps > 0: 189 | # A kind of reverse cosine decay. 190 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 191 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 192 | ) 193 | else: 194 | delay_rate = 1.0 195 | t = np.clip(step / max_steps, 0, 1) 196 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 197 | return delay_rate * log_lerp 198 | 199 | return helper 200 | 201 | 202 | def strip_lowerdiag(L): 203 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 204 | 205 | uncertainty[:, 0] = L[:, 0, 0] 206 | uncertainty[:, 1] = L[:, 0, 1] 207 | uncertainty[:, 2] = L[:, 0, 2] 208 | uncertainty[:, 3] = L[:, 1, 1] 209 | uncertainty[:, 4] = L[:, 1, 2] 210 | uncertainty[:, 5] = L[:, 2, 2] 211 | return uncertainty 212 | 213 | 214 | def strip_symmetric(sym): 215 | return strip_lowerdiag(sym) 216 | 217 | 218 | def build_rotation(r): 219 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 220 | 221 | q = r / norm[:, None] 222 | 223 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 224 | 225 | r = q[:, 0] 226 | x = q[:, 1] 227 | y = q[:, 2] 228 | z = q[:, 3] 229 | 230 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 231 | R[:, 0, 1] = 2 * (x * y - r * z) 232 | R[:, 0, 2] = 2 * (x * z + r * y) 233 | R[:, 1, 0] = 2 * (x * y + r * z) 234 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 235 | R[:, 1, 2] = 2 * (y * z - r * x) 236 | R[:, 2, 0] = 2 * (x * z - r * y) 237 | R[:, 2, 1] = 2 * (y * z + r * x) 238 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 239 | return R 240 | 241 | def rotation_to_quaternion(R): 242 | r11, r12, r13 = R[:, 0, 0], R[:, 0, 1], R[:, 0, 2] 243 | r21, r22, r23 = R[:, 1, 0], R[:, 1, 1], R[:, 1, 2] 244 | r31, r32, r33 = R[:, 2, 0], R[:, 2, 1], R[:, 2, 2] 245 | 246 | qw = torch.sqrt((1 + r11 + r22 + r33).clamp_min(1e-7)) / 2 247 | qx = (r32 - r23) / (4 * qw) 248 | qy = (r13 - r31) / (4 * qw) 249 | qz = (r21 - r12) / (4 * qw) 250 | 251 | quaternion = torch.stack((qw, qx, qy, qz), dim=-1) 252 | quaternion = torch.nn.functional.normalize(quaternion, dim=-1) 253 | return quaternion 254 | 255 | def quaternion_to_rotation_matrix(q): 256 | w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3] 257 | 258 | r11 = 1 - 2 * y * y - 2 * z * z 259 | r12 = 2 * x * y - 2 * w * z 260 | r13 = 2 * x * z + 2 * w * y 261 | 262 | r21 = 2 * x * y + 2 * w * z 263 | r22 = 1 - 2 * x * x - 2 * z * z 264 | r23 = 2 * y * z - 2 * w * x 265 | 266 | r31 = 2 * x * z - 2 * w * y 267 | r32 = 2 * y * z + 2 * w * x 268 | r33 = 1 - 2 * x * x - 2 * y * y 269 | 270 | rotation_matrix = torch.stack((torch.stack((r11, r12, r13), dim=1), 271 | torch.stack((r21, r22, r23), dim=1), 272 | torch.stack((r31, r32, r33), dim=1)), dim=1) 273 | return rotation_matrix 274 | 
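# NOTE (editor): minimal round-trip sanity check for the two converters above; this
# helper is an editorial sketch, not part of the original module. A unit quaternion
# mapped to a rotation matrix and back should be recovered up to sign, since q and -q
# encode the same rotation (rotation_to_quaternion always returns qw >= 0). The loose
# tolerance accounts for float32 round-off when qw is small.
def _check_quaternion_roundtrip(n=8, atol=1e-4):
    q = torch.nn.functional.normalize(torch.randn(n, 4), dim=-1)  # (w, x, y, z) order
    q_rec = rotation_to_quaternion(quaternion_to_rotation_matrix(q))
    sign = torch.sign((q * q_rec).sum(dim=-1, keepdim=True))  # resolve the q / -q ambiguity
    return torch.allclose(q, sign * q_rec, atol=atol)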
275 | def quaternion_multiply(q1, q2): 276 | w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] 277 | w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] 278 | 279 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 280 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 281 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 282 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 283 | 284 | result_quaternion = torch.stack((w, x, y, z), dim=1) 285 | return result_quaternion 286 | 287 | def build_scaling_rotation(s, r): 288 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 289 | R = build_rotation(r) 290 | 291 | L[:, 0, 0] = s[:, 0] 292 | L[:, 1, 1] = s[:, 1] 293 | L[:, 2, 2] = s[:, 2] 294 | 295 | L = R @ L 296 | return L 297 | 298 | 299 | def seed_everything(seed): 300 | random.seed(seed) 301 | os.environ['PYTHONHASHSEED'] = str(seed) 302 | np.random.seed(seed) 303 | torch.manual_seed(seed) 304 | torch.cuda.manual_seed(seed) 305 | torch.cuda.manual_seed_all(seed) 306 | 307 | import logging 308 | 309 | import sys 310 | 311 | def init_logging(filename=None, debug=False): 312 | logging.root = logging.RootLogger('DEBUG' if debug else 'INFO') 313 | formatter = logging.Formatter('[%(asctime)s][%(filename)s][%(levelname)s] - %(message)s') 314 | stream_handler = logging.StreamHandler(sys.stdout) 315 | 316 | stream_handler.setFormatter(formatter) 317 | logging.root.addHandler(stream_handler) 318 | 319 | if filename is not None: 320 | file_handler = logging.FileHandler(filename) 321 | file_handler.setFormatter(formatter) 322 | logging.root.addHandler(file_handler) 323 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | def ndc_2_cam(ndc_xyz, intrinsic, W, H): 18 | inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device) 19 | cam_z = ndc_xyz[..., 2:3] 20 | cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z 21 | cam_xyz = torch.cat([cam_xy, cam_z], dim=-1) 22 | cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t()) 23 | return cam_xyz 24 | 25 | def depth2point_cam(sampled_depth, ref_intrinsic): 26 | B, N, C, H, W = sampled_depth.shape 27 | valid_z = sampled_depth 28 | valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device) / (W - 1) 29 | valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device) / (H - 1) 30 | valid_x, valid_y = torch.meshgrid(valid_x, valid_y, indexing='xy') 31 | # B,N,H,W 32 | valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1) 33 | valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1) 34 | ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view(B, N, C, H, W, 3) # 1, 1, 5, 512, 640, 3 35 | cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H) # 1, 1, 5, 512, 640, 3 36 | return ndc_xyz, cam_xyz 37 | 38 | def depth2point_world(depth_image, intrinsic_matrix, extrinsic_matrix): 39 | # depth_image: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4) 40 | _, xyz_cam = depth2point_cam(depth_image[None,None,None,...], intrinsic_matrix[None,...]) 41 | xyz_cam = xyz_cam.reshape(-1,3) 42 | # xyz_world = torch.cat([xyz_cam, torch.ones_like(xyz_cam[...,0:1])], axis=-1) @ torch.inverse(extrinsic_matrix).transpose(0,1) 43 | # xyz_world = xyz_world[...,:3] 44 | 45 | return xyz_cam 46 | 47 | def depth_pcd2normal(xyz, offset=None, gt_image=None): 48 | hd, wd, _ = xyz.shape 49 | if offset is not None: 50 | ix, iy = torch.meshgrid( 51 | torch.arange(wd), torch.arange(hd), indexing='xy') 52 | xy = (torch.stack((ix, iy), dim=-1)[1:-1,1:-1]).to(xyz.device) 53 | p_offset = torch.tensor([[0,1],[0,-1],[1,0],[-1,0]]).float().to(xyz.device) 54 | new_offset = p_offset[None,None] + offset.reshape(hd, wd, 4, 2)[1:-1,1:-1] 55 | xys = xy[:,:,None] + new_offset 56 | xys[..., 0] = 2 * xys[..., 0] / (wd - 1) - 1.0 57 | xys[..., 1] = 2 * xys[..., 1] / (hd - 1) - 1.0 58 | sampled_xyzs = torch.nn.functional.grid_sample(xyz.permute(2,0,1)[None], xys.reshape(1, -1, 1, 2)) 59 | sampled_xyzs = sampled_xyzs.permute(0,2,3,1).reshape(hd-2,wd-2,4,3) 60 | bottom_point = sampled_xyzs[:,:,0] 61 | top_point = sampled_xyzs[:,:,1] 62 | right_point = sampled_xyzs[:,:,2] 63 | left_point = sampled_xyzs[:,:,3] 64 | else: 65 | bottom_point = xyz[..., 2:hd, 1:wd-1, :] 66 | top_point = xyz[..., 0:hd-2, 1:wd-1, :] 67 | right_point = xyz[..., 1:hd-1, 2:wd, :] 68 | left_point = xyz[..., 1:hd-1, 0:wd-2, :] 69 | left_to_right = right_point - left_point 70 | bottom_to_top = top_point - bottom_point 71 | xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1) 72 | xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1) 73 | xyz_normal = torch.nn.functional.pad(xyz_normal.permute(2,0,1), (1,1,1,1), mode='constant').permute(1,2,0) 74 | return xyz_normal 75 | 76 | def normal_from_depth_image(depth, intrinsic_matrix, extrinsic_matrix, offset=None, gt_image=None): 77 | # depth: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4) 78 | # xyz_normal: (H, W, 3) 79 | xyz_world = depth2point_world(depth, intrinsic_matrix, extrinsic_matrix) # (HxW, 3) 80 | xyz_world = 
xyz_world.reshape(*depth.shape, 3) 81 | xyz_normal = depth_pcd2normal(xyz_world, offset, gt_image) 82 | 83 | return xyz_normal 84 | 85 | def render_normal(viewpoint_camera, depth, offset=None, normal=None, scale=1): 86 | # depth: (H, W), bg_color: (3), alpha: (H, W) 87 | # normal_ref: (3, H, W) 88 | intrinsic_matrix, extrinsic_matrix = viewpoint_camera.get_calib_matrix_nerf( 89 | scale=scale 90 | ) 91 | st = max(int(scale / 2) - 1, 0) 92 | if offset is not None: 93 | offset = offset[st::scale, st::scale] 94 | normal_ref = normal_from_depth_image( 95 | depth[st::scale, st::scale], 96 | intrinsic_matrix.to(depth.device), 97 | extrinsic_matrix.to(depth.device), 98 | offset, 99 | ) 100 | 101 | normal_ref = normal_ref.permute(2, 0, 1) 102 | return normal_ref 103 | 104 | def normal_from_neareast(normal, offset): 105 | _, hd, wd = normal.shape 106 | left_top_point = normal[..., 0:hd-2, 0:wd-2] 107 | top_point = normal[..., 0:hd-2, 1:wd-1] 108 | right_top_point= normal[..., 0:hd-2, 2:wd] 109 | left_point = normal[..., 1:hd-1, 0:wd-2] 110 | right_point = normal[..., 1:hd-1, 2:wd] 111 | left_bottom_point = normal[..., 2:hd, 0:wd-2] 112 | bottom_point = normal[..., 2:hd, 1:wd-1] 113 | right_bottom_point = normal[..., 2:hd, 2:wd] 114 | normals = torch.stack((left_top_point,top_point,right_top_point,left_point,right_point,left_bottom_point,bottom_point,right_bottom_point),dim=0) 115 | new_normal = (normals * offset[:,None,1:-1,1:-1]).sum(0) 116 | new_normal = torch.nn.functional.normalize(new_normal, p=2, dim=0) 117 | new_normal = torch.nn.functional.pad(new_normal, (1,1,1,1), mode='constant').permute(1,2,0) 118 | return new_normal 119 | 120 | def patch_offsets(h_patch_size, device): 121 | offsets = torch.arange(-h_patch_size, h_patch_size + 1, device=device) 122 | return torch.stack(torch.meshgrid(offsets, offsets, indexing='xy')[::-1], dim=-1).view(1, -1, 2) 123 | 124 | def patch_warp(H, uv): 125 | B, P = uv.shape[:2] 126 | H = H.view(B, 3, 3) 127 | ones = torch.ones((B,P,1), device=uv.device) 128 | homo_uv = torch.cat((uv, ones), dim=-1) 129 | 130 | grid_tmp = torch.einsum("bik,bpk->bpi", H, homo_uv) 131 | grid_tmp = grid_tmp.reshape(B, P, 3) 132 | grid = grid_tmp[..., :2] / (grid_tmp[..., 2:] + 1e-10) 133 | return grid 134 | 135 | class BasicPointCloud(NamedTuple): 136 | points : np.array 137 | colors : np.array 138 | normals : np.array 139 | time : np.array = None 140 | 141 | def getWorld2View(R, t): 142 | Rt = np.zeros((4, 4)) 143 | Rt[:3, :3] = R.transpose() 144 | Rt[:3, 3] = t 145 | Rt[3, 3] = 1.0 146 | return np.float32(Rt) 147 | 148 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 149 | Rt = np.zeros((4, 4)) 150 | Rt[:3, :3] = R.transpose() 151 | Rt[:3, 3] = t 152 | Rt[3, 3] = 1.0 153 | 154 | C2W = np.linalg.inv(Rt) 155 | cam_center = C2W[:3, 3] 156 | cam_center = (cam_center + translate) * scale 157 | C2W[:3, 3] = cam_center 158 | Rt = np.linalg.inv(C2W) 159 | return np.float32(Rt) 160 | 161 | def getProjectionMatrix(znear, zfar, fovX, fovY): 162 | tanHalfFovY = math.tan((fovY / 2)) 163 | tanHalfFovX = math.tan((fovX / 2)) 164 | 165 | top = tanHalfFovY * znear 166 | bottom = -top 167 | right = tanHalfFovX * znear 168 | left = -right 169 | 170 | P = torch.zeros(4, 4) 171 | 172 | z_sign = 1.0 173 | 174 | P[0, 0] = 2.0 * znear / (right - left) 175 | P[1, 1] = 2.0 * znear / (top - bottom) 176 | P[0, 2] = (right + left) / (right - left) 177 | P[1, 2] = (top + bottom) / (top - bottom) 178 | P[3, 2] = z_sign 179 | P[2, 2] = z_sign * zfar / (zfar - znear) 180 | P[2, 3] = 
-(zfar * znear) / (zfar - znear) 181 | return P 182 | 183 | def getProjectionMatrixCenterShift(znear, zfar, cx, cy, fx, fy, w, h): 184 | top = cy / fy * znear 185 | bottom = -(h-cy) / fy * znear 186 | 187 | left = -(w-cx) / fx * znear 188 | right = cx / fx * znear 189 | 190 | P = torch.zeros(4, 4) 191 | 192 | z_sign = 1.0 193 | 194 | P[0, 0] = 2.0 * znear / (right - left) 195 | P[1, 1] = 2.0 * znear / (top - bottom) 196 | P[0, 2] = (right + left) / (right - left) 197 | P[1, 2] = (top + bottom) / (top - bottom) 198 | P[3, 2] = z_sign 199 | P[2, 2] = z_sign * zfar / (zfar - znear) 200 | P[2, 3] = -(zfar * znear) / (zfar - znear) 201 | return P 202 | 203 | def fov2focal(fov, pixels): 204 | return pixels / (2 * math.tan(fov / 2)) 205 | 206 | def focal2fov(focal, pixels): 207 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | import PIL 15 | import torch.nn.functional as F 16 | import torch.nn as nn 17 | from typing import Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union 18 | 19 | # RGB colors used to visualize each semantic segmentation class. 20 | SEGMENTATION_COLOR_MAP = dict( 21 | TYPE_UNDEFINED=[0, 0, 0], 22 | TYPE_EGO_VEHICLE=[102, 102, 102], 23 | TYPE_CAR=[0, 0, 142], 24 | TYPE_TRUCK=[0, 0, 70], 25 | TYPE_BUS=[0, 60, 100], 26 | TYPE_OTHER_LARGE_VEHICLE=[61, 133, 198], 27 | TYPE_BICYCLE=[119, 11, 32], 28 | TYPE_MOTORCYCLE=[0, 0, 230], 29 | TYPE_TRAILER=[111, 168, 220], 30 | TYPE_PEDESTRIAN=[220, 20, 60], 31 | TYPE_CYCLIST=[255, 0, 0], 32 | TYPE_MOTORCYCLIST=[180, 0, 0], 33 | TYPE_BIRD=[127, 96, 0], 34 | TYPE_GROUND_ANIMAL=[91, 15, 0], 35 | TYPE_CONSTRUCTION_CONE_POLE=[230, 145, 56], 36 | TYPE_POLE=[153, 153, 153], 37 | TYPE_PEDESTRIAN_OBJECT=[234, 153, 153], 38 | TYPE_SIGN=[246, 178, 107], 39 | TYPE_TRAFFIC_LIGHT=[250, 170, 30], 40 | TYPE_BUILDING=[70, 70, 70], 41 | TYPE_ROAD=[128, 64, 128], 42 | TYPE_LANE_MARKER=[234, 209, 220], 43 | TYPE_ROAD_MARKER=[217, 210, 233], 44 | TYPE_SIDEWALK=[244, 35, 232], 45 | TYPE_VEGETATION=[107, 142, 35], 46 | TYPE_SKY=[70, 130, 180], 47 | TYPE_GROUND=[102, 102, 102], 48 | TYPE_DYNAMIC=[102, 102, 102], 49 | TYPE_STATIC=[102, 102, 102], 50 | ) 51 | 52 | def _generate_color_map( 53 | color_map_dict: Optional[ 54 | Mapping[int, Sequence[int]]] = None 55 | ) -> np.ndarray: 56 | """Generates a mapping from segmentation classes (rows) to colors (cols). 57 | 58 | Args: 59 | color_map_dict: An optional dict mapping from semantic classes to colors. If 60 | None, the default colors in SEGMENTATION_COLOR_MAP will be used. 61 | Returns: 62 | A np array of shape [max_class_id + 1, 3], where each row encodes the color 63 | for the corresponding class id. 
64 | """ 65 | if color_map_dict is None: 66 | color_map_dict = SEGMENTATION_COLOR_MAP 67 | classes = list(color_map_dict.keys()) 68 | colors = list(color_map_dict.values()) 69 | color_map = np.zeros([#np.amax(classes) + 1 70 | len(classes) 71 | , 3], dtype=np.uint8) 72 | for idx, color in enumerate(colors): 73 | color_map[idx] = color 74 | #color_map[classes] = colors 75 | return color_map 76 | 77 | DEFAULT_COLOR_MAP = _generate_color_map() 78 | 79 | def get_panoptic_id(semantic_id, instance_id, semantic_interval=1000): 80 | if isinstance(semantic_id, np.ndarray): 81 | semantic_id = torch.from_numpy(semantic_id) 82 | instance_id = torch.from_numpy(instance_id) 83 | elif isinstance(semantic_id, PIL.Image.Image): 84 | semantic_id = torch.from_numpy(np.array(semantic_id)) 85 | instance_id = torch.from_numpy(np.array(instance_id)) 86 | elif isinstance(semantic_id, torch.Tensor): 87 | pass 88 | else: 89 | raise ValueError("semantic_id type is not supported!") 90 | 91 | return semantic_id * semantic_interval + instance_id 92 | 93 | def get_panoptic_encoding(semantic_id, instance_id, ): 94 | # 将 semantic-id 和 instance-id 编码成 panoptic one-hot编码 95 | panoptic_id = get_panoptic_id(semantic_id, instance_id) 96 | unique_panoptic_classes = panoptic_id.unique() 97 | num_panoptic_classes = unique_panoptic_classes.shape[0] 98 | # construct id map dict: panoptic_id -> num_class_idx 99 | id_to_idx_dict = {} 100 | for i in range(num_panoptic_classes): 101 | id_to_idx_dict[unique_panoptic_classes[i]] = i 102 | 103 | # convert to one-hot encoding 104 | panoptic_encoding = torch.zeros((num_panoptic_classes, ), dtype=torch.float32) 105 | 106 | 107 | def feat_encode(obj_id, id_to_idx, gt_label_embedding: nn.Embedding = None, output_both=False, only_idx=False): 108 | """ 根据 obj_id 和 id_to_idx_dict 编码成 one-hot """ 109 | 110 | map_ids = torch.zeros_like(obj_id) #obj_id.clone() 111 | # 将 gt-obj-id 替换成 global-obj-idx ,然后转成 one-hot 112 | for key, value in id_to_idx.items(): 113 | map_ids[obj_id == key] = value 114 | 115 | # query embedding 116 | if gt_label_embedding is not None: 117 | gt_label = gt_label_embedding(map_ids.flatten().long()) 118 | else: 119 | gt_label = None 120 | 121 | if output_both: 122 | return map_ids, gt_label 123 | else: 124 | if only_idx: 125 | return map_ids.long().flatten() 126 | else: 127 | return gt_label 128 | 129 | 130 | 131 | 132 | 133 | def mse(img1, img2): 134 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 135 | 136 | def psnr(img1, img2): 137 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 138 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 139 | 140 | def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): 141 | # features: (N, C) 142 | # m: a hyperparam controlling how many std dev outside for outliers 143 | assert len(features.shape) == 2, "features should be (N, C)" 144 | reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] 145 | colors = features @ reduction_mat 146 | if remove_first_component: 147 | colors_min = colors.min(dim=0).values 148 | colors_max = colors.max(dim=0).values 149 | tmp_colors = (colors - colors_min) / (colors_max - colors_min) 150 | fg_mask = tmp_colors[..., 0] < 0.2 151 | reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] 152 | colors = features @ reduction_mat 153 | else: 154 | fg_mask = torch.ones_like(colors[:, 0]).bool() 155 | d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) 156 | mdev = torch.median(d, 
dim=0).values 157 | s = d / mdev 158 | rins = colors[fg_mask][s[:, 0] < m, 0] 159 | gins = colors[fg_mask][s[:, 1] < m, 1] 160 | bins = colors[fg_mask][s[:, 2] < m, 2] 161 | 162 | rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) 163 | rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) 164 | return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) 165 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | import torch.nn as nn 17 | from kornia.filters import laplacian, spatial_gradient 18 | import numpy as np 19 | def psnr(img1, img2): 20 | mse = F.mse_loss(img1, img2) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map 64 | 65 | def tv_loss(depth): 66 | c, h, w = depth.shape[0], depth.shape[1], depth.shape[2] 67 | count_h = c * (h - 1) * w 68 | count_w = c * h * (w - 1) 69 | h_tv = torch.square(depth[..., 1:, :] - depth[..., :h-1, :]).sum() 70 | w_tv = torch.square(depth[..., :, 1:] - depth[..., :, :w-1]).sum() 71 | return 2 * (h_tv / count_h + w_tv / count_w) 72 | 73 | def get_img_grad_weight(img, beta=2.0): 74 | _, hd, wd = img.shape 75 | bottom_point = img[..., 2:hd, 1:wd-1] 76 | top_point = img[..., 0:hd-2, 1:wd-1] 77 | right_point = img[..., 1:hd-1, 2:wd] 78 | left_point = img[..., 1:hd-1, 0:wd-2] 79 | 
grad_img_x = torch.mean(torch.abs(right_point - left_point), 0, keepdim=True) 80 | grad_img_y = torch.mean(torch.abs(top_point - bottom_point), 0, keepdim=True) 81 | grad_img = torch.cat((grad_img_x, grad_img_y), dim=0) 82 | grad_img, _ = torch.max(grad_img, dim=0) 83 | grad_img = (grad_img - grad_img.min()) / (grad_img.max() - grad_img.min()) 84 | grad_img = torch.nn.functional.pad(grad_img[None,None], (1,1,1,1), mode='constant', value=1.0).squeeze() 85 | return grad_img 86 | 87 | def lncc(ref, nea): 88 | # ref_gray: [batch_size, total_patch_size] 89 | # nea_grays: [batch_size, total_patch_size] 90 | bs, tps = nea.shape 91 | patch_size = int(np.sqrt(tps)) 92 | 93 | ref_nea = ref * nea 94 | ref_nea = ref_nea.view(bs, 1, patch_size, patch_size) 95 | ref = ref.view(bs, 1, patch_size, patch_size) 96 | nea = nea.view(bs, 1, patch_size, patch_size) 97 | ref2 = ref.pow(2) 98 | nea2 = nea.pow(2) 99 | 100 | # sum over kernel 101 | filters = torch.ones(1, 1, patch_size, patch_size, device=ref.device) 102 | padding = patch_size // 2 103 | ref_sum = F.conv2d(ref, filters, stride=1, padding=padding)[:, :, padding, padding] 104 | nea_sum = F.conv2d(nea, filters, stride=1, padding=padding)[:, :, padding, padding] 105 | ref2_sum = F.conv2d(ref2, filters, stride=1, padding=padding)[:, :, padding, padding] 106 | nea2_sum = F.conv2d(nea2, filters, stride=1, padding=padding)[:, :, padding, padding] 107 | ref_nea_sum = F.conv2d(ref_nea, filters, stride=1, padding=padding)[:, :, padding, padding] 108 | 109 | # average over kernel 110 | ref_avg = ref_sum / tps 111 | nea_avg = nea_sum / tps 112 | 113 | cross = ref_nea_sum - nea_avg * ref_sum 114 | ref_var = ref2_sum - ref_avg * ref_sum 115 | nea_var = nea2_sum - nea_avg * nea_sum 116 | 117 | cc = cross * cross / (ref_var * nea_var + 1e-8) 118 | ncc = 1 - cc 119 | ncc = torch.clamp(ncc, 0.0, 2.0) 120 | ncc = torch.mean(ncc, dim=1, keepdim=True) 121 | mask = (ncc < 0.9) 122 | return ncc, mask 123 | 124 | 125 | def dilate(bin_img, ksize=5): 126 | pad = (ksize - 1) // 2 127 | bin_img = F.pad(bin_img, pad=[pad, pad, pad, pad], mode='reflect') 128 | out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0) 129 | return out 130 | 131 | def erode(bin_img, ksize=5): 132 | out = 1 - dilate(1 - bin_img, ksize) 133 | return out 134 | 135 | def cal_gradient(data): 136 | """ 137 | data: [1, C, H, W] 138 | """ 139 | kernel_x = [[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]] 140 | kernel_x = torch.FloatTensor(kernel_x).unsqueeze(0).unsqueeze(0).to(data.device) 141 | 142 | kernel_y = [[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]] 143 | kernel_y = torch.FloatTensor(kernel_y).unsqueeze(0).unsqueeze(0).to(data.device) 144 | 145 | weight_x = nn.Parameter(data=kernel_x, requires_grad=False) 146 | weight_y = nn.Parameter(data=kernel_y, requires_grad=False) 147 | 148 | grad_x = F.conv2d(data, weight_x, padding='same') 149 | grad_y = F.conv2d(data, weight_y, padding='same') 150 | gradient = torch.abs(grad_x) + torch.abs(grad_y) 151 | 152 | return gradient 153 | 154 | 155 | def bilateral_smooth_loss(data, image, mask): 156 | """ 157 | image: [C, H, W] 158 | data: [C, H, W] 159 | mask: [C, H, W] 160 | """ 161 | rgb_grad = cal_gradient(image.mean(0, keepdim=True).unsqueeze(0)).squeeze(0) # [1, H, W] 162 | data_grad = cal_gradient(data.mean(0, keepdim=True).unsqueeze(0)).squeeze(0) # [1, H, W] 163 | 164 | smooth_loss = (data_grad * (-rgb_grad).exp() * mask).mean() 165 | 166 | return smooth_loss 167 | 168 | 169 | def second_order_edge_aware_loss(data, img): 170 | return 
(spatial_gradient(data[None], order=2)[0, :, [0, 2]].abs() * torch.exp(-10*spatial_gradient(img[None], order=1)[0].abs())).sum(1).mean() 171 | 172 | 173 | def first_order_edge_aware_loss(data, img): 174 | return (spatial_gradient(data[None], order=1)[0].abs() * torch.exp(-spatial_gradient(img[None], order=1)[0].abs())).sum(1).mean() 175 | 176 | def first_order_edge_aware_norm_loss(data, img): 177 | return (spatial_gradient(data[None], order=1)[0].abs() * torch.exp(-spatial_gradient(img[None], order=1)[0].norm(dim=1, keepdim=True))).sum(1).mean() 178 | 179 | def first_order_loss(data): 180 | return spatial_gradient(data[None], order=1)[0].abs().sum(1).mean() 181 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-4 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from errno import EEXIST 15 | from os import makedirs, path 16 | 17 | def searchForMaxIteration(folder): 18 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 19 | return max(saved_iters) 20 | 21 | def mkdir_p(folder_path): 22 | # Creates a directory. 
equivalent to using mkdir -p on the command line 23 | try: 24 | makedirs(folder_path) 25 | except OSError as exc: # Python >2.5 26 | if exc.errno == EEXIST and path.isdir(folder_path): 27 | pass 28 | else: 29 | raise 30 | 31 | class Timing: 32 | """ 33 | From https://github.com/sxyu/svox2/blob/ee80e2c4df8f29a407fda5729a494be94ccf9234/svox2/utils.py#L611 34 | 35 | Timing environment 36 | usage: 37 | with Timing("message"): 38 | your commands here 39 | will print CUDA runtime in ms 40 | """ 41 | 42 | def __init__(self, name): 43 | self.name = name 44 | 45 | def __enter__(self): 46 | self.start = torch.cuda.Event(enable_timing=True) 47 | self.end = torch.cuda.Event(enable_timing=True) 48 | self.start.record() 49 | 50 | def __exit__(self, type, value, traceback): 51 | self.end.record() 52 | torch.cuda.synchronize() 53 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms") -------------------------------------------------------------------------------- /visualize_gs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import glob 12 | import os 13 | import torch 14 | 15 | from scene import GaussianModel, EnvLight 16 | from utils.general_utils import seed_everything 17 | from tqdm import tqdm 18 | from argparse import ArgumentParser, Namespace 19 | from omegaconf import OmegaConf 20 | 21 | EPS = 1e-5 22 | 23 | 24 | 25 | if __name__ == "__main__": 26 | # Set up command line argument parser 27 | parser = ArgumentParser(description="Training script parameters") 28 | parser.add_argument("--config_path", type=str, required=True) 29 | params, _ = parser.parse_known_args() 30 | 31 | args = OmegaConf.load(params.config_path) 32 | args.resolution_scales = args.resolution_scales[:1] 33 | # print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True)))) 34 | # convert to DictConfig 35 | conf = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) 36 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 37 | cfg_log_f.write(str(Namespace(**conf))) 38 | 39 | seed_everything(args.seed) 40 | 41 | sep_path = os.path.join(args.model_path, 'point_cloud') 42 | os.makedirs(sep_path, exist_ok=True) 43 | 44 | gaussians = GaussianModel(args) 45 | 46 | if args.env_map_res > 0: 47 | env_map = EnvLight(resolution=args.env_map_res).cuda() 48 | env_map.training_setup(args) 49 | else: 50 | env_map = None 51 | 52 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth")) 53 | assert len(checkpoints) > 0, "No checkpoints found." 
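# Pick the most recent checkpoint: filenames follow the pattern "chkpnt<iteration>.pth", so sort by the parsed iteration number and take the last entry.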
54 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1] 55 | print(f"Loading checkpoint {checkpoint}") 56 | (model_params, first_iter) = torch.load(checkpoint) 57 | gaussians.restore(model_params, args) 58 | 59 | if env_map is not None: 60 | env_checkpoint = os.path.join(os.path.dirname(checkpoint), 61 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt")) 62 | (light_params, _) = torch.load(env_checkpoint) 63 | env_map.restore(light_params) 64 | print("Number of Gaussians: ", gaussians.get_xyz.shape[0]) 65 | 66 | 67 | save_timestamp = 0.0 68 | gaussians.save_ply_at_t(os.path.join(args.model_path, "point_cloud", "iteration_{}".format(first_iter), "point_cloud.ply"), save_timestamp) 69 | 70 | 71 | --------------------------------------------------------------------------------
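# Usage sketch (illustrative only, not a file from this repository): how eval_sh, RGB2SH and
# SH2RGB from utils/sh_utils.py fit together. With only the DC coefficient set, eval_sh
# returns C0 * sh[..., 0] = rgb - 0.5 for every viewing direction, so adding 0.5 back
# recovers the base colour; higher-order coefficients would add view-dependent detail.
import torch
from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([[0.2, 0.5, 0.8]])                             # (N, 3) base colour per point
sh = torch.zeros(1, 3, 16)                                        # degree-3 coeffs: (3 + 1) ** 2 = 16 per channel
sh[..., 0] = RGB2SH(rgb)                                          # DC term carries the base colour
dirs = torch.nn.functional.normalize(torch.randn(1, 3), dim=-1)   # unit viewing directions (N, 3)
out = eval_sh(3, sh, dirs) + 0.5                                  # (N, 3); equals rgb here because higher-order coeffs are zero
assert torch.allclose(out, rgb, atol=1e-5)
assert torch.allclose(SH2RGB(RGB2SH(rgb)), rgb)                   # RGB <-> DC-coefficient round trip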