├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── main_figure.png ├── configs ├── eval │ ├── eval_w_align.yaml │ └── eval_wo_align.yaml └── train │ ├── train_iphone.yaml │ ├── train_kubric_mrig.yaml │ ├── train_nvidia.yaml │ └── train_tnt.yaml ├── licences └── LICENSE_gs.md ├── scripts ├── img2video.py ├── iphone2format.py ├── kubricmrig2format.py ├── nvidia2format.py ├── run_depthanything.py ├── run_mast3r │ ├── depth_preprocessor │ │ ├── get_pcd.py │ │ ├── pcd_utils.py │ │ └── utils.py │ └── run.py ├── tam_npy2png.py └── tnt2format.py └── src ├── data ├── __init__.py ├── asset_readers.py ├── dataloader.py ├── datamodule.py └── utils.py ├── evaluator ├── eval.py └── utils.py ├── model ├── __init__.py ├── rodygs_dynamic.py └── rodygs_static.py ├── pipelines ├── eval.py ├── train.py └── utils.py ├── trainer ├── __init__.py ├── losses.py ├── optim.py ├── renderer.py ├── rodygs.py ├── rodygs_dynamic.py ├── rodygs_static.py └── utils.py └── utils ├── configs.py ├── eval_utils.py ├── general_utils.py ├── graphic_utils.py ├── loss_utils.py ├── point_utils.py ├── pose_estim_utils.py ├── sh_utils.py └── store_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | *_temporal* 3 | __pycache__ 4 | cache/* 5 | logs/* 6 | .vscode 7 | 8 | data/* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/diff-gaussian-rasterization"] 2 | path = thirdparty/diff-gaussian-rasterization 3 | url = https://github.com/slothfulxtx/diff-gaussian-rasterization.git 4 | branch = pose 5 | [submodule "thirdparty/simple-knn"] 6 | path = thirdparty/simple-knn 7 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 8 | [submodule "thirdparty/mast3r"] 9 | path = thirdparty/mast3r 10 | url = https://github.com/naver/mast3r.git 11 | [submodule "thirdparty/pytorch3d"] 12 | path = thirdparty/pytorch3d 13 | url = https://github.com/facebookresearch/pytorch3d.git 14 | [submodule "thirdparty/depth_anything_v2"] 15 | path = thirdparty/depth_anything_v2 16 | url = https://github.com/DepthAnything/Depth-Anything-V2.git 17 | [submodule "thirdparty/Track-Anything"] 18 | path = thirdparty/Track-Anything 19 | url = https://github.com/gaomingqi/Track-Anything.git -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Dynamic Gaussian Splatting for Casual Videos 2 | ### [Project Page](https://rodygs.github.io/) | [Paper](https://www.arxiv.org/abs/2412.03077) | [Video](https://www.youtube.com/watch?v=pwv4Gwl07Tw) 3 | #### [Yoonwoo Jeong](https://jeongyw12382.github.io/), [Junmyeong Lee](https://www.linkedin.com/in/junmyeong-lee/), [Hoseung Choi](https://www.linkedin.com/in/hoseung-choi/), [Minsu Cho](https://cvlab.postech.ac.kr/~mcho/) 4 | 5 | ![Main Figure](assets/main_figure.png) 6 | 7 | 8 | ## News 9 | - [2024-12-11] We have released the code. (dataset coming soon) 10 | 11 | ## Abstract 12 | Dynamic view synthesis (DVS) has advanced remarkably in recent years, achieving high-fidelity rendering while reducing computational costs. Despite the progress, optimizing dynamic neural fields from casual videos remains challenging, as these videos do not provide direct 3D information, such as camera trajectories or the underlying scene geometry. In this work, we present RoDyGS, an optimization pipeline for dynamic Gaussian Splatting from casual videos. It effectively learns motion and underlying geometry of scenes by separating dynamic and static primitives, and ensures that the learned motion and geometry are physically plausible by incorporating motion and geometric regularization terms. We also introduce a comprehensive benchmark, Kubric-MRig, that provides extensive camera and object motion along with simultaneous multi-view captures, features that are absent in previous benchmarks. Experimental results demonstrate that the proposed method significantly outperforms previous pose-free dynamic neural fields and achieves competitive rendering quality compared to existing pose-free static neural fields. 13 | 14 | ## Short Description Video 15 | 16 | [![Video Title](https://img.youtube.com/vi/pwv4Gwl07Tw/maxresdefault.jpg)](https://www.youtube.com/watch?v=pwv4Gwl07Tw) 17 | 18 | 19 | ## Installation 20 | The code has been tested on Python 3.11 with CUDA version 12.2. 21 | ```bash 22 | git clone https://github.com/POSTECH-CVLab/RoDyGS.git --recursive 23 | cd RoDyGS 24 | 25 | conda create -n rodygs python=3.11 -c anaconda -y 26 | conda activate rodygs 27 | 28 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia -y 29 | pip3 install plyfile==1.0.3 omegaconf==2.3.0 tqdm==4.66.4 piqa==1.3.2 scipy==1.13.1 tifffile==2024.7.24 pandas==2.2.2 numpy==1.24.1 gradio==3.39.0 av==10.0.0 imageio==2.19.5 imageio-ffmpeg 30 | pip3 install opencv-python matplotlib trimesh roma einops huggingface_hub ninja gdown progressbar mmcv 31 | pip3 install git+https://github.com/facebookresearch/segment-anything.git 32 | 33 | pip3 install thirdparty/simple-knn 34 | pip3 install thirdparty/diff-gaussian-rasterization 35 | pip3 install thirdparty/pytorch3d 36 | ``` 37 | 38 | ## Dataset Format 39 | - First, you need to set up the dataset in the following format to run RoDyGS. 40 | ``` 41 | [scene_name] 42 | |- train # train images 43 | |- rgba_00000.png 44 | |- rgba_00001.png 45 | |- ...
46 | |- test # test images 47 | |- rgba_00000.png 48 | |- rgba_00001.png 49 | |- ... 50 | |- train_transforms.json # train camera information 51 | |- test_transforms.json # test camera information 52 | ``` 53 | 54 |
55 | The format of train_transforms.json 56 | 57 | - The units for "camera_angle_x" and "camera_angle_y" are in degrees. 58 | - The range of "time" is from 0.0 to 1.0. 59 | - The format of "transform_matrix" follows the OpenCV camera-to-world matrix. 60 | - "train/rgba_\*.png" in "file path" should be changed to "test/rgba_\*.png" in test_transforms.json. 61 | ```json 62 | { 63 | "camera_angle_x": 50, 64 | "camera_angle_y": 50, 65 | "frames": [ 66 | { 67 | "time": 0.0, 68 | "file_path": "train/rgba_00000.png", 69 | "width": 1920, 70 | "height": 1080, 71 | "transform_matrix": [ 72 | [ 73 | 1.0, 74 | 0.0, 75 | 0.0, 76 | 0.0 77 | ], 78 | [ 79 | 0.0, 80 | 1.0, 81 | 0.0, 82 | 0.0 83 | ], 84 | [ 85 | 0.0, 86 | 0.0, 87 | 1.0, 88 | 0.0 89 | ], 90 | [ 91 | 0.0, 92 | 0.0, 93 | 0.0, 94 | 1.0 95 | ] 96 | ] 97 | }, 98 | { 99 | "time": 0.01, 100 | "file_path": "train/rgba_00001.png", 101 | "width": 1920, 102 | "height": 1080, 103 | "transform_matrix": ~ 104 | }, 105 | ... 106 | { 107 | "time": 0.99, 108 | "file_path": "train/rgba_00099.png", 109 | "width": 1920, 110 | "height": 1080, 111 | "transform_matrix": ~ 112 | } 113 | ] 114 | } 115 | ``` 116 |
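If you are preparing a custom scene, the sketch below shows how the conventions above fit together (field of view in degrees, "time" normalized to [0, 1), and an OpenCV camera-to-world "transform_matrix"). It is a minimal illustration, not part of the repository: the helper name `write_transforms` and the source of the per-frame poses are assumptions. You would call it once for the train split and once for the test split (with "test/rgba_*.png" file paths).

```python
# Minimal sketch (not part of the repository) for writing train_transforms.json
# in the format described above. The per-frame OpenCV camera-to-world matrices
# and the FoV values are assumed to come from your own capture pipeline.
import json
from pathlib import Path

import numpy as np


def write_transforms(out_path, file_paths, c2w_matrices, fov_x_deg, fov_y_deg,
                     width=1920, height=1080):
    num_frames = len(file_paths)
    frames = []
    for idx, (fpath, c2w) in enumerate(zip(file_paths, c2w_matrices)):
        frames.append({
            "time": idx / num_frames,                      # normalized to [0, 1)
            "file_path": fpath,                            # e.g. "train/rgba_00000.png"
            "width": width,
            "height": height,
            "transform_matrix": np.asarray(c2w).tolist(),  # OpenCV camera-to-world
        })
    payload = {
        "camera_angle_x": fov_x_deg,  # degrees, not radians
        "camera_angle_y": fov_y_deg,
        "frames": frames,
    }
    Path(out_path).write_text(json.dumps(payload, indent=4))


if __name__ == "__main__":
    # Hypothetical example: two frames with identity poses.
    imgs = [f"train/rgba_{i:05d}.png" for i in range(2)]
    poses = [np.eye(4) for _ in imgs]
    write_transforms("data/my_scene/train_transforms.json", imgs, poses, 50.0, 50.0)
```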
117 | 118 | ## Running 119 | ### Pre-processing 120 | 121 | - Please download the checkpoints from https://github.com/naver/mast3r ("MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric") and https://github.com/DepthAnything/Depth-Anything-V2 ("Depth-Anything-V2-Large"). 122 | 123 | - You need to run Track-Anything and obtain motion masks from the demo UI. For more information, visit https://github.com/gaomingqi/Track-Anything. 124 | ```bash 125 | cd RoDyGS/thirdparty/Track-Anything 126 | python3 app.py --mask_save True 127 | ``` 128 | 129 | - Then, follow these steps: 130 | ```bash 131 | cd RoDyGS 132 | 133 | # Get TAM mask 134 | python3 -m scripts.img2video --data_dir [path_to_dataset_dir] 135 | python3 -m scripts.tam_npy2png --npy_dir [path_to_tam_npy_mask] --output_dir [path_to_dataset_dir] 136 | 137 | # Run Depth Anything 138 | python3 -m scripts.run_depthanything --encoder [encoder_type] --encoder-path [path_to_encoder_ckpt] --img-path [path_to_image_dir] --outdir [path_to_output] --raw-depth 139 | 140 | # Run MASt3R 141 | python3 -m scripts.run_mast3r.run --input_dir [path_to_images] --exp_name [mast3r_expname] --output_dir [path_to_mast3r_output] --ckpt [path_to_model_ckpt] --cache_dir [path_to_cache_output] 142 | python3 -m scripts.run_mast3r.depth_preprocessor.get_pcd --datadir [path_to_dataset_dir] --mask_name [mask_dir_name] --mast3r_expname [mast3r_expname] 143 | ``` 144 |
145 | example commands 146 | 147 | ```bash 148 | python3 -m scripts.img2video --data_dir data/kubric_mrig/scene0 149 | python3 -m scripts.tam_npy2png --npy_dir thirdparty/Track-Anything/result/mask/train --output_dir data/kubric_mrig/scene0 150 | 151 | python3 -m scripts.run_depthanything --encoder vitl --encoder-path checkpoints/depth_anything_v2_vitl.pth --img-path data/kubric_mrig/scene0/train --outdir data/kubric_mrig/scene0/depth_anything --raw-depth 152 | 153 | python3 -m scripts.run_mast3r.run --input_dir data/kubric_mrig/scene0/train --exp_name mast3r --output_dir data/kubric_mrig/scene0/mast3r_opt --ckpt MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth --cache_dir mast3r_cache/ 154 | 155 | python3 -m scripts.run_mast3r.depth_preprocessor.get_pcd --datadir data/kubric_mrig/scene0/ --mask_name tam_mask --mast3r_expname mast3r_000 156 | ``` 157 |
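If you process several scenes, the example commands above can be replayed from a small wrapper. The sketch below is an optional convenience script, not part of the repository: it assumes `scripts.img2video` and the interactive Track-Anything step have already produced the `.npy` masks, the checkpoint paths mirror the example commands, and the `_000` suffix on the MASt3R experiment name follows the example naming and may differ in your setup.

```python
# Optional convenience sketch (not part of the repository) that replays the
# example pre-processing commands above for one scene. It assumes the
# Track-Anything masks were already exported from the demo UI, and that the
# checkpoint paths below exist on your machine.
import subprocess
import sys


def run(cmd):
    print("+", " ".join(cmd))
    subprocess.run(cmd, check=True)


def preprocess(scene_dir, tam_npy_dir, expname="mast3r"):
    py = sys.executable
    # 1. Convert the TAM .npy masks into per-frame .png motion masks.
    run([py, "-m", "scripts.tam_npy2png",
         "--npy_dir", tam_npy_dir, "--output_dir", scene_dir])
    # 2. Monocular depth with Depth Anything V2.
    run([py, "-m", "scripts.run_depthanything",
         "--encoder", "vitl",
         "--encoder-path", "checkpoints/depth_anything_v2_vitl.pth",  # assumed location
         "--img-path", f"{scene_dir}/train",
         "--outdir", f"{scene_dir}/depth_anything",
         "--raw-depth"])
    # 3. MASt3R pose / point-cloud optimization.
    run([py, "-m", "scripts.run_mast3r.run",
         "--input_dir", f"{scene_dir}/train",
         "--exp_name", expname,
         "--output_dir", f"{scene_dir}/mast3r_opt",
         "--ckpt", "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
         "--cache_dir", "mast3r_cache/"])
    # 4. Split the MASt3R point clouds into static/dynamic sets using the masks.
    #    The "_000" suffix mirrors the example command above.
    run([py, "-m", "scripts.run_mast3r.depth_preprocessor.get_pcd",
         "--datadir", scene_dir,
         "--mask_name", "tam_mask",
         "--mast3r_expname", f"{expname}_000"])


if __name__ == "__main__":
    preprocess("data/kubric_mrig/scene0",
               "thirdparty/Track-Anything/result/mask/train")
```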
158 | 159 | 160 | - After pre-processing, your dataset will have the following format: 161 | ``` 162 | [scene_name] 163 | |- train # train images 164 | |- test # test images 165 | |- train_transforms.json # train camera information 166 | |- test_transforms.json # test camera information 167 | |- tam_mask 168 | |- depth_anything 169 | |- mast3r_opt 170 | |- [expname] 171 | |- dynamic # dynamic point clouds (per frame) 172 | |- static # static point clouds (per frame) 173 | |- op_results # merged point clouds (unseparated) 174 | |- global_params.pkl # initial camera poses 175 | ``` 176 | 177 | ### Training 178 | ```bash 179 | python3 -m src.pipelines.train -d [path_to_dataset_dir] -b [path_to_training_config] -g [logging_group_name] -n [training_log_name] 180 | ``` 181 |
182 | example commands 183 | 184 | ```bash 185 | python3 -m src.pipelines.train -d data/kubric_mrig/scene0/ -b configs/train/train_kubric_mrig.yaml -g kubric_mrig -n scene0 186 | ``` 187 |
188 | 189 | ### Rendering & Evaluation 190 | - We used "configs/eval/eval_w_align.yaml" for iPhone dataset evaluation. 191 | - We used "configs/eval/eval_wo_align.yaml" for Kubric-MRig & NVIDIA Dynamic dataset evaluation. 192 | ```bash 193 | python3 -m src.pipelines.eval -c [path_to_config] -t [test_log_name] -d [path_to_dataset] -m [path_to_training_log] 194 | ``` 195 |
196 | example commands 197 | 198 | ```bash 199 | python3 -m src.pipelines.eval -c configs/eval/eval_wo_align.yaml -t eval -d ../data/kubric_mrig/scene0/ -m logs/kubric_0_0777/ 200 | ``` 201 |
202 | 203 | -------------------------------------------------------------------------------- /assets/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/RoDyGS/7057b57afcb3c0297bab500388edf25b0d8d9754/assets/main_figure.png -------------------------------------------------------------------------------- /configs/eval/eval_w_align.yaml: -------------------------------------------------------------------------------- 1 | evaluator: 2 | target: src.evaluator.eval.RoDyGSEvaluator 3 | params: 4 | camera_lr: 0.00005 5 | num_opts: 1000 6 | 7 | static_data: 8 | target: src.data.datamodule.GSDataModule 9 | params: 10 | train_dset_config: 11 | target: src.data.datamodule.LazyDataReader 12 | params: 13 | pose_reader: 14 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 15 | params: 16 | mast3r_expname: swin_noloop_000 17 | mast3r_img_res: 512 18 | train_dloader_config: 19 | target: src.data.dataloader.SequentialSingleDataLoader 20 | params: {} 21 | test_dset_config: 22 | target: src.data.datamodule.LazyDataReader 23 | params: 24 | pose_reader: 25 | target: src.data.asset_readers.Test_MASt3RFovCameraReader 26 | params: 27 | mast3r_expname: swin_noloop_000 28 | mast3r_img_res: 512 29 | test_dloader_config: 30 | target: src.data.dataloader.SequentialSingleDataLoader 31 | params: {} 32 | test_transform_fname: test_transforms.json 33 | 34 | normalize_cams: false -------------------------------------------------------------------------------- /configs/eval/eval_wo_align.yaml: -------------------------------------------------------------------------------- 1 | evaluator: 2 | target: src.evaluator.eval.RoDyGSEvaluator 3 | params: 4 | camera_lr: 0 5 | num_opts: 0 6 | 7 | static_data: 8 | target: src.data.datamodule.GSDataModule 9 | params: 10 | train_dset_config: 11 | target: src.data.datamodule.LazyDataReader 12 | params: 13 | pose_reader: 14 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 15 | params: 16 | mast3r_expname: swin_noloop_000 17 | mast3r_img_res: 512 18 | train_dloader_config: 19 | target: src.data.dataloader.SequentialSingleDataLoader 20 | params: {} 21 | test_dset_config: 22 | target: src.data.datamodule.LazyDataReader 23 | params: 24 | pose_reader: 25 | target: src.data.asset_readers.Test_MASt3RFovCameraReader 26 | params: 27 | mast3r_expname: swin_noloop_000 28 | mast3r_img_res: 512 29 | test_dloader_config: 30 | target: src.data.dataloader.SequentialSingleDataLoader 31 | params: {} 32 | test_transform_fname: test_transforms.json 33 | 34 | normalize_cams: false -------------------------------------------------------------------------------- /configs/train/train_iphone.yaml: -------------------------------------------------------------------------------- 1 | static_data: 2 | target: src.data.datamodule.GSDataModule 3 | params: 4 | dirpath: ./ 5 | train_dset_config: 6 | target: src.data.datamodule.LazyDataReader 7 | params: 8 | camera_config: 9 | target: src.data.utils.FixedCamera 10 | pose_reader: 11 | target: src.data.asset_readers.MASt3RCameraReader 12 | params: 13 | mast3r_expname: swin_noloop_000 14 | mast3r_img_res: 512 15 | depth_reader: 16 | target: src.data.asset_readers.DepthAnythingReader 17 | params: 18 | split: train 19 | motion_mask_reader: 20 | target: src.data.asset_readers.TAMMaskReader 21 | params: 22 | split: train 23 | train_dloader_config: 24 | target: src.data.dataloader.PermutationSingleDataLoader 25 | params: 26 | num_iterations: 20000 27 | test_dset_config: 28 
| target: src.data.datamodule.DataReader 29 | params: 30 | camera_config: 31 | target: src.data.utils.FixedCamera 32 | pose_reader: 33 | target: src.data.asset_readers.GTCameraReader 34 | test_dloader_config: 35 | target: src.data.dataloader.SequentialSingleDataLoader 36 | params: {} 37 | train_pcd_reader_config: 38 | target: src.data.asset_readers.MASt3RPCDReader 39 | params: 40 | mast3r_expname: swin_noloop_000 41 | mode: static 42 | num_limit_points: 120000 43 | train_pose_reader_config: 44 | target: src.data.datamodule.DataReader 45 | params: 46 | camera_config: 47 | target: src.data.utils.FixedCamera 48 | pose_reader: 49 | target: src.data.asset_readers.GTCameraReader 50 | normalize_cams: false 51 | static_model: 52 | target: src.model.rodygs_static.StaticRoDyGS 53 | params: 54 | sh_degree: 3 55 | isotropic: false 56 | static_calibrated_pose_reader: 57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 58 | params: 59 | mast3r_expname: swin_noloop_000 60 | mast3r_img_res: 512 61 | dynamic_data: 62 | target: src.data.datamodule.GSDataModule 63 | params: 64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/ 65 | train_dset_config: 66 | target: src.data.datamodule.LazyDataReader 67 | params: 68 | camera_config: 69 | target: src.data.utils.FixedCamera 70 | pose_reader: 71 | target: src.data.asset_readers.MASt3RCameraReader 72 | params: 73 | mast3r_expname: swin_noloop_000 74 | mast3r_img_res: 512 75 | depth_reader: 76 | target: src.data.asset_readers.DepthAnythingReader 77 | params: 78 | split: train 79 | motion_mask_reader: 80 | target: src.data.asset_readers.TAMMaskReader 81 | params: 82 | split: train 83 | train_dloader_config: 84 | target: src.data.dataloader.PermutationSingleDataLoader 85 | params: 86 | num_iterations: 20000 87 | test_dset_config: 88 | target: src.data.datamodule.DataReader 89 | params: 90 | camera_config: 91 | target: src.data.utils.FixedCamera 92 | pose_reader: 93 | target: src.data.asset_readers.GTCameraReader 94 | test_dloader_config: 95 | target: src.data.dataloader.SequentialSingleDataLoader 96 | params: {} 97 | train_pcd_reader_config: 98 | target: src.data.asset_readers.MASt3RPCDReader 99 | params: 100 | mast3r_expname: swin_noloop_000 101 | mode: dynamic 102 | num_limit_points: 120000 103 | test_transform_fname: test_transforms.json 104 | train_pose_reader_config: 105 | target: src.data.datamodule.DataReader 106 | params: 107 | camera_config: 108 | target: src.data.utils.FixedCamera 109 | pose_reader: 110 | target: src.data.asset_readers.GTCameraReader 111 | normalize_cams: false 112 | dynamic_model: 113 | target: src.model.rodygs_dynamic.DynRoDyGS 114 | params: 115 | sh_degree: 3 116 | deform_netwidth: 128 117 | deform_t_emb_multires: 26 118 | deform_t_log_sampling: false 119 | num_basis: 16 120 | isotropic: false 121 | inverse_motion: true 122 | trainer: 123 | target: src.trainer.rodygs.RoDyGSTrainer 124 | params: 125 | log_freq: 50 126 | sh_up_start_iteration: 15000 127 | sh_up_period: 1000 128 | static: 129 | target: src.trainer.rodygs_static.ThreeDGSTrainer 130 | params: 131 | loss_config: 132 | target: src.trainer.losses.MultiLoss 133 | params: 134 | loss_configs: 135 | - name: d_ssim 136 | weight: 0.2 137 | target: src.trainer.losses.SSIMLoss 138 | params: 139 | mode: all 140 | - name: l1 141 | weight: 0.8 142 | target: src.trainer.losses.L1Loss 143 | params: 144 | mode: all 145 | - name: global_pearson_depth 146 | weight: 0.05 147 | target: src.trainer.losses.GlobalPearsonDepthLoss 148 | start: 0 149 | params: 150 | mode: all 151 | - 
name: local_pearson_depth 152 | weight: 0.15 153 | target: src.trainer.losses.LocalPearsonDepthLoss 154 | start: 0 155 | params: 156 | box_p: 128 157 | p_corr: 0.5 158 | mode: all 159 | num_iterations: 20000 160 | position_lr_init: 0.00016 161 | position_lr_final: 1.6e-06 162 | position_lr_delay_mult: 0.01 163 | position_lr_max_steps: 20000 164 | feature_lr: 0.0025 165 | opacity_lr: 0.05 166 | scaling_lr: 0.005 167 | rotation_lr: 0.001 168 | percent_dense: 0.01 169 | densification_interval: 100 170 | opacity_reset_interval: 5000000 171 | densify_from_iter: 500 172 | densify_until_iter: 20000 173 | densify_grad_threshold: 0.0002 174 | camera_opt_config: 175 | target: src.trainer.optim.CameraQuatOptimizer 176 | params: 177 | camera_rotation_lr: 1.0e-05 178 | camera_translation_lr: 1.0e-06 179 | camera_lr_warmup: 3000 180 | total_steps: 20000 181 | dynamic: 182 | target: src.trainer.rodygs_dynamic.DynTrainer 183 | params: 184 | loss_config: 185 | target: src.trainer.losses.MultiLoss 186 | params: 187 | loss_configs: 188 | - name: d_ssim 189 | weight: 0.2 190 | target: src.trainer.losses.SSIMLoss 191 | params: 192 | mode: all 193 | - name: l1 194 | weight: 0.8 195 | target: src.trainer.losses.L1Loss 196 | params: 197 | mode: all 198 | - name: motion_l1_reg 199 | weight: 0.01 200 | start: 0 201 | target: src.trainer.losses.MotionL1Loss 202 | - name: motion_sparsity 203 | weight: 0.002 204 | start: 0 205 | target: src.trainer.losses.MotionSparsityLoss 206 | - name: global_pearson_depth 207 | weight: 0.05 208 | target: src.trainer.losses.GlobalPearsonDepthLoss 209 | start: 0 210 | params: 211 | mode: all 212 | - name: local_pearson_depth 213 | weight: 0.15 214 | target: src.trainer.losses.LocalPearsonDepthLoss 215 | start: 0 216 | params: 217 | box_p: 128 218 | p_corr: 0.5 219 | mode: all 220 | - name: rigidity 221 | weight: 0.5 222 | freq: 5 223 | start: 0 224 | target: src.trainer.losses.RigidityLoss 225 | params: 226 | mode: 227 | - distance_preserving 228 | - surface 229 | K: 8 230 | - name: motion_basis_reg 231 | weight: 0.5 232 | start: 0 233 | target: src.trainer.losses.MotionBasisRegularizaiton 234 | params: 235 | transl_degree: 0 236 | rot_degree: 0 237 | freq_div_mode: cum_exponential 238 | num_iterations: 20000 239 | position_lr_init: 0.00016 240 | position_lr_final: 1.6e-06 241 | position_lr_delay_mult: 0.01 242 | position_lr_max_steps: 20000 243 | feature_lr: 0.0025 244 | opacity_lr: 0.05 245 | scaling_lr: 0.001 246 | rotation_lr: 0.001 247 | percent_dense: 0.01 248 | densification_interval: 100 249 | opacity_reset_interval: 5000000 250 | densify_from_iter: 500 251 | densify_until_iter: 15000 252 | densify_grad_threshold: 0.0002 253 | deform_warmup_steps: 0 254 | deform_lr_init: 0.0016 255 | deform_lr_final: 0.00016 256 | deform_lr_delay_mult: 0.01 257 | deform_lr_max_steps: 20000 258 | motion_coeff_lr: 0.00016 259 | camera_opt_config: 260 | target: src.trainer.optim.CameraQuatOptimizer 261 | params: 262 | camera_rotation_lr: 0.0 263 | camera_translation_lr: 0.0 264 | camera_lr_warmup: 0 265 | total_steps: 20000 -------------------------------------------------------------------------------- /configs/train/train_kubric_mrig.yaml: -------------------------------------------------------------------------------- 1 | static_data: 2 | target: src.data.datamodule.GSDataModule 3 | params: 4 | dirpath: ./ 5 | train_dset_config: 6 | target: src.data.datamodule.LazyDataReader 7 | params: 8 | camera_config: 9 | target: src.data.utils.FixedCamera 10 | pose_reader: 11 | target: 
src.data.asset_readers.MASt3RCameraReader 12 | params: 13 | mast3r_expname: swin_noloop_000 14 | mast3r_img_res: 512 15 | depth_reader: 16 | target: src.data.asset_readers.DepthAnythingReader 17 | params: 18 | split: train 19 | motion_mask_reader: 20 | target: src.data.asset_readers.TAMMaskReader 21 | params: 22 | split: train 23 | train_dloader_config: 24 | target: src.data.dataloader.PermutationSingleDataLoader 25 | params: 26 | num_iterations: 20000 27 | test_dset_config: 28 | target: src.data.datamodule.DataReader 29 | params: 30 | camera_config: 31 | target: src.data.utils.FixedCamera 32 | pose_reader: 33 | target: src.data.asset_readers.GTCameraReader 34 | test_dloader_config: 35 | target: src.data.dataloader.SequentialSingleDataLoader 36 | params: {} 37 | train_pcd_reader_config: 38 | target: src.data.asset_readers.MASt3RPCDReader 39 | params: 40 | mast3r_expname: swin_noloop_000 41 | mode: static 42 | num_limit_points: 120000 43 | train_pose_reader_config: 44 | target: src.data.datamodule.DataReader 45 | params: 46 | camera_config: 47 | target: src.data.utils.FixedCamera 48 | pose_reader: 49 | target: src.data.asset_readers.GTCameraReader 50 | normalize_cams: false 51 | static_model: 52 | target: src.model.rodygs_static.StaticRoDyGS 53 | params: 54 | sh_degree: 3 55 | isotropic: false 56 | static_calibrated_pose_reader: 57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 58 | params: 59 | mast3r_expname: swin_noloop_000 60 | mast3r_img_res: 512 61 | dynamic_data: 62 | target: src.data.datamodule.GSDataModule 63 | params: 64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/ 65 | train_dset_config: 66 | target: src.data.datamodule.LazyDataReader 67 | params: 68 | camera_config: 69 | target: src.data.utils.FixedCamera 70 | pose_reader: 71 | target: src.data.asset_readers.MASt3RCameraReader 72 | params: 73 | mast3r_expname: swin_noloop_000 74 | mast3r_img_res: 512 75 | depth_reader: 76 | target: src.data.asset_readers.DepthAnythingReader 77 | params: 78 | split: train 79 | motion_mask_reader: 80 | target: src.data.asset_readers.TAMMaskReader 81 | params: 82 | split: train 83 | train_dloader_config: 84 | target: src.data.dataloader.PermutationSingleDataLoader 85 | params: 86 | num_iterations: 20000 87 | test_dset_config: 88 | target: src.data.datamodule.DataReader 89 | params: 90 | camera_config: 91 | target: src.data.utils.FixedCamera 92 | pose_reader: 93 | target: src.data.asset_readers.GTCameraReader 94 | test_dloader_config: 95 | target: src.data.dataloader.SequentialSingleDataLoader 96 | params: {} 97 | train_pcd_reader_config: 98 | target: src.data.asset_readers.MASt3RPCDReader 99 | params: 100 | mast3r_expname: swin_noloop_000 101 | mode: dynamic 102 | num_limit_points: 120000 103 | test_transform_fname: test_transforms.json 104 | train_pose_reader_config: 105 | target: src.data.datamodule.DataReader 106 | params: 107 | camera_config: 108 | target: src.data.utils.FixedCamera 109 | pose_reader: 110 | target: src.data.asset_readers.GTCameraReader 111 | normalize_cams: false 112 | dynamic_model: 113 | target: src.model.rodygs_dynamic.DynRoDyGS 114 | params: 115 | sh_degree: 3 116 | deform_netwidth: 128 117 | deform_t_emb_multires: 26 118 | deform_t_log_sampling: false 119 | num_basis: 16 120 | isotropic: false 121 | inverse_motion: true 122 | trainer: 123 | target: src.trainer.rodygs.RoDyGSTrainer 124 | params: 125 | log_freq: 50 126 | sh_up_start_iteration: 15000 127 | sh_up_period: 1000 128 | static: 129 | target: src.trainer.rodygs_static.ThreeDGSTrainer 130 | 
params: 131 | loss_config: 132 | target: src.trainer.losses.MultiLoss 133 | params: 134 | loss_configs: 135 | - name: d_ssim 136 | weight: 0.2 137 | target: src.trainer.losses.SSIMLoss 138 | params: 139 | mode: all 140 | - name: l1 141 | weight: 0.8 142 | target: src.trainer.losses.L1Loss 143 | params: 144 | mode: all 145 | - name: global_pearson_depth 146 | weight: 0.05 147 | target: src.trainer.losses.GlobalPearsonDepthLoss 148 | start: 0 149 | params: 150 | mode: all 151 | - name: local_pearson_depth 152 | weight: 0.15 153 | target: src.trainer.losses.LocalPearsonDepthLoss 154 | start: 0 155 | params: 156 | box_p: 128 157 | p_corr: 0.5 158 | mode: all 159 | num_iterations: 20000 160 | position_lr_init: 0.00016 161 | position_lr_final: 1.6e-06 162 | position_lr_delay_mult: 0.01 163 | position_lr_max_steps: 20000 164 | feature_lr: 0.0025 165 | opacity_lr: 0.05 166 | scaling_lr: 0.005 167 | rotation_lr: 0.001 168 | percent_dense: 0.01 169 | densification_interval: 100 170 | opacity_reset_interval: 5000000 171 | densify_from_iter: 500 172 | densify_until_iter: 20000 173 | densify_grad_threshold: 0.0002 174 | camera_opt_config: 175 | target: src.trainer.optim.CameraQuatOptimizer 176 | params: 177 | camera_rotation_lr: 1.0e-05 178 | camera_translation_lr: 1.0e-06 179 | camera_lr_warmup: 0 180 | total_steps: 20000 181 | dynamic: 182 | target: src.trainer.rodygs_dynamic.DynTrainer 183 | params: 184 | loss_config: 185 | target: src.trainer.losses.MultiLoss 186 | params: 187 | loss_configs: 188 | - name: d_ssim 189 | weight: 0.2 190 | target: src.trainer.losses.SSIMLoss 191 | params: 192 | mode: all 193 | - name: l1 194 | weight: 0.8 195 | target: src.trainer.losses.L1Loss 196 | params: 197 | mode: all 198 | - name: motion_l1_reg 199 | weight: 0.01 200 | start: 0 201 | target: src.trainer.losses.MotionL1Loss 202 | - name: motion_sparsity 203 | weight: 0.002 204 | start: 0 205 | target: src.trainer.losses.MotionSparsityLoss 206 | - name: global_pearson_depth 207 | weight: 0.05 208 | target: src.trainer.losses.GlobalPearsonDepthLoss 209 | start: 0 210 | params: 211 | mode: all 212 | - name: local_pearson_depth 213 | weight: 0.15 214 | target: src.trainer.losses.LocalPearsonDepthLoss 215 | start: 0 216 | params: 217 | box_p: 128 218 | p_corr: 0.5 219 | mode: all 220 | - name: rigidity 221 | weight: 0.5 222 | freq: 5 223 | start: 0 224 | target: src.trainer.losses.RigidityLoss 225 | params: 226 | mode: 227 | - distance_preserving 228 | - surface 229 | K: 8 230 | - name: motion_basis_reg 231 | weight: 0.1 232 | start: 0 233 | target: src.trainer.losses.MotionBasisRegularizaiton 234 | params: 235 | transl_degree: 0 236 | rot_degree: 0 237 | freq_div_mode: cum_exponential 238 | num_iterations: 20000 239 | position_lr_init: 0.00016 240 | position_lr_final: 1.6e-06 241 | position_lr_delay_mult: 0.01 242 | position_lr_max_steps: 20000 243 | feature_lr: 0.0025 244 | opacity_lr: 0.05 245 | scaling_lr: 0.001 246 | rotation_lr: 0.001 247 | percent_dense: 0.01 248 | densification_interval: 100 249 | opacity_reset_interval: 5000000 250 | densify_from_iter: 500 251 | densify_until_iter: 15000 252 | densify_grad_threshold: 0.0002 253 | deform_warmup_steps: 0 254 | deform_lr_init: 0.0016 255 | deform_lr_final: 0.00016 256 | deform_lr_delay_mult: 0.01 257 | deform_lr_max_steps: 20000 258 | motion_coeff_lr: 0.00016 259 | camera_opt_config: 260 | target: src.trainer.optim.CameraQuatOptimizer 261 | params: 262 | camera_rotation_lr: 0.0 263 | camera_translation_lr: 0.0 264 | camera_lr_warmup: 0 265 | total_steps: 20000 
-------------------------------------------------------------------------------- /configs/train/train_nvidia.yaml: -------------------------------------------------------------------------------- 1 | static_data: 2 | target: src.data.datamodule.GSDataModule 3 | params: 4 | dirpath: ./ 5 | train_dset_config: 6 | target: src.data.datamodule.LazyDataReader 7 | params: 8 | camera_config: 9 | target: src.data.utils.FixedCamera 10 | pose_reader: 11 | target: src.data.asset_readers.MASt3RCameraReader 12 | params: 13 | mast3r_expname: swin_noloop_000 14 | mast3r_img_res: 512 15 | depth_reader: 16 | target: src.data.asset_readers.DepthAnythingReader 17 | params: 18 | split: train 19 | motion_mask_reader: 20 | target: src.data.asset_readers.TAMMaskReader 21 | params: 22 | split: train 23 | train_dloader_config: 24 | target: src.data.dataloader.PermutationSingleDataLoader 25 | params: 26 | num_iterations: 20000 27 | test_dset_config: 28 | target: src.data.datamodule.DataReader 29 | params: 30 | camera_config: 31 | target: src.data.utils.FixedCamera 32 | pose_reader: 33 | target: src.data.asset_readers.GTCameraReader 34 | test_dloader_config: 35 | target: src.data.dataloader.SequentialSingleDataLoader 36 | params: {} 37 | train_pcd_reader_config: 38 | target: src.data.asset_readers.MASt3RPCDReader 39 | params: 40 | mast3r_expname: swin_noloop_000 41 | mode: static 42 | num_limit_points: 120000 43 | train_pose_reader_config: 44 | target: src.data.datamodule.DataReader 45 | params: 46 | camera_config: 47 | target: src.data.utils.FixedCamera 48 | pose_reader: 49 | target: src.data.asset_readers.GTCameraReader 50 | normalize_cams: false 51 | static_model: 52 | target: src.model.rodygs_static.StaticRoDyGS 53 | params: 54 | sh_degree: 3 55 | isotropic: false 56 | static_calibrated_pose_reader: 57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 58 | params: 59 | mast3r_expname: swin_noloop_000 60 | mast3r_img_res: 512 61 | dynamic_data: 62 | target: src.data.datamodule.GSDataModule 63 | params: 64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/ 65 | train_dset_config: 66 | target: src.data.datamodule.LazyDataReader 67 | params: 68 | camera_config: 69 | target: src.data.utils.FixedCamera 70 | pose_reader: 71 | target: src.data.asset_readers.MASt3RCameraReader 72 | params: 73 | mast3r_expname: swin_noloop_000 74 | mast3r_img_res: 512 75 | depth_reader: 76 | target: src.data.asset_readers.DepthAnythingReader 77 | params: 78 | split: train 79 | motion_mask_reader: 80 | target: src.data.asset_readers.TAMMaskReader 81 | params: 82 | split: train 83 | train_dloader_config: 84 | target: src.data.dataloader.PermutationSingleDataLoader 85 | params: 86 | num_iterations: 20000 87 | test_dset_config: 88 | target: src.data.datamodule.DataReader 89 | params: 90 | camera_config: 91 | target: src.data.utils.FixedCamera 92 | pose_reader: 93 | target: src.data.asset_readers.GTCameraReader 94 | test_dloader_config: 95 | target: src.data.dataloader.SequentialSingleDataLoader 96 | params: {} 97 | train_pcd_reader_config: 98 | target: src.data.asset_readers.MASt3RPCDReader 99 | params: 100 | mast3r_expname: swin_noloop_000 101 | mode: dynamic 102 | num_limit_points: 120000 103 | test_transform_fname: test_transforms.json 104 | train_pose_reader_config: 105 | target: src.data.datamodule.DataReader 106 | params: 107 | camera_config: 108 | target: src.data.utils.FixedCamera 109 | pose_reader: 110 | target: src.data.asset_readers.GTCameraReader 111 | normalize_cams: false 112 | dynamic_model: 113 | target: 
src.model.rodygs_dynamic.DynRoDyGS 114 | params: 115 | sh_degree: 3 116 | deform_netwidth: 128 117 | deform_t_emb_multires: 26 118 | deform_t_log_sampling: false 119 | num_basis: 16 120 | isotropic: false 121 | inverse_motion: true 122 | trainer: 123 | target: src.trainer.rodygs.RoDyGSTrainer 124 | params: 125 | log_freq: 50 126 | sh_up_start_iteration: 15000 127 | sh_up_period: 1000 128 | static: 129 | target: src.trainer.rodygs_static.ThreeDGSTrainer 130 | params: 131 | loss_config: 132 | target: src.trainer.losses.MultiLoss 133 | params: 134 | loss_configs: 135 | - name: d_ssim 136 | weight: 0.2 137 | target: src.trainer.losses.SSIMLoss 138 | params: 139 | mode: all 140 | - name: l1 141 | weight: 0.8 142 | target: src.trainer.losses.L1Loss 143 | params: 144 | mode: all 145 | - name: global_pearson_depth 146 | weight: 0.05 147 | target: src.trainer.losses.GlobalPearsonDepthLoss 148 | start: 0 149 | params: 150 | mode: all 151 | - name: local_pearson_depth 152 | weight: 0.15 153 | target: src.trainer.losses.LocalPearsonDepthLoss 154 | start: 0 155 | params: 156 | box_p: 128 157 | p_corr: 0.5 158 | mode: all 159 | num_iterations: 20000 160 | position_lr_init: 0.00016 161 | position_lr_final: 1.6e-06 162 | position_lr_delay_mult: 0.01 163 | position_lr_max_steps: 20000 164 | feature_lr: 0.0025 165 | opacity_lr: 0.05 166 | scaling_lr: 0.005 167 | rotation_lr: 0.001 168 | percent_dense: 0.01 169 | densification_interval: 100 170 | opacity_reset_interval: 5000000 171 | densify_from_iter: 500 172 | densify_until_iter: 20000 173 | densify_grad_threshold: 0.0002 174 | camera_opt_config: 175 | target: src.trainer.optim.CameraQuatOptimizer 176 | params: 177 | camera_rotation_lr: 1.0e-05 178 | camera_translation_lr: 1.0e-06 179 | camera_lr_warmup: 0 180 | total_steps: 20000 181 | dynamic: 182 | target: src.trainer.rodygs_dynamic.DynTrainer 183 | params: 184 | loss_config: 185 | target: src.trainer.losses.MultiLoss 186 | params: 187 | loss_configs: 188 | - name: d_ssim 189 | weight: 0.2 190 | target: src.trainer.losses.SSIMLoss 191 | params: 192 | mode: all 193 | - name: l1 194 | weight: 0.8 195 | target: src.trainer.losses.L1Loss 196 | params: 197 | mode: all 198 | - name: motion_l1_reg 199 | weight: 0.01 200 | start: 0 201 | target: src.trainer.losses.MotionL1Loss 202 | - name: motion_sparsity 203 | weight: 0.002 204 | start: 0 205 | target: src.trainer.losses.MotionSparsityLoss 206 | - name: global_pearson_depth 207 | weight: 0.05 208 | target: src.trainer.losses.GlobalPearsonDepthLoss 209 | start: 0 210 | params: 211 | mode: all 212 | - name: local_pearson_depth 213 | weight: 0.15 214 | target: src.trainer.losses.LocalPearsonDepthLoss 215 | start: 0 216 | params: 217 | box_p: 128 218 | p_corr: 0.5 219 | mode: all 220 | - name: rigidity 221 | weight: 0.5 222 | freq: 5 223 | start: 0 224 | target: src.trainer.losses.RigidityLoss 225 | params: 226 | mode: 227 | - distance_preserving 228 | - surface 229 | K: 8 230 | - name: motion_basis_reg 231 | weight: 0.1 232 | start: 0 233 | target: src.trainer.losses.MotionBasisRegularizaiton 234 | params: 235 | transl_degree: 0 236 | rot_degree: 0 237 | freq_div_mode: cum_exponential 238 | num_iterations: 20000 239 | position_lr_init: 0.00016 240 | position_lr_final: 1.6e-06 241 | position_lr_delay_mult: 0.01 242 | position_lr_max_steps: 20000 243 | feature_lr: 0.0025 244 | opacity_lr: 0.05 245 | scaling_lr: 0.001 246 | rotation_lr: 0.001 247 | percent_dense: 0.01 248 | densification_interval: 100 249 | opacity_reset_interval: 5000000 250 | densify_from_iter: 
500 251 | densify_until_iter: 15000 252 | densify_grad_threshold: 0.0002 253 | deform_warmup_steps: 0 254 | deform_lr_init: 0.0016 255 | deform_lr_final: 0.00016 256 | deform_lr_delay_mult: 0.01 257 | deform_lr_max_steps: 20000 258 | motion_coeff_lr: 0.00016 259 | camera_opt_config: 260 | target: src.trainer.optim.CameraQuatOptimizer 261 | params: 262 | camera_rotation_lr: 0.0 263 | camera_translation_lr: 0.0 264 | camera_lr_warmup: 0 265 | total_steps: 20000 -------------------------------------------------------------------------------- /configs/train/train_tnt.yaml: -------------------------------------------------------------------------------- 1 | static_data: 2 | target: src.data.datamodule.GSDataModule 3 | params: 4 | dirpath: ./ 5 | train_dset_config: 6 | target: src.data.datamodule.LazyDataReader 7 | params: 8 | camera_config: 9 | target: src.data.utils.FixedCamera 10 | pose_reader: 11 | target: src.data.asset_readers.MASt3RCameraReader 12 | params: 13 | mast3r_expname: swin_noloop_000 14 | mast3r_img_res: 512 15 | depth_reader: 16 | target: src.data.asset_readers.DepthAnythingReader 17 | params: 18 | split: train 19 | motion_mask_reader: 20 | target: src.data.asset_readers.TAMMaskReader 21 | params: 22 | split: train 23 | train_dloader_config: 24 | target: src.data.dataloader.PermutationSingleDataLoader 25 | params: 26 | num_iterations: 150000 27 | test_dset_config: 28 | target: src.data.datamodule.DataReader 29 | params: 30 | camera_config: 31 | target: src.data.utils.FixedCamera 32 | pose_reader: 33 | target: src.data.asset_readers.GTCameraReader 34 | test_dloader_config: 35 | target: src.data.dataloader.SequentialSingleDataLoader 36 | params: {} 37 | train_pcd_reader_config: 38 | target: src.data.asset_readers.MASt3RPCDReader 39 | params: 40 | mast3r_expname: swin_noloop_000 41 | mode: static 42 | num_limit_points: 120000 43 | train_pose_reader_config: 44 | target: src.data.datamodule.DataReader 45 | params: 46 | camera_config: 47 | target: src.data.utils.FixedCamera 48 | pose_reader: 49 | target: src.data.asset_readers.GTCameraReader 50 | normalize_cams: false 51 | static_model: 52 | target: src.model.rodygs_static.StaticRoDyGS 53 | params: 54 | sh_degree: 3 55 | isotropic: false 56 | static_calibrated_pose_reader: 57 | target: src.data.asset_readers.MASt3R_CKPTCameraReader 58 | params: 59 | mast3r_expname: swin_noloop_000 60 | mast3r_img_res: 512 61 | dynamic_data: 62 | target: src.data.datamodule.GSDataModule 63 | params: 64 | dirpath: /home/junmyeong/data/kubric_multirig/scene_0/ 65 | train_dset_config: 66 | target: src.data.datamodule.LazyDataReader 67 | params: 68 | camera_config: 69 | target: src.data.utils.FixedCamera 70 | pose_reader: 71 | target: src.data.asset_readers.MASt3RCameraReader 72 | params: 73 | mast3r_expname: swin_noloop_000 74 | mast3r_img_res: 512 75 | depth_reader: 76 | target: src.data.asset_readers.DepthAnythingReader 77 | params: 78 | split: train 79 | motion_mask_reader: 80 | target: src.data.asset_readers.TAMMaskReader 81 | params: 82 | split: train 83 | train_dloader_config: 84 | target: src.data.dataloader.PermutationSingleDataLoader 85 | params: 86 | num_iterations: 150000 87 | test_dset_config: 88 | target: src.data.datamodule.DataReader 89 | params: 90 | camera_config: 91 | target: src.data.utils.FixedCamera 92 | pose_reader: 93 | target: src.data.asset_readers.GTCameraReader 94 | test_dloader_config: 95 | target: src.data.dataloader.SequentialSingleDataLoader 96 | params: {} 97 | train_pcd_reader_config: 98 | target: 
src.data.asset_readers.MASt3RPCDReader 99 | params: 100 | mast3r_expname: swin_noloop_000 101 | mode: dynamic 102 | num_limit_points: 120000 103 | test_transform_fname: test_transforms.json 104 | train_pose_reader_config: 105 | target: src.data.datamodule.DataReader 106 | params: 107 | camera_config: 108 | target: src.data.utils.FixedCamera 109 | pose_reader: 110 | target: src.data.asset_readers.GTCameraReader 111 | normalize_cams: false 112 | dynamic_model: 113 | target: src.model.rodygs_dynamic.DynRoDyGS 114 | params: 115 | sh_degree: 3 116 | deform_netwidth: 128 117 | deform_t_emb_multires: 26 118 | deform_t_log_sampling: false 119 | num_basis: 16 120 | isotropic: false 121 | inverse_motion: true 122 | trainer: 123 | target: src.trainer.rodygs.RoDyGSTrainer 124 | params: 125 | log_freq: 50 126 | sh_up_start_iteration: 15000 127 | sh_up_period: 1000 128 | static: 129 | target: src.trainer.rodygs_static.ThreeDGSTrainer 130 | params: 131 | loss_config: 132 | target: src.trainer.losses.MultiLoss 133 | params: 134 | loss_configs: 135 | - name: d_ssim 136 | weight: 0.2 137 | target: src.trainer.losses.SSIMLoss 138 | params: 139 | mode: all 140 | - name: l1 141 | weight: 0.8 142 | target: src.trainer.losses.L1Loss 143 | params: 144 | mode: all 145 | - name: global_pearson_depth 146 | weight: 0.001 147 | target: src.trainer.losses.GlobalPearsonDepthLoss 148 | start: 0 149 | params: 150 | mode: all 151 | - name: local_pearson_depth 152 | weight: 0.01 153 | target: src.trainer.losses.LocalPearsonDepthLoss 154 | start: 0 155 | params: 156 | box_p: 128 157 | p_corr: 0.5 158 | mode: all 159 | num_iterations: 150000 160 | position_lr_init: 0.00016 161 | position_lr_final: 1.6e-06 162 | position_lr_delay_mult: 0.01 163 | position_lr_max_steps: 150000 164 | feature_lr: 0.0025 165 | opacity_lr: 0.05 166 | scaling_lr: 0.005 167 | rotation_lr: 0.001 168 | percent_dense: 0.01 169 | densification_interval: 100 170 | opacity_reset_interval: 5000000 171 | densify_from_iter: 500 172 | densify_until_iter: 150000 173 | densify_grad_threshold: 0.0002 174 | camera_opt_config: 175 | target: src.trainer.optim.CameraQuatOptimizer 176 | params: 177 | camera_rotation_lr: 5.0e-06 178 | camera_translation_lr: 5.0e-05 179 | camera_lr_warmup: 3000 180 | total_steps: 150000 181 | dynamic: 182 | target: src.trainer.rodygs_dynamic.DynTrainer 183 | params: 184 | loss_config: 185 | target: src.trainer.losses.MultiLoss 186 | params: 187 | loss_configs: 188 | - name: d_ssim 189 | weight: 0.2 190 | target: src.trainer.losses.SSIMLoss 191 | params: 192 | mode: all 193 | - name: l1 194 | weight: 0.8 195 | target: src.trainer.losses.L1Loss 196 | params: 197 | mode: all 198 | - name: motion_l1_reg 199 | weight: 0.01 200 | start: 0 201 | target: src.trainer.losses.MotionL1Loss 202 | - name: motion_sparsity 203 | weight: 0.002 204 | start: 0 205 | target: src.trainer.losses.MotionSparsityLoss 206 | - name: global_pearson_depth 207 | weight: 0.05 208 | target: src.trainer.losses.GlobalPearsonDepthLoss 209 | start: 0 210 | params: 211 | mode: all 212 | - name: local_pearson_depth 213 | weight: 0.15 214 | target: src.trainer.losses.LocalPearsonDepthLoss 215 | start: 0 216 | params: 217 | box_p: 128 218 | p_corr: 0.5 219 | mode: all 220 | - name: rigidity 221 | weight: 0.5 222 | freq: 5 223 | start: 0 224 | target: src.trainer.losses.RigidityLoss 225 | params: 226 | mode: 227 | - distance_preserving 228 | - surface 229 | K: 8 230 | - name: motion_basis_reg 231 | weight: 0.1 232 | start: 0 233 | target: 
src.trainer.losses.MotionBasisRegularizaiton 234 | params: 235 | transl_degree: 0 236 | rot_degree: 0 237 | freq_div_mode: cum_exponential 238 | num_iterations: 150000 239 | position_lr_init: 0.00016 240 | position_lr_final: 1.6e-06 241 | position_lr_delay_mult: 0.01 242 | position_lr_max_steps: 150000 243 | feature_lr: 0.0025 244 | opacity_lr: 0.05 245 | scaling_lr: 0.001 246 | rotation_lr: 0.001 247 | percent_dense: 0.01 248 | densification_interval: 100 249 | opacity_reset_interval: 5000000 250 | densify_from_iter: 500 251 | densify_until_iter: 15000 252 | densify_grad_threshold: 0.0002 253 | deform_warmup_steps: 0 254 | deform_lr_init: 0.0016 255 | deform_lr_final: 0.00016 256 | deform_lr_delay_mult: 0.01 257 | deform_lr_max_steps: 150000 258 | motion_coeff_lr: 0.00016 259 | camera_opt_config: 260 | target: src.trainer.optim.CameraQuatOptimizer 261 | params: 262 | camera_rotation_lr: 0.0 263 | camera_translation_lr: 0.0 264 | camera_lr_warmup: 0 265 | total_steps: 150000 -------------------------------------------------------------------------------- /licences/LICENSE_gs.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e Inria and MPII 19 | 20 | *Software* means the original work of authorship made available under this 21 | License ie gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 
50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing to appreciate the adequacy between of the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 
87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /scripts/img2video.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import imageio 12 | import os 13 | import argparse 14 | 15 | 16 | def convert_png_to_video(data_dir): 17 | 18 | imgs_path = os.path.join(data_dir, "train") 19 | images = [img for img in os.listdir(imgs_path) if img.endswith((".png"))] 20 | images.sort() 21 | output_video = os.path.join(data_dir, "train.mp4") 22 | 23 | if not images: 24 | print("No images!") 25 | else: 26 | with imageio.get_writer( 27 | output_video, fps=30, codec="libx264", macro_block_size=1 28 | ) as writer: 29 | for image in images: 30 | writer.append_data(imageio.imread(os.path.join(imgs_path, image))) 31 | print(f"Generate {output_video} ...") 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--data_dir", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | convert_png_to_video(args.data_dir) 40 | -------------------------------------------------------------------------------- /scripts/iphone2format.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | 12 | import os 13 | import numpy as np 14 | import argparse 15 | from PIL import Image 16 | import json 17 | import math 18 | 19 | 20 | def focal2fov(focal, pixels): 21 | return 2 * math.atan(pixels / (2 * focal)) 22 | 23 | 24 | def convert2format(data_dir, output_dir, resolution): 25 | 26 | split_path = os.path.join(data_dir, "splits") 27 | with open(os.path.join(split_path, "train.json"), "r") as fp: 28 | train_json = json.load(fp) 29 | 30 | img_paths = [] 31 | cam_paths = [] 32 | for frame_name in train_json["frame_names"]: 33 | if resolution == 1: 34 | img_paths.append(os.path.join(data_dir, "rgb", "1x", frame_name + ".png")) 35 | else: 36 | img_paths.append(os.path.join(data_dir, "rgb", "2x", frame_name + ".png")) 37 | cam_paths.append(os.path.join(data_dir, "camera", frame_name + ".json")) 38 | 39 | os.makedirs(output_dir, exist_ok=True) 40 | train_outimgdir = os.path.join(output_dir, "train") 41 | os.makedirs(train_outimgdir, exist_ok=True) 42 | test_outimgdir = os.path.join(output_dir, "test") 43 | os.makedirs(test_outimgdir, exist_ok=True) 44 | 45 | with open(cam_paths[0], "r") as fp: 46 | cam_0 = json.load(fp) 47 | train_transforms = dict() 48 | train_transforms["camera_angle_x"] = ( 49 | focal2fov(cam_0["focal_length"], 720) * 180 / math.pi 50 | ) 51 | train_transforms["camera_angle_y"] = ( 52 | focal2fov(cam_0["focal_length"], 960) * 180 / math.pi 53 | ) 54 | train_transforms["frames"] = [] 55 | test_transforms = dict() 56 | test_transforms["camera_angle_x"] = ( 57 | focal2fov(cam_0["focal_length"], 720) * 180 / math.pi 58 | ) 59 | test_transforms["camera_angle_y"] = ( 60 | focal2fov(cam_0["focal_length"], 960) * 180 / math.pi 61 | ) 62 | test_transforms["frames"] = [] 63 | 64 | # Get train & test images from only the train cameras 65 | frame_idx = 0 66 | train_id = 0 67 | test_id = 0 68 | for img, cam in zip(img_paths, cam_paths): 69 | frame = dict() 70 | image_fname = f"rgba_{frame_idx:05d}.png" 71 | image = Image.open(img) 72 | 73 | frame["time"] = frame_idx / len(img_paths) 74 | frame["file_path"] = image_fname # tmp 75 | frame["width"] = int(720 / resolution) 76 | frame["height"] = int(960 / resolution) 77 | 78 | with open(cam, "r") as fp: 79 | cam = json.load(fp) 80 | w2c_rot = np.array(cam["orientation"]) 81 | c2w_rot = np.linalg.inv(w2c_rot) 82 | 83 | c2w = np.eye(4) 84 | c2w[:3, :3] = c2w_rot 85 | c2w[:3, 3] = np.array(cam["position"]) 86 | frame["transform_matrix"] = c2w.tolist() 87 | 88 | if (frame_idx + 4) % 8 == 0: 89 | image_fname = f"rgba_{train_id:05d}.png" 90 | frame["file_path"] = os.path.join("test", image_fname) 91 | test_transforms["frames"].append(frame) 92 | test_out_image_fpath = os.path.join(test_outimgdir, image_fname) 93 | image.save(test_out_image_fpath) 94 | train_id += 1 95 | else: 96 | image_fname = f"rgba_{test_id:05d}.png" 97 | frame["file_path"] = os.path.join("train", image_fname) 98 | train_transforms["frames"].append(frame) 99 | train_out_image_fpath = os.path.join(train_outimgdir, image_fname) 100 | image.save(train_out_image_fpath) 101 | test_id += 1 102 | 103 | frame_idx += 1 104 | 105 | with open(os.path.join(output_dir, "train_transforms.json"), "w") as fp: 106 | json.dump(train_transforms, fp, indent=4) 107 | with open(os.path.join(output_dir, "test_transforms.json"), 
"w") as fp: 108 | json.dump(test_transforms, fp, indent=4) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--data_dir", required=True, type=str) 114 | parser.add_argument("--output_dir", required=True, type=str) 115 | parser.add_argument("--resolution", type=int, default=1) 116 | args = parser.parse_args() 117 | assert args.resolution in [1, 2], "assume resolution is 1x or 2x" 118 | 119 | convert2format(args.data_dir, args.output_dir, args.resolution) 120 | -------------------------------------------------------------------------------- /scripts/kubricmrig2format.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | 11 | import argparse 12 | import json 13 | from pathlib import Path 14 | from PIL import Image 15 | 16 | import numpy as np 17 | 18 | 19 | # used for camera direction conversion(inverse direction) 20 | gl_matrix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 21 | 22 | # used for world coordinate conversion 23 | opencv_matrix = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 24 | 25 | 26 | def quaternion_to_rotation_matrix(quaternion): 27 | q = np.array(quaternion, dtype=np.float64) 28 | norm = np.linalg.norm(q) 29 | if norm == 0: 30 | return np.eye(3) 31 | q /= norm 32 | w, x, y, z = q 33 | rotation_matrix = np.array( 34 | [ 35 | [1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y], 36 | [2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x], 37 | [2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2], 38 | ] 39 | ) 40 | return rotation_matrix 41 | 42 | 43 | def kubric2opencv(extrinsic): 44 | extrinsic = opencv_matrix @ extrinsic @ gl_matrix 45 | 46 | return extrinsic 47 | 48 | 49 | def convert2format(args): 50 | input_dir = Path(args.input_dir) 51 | train_dirpath = input_dir.joinpath("train") 52 | test_dirpath = input_dir.joinpath("test") 53 | 54 | outdirpath = Path(args.output_dir) 55 | outdirpath.mkdir(exist_ok=True, parents=True) 56 | 57 | for split, dirpath in zip( 58 | ["train", "val", "test"], [train_dirpath, test_dirpath, test_dirpath] 59 | ): 60 | metadata = dirpath.joinpath("metadata.json") 61 | outimgdir = outdirpath.joinpath(split) 62 | outimgdir.mkdir(exist_ok=True, parents=True) 63 | 64 | with open(metadata.as_posix(), "r") as fp: 65 | metadata = json.load(fp) 66 | 67 | transforms = dict() 68 | 69 | H, W = metadata["metadata"]["resolution"] 70 | fov = np.rad2deg(metadata["camera"]["field_of_view"]) 71 | camera_angle_x, camera_angle_y = fov, fov 72 | 73 | transforms["camera_angle_x"] = camera_angle_x 74 | transforms["camera_angle_y"] = camera_angle_y 75 | 76 | transforms["frames"] = [] 77 | num_frames = metadata["metadata"]["num_frames"] 78 | 79 | if split == "train": 80 | iterator = range(num_frames) 81 | elif split == "val": 82 | iterator = np.array(range(num_frames))[::10] 83 | else: 84 | # use the rest of the frames for testing 85 | iterator = np.array([idx for idx in range(num_frames) if idx % 10 != 0]) 86 | 87 | for frame_idx in iterator: 88 | frame = dict() 89 | 90 | image_fname = f"rgba_{frame_idx:05d}.png" 91 | image_fpath = dirpath.joinpath(image_fname) 92 | image = Image.open(image_fpath) 93 | out_image_fpath = outimgdir.joinpath(image_fname) 94 | image.save(out_image_fpath) 95 | 96 | frame["time"] = frame_idx / num_frames 97 | frame["file_path"] = Path(split, image_fname).as_posix() 98 | frame["width"] = W 99 | frame["height"] = H 100 | 101 | c2w = np.eye(4) 102 | quaternion = metadata["camera"]["quaternions"][frame_idx] 103 | c2w[:3, :3] = quaternion_to_rotation_matrix(quaternion) 104 | c2w[:3, 3] = np.array(metadata["camera"]["positions"][frame_idx]) 105 | 106 | # change coordinate system to opencv format 107 | # world coordinate : blender -> opencv 108 | # camera (local) type : opengl -> opencv 109 | c2w = kubric2opencv(c2w) 110 | 111 | frame["transform_matrix"] = c2w.tolist() 112 | transforms["frames"].append(frame) 113 | 114 | with open(outdirpath.joinpath(f"{split}_transforms.json"), "w") as fp: 115 | json.dump(transforms, fp, indent=4) 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument( 121 | 
"--input_dir", required=True, type=str, help="path to generated kubric-mrig" 122 | ) 123 | parser.add_argument( 124 | "--output_dir", required=True, type=str, help="path to store converted assets" 125 | ) 126 | args = parser.parse_args() 127 | 128 | convert2format(args) 129 | -------------------------------------------------------------------------------- /scripts/nvidia2format.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import glob 15 | import argparse 16 | from PIL import Image 17 | import json 18 | import math 19 | 20 | img_downsample = 2 21 | 22 | 23 | def focal2fov(focal, pixels): 24 | return 2 * math.atan(pixels / (2 * focal)) 25 | 26 | 27 | def convert2format(train_dir, test_dir, output_dir): 28 | 29 | train_poses_bounds = np.load( 30 | os.path.join(train_dir, "poses_bounds.npy") 31 | ) # (N_images, 17) 32 | train_image_paths = sorted(glob.glob(os.path.join(train_dir, "images_2/*"))) 33 | test_image_paths = sorted(glob.glob(os.path.join(test_dir, "*.png"))) 34 | 35 | train_poses = train_poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 36 | H, W, focal = train_poses[0, :, -1] # original intrinsics, same for all images 37 | 38 | H, W, focal = ( 39 | H / img_downsample, 40 | W / img_downsample, 41 | focal / img_downsample, 42 | ) # original images are 2x downscaled (same as RodynRF training setup) 43 | 44 | print(f"img_h={H}, img_w={W}") 45 | FoVx = focal2fov(focal, W) * 180 / math.pi 46 | FoVy = focal2fov(focal, H) * 180 / math.pi 47 | 48 | # Original poses has rotation in form "down right back"(LLFF), change to "right down front"(OpenCV) 49 | train_poses = np.concatenate( 50 | [train_poses[..., 1:2], train_poses[..., :1], -train_poses[..., 2:4]], axis=-1 51 | ) 52 | padding = np.array([0, 0, 0, 1]).reshape(1, 1, 4) 53 | train_poses = np.concatenate( 54 | [train_poses, np.tile(padding, (train_poses.shape[0], 1, 1))], axis=-2 55 | ) 56 | 57 | train_outimgdir = os.path.join(output_dir, "train") 58 | os.makedirs(train_outimgdir, exist_ok=True) 59 | test_outimgdir = os.path.join(output_dir, "test") 60 | os.makedirs(test_outimgdir, exist_ok=True) 61 | 62 | train_transforms = dict() 63 | train_transforms["camera_angle_x"] = FoVx 64 | train_transforms["camera_angle_y"] = FoVy 65 | train_transforms["frames"] = [] 66 | 
test_transforms = dict() 67 | test_transforms["camera_angle_x"] = FoVx 68 | test_transforms["camera_angle_y"] = FoVy 69 | test_transforms["frames"] = [] 70 | 71 | for i, train_dirpath in enumerate(train_image_paths): 72 | 73 | frame = dict() 74 | image_fname = f"rgba_{i:05d}.png" 75 | image = Image.open(train_dirpath) 76 | out_image_fpath = os.path.join(train_outimgdir, image_fname) 77 | image.save(out_image_fpath) 78 | 79 | frame["time"] = i / len(train_image_paths) 80 | frame["file_path"] = os.path.join("train", image_fname) 81 | frame["width"] = int(W) 82 | frame["height"] = int(H) 83 | 84 | c2w = train_poses[i] 85 | frame["transform_matrix"] = c2w.tolist() 86 | train_transforms["frames"].append(frame) 87 | 88 | # c2w of all test cam == c2w of first train cam 89 | if i == 0: 90 | for j, test_dirpath in enumerate(test_image_paths): 91 | frame = dict() 92 | image_fname = f"rgba_{j:05d}.png" 93 | image = Image.open(test_dirpath) 94 | out_image_fpath = os.path.join(test_outimgdir, image_fname) 95 | image.save(out_image_fpath) 96 | 97 | frame["time"] = j / len(test_image_paths) 98 | frame["file_path"] = os.path.join("test", image_fname) 99 | frame["width"] = int(W) 100 | frame["height"] = int(H) 101 | 102 | frame["transform_matrix"] = c2w.tolist() 103 | test_transforms["frames"].append(frame) 104 | with open(os.path.join(output_dir, "test_transforms.json"), "w") as fp: 105 | json.dump(test_transforms, fp, indent=4) 106 | 107 | with open(os.path.join(output_dir, "train_transforms.json"), "w") as fp: 108 | json.dump(train_transforms, fp, indent=4) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--train_dir", required=True, type=str) 114 | parser.add_argument("--test_dir", required=True, type=str) 115 | parser.add_argument("--output_dir", required=True, type=str) 116 | args = parser.parse_args() 117 | 118 | convert2format(args.train_dir, args.test_dir, args.output_dir) 119 | -------------------------------------------------------------------------------- /scripts/run_depthanything.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
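As the comment in `scripts/nvidia2format.py` above notes, the axis shuffle rewrites the LLFF-style rotation columns ("down right back") into the OpenCV-style ordering ("right down front") before each pose is padded to a homogeneous 4x4 matrix. A standalone sketch of that same reordering on stand-in poses (the random values are purely illustrative):

```python
import numpy as np

# stand-in LLFF-style poses: (N, 3, 4) blocks of [R | t], as sliced from poses_bounds.npy
poses = np.random.randn(2, 3, 4)

# same column shuffle as the script: keep column 1, then column 0, then negate columns 2-3
converted = np.concatenate([poses[..., 1:2], poses[..., :1], -poses[..., 2:4]], axis=-1)

# pad with [0, 0, 0, 1] to obtain 4x4 camera-to-world matrices, as the script does
padding = np.array([0, 0, 0, 1]).reshape(1, 1, 4)
c2w = np.concatenate([converted, np.tile(padding, (poses.shape[0], 1, 1))], axis=-2)
print(c2w.shape)  # (2, 4, 4)
```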
10 | 11 | import argparse 12 | import cv2 13 | import glob 14 | import matplotlib 15 | import numpy as np 16 | import os 17 | import torch 18 | 19 | from thirdparty.depth_anything_v2.depth_anything_v2.dpt import DepthAnythingV2 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description="Depth Anything V2") 24 | 25 | parser.add_argument("--img-path", type=str) 26 | parser.add_argument("--input-size", type=int, default=518) 27 | parser.add_argument("--outdir", type=str, default="./vis_depth") 28 | 29 | parser.add_argument( 30 | "--encoder", type=str, default="vitl", choices=["vits", "vitb", "vitl", "vitg"] 31 | ) 32 | parser.add_argument("--encoder-path", type=str, required=True) 33 | 34 | parser.add_argument( 35 | "--pred-only", 36 | dest="pred_only", 37 | action="store_true", 38 | help="only display the prediction", 39 | ) 40 | parser.add_argument( 41 | "--grayscale", 42 | dest="grayscale", 43 | action="store_true", 44 | help="do not apply colorful palette", 45 | ) 46 | parser.add_argument( 47 | "--raw-depth", 48 | dest="raw_depth", 49 | action="store_true", 50 | help="do not apply colormap", 51 | ) 52 | 53 | args = parser.parse_args() 54 | 55 | DEVICE = ( 56 | "cuda" 57 | if torch.cuda.is_available() 58 | else "mps" if torch.backends.mps.is_available() else "cpu" 59 | ) 60 | 61 | model_configs = { 62 | "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]}, 63 | "vitb": { 64 | "encoder": "vitb", 65 | "features": 128, 66 | "out_channels": [96, 192, 384, 768], 67 | }, 68 | "vitl": { 69 | "encoder": "vitl", 70 | "features": 256, 71 | "out_channels": [256, 512, 1024, 1024], 72 | }, 73 | "vitg": { 74 | "encoder": "vitg", 75 | "features": 384, 76 | "out_channels": [1536, 1536, 1536, 1536], 77 | }, 78 | } 79 | 80 | depth_anything = DepthAnythingV2(**model_configs[args.encoder]) 81 | depth_anything.load_state_dict(torch.load(args.encoder_path, map_location="cpu")) 82 | depth_anything = depth_anything.to(DEVICE).eval() 83 | 84 | if os.path.isfile(args.img_path): 85 | if args.img_path.endswith("txt"): 86 | with open(args.img_path, "r") as f: 87 | filenames = f.read().splitlines() 88 | else: 89 | filenames = [args.img_path] 90 | else: 91 | filenames = glob.glob(os.path.join(args.img_path, "**/*"), recursive=True) 92 | 93 | os.makedirs(args.outdir, exist_ok=True) 94 | 95 | cmap = matplotlib.colormaps.get_cmap("Spectral_r") 96 | 97 | for k, filename in enumerate(filenames): 98 | print(f"Progress {k+1}/{len(filenames)}: {filename}") 99 | 100 | raw_image = cv2.imread(filename) 101 | 102 | depth = depth_anything.infer_image(raw_image, args.input_size) 103 | 104 | if args.raw_depth: 105 | np.save( 106 | os.path.join( 107 | args.outdir, 108 | os.path.splitext(os.path.basename(filename))[0] + ".npy", 109 | ), 110 | depth, 111 | ) 112 | 113 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 114 | depth = depth.astype(np.uint8) 115 | 116 | if args.grayscale: 117 | depth = np.repeat(depth[..., np.newaxis], 3, axis=-1) 118 | else: 119 | depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8) 120 | 121 | if args.pred_only: 122 | cv2.imwrite( 123 | os.path.join( 124 | args.outdir, 125 | os.path.splitext(os.path.basename(filename))[0] + ".png", 126 | ), 127 | depth, 128 | ) 129 | else: 130 | split_region = np.ones((raw_image.shape[0], 50, 3), dtype=np.uint8) * 255 131 | combined_result = cv2.hconcat([raw_image, split_region, depth]) 132 | 133 | cv2.imwrite( 134 | os.path.join( 135 | args.outdir, 136 | 
os.path.splitext(os.path.basename(filename))[0] + ".png", 137 | ), 138 | combined_result, 139 | ) 140 | -------------------------------------------------------------------------------- /scripts/run_mast3r/depth_preprocessor/get_pcd.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os, cv2 3 | import numpy as np 4 | from PIL import Image 5 | from glob import glob 6 | from argparse import ArgumentParser 7 | 8 | from scripts.run_mast3r.depth_preprocessor.utils import resize_to_mast3r 9 | from scripts.run_mast3r.depth_preprocessor.pcd_utils import unproject_depth 10 | 11 | 12 | def mast3r_unprojection( 13 | data_dir, 14 | maskpaths, 15 | imagepaths, 16 | img_w, 17 | img_h, 18 | skip_dynamic, 19 | static_dst_dir="static", 20 | dynamic_dst_dir="dynamic", 21 | depth_dir="depth", 22 | ): 23 | assert data_dir is not None, "Please provide a depth name" 24 | 25 | # Load the global params 26 | pkl_path = os.path.join(data_dir, "global_params.pkl") 27 | with open(pkl_path, "rb") as pkl_file: 28 | data = pickle.load(pkl_file) 29 | 30 | focal = data["focals"][0] 31 | depth_max = data["max_depths"][0] 32 | depths = np.array(data["depths"]) 33 | 34 | depths *= depth_max 35 | depths = np.clip(depths, 0, depth_max) 36 | 37 | static_pcd_path = os.path.join(data_dir, static_dst_dir) 38 | os.makedirs(static_pcd_path, exist_ok=True) 39 | depth_dst_dir = os.path.join(data_dir, depth_dir) 40 | os.makedirs(depth_dst_dir, exist_ok=True) 41 | 42 | # skip masked unprojection (for tanks and temples) 43 | if skip_dynamic: 44 | for i, imgpath in enumerate(list(imagepaths)): 45 | img = cv2.imread(imgpath) 46 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 47 | img = resize_to_mast3r(img, dst_size=None) 48 | extrinsic = data["cam2worlds"][i] 49 | depth = depths[i].reshape(img.shape[:2]) 50 | static_pcd_path = os.path.join(data_dir, static_dst_dir) 51 | static_pcd_save_path = os.path.join(static_pcd_path, f"{i:04d}_static.ply") 52 | unproject_depth( 53 | focal, 54 | extrinsic, 55 | img, 56 | depth, 57 | export_path=static_pcd_save_path, 58 | mask=None, 59 | ) 60 | depth_save_path = os.path.join(depth_dst_dir, f"{i:05}_depth.npy") 61 | np.save(depth_save_path, depth.reshape(img_h, img_w)) 62 | return 63 | 64 | dynamic_pcd_path = os.path.join(data_dir, dynamic_dst_dir) 65 | os.makedirs(dynamic_pcd_path, exist_ok=True) 66 | 67 | loader = list(zip(imagepaths, maskpaths)) 68 | for i, (imgpath, maskpath) in enumerate(loader): 69 | img = cv2.imread(imgpath) 70 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 71 | img = resize_to_mast3r(img, dst_size=None) 72 | 73 | mask = cv2.imread(maskpath, cv2.IMREAD_GRAYSCALE) 74 | mask = resize_to_mast3r(mask, dst_size=None) 75 | mask = mask > 0 76 | 77 | extrinsic = data["cam2worlds"][i] 78 | depth = depths[i].reshape(img.shape[:2]) 79 | 80 | dynamic_pcd_save_path = os.path.join(dynamic_pcd_path, f"{i:04d}_dynamic.ply") 81 | unproject_depth( 82 | focal, extrinsic, img, depth, export_path=dynamic_pcd_save_path, mask=mask 83 | ) 84 | static_pcd_save_path = os.path.join(static_pcd_path, f"{i:04d}_static.ply") 85 | unproject_depth( 86 | focal, extrinsic, img, depth, export_path=static_pcd_save_path, mask=~mask 87 | ) 88 | 89 | depth_save_path = os.path.join(depth_dst_dir, f"{i:05}_depth.npy") 90 | np.save(depth_save_path, depth.reshape(img_h, img_w)) 91 | 92 | 93 | def check_all_masks_false(maskpaths): 94 | for maskpath in maskpaths: 95 | # Load the image and convert it to a numpy array 96 | mask = np.array(Image.open(maskpath)) 97 | if 
np.any(mask): 98 | return False 99 | 100 | return True 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = ArgumentParser() 105 | parser.add_argument("--datadir", type=str, required=True) 106 | parser.add_argument("--mast3r_expname", type=str, required=True) 107 | parser.add_argument("--mask_name", type=str) 108 | 109 | args = parser.parse_args() 110 | skip_dynamic = False 111 | 112 | datadir = args.datadir 113 | mast3r_exp_dir = os.path.join(datadir, "mast3r_opt", args.mast3r_expname) 114 | 115 | # load mast3r format image 116 | imagepaths = sorted(glob(f"{datadir}/train/*.png")) 117 | 118 | # load dynamic mask 119 | maskpaths = sorted(glob(f"{datadir}/{args.mask_name}/*.png")) 120 | if len(maskpaths) == 0: 121 | maskpaths = sorted( 122 | glob(f"{datadir}/{args.mask_name}/*.jpg") 123 | ) # mask should not be jpg... 124 | print(f"\nload mask from {maskpaths[0]} ~ {maskpaths[-1]}\n") 125 | 126 | if check_all_masks_false(maskpaths): 127 | skip_dynamic = True 128 | print("\nNo Dynamic regions found. Skip dynamic unprojection\n") 129 | 130 | with open(os.path.join(mast3r_exp_dir, "global_params.pkl"), "rb") as f: 131 | global_params = pickle.load(f) 132 | mast3r_img_h = len(global_params["masks"][0]) 133 | mast3r_img_w = len(global_params["masks"][0][0]) 134 | 135 | mast3r_unprojection( 136 | mast3r_exp_dir, maskpaths, imagepaths, mast3r_img_w, mast3r_img_h, skip_dynamic 137 | ) 138 | -------------------------------------------------------------------------------- /scripts/run_mast3r/depth_preprocessor/pcd_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import trimesh 4 | 5 | from functools import lru_cache 6 | 7 | 8 | @lru_cache 9 | def mask110(device, dtype): 10 | return torch.tensor((1, 1, 0), device=device, dtype=dtype) 11 | 12 | 13 | def proj3d(inv_K, pixels, z): 14 | if pixels.shape[-1] == 2: 15 | pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1) 16 | return z.unsqueeze(-1) * ( 17 | pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype) 18 | ) 19 | 20 | 21 | # from dust3r point transformation 22 | def geotrf(Trf, pts, ncol=None, norm=False): 23 | """Apply a geometric transformation to a list of 3-D points. 24 | 25 | H: 3x3 or 4x4 projection matrix (typically a Homography) 26 | p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) 27 | 28 | ncol: int. number of columns of the result (2 or 3) 29 | norm: float. if != 0, the resut is projected on the z=norm plane. 30 | 31 | Returns an array of projected 2d points. 
32 | """ 33 | assert Trf.ndim >= 2 34 | if isinstance(Trf, np.ndarray): 35 | pts = np.asarray(pts) 36 | elif isinstance(Trf, torch.Tensor): 37 | pts = torch.as_tensor(pts, dtype=Trf.dtype) 38 | 39 | # adapt shape if necessary 40 | output_reshape = pts.shape[:-1] 41 | ncol = ncol or pts.shape[-1] 42 | 43 | # optimized code 44 | if ( 45 | isinstance(Trf, torch.Tensor) 46 | and isinstance(pts, torch.Tensor) 47 | and Trf.ndim == 3 48 | and pts.ndim == 4 49 | ): 50 | d = pts.shape[3] 51 | if Trf.shape[-1] == d: 52 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) 53 | elif Trf.shape[-1] == d + 1: 54 | pts = ( 55 | torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) 56 | + Trf[:, None, None, :d, d] 57 | ) 58 | else: 59 | raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") 60 | else: 61 | if Trf.ndim >= 3: 62 | n = Trf.ndim - 2 63 | assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" 64 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) 65 | 66 | if pts.ndim > Trf.ndim: 67 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) 68 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) 69 | elif pts.ndim == 2: 70 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) 71 | pts = pts[:, None, :] 72 | 73 | if pts.shape[-1] + 1 == Trf.shape[-1]: 74 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 75 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] 76 | elif pts.shape[-1] == Trf.shape[-1]: 77 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 78 | pts = pts @ Trf 79 | else: 80 | pts = Trf @ pts.T 81 | if pts.ndim >= 2: 82 | pts = pts.swapaxes(-1, -2) 83 | 84 | if norm: 85 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG 86 | if norm != 1: 87 | pts *= norm 88 | 89 | res = pts[..., :ncol].reshape(*output_reshape, ncol) 90 | return res 91 | 92 | 93 | def unproject_depth(focal, extrinsic, image, depth, export_path=None, mask=None): 94 | h, w, _ = image.shape 95 | 96 | # extract color 97 | pixels = torch.tensor(np.mgrid[:w, :h].T.reshape(-1, 2)) 98 | K = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 99 | points = proj3d( 100 | torch.tensor(np.linalg.inv(K)), pixels, torch.tensor(depth.reshape(-1)) 101 | ) 102 | 103 | color = image.reshape(-1, 3) 104 | 105 | if mask is not None: 106 | mask = mask.ravel() 107 | color = color[mask] 108 | points = points[mask.ravel()] 109 | 110 | # convert to world coordinates 111 | # points = np.stack([x, y, z], axis=-1) 112 | points = geotrf(torch.tensor(extrinsic), torch.tensor(points)) 113 | points = np.asarray(points).reshape(-1, 3) 114 | 115 | # save to ply with img color 116 | if export_path is not None: 117 | mesh = trimesh.PointCloud(vertices=points, colors=color) 118 | mesh.export(export_path) 119 | 120 | return points 121 | -------------------------------------------------------------------------------- /scripts/run_mast3r/depth_preprocessor/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import PIL 4 | 5 | from functools import lru_cache 6 | 7 | import torch 8 | import numpy as np 9 | import trimesh 10 | 11 | 12 | def _resize_pil_image(img, long_edge_size): 13 | S = max(img.size) 14 | if S > long_edge_size: 15 | interp = PIL.Image.LANCZOS 16 | elif S <= long_edge_size: 17 | interp = PIL.Image.BICUBIC 18 | new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size) 19 | return img.resize(new_size, interp) 20 | 21 | 22 | def resize_to_mast3r(img, dst_size=None, size=512, square_ok=False): 23 | # convert to 
PIL image 24 | if isinstance(img, str): 25 | img = PIL.Image.open(img) 26 | elif isinstance(img, np.ndarray): 27 | img = PIL.Image.fromarray(img) 28 | elif not isinstance(img, PIL.Image.Image): 29 | raise ValueError(f"Invalid input type: {type(img)}") 30 | 31 | W1, H1 = img.size 32 | 33 | # resize long side to 512 34 | img = _resize_pil_image(img, size) 35 | 36 | W, H = img.size 37 | cx, cy = W // 2, H // 2 38 | 39 | halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8 40 | if not (square_ok) and W == H: 41 | halfh = 3 * halfw / 4 42 | img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) 43 | 44 | # convert to numpy array 45 | if dst_size is not None and dst_size != (W, H): 46 | img = img.resize(dst_size, PIL.Image.BICUBIC) 47 | 48 | img = np.array(img) 49 | return img 50 | 51 | 52 | @lru_cache 53 | def mask110(device, dtype): 54 | return torch.tensor((1, 1, 0), device=device, dtype=dtype) 55 | 56 | 57 | def proj3d(inv_K, pixels, z): 58 | if pixels.shape[-1] == 2: 59 | pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1) 60 | return z.unsqueeze(-1) * ( 61 | pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype) 62 | ) 63 | 64 | 65 | # from dust3r point transformation 66 | def geotrf(Trf, pts, ncol=None, norm=False): 67 | """Apply a geometric transformation to a list of 3-D points. 68 | 69 | H: 3x3 or 4x4 projection matrix (typically a Homography) 70 | p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) 71 | 72 | ncol: int. number of columns of the result (2 or 3) 73 | norm: float. if != 0, the resut is projected on the z=norm plane. 74 | 75 | Returns an array of projected 2d points. 76 | """ 77 | assert Trf.ndim >= 2 78 | if isinstance(Trf, np.ndarray): 79 | pts = np.asarray(pts) 80 | elif isinstance(Trf, torch.Tensor): 81 | pts = torch.as_tensor(pts, dtype=Trf.dtype) 82 | 83 | # adapt shape if necessary 84 | output_reshape = pts.shape[:-1] 85 | ncol = ncol or pts.shape[-1] 86 | 87 | # optimized code 88 | if ( 89 | isinstance(Trf, torch.Tensor) 90 | and isinstance(pts, torch.Tensor) 91 | and Trf.ndim == 3 92 | and pts.ndim == 4 93 | ): 94 | d = pts.shape[3] 95 | if Trf.shape[-1] == d: 96 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) 97 | elif Trf.shape[-1] == d + 1: 98 | pts = ( 99 | torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) 100 | + Trf[:, None, None, :d, d] 101 | ) 102 | else: 103 | raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") 104 | else: 105 | if Trf.ndim >= 3: 106 | n = Trf.ndim - 2 107 | assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" 108 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) 109 | 110 | if pts.ndim > Trf.ndim: 111 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) 112 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) 113 | elif pts.ndim == 2: 114 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) 115 | pts = pts[:, None, :] 116 | 117 | if pts.shape[-1] + 1 == Trf.shape[-1]: 118 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 119 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] 120 | elif pts.shape[-1] == Trf.shape[-1]: 121 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 122 | pts = pts @ Trf 123 | else: 124 | pts = Trf @ pts.T 125 | if pts.ndim >= 2: 126 | pts = pts.swapaxes(-1, -2) 127 | 128 | if norm: 129 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG 130 | if norm != 1: 131 | pts *= norm 132 | 133 | res = pts[..., :ncol].reshape(*output_reshape, ncol) 134 | return res 135 | 136 | 137 | def unproject_depth(focal, 
extrinsic, image, depth, export_path=None, mask=None): 138 | h, w, _ = image.shape 139 | 140 | # extract color 141 | pixels = torch.tensor(np.mgrid[:w, :h].T.reshape(-1, 2)) 142 | K = np.array([[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]) 143 | points = proj3d( 144 | torch.tensor(np.linalg.inv(K)), pixels, torch.tensor(depth.reshape(-1)) 145 | ) 146 | 147 | color = image.reshape(-1, 3) 148 | 149 | if mask is not None: 150 | mask = mask.ravel() 151 | color = color[mask] 152 | points = points[mask.ravel()] 153 | 154 | # convert to world coordinates 155 | # points = np.stack([x, y, z], axis=-1) 156 | points = geotrf(torch.tensor(extrinsic), torch.tensor(points)) 157 | points = np.asarray(points).reshape(-1, 3) 158 | 159 | # save to ply with img color 160 | if export_path is not None: 161 | mesh = trimesh.PointCloud(vertices=points, colors=color) 162 | mesh.export(export_path) 163 | 164 | return points 165 | -------------------------------------------------------------------------------- /scripts/run_mast3r/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 3 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | # 5 | # -------------------------------------------------------- 6 | # sparse gradio demo functions 7 | # -------------------------------------------------------- 8 | import math 9 | import os 10 | from glob import glob 11 | import numpy as np 12 | import trimesh 13 | import copy 14 | from scipy.spatial.transform import Rotation 15 | from pathlib import Path 16 | import argparse 17 | import pickle 18 | import math 19 | import sys 20 | import cv2 21 | 22 | sys.path.insert(0, "./thirdparty/mast3r") 23 | 24 | from mast3r.cloud_opt.sparse_ga import sparse_global_alignment 25 | from mast3r.cloud_opt.utils.schedules import cosine_schedule 26 | from dust3r.image_pairs import make_pairs 27 | from dust3r.utils.image import load_images 28 | from dust3r.utils.device import to_numpy 29 | import matplotlib.pyplot as pl 30 | from mast3r.model import AsymmetricMASt3R 31 | import torch 32 | 33 | torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 34 | 35 | 36 | def get_sparse_optim_args(config): 37 | optim_args = { 38 | "lr1": 0.2, 39 | "niter1": 500, 40 | "lr2": 0.02, 41 | "niter2": 500, 42 | "opt_pp": True, 43 | "opt_depth": True, 44 | "schedule": "cosine", 45 | "depth_mode": "add", 46 | "exp_depth": False, 47 | "lora_depth": False, 48 | "shared_intrinsics": True, 49 | "device": "cuda", 50 | "dtype": torch.float32, 51 | "matching_conf_thr": 5.0, 52 | "loss_dust3r_w": 0.01, 53 | } 54 | 55 | # update with config 56 | for key, value in config.items(): 57 | if key in optim_args: 58 | optim_args[key] = value 59 | optim_args["opt_depth"] = "depth" in config["optim_level"] 60 | optim_args["schedule"] = cosine_schedule 61 | 62 | return optim_args 63 | 64 | 65 | def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type, winsize=10): 66 | num_files = len(inputfiles) if inputfiles is not None else 1 67 | max_winsize, min_winsize = 1, 1 68 | if scenegraph_type == "swin": 69 | if win_cyclic: 70 | max_winsize = max(1, math.ceil((num_files - 1) / 2)) 71 | else: 72 | max_winsize = num_files - 1 73 | elif scenegraph_type == "logwin": 74 | if win_cyclic: 75 | half_size = math.ceil((num_files - 1) / 2) 76 | max_winsize = max(1, math.ceil(math.log(half_size, 2))) 77 | else: 78 | max_winsize = max(1, 
math.ceil(math.log(num_files, 2))) 79 | winsize = min(max_winsize, max(min_winsize, winsize)) 80 | 81 | return winsize, win_cyclic 82 | 83 | 84 | def get_geometries_from_scene(scene, clean_depth, mask_sky, min_conf_thr): 85 | 86 | # post processes - clean depth, mask sky 87 | if mask_sky: 88 | scene = scene.mask_sky() 89 | 90 | # get optimized values from scene 91 | rgbimg, focals = scene.imgs, scene.get_focals().cpu() 92 | 93 | print("obtained focal length : ", focals) 94 | cams2world = scene.get_im_poses().cpu() 95 | 96 | # 3D pointcloud from depthmap, poses and intrinsics 97 | pts3d, depths, confs = to_numpy( 98 | scene.get_dense_pts3d(clean_depth=clean_depth) 99 | ) # ref mast3r/cloud_opt SparseGA() class 100 | msk = to_numpy([c > min_conf_thr for c in confs]) 101 | 102 | # get normalized depthmaps 103 | depths_max = max([d.max() for d in depths]) # type: ignore 104 | depths = [d / depths_max for d in depths] 105 | 106 | return rgbimg, pts3d, msk, focals, cams2world, depths, depths_max 107 | 108 | 109 | def points_to_pct(pts, color, extrinsic, save_path=None): 110 | rot = np.eye(4) 111 | rot[:3, :3] = Rotation.from_euler("y", [np.deg2rad(180)]).as_matrix() 112 | 113 | pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=color.reshape(-1, 3)) 114 | pct.apply_transform(np.linalg.inv(extrinsic)) 115 | 116 | if save_path is not None: 117 | os.makedirs(Path(save_path).parent, exist_ok=True) 118 | pct.export(save_path) 119 | return pct 120 | 121 | 122 | def save_each_geometry( 123 | outdir, 124 | imgs, 125 | pts3d, 126 | mask, 127 | focals, 128 | cams2world, 129 | imgname=None, 130 | depths=None, 131 | depths_max=None, 132 | filter_pct=True, 133 | ): 134 | print( 135 | f"len(pts3d) : {len(pts3d)}, len(mask) : {len(mask)}, len(imgs) : {len(imgs)}, len(cams2world) : {len(cams2world)}" 136 | ) 137 | assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) 138 | extrinsic = cams2world[0].detach().numpy() 139 | pts3d, imgs, focals, cams2world = ( 140 | to_numpy(pts3d), 141 | to_numpy(imgs), 142 | to_numpy(focals), 143 | to_numpy(cams2world), 144 | ) 145 | 146 | if imgname is not None: 147 | outdir = os.path.join(outdir, imgname["data"]) 148 | 149 | # original : dict_keys(['focal', 'cam2worlds', 'pct2worlds', 'pointcloud_paths', 'max_depths', 'depths']) 150 | global_dict = { 151 | "focals": [], 152 | "cam2worlds": [], 153 | "pointcloud_paths": [], 154 | "max_depths": [], 155 | "depths": [], 156 | "masks": [], 157 | } 158 | 159 | args_iter = ( 160 | zip(pts3d, imgs, mask, focals, cams2world) 161 | if depths is None 162 | else zip(pts3d, imgs, mask, focals, cams2world, depths) 163 | ) 164 | for i, arguments in enumerate(args_iter): 165 | if depths is None: 166 | points, img, point_mask, focal, cam2world = arguments 167 | depth, depths_max = None, None 168 | else: 169 | points, img, point_mask, focal, cam2world, depth = arguments 170 | 171 | if filter_pct: 172 | pts = points[point_mask.ravel()].reshape(-1, 3) 173 | col = img[point_mask].reshape(-1, 3) 174 | valid_msk = np.isfinite(pts.sum(axis=1)) 175 | pts, col = pts[valid_msk], col[valid_msk] 176 | else: 177 | pts, col = points.reshape(-1, 3), img.reshape(-1, 3) 178 | 179 | filename = ( 180 | f"pointcloud_{i:04d}.ply" 181 | if imgname is None 182 | else f"{imgname['img_nums'][i]:04d}_pointcloud_{i:04d}.ply" 183 | ) 184 | points_to_pct( 185 | pts, 186 | col, 187 | np.eye(extrinsic.shape[0]), 188 | save_path=os.path.join(outdir, filename), 189 | ) 190 | cam_params = { 191 | "focal": focal, 192 | "cam2world": cam2world, 193 | "c2w_original": 
cam2world, 194 | "depth": depth, 195 | "depth_max": depths_max, 196 | "base_extrinsic": extrinsic, 197 | "imgname": imgname, 198 | } 199 | maskdir = os.path.join(outdir, "masks") 200 | os.makedirs(maskdir, exist_ok=True) 201 | 202 | # save mask to png 203 | re_mask = point_mask.reshape(depth.shape) 204 | mask = np.zeros_like(re_mask, dtype=np.uint8) 205 | mask[re_mask == 0] = 255 206 | cv2.imwrite(os.path.join(maskdir, f"{i:04d}.png"), mask) 207 | 208 | with open(os.path.join(outdir, filename.replace(".ply", ".pkl")), "wb") as f: 209 | pickle.dump(cam_params, f) 210 | 211 | global_dict["focals"].append(focal) 212 | global_dict["cam2worlds"].append(cam2world) 213 | global_dict["pointcloud_paths"].append(os.path.join(outdir, filename)) 214 | global_dict["max_depths"].append(depths_max) 215 | global_dict["depths"].append(depth) 216 | global_dict["masks"].append(point_mask) 217 | 218 | return global_dict 219 | 220 | 221 | def get_reconstructed_scene( 222 | outdir, 223 | cache_dir, 224 | model, 225 | device, 226 | image_size, 227 | filelist, 228 | optim_level, 229 | lr1, 230 | niter1, 231 | lr2, 232 | niter2, 233 | min_conf_thr, 234 | matching_conf_thr, 235 | mask_sky, 236 | clean_depth, 237 | scenegraph_type, 238 | winsize, 239 | win_cyclic, 240 | refid, 241 | TSDF_thresh, 242 | shared_intrinsics, 243 | filter_pct, 244 | optim_args, 245 | **kw, 246 | ): 247 | """ 248 | from a list of images, run mast3r inference, sparse global aligner. 249 | then run get_3D_model_from_scene 250 | """ 251 | imgs = load_images(filelist, size=image_size, verbose=True) 252 | if len(imgs) == 1: 253 | imgs = [imgs[0], copy.deepcopy(imgs[0])] 254 | imgs[1]["idx"] = 1 255 | filelist = [filelist[0], filelist[0] + "_2"] 256 | 257 | scene_graph_params = [scenegraph_type] 258 | if scenegraph_type in ["swin", "logwin"]: 259 | scene_graph_params.append(str(winsize)) 260 | elif scenegraph_type == "oneref": 261 | scene_graph_params.append(str(refid)) 262 | if scenegraph_type in ["swin", "logwin"] and not win_cyclic: 263 | scene_graph_params.append("noncyclic") 264 | scene_graph = "-".join(scene_graph_params) 265 | pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True) 266 | if optim_level == "coarse": 267 | niter2 = 0 268 | 269 | # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation) 270 | os.makedirs(cache_dir, exist_ok=True) 271 | scene = sparse_global_alignment(filelist, pairs, cache_dir, model, **optim_args) 272 | 273 | rgbimg, pts3d, msk, focals, cams2world, depths, depths_max = ( 274 | get_geometries_from_scene(scene, clean_depth, mask_sky, min_conf_thr) 275 | ) 276 | global_dict = save_each_geometry( 277 | os.path.join(outdir, "op_results"), 278 | rgbimg, 279 | pts3d, 280 | msk, 281 | focals, 282 | cams2world, 283 | imgname=None, 284 | depths=depths, 285 | depths_max=depths_max, 286 | filter_pct=filter_pct, 287 | ) 288 | 289 | with open(os.path.join(outdir, "global_params.pkl"), "wb") as f: 290 | pickle.dump(global_dict, f) 291 | 292 | 293 | def main(): 294 | 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument("--input_dir", type=str, default="data", help="input directory") 297 | parser.add_argument( 298 | "--output_dir", type=str, default="output", help="output directory" 299 | ) 300 | parser.add_argument("--exp_name", type=str, default="exp", help="experiment name") 301 | parser.add_argument( 302 | "--ckpt", 303 | type=str, 304 | default="checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth", 305 | help="mast3r ckpt", 306 | ) 307 | 
parser.add_argument( 308 | "--cache_dir", type=str, default="optim_cache", help="cache directory" 309 | ) 310 | 311 | args = parser.parse_args() 312 | 313 | device = "cuda" 314 | model = AsymmetricMASt3R.from_pretrained(args.ckpt).to(device) 315 | 316 | filelist = sorted(glob(os.path.join(args.input_dir, "*.png"))) 317 | 318 | cache_dir = os.path.join( 319 | args.cache_dir, 320 | f"{os.path.basename(os.path.dirname(args.input_dir))}_{np.random.randint(1e6):05d}", 321 | ) 322 | outdir = os.path.join(args.output_dir, args.exp_name + "_000") 323 | 324 | data = {"cache_dir": "optim_cache"} 325 | optimization = { 326 | "device": "cuda", 327 | "image_size": 512, 328 | "shared_intrinsics": True, 329 | "win_cyclic": False, 330 | "lr1": 0.07, 331 | "niter1": 500, 332 | "lr2": 0.014, 333 | "niter2": 200, 334 | "optim_level": "refine+depth", 335 | "scenegraph_type": "swin", 336 | "winsize": 10, 337 | "refid": 0, 338 | "TSDF_thresh": 0, 339 | "schedule": "cosine", 340 | "min_conf_thr": 1.5, 341 | "matching_conf_thr": 5, 342 | "mask_sky": False, 343 | "clean_depth": True, 344 | "filter_pct": True, 345 | } 346 | arg_dict = dict() 347 | 348 | arg_dict.update(data) 349 | arg_dict.update(optimization) 350 | arg_dict.update({"filelist": filelist, "model": model, "schedule": cosine_schedule}) 351 | arg_dict.update( 352 | {"outdir": outdir, "input_dir": args.input_dir, "cache_dir": cache_dir} 353 | ) 354 | winsize, win_cyclic = set_scenegraph_options(filelist, False, 0, "swin", 10) 355 | arg_dict.update({"winsize": winsize, "win_cyclic": win_cyclic}) 356 | arg_dict.update({"optim_args": get_sparse_optim_args(optimization)}) 357 | 358 | get_reconstructed_scene(**arg_dict) 359 | 360 | 361 | if __name__ == "__main__": 362 | main() 363 | -------------------------------------------------------------------------------- /scripts/tam_npy2png.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
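Note how the two MASt3R helpers fit together: `scripts/run_mast3r/run.py` saves per-frame depths normalised by the global maximum (and records that maximum in `global_params.pkl`), while `scripts/run_mast3r/depth_preprocessor/get_pcd.py` multiplies the stored depths back by `max_depths[0]` and clips them before unprojection. A compact sketch of that round trip on toy arrays:

```python
import numpy as np

# toy per-frame depth maps, standing in for the MASt3R predictions held by run.py
depths = [np.random.rand(4, 5) * scale for scale in (2.0, 5.0, 3.0)]

# run.py: normalise every depth map by the global maximum and remember that maximum
depth_max = max(d.max() for d in depths)
stored = [d / depth_max for d in depths]

# get_pcd.py: scale the stored depths back up and clip to the recorded maximum
recovered = np.clip(np.array(stored) * depth_max, 0, depth_max)
assert np.allclose(recovered, np.stack(depths))
```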
10 | 11 | 12 | import numpy as np 13 | from PIL import Image 14 | import os 15 | import argparse 16 | 17 | 18 | def convert_npy_to_png(input_dir, output_dir): 19 | os.makedirs(os.path.join(output_dir, "tam_mask"), exist_ok=True) 20 | 21 | for file_name in os.listdir(input_dir): 22 | if file_name.endswith(".npy"): 23 | input_file_path = os.path.join(input_dir, file_name) 24 | output_file_name = f"{int(os.path.splitext(file_name)[0]):06d}" + ".png" 25 | output_tam_path = os.path.join(output_dir, "tam_mask", output_file_name) 26 | 27 | try: 28 | motion_mask = np.load(input_file_path) 29 | binary_image = (motion_mask * 255).astype(np.uint8) 30 | 31 | Image.fromarray(binary_image).save(output_tam_path) 32 | print(f"Converted: {input_file_path} -> {output_tam_path}") 33 | except Exception as e: 34 | print(f"Failed to convert {input_file_path}: {e}") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--npy_dir", required=True, type=str) 40 | parser.add_argument("--output_dir", required=True, type=str) 41 | args = parser.parse_args() 42 | 43 | convert_npy_to_png(args.npy_dir, args.output_dir) 44 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /src/data/asset_readers.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
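For reference, `scripts/tam_npy2png.py` above writes each Track-Anything motion mask as a 0/255 grayscale PNG named by its zero-padded frame index, and the `TAMMaskReader` in the next file thresholds those PNGs back to boolean masks. A small round-trip sketch on a toy mask (the array values and output filename are illustrative only):

```python
import numpy as np
from PIL import Image

# toy motion mask, standing in for one Track-Anything .npy output
motion_mask = np.zeros((4, 6), dtype=np.float32)
motion_mask[1:3, 2:5] = 1.0

# tam_npy2png.py: store the mask as a 0/255 grayscale PNG (written to the current directory here)
binary_image = (motion_mask * 255).astype(np.uint8)
Image.fromarray(binary_image).save("000000.png")

# TAMMaskReader-style read-back: any non-zero pixel becomes True again
restored = np.array(Image.open("000000.png")) > 0
assert (restored == motion_mask.astype(bool)).all()
```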
8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import os 12 | from pathlib import Path 13 | from PIL import Image 14 | import json 15 | import pickle 16 | 17 | import numpy as np 18 | import torch 19 | from torchvision.transforms import ToTensor 20 | 21 | from src.utils.graphic_utils import focal2fov, quaternion_to_matrix 22 | from src.utils.point_utils import uniform_sample, merge_pcds 23 | from .utils import fetchPly 24 | 25 | 26 | class GTCameraReader: 27 | # To read full poses (for evaluation) 28 | 29 | def __init__(self, dirpath, fname, **kwargs): 30 | 31 | poses = [] 32 | 33 | with open(os.path.join(dirpath, fname)) as json_file: 34 | contents = json.load(json_file) 35 | fovx = contents["camera_angle_x"] 36 | for frame in contents["frames"]: 37 | c2w = np.array(frame["transform_matrix"], dtype=np.float32) 38 | poses.append(torch.from_numpy(c2w)) 39 | self._poses = np.array(torch.stack(poses)) 40 | self._fovx = np.deg2rad(fovx) 41 | 42 | def get_poses(self, idx=None): 43 | if idx is None: 44 | return self._poses 45 | else: 46 | return self._poses[idx] 47 | 48 | def get_fovx(self, idx): 49 | return self._fovx 50 | 51 | 52 | class DepthAnythingReader: 53 | prefix: str = "depth_anything" 54 | 55 | def __init__(self, **kwargs): 56 | pass 57 | 58 | def __call__(self, dirpath, basename): 59 | asset_path = Path(dirpath).joinpath(self.prefix) 60 | base_name_wo_ext = os.path.splitext(basename)[0] 61 | base_name_tiff = base_name_wo_ext + ".npy" 62 | file_path = asset_path.joinpath(base_name_tiff) 63 | depth = -torch.from_numpy(np.load(file_path.as_posix())).unsqueeze(0) 64 | return (depth - depth.min()) / (depth.max() - depth.min()) 65 | 66 | 67 | class TAMMaskReader: 68 | prefix: str = "tam_mask" 69 | to_tensor = ToTensor() 70 | 71 | def __init__(self, split, resolution=1): 72 | assert split in ["train", "val", "test"] 73 | self.prefix = f"tam_mask" 74 | self.resolution = resolution 75 | 76 | def __call__(self, dirpath, basename): 77 | asset_path = Path(dirpath).joinpath(self.prefix) 78 | base_name_wo_ext = os.path.splitext(basename)[0] 79 | rgb_idx = base_name_wo_ext.split("_")[-1] 80 | rgb_idx = rgb_idx.zfill(6) 81 | base_name_png = f"{rgb_idx}.jpg" 82 | file_path = asset_path.joinpath(base_name_png) 83 | if not file_path.exists(): 84 | base_name_png = f"{rgb_idx}.png" 85 | file_path = asset_path.joinpath(base_name_png) 86 | if self.resolution != 1: 87 | img = Image.open(file_path) 88 | w, h = img.size 89 | new_size = (w // self.resolution, h // self.resolution) 90 | return self.to_tensor(img.resize(new_size, Image.NEAREST)) > 0 91 | else: 92 | return self.to_tensor(Image.open(file_path)) > 0 93 | 94 | 95 | class Test_MASt3RFovCameraReader: 96 | # To read full poses and trained fov (for evaluation) 97 | 98 | def __init__(self, dirpath, fname, mast3r_expname, mast3r_img_res, **kwargs): 99 | 100 | self.dirname = "mast3r_opt" 101 | poses = [] 102 | 103 | # read gt test poses from test_transforms.json 104 | with open(os.path.join(dirpath, fname)) as json_file: 105 | contents = json.load(json_file) 106 | for frame in contents["frames"]: 107 | c2w = 
frame["transform_matrix"] 108 | poses.append(c2w) 109 | self._poses = np.array(poses, dtype=np.float32) 110 | 111 | # read fov from dust3r init 112 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl") 113 | with open(pkl_path.as_posix(), "rb") as pkl_file: 114 | data = pickle.load(pkl_file) 115 | 116 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res) 117 | 118 | def get_poses(self, idx=None): 119 | if idx is None: 120 | return self._poses 121 | else: 122 | return self._poses[idx] 123 | 124 | def get_fovx(self, idx): 125 | return self._fovx 126 | 127 | 128 | class MASt3RCameraReader: 129 | 130 | dirname = "mast3r_opt" 131 | 132 | def __init__(self, dirpath, mast3r_expname, mast3r_img_res, **kwargs): 133 | 134 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl") 135 | with open(pkl_path.as_posix(), "rb") as pkl_file: 136 | data = pickle.load(pkl_file) 137 | 138 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res) 139 | self._poses = data["cam2worlds"] 140 | 141 | def get_poses(self, idx): 142 | return self._poses[idx] 143 | 144 | def get_fovx(self, idx): 145 | return self._fovx 146 | 147 | 148 | class MASt3R_CKPTCameraReader: 149 | 150 | dirname = "mast3r_opt" 151 | 152 | def __init__(self, dirpath, ckpt_path, mast3r_expname, mast3r_img_res, **kwargs): 153 | 154 | # read poses from ckpt 155 | ckpt = torch.load(ckpt_path) 156 | c2w_rot = quaternion_to_matrix(ckpt[0]["camera"]["R_c2ws_quat"]) 157 | c2w_trans = ckpt[0]["camera"]["T_c2ws"].unsqueeze(-1) 158 | c2w = torch.cat((c2w_rot, c2w_trans), dim=-1) 159 | bottom_row = ( 160 | torch.tensor([[0, 0, 0, 1]], dtype=c2w_rot.dtype) 161 | .expand(c2w.shape[0], -1, -1) 162 | .cuda() 163 | ) 164 | self._poses = torch.cat((c2w, bottom_row), dim=1).cpu().numpy() 165 | 166 | # read fovx from dust3r init 167 | pkl_path = Path(dirpath, self.dirname, mast3r_expname, "global_params.pkl") 168 | with open(pkl_path.as_posix(), "rb") as pkl_file: 169 | data = pickle.load(pkl_file) 170 | self._fovx = focal2fov(data["focals"][0], mast3r_img_res) 171 | 172 | def get_poses(self, idx): 173 | return self._poses[idx] 174 | 175 | def get_fovx(self, idx): 176 | return self._fovx 177 | 178 | 179 | class MASt3RPCDReader: 180 | 181 | path = "./op_results" 182 | dirname = "mast3r_opt" 183 | dynamic_path = "./dynamic" 184 | static_path = "./static" 185 | 186 | skip_dynamic = False 187 | 188 | def __init__( 189 | self, 190 | dirpath, 191 | mast3r_expname, 192 | mode=None, 193 | downsample_ratio=0.1, 194 | num_limit_points=None, 195 | **kwargs, 196 | ): 197 | 198 | if not Path(dirpath, self.dirname, mast3r_expname, self.dynamic_path).exists(): 199 | static_pcd_paths = Path( 200 | dirpath, self.dirname, mast3r_expname, self.static_path 201 | ) 202 | static_pcd_file = [ 203 | pth.as_posix() for pth in sorted(static_pcd_paths.glob("*.ply")) 204 | ][0] 205 | static_pcd = fetchPly(static_pcd_file) 206 | self.pcd = static_pcd 207 | self.skip_dynamic = True 208 | return 209 | 210 | if mode == "dynamic": 211 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.dynamic_path) 212 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))] 213 | elif mode == "static": 214 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.static_path) 215 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))] 216 | else: 217 | pcd_paths = Path(dirpath, self.dirname, mast3r_expname, self.path) 218 | pcd_files = [pth.as_posix() for pth in sorted(pcd_paths.glob("*.ply"))] 219 | pcds = 
[fetchPly(pcd_file) for pcd_file in pcd_files] 220 | 221 | json_file = Path(dirpath, "train_transforms.json") 222 | with open(json_file) as json_file: 223 | contents = json.load(json_file) 224 | times = [frame["time"] for frame in contents["frames"]] 225 | 226 | for idx, pcd in enumerate(pcds): 227 | time = times[idx] 228 | num_points = len(pcd.points) 229 | pcd.time = np.ones(num_points) * time 230 | 231 | merged_pcds = merge_pcds(pcds) 232 | 233 | if num_limit_points is not None: 234 | print(f"override downsample_ratio with num_vertices: {num_limit_points}") 235 | downsample_ratio = min(num_limit_points / len(merged_pcds.points), 1.0) 236 | 237 | self.pcd = uniform_sample(merged_pcds, downsample_ratio) 238 | 239 | def __call__(self): 240 | return self.pcd, self.skip_dynamic 241 | -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import numpy as np 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class WarmupDataLoader: 16 | # Pick a random viewpoint over selectable viewpoints. 17 | # Selectable viewpoints are incrementally increasing as the `register_frame` is called. 
18 | is_incremental: bool = True 19 | warmup_rate = 20 20 | 21 | def __init__(self, dataset: Dataset, init_num_frames: int, return_idx=False): 22 | self.dataset = dataset 23 | self.curr_num_frames = init_num_frames 24 | self.total_len = len(dataset) 25 | self.return_idx = return_idx 26 | 27 | num_list = int(np.ceil(self.num_iterations / self.total_len)) 28 | cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)] 29 | self.idx_list = [i for idx, i in cat_list if idx % self.warmup_rate == 0] 30 | 31 | def __iter__(self): 32 | for random_idx in self.idx_list: 33 | if self.return_idx: 34 | yield random_idx, self.dataset[random_idx] 35 | else: 36 | yield self.dataset[random_idx] 37 | 38 | def register_all_frame(self): 39 | num_list = int(np.ceil(self.num_iterations / self.total_len)) 40 | cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)] 41 | remainder = self.num_iterations % self.total_len 42 | if remainder != 0: 43 | cat_list[-1] = cat_list[-1][:remainder] 44 | self.idx_list = np.concatenate(cat_list) 45 | 46 | 47 | class PermutationSingleDataLoader: 48 | is_incremental: bool = False 49 | 50 | def __init__(self, dataset: Dataset, num_iterations: int, return_idx: bool = False): 51 | self.dataset = dataset 52 | self.total_len = len(dataset) 53 | self.return_idx = return_idx 54 | self.num_iterations = num_iterations 55 | self.idx_list = self.get_permuted_idx_list() 56 | 57 | def get_permuted_idx_list(self): 58 | num_list = int(np.ceil(self.num_iterations / self.total_len)) 59 | cat_list = [np.random.permutation(self.total_len) for _ in range(num_list)] 60 | remainder = self.num_iterations % self.total_len 61 | if remainder != 0: 62 | cat_list[-1] = cat_list[-1][:remainder] 63 | idx_list = np.concatenate(cat_list) 64 | return idx_list 65 | 66 | def __iter__(self): 67 | for random_idx in self.idx_list: 68 | if self.return_idx: 69 | yield random_idx, self.dataset[random_idx] 70 | else: 71 | yield self.dataset[random_idx] 72 | 73 | 74 | class SequentialSingleDataLoader: 75 | # Call all viewpoints in a sequential manner. 76 | is_incremental: bool = False 77 | 78 | def __init__(self, dataset: Dataset, return_idx: bool = False, **kwargs): 79 | self.dataset = dataset 80 | self.total_len = len(dataset) 81 | self.return_idx = return_idx 82 | 83 | def __len__(self): 84 | return self.total_len 85 | 86 | def __iter__(self): 87 | for idx in range(self.total_len): 88 | if self.return_idx: 89 | yield idx, self.dataset[idx] 90 | else: 91 | yield self.dataset[idx] 92 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
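As a usage note, `PermutationSingleDataLoader` above yields exactly `num_iterations` samples by concatenating as many random index permutations as needed and trimming the last one. A minimal sketch with a stand-in dataset (the toy class and sizes are illustrative, and it assumes the loader is importable as `src.data.dataloader` from the repository root):

```python
from torch.utils.data import Dataset

from src.data.dataloader import PermutationSingleDataLoader


class ToyDataset(Dataset):
    # stand-in for the camera dataset used during training
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx  # a real dataset would return a camera / image sample


loader = PermutationSingleDataLoader(ToyDataset(7), num_iterations=20, return_idx=True)
seen = [idx for idx, _ in loader]
assert len(seen) == 20  # three permutations of 7 indices, with the last one trimmed to 6
```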
8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import math 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | from plyfile import PlyData 17 | 18 | from src.utils.graphic_utils import ( 19 | getProjectionMatrix, 20 | getWorld2View2, 21 | matrix_to_quaternion, 22 | quaternion_to_matrix, 23 | ) 24 | from src.utils.point_utils import BasicPointCloud 25 | 26 | 27 | class CameraInterface(nn.Module): 28 | 29 | znear = 0.01 30 | zfar = 100.0 31 | trans = 0.0 32 | scale = 1.0 33 | 34 | def __init__( 35 | self, R, T, image, image_name, time, depth, normal, motion_mask, cam_idx 36 | ): 37 | super(CameraInterface, self).__init__() 38 | 39 | self.R = R 40 | self.T = T 41 | self.image_name = image_name 42 | self.image_width = image.shape[2] 43 | self.image_height = image.shape[1] 44 | self.cam_idx = cam_idx 45 | self.register_buffer("time", torch.tensor(time)) 46 | self.register_buffer("original_image", image) 47 | 48 | if depth is not None: 49 | self.register_buffer("depth", depth) 50 | else: 51 | self.depth = None 52 | 53 | if normal is not None: 54 | self.register_buffer("normal", normal) 55 | else: 56 | self.normal = None 57 | 58 | if motion_mask is not None: 59 | self.register_buffer("motion_mask", motion_mask) 60 | else: 61 | self.motion_mask = None 62 | 63 | 64 | class FixedCamera(CameraInterface): 65 | def __init__( 66 | self, 67 | R, 68 | T, 69 | FoVx, 70 | FoVy, 71 | image, 72 | image_name, 73 | time, 74 | depth, 75 | normal, 76 | motion_mask, 77 | cam_idx, 78 | ): 79 | 80 | super(FixedCamera, self).__init__( 81 | R=R, 82 | T=T, 83 | image=image, 84 | image_name=image_name, 85 | time=time, 86 | depth=depth, 87 | normal=normal, 88 | motion_mask=motion_mask, 89 | cam_idx=cam_idx, 90 | ) 91 | self.FoVx = FoVx 92 | self.FoVy = FoVy 93 | 94 | world_view_transform = torch.from_numpy( 95 | getWorld2View2(R, T, self.trans, self.scale) 96 | ) 97 | projection_matrix = getProjectionMatrix( 98 | self.znear, self.zfar, self.FoVx, self.FoVy 99 | ) 100 | 101 | self.register_buffer("world_view_transform", world_view_transform) 102 | self.register_buffer("projection_matrix", projection_matrix) 103 | 104 | 105 | class FixedCameraTorch: 106 | 107 | znear = 0.01 108 | zfar = 100.0 109 | trans = 0.0 110 | scale = 1.0 111 | 112 | def __init__( 113 | self, 114 | R_quat_c2w, 115 | T_c2w, 116 | FoVx, 117 | FoVy, 118 | image, 119 | image_name, 120 | time, 121 | depth, 122 | max_depth, 123 | normal, 124 | motion_mask, 125 | cam_idx, 126 | ): 127 | 128 | self.image_name = image_name 129 | self.image_width = image.shape[2] 130 | self.image_height = image.shape[1] 131 | self.time = torch.tensor(time) 132 | self.original_image = image 133 | 134 | self.depth = depth 135 | self.max_depth = max_depth 136 | self.normal = normal 137 | self.motion_mask = motion_mask 138 | 139 | self.FoVx = FoVx 140 | self.FoVy = FoVy 141 | self.R_quat_c2w = R_quat_c2w 142 | self.T_c2w = T_c2w 143 | projection_matrix = getProjectionMatrix( 144 | self.znear, self.zfar, self.FoVx, self.FoVy 145 | ) 146 | self.projection_matrix = projection_matrix 147 | self.cam_idx = cam_idx 148 | 149 | def 
cuda(self): 150 | self.original_image = self.original_image.cuda() 151 | if self.depth is not None: 152 | self.depth = self.depth.cuda() 153 | if self.normal is not None: 154 | self.normal = self.normal.cuda() 155 | if self.motion_mask is not None: 156 | self.motion_mask = self.motion_mask.cuda() 157 | self.time = self.time.cuda() 158 | self.projection_matrix = self.projection_matrix.cuda() 159 | return self 160 | 161 | @property 162 | def world_view_transform(self): 163 | R_c2w = quaternion_to_matrix(self.R_quat_c2w) 164 | T_c2w = (self.T_c2w + self.trans) * self.scale 165 | R_w2c = R_c2w.transpose(0, 1) 166 | T_w2c = -torch.einsum("ij, j -> i", R_w2c, T_c2w) 167 | ret = torch.eye(4).type_as(R_c2w) 168 | ret[:3, :3] = R_w2c 169 | ret[:3, 3] = T_w2c 170 | return ret 171 | 172 | 173 | class LearnableCamera(CameraInterface): 174 | 175 | def __init__( 176 | self, 177 | R, 178 | T, 179 | FoVx, 180 | FoVy, 181 | image, 182 | image_name, 183 | time, 184 | depth, 185 | normal, 186 | motion_mask, 187 | cam_idx, 188 | ): 189 | super(LearnableCamera, self).__init__( 190 | R=R, 191 | T=T, 192 | image=image, 193 | image_name=image_name, 194 | time=time, 195 | depth=depth, 196 | normal=normal, 197 | motion_mask=motion_mask, 198 | cam_idx=cam_idx, 199 | ) 200 | 201 | self.image_name = image_name 202 | self.image_width = image.shape[2] 203 | self.image_height = image.shape[1] 204 | 205 | R_tr = np.transpose(R, (1, 0)) 206 | R_c2w = torch.from_numpy(R_tr) 207 | T_c2w = torch.from_numpy(-np.einsum("ij, j", R_tr, T)) 208 | 209 | R_c2w_quat = matrix_to_quaternion(R_c2w) 210 | 211 | self.register_parameter( 212 | "R_c2w_quat", nn.Parameter(R_c2w_quat, requires_grad=True) 213 | ) 214 | self.register_parameter("T_c2w", nn.Parameter(T_c2w, requires_grad=True)) 215 | self.register_parameter( 216 | "FoVx", 217 | nn.Parameter(torch.tensor(FoVx).clone().detach(), requires_grad=False), 218 | ) 219 | self.register_parameter( 220 | "FoVy", 221 | nn.Parameter(torch.tensor(FoVy).clone().detach(), requires_grad=False), 222 | ) 223 | 224 | @property 225 | def world_view_transform(self): 226 | R_c2w = quaternion_to_matrix(self.R_c2w_quat) 227 | T_c2w = (self.T_c2w + self.trans) * self.scale 228 | ret = torch.eye(4).type_as(R_c2w) 229 | rot_transpose = R_c2w.transpose(0, 1) 230 | ret[:3, :3] = rot_transpose 231 | ret[:3, 3] = -torch.einsum("ij, j -> i", rot_transpose, T_c2w) 232 | return ret 233 | 234 | @property 235 | def projection_matrix(self): 236 | tanHalfFovY = math.tan((self.FoVy / 2)) 237 | tanHalfFovX = math.tan((self.FoVx / 2)) 238 | 239 | top = tanHalfFovY * self.znear 240 | bottom = -top 241 | right = tanHalfFovX * self.znear 242 | left = -right 243 | 244 | P = torch.zeros(4, 4).type_as(self.R_c2w_quat) 245 | 246 | z_sign = 1.0 247 | 248 | P[0, 0] = 2.0 * self.znear / (right - left) 249 | P[1, 1] = 2.0 * self.znear / (top - bottom) 250 | P[0, 2] = (right + left) / (right - left) 251 | P[1, 2] = (top + bottom) / (top - bottom) 252 | P[3, 2] = z_sign 253 | P[2, 2] = z_sign * self.zfar / (self.zfar - self.znear) 254 | P[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear) 255 | 256 | return P 257 | 258 | 259 | def fetchPly(path): 260 | plydata = PlyData.read(path) 261 | vertices = plydata["vertex"] 262 | positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T 263 | colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 264 | if "nx" in vertices: 265 | normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T 266 | else: 267 | normals = 
np.zeros_like(positions) 268 | if "time" in vertices: 269 | timestamp = vertices["time"][:, None] 270 | else: 271 | timestamp = None 272 | return BasicPointCloud( 273 | points=positions, colors=colors, normals=normals, time=timestamp 274 | ) 275 | 276 | 277 | def PILtoTorch(pil_image): 278 | image_tensor = torch.from_numpy(np.array(pil_image)) / 255.0 279 | if len(image_tensor.shape) == 3: 280 | return image_tensor.permute(2, 0, 1) 281 | else: 282 | return image_tensor.unsqueeze(dim=-1).permute(2, 0, 1) 283 | -------------------------------------------------------------------------------- /src/evaluator/utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def search_nearest_two(query_pose, db_poses): 16 | """ 17 | query_pose: torch.Tensor (4 X 4) 18 | db_poses: torch.Tensor (N X 4 X 4) 19 | """ 20 | query_t = query_pose[None, :3, 3] 21 | db_t = db_poses[:, :3, 3] 22 | 23 | distances = torch.norm(query_t - db_t, dim=1) 24 | nearest_indices = torch.topk(distances, k=2, largest=False).indices 25 | 26 | return nearest_indices 27 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /src/pipelines/eval.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import argparse 12 | 13 | from pathlib import Path 14 | from omegaconf import OmegaConf 15 | 16 | from src.utils.configs import str2bool, instantiate_from_config 17 | from src.utils.general_utils import seed_all 18 | 19 | 20 | def parse_args(): 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "-m", "--model_path", type=str, required=True, help="path to log" 25 | ) 26 | parser.add_argument( 27 | "-c", 28 | "--eval_config", 29 | required=True, 30 | ) 31 | parser.add_argument( 32 | "-d", 33 | "--dirpath", 34 | type=str, 35 | required=True, 36 | help="path to data directory", 37 | ) 38 | parser.add_argument( 39 | "-t", 40 | "--task_name", 41 | type=str, 42 | default="eval", 43 | help="the name of evaluation task.", 44 | ) 45 | parser.add_argument( 46 | "--verbose", type=str2bool, default=False, help="verbose printing" 47 | ) 48 | parser.add_argument( 49 | "--debug", type=str2bool, default=False, help="debug mode running" 50 | ) 51 | 52 | args, unknown = parser.parse_known_args() 53 | 54 | return args, unknown 55 | 56 | 57 | if __name__ == "__main__": 58 | 59 | args, unknown = parse_args() 60 | 61 | model_path = Path(args.model_path) 62 | 63 | train_config_path = model_path.joinpath("train/config.yaml") 64 | train_config = OmegaConf.load(train_config_path) 65 | eval_config = OmegaConf.load(args.eval_config) 66 | config = OmegaConf.merge(train_config, eval_config) 67 | 68 | out_path = model_path.joinpath(args.task_name) 69 | out_path.mkdir(exist_ok=args.debug) 70 | 71 | static_ckpt_path = model_path.joinpath("train", "static_last.ckpt") 72 | dynamic_ckpt_path = model_path.joinpath("train", "dynamic_last.ckpt") 73 | 74 | seed_all(config.metadata.seed) 75 | 76 | normalize_cams = eval_config.get("normalize_cams", False) 77 | 78 | static_datamodule = instantiate_from_config( 79 | config.static_data, normalize_cams=normalize_cams, ckpt_path=static_ckpt_path 80 | ) 81 | 
dynamic_datamodule = instantiate_from_config( 82 | config.dynamic_data, normalize_cams=normalize_cams, ckpt_path=dynamic_ckpt_path 83 | ) 84 | static_model = instantiate_from_config(config.static_model) 85 | dynamic_model = instantiate_from_config(config.dynamic_model) 86 | evaluator = instantiate_from_config( 87 | eval_config.evaluator, 88 | dirpath=args.dirpath, 89 | static_datamodule=static_datamodule, 90 | dynamic_datamodule=dynamic_datamodule, 91 | static_model=static_model, 92 | dynamic_model=dynamic_model, 93 | out_path=out_path, 94 | static_ckpt_path=static_ckpt_path, 95 | dynamic_ckpt_path=dynamic_ckpt_path, 96 | ) 97 | 98 | evaluator.eval() 99 | -------------------------------------------------------------------------------- /src/pipelines/train.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import argparse 12 | import os 13 | import logging 14 | import sys 15 | import shutil 16 | 17 | from pathlib import Path 18 | from omegaconf import OmegaConf, DictConfig 19 | 20 | from src.utils.configs import str2bool, is_instantiable, instantiate_from_config 21 | from src.utils.general_utils import seed_all 22 | from src.pipelines.utils import StreamToLogger 23 | 24 | 25 | def check_argument_sanity(args: argparse.Namespace) -> None: 26 | # check whether arguments are given properly. 
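    # For reference, a hypothetical invocation that satisfies these checks (the
    # scene directory, group, and experiment name are placeholders; only the config
    # path is taken from this repository):
    #   python -m src.pipelines.train -b configs/train/train_nvidia.yaml \
    #       -d data/<scene_name> -g <group_name> -n <exp_name>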
27 | assert args.name != "", "provide an appropriate expname" 28 | 29 | config_paths = args.base 30 | assert len(config_paths) != 0, "no config given for training" 31 | for config_path in config_paths: 32 | assert os.path.exists(config_path), f"no config exists in path {config_path}" 33 | 34 | if args.verbose: 35 | # When running verbose mode, set the os.environ to globally share 36 | # the running mode in a conveninent way 37 | os.environ["VERBOSE_RUN"] = "True" 38 | 39 | if args.debug: 40 | os.environ["DEBUG_RUN"] = "True" 41 | 42 | assert args.group != "", "specify the group name" 43 | 44 | 45 | def check_config(config: OmegaConf) -> None: 46 | # check whether config has a proper structure to run training 47 | assert is_instantiable(config, "static_data") 48 | assert is_instantiable(config, "static_model") 49 | assert is_instantiable(config, "dynamic_data") 50 | assert is_instantiable(config, "dynamic_model") 51 | assert is_instantiable(config, "trainer") 52 | 53 | 54 | def set_traindir(logdir: str, name: str, group: str, seed: int, debug: bool) -> Path: 55 | # explogdir (logdir/group/name) 56 | is_overridable = debug 57 | 58 | # in the case of group, previous experiments might already have made the directory 59 | path_logdir = Path(logdir, group) 60 | path_logdir_posix = path_logdir.as_posix() 61 | # group logdir is always overridable 62 | os.makedirs(path_logdir_posix, exist_ok=True) 63 | 64 | # but each experiment logdir should not exist 65 | expname = f"{name}_{str(seed).zfill(4)}" 66 | path_expdir = path_logdir.joinpath(expname) 67 | path_expdir_posix = path_expdir.as_posix() 68 | os.makedirs(path_expdir_posix, exist_ok=is_overridable) 69 | 70 | traindir = path_expdir.joinpath("train") 71 | traindir_posix = traindir.as_posix() 72 | os.makedirs(traindir_posix, exist_ok=is_overridable) 73 | 74 | return traindir 75 | 76 | 77 | def set_logger(logdir: Path) -> logging.Logger: 78 | # set logger in "logdir/train.log" 79 | 80 | logger = logging.getLogger("__main__") 81 | formatter = logging.Formatter( 82 | "%(asctime)s;[%(levelname)s];%(message)s", "%Y-%m-%d %H:%M:%S" 83 | ) 84 | logger.setLevel(logging.INFO) 85 | 86 | stream_handler = logging.StreamHandler() 87 | stream_handler.setLevel(logging.INFO) 88 | stream_handler.setFormatter(formatter) 89 | 90 | file_handler = logging.FileHandler(logdir.joinpath("train.log").as_posix()) 91 | file_handler.setLevel(logging.INFO) 92 | file_handler.setFormatter(formatter) 93 | 94 | logger.addHandler(stream_handler) 95 | logger.addHandler(file_handler) 96 | 97 | # Change the print function to 98 | sys.stdout = StreamToLogger(logger, logging.INFO) 99 | sys.stderr = StreamToLogger(logger, logging.INFO) 100 | 101 | return logger 102 | 103 | 104 | def store_args_and_config(logdir: Path, args: argparse.Namespace, config: DictConfig): 105 | # store args as runnable metadata 106 | config["metadata"] = vars(args) 107 | store_path = logdir.joinpath("config.yaml") 108 | 109 | # yaml_format = OmegaConf.to_yaml(config) 110 | OmegaConf.save(config, store_path) 111 | 112 | 113 | def store_code(logdir: Path): 114 | code_path = "./src" 115 | dst_path = logdir.joinpath("code").as_posix() 116 | if not os.path.exists(dst_path): 117 | shutil.copytree( 118 | src=code_path, 119 | dst=dst_path, 120 | ignore=shutil.ignore_patterns("*__pycache__*"), 121 | ) 122 | 123 | 124 | def parse_args(): 125 | 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument( 128 | "-n", 129 | "--name", 130 | type=str, 131 | const=True, 132 | default="", 133 | nargs="?", 134 | 
help="postfix for logdir", 135 | ) 136 | parser.add_argument( 137 | "-b", 138 | "--base", 139 | nargs="*", 140 | metavar="configs/default.yaml", 141 | default=list(), 142 | help="path to config files", 143 | ) 144 | parser.add_argument( 145 | "-d", 146 | "--dirpath", 147 | type=str, 148 | required=True, 149 | help="path to data directory", 150 | ) 151 | parser.add_argument( 152 | "-s", 153 | "--seed", 154 | type=int, 155 | default=777, 156 | help="seed for seed_everything", 157 | ) 158 | parser.add_argument( 159 | "-l", 160 | "--logdir", 161 | type=str, 162 | default="./logs", 163 | help="path to store logs", 164 | ) 165 | parser.add_argument( 166 | "-g", 167 | "--group", 168 | type=str, 169 | default="", 170 | help="group name", 171 | ) 172 | parser.add_argument( 173 | "--verbose", 174 | type=str2bool, 175 | nargs="?", 176 | const=True, 177 | default=True, 178 | help="verbose printing", 179 | ) 180 | parser.add_argument( 181 | "--debug", 182 | type=str2bool, 183 | nargs="?", 184 | const=True, 185 | default=False, 186 | help="run with debug mode", 187 | ) 188 | 189 | args, unknown = parser.parse_known_args() 190 | 191 | return args, unknown 192 | 193 | 194 | def override_config(static_data_config, dynamic_data_config, trainer_config): 195 | # override static_config with dynamic_config 196 | iteration = static_data_config["params"]["train_dloader_config"]["params"][ 197 | "num_iterations" 198 | ] 199 | if ( 200 | dynamic_data_config["params"]["train_dloader_config"]["params"][ 201 | "num_iterations" 202 | ] 203 | != iteration 204 | ): 205 | print("Override num_iterations...") 206 | 207 | dynamic_data_config["params"]["train_dloader_config"]["params"][ 208 | "num_iterations" 209 | ] = iteration 210 | 211 | trainer_config["params"]["static"]["params"]["num_iterations"] = iteration 212 | trainer_config["params"]["static"]["params"]["camera_opt_config"]["params"][ 213 | "total_steps" 214 | ] = iteration 215 | trainer_config["params"]["static"]["params"][ 216 | "position_lr_max_steps" 217 | ] = iteration 218 | 219 | trainer_config["params"]["dynamic"]["params"]["num_iterations"] = iteration 220 | trainer_config["params"]["dynamic"]["params"]["camera_opt_config"]["params"][ 221 | "total_steps" 222 | ] = iteration 223 | trainer_config["params"]["dynamic"]["params"][ 224 | "position_lr_max_steps" 225 | ] = iteration 226 | trainer_config["params"]["dynamic"]["params"]["deform_lr_max_steps"] = iteration 227 | 228 | return dynamic_data_config, trainer_config 229 | 230 | 231 | if __name__ == "__main__": 232 | 233 | args, unknown = parse_args() 234 | 235 | # argument sanity check 236 | check_argument_sanity(args) 237 | 238 | configs = [OmegaConf.load(cfg) for cfg in args.base] 239 | cli = OmegaConf.from_dotlist(unknown) 240 | config = OmegaConf.merge(*configs, cli) 241 | 242 | # config sanity check 243 | check_config(config) 244 | config.static_data.params.dirpath = args.dirpath 245 | config.dynamic_data.params.dirpath = args.dirpath 246 | 247 | seed_all(args.seed) 248 | 249 | logdir: Path = set_traindir( 250 | logdir=args.logdir, 251 | name=args.name, 252 | group=args.group, 253 | seed=args.seed, 254 | debug=args.debug, 255 | ) 256 | logger = set_logger(logdir=logdir) 257 | store_args_and_config(logdir, args, config) 258 | store_code(logdir) 259 | 260 | config.dynamic_data, config.trainer = override_config( 261 | config.static_data, config.dynamic_data, config.trainer 262 | ) 263 | 264 | static_datamodule = instantiate_from_config(config.static_data) 265 | dynamic_datamodule = 
instantiate_from_config(config.dynamic_data) 266 | 267 | print("Init Static GS model") 268 | static_model = instantiate_from_config(config.static_model) 269 | 270 | print("Init Dynamic GS model") 271 | dynamic_model = instantiate_from_config(config.dynamic_model) 272 | trainer = instantiate_from_config( 273 | config.trainer, 274 | static_datamodule=static_datamodule, 275 | dynamic_datamodule=dynamic_datamodule, 276 | static_model=static_model, 277 | dynamic_model=dynamic_model, 278 | logdir=logdir, 279 | ) 280 | 281 | trainer.train() 282 | -------------------------------------------------------------------------------- /src/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import logging 12 | 13 | 14 | class StreamToLogger(object): 15 | # Reference: https://stackoverflow.com/questions/11124093/redirect-python-print-output-to-logger 16 | def __init__(self, logger, log_level=logging.INFO): 17 | self.logger = logger 18 | self.log_level = log_level 19 | self.linebuf = "" 20 | 21 | def write(self, buf): 22 | temp_linebuf = self.linebuf + buf 23 | self.linebuf = "" 24 | for line in temp_linebuf.splitlines(True): 25 | if line[-1] == "\n": 26 | self.logger.log(self.log_level, line.rstrip()) 27 | else: 28 | self.linebuf += line 29 | 30 | def flush(self): 31 | if self.linebuf != "": 32 | self.logger.log(self.log_level, self.linebuf.rstrip()) 33 | self.linebuf = "" 34 | -------------------------------------------------------------------------------- /src/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /src/trainer/optim.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
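# A hedged sanity-check sketch for the warmup + cosine-annealing schedule defined
# below (the numbers are illustrative, not taken from any shipped config):
#
#   lr_fn = partial(linear_warmup_cosine_annealing_func,
#                   max_lr=1e-3, warmup_steps=100, total_steps=1000)
#   lr_fn(0)     # -> 0.0    start of the linear warmup
#   lr_fn(50)    # -> 5e-4   halfway through the warmup
#   lr_fn(100)   # -> 1e-3   peak learning rate once the warmup ends
#   lr_fn(1000)  # -> 0.0    cosine decay reaches zero at total_steps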
10 | 11 | import os 12 | import math 13 | 14 | import torch 15 | from functools import partial 16 | 17 | VERBOSE = os.environ.get("VERBOSE_RUN", False) 18 | 19 | 20 | def linear_warmup_cosine_annealing_func(step, max_lr, warmup_steps, total_steps): 21 | if step < warmup_steps: 22 | # Linear warmup 23 | return max_lr * (step / warmup_steps) 24 | else: 25 | # Cosine annealing 26 | progress = (step - warmup_steps) / (total_steps - warmup_steps) 27 | cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) 28 | return max_lr * cosine_decay 29 | 30 | 31 | class CameraQuatOptimizer: 32 | 33 | eps = 1e-15 34 | 35 | def __init__( 36 | self, 37 | dataset, 38 | camera_rotation_lr, 39 | camera_translation_lr, 40 | camera_lr_warmup, 41 | total_steps, 42 | spatial_lr_scale, 43 | ): 44 | 45 | self.spatial_lr_scale = spatial_lr_scale 46 | 47 | cam_R_params, cam_T_params = [], [] 48 | for name, param in dataset.named_parameters(): 49 | if "R_" in name: 50 | cam_R_params.append(param) 51 | elif "T_" in name: 52 | cam_T_params.append(param) 53 | else: 54 | raise NameError(f"Unknown parameter {name}") 55 | 56 | l = ( 57 | {"params": cam_R_params, "lr": 0.0, "name": "camera_R"}, 58 | {"params": cam_T_params, "lr": 0.0, "name": "camera_T"}, 59 | ) 60 | 61 | self._optimizer = torch.optim.Adam(l, lr=0.0, eps=self.eps) 62 | self.R_lr_fn = partial( 63 | linear_warmup_cosine_annealing_func, 64 | max_lr=camera_rotation_lr, 65 | warmup_steps=camera_lr_warmup, 66 | total_steps=total_steps, 67 | ) 68 | self.T_lr_fn = partial( 69 | linear_warmup_cosine_annealing_func, 70 | max_lr=camera_translation_lr, 71 | warmup_steps=camera_lr_warmup, 72 | total_steps=total_steps, 73 | ) 74 | 75 | def update_learning_rate(self, iter): 76 | for param_group in self._optimizer.param_groups: 77 | if param_group["name"] == "camera_R": 78 | param_group["lr"] = self.R_lr_fn(iter) 79 | elif param_group["name"] == "camera_T": 80 | param_group["lr"] = self.T_lr_fn(iter) 81 | else: 82 | raise NameError(f"Unknown param_group {param_group['name']}") 83 | 84 | def state_dict(self): 85 | return self._optimizer.state_dict() 86 | 87 | def zero_grad(self, set_to_none): 88 | return self._optimizer.zero_grad(set_to_none=set_to_none) 89 | 90 | def step(self): 91 | self._optimizer.step() 92 | -------------------------------------------------------------------------------- /src/trainer/renderer.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import torch 12 | import math 13 | 14 | from diff_gauss_pose import GaussianRasterizationSettings, GaussianRasterizer 15 | 16 | 17 | def render( 18 | xyz, 19 | active_sh_degree, 20 | opacity, 21 | scaling, 22 | rotation, 23 | features, 24 | viewpoint_camera, 25 | bg_color: torch.Tensor, 26 | scaling_modifier=1, 27 | override_color=None, 28 | enable_sh_grad=False, 29 | enable_cov_grad=False, 30 | ): 31 | """ 32 | Render the scene. 33 | 34 | Background tensor (bg_color) must be on GPU! 35 | """ 36 | 37 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 38 | screenspace_points = ( 39 | torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device="cuda") + 0 40 | ) 41 | try: 42 | screenspace_points.retain_grad() 43 | except: 44 | pass 45 | 46 | # Set up rasterization configuration 47 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 48 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 49 | 50 | raster_settings = GaussianRasterizationSettings( 51 | image_height=int(viewpoint_camera.image_height), 52 | image_width=int(viewpoint_camera.image_width), 53 | tanfovx=tanfovx, 54 | tanfovy=tanfovy, 55 | bg=bg_color, 56 | scale_modifier=scaling_modifier, 57 | projmatrix=viewpoint_camera.projection_matrix.transpose(0, 1), # glm storage 58 | sh_degree=active_sh_degree, 59 | prefiltered=False, 60 | debug=False, 61 | enable_cov_grad=enable_cov_grad, 62 | enable_sh_grad=enable_sh_grad, 63 | ) 64 | 65 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 66 | 67 | means3D = xyz 68 | means2D = screenspace_points 69 | opacity = opacity 70 | 71 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 72 | # scaling / rotation by the rasterizer. 73 | cov3D_precomp = None 74 | scales = scaling 75 | rotations = rotation 76 | 77 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 78 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 79 | shs = None 80 | colors_precomp = None 81 | if override_color is None: 82 | shs = features 83 | else: 84 | colors_precomp = override_color 85 | 86 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 87 | rendered_image, rendered_depth, rendered_normal, rendered_alpha, radii, extra = ( 88 | rasterizer( 89 | means3D=means3D, 90 | means2D=means2D, 91 | shs=shs, 92 | colors_precomp=colors_precomp, 93 | opacities=opacity, 94 | scales=scales, 95 | rotations=rotations, 96 | cov3Ds_precomp=cov3D_precomp, 97 | viewmatrix=viewpoint_camera.world_view_transform.transpose( 98 | 0, 1 99 | ), # glm storage 100 | ) 101 | ) 102 | 103 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 104 | # They will be excluded from value updates used in the splitting criteria.
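    # A hedged call-site sketch (the Gaussian attribute tensors and the camera
    # object are placeholders; keyword names follow this function's signature):
    #
    #   out = render(xyz=xyz, active_sh_degree=3, opacity=opacity, scaling=scaling,
    #                rotation=rotation, features=shs, viewpoint_camera=cam,
    #                bg_color=torch.zeros(3, device="cuda"))
    #   photometric = (out["rendered_image"] - cam.original_image).abs().mean()
    #
    # "rendered_depth", "rendered_normal", and "rendered_alpha" can drive depth and
    # normal supervision, while "viewspace_points", "visibility_filter", and "radii"
    # feed the densification statistics accumulated by the trainers.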
105 | return { 106 | "rendered_image": rendered_image, 107 | "rendered_depth": rendered_depth, 108 | "rendered_normal": rendered_normal, 109 | "rendered_alpha": rendered_alpha, 110 | "viewspace_points": screenspace_points, 111 | "visibility_filter": radii > 0, 112 | "radii": radii, 113 | "extra": extra, 114 | } 115 | -------------------------------------------------------------------------------- /src/trainer/rodygs_dynamic.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | from pathlib import Path 12 | from typing import Optional 13 | 14 | from omegaconf import DictConfig 15 | import torch 16 | from tqdm import tqdm 17 | from src.model.rodygs_dynamic import DynRoDyGS 18 | from src.trainer.rodygs_static import ThreeDGSTrainer 19 | from src.trainer.utils import prune_optimizer 20 | from src.utils.general_utils import get_expon_lr_func 21 | 22 | 23 | class DynTrainer(ThreeDGSTrainer): 24 | 25 | def __init__( 26 | self, 27 | datamodule, 28 | logdir: Path, 29 | model: DynRoDyGS, 30 | loss_config: DictConfig, 31 | # GS params 32 | num_iterations: int, 33 | position_lr_init: float, 34 | position_lr_final: float, 35 | position_lr_delay_mult: float, 36 | position_lr_max_steps: int, 37 | feature_lr: float, 38 | opacity_lr: float, 39 | scaling_lr: float, 40 | rotation_lr: float, 41 | percent_dense: float, 42 | opacity_reset_interval: int, 43 | densify_grad_threshold: float, 44 | densify_from_iter: int, 45 | densify_until_iter: int, 46 | densification_interval: int, 47 | # DyNMF params 48 | deform_lr_init: float, 49 | deform_lr_final: float, 50 | deform_lr_delay_mult: float, 51 | deform_lr_max_steps: int, 52 | motion_coeff_lr: float, 53 | deform_warmup_steps: int, 54 | # logging option 55 | log_freq: int = 50, 56 | # camera optim 57 | camera_opt_config: Optional[DictConfig] = None, 58 | ): 59 | super(DynTrainer, self).__init__( 60 | datamodule=datamodule, 61 | logdir=logdir, 62 | model=model, 63 | loss_config=loss_config, 64 | num_iterations=num_iterations, 65 | position_lr_init=position_lr_init, 66 | position_lr_final=position_lr_final, 67 | position_lr_delay_mult=position_lr_delay_mult, 68 | position_lr_max_steps=position_lr_max_steps, 69 | feature_lr=feature_lr, 70 | opacity_lr=opacity_lr, 71 | scaling_lr=scaling_lr, 72 | rotation_lr=rotation_lr, 73 | percent_dense=percent_dense, 74 | 
opacity_reset_interval=opacity_reset_interval, 75 | densify_grad_threshold=densify_grad_threshold, 76 | densify_from_iter=densify_from_iter, 77 | densify_until_iter=densify_until_iter, 78 | densification_interval=densification_interval, 79 | log_freq=log_freq, 80 | camera_opt_config=camera_opt_config, 81 | deform_warmup_steps=deform_warmup_steps, 82 | ) 83 | 84 | self.append_motion_optim( 85 | deform_lr_init=deform_lr_init, 86 | deform_lr_final=deform_lr_final, 87 | deform_lr_delay_mult=deform_lr_delay_mult, 88 | deform_lr_max_steps=deform_lr_max_steps, 89 | motion_coeff_lr=motion_coeff_lr, 90 | ) 91 | 92 | def append_motion_optim( 93 | self, 94 | deform_lr_init: float, 95 | deform_lr_final: float, 96 | deform_lr_delay_mult: float, 97 | deform_lr_max_steps: int, 98 | motion_coeff_lr: float, 99 | ): 100 | # Additional parameters and learning rates 101 | additional_params = [ 102 | { 103 | "params": list(self.model._deform_network.parameters()), 104 | "lr": deform_lr_init, 105 | "name": "deform_network", 106 | }, 107 | { 108 | "params": [self.model._motion_coeff], 109 | "lr": motion_coeff_lr, 110 | "name": "motion_coeff", 111 | }, 112 | ] 113 | 114 | # Append additional parameters and learning rates to the optimizer 115 | for additional_param in additional_params: 116 | self.optimizer.add_param_group(additional_param) 117 | 118 | self.deform_scheduler_fn = get_expon_lr_func( 119 | lr_init=deform_lr_init, 120 | lr_final=deform_lr_final, 121 | lr_delay_mult=deform_lr_delay_mult, 122 | max_steps=deform_lr_max_steps, 123 | ) 124 | 125 | def train(self): 126 | pbar = tqdm(total=self.num_iterations) 127 | dloader = self.datamodule.get_train_dloader() 128 | 129 | times = dloader.dataset.get_times() 130 | times = torch.tensor(times).cuda().view(-1, 1) 131 | self.model._time_embeddings = self.model._deform_network.batch_embedding(times) 132 | 133 | print("pre-encoding time embeddings") 134 | print("shape : ", self.model._time_embeddings.shape) 135 | 136 | train_dloader_iter = iter(dloader) 137 | 138 | for iteration in range(1, self.num_iterations + 1): 139 | self.train_iteration(train_dloader_iter, iteration) 140 | 141 | if iteration % self.log_freq == 0: 142 | pbar.update(self.log_freq) 143 | 144 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 145 | torch.save( 146 | (self.state_dict(iteration), iteration), 147 | self.logdir.as_posix() + "/last.ckpt", 148 | ) 149 | 150 | def prune_points(self, mask): 151 | # the same as the original 3DGS pruner, but also prunes the motion coefficients 152 | valid_points_mask = ~mask 153 | optimizable_tensors = prune_optimizer(self.optimizer, valid_points_mask) 154 | 155 | self.model._xyz = optimizable_tensors["xyz"] 156 | self.model._features_dc = optimizable_tensors["f_dc"] 157 | self.model._features_rest = optimizable_tensors["f_rest"] 158 | self.model._opacity = optimizable_tensors["opacity"] 159 | self.model._scaling = optimizable_tensors["scaling"] 160 | self.model._rotation = optimizable_tensors["rotation"] 161 | self.model._motion_coeff = optimizable_tensors["motion_coeff"] 162 | 163 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 164 | self.denom = self.denom[valid_points_mask] 165 | self.max_radii2D = self.max_radii2D[valid_points_mask] 166 | self.model.gaussian_to_time = self.model.gaussian_to_time[valid_points_mask] 167 | self.model.gaussian_to_time_ind = self.model.gaussian_to_time_ind[ 168 | valid_points_mask 169 | ] 170 | 171 | def set_attributes_from_opt_tensors(self, optimizable_tensors): 172 | super(DynTrainer, 
self).set_attributes_from_opt_tensors(optimizable_tensors) 173 | self.model._motion_coeff = optimizable_tensors["motion_coeff"] 174 | 175 | def densify_and_clone_update_attributes(self, grads, grad_threshold, scene_extent): 176 | ( 177 | updated_attributes, 178 | selected_pts_mask, 179 | ) = super().densify_and_clone_update_attributes( 180 | grads, grad_threshold, scene_extent 181 | ) 182 | updated_attributes["motion_coeff"] = self.model._motion_coeff[selected_pts_mask] 183 | return updated_attributes, selected_pts_mask 184 | 185 | def densify_and_split_update_attributes( 186 | self, grads, grad_threshold, scene_extent, N 187 | ): 188 | ( 189 | updated_attributes, 190 | selected_pts_mask, 191 | ) = super().densify_and_split_update_attributes( 192 | grads, grad_threshold, scene_extent, N 193 | ) 194 | updated_attributes["motion_coeff"] = self.model._motion_coeff[ 195 | selected_pts_mask 196 | ].repeat(N, 1, 1) 197 | return updated_attributes, selected_pts_mask 198 | 199 | def update_learning_rate(self, iteration): 200 | """Learning rate scheduling per step""" 201 | 202 | xyz_found = False 203 | deform_found = False 204 | 205 | for param_group in self.optimizer.param_groups: 206 | if xyz_found and deform_found: 207 | break 208 | if param_group["name"] == "xyz": 209 | lr = self.xyz_scheduler_fn(iteration) 210 | param_group["lr"] = lr 211 | xyz_found = True 212 | elif param_group["name"] == "deform": 213 | lr = self.deform_scheduler_fn(iteration) 214 | param_group["lr"] = lr 215 | deform_found = True 216 | 217 | def state_dict(self, iteration): 218 | state_dict = super(DynTrainer, self).state_dict(iteration) 219 | state_dict["model"]["_motion_coeff"] = self.model._motion_coeff 220 | state_dict["model"]["_deform_network"] = self.model._deform_network.state_dict() 221 | state_dict["model"]["_timestep"] = self.model.gaussian_to_time 222 | return state_dict 223 | -------------------------------------------------------------------------------- /src/trainer/utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | def replace_tensor_to_optimizer(optimizer, tensor, name): 16 | optimizable_tensors = {} 17 | for group in optimizer.param_groups: 18 | if len(group["params"]) > 1: 19 | # for these cases, we don't want to modify the optimizer 20 | continue 21 | if group["name"] == name: 22 | stored_state = optimizer.state.get(group["params"][0], None) 23 | stored_state["exp_avg"] = torch.zeros_like(tensor) 24 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 25 | 26 | del optimizer.state[group["params"][0]] 27 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 28 | optimizer.state[group["params"][0]] = stored_state 29 | 30 | optimizable_tensors[group["name"]] = group["params"][0] 31 | return optimizable_tensors 32 | 33 | 34 | def cat_tensors_to_optimizer(optimizer, tensors_dict): 35 | optimizable_tensors = {} 36 | for group in optimizer.param_groups: 37 | if len(group["params"]) > 1: 38 | # for these cases, we don't want to modify the optimizer 39 | continue 40 | extension_tensor = tensors_dict[group["name"]] 41 | stored_state = optimizer.state.get(group["params"][0], None) 42 | if stored_state is not None: 43 | 44 | stored_state["exp_avg"] = torch.cat( 45 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0 46 | ) 47 | stored_state["exp_avg_sq"] = torch.cat( 48 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0 49 | ) 50 | 51 | del optimizer.state[group["params"][0]] 52 | group["params"][0] = nn.Parameter( 53 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_( 54 | True 55 | ) 56 | ) 57 | optimizer.state[group["params"][0]] = stored_state 58 | 59 | optimizable_tensors[group["name"]] = group["params"][0] 60 | else: 61 | group["params"][0] = nn.Parameter( 62 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_( 63 | True 64 | ) 65 | ) 66 | optimizable_tensors[group["name"]] = group["params"][0] 67 | 68 | return optimizable_tensors 69 | 70 | 71 | def prune_optimizer(optimizer, mask): 72 | optimizable_tensors = {} 73 | for group in optimizer.param_groups: 74 | if len(group["params"]) > 1: 75 | # for these cases, we don't want to modify the optimizer 76 | continue 77 | stored_state = optimizer.state.get(group["params"][0], None) 78 | if stored_state is not None: 79 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 80 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 81 | 82 | del optimizer.state[group["params"][0]] 83 | group["params"][0] = nn.Parameter( 84 | (group["params"][0][mask].requires_grad_(True)) 85 | ) 86 | optimizer.state[group["params"][0]] = stored_state 87 | 88 | optimizable_tensors[group["name"]] = group["params"][0] 89 | else: 90 | group["params"][0] = nn.Parameter( 91 | group["params"][0][mask].requires_grad_(True) 92 | ) 93 | optimizable_tensors[group["name"]] = group["params"][0] 94 | 95 | return optimizable_tensors 96 | -------------------------------------------------------------------------------- /src/utils/configs.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import argparse 12 | import importlib 13 | 14 | from omegaconf import OmegaConf, DictConfig 15 | 16 | 17 | def str2bool(v): 18 | # a util function that accepts various kinds of flag options. 19 | if isinstance(v, bool): 20 | return v 21 | if v.lower() in ("yes", "true", "t", "y", "1"): 22 | return True 23 | elif v.lower() in ("no", "false", "f", "n", "0"): 24 | return False 25 | else: 26 | raise argparse.ArgumentTypeError("Boolean value expected.") 27 | 28 | 29 | def is_instantiable(config: DictConfig, key: str): 30 | # check whether the configuration is an instantiable format. 31 | subconfig = config.get(key) 32 | is_config = OmegaConf.is_config(subconfig) 33 | if not is_config: 34 | return False 35 | 36 | # target and params should be in keys. 37 | key_list = subconfig.keys() 38 | has_params = "params" in key_list 39 | has_target = "target" in key_list 40 | 41 | # no other keys are needed. 42 | no_more_keys = len(key_list) == 2 43 | 44 | return has_params and has_target and no_more_keys 45 | 46 | 47 | def instantiate_from_config(config: DictConfig, **kwargs): 48 | # instantiate an object from configuration. 49 | # if some arguments are passed in kwargs, this function overloads the configruation. 50 | assert "target" in config.keys(), "target not exists" 51 | 52 | params = dict() 53 | params.update(config.get("params", dict())) 54 | params.update(kwargs) 55 | return get_obj_from_str(config["target"])(**params) 56 | 57 | 58 | def get_obj_from_str(string, reload=False, invalidate_cache=True): 59 | # from a str, return the corresponding object. 60 | module, obj_name = string.rsplit(".", 1) 61 | if invalidate_cache: 62 | importlib.invalidate_caches() 63 | if reload: 64 | module_imp = importlib.import_module(module) 65 | importlib.reload(module_imp) 66 | return getattr(importlib.import_module(module, package=None), obj_name) 67 | -------------------------------------------------------------------------------- /src/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
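# A hedged sketch of the `target` / `params` convention implemented in
# src/utils/configs.py above (the class path and value below are illustrative
# placeholders, not copied from the shipped YAML files):
#
#   static_model:
#     target: src.model.rodygs_static.StaticRoDyGS   # hypothetical dotted path
#     params:
#       sh_degree: 3
#
# instantiate_from_config imports the `target` dotted path and calls it with
# `params`; keyword arguments supplied at the call site (e.g. normalize_cams in
# the pipelines) are merged last and therefore override the YAML values.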
8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | from itertools import chain 12 | from typing import Sequence 13 | from collections import OrderedDict 14 | 15 | import numpy as np 16 | import scipy 17 | from piqa import PSNR, SSIM, LPIPS, MS_SSIM 18 | import torch 19 | import torch.nn as nn 20 | from torchvision import models 21 | 22 | from src.utils.general_utils import batchify, reduce 23 | from src.utils.pose_estim_utils import align_ate_c2b_use_a2b, compute_ATE, compute_rpe 24 | 25 | 26 | class VizScoreEvaluator: 27 | 28 | def __init__(self, device): 29 | self.psnr_module = PSNR().to(device) 30 | self.ssim_module = SSIM().to(device) 31 | self.msssim_module = MS_SSIM().to(device) 32 | 33 | @torch.inference_mode() 34 | def get_score(self, gt_image, pred_image): 35 | 36 | bf_gt_image = batchify(gt_image).clip(0, 1).contiguous() 37 | bf_pred_image = batchify(pred_image).clip(0, 1).contiguous() 38 | 39 | psnr_score = reduce(self.psnr_module(bf_gt_image, bf_pred_image)) 40 | ssim_score = reduce(self.ssim_module(bf_gt_image, bf_pred_image)) 41 | lpipsa_score = reduce(lpips(bf_gt_image, bf_pred_image, net_type="alex")) 42 | lpipsv_score = reduce(lpips(bf_gt_image, bf_pred_image, net_type="vgg")) 43 | msssim_score = reduce(self.msssim_module(bf_gt_image, bf_pred_image)) 44 | dssim_score = (1 - msssim_score) / 2 45 | 46 | return { 47 | "psnr": psnr_score, 48 | "ssim": ssim_score, 49 | "lpipsa": lpipsa_score, 50 | "lpipsv": lpipsv_score, 51 | "msssim": msssim_score, 52 | "dssim": dssim_score, 53 | } 54 | 55 | 56 | class PoseEvaluator: 57 | def __init__(self): 58 | pass 59 | 60 | def normalize_pose(self, pose1, pose2): 61 | mtx1 = np.array(pose1.cpu().numpy(), dtype=np.double, copy=True) 62 | mtx2 = np.array(pose2.cpu().numpy(), dtype=np.double, copy=True) 63 | 64 | if mtx1.ndim != 2 or mtx2.ndim != 2: 65 | raise ValueError("Input matrices must be two-dimensional") 66 | if mtx1.shape != mtx2.shape: 67 | raise ValueError("Input matrices must be of same shape") 68 | if mtx1.size == 0: 69 | raise ValueError("Input matrices must be >0 rows and >0 cols") 70 | 71 | # translate all the data to the origin 72 | mtx1 -= np.mean(mtx1, 0) 73 | mtx2 -= np.mean(mtx2, 0) 74 | 75 | norm1 = np.linalg.norm(mtx1) 76 | norm2 = np.linalg.norm(mtx2) 77 | 78 | if norm1 == 0 or norm2 == 0: 79 | raise ValueError("Input matrices must contain >1 unique points") 80 | 81 | # change scaling of data (in rows) such that trace(mtx*mtx') = 1 82 | mtx1 /= norm1 83 | mtx2 /= norm2 84 | 85 | # transform mtx2 to minimize disparity 86 | R, s = scipy.linalg.orthogonal_procrustes(mtx1, mtx2) 87 | mtx2 = mtx2 * s 88 | 89 | return mtx1, mtx2, R 90 | 91 | def algin_pose(self, gt, estim): 92 | gt_ret, estim_ret = gt.clone(), estim.clone() 93 | gt_transl, estim_transl, _ = self.normalize_pose( 94 | gt[:, :3, -1], estim_ret[:, :3, -1] 95 | ) 96 | gt_ret[:, :3, -1], estim_ret[:, :3, -1] = torch.from_numpy(gt_transl).type_as( 97 | gt_ret 98 | ), torch.from_numpy(estim_transl).type_as(estim_ret) 99 | c2ws_est_aligned = align_ate_c2b_use_a2b(estim_ret, gt_ret) 100 | return gt_ret, c2ws_est_aligned 101 | 102 | 
@torch.inference_mode() 103 | def get_score(self, gt, estim): 104 | gt, c2ws_est_aligned = self.algin_pose(gt, estim) 105 | ate = compute_ATE(gt.cpu().numpy(), c2ws_est_aligned.cpu().numpy()) 106 | rpe_trans, rpe_rot = compute_rpe( 107 | gt.cpu().numpy(), c2ws_est_aligned.cpu().numpy() 108 | ) 109 | rpe_trans *= 100 110 | rpe_rot *= 180 / np.pi 111 | 112 | return { 113 | "ATE": ate, 114 | "RPE_trans": rpe_trans, 115 | "RPE_rot": rpe_rot, 116 | "aligned": c2ws_est_aligned, 117 | } 118 | 119 | 120 | class LPIPS(nn.Module): 121 | r"""Creates a criterion that measures 122 | Learned Perceptual Image Patch Similarity (LPIPS). 123 | 124 | Arguments: 125 | net_type (str): the network type to compare the features: 126 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 127 | version (str): the version of LPIPS. Default: 0.1. 128 | """ 129 | 130 | def __init__(self, net_type: str = "alex", version: str = "0.1"): 131 | 132 | assert version in ["0.1"], "v0.1 is only supported now" 133 | 134 | super(LPIPS, self).__init__() 135 | 136 | # pretrained network 137 | self.net = get_network(net_type) 138 | 139 | # linear layers 140 | self.lin = LinLayers(self.net.n_channels_list) 141 | self.lin.load_state_dict(get_state_dict(net_type, version)) 142 | 143 | def forward(self, x: torch.Tensor, y: torch.Tensor): 144 | feat_x, feat_y = self.net(x), self.net(y) 145 | 146 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 147 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 148 | 149 | return torch.sum(torch.cat(res, 0), 0, True) 150 | 151 | 152 | def get_network(net_type: str): 153 | if net_type == "alex": 154 | return AlexNet() 155 | elif net_type == "squeeze": 156 | return SqueezeNet() 157 | elif net_type == "vgg": 158 | return VGG16() 159 | else: 160 | raise NotImplementedError("choose net_type from [alex, squeeze, vgg].") 161 | 162 | 163 | class LinLayers(nn.ModuleList): 164 | def __init__(self, n_channels_list: Sequence[int]): 165 | super(LinLayers, self).__init__( 166 | [ 167 | nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False)) 168 | for nc in n_channels_list 169 | ] 170 | ) 171 | 172 | for param in self.parameters(): 173 | param.requires_grad = False 174 | 175 | 176 | class BaseNet(nn.Module): 177 | def __init__(self): 178 | super(BaseNet, self).__init__() 179 | 180 | # register buffer 181 | self.register_buffer( 182 | "mean", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 183 | ) 184 | self.register_buffer( 185 | "std", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 186 | ) 187 | 188 | def set_requires_grad(self, state: bool): 189 | for param in chain(self.parameters(), self.buffers()): 190 | param.requires_grad = state 191 | 192 | def z_score(self, x: torch.Tensor): 193 | return (x - self.mean) / self.std 194 | 195 | def forward(self, x: torch.Tensor): 196 | x = self.z_score(x) 197 | 198 | output = [] 199 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 200 | x = layer(x) 201 | if i in self.target_layers: 202 | output.append(normalize_activation(x)) 203 | if len(output) == len(self.target_layers): 204 | break 205 | return output 206 | 207 | 208 | class SqueezeNet(BaseNet): 209 | def __init__(self): 210 | super(SqueezeNet, self).__init__() 211 | 212 | self.layers = models.squeezenet1_1(True).features 213 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 214 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 215 | 216 | self.set_requires_grad(False) 217 | 218 | 219 | class AlexNet(BaseNet): 220 | def __init__(self): 221 | 
super(AlexNet, self).__init__() 222 | 223 | self.layers = models.alexnet(True).features 224 | self.target_layers = [2, 5, 8, 10, 12] 225 | self.n_channels_list = [64, 192, 384, 256, 256] 226 | 227 | self.set_requires_grad(False) 228 | 229 | 230 | class VGG16(BaseNet): 231 | def __init__(self): 232 | super(VGG16, self).__init__() 233 | 234 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 235 | self.target_layers = [4, 9, 16, 23, 30] 236 | self.n_channels_list = [64, 128, 256, 512, 512] 237 | 238 | self.set_requires_grad(False) 239 | 240 | 241 | def normalize_activation(x, eps=1e-10): 242 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 243 | return x / (norm_factor + eps) 244 | 245 | 246 | def get_state_dict(net_type: str = "alex", version: str = "0.1"): 247 | # build url 248 | url = ( 249 | "https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/" 250 | + f"master/lpips/weights/v{version}/{net_type}.pth" 251 | ) 252 | 253 | # download 254 | old_state_dict = torch.hub.load_state_dict_from_url( 255 | url, 256 | progress=True, 257 | map_location=None if torch.cuda.is_available() else torch.device("cpu"), 258 | ) 259 | 260 | # rename keys 261 | new_state_dict = OrderedDict() 262 | for key, val in old_state_dict.items(): 263 | new_key = key 264 | new_key = new_key.replace("lin", "") 265 | new_key = new_key.replace("model.", "") 266 | new_state_dict[new_key] = val 267 | 268 | return new_state_dict 269 | 270 | 271 | def lpips( 272 | x: torch.Tensor, y: torch.Tensor, net_type: str = "alex", version: str = "0.1" 273 | ): 274 | r"""Function that measures 275 | Learned Perceptual Image Patch Similarity (LPIPS). 276 | 277 | Arguments: 278 | x, y (torch.Tensor): the input tensors to compare. 279 | net_type (str): the network type to compare the features: 280 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 281 | version (str): the version of LPIPS. Default: 0.1. 282 | """ 283 | device = x.device 284 | criterion = LPIPS(net_type, version).to(device) 285 | return criterion(x, y) 286 | -------------------------------------------------------------------------------- /src/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | 11 | import os 12 | import random 13 | 14 | import torch 15 | import numpy as np 16 | 17 | 18 | def seed_all(seed: int): 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | random.seed(seed) 25 | os.environ["PYTHONHASHSEED"] = str(seed) 26 | 27 | 28 | def batchify(tensor): 29 | return tensor.unsqueeze(0) 30 | 31 | 32 | def reduce(tensor): 33 | return tensor.squeeze(0) 34 | 35 | 36 | def inverse_sigmoid(x): 37 | return torch.log(x / (1 - x)) 38 | 39 | 40 | def get_expon_lr_func( 41 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 42 | ): 43 | """ 44 | Copied from Plenoxels 45 | 46 | Continuous learning rate decay function. Adapted from JaxNeRF 47 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 48 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 
49 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 50 | function of lr_delay_mult, such that the initial learning rate is 51 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 52 | to the normal learning rate when steps>lr_delay_steps. 53 | :param conf: config subtree 'lr' or similar 54 | :param max_steps: int, the number of steps during optimization. 55 | :return HoF which takes step as input 56 | """ 57 | 58 | def helper(step): 59 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 60 | # Disable this parameter 61 | return 0.0 62 | if lr_delay_steps > 0: 63 | # A kind of reverse cosine decay. 64 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 65 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 66 | ) 67 | else: 68 | delay_rate = 1.0 69 | t = np.clip(step / max_steps, 0, 1) 70 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 71 | return delay_rate * log_lerp 72 | 73 | return helper 74 | 75 | 76 | def strip_lowerdiag(L): 77 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 78 | 79 | uncertainty[:, 0] = L[:, 0, 0] 80 | uncertainty[:, 1] = L[:, 0, 1] 81 | uncertainty[:, 2] = L[:, 0, 2] 82 | uncertainty[:, 3] = L[:, 1, 1] 83 | uncertainty[:, 4] = L[:, 1, 2] 84 | uncertainty[:, 5] = L[:, 2, 2] 85 | return uncertainty 86 | 87 | 88 | def strip_symmetric(sym): 89 | return strip_lowerdiag(sym) 90 | 91 | 92 | def build_rotation(r): 93 | norm = torch.sqrt( 94 | r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] 95 | ) 96 | 97 | q = r / norm[:, None] 98 | 99 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 100 | 101 | r = q[:, 0] 102 | x = q[:, 1] 103 | y = q[:, 2] 104 | z = q[:, 3] 105 | 106 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 107 | R[:, 0, 1] = 2 * (x * y - r * z) 108 | R[:, 0, 2] = 2 * (x * z + r * y) 109 | R[:, 1, 0] = 2 * (x * y + r * z) 110 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 111 | R[:, 1, 2] = 2 * (y * z - r * x) 112 | R[:, 2, 0] = 2 * (x * z - r * y) 113 | R[:, 2, 1] = 2 * (y * z + r * x) 114 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 115 | return R 116 | 117 | 118 | def build_scaling_rotation(s, r): 119 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 120 | R = build_rotation(r) 121 | 122 | L[:, 0, 0] = s[:, 0] 123 | L[:, 1, 1] = s[:, 1] 124 | L[:, 2, 2] = s[:, 2] 125 | 126 | L = R @ L 127 | return L 128 | -------------------------------------------------------------------------------- /src/utils/graphic_utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | import math 12 | import warnings 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | 19 | def geom_transform_points(points, transf_matrix): 20 | P, _ = points.shape 21 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 22 | points_hom = torch.cat([points, ones], dim=1) 23 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 24 | 25 | denom = points_out[..., 3:] + 0.0000001 26 | return (points_out[..., :3] / denom).squeeze(dim=0) 27 | 28 | 29 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 30 | Rt = np.zeros((4, 4)) 31 | Rt[:3, :3] = R 32 | Rt[:3, 3] = t 33 | Rt[3, 3] = 1.0 34 | 35 | C2W = np.linalg.inv(Rt) 36 | cam_center = C2W[:3, 3] 37 | cam_center = (cam_center + translate) * scale 38 | C2W[:3, 3] = cam_center 39 | Rt = np.linalg.inv(C2W) 40 | return np.float32(Rt) 41 | 42 | 43 | def getProjectionMatrix(znear, zfar, fovX, fovY): 44 | tanHalfFovY = math.tan((fovY / 2)) 45 | tanHalfFovX = math.tan((fovX / 2)) 46 | 47 | top = tanHalfFovY * znear 48 | bottom = -top 49 | right = tanHalfFovX * znear 50 | left = -right 51 | 52 | P = torch.zeros(4, 4) 53 | 54 | z_sign = 1.0 55 | 56 | P[0, 0] = 2.0 * znear / (right - left) 57 | P[1, 1] = 2.0 * znear / (top - bottom) 58 | P[0, 2] = (right + left) / (right - left) 59 | P[1, 2] = (top + bottom) / (top - bottom) 60 | P[3, 2] = z_sign 61 | P[2, 2] = z_sign * zfar / (zfar - znear) 62 | P[2, 3] = -(zfar * znear) / (zfar - znear) 63 | return P 64 | 65 | 66 | def fov2focal(fov, pixels): 67 | if fov > math.pi * 2 or fov < -math.pi * 2: 68 | warnings.warn("fov seems to be degree value, plz double check!") 69 | return pixels / (2 * math.tan(fov / 2)) 70 | 71 | 72 | def focal2fov(focal, pixels): 73 | return 2 * math.atan(pixels / (2 * focal)) 74 | 75 | 76 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 77 | """ 78 | Convert rotations given as quaternions to rotation matrices. 79 | Args: 80 | quaternions: quaternions with real part first, 81 | as tensor of shape (..., 4). 82 | Returns: 83 | Rotation matrices as tensor of shape (..., 3, 3). 84 | """ 85 | r, i, j, k = torch.unbind(quaternions, -1) 86 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 87 | 88 | o = torch.stack( 89 | ( 90 | 1 - two_s * (j * j + k * k), 91 | two_s * (i * j - k * r), 92 | two_s * (i * k + j * r), 93 | two_s * (i * j + k * r), 94 | 1 - two_s * (i * i + k * k), 95 | two_s * (j * k - i * r), 96 | two_s * (i * k - j * r), 97 | two_s * (j * k + i * r), 98 | 1 - two_s * (i * i + j * j), 99 | ), 100 | -1, 101 | ) 102 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 103 | 104 | 105 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 106 | """ 107 | Returns torch.sqrt(torch.max(0, x)) 108 | but with a zero subgradient where x is 0. 109 | """ 110 | ret = torch.zeros_like(x) 111 | positive_mask = x > 0 112 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 113 | return ret 114 | 115 | 116 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 117 | """ 118 | Convert rotations given as rotation matrices to quaternions. 119 | Args: 120 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 121 | Returns: 122 | quaternions with real part first, as tensor of shape (..., 4). 
123 | """ 124 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 125 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 126 | 127 | batch_dim = matrix.shape[:-2] 128 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 129 | matrix.reshape(batch_dim + (9,)), dim=-1 130 | ) 131 | 132 | q_abs = _sqrt_positive_part( 133 | torch.stack( 134 | [ 135 | 1.0 + m00 + m11 + m22, 136 | 1.0 + m00 - m11 - m22, 137 | 1.0 - m00 + m11 - m22, 138 | 1.0 - m00 - m11 + m22, 139 | ], 140 | dim=-1, 141 | ) 142 | ) 143 | 144 | # we produce the desired quaternion multiplied by each of r, i, j, k 145 | quat_by_rijk = torch.stack( 146 | [ 147 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 148 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 149 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 150 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 151 | ], 152 | dim=-2, 153 | ) 154 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 155 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 156 | 157 | return quat_candidates[ 158 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 159 | ].reshape(batch_dim + (4,)) 160 | -------------------------------------------------------------------------------- /src/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | from math import exp 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | 19 | def l1_loss(network_output, gt): 20 | return torch.abs((network_output - gt)).mean() 21 | 22 | 23 | def l2_loss(network_output, gt): 24 | return ((network_output - gt) ** 2).mean() 25 | 26 | 27 | def smooth_loss(output): 28 | grad_output_x = output[:, :, :-1] - output[:, :, 1:] 29 | grad_output_y = output[:, :-1, :] - output[:, 1:, :] 30 | 31 | return torch.abs(grad_output_x).mean() + torch.abs(grad_output_y).mean() 32 | 33 | 34 | def gaussian(window_size, sigma): 35 | gauss = torch.Tensor( 36 | [ 37 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 38 | for x in range(window_size) 39 | ] 40 | ) 41 | return gauss / gauss.sum() 42 | 43 | 44 | def logl1(pred, gt): 45 | return torch.log(1 + torch.abs(pred - gt)) 46 | 47 | 48 | def create_window(window_size, channel): 49 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 50 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 51 | window = Variable( 52 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 53 | ) 54 | return window 55 | 56 | 57 | def ssim(img1, img2, window_size=11, size_average=True): 58 | channel = img1.size(-3) 59 | window = create_window(window_size, channel) 60 | 61 | if img1.is_cuda: 62 | window = window.cuda(img1.get_device()) 63 | window = window.type_as(img1) 64 | 65 | return _ssim(img1, img2, window, window_size, channel, size_average) 66 | 67 | 68 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 69 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 70 | mu2 = F.conv2d(img2, window, 
padding=window_size // 2, groups=channel) 71 | 72 | mu1_sq = mu1.pow(2) 73 | mu2_sq = mu2.pow(2) 74 | mu1_mu2 = mu1 * mu2 75 | 76 | sigma1_sq = ( 77 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 78 | ) 79 | sigma2_sq = ( 80 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 81 | ) 82 | sigma12 = ( 83 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 84 | - mu1_mu2 85 | ) 86 | 87 | C1 = 0.01**2 88 | C2 = 0.03**2 89 | 90 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 91 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 92 | ) 93 | 94 | if size_average: 95 | return ssim_map.mean() 96 | else: 97 | return ssim_map.mean(1).mean(1).mean(1) 98 | 99 | 100 | def pearson_depth_loss(input_depth, target_depth, eps, mask=None): 101 | 102 | if mask is not None: 103 | pred_depth = input_depth * mask 104 | gt_depth = target_depth * mask 105 | else: 106 | pred_depth = input_depth 107 | gt_depth = target_depth 108 | 109 | centered_pred_depth = pred_depth - pred_depth.mean() 110 | centered_gt_depth = gt_depth - gt_depth.mean() 111 | 112 | normalized_pred_depth = centered_pred_depth / (centered_pred_depth.std() + eps) 113 | normalized_gt_depth = centered_gt_depth / (centered_gt_depth.std() + eps) 114 | 115 | covariance = (normalized_pred_depth * normalized_gt_depth).mean() 116 | 117 | return 1 - covariance 118 | 119 | 120 | def compute_fundamental_matrix(K, w2c1, w2c2): 121 | 122 | # Compute relative rotation and translation 123 | R_rel = w2c2[:, :3, :3] @ w2c1[:, :3, :3].transpose(1, 2) 124 | t_rel = w2c2[:, :3, 3] - torch.einsum("nij, nj -> ni", R_rel, w2c1[:, :3, 3]) 125 | 126 | # Skew-symmetric matrix [t]_x for cross product 127 | def skew_symmetric(t): 128 | B = t.shape[0] 129 | zero = torch.zeros(B, device=t.device) 130 | tx = torch.stack( 131 | [zero, -t[:, 2], t[:, 1], t[:, 2], zero, -t[:, 0], -t[:, 1], t[:, 0], zero], 132 | dim=1, 133 | ) 134 | return tx.view(B, 3, 3) 135 | 136 | t_rel_skew = skew_symmetric(t_rel) # (B, 3, 3) 137 | 138 | # Compute Essential matrix E = [t]_x * R_rel 139 | E = t_rel_skew @ R_rel 140 | 141 | # Compute Fundamental matrix: F = K2^-T * E * K1^-1 142 | inverse_K = invert_intrinsics(K) 143 | K1_inv = inverse_K 144 | K2_inv_T = inverse_K.transpose(1, 2) 145 | F = K2_inv_T.bmm(E).bmm(K1_inv) 146 | 147 | return F 148 | 149 | 150 | def invert_intrinsics(K): 151 | """ 152 | Efficiently compute the inverse of a batch of intrinsic matrices. 153 | 154 | Args: 155 | K: (B, 3, 3) tensor representing a batch of intrinsic matrices. 156 | 157 | Returns: 158 | K_inv: (B, 3, 3) tensor representing the batch of inverse intrinsic matrices. 
159 | """ 160 | B = K.shape[0] 161 | 162 | # Extract intrinsic parameters from the matrix 163 | fx = K[:, 0, 0] # focal length in x 164 | fy = K[:, 1, 1] # focal length in y 165 | cx = K[:, 0, 2] # principal point x 166 | cy = K[:, 1, 2] # principal point y 167 | 168 | # Construct the inverse intrinsic matrices manually 169 | K_inv = torch.zeros_like(K) 170 | K_inv[:, 0, 0] = 1.0 / fx 171 | K_inv[:, 1, 1] = 1.0 / fy 172 | K_inv[:, 0, 2] = -cx / fx 173 | K_inv[:, 1, 2] = -cy / fy 174 | K_inv[:, 2, 2] = 1.0 175 | 176 | return K_inv 177 | 178 | 179 | def construct_intrinsics(focal, image_width, image_height, batch_size): 180 | K = torch.zeros(batch_size, 3, 3) 181 | K[:, 0, 0] = focal 182 | K[:, 1, 1] = focal 183 | K[:, 2, 2] = 1.0 184 | K[:, 0, 2] = image_width / 2 185 | K[:, 1, 2] = image_height / 2 186 | return K 187 | 188 | 189 | def compute_sampson_error(x1, x2, F): 190 | """ 191 | :param x1 (*, N, 2) 192 | :param x2 (*, N, 2) 193 | :param F (*, 3, 3) 194 | """ 195 | h1 = torch.cat([x1, torch.ones_like(x1[..., :1])], dim=-1) 196 | h2 = torch.cat([x2, torch.ones_like(x2[..., :1])], dim=-1) 197 | d1 = torch.matmul(h1, F.transpose(-1, -2)) # (B, N, 3) 198 | d2 = torch.matmul(h2, F) # (B, N, 3) 199 | z = (h2 * d1).sum(dim=-1) # (B, N) 200 | err = (z**2) / ( 201 | d1[..., 0] ** 2 + d1[..., 1] ** 2 + d2[..., 0] ** 2 + d2[..., 1] ** 2 202 | ) 203 | return err 204 | 205 | 206 | def get_outnorm(x: torch.Tensor, out_norm: str = "") -> torch.Tensor: 207 | """Common function to get a loss normalization value. Can 208 | normalize by either the batch size ('b'), the number of 209 | channels ('c'), the image size ('i') or combinations 210 | ('bi', 'bci', etc) 211 | """ 212 | # b, c, h, w = x.size() 213 | img_shape = x.shape 214 | 215 | if not out_norm: 216 | return 1 217 | 218 | norm = 1 219 | if "b" in out_norm: 220 | # normalize by batch size 221 | # norm /= b 222 | norm /= img_shape[0] 223 | if "c" in out_norm: 224 | # normalize by the number of channels 225 | # norm /= c 226 | norm /= img_shape[1] 227 | if "i" in out_norm: 228 | # normalize by image/map size 229 | # norm /= h*w 230 | norm /= img_shape[-1] * img_shape[-2] 231 | 232 | return norm 233 | 234 | 235 | class CharbonnierLoss(nn.Module): 236 | """Charbonnier Loss (L1)""" 237 | 238 | def __init__(self, eps=1e-6, out_norm: str = "bci"): 239 | super(CharbonnierLoss, self).__init__() 240 | self.eps = eps 241 | self.out_norm = out_norm 242 | 243 | def forward(self, x, y, weight=None): 244 | norm = get_outnorm(x, self.out_norm) 245 | if weight is None: 246 | loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2)) 247 | else: 248 | loss = torch.sum(weight * torch.sqrt((x - y).pow(2) + self.eps**2)) 249 | return loss * norm 250 | -------------------------------------------------------------------------------- /src/utils/point_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from typing import Optional 13 | from dataclasses import dataclass 14 | import numpy as np 15 | 16 | 17 | @dataclass 18 | class BasicPointCloud: 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | time: Optional[np.array] 23 | 24 | 25 | def uniform_sample(pcd: BasicPointCloud, ratio: float): 26 | assert ratio <= 1.0 27 | num_points = len(pcd.points) 28 | num_sample = int(num_points * ratio) 29 | point_idx = np.random.choice(num_points, num_sample, replace=False) 30 | 31 | return BasicPointCloud( 32 | points=pcd.points[point_idx], 33 | colors=pcd.colors[point_idx], 34 | normals=pcd.normals[point_idx], 35 | time=pcd.time[point_idx] if not pcd.time is None else None, 36 | ) 37 | 38 | 39 | def merge_pcds(pcds): 40 | 41 | points, colors, normals, time = [], [], [], [] 42 | 43 | for pcd in pcds: 44 | points.append(pcd.points) 45 | colors.append(pcd.colors) 46 | normals.append(pcd.normals) 47 | time.append(pcd.time) 48 | 49 | return BasicPointCloud( 50 | points=np.concatenate(points), 51 | colors=np.concatenate(colors), 52 | normals=np.concatenate(normals), 53 | time=np.concatenate(time) if not time[0] is None else None, 54 | ) 55 | -------------------------------------------------------------------------------- /src/utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
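# Editor's note (added comment, not in the original file): the numeric constants
# that follow (C0-C4) are the standard real spherical-harmonics basis coefficients
# for degrees 0 through 4, e.g. C0 = 1 / (2 * sqrt(pi)) and C1 = sqrt(3 / (4 * pi));
# eval_sh() further down combines them with the input SH coefficients to evaluate
# the basis at the given unit directions.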
23 | 24 | C0 = 0.28209479177387814 25 | C1 = 0.4886025119029199 26 | C2 = [ 27 | 1.0925484305920792, 28 | -1.0925484305920792, 29 | 0.31539156525252005, 30 | -1.0925484305920792, 31 | 0.5462742152960396, 32 | ] 33 | C3 = [ 34 | -0.5900435899266435, 35 | 2.890611442640554, 36 | -0.4570457994644658, 37 | 0.3731763325901154, 38 | -0.4570457994644658, 39 | 1.445305721320277, 40 | -0.5900435899266435, 41 | ] 42 | C4 = [ 43 | 2.5033429417967046, 44 | -1.7701307697799304, 45 | 0.9461746957575601, 46 | -0.6690465435572892, 47 | 0.10578554691520431, 48 | -0.6690465435572892, 49 | 0.47308734787878004, 50 | -1.7701307697799304, 51 | 0.6258357354491761, 52 | ] 53 | 54 | 55 | def eval_sh(deg, sh, dirs): 56 | """ 57 | Evaluate spherical harmonics at unit directions 58 | using hardcoded SH polynomials. 59 | Works with torch/np/jnp. 60 | ... Can be 0 or more batch dimensions. 61 | Args: 62 | deg: int SH deg. Currently, 0-3 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | Returns: 66 | [..., C] 67 | """ 68 | assert deg <= 4 and deg >= 0 69 | coeff = (deg + 1) ** 2 70 | assert sh.shape[-1] >= coeff 71 | 72 | result = C0 * sh[..., 0] 73 | if deg > 0: 74 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 75 | result = ( 76 | result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 77 | ) 78 | 79 | if deg > 1: 80 | xx, yy, zz = x * x, y * y, z * z 81 | xy, yz, xz = x * y, y * z, x * z 82 | result = ( 83 | result 84 | + C2[0] * xy * sh[..., 4] 85 | + C2[1] * yz * sh[..., 5] 86 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 87 | + C2[3] * xz * sh[..., 7] 88 | + C2[4] * (xx - yy) * sh[..., 8] 89 | ) 90 | 91 | if deg > 2: 92 | result = ( 93 | result 94 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 95 | + C3[1] * xy * z * sh[..., 10] 96 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 97 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 98 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 99 | + C3[5] * z * (xx - yy) * sh[..., 14] 100 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 101 | ) 102 | 103 | if deg > 3: 104 | result = ( 105 | result 106 | + C4[0] * xy * (xx - yy) * sh[..., 16] 107 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 108 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 109 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 110 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 111 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 112 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 113 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 114 | + C4[8] 115 | * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 116 | * sh[..., 24] 117 | ) 118 | return result 119 | 120 | 121 | def RGB2SH(rgb): 122 | return (rgb - 0.5) / C0 123 | 124 | 125 | def SH2RGB(sh): 126 | return sh * C0 + 0.5 127 | -------------------------------------------------------------------------------- /src/utils/store_utils.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | 3 | # Copyright (c) 2024 Yoonwoo Jeong, Junmyeong Lee, Hoseung Choi, and Minsu Cho (POSTECH) 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | # 
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | 11 | from pathlib import Path 12 | from typing import Union 13 | from abc import ABC, abstractmethod 14 | 15 | import numpy as np 16 | import torch 17 | import cv2 18 | 19 | 20 | def load_depth_from_pgm(pgm_file_path, min_depth_threshold=None): 21 | """ 22 | Load the depth map from PGM file 23 | :param pgm_file_path: pgm file path 24 | :return: depth map with 3D ND-array 25 | """ 26 | raw_img = None 27 | with open(pgm_file_path, "rb") as f: 28 | line = str(f.readline(), encoding="ascii") 29 | if line != "P5\n": 30 | print("Error loading pgm, format error\n") 31 | 32 | line = str(f.readline(), encoding="ascii") 33 | max_depth = float(line.split(" ")[-1].strip()) 34 | 35 | line = str(f.readline(), encoding="ascii") 36 | dims = line.split(" ") 37 | cols = int(dims[0].strip()) 38 | rows = int(dims[1].strip()) 39 | 40 | line = str(f.readline(), encoding="ascii") 41 | max_factor = float(line.strip()) 42 | 43 | raw_img = ( 44 | np.frombuffer( 45 | f.read(cols * rows * np.dtype(np.uint16).itemsize), dtype=np.uint16 46 | ) 47 | .reshape((rows, cols)) 48 | .astype(np.float32) 49 | ) 50 | raw_img *= max_depth / max_factor 51 | 52 | if min_depth_threshold is not None: 53 | raw_img[raw_img < min_depth_threshold] = min_depth_threshold 54 | 55 | return np.expand_dims(raw_img, axis=0) 56 | 57 | 58 | def save_depth_to_pgm(depth, pgm_file_path): 59 | """ 60 | Save the depth map to PGM file 61 | :param depth: depth map with 3D ND-array, 1XHXW 62 | :param pgm_file_path: output file path 63 | """ 64 | depth_flatten = depth[0] 65 | max_depth = np.max(depth_flatten) 66 | depth_copy = np.copy(depth_flatten) 67 | depth_copy = 65535.0 * (depth_copy / max_depth) 68 | depth_copy = depth_copy.astype(np.uint16) 69 | 70 | with open(pgm_file_path, "wb") as f: 71 | f.write(bytes("P5\n", encoding="ascii")) 72 | f.write(bytes("# %f\n" % max_depth, encoding="ascii")) 73 | f.write( 74 | bytes( 75 | "%d %d\n" % (depth_flatten.shape[1], depth_flatten.shape[0]), 76 | encoding="ascii", 77 | ) 78 | ) 79 | f.write(bytes("65535\n", encoding="ascii")) 80 | f.write(depth_copy.tobytes()) 81 | 82 | 83 | class Storer(ABC): 84 | """ 85 | Abstract base class for storing tensor data with validation. Subclasses 86 | must implement the `sanity_check` method to validate tensors and the `store` 87 | method to define how tensors are stored. 88 | """ 89 | 90 | sanity_check = None 91 | 92 | def __init__(self, path: Path): 93 | self.path = path 94 | self.path.mkdir(exist_ok=True) 95 | 96 | def to_cv2(self, tensor: torch.Tensor): 97 | np_arr = np.transpose(tensor.clamp(0, 1).cpu().detach().numpy(), (1, 2, 0)) 98 | np_arr = np_arr[..., ::-1] 99 | # cv2_format = (np_arr * 65536).astype(np.uint16) 100 | cv2_format = (np_arr * 65535).astype(np.uint16) 101 | return cv2_format 102 | 103 | @abstractmethod 104 | def sanity_check(self, tensor: torch.Tensor) -> bool: 105 | """Abstract method for sanity checking the tensor. 
106 | Subclasses should implement this method to define their specific checks 107 | """ 108 | raise NotImplementedError 109 | 110 | @abstractmethod 111 | def store(self, image_name, tensor: torch.Tensor) -> None: 112 | """Abstract method for storing the tensor. 113 | Subclasses should implement this method to define how the tensor is stored 114 | """ 115 | raise NotImplementedError 116 | 117 | def __call__(self, image_name, tensor: torch.Tensor): 118 | if not self.sanity_check(tensor): 119 | raise ValueError( 120 | f"Sanity check failed for tensor with shape {tensor.shape}" 121 | ) 122 | self.store(image_name, tensor) 123 | 124 | 125 | class RGBStorer(Storer): 126 | 127 | def sanity_check(self, tensor: torch.Tensor) -> bool: 128 | return tensor.ndim == 3 and tensor.shape[0] == 3 129 | 130 | def store(self, image_name: str, tensor: torch.Tensor) -> None: 131 | cv2_image = self.to_cv2(tensor.clamp(0, 1)) 132 | cv2.imwrite(self.path.joinpath(image_name).as_posix(), cv2_image) 133 | 134 | 135 | class AssetStorer: 136 | 137 | def __init__( 138 | self, 139 | out_path: Path, 140 | ): 141 | self.out_path = out_path 142 | out_path.mkdir(exist_ok=True) 143 | 144 | self.viz_storer = RGBStorer(out_path.joinpath("viz")) 145 | 146 | def __call__( 147 | self, 148 | image_name: str, 149 | viz_tensor: Union[torch.Tensor], # 3 X H X W 150 | ): 151 | self.viz_storer(image_name, viz_tensor) 152 | --------------------------------------------------------------------------------
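The utilities above are plain functions and can be exercised in isolation. Below is a minimal sketch (not part of the repository) of computing SSIM and LPIPS between a rendered and a ground-truth image. The import paths are inferred from the directory tree (`src/utils/loss_utils.py`, `src/utils/eval_utils.py`), the random tensors are placeholders, and the [-1, 1] rescaling for LPIPS is an assumption based on the normalization buffers in `BaseNet`.
```python
import torch

from src.utils.loss_utils import ssim
from src.utils.eval_utils import lpips  # module name inferred from the tree above

# Placeholder images in [0, 1], shaped (B, 3, H, W).
render = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)

ssim_score = ssim(render, target)  # scalar tensor, higher is better

# LPIPS networks are conventionally fed inputs in [-1, 1] (see BaseNet.mean/std),
# so rescale from [0, 1] first; lower is better. Weights are downloaded on first use.
lpips_score = lpips(render * 2 - 1, target * 2 - 1)

print(f"SSIM: {ssim_score.item():.4f}  LPIPS: {lpips_score.item():.4f}")
```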
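A similar sketch for the log-linear learning-rate schedule returned by `get_expon_lr_func`; the Adam optimizer, the dummy parameter, and the 1.6e-4 to 1.6e-6 range are illustrative values, not settings taken from the training configs.
```python
import torch

from src.utils.general_utils import get_expon_lr_func

# Dummy parameter standing in for one optimized attribute (e.g. positions).
xyz = torch.nn.Parameter(torch.zeros(100, 3))
optimizer = torch.optim.Adam([xyz], lr=1.6e-4)

# Maps step -> learning rate, interpolating log-linearly from lr_init (step 0)
# to lr_final (step max_steps).
lr_schedule = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)

for step in range(30_000):
    for group in optimizer.param_groups:
        group["lr"] = lr_schedule(step)
    loss = (xyz**2).mean()  # placeholder objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```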
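And a sketch of the epipolar helpers in `src/utils/loss_utils.py`: building batched intrinsics, deriving a fundamental matrix from two world-to-camera poses, and scoring matched pixels with the Sampson error. All numbers here (focal length, image size, the small translation, the random matches) are made-up placeholders.
```python
import torch

from src.utils.loss_utils import (
    construct_intrinsics,
    compute_fundamental_matrix,
    compute_sampson_error,
)

B, N = 2, 100
K = construct_intrinsics(focal=500.0, image_width=640, image_height=480, batch_size=B)

# Two world-to-camera poses; the second camera is shifted slightly along x so
# the relative translation (and hence the fundamental matrix) is non-degenerate.
w2c1 = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
w2c2 = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
w2c2[:, 0, 3] = 0.1

F_mat = compute_fundamental_matrix(K, w2c1, w2c2)  # (B, 3, 3)

# Random pixel "matches" in each view; real callers would use tracked correspondences.
x1 = torch.rand(B, N, 2) * torch.tensor([640.0, 480.0])
x2 = torch.rand(B, N, 2) * torch.tensor([640.0, 480.0])

err = compute_sampson_error(x1, x2, F_mat)  # (B, N) first-order epipolar residuals
print(err.mean())
```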