├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   └── main_figure.png
├── configs
│   ├── eval
│   │   ├── eval_w_align.yaml
│   │   └── eval_wo_align.yaml
│   └── train
│       ├── train_iphone.yaml
│       ├── train_kubric_mrig.yaml
│       ├── train_nvidia.yaml
│       └── train_tnt.yaml
├── licences
│   └── LICENSE_gs.md
├── scripts
│   ├── img2video.py
│   ├── iphone2format.py
│   ├── kubricmrig2format.py
│   ├── nvidia2format.py
│   ├── run_depthanything.py
│   ├── run_mast3r
│   │   ├── depth_preprocessor
│   │   │   ├── get_pcd.py
│   │   │   ├── pcd_utils.py
│   │   │   └── utils.py
│   │   └── run.py
│   ├── tam_npy2png.py
│   └── tnt2format.py
└── src
    ├── data
    │   ├── __init__.py
    │   ├── asset_readers.py
    │   ├── dataloader.py
    │   ├── datamodule.py
    │   └── utils.py
    ├── evaluator
    │   ├── eval.py
    │   └── utils.py
    ├── model
    │   ├── __init__.py
    │   ├── rodygs_dynamic.py
    │   └── rodygs_static.py
    ├── pipelines
    │   ├── eval.py
    │   ├── train.py
    │   └── utils.py
    ├── trainer
    │   ├── __init__.py
    │   ├── losses.py
    │   ├── optim.py
    │   ├── renderer.py
    │   ├── rodygs.py
    │   ├── rodygs_dynamic.py
    │   ├── rodygs_static.py
    │   └── utils.py
    └── utils
        ├── configs.py
        ├── eval_utils.py
        ├── general_utils.py
        ├── graphic_utils.py
        ├── loss_utils.py
        ├── point_utils.py
        ├── pose_estim_utils.py
        ├── sh_utils.py
        └── store_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output
2 | *_temporal*
3 | __pycache__
4 | cache/*
5 | logs/*
6 | .vscode
7 |
8 | data/*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "thirdparty/diff-gaussian-rasterization"]
2 | path = thirdparty/diff-gaussian-rasterization
3 | url = https://github.com/slothfulxtx/diff-gaussian-rasterization.git
4 | branch = pose
5 | [submodule "thirdparty/simple-knn"]
6 | path = thirdparty/simple-knn
7 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
8 | [submodule "thirdparty/mast3r"]
9 | path = thirdparty/mast3r
10 | url = https://github.com/naver/mast3r.git
11 | [submodule "thirdparty/pytorch3d"]
12 | path = thirdparty/pytorch3d
13 | url = https://github.com/facebookresearch/pytorch3d.git
14 | [submodule "thirdparty/depth_anything_v2"]
15 | path = thirdparty/depth_anything_v2
16 | url = https://github.com/DepthAnything/Depth-Anything-V2.git
17 | [submodule "thirdparty/Track-Anything"]
18 | path = thirdparty/Track-Anything
19 | url = https://github.com/gaomingqi/Track-Anything.git
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Robust Dynamic Gaussian Splatting for Casual Videos
2 | ### [Project Page](https://rodygs.github.io/) | [Paper](https://www.arxiv.org/abs/2412.03077) | [Video](https://www.youtube.com/watch?v=pwv4Gwl07Tw)
3 | #### [Yoonwoo Jeong](https://jeongyw12382.github.io/), [Junmyeong Lee](https://www.linkedin.com/in/junmyeong-lee/), [Hoseung Choi](https://www.linkedin.com/in/hoseung-choi/), [Minsu Cho](https://cvlab.postech.ac.kr/~mcho/)
4 |
5 | ![main figure](./assets/main_figure.png)
6 |
7 |
8 | ## News
9 | - [2024-12-11] We have released the code. (dataset coming soon)
10 |
11 | ## Abstract
12 | Dynamic view synthesis (DVS) has advanced remarkably in recent years, achieving high-fidelity rendering while reducing computational costs. Despite the progress, optimizing dynamic neural fields from casual videos remains challenging, as these videos do not provide direct 3D information, such as camera trajectories or the underlying scene geometry. In this work, we present RoDyGS, an optimization pipeline for dynamic Gaussian Splatting from casual videos. It effectively learns motion and underlying geometry of scenes by separating dynamic and static primitives, and ensures that the learned motion and geometry are physically plausible by incorporating motion and geometric regularization terms. We also introduce a comprehensive benchmark, Kubric-MRig, that provides extensive camera and object motion along with simultaneous multi-view captures, features that are absent in previous benchmarks. Experimental results demonstrate that the proposed method significantly outperforms previous pose-free dynamic neural fields and achieves competitive rendering quality compared to existing pose-free static neural fields.
13 |
14 | ## Short Description Video
15 |
16 | [Watch the video on YouTube](https://www.youtube.com/watch?v=pwv4Gwl07Tw)
17 |
18 |
19 | ## Installation
20 | The code has been tested on Python 3.11 with CUDA 12.2.
21 | ```bash
22 | git clone https://github.com/POSTECH-CVLab/RoDyGS.git --recursive
23 | cd RoDyGS
24 |
25 | conda create -n rodygs python=3.11 -c anaconda -y
26 | conda activate rodygs
27 |
28 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia -y
29 | pip3 install plyfile==1.0.3 omegaconf==2.3.0 tqdm==4.66.4 piqa==1.3.2 scipy==1.13.1 tifffile==2024.7.24 pandas==2.2.2 numpy==1.24.1 gradio==3.39.0 av==10.0.0 imageio==2.19.5 imageio-ffmpeg
30 | pip3 install opencv-python matplotlib trimesh roma einops huggingface_hub ninja gdown progressbar mmcv
31 | pip3 install git+https://github.com/facebookresearch/segment-anything.git
32 |
33 | pip3 install thirdparty/simple-knn
34 | pip3 install thirdparty/diff-gaussian-rasterization
35 | pip3 install thirdparty/pytorch3d
36 | ```
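- If the installation succeeded, a quick import check such as the one below should run without errors (a minimal sketch; the module names `diff_gaussian_rasterization`, `simple_knn._C`, and `pytorch3d` are assumed to match the packages built above):
```python
# Minimal post-install sanity check (module names assumed from the install steps above).
import importlib

import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

for name in ("diff_gaussian_rasterization", "simple_knn._C", "pytorch3d"):
    try:
        importlib.import_module(name)
        print("OK  ", name)
    except ImportError as exc:
        print("FAIL", name, "->", exc)
```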
37 |
38 | ## Dataset Format
39 | - First, you need to set up the dataset in the following format to run RoDyGS.
40 | ```
41 | [scene_name]
42 | |- train # train images
43 | |- rgba_00000.png
44 | |- rgba_00001.png
45 | |- ...
46 | |- test # test images
47 | |- rgba_00000.png
48 | |- rgba_00001.png
49 | |- ...
50 | |- train_transforms.json # train camera information
51 | |- test_transforms.json # test camera information
52 | ```
53 |
54 |
55 | The format of train_transforms.json is as follows:
56 |
57 | - "camera_angle_x" and "camera_angle_y" are specified in degrees.
58 | - The range of "time" is from 0.0 to 1.0.
59 | - The format of "transform_matrix" follows the OpenCV camera-to-world matrix.
60 | - "train/rgba_\*.png" in "file path" should be changed to "test/rgba_\*.png" in test_transforms.json.
61 | ```json
62 | {
63 | "camera_angle_x": 50,
64 | "camera_angle_y": 50,
65 | "frames": [
66 | {
67 | "time": 0.0,
68 | "file_path": "train/rgba_00000.png",
69 | "width": 1920,
70 | "height": 1080,
71 | "transform_matrix": [
72 | [
73 | 1.0,
74 | 0.0,
75 | 0.0,
76 | 0.0
77 | ],
78 | [
79 | 0.0,
80 | 1.0,
81 | 0.0,
82 | 0.0
83 | ],
84 | [
85 | 0.0,
86 | 0.0,
87 | 1.0,
88 | 0.0
89 | ],
90 | [
91 | 0.0,
92 | 0.0,
93 | 0.0,
94 | 1.0
95 | ]
96 | ]
97 | },
98 | {
99 | "time": 0.01,
100 | "file_path": "train/rgba_00001.png",
101 | "width": 1920,
102 | "height": 1080,
103 | "transform_matrix": ~
104 | },
105 | ...
106 | {
107 | "time": 0.99,
108 | "file_path": "train/rgba_00099.png",
109 | "width": 1920,
110 | "height": 1080,
111 | "transform_matrix": ~
112 | }
113 | ]
114 | }
115 | ```
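- For reference, the snippet below writes a `train_transforms.json` in this format (a minimal sketch: the 50-degree FoV, the identity `transform_matrix`, and the 1920x1080 resolution simply mirror the example above; replace them with the values of your own capture):
```python
import json

# Placeholder values mirroring the example above; adapt them to your capture.
num_frames, width, height = 100, 1920, 1080
identity = [[1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0]]

transforms = {
    "camera_angle_x": 50,  # degrees
    "camera_angle_y": 50,  # degrees
    "frames": [
        {
            "time": i / num_frames,                  # normalized to [0, 1)
            "file_path": f"train/rgba_{i:05d}.png",  # use "test/rgba_*.png" in test_transforms.json
            "width": width,
            "height": height,
            "transform_matrix": identity,            # OpenCV camera-to-world matrix per frame
        }
        for i in range(num_frames)
    ],
}

with open("train_transforms.json", "w") as fp:
    json.dump(transforms, fp, indent=4)
```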
116 |
117 |
118 | ## Running
119 | ### Pre-processing
120 |
121 | - Please download the checkpoints from https://github.com/naver/mast3r ("MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric") and https://github.com/DepthAnything/Depth-Anything-V2 ("Depth-Anything-V2-Large").
122 |
123 | - You need to run Track-Anything and obtain motion masks from the demo UI. For more information, visit https://github.com/gaomingqi/Track-Anything.
124 | ```bash
125 | cd RoDyGS/thirdparty/Track-Anything
126 | python3 app.py --mask_save True
127 | ```
128 |
129 | - Then, follow these steps:
130 | ```bash
131 | cd RoDyGS
132 |
133 | # Get TAM mask
134 | python3 -m scripts.img2video --data_dir [path_to_dataset_dir]
135 | python3 -m scripts.tam_npy2png --npy_dir [path_to_tam_npy_mask] --output_dir [path_to_dataset_dir]
136 |
137 | # Run Depth Anything
138 | python3 -m scripts.run_depthanything --encoder [encoder_type] --encoder-path [path_to_encoder_ckpt] --img-path [path_to_image_dir] --outdir [path_to_output] --raw-depth
139 |
140 | # Run MASt3R
141 | python3 -m scripts.run_mast3r.run --input_dir [path_to_images] --exp_name [mast3r_expname] --output_dir [path_to_mast3r_output] --ckpt [path_to_model_ckpt] --cache_dir [path_to_cache_output]
142 | python3 -m scripts.run_mast3r.depth_preprocessor.get_pcd --datadir [path_to_dataset_dir] --mask_name [mask_dir_name] --mast3r_expname [mast3r_expname]
143 | ```
144 |
145 | Example commands:
146 |
147 | ```bash
148 | python3 -m scripts.img2video --data_dir data/kubric_mrig/scene0
149 | python3 -m scripts.tam_npy2png --npy_dir thirdparty/Track-Anything/result/mask/train --output_dir data/kubric_mrig/scene0
150 |
151 | python3 -m scripts.run_depthanything --encoder vitl --encoder-path checkpoints/depth_anything_v2_vitl.pth --img-path data/kubric_mrig/scene0/train --outdir data/kubric_mrig/scene0/depth_anything --raw-depth
152 |
153 | python3 -m scripts.run_mast3r.run --input_dir data/kubric_mrig/scene0/train --exp_name mast3r --output_dir data/kubric_mrig/scene0/mast3r_opt --ckpt MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth --cache_dir mast3r_cache/
154 |
155 | python3 -m scripts.run_mast3r.depth_preprocessor.get_pcd --datadir data/kubric_mrig/scene0/ --mask_name tam_mask --mast3r_expname mast3r_000
156 | ```
157 |
158 |
159 |
160 | - After pre-processing, your dataset will have the following format:
161 | ```
162 | [scene_name]
163 | |- train # train images
164 | |- test # test images
165 | |- train_transforms.json # train camera information
166 | |- test_transforms.json # test camera information
167 | |- tam_mask
168 | |- depth_anything
169 | |- mast3r_opt
170 | |- [expname]
171 | |- dynamic # dynamic point clouds (per frame)
172 | |- static # static point clouds (per frame)
173 | |- op_results # merged point clouds (unseparated)
174 | |- global_params.pkl # initial camera poses
175 | ```
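- Before launching training, a quick check like the one below can confirm that the pre-processing outputs are in place (a minimal sketch; the scene path and `expname` are placeholders, with `expname` matching the `--exp_name` passed to `scripts.run_mast3r.run`):
```python
from pathlib import Path

scene_dir = Path("data/kubric_mrig/scene0")  # path to your [scene_name] directory
expname = "mast3r"                           # the --exp_name used for scripts.run_mast3r.run

# Files and directories expected after pre-processing (see the layout above).
expected = [
    scene_dir / "train",
    scene_dir / "test",
    scene_dir / "train_transforms.json",
    scene_dir / "test_transforms.json",
    scene_dir / "tam_mask",
    scene_dir / "depth_anything",
    scene_dir / "mast3r_opt" / expname / "global_params.pkl",
]

for path in expected:
    print("OK     " if path.exists() else "MISSING", path)
```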
176 |
177 | ### Training
178 | ```bash
179 | python3 -m src.pipelines.train -d [path_to_dataset_dir] -b [path_to_training_config] -g [logging_group_name] -n [training_log_name]
180 | ```
181 |
182 | Example commands:
183 |
184 | ```bash
185 | python3 -m src.pipelines.train -d data/kubric_mrig/scene0/ -b configs/train/train_kubric_mrig.yaml -g kubric_mrig -n scene0
186 | ```
187 |
188 |
189 | ### Rendering & Evaluation
190 | - We used "configs/eval/eval_w_align.yaml" for the iPhone dataset evaluation.
191 | - We used "configs/eval/eval_wo_align.yaml" for the Kubric-MRig and NVIDIA Dynamic dataset evaluations.
192 | ```bash
193 | python3 -m src.pipelines.eval -c [path_to_config] -t [test_log_name] -d [path_to_dataset] -m [path_to_training_log]
194 | ```
195 |
196 | Example commands:
197 |
198 | ```bash
199 | python3 -m src.pipelines.eval -c configs/eval/eval_wo_align.yaml -t eval -d ../data/kubric_mrig/scene0/ -m logs/kubric_0_0777/
200 | ```
201 |
202 |
203 |
--------------------------------------------------------------------------------
/assets/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/POSTECH-CVLab/RoDyGS/7057b57afcb3c0297bab500388edf25b0d8d9754/assets/main_figure.png
--------------------------------------------------------------------------------
/configs/eval/eval_w_align.yaml:
--------------------------------------------------------------------------------
1 | evaluator:
2 | target: src.evaluator.eval.RoDyGSEvaluator
3 | params:
4 | camera_lr: 0.00005
5 | num_opts: 1000
6 |
7 | static_data:
8 | target: src.data.datamodule.GSDataModule
9 | params:
10 | train_dset_config:
11 | target: src.data.datamodule.LazyDataReader
12 | params:
13 | pose_reader:
14 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
15 | params:
16 | mast3r_expname: swin_noloop_000
17 | mast3r_img_res: 512
18 | train_dloader_config:
19 | target: src.data.dataloader.SequentialSingleDataLoader
20 | params: {}
21 | test_dset_config:
22 | target: src.data.datamodule.LazyDataReader
23 | params:
24 | pose_reader:
25 | target: src.data.asset_readers.Test_MASt3RFovCameraReader
26 | params:
27 | mast3r_expname: swin_noloop_000
28 | mast3r_img_res: 512
29 | test_dloader_config:
30 | target: src.data.dataloader.SequentialSingleDataLoader
31 | params: {}
32 | test_transform_fname: test_transforms.json
33 |
34 | normalize_cams: false
--------------------------------------------------------------------------------
/configs/eval/eval_wo_align.yaml:
--------------------------------------------------------------------------------
1 | evaluator:
2 | target: src.evaluator.eval.RoDyGSEvaluator
3 | params:
4 | camera_lr: 0
5 | num_opts: 0
6 |
7 | static_data:
8 | target: src.data.datamodule.GSDataModule
9 | params:
10 | train_dset_config:
11 | target: src.data.datamodule.LazyDataReader
12 | params:
13 | pose_reader:
14 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
15 | params:
16 | mast3r_expname: swin_noloop_000
17 | mast3r_img_res: 512
18 | train_dloader_config:
19 | target: src.data.dataloader.SequentialSingleDataLoader
20 | params: {}
21 | test_dset_config:
22 | target: src.data.datamodule.LazyDataReader
23 | params:
24 | pose_reader:
25 | target: src.data.asset_readers.Test_MASt3RFovCameraReader
26 | params:
27 | mast3r_expname: swin_noloop_000
28 | mast3r_img_res: 512
29 | test_dloader_config:
30 | target: src.data.dataloader.SequentialSingleDataLoader
31 | params: {}
32 | test_transform_fname: test_transforms.json
33 |
34 | normalize_cams: false
--------------------------------------------------------------------------------
/configs/train/train_iphone.yaml:
--------------------------------------------------------------------------------
1 | static_data:
2 | target: src.data.datamodule.GSDataModule
3 | params:
4 | dirpath: ./
5 | train_dset_config:
6 | target: src.data.datamodule.LazyDataReader
7 | params:
8 | camera_config:
9 | target: src.data.utils.FixedCamera
10 | pose_reader:
11 | target: src.data.asset_readers.MASt3RCameraReader
12 | params:
13 | mast3r_expname: swin_noloop_000
14 | mast3r_img_res: 512
15 | depth_reader:
16 | target: src.data.asset_readers.DepthAnythingReader
17 | params:
18 | split: train
19 | motion_mask_reader:
20 | target: src.data.asset_readers.TAMMaskReader
21 | params:
22 | split: train
23 | train_dloader_config:
24 | target: src.data.dataloader.PermutationSingleDataLoader
25 | params:
26 | num_iterations: 20000
27 | test_dset_config:
28 | target: src.data.datamodule.DataReader
29 | params:
30 | camera_config:
31 | target: src.data.utils.FixedCamera
32 | pose_reader:
33 | target: src.data.asset_readers.GTCameraReader
34 | test_dloader_config:
35 | target: src.data.dataloader.SequentialSingleDataLoader
36 | params: {}
37 | train_pcd_reader_config:
38 | target: src.data.asset_readers.MASt3RPCDReader
39 | params:
40 | mast3r_expname: swin_noloop_000
41 | mode: static
42 | num_limit_points: 120000
43 | train_pose_reader_config:
44 | target: src.data.datamodule.DataReader
45 | params:
46 | camera_config:
47 | target: src.data.utils.FixedCamera
48 | pose_reader:
49 | target: src.data.asset_readers.GTCameraReader
50 | normalize_cams: false
51 | static_model:
52 | target: src.model.rodygs_static.StaticRoDyGS
53 | params:
54 | sh_degree: 3
55 | isotropic: false
56 | static_calibrated_pose_reader:
57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
58 | params:
59 | mast3r_expname: swin_noloop_000
60 | mast3r_img_res: 512
61 | dynamic_data:
62 | target: src.data.datamodule.GSDataModule
63 | params:
64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/
65 | train_dset_config:
66 | target: src.data.datamodule.LazyDataReader
67 | params:
68 | camera_config:
69 | target: src.data.utils.FixedCamera
70 | pose_reader:
71 | target: src.data.asset_readers.MASt3RCameraReader
72 | params:
73 | mast3r_expname: swin_noloop_000
74 | mast3r_img_res: 512
75 | depth_reader:
76 | target: src.data.asset_readers.DepthAnythingReader
77 | params:
78 | split: train
79 | motion_mask_reader:
80 | target: src.data.asset_readers.TAMMaskReader
81 | params:
82 | split: train
83 | train_dloader_config:
84 | target: src.data.dataloader.PermutationSingleDataLoader
85 | params:
86 | num_iterations: 20000
87 | test_dset_config:
88 | target: src.data.datamodule.DataReader
89 | params:
90 | camera_config:
91 | target: src.data.utils.FixedCamera
92 | pose_reader:
93 | target: src.data.asset_readers.GTCameraReader
94 | test_dloader_config:
95 | target: src.data.dataloader.SequentialSingleDataLoader
96 | params: {}
97 | train_pcd_reader_config:
98 | target: src.data.asset_readers.MASt3RPCDReader
99 | params:
100 | mast3r_expname: swin_noloop_000
101 | mode: dynamic
102 | num_limit_points: 120000
103 | test_transform_fname: test_transforms.json
104 | train_pose_reader_config:
105 | target: src.data.datamodule.DataReader
106 | params:
107 | camera_config:
108 | target: src.data.utils.FixedCamera
109 | pose_reader:
110 | target: src.data.asset_readers.GTCameraReader
111 | normalize_cams: false
112 | dynamic_model:
113 | target: src.model.rodygs_dynamic.DynRoDyGS
114 | params:
115 | sh_degree: 3
116 | deform_netwidth: 128
117 | deform_t_emb_multires: 26
118 | deform_t_log_sampling: false
119 | num_basis: 16
120 | isotropic: false
121 | inverse_motion: true
122 | trainer:
123 | target: src.trainer.rodygs.RoDyGSTrainer
124 | params:
125 | log_freq: 50
126 | sh_up_start_iteration: 15000
127 | sh_up_period: 1000
128 | static:
129 | target: src.trainer.rodygs_static.ThreeDGSTrainer
130 | params:
131 | loss_config:
132 | target: src.trainer.losses.MultiLoss
133 | params:
134 | loss_configs:
135 | - name: d_ssim
136 | weight: 0.2
137 | target: src.trainer.losses.SSIMLoss
138 | params:
139 | mode: all
140 | - name: l1
141 | weight: 0.8
142 | target: src.trainer.losses.L1Loss
143 | params:
144 | mode: all
145 | - name: global_pearson_depth
146 | weight: 0.05
147 | target: src.trainer.losses.GlobalPearsonDepthLoss
148 | start: 0
149 | params:
150 | mode: all
151 | - name: local_pearson_depth
152 | weight: 0.15
153 | target: src.trainer.losses.LocalPearsonDepthLoss
154 | start: 0
155 | params:
156 | box_p: 128
157 | p_corr: 0.5
158 | mode: all
159 | num_iterations: 20000
160 | position_lr_init: 0.00016
161 | position_lr_final: 1.6e-06
162 | position_lr_delay_mult: 0.01
163 | position_lr_max_steps: 20000
164 | feature_lr: 0.0025
165 | opacity_lr: 0.05
166 | scaling_lr: 0.005
167 | rotation_lr: 0.001
168 | percent_dense: 0.01
169 | densification_interval: 100
170 | opacity_reset_interval: 5000000
171 | densify_from_iter: 500
172 | densify_until_iter: 20000
173 | densify_grad_threshold: 0.0002
174 | camera_opt_config:
175 | target: src.trainer.optim.CameraQuatOptimizer
176 | params:
177 | camera_rotation_lr: 1.0e-05
178 | camera_translation_lr: 1.0e-06
179 | camera_lr_warmup: 3000
180 | total_steps: 20000
181 | dynamic:
182 | target: src.trainer.rodygs_dynamic.DynTrainer
183 | params:
184 | loss_config:
185 | target: src.trainer.losses.MultiLoss
186 | params:
187 | loss_configs:
188 | - name: d_ssim
189 | weight: 0.2
190 | target: src.trainer.losses.SSIMLoss
191 | params:
192 | mode: all
193 | - name: l1
194 | weight: 0.8
195 | target: src.trainer.losses.L1Loss
196 | params:
197 | mode: all
198 | - name: motion_l1_reg
199 | weight: 0.01
200 | start: 0
201 | target: src.trainer.losses.MotionL1Loss
202 | - name: motion_sparsity
203 | weight: 0.002
204 | start: 0
205 | target: src.trainer.losses.MotionSparsityLoss
206 | - name: global_pearson_depth
207 | weight: 0.05
208 | target: src.trainer.losses.GlobalPearsonDepthLoss
209 | start: 0
210 | params:
211 | mode: all
212 | - name: local_pearson_depth
213 | weight: 0.15
214 | target: src.trainer.losses.LocalPearsonDepthLoss
215 | start: 0
216 | params:
217 | box_p: 128
218 | p_corr: 0.5
219 | mode: all
220 | - name: rigidity
221 | weight: 0.5
222 | freq: 5
223 | start: 0
224 | target: src.trainer.losses.RigidityLoss
225 | params:
226 | mode:
227 | - distance_preserving
228 | - surface
229 | K: 8
230 | - name: motion_basis_reg
231 | weight: 0.5
232 | start: 0
233 | target: src.trainer.losses.MotionBasisRegularizaiton
234 | params:
235 | transl_degree: 0
236 | rot_degree: 0
237 | freq_div_mode: cum_exponential
238 | num_iterations: 20000
239 | position_lr_init: 0.00016
240 | position_lr_final: 1.6e-06
241 | position_lr_delay_mult: 0.01
242 | position_lr_max_steps: 20000
243 | feature_lr: 0.0025
244 | opacity_lr: 0.05
245 | scaling_lr: 0.001
246 | rotation_lr: 0.001
247 | percent_dense: 0.01
248 | densification_interval: 100
249 | opacity_reset_interval: 5000000
250 | densify_from_iter: 500
251 | densify_until_iter: 15000
252 | densify_grad_threshold: 0.0002
253 | deform_warmup_steps: 0
254 | deform_lr_init: 0.0016
255 | deform_lr_final: 0.00016
256 | deform_lr_delay_mult: 0.01
257 | deform_lr_max_steps: 20000
258 | motion_coeff_lr: 0.00016
259 | camera_opt_config:
260 | target: src.trainer.optim.CameraQuatOptimizer
261 | params:
262 | camera_rotation_lr: 0.0
263 | camera_translation_lr: 0.0
264 | camera_lr_warmup: 0
265 | total_steps: 20000
--------------------------------------------------------------------------------
/configs/train/train_kubric_mrig.yaml:
--------------------------------------------------------------------------------
1 | static_data:
2 | target: src.data.datamodule.GSDataModule
3 | params:
4 | dirpath: ./
5 | train_dset_config:
6 | target: src.data.datamodule.LazyDataReader
7 | params:
8 | camera_config:
9 | target: src.data.utils.FixedCamera
10 | pose_reader:
11 | target: src.data.asset_readers.MASt3RCameraReader
12 | params:
13 | mast3r_expname: swin_noloop_000
14 | mast3r_img_res: 512
15 | depth_reader:
16 | target: src.data.asset_readers.DepthAnythingReader
17 | params:
18 | split: train
19 | motion_mask_reader:
20 | target: src.data.asset_readers.TAMMaskReader
21 | params:
22 | split: train
23 | train_dloader_config:
24 | target: src.data.dataloader.PermutationSingleDataLoader
25 | params:
26 | num_iterations: 20000
27 | test_dset_config:
28 | target: src.data.datamodule.DataReader
29 | params:
30 | camera_config:
31 | target: src.data.utils.FixedCamera
32 | pose_reader:
33 | target: src.data.asset_readers.GTCameraReader
34 | test_dloader_config:
35 | target: src.data.dataloader.SequentialSingleDataLoader
36 | params: {}
37 | train_pcd_reader_config:
38 | target: src.data.asset_readers.MASt3RPCDReader
39 | params:
40 | mast3r_expname: swin_noloop_000
41 | mode: static
42 | num_limit_points: 120000
43 | train_pose_reader_config:
44 | target: src.data.datamodule.DataReader
45 | params:
46 | camera_config:
47 | target: src.data.utils.FixedCamera
48 | pose_reader:
49 | target: src.data.asset_readers.GTCameraReader
50 | normalize_cams: false
51 | static_model:
52 | target: src.model.rodygs_static.StaticRoDyGS
53 | params:
54 | sh_degree: 3
55 | isotropic: false
56 | static_calibrated_pose_reader:
57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
58 | params:
59 | mast3r_expname: swin_noloop_000
60 | mast3r_img_res: 512
61 | dynamic_data:
62 | target: src.data.datamodule.GSDataModule
63 | params:
64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/
65 | train_dset_config:
66 | target: src.data.datamodule.LazyDataReader
67 | params:
68 | camera_config:
69 | target: src.data.utils.FixedCamera
70 | pose_reader:
71 | target: src.data.asset_readers.MASt3RCameraReader
72 | params:
73 | mast3r_expname: swin_noloop_000
74 | mast3r_img_res: 512
75 | depth_reader:
76 | target: src.data.asset_readers.DepthAnythingReader
77 | params:
78 | split: train
79 | motion_mask_reader:
80 | target: src.data.asset_readers.TAMMaskReader
81 | params:
82 | split: train
83 | train_dloader_config:
84 | target: src.data.dataloader.PermutationSingleDataLoader
85 | params:
86 | num_iterations: 20000
87 | test_dset_config:
88 | target: src.data.datamodule.DataReader
89 | params:
90 | camera_config:
91 | target: src.data.utils.FixedCamera
92 | pose_reader:
93 | target: src.data.asset_readers.GTCameraReader
94 | test_dloader_config:
95 | target: src.data.dataloader.SequentialSingleDataLoader
96 | params: {}
97 | train_pcd_reader_config:
98 | target: src.data.asset_readers.MASt3RPCDReader
99 | params:
100 | mast3r_expname: swin_noloop_000
101 | mode: dynamic
102 | num_limit_points: 120000
103 | test_transform_fname: test_transforms.json
104 | train_pose_reader_config:
105 | target: src.data.datamodule.DataReader
106 | params:
107 | camera_config:
108 | target: src.data.utils.FixedCamera
109 | pose_reader:
110 | target: src.data.asset_readers.GTCameraReader
111 | normalize_cams: false
112 | dynamic_model:
113 | target: src.model.rodygs_dynamic.DynRoDyGS
114 | params:
115 | sh_degree: 3
116 | deform_netwidth: 128
117 | deform_t_emb_multires: 26
118 | deform_t_log_sampling: false
119 | num_basis: 16
120 | isotropic: false
121 | inverse_motion: true
122 | trainer:
123 | target: src.trainer.rodygs.RoDyGSTrainer
124 | params:
125 | log_freq: 50
126 | sh_up_start_iteration: 15000
127 | sh_up_period: 1000
128 | static:
129 | target: src.trainer.rodygs_static.ThreeDGSTrainer
130 | params:
131 | loss_config:
132 | target: src.trainer.losses.MultiLoss
133 | params:
134 | loss_configs:
135 | - name: d_ssim
136 | weight: 0.2
137 | target: src.trainer.losses.SSIMLoss
138 | params:
139 | mode: all
140 | - name: l1
141 | weight: 0.8
142 | target: src.trainer.losses.L1Loss
143 | params:
144 | mode: all
145 | - name: global_pearson_depth
146 | weight: 0.05
147 | target: src.trainer.losses.GlobalPearsonDepthLoss
148 | start: 0
149 | params:
150 | mode: all
151 | - name: local_pearson_depth
152 | weight: 0.15
153 | target: src.trainer.losses.LocalPearsonDepthLoss
154 | start: 0
155 | params:
156 | box_p: 128
157 | p_corr: 0.5
158 | mode: all
159 | num_iterations: 20000
160 | position_lr_init: 0.00016
161 | position_lr_final: 1.6e-06
162 | position_lr_delay_mult: 0.01
163 | position_lr_max_steps: 20000
164 | feature_lr: 0.0025
165 | opacity_lr: 0.05
166 | scaling_lr: 0.005
167 | rotation_lr: 0.001
168 | percent_dense: 0.01
169 | densification_interval: 100
170 | opacity_reset_interval: 5000000
171 | densify_from_iter: 500
172 | densify_until_iter: 20000
173 | densify_grad_threshold: 0.0002
174 | camera_opt_config:
175 | target: src.trainer.optim.CameraQuatOptimizer
176 | params:
177 | camera_rotation_lr: 1.0e-05
178 | camera_translation_lr: 1.0e-06
179 | camera_lr_warmup: 0
180 | total_steps: 20000
181 | dynamic:
182 | target: src.trainer.rodygs_dynamic.DynTrainer
183 | params:
184 | loss_config:
185 | target: src.trainer.losses.MultiLoss
186 | params:
187 | loss_configs:
188 | - name: d_ssim
189 | weight: 0.2
190 | target: src.trainer.losses.SSIMLoss
191 | params:
192 | mode: all
193 | - name: l1
194 | weight: 0.8
195 | target: src.trainer.losses.L1Loss
196 | params:
197 | mode: all
198 | - name: motion_l1_reg
199 | weight: 0.01
200 | start: 0
201 | target: src.trainer.losses.MotionL1Loss
202 | - name: motion_sparsity
203 | weight: 0.002
204 | start: 0
205 | target: src.trainer.losses.MotionSparsityLoss
206 | - name: global_pearson_depth
207 | weight: 0.05
208 | target: src.trainer.losses.GlobalPearsonDepthLoss
209 | start: 0
210 | params:
211 | mode: all
212 | - name: local_pearson_depth
213 | weight: 0.15
214 | target: src.trainer.losses.LocalPearsonDepthLoss
215 | start: 0
216 | params:
217 | box_p: 128
218 | p_corr: 0.5
219 | mode: all
220 | - name: rigidity
221 | weight: 0.5
222 | freq: 5
223 | start: 0
224 | target: src.trainer.losses.RigidityLoss
225 | params:
226 | mode:
227 | - distance_preserving
228 | - surface
229 | K: 8
230 | - name: motion_basis_reg
231 | weight: 0.1
232 | start: 0
233 | target: src.trainer.losses.MotionBasisRegularizaiton
234 | params:
235 | transl_degree: 0
236 | rot_degree: 0
237 | freq_div_mode: cum_exponential
238 | num_iterations: 20000
239 | position_lr_init: 0.00016
240 | position_lr_final: 1.6e-06
241 | position_lr_delay_mult: 0.01
242 | position_lr_max_steps: 20000
243 | feature_lr: 0.0025
244 | opacity_lr: 0.05
245 | scaling_lr: 0.001
246 | rotation_lr: 0.001
247 | percent_dense: 0.01
248 | densification_interval: 100
249 | opacity_reset_interval: 5000000
250 | densify_from_iter: 500
251 | densify_until_iter: 15000
252 | densify_grad_threshold: 0.0002
253 | deform_warmup_steps: 0
254 | deform_lr_init: 0.0016
255 | deform_lr_final: 0.00016
256 | deform_lr_delay_mult: 0.01
257 | deform_lr_max_steps: 20000
258 | motion_coeff_lr: 0.00016
259 | camera_opt_config:
260 | target: src.trainer.optim.CameraQuatOptimizer
261 | params:
262 | camera_rotation_lr: 0.0
263 | camera_translation_lr: 0.0
264 | camera_lr_warmup: 0
265 | total_steps: 20000
--------------------------------------------------------------------------------
/configs/train/train_nvidia.yaml:
--------------------------------------------------------------------------------
1 | static_data:
2 | target: src.data.datamodule.GSDataModule
3 | params:
4 | dirpath: ./
5 | train_dset_config:
6 | target: src.data.datamodule.LazyDataReader
7 | params:
8 | camera_config:
9 | target: src.data.utils.FixedCamera
10 | pose_reader:
11 | target: src.data.asset_readers.MASt3RCameraReader
12 | params:
13 | mast3r_expname: swin_noloop_000
14 | mast3r_img_res: 512
15 | depth_reader:
16 | target: src.data.asset_readers.DepthAnythingReader
17 | params:
18 | split: train
19 | motion_mask_reader:
20 | target: src.data.asset_readers.TAMMaskReader
21 | params:
22 | split: train
23 | train_dloader_config:
24 | target: src.data.dataloader.PermutationSingleDataLoader
25 | params:
26 | num_iterations: 20000
27 | test_dset_config:
28 | target: src.data.datamodule.DataReader
29 | params:
30 | camera_config:
31 | target: src.data.utils.FixedCamera
32 | pose_reader:
33 | target: src.data.asset_readers.GTCameraReader
34 | test_dloader_config:
35 | target: src.data.dataloader.SequentialSingleDataLoader
36 | params: {}
37 | train_pcd_reader_config:
38 | target: src.data.asset_readers.MASt3RPCDReader
39 | params:
40 | mast3r_expname: swin_noloop_000
41 | mode: static
42 | num_limit_points: 120000
43 | train_pose_reader_config:
44 | target: src.data.datamodule.DataReader
45 | params:
46 | camera_config:
47 | target: src.data.utils.FixedCamera
48 | pose_reader:
49 | target: src.data.asset_readers.GTCameraReader
50 | normalize_cams: false
51 | static_model:
52 | target: src.model.rodygs_static.StaticRoDyGS
53 | params:
54 | sh_degree: 3
55 | isotropic: false
56 | static_calibrated_pose_reader:
57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
58 | params:
59 | mast3r_expname: swin_noloop_000
60 | mast3r_img_res: 512
61 | dynamic_data:
62 | target: src.data.datamodule.GSDataModule
63 | params:
64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/
65 | train_dset_config:
66 | target: src.data.datamodule.LazyDataReader
67 | params:
68 | camera_config:
69 | target: src.data.utils.FixedCamera
70 | pose_reader:
71 | target: src.data.asset_readers.MASt3RCameraReader
72 | params:
73 | mast3r_expname: swin_noloop_000
74 | mast3r_img_res: 512
75 | depth_reader:
76 | target: src.data.asset_readers.DepthAnythingReader
77 | params:
78 | split: train
79 | motion_mask_reader:
80 | target: src.data.asset_readers.TAMMaskReader
81 | params:
82 | split: train
83 | train_dloader_config:
84 | target: src.data.dataloader.PermutationSingleDataLoader
85 | params:
86 | num_iterations: 20000
87 | test_dset_config:
88 | target: src.data.datamodule.DataReader
89 | params:
90 | camera_config:
91 | target: src.data.utils.FixedCamera
92 | pose_reader:
93 | target: src.data.asset_readers.GTCameraReader
94 | test_dloader_config:
95 | target: src.data.dataloader.SequentialSingleDataLoader
96 | params: {}
97 | train_pcd_reader_config:
98 | target: src.data.asset_readers.MASt3RPCDReader
99 | params:
100 | mast3r_expname: swin_noloop_000
101 | mode: dynamic
102 | num_limit_points: 120000
103 | test_transform_fname: test_transforms.json
104 | train_pose_reader_config:
105 | target: src.data.datamodule.DataReader
106 | params:
107 | camera_config:
108 | target: src.data.utils.FixedCamera
109 | pose_reader:
110 | target: src.data.asset_readers.GTCameraReader
111 | normalize_cams: false
112 | dynamic_model:
113 | target: src.model.rodygs_dynamic.DynRoDyGS
114 | params:
115 | sh_degree: 3
116 | deform_netwidth: 128
117 | deform_t_emb_multires: 26
118 | deform_t_log_sampling: false
119 | num_basis: 16
120 | isotropic: false
121 | inverse_motion: true
122 | trainer:
123 | target: src.trainer.rodygs.RoDyGSTrainer
124 | params:
125 | log_freq: 50
126 | sh_up_start_iteration: 15000
127 | sh_up_period: 1000
128 | static:
129 | target: src.trainer.rodygs_static.ThreeDGSTrainer
130 | params:
131 | loss_config:
132 | target: src.trainer.losses.MultiLoss
133 | params:
134 | loss_configs:
135 | - name: d_ssim
136 | weight: 0.2
137 | target: src.trainer.losses.SSIMLoss
138 | params:
139 | mode: all
140 | - name: l1
141 | weight: 0.8
142 | target: src.trainer.losses.L1Loss
143 | params:
144 | mode: all
145 | - name: global_pearson_depth
146 | weight: 0.05
147 | target: src.trainer.losses.GlobalPearsonDepthLoss
148 | start: 0
149 | params:
150 | mode: all
151 | - name: local_pearson_depth
152 | weight: 0.15
153 | target: src.trainer.losses.LocalPearsonDepthLoss
154 | start: 0
155 | params:
156 | box_p: 128
157 | p_corr: 0.5
158 | mode: all
159 | num_iterations: 20000
160 | position_lr_init: 0.00016
161 | position_lr_final: 1.6e-06
162 | position_lr_delay_mult: 0.01
163 | position_lr_max_steps: 20000
164 | feature_lr: 0.0025
165 | opacity_lr: 0.05
166 | scaling_lr: 0.005
167 | rotation_lr: 0.001
168 | percent_dense: 0.01
169 | densification_interval: 100
170 | opacity_reset_interval: 5000000
171 | densify_from_iter: 500
172 | densify_until_iter: 20000
173 | densify_grad_threshold: 0.0002
174 | camera_opt_config:
175 | target: src.trainer.optim.CameraQuatOptimizer
176 | params:
177 | camera_rotation_lr: 1.0e-05
178 | camera_translation_lr: 1.0e-06
179 | camera_lr_warmup: 0
180 | total_steps: 20000
181 | dynamic:
182 | target: src.trainer.rodygs_dynamic.DynTrainer
183 | params:
184 | loss_config:
185 | target: src.trainer.losses.MultiLoss
186 | params:
187 | loss_configs:
188 | - name: d_ssim
189 | weight: 0.2
190 | target: src.trainer.losses.SSIMLoss
191 | params:
192 | mode: all
193 | - name: l1
194 | weight: 0.8
195 | target: src.trainer.losses.L1Loss
196 | params:
197 | mode: all
198 | - name: motion_l1_reg
199 | weight: 0.01
200 | start: 0
201 | target: src.trainer.losses.MotionL1Loss
202 | - name: motion_sparsity
203 | weight: 0.002
204 | start: 0
205 | target: src.trainer.losses.MotionSparsityLoss
206 | - name: global_pearson_depth
207 | weight: 0.05
208 | target: src.trainer.losses.GlobalPearsonDepthLoss
209 | start: 0
210 | params:
211 | mode: all
212 | - name: local_pearson_depth
213 | weight: 0.15
214 | target: src.trainer.losses.LocalPearsonDepthLoss
215 | start: 0
216 | params:
217 | box_p: 128
218 | p_corr: 0.5
219 | mode: all
220 | - name: rigidity
221 | weight: 0.5
222 | freq: 5
223 | start: 0
224 | target: src.trainer.losses.RigidityLoss
225 | params:
226 | mode:
227 | - distance_preserving
228 | - surface
229 | K: 8
230 | - name: motion_basis_reg
231 | weight: 0.1
232 | start: 0
233 | target: src.trainer.losses.MotionBasisRegularizaiton
234 | params:
235 | transl_degree: 0
236 | rot_degree: 0
237 | freq_div_mode: cum_exponential
238 | num_iterations: 20000
239 | position_lr_init: 0.00016
240 | position_lr_final: 1.6e-06
241 | position_lr_delay_mult: 0.01
242 | position_lr_max_steps: 20000
243 | feature_lr: 0.0025
244 | opacity_lr: 0.05
245 | scaling_lr: 0.001
246 | rotation_lr: 0.001
247 | percent_dense: 0.01
248 | densification_interval: 100
249 | opacity_reset_interval: 5000000
250 | densify_from_iter: 500
251 | densify_until_iter: 15000
252 | densify_grad_threshold: 0.0002
253 | deform_warmup_steps: 0
254 | deform_lr_init: 0.0016
255 | deform_lr_final: 0.00016
256 | deform_lr_delay_mult: 0.01
257 | deform_lr_max_steps: 20000
258 | motion_coeff_lr: 0.00016
259 | camera_opt_config:
260 | target: src.trainer.optim.CameraQuatOptimizer
261 | params:
262 | camera_rotation_lr: 0.0
263 | camera_translation_lr: 0.0
264 | camera_lr_warmup: 0
265 | total_steps: 20000
--------------------------------------------------------------------------------
/configs/train/train_tnt.yaml:
--------------------------------------------------------------------------------
1 | static_data:
2 | target: src.data.datamodule.GSDataModule
3 | params:
4 | dirpath: ./
5 | train_dset_config:
6 | target: src.data.datamodule.LazyDataReader
7 | params:
8 | camera_config:
9 | target: src.data.utils.FixedCamera
10 | pose_reader:
11 | target: src.data.asset_readers.MASt3RCameraReader
12 | params:
13 | mast3r_expname: swin_noloop_000
14 | mast3r_img_res: 512
15 | depth_reader:
16 | target: src.data.asset_readers.DepthAnythingReader
17 | params:
18 | split: train
19 | motion_mask_reader:
20 | target: src.data.asset_readers.TAMMaskReader
21 | params:
22 | split: train
23 | train_dloader_config:
24 | target: src.data.dataloader.PermutationSingleDataLoader
25 | params:
26 | num_iterations: 150000
27 | test_dset_config:
28 | target: src.data.datamodule.DataReader
29 | params:
30 | camera_config:
31 | target: src.data.utils.FixedCamera
32 | pose_reader:
33 | target: src.data.asset_readers.GTCameraReader
34 | test_dloader_config:
35 | target: src.data.dataloader.SequentialSingleDataLoader
36 | params: {}
37 | train_pcd_reader_config:
38 | target: src.data.asset_readers.MASt3RPCDReader
39 | params:
40 | mast3r_expname: swin_noloop_000
41 | mode: static
42 | num_limit_points: 120000
43 | train_pose_reader_config:
44 | target: src.data.datamodule.DataReader
45 | params:
46 | camera_config:
47 | target: src.data.utils.FixedCamera
48 | pose_reader:
49 | target: src.data.asset_readers.GTCameraReader
50 | normalize_cams: false
51 | static_model:
52 | target: src.model.rodygs_static.StaticRoDyGS
53 | params:
54 | sh_degree: 3
55 | isotropic: false
56 | static_calibrated_pose_reader:
57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader
58 | params:
59 | mast3r_expname: swin_noloop_000
60 | mast3r_img_res: 512
61 | dynamic_data:
62 | target: src.data.datamodule.GSDataModule
63 | params:
64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/
65 | train_dset_config:
66 | target: src.data.datamodule.LazyDataReader
67 | params:
68 | camera_config:
69 | target: src.data.utils.FixedCamera
70 | pose_reader:
71 | target: src.data.asset_readers.MASt3RCameraReader
72 | params:
73 | mast3r_expname: swin_noloop_000
74 | mast3r_img_res: 512
75 | depth_reader:
76 | target: src.data.asset_readers.DepthAnythingReader
77 | params:
78 | split: train
79 | motion_mask_reader:
80 | target: src.data.asset_readers.TAMMaskReader
81 | params:
82 | split: train
83 | train_dloader_config:
84 | target: src.data.dataloader.PermutationSingleDataLoader
85 | params:
86 | num_iterations: 150000
87 | test_dset_config:
88 | target: src.data.datamodule.DataReader
89 | params:
90 | camera_config:
91 | target: src.data.utils.FixedCamera
92 | pose_reader:
93 | target: src.data.asset_readers.GTCameraReader
94 | test_dloader_config:
95 | target: src.data.dataloader.SequentialSingleDataLoader
96 | params: {}
97 | train_pcd_reader_config:
98 | target: src.data.asset_readers.MASt3RPCDReader
99 | params:
100 | mast3r_expname: swin_noloop_000
101 | mode: dynamic
102 | num_limit_points: 120000
103 | test_transform_fname: test_transforms.json
104 | train_pose_reader_config:
105 | target: src.data.datamodule.DataReader
106 | params:
107 | camera_config:
108 | target: src.data.utils.FixedCamera
109 | pose_reader:
110 | target: src.data.asset_readers.GTCameraReader
111 | normalize_cams: false
112 | dynamic_model:
113 | target: src.model.rodygs_dynamic.DynRoDyGS
114 | params:
115 | sh_degree: 3
116 | deform_netwidth: 128
117 | deform_t_emb_multires: 26
118 | deform_t_log_sampling: false
119 | num_basis: 16
120 | isotropic: false
121 | inverse_motion: true
122 | trainer:
123 | target: src.trainer.rodygs.RoDyGSTrainer
124 | params:
125 | log_freq: 50
126 | sh_up_start_iteration: 15000
127 | sh_up_period: 1000
128 | static:
129 | target: src.trainer.rodygs_static.ThreeDGSTrainer
130 | params:
131 | loss_config:
132 | target: src.trainer.losses.MultiLoss
133 | params:
134 | loss_configs:
135 | - name: d_ssim
136 | weight: 0.2
137 | target: src.trainer.losses.SSIMLoss
138 | params:
139 | mode: all
140 | - name: l1
141 | weight: 0.8
142 | target: src.trainer.losses.L1Loss
143 | params:
144 | mode: all
145 | - name: global_pearson_depth
146 | weight: 0.001
147 | target: src.trainer.losses.GlobalPearsonDepthLoss
148 | start: 0
149 | params:
150 | mode: all
151 | - name: local_pearson_depth
152 | weight: 0.01
153 | target: src.trainer.losses.LocalPearsonDepthLoss
154 | start: 0
155 | params:
156 | box_p: 128
157 | p_corr: 0.5
158 | mode: all
159 | num_iterations: 150000
160 | position_lr_init: 0.00016
161 | position_lr_final: 1.6e-06
162 | position_lr_delay_mult: 0.01
163 | position_lr_max_steps: 150000
164 | feature_lr: 0.0025
165 | opacity_lr: 0.05
166 | scaling_lr: 0.005
167 | rotation_lr: 0.001
168 | percent_dense: 0.01
169 | densification_interval: 100
170 | opacity_reset_interval: 5000000
171 | densify_from_iter: 500
172 | densify_until_iter: 150000
173 | densify_grad_threshold: 0.0002
174 | camera_opt_config:
175 | target: src.trainer.optim.CameraQuatOptimizer
176 | params:
177 | camera_rotation_lr: 5.0e-06
178 | camera_translation_lr: 5.0e-05
179 | camera_lr_warmup: 3000
180 | total_steps: 150000
181 | dynamic:
182 | target: src.trainer.rodygs_dynamic.DynTrainer
183 | params:
184 | loss_config:
185 | target: src.trainer.losses.MultiLoss
186 | params:
187 | loss_configs:
188 | - name: d_ssim
189 | weight: 0.2
190 | target: src.trainer.losses.SSIMLoss
191 | params:
192 | mode: all
193 | - name: l1
194 | weight: 0.8
195 | target: src.trainer.losses.L1Loss
196 | params:
197 | mode: all
198 | - name: motion_l1_reg
199 | weight: 0.01
200 | start: 0
201 | target: src.trainer.losses.MotionL1Loss
202 | - name: motion_sparsity
203 | weight: 0.002
204 | start: 0
205 | target: src.trainer.losses.MotionSparsityLoss
206 | - name: global_pearson_depth
207 | weight: 0.05
208 | target: src.trainer.losses.GlobalPearsonDepthLoss
209 | start: 0
210 | params:
211 | mode: all
212 | - name: local_pearson_depth
213 | weight: 0.15
214 | target: src.trainer.losses.LocalPearsonDepthLoss
215 | start: 0
216 | params:
217 | box_p: 128
218 | p_corr: 0.5
219 | mode: all
220 | - name: rigidity
221 | weight: 0.5
222 | freq: 5
223 | start: 0
224 | target: src.trainer.losses.RigidityLoss
225 | params:
226 | mode:
227 | - distance_preserving
228 | - surface
229 | K: 8
230 | - name: motion_basis_reg
231 | weight: 0.1
232 | start: 0
233 | target: src.trainer.losses.MotionBasisRegularizaiton
234 | params:
235 | transl_degree: 0
236 | rot_degree: 0
237 | freq_div_mode: cum_exponential
238 | num_iterations: 150000
239 | position_lr_init: 0.00016
240 | position_lr_final: 1.6e-06
241 | position_lr_delay_mult: 0.01
242 | position_lr_max_steps: 150000
243 | feature_lr: 0.0025
244 | opacity_lr: 0.05
245 | scaling_lr: 0.001
246 | rotation_lr: 0.001
247 | percent_dense: 0.01
248 | densification_interval: 100
249 | opacity_reset_interval: 5000000
250 | densify_from_iter: 500
251 | densify_until_iter: 15000
252 | densify_grad_threshold: 0.0002
253 | deform_warmup_steps: 0
254 | deform_lr_init: 0.0016
255 | deform_lr_final: 0.00016
256 | deform_lr_delay_mult: 0.01
257 | deform_lr_max_steps: 150000
258 | motion_coeff_lr: 0.00016
259 | camera_opt_config:
260 | target: src.trainer.optim.CameraQuatOptimizer
261 | params:
262 | camera_rotation_lr: 0.0
263 | camera_translation_lr: 0.0
264 | camera_lr_warmup: 0
265 | total_steps: 150000
--------------------------------------------------------------------------------
/licences/LICENSE_gs.md:
--------------------------------------------------------------------------------
1 | Gaussian-Splatting License
2 | ===========================
3 |
4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5 | The *Software* is in the process of being registered with the Agence pour la Protection des
6 | Programmes (APP).
7 |
8 | The *Software* is still being developed by the *Licensor*.
9 |
10 | *Licensor*'s goal is to allow the research community to use, test and evaluate
11 | the *Software*.
12 |
13 | ## 1. Definitions
14 |
15 | *Licensee* means any person or entity that uses the *Software* and distributes
16 | its *Work*.
17 |
18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII
19 |
20 | *Software* means the original work of authorship made available under this
21 | License ie gaussian-splatting.
22 |
23 | *Work* means the *Software* and any additions to or derivative works of the
24 | *Software* that are made available under this License.
25 |
26 |
27 | ## 2. Purpose
28 | This license is intended to define the rights granted to the *Licensee* by
29 | Licensors under the *Software*.
30 |
31 | ## 3. Rights granted
32 |
33 | For the above reasons Licensors have decided to distribute the *Software*.
34 | Licensors grant non-exclusive rights to use the *Software* for research purposes
35 | to research users (both academic and industrial), free of charge, without right
36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37 | and/or evaluation purposes only.
38 |
39 | Subject to the terms and conditions of this License, you are granted a
40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41 | publicly display, publicly perform and distribute its *Work* and any resulting
42 | derivative works in any form.
43 |
44 | ## 4. Limitations
45 |
46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47 | so under this License, (b) you include a complete copy of this License with
48 | your distribution, and (c) you retain without modification any copyright,
49 | patent, trademark, or attribution notices that are present in the *Work*.
50 |
51 | **4.2 Derivative Works.** You may specify that additional or different terms apply
52 | to the use, reproduction, and distribution of your derivative works of the *Work*
53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
54 | Section 2 applies to your derivative works, and (b) you identify the specific
55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56 | this License (including the redistribution requirements in Section 3.1) will
57 | continue to apply to the *Work* itself.
58 |
59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60 | users explicitly acknowledge having received from Licensors all information
61 | allowing to appreciate the adequacy between of the *Software* and their needs and
62 | to undertake all necessary precautions for its execution and use.
63 |
64 | **4.4** The *Software* is provided both as a compiled library file and as source
65 | code. In case of using the *Software* for a publication or other results obtained
66 | through the use of the *Software*, users are strongly encouraged to cite the
67 | corresponding publications as explained in the documentation of the *Software*.
68 |
69 | ## 5. Disclaimer
70 |
71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
84 |
85 | ## 6. Files subject to permissive licenses
86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license.
87 |
88 | Title: pytorch-ssim\
89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\
90 | Copyright Evan Su, 2017\
91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)
--------------------------------------------------------------------------------
/scripts/img2video.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import imageio
12 | import os
13 | import argparse
14 |
15 |
16 | def convert_png_to_video(data_dir):
17 |
18 | imgs_path = os.path.join(data_dir, "train")
19 | images = [img for img in os.listdir(imgs_path) if img.endswith((".png"))]
20 | images.sort()
21 | output_video = os.path.join(data_dir, "train.mp4")
22 |
23 | if not images:
24 | print("No images!")
25 | else:
26 | with imageio.get_writer(
27 | output_video, fps=30, codec="libx264", macro_block_size=1
28 | ) as writer:
29 | for image in images:
30 | writer.append_data(imageio.imread(os.path.join(imgs_path, image)))
31 | print(f"Generate {output_video} ...")
32 |
33 |
34 | if __name__ == "__main__":
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument("--data_dir", required=True, type=str)
37 | args = parser.parse_args()
38 |
39 | convert_png_to_video(args.data_dir)
40 |
--------------------------------------------------------------------------------
/scripts/iphone2format.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 |
12 | import os
13 | import numpy as np
14 | import argparse
15 | from PIL import Image
16 | import json
17 | import math
18 |
19 |
20 | def focal2fov(focal, pixels):
21 | return 2 * math.atan(pixels / (2 * focal))
22 |
23 |
24 | def convert2format(data_dir, output_dir, resolution):
25 |
26 | split_path = os.path.join(data_dir, "splits")
27 | with open(os.path.join(split_path, "train.json"), "r") as fp:
28 | train_json = json.load(fp)
29 |
30 | img_paths = []
31 | cam_paths = []
32 | for frame_name in train_json["frame_names"]:
33 | if resolution == 1:
34 | img_paths.append(os.path.join(data_dir, "rgb", "1x", frame_name + ".png"))
35 | else:
36 | img_paths.append(os.path.join(data_dir, "rgb", "2x", frame_name + ".png"))
37 | cam_paths.append(os.path.join(data_dir, "camera", frame_name + ".json"))
38 |
39 | os.makedirs(output_dir, exist_ok=True)
40 | train_outimgdir = os.path.join(output_dir, "train")
41 | os.makedirs(train_outimgdir, exist_ok=True)
42 | test_outimgdir = os.path.join(output_dir, "test")
43 | os.makedirs(test_outimgdir, exist_ok=True)
44 |
45 | with open(cam_paths[0], "r") as fp:
46 | cam_0 = json.load(fp)
47 | train_transforms = dict()
48 | train_transforms["camera_angle_x"] = (
49 | focal2fov(cam_0["focal_length"], 720) * 180 / math.pi
50 | )
51 | train_transforms["camera_angle_y"] = (
52 | focal2fov(cam_0["focal_length"], 960) * 180 / math.pi
53 | )
54 | train_transforms["frames"] = []
55 | test_transforms = dict()
56 | test_transforms["camera_angle_x"] = (
57 | focal2fov(cam_0["focal_length"], 720) * 180 / math.pi
58 | )
59 | test_transforms["camera_angle_y"] = (
60 | focal2fov(cam_0["focal_length"], 960) * 180 / math.pi
61 | )
62 | test_transforms["frames"] = []
63 |
64 | # Get train & test images from only the train cameras
65 | frame_idx = 0
66 | train_id = 0
67 | test_id = 0
68 | for img, cam in zip(img_paths, cam_paths):
69 | frame = dict()
70 | image_fname = f"rgba_{frame_idx:05d}.png"
71 | image = Image.open(img)
72 |
73 | frame["time"] = frame_idx / len(img_paths)
74 | frame["file_path"] = image_fname # tmp
75 | frame["width"] = int(720 / resolution)
76 | frame["height"] = int(960 / resolution)
77 |
78 | with open(cam, "r") as fp:
79 | cam = json.load(fp)
80 | w2c_rot = np.array(cam["orientation"])
81 | c2w_rot = np.linalg.inv(w2c_rot)
82 |
83 | c2w = np.eye(4)
84 | c2w[:3, :3] = c2w_rot
85 | c2w[:3, 3] = np.array(cam["position"])
86 | frame["transform_matrix"] = c2w.tolist()
87 |
88 | if (frame_idx + 4) % 8 == 0:
89 | image_fname = f"rgba_{test_id:05d}.png"
90 | frame["file_path"] = os.path.join("test", image_fname)
91 | test_transforms["frames"].append(frame)
92 | test_out_image_fpath = os.path.join(test_outimgdir, image_fname)
93 | image.save(test_out_image_fpath)
94 | test_id += 1
95 | else:
96 | image_fname = f"rgba_{train_id:05d}.png"
97 | frame["file_path"] = os.path.join("train", image_fname)
98 | train_transforms["frames"].append(frame)
99 | train_out_image_fpath = os.path.join(train_outimgdir, image_fname)
100 | image.save(train_out_image_fpath)
101 | train_id += 1
102 |
103 | frame_idx += 1
104 |
105 | with open(os.path.join(output_dir, "train_transforms.json"), "w") as fp:
106 | json.dump(train_transforms, fp, indent=4)
107 | with open(os.path.join(output_dir, "test_transforms.json"), "w") as fp:
108 | json.dump(test_transforms, fp, indent=4)
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--data_dir", required=True, type=str)
114 | parser.add_argument("--output_dir", required=True, type=str)
115 | parser.add_argument("--resolution", type=int, default=1)
116 | args = parser.parse_args()
117 | assert args.resolution in [1, 2], "assume resolution is 1x or 2x"
118 |
119 | convert2format(args.data_dir, args.output_dir, args.resolution)
120 |
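121 | # Example invocation (hypothetical paths, shown for illustration only):
122 | #   python scripts/iphone2format.py --data_dir data/iphone/scene --output_dir data/iphone_format/scene --resolution 2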
--------------------------------------------------------------------------------
/scripts/kubricmrig2format.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import argparse
12 | import json
13 | from pathlib import Path
14 | from PIL import Image
15 |
16 | import numpy as np
17 |
18 |
19 | # used for camera (local) axis conversion: OpenGL -> OpenCV (flips the y and z axes)
20 | gl_matrix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
21 | 
22 | # used for world coordinate conversion: Blender/Kubric -> OpenCV
23 | opencv_matrix = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
24 |
25 |
26 | def quaternion_to_rotation_matrix(quaternion):
27 | q = np.array(quaternion, dtype=np.float64)
28 | norm = np.linalg.norm(q)
29 | if norm == 0:
30 | return np.eye(3)
31 | q /= norm
32 | w, x, y, z = q
33 | rotation_matrix = np.array(
34 | [
35 | [1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y],
36 | [2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x],
37 | [2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2],
38 | ]
39 | )
40 | return rotation_matrix
41 |
42 |
43 | def kubric2opencv(extrinsic):
44 | extrinsic = opencv_matrix @ extrinsic @ gl_matrix
45 |
46 | return extrinsic
47 |
48 |
49 | def convert2format(args):
50 | input_dir = Path(args.input_dir)
51 | train_dirpath = input_dir.joinpath("train")
52 | test_dirpath = input_dir.joinpath("test")
53 |
54 | outdirpath = Path(args.output_dir)
55 | outdirpath.mkdir(exist_ok=True, parents=True)
56 |
57 | for split, dirpath in zip(
58 | ["train", "val", "test"], [train_dirpath, test_dirpath, test_dirpath]
59 | ):
60 | metadata = dirpath.joinpath("metadata.json")
61 | outimgdir = outdirpath.joinpath(split)
62 | outimgdir.mkdir(exist_ok=True, parents=True)
63 |
64 | with open(metadata.as_posix(), "r") as fp:
65 | metadata = json.load(fp)
66 |
67 | transforms = dict()
68 |
69 | H, W = metadata["metadata"]["resolution"]
70 | fov = np.rad2deg(metadata["camera"]["field_of_view"])
71 | camera_angle_x, camera_angle_y = fov, fov
72 |
73 | transforms["camera_angle_x"] = camera_angle_x
74 | transforms["camera_angle_y"] = camera_angle_y
75 |
76 | transforms["frames"] = []
77 | num_frames = metadata["metadata"]["num_frames"]
78 |
79 | if split == "train":
80 | iterator = range(num_frames)
81 | elif split == "val":
82 | iterator = np.array(range(num_frames))[::10]
83 | else:
84 | # use the rest of the frames for testing
85 | iterator = np.array([idx for idx in range(num_frames) if idx % 10 != 0])
86 |
87 | for frame_idx in iterator:
88 | frame = dict()
89 |
90 | image_fname = f"rgba_{frame_idx:05d}.png"
91 | image_fpath = dirpath.joinpath(image_fname)
92 | image = Image.open(image_fpath)
93 | out_image_fpath = outimgdir.joinpath(image_fname)
94 | image.save(out_image_fpath)
95 |
96 | frame["time"] = frame_idx / num_frames
97 | frame["file_path"] = Path(split, image_fname).as_posix()
98 | frame["width"] = W
99 | frame["height"] = H
100 |
101 | c2w = np.eye(4)
102 | quaternion = metadata["camera"]["quaternions"][frame_idx]
103 | c2w[:3, :3] = quaternion_to_rotation_matrix(quaternion)
104 | c2w[:3, 3] = np.array(metadata["camera"]["positions"][frame_idx])
105 |
106 | # change coordinate system to opencv format
107 | # world coordinate : blender -> opencv
108 | # camera (local) type : opengl -> opencv
109 | c2w = kubric2opencv(c2w)
110 |
111 | frame["transform_matrix"] = c2w.tolist()
112 | transforms["frames"].append(frame)
113 |
114 | with open(outdirpath.joinpath(f"{split}_transforms.json"), "w") as fp:
115 | json.dump(transforms, fp, indent=4)
116 |
117 |
118 | if __name__ == "__main__":
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument(
121 | "--input_dir", required=True, type=str, help="path to generated kubric-mrig"
122 | )
123 | parser.add_argument(
124 | "--output_dir", required=True, type=str, help="path to store converted assets"
125 | )
126 | args = parser.parse_args()
127 |
128 | convert2format(args)
129 |
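130 | # Example invocation (hypothetical paths, shown for illustration only):
131 | #   python scripts/kubricmrig2format.py --input_dir data/kubric_mrig/scene --output_dir data/kubric_format/scene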
--------------------------------------------------------------------------------
/scripts/nvidia2format.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import os
12 | import numpy as np
13 | import torch
14 | import glob
15 | import argparse
16 | from PIL import Image
17 | import json
18 | import math
19 |
20 | img_downsample = 2
21 |
22 |
23 | def focal2fov(focal, pixels):
24 | return 2 * math.atan(pixels / (2 * focal))
25 |
26 |
27 | def convert2format(train_dir, test_dir, output_dir):
28 |
29 | train_poses_bounds = np.load(
30 | os.path.join(train_dir, "poses_bounds.npy")
31 | ) # (N_images, 17)
32 | train_image_paths = sorted(glob.glob(os.path.join(train_dir, "images_2/*")))
33 | test_image_paths = sorted(glob.glob(os.path.join(test_dir, "*.png")))
34 |
35 | train_poses = train_poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
36 | H, W, focal = train_poses[0, :, -1] # original intrinsics, same for all images
37 |
38 | H, W, focal = (
39 | H / img_downsample,
40 | W / img_downsample,
41 | focal / img_downsample,
42 | ) # original images are 2x downscaled (same as RodynRF training setup)
43 |
44 | print(f"img_h={H}, img_w={W}")
45 | FoVx = focal2fov(focal, W) * 180 / math.pi
46 | FoVy = focal2fov(focal, H) * 180 / math.pi
47 |
48 |     # Original poses have rotations in the "down right back" (LLFF) convention; convert to "right down front" (OpenCV)
49 | train_poses = np.concatenate(
50 | [train_poses[..., 1:2], train_poses[..., :1], -train_poses[..., 2:4]], axis=-1
51 | )
52 | padding = np.array([0, 0, 0, 1]).reshape(1, 1, 4)
53 | train_poses = np.concatenate(
54 | [train_poses, np.tile(padding, (train_poses.shape[0], 1, 1))], axis=-2
55 | )
56 |
57 | train_outimgdir = os.path.join(output_dir, "train")
58 | os.makedirs(train_outimgdir, exist_ok=True)
59 | test_outimgdir = os.path.join(output_dir, "test")
60 | os.makedirs(test_outimgdir, exist_ok=True)
61 |
62 | train_transforms = dict()
63 | train_transforms["camera_angle_x"] = FoVx
64 | train_transforms["camera_angle_y"] = FoVy
65 | train_transforms["frames"] = []
66 | test_transforms = dict()
67 | test_transforms["camera_angle_x"] = FoVx
68 | test_transforms["camera_angle_y"] = FoVy
69 | test_transforms["frames"] = []
70 |
71 | for i, train_dirpath in enumerate(train_image_paths):
72 |
73 | frame = dict()
74 | image_fname = f"rgba_{i:05d}.png"
75 | image = Image.open(train_dirpath)
76 | out_image_fpath = os.path.join(train_outimgdir, image_fname)
77 | image.save(out_image_fpath)
78 |
79 | frame["time"] = i / len(train_image_paths)
80 | frame["file_path"] = os.path.join("train", image_fname)
81 | frame["width"] = int(W)
82 | frame["height"] = int(H)
83 |
84 | c2w = train_poses[i]
85 | frame["transform_matrix"] = c2w.tolist()
86 | train_transforms["frames"].append(frame)
87 |
88 |         # all test cameras share the c2w of the first train camera
89 | if i == 0:
90 | for j, test_dirpath in enumerate(test_image_paths):
91 | frame = dict()
92 | image_fname = f"rgba_{j:05d}.png"
93 | image = Image.open(test_dirpath)
94 | out_image_fpath = os.path.join(test_outimgdir, image_fname)
95 | image.save(out_image_fpath)
96 |
97 | frame["time"] = j / len(test_image_paths)
98 | frame["file_path"] = os.path.join("test", image_fname)
99 | frame["width"] = int(W)
100 | frame["height"] = int(H)
101 |
102 | frame["transform_matrix"] = c2w.tolist()
103 | test_transforms["frames"].append(frame)
104 | with open(os.path.join(output_dir, "test_transforms.json"), "w") as fp:
105 | json.dump(test_transforms, fp, indent=4)
106 |
107 | with open(os.path.join(output_dir, "train_transforms.json"), "w") as fp:
108 | json.dump(train_transforms, fp, indent=4)
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--train_dir", required=True, type=str)
114 | parser.add_argument("--test_dir", required=True, type=str)
115 | parser.add_argument("--output_dir", required=True, type=str)
116 | args = parser.parse_args()
117 |
118 | convert2format(args.train_dir, args.test_dir, args.output_dir)
119 |
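120 | # Example invocation (hypothetical paths, shown for illustration only):
121 | #   python scripts/nvidia2format.py --train_dir data/nvidia/scene/dense --test_dir data/nvidia/scene/gt --output_dir data/nvidia_format/scene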
--------------------------------------------------------------------------------
/scripts/run_depthanything.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import argparse
12 | import cv2
13 | import glob
14 | import matplotlib
15 | import numpy as np
16 | import os
17 | import torch
18 |
19 | from thirdparty.depth_anything_v2.depth_anything_v2.dpt import DepthAnythingV2
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser(description="Depth Anything V2")
24 |
25 | parser.add_argument("--img-path", type=str)
26 | parser.add_argument("--input-size", type=int, default=518)
27 | parser.add_argument("--outdir", type=str, default="./vis_depth")
28 |
29 | parser.add_argument(
30 | "--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"]
31 | )
32 | parser.add_argument("--encoder-path", type=str, required=True)
33 |
34 | parser.add_argument(
35 | "--pred-only",
36 | dest="pred_only",
37 | action="store_true",
38 |         help="save only the depth prediction (no side-by-side visualization)",
39 | )
40 | parser.add_argument(
41 | "--grayscale",
42 | dest="grayscale",
43 | action="store_true",
44 | help="do not apply colorful palette",
45 | )
46 | parser.add_argument(
47 | "--raw-depth",
48 | dest="raw_depth",
49 | action="store_true",
50 |         help="additionally save the raw depth map as a .npy file",
51 | )
52 |
53 | args = parser.parse_args()
54 |
55 | DEVICE = (
56 | "cuda"
57 | if torch.cuda.is_available()
58 | else "mps" if torch.backends.mps.is_available() else "cpu"
59 | )
60 |
61 | model_configs = {
62 | "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
63 | "vitb": {
64 | "encoder": "vitb",
65 | "features": 128,
66 | "out_channels": [96, 192, 384, 768],
67 | },
68 | "vitl": {
69 | "encoder": "vitl",
70 | "features": 256,
71 | "out_channels": [256, 512, 1024, 1024],
72 | },
73 | "vitg": {
74 | "encoder": "vitg",
75 | "features": 384,
76 | "out_channels": [1536, 1536, 1536, 1536],
77 | },
78 | }
79 |
80 | depth_anything = DepthAnythingV2(**model_configs[args.encoder])
81 | depth_anything.load_state_dict(torch.load(args.encoder_path, map_location="cpu"))
82 | depth_anything = depth_anything.to(DEVICE).eval()
83 |
84 | if os.path.isfile(args.img_path):
85 | if args.img_path.endswith("txt"):
86 | with open(args.img_path, "r") as f:
87 | filenames = f.read().splitlines()
88 | else:
89 | filenames = [args.img_path]
90 | else:
91 | filenames = glob.glob(os.path.join(args.img_path, "**/*"), recursive=True)
92 |
93 | os.makedirs(args.outdir, exist_ok=True)
94 |
95 | cmap = matplotlib.colormaps.get_cmap("Spectral_r")
96 |
97 | for k, filename in enumerate(filenames):
98 | print(f"Progress {k+1}/{len(filenames)}: {filename}")
99 |
100 | raw_image = cv2.imread(filename)
101 |
102 | depth = depth_anything.infer_image(raw_image, args.input_size)
103 |
104 | if args.raw_depth:
105 | np.save(
106 | os.path.join(
107 | args.outdir,
108 | os.path.splitext(os.path.basename(filename))[0] + ".npy",
109 | ),
110 | depth,
111 | )
112 |
113 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
114 | depth = depth.astype(np.uint8)
115 |
116 | if args.grayscale:
117 | depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
118 | else:
119 | depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
120 |
121 | if args.pred_only:
122 | cv2.imwrite(
123 | os.path.join(
124 | args.outdir,
125 | os.path.splitext(os.path.basename(filename))[0] + ".png",
126 | ),
127 | depth,
128 | )
129 | else:
130 | split_region = np.ones((raw_image.shape[0], 50, 3), dtype=np.uint8) * 255
131 | combined_result = cv2.hconcat([raw_image, split_region, depth])
132 |
133 | cv2.imwrite(
134 | os.path.join(
135 | args.outdir,
136 | os.path.splitext(os.path.basename(filename))[0] + ".png",
137 | ),
138 | combined_result,
139 | )
140 |
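141 | # Example invocation (hypothetical paths; the Depth Anything V2 checkpoint must be downloaded separately):
142 | #   python scripts/run_depthanything.py --img-path data/scene/train --encoder vitl --encoder-path checkpoints/depth_anything_v2_vitl.pth --outdir data/scene/depth_anything --pred-only --raw-depth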
--------------------------------------------------------------------------------
/scripts/run_mast3r/depth_preprocessor/get_pcd.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os, cv2
3 | import numpy as np
4 | from PIL import Image
5 | from glob import glob
6 | from argparse import ArgumentParser
7 |
8 | from scripts.run_mast3r.depth_preprocessor.utils import resize_to_mast3r
9 | from scripts.run_mast3r.depth_preprocessor.pcd_utils import unproject_depth
10 |
11 |
12 | def mast3r_unprojection(
13 | data_dir,
14 | maskpaths,
15 | imagepaths,
16 | img_w,
17 | img_h,
18 | skip_dynamic,
19 | static_dst_dir="static",
20 | dynamic_dst_dir="dynamic",
21 | depth_dir="depth",
22 | ):
23 |     assert data_dir is not None, "Please provide a data directory"
24 |
25 | # Load the global params
26 | pkl_path = os.path.join(data_dir, "global_params.pkl")
27 | with open(pkl_path, "rb") as pkl_file:
28 | data = pickle.load(pkl_file)
29 |
30 | focal = data["focals"][0]
31 | depth_max = data["max_depths"][0]
32 | depths = np.array(data["depths"])
33 |
34 | depths *= depth_max
35 | depths = np.clip(depths, 0, depth_max)
36 |
37 | static_pcd_path = os.path.join(data_dir, static_dst_dir)
38 | os.makedirs(static_pcd_path, exist_ok=True)
39 | depth_dst_dir = os.path.join(data_dir, depth_dir)
40 | os.makedirs(depth_dst_dir, exist_ok=True)
41 |
42 | # skip masked unprojection (for tanks and temples)
43 | if skip_dynamic:
44 | for i, imgpath in enumerate(list(imagepaths)):
45 | img = cv2.imread(imgpath)
46 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
47 | img = resize_to_mast3r(img, dst_size=None)
48 | extrinsic = data["cam2worlds"][i]
49 | depth = depths[i].reshape(img.shape[:2])
50 | static_pcd_path = os.path.join(data_dir, static_dst_dir)
51 | static_pcd_save_path = os.path.join(static_pcd_path, f"{i:04d}_static.ply")
52 | unproject_depth(
53 | focal,
54 | extrinsic,
55 | img,
56 | depth,
57 | export_path=static_pcd_save_path,
58 | mask=None,
59 | )
60 | depth_save_path = os.path.join(depth_dst_dir, f"{i:05}_depth.npy")
61 | np.save(depth_save_path, depth.reshape(img_h, img_w))
62 | return
63 |
64 | dynamic_pcd_path = os.path.join(data_dir, dynamic_dst_dir)
65 | os.makedirs(dynamic_pcd_path, exist_ok=True)
66 |
67 | loader = list(zip(imagepaths, maskpaths))
68 | for i, (imgpath, maskpath) in enumerate(loader):
69 | img = cv2.imread(imgpath)
70 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
71 | img = resize_to_mast3r(img, dst_size=None)
72 |
73 | mask = cv2.imread(maskpath, cv2.IMREAD_GRAYSCALE)
74 | mask = resize_to_mast3r(mask, dst_size=None)
75 | mask = mask > 0
76 |
77 | extrinsic = data["cam2worlds"][i]
78 | depth = depths[i].reshape(img.shape[:2])
79 |
80 | dynamic_pcd_save_path = os.path.join(dynamic_pcd_path, f"{i:04d}_dynamic.ply")
81 | unproject_depth(
82 | focal, extrinsic, img, depth, export_path=dynamic_pcd_save_path, mask=mask
83 | )
84 | static_pcd_save_path = os.path.join(static_pcd_path, f"{i:04d}_static.ply")
85 | unproject_depth(
86 | focal, extrinsic, img, depth, export_path=static_pcd_save_path, mask=~mask
87 | )
88 |
89 | depth_save_path = os.path.join(depth_dst_dir, f"{i:05}_depth.npy")
90 | np.save(depth_save_path, depth.reshape(img_h, img_w))
91 |
92 |
93 | def check_all_masks_false(maskpaths):
94 | for maskpath in maskpaths:
95 | # Load the image and convert it to a numpy array
96 | mask = np.array(Image.open(maskpath))
97 | if np.any(mask):
98 | return False
99 |
100 | return True
101 |
102 |
103 | if __name__ == "__main__":
104 | parser = ArgumentParser()
105 | parser.add_argument("--datadir", type=str, required=True)
106 | parser.add_argument("--mast3r_expname", type=str, required=True)
107 | parser.add_argument("--mask_name", type=str)
108 |
109 | args = parser.parse_args()
110 | skip_dynamic = False
111 |
112 | datadir = args.datadir
113 | mast3r_exp_dir = os.path.join(datadir, "mast3r_opt", args.mast3r_expname)
114 |
115 | # load mast3r format image
116 | imagepaths = sorted(glob(f"{datadir}/train/*.png"))
117 |
118 | # load dynamic mask
119 | maskpaths = sorted(glob(f"{datadir}/{args.mask_name}/*.png"))
120 | if len(maskpaths) == 0:
121 | maskpaths = sorted(
122 | glob(f"{datadir}/{args.mask_name}/*.jpg")
123 |         )  # masks should ideally be PNG (lossless); fall back to JPG if needed
124 | print(f"\nload mask from {maskpaths[0]} ~ {maskpaths[-1]}\n")
125 |
126 | if check_all_masks_false(maskpaths):
127 | skip_dynamic = True
128 | print("\nNo Dynamic regions found. Skip dynamic unprojection\n")
129 |
130 | with open(os.path.join(mast3r_exp_dir, "global_params.pkl"), "rb") as f:
131 | global_params = pickle.load(f)
132 | mast3r_img_h = len(global_params["masks"][0])
133 | mast3r_img_w = len(global_params["masks"][0][0])
134 |
135 | mast3r_unprojection(
136 | mast3r_exp_dir, maskpaths, imagepaths, mast3r_img_w, mast3r_img_h, skip_dynamic
137 | )
138 |
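139 | # Example invocation (hypothetical paths; expects <datadir>/mast3r_opt/<expname>/global_params.pkl from run_mast3r):
140 | #   python -m scripts.run_mast3r.depth_preprocessor.get_pcd --datadir data/scene --mast3r_expname exp_000 --mask_name tam_mask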
--------------------------------------------------------------------------------
/scripts/run_mast3r/depth_preprocessor/pcd_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import trimesh
4 |
5 | from functools import lru_cache
6 |
7 |
8 | @lru_cache
9 | def mask110(device, dtype):
10 | return torch.tensor((1, 1, 0), device=device, dtype=dtype)
11 |
12 |
13 | def proj3d(inv_K, pixels, z):
14 | if pixels.shape[-1] == 2:
15 | pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
16 | return z.unsqueeze(-1) * (
17 | pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype)
18 | )
19 |
20 |
21 | # from dust3r point transformation
22 | def geotrf(Trf, pts, ncol=None, norm=False):
23 | """Apply a geometric transformation to a list of 3-D points.
24 |
25 |     Trf: 3x3 or 4x4 projection matrix (typically a homography)
26 |     pts: numpy/torch/tuple of coordinates. Shape must be (..., 2) or (..., 3)
27 | 
28 |     ncol: int. number of columns of the result (2 or 3)
29 |     norm: float. if != 0, the result is projected onto the z=norm plane.
30 | 
31 |     Returns an array of transformed points with `ncol` columns.
32 | """
33 | assert Trf.ndim >= 2
34 | if isinstance(Trf, np.ndarray):
35 | pts = np.asarray(pts)
36 | elif isinstance(Trf, torch.Tensor):
37 | pts = torch.as_tensor(pts, dtype=Trf.dtype)
38 |
39 | # adapt shape if necessary
40 | output_reshape = pts.shape[:-1]
41 | ncol = ncol or pts.shape[-1]
42 |
43 | # optimized code
44 | if (
45 | isinstance(Trf, torch.Tensor)
46 | and isinstance(pts, torch.Tensor)
47 | and Trf.ndim == 3
48 | and pts.ndim == 4
49 | ):
50 | d = pts.shape[3]
51 | if Trf.shape[-1] == d:
52 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
53 | elif Trf.shape[-1] == d + 1:
54 | pts = (
55 | torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
56 | + Trf[:, None, None, :d, d]
57 | )
58 | else:
59 | raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
60 | else:
61 | if Trf.ndim >= 3:
62 | n = Trf.ndim - 2
63 | assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
64 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
65 |
66 | if pts.ndim > Trf.ndim:
67 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
68 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
69 | elif pts.ndim == 2:
70 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
71 | pts = pts[:, None, :]
72 |
73 | if pts.shape[-1] + 1 == Trf.shape[-1]:
74 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
75 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
76 | elif pts.shape[-1] == Trf.shape[-1]:
77 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
78 | pts = pts @ Trf
79 | else:
80 | pts = Trf @ pts.T
81 | if pts.ndim >= 2:
82 | pts = pts.swapaxes(-1, -2)
83 |
84 | if norm:
85 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
86 | if norm != 1:
87 | pts *= norm
88 |
89 | res = pts[..., :ncol].reshape(*output_reshape, ncol)
90 | return res
91 |
92 |
93 | def unproject_depth(focal, extrinsic, image, depth, export_path=None, mask=None):
94 | h, w, _ = image.shape
95 |
96 | # extract color
97 | pixels = torch.tensor(np.mgrid[:w, :h].T.reshape(-1, 2))
98 | K = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
99 | points = proj3d(
100 | torch.tensor(np.linalg.inv(K)), pixels, torch.tensor(depth.reshape(-1))
101 | )
102 |
103 | color = image.reshape(-1, 3)
104 |
105 | if mask is not None:
106 | mask = mask.ravel()
107 | color = color[mask]
108 | points = points[mask.ravel()]
109 |
110 | # convert to world coordinates
111 | # points = np.stack([x, y, z], axis=-1)
112 | points = geotrf(torch.tensor(extrinsic), torch.tensor(points))
113 | points = np.asarray(points).reshape(-1, 3)
114 |
115 | # save to ply with img color
116 | if export_path is not None:
117 | mesh = trimesh.PointCloud(vertices=points, colors=color)
118 | mesh.export(export_path)
119 |
120 | return points
121 |
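122 | # Minimal usage sketch for unproject_depth (illustrative values, not taken from the pipeline):
123 | #   img = np.zeros((288, 512, 3), dtype=np.uint8)      # (H, W, 3) RGB image
124 | #   depth = np.ones((288, 512), dtype=np.float32)      # (H, W) depth map at the same resolution
125 | #   pts = unproject_depth(focal=500.0, extrinsic=np.eye(4), image=img, depth=depth,
126 | #                         export_path=None, mask=None)  # returns (H*W, 3) world-space points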
--------------------------------------------------------------------------------
/scripts/run_mast3r/depth_preprocessor/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import PIL
4 |
5 | from functools import lru_cache
6 |
7 | import torch
8 | import numpy as np
9 | import trimesh
10 |
11 |
12 | def _resize_pil_image(img, long_edge_size):
13 | S = max(img.size)
14 | if S > long_edge_size:
15 | interp = PIL.Image.LANCZOS
16 | elif S <= long_edge_size:
17 | interp = PIL.Image.BICUBIC
18 | new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
19 | return img.resize(new_size, interp)
20 |
21 |
22 | def resize_to_mast3r(img, dst_size=None, size=512, square_ok=False):
23 | # convert to PIL image
24 | if isinstance(img, str):
25 | img = PIL.Image.open(img)
26 | elif isinstance(img, np.ndarray):
27 | img = PIL.Image.fromarray(img)
28 | elif not isinstance(img, PIL.Image.Image):
29 | raise ValueError(f"Invalid input type: {type(img)}")
30 |
31 | W1, H1 = img.size
32 |
33 | # resize long side to 512
34 | img = _resize_pil_image(img, size)
35 |
36 | W, H = img.size
37 | cx, cy = W // 2, H // 2
38 |
39 | halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
40 | if not (square_ok) and W == H:
41 | halfh = 3 * halfw / 4
42 | img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
43 |
44 | # convert to numpy array
45 | if dst_size is not None and dst_size != (W, H):
46 | img = img.resize(dst_size, PIL.Image.BICUBIC)
47 |
48 | img = np.array(img)
49 | return img
50 |
51 |
52 | @lru_cache
53 | def mask110(device, dtype):
54 | return torch.tensor((1, 1, 0), device=device, dtype=dtype)
55 |
56 |
57 | def proj3d(inv_K, pixels, z):
58 | if pixels.shape[-1] == 2:
59 | pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
60 | return z.unsqueeze(-1) * (
61 | pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype)
62 | )
63 |
64 |
65 | # from dust3r point transformation
66 | def geotrf(Trf, pts, ncol=None, norm=False):
67 | """Apply a geometric transformation to a list of 3-D points.
68 |
69 |     Trf: 3x3 or 4x4 projection matrix (typically a homography)
70 |     pts: numpy/torch/tuple of coordinates. Shape must be (..., 2) or (..., 3)
71 | 
72 |     ncol: int. number of columns of the result (2 or 3)
73 |     norm: float. if != 0, the result is projected onto the z=norm plane.
74 | 
75 |     Returns an array of transformed points with `ncol` columns.
76 | """
77 | assert Trf.ndim >= 2
78 | if isinstance(Trf, np.ndarray):
79 | pts = np.asarray(pts)
80 | elif isinstance(Trf, torch.Tensor):
81 | pts = torch.as_tensor(pts, dtype=Trf.dtype)
82 |
83 | # adapt shape if necessary
84 | output_reshape = pts.shape[:-1]
85 | ncol = ncol or pts.shape[-1]
86 |
87 | # optimized code
88 | if (
89 | isinstance(Trf, torch.Tensor)
90 | and isinstance(pts, torch.Tensor)
91 | and Trf.ndim == 3
92 | and pts.ndim == 4
93 | ):
94 | d = pts.shape[3]
95 | if Trf.shape[-1] == d:
96 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
97 | elif Trf.shape[-1] == d + 1:
98 | pts = (
99 | torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
100 | + Trf[:, None, None, :d, d]
101 | )
102 | else:
103 | raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
104 | else:
105 | if Trf.ndim >= 3:
106 | n = Trf.ndim - 2
107 | assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
108 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
109 |
110 | if pts.ndim > Trf.ndim:
111 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
112 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
113 | elif pts.ndim == 2:
114 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
115 | pts = pts[:, None, :]
116 |
117 | if pts.shape[-1] + 1 == Trf.shape[-1]:
118 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
119 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
120 | elif pts.shape[-1] == Trf.shape[-1]:
121 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
122 | pts = pts @ Trf
123 | else:
124 | pts = Trf @ pts.T
125 | if pts.ndim >= 2:
126 | pts = pts.swapaxes(-1, -2)
127 |
128 | if norm:
129 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
130 | if norm != 1:
131 | pts *= norm
132 |
133 | res = pts[..., :ncol].reshape(*output_reshape, ncol)
134 | return res
135 |
136 |
137 | def unproject_depth(focal, extrinsic, image, depth, export_path=None, mask=None):
138 | h, w, _ = image.shape
139 |
140 | # extract color
141 | pixels = torch.tensor(np.mgrid[:w, :h].T.reshape(-1, 2))
142 | K = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]])
143 | points = proj3d(
144 | torch.tensor(np.linalg.inv(K)), pixels, torch.tensor(depth.reshape(-1))
145 | )
146 |
147 | color = image.reshape(-1, 3)
148 |
149 | if mask is not None:
150 | mask = mask.ravel()
151 | color = color[mask]
152 | points = points[mask.ravel()]
153 |
154 | # convert to world coordinates
155 | # points = np.stack([x, y, z], axis=-1)
156 | points = geotrf(torch.tensor(extrinsic), torch.tensor(points))
157 | points = np.asarray(points).reshape(-1, 3)
158 |
159 | # save to ply with img color
160 | if export_path is not None:
161 | mesh = trimesh.PointCloud(vertices=points, colors=color)
162 | mesh.export(export_path)
163 |
164 | return points
165 |
--------------------------------------------------------------------------------
/scripts/run_mast3r/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | #
5 | # --------------------------------------------------------
6 | # sparse gradio demo functions
7 | # --------------------------------------------------------
8 | import math
9 | import os
10 | from glob import glob
11 | import numpy as np
12 | import trimesh
13 | import copy
14 | from scipy.spatial.transform import Rotation
15 | from pathlib import Path
16 | import argparse
17 | import pickle
18 | import math
19 | import sys
20 | import cv2
21 |
22 | sys.path.insert(0, "./thirdparty/mast3r")
23 |
24 | from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
25 | from mast3r.cloud_opt.utils.schedules import cosine_schedule
26 | from dust3r.image_pairs import make_pairs
27 | from dust3r.utils.image import load_images
28 | from dust3r.utils.device import to_numpy
29 | import matplotlib.pyplot as pl
30 | from mast3r.model import AsymmetricMASt3R
31 | import torch
32 |
33 | torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
34 |
35 |
36 | def get_sparse_optim_args(config):
37 | optim_args = {
38 | "lr1": 0.2,
39 | "niter1": 500,
40 | "lr2": 0.02,
41 | "niter2": 500,
42 | "opt_pp": True,
43 | "opt_depth": True,
44 | "schedule": "cosine",
45 | "depth_mode": "add",
46 | "exp_depth": False,
47 | "lora_depth": False,
48 | "shared_intrinsics": True,
49 | "device": "cuda",
50 | "dtype": torch.float32,
51 | "matching_conf_thr": 5.0,
52 | "loss_dust3r_w": 0.01,
53 | }
54 |
55 | # update with config
56 | for key, value in config.items():
57 | if key in optim_args:
58 | optim_args[key] = value
59 | optim_args["opt_depth"] = "depth" in config["optim_level"]
60 | optim_args["schedule"] = cosine_schedule
61 |
62 | return optim_args
63 |
64 |
65 | def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type, winsize=10):
66 | num_files = len(inputfiles) if inputfiles is not None else 1
67 | max_winsize, min_winsize = 1, 1
68 | if scenegraph_type == "swin":
69 | if win_cyclic:
70 | max_winsize = max(1, math.ceil((num_files - 1) / 2))
71 | else:
72 | max_winsize = num_files - 1
73 | elif scenegraph_type == "logwin":
74 | if win_cyclic:
75 | half_size = math.ceil((num_files - 1) / 2)
76 | max_winsize = max(1, math.ceil(math.log(half_size, 2)))
77 | else:
78 | max_winsize = max(1, math.ceil(math.log(num_files, 2)))
79 | winsize = min(max_winsize, max(min_winsize, winsize))
80 |
81 | return winsize, win_cyclic
82 |
83 |
84 | def get_geometries_from_scene(scene, clean_depth, mask_sky, min_conf_thr):
85 |
86 | # post processes - clean depth, mask sky
87 | if mask_sky:
88 | scene = scene.mask_sky()
89 |
90 | # get optimized values from scene
91 | rgbimg, focals = scene.imgs, scene.get_focals().cpu()
92 |
93 | print("obtained focal length : ", focals)
94 | cams2world = scene.get_im_poses().cpu()
95 |
96 | # 3D pointcloud from depthmap, poses and intrinsics
97 | pts3d, depths, confs = to_numpy(
98 | scene.get_dense_pts3d(clean_depth=clean_depth)
99 | ) # ref mast3r/cloud_opt SparseGA() class
100 | msk = to_numpy([c > min_conf_thr for c in confs])
101 |
102 | # get normalized depthmaps
103 | depths_max = max([d.max() for d in depths]) # type: ignore
104 | depths = [d / depths_max for d in depths]
105 |
106 | return rgbimg, pts3d, msk, focals, cams2world, depths, depths_max
107 |
108 |
109 | def points_to_pct(pts, color, extrinsic, save_path=None):
110 | rot = np.eye(4)
111 | rot[:3, :3] = Rotation.from_euler("y", [np.deg2rad(180)]).as_matrix()
112 |
113 | pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=color.reshape(-1, 3))
114 | pct.apply_transform(np.linalg.inv(extrinsic))
115 |
116 | if save_path is not None:
117 | os.makedirs(Path(save_path).parent, exist_ok=True)
118 | pct.export(save_path)
119 | return pct
120 |
121 |
122 | def save_each_geometry(
123 | outdir,
124 | imgs,
125 | pts3d,
126 | mask,
127 | focals,
128 | cams2world,
129 | imgname=None,
130 | depths=None,
131 | depths_max=None,
132 | filter_pct=True,
133 | ):
134 | print(
135 | f"len(pts3d) : {len(pts3d)}, len(mask) : {len(mask)}, len(imgs) : {len(imgs)}, len(cams2world) : {len(cams2world)}"
136 | )
137 | assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
138 | extrinsic = cams2world[0].detach().numpy()
139 | pts3d, imgs, focals, cams2world = (
140 | to_numpy(pts3d),
141 | to_numpy(imgs),
142 | to_numpy(focals),
143 | to_numpy(cams2world),
144 | )
145 |
146 | if imgname is not None:
147 | outdir = os.path.join(outdir, imgname["data"])
148 |
149 | # original : dict_keys(['focal', 'cam2worlds', 'pct2worlds', 'pointcloud_paths', 'max_depths', 'depths'])
150 | global_dict = {
151 | "focals": [],
152 | "cam2worlds": [],
153 | "pointcloud_paths": [],
154 | "max_depths": [],
155 | "depths": [],
156 | "masks": [],
157 | }
158 |
159 | args_iter = (
160 | zip(pts3d, imgs, mask, focals, cams2world)
161 | if depths is None
162 | else zip(pts3d, imgs, mask, focals, cams2world, depths)
163 | )
164 | for i, arguments in enumerate(args_iter):
165 | if depths is None:
166 | points, img, point_mask, focal, cam2world = arguments
167 | depth, depths_max = None, None
168 | else:
169 | points, img, point_mask, focal, cam2world, depth = arguments
170 |
171 | if filter_pct:
172 | pts = points[point_mask.ravel()].reshape(-1, 3)
173 | col = img[point_mask].reshape(-1, 3)
174 | valid_msk = np.isfinite(pts.sum(axis=1))
175 | pts, col = pts[valid_msk], col[valid_msk]
176 | else:
177 | pts, col = points.reshape(-1, 3), img.reshape(-1, 3)
178 |
179 | filename = (
180 | f"pointcloud_{i:04d}.ply"
181 | if imgname is None
182 | else f"{imgname['img_nums'][i]:04d}_pointcloud_{i:04d}.ply"
183 | )
184 | points_to_pct(
185 | pts,
186 | col,
187 | np.eye(extrinsic.shape[0]),
188 | save_path=os.path.join(outdir, filename),
189 | )
190 | cam_params = {
191 | "focal": focal,
192 | "cam2world": cam2world,
193 | "c2w_original": cam2world,
194 | "depth": depth,
195 | "depth_max": depths_max,
196 | "base_extrinsic": extrinsic,
197 | "imgname": imgname,
198 | }
199 | maskdir = os.path.join(outdir, "masks")
200 | os.makedirs(maskdir, exist_ok=True)
201 |
202 | # save mask to png
203 | re_mask = point_mask.reshape(depth.shape)
204 | mask = np.zeros_like(re_mask, dtype=np.uint8)
205 | mask[re_mask == 0] = 255
206 | cv2.imwrite(os.path.join(maskdir, f"{i:04d}.png"), mask)
207 |
208 | with open(os.path.join(outdir, filename.replace(".ply", ".pkl")), "wb") as f:
209 | pickle.dump(cam_params, f)
210 |
211 | global_dict["focals"].append(focal)
212 | global_dict["cam2worlds"].append(cam2world)
213 | global_dict["pointcloud_paths"].append(os.path.join(outdir, filename))
214 | global_dict["max_depths"].append(depths_max)
215 | global_dict["depths"].append(depth)
216 | global_dict["masks"].append(point_mask)
217 |
218 | return global_dict
219 |
220 |
221 | def get_reconstructed_scene(
222 | outdir,
223 | cache_dir,
224 | model,
225 | device,
226 | image_size,
227 | filelist,
228 | optim_level,
229 | lr1,
230 | niter1,
231 | lr2,
232 | niter2,
233 | min_conf_thr,
234 | matching_conf_thr,
235 | mask_sky,
236 | clean_depth,
237 | scenegraph_type,
238 | winsize,
239 | win_cyclic,
240 | refid,
241 | TSDF_thresh,
242 | shared_intrinsics,
243 | filter_pct,
244 | optim_args,
245 | **kw,
246 | ):
247 | """
248 |     From a list of images, run MASt3R inference and sparse global alignment,
249 |     then export the per-view point clouds, depths, masks, and camera parameters.
250 | """
251 | imgs = load_images(filelist, size=image_size, verbose=True)
252 | if len(imgs) == 1:
253 | imgs = [imgs[0], copy.deepcopy(imgs[0])]
254 | imgs[1]["idx"] = 1
255 | filelist = [filelist[0], filelist[0] + "_2"]
256 |
257 | scene_graph_params = [scenegraph_type]
258 | if scenegraph_type in ["swin", "logwin"]:
259 | scene_graph_params.append(str(winsize))
260 | elif scenegraph_type == "oneref":
261 | scene_graph_params.append(str(refid))
262 | if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
263 | scene_graph_params.append("noncyclic")
264 | scene_graph = "-".join(scene_graph_params)
265 | pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
266 | if optim_level == "coarse":
267 | niter2 = 0
268 |
269 | # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
270 | os.makedirs(cache_dir, exist_ok=True)
271 | scene = sparse_global_alignment(filelist, pairs, cache_dir, model, **optim_args)
272 |
273 | rgbimg, pts3d, msk, focals, cams2world, depths, depths_max = (
274 | get_geometries_from_scene(scene, clean_depth, mask_sky, min_conf_thr)
275 | )
276 | global_dict = save_each_geometry(
277 | os.path.join(outdir, "op_results"),
278 | rgbimg,
279 | pts3d,
280 | msk,
281 | focals,
282 | cams2world,
283 | imgname=None,
284 | depths=depths,
285 | depths_max=depths_max,
286 | filter_pct=filter_pct,
287 | )
288 |
289 | with open(os.path.join(outdir, "global_params.pkl"), "wb") as f:
290 | pickle.dump(global_dict, f)
291 |
292 |
293 | def main():
294 |
295 | parser = argparse.ArgumentParser()
296 | parser.add_argument("--input_dir", type=str, default="data", help="input directory")
297 | parser.add_argument(
298 | "--output_dir", type=str, default="output", help="output directory"
299 | )
300 | parser.add_argument("--exp_name", type=str, default="exp", help="experiment name")
301 | parser.add_argument(
302 | "--ckpt",
303 | type=str,
304 | default="checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
305 | help="mast3r ckpt",
306 | )
307 | parser.add_argument(
308 | "--cache_dir", type=str, default="optim_cache", help="cache directory"
309 | )
310 |
311 | args = parser.parse_args()
312 |
313 | device = "cuda"
314 | model = AsymmetricMASt3R.from_pretrained(args.ckpt).to(device)
315 |
316 | filelist = sorted(glob(os.path.join(args.input_dir, "*.png")))
317 |
318 | cache_dir = os.path.join(
319 | args.cache_dir,
320 | f"{os.path.basename(os.path.dirname(args.input_dir))}_{np.random.randint(1e6):05d}",
321 | )
322 | outdir = os.path.join(args.output_dir, args.exp_name + "_000")
323 |
324 | data = {"cache_dir": "optim_cache"}
325 | optimization = {
326 | "device": "cuda",
327 | "image_size": 512,
328 | "shared_intrinsics": True,
329 | "win_cyclic": False,
330 | "lr1": 0.07,
331 | "niter1": 500,
332 | "lr2": 0.014,
333 | "niter2": 200,
334 | "optim_level": "refine+depth",
335 | "scenegraph_type": "swin",
336 | "winsize": 10,
337 | "refid": 0,
338 | "TSDF_thresh": 0,
339 | "schedule": "cosine",
340 | "min_conf_thr": 1.5,
341 | "matching_conf_thr": 5,
342 | "mask_sky": False,
343 | "clean_depth": True,
344 | "filter_pct": True,
345 | }
346 | arg_dict = dict()
347 |
348 | arg_dict.update(data)
349 | arg_dict.update(optimization)
350 | arg_dict.update({"filelist": filelist, "model": model, "schedule": cosine_schedule})
351 | arg_dict.update(
352 | {"outdir": outdir, "input_dir": args.input_dir, "cache_dir": cache_dir}
353 | )
354 | winsize, win_cyclic = set_scenegraph_options(filelist, False, 0, "swin", 10)
355 | arg_dict.update({"winsize": winsize, "win_cyclic": win_cyclic})
356 | arg_dict.update({"optim_args": get_sparse_optim_args(optimization)})
357 |
358 | get_reconstructed_scene(**arg_dict)
359 |
360 |
361 | if __name__ == "__main__":
362 | main()
363 |
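364 | # Example invocation (hypothetical paths; --ckpt defaults to the public MASt3R metric checkpoint name):
365 | #   python scripts/run_mast3r/run.py --input_dir data/scene/train --output_dir data/scene/mast3r_opt --exp_name exp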
--------------------------------------------------------------------------------
/scripts/tam_npy2png.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 |
12 | import numpy as np
13 | from PIL import Image
14 | import os
15 | import argparse
16 |
17 |
18 | def convert_npy_to_png(input_dir, output_dir):
19 | os.makedirs(os.path.join(output_dir, "tam_mask"), exist_ok=True)
20 |
21 | for file_name in os.listdir(input_dir):
22 | if file_name.endswith(".npy"):
23 | input_file_path = os.path.join(input_dir, file_name)
24 | output_file_name = f"{int(os.path.splitext(file_name)[0]):06d}" + ".png"
25 | output_tam_path = os.path.join(output_dir, "tam_mask", output_file_name)
26 |
27 | try:
28 | motion_mask = np.load(input_file_path)
29 | binary_image = (motion_mask * 255).astype(np.uint8)
30 |
31 | Image.fromarray(binary_image).save(output_tam_path)
32 | print(f"Converted: {input_file_path} -> {output_tam_path}")
33 | except Exception as e:
34 | print(f"Failed to convert {input_file_path}: {e}")
35 |
36 |
37 | if __name__ == "__main__":
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--npy_dir", required=True, type=str)
40 | parser.add_argument("--output_dir", required=True, type=str)
41 | args = parser.parse_args()
42 |
43 | convert_npy_to_png(args.npy_dir, args.output_dir)
44 |
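45 | # Example invocation (hypothetical paths):
46 | #   python scripts/tam_npy2png.py --npy_dir data/scene/tam_npy --output_dir data/scene
47 | # Masks are written to <output_dir>/tam_mask/<6-digit frame index>.png.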
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/src/data/asset_readers.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import os
12 | from pathlib import Path
13 | from PIL import Image
14 | import json
15 | import pickle
16 |
17 | import numpy as np
18 | import torch
19 | from torchvision.transforms import ToTensor
20 |
21 | from src.utils.graphic_utils import focal2fov, quaternion_to_matrix
22 | from src.utils.point_utils import uniform_sample, merge_pcds
23 | from .utils import fetchPly
24 |
25 |
26 | class GTCameraReader:
27 | # To read full poses (for evaluation)
28 |
29 | def __init__(self, dirpath, fname, **kwargs):
30 |
31 | poses = []
32 |
33 | with open(os.path.join(dirpath, fname)) as json_file:
34 | contents = json.load(json_file)
35 | fovx = contents["camera_angle_x"]
36 | for frame in contents["frames"]:
37 | c2w = np.array(frame["transform_matrix"], dtype=np.float32)
38 | poses.append(torch.from_numpy(c2w))
39 | self._poses = np.array(torch.stack(poses))
40 | self._fovx = np.deg2rad(fovx)
41 |
42 | def get_poses(self, idx=None):
43 | if idx is None:
44 | return self._poses
45 | else:
46 | return self._poses[idx]
47 |
48 | def get_fovx(self, idx):
49 | return self._fovx
50 |
51 |
52 | class DepthAnythingReader:
53 | prefix: str = "depth_anything"
54 |
55 | def __init__(self, **kwargs):
56 | pass
57 |
58 | def __call__(self, dirpath, basename):
59 | asset_path = Path(dirpath).joinpath(self.prefix)
60 | base_name_wo_ext = os.path.splitext(basename)[0]
61 |         base_name_npy = base_name_wo_ext + ".npy"
62 |         file_path = asset_path.joinpath(base_name_npy)
63 | depth = -torch.from_numpy(np.load(file_path.as_posix())).unsqueeze(0)
64 | return (depth - depth.min()) / (depth.max() - depth.min())
65 |
66 |
67 | class TAMMaskReader:
68 | prefix: str = "tam_mask"
69 | to_tensor = ToTensor()
70 |
71 | def __init__(self, split, resolution=1):
72 | assert split in ["train", "val", "test"]
73 |         self.prefix = "tam_mask"
74 | self.resolution = resolution
75 |
76 | def __call__(self, dirpath, basename):
77 | asset_path = Path(dirpath).joinpath(self.prefix)
78 | base_name_wo_ext = os.path.splitext(basename)[0]
79 | rgb_idx = base_name_wo_ext.split("_")[-1]
80 | rgb_idx = rgb_idx.zfill(6)
81 |         base_name_img = f"{rgb_idx}.jpg"
82 |         file_path = asset_path.joinpath(base_name_img)
83 |         if not file_path.exists():
84 |             base_name_img = f"{rgb_idx}.png"
85 |             file_path = asset_path.joinpath(base_name_img)
86 | if self.resolution != 1:
87 | img = Image.open(file_path)
88 | w, h = img.size
89 | new_size = (w // self.resolution, h // self.resolution)
90 | return self.to_tensor(img.resize(new_size, Image.NEAREST)) > 0
91 | else:
92 | return self.to_tensor(Image.open(file_path)) > 0
93 |
94 |
95 | class Test_MASt3RFovCameraReader:
96 | # To read full poses and trained fov (for evaluation)
97 |
98 | def __init__(self, dirpath, fname, mast3r_expname, mast3r_img_res, **kwargs):
99 |
100 | self.dirname = "mast3r_opt"
101 | poses = []
102 |
103 | # read gt test poses from test_transforms.json
104 | with open(os.path.join(dirpath, fname)) as json_file:
105 | contents = json.load(json_file)
106 | for frame in contents["frames"]:
107 | c2w = frame["transform_matrix"]
108 | poses.append(c2w)
109 | self._poses = np.array(poses, dtype=np.float32)
110 |
111 | # read fov from dust3r init
112 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl")
113 | with open(pkl_path.as_posix(), "rb") as pkl_file:
114 | data = pickle.load(pkl_file)
115 |
116 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res)
117 |
118 | def get_poses(self, idx=None):
119 | if idx is None:
120 | return self._poses
121 | else:
122 | return self._poses[idx]
123 |
124 | def get_fovx(self, idx):
125 | return self._fovx
126 |
127 |
128 | class MASt3RCameraReader:
129 |
130 | dirname = "mast3r_opt"
131 |
132 | def __init__(self, dirpath, mast3r_expname, mast3r_img_res, **kwargs):
133 |
134 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl")
135 | with open(pkl_path.as_posix(), "rb") as pkl_file:
136 | data = pickle.load(pkl_file)
137 |
138 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res)
139 | self._poses = data["cam2worlds"]
140 |
141 | def get_poses(self, idx):
142 | return self._poses[idx]
143 |
144 | def get_fovx(self, idx):
145 | return self._fovx
146 |
147 |
148 | class MASt3R_CKPTCameraReader:
149 |
150 | dirname = "mast3r_opt"
151 |
152 | def __init__(self, dirpath, ckpt_path, mast3r_expname, mast3r_img_res, **kwargs):
153 |
154 | # read poses from ckpt
155 | ckpt = torch.load(ckpt_path)
156 | c2w_rot = quaternion_to_matrix(ckpt[0]["camera"]["R_c2ws_quat"])
157 | c2w_trans = ckpt[0]["camera"]["T_c2ws"].unsqueeze(-1)
158 | c2w = torch.cat((c2w_rot, c2w_trans), dim=-1)
159 | bottom_row = (
160 | torch.tensor([[0, 0, 0, 1]], dtype=c2w_rot.dtype)
161 | .expand(c2w.shape[0], -1, -1)
162 | .cuda()
163 | )
164 | self._poses = torch.cat((c2w, bottom_row), dim=1).cpu().numpy()
165 |
166 | # read fovx from dust3r init
167 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl")
168 | with open(pkl_path.as_posix(), "rb") as pkl_file:
169 | data = pickle.load(pkl_file)
170 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res)
171 |
172 | def get_poses(self, idx):
173 | return self._poses[idx]
174 |
175 | def get_fovx(self, idx):
176 | return self._fovx
177 |
178 |
179 | class MASt3RPCDReader:
180 |
181 | path = "./op_results"
182 | dirname = "mast3r_opt"
183 | dynamic_path = "./dynamic"
184 | static_path = "./static"
185 |
186 | skip_dynamic = False
187 |
188 | def __init__(
189 | self,
190 | dirpath,
191 | mast3r_expname,
192 | mode=None,
193 | downsample_ratio=0.1,
194 | num_limit_points=None,
195 | **kwargs,
196 | ):
197 |
198 | if not Path(dirpath, self.dirname, mast3r_expname, self.dynamic_path).exists():
199 | static_pcd_paths = Path(
200 | dirpath, self.dirname, mast3r_expname, self.static_path
201 | )
202 | static_pcd_file = [
203 | pth.as_posix() for pth in sorted(static_pcd_paths.glob("*.ply"))
204 | ][0]
205 | static_pcd = fetchPly(static_pcd_file)
206 | self.pcd = static_pcd
207 | self.skip_dynamic = True
208 | return
209 |
210 | if mode == "dynamic":
211 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.dynamic_path)
212 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))]
213 | elif mode == "static":
214 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.static_path)
215 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))]
216 | else:
217 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.path)
218 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))]
219 | pcds = [fetchPly(pcd_file) for pcd_file in pcd_files]
220 |
221 | json_file = Path(dirpath, "train_transforms.json")
222 | with open(json_file) as json_file:
223 | contents = json.load(json_file)
224 | times = [frame["time"] for frame in contents["frames"]]
225 |
226 | for idx, pcd in enumerate(pcds):
227 | time = times[idx]
228 | num_points = len(pcd.points)
229 | pcd.time = np.ones(num_points) * time
230 |
231 | merged_pcds = merge_pcds(pcds)
232 |
233 | if num_limit_points is not None:
234 | print(f"override downsample_ratio with num_vertices: {num_limit_points}")
235 | downsample_ratio = min(num_limit_points / len(merged_pcds.points), 1.0)
236 |
237 | self.pcd = uniform_sample(merged_pcds, downsample_ratio)
238 |
239 | def __call__(self):
240 | return self.pcd, self.skip_dynamic
241 |
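242 | # Usage sketch (hypothetical arguments, for illustration only):
243 | #   reader = MASt3RPCDReader(dirpath="data/scene", mast3r_expname="exp_000", mode="static", num_limit_points=120_000)
244 | #   pcd, skip_dynamic = reader()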
--------------------------------------------------------------------------------
/src/data/dataloader.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import numpy as np
12 | from torch.utils.data import Dataset
13 |
14 |
15 | class WarmupDataLoader:
16 |     # Pick a random viewpoint from the currently selectable viewpoints.
17 |     # The set of selectable viewpoints grows during warmup until `register_all_frame` is called.
18 | is_incremental: bool = True
19 | warmup_rate = 20
20 |
21 |     def __init__(self, dataset: Dataset, init_num_frames: int, num_iterations: int, return_idx=False):
22 |         self.dataset = dataset
23 |         self.curr_num_frames = init_num_frames
24 |         self.total_len = len(dataset)
25 |         self.return_idx = return_idx
26 |         self.num_iterations = num_iterations
27 |         num_list = int(np.ceil(self.num_iterations / self.total_len))
28 |         cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)]
29 |         self.idx_list = [i for idx, i in enumerate(np.concatenate(cat_list)) if idx % self.warmup_rate == 0]
30 |
31 | def __iter__(self):
32 | for random_idx in self.idx_list:
33 | if self.return_idx:
34 | yield random_idx, self.dataset[random_idx]
35 | else:
36 | yield self.dataset[random_idx]
37 |
38 | def register_all_frame(self):
39 | num_list = int(np.ceil(self.num_iterations / self.total_len))
40 | cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)]
41 | remainder = self.num_iterations % self.total_len
42 | if remainder != 0:
43 | cat_list[-1] = cat_list[-1][:remainder]
44 | self.idx_list = np.concatenate(cat_list)
45 |
46 |
47 | class PermutationSingleDataLoader:
48 | is_incremental: bool = False
49 |
50 | def __init__(self, dataset: Dataset, num_iterations: int, return_idx: bool = False):
51 | self.dataset = dataset
52 | self.total_len = len(dataset)
53 | self.return_idx = return_idx
54 | self.num_iterations = num_iterations
55 | self.idx_list = self.get_permuted_idx_list()
56 |
57 | def get_permuted_idx_list(self):
58 | num_list = int(np.ceil(self.num_iterations / self.total_len))
59 | cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)]
60 | remainder = self.num_iterations % self.total_len
61 | if remainder != 0:
62 | cat_list[-1] = cat_list[-1][:remainder]
63 | idx_list = np.concatenate(cat_list)
64 | return idx_list
65 |
66 | def __iter__(self):
67 | for random_idx in self.idx_list:
68 | if self.return_idx:
69 | yield random_idx, self.dataset[random_idx]
70 | else:
71 | yield self.dataset[random_idx]
72 |
73 |
74 | class SequentialSingleDataLoader:
75 | # Call all viewpoints in a sequential manner.
76 | is_incremental: bool = False
77 |
78 | def __init__(self, dataset: Dataset, return_idx: bool = False, **kwargs):
79 | self.dataset = dataset
80 | self.total_len = len(dataset)
81 | self.return_idx = return_idx
82 |
83 | def __len__(self):
84 | return self.total_len
85 |
86 | def __iter__(self):
87 | for idx in range(self.total_len):
88 | if self.return_idx:
89 | yield idx, self.dataset[idx]
90 | else:
91 | yield self.dataset[idx]
92 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import math
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | from plyfile import PlyData
17 |
18 | from src.utils.graphic_utils import (
19 | getProjectionMatrix,
20 | getWorld2View2,
21 | matrix_to_quaternion,
22 | quaternion_to_matrix,
23 | )
24 | from src.utils.point_utils import BasicPointCloud
25 |
26 |
27 | class CameraInterface(nn.Module):
28 |
29 | znear = 0.01
30 | zfar = 100.0
31 | trans = 0.0
32 | scale = 1.0
33 |
34 | def __init__(
35 | self, R, T, image, image_name, time, depth, normal, motion_mask, cam_idx
36 | ):
37 | super(CameraInterface, self).__init__()
38 |
39 | self.R = R
40 | self.T = T
41 | self.image_name = image_name
42 | self.image_width = image.shape[2]
43 | self.image_height = image.shape[1]
44 | self.cam_idx = cam_idx
45 | self.register_buffer("time", torch.tensor(time))
46 | self.register_buffer("original_image", image)
47 |
48 | if depth is not None:
49 | self.register_buffer("depth", depth)
50 | else:
51 | self.depth = None
52 |
53 | if normal is not None:
54 | self.register_buffer("normal", normal)
55 | else:
56 | self.normal = None
57 |
58 | if motion_mask is not None:
59 | self.register_buffer("motion_mask", motion_mask)
60 | else:
61 | self.motion_mask = None
62 |
63 |
64 | class FixedCamera(CameraInterface):
65 | def __init__(
66 | self,
67 | R,
68 | T,
69 | FoVx,
70 | FoVy,
71 | image,
72 | image_name,
73 | time,
74 | depth,
75 | normal,
76 | motion_mask,
77 | cam_idx,
78 | ):
79 |
80 | super(FixedCamera, self).__init__(
81 | R=R,
82 | T=T,
83 | image=image,
84 | image_name=image_name,
85 | time=time,
86 | depth=depth,
87 | normal=normal,
88 | motion_mask=motion_mask,
89 | cam_idx=cam_idx,
90 | )
91 | self.FoVx = FoVx
92 | self.FoVy = FoVy
93 |
94 | world_view_transform = torch.from_numpy(
95 | getWorld2View2(R, T, self.trans, self.scale)
96 | )
97 | projection_matrix = getProjectionMatrix(
98 | self.znear, self.zfar, self.FoVx, self.FoVy
99 | )
100 |
101 | self.register_buffer("world_view_transform", world_view_transform)
102 | self.register_buffer("projection_matrix", projection_matrix)
103 |
104 |
105 | class FixedCameraTorch:
106 |
107 | znear = 0.01
108 | zfar = 100.0
109 | trans = 0.0
110 | scale = 1.0
111 |
112 | def __init__(
113 | self,
114 | R_quat_c2w,
115 | T_c2w,
116 | FoVx,
117 | FoVy,
118 | image,
119 | image_name,
120 | time,
121 | depth,
122 | max_depth,
123 | normal,
124 | motion_mask,
125 | cam_idx,
126 | ):
127 |
128 | self.image_name = image_name
129 | self.image_width = image.shape[2]
130 | self.image_height = image.shape[1]
131 | self.time = torch.tensor(time)
132 | self.original_image = image
133 |
134 | self.depth = depth
135 | self.max_depth = max_depth
136 | self.normal = normal
137 | self.motion_mask = motion_mask
138 |
139 | self.FoVx = FoVx
140 | self.FoVy = FoVy
141 | self.R_quat_c2w = R_quat_c2w
142 | self.T_c2w = T_c2w
143 | projection_matrix = getProjectionMatrix(
144 | self.znear, self.zfar, self.FoVx, self.FoVy
145 | )
146 | self.projection_matrix = projection_matrix
147 | self.cam_idx = cam_idx
148 |
149 | def cuda(self):
150 | self.original_image = self.original_image.cuda()
151 | if self.depth is not None:
152 | self.depth = self.depth.cuda()
153 | if self.normal is not None:
154 | self.normal = self.normal.cuda()
155 | if self.motion_mask is not None:
156 | self.motion_mask = self.motion_mask.cuda()
157 | self.time = self.time.cuda()
158 | self.projection_matrix = self.projection_matrix.cuda()
159 | return self
160 |
161 | @property
162 | def world_view_transform(self):
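# invert the stored camera-to-world pose (unit quaternion + translation) into a 4x4 world-to-view matrix: R_w2c = R_c2w^T, T_w2c = -R_w2c @ T_c2w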
163 | R_c2w = quaternion_to_matrix(self.R_quat_c2w)
164 | T_c2w = (self.T_c2w + self.trans) * self.scale
165 | R_w2c = R_c2w.transpose(0, 1)
166 | T_w2c = -torch.einsum("ij, j -> i", R_w2c, T_c2w)
167 | ret = torch.eye(4).type_as(R_c2w)
168 | ret[:3, :3] = R_w2c
169 | ret[:3, 3] = T_w2c
170 | return ret
171 |
172 |
173 | class LearnableCamera(CameraInterface):
174 |
175 | def __init__(
176 | self,
177 | R,
178 | T,
179 | FoVx,
180 | FoVy,
181 | image,
182 | image_name,
183 | time,
184 | depth,
185 | normal,
186 | motion_mask,
187 | cam_idx,
188 | ):
189 | super(LearnableCamera, self).__init__(
190 | R=R,
191 | T=T,
192 | image=image,
193 | image_name=image_name,
194 | time=time,
195 | depth=depth,
196 | normal=normal,
197 | motion_mask=motion_mask,
198 | cam_idx=cam_idx,
199 | )
200 |
201 | self.image_name = image_name
202 | self.image_width = image.shape[2]
203 | self.image_height = image.shape[1]
204 |
205 | R_tr = np.transpose(R, (1, 0))
206 | R_c2w = torch.from_numpy(R_tr)
207 | T_c2w = torch.from_numpy(-np.einsum("ij, j", R_tr, T))
208 |
209 | R_c2w_quat = matrix_to_quaternion(R_c2w)
210 |
211 | self.register_parameter(
212 | "R_c2w_quat", nn.Parameter(R_c2w_quat, requires_grad=True)
213 | )
214 | self.register_parameter("T_c2w", nn.Parameter(T_c2w, requires_grad=True))
215 | self.register_parameter(
216 | "FoVx",
217 | nn.Parameter(torch.tensor(FoVx).clone().detach(), requires_grad=False),
218 | )
219 | self.register_parameter(
220 | "FoVy",
221 | nn.Parameter(torch.tensor(FoVy).clone().detach(), requires_grad=False),
222 | )
223 |
224 | @property
225 | def world_view_transform(self):
226 | R_c2w = quaternion_to_matrix(self.R_c2w_quat)
227 | T_c2w = (self.T_c2w + self.trans) * self.scale
228 | ret = torch.eye(4).type_as(R_c2w)
229 | rot_transpose = R_c2w.transpose(0, 1)
230 | ret[:3, :3] = rot_transpose
231 | ret[:3, 3] = -torch.einsum("ij, j -> i", rot_transpose, T_c2w)
232 | return ret
233 |
234 | @property
235 | def projection_matrix(self):
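# recompute the perspective projection from the stored FoV parameters on every access; view-space depth in [znear, zfar] is mapped to [0, 1]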
236 | tanHalfFovY = math.tan((self.FoVy / 2))
237 | tanHalfFovX = math.tan((self.FoVx / 2))
238 |
239 | top = tanHalfFovY * self.znear
240 | bottom = -top
241 | right = tanHalfFovX * self.znear
242 | left = -right
243 |
244 | P = torch.zeros(4, 4).type_as(self.R_c2w_quat)
245 |
246 | z_sign = 1.0
247 |
248 | P[0, 0] = 2.0 * self.znear / (right - left)
249 | P[1, 1] = 2.0 * self.znear / (top - bottom)
250 | P[0, 2] = (right + left) / (right - left)
251 | P[1, 2] = (top + bottom) / (top - bottom)
252 | P[3, 2] = z_sign
253 | P[2, 2] = z_sign * self.zfar / (self.zfar - self.znear)
254 | P[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear)
255 |
256 | return P
257 |
258 |
259 | def fetchPly(path):
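# read a PLY point cloud: positions, RGB colors rescaled to [0, 1], normals (zeros if absent), and an optional per-point "time" attribute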
260 | plydata = PlyData.read(path)
261 | vertices = plydata["vertex"]
262 | positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T
263 | colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0
264 | if "nx" in vertices:
265 | normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T
266 | else:
267 | normals = np.zeros_like(positions)
268 | if "time" in vertices:
269 | timestamp = vertices["time"][:, None]
270 | else:
271 | timestamp = None
272 | return BasicPointCloud(
273 | points=positions, colors=colors, normals=normals, time=timestamp
274 | )
275 |
276 |
277 | def PILtoTorch(pil_image):
278 | image_tensor = torch.from_numpy(np.array(pil_image)) / 255.0
279 | if len(image_tensor.shape) == 3:
280 | return image_tensor.permute(2, 0, 1)
281 | else:
282 | return image_tensor.unsqueeze(dim=-1).permute(2, 0, 1)
283 |
--------------------------------------------------------------------------------
/src/evaluator/utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import torch
12 | import torch.nn.functional as F
13 |
14 |
15 | def search_nearest_two(query_pose, db_poses):
16 | """
17 | query_pose: torch.Tensor (4 X 4)
18 | db_poses: torch.Tensor (N X 4 X 4)
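Returns: torch.Tensor (2,) - indices of the two poses in db_poses whose camera centers are closest to the query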
19 | """
20 | query_t = query_pose[None, :3, 3]
21 | db_t = db_poses[:, :3, 3]
22 |
23 | distances = torch.norm(query_t - db_t, dim=1)
24 | nearest_indices = torch.topk(distances, k=2, largest=False).indices
25 |
26 | return nearest_indices
27 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/src/pipelines/eval.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import argparse
12 |
13 | from pathlib import Path
14 | from omegaconf import OmegaConf
15 |
16 | from src.utils.configs import str2bool, instantiate_from_config
17 | from src.utils.general_utils import seed_all
18 |
19 |
20 | def parse_args():
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument(
24 | "-m", "--model_path", type=str, required=True, help="path to log"
25 | )
26 | parser.add_argument(
27 | "-c",
28 | "--eval_config",
29 | required=True,
30 | )
31 | parser.add_argument(
32 | "-d",
33 | "--dirpath",
34 | type=str,
35 | required=True,
36 | help="path to data directory",
37 | )
38 | parser.add_argument(
39 | "-t",
40 | "--task_name",
41 | type=str,
42 | default="eval",
43 | help="the name of evaluation task.",
44 | )
45 | parser.add_argument(
46 | "--verbose", type=str2bool, default=False, help="verbose printing"
47 | )
48 | parser.add_argument(
49 | "--debug", type=str2bool, default=False, help="debug mode running"
50 | )
51 |
52 | args, unknown = parser.parse_known_args()
53 |
54 | return args, unknown
55 |
56 |
57 | if __name__ == "__main__":
58 |
59 | args, unknown = parse_args()
60 |
61 | model_path = Path(args.model_path)
62 |
63 | train_config_path = model_path.joinpath("train/config.yaml")
64 | train_config = OmegaConf.load(train_config_path)
65 | eval_config = OmegaConf.load(args.eval_config)
66 | config = OmegaConf.merge(train_config, eval_config)
67 |
68 | out_path = model_path.joinpath(args.task_name)
69 | out_path.mkdir(exist_ok=args.debug)
70 |
71 | static_ckpt_path = model_path.joinpath("train", "static_last.ckpt")
72 | dynamic_ckpt_path = model_path.joinpath("train", "dynamic_last.ckpt")
73 |
74 | seed_all(config.metadata.seed)
75 |
76 | normalize_cams = eval_config.get("normalize_cams", False)
77 |
78 | static_datamodule = instantiate_from_config(
79 | config.static_data, normalize_cams=normalize_cams, ckpt_path=static_ckpt_path
80 | )
81 | dynamic_datamodule = instantiate_from_config(
82 | config.dynamic_data, normalize_cams=normalize_cams, ckpt_path=dynamic_ckpt_path
83 | )
84 | static_model = instantiate_from_config(config.static_model)
85 | dynamic_model = instantiate_from_config(config.dynamic_model)
86 | evaluator = instantiate_from_config(
87 | eval_config.evaluator,
88 | dirpath=args.dirpath,
89 | static_datamodule=static_datamodule,
90 | dynamic_datamodule=dynamic_datamodule,
91 | static_model=static_model,
92 | dynamic_model=dynamic_model,
93 | out_path=out_path,
94 | static_ckpt_path=static_ckpt_path,
95 | dynamic_ckpt_path=dynamic_ckpt_path,
96 | )
97 |
98 | evaluator.eval()
99 |
--------------------------------------------------------------------------------
/src/pipelines/train.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import argparse
12 | import os
13 | import logging
14 | import sys
15 | import shutil
16 |
17 | from pathlib import Path
18 | from omegaconf import OmegaConf, DictConfig
19 |
20 | from src.utils.configs import str2bool, is_instantiable, instantiate_from_config
21 | from src.utils.general_utils import seed_all
22 | from src.pipelines.utils import StreamToLogger
23 |
24 |
25 | def check_argument_sanity(args: argparse.Namespace) -> None:
26 | # check whether arguments are given properly.
27 | assert args.name != "", "provide an appropriate expname"
28 |
29 | config_paths = args.base
30 | assert len(config_paths) != 0, "no config given for training"
31 | for config_path in config_paths:
32 | assert os.path.exists(config_path), f"no config exists in path {config_path}"
33 |
34 | if args.verbose:
35 | # When running verbose mode, set the os.environ to globally share
36 | # the running mode in a convenient way
37 | os.environ["VERBOSE_RUN"] = "True"
38 |
39 | if args.debug:
40 | os.environ["DEBUG_RUN"] = "True"
41 |
42 | assert args.group != "", "specify the group name"
43 |
44 |
45 | def check_config(config: OmegaConf) -> None:
46 | # check whether config has a proper structure to run training
47 | assert is_instantiable(config, "static_data")
48 | assert is_instantiable(config, "static_model")
49 | assert is_instantiable(config, "dynamic_data")
50 | assert is_instantiable(config, "dynamic_model")
51 | assert is_instantiable(config, "trainer")
52 |
53 |
54 | def set_traindir(logdir: str, name: str, group: str, seed: int, debug: bool) -> Path:
55 | # explogdir (logdir/group/name)
56 | is_overridable = debug
57 |
58 | # in the case of group, previous experiments might already have made the directory
59 | path_logdir = Path(logdir, group)
60 | path_logdir_posix = path_logdir.as_posix()
61 | # group logdir is always overridable
62 | os.makedirs(path_logdir_posix, exist_ok=True)
63 |
64 | # but each experiment logdir should not exist
65 | expname = f"{name}_{str(seed).zfill(4)}"
66 | path_expdir = path_logdir.joinpath(expname)
67 | path_expdir_posix = path_expdir.as_posix()
68 | os.makedirs(path_expdir_posix, exist_ok=is_overridable)
69 |
70 | traindir = path_expdir.joinpath("train")
71 | traindir_posix = traindir.as_posix()
72 | os.makedirs(traindir_posix, exist_ok=is_overridable)
73 |
74 | return traindir
75 |
76 |
77 | def set_logger(logdir: Path) -> logging.Logger:
78 | # set logger in "logdir/train.log"
79 |
80 | logger = logging.getLogger("__main__")
81 | formatter = logging.Formatter(
82 | "%(asctime)s;[%(levelname)s];%(message)s", "%Y-%m-%d %H:%M:%S"
83 | )
84 | logger.setLevel(logging.INFO)
85 |
86 | stream_handler = logging.StreamHandler()
87 | stream_handler.setLevel(logging.INFO)
88 | stream_handler.setFormatter(formatter)
89 |
90 | file_handler = logging.FileHandler(logdir.joinpath("train.log").as_posix())
91 | file_handler.setLevel(logging.INFO)
92 | file_handler.setFormatter(formatter)
93 |
94 | logger.addHandler(stream_handler)
95 | logger.addHandler(file_handler)
96 |
97 | # Redirect stdout/stderr so that print output is also captured by the logger
98 | sys.stdout = StreamToLogger(logger, logging.INFO)
99 | sys.stderr = StreamToLogger(logger, logging.INFO)
100 |
101 | return logger
102 |
103 |
104 | def store_args_and_config(logdir: Path, args: argparse.Namespace, config: DictConfig):
105 | # store args as runnable metadata
106 | config["metadata"] = vars(args)
107 | store_path = logdir.joinpath("config.yaml")
108 |
109 | # yaml_format = OmegaConf.to_yaml(config)
110 | OmegaConf.save(config, store_path)
111 |
112 |
113 | def store_code(logdir: Path):
114 | code_path = "./src"
115 | dst_path = logdir.joinpath("code").as_posix()
116 | if not os.path.exists(dst_path):
117 | shutil.copytree(
118 | src=code_path,
119 | dst=dst_path,
120 | ignore=shutil.ignore_patterns("*__pycache__*"),
121 | )
122 |
123 |
124 | def parse_args():
125 |
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument(
128 | "-n",
129 | "--name",
130 | type=str,
131 | const=True,
132 | default="",
133 | nargs="?",
134 | help="postfix for logdir",
135 | )
136 | parser.add_argument(
137 | "-b",
138 | "--base",
139 | nargs="*",
140 | metavar="configs/default.yaml",
141 | default=list(),
142 | help="path to config files",
143 | )
144 | parser.add_argument(
145 | "-d",
146 | "--dirpath",
147 | type=str,
148 | required=True,
149 | help="path to data directory",
150 | )
151 | parser.add_argument(
152 | "-s",
153 | "--seed",
154 | type=int,
155 | default=777,
156 | help="seed for seed_everything",
157 | )
158 | parser.add_argument(
159 | "-l",
160 | "--logdir",
161 | type=str,
162 | default="./logs",
163 | help="path to store logs",
164 | )
165 | parser.add_argument(
166 | "-g",
167 | "--group",
168 | type=str,
169 | default="",
170 | help="group name",
171 | )
172 | parser.add_argument(
173 | "--verbose",
174 | type=str2bool,
175 | nargs="?",
176 | const=True,
177 | default=True,
178 | help="verbose printing",
179 | )
180 | parser.add_argument(
181 | "--debug",
182 | type=str2bool,
183 | nargs="?",
184 | const=True,
185 | default=False,
186 | help="run with debug mode",
187 | )
188 |
189 | args, unknown = parser.parse_known_args()
190 |
191 | return args, unknown
192 |
193 |
194 | def override_config(static_data_config, dynamic_data_config, trainer_config):
195 | # propagate num_iterations from static_data_config to the dynamic data and trainer configs
196 | iteration = static_data_config["params"]["train_dloader_config"]["params"][
197 | "num_iterations"
198 | ]
199 | if (
200 | dynamic_data_config["params"]["train_dloader_config"]["params"][
201 | "num_iterations"
202 | ]
203 | != iteration
204 | ):
205 | print("Override num_iterations...")
206 |
207 | dynamic_data_config["params"]["train_dloader_config"]["params"][
208 | "num_iterations"
209 | ] = iteration
210 |
211 | trainer_config["params"]["static"]["params"]["num_iterations"] = iteration
212 | trainer_config["params"]["static"]["params"]["camera_opt_config"]["params"][
213 | "total_steps"
214 | ] = iteration
215 | trainer_config["params"]["static"]["params"][
216 | "position_lr_max_steps"
217 | ] = iteration
218 |
219 | trainer_config["params"]["dynamic"]["params"]["num_iterations"] = iteration
220 | trainer_config["params"]["dynamic"]["params"]["camera_opt_config"]["params"][
221 | "total_steps"
222 | ] = iteration
223 | trainer_config["params"]["dynamic"]["params"][
224 | "position_lr_max_steps"
225 | ] = iteration
226 | trainer_config["params"]["dynamic"]["params"]["deform_lr_max_steps"] = iteration
227 |
228 | return dynamic_data_config, trainer_config
229 |
230 |
231 | if __name__ == "__main__":
232 |
233 | args, unknown = parse_args()
234 |
235 | # argument sanity check
236 | check_argument_sanity(args)
237 |
238 | configs = [OmegaConf.load(cfg) for cfg in args.base]
239 | cli = OmegaConf.from_dotlist(unknown)
240 | config = OmegaConf.merge(*configs, cli)
241 |
242 | # config sanity check
243 | check_config(config)
244 | config.static_data.params.dirpath = args.dirpath
245 | config.dynamic_data.params.dirpath = args.dirpath
246 |
247 | seed_all(args.seed)
248 |
249 | logdir: Path = set_traindir(
250 | logdir=args.logdir,
251 | name=args.name,
252 | group=args.group,
253 | seed=args.seed,
254 | debug=args.debug,
255 | )
256 | logger = set_logger(logdir=logdir)
257 | store_args_and_config(logdir, args, config)
258 | store_code(logdir)
259 |
260 | config.dynamic_data, config.trainer = override_config(
261 | config.static_data, config.dynamic_data, config.trainer
262 | )
263 |
264 | static_datamodule = instantiate_from_config(config.static_data)
265 | dynamic_datamodule = instantiate_from_config(config.dynamic_data)
266 |
267 | print("Init Static GS model")
268 | static_model = instantiate_from_config(config.static_model)
269 |
270 | print("Init Dynamic GS model")
271 | dynamic_model = instantiate_from_config(config.dynamic_model)
272 | trainer = instantiate_from_config(
273 | config.trainer,
274 | static_datamodule=static_datamodule,
275 | dynamic_datamodule=dynamic_datamodule,
276 | static_model=static_model,
277 | dynamic_model=dynamic_model,
278 | logdir=logdir,
279 | )
280 |
281 | trainer.train()
282 |
--------------------------------------------------------------------------------
/src/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import logging
12 |
13 |
14 | class StreamToLogger(object):
15 | # Reference: https://stackoverflow.com/questions/11124093/redirect-python-print-output-to-logger
16 | def __init__(self, logger, log_level=logging.INFO):
17 | self.logger = logger
18 | self.log_level = log_level
19 | self.linebuf = ""
20 |
21 | def write(self, buf):
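# emit only complete lines (ending in "\n") to the logger; keep any trailing partial line buffered until the next write() or flush()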
22 | temp_linebuf = self.linebuf + buf
23 | self.linebuf = ""
24 | for line in temp_linebuf.splitlines(True):
25 | if line[-1] == "\n":
26 | self.logger.log(self.log_level, line.rstrip())
27 | else:
28 | self.linebuf += line
29 |
30 | def flush(self):
31 | if self.linebuf != "":
32 | self.logger.log(self.log_level, self.linebuf.rstrip())
33 | self.linebuf = ""
34 |
--------------------------------------------------------------------------------
/src/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/src/trainer/optim.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import os
12 | import math
13 |
14 | import torch
15 | from functools import partial
16 |
17 | VERBOSE = os.environ.get("VERBOSE_RUN", False)
18 |
19 |
20 | def linear_warmup_cosine_annealing_func(step, max_lr, warmup_steps, total_steps):
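# ramp the learning rate linearly from 0 to max_lr over warmup_steps, then cosine-decay it to 0 by total_steps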
21 | if step < warmup_steps:
22 | # Linear warmup
23 | return max_lr * (step / warmup_steps)
24 | else:
25 | # Cosine annealing
26 | progress = (step - warmup_steps) / (total_steps - warmup_steps)
27 | cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
28 | return max_lr * cosine_decay
29 |
30 |
31 | class CameraQuatOptimizer:
32 |
33 | eps = 1e-15
34 |
35 | def __init__(
36 | self,
37 | dataset,
38 | camera_rotation_lr,
39 | camera_translation_lr,
40 | camera_lr_warmup,
41 | total_steps,
42 | spatial_lr_scale,
43 | ):
44 |
45 | self.spatial_lr_scale = spatial_lr_scale
46 |
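# split the dataset's learnable camera parameters into rotation (quaternion) and translation groups so each follows its own warmup + cosine-annealing schedule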
47 | cam_R_params, cam_T_params = [], []
48 | for name, param in dataset.named_parameters():
49 | if "R_" in name:
50 | cam_R_params.append(param)
51 | elif "T_" in name:
52 | cam_T_params.append(param)
53 | else:
54 | raise NameError(f"Unknown parameter {name}")
55 |
56 | l = (
57 | {"params": cam_R_params, "lr": 0.0, "name": "camera_R"},
58 | {"params": cam_T_params, "lr": 0.0, "name": "camera_T"},
59 | )
60 |
61 | self._optimizer = torch.optim.Adam(l, lr=0.0, eps=self.eps)
62 | self.R_lr_fn = partial(
63 | linear_warmup_cosine_annealing_func,
64 | max_lr=camera_rotation_lr,
65 | warmup_steps=camera_lr_warmup,
66 | total_steps=total_steps,
67 | )
68 | self.T_lr_fn = partial(
69 | linear_warmup_cosine_annealing_func,
70 | max_lr=camera_translation_lr,
71 | warmup_steps=camera_lr_warmup,
72 | total_steps=total_steps,
73 | )
74 |
75 | def update_learning_rate(self, iter):
76 | for param_group in self._optimizer.param_groups:
77 | if param_group["name"] == "camera_R":
78 | param_group["lr"] = self.R_lr_fn(iter)
79 | elif param_group["name"] == "camera_T":
80 | param_group["lr"] = self.T_lr_fn(iter)
81 | else:
82 | raise NameError(f"Unknown param_group {param_group['name']}")
83 |
84 | def state_dict(self):
85 | return self._optimizer.state_dict()
86 |
87 | def zero_grad(self, set_to_none):
88 | return self._optimizer.zero_grad(set_to_none=set_to_none)
89 |
90 | def step(self):
91 | self._optimizer.step()
92 |
--------------------------------------------------------------------------------
/src/trainer/renderer.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import torch
12 | import math
13 |
14 | from diff_gauss_pose import GaussianRasterizationSettings, GaussianRasterizer
15 |
16 |
17 | def render(
18 | xyz,
19 | active_sh_degree,
20 | opacity,
21 | scaling,
22 | rotation,
23 | features,
24 | viewpoint_camera,
25 | bg_color: torch.Tensor,
26 | scaling_modifier=1,
27 | override_color=None,
28 | enable_sh_grad=False,
29 | enable_cov_grad=False,
30 | ):
31 | """
32 | Render the scene.
33 |
34 | Background tensor (bg_color) must be on GPU!
35 | """
36 |
37 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
38 | screenspace_points = (
39 | torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device="cuda") + 0
40 | )
41 | try:
42 | screenspace_points.retain_grad()
43 | except:
44 | pass
45 |
46 | # Set up rasterization configuration
47 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
48 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
49 |
50 | raster_settings = GaussianRasterizationSettings(
51 | image_height=int(viewpoint_camera.image_height),
52 | image_width=int(viewpoint_camera.image_width),
53 | tanfovx=tanfovx,
54 | tanfovy=tanfovy,
55 | bg=bg_color,
56 | scale_modifier=scaling_modifier,
57 | projmatrix=viewpoint_camera.projection_matrix.transpose(0, 1), # glm storage
58 | sh_degree=active_sh_degree,
59 | prefiltered=False,
60 | debug=False,
61 | enable_cov_grad=enable_cov_grad,
62 | enable_sh_grad=enable_sh_grad,
63 | )
64 |
65 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
66 |
67 | means3D = xyz
68 | means2D = screenspace_points
69 | opacity = opacity
70 |
71 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
72 | # scaling / rotation by the rasterizer.
73 | cov3D_precomp = None
74 | scales = scaling
75 | rotations = rotation
76 |
77 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
78 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
79 | shs = None
80 | colors_precomp = None
81 | if override_color is None:
82 | shs = features
83 | else:
84 | colors_precomp = override_color
85 |
86 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
87 | rendered_image, rendered_depth, rendered_normal, rendered_alpha, radii, extra = (
88 | rasterizer(
89 | means3D=means3D,
90 | means2D=means2D,
91 | shs=shs,
92 | colors_precomp=colors_precomp,
93 | opacities=opacity,
94 | scales=scales,
95 | rotations=rotations,
96 | cov3Ds_precomp=cov3D_precomp,
97 | viewmatrix=viewpoint_camera.world_view_transform.transpose(
98 | 0, 1
99 | ), # glm storage
100 | )
101 | )
102 |
103 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
104 | # They will be excluded from value updates used in the splitting criteria.
105 | return {
106 | "rendered_image": rendered_image,
107 | "rendered_depth": rendered_depth,
108 | "rendered_normal": rendered_normal,
109 | "rendered_alpha": rendered_alpha,
110 | "viewspace_points": screenspace_points,
111 | "visibility_filter": radii > 0,
112 | "radii": radii,
113 | "extra": extra,
114 | }
115 |
--------------------------------------------------------------------------------
/src/trainer/rodygs_dynamic.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | from pathlib import Path
12 | from typing import Optional
13 |
14 | from omegaconf import DictConfig
15 | import torch
16 | from tqdm import tqdm
17 | from src.model.rodygs_dynamic import DynRoDyGS
18 | from src.trainer.rodygs_static import ThreeDGSTrainer
19 | from src.trainer.utils import prune_optimizer
20 | from src.utils.general_utils import get_expon_lr_func
21 |
22 |
23 | class DynTrainer(ThreeDGSTrainer):
24 |
25 | def __init__(
26 | self,
27 | datamodule,
28 | logdir: Path,
29 | model: DynRoDyGS,
30 | loss_config: DictConfig,
31 | # GS params
32 | num_iterations: int,
33 | position_lr_init: float,
34 | position_lr_final: float,
35 | position_lr_delay_mult: float,
36 | position_lr_max_steps: int,
37 | feature_lr: float,
38 | opacity_lr: float,
39 | scaling_lr: float,
40 | rotation_lr: float,
41 | percent_dense: float,
42 | opacity_reset_interval: int,
43 | densify_grad_threshold: float,
44 | densify_from_iter: int,
45 | densify_until_iter: int,
46 | densification_interval: int,
47 | # DyNMF params
48 | deform_lr_init: float,
49 | deform_lr_final: float,
50 | deform_lr_delay_mult: float,
51 | deform_lr_max_steps: int,
52 | motion_coeff_lr: float,
53 | deform_warmup_steps: int,
54 | # logging option
55 | log_freq: int = 50,
56 | # camera optim
57 | camera_opt_config: Optional[DictConfig] = None,
58 | ):
59 | super(DynTrainer, self).__init__(
60 | datamodule=datamodule,
61 | logdir=logdir,
62 | model=model,
63 | loss_config=loss_config,
64 | num_iterations=num_iterations,
65 | position_lr_init=position_lr_init,
66 | position_lr_final=position_lr_final,
67 | position_lr_delay_mult=position_lr_delay_mult,
68 | position_lr_max_steps=position_lr_max_steps,
69 | feature_lr=feature_lr,
70 | opacity_lr=opacity_lr,
71 | scaling_lr=scaling_lr,
72 | rotation_lr=rotation_lr,
73 | percent_dense=percent_dense,
74 | opacity_reset_interval=opacity_reset_interval,
75 | densify_grad_threshold=densify_grad_threshold,
76 | densify_from_iter=densify_from_iter,
77 | densify_until_iter=densify_until_iter,
78 | densification_interval=densification_interval,
79 | log_freq=log_freq,
80 | camera_opt_config=camera_opt_config,
81 | deform_warmup_steps=deform_warmup_steps,
82 | )
83 |
84 | self.append_motion_optim(
85 | deform_lr_init=deform_lr_init,
86 | deform_lr_final=deform_lr_final,
87 | deform_lr_delay_mult=deform_lr_delay_mult,
88 | deform_lr_max_steps=deform_lr_max_steps,
89 | motion_coeff_lr=motion_coeff_lr,
90 | )
91 |
92 | def append_motion_optim(
93 | self,
94 | deform_lr_init: float,
95 | deform_lr_final: float,
96 | deform_lr_delay_mult: float,
97 | deform_lr_max_steps: int,
98 | motion_coeff_lr: float,
99 | ):
100 | # Additional parameters and learning rates
101 | additional_params = [
102 | {
103 | "params": list(self.model._deform_network.parameters()),
104 | "lr": deform_lr_init,
105 | "name": "deform_network",
106 | },
107 | {
108 | "params": [self.model._motion_coeff],
109 | "lr": motion_coeff_lr,
110 | "name": "motion_coeff",
111 | },
112 | ]
113 |
114 | # Append additional parameters and learning rates to the optimizer
115 | for additional_param in additional_params:
116 | self.optimizer.add_param_group(additional_param)
117 |
118 | self.deform_scheduler_fn = get_expon_lr_func(
119 | lr_init=deform_lr_init,
120 | lr_final=deform_lr_final,
121 | lr_delay_mult=deform_lr_delay_mult,
122 | max_steps=deform_lr_max_steps,
123 | )
124 |
125 | def train(self):
126 | pbar = tqdm(total=self.num_iterations)
127 | dloader = self.datamodule.get_train_dloader()
128 |
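# pre-compute the deformation network's time embeddings for all training timestamps once, so per-iteration deformation queries can reuse them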
129 | times = dloader.dataset.get_times()
130 | times = torch.tensor(times).cuda().view(-1, 1)
131 | self.model._time_embeddings = self.model._deform_network.batch_embedding(times)
132 |
133 | print("pre-encoding time embeddings")
134 | print("shape : ", self.model._time_embeddings.shape)
135 |
136 | train_dloader_iter = iter(dloader)
137 |
138 | for iteration in range(1, self.num_iterations + 1):
139 | self.train_iteration(train_dloader_iter, iteration)
140 |
141 | if iteration % self.log_freq == 0:
142 | pbar.update(self.log_freq)
143 |
144 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
145 | torch.save(
146 | (self.state_dict(iteration), iteration),
147 | self.logdir.as_posix() + "/last.ckpt",
148 | )
149 |
150 | def prune_points(self, mask):
151 | # the same as the original 3DGS pruner, but also prunes the motion coefficients
152 | valid_points_mask = ~mask
153 | optimizable_tensors = prune_optimizer(self.optimizer, valid_points_mask)
154 |
155 | self.model._xyz = optimizable_tensors["xyz"]
156 | self.model._features_dc = optimizable_tensors["f_dc"]
157 | self.model._features_rest = optimizable_tensors["f_rest"]
158 | self.model._opacity = optimizable_tensors["opacity"]
159 | self.model._scaling = optimizable_tensors["scaling"]
160 | self.model._rotation = optimizable_tensors["rotation"]
161 | self.model._motion_coeff = optimizable_tensors["motion_coeff"]
162 |
163 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
164 | self.denom = self.denom[valid_points_mask]
165 | self.max_radii2D = self.max_radii2D[valid_points_mask]
166 | self.model.gaussian_to_time = self.model.gaussian_to_time[valid_points_mask]
167 | self.model.gaussian_to_time_ind = self.model.gaussian_to_time_ind[
168 | valid_points_mask
169 | ]
170 |
171 | def set_attributes_from_opt_tensors(self, optimizable_tensors):
172 | super(DynTrainer, self).set_attributes_from_opt_tensors(optimizable_tensors)
173 | self.model._motion_coeff = optimizable_tensors["motion_coeff"]
174 |
175 | def densify_and_clone_update_attributes(self, grads, grad_threshold, scene_extent):
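# extend the parent (ThreeDGSTrainer) clone step: cloned Gaussians also copy their per-Gaussian motion coefficients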
176 | (
177 | updated_attributes,
178 | selected_pts_mask,
179 | ) = super().densify_and_clone_update_attributes(
180 | grads, grad_threshold, scene_extent
181 | )
182 | updated_attributes["motion_coeff"] = self.model._motion_coeff[selected_pts_mask]
183 | return updated_attributes, selected_pts_mask
184 |
185 | def densify_and_split_update_attributes(
186 | self, grads, grad_threshold, scene_extent, N
187 | ):
188 | (
189 | updated_attributes,
190 | selected_pts_mask,
191 | ) = super().densify_and_split_update_attributes(
192 | grads, grad_threshold, scene_extent, N
193 | )
194 | updated_attributes["motion_coeff"] = self.model._motion_coeff[
195 | selected_pts_mask
196 | ].repeat(N, 1, 1)
197 | return updated_attributes, selected_pts_mask
198 |
199 | def update_learning_rate(self, iteration):
200 | """Learning rate scheduling per step"""
201 |
202 | xyz_found = False
203 | deform_found = False
204 |
205 | for param_group in self.optimizer.param_groups:
206 | if xyz_found and deform_found:
207 | break
208 | if param_group["name"] == "xyz":
209 | lr = self.xyz_scheduler_fn(iteration)
210 | param_group["lr"] = lr
211 | xyz_found = True
212 | elif param_group["name"] == "deform_network":
213 | lr = self.deform_scheduler_fn(iteration)
214 | param_group["lr"] = lr
215 | deform_found = True
216 |
217 | def state_dict(self, iteration):
218 | state_dict = super(DynTrainer, self).state_dict(iteration)
219 | state_dict["model"]["_motion_coeff"] = self.model._motion_coeff
220 | state_dict["model"]["_deform_network"] = self.model._deform_network.state_dict()
221 | state_dict["model"]["_timestep"] = self.model.gaussian_to_time
222 | return state_dict
223 |
--------------------------------------------------------------------------------
/src/trainer/utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | def replace_tensor_to_optimizer(optimizer, tensor, name):
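# swap the single tensor stored under the given name with the new tensor, resetting its Adam moments (exp_avg, exp_avg_sq) to zero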
16 | optimizable_tensors = {}
17 | for group in optimizer.param_groups:
18 | if len(group["params"]) > 1:
19 | # for these cases, we don't want to modify the optimizer
20 | continue
21 | if group["name"] == name:
22 | stored_state = optimizer.state.get(group["params"][0], None)
23 | stored_state["exp_avg"] = torch.zeros_like(tensor)
24 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
25 |
26 | del optimizer.state[group["params"][0]]
27 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
28 | optimizer.state[group["params"][0]] = stored_state
29 |
30 | optimizable_tensors[group["name"]] = group["params"][0]
31 | return optimizable_tensors
32 |
33 |
34 | def cat_tensors_to_optimizer(optimizer, tensors_dict):
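# concatenate new rows (e.g. densified Gaussians) onto each single-tensor param group, zero-padding the stored Adam moments for the new entries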
35 | optimizable_tensors = {}
36 | for group in optimizer.param_groups:
37 | if len(group["params"]) > 1:
38 | # for these cases, we don't want to modify the optimizer
39 | continue
40 | extension_tensor = tensors_dict[group["name"]]
41 | stored_state = optimizer.state.get(group["params"][0], None)
42 | if stored_state is not None:
43 |
44 | stored_state["exp_avg"] = torch.cat(
45 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0
46 | )
47 | stored_state["exp_avg_sq"] = torch.cat(
48 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0
49 | )
50 |
51 | del optimizer.state[group["params"][0]]
52 | group["params"][0] = nn.Parameter(
53 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(
54 | True
55 | )
56 | )
57 | optimizer.state[group["params"][0]] = stored_state
58 |
59 | optimizable_tensors[group["name"]] = group["params"][0]
60 | else:
61 | group["params"][0] = nn.Parameter(
62 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(
63 | True
64 | )
65 | )
66 | optimizable_tensors[group["name"]] = group["params"][0]
67 |
68 | return optimizable_tensors
69 |
70 |
71 | def prune_optimizer(optimizer, mask):
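# keep only the rows selected by the mask in each single-tensor param group, slicing the stored Adam moments to match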
72 | optimizable_tensors = {}
73 | for group in optimizer.param_groups:
74 | if len(group["params"]) > 1:
75 | # for these cases, we don't want to modify the optimizer
76 | continue
77 | stored_state = optimizer.state.get(group["params"][0], None)
78 | if stored_state is not None:
79 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
80 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
81 |
82 | del optimizer.state[group["params"][0]]
83 | group["params"][0] = nn.Parameter(
84 | (group["params"][0][mask].requires_grad_(True))
85 | )
86 | optimizer.state[group["params"][0]] = stored_state
87 |
88 | optimizable_tensors[group["name"]] = group["params"][0]
89 | else:
90 | group["params"][0] = nn.Parameter(
91 | group["params"][0][mask].requires_grad_(True)
92 | )
93 | optimizable_tensors[group["name"]] = group["params"][0]
94 |
95 | return optimizable_tensors
96 |
--------------------------------------------------------------------------------
/src/utils/configs.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import argparse
12 | import importlib
13 |
14 | from omegaconf import OmegaConf, DictConfig
15 |
16 |
17 | def str2bool(v):
18 | # a util function that accepts various kinds of flag options.
19 | if isinstance(v, bool):
20 | return v
21 | if v.lower() in ("yes", "true", "t", "y", "1"):
22 | return True
23 | elif v.lower() in ("no", "false", "f", "n", "0"):
24 | return False
25 | else:
26 | raise argparse.ArgumentTypeError("Boolean value expected.")
27 |
28 |
29 | def is_instantiable(config: DictConfig, key: str):
30 | # check whether the configuration is an instantiable format.
31 | subconfig = config.get(key)
32 | is_config = OmegaConf.is_config(subconfig)
33 | if not is_config:
34 | return False
35 |
36 | # target and params should be in keys.
37 | key_list = subconfig.keys()
38 | has_params = "params" in key_list
39 | has_target = "target" in key_list
40 |
41 | # no other keys are needed.
42 | no_more_keys = len(key_list) == 2
43 |
44 | return has_params and has_target and no_more_keys
45 |
46 |
47 | def instantiate_from_config(config: DictConfig, **kwargs):
48 | # instantiate an object from configuration.
49 | # if some arguments are passed in kwargs, they override the corresponding entries in the configuration.
50 | assert "target" in config.keys(), "target not exists"
51 |
52 | params = dict()
53 | params.update(config.get("params", dict()))
54 | params.update(kwargs)
55 | return get_obj_from_str(config["target"])(**params)
56 |
57 |
58 | def get_obj_from_str(string, reload=False, invalidate_cache=True):
59 | # from a str, return the corresponding object.
60 | module, obj_name = string.rsplit(".", 1)
61 | if invalidate_cache:
62 | importlib.invalidate_caches()
63 | if reload:
64 | module_imp = importlib.import_module(module)
65 | importlib.reload(module_imp)
66 | return getattr(importlib.import_module(module, package=None), obj_name)
67 |
--------------------------------------------------------------------------------
/src/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | from itertools import chain
12 | from typing import Sequence
13 | from collections import OrderedDict
14 |
15 | import numpy as np
16 | import scipy
17 | from piqa import PSNR, SSIM, LPIPS, MS_SSIM
18 | import torch
19 | import torch.nn as nn
20 | from torchvision import models
21 |
22 | from src.utils.general_utils import batchify, reduce
23 | from src.utils.pose_estim_utils import align_ate_c2b_use_a2b, compute_ATE, compute_rpe
24 |
25 |
26 | class VizScoreEvaluator:
27 |
28 | def __init__(self, device):
29 | self.psnr_module = PSNR().to(device)
30 | self.ssim_module = SSIM().to(device)
31 | self.msssim_module = MS_SSIM().to(device)
32 |
33 | @torch.inference_mode()
34 | def get_score(self, gt_image, pred_image):
35 |
36 | bf_gt_image = batchify(gt_image).clip(0, 1).contiguous()
37 | bf_pred_image = batchify(pred_image).clip(0, 1).contiguous()
38 |
39 | psnr_score = reduce(self.psnr_module(bf_gt_image, bf_pred_image))
40 | ssim_score = reduce(self.ssim_module(bf_gt_image, bf_pred_image))
41 | lpipsa_score = reduce(lpips(bf_gt_image, bf_pred_image, net_type="alex"))
42 | lpipsv_score = reduce(lpips(bf_gt_image, bf_pred_image, net_type="vgg"))
43 | msssim_score = reduce(self.msssim_module(bf_gt_image, bf_pred_image))
44 | dssim_score = (1 - msssim_score) / 2
45 |
46 | return {
47 | "psnr": psnr_score,
48 | "ssim": ssim_score,
49 | "lpipsa": lpipsa_score,
50 | "lpipsv": lpipsv_score,
51 | "msssim": msssim_score,
52 | "dssim": dssim_score,
53 | }
54 |
55 |
56 | class PoseEvaluator:
57 | def __init__(self):
58 | pass
59 |
60 | def normalize_pose(self, pose1, pose2):
61 | mtx1 = np.array(pose1.cpu().numpy(), dtype=np.double, copy=True)
62 | mtx2 = np.array(pose2.cpu().numpy(), dtype=np.double, copy=True)
63 |
64 | if mtx1.ndim != 2 or mtx2.ndim != 2:
65 | raise ValueError("Input matrices must be two-dimensional")
66 | if mtx1.shape != mtx2.shape:
67 | raise ValueError("Input matrices must be of same shape")
68 | if mtx1.size == 0:
69 | raise ValueError("Input matrices must be >0 rows and >0 cols")
70 |
71 | # translate all the data to the origin
72 | mtx1 -= np.mean(mtx1, 0)
73 | mtx2 -= np.mean(mtx2, 0)
74 |
75 | norm1 = np.linalg.norm(mtx1)
76 | norm2 = np.linalg.norm(mtx2)
77 |
78 | if norm1 == 0 or norm2 == 0:
79 | raise ValueError("Input matrices must contain >1 unique points")
80 |
81 | # change scaling of data (in rows) such that trace(mtx*mtx') = 1
82 | mtx1 /= norm1
83 | mtx2 /= norm2
84 |
85 | # transform mtx2 to minimize disparity
86 | R, s = scipy.linalg.orthogonal_procrustes(mtx1, mtx2)
87 | mtx2 = mtx2 * s
88 |
89 | return mtx1, mtx2, R
90 |
91 | def align_pose(self, gt, estim):
92 | gt_ret, estim_ret = gt.clone(), estim.clone()
93 | gt_transl, estim_transl, _ = self.normalize_pose(
94 | gt[:, :3, -1], estim_ret[:, :3, -1]
95 | )
96 | gt_ret[:, :3, -1], estim_ret[:, :3, -1] = torch.from_numpy(gt_transl).type_as(
97 | gt_ret
98 | ), torch.from_numpy(estim_transl).type_as(estim_ret)
99 | c2ws_est_aligned = align_ate_c2b_use_a2b(estim_ret, gt_ret)
100 | return gt_ret, c2ws_est_aligned
101 |
102 | @torch.inference_mode()
103 | def get_score(self, gt, estim):
104 | gt, c2ws_est_aligned = self.align_pose(gt, estim)
105 | ate = compute_ATE(gt.cpu().numpy(), c2ws_est_aligned.cpu().numpy())
106 | rpe_trans, rpe_rot = compute_rpe(
107 | gt.cpu().numpy(), c2ws_est_aligned.cpu().numpy()
108 | )
109 | rpe_trans *= 100
110 | rpe_rot *= 180 / np.pi
111 |
112 | return {
113 | "ATE": ate,
114 | "RPE_trans": rpe_trans,
115 | "RPE_rot": rpe_rot,
116 | "aligned": c2ws_est_aligned,
117 | }
118 |
119 |
120 | class LPIPS(nn.Module):
121 | r"""Creates a criterion that measures
122 | Learned Perceptual Image Patch Similarity (LPIPS).
123 |
124 | Arguments:
125 | net_type (str): the network type to compare the features:
126 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
127 | version (str): the version of LPIPS. Default: 0.1.
128 | """
129 |
130 | def __init__(self, net_type: str = "alex", version: str = "0.1"):
131 |
132 | assert version in ["0.1"], "only v0.1 is currently supported"
133 |
134 | super(LPIPS, self).__init__()
135 |
136 | # pretrained network
137 | self.net = get_network(net_type)
138 |
139 | # linear layers
140 | self.lin = LinLayers(self.net.n_channels_list)
141 | self.lin.load_state_dict(get_state_dict(net_type, version))
142 |
143 | def forward(self, x: torch.Tensor, y: torch.Tensor):
144 | feat_x, feat_y = self.net(x), self.net(y)
145 |
146 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
147 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
148 |
149 | return torch.sum(torch.cat(res, 0), 0, True)
150 |
151 |
152 | def get_network(net_type: str):
153 | if net_type == "alex":
154 | return AlexNet()
155 | elif net_type == "squeeze":
156 | return SqueezeNet()
157 | elif net_type == "vgg":
158 | return VGG16()
159 | else:
160 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")
161 |
162 |
163 | class LinLayers(nn.ModuleList):
164 | def __init__(self, n_channels_list: Sequence[int]):
165 | super(LinLayers, self).__init__(
166 | [
167 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))
168 | for nc in n_channels_list
169 | ]
170 | )
171 |
172 | for param in self.parameters():
173 | param.requires_grad = False
174 |
175 |
176 | class BaseNet(nn.Module):
177 | def __init__(self):
178 | super(BaseNet, self).__init__()
179 |
180 | # register buffer
181 | self.register_buffer(
182 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
183 | )
184 | self.register_buffer(
185 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
186 | )
187 |
188 | def set_requires_grad(self, state: bool):
189 | for param in chain(self.parameters(), self.buffers()):
190 | param.requires_grad = state
191 |
192 | def z_score(self, x: torch.Tensor):
193 | return (x - self.mean) / self.std
194 |
195 | def forward(self, x: torch.Tensor):
196 | x = self.z_score(x)
197 |
198 | output = []
199 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
200 | x = layer(x)
201 | if i in self.target_layers:
202 | output.append(normalize_activation(x))
203 | if len(output) == len(self.target_layers):
204 | break
205 | return output
206 |
207 |
208 | class SqueezeNet(BaseNet):
209 | def __init__(self):
210 | super(SqueezeNet, self).__init__()
211 |
212 | self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
213 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
214 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
215 |
216 | self.set_requires_grad(False)
217 |
218 |
219 | class AlexNet(BaseNet):
220 | def __init__(self):
221 | super(AlexNet, self).__init__()
222 |
223 | self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
224 | self.target_layers = [2, 5, 8, 10, 12]
225 | self.n_channels_list = [64, 192, 384, 256, 256]
226 |
227 | self.set_requires_grad(False)
228 |
229 |
230 | class VGG16(BaseNet):
231 | def __init__(self):
232 | super(VGG16, self).__init__()
233 |
234 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
235 | self.target_layers = [4, 9, 16, 23, 30]
236 | self.n_channels_list = [64, 128, 256, 512, 512]
237 |
238 | self.set_requires_grad(False)
239 |
240 |
241 | def normalize_activation(x, eps=1e-10):
242 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
243 | return x / (norm_factor + eps)
244 |
245 |
246 | def get_state_dict(net_type: str = "alex", version: str = "0.1"):
247 | # build url
248 | url = (
249 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/"
250 | + f"master/lpips/weights/v{version}/{net_type}.pth"
251 | )
252 |
253 | # download
254 | old_state_dict = torch.hub.load_state_dict_from_url(
255 | url,
256 | progress=True,
257 | map_location=None if torch.cuda.is_available() else torch.device("cpu"),
258 | )
259 |
260 | # rename keys
261 | new_state_dict = OrderedDict()
262 | for key, val in old_state_dict.items():
263 | new_key = key
264 | new_key = new_key.replace("lin", "")
265 | new_key = new_key.replace("model.", "")
266 | new_state_dict[new_key] = val
267 |
268 | return new_state_dict
269 |
270 |
271 | def lpips(
272 | x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1"
273 | ):
274 | r"""Function that measures
275 | Learned Perceptual Image Patch Similarity (LPIPS).
276 |
277 | Arguments:
278 | x, y (torch.Tensor): the input tensors to compare.
279 | net_type (str): the network type to compare the features:
280 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
281 | version (str): the version of LPIPS. Default: 0.1.
282 | """
283 | device = x.device
284 | criterion = LPIPS(net_type, version).to(device)
285 | return criterion(x, y)
286 |
--------------------------------------------------------------------------------
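A hedged sketch of how the two evaluators above could be driven; the tensors are random placeholders (3 x H x W images in [0, 1], N x 4 x 4 camera-to-world poses), and the LPIPS backbones and weights are downloaded on first use.

import torch

from src.utils.eval_utils import VizScoreEvaluator, PoseEvaluator

device = "cuda" if torch.cuda.is_available() else "cpu"

# Image metrics on a dummy GT / prediction pair (MS-SSIM needs images >= ~160 px).
viz_eval = VizScoreEvaluator(device)
gt = torch.rand(3, 256, 256, device=device)
pred = (gt + 0.05 * torch.randn_like(gt)).clamp(0, 1)
print({k: float(v) for k, v in viz_eval.get_score(gt, pred).items()})

# Pose metrics on dummy camera-to-world matrices with small translation noise.
poses_gt = torch.eye(4).repeat(8, 1, 1)
poses_gt[:, :3, 3] = torch.rand(8, 3)
poses_est = poses_gt.clone()
poses_est[:, :3, 3] += 0.01 * torch.randn(8, 3)
scores = PoseEvaluator().get_score(poses_gt, poses_est)
print(scores["ATE"], scores["RPE_trans"], scores["RPE_rot"])
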
/src/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 |
11 | import os
12 | import random
13 |
14 | import torch
15 | import numpy as np
16 |
17 |
18 | def seed_all(seed: int):
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 | torch.backends.cudnn.deterministic = True
23 | torch.backends.cudnn.benchmark = False
24 | random.seed(seed)
25 | os.environ["PYTHONHASHSEED"] = str(seed)
26 |
27 |
28 | def batchify(tensor):
29 | return tensor.unsqueeze(0)
30 |
31 |
32 | def reduce(tensor):
33 | return tensor.squeeze(0)
34 |
35 |
36 | def inverse_sigmoid(x):
37 | return torch.log(x / (1 - x))
38 |
39 |
40 | def get_expon_lr_func(
41 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
42 | ):
43 | """
44 | Copied from Plenoxels
45 |
46 | Continuous learning rate decay function. Adapted from JaxNeRF
47 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
48 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
49 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
50 | function of lr_delay_mult, such that the initial learning rate is
51 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
52 | to the normal learning rate when steps>lr_delay_steps.
53 | :param lr_init, lr_final: initial and final learning rates.
54 | :param max_steps: int, the number of steps during optimization.
55 | :return: helper function that takes the current step and returns the learning rate.
56 | """
57 |
58 | def helper(step):
59 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
60 | # Disable this parameter
61 | return 0.0
62 | if lr_delay_steps > 0:
63 | # A kind of reverse cosine decay.
64 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
65 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
66 | )
67 | else:
68 | delay_rate = 1.0
69 | t = np.clip(step / max_steps, 0, 1)
70 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
71 | return delay_rate * log_lerp
72 |
73 | return helper
74 |
75 |
76 | def strip_lowerdiag(L):
77 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
78 |
79 | uncertainty[:, 0] = L[:, 0, 0]
80 | uncertainty[:, 1] = L[:, 0, 1]
81 | uncertainty[:, 2] = L[:, 0, 2]
82 | uncertainty[:, 3] = L[:, 1, 1]
83 | uncertainty[:, 4] = L[:, 1, 2]
84 | uncertainty[:, 5] = L[:, 2, 2]
85 | return uncertainty
86 |
87 |
88 | def strip_symmetric(sym):
89 | return strip_lowerdiag(sym)
90 |
91 |
92 | def build_rotation(r):
93 | norm = torch.sqrt(
94 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
95 | )
96 |
97 | q = r / norm[:, None]
98 |
99 | R = torch.zeros((q.size(0), 3, 3), device="cuda")
100 |
101 | r = q[:, 0]
102 | x = q[:, 1]
103 | y = q[:, 2]
104 | z = q[:, 3]
105 |
106 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
107 | R[:, 0, 1] = 2 * (x * y - r * z)
108 | R[:, 0, 2] = 2 * (x * z + r * y)
109 | R[:, 1, 0] = 2 * (x * y + r * z)
110 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
111 | R[:, 1, 2] = 2 * (y * z - r * x)
112 | R[:, 2, 0] = 2 * (x * z - r * y)
113 | R[:, 2, 1] = 2 * (y * z + r * x)
114 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
115 | return R
116 |
117 |
118 | def build_scaling_rotation(s, r):
119 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
120 | R = build_rotation(r)
121 |
122 | L[:, 0, 0] = s[:, 0]
123 | L[:, 1, 1] = s[:, 1]
124 | L[:, 2, 2] = s[:, 2]
125 |
126 | L = R @ L
127 | return L
128 |
--------------------------------------------------------------------------------
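The learning-rate helper above returns a closure; a small sketch of its log-linear behaviour (the numbers below are illustrative, not the values used by the training configs).

from src.utils.general_utils import get_expon_lr_func, seed_all

seed_all(42)  # deterministic runs: seeds numpy, torch (CPU/CUDA) and random

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
for step in (0, 15_000, 30_000):
    print(step, lr_fn(step))
# 0 -> 1.6e-4, 15000 -> ~1.6e-5 (geometric midpoint), 30000 -> 1.6e-6
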
/src/utils/graphic_utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | import math
12 | import warnings
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn.functional as F
17 |
18 |
19 | def geom_transform_points(points, transf_matrix):
20 | P, _ = points.shape
21 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
22 | points_hom = torch.cat([points, ones], dim=1)
23 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
24 |
25 | denom = points_out[..., 3:] + 0.0000001
26 | return (points_out[..., :3] / denom).squeeze(dim=0)
27 |
28 |
29 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
30 | Rt = np.zeros((4, 4))
31 | Rt[:3, :3] = R
32 | Rt[:3, 3] = t
33 | Rt[3, 3] = 1.0
34 |
35 | C2W = np.linalg.inv(Rt)
36 | cam_center = C2W[:3, 3]
37 | cam_center = (cam_center + translate) * scale
38 | C2W[:3, 3] = cam_center
39 | Rt = np.linalg.inv(C2W)
40 | return np.float32(Rt)
41 |
42 |
43 | def getProjectionMatrix(znear, zfar, fovX, fovY):
44 | tanHalfFovY = math.tan((fovY / 2))
45 | tanHalfFovX = math.tan((fovX / 2))
46 |
47 | top = tanHalfFovY * znear
48 | bottom = -top
49 | right = tanHalfFovX * znear
50 | left = -right
51 |
52 | P = torch.zeros(4, 4)
53 |
54 | z_sign = 1.0
55 |
56 | P[0, 0] = 2.0 * znear / (right - left)
57 | P[1, 1] = 2.0 * znear / (top - bottom)
58 | P[0, 2] = (right + left) / (right - left)
59 | P[1, 2] = (top + bottom) / (top - bottom)
60 | P[3, 2] = z_sign
61 | P[2, 2] = z_sign * zfar / (zfar - znear)
62 | P[2, 3] = -(zfar * znear) / (zfar - znear)
63 | return P
64 |
65 |
66 | def fov2focal(fov, pixels):
67 | if fov > math.pi * 2 or fov < -math.pi * 2:
68 | warnings.warn("fov appears to be in degrees; please double-check that it is in radians!")
69 | return pixels / (2 * math.tan(fov / 2))
70 |
71 |
72 | def focal2fov(focal, pixels):
73 | return 2 * math.atan(pixels / (2 * focal))
74 |
75 |
76 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
77 | """
78 | Convert rotations given as quaternions to rotation matrices.
79 | Args:
80 | quaternions: quaternions with real part first,
81 | as tensor of shape (..., 4).
82 | Returns:
83 | Rotation matrices as tensor of shape (..., 3, 3).
84 | """
85 | r, i, j, k = torch.unbind(quaternions, -1)
86 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
87 |
88 | o = torch.stack(
89 | (
90 | 1 - two_s * (j * j + k * k),
91 | two_s * (i * j - k * r),
92 | two_s * (i * k + j * r),
93 | two_s * (i * j + k * r),
94 | 1 - two_s * (i * i + k * k),
95 | two_s * (j * k - i * r),
96 | two_s * (i * k - j * r),
97 | two_s * (j * k + i * r),
98 | 1 - two_s * (i * i + j * j),
99 | ),
100 | -1,
101 | )
102 | return o.reshape(quaternions.shape[:-1] + (3, 3))
103 |
104 |
105 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
106 | """
107 | Returns torch.sqrt(torch.max(0, x))
108 | but with a zero subgradient where x is 0.
109 | """
110 | ret = torch.zeros_like(x)
111 | positive_mask = x > 0
112 | ret[positive_mask] = torch.sqrt(x[positive_mask])
113 | return ret
114 |
115 |
116 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
117 | """
118 | Convert rotations given as rotation matrices to quaternions.
119 | Args:
120 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
121 | Returns:
122 | quaternions with real part first, as tensor of shape (..., 4).
123 | """
124 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
125 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
126 |
127 | batch_dim = matrix.shape[:-2]
128 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
129 | matrix.reshape(batch_dim + (9,)), dim=-1
130 | )
131 |
132 | q_abs = _sqrt_positive_part(
133 | torch.stack(
134 | [
135 | 1.0 + m00 + m11 + m22,
136 | 1.0 + m00 - m11 - m22,
137 | 1.0 - m00 + m11 - m22,
138 | 1.0 - m00 - m11 + m22,
139 | ],
140 | dim=-1,
141 | )
142 | )
143 |
144 | # we produce the desired quaternion multiplied by each of r, i, j, k
145 | quat_by_rijk = torch.stack(
146 | [
147 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
148 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
149 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
150 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
151 | ],
152 | dim=-2,
153 | )
154 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
155 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
156 |
157 | return quat_candidates[
158 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
159 | ].reshape(batch_dim + (4,))
160 |
--------------------------------------------------------------------------------
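A hedged sketch exercising the camera and rotation helpers above; the image width and field of view are arbitrary, and quaternions are real-part-first as the docstrings state.

import math
import torch

from src.utils.graphic_utils import (
    fov2focal, focal2fov, quaternion_to_matrix, matrix_to_quaternion,
)

# fov <-> focal round trip for an 800-pixel-wide image.
fov_x = math.radians(60.0)
focal = fov2focal(fov_x, 800)
assert abs(focal2fov(focal, 800) - fov_x) < 1e-9

# quaternion <-> matrix round trip; q and -q encode the same rotation,
# so the matrices (not the quaternions) are compared.
q = torch.nn.functional.normalize(torch.randn(5, 4), dim=-1)
R = quaternion_to_matrix(q)
assert torch.allclose(quaternion_to_matrix(matrix_to_quaternion(R)), R, atol=1e-5)
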
/src/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | from math import exp
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | import torch.nn as nn
16 | from torch.autograd import Variable
17 |
18 |
19 | def l1_loss(network_output, gt):
20 | return torch.abs((network_output - gt)).mean()
21 |
22 |
23 | def l2_loss(network_output, gt):
24 | return ((network_output - gt) ** 2).mean()
25 |
26 |
27 | def smooth_loss(output):
28 | grad_output_x = output[:, :, :-1] - output[:, :, 1:]
29 | grad_output_y = output[:, :-1, :] - output[:, 1:, :]
30 |
31 | return torch.abs(grad_output_x).mean() + torch.abs(grad_output_y).mean()
32 |
33 |
34 | def gaussian(window_size, sigma):
35 | gauss = torch.Tensor(
36 | [
37 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
38 | for x in range(window_size)
39 | ]
40 | )
41 | return gauss / gauss.sum()
42 |
43 |
44 | def logl1(pred, gt):
45 | return torch.log(1 + torch.abs(pred - gt))
46 |
47 |
48 | def create_window(window_size, channel):
49 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
50 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
51 | window = Variable(
52 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()
53 | )
54 | return window
55 |
56 |
57 | def ssim(img1, img2, window_size=11, size_average=True):
58 | channel = img1.size(-3)
59 | window = create_window(window_size, channel)
60 |
61 | if img1.is_cuda:
62 | window = window.cuda(img1.get_device())
63 | window = window.type_as(img1)
64 |
65 | return _ssim(img1, img2, window, window_size, channel, size_average)
66 |
67 |
68 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
69 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
70 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
71 |
72 | mu1_sq = mu1.pow(2)
73 | mu2_sq = mu2.pow(2)
74 | mu1_mu2 = mu1 * mu2
75 |
76 | sigma1_sq = (
77 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
78 | )
79 | sigma2_sq = (
80 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
81 | )
82 | sigma12 = (
83 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
84 | - mu1_mu2
85 | )
86 |
87 | C1 = 0.01**2
88 | C2 = 0.03**2
89 |
90 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
91 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
92 | )
93 |
94 | if size_average:
95 | return ssim_map.mean()
96 | else:
97 | return ssim_map.mean(1).mean(1).mean(1)
98 |
99 |
100 | def pearson_depth_loss(input_depth, target_depth, eps, mask=None):
101 |
102 | if mask is not None:
103 | pred_depth = input_depth * mask
104 | gt_depth = target_depth * mask
105 | else:
106 | pred_depth = input_depth
107 | gt_depth = target_depth
108 |
109 | centered_pred_depth = pred_depth - pred_depth.mean()
110 | centered_gt_depth = gt_depth - gt_depth.mean()
111 |
112 | normalized_pred_depth = centered_pred_depth / (centered_pred_depth.std() + eps)
113 | normalized_gt_depth = centered_gt_depth / (centered_gt_depth.std() + eps)
114 |
115 | covariance = (normalized_pred_depth * normalized_gt_depth).mean()
116 |
117 | return 1 - covariance
118 |
119 |
120 | def compute_fundamental_matrix(K, w2c1, w2c2):
121 |
122 | # Compute relative rotation and translation
123 | R_rel = w2c2[:, :3, :3] @ w2c1[:, :3, :3].transpose(1, 2)
124 | t_rel = w2c2[:, :3, 3] - torch.einsum("nij, nj -> ni", R_rel, w2c1[:, :3, 3])
125 |
126 | # Skew-symmetric matrix [t]_x for cross product
127 | def skew_symmetric(t):
128 | B = t.shape[0]
129 | zero = torch.zeros(B, device=t.device)
130 | tx = torch.stack(
131 | [zero, -t[:, 2], t[:, 1], t[:, 2], zero, -t[:, 0], -t[:, 1], t[:, 0], zero],
132 | dim=1,
133 | )
134 | return tx.view(B, 3, 3)
135 |
136 | t_rel_skew = skew_symmetric(t_rel) # (B, 3, 3)
137 |
138 | # Compute Essential matrix E = [t]_x * R_rel
139 | E = t_rel_skew @ R_rel
140 |
141 | # Compute Fundamental matrix: F = K2^-T * E * K1^-1
142 | inverse_K = invert_intrinsics(K)
143 | K1_inv = inverse_K
144 | K2_inv_T = inverse_K.transpose(1, 2)
145 | F = K2_inv_T.bmm(E).bmm(K1_inv)
146 |
147 | return F
148 |
149 |
150 | def invert_intrinsics(K):
151 | """
152 | Efficiently compute the inverse of a batch of intrinsic matrices.
153 |
154 | Args:
155 | K: (B, 3, 3) tensor representing a batch of intrinsic matrices.
156 |
157 | Returns:
158 | K_inv: (B, 3, 3) tensor representing the batch of inverse intrinsic matrices.
159 | """
160 | B = K.shape[0]
161 |
162 | # Extract intrinsic parameters from the matrix
163 | fx = K[:, 0, 0] # focal length in x
164 | fy = K[:, 1, 1] # focal length in y
165 | cx = K[:, 0, 2] # principal point x
166 | cy = K[:, 1, 2] # principal point y
167 |
168 | # Construct the inverse intrinsic matrices manually
169 | K_inv = torch.zeros_like(K)
170 | K_inv[:, 0, 0] = 1.0 / fx
171 | K_inv[:, 1, 1] = 1.0 / fy
172 | K_inv[:, 0, 2] = -cx / fx
173 | K_inv[:, 1, 2] = -cy / fy
174 | K_inv[:, 2, 2] = 1.0
175 |
176 | return K_inv
177 |
178 |
179 | def construct_intrinsics(focal, image_width, image_height, batch_size):
180 | K = torch.zeros(batch_size, 3, 3)
181 | K[:, 0, 0] = focal
182 | K[:, 1, 1] = focal
183 | K[:, 2, 2] = 1.0
184 | K[:, 0, 2] = image_width / 2
185 | K[:, 1, 2] = image_height / 2
186 | return K
187 |
188 |
189 | def compute_sampson_error(x1, x2, F):
190 | """
191 | :param x1 (*, N, 2)
192 | :param x2 (*, N, 2)
193 | :param F (*, 3, 3)
194 | """
195 | h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1)
196 | h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1)
197 | d1 = torch.matmul(h1, F.transpose(-1, -2)) # (B, N, 3)
198 | d2 = torch.matmul(h2, F) # (B, N, 3)
199 | z = (h2 * d1).sum(dim=-1) # (B, N)
200 | err = (z**2) / (
201 | d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2
202 | )
203 | return err
204 |
205 |
206 | def get_outnorm(x: torch.Tensor, out_norm: str = "") -> torch.Tensor:
207 | """Common function to get a loss normalization value. Can
208 | normalize by either the batch size ('b'), the number of
209 | channels ('c'), the image size ('i') or combinations
210 | ('bi', 'bci', etc)
211 | """
212 | # b, c, h, w = x.size()
213 | img_shape = x.shape
214 |
215 | if not out_norm:
216 | return 1
217 |
218 | norm = 1
219 | if "b" in out_norm:
220 | # normalize by batch size
221 | # norm /= b
222 | norm /= img_shape[0]
223 | if "c" in out_norm:
224 | # normalize by the number of channels
225 | # norm /= c
226 | norm /= img_shape[1]
227 | if "i" in out_norm:
228 | # normalize by image/map size
229 | # norm /= h*w
230 | norm /= img_shape[-1] * img_shape[-2]
231 |
232 | return norm
233 |
234 |
235 | class CharbonnierLoss(nn.Module):
236 | """Charbonnier Loss (L1)"""
237 |
238 | def __init__(self, eps=1e-6, out_norm: str = "bci"):
239 | super(CharbonnierLoss, self).__init__()
240 | self.eps = eps
241 | self.out_norm = out_norm
242 |
243 | def forward(self, x, y, weight=None):
244 | norm = get_outnorm(x, self.out_norm)
245 | if weight is None:
246 | loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
247 | else:
248 | loss = torch.sum(weight * torch.sqrt((x - y).pow(2) + self.eps**2))
249 | return loss * norm
250 |
--------------------------------------------------------------------------------
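A hedged sketch of the photometric and monocular-depth terms above on dummy tensors; the batched shapes and loss weights are assumptions for illustration, not the repo's training settings.

import torch

from src.utils.loss_utils import l1_loss, ssim, pearson_depth_loss

render = torch.rand(1, 3, 64, 64)
gt = torch.rand(1, 3, 64, 64)
photometric = 0.8 * l1_loss(render, gt) + 0.2 * (1.0 - ssim(render, gt))

# Pearson correlation is invariant to an affine rescale of the depth,
# so a scaled/shifted copy of the monocular depth gives a loss near zero.
mono_depth = torch.rand(1, 64, 64)
rendered_depth = 2.0 * mono_depth + 0.1
depth_term = pearson_depth_loss(rendered_depth, mono_depth, eps=1e-8)
print(float(photometric), float(depth_term))
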
/src/utils/point_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from typing import Optional
13 | from dataclasses import dataclass
14 | import numpy as np
15 |
16 |
17 | @dataclass
18 | class BasicPointCloud:
19 | points: np.ndarray
20 | colors: np.ndarray
21 | normals: np.ndarray
22 | time: Optional[np.ndarray]
23 |
24 |
25 | def uniform_sample(pcd: BasicPointCloud, ratio: float):
26 | assert ratio <= 1.0
27 | num_points = len(pcd.points)
28 | num_sample = int(num_points * ratio)
29 | point_idx = np.random.choice(num_points, num_sample, replace=False)
30 |
31 | return BasicPointCloud(
32 | points=pcd.points[point_idx],
33 | colors=pcd.colors[point_idx],
34 | normals=pcd.normals[point_idx],
35 | time=pcd.time[point_idx] if pcd.time is not None else None,
36 | )
37 |
38 |
39 | def merge_pcds(pcds):
40 |
41 | points, colors, normals, time = [], [], [], []
42 |
43 | for pcd in pcds:
44 | points.append(pcd.points)
45 | colors.append(pcd.colors)
46 | normals.append(pcd.normals)
47 | time.append(pcd.time)
48 |
49 | return BasicPointCloud(
50 | points=np.concatenate(points),
51 | colors=np.concatenate(colors),
52 | normals=np.concatenate(normals),
53 | time=np.concatenate(time) if time[0] is not None else None,
54 | )
55 |
--------------------------------------------------------------------------------
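A hedged sketch of the point-cloud helpers above; the random arrays merely stand in for the MASt3R-initialised point clouds used elsewhere in the pipeline.

import numpy as np

from src.utils.point_utils import BasicPointCloud, uniform_sample, merge_pcds

def random_pcd(n):
    return BasicPointCloud(
        points=np.random.rand(n, 3),
        colors=np.random.rand(n, 3),
        normals=np.zeros((n, 3)),
        time=None,
    )

pcd_a, pcd_b = random_pcd(1000), random_pcd(500)
pcd_a_small = uniform_sample(pcd_a, ratio=0.1)   # keeps 100 random points
merged = merge_pcds([pcd_a_small, pcd_b])
print(len(merged.points))  # 600
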
/src/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | C0 = 0.28209479177387814
25 | C1 = 0.4886025119029199
26 | C2 = [
27 | 1.0925484305920792,
28 | -1.0925484305920792,
29 | 0.31539156525252005,
30 | -1.0925484305920792,
31 | 0.5462742152960396,
32 | ]
33 | C3 = [
34 | -0.5900435899266435,
35 | 2.890611442640554,
36 | -0.4570457994644658,
37 | 0.3731763325901154,
38 | -0.4570457994644658,
39 | 1.445305721320277,
40 | -0.5900435899266435,
41 | ]
42 | C4 = [
43 | 2.5033429417967046,
44 | -1.7701307697799304,
45 | 0.9461746957575601,
46 | -0.6690465435572892,
47 | 0.10578554691520431,
48 | -0.6690465435572892,
49 | 0.47308734787878004,
50 | -1.7701307697799304,
51 | 0.6258357354491761,
52 | ]
53 |
54 |
55 | def eval_sh(deg, sh, dirs):
56 | """
57 | Evaluate spherical harmonics at unit directions
58 | using hardcoded SH polynomials.
59 | Works with torch/np/jnp.
60 | ... Can be 0 or more batch dimensions.
61 | Args:
62 | deg: int SH deg. Currently, 0-4 supported
63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
64 | dirs: jnp.ndarray unit directions [..., 3]
65 | Returns:
66 | [..., C]
67 | """
68 | assert deg <= 4 and deg >= 0
69 | coeff = (deg + 1) ** 2
70 | assert sh.shape[-1] >= coeff
71 |
72 | result = C0 * sh[..., 0]
73 | if deg > 0:
74 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
75 | result = (
76 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
77 | )
78 |
79 | if deg > 1:
80 | xx, yy, zz = x * x, y * y, z * z
81 | xy, yz, xz = x * y, y * z, x * z
82 | result = (
83 | result
84 | + C2[0] * xy * sh[..., 4]
85 | + C2[1] * yz * sh[..., 5]
86 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
87 | + C2[3] * xz * sh[..., 7]
88 | + C2[4] * (xx - yy) * sh[..., 8]
89 | )
90 |
91 | if deg > 2:
92 | result = (
93 | result
94 | + C3[0] * y * (3 * xx - yy) * sh[..., 9]
95 | + C3[1] * xy * z * sh[..., 10]
96 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
97 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
98 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
99 | + C3[5] * z * (xx - yy) * sh[..., 14]
100 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
101 | )
102 |
103 | if deg > 3:
104 | result = (
105 | result
106 | + C4[0] * xy * (xx - yy) * sh[..., 16]
107 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
108 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
109 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
110 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
111 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
112 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
113 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
114 | + C4[8]
115 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
116 | * sh[..., 24]
117 | )
118 | return result
119 |
120 |
121 | def RGB2SH(rgb):
122 | return (rgb - 0.5) / C0
123 |
124 |
125 | def SH2RGB(sh):
126 | return sh * C0 + 0.5
127 |
--------------------------------------------------------------------------------
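A hedged sketch of the SH helpers above using dummy colours; at degree 0 a spherical-harmonic colour is just a scaled constant, so the conversions round-trip exactly.

import numpy as np

from src.utils.sh_utils import C0, RGB2SH, SH2RGB, eval_sh

rgb = np.random.rand(10, 3)
sh_dc = RGB2SH(rgb)                 # DC coefficient per channel
assert np.allclose(SH2RGB(sh_dc), rgb)

# eval_sh expects coefficients shaped [..., C, (deg + 1) ** 2];
# with deg = 0 the view directions do not influence the result.
dirs = np.random.randn(10, 3)
vals = eval_sh(0, sh_dc[..., None], dirs)
assert np.allclose(vals, C0 * sh_dc)
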
/src/utils/store_utils.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 |
3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH)
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | from pathlib import Path
12 | from typing import Union
13 | from abc import ABC, abstractmethod
14 |
15 | import numpy as np
16 | import torch
17 | import cv2
18 |
19 |
20 | def load_depth_from_pgm(pgm_file_path, min_depth_threshold=None):
21 | """
22 | Load the depth map from PGM file
23 | :param pgm_file_path: pgm file path
24 | :return: depth map with 3D ND-array
25 | """
26 | raw_img = None
27 | with open(pgm_file_path, "rb") as f:
28 | line = str(f.readline(), encoding="ascii")
29 | if line != "P5\n":
30 | print("Error loading pgm, format error\n")
31 |
32 | line = str(f.readline(), encoding="ascii")
33 | max_depth = float(line.split(" ")[-1].strip())
34 |
35 | line = str(f.readline(), encoding="ascii")
36 | dims = line.split(" ")
37 | cols = int(dims[0].strip())
38 | rows = int(dims[1].strip())
39 |
40 | line = str(f.readline(), encoding="ascii")
41 | max_factor = float(line.strip())
42 |
43 | raw_img = (
44 | np.frombuffer(
45 | f.read(cols * rows * np.dtype(np.uint16).itemsize), dtype=np.uint16
46 | )
47 | .reshape((rows, cols))
48 | .astype(np.float32)
49 | )
50 | raw_img *= max_depth / max_factor
51 |
52 | if min_depth_threshold is not None:
53 | raw_img[raw_img < min_depth_threshold] = min_depth_threshold
54 |
55 | return np.expand_dims(raw_img, axis=0)
56 |
57 |
58 | def save_depth_to_pgm(depth, pgm_file_path):
59 | """
60 | Save the depth map to PGM file
61 | :param depth: depth map with 3D ND-array, 1XHXW
62 | :param pgm_file_path: output file path
63 | """
64 | depth_flatten = depth[0]
65 | max_depth = np.max(depth_flatten)
66 | depth_copy = np.copy(depth_flatten)
67 | depth_copy = 65535.0 * (depth_copy / max_depth)
68 | depth_copy = depth_copy.astype(np.uint16)
69 |
70 | with open(pgm_file_path, "wb") as f:
71 | f.write(bytes("P5\n", encoding="ascii"))
72 | f.write(bytes("# %f\n" % max_depth, encoding="ascii"))
73 | f.write(
74 | bytes(
75 | "%d %d\n" % (depth_flatten.shape[1], depth_flatten.shape[0]),
76 | encoding="ascii",
77 | )
78 | )
79 | f.write(bytes("65535\n", encoding="ascii"))
80 | f.write(depth_copy.tobytes())
81 |
82 |
83 | class Storer(ABC):
84 | """
85 | Abstract base class for storing tensor data with validation. Subclasses
86 | must implement the `sanity_check` method to validate tensors and the `store`
87 | method to define how tensors are stored.
88 | """
89 |
90 | sanity_check = None
91 |
92 | def __init__(self, path: Path):
93 | self.path = path
94 | self.path.mkdir(exist_ok=True)
95 |
96 | def to_cv2(self, tensor: torch.Tensor):
97 | np_arr = np.transpose(tensor.clamp(0, 1).cpu().detach().numpy(), (1, 2, 0))
98 | np_arr = np_arr[..., ::-1]
99 | # cv2_format = (np_arr * 65536).astype(np.uint16)
100 | cv2_format = (np_arr * 65535).astype(np.uint16)
101 | return cv2_format
102 |
103 | @abstractmethod
104 | def sanity_check(self, tensor: torch.Tensor) -> bool:
105 | """Abstract method for sanity checking the tensor.
106 | Subclasses should implement this method to define their specific checks.
107 | """
108 | raise NotImplementedError
109 |
110 | @abstractmethod
111 | def store(self, image_name, tensor: torch.Tensor) -> None:
112 | """Abstract method for storing the tensor.
113 | Subclasses should implement this method to define how the tensor is stored.
114 | """
115 | raise NotImplementedError
116 |
117 | def __call__(self, image_name, tensor: torch.Tensor):
118 | if not self.sanity_check(tensor):
119 | raise ValueError(
120 | f"Sanity check failed for tensor with shape {tensor.shape}"
121 | )
122 | self.store(image_name, tensor)
123 |
124 |
125 | class RGBStorer(Storer):
126 |
127 | def sanity_check(self, tensor: torch.Tensor) -> bool:
128 | return tensor.ndim == 3 and tensor.shape[0] == 3
129 |
130 | def store(self, image_name: str, tensor: torch.Tensor) -> None:
131 | cv2_image = self.to_cv2(tensor.clamp(0, 1))
132 | cv2.imwrite(self.path.joinpath(image_name).as_posix(), cv2_image)
133 |
134 |
135 | class AssetStorer:
136 |
137 | def __init__(
138 | self,
139 | out_path: Path,
140 | ):
141 | self.out_path = out_path
142 | out_path.mkdir(exist_ok=True)
143 |
144 | self.viz_storer = RGBStorer(out_path.joinpath("viz"))
145 |
146 | def __call__(
147 | self,
148 | image_name: str,
149 | viz_tensor: torch.Tensor,  # 3 X H X W
150 | ):
151 | self.viz_storer(image_name, viz_tensor)
152 |
--------------------------------------------------------------------------------
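A hedged sketch of the storage helpers above; paths and tensors are placeholders, and the depth round-trip error is only the 16-bit PGM quantisation.

from pathlib import Path

import numpy as np
import torch

from src.utils.store_utils import RGBStorer, load_depth_from_pgm, save_depth_to_pgm

# Depth round trip through the PGM helpers (1 x H x W float32 map).
depth = (np.random.rand(1, 120, 160) * 10.0).astype(np.float32)
save_depth_to_pgm(depth, "depth.pgm")
depth_back = load_depth_from_pgm("depth.pgm")
print(np.abs(depth - depth_back).max())  # quantisation error only

# Store an RGB render as a 16-bit PNG; RGBStorer creates the target directory.
storer = RGBStorer(Path("renders"))
storer("frame_0000.png", torch.rand(3, 120, 160))  # 3 x H x W in [0, 1]
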