├── .gitignore
├── README.md
├── assets
├── combined_normal.gif
├── combined_render.gif
├── combined_video.gif
├── depth.gif
├── pipeline.png
├── qual.png
└── teaser.png
├── configs
├── base.yaml
├── emer_reconstruction_stage1.yaml
└── emer_reconstruction_stage2.yaml
├── docs
└── dataset_prep.md
├── evaluate.py
├── gaussian_renderer
├── __init__.py
├── gs_render.py
└── pvg_render.py
├── lpipsPyTorch
├── __init__.py
└── modules
│ ├── lpips.py
│ ├── networks.py
│ └── utils.py
├── requirements.txt
├── scene
├── __init__.py
├── cameras.py
├── dinov2.py
├── dynamic_model.py
├── emer_waymo_loader.py
├── emernerf_loader.py
├── envlight.py
├── fit3d.py
├── gaussian_model.py
├── kittimot_loader.py
├── scene_utils.py
└── waymo_loader.py
├── scripts
├── extract_mask_kitti.py
├── extract_mask_waymo.py
├── extract_mono_cues_kitti.py
├── extract_mono_cues_notr.py
├── extract_mono_cues_waymo.py
├── waymo_converter.py
└── waymo_download.py
├── separate.py
├── train.py
├── utils
├── camera_utils.py
├── feature_extractor.py
├── general_utils.py
├── graphics_utils.py
├── image_utils.py
├── loss_utils.py
├── sh_utils.py
└── system_utils.py
└── visualize_gs.py
/.gitignore:
--------------------------------------------------------------------------------
1 | work_dir*/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 | data/
11 | paper_code/
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | __pycache__/
32 | *.so
33 | build/
34 | *.egg-info/
35 | .vscode
36 |
37 |
38 | build
39 | data/
40 | dataset/
41 | output
42 | eval_output
43 | diff-gaussian-rasterization
44 | nvdiffrast
45 | simple-knn
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .nox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | *.py,cover
67 | .hypothesis/
68 | .pytest_cache/
69 | cover/
70 |
71 | # Translations
72 | *.mo
73 | *.pot
74 |
75 | # Django stuff:
76 | *.log
77 | local_settings.py
78 | db.sqlite3
79 | db.sqlite3-journal
80 |
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | .pybuilder/
93 | target/
94 |
95 | # Jupyter Notebook
96 | .ipynb_checkpoints
97 |
98 | # IPython
99 | profile_default/
100 | ipython_config.py
101 |
102 | # pyenv
103 | # For a library or package, you might want to ignore these files since the code is
104 | # intended to run in multiple environments; otherwise, check them in:
105 | # .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # poetry
115 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
116 | # This is especially recommended for binary packages to ensure reproducibility, and is more
117 | # commonly ignored for libraries.
118 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
119 | #poetry.lock
120 |
121 | # pdm
122 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
123 | #pdm.lock
124 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
125 | # in version control.
126 | # https://pdm.fming.dev/#use-with-ide
127 | .pdm.toml
128 |
129 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130 | __pypackages__/
131 |
132 | # Celery stuff
133 | celerybeat-schedule
134 | celerybeat.pid
135 |
136 | # SageMath parsed files
137 | *.sage.py
138 |
139 | # Environments
140 | .env
141 | .venv
142 | env/
143 | venv/
144 | ENV/
145 | env.bak/
146 | venv.bak/
147 |
148 | # Spyder project settings
149 | .spyderproject
150 | .spyproject
151 |
152 | # Rope project settings
153 | .ropeproject
154 |
155 | # mkdocs documentation
156 | /site
157 |
158 | # mypy
159 | .mypy_cache/
160 | .dmypy.json
161 | dmypy.json
162 |
163 | # Pyre type checker
164 | .pyre/
165 |
166 | # pytype static type analyzer
167 | .pytype/
168 |
169 | # Cython debug symbols
170 | cython_debug/
171 |
172 | # PyCharm
173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
175 | # and can be added to the global gitignore or merged into this file. For a more nuclear
176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
177 | #.idea/
178 |
179 |
180 | # custom
181 | # logs
182 | data/
183 | local_scripts/
184 | package_hacks/
185 |
186 | # caches
187 | *.pyc
188 | *.swp
189 |
190 | # media
191 | *.mp4
192 | *.png
193 | *.jpg
194 |
195 |
196 | # wandb
197 | wandb/
198 |
199 | # work in progress
200 | *wip*
201 |
202 | *results*
203 | *debug*
204 |
205 | network/src/*
206 |
207 | network/threestudio_*/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeSiRe-GS
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | > [**DeSiRe-GS: 4D Street Gaussians for Static-Dynamic Decomposition and Surface Reconstruction for Urban Driving Scenes**](https://arxiv.org/abs/2411.11921)
11 | >
12 | > [Chensheng Peng](https://pholypeng.github.io/), [Chengwei Zhang](https://chengweialan.github.io/), [Yixiao Wang](https://yixiaowang7.github.io/), [Chenfeng Xu](https://www.chenfengx.com/), [Yichen Xie](https://scholar.google.com/citations?user=SdX6DaEAAAAJ), [Wenzhao Zheng](https://wzzheng.net/), [Kurt Keutzer](https://people.eecs.berkeley.edu/~keutzer/), [Masayoshi Tomizuka](https://me.berkeley.edu/people/masayoshi-tomizuka/), [Wei Zhan](https://zhanwei.site/)
13 | >
14 | > **Arxiv preprint**
15 |
16 |
17 |
18 | ## 📖 Overview
19 |
20 |
21 |
22 |
23 |
24 |
25 | In the first stage, we extract 2D motion masks based on the observation that 3D Gaussian Splatting can inherently reconstruct only the static regions of dynamic environments. In the second stage, these 2D motion priors are mapped into the Gaussian space in a differentiable manner, leveraging an efficient formulation of dynamic Gaussians.
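
To make the first-stage idea concrete, the toy snippet below thresholds the photometric residual between a static-only rendering and the ground-truth frame to flag likely-dynamic pixels. It is purely illustrative: DeSiRe-GS learns this decision with an uncertainty model (see the `uncertainty_*` options in `configs/`) rather than a fixed threshold.

```python
import torch

def toy_motion_mask(render: torch.Tensor, gt: torch.Tensor, thresh: float = 0.1) -> torch.Tensor:
    """Illustrative stand-in for the learned motion/uncertainty mask.

    Pixels where a static-only reconstruction fails to explain the image
    (large photometric residual) are flagged as potentially dynamic.
    """
    residual = (render - gt).abs().mean(dim=0)  # (H, W) mean error over RGB
    return residual > thresh                    # True = likely dynamic

# Dummy tensors standing in for a rendered frame and its ground truth.
render, gt = torch.rand(3, 64, 96), torch.rand(3, 64, 96)
print(toy_motion_mask(render, gt).float().mean().item())  # fraction of pixels flagged as dynamic
```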
26 |
27 |
28 | ## 🛠️ Installation
29 |
30 | We test our code on Ubuntu 20.04 using Python 3.10 and PyTorch 2.2.0. We recommend using conda to install all the dependencies.
31 |
32 |
33 | 1. Create the conda environment and install requirements.
34 |
35 | ```
36 | # Clone the repo.
37 | git clone https://github.com/chengweialan/DeSiRe-GS.git
38 | cd DeSiRe-GS
39 |
40 | # Create the conda environment.
41 | conda create -n DeSiReGS python=3.10
42 | conda activate DeSiReGS
43 |
44 | # Install torch.
45 | pip install torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu118 # replace with your own CUDA version
46 |
47 | # Install requirements.
48 | pip install -r requirements.txt
49 | ```
50 |
51 | 2. Install the submodules. The repository uses the same submodules as [PVG](https://github.com/fudan-zvg/PVG). A quick import check to verify the build is sketched after the code block below.
52 |
53 | ```
54 | # Install simple-knn
55 | git clone https://gitlab.inria.fr/bkerbl/simple-knn.git
56 | pip install ./simple-knn
57 |
58 | # a modified gaussian splatting (for feature rendering)
59 | git clone --recursive https://github.com/SuLvXiangXin/diff-gaussian-rasterization
60 | pip install ./diff-gaussian-rasterization
61 |
62 | # Install nvdiffrast (for Envlight)
63 | git clone https://github.com/NVlabs/nvdiffrast
64 | pip install ./nvdiffrast
65 | ```
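
Once the three submodules are built, an import check like the following can confirm they are usable. This is a minimal sketch; the `simple_knn._C.distCUDA2` entry point is the one exposed by the standard 3DGS `simple-knn` package, so adjust the imports if your build differs.

```python
# Sanity check that the compiled CUDA extensions import correctly.
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from simple_knn._C import distCUDA2   # exposed by the standard 3DGS simple-knn build
import nvdiffrast.torch as dr         # used for the environment light

print("CUDA available:", torch.cuda.is_available())
print("Imports OK")
```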
66 |
67 |
68 |
69 | ## 💾 Data Preparation
70 |
71 | Create a directory to save the data. Run ```mkdir dataset```.
72 |
73 | We provide a sample sequence in [google drive](https://drive.google.com/drive/u/0/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH); you may download it and unzip it to `dataset`.
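
If you prefer scripting the download, `gdown` (already listed in `requirements.txt`) can fetch the shared folder. The snippet below is only a sketch and assumes the sample is distributed as zip archives; downloading manually from the browser works just as well.

```python
# Sketch: fetch the sample sequence with gdown and unpack any zip archives into ./dataset.
import glob
import subprocess
import zipfile

url = "https://drive.google.com/drive/u/0/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH"
subprocess.run(["gdown", "--folder", url, "-O", "downloads"], check=True)

for archive in glob.glob("downloads/**/*.zip", recursive=True):
    with zipfile.ZipFile(archive) as zf:
        zf.extractall("dataset")
```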
74 |
75 | **Waymo Dataset**
76 |
77 |
78 | | Source | Number of Sequences | Scene Type | Description |
79 | | :--------------------------------------------: | :-----------------: | :---------------------: | ------------------------------------------------------------ |
80 | | [PVG](https://github.com/fudan-zvg/PVG) | 4 | Dynamic | • Refer to [this page](https://github.com/fudan-zvg/PVG?tab=readme-ov-file#data-preparation). |
81 | | [OmniRe](https://ziyc.github.io/omnire/) | 8 | Dynamic | • Described as highly complex dynamic scenes. <br>• Refer to [this page](https://github.com/ziyc/drivestudio/blob/main/docs/Waymo.md). |
82 | | [EmerNeRF](https://github.com/NVlabs/EmerNeRF) | 64 | 32 dynamic <br> 32 static | • Contains 32 static, 32 dynamic and 56 diverse scenes. <br>• We test our code on the 32 static and 32 dynamic scenes. <br>• See [this page](https://github.com/NVlabs/EmerNeRF?tab=readme-ov-file#dataset-preparation) for detailed instructions. |
83 |
84 |
85 |
86 | **KITTI Dataset**
87 |
88 | | Source | Number of Sequences | Scene Type | Description |
89 | | :-------------------------------------: | :-----------------: | :--------: | ------------------------------------------------------------ |
90 | | [PVG](https://github.com/fudan-zvg/PVG) | 3 | Dynamic | • Refer to [this page](https://github.com/fudan-zvg/PVG?tab=readme-ov-file#kitti-dataset). |
91 |
92 |
93 |
94 | ## :memo: Training and Evaluation
95 |
96 | ### Training
97 |
98 | First, we use the following command to train stage I:
99 |
100 | ```sh
101 | # Stage 1
102 | python train.py \
103 | --config configs/emer_reconstruction_stage1.yaml \
104 | source_path=dataset/084 \
105 | model_path=eval_output/waymo_reconstruction/084_stage1
106 | ```
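
The `key=value` arguments override entries in the YAML configs. Since the configs are OmegaConf files, the composition can be reproduced offline roughly as below; this is a sketch only, as the actual loading logic lives in `train.py` and may differ.

```python
# Sketch (assumption: train.py merges base.yaml, the stage config, and CLI overrides with OmegaConf).
from omegaconf import OmegaConf

base = OmegaConf.load("configs/base.yaml")
stage = OmegaConf.load("configs/emer_reconstruction_stage1.yaml")
cli = OmegaConf.from_dotlist([
    "source_path=dataset/084",
    "model_path=eval_output/waymo_reconstruction/084_stage1",
])
cfg = OmegaConf.merge(base, stage, cli)
print(cfg.scene_type, cfg.iterations, cfg.source_path)
```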
107 |
108 | After running the command, the uncertainty model will be saved to ```${YOUR_MODEL_PATH}/uncertainty_model.pth``` by default. Stage II then loads this checkpoint through the `uncertainty_model_path` argument:
109 |
110 | ```sh
111 | # Stage 2
112 | python train.py \
113 | --config configs/emer_reconstruction_stage2.yaml \
114 | source_path=dataset/084 \
115 | model_path=eval_output/waymo_reconstruction/084_stage2 \
116 | uncertainty_model_path=eval_output/waymo_reconstruction/084_stage1/uncertainty_model30000.pth
117 | ```
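
If stage II complains about the uncertainty checkpoint, you can sanity-check the file saved by stage I with a few lines of PyTorch. This is a minimal sketch; the exact parameter names depend on the uncertainty model implementation.

```python
# Minimal sketch: confirm the stage-1 uncertainty checkpoint exists and loads as a state dict.
import torch

path = "eval_output/waymo_reconstruction/084_stage1/uncertainty_model30000.pth"
state_dict = torch.load(path, map_location="cpu")
print(len(state_dict), "tensors, e.g.:", sorted(state_dict)[:3])
```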
118 |
119 | ### Evaluating
120 |
121 | We provide the checkpoints in [google drive](https://drive.google.com/drive/u/0/folders/1fHQJy0cq9ofADpCxtlfmpCh6nCpcw-mH). You may download and unzip them under `${PROJECT_FOLDER}`.
122 |
123 | ```sh
124 | python evaluate.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml
125 | ```
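
`evaluate.py` writes per-split metrics to `eval/eval_on_<split>data_render/metrics.json` inside the model folder; the snippet below simply prints them (the path layout is taken from the evaluation code, so only the model folder needs adjusting).

```python
# Print the metrics written by evaluate.py for the test and train splits.
import json
import os

model_path = "eval_output/waymo_reconstruction/084_stage2"
for split in ("test", "train"):
    metrics_file = os.path.join(model_path, "eval", f"eval_on_{split}data_render", "metrics.json")
    if os.path.exists(metrics_file):
        with open(metrics_file) as f:
            print(split, json.load(f))
```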
126 |
127 | ### Static-Dynamic Decomposition
128 |
129 | We provide code ```separate.py``` for static-dynamic decomposition. Run
130 |
131 | ```
132 | python separate.py --config_path ${YOUR_MODEL_PATH}/config.yaml
133 | ```
134 |
135 | For instance,
136 |
137 | ```
138 | # example
139 | python separate.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml
140 | ```
141 |
142 | The decomposition results will be saved in `${YOUR_MODEL_PATH}/separation`.
143 |
144 |
145 | Qualitative comparisons (rendered image, decomposed static, rendered depth, and rendered normal) are shown as GIFs under `assets/`.
146 |
193 | ## :clapper: Visualization
194 |
195 | ### 3D Gaussians Visualization
196 |
197 | Following [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), we use the [SIBR](https://sibr.gitlabpages.inria.fr/) framework, developed by the GRAPHDECO group, as an interactive viewer to visualize the Gaussian ellipsoids. Refer to [this page](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#interactive-viewers) for more installation details.
198 |
199 | We provide ```visualize_gs.py``` for Gaussian ellipsoid visualization. For example, run
200 |
201 | ```
202 | # Save gaussian point cloud.
203 | python visualize_gs.py --config_path eval_output/waymo_reconstruction/084_stage2/config.yaml
204 | ```
205 |
206 | The ```.ply``` file containing the visible Gaussians will be saved to ```${YOUR_MODEL_PATH}/point_cloud/point_cloud.ply```. You can then use the SIBR viewer to visualize the Gaussians directly from your model path folder. For example,
207 |
208 | ```
209 | # Enter your SIBR folder.
210 | cd ${SIBR_FOLDER}/SIBR_viewers/install/bin
211 |
212 | # Visualize the gaussians.
213 | ./SIBR_gaussianViewer_app -m ${YOUR_MODEL_PATH}/
214 |
215 | # Example
216 | ./SIBR_gaussianViewer_app -m ${PROJECT_FOLDER}/eval_output/waymo_reconstruction/084_stage2/
217 | ```
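
To quickly inspect the exported point cloud without SIBR, you can read it with `plyfile` (already in `requirements.txt`). This is only a sketch and assumes the standard 3DGS PLY layout with per-vertex Gaussian attributes.

```python
# Sketch: count the exported Gaussians and list the stored per-vertex attributes.
from plyfile import PlyData

ply = PlyData.read("eval_output/waymo_reconstruction/084_stage2/point_cloud/point_cloud.ply")
vertices = ply["vertex"]
print("number of gaussians:", vertices.count)
print("attributes:", [p.name for p in vertices.properties])
```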
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 | ## 📜 BibTeX
226 |
227 | ```bibtex
228 | @misc{peng2024desiregs4dstreetgaussians,
229 | title={DeSiRe-GS: 4D Street Gaussians for Static-Dynamic Decomposition and Surface Reconstruction for Urban Driving Scenes},
230 | author={Chensheng Peng and Chengwei Zhang and Yixiao Wang and Chenfeng Xu and Yichen Xie and Wenzhao Zheng and Kurt Keutzer and Masayoshi Tomizuka and Wei Zhan},
231 | year={2024},
232 | eprint={2411.11921},
233 | archivePrefix={arXiv},
234 | primaryClass={cs.CV},
235 | }
236 | ```
237 |
238 |
239 |
240 |
241 | ## :pray: Acknowledgements
242 |
243 | We adapted code from several awesome repositories, including [PVG](https://github.com/fudan-zvg/PVG) and [PGSR](https://github.com/zju3dv/PGSR/).
244 |
245 |
--------------------------------------------------------------------------------
/assets/combined_normal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_normal.gif
--------------------------------------------------------------------------------
/assets/combined_render.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_render.gif
--------------------------------------------------------------------------------
/assets/combined_video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/combined_video.gif
--------------------------------------------------------------------------------
/assets/depth.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/depth.gif
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/qual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/qual.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengweialan/DeSiRe-GS/c38a56767a9c72591e684db47754004787210cb2/assets/teaser.png
--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | test_iterations: [7000, 30000]
2 | save_iterations: [7000, 30000]
3 | checkpoint_iterations: [7000, 30000]
4 | exhaust_test: false
5 | test_interval: 5000
6 | render_static: false
7 | vis_step: 500
8 | start_checkpoint: null
9 | seed: 0
10 | resume: false
11 |
12 | # ModelParams
13 | sh_degree: 3
14 | scene_type: "Waymo"
15 | source_path: ???
16 | start_frame: 65 # for kitti
17 | end_frame: 120 # for kitti
18 | model_path: ???
19 | resolution_scales: [1]
20 | resolution: -1
21 | white_background: false
22 | data_device: "cuda"
23 | eval: false
24 | debug_cuda: false
25 | cam_num: 3 # for waymo
26 | t_init: 0.2
27 | cycle: 0.2
28 | velocity_decay: 1.0
29 | random_init_point: 200000
30 | fix_radius: 0.0
31 | time_duration: [-0.5, 0.5]
32 | num_pts: 100000
33 | frame_interval: 0.02
34 | testhold: 4 # NVS
35 | env_map_res: 1024
36 | separate_scaling_t: 0.1
37 | neg_fov: true
38 |
39 | # PipelineParams
40 | convert_SHs_python: false
41 | compute_cov3D_python: false
42 | debug: false
43 | depth_blend_mode: 0
44 | env_optimize_until: 1000000000
45 | env_optimize_from: 0
46 |
47 |
48 | # OptimizationParams
49 | iterations: 30000
50 | position_lr_init: 0.00016
51 | position_lr_final: 0.0000016
52 | position_lr_delay_mult: 0.01
53 | t_lr_init: 0.0008
54 | position_lr_max_steps: 30_000
55 | feature_lr: 0.0025
56 | opacity_lr: 0.05
57 | scaling_lr: 0.005
58 | scaling_t_lr: 0.002
59 | velocity_lr: 0.001
60 | rotation_lr: 0.001
61 | envmap_lr: 0.01
62 | normal_lr: 0.001
63 |
64 | time_split_frac: 0.5
65 | percent_dense: 0.01
66 | thresh_opa_prune: 0.005
67 | densification_interval: 100
68 | opacity_reset_interval: 3000
69 | densify_from_iter: 500
70 | densify_until_iter: 15_000
71 | densify_grad_threshold: 0.0002
72 | densify_grad_t_threshold: 0.002
73 | densify_until_num_points: 3000000
74 | sh_increase_interval: 1000
75 | scale_increase_interval: 5000
76 | prune_big_point: 1
77 | size_threshold: 20
78 | big_point_threshold: 0.1
79 | t_grad: true
80 | no_time_split: true
81 | contract: true
82 |
83 | lambda_dssim: 0.2
84 | lambda_opa: 0.0
85 | lambda_sky_opa: 0.05
86 | lambda_opacity_entropy: 0.05
87 | lambda_inv_depth: 0.001
88 | lambda_self_supervision: 0.5 # for random sampling to self-supervise
89 | lambda_t_reg: 0.0
90 | lambda_v_reg: 0.0
91 | lambda_lidar: 0.1
92 | lidar_decay: 1.0
93 | lambda_v_smooth: 0.01
94 | lambda_t_smooth: 0.0
95 | lambda_normal: 0.0
--------------------------------------------------------------------------------
/configs/emer_reconstruction_stage1.yaml:
--------------------------------------------------------------------------------
1 | test_iterations: [30000]
2 | save_iterations: [30000]
3 | checkpoint_iterations: [30000]
4 | isotropic: false
5 | exhaust_test: false
6 |
7 | enable_dynamic: false
8 | scale_increase_interval: 1500
9 | uncertainty_warmup_iters: 1800
10 | uncertainty_warmup_start: 6000
11 |
12 | # ModelParams
13 | scene_type: "EmerWaymo"
14 | resolution_scales: [1, 2, 4, 8, 16]
15 | cam_num: 3
16 | eval: false
17 | num_pts: 600000
18 | t_init: 0.1
19 | time_duration: [0, 1]
20 | separate_scaling_t: 0.2
21 | start_time: 0
22 | end_time: 49
23 | stride: 0
24 | original_start_time: 0
25 |
26 | load_sky_mask: true
27 | load_panoptic_mask: false
28 | load_sam_mask: false
29 | load_dynamic_mask: true
30 | load_feat_map: false
31 | load_normal_map: false
32 |
33 | load_intrinsic: false
34 | load_c2w: false
35 |
36 | save_occ_grid: true
37 | occ_voxel_size: 0.4
38 | recompute_occ_grid: false
39 | use_bg_gs: false
40 | white_background: false
41 | # PipelineParams
42 |
43 |
44 | # OptimizationParams
45 | iterations: 30000
46 |
47 | dynamic_mask_epoch: 50000
48 |
49 | opacity_lr: 0.005
50 |
51 | densify_until_iter: 15000
52 | densify_grad_threshold: 0.00017
53 | sh_increase_interval: 2000
54 |
55 |
56 | lambda_v_reg: 0.01
57 | lambda_normal: 0.0
58 | lambda_lidar: 0.0
59 | lambda_scaling: 0.0
60 | lambda_min_scale: 0.0
61 | lambda_max_scale: 0.0
62 |
63 |
64 | # uncertainty model
65 | uncertainty_stage: stage1
66 | uncertainty_mode: "dino" # ["disabled", "l2reg", "l1reg", "dino", "dino+mssim"]
67 | uncertainty_backbone: "dinov2_vits14_reg"
68 | uncertainty_regularizer_weight: 0.5
69 | uncertainty_clip_min: 0.1
70 | uncertainty_mask_clip_max: null
71 | uncertainty_dssim_clip_max: 1.0 # 0.05 -> 0.005
72 | uncertainty_lr: 0.001
73 | uncertainty_dropout: 0.1
74 | uncertainty_dino_max_size: null
75 | uncertainty_scale_grad: false
76 | uncertainty_center_mult: false
77 | uncertainty_after_opacity_reset: 1000
78 | uncertainty_protected_iters: 500
79 | uncertainty_preserve_sky: false
80 |
81 |
82 | render_type: pvg
83 |
84 |
85 | multi_view_weight_from_iter: 50000
86 | multi_view_patch_size: 3
87 | multi_view_sample_num: 102400
88 | multi_view_ncc_weight: 0.15
89 | multi_view_geo_weight: 0.0
90 | multi_view_pixel_noise_th: 1.0
91 |
--------------------------------------------------------------------------------
/configs/emer_reconstruction_stage2.yaml:
--------------------------------------------------------------------------------
1 | exhaust_test: false
2 | isotropic: false
3 | enable_dynamic: true
4 |
5 | scale_increase_interval: 5000
6 | # ModelParams
7 | scene_type: "EmerWaymo"
8 | resolution_scales: [1, 2, 4, 8, 16]
9 | cam_num: 3
10 | eval: false
11 | num_pts: 600000
12 | t_init: 0.1
13 | time_duration: [0, 1]
14 | separate_scaling_t: 0.2
15 | separate_velocity: 10
16 | start_time: 0
17 | end_time: 49
18 | stride: 0
19 | original_start_time: 0
20 |
21 | load_sky_mask: true
22 | load_panoptic_mask: false
23 | load_sam_mask: false
24 | load_dynamic_mask: true
25 | load_feat_map: false
26 | load_normal_map: true
27 |
28 | load_intrinsic: false
29 | load_c2w: false
30 |
31 | save_occ_grid: true
32 | occ_voxel_size: 0.4
33 | recompute_occ_grid: false
34 | use_bg_gs: false
35 | white_background: false
36 | # PipelineParams
37 |
38 |
39 | # OptimizationParams
40 | iterations: 50000
41 |
42 |
43 |
44 | opacity_lr: 0.005
45 |
46 | densify_until_iter: 15000
47 | densify_grad_threshold: 0.00017
48 | sh_increase_interval: 2000
49 |
50 |
51 | lambda_t_reg: 0.0
52 | lambda_v_reg: 0.01
53 | lambda_normal: 0.1
54 | lambda_lidar: 0.1
55 | lambda_scaling: 0.0
56 | lambda_depth_var: 0.0
57 | lambda_min_scale: 10.0
58 | min_scale: 0.001
59 | max_scale: 0.2 # corresponding to 4.0m
60 | lambda_max_scale: 1.0
61 |
62 | neg_fov: false
63 | # uncertainty model
64 | dynamic_mask_epoch: 30000
65 | uncertainty_stage: stage2
66 | uncertainty_mode: "dino" # ["disabled", "l2reg", "l1reg", "dino", "dino+mssim"]
67 | uncertainty_backbone: "dinov2_vits14_reg"
68 | uncertainty_model_path: null
69 | uncertainty_regularizer_weight: 0.5
70 | uncertainty_clip_min: 0.1
71 | uncertainty_mask_clip_max: null
72 | uncertainty_dssim_clip_max: 1.0 # 0.05 -> 0.005
73 | uncertainty_lr: 0.001
74 | uncertainty_dropout: 0.1
75 | uncertainty_dino_max_size: null
76 | uncertainty_scale_grad: false
77 | uncertainty_center_mult: false
78 | uncertainty_after_opacity_reset: 1000
79 | uncertainty_protected_iters: 500
80 | uncertainty_preserve_sky: false
81 |
82 |
83 | render_type: pvg
84 |
85 |
86 | multi_view_weight_from_iter: 20000
87 | multi_view_patch_size: 3
88 | multi_view_sample_num: 102400
89 | multi_view_ncc_weight: 0.0
90 | multi_view_geo_weight: 0.03
91 | multi_view_pixel_noise_th: 1.0
--------------------------------------------------------------------------------
/docs/dataset_prep.md:
--------------------------------------------------------------------------------
1 | # Preparing Normals
2 |
3 | We use [Omnidata](https://github.com/EPFL-VILAB/omnidata) as the foundation model to extract monocular normals. Feel free to try other models, such as [Metric3D](https://github.com/YvanYin/Metric3D).
4 |
5 | ## Installation
6 |
7 |
8 | ```sh
9 | git clone https://github.com/EPFL-VILAB/omnidata.git
10 | cd omnidata/omnidata_tools/torch
11 | mkdir -p pretrained_models && cd pretrained_models
12 | wget 'https://zenodo.org/records/10447888/files/omnidata_dpt_depth_v2.ckpt'
13 | wget 'https://zenodo.org/records/10447888/files/omnidata_dpt_normal_v2.ckpt'
14 | ```
15 |
16 | ## Processing
17 |
18 | ### Waymo dataset
19 |
20 | If you download the dataset from [PVG](https://drive.google.com/file/d/1eTNJz7WeYrB3IctVlUmJIY0z8qhjR_qF/view?usp=sharing) to `dataset`, run the following command to generate the normal maps:
21 |
22 | ```sh
23 | python scripts/extract_mono_cues_waymo.py --data_root ./dataset/waymo_scenes --task normal
24 | ```
25 |
26 | ### KITTI dataset
27 |
28 | ```sh
29 | python scripts/extract_mono_cues_kitti.py --data_root ./dataset/kitti_mot/training --task normal
30 | ```
31 |
32 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import glob
12 | import json
13 | import os
14 | import torch
15 | import torch.nn.functional as F
16 | from utils.loss_utils import psnr, ssim
17 | from gaussian_renderer import get_renderer
18 | from scene import Scene, GaussianModel, EnvLight
19 | from utils.general_utils import seed_everything, visualize_depth
20 | from tqdm import tqdm
21 | from argparse import ArgumentParser
22 | from torchvision.utils import make_grid, save_image
23 | from omegaconf import OmegaConf
24 | from pprint import pprint, pformat
25 | from omegaconf import OmegaConf
26 | from texttable import Texttable
27 | import cv2
28 | import numpy as np
29 | import warnings
30 | warnings.filterwarnings("ignore")
31 | EPS = 1e-5
32 | non_zero_mean = (
33 | lambda x: sum(x) / len(x) if len(x) > 0 else -1
34 | )
35 | @torch.no_grad()
36 | def evaluation(iteration, scene : Scene, renderFunc, renderArgs, env_map=None):
37 | from lpipsPyTorch import lpips
38 |
39 | scale = scene.resolution_scales[0]
40 | if "kitti" in args.model_path:
41 | # follow NSG: https://github.com/princeton-computational-imaging/neural-scene-graphs/blob/8d3d9ce9064ded8231a1374c3866f004a4a281f8/data_loader/load_kitti.py#L766
42 | num = len(scene.getTrainCameras())//2
43 | eval_train_frame = num//5
44 | traincamera = sorted(scene.getTrainCameras(), key =lambda x: x.colmap_id)
45 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
46 | {'name': 'train', 'cameras': traincamera[:num][-eval_train_frame:]+traincamera[num:][-eval_train_frame:]})
47 | else:
48 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
49 | {'name': 'train', 'cameras': scene.getTrainCameras()})
50 |
51 | for config in validation_configs:
52 | if config['cameras'] and len(config['cameras']) > 0:
53 | l1_test = []
54 | psnr_test = []
55 | ssim_test = []
56 | lpips_test = []
57 | masked_psnr_test = []
58 | masked_ssim_test = []
59 | outdir = os.path.join(args.model_path, "eval", "eval_on_" + config['name'] + "data" + "_render")
60 | image_folder = os.path.join(args.model_path, "images")
61 | os.makedirs(outdir,exist_ok=True)
62 | os.makedirs(image_folder,exist_ok=True)
63 | # opaciity_mask = scene.gaussians.get_opacity[:, 0] > 0.01
64 | # print("number of valid gaussians: {:d}".format(opaciity_mask.sum().item()))
65 | for camera_id, viewpoint in enumerate(tqdm(config['cameras'], bar_format="{l_bar}{bar:50}{r_bar}")):
66 | # render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=opaciity_mask)
67 | _, _, render_pkg = renderFunc(args, viewpoint, gaussians, background, scene.time_interval, env_map, iteration, camera_id)
68 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
69 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
70 | cv2.imwrite(os.path.join(image_folder, f"{viewpoint.colmap_id:03d}_gt.png"), (gt_image[[2,1,0], :, :].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
71 | cv2.imwrite(os.path.join(image_folder, f"{viewpoint.colmap_id:03d}_render.png"), (image[[2,1,0], :, :].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
72 |
73 | depth = render_pkg['depth']
74 | alpha = render_pkg['alpha']
75 | sky_depth = 900
76 | depth = depth / alpha.clamp_min(EPS)
77 | if env_map is not None:
78 | if args.depth_blend_mode == 0: # harmonic mean
79 | depth = 1 / (alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth).clamp_min(EPS)
80 | elif args.depth_blend_mode == 1:
81 | depth = alpha * depth + (1 - alpha) * sky_depth
82 | sky_mask = viewpoint.sky_mask.to("cuda")
83 | # dynamic_mask = viewpoint.dynamic_mask.to("cuda") if viewpoint.dynamic_mask is not None else torch.zeros_like(alpha, dtype=torch.bool)
84 | dynamic_mask = render_pkg['dynamic_mask']
85 | depth = visualize_depth(depth)
86 | alpha = alpha.repeat(3, 1, 1)
87 |
88 | grid = [image, alpha, depth, gt_image, torch.logical_not(sky_mask[:1]).float().repeat(3, 1, 1), dynamic_mask.float().repeat(3, 1, 1)]
89 | grid = make_grid(grid, nrow=3)
90 |
91 | save_image(grid, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png"))
92 |
93 | l1_test.append(F.l1_loss(image, gt_image).double().item())
94 | psnr_test.append(psnr(image, gt_image).double().item())
95 | ssim_test.append(ssim(image, gt_image).double().item())
96 | lpips_test.append(lpips(image, gt_image, net_type='alex').double().item()) # very slow
97 | # get the binary dynamic mask
98 |
99 | if dynamic_mask.sum() > 0:
100 | dynamic_mask = dynamic_mask.repeat(3, 1, 1) > 0 # (C, H, W)
101 | masked_psnr_test.append(psnr(image[dynamic_mask], gt_image[dynamic_mask]).double().item())
102 | unaveraged_ssim = ssim(image, gt_image, size_average=False) # (C, H, W)
103 | masked_ssim_test.append(unaveraged_ssim[dynamic_mask].mean().double().item())
104 |
105 |
106 | psnr_test = non_zero_mean(psnr_test)
107 | l1_test = non_zero_mean(l1_test)
108 | ssim_test = non_zero_mean(ssim_test)
109 | lpips_test = non_zero_mean(lpips_test)
110 | masked_psnr_test = non_zero_mean(masked_psnr_test)
111 | masked_ssim_test = non_zero_mean(masked_ssim_test)
112 |
113 |
114 | t = Texttable()
115 | t.add_rows([["PSNR", "SSIM", "LPIPS", "L1", "PSNR (dynamic)", "SSIM (dynamic)"],
116 | [f"{psnr_test:.4f}", f"{ssim_test:.4f}", f"{lpips_test:.4f}", f"{l1_test:.4f}", f"{masked_psnr_test:.4f}", f"{masked_ssim_test:.4f}"]])
117 | print(t.draw())
118 | with open(os.path.join(outdir, "metrics.json"), "w") as f:
119 | json.dump({"split": config['name'], "iteration": iteration,
120 | "psnr": psnr_test, "ssim": ssim_test, "lpips": lpips_test, "masked_psnr": masked_psnr_test, "masked_ssim": masked_ssim_test,
121 | }, f)
122 |
123 |
124 | if __name__ == "__main__":
125 | # Set up command line argument parser
126 | parser = ArgumentParser(description="Training script parameters")
127 | parser.add_argument("--config_path", type=str, required=True)
128 | params, _ = parser.parse_known_args()
129 |
130 | args = OmegaConf.load(params.config_path)
131 | args.resolution_scales = args.resolution_scales[:1]
132 | print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True))))
133 |
134 | seed_everything(args.seed)
135 |
136 | sep_path = os.path.join(args.model_path, 'separation')
137 | os.makedirs(sep_path, exist_ok=True)
138 |
139 | gaussians = GaussianModel(args)
140 | scene = Scene(args, gaussians, shuffle=False)
141 |
142 | if args.env_map_res > 0:
143 | env_map = EnvLight(resolution=args.env_map_res).cuda()
144 | env_map.training_setup(args)
145 | else:
146 | env_map = None
147 |
148 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
149 | assert len(checkpoints) > 0, "No checkpoints found."
150 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
151 | print(f"Loading checkpoint {checkpoint}")
152 | (model_params, first_iter) = torch.load(checkpoint)
153 | gaussians.restore(model_params, args)
154 |
155 | if env_map is not None:
156 | env_checkpoint = os.path.join(os.path.dirname(checkpoint),
157 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt"))
158 | (light_params, _) = torch.load(env_checkpoint)
159 | env_map.restore(light_params)
160 | uncertainty_model_path = os.path.join(os.path.dirname(checkpoint),
161 | os.path.basename(checkpoint).replace("chkpnt", "uncertainty_model"))
162 | state_dict = torch.load(uncertainty_model_path)
163 | gaussians.uncertainty_model.load_state_dict(state_dict, strict=False)
164 |
165 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
166 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
167 | render_func, render_wrapper = get_renderer(args.render_type)
168 | evaluation(first_iter, scene, render_wrapper, (args, background), env_map=env_map)
169 |
170 | print("Evaluation complete.")
171 |
--------------------------------------------------------------------------------
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from .gs_render import render_original_gs, render_gs_origin_wrapper
13 | from .pvg_render import render_pvg, render_pvg_wrapper
14 |
15 | EPS = 1e-5
16 |
17 | rendererTypeCallbacks = {
18 | "gs": render_original_gs,
19 | "pvg": render_pvg
20 | }
21 |
22 | renderWrapperTypeCallbacks = {
23 | "gs": render_gs_origin_wrapper,
24 | "pvg": render_pvg_wrapper,
25 | }
26 |
27 |
28 | def get_renderer(render_type: str):
29 | return rendererTypeCallbacks[render_type], renderWrapperTypeCallbacks[render_type]
30 |
--------------------------------------------------------------------------------
/gaussian_renderer/gs_render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 |
13 | """
14 | conda activate gs
15 |
16 | """
17 |
18 | import torch
19 | import math
20 | from diff_gaussian_rasterization import (
21 | GaussianRasterizationSettings,
22 | GaussianRasterizer,
23 | )
24 | from scene.gaussian_model import GaussianModel
25 | from utils.sh_utils import eval_sh
26 | from scene.cameras import Camera
27 | import torch.nn.functional as F
28 | import numpy as np
29 | import kornia
30 | from utils.loss_utils import psnr, ssim, tv_loss
31 |
32 | EPS = 1e-5
33 |
34 |
35 | def render_original_gs(
36 | viewpoint_camera: Camera,
37 | pc: GaussianModel,
38 | pipe,
39 | bg_color: torch.Tensor,
40 | scaling_modifier=1.0,
41 | override_color=None,
42 | env_map=None,
43 | time_shift=None,
44 | other=[],
45 | mask=None,
46 | is_training=False,
47 | return_opacity=True,
48 | return_depth=False,
49 | ):
50 | """
51 | Render the scene.
52 |
53 | Background tensor (bg_color) must be on GPU!
54 | """
55 |
56 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
57 | screenspace_points = (
58 | torch.zeros_like(
59 | pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
60 | )
61 | + 0
62 | )
63 | try:
64 | screenspace_points.retain_grad()
65 | except:
66 | pass
67 |
68 | # Set up rasterization configuration
69 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
70 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
71 |
72 | raster_settings = GaussianRasterizationSettings(
73 | image_height=int(viewpoint_camera.image_height),
74 | image_width=int(viewpoint_camera.image_width),
75 | tanfovx=tanfovx,
76 | tanfovy=tanfovy,
77 | bg=bg_color,
78 | scale_modifier=scaling_modifier,
79 | viewmatrix=viewpoint_camera.world_view_transform,
80 | projmatrix=viewpoint_camera.full_proj_transform,
81 | sh_degree=pc.active_sh_degree,
82 | campos=viewpoint_camera.camera_center,
83 | prefiltered=False,
84 | debug=pipe.debug,
85 | )
86 |
87 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
88 |
89 | means3D = pc.get_xyz
90 | means2D = screenspace_points
91 | opacity = pc.get_opacity
92 |
93 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
94 | # scaling / rotation by the rasterizer.
95 | scales = None
96 | rotations = None
97 | cov3D_precomp = None
98 |
99 | # if time_shift is not None:
100 | # means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp-time_shift)
101 | # means3D = means3D + pc.get_inst_velocity * time_shift
102 | # marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp-time_shift)
103 | # else:
104 | # means3D = pc.get_xyz_SHM(viewpoint_camera.timestamp)
105 | # marginal_t = pc.get_marginal_t(viewpoint_camera.timestamp)
106 | # opacity = opacity * marginal_t
107 |
108 | if pipe.compute_cov3D_python:
109 | cov3D_precomp = pc.get_covariance(scaling_modifier)
110 | else:
111 | scales = pc.get_scaling
112 | rotations = pc.get_rotation
113 |
114 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
115 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
116 | shs = None
117 | colors_precomp = None
118 | if override_color is None:
119 | if pipe.convert_SHs_python:
120 | shs_view = pc.get_features.transpose(1, 2).view(
121 | -1, 3, pc.get_max_sh_channels
122 | )
123 | dir_pp = (
124 | means3D.detach()
125 | - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)
126 | ).detach()
127 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
128 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
129 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
130 | else:
131 | shs = pc.get_features
132 | else:
133 | colors_precomp = override_color
134 |
135 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
136 | rendered_image, radii = rasterizer(
137 | means3D=means3D,
138 | means2D=means2D,
139 | shs=shs,
140 | colors_precomp=colors_precomp,
141 | opacities=opacity,
142 | scales=scales,
143 | rotations=rotations,
144 | cov3D_precomp=cov3D_precomp,
145 | )
146 | H, W = rendered_image.shape[1], rendered_image.shape[2]
147 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
148 | # They will be excluded from value updates used in the splitting criteria.
149 | return_dict = {
150 | "render_nobg": rendered_image,
151 | "normal": torch.zeros_like(rendered_image),
152 | "feature": torch.zeros([2, H, W]).to(rendered_image.device),
153 | "viewspace_points": screenspace_points,
154 | "visibility_filter": radii > 0,
155 | "radii": radii,
156 | }
157 |
158 | if return_opacity:
159 | density = torch.ones_like(means3D)
160 |
161 | render_opacity, _ = rasterizer(
162 | means3D=means3D,
163 | means2D=means2D,
164 | shs=None,
165 | colors_precomp=density,
166 | opacities=opacity,
167 | scales=scales,
168 | rotations=rotations,
169 | cov3D_precomp=cov3D_precomp,
170 | )
171 | render_opacity = render_opacity.mean(dim=0, keepdim=True) # (1, H, W)
172 | return_dict.update({"alpha": render_opacity})
173 |
174 | if return_depth:
175 | projvect1 = viewpoint_camera.world_view_transform[:, 2][:3].detach()
176 | projvect2 = viewpoint_camera.world_view_transform[:, 2][-1].detach()
177 | means3D_depth = (means3D * projvect1.unsqueeze(0)).sum(
178 | dim=-1, keepdim=True
179 | ) + projvect2
180 | means3D_depth = means3D_depth.repeat(1, 3)
181 | render_depth, _ = rasterizer(
182 | means3D=means3D,
183 | means2D=means2D,
184 | shs=None,
185 | colors_precomp=means3D_depth,
186 | opacities=opacity,
187 | scales=scales,
188 | rotations=rotations,
189 | cov3D_precomp=cov3D_precomp,
190 | )
191 | render_depth = render_depth.mean(dim=0, keepdim=True)
192 | return_dict.update({"depth": render_depth})
193 | else:
194 | return_dict.update({"depth": torch.zeros_like(render_opacity)})
195 |
196 | if env_map is not None:
197 | bg_color_from_envmap = env_map(
198 | viewpoint_camera.get_world_directions(is_training).permute(1, 2, 0)
199 | ).permute(2, 0, 1)
200 | rendered_image = rendered_image + (1 - render_opacity) * bg_color_from_envmap
201 | return_dict.update({"render": rendered_image})
202 |
203 | return return_dict
204 |
205 |
206 | def calculate_loss(
207 | gaussians: GaussianModel,
208 | viewpoint_camera: Camera,
209 | args,
210 | render_pkg: dict,
211 | env_map,
212 | iteration,
213 | camera_id,
214 | ):
215 | log_dict = {}
216 |
217 | image = render_pkg["render"]
218 | depth = render_pkg["depth"]
219 | alpha = render_pkg["alpha"]
220 |
221 | sky_mask = (
222 | viewpoint_camera.sky_mask.cuda()
223 | if viewpoint_camera.sky_mask is not None
224 | else torch.zeros_like(alpha, dtype=torch.bool)
225 | )
226 |
227 | sky_depth = 900
228 | depth = depth / alpha.clamp_min(EPS)
229 | if env_map is not None:
230 | if args.depth_blend_mode == 0: # harmonic mean
231 | depth = 1 / (
232 | alpha / depth.clamp_min(EPS) + (1 - alpha) / sky_depth
233 | ).clamp_min(EPS)
234 | elif args.depth_blend_mode == 1:
235 | depth = alpha * depth + (1 - alpha) * sky_depth
236 |
237 | gt_image = viewpoint_camera.original_image.cuda()
238 |
239 | loss_l1 = F.l1_loss(image, gt_image, reduction="none") # [3, H, W]
240 | loss_ssim = 1.0 - ssim(image, gt_image, size_average=False) # [3, H, W]
241 |
242 | log_dict["loss_l1"] = loss_l1.mean().item()
243 | log_dict["loss_ssim"] = loss_ssim.mean().item()
244 |
245 | loss = (
246 | 1.0 - args.lambda_dssim
247 | ) * loss_l1.mean() + args.lambda_dssim * loss_ssim.mean()
248 |
249 | psnr_for_log = psnr(image, gt_image).double()
250 | log_dict["psnr"] = psnr_for_log
251 |
252 | if args.lambda_lidar > 0:
253 | assert viewpoint_camera.pts_depth is not None
254 | pts_depth = viewpoint_camera.pts_depth.cuda()
255 |
256 | mask = pts_depth > 0
257 | loss_lidar = torch.abs(
258 | 1 / (pts_depth[mask] + 1e-5) - 1 / (depth[mask] + 1e-5)
259 | ).mean()
260 | if args.lidar_decay > 0:
261 | iter_decay = np.exp(-iteration / 8000 * args.lidar_decay)
262 | else:
263 | iter_decay = 1
264 | log_dict["loss_lidar"] = loss_lidar.item()
265 | loss += iter_decay * args.lambda_lidar * loss_lidar
266 |
267 | # if args.lambda_normal > 0 and args.load_normal_map:
268 | # alpha_mask = (alpha.data > EPS).repeat(3, 1, 1) # (3, H, W) detached
269 | # rendered_normal = render_pkg['normal'] # (3, H, W)
270 | # gt_normal = viewpoint_camera.normal_map.cuda()
271 | # loss_normal = F.l1_loss(rendered_normal[alpha_mask], gt_normal[alpha_mask])
272 | # loss_normal += tv_loss(rendered_normal)
273 | # log_dict['loss_normal'] = loss_normal.item()
274 | # loss += args.lambda_normal * loss_normal
275 |
276 | # if args.lambda_v_reg > 0 and args.enable_dynamic:
277 | # loss_v_reg = (torch.abs(v_map) * loss_mult).mean()
278 | # log_dict['loss_v_reg'] = loss_v_reg.item()
279 | # loss += args.lambda_v_reg * loss_v_reg
280 |
281 | # loss_mult[alpha.data < EPS] = 0.0
282 | # if args.lambda_t_reg > 0 and args.enable_dynamic:
283 | # loss_t_reg = (-torch.abs(t_map) * loss_mult).mean()
284 | # log_dict['loss_t_reg'] = loss_t_reg.item()
285 | # loss += args.lambda_t_reg * loss_t_reg
286 |
287 | if args.lambda_inv_depth > 0:
288 | inverse_depth = 1 / (depth + 1e-5)
289 | loss_inv_depth = kornia.losses.inverse_depth_smoothness_loss(
290 | inverse_depth[None], gt_image[None]
291 | )
292 | log_dict["loss_inv_depth"] = loss_inv_depth.item()
293 | loss = loss + args.lambda_inv_depth * loss_inv_depth
294 |
295 | if args.lambda_sky_opa > 0:
296 | o = alpha.clamp(1e-6, 1 - 1e-6)
297 | sky = sky_mask.float()
298 | loss_sky_opa = (-sky * torch.log(1 - o)).mean()
299 | log_dict["loss_sky_opa"] = loss_sky_opa.item()
300 | loss = loss + args.lambda_sky_opa * loss_sky_opa
301 |
302 | if args.lambda_opacity_entropy > 0:
303 | o = alpha.clamp(1e-6, 1 - 1e-6)
304 | loss_opacity_entropy = -(o * torch.log(o)).mean()
305 | log_dict["loss_opacity_entropy"] = loss_opacity_entropy.item()
306 | loss = loss + args.lambda_opacity_entropy * loss_opacity_entropy
307 |
308 | extra_render_pkg = {}
309 | extra_render_pkg["t_map"] = torch.zeros_like(alpha)
310 | extra_render_pkg["v_map"] = torch.zeros_like(alpha)
311 | # extra_render_pkg['depth'] = torch.zeros_like(alpha)
312 | extra_render_pkg["dynamic_mask"] = torch.zeros_like(alpha)
313 | extra_render_pkg["dino_cosine"] = torch.zeros_like(alpha)
314 |
315 | return loss, log_dict, extra_render_pkg
316 |
317 |
318 | def render_gs_origin_wrapper(
319 | args,
320 | viewpoint_camera: Camera,
321 | gaussians: GaussianModel,
322 | background: torch.Tensor,
323 | time_interval: float,
324 | env_map,
325 | iterations,
326 | camera_id,
327 | ):
328 |
329 | render_pkg = render_original_gs(
330 | viewpoint_camera, gaussians, args, background, env_map=env_map, is_training=True
331 | )
332 |
333 | loss, log_dict, extra_render_pkg = calculate_loss(
334 | gaussians, viewpoint_camera, args, render_pkg, env_map, iterations, camera_id
335 | )
336 |
337 | render_pkg.update(extra_render_pkg)
338 |
339 | return loss, log_dict, render_pkg
340 |
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y).mean()
22 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | imageio
3 | imageio[ffmpeg]
4 | kornia
5 | trimesh
6 | Pillow
7 | ninja
8 | omegaconf
9 | plyfile
10 | opencv_python
11 | opencv_contrib_python
12 | tensorboardX
13 | matplotlib
14 | texttable
15 | albumentations
16 | timm==0.9.10
17 | scikit-learn
18 | torch-kmeans
19 | albumentations
20 | albucore
21 | open3d
22 | tensorboard
23 | gdown
24 | h5py
25 | pytorch_lightning
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | import torch
16 | from tqdm import tqdm
17 | import numpy as np
18 | from utils.system_utils import searchForMaxIteration
19 | from scene.gaussian_model import GaussianModel
20 | from scene.envlight import EnvLight
21 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, calculate_mean_and_std
22 | from scene.waymo_loader import readWaymoInfo
23 | from scene.kittimot_loader import readKittiMotInfo
24 | from scene.emer_waymo_loader import readEmerWaymoInfo
25 | import logging
26 | sceneLoadTypeCallbacks = {
27 | "Waymo": readWaymoInfo,
28 | "KittiMot": readKittiMotInfo,
29 | 'EmerWaymo': readEmerWaymoInfo,
30 | }
31 |
32 | class Scene:
33 |
34 | gaussians : GaussianModel
35 |
36 | def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True):
37 | self.model_path = args.model_path
38 | self.loaded_iter = None
39 | self.gaussians = gaussians
40 | self.white_background = args.white_background
41 |
42 | if load_iteration:
43 | if load_iteration == -1:
44 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
45 | else:
46 | self.loaded_iter = load_iteration
47 | logging.info("Loading trained model at iteration {}".format(self.loaded_iter))
48 |
49 | self.train_cameras = {}
50 | self.test_cameras = {}
51 |
52 | scene_info = sceneLoadTypeCallbacks[args.scene_type](args)
53 |
54 | self.time_interval = args.frame_interval
55 | self.gaussians.time_duration = scene_info.time_duration
56 | # print("time duration: ", scene_info.time_duration)
57 | # print("frame interval: ", self.time_interval)
58 |
59 |
60 | if not self.loaded_iter:
61 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
62 | dest_file.write(src_file.read())
63 | json_cams = []
64 | camlist = []
65 | if scene_info.test_cameras:
66 | camlist.extend(scene_info.test_cameras)
67 | if scene_info.train_cameras:
68 | camlist.extend(scene_info.train_cameras)
69 | for id, cam in enumerate(camlist):
70 | json_cams.append(camera_to_JSON(id, cam))
71 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
72 | json.dump(json_cams, file)
73 |
74 | if shuffle:
75 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
76 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
77 |
78 | self.cameras_extent = scene_info.nerf_normalization["radius"]
79 | self.resolution_scales = args.resolution_scales
80 | self.scale_index = len(self.resolution_scales) - 1
81 | for resolution_scale in self.resolution_scales:
82 | logging.info("Loading Training Cameras at resolution scale {}".format(resolution_scale))
83 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
84 | logging.info("Loading Test Cameras")
85 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
86 | logging.info("Computing nearest_id")
87 | image_name_to_id = {"train": {cam.image_name: id for id, cam in enumerate(self.train_cameras[resolution_scale])},
88 | "test": {cam.image_name: id for id, cam in enumerate(self.test_cameras[resolution_scale])}}
89 | with open(os.path.join(self.model_path, "multi_view.json"), 'w') as file:
90 | json_d = []
91 | for id, cur_cam in enumerate(tqdm(self.train_cameras[resolution_scale], bar_format='{l_bar}{bar:50}{r_bar}')):
92 | image_name_to_id_map = image_name_to_id["train"]
93 | # cur_image_name = cur_cam.image_name
94 | cur_colmap_id = cur_cam.colmap_id
95 | nearest_colmap_id_candidate = [cur_colmap_id - 10, cur_colmap_id + 10, cur_colmap_id - 20, cur_colmap_id + 20]
96 |
97 | for colmap_id in nearest_colmap_id_candidate:
98 | near_image_name = "{:03d}_{:1d}".format(colmap_id // 10, colmap_id % 10)
99 | if near_image_name in image_name_to_id_map:
100 | cur_cam.nearest_id.append(image_name_to_id_map[near_image_name])
101 | cur_cam.nearest_names.append(near_image_name)
102 |
103 | json_d.append({'ref_name' : cur_cam.image_name, 'nearest_name': cur_cam.nearest_names, "id": id, 'nearest_id': cur_cam.nearest_id})
104 | json.dump(json_d, file)
105 |
106 |
107 | if resolution_scale == 1.0:
108 | logging.info("Computing mean and std of dataset")
109 | mean = []
110 | std = []
111 | all_cameras = self.train_cameras[resolution_scale] + self.test_cameras[resolution_scale]
112 | for idx, viewpoint in enumerate(tqdm(all_cameras, bar_format="{l_bar}{bar:50}{r_bar}")):
113 | gt_image = viewpoint.original_image # [3, H, W]
114 | mean.append(gt_image.mean(dim=[1, 2]).cpu().numpy())
115 | std.append(gt_image.std(dim=[1, 2]).cpu().numpy())
116 | mean = np.array(mean)
117 | std = np.array(std)
118 | # calculate mean and std of dataset
119 | mean_dataset, std_rgb_dataset = calculate_mean_and_std(mean, std)
120 |
121 | if gaussians.uncertainty_model is not None:
122 | gaussians.uncertainty_model.img_norm_mean = torch.from_numpy(mean_dataset)
123 | gaussians.uncertainty_model.img_norm_std = torch.from_numpy(std_rgb_dataset)
124 |
125 | if self.loaded_iter:
126 | self.gaussians.load_ply(os.path.join(self.model_path,
127 | "point_cloud",
128 | "iteration_" + str(self.loaded_iter),
129 | "point_cloud.ply"))
130 | else:
131 | self.gaussians.create_from_pcd(scene_info.point_cloud, 1)
132 |
133 | def upScale(self):
134 | self.scale_index = max(0, self.scale_index - 1)
135 |
136 | def getTrainCameras(self):
137 | return self.train_cameras[self.resolution_scales[self.scale_index]]
138 |
139 | def getTestCameras(self, scale=1.0):
140 | return self.test_cameras[scale]
141 |
142 |
--------------------------------------------------------------------------------
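A small, self-contained illustration of the neighbor-lookup convention used in Scene above: under the EmerWaymo loader each camera gets uid = frame_index * 10 + camera_index and image names follow "{frame:03d}_{cam}", so offsets of ±10 / ±20 in colmap_id pick the same camera one or two frames away. The frame/camera numbers below are made-up example values.

cur_colmap_id = 42 * 10 + 1                      # frame 042, camera 1 (example values)
candidates = [cur_colmap_id - 10, cur_colmap_id + 10, cur_colmap_id - 20, cur_colmap_id + 20]
names = ["{:03d}_{:1d}".format(c // 10, c % 10) for c in candidates]
print(names)                                     # ['041_1', '043_1', '040_1', '044_1']
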
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import math
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 | import numpy as np
17 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getProjectionMatrixCenterShift
18 | import kornia
19 |
20 |
21 | class Camera(nn.Module):
22 | def __init__(self, colmap_id, R, T, FoVx=None, FoVy=None, cx=None, cy=None, fx=None, fy=None,
23 | image=None,
24 | image_name=None, uid=0,
25 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", timestamp=0.0,
26 | resolution=None, image_path=None,
27 | pts_depth=None, sky_mask=None, dynamic_mask=None, normal_map=None, ncc_scale=1.0,
28 | ):
29 | super(Camera, self).__init__()
30 |
31 | self.uid = uid
32 | self.colmap_id = colmap_id
33 | self.R = R
34 | self.T = T
35 | self.FoVx = FoVx
36 | self.FoVy = FoVy
37 | self.image_name = image_name
38 | self.image = image
39 | self.cx = cx
40 | self.cy = cy
41 | self.fx = fx
42 | self.fy = fy
43 | self.resolution = resolution
44 | self.image_path = image_path
45 | self.ncc_scale = ncc_scale
46 | self.nearest_id = []
47 | self.nearest_names = []
48 |
49 | try:
50 | self.data_device = torch.device(data_device)
51 | except Exception as e:
52 | print(e)
53 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
54 | self.data_device = torch.device("cuda")
55 |
56 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
57 | self.sky_mask = sky_mask.to(self.data_device) > 0 if sky_mask is not None else sky_mask
58 | self.pts_depth = pts_depth.to(self.data_device) if pts_depth is not None else pts_depth
59 | self.dynamic_mask = dynamic_mask.to(self.data_device) > 0 if dynamic_mask is not None else dynamic_mask
60 | self.normal_map = normal_map.to(self.data_device) if normal_map is not None else normal_map
61 | self.image_width = resolution[0]
62 | self.image_height = resolution[1]
63 |
64 | self.zfar = 1000.0
65 | self.znear = 0.01
66 |
67 | self.trans = trans
68 | self.scale = scale
69 |
70 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
71 | if cx is not None:
72 | self.FoVx = 2 * math.atan(0.5*self.image_width / fx) if FoVx is None else FoVx
73 | self.FoVy = 2 * math.atan(0.5*self.image_height / fy) if FoVy is None else FoVy
74 | self.projection_matrix = getProjectionMatrixCenterShift(self.znear, self.zfar, cx, cy, fx, fy,
75 | self.image_width, self.image_height).transpose(0, 1).cuda()
76 | else:
77 | self.cx = self.image_width / 2
78 | self.cy = self.image_height / 2
79 | self.fx = self.image_width / (2 * np.tan(self.FoVx * 0.5))
80 | self.fy = self.image_height / (2 * np.tan(self.FoVy * 0.5))
81 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx,
82 | fovY=self.FoVy).transpose(0, 1).cuda()
83 | self.full_proj_transform = (
84 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
85 | self.camera_center = self.world_view_transform.inverse()[3, :3]
86 | self.c2w = self.world_view_transform.transpose(0, 1).inverse()
87 | self.timestamp = timestamp
88 | self.grid = kornia.utils.create_meshgrid(self.image_height, self.image_width, normalized_coordinates=False, device='cuda')[0]
89 |
90 | def get_world_directions(self, train=False):
91 | u, v = self.grid.unbind(-1)
92 | if train:
93 | directions = torch.stack([(u-self.cx+torch.rand_like(u))/self.fx,
94 | (v-self.cy+torch.rand_like(v))/self.fy,
95 | torch.ones_like(u)], dim=0)
96 | else:
97 | directions = torch.stack([(u-self.cx+0.5)/self.fx,
98 | (v-self.cy+0.5)/self.fy,
99 | torch.ones_like(u)], dim=0)
100 | directions = F.normalize(directions, dim=0)
101 | directions = (self.c2w[:3, :3] @ directions.reshape(3, -1)).reshape(3, self.image_height, self.image_width)
102 | return directions
103 |
104 | def get_image(self):
105 | original_image = self.original_image # [3, H, W]
106 |
107 | # convert the image tensor to grayscale L = R * 299/1000 + G * 587/1000 + B * 114/1000
108 | # image_gray = (0.299 * original_image[0] + 0.587 * original_image[1] + 0.114 * original_image[2])[None]
109 | image_gray = original_image[0:1] * 0.299 + original_image[1:2] * 0.587 + original_image[2:3] * 0.114
110 |
111 | return original_image.cuda(), image_gray.cuda()
112 |
113 |
114 | def get_calib_matrix_nerf(self, scale=1.0):
115 | intrinsic_matrix = torch.tensor([[self.fx/scale, 0, self.cx/scale], [0, self.fy/scale, self.cy/scale], [0, 0, 1]]).float()
116 |         extrinsic_matrix = self.world_view_transform.transpose(0,1).contiguous()  # world2cam (W2C)
117 | return intrinsic_matrix, extrinsic_matrix
118 |
119 | def get_rays(self, scale=1.0):
120 | W, H = int(self.image_width/scale), int(self.image_height/scale)
121 | ix, iy = torch.meshgrid(
122 | torch.arange(W), torch.arange(H), indexing='xy')
123 | rays_d = torch.stack(
124 | [(ix-self.cx/scale) / self.fx * scale,
125 | (iy-self.cy/scale) / self.fy * scale,
126 | torch.ones_like(ix)], -1).float().cuda()
127 | return rays_d
128 |
129 | def get_k(self, scale=1.0):
130 | K = torch.tensor([[self.fx / scale, 0, self.cx / scale],
131 | [0, self.fy / scale, self.cy / scale],
132 | [0, 0, 1]]).cuda()
133 | return K
134 |
135 | def get_inv_k(self, scale=1.0):
136 | K_T = torch.tensor([[scale/self.fx, 0, -self.cx/self.fx],
137 | [0, scale/self.fy, -self.cy/self.fy],
138 | [0, 0, 1]]).cuda()
139 | return K_T
--------------------------------------------------------------------------------
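A minimal sanity check for the intrinsics helpers in Camera above: get_k() builds K scaled by 1/scale and get_inv_k() builds its inverse, so their product is the identity. The focal lengths and principal point below are made-up example values.

import torch

fx, fy, cx, cy, scale = 500.0, 520.0, 320.0, 240.0, 2.0
K = torch.tensor([[fx / scale, 0, cx / scale],
                  [0, fy / scale, cy / scale],
                  [0, 0, 1]])
K_inv = torch.tensor([[scale / fx, 0, -cx / fx],
                      [0, scale / fy, -cy / fy],
                      [0, 0, 1]])
assert torch.allclose(K @ K_inv, torch.eye(3), atol=1e-6)
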
/scene/emer_waymo_loader.py:
--------------------------------------------------------------------------------
1 | # Description: Load the EmerWaymo dataset for training and testing
2 | # adapted from the PVG datareader for the data from EmerNeRF
3 |
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
9 | from utils.graphics_utils import BasicPointCloud, focal2fov
10 |
11 | def pad_poses(p):
12 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
13 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
14 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
15 |
16 |
17 | def unpad_poses(p):
18 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
19 | return p[..., :3, :4]
20 |
21 |
22 | def transform_poses_pca(poses, fix_radius=0):
23 | """Transforms poses so principal components lie on XYZ axes.
24 |
25 | Args:
26 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
27 |
28 |     Returns:
29 |         A tuple (poses, transform, scale_factor): the transformed poses, the applied
30 |         4x4 world transform, and the scale factor used to fit the scene in the unit cube.
31 |
32 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
33 | """
34 | t = poses[:, :3, 3]
35 | t_mean = t.mean(axis=0)
36 | t = t - t_mean
37 |
38 | eigval, eigvec = np.linalg.eig(t.T @ t)
39 | # Sort eigenvectors in order of largest to smallest eigenvalue.
40 | inds = np.argsort(eigval)[::-1]
41 | eigvec = eigvec[:, inds]
42 | rot = eigvec.T
43 | if np.linalg.det(rot) < 0:
44 | rot = np.diag(np.array([1, 1, -1])) @ rot
45 |
46 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
47 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
48 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
49 |
50 | # Flip coordinate system if z component of y-axis is negative
51 | if poses_recentered.mean(axis=0)[2, 1] < 0:
52 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
53 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
54 |
55 |     # Just make sure it's in the [-1, 1]^3 cube
56 | if fix_radius>0:
57 | scale_factor = 1./fix_radius
58 | else:
59 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
60 | scale_factor = min(1 / 10, scale_factor)
61 |
62 | poses_recentered[:, :3, 3] *= scale_factor
63 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
64 |
65 | return poses_recentered, transform, scale_factor
66 |
67 |
68 | def readEmerWaymoInfo(args):
69 |
70 | eval = args.eval
71 | load_sky_mask = args.load_sky_mask
72 | load_panoptic_mask = args.load_panoptic_mask
73 | load_sam_mask = args.load_sam_mask
74 | load_dynamic_mask = args.load_dynamic_mask
75 | load_normal_map = args.load_normal_map
76 | load_feat_map = args.load_feat_map
77 | load_intrinsic = args.load_intrinsic
78 | load_c2w = args.load_c2w
79 | save_occ_grid = args.save_occ_grid
80 | occ_voxel_size = args.occ_voxel_size
81 | recompute_occ_grid = args.recompute_occ_grid
82 | use_bg_gs = args.use_bg_gs
83 | white_background = args.white_background
84 | neg_fov = args.neg_fov
85 |
86 |
87 | num_pts = args.num_pts
88 | start_time = args.start_time
89 | end_time = args.end_time
90 | stride = args.stride
91 | original_start_time = args.original_start_time
92 |
93 | ORIGINAL_SIZE = [[1280, 1920], [1280, 1920], [1280, 1920], [884, 1920], [884, 1920]]
94 | OPENCV2DATASET = np.array(
95 | [[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]]
96 | )
97 | load_size = [640, 960]
98 |
99 | cam_infos = []
100 | points = []
101 | points_time = []
102 |
103 | data_root = args.source_path
104 | image_folder = os.path.join(data_root, "images")
105 | num_seqs = len(os.listdir(image_folder))/5
106 | if end_time == -1:
107 | end_time = int(num_seqs)
108 | else:
109 | end_time += 1
110 |
111 | frame_num = end_time - start_time
112 | # assert frame_num == 50, "frame_num should be 50"
113 | time_duration = args.time_duration
114 | time_interval = (time_duration[1] - time_duration[0]) / (end_time - start_time)
115 |
116 | camera_list = [0, 1, 2]
117 | truncated_min_range, truncated_max_range = -2, 80
118 |
119 |
120 | # ---------------------------------------------
121 | # load poses: intrinsic, c2w, l2w per camera
122 | # ---------------------------------------------
123 | _intrinsics = []
124 | cam_to_egos = []
125 | for i in camera_list:
126 | # load intrinsics
127 | intrinsic = np.loadtxt(os.path.join(data_root, "intrinsics", f"{i}.txt"))
128 | fx, fy, cx, cy = intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3]
129 | # scale intrinsics w.r.t. load size
130 | fx, fy = (
131 | fx * load_size[1] / ORIGINAL_SIZE[i][1],
132 | fy * load_size[0] / ORIGINAL_SIZE[i][0],
133 | )
134 | cx, cy = (
135 | cx * load_size[1] / ORIGINAL_SIZE[i][1],
136 | cy * load_size[0] / ORIGINAL_SIZE[i][0],
137 | )
138 | intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
139 | _intrinsics.append(intrinsic)
140 | # load extrinsics
141 | cam_to_ego = np.loadtxt(os.path.join(data_root, "extrinsics", f"{i}.txt"))
142 | # opencv coordinate system: x right, y down, z front
143 | # waymo coordinate system: x front, y left, z up
144 | cam_to_egos.append(cam_to_ego @ OPENCV2DATASET) # opencv_cam -> waymo_cam -> waymo_ego
145 |
146 |
147 | # ---------------------------------------------
148 | # get c2w and w2c transformation per frame and camera
149 | # ---------------------------------------------
150 | # compute per-image poses and intrinsics
151 | cam_to_worlds, ego_to_worlds = [], []
152 | intrinsics, cam_ids = [], []
153 | lidar_to_worlds = []
154 | # ===! for waymo, we simplify timestamps as the time indices
155 | timestamps, timesteps = [], []
156 |     # we transform the camera poses w.r.t. the first timestep so that the translation of
157 |     # the first ego pose becomes the origin of the world coordinate system.
158 | ego_to_world_start = np.loadtxt(os.path.join(data_root, "ego_pose", f"{start_time:03d}.txt"))
159 | for t in range(start_time, end_time):
160 | ego_to_world_current = np.loadtxt(os.path.join(data_root, "ego_pose", f"{t:03d}.txt"))
161 | # ego to world transformation: cur_ego -> world -> start_ego(world)
162 | ego_to_world = np.linalg.inv(ego_to_world_start) @ ego_to_world_current
163 | ego_to_worlds.append(ego_to_world)
164 | for cam_id in camera_list:
165 | cam_ids.append(cam_id)
166 | # transformation:
167 | # opencv_cam -> waymo_cam -> waymo_cur_ego -> world -> start_ego(world)
168 | cam2world = ego_to_world @ cam_to_egos[cam_id]
169 | cam_to_worlds.append(cam2world)
170 | intrinsics.append(_intrinsics[cam_id])
171 |             # ===! we use time indices as timestamps for the Waymo dataset for simplicity;
172 |             # ===! the actual sensor timestamps could be substituted here if needed
173 |             # to be improved
174 | timestamps.append(t - start_time)
175 | timesteps.append(t - start_time)
176 | # lidar to world : lidar = ego in waymo
177 | lidar_to_worlds.append(ego_to_world)
178 | # convert to numpy arrays
179 | intrinsics = np.stack(intrinsics, axis=0)
180 | cam_to_worlds = np.stack(cam_to_worlds, axis=0)
181 | ego_to_worlds = np.stack(ego_to_worlds, axis=0)
182 | lidar_to_worlds = np.stack(lidar_to_worlds, axis=0)
183 | cam_ids = np.array(cam_ids)
184 | timestamps = np.array(timestamps)
185 | timesteps = np.array(timesteps)
186 |
187 |
188 | # ---------------------------------------------
189 | # get image, sky_mask, lidar per frame and camera
190 | # ---------------------------------------------
191 | accumulated_num_original_rays = 0
192 | accumulated_num_rays = 0
193 |
194 | for idx, t in enumerate(tqdm(range(start_time, end_time), desc="Loading data", bar_format='{l_bar}{bar:50}{r_bar}')):
195 |
196 | images = []
197 | image_paths = []
198 | HWs = []
199 | sky_masks = []
200 | dynamic_masks = []
201 | normal_maps = []
202 |
203 | for cam_idx in camera_list:
204 | image_path = os.path.join(args.source_path, "images", f"{t:03d}_{cam_idx}.jpg")
205 | im_data = Image.open(image_path)
206 | im_data = im_data.resize((load_size[1], load_size[0]), Image.BILINEAR) # PIL resize: (W, H)
207 | W, H = im_data.size
208 | image = np.array(im_data) / 255.
209 | HWs.append((H, W))
210 | images.append(image)
211 | image_paths.append(image_path)
212 |
213 | sky_path = os.path.join(args.source_path, "sky_masks", f"{t:03d}_{cam_idx}.png")
214 | sky_data = Image.open(sky_path)
215 | sky_data = sky_data.resize((load_size[1], load_size[0]), Image.NEAREST) # PIL resize: (W, H)
216 | sky_mask = np.array(sky_data)>0
217 | sky_masks.append(sky_mask.astype(np.float32))
218 |
219 | if load_normal_map:
220 | normal_path = os.path.join(args.source_path, "normals", f"{t:03d}_{cam_idx}.jpg")
221 | normal_data = Image.open(normal_path)
222 | normal_data = normal_data.resize((load_size[1], load_size[0]), Image.BILINEAR)
223 | normal_map = (np.array(normal_data)) / 255. # [0, 1]
224 | normal_maps.append(normal_map)
225 |
226 | if load_dynamic_mask:
227 | dynamic_path = os.path.join(args.source_path, "dynamic_masks", f"{t:03d}_{cam_idx}.png")
228 | dynamic_data = Image.open(dynamic_path)
229 | dynamic_data = dynamic_data.resize((load_size[1], load_size[0]), Image.BILINEAR)
230 | dynamic_mask = np.array(dynamic_data)>0
231 | dynamic_masks.append(dynamic_mask.astype(np.float32))
232 |
233 |
234 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (frame_num - 1)
235 | lidar_info = np.memmap(
236 | os.path.join(data_root, "lidar", f"{t:03d}.bin"),
237 | dtype=np.float32,
238 | mode="r",
239 | ).reshape(-1, 10)
240 | original_length = len(lidar_info)
241 | accumulated_num_original_rays += original_length
242 | lidar_origins = lidar_info[:, :3]
243 | lidar_points = lidar_info[:, 3:6]
244 | lidar_ids = lidar_info[:, -1]
245 |         # select lidar points based on a truncated ego-forward-directional range
246 |         # to make sure most of the lidar points are within the cameras' range
247 | valid_mask = lidar_points[:, 0] < truncated_max_range
248 | valid_mask = valid_mask & (lidar_points[:, 0] > truncated_min_range)
249 | lidar_origins = lidar_origins[valid_mask]
250 | lidar_points = lidar_points[valid_mask]
251 | lidar_ids = lidar_ids[valid_mask]
252 | # transform lidar points to world coordinate system
253 | lidar_origins = (
254 | lidar_to_worlds[idx][:3, :3] @ lidar_origins.T
255 | + lidar_to_worlds[idx][:3, 3:4]
256 | ).T
257 | lidar_points = (
258 | lidar_to_worlds[idx][:3, :3] @ lidar_points.T
259 | + lidar_to_worlds[idx][:3, 3:4]
260 | ).T # point_xyz_world
261 |
262 | points.append(lidar_points)
263 | point_time = np.full_like(lidar_points[:, :1], timestamp)
264 | points_time.append(point_time)
265 |
266 | for cam_idx in camera_list:
267 | # world-lidar-pts --> camera-pts : w2c
268 |             c2w = cam_to_worlds[len(camera_list) * idx + cam_idx]
269 | w2c = np.linalg.inv(c2w)
270 | point_camera = (
271 | w2c[:3, :3] @ lidar_points.T
272 | + w2c[:3, 3:4]
273 | ).T
274 |
275 | R = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
276 | T = w2c[:3, 3]
277 | K = _intrinsics[cam_idx]
278 | fx = float(K[0, 0])
279 | fy = float(K[1, 1])
280 | cx = float(K[0, 2])
281 | cy = float(K[1, 2])
282 | height, width = HWs[cam_idx]
283 | if neg_fov:
284 | FovY = -1.0
285 | FovX = -1.0
286 | else:
287 | FovY = focal2fov(fy, height)
288 | FovX = focal2fov(fx, width)
289 | cam_infos.append(CameraInfo(uid=idx * 10 + cam_idx, R=R, T=T, FovY=FovY, FovX=FovX,
290 | image=images[cam_idx],
291 | image_path=image_paths[cam_idx], image_name=f"{t:03d}_{cam_idx}",
292 | width=width, height=height, timestamp=timestamp,
293 | pointcloud_camera = point_camera,
294 | fx=fx, fy=fy, cx=cx, cy=cy,
295 | sky_mask=sky_masks[cam_idx],
296 | dynamic_mask=dynamic_masks[cam_idx] if load_dynamic_mask else None,
297 | normal_map=normal_maps[cam_idx] if load_normal_map else None,))
298 |
299 | if args.debug_cuda:
300 | break
301 |
302 | pointcloud = np.concatenate(points, axis=0)
303 | pointcloud_timestamp = np.concatenate(points_time, axis=0)
304 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True)
305 | pointcloud = pointcloud[indices]
306 | pointcloud_timestamp = pointcloud_timestamp[indices]
307 |
308 | w2cs = np.zeros((len(cam_infos), 4, 4))
309 | Rs = np.stack([c.R for c in cam_infos], axis=0)
310 | Ts = np.stack([c.T for c in cam_infos], axis=0)
311 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1))
312 | w2cs[:, :3, 3] = Ts
313 | w2cs[:, 3, 3] = 1
314 | c2ws = unpad_poses(np.linalg.inv(w2cs))
315 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius)
316 |
317 | c2ws = pad_poses(c2ws)
318 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data", bar_format='{l_bar}{bar:50}{r_bar}')):
319 | c2w = c2ws[idx]
320 | w2c = np.linalg.inv(c2w)
321 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
322 | cam_info.T[:] = w2c[:3, 3]
323 | cam_info.pointcloud_camera[:] *= scale_factor
324 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3]
325 | if args.eval:
326 | # ## for snerf scene
327 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0]
328 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0]
329 |
330 | # for dynamic scene
331 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0]
332 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0]
333 |
334 | # for emernerf comparison [testhold::testhold]
335 | if args.testhold == 10:
336 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0]
337 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0]
338 | else:
339 | train_cam_infos = cam_infos
340 | test_cam_infos = []
341 |
342 | nerf_normalization = getNerfppNorm(train_cam_infos)
343 | nerf_normalization['radius'] = 1/nerf_normalization['radius']
344 |
345 | ply_path = os.path.join(args.source_path, "points3d.ply")
346 | if not os.path.exists(ply_path):
347 | rgbs = np.random.random((pointcloud.shape[0], 3))
348 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp)
349 | try:
350 | pcd = fetchPly(ply_path)
351 | except:
352 | pcd = None
353 |
354 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp)
355 |
356 | scene_info = SceneInfo(point_cloud=pcd,
357 | train_cameras=train_cam_infos,
358 | test_cameras=test_cam_infos,
359 | nerf_normalization=nerf_normalization,
360 | ply_path=ply_path,
361 | time_interval=time_interval,
362 | time_duration=time_duration)
363 |
364 | return scene_info
365 |
--------------------------------------------------------------------------------
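A self-contained sketch exercising the pose-normalization helpers defined at the top of this loader (pad_poses / unpad_poses / transform_poses_pca) on synthetic camera-to-world poses; it assumes the repository's dependencies are installed so the module imports cleanly, and the random poses are made-up data.

import numpy as np
from scene.emer_waymo_loader import pad_poses, unpad_poses, transform_poses_pca

# synthetic camera-to-world poses: identity rotations, random camera centers
rng = np.random.default_rng(0)
c2ws = np.tile(np.eye(4)[:3, :4][None], (10, 1, 1))
c2ws[:, :3, 3] = rng.normal(scale=20.0, size=(10, 3))

poses, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=0)
assert poses.shape == (10, 3, 4) and transform.shape == (4, 4)
# after recentering and scaling, all camera centers lie inside the [-1, 1]^3 cube
assert np.abs(poses[:, :3, 3]).max() <= 1.0
# pad_poses / unpad_poses are inverses on the [R|t] block
assert np.allclose(unpad_poses(pad_poses(c2ws)), c2ws)
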
/scene/envlight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import nvdiffrast.torch as dr
3 |
4 |
5 | class EnvLight(torch.nn.Module):
6 |
7 | def __init__(self, resolution=1024):
8 | super().__init__()
9 | self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")
10 | self.base = torch.nn.Parameter(
11 | 0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),
12 | )
13 |
14 | def capture(self):
15 | return (
16 | self.base,
17 | self.optimizer.state_dict(),
18 | )
19 |
20 | def restore(self, model_args, training_args=None):
21 | self.base, opt_dict = model_args
22 | if training_args is not None:
23 | self.training_setup(training_args)
24 | self.optimizer.load_state_dict(opt_dict)
25 |
26 | def training_setup(self, training_args):
27 | self.optimizer = torch.optim.Adam(self.parameters(), lr=training_args.envmap_lr, eps=1e-15)
28 |
29 | def forward(self, l):
30 | l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape)
31 | l = l.contiguous()
32 | prefix = l.shape[:-1]
33 | if len(prefix) != 3: # reshape to [B, H, W, -1]
34 | l = l.reshape(1, 1, -1, l.shape[-1])
35 |
36 | light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube')
37 | light = light.view(*prefix, -1)
38 |
39 | return light
40 |
--------------------------------------------------------------------------------
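A hedged usage sketch for EnvLight above: query the learnable cube-map with a batch of unit directions. It needs a CUDA device and nvdiffrast (the module hard-codes CUDA for its OpenGL conversion matrix); the resolution and batch shape are arbitrary example values.

import torch
import torch.nn.functional as F
from scene.envlight import EnvLight

env = EnvLight(resolution=64).cuda()
dirs = F.normalize(torch.randn(1, 8, 8, 3, device="cuda"), dim=-1)  # [B, H, W, 3] unit directions
rgb = env(dirs)                                                     # per-direction radiance
assert rgb.shape == (1, 8, 8, 3)
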
/scene/fit3d.py:
--------------------------------------------------------------------------------
1 | """
2 | timm 0.9.10 must be installed to use the get_intermediate_layers method.
3 | pip install timm==0.9.10
4 | pip install torch_kmeans
5 |
6 | """
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import os
11 | import requests
12 | import itertools
13 | import timm
14 | import torch
15 | import types
16 | import albumentations as A
17 | from torch.nn import functional as F
18 |
19 | from PIL import Image
20 | from sklearn.decomposition import PCA
21 | from torch_kmeans import KMeans, CosineSimilarity
22 | from scene.dinov2 import dinov2_vits14_reg
23 | import cv2
24 |
25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 | cmap = plt.get_cmap("tab20")
27 | MEAN = np.array([0.3609696, 0.38405442, 0.4348492])
28 | STD = np.array([0.19669543, 0.20297967, 0.22123419])
29 |
30 | options = ['DINOv2-reg']
31 |
32 | timm_model_card = {
33 | "DINOv2": "vit_small_patch14_dinov2.lvd142m",
34 | "DINOv2-reg": "vit_small_patch14_reg4_dinov2.lvd142m",
35 | "CLIP": "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
36 | "MAE": "vit_base_patch16_224.mae",
37 | "DeiT-III": "deit3_base_patch16_224.fb_in1k"
38 | }
39 |
40 | our_model_card = {
41 | "DINOv2": "dinov2_small_fine",
42 | "DINOv2-reg": "dinov2_reg_small_fine",
43 | "CLIP": "clip_base_fine",
44 | "MAE": "mae_base_fine",
45 | "DeiT-III": "deit3_base_fine"
46 | }
47 |
48 |
49 |
50 | # transforms = A.Compose([
51 | # A.Normalize(mean=list(MEAN), std=list(STD)),
52 | # ])
53 |
54 |
55 | def get_intermediate_layers(
56 | self,
57 | x: torch.Tensor,
58 | n=1,
59 | reshape: bool = False,
60 | return_prefix_tokens: bool = False,
61 | return_class_token: bool = False,
62 | norm: bool = True,
63 | ):
64 |
65 | outputs = self._intermediate_layers(x, n)
66 | if norm:
67 | outputs = [self.norm(out) for out in outputs]
68 | if return_class_token:
69 | prefix_tokens = [out[:, 0] for out in outputs]
70 | else:
71 | prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
72 | outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
73 |
74 | if reshape:
75 | B, C, H, W = x.shape
76 | grid_size = (
77 | (H - self.patch_embed.patch_size[0])
78 | // self.patch_embed.proj.stride[0]
79 | + 1,
80 | (W - self.patch_embed.patch_size[1])
81 | // self.patch_embed.proj.stride[1]
82 | + 1,
83 | )
84 | outputs = [
85 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
86 | .permute(0, 3, 1, 2)
87 | .contiguous()
88 | for out in outputs
89 | ]
90 |
91 | if return_prefix_tokens or return_class_token:
92 | return tuple(zip(outputs, prefix_tokens))
93 | return tuple(outputs)
94 |
95 |
96 | def viz_feat(feat):
97 |
98 | _, _, h, w = feat.shape
99 | feat = feat.squeeze(0).permute((1,2,0))
100 | projected_featmap = feat.reshape(-1, feat.shape[-1]).cpu()
101 |
102 | pca = PCA(n_components=3)
103 | pca.fit(projected_featmap)
104 | pca_features = pca.transform(projected_featmap)
105 | pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min())
106 | pca_features = pca_features * 255
107 | res_pred = Image.fromarray(pca_features.reshape(h, w, 3).astype(np.uint8))
108 |
109 | return res_pred
110 |
111 |
112 | def plot_feats(image, model_option, ori_feats, fine_feats, ori_labels=None, fine_labels=None):
113 |
114 | ori_feats_map = viz_feat(ori_feats)
115 | fine_feats_map = viz_feat(fine_feats)
116 |
117 | if ori_labels is not None:
118 | fig, ax = plt.subplots(2, 3, figsize=(10, 5))
119 | ax[0][0].imshow(image)
120 | ax[0][0].set_title("Input image", fontsize=15)
121 | ax[0][1].imshow(ori_feats_map)
122 | ax[0][1].set_title("Original " + model_option, fontsize=15)
123 | ax[0][2].imshow(fine_feats_map)
124 | ax[0][2].set_title("Ours", fontsize=15)
125 | ax[1][1].imshow(ori_labels)
126 | ax[1][2].imshow(fine_labels)
127 | for xx in ax:
128 | for x in xx:
129 | x.xaxis.set_major_formatter(plt.NullFormatter())
130 | x.yaxis.set_major_formatter(plt.NullFormatter())
131 | x.set_xticks([])
132 | x.set_yticks([])
133 | x.axis('off')
134 |
135 | else:
136 | fig, ax = plt.subplots(1, 3, figsize=(30, 8))
137 | ax[0].imshow(image)
138 | ax[0].set_title("Input image", fontsize=15)
139 | ax[1].imshow(ori_feats_map)
140 | ax[1].set_title("Original " + model_option, fontsize=15)
141 | ax[2].imshow(fine_feats_map)
142 | ax[2].set_title("FiT3D", fontsize=15)
143 |
144 | for x in ax:
145 | x.xaxis.set_major_formatter(plt.NullFormatter())
146 | x.yaxis.set_major_formatter(plt.NullFormatter())
147 | x.set_xticks([])
148 | x.set_yticks([])
149 | x.axis('off')
150 |
151 | plt.tight_layout()
152 | plt.savefig("output2.png")
153 | # plt.close(fig)
154 | return fig
155 |
156 |
157 | def download_image(url, save_path):
158 | response = requests.get(url)
159 | with open(save_path, 'wb') as file:
160 | file.write(response.content)
161 |
162 |
163 | def process_image(image, stride, transforms):
164 | transformed = transforms(image=np.array(image))
165 | image_tensor = torch.tensor(transformed['image'])
166 | image_tensor = image_tensor.permute(2,0,1)
167 | image_tensor = image_tensor.unsqueeze(0).to(device)
168 |
169 | h, w = image_tensor.shape[2:]
170 |
171 | height_int = ((h + stride-1) // stride)*stride
172 | width_int = ((w+stride-1) // stride)*stride
173 |
174 | image_resized = torch.nn.functional.interpolate(image_tensor, size=(height_int, width_int), mode='bilinear')
175 |
176 | return image_resized
177 |
178 |
179 | def kmeans_clustering(feats_map, n_clusters=20):
180 |
181 | B, D, h, w = feats_map.shape
182 | feats_map_flattened = feats_map.permute((0, 2, 3, 1)).reshape(B, -1, D)
183 |
184 | kmeans_engine = KMeans(n_clusters=n_clusters, distance=CosineSimilarity)
185 | kmeans_engine.fit(feats_map_flattened)
186 | labels = kmeans_engine.predict(
187 | feats_map_flattened
188 | )
189 | labels = labels.reshape(
190 | B, h, w
191 | ).float()
192 | labels = labels[0].cpu().numpy()
193 |
194 | label_map = cmap(labels / n_clusters)[..., :3]
195 | label_map = np.uint8(label_map * 255)
196 | label_map = Image.fromarray(label_map)
197 |
198 | return label_map
199 |
200 |
201 | def run_demo(original_model, fine_model, image_path, kmeans=20):
202 |     """
203 |     Run the demo for a given pair of models and an image.
204 |     original_model, fine_model: original backbone and FiT3D fine-tuned model (see LoadDinov2Model)
205 |     image_path: path to the image
206 |     kmeans: number of clusters for k-means. Default is 20; -1 disables k-means.
207 |     """
208 | p = original_model.patch_embed.patch_size
209 | stride = p if isinstance(p, int) else p[0]
210 | image = Image.open(image_path)
211 | transforms = A.Compose([
212 | A.Normalize(mean=list(MEAN), std=list(STD)),
213 | ])
214 | image_resized = process_image(image, stride, transforms)
215 | with torch.no_grad():
216 | ori_feats = original_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True,
217 | return_class_token=False, norm=True)
218 | fine_feats = fine_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True,
219 | return_class_token=False, norm=True)
220 |
221 | ori_feats = ori_feats[-1]
222 | fine_feats = fine_feats[-1]
223 | if kmeans != -1:
224 | ori_labels = kmeans_clustering(ori_feats, kmeans)
225 | fine_labels = kmeans_clustering(fine_feats, kmeans)
226 | else:
227 | ori_labels = None
228 | fine_labels = None
229 | print("image shape: ", image.size)
230 | print("image_resized shape: ", image_resized.shape)
231 | print("ori_feats shape: ", ori_feats.shape)
232 | print("fine_feats shape: ", fine_feats.shape)
233 |
234 |
235 | return plot_feats(image, "DINOv2-reg", ori_feats, fine_feats, ori_labels, fine_labels),ori_feats,fine_feats
236 |
237 | def Dinov2RegExtractor(original_model, fine_model, image, transforms=None, kmeans=20, only_fine_feats: bool = False):
238 |     """
239 |     Extract DINOv2-reg features for an image with both the original and the fine-tuned (FiT3D) model.
240 |     image: image tensor in [0, 1] with shape [1, 3, H, W]
241 |     transforms: optional albumentations transform applied before inference
242 |     only_fine_feats: if True, skip the original model and return (None, fine_feats)
243 |     """
244 | p = original_model.patch_embed.patch_size
245 | stride = p if isinstance(p, int) else p[0]
246 |     image = image.cpu().numpy()
247 |     image_array = (image * 255).astype(np.uint8)
248 |     image_array = image_array.squeeze(0).transpose(1, 2, 0)
249 |     image = Image.fromarray(image_array)
250 |     fine_feats = None
251 |     ori_feats = None
252 |     if transforms is not None:
253 |         image_resized = process_image(image, stride, transforms)
254 |     else:
255 |         image_resized = image
256 | with torch.no_grad():
257 |
258 | fine_feats = fine_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True,
259 | return_class_token=False, norm=True)
260 | if not only_fine_feats:
261 | ori_feats = original_model.get_intermediate_layers(image_resized, n=[8,9,10,11], reshape=True,
262 | return_class_token=False, norm=True)
263 |
264 |
265 | fine_feats = fine_feats[-1]
266 | if not only_fine_feats:
267 | ori_feats = ori_feats[-1]
268 | # For semantic segmentation
269 | # if kmeans != -1:
270 | # ori_labels = kmeans_clustering(ori_feats, kmeans)
271 | # fine_labels = kmeans_clustering(fine_feats, kmeans)
272 | # else:
273 | # ori_labels = None
274 | # fine_labels = None
275 | # print("image shape: ", image.size)
276 | # print("image_resized shape: ", image_resized.shape)
277 | # print("ori_feats shape: ", ori_feats.shape)
278 | # print("fine_feats shape: ", fine_feats.shape)
279 |
280 |
281 | return ori_feats,fine_feats
282 |
283 |
284 | def LoadDinov2Model():
285 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
286 | original_model = timm.create_model(
287 | "vit_small_patch14_reg4_dinov2.lvd142m",
288 | pretrained=True,
289 | num_classes=0,
290 | dynamic_img_size=True,
291 | dynamic_img_pad=False,
292 | ).to(device)
293 | original_model.get_intermediate_layers = types.MethodType(
294 | get_intermediate_layers,
295 | original_model
296 | )
297 | fine_model = torch.hub.load("ywyue/FiT3D", "dinov2_reg_small_fine").to(device)
298 | fine_model.get_intermediate_layers = types.MethodType(
299 | get_intermediate_layers,
300 | fine_model
301 | )
302 | return original_model,fine_model
303 |
304 |
305 | def GetDinov2RegFeats(original_model, fine_model, image, transforms=None, kmeans=20):
306 |     ori_feats, fine_feats = Dinov2RegExtractor(original_model, fine_model, image, transforms)
307 |     return fine_feats
308 |
--------------------------------------------------------------------------------
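A hedged end-to-end sketch for the feature helpers above (LoadDinov2Model / GetDinov2RegFeats). It assumes internet access for the timm and torch.hub weights and a GPU (the module falls back to CPU); the 518x518 input is just a convenient multiple of the 14-pixel patch size, and the dummy image stands in for a real frame in [0, 1] with shape [1, 3, H, W].

import torch
import albumentations as A
from scene.fit3d import LoadDinov2Model, GetDinov2RegFeats, MEAN, STD

original_model, fine_model = LoadDinov2Model()            # downloads pretrained weights

image = torch.rand(1, 3, 518, 518)                        # dummy RGB image in [0, 1]
transforms = A.Compose([A.Normalize(mean=list(MEAN), std=list(STD))])

feats = GetDinov2RegFeats(original_model, fine_model, image, transforms)
print(feats.shape)                                        # e.g. [1, 384, 37, 37] patch features for ViT-S
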
/scene/scene_utils.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 | import numpy as np
3 | from utils.graphics_utils import getWorld2View2
4 | from scene.gaussian_model import BasicPointCloud
5 | from plyfile import PlyData, PlyElement
6 |
7 |
8 | class CameraInfo(NamedTuple):
9 | uid: int
10 | R: np.array
11 | T: np.array
12 | image: np.array
13 | image_path: str
14 | image_name: str
15 | width: int
16 | height: int
17 | sky_mask: np.array = None
18 | timestamp: float = 0.0
19 | FovY: float = None
20 | FovX: float = None
21 | fx: float = None
22 | fy: float = None
23 | cx: float = None
24 | cy: float = None
25 | pointcloud_camera: np.array = None
26 |
27 |
28 | depth_map: np.array = None
29 | semantic_mash: np.array = None
30 | instance_mask: np.array = None
31 | sam_mask: np.array = None
32 | dynamic_mask: np.array = None
33 | feat_map: np.array = None
34 | normal_map: np.array = None
35 |
36 | objects: np.array = None
37 | intrinsic: np.array = None
38 | c2w: np.array = None
39 |
40 | class SceneInfo(NamedTuple):
41 | point_cloud: BasicPointCloud
42 | train_cameras: list
43 | test_cameras: list
44 | nerf_normalization: dict
45 | ply_path: str
46 | time_interval: float = 0.02
47 | time_duration: list = [0, 1]
48 |
49 |
50 | full_cameras : list = None
51 | bg_point_cloud: BasicPointCloud = None
52 | bg_ply_path: str = None
53 | cam_frustum_aabb: np.array = None
54 | num_panoptic_objects: int = 0
55 | panoptic_id_to_idx: dict = None
56 | panoptic_object_ids: list = None
57 | occ_grid: np.array = None
58 |
59 | def getNerfppNorm(cam_info):
60 | def get_center_and_diag(cam_centers):
61 | cam_centers = np.hstack(cam_centers)
62 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
63 | center = avg_cam_center
64 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
65 | diagonal = np.max(dist)
66 | return center.flatten(), diagonal
67 |
68 | cam_centers = []
69 |
70 | for cam in cam_info:
71 | W2C = getWorld2View2(cam.R, cam.T)
72 | C2W = np.linalg.inv(W2C)
73 | cam_centers.append(C2W[:3, 3:4])
74 |
75 | center, diagonal = get_center_and_diag(cam_centers)
76 | radius = diagonal * 1.1
77 |
78 | translate = -center
79 |
80 | return {"translate": translate, "radius": radius}
81 |
82 |
83 | def fetchPly(path):
84 | plydata = PlyData.read(path)
85 | vertices = plydata['vertex']
86 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
87 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
88 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
89 | if 'time' in vertices:
90 | timestamp = vertices['time'][:, None]
91 | else:
92 | timestamp = None
93 | return BasicPointCloud(points=positions, colors=colors, normals=normals, time=timestamp)
94 |
95 |
96 | def storePly(path, xyz, rgb, timestamp=None):
97 | # Define the dtype for the structured array
98 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
99 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
100 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
101 | ('time', 'f4')]
102 |
103 | normals = np.zeros_like(xyz)
104 | if timestamp is None:
105 | timestamp = np.zeros_like(xyz[:, :1])
106 |
107 | elements = np.empty(xyz.shape[0], dtype=dtype)
108 | attributes = np.concatenate((xyz, normals, rgb, timestamp), axis=1)
109 | elements[:] = list(map(tuple, attributes))
110 |
111 | # Create the PlyData object and write to file
112 | vertex_element = PlyElement.describe(elements, 'vertex')
113 | ply_data = PlyData([vertex_element])
114 | ply_data.write(path)
115 |
--------------------------------------------------------------------------------
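A small round-trip sketch for storePly / fetchPly above, using random points, colors, and per-point timestamps written to a temporary file; it assumes the repository's dependencies (plyfile and the gaussian_model import chain) are installed.

import os
import tempfile
import numpy as np
from scene.scene_utils import storePly, fetchPly

xyz = np.random.rand(100, 3).astype(np.float32)
rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)
t = np.random.rand(100, 1).astype(np.float32)

ply_path = os.path.join(tempfile.gettempdir(), "points3d_demo.ply")
storePly(ply_path, xyz, rgb, t)

pcd = fetchPly(ply_path)
assert pcd.points.shape == (100, 3)
assert pcd.time.shape == (100, 1)      # per-point timestamps survive the round trip
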
/scene/waymo_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 | from scene.scene_utils import CameraInfo, SceneInfo, getNerfppNorm, fetchPly, storePly
6 | from utils.graphics_utils import BasicPointCloud, focal2fov
7 |
8 |
9 | def pad_poses(p):
10 | """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
11 | bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
12 | return np.concatenate([p[..., :3, :4], bottom], axis=-2)
13 |
14 |
15 | def unpad_poses(p):
16 | """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
17 | return p[..., :3, :4]
18 |
19 |
20 | def transform_poses_pca(poses, fix_radius=0):
21 | """Transforms poses so principal components lie on XYZ axes.
22 |
23 | Args:
24 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
25 |
26 |     Returns:
27 |         A tuple (poses, transform, scale_factor): the transformed poses, the applied
28 |         4x4 world transform, and the scale factor used to fit the scene in the unit cube.
29 |
30 | From https://github.com/SuLvXiangXin/zipnerf-pytorch/blob/af86ea6340b9be6b90ea40f66c0c02484dfc7302/internal/camera_utils.py#L161
31 | """
32 | t = poses[:, :3, 3]
33 | t_mean = t.mean(axis=0)
34 | t = t - t_mean
35 |
36 | eigval, eigvec = np.linalg.eig(t.T @ t)
37 | # Sort eigenvectors in order of largest to smallest eigenvalue.
38 | inds = np.argsort(eigval)[::-1]
39 | eigvec = eigvec[:, inds]
40 | rot = eigvec.T
41 | if np.linalg.det(rot) < 0:
42 | rot = np.diag(np.array([1, 1, -1])) @ rot
43 |
44 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
45 | poses_recentered = unpad_poses(transform @ pad_poses(poses))
46 | transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
47 |
48 | # Flip coordinate system if z component of y-axis is negative
49 | if poses_recentered.mean(axis=0)[2, 1] < 0:
50 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
51 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
52 |
53 |     # Just make sure it's in the [-1, 1]^3 cube
54 | if fix_radius>0:
55 | scale_factor = 1./fix_radius
56 | else:
57 | scale_factor = 1. / (np.max(np.abs(poses_recentered[:, :3, 3])) + 1e-5)
58 | scale_factor = min(1 / 10, scale_factor)
59 |
60 | poses_recentered[:, :3, 3] *= scale_factor
61 | transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
62 |
63 | return poses_recentered, transform, scale_factor
64 |
65 |
66 | def readWaymoInfo(args):
67 | neg_fov = args.neg_fov
68 | cam_infos = []
69 | car_list = [f[:-4] for f in sorted(os.listdir(os.path.join(args.source_path, "calib"))) if f.endswith('.txt')]
70 | points = []
71 | points_time = []
72 |
73 | load_size = [640, 960]
74 | ORIGINAL_SIZE = [[1280, 1920], [1280, 1920], [1280, 1920], [884, 1920], [884, 1920]]
75 | frame_num = len(car_list)
76 | if args.frame_interval > 0:
77 | time_duration = [-args.frame_interval*(frame_num-1)/2,args.frame_interval*(frame_num-1)/2]
78 | else:
79 | time_duration = args.time_duration
80 |
81 |     for idx, car_id in enumerate(tqdm(car_list, desc="Loading data", bar_format='{l_bar}{bar:50}{r_bar}')):
82 | ego_pose = np.loadtxt(os.path.join(args.source_path, 'pose', car_id + '.txt'))
83 |
84 | # CAMERA DIRECTION: RIGHT DOWN FORWARDS
85 | with open(os.path.join(args.source_path, 'calib', car_id + '.txt')) as f:
86 | calib_data = f.readlines()
87 | L = [list(map(float, line.split()[1:])) for line in calib_data]
88 | Ks = np.array(L[:5]).reshape(-1, 3, 4)[:, :, :3]
89 | lidar2cam = np.array(L[-5:]).reshape(-1, 3, 4)
90 | lidar2cam = pad_poses(lidar2cam)
91 |
92 | cam2lidar = np.linalg.inv(lidar2cam)
93 | c2w = ego_pose @ cam2lidar
94 | w2c = np.linalg.inv(c2w)
95 | images = []
96 | image_paths = []
97 | HWs = []
98 | for subdir in ['image_0', 'image_1', 'image_2', 'image_3', 'image_4'][:args.cam_num]:
99 | image_path = os.path.join(args.source_path, subdir, car_id + '.png')
100 | im_data = Image.open(image_path)
101 | im_data = im_data.resize((load_size[1], load_size[0]), Image.BILINEAR) # PIL resize: (W, H)
102 | W, H = im_data.size
103 | image = np.array(im_data) / 255.
104 | HWs.append((H, W))
105 | images.append(image)
106 | image_paths.append(image_path)
107 |
108 | sky_masks = []
109 | for subdir in ['sky_0', 'sky_1', 'sky_2', 'sky_3', 'sky_4'][:args.cam_num]:
110 | sky_data = Image.open(os.path.join(args.source_path, subdir, car_id + '.png'))
111 | sky_data = sky_data.resize((load_size[1], load_size[0]), Image.NEAREST) # PIL resize: (W, H)
112 | sky_mask = np.array(sky_data)>0
113 | sky_masks.append(sky_mask.astype(np.float32))
114 |
115 | normal_maps = []
116 | if args.load_normal_map:
117 | for subdir in ['normal_0', 'normal_1', 'normal_2', 'normal_3', 'normal_4'][:args.cam_num]:
118 | normal_data = Image.open(os.path.join(args.source_path, subdir, car_id + '.jpg'))
119 | normal_data = normal_data.resize((load_size[1], load_size[0]), Image.BILINEAR)
120 | normal_map = (np.array(normal_data)) / 255.
121 | normal_maps.append(normal_map)
122 |
123 |
124 | timestamp = time_duration[0] + (time_duration[1] - time_duration[0]) * idx / (len(car_list) - 1)
125 | point = np.fromfile(os.path.join(args.source_path, "velodyne", car_id + ".bin"),
126 | dtype=np.float32, count=-1).reshape(-1, 6)
127 | point_xyz, intensity, elongation, timestamp_pts = np.split(point, [3, 4, 5], axis=1)
128 |         point_xyz_world = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ ego_pose.T)[:, :3]
129 | points.append(point_xyz_world)
130 | point_time = np.full_like(point_xyz_world[:, :1], timestamp)
131 | points_time.append(point_time)
132 | for j in range(args.cam_num):
133 | point_camera = (np.pad(point_xyz, ((0, 0), (0, 1)), constant_values=1) @ lidar2cam[j].T)[:, :3]
134 | R = np.transpose(w2c[j, :3, :3]) # R is stored transposed due to 'glm' in CUDA code
135 | T = w2c[j, :3, 3]
136 | K = Ks[j]
137 | fx = float(K[0, 0]) * load_size[1] / ORIGINAL_SIZE[j][1]
138 | fy = float(K[1, 1]) * load_size[0] / ORIGINAL_SIZE[j][0]
139 | cx = float(K[0, 2]) * load_size[1] / ORIGINAL_SIZE[j][1]
140 | cy = float(K[1, 2]) * load_size[0] / ORIGINAL_SIZE[j][0]
141 | width=HWs[j][1]
142 | height=HWs[j][0]
143 | if neg_fov:
144 | FovY = -1.0
145 | FovX = -1.0
146 | else:
147 | FovY = focal2fov(fy, height)
148 | FovX = focal2fov(fx, width)
149 | cam_infos.append(CameraInfo(uid=idx * 5 + j, R=R, T=T, FovY=FovY, FovX=FovX,
150 | image=images[j],
151 | image_path=image_paths[j], image_name=car_id,
152 | width=HWs[j][1], height=HWs[j][0], timestamp=timestamp,
153 | pointcloud_camera = point_camera,
154 | fx=fx, fy=fy, cx=cx, cy=cy,
155 | sky_mask=sky_masks[j],
156 | normal_map=normal_maps[j] if args.load_normal_map else None))
157 |
158 | if args.debug_cuda:
159 | break
160 |
161 | pointcloud = np.concatenate(points, axis=0)
162 | pointcloud_timestamp = np.concatenate(points_time, axis=0)
163 | indices = np.random.choice(pointcloud.shape[0], args.num_pts, replace=True)
164 | pointcloud = pointcloud[indices]
165 | pointcloud_timestamp = pointcloud_timestamp[indices]
166 |
167 | w2cs = np.zeros((len(cam_infos), 4, 4))
168 | Rs = np.stack([c.R for c in cam_infos], axis=0)
169 | Ts = np.stack([c.T for c in cam_infos], axis=0)
170 | w2cs[:, :3, :3] = Rs.transpose((0, 2, 1))
171 | w2cs[:, :3, 3] = Ts
172 | w2cs[:, 3, 3] = 1
173 | c2ws = unpad_poses(np.linalg.inv(w2cs))
174 | c2ws, transform, scale_factor = transform_poses_pca(c2ws, fix_radius=args.fix_radius)
175 |
176 | c2ws = pad_poses(c2ws)
177 | for idx, cam_info in enumerate(tqdm(cam_infos, desc="Transform data")):
178 | c2w = c2ws[idx]
179 | w2c = np.linalg.inv(c2w)
180 | cam_info.R[:] = np.transpose(w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
181 | cam_info.T[:] = w2c[:3, 3]
182 | cam_info.pointcloud_camera[:] *= scale_factor
183 | pointcloud = (np.pad(pointcloud, ((0, 0), (0, 1)), constant_values=1) @ transform.T)[:, :3]
184 | if args.eval:
185 | # ## for snerf scene
186 | # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold != 0]
187 | # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // cam_num) % testhold == 0]
188 |
189 | # for dynamic scene
190 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold != 0]
191 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num + 1) % args.testhold == 0]
192 |
193 | # for emernerf comparison [testhold::testhold]
194 | if args.testhold == 10:
195 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold != 0 or (idx // args.cam_num) == 0]
196 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx // args.cam_num) % args.testhold == 0 and (idx // args.cam_num)>0]
197 | else:
198 | train_cam_infos = cam_infos
199 | test_cam_infos = []
200 |
201 | nerf_normalization = getNerfppNorm(train_cam_infos)
202 | nerf_normalization['radius'] = 1/nerf_normalization['radius']
203 |
204 | ply_path = os.path.join(args.source_path, "points3d.ply")
205 | if not os.path.exists(ply_path):
206 | rgbs = np.random.random((pointcloud.shape[0], 3))
207 | storePly(ply_path, pointcloud, rgbs, pointcloud_timestamp)
208 | try:
209 | pcd = fetchPly(ply_path)
210 | except:
211 | pcd = None
212 |
213 | pcd = BasicPointCloud(pointcloud, colors=np.zeros([pointcloud.shape[0],3]), normals=None, time=pointcloud_timestamp)
214 | time_interval = (time_duration[1] - time_duration[0]) / (len(car_list) - 1)
215 |
216 | scene_info = SceneInfo(point_cloud=pcd,
217 | train_cameras=train_cam_infos,
218 | test_cameras=test_cam_infos,
219 | nerf_normalization=nerf_normalization,
220 | ply_path=ply_path,
221 | time_interval=time_interval,
222 | time_duration=time_duration)
223 |
224 | return scene_info
225 |
--------------------------------------------------------------------------------
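A tiny self-contained illustration of the eval split used in both Waymo loaders above: cameras are grouped per frame (cam_num consecutive entries) and every testhold-th frame group is held out. cam_num=3 and testhold=4 below are example values, not defaults taken from the configs.

cam_num, testhold = 3, 4
idxs = list(range(8 * cam_num))          # 8 frames x cam_num cameras

test = [i for i in idxs if (i // cam_num + 1) % testhold == 0]
train = [i for i in idxs if (i // cam_num + 1) % testhold != 0]

print(test)   # cameras of frames 3 and 7 (0-based): indices 9..11 and 21..23
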
/scripts/extract_mask_kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_masks.py
3 | @author Jianfei Guo, Shanghai AI Lab
4 | @brief Extract semantic mask
5 |
6 | Using SegFormer (2021); 83.2% mIoU on Cityscapes.
7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9)
8 |
9 | Installation:
10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env.
11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9;
12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num'
13 | Hence, a separate conda env is needed.
14 |
15 | git clone https://github.com/NVlabs/SegFormer
16 |
17 | conda create -n segformer python=3.8
18 | conda activate segformer
19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge
20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
21 |
22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf
23 | pip install mmcv-full==1.2.7 --no-cache-dir
24 |
25 | cd SegFormer
26 | pip install .
27 |
28 | Usage:
29 |     Run this script directly in the newly created conda env.
30 | """
31 | import os
32 | import numpy as np
33 | import cv2
34 | from tqdm import tqdm
35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
36 | from mmseg.core.evaluation import get_palette
37 |
38 | if __name__ == "__main__":
39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer'
40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5', 'segformer.b5.1024x1024.city.160k.py')
41 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth')
42 | model = init_segmentor(config, checkpoint, device='cuda')
43 |
44 | root = 'data/kitti_mot/training'
45 |
46 | for cam_id in ['2', '3']:
47 | image_dir = os.path.join(root, f'image_0{cam_id}')
48 | sky_dir = os.path.join(root, f'sky_0{cam_id}')
49 | for seq in sorted(os.listdir(image_dir)):
50 | seq_dir = os.path.join(image_dir, seq)
51 | mask_dir = os.path.join(sky_dir, seq)
52 | if not os.path.isdir(seq_dir):
53 | continue
54 |
55 | os.makedirs(image_dir, exist_ok=True)
56 | os.makedirs(mask_dir, exist_ok=True)
57 | for image_name in sorted(os.listdir(seq_dir)):
58 | image_path = os.path.join(seq_dir, image_name)
59 | print(image_path)
60 | mask_path = os.path.join(mask_dir, image_name)
61 | if not image_path.endswith(".png"):
62 | continue
63 | result = inference_segmentor(model, image_path)
64 | mask = result[0].astype(np.uint8)
65 |                 mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8)  # class 10 = sky in the Cityscapes label set
66 | cv2.imwrite(os.path.join(mask_dir, image_name), mask)
67 |
--------------------------------------------------------------------------------
/scripts/extract_mask_waymo.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_masks.py
3 | @author Jianfei Guo, Shanghai AI Lab
4 | @brief Extract semantic mask
5 |
6 | Using SegFormer (2021); 83.2% mIoU on Cityscapes.
7 | Relies on timm==0.3.2 & pytorch 1.8.1 (buggy on pytorch >= 1.9)
8 |
9 | Installation:
10 | NOTE: mmcv-full==1.2.7 requires another pytorch version & conda env.
11 | Currently mmcv-full==1.2.7 does not support pytorch>=1.9;
12 | will raise AttributeError: 'super' object has no attribute '_specify_ddp_gpu_num'
13 | Hence, a separate conda env is needed.
14 |
15 | git clone https://github.com/NVlabs/SegFormer
16 |
17 | conda create -n segformer python=3.8
18 | conda activate segformer
19 | # conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 cudatoolkit=11.3 -c pytorch -c conda-forge
20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
21 |
22 | pip install timm==0.3.2 pylint debugpy opencv-python attrs ipython tqdm imageio scikit-image omegaconf
23 | pip install mmcv-full==1.2.7 --no-cache-dir
24 |
25 | cd SegFormer
26 | pip install .
27 |
28 | Usage:
29 |     Run this script directly in the newly created conda env.
30 | """
31 | import os
32 | import numpy as np
33 | import cv2
34 | from tqdm import tqdm
35 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
36 | from mmseg.core.evaluation import get_palette
37 |
38 | if __name__ == "__main__":
39 | segformer_path = '/SSD_DISK/users/guchun/reconstruction/neuralsim/SegFormer'
40 | config = os.path.join(segformer_path, 'local_configs', 'segformer', 'B5',
41 | 'segformer.b5.1024x1024.city.160k.py')
42 | checkpoint = os.path.join(segformer_path, 'segformer.b5.1024x1024.city.160k.pth')
43 | model = init_segmentor(config, checkpoint, device='cuda')
44 |
45 | root = 'data/waymo_scenes'
46 |
47 | scenes = sorted(os.listdir(root))
48 |
49 | for scene in scenes:
50 | for cam_id in range(5):
51 | image_dir = os.path.join(root, scene, f'image_{cam_id}')
52 | sky_dir = os.path.join(root, scene, f'sky_{cam_id}')
53 | os.makedirs(sky_dir, exist_ok=True)
54 | for image_name in tqdm(sorted(os.listdir(image_dir))):
55 | if not image_name.endswith(".png"):
56 | continue
57 | image_path = os.path.join(image_dir, image_name)
58 | mask_path = os.path.join(sky_dir, image_name)
59 | result = inference_segmentor(model, image_path)
60 | mask = result[0].astype(np.uint8)
61 |                 mask = ((mask == 10).astype(np.float32) * 255).astype(np.uint8)  # class 10 = sky in the Cityscapes label set
62 | cv2.imwrite(mask_path, mask)
63 |
--------------------------------------------------------------------------------
/scripts/extract_mono_cues_kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_mono_cues.py
3 | @brief extract monocular cues (normal & depth)
4 | Adapted from https://github.com/EPFL-VILAB/omnidata
5 |
6 | Installation:
7 | git clone https://github.com/EPFL-VILAB/omnidata
8 |
9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning
10 | """
11 |
12 | import os
13 | import sys
14 | import argparse
15 | import imageio
16 | import numpy as np
17 | from glob import glob
18 | from tqdm import tqdm
19 | from typing import Literal
20 |
21 | import PIL
22 | import skimage
23 | from PIL import Image
24 | import matplotlib.pyplot as plt
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | from torchvision import transforms
29 | import torchvision.transforms.functional as transF
30 |
31 | def list_contains(l: list, v):
32 | """
33 | Whether any item in `l` contains `v`
34 | """
35 |     for item in l:
36 |         if v in item:
37 |             return True
38 |     return False
39 | 
40 |
41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
42 | if mask_valid is not None:
43 | img[~mask_valid] = torch.nan
44 | sorted_img = torch.sort(torch.flatten(img))[0]
45 | # Remove nan, nan at the end of sort
46 | num_nan = sorted_img.isnan().sum()
47 | if num_nan > 0:
48 | sorted_img = sorted_img[:-num_nan]
49 | # Remove outliers
50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))]
51 | trunc_mean = trunc_img.mean()
52 | trunc_var = trunc_img.var()
53 | eps = 1e-6
54 | # Replace nan by mean
55 | img = torch.nan_to_num(img, nan=trunc_mean)
56 | # Standardize
57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
58 | return img
59 |
60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True):
61 | with torch.no_grad():
62 | # img = Image.open(img_path)
63 | img = imageio.imread(img_path)
64 | img = skimage.img_as_float32(img)
65 | H, W, _ = img.shape
66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device)
67 | if H < W:
68 | H_ = ref_img_size
69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32
70 | else:
71 | W_ = ref_img_size
72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32
73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True)
74 |
75 | if img_tensor.shape[1] == 1:
76 | img_tensor = img_tensor.repeat_interleave(3,1)
77 |
78 | output = model(img_tensor).clamp(min=0, max=1)
79 |
80 | if args.task == 'depth':
81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
82 | output = output.clamp(0,1).data / output.max()
83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
86 |
87 | # np.savez_compressed(f"{output_path_base}.npz", output)
88 | #output = 1 - output
89 | # output = standardize_depth_map(output)
90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis')
91 | if verbose:
92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size
93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here.
94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16))
95 |
96 | else:
97 | output = output.data.clamp(0,1).squeeze(0)
98 | # Resize to original shape
99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here.
100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
101 |
102 | # np.savez_compressed(f"{output_path_base}.npz", output)
103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5)
104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size
105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow
106 | # np.save(f"{output_path_base}.npy", output)
107 |
108 | if __name__ == "__main__":
109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals')
110 |
111 | # Dataset specific configs
112 | parser.add_argument('--data_root', type=str, default='./data/kitti_mot/training')
113 | parser.add_argument('--seq_list', type=str, nargs='+', default=['0001', '0002', '0006'], help='specify --seq_list if you want to limit the list of seqs')
114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization")
115 | parser.add_argument('--ignore_existing', action='store_true')
116 | parser.add_argument('--rgb_dirname', type=str, default="image")
117 | parser.add_argument('--depth_dirname', type=str, default="depth")
118 | parser.add_argument('--normals_dirname', type=str, default="normal")
119 |
120 | # Algorithm configs
121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth")
122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model",
123 | default='omnidata/omnidata_tools/torch')
124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models",
125 | default='omnidata/omnidata_tools/torch/pretrained_models')
126 | parser.add_argument('--ref_img_size', dest='ref_img_size', type=int, default=512,
127 | help="image size when inference (will still save full-scale output)")
128 | args = parser.parse_args()
129 |
130 | #-----------------------------------------------
131 | #-- Original preparation
132 | #-----------------------------------------------
133 | if args.pretrained_models is None:
134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/'
135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models")
136 |
137 | sys.path.append(args.omnidata_path)
138 | print(sys.path)
139 | from modules.unet import UNet
140 | from modules.midas.dpt_depth import DPTDepthModel
141 | from data.transforms import get_transform
142 |
143 | trans_topil = transforms.ToPILImage()
144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')
145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
146 |
147 | # get target task and model
148 | if args.task == 'normal' or args.task == 'normals':
149 | image_size = 384
150 |
151 | #---- Version 1 model
152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth')
153 | # model = UNet(in_channels=3, out_channels=3)
154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
155 |
156 | # if 'state_dict' in checkpoint:
157 | # state_dict = {}
158 | # for k, v in checkpoint['state_dict'].items():
159 | # state_dict[k.replace('model.', '')] = v
160 | # else:
161 | # state_dict = checkpoint
162 |
163 |
164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt')
165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid
166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
167 | if 'state_dict' in checkpoint:
168 | state_dict = {}
169 | for k, v in checkpoint['state_dict'].items():
170 | state_dict[k[6:]] = v
171 | else:
172 | state_dict = checkpoint
173 |
174 | model.load_state_dict(state_dict)
175 | model.to(device)
176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
177 | # transforms.CenterCrop(image_size),
178 | # get_transform('rgb', image_size=None)])
179 |
180 | trans_totensor = transforms.Compose([
181 | get_transform('rgb', image_size=None)])
182 |
183 | elif args.task == 'depth':
184 | image_size = 384
185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt'
186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large
187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid
188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
189 | if 'state_dict' in checkpoint:
190 | state_dict = {}
191 | for k, v in checkpoint['state_dict'].items():
192 | state_dict[k[6:]] = v
193 | else:
194 | state_dict = checkpoint
195 | model.load_state_dict(state_dict)
196 | model.to(device)
197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR),
198 | # transforms.CenterCrop(image_size),
199 | # transforms.ToTensor(),
200 | # transforms.Normalize(mean=0.5, std=0.5)])
201 | trans_totensor = transforms.Compose([
202 | transforms.ToTensor(),
203 | transforms.Normalize(mean=0.5, std=0.5)])
204 |
205 | else:
206 | print("task should be one of the following: normal, depth")
207 | sys.exit()
208 |
209 | #-----------------------------------------------
210 | #--- Dataset Specific processing
211 | #-----------------------------------------------
212 |
213 | select_scene_ids = [f'{args.rgb_dirname}_02', f'{args.rgb_dirname}_03']
214 |
215 |
216 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')):
217 | for seq in args.seq_list:
218 | image_dir = os.path.join(args.data_root, scene_id, seq)
219 | obs_id_list = sorted(os.listdir(image_dir))
220 |
221 | if args.task == 'depth':
222 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname)
223 | elif args.task == 'normal':
224 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname)
225 | else:
226 | raise RuntimeError(f"Invalid task={args.task}")
227 |
228 | os.makedirs(output_dir, exist_ok=True)
229 |
230 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')):
231 | img_path = os.path.join(image_dir, obs_id)
232 | fbase = os.path.splitext(os.path.basename(img_path))[0]
233 |
234 | if args.task == 'depth':
235 | output_base = os.path.join(output_dir, fbase)
236 | elif args.task == 'normal':
237 | output_base = os.path.join(output_dir, fbase)
238 | else:
239 | raise RuntimeError(f"Invalid task={args.task}")
240 |
241 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase):
242 | continue
243 |
244 | #---- Inference and save outputs
245 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose)
--------------------------------------------------------------------------------
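For reference, a plausible invocation of the script above, assembled from its argument parser (--task is the only required flag; the paths and sequence ids are the defaults shown, and --seq_list accepts one or more ids):

# python scripts/extract_mono_cues_kitti.py --task depth --data_root ./data/kitti_mot/training --seq_list 0001 0002 0006 --verbose
# python scripts/extract_mono_cues_kitti.py --task normal --data_root ./data/kitti_mot/training --seq_list 0001 0002 0006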
/scripts/extract_mono_cues_notr.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_mono_cues_notr.py
3 | @brief extract monocular cues (normal & depth)
4 | Adapted from https://github.com/EPFL-VILAB/omnidata
5 |
6 | Installation:
7 | git clone https://github.com/EPFL-VILAB/omnidata
8 |
9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning
10 | """
11 |
12 | import os
13 | import sys
14 | import argparse
15 | import imageio
16 | import numpy as np
17 | from glob import glob
18 | from tqdm import tqdm
19 | from typing import Literal
20 |
21 | import PIL
22 | import skimage
23 | from PIL import Image
24 | import matplotlib.pyplot as plt
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | from torchvision import transforms
29 | import torchvision.transforms.functional as transF
30 |
31 | def list_contains(l: list, v):
32 | """
33 | Whether any item in `l` contains `v`
34 | """
35 | for item in l:
36 | if v in item:
37 | return True
38 | else:
39 | return False
40 |
41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
42 | if mask_valid is not None:
43 | img[~mask_valid] = torch.nan
44 | sorted_img = torch.sort(torch.flatten(img))[0]
45 | # Remove nan, nan at the end of sort
46 | num_nan = sorted_img.isnan().sum()
47 | if num_nan > 0:
48 | sorted_img = sorted_img[:-num_nan]
49 | # Remove outliers
50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))]
51 | trunc_mean = trunc_img.mean()
52 | trunc_var = trunc_img.var()
53 | eps = 1e-6
54 | # Replace nan by mean
55 | img = torch.nan_to_num(img, nan=trunc_mean)
56 | # Standardize
57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
58 | return img
59 |
60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True):
61 | with torch.no_grad():
62 | # img = Image.open(img_path)
63 | img = imageio.imread(img_path)
64 | img = skimage.img_as_float32(img)
65 | H, W, _ = img.shape
66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device)
67 | if H < W:
68 | H_ = ref_img_size
69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32
70 | else:
71 | W_ = ref_img_size
72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32
73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True)
74 |
75 | if img_tensor.shape[1] == 1:
76 | img_tensor = img_tensor.repeat_interleave(3,1)
77 |
78 | output = model(img_tensor).clamp(min=0, max=1)
79 |
80 | if args.task == 'depth':
81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
82 | output = output.clamp(0,1).data / output.max()
83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
86 |
87 | # np.savez_compressed(f"{output_path_base}.npz", output)
88 | #output = 1 - output
89 | # output = standardize_depth_map(output)
90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis')
91 | if verbose:
92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size
93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here.
94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16))
95 |
96 | else:
97 | output = output.data.clamp(0,1).squeeze(0)
98 | # Resize to original shape
99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here.
100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
101 |
102 | # np.savez_compressed(f"{output_path_base}.npz", output)
103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5)
104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size
105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow
106 | # np.save(f"{output_path_base}.npy", output)
107 |
108 | if __name__ == "__main__":
109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals')
110 |
111 | # Dataset specific configs
112 | parser.add_argument('--data_root', type=str, default='./data/notr/')
113 | parser.add_argument('--seq_list', type=str, nargs='+', default=['084'], help='specify --seq_list if you want to limit the list of seqs')
114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization")
115 | parser.add_argument('--ignore_existing', action='store_true')
116 | parser.add_argument('--rgb_dirname', type=str, default="images")
117 | parser.add_argument('--depth_dirname', type=str, default="depths")
118 | parser.add_argument('--normals_dirname', type=str, default="normals")
119 |
120 | # Algorithm configs
121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth")
122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model",
123 | default='omnidata/omnidata_tools/torch')
124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models",
125 | default='omnidata/omnidata_tools/torch/pretrained_models')
126 | parser.add_argument('--ref_img_size', dest='ref_img_size', type=int, default=512,
127 | help="image size when inference (will still save full-scale output)")
128 | args = parser.parse_args()
129 |
130 | #-----------------------------------------------
131 | #-- Original preparation
132 | #-----------------------------------------------
133 | if args.pretrained_models is None:
134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/'
135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models")
136 |
137 | sys.path.append(args.omnidata_path)
138 | print(sys.path)
139 | from modules.unet import UNet
140 | from modules.midas.dpt_depth import DPTDepthModel
141 | from data.transforms import get_transform
142 |
143 | trans_topil = transforms.ToPILImage()
144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')
145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
146 |
147 | # get target task and model
148 | if args.task == 'normal' or args.task == 'normals':
149 | image_size = 384
150 |
151 | #---- Version 1 model
152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth')
153 | # model = UNet(in_channels=3, out_channels=3)
154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
155 |
156 | # if 'state_dict' in checkpoint:
157 | # state_dict = {}
158 | # for k, v in checkpoint['state_dict'].items():
159 | # state_dict[k.replace('model.', '')] = v
160 | # else:
161 | # state_dict = checkpoint
162 |
163 |
164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt')
165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid
166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
167 | if 'state_dict' in checkpoint:
168 | state_dict = {}
169 | for k, v in checkpoint['state_dict'].items():
170 | state_dict[k[6:]] = v
171 | else:
172 | state_dict = checkpoint
173 |
174 | model.load_state_dict(state_dict)
175 | model.to(device)
176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
177 | # transforms.CenterCrop(image_size),
178 | # get_transform('rgb', image_size=None)])
179 |
180 | trans_totensor = transforms.Compose([
181 | get_transform('rgb', image_size=None)])
182 |
183 | elif args.task == 'depth':
184 | image_size = 384
185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt'
186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large
187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid
188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
189 | if 'state_dict' in checkpoint:
190 | state_dict = {}
191 | for k, v in checkpoint['state_dict'].items():
192 | state_dict[k[6:]] = v
193 | else:
194 | state_dict = checkpoint
195 | model.load_state_dict(state_dict)
196 | model.to(device)
197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR),
198 | # transforms.CenterCrop(image_size),
199 | # transforms.ToTensor(),
200 | # transforms.Normalize(mean=0.5, std=0.5)])
201 | trans_totensor = transforms.Compose([
202 | transforms.ToTensor(),
203 | transforms.Normalize(mean=0.5, std=0.5)])
204 |
205 | else:
206 | print("task should be one of the following: normal, depth")
207 | sys.exit()
208 |
209 | #-----------------------------------------------
210 | #--- Dataset Specific processing
211 | #-----------------------------------------------
212 | if isinstance(args.seq_list, list):
213 | select_scene_ids = args.seq_list
214 | else:
215 | select_scene_ids = [args.seq_list]
216 |
217 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')):
218 |
219 | image_dir = os.path.join(args.data_root, scene_id, f"{args.rgb_dirname}")
220 | obs_id_list = sorted(os.listdir(image_dir))
221 |
222 | if args.task == 'depth':
223 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname)
224 | elif args.task == 'normal':
225 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname)
226 | else:
227 | raise RuntimeError(f"Invalid task={args.task}")
228 |
229 | os.makedirs(output_dir, exist_ok=True)
230 |
231 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')):
232 | img_path = os.path.join(image_dir, obs_id)
233 | fbase = os.path.splitext(os.path.basename(img_path))[0]
234 |
235 | if args.task == 'depth':
236 | output_base = os.path.join(output_dir, fbase)
237 | elif args.task == 'normal':
238 | output_base = os.path.join(output_dir, fbase)
239 | else:
240 | raise RuntimeError(f"Invalid task={args.task}")
241 |
242 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase):
243 | continue
244 |
245 | #---- Inference and save outputs
246 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose)
--------------------------------------------------------------------------------
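The depth branch above saves a compressed float16 array (plus a preview .jpg when --verbose is set), while the normal branch saves an 8-bit RGB .jpg. A minimal sketch of reading these back, assuming the default notr directory names; the frame name is illustrative:

import imageio
import numpy as np

depth = np.load("data/notr/084/depths/000000.npz")["arr_0"].astype(np.float32)          # (H, W, 1), values in [0, 1]
normal = imageio.imread("data/notr/084/normals/000000.jpg").astype(np.float32) / 255.0  # (H, W, 3), values in [0, 1]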
/scripts/extract_mono_cues_waymo.py:
--------------------------------------------------------------------------------
1 | """
2 | @file extract_mono_cues_waymo.py
3 | @brief extract monocular cues (normal & depth)
4 | Adapted from https://github.com/EPFL-VILAB/omnidata
5 |
6 | Installation:
7 | git clone https://github.com/EPFL-VILAB/omnidata
8 |
9 | pip install einops joblib pandas h5py scipy seaborn kornia timm pytorch-lightning
10 | """
11 |
12 | import os
13 | import sys
14 | import argparse
15 | import imageio
16 | import numpy as np
17 | from glob import glob
18 | from tqdm import tqdm
19 | from typing import Literal
20 |
21 | import PIL
22 | import skimage
23 | from PIL import Image
24 | import matplotlib.pyplot as plt
25 |
26 | import torch
27 | import torch.nn.functional as F
28 | from torchvision import transforms
29 | import torchvision.transforms.functional as transF
30 |
31 | def list_contains(l: list, v):
32 | """
33 | Whether any item in `l` contains `v`
34 | """
35 | for item in l:
36 | if v in item:
37 | return True
38 | else:
39 | return False
40 |
41 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
42 | if mask_valid is not None:
43 | img[~mask_valid] = torch.nan
44 | sorted_img = torch.sort(torch.flatten(img))[0]
45 | # Remove nan, nan at the end of sort
46 | num_nan = sorted_img.isnan().sum()
47 | if num_nan > 0:
48 | sorted_img = sorted_img[:-num_nan]
49 | # Remove outliers
50 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))]
51 | trunc_mean = trunc_img.mean()
52 | trunc_var = trunc_img.var()
53 | eps = 1e-6
54 | # Replace nan by mean
55 | img = torch.nan_to_num(img, nan=trunc_mean)
56 | # Standardize
57 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
58 | return img
59 |
60 | def extract_cues(img_path: str, output_path_base: str, ref_img_size: int=384, verbose=True):
61 | with torch.no_grad():
62 | # img = Image.open(img_path)
63 | img = imageio.imread(img_path)
64 | img = skimage.img_as_float32(img)
65 | H, W, _ = img.shape
66 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device)
67 | if H < W:
68 | H_ = ref_img_size
69 | W_ = int(((W / (H/H_)) // 32)) * 32 # Force to be a multiple of 32
70 | else:
71 | W_ = ref_img_size
72 | H_ = int(((H / (W/W_)) // 32)) * 32 # Force to be a multiple of 32
73 | img_tensor = transF.resize(img_tensor, (H_, W_), antialias=True)
74 |
75 | if img_tensor.shape[1] == 1:
76 | img_tensor = img_tensor.repeat_interleave(3,1)
77 |
78 | output = model(img_tensor).clamp(min=0, max=1)
79 |
80 | if args.task == 'depth':
81 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
82 | output = output.clamp(0,1).data / output.max()
83 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
84 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
85 | # output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).cpu().numpy()
86 |
87 | # np.savez_compressed(f"{output_path_base}.npz", output)
88 | #output = 1 - output
89 | # output = standardize_depth_map(output)
90 | # plt.imsave(f"{output_path_base}.png", output, cmap='viridis')
91 | if verbose:
92 | imageio.imwrite(f"{output_path_base}.jpg", (output*255).clip(0,255).astype(np.uint8)[..., 0], format="jpg") # Fastest and smallest file size
93 | # NOTE: jianfei: Although saving to float16 is lossy, we are allowing it since it's just extracting some weak hint here.
94 | np.savez_compressed(f"{output_path_base}.npz", output.astype(np.float16))
95 |
96 | else:
97 | output = output.data.clamp(0,1).squeeze(0)
98 | # Resize to original shape
99 | # NOTE: jianfei: Although saving to uint8 is lossy, we are allowing it since it's just extracting some weak hint here.
100 | output = transF.resize(output, (H,W), antialias=True).movedim(0,-1).mul_(255.).to(torch.uint8).clamp_(0,255).cpu().numpy()
101 |
102 | # np.savez_compressed(f"{output_path_base}.npz", output)
103 | # plt.imsave(f"{output_path_base}.png", output/2+0.5)
104 | imageio.imwrite(f"{output_path_base}.jpg", output) # Fastest and smallest file size
105 | # imageio.imwrite(f"{output_path_base}.png", output) # Very slow
106 | # np.save(f"{output_path_base}.npy", output)
107 |
108 | if __name__ == "__main__":
109 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals')
110 |
111 | # Dataset specific configs
112 | parser.add_argument('--data_root', type=str, default='./data/waymo_scenes')
113 | parser.add_argument('--seq_list', type=str, nargs='+', default=['0017085', '0145050', '0147030', '0158150'], help='specify --seq_list if you want to limit the list of seqs')
114 | parser.add_argument('--verbose', action='store_true', help="Additionally generate .jpg files for visualization")
115 | parser.add_argument('--ignore_existing', action='store_true')
116 | parser.add_argument('--rgb_dirname', type=str, default="image")
117 | parser.add_argument('--depth_dirname', type=str, default="depth")
118 | parser.add_argument('--normals_dirname', type=str, default="normal")
119 |
120 | # Algorithm configs
121 | parser.add_argument('--task', dest='task', required=True, default=None, help="normal or depth")
122 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model",
123 | default='omnidata/omnidata_tools/torch')
124 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models",
125 | default='omnidata/omnidata_tools/torch/pretrained_models')
126 | parser.add_argument('--ref_img_size', dest='ref_img_size', type=int, default=512,
127 | help="image size when inference (will still save full-scale output)")
128 | args = parser.parse_args()
129 |
130 | #-----------------------------------------------
131 | #-- Original preparation
132 | #-----------------------------------------------
133 | if args.pretrained_models is None:
134 | # '/home/guojianfei/ai_ws/omnidata/omnidata_tools/torch/pretrained_models/'
135 | args.pretrained_models = os.path.join(args.omnidata_path, "pretrained_models")
136 |
137 | sys.path.append(args.omnidata_path)
138 | print(sys.path)
139 | from modules.unet import UNet
140 | from modules.midas.dpt_depth import DPTDepthModel
141 | from data.transforms import get_transform
142 |
143 | trans_topil = transforms.ToPILImage()
144 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')
145 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
146 |
147 | # get target task and model
148 | if args.task == 'normal' or args.task == 'normals':
149 | image_size = 384
150 |
151 | #---- Version 1 model
152 | # pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_unet_normal_v1.pth')
153 | # model = UNet(in_channels=3, out_channels=3)
154 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
155 |
156 | # if 'state_dict' in checkpoint:
157 | # state_dict = {}
158 | # for k, v in checkpoint['state_dict'].items():
159 | # state_dict[k.replace('model.', '')] = v
160 | # else:
161 | # state_dict = checkpoint
162 |
163 |
164 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_normal_v2.ckpt')
165 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid
166 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
167 | if 'state_dict' in checkpoint:
168 | state_dict = {}
169 | for k, v in checkpoint['state_dict'].items():
170 | state_dict[k[6:]] = v
171 | else:
172 | state_dict = checkpoint
173 |
174 | model.load_state_dict(state_dict)
175 | model.to(device)
176 | # trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
177 | # transforms.CenterCrop(image_size),
178 | # get_transform('rgb', image_size=None)])
179 |
180 | trans_totensor = transforms.Compose([
181 | get_transform('rgb', image_size=None)])
182 |
183 | elif args.task == 'depth':
184 | image_size = 384
185 | pretrained_weights_path = os.path.join(args.pretrained_models, 'omnidata_dpt_depth_v2.ckpt') # 'omnidata_dpt_depth_v1.ckpt'
186 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large
187 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid
188 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
189 | if 'state_dict' in checkpoint:
190 | state_dict = {}
191 | for k, v in checkpoint['state_dict'].items():
192 | state_dict[k[6:]] = v
193 | else:
194 | state_dict = checkpoint
195 | model.load_state_dict(state_dict)
196 | model.to(device)
197 | # trans_totensor = transforms.Compose([transforms.Resize(args.ref_img_size, interpolation=PIL.Image.BILINEAR),
198 | # transforms.CenterCrop(image_size),
199 | # transforms.ToTensor(),
200 | # transforms.Normalize(mean=0.5, std=0.5)])
201 | trans_totensor = transforms.Compose([
202 | transforms.ToTensor(),
203 | transforms.Normalize(mean=0.5, std=0.5)])
204 |
205 | else:
206 | print("task should be one of the following: normal, depth")
207 | sys.exit()
208 |
209 | #-----------------------------------------------
210 | #--- Dataset Specific processing
211 | #-----------------------------------------------
212 | if isinstance(args.seq_list, list):
213 | select_scene_ids = args.seq_list
214 | else:
215 | select_scene_ids = [args.seq_list]
216 |
217 | for scene_i, scene_id in enumerate(tqdm(select_scene_ids, f'Extracting {args.task} ...')):
218 | for i in range(5):
219 | image_dir = os.path.join(args.data_root, scene_id, f"{args.rgb_dirname}_{i}")
220 | obs_id_list = sorted(os.listdir(image_dir))
221 |
222 | if args.task == 'depth':
223 | output_dir = image_dir.replace(args.rgb_dirname, args.depth_dirname)
224 | elif args.task == 'normal':
225 | output_dir = image_dir.replace(args.rgb_dirname, args.normals_dirname)
226 | else:
227 | raise RuntimeError(f"Invalid task={args.task}")
228 |
229 | os.makedirs(output_dir, exist_ok=True)
230 |
231 | for obs_i, obs_id in enumerate(tqdm(obs_id_list, f'scene [{scene_i}/{len(select_scene_ids)}]')):
232 | img_path = os.path.join(image_dir, obs_id)
233 | fbase = os.path.splitext(os.path.basename(img_path))[0]
234 |
235 | if args.task == 'depth':
236 | output_base = os.path.join(output_dir, fbase)
237 | elif args.task == 'normal':
238 | output_base = os.path.join(output_dir, fbase)
239 | else:
240 | raise RuntimeError(f"Invalid task={args.task}")
241 |
242 | if args.ignore_existing and list_contains(os.listdir(output_dir), fbase):
243 | continue
244 |
245 | #---- Inference and save outputs
246 | extract_cues(img_path, output_base, args.ref_img_size, verbose=args.verbose)
--------------------------------------------------------------------------------
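A worked example of the resize rule inside extract_cues above, assuming a 1920x1280 landscape frame and the default --ref_img_size of 512: the short side is set to 512 and the long side is scaled proportionally, then rounded down to a multiple of 32 as the in-code comment requires.

H, W, ref_img_size = 1280, 1920, 512   # illustrative input resolution
H_ = ref_img_size                      # short side (H < W)
W_ = int((W / (H / H_)) // 32) * 32    # 1920 / 2.5 = 768, already a multiple of 32
assert (H_, W_) == (512, 768)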
/scripts/waymo_download.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 | from concurrent.futures import ThreadPoolExecutor
5 | from typing import List
6 |
7 |
8 | def download_file(filename, target_dir, source):
9 | result = subprocess.run(
10 | [
11 | "gsutil",
12 | "cp",
13 | "-n",
14 | f"{source}/{filename}.tfrecord",
15 | target_dir,
16 | ],
17 | capture_output=True, # To capture stderr and stdout for detailed error information
18 | text=True,
19 | )
20 |
21 | # Check the return code of the gsutil command
22 | if result.returncode != 0:
23 | raise Exception(
24 | result.stderr
25 | ) # Raise an exception with the error message from the gsutil command
26 |
27 |
28 | def download_files(
29 | file_names: List[str],
30 | target_dir: str,
31 | source: str = "gs://waymo_open_dataset_scene_flow/train",
32 | ) -> None:
33 | """
34 | Downloads a list of files from a given source to a target directory using multiple threads.
35 |
36 | Args:
37 | file_names (List[str]): A list of file names to download.
38 | target_dir (str): The target directory to save the downloaded files.
39 | source (str, optional): The source directory to download the files from. Defaults to "gs://waymo_open_dataset_scene_flow/train".
40 | """
41 | # Get the total number of file_names
42 | total_files = len(file_names)
43 |
44 | # Use ThreadPoolExecutor to manage concurrent downloads
45 | with ThreadPoolExecutor(max_workers=10) as executor:
46 | futures = [
47 | executor.submit(download_file, filename, target_dir, source)
48 | for filename in file_names
49 | ]
50 |
51 | for counter, future in enumerate(futures, start=1):
52 | # Wait for the download to complete and handle any exceptions
53 | try:
54 | # inspects the result of the future and raises an exception if one occurred during execution
55 | future.result()
56 | print(f"[{counter}/{total_files}] Downloaded successfully!")
57 | except Exception as e:
58 | print(f"[{counter}/{total_files}] Failed to download. Error: {e}")
59 |
60 |
61 | if __name__ == "__main__":
62 | print("note: `gcloud auth login` is required before running this script")
63 | print("Downloading Waymo dataset from Google Cloud Storage...")
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument(
66 | "--target_dir",
67 | type=str,
68 | default="dataset/waymo/raw",
69 | help="Path to the target directory",
70 | )
71 | parser.add_argument(
72 | "--scene_ids", type=int, nargs="+", help="scene ids to download"
73 | )
74 | parser.add_argument(
75 | "--split_file", type=str, default=None, help="split file in data/waymo_splits"
76 | )
77 | args = parser.parse_args()
78 | os.makedirs(args.target_dir, exist_ok=True)
79 | total_list = open("data/waymo_train_list.txt", "r").readlines()
80 | if args.split_file is None:
81 | file_names = [total_list[i].strip() for i in args.scene_ids]
82 | else:
83 | # parse the split file
84 | split_file = open(args.split_file, "r").readlines()[1:]
85 | scene_ids = [int(line.strip().split(",")[0]) for line in split_file]
86 | file_names = [total_list[i].strip() for i in scene_ids]
87 | download_files(file_names, args.target_dir)
88 |
89 |
90 | # python scripts/waymo_download.py --scene_ids 23 114 327 621 703 172 552 788
--------------------------------------------------------------------------------
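When --split_file is given, the script above skips the first line as a header and reads the scene index from the first comma-separated field of each remaining line. A minimal sketch of that shape; the header and segment names are placeholders, not the actual contents of data/waymo_splits:

lines = [
    "scene_idx,seq_name",           # header line, skipped by the parser
    "23,segment-placeholder-a",
    "114,segment-placeholder-b",
]
scene_ids = [int(line.strip().split(",")[0]) for line in lines[1:]]
assert scene_ids == [23, 114]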
/separate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import glob
12 | import os
13 | import torch
14 | from gaussian_renderer import get_renderer
15 | from scene import Scene, GaussianModel, EnvLight
16 | from utils.general_utils import seed_everything
17 | from tqdm import tqdm
18 | from argparse import ArgumentParser
19 | from torchvision.utils import save_image
20 | from omegaconf import OmegaConf
21 | from pprint import pformat
22 | EPS = 1e-5
23 |
24 | @torch.no_grad()
25 | def separation(scene : Scene, renderFunc, renderArgs, env_map=None):
26 | scale = scene.resolution_scales[0]
27 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras(scale=scale)},
28 | {'name': 'train', 'cameras': scene.getTrainCameras()})
29 |
30 |     # we suppose the high-altitude area is static;
31 |     # the z axis points downward here, so this corresponds to gaussians.get_xyz[:, 2] < 0
32 | high_mask = gaussians.get_xyz[:, 2] < 0
33 | # import pdb;pdb.set_trace()
34 | mask = (gaussians.get_scaling_t[:, 0] > args.separate_scaling_t) | high_mask
35 | for config in validation_configs:
36 | if config['cameras'] and len(config['cameras']) > 0:
37 | outdir = os.path.join(args.model_path, "separation", config['name'])
38 | os.makedirs(outdir,exist_ok=True)
39 | for idx, viewpoint in enumerate(tqdm(config['cameras'])):
40 | render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map)
41 | render_pkg_static = renderFunc(viewpoint, scene.gaussians, *renderArgs, env_map=env_map, mask=mask)
42 |
43 | image = torch.clamp(render_pkg["render"], 0.0, 1.0)
44 | image_static = torch.clamp(render_pkg_static["render"], 0.0, 1.0)
45 |
46 | save_image(image, os.path.join(outdir, f"{viewpoint.colmap_id:03d}.png"))
47 | save_image(image_static, os.path.join(outdir, f"{viewpoint.colmap_id:03d}_static.png"))
48 |
49 |
50 | if __name__ == "__main__":
51 | # Set up command line argument parser
52 |     parser = ArgumentParser(description="Separation script parameters")
53 | parser.add_argument("--config_path", type=str, required=True)
54 | params, _ = parser.parse_known_args()
55 |
56 | args = OmegaConf.load(params.config_path)
57 | args.resolution_scales = args.resolution_scales[:1]
58 | print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True))))
59 |
60 |
61 | seed_everything(args.seed)
62 |
63 | sep_path = os.path.join(args.model_path, 'separation')
64 | os.makedirs(sep_path, exist_ok=True)
65 |
66 | gaussians = GaussianModel(args)
67 | scene = Scene(args, gaussians, shuffle=False)
68 |
69 | if args.env_map_res > 0:
70 | env_map = EnvLight(resolution=args.env_map_res).cuda()
71 | env_map.training_setup(args)
72 | else:
73 | env_map = None
74 |
75 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
76 | assert len(checkpoints) > 0, "No checkpoints found."
77 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
78 | print(f"Loading checkpoint {checkpoint}")
79 | (model_params, first_iter) = torch.load(checkpoint)
80 | gaussians.restore(model_params, args)
81 |
82 | if env_map is not None:
83 | env_checkpoint = os.path.join(os.path.dirname(checkpoint),
84 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt"))
85 | (light_params, _) = torch.load(env_checkpoint)
86 | env_map.restore(light_params)
87 |
88 | bg_color = [1, 1, 1] if args.white_background else [0, 0, 0]
89 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
90 | render_func, _ = get_renderer(args.render_type)
91 | separation(scene, render_func, (args, background), env_map=env_map)
92 |
93 | print("Rendering statics and dynamics complete.")
94 |
--------------------------------------------------------------------------------
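A small check of the checkpoint-selection rule used in separate.py above: the suffix after "chkpnt" is compared as an integer, so the latest iteration is chosen even when it sorts lower lexicographically. The filenames are illustrative.

checkpoints = ["work_dir/chkpnt7000.pth", "work_dir/chkpnt30000.pth"]
latest = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
assert latest.endswith("chkpnt30000.pth")   # numeric order: 30000 > 7000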
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import cv2
14 | from scene.cameras import Camera
15 | import numpy as np
16 | from scene.scene_utils import CameraInfo
17 | from tqdm import tqdm
18 | from .graphics_utils import fov2focal
19 |
20 |
21 | def loadCam(args, id, cam_info: CameraInfo, resolution_scale):
22 | orig_w, orig_h = cam_info.width, cam_info.height # cam_info.image.size
23 |
24 | if args.resolution in [1, 2, 3, 4, 8, 16, 32]:
25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round(
26 | orig_h / (resolution_scale * args.resolution)
27 | )
28 | scale = resolution_scale * args.resolution
29 | else: # should be a type that converts to float
30 | if args.resolution == -1:
31 | global_down = 1
32 | else:
33 | global_down = orig_w / args.resolution
34 |
35 | scale = float(global_down) * float(resolution_scale)
36 | resolution = (int(orig_w / scale), int(orig_h / scale))
37 |
38 | if cam_info.cx:
39 | cx = cam_info.cx / scale
40 | cy = cam_info.cy / scale
41 | fy = cam_info.fy / scale
42 | fx = cam_info.fx / scale
43 | else:
44 | cx = None
45 | cy = None
46 | fy = None
47 | fx = None
48 |
49 | if cam_info.image.shape[:2] != resolution[::-1]:
50 | image_rgb = cv2.resize(cam_info.image, resolution)
51 | else:
52 | image_rgb = cam_info.image
53 | image_rgb = torch.from_numpy(image_rgb).float().permute(2, 0, 1)
54 | gt_image = image_rgb[:3, ...]
55 |
56 | if cam_info.sky_mask is not None:
57 | if cam_info.sky_mask.shape[:2] != resolution[::-1]:
58 | sky_mask = cv2.resize(cam_info.sky_mask, resolution)
59 | else:
60 | sky_mask = cam_info.sky_mask
61 | if len(sky_mask.shape) == 2:
62 | sky_mask = sky_mask[..., None]
63 | sky_mask = torch.from_numpy(sky_mask).float().permute(2, 0, 1)
64 | else:
65 | sky_mask = None
66 |
67 | if cam_info.pointcloud_camera is not None:
68 | h, w = gt_image.shape[1:]
69 | K = np.eye(3)
70 | if cam_info.cx:
71 | K[0, 0] = fx
72 | K[1, 1] = fy
73 | K[0, 2] = cx
74 | K[1, 2] = cy
75 | else:
76 | K[0, 0] = fov2focal(cam_info.FovX, w)
77 | K[1, 1] = fov2focal(cam_info.FovY, h)
78 | K[0, 2] = cam_info.width / 2
79 | K[1, 2] = cam_info.height / 2
80 | pts_depth = np.zeros([1, h, w])
81 | point_camera = cam_info.pointcloud_camera
82 | uvz = point_camera[point_camera[:, 2] > 0]
83 | uvz = uvz @ K.T
84 | uvz[:, :2] /= uvz[:, 2:]
85 | uvz = uvz[uvz[:, 1] >= 0]
86 | uvz = uvz[uvz[:, 1] < h]
87 | uvz = uvz[uvz[:, 0] >= 0]
88 | uvz = uvz[uvz[:, 0] < w]
89 | uv = uvz[:, :2]
90 | uv = uv.astype(int)
91 | # TODO: may need to consider overlap
92 | pts_depth[0, uv[:, 1], uv[:, 0]] = uvz[:, 2]
93 | pts_depth = torch.from_numpy(pts_depth).float()
94 |
95 | elif cam_info.depth_map is not None:
96 | if cam_info.depth_map.shape[:2] != resolution[::-1]:
97 | depth_map = cv2.resize(cam_info.depth_map, resolution)
98 | else:
99 | depth_map = cam_info.depth_map
100 | depth_map = torch.from_numpy(depth_map).float()
101 | pts_depth = depth_map.unsqueeze(0)
102 |
103 | else:
104 | pts_depth = None
105 |
106 | if cam_info.dynamic_mask is not None:
107 | if cam_info.dynamic_mask.shape[:2] != resolution[::-1]:
108 | dynamic_mask = cv2.resize(cam_info.dynamic_mask, resolution)
109 | else:
110 | dynamic_mask = cam_info.dynamic_mask
111 | if len(dynamic_mask.shape) == 2:
112 | dynamic_mask = dynamic_mask[..., None]
113 | dynamic_mask = torch.from_numpy(dynamic_mask).float().permute(2, 0, 1)
114 | else:
115 | dynamic_mask = None
116 |
117 | if cam_info.normal_map is not None:
118 | if cam_info.normal_map.shape[:2] != resolution[::-1]:
119 | normal_map = cv2.resize(cam_info.normal_map, resolution)
120 | else:
121 | normal_map = cam_info.normal_map
122 | normal_map = torch.from_numpy(normal_map).float().permute(2, 0, 1) # H, W, 3-> 3, H, W
123 | else:
124 | normal_map = None
125 |
126 | return Camera(
127 | colmap_id=cam_info.uid,
128 | uid=id,
129 | R=cam_info.R,
130 | T=cam_info.T,
131 | FoVx=cam_info.FovX,
132 | FoVy=cam_info.FovY,
133 | cx=cx,
134 | cy=cy,
135 | fx=fx,
136 | fy=fy,
137 | image=gt_image,
138 | image_name=cam_info.image_name,
139 | data_device=args.data_device,
140 | timestamp=cam_info.timestamp,
141 | resolution=resolution,
142 | image_path=cam_info.image_path,
143 | pts_depth=pts_depth,
144 | sky_mask=sky_mask,
145 | dynamic_mask = dynamic_mask,
146 | normal_map = normal_map,
147 | )
148 |
149 |
150 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
151 | camera_list = []
152 |
153 | for id, c in enumerate(tqdm(cam_infos, bar_format='{l_bar}{bar:50}{r_bar}')):
154 | camera_list.append(loadCam(args, id, c, resolution_scale))
155 |
156 | return camera_list
157 |
158 |
159 | def camera_to_JSON(id, camera: Camera):
160 | Rt = np.zeros((4, 4))
161 | Rt[:3, :3] = camera.R.transpose()
162 | Rt[:3, 3] = camera.T
163 | Rt[3, 3] = 1.0
164 |
165 | W2C = np.linalg.inv(Rt)
166 | pos = W2C[:3, 3]
167 | rot = W2C[:3, :3]
168 | serializable_array_2d = [x.tolist() for x in rot]
169 |
170 | if camera.cx is None:
171 | camera_entry = {
172 | "id": id,
173 | "img_name": camera.image_name,
174 | "width": camera.width,
175 | "height": camera.height,
176 | "position": pos.tolist(),
177 | "rotation": serializable_array_2d,
178 | "FoVx": camera.FovX,
179 | "FoVy": camera.FovY,
180 | }
181 | else:
182 | camera_entry = {
183 | "id": id,
184 | "img_name": camera.image_name,
185 | "width": camera.width,
186 | "height": camera.height,
187 | "position": pos.tolist(),
188 | "rotation": serializable_array_2d,
189 | "fx": camera.fx,
190 | "fy": camera.fy,
191 | "cx": camera.cx,
192 | "cy": camera.cy,
193 | }
194 | return camera_entry
195 |
196 |
197 | def calculate_mean_and_std(mean_per_image, std_per_image):
198 | # Calculate mean RGB across dataset
199 | mean_dataset = np.mean(mean_per_image, axis=0)
200 |
201 | # Calculate variance of each image
202 | variances = std_per_image**2
203 |
204 | # Calculate overall variance across the dataset
205 | overall_variance = np.mean(variances, axis=0) + np.mean((mean_per_image - mean_dataset)**2, axis=0) # (C,)
206 |
207 | # Calculate std RGB across dataset
208 | std_rgb_dataset = np.sqrt(overall_variance)
209 |
210 | return mean_dataset, std_rgb_dataset
--------------------------------------------------------------------------------
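calculate_mean_and_std above applies the law of total variance, Var_total = E[Var_i] + Var[E_i], which is exact when every image contributes the same number of pixels. A quick numeric check on synthetic data, assuming the function is imported as below:

import numpy as np
from utils.camera_utils import calculate_mean_and_std

pixels = np.random.default_rng(0).normal(size=(4, 1000, 3))  # 4 synthetic "images", RGB
mean_per_image = pixels.mean(axis=1)                         # (4, 3)
std_per_image = pixels.std(axis=1)                           # (4, 3)
mean_ds, std_ds = calculate_mean_and_std(mean_per_image, std_per_image)
assert np.allclose(mean_ds, pixels.reshape(-1, 3).mean(axis=0))
assert np.allclose(std_ds, pixels.reshape(-1, 3).std(axis=0))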
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | import numpy as np
15 | import random
16 | from matplotlib import cm
17 |
18 | def GridSample3D(in_pc,in_shs, voxel_size=0.013):
19 | in_pc_ = in_pc[:,:3].copy()
20 | quantized_pc = np.around(in_pc_ / voxel_size)
21 | quantized_pc -= np.min(quantized_pc, axis=0)
22 | pc_boundary = np.max(quantized_pc, axis=0) - np.min(quantized_pc, axis=0)
23 |
24 | voxel_index = quantized_pc[:,0] * pc_boundary[1] * pc_boundary[2] + quantized_pc[:,1] * pc_boundary[2] + quantized_pc[:,2]
25 |
26 | split_point, index = get_split_point(voxel_index)
27 |
28 | in_points = in_pc[index,:]
29 | out_points = in_points[split_point[:-1],:]
30 |
31 | in_colors = in_shs[index]
32 | out_colors = in_colors[split_point[:-1]]
33 |
34 | return out_points,out_colors
35 |
36 | def get_split_point(labels):
37 | index = np.argsort(labels)
38 | label = labels[index]
39 | label_shift = label.copy()
40 |
41 | label_shift[1:] = label[:-1]
42 | remain = label - label_shift
43 | step_index = np.where(remain > 0)[0].tolist()
44 | step_index.insert(0,0)
45 | step_index.append(labels.shape[0])
46 | return step_index,index
47 |
48 | def sample_on_aabb_surface(aabb_center, aabb_size, n_pts=1000, above_half=False):
49 | """
50 |     0: left face of the cube (negative x direction)
51 |     1: right face of the cube (positive x direction)
52 |     2: bottom face of the cube (negative y direction)
53 |     3: top face of the cube (positive y direction)
54 |     4: back face of the cube (negative z direction)
55 |     5: front face of the cube (positive z direction)
56 | """
57 | # Choose a face randomly
58 | faces = np.random.randint(0, 6, size=n_pts)
59 |
60 | # Generate two random numbers
61 | r_ = np.random.random((n_pts, 2))
62 |
63 | # Create an array to store the points
64 | points = np.zeros((n_pts, 3))
65 |
66 | # Define the offsets for each face
67 | offsets = np.array([
68 | [-aabb_size[0]/2, 0, 0],
69 | [aabb_size[0]/2, 0, 0],
70 | [0, -aabb_size[1]/2, 0],
71 | [0, aabb_size[1]/2, 0],
72 | [0, 0, -aabb_size[2]/2],
73 | [0, 0, aabb_size[2]/2]
74 | ])
75 |
76 | # Define the scales for each face
77 | scales = np.array([
78 | [aabb_size[1], aabb_size[2]],
79 | [aabb_size[1], aabb_size[2]],
80 | [aabb_size[0], aabb_size[2]],
81 | [aabb_size[0], aabb_size[2]],
82 | [aabb_size[0], aabb_size[1]],
83 | [aabb_size[0], aabb_size[1]]
84 | ])
85 |
86 | # Define the positions of the zero column for each face
87 | zero_column_positions = [0, 0, 1, 1, 2, 2]
88 | # Define the indices of the aabb_size components for each face
89 | aabb_size_indices = [[1, 2], [1, 2], [0, 2], [0, 2], [0, 1], [0, 1]]
90 | # Calculate the coordinates of the points for each face
91 | for i in range(6):
92 | mask = faces == i
93 | r_scaled = r_[mask] * scales[i]
94 | r_scaled = np.insert(r_scaled, zero_column_positions[i], 0, axis=1)
95 | aabb_size_adjusted = np.insert(aabb_size[aabb_size_indices[i]] / 2, zero_column_positions[i], 0)
96 | points[mask] = aabb_center + offsets[i] + r_scaled - aabb_size_adjusted
97 | #visualize_points(points[mask], aabb_center, aabb_size)
98 | #visualize_points(points, aabb_center, aabb_size)
99 |
100 |     # Keep only the points in the upper half
101 | if above_half:
102 | points = points[points[:, -1] > aabb_center[-1]]
103 | return points
104 |
105 | def get_OccGrid(pts, aabb, occ_voxel_size):
106 |     # Compute the grid size
107 | grid_size = np.ceil((aabb[1] - aabb[0]) / occ_voxel_size).astype(int)
108 | assert pts.min() >= aabb[0].min() and pts.max() <= aabb[1].max(), "Points are outside the AABB"
109 |
110 |     # Create an empty voxel grid
111 | voxel_grid = np.zeros(grid_size, dtype=np.uint8)
112 |
113 |     # Convert the point cloud to grid coordinates
114 | grid_pts = ((pts - aabb[0]) / occ_voxel_size).astype(int)
115 |
116 |     # Mark occupied voxels as 1
117 | voxel_grid[grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2]] = 1
118 |
119 | # check
120 | #voxel_coords = np.floor((pts - aabb[0]) / occ_voxel_size).astype(int)
121 | #occ = voxel_grid[voxel_coords[:, 0], voxel_coords[:, 1], voxel_coords[:, 2]]
122 |
123 | return voxel_grid
124 |
125 | def visualize_depth(depth, near=0.2, far=13, linear=False):
126 | depth = depth[0].clone().detach().cpu().numpy()
127 | colormap = cm.get_cmap('turbo')
128 | curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps)
129 | if linear:
130 | curve_fn = lambda x: -x
131 | eps = np.finfo(np.float32).eps
132 | near = near if near else depth.min()
133 | far = far if far else depth.max()
134 | near -= eps
135 | far += eps
136 | near, far, depth = [curve_fn(x) for x in [near, far, depth]]
137 | depth = np.nan_to_num(
138 | np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1))
139 | vis = colormap(depth)[:, :, :3]
140 | out_depth = np.clip(np.nan_to_num(vis), 0., 1.) * 255
141 | out_depth = torch.from_numpy(out_depth).permute(2, 0, 1).float().cuda() / 255
142 | return out_depth
143 |
144 |
145 | def inverse_sigmoid(x):
146 | return torch.log(x / (1 - x))
147 |
148 |
149 | def PILtoTorch(pil_image, resolution):
150 | resized_image_PIL = pil_image.resize(resolution)
151 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
152 | if len(resized_image.shape) == 3:
153 | return resized_image.permute(2, 0, 1)
154 | else:
155 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
156 |
157 |
158 | def get_step_lr_func(lr_init, lr_final, start_step):
159 | def helper(step):
160 | if step < start_step:
161 | return lr_init
162 | else:
163 | return lr_final
164 | return helper
165 |
166 | def get_expon_lr_func(
167 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
168 | ):
169 | """
170 | Copied from Plenoxels
171 |
172 | Continuous learning rate decay function. Adapted from JaxNeRF
173 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
174 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
175 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
176 | function of lr_delay_mult, such that the initial learning rate is
177 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
178 | to the normal learning rate when steps>lr_delay_steps.
179 |     :param lr_init, lr_final: initial and final learning rates
180 | :param max_steps: int, the number of steps during optimization.
181 | :return HoF which takes step as input
182 | """
183 |
184 | def helper(step):
185 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
186 | # Disable this parameter
187 | return 0.0
188 | if lr_delay_steps > 0:
189 | # A kind of reverse cosine decay.
190 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
191 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
192 | )
193 | else:
194 | delay_rate = 1.0
195 | t = np.clip(step / max_steps, 0, 1)
196 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
197 | return delay_rate * log_lerp
198 |
199 | return helper
200 |
201 |
202 | def strip_lowerdiag(L):
203 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
204 |
205 | uncertainty[:, 0] = L[:, 0, 0]
206 | uncertainty[:, 1] = L[:, 0, 1]
207 | uncertainty[:, 2] = L[:, 0, 2]
208 | uncertainty[:, 3] = L[:, 1, 1]
209 | uncertainty[:, 4] = L[:, 1, 2]
210 | uncertainty[:, 5] = L[:, 2, 2]
211 | return uncertainty
212 |
213 |
214 | def strip_symmetric(sym):
215 | return strip_lowerdiag(sym)
216 |
217 |
218 | def build_rotation(r):
219 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
220 |
221 | q = r / norm[:, None]
222 |
223 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
224 |
225 | r = q[:, 0]
226 | x = q[:, 1]
227 | y = q[:, 2]
228 | z = q[:, 3]
229 |
230 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
231 | R[:, 0, 1] = 2 * (x * y - r * z)
232 | R[:, 0, 2] = 2 * (x * z + r * y)
233 | R[:, 1, 0] = 2 * (x * y + r * z)
234 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
235 | R[:, 1, 2] = 2 * (y * z - r * x)
236 | R[:, 2, 0] = 2 * (x * z - r * y)
237 | R[:, 2, 1] = 2 * (y * z + r * x)
238 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
239 | return R
240 |
241 | def rotation_to_quaternion(R):
242 | r11, r12, r13 = R[:, 0, 0], R[:, 0, 1], R[:, 0, 2]
243 | r21, r22, r23 = R[:, 1, 0], R[:, 1, 1], R[:, 1, 2]
244 | r31, r32, r33 = R[:, 2, 0], R[:, 2, 1], R[:, 2, 2]
245 |
246 | qw = torch.sqrt((1 + r11 + r22 + r33).clamp_min(1e-7)) / 2
247 | qx = (r32 - r23) / (4 * qw)
248 | qy = (r13 - r31) / (4 * qw)
249 | qz = (r21 - r12) / (4 * qw)
250 |
251 | quaternion = torch.stack((qw, qx, qy, qz), dim=-1)
252 | quaternion = torch.nn.functional.normalize(quaternion, dim=-1)
253 | return quaternion
254 |
255 | def quaternion_to_rotation_matrix(q):
256 | w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
257 |
258 | r11 = 1 - 2 * y * y - 2 * z * z
259 | r12 = 2 * x * y - 2 * w * z
260 | r13 = 2 * x * z + 2 * w * y
261 |
262 | r21 = 2 * x * y + 2 * w * z
263 | r22 = 1 - 2 * x * x - 2 * z * z
264 | r23 = 2 * y * z - 2 * w * x
265 |
266 | r31 = 2 * x * z - 2 * w * y
267 | r32 = 2 * y * z + 2 * w * x
268 | r33 = 1 - 2 * x * x - 2 * y * y
269 |
270 | rotation_matrix = torch.stack((torch.stack((r11, r12, r13), dim=1),
271 | torch.stack((r21, r22, r23), dim=1),
272 | torch.stack((r31, r32, r33), dim=1)), dim=1)
273 | return rotation_matrix
274 |
275 | def quaternion_multiply(q1, q2):
276 | w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
277 | w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
278 |
279 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
280 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
281 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
282 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
283 |
284 | result_quaternion = torch.stack((w, x, y, z), dim=1)
285 | return result_quaternion
286 |
287 | def build_scaling_rotation(s, r):
288 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
289 | R = build_rotation(r)
290 |
291 | L[:, 0, 0] = s[:, 0]
292 | L[:, 1, 1] = s[:, 1]
293 | L[:, 2, 2] = s[:, 2]
294 |
295 | L = R @ L
296 | return L
297 |
298 |
299 | def seed_everything(seed):
300 | random.seed(seed)
301 | os.environ['PYTHONHASHSEED'] = str(seed)
302 | np.random.seed(seed)
303 | torch.manual_seed(seed)
304 | torch.cuda.manual_seed(seed)
305 | torch.cuda.manual_seed_all(seed)
306 |
307 | import logging
308 |
309 | import sys
310 |
311 | def init_logging(filename=None, debug=False):
312 | logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
313 | formatter = logging.Formatter('[%(asctime)s][%(filename)s][%(levelname)s] - %(message)s')
314 | stream_handler = logging.StreamHandler(sys.stdout)
315 |
316 | stream_handler.setFormatter(formatter)
317 | logging.root.addHandler(stream_handler)
318 |
319 | if filename is not None:
320 | file_handler = logging.FileHandler(filename)
321 | file_handler.setFormatter(formatter)
322 | logging.root.addHandler(file_handler)
323 |
--------------------------------------------------------------------------------
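A small sanity check of get_expon_lr_func above: with no delay steps, the schedule is log-linear, so at the halfway point the learning rate equals the geometric mean of lr_init and lr_final. The numbers are illustrative.

import numpy as np
from utils.general_utils import get_expon_lr_func

lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=1000)
assert np.isclose(lr_fn(0), 1e-2)
assert np.isclose(lr_fn(1000), 1e-4)
assert np.isclose(lr_fn(500), np.sqrt(1e-2 * 1e-4))  # geometric mean = 1e-3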
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 | def ndc_2_cam(ndc_xyz, intrinsic, W, H):
18 | inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device)
19 | cam_z = ndc_xyz[..., 2:3]
20 | cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z
21 | cam_xyz = torch.cat([cam_xy, cam_z], dim=-1)
22 | cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t())
23 | return cam_xyz
24 |
25 | def depth2point_cam(sampled_depth, ref_intrinsic):
26 | B, N, C, H, W = sampled_depth.shape
27 | valid_z = sampled_depth
28 | valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device) / (W - 1)
29 | valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device) / (H - 1)
30 | valid_x, valid_y = torch.meshgrid(valid_x, valid_y, indexing='xy')
31 | # B,N,H,W
32 | valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1)
33 | valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1)
34 | ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view(B, N, C, H, W, 3) # 1, 1, 5, 512, 640, 3
35 | cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H) # 1, 1, 5, 512, 640, 3
36 | return ndc_xyz, cam_xyz
37 |
38 | def depth2point_world(depth_image, intrinsic_matrix, extrinsic_matrix):
39 | # depth_image: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)
40 | _, xyz_cam = depth2point_cam(depth_image[None,None,None,...], intrinsic_matrix[None,...])
41 | xyz_cam = xyz_cam.reshape(-1,3)
42 | # xyz_world = torch.cat([xyz_cam, torch.ones_like(xyz_cam[...,0:1])], axis=-1) @ torch.inverse(extrinsic_matrix).transpose(0,1)
43 | # xyz_world = xyz_world[...,:3]
44 |
45 | return xyz_cam
46 |
47 | def depth_pcd2normal(xyz, offset=None, gt_image=None):
48 | hd, wd, _ = xyz.shape
49 | if offset is not None:
50 | ix, iy = torch.meshgrid(
51 | torch.arange(wd), torch.arange(hd), indexing='xy')
52 | xy = (torch.stack((ix, iy), dim=-1)[1:-1,1:-1]).to(xyz.device)
53 | p_offset = torch.tensor([[0,1],[0,-1],[1,0],[-1,0]]).float().to(xyz.device)
54 | new_offset = p_offset[None,None] + offset.reshape(hd, wd, 4, 2)[1:-1,1:-1]
55 | xys = xy[:,:,None] + new_offset
56 | xys[..., 0] = 2 * xys[..., 0] / (wd - 1) - 1.0
57 | xys[..., 1] = 2 * xys[..., 1] / (hd - 1) - 1.0
58 | sampled_xyzs = torch.nn.functional.grid_sample(xyz.permute(2,0,1)[None], xys.reshape(1, -1, 1, 2))
59 | sampled_xyzs = sampled_xyzs.permute(0,2,3,1).reshape(hd-2,wd-2,4,3)
60 | bottom_point = sampled_xyzs[:,:,0]
61 | top_point = sampled_xyzs[:,:,1]
62 | right_point = sampled_xyzs[:,:,2]
63 | left_point = sampled_xyzs[:,:,3]
64 | else:
65 | bottom_point = xyz[..., 2:hd, 1:wd-1, :]
66 | top_point = xyz[..., 0:hd-2, 1:wd-1, :]
67 | right_point = xyz[..., 1:hd-1, 2:wd, :]
68 | left_point = xyz[..., 1:hd-1, 0:wd-2, :]
69 | left_to_right = right_point - left_point
70 | bottom_to_top = top_point - bottom_point
71 | xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1)
72 | xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1)
73 | xyz_normal = torch.nn.functional.pad(xyz_normal.permute(2,0,1), (1,1,1,1), mode='constant').permute(1,2,0)
74 | return xyz_normal
75 |
76 | def normal_from_depth_image(depth, intrinsic_matrix, extrinsic_matrix, offset=None, gt_image=None):
77 | # depth: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)
78 | # xyz_normal: (H, W, 3)
79 | xyz_world = depth2point_world(depth, intrinsic_matrix, extrinsic_matrix) # (HxW, 3)
80 | xyz_world = xyz_world.reshape(*depth.shape, 3)
81 | xyz_normal = depth_pcd2normal(xyz_world, offset, gt_image)
82 |
83 | return xyz_normal
84 |
85 | def render_normal(viewpoint_camera, depth, offset=None, normal=None, scale=1):
86 | # depth: (H, W), bg_color: (3), alpha: (H, W)
87 | # normal_ref: (3, H, W)
88 | intrinsic_matrix, extrinsic_matrix = viewpoint_camera.get_calib_matrix_nerf(
89 | scale=scale
90 | )
91 | st = max(int(scale / 2) - 1, 0)
92 | if offset is not None:
93 | offset = offset[st::scale, st::scale]
94 | normal_ref = normal_from_depth_image(
95 | depth[st::scale, st::scale],
96 | intrinsic_matrix.to(depth.device),
97 | extrinsic_matrix.to(depth.device),
98 | offset,
99 | )
100 |
101 | normal_ref = normal_ref.permute(2, 0, 1)
102 | return normal_ref
103 |
104 | def normal_from_neareast(normal, offset):
105 | _, hd, wd = normal.shape
106 | left_top_point = normal[..., 0:hd-2, 0:wd-2]
107 | top_point = normal[..., 0:hd-2, 1:wd-1]
108 | right_top_point= normal[..., 0:hd-2, 2:wd]
109 | left_point = normal[..., 1:hd-1, 0:wd-2]
110 | right_point = normal[..., 1:hd-1, 2:wd]
111 | left_bottom_point = normal[..., 2:hd, 0:wd-2]
112 | bottom_point = normal[..., 2:hd, 1:wd-1]
113 | right_bottom_point = normal[..., 2:hd, 2:wd]
114 | normals = torch.stack((left_top_point,top_point,right_top_point,left_point,right_point,left_bottom_point,bottom_point,right_bottom_point),dim=0)
115 | new_normal = (normals * offset[:,None,1:-1,1:-1]).sum(0)
116 | new_normal = torch.nn.functional.normalize(new_normal, p=2, dim=0)
117 | new_normal = torch.nn.functional.pad(new_normal, (1,1,1,1), mode='constant').permute(1,2,0)
118 | return new_normal
119 |
120 | def patch_offsets(h_patch_size, device):
121 | offsets = torch.arange(-h_patch_size, h_patch_size + 1, device=device)
122 | return torch.stack(torch.meshgrid(offsets, offsets, indexing='xy')[::-1], dim=-1).view(1, -1, 2)
123 |
124 | def patch_warp(H, uv):
125 | B, P = uv.shape[:2]
126 | H = H.view(B, 3, 3)
127 | ones = torch.ones((B,P,1), device=uv.device)
128 | homo_uv = torch.cat((uv, ones), dim=-1)
129 |
130 | grid_tmp = torch.einsum("bik,bpk->bpi", H, homo_uv)
131 | grid_tmp = grid_tmp.reshape(B, P, 3)
132 | grid = grid_tmp[..., :2] / (grid_tmp[..., 2:] + 1e-10)
133 | return grid
134 |
135 | class BasicPointCloud(NamedTuple):
136 | points : np.array
137 | colors : np.array
138 | normals : np.array
139 | time : np.array = None
140 |
141 | def getWorld2View(R, t):
142 | Rt = np.zeros((4, 4))
143 | Rt[:3, :3] = R.transpose()
144 | Rt[:3, 3] = t
145 | Rt[3, 3] = 1.0
146 | return np.float32(Rt)
147 |
148 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
149 | Rt = np.zeros((4, 4))
150 | Rt[:3, :3] = R.transpose()
151 | Rt[:3, 3] = t
152 | Rt[3, 3] = 1.0
153 |
154 | C2W = np.linalg.inv(Rt)
155 | cam_center = C2W[:3, 3]
156 | cam_center = (cam_center + translate) * scale
157 | C2W[:3, 3] = cam_center
158 | Rt = np.linalg.inv(C2W)
159 | return np.float32(Rt)
160 |
161 | def getProjectionMatrix(znear, zfar, fovX, fovY):
162 | tanHalfFovY = math.tan((fovY / 2))
163 | tanHalfFovX = math.tan((fovX / 2))
164 |
165 | top = tanHalfFovY * znear
166 | bottom = -top
167 | right = tanHalfFovX * znear
168 | left = -right
169 |
170 | P = torch.zeros(4, 4)
171 |
172 | z_sign = 1.0
173 |
174 | P[0, 0] = 2.0 * znear / (right - left)
175 | P[1, 1] = 2.0 * znear / (top - bottom)
176 | P[0, 2] = (right + left) / (right - left)
177 | P[1, 2] = (top + bottom) / (top - bottom)
178 | P[3, 2] = z_sign
179 | P[2, 2] = z_sign * zfar / (zfar - znear)
180 | P[2, 3] = -(zfar * znear) / (zfar - znear)
181 | return P
182 |
183 | def getProjectionMatrixCenterShift(znear, zfar, cx, cy, fx, fy, w, h):
184 | top = cy / fy * znear
185 | bottom = -(h-cy) / fy * znear
186 |
187 | left = -(w-cx) / fx * znear
188 | right = cx / fx * znear
189 |
190 | P = torch.zeros(4, 4)
191 |
192 | z_sign = 1.0
193 |
194 | P[0, 0] = 2.0 * znear / (right - left)
195 | P[1, 1] = 2.0 * znear / (top - bottom)
196 | P[0, 2] = (right + left) / (right - left)
197 | P[1, 2] = (top + bottom) / (top - bottom)
198 | P[3, 2] = z_sign
199 | P[2, 2] = z_sign * zfar / (zfar - znear)
200 | P[2, 3] = -(zfar * znear) / (zfar - znear)
201 | return P
202 |
203 | def fov2focal(fov, pixels):
204 | return pixels / (2 * math.tan(fov / 2))
205 |
206 | def focal2fov(focal, pixels):
207 | return 2*math.atan(pixels/(2*focal))
--------------------------------------------------------------------------------
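
As a quick illustration of the FoV helpers and the pinhole projection above (the resolution, field of view, and near/far planes are arbitrary placeholders, not values from this repo):

import math

from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix

W = 1920
fovX, fovY = math.radians(60.0), math.radians(40.0)

fx = fov2focal(fovX, W)                      # horizontal focal length in pixels
assert abs(focal2fov(fx, W) - fovX) < 1e-6   # fov <-> focal round-trips exactly

# 4x4 perspective matrix consumed by the rasterizer.
P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovX, fovY=fovY)
print(P.shape)                               # torch.Size([4, 4])
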
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import numpy as np
14 | import PIL
15 | import torch.nn.functional as F
16 | import torch.nn as nn
17 | from typing import Mapping, Optional, Sequence
18 |
19 | # RGB colors used to visualize each semantic segmentation class.
20 | SEGMENTATION_COLOR_MAP = dict(
21 | TYPE_UNDEFINED=[0, 0, 0],
22 | TYPE_EGO_VEHICLE=[102, 102, 102],
23 | TYPE_CAR=[0, 0, 142],
24 | TYPE_TRUCK=[0, 0, 70],
25 | TYPE_BUS=[0, 60, 100],
26 | TYPE_OTHER_LARGE_VEHICLE=[61, 133, 198],
27 | TYPE_BICYCLE=[119, 11, 32],
28 | TYPE_MOTORCYCLE=[0, 0, 230],
29 | TYPE_TRAILER=[111, 168, 220],
30 | TYPE_PEDESTRIAN=[220, 20, 60],
31 | TYPE_CYCLIST=[255, 0, 0],
32 | TYPE_MOTORCYCLIST=[180, 0, 0],
33 | TYPE_BIRD=[127, 96, 0],
34 | TYPE_GROUND_ANIMAL=[91, 15, 0],
35 | TYPE_CONSTRUCTION_CONE_POLE=[230, 145, 56],
36 | TYPE_POLE=[153, 153, 153],
37 | TYPE_PEDESTRIAN_OBJECT=[234, 153, 153],
38 | TYPE_SIGN=[246, 178, 107],
39 | TYPE_TRAFFIC_LIGHT=[250, 170, 30],
40 | TYPE_BUILDING=[70, 70, 70],
41 | TYPE_ROAD=[128, 64, 128],
42 | TYPE_LANE_MARKER=[234, 209, 220],
43 | TYPE_ROAD_MARKER=[217, 210, 233],
44 | TYPE_SIDEWALK=[244, 35, 232],
45 | TYPE_VEGETATION=[107, 142, 35],
46 | TYPE_SKY=[70, 130, 180],
47 | TYPE_GROUND=[102, 102, 102],
48 | TYPE_DYNAMIC=[102, 102, 102],
49 | TYPE_STATIC=[102, 102, 102],
50 | )
51 |
52 | def _generate_color_map(
53 | color_map_dict: Optional[
54 |         Mapping[str, Sequence[int]]] = None
55 | ) -> np.ndarray:
56 | """Generates a mapping from segmentation classes (rows) to colors (cols).
57 |
58 | Args:
59 | color_map_dict: An optional dict mapping from semantic classes to colors. If
60 | None, the default colors in SEGMENTATION_COLOR_MAP will be used.
61 | Returns:
62 |     A np array of shape [num_classes, 3], where row i holds the color of the
63 |       i-th class in the dict's insertion order.
64 | """
65 | if color_map_dict is None:
66 | color_map_dict = SEGMENTATION_COLOR_MAP
67 | classes = list(color_map_dict.keys())
68 | colors = list(color_map_dict.values())
69 |   # One row per class, filled in the dict's insertion order.
70 |   color_map = np.zeros([len(classes), 3], dtype=np.uint8)
71 |   for idx, color in enumerate(colors):
72 |     color_map[idx] = color
73 | 
74 | 
75 | return color_map
76 |
77 | DEFAULT_COLOR_MAP = _generate_color_map()
78 |
79 | def get_panoptic_id(semantic_id, instance_id, semantic_interval=1000):
80 | if isinstance(semantic_id, np.ndarray):
81 | semantic_id = torch.from_numpy(semantic_id)
82 | instance_id = torch.from_numpy(instance_id)
83 | elif isinstance(semantic_id, PIL.Image.Image):
84 | semantic_id = torch.from_numpy(np.array(semantic_id))
85 | instance_id = torch.from_numpy(np.array(instance_id))
86 | elif isinstance(semantic_id, torch.Tensor):
87 | pass
88 | else:
89 | raise ValueError("semantic_id type is not supported!")
90 |
91 | return semantic_id * semantic_interval + instance_id
92 |
93 | def get_panoptic_encoding(semantic_id, instance_id):
94 |     # Encode semantic-id and instance-id into a panoptic one-hot encoding.
95 | panoptic_id = get_panoptic_id(semantic_id, instance_id)
96 | unique_panoptic_classes = panoptic_id.unique()
97 | num_panoptic_classes = unique_panoptic_classes.shape[0]
98 | # construct id map dict: panoptic_id -> num_class_idx
99 | id_to_idx_dict = {}
100 | for i in range(num_panoptic_classes):
101 | id_to_idx_dict[unique_panoptic_classes[i]] = i
102 |
103 |     # convert to one-hot encoding (unfinished: the encoding below is allocated but never filled or returned)
104 | panoptic_encoding = torch.zeros((num_panoptic_classes, ), dtype=torch.float32)
105 |
106 |
107 | def feat_encode(obj_id, id_to_idx, gt_label_embedding: nn.Embedding = None, output_both=False, only_idx=False):
108 |     """ Map obj_id to contiguous indices via id_to_idx and, optionally, look up its label embedding. """
109 | 
110 |     map_ids = torch.zeros_like(obj_id)
111 |     # Replace each gt object id with its global object index, then embed / index it.
112 | for key, value in id_to_idx.items():
113 | map_ids[obj_id == key] = value
114 |
115 | # query embedding
116 | if gt_label_embedding is not None:
117 | gt_label = gt_label_embedding(map_ids.flatten().long())
118 | else:
119 | gt_label = None
120 |
121 | if output_both:
122 | return map_ids, gt_label
123 | else:
124 | if only_idx:
125 | return map_ids.long().flatten()
126 | else:
127 | return gt_label
128 |
129 |
130 |
131 |
132 |
133 | def mse(img1, img2):
134 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
135 |
136 | def psnr(img1, img2):
137 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
138 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
139 |
140 | def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
141 | # features: (N, C)
142 | # m: a hyperparam controlling how many std dev outside for outliers
143 | assert len(features.shape) == 2, "features should be (N, C)"
144 | reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
145 | colors = features @ reduction_mat
146 | if remove_first_component:
147 | colors_min = colors.min(dim=0).values
148 | colors_max = colors.max(dim=0).values
149 | tmp_colors = (colors - colors_min) / (colors_max - colors_min)
150 | fg_mask = tmp_colors[..., 0] < 0.2
151 | reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
152 | colors = features @ reduction_mat
153 | else:
154 | fg_mask = torch.ones_like(colors[:, 0]).bool()
155 | d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
156 | mdev = torch.median(d, dim=0).values
157 | s = d / mdev
158 | rins = colors[fg_mask][s[:, 0] < m, 0]
159 | gins = colors[fg_mask][s[:, 1] < m, 1]
160 | bins = colors[fg_mask][s[:, 2] < m, 2]
161 |
162 | rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
163 | rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
164 | return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
165 |
--------------------------------------------------------------------------------
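
A small illustrative check of the metric and panoptic-id helpers above; the shapes, class ids, and noise level are arbitrary and not taken from the repo's data.

import numpy as np
import torch

from utils.image_utils import DEFAULT_COLOR_MAP, get_panoptic_id, psnr

img_gt = torch.rand(1, 3, 64, 64)
img_noisy = (img_gt + 0.01 * torch.randn_like(img_gt)).clamp(0, 1)
print(psnr(img_gt, img_noisy))               # per-image PSNR in dB, shape [1, 1]

semantic = np.array([[2, 2], [20, 0]], dtype=np.int64)   # class indices follow SEGMENTATION_COLOR_MAP order
instance = np.array([[1, 2], [0, 0]], dtype=np.int64)
panoptic = get_panoptic_id(semantic, instance)           # semantic_id * 1000 + instance_id, as a torch tensor
colors = DEFAULT_COLOR_MAP[semantic]                     # (2, 2, 3) uint8 color visualization
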
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 | import torch.nn as nn
17 | from kornia.filters import laplacian, spatial_gradient
18 | import numpy as np
19 | def psnr(img1, img2):
20 | mse = F.mse_loss(img1, img2)
21 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 | def create_window(window_size, channel):
28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 | return window
32 |
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 | channel = img1.size(-3)
35 | window = create_window(window_size, channel)
36 |
37 | if img1.is_cuda:
38 | window = window.cuda(img1.get_device())
39 | window = window.type_as(img1)
40 |
41 | return _ssim(img1, img2, window, window_size, channel, size_average)
42 |
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 |
55 | C1 = 0.01 ** 2
56 | C2 = 0.03 ** 2
57 |
58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 |
60 | if size_average:
61 | return ssim_map.mean()
62 | else:
63 | return ssim_map
64 |
65 | def tv_loss(depth):
66 | c, h, w = depth.shape[0], depth.shape[1], depth.shape[2]
67 | count_h = c * (h - 1) * w
68 | count_w = c * h * (w - 1)
69 | h_tv = torch.square(depth[..., 1:, :] - depth[..., :h-1, :]).sum()
70 | w_tv = torch.square(depth[..., :, 1:] - depth[..., :, :w-1]).sum()
71 | return 2 * (h_tv / count_h + w_tv / count_w)
72 |
73 | def get_img_grad_weight(img, beta=2.0):
74 | _, hd, wd = img.shape
75 | bottom_point = img[..., 2:hd, 1:wd-1]
76 | top_point = img[..., 0:hd-2, 1:wd-1]
77 | right_point = img[..., 1:hd-1, 2:wd]
78 | left_point = img[..., 1:hd-1, 0:wd-2]
79 | grad_img_x = torch.mean(torch.abs(right_point - left_point), 0, keepdim=True)
80 | grad_img_y = torch.mean(torch.abs(top_point - bottom_point), 0, keepdim=True)
81 | grad_img = torch.cat((grad_img_x, grad_img_y), dim=0)
82 | grad_img, _ = torch.max(grad_img, dim=0)
83 | grad_img = (grad_img - grad_img.min()) / (grad_img.max() - grad_img.min())
84 | grad_img = torch.nn.functional.pad(grad_img[None,None], (1,1,1,1), mode='constant', value=1.0).squeeze()
85 | return grad_img
86 |
87 | def lncc(ref, nea):
88 | # ref_gray: [batch_size, total_patch_size]
89 | # nea_grays: [batch_size, total_patch_size]
90 | bs, tps = nea.shape
91 | patch_size = int(np.sqrt(tps))
92 |
93 | ref_nea = ref * nea
94 | ref_nea = ref_nea.view(bs, 1, patch_size, patch_size)
95 | ref = ref.view(bs, 1, patch_size, patch_size)
96 | nea = nea.view(bs, 1, patch_size, patch_size)
97 | ref2 = ref.pow(2)
98 | nea2 = nea.pow(2)
99 |
100 | # sum over kernel
101 | filters = torch.ones(1, 1, patch_size, patch_size, device=ref.device)
102 | padding = patch_size // 2
103 | ref_sum = F.conv2d(ref, filters, stride=1, padding=padding)[:, :, padding, padding]
104 | nea_sum = F.conv2d(nea, filters, stride=1, padding=padding)[:, :, padding, padding]
105 | ref2_sum = F.conv2d(ref2, filters, stride=1, padding=padding)[:, :, padding, padding]
106 | nea2_sum = F.conv2d(nea2, filters, stride=1, padding=padding)[:, :, padding, padding]
107 | ref_nea_sum = F.conv2d(ref_nea, filters, stride=1, padding=padding)[:, :, padding, padding]
108 |
109 | # average over kernel
110 | ref_avg = ref_sum / tps
111 | nea_avg = nea_sum / tps
112 |
113 | cross = ref_nea_sum - nea_avg * ref_sum
114 | ref_var = ref2_sum - ref_avg * ref_sum
115 | nea_var = nea2_sum - nea_avg * nea_sum
116 |
117 | cc = cross * cross / (ref_var * nea_var + 1e-8)
118 | ncc = 1 - cc
119 | ncc = torch.clamp(ncc, 0.0, 2.0)
120 | ncc = torch.mean(ncc, dim=1, keepdim=True)
121 | mask = (ncc < 0.9)
122 | return ncc, mask
123 |
124 |
125 | def dilate(bin_img, ksize=5):
126 | pad = (ksize - 1) // 2
127 | bin_img = F.pad(bin_img, pad=[pad, pad, pad, pad], mode='reflect')
128 | out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
129 | return out
130 |
131 | def erode(bin_img, ksize=5):
132 | out = 1 - dilate(1 - bin_img, ksize)
133 | return out
134 |
135 | def cal_gradient(data):
136 | """
137 | data: [1, C, H, W]
138 | """
139 | kernel_x = [[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]
140 | kernel_x = torch.FloatTensor(kernel_x).unsqueeze(0).unsqueeze(0).to(data.device)
141 |
142 | kernel_y = [[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]
143 | kernel_y = torch.FloatTensor(kernel_y).unsqueeze(0).unsqueeze(0).to(data.device)
144 |
145 | weight_x = nn.Parameter(data=kernel_x, requires_grad=False)
146 | weight_y = nn.Parameter(data=kernel_y, requires_grad=False)
147 |
148 | grad_x = F.conv2d(data, weight_x, padding='same')
149 | grad_y = F.conv2d(data, weight_y, padding='same')
150 | gradient = torch.abs(grad_x) + torch.abs(grad_y)
151 |
152 | return gradient
153 |
154 |
155 | def bilateral_smooth_loss(data, image, mask):
156 | """
157 | image: [C, H, W]
158 | data: [C, H, W]
159 | mask: [C, H, W]
160 | """
161 | rgb_grad = cal_gradient(image.mean(0, keepdim=True).unsqueeze(0)).squeeze(0) # [1, H, W]
162 | data_grad = cal_gradient(data.mean(0, keepdim=True).unsqueeze(0)).squeeze(0) # [1, H, W]
163 |
164 | smooth_loss = (data_grad * (-rgb_grad).exp() * mask).mean()
165 |
166 | return smooth_loss
167 |
168 |
169 | def second_order_edge_aware_loss(data, img):
170 | return (spatial_gradient(data[None], order=2)[0, :, [0, 2]].abs() * torch.exp(-10*spatial_gradient(img[None], order=1)[0].abs())).sum(1).mean()
171 |
172 |
173 | def first_order_edge_aware_loss(data, img):
174 | return (spatial_gradient(data[None], order=1)[0].abs() * torch.exp(-spatial_gradient(img[None], order=1)[0].abs())).sum(1).mean()
175 |
176 | def first_order_edge_aware_norm_loss(data, img):
177 | return (spatial_gradient(data[None], order=1)[0].abs() * torch.exp(-spatial_gradient(img[None], order=1)[0].norm(dim=1, keepdim=True))).sum(1).mean()
178 |
179 | def first_order_loss(data):
180 | return spatial_gradient(data[None], order=1)[0].abs().sum(1).mean()
181 |
--------------------------------------------------------------------------------
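
To show how these losses are typically combined, here is an illustrative photometric objective in the usual 3DGS style; the 0.2 D-SSIM weight and the 0.01 smoothness weight are common choices in the literature, not necessarily the values this repo's configs use, and the tensors are random placeholders.

import torch

from utils.loss_utils import first_order_edge_aware_loss, ssim

render = torch.rand(3, 128, 128)     # rendered image (C, H, W)
gt = torch.rand(3, 128, 128)         # ground-truth image
depth = torch.rand(1, 128, 128)      # rendered depth

lambda_dssim = 0.2
l1 = torch.abs(render - gt).mean()
loss = (1.0 - lambda_dssim) * l1 + lambda_dssim * (1.0 - ssim(render[None], gt[None]))
loss = loss + 0.01 * first_order_edge_aware_loss(depth, gt)   # edge-aware depth smoothness
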
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 |         deg: int SH deg. Currently, 0-4 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
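
A small sanity check of the DC (degree-0) conversion above: for deg = 0, eval_sh reduces to C0 * sh, so RGB2SH and SH2RGB are exact inverses. The point count and colors are arbitrary.

import torch

from utils.sh_utils import RGB2SH, SH2RGB, eval_sh

rgb = torch.rand(10, 3)                                     # N points, RGB in [0, 1]
sh0 = RGB2SH(rgb)                                           # DC spherical-harmonic coefficients
assert torch.allclose(SH2RGB(sh0), rgb)

# eval_sh expects coefficients shaped [..., C, (deg + 1) ** 2]; the 0.5 offset mirrors SH2RGB.
dirs = torch.nn.functional.normalize(torch.randn(10, 3), dim=-1)
recovered = eval_sh(0, sh0[..., None], dirs) + 0.5
assert torch.allclose(recovered, rgb, atol=1e-6)
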
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torch
14 | from errno import EEXIST
15 | from os import makedirs, path
16 |
17 | def searchForMaxIteration(folder):
18 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
19 | return max(saved_iters)
20 |
21 | def mkdir_p(folder_path):
22 | # Creates a directory. equivalent to using mkdir -p on the command line
23 | try:
24 | makedirs(folder_path)
25 | except OSError as exc: # Python >2.5
26 | if exc.errno == EEXIST and path.isdir(folder_path):
27 | pass
28 | else:
29 | raise
30 |
31 | class Timing:
32 | """
33 | From https://github.com/sxyu/svox2/blob/ee80e2c4df8f29a407fda5729a494be94ccf9234/svox2/utils.py#L611
34 |
35 | Timing environment
36 | usage:
37 | with Timing("message"):
38 | your commands here
39 | will print CUDA runtime in ms
40 | """
41 |
42 | def __init__(self, name):
43 | self.name = name
44 |
45 | def __enter__(self):
46 | self.start = torch.cuda.Event(enable_timing=True)
47 | self.end = torch.cuda.Event(enable_timing=True)
48 | self.start.record()
49 |
50 | def __exit__(self, type, value, traceback):
51 | self.end.record()
52 | torch.cuda.synchronize()
53 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms")
--------------------------------------------------------------------------------
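
Illustrative usage of the filesystem and timing helpers above. The work_dir layout follows the usual point_cloud/iteration_<N> convention and is only a placeholder, and the CUDA timer is guarded since it requires a GPU.

import torch

from utils.system_utils import Timing, mkdir_p, searchForMaxIteration

mkdir_p("work_dir/point_cloud/iteration_7000")        # no error if it already exists
print(searchForMaxIteration("work_dir/point_cloud"))  # -> 7000

if torch.cuda.is_available():
    with Timing("matmul"):                            # prints elapsed CUDA time in ms on exit
        a = torch.randn(1024, 1024, device="cuda")
        (a @ a).sum()
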
/visualize_gs.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | import glob
12 | import os
13 | import torch
14 |
15 | from scene import GaussianModel, EnvLight
16 | from utils.general_utils import seed_everything
17 | from tqdm import tqdm
18 | from argparse import ArgumentParser, Namespace
19 | from omegaconf import OmegaConf
20 |
21 | EPS = 1e-5
22 |
23 |
24 |
25 | if __name__ == "__main__":
26 | # Set up command line argument parser
27 | parser = ArgumentParser(description="Training script parameters")
28 | parser.add_argument("--config_path", type=str, required=True)
29 | params, _ = parser.parse_known_args()
30 |
31 | args = OmegaConf.load(params.config_path)
32 | args.resolution_scales = args.resolution_scales[:1]
33 | # print('Configurations:\n {}'.format(pformat(OmegaConf.to_container(args, resolve=True, throw_on_missing=True))))
34 |     # resolve the config into a plain container so it can be logged as a Namespace
35 | conf = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
36 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
37 | cfg_log_f.write(str(Namespace(**conf)))
38 |
39 | seed_everything(args.seed)
40 |
41 | sep_path = os.path.join(args.model_path, 'point_cloud')
42 | os.makedirs(sep_path, exist_ok=True)
43 |
44 | gaussians = GaussianModel(args)
45 |
46 | if args.env_map_res > 0:
47 | env_map = EnvLight(resolution=args.env_map_res).cuda()
48 | env_map.training_setup(args)
49 | else:
50 | env_map = None
51 |
52 | checkpoints = glob.glob(os.path.join(args.model_path, "chkpnt*.pth"))
53 | assert len(checkpoints) > 0, "No checkpoints found."
54 | checkpoint = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
55 | print(f"Loading checkpoint {checkpoint}")
56 | (model_params, first_iter) = torch.load(checkpoint)
57 | gaussians.restore(model_params, args)
58 |
59 | if env_map is not None:
60 | env_checkpoint = os.path.join(os.path.dirname(checkpoint),
61 | os.path.basename(checkpoint).replace("chkpnt", "env_light_chkpnt"))
62 | (light_params, _) = torch.load(env_checkpoint)
63 | env_map.restore(light_params)
64 | print("Number of Gaussians: ", gaussians.get_xyz.shape[0])
65 |
66 |
67 | save_timestamp = 0.0
68 | gaussians.save_ply_at_t(os.path.join(args.model_path, "point_cloud", "iteration_{}".format(first_iter), "point_cloud.ply"), save_timestamp)
69 |
70 |
71 |
--------------------------------------------------------------------------------
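
For reference, the checkpoint selection in visualize_gs.py boils down to the sketch below: the newest chkpnt<iter>.pth under model_path is chosen by its iteration number, and the matching env_light_chkpnt file is derived from its name. The model_path value here is a placeholder.

import glob
import os

model_path = "work_dir/example_scene"                 # placeholder
checkpoints = glob.glob(os.path.join(model_path, "chkpnt*.pth"))
if checkpoints:
    latest = sorted(checkpoints, key=lambda x: int(x.split("chkpnt")[-1].split(".")[0]))[-1]
    env_ckpt = os.path.join(os.path.dirname(latest),
                            os.path.basename(latest).replace("chkpnt", "env_light_chkpnt"))
    print("would load:", latest, "and", env_ckpt)
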