├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments ├── __init__.py └── nvidia_rodynrf │ ├── Balloon1.py │ ├── Balloon2.py │ ├── Jumping.py │ ├── Playground.py │ ├── Skating.py │ ├── Truck.py │ ├── Umbrella.py │ └── default.py ├── assets ├── architecture.png └── icon.png ├── dycheck_geometry ├── __init__.py ├── barf_se3.py ├── camera.py ├── se3.py ├── trajs.py └── utils.py ├── eval.sh ├── eval_nvidia.py ├── gaussian_renderer └── __init__.py ├── gen_depth.py ├── gen_depth.sh ├── gen_tracks.py ├── gen_tracks.sh ├── install.sh ├── requirements.txt ├── requirements_unidepth.txt ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset.py ├── dataset_readers.py ├── deformation.py └── gaussian_model.py ├── train.py ├── train.sh └── utils ├── TIMES.TTF ├── TIMESBD.TTF ├── TIMESBI.TTF ├── TIMESI.TTF ├── camera_utils.py ├── depth_loss_utils.py ├── dycheck_utils ├── __init__.py ├── annotation.py ├── common.py ├── flax_multioptim.py ├── image.py ├── io.py ├── path_ops.py ├── safe_ops.py ├── struct.py ├── types.py └── visuals │ ├── __init__.py │ ├── corrs.py │ ├── depth.py │ ├── flow.py │ ├── kps │ ├── __init__.py │ └── skeleton.py │ ├── plotly.py │ └── rendering.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loader_utils.py ├── loss_utils.py ├── main_utils.py ├── model_utils.py ├── params_utils.py ├── point_utils.py ├── pose_utils.py ├── render_utils.py ├── scene_utils.py ├── sh_utils.py ├── system_utils.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | .idea 4 | output 5 | output_1 6 | build 7 | diff_rasterization/diff_rast.egg-info 8 | diff_rasterization/dist 9 | tensorboard_3d 10 | screenshots 11 | data/ 12 | data 13 | extensions 14 | lab4d 15 | preprocess -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/co-tracker"] 5 | path = submodules/co-tracker 6 | url = https://github.com/facebookresearch/co-tracker.git 7 | [submodule "submodules/UniDepth"] 8 | path = submodules/UniDepth 9 | url = https://github.com/lpiccinelli-eth/UniDepth.git 10 | [submodule "submodules/mega-sam"] 11 | path = submodules/mega-sam 12 | url = https://github.com/mega-sam/mega-sam.git 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Minh-Quan Viet Bui 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

[CVPR'25] SplineGS: Robust Motion-Adaptive Spline for Real-Time Dynamic 3D Gaussians from Monocular Video

2 |
3 | 4 | **[Jongmin Park](https://sites.google.com/view/jongmin-park)1\*, [Minh-Quan Viet Bui](https://quan5609.github.io/)1\*, [Juan Luis Gonzalez Bello](https://sites.google.com/view/juan-luis-gb/home)1, [Jaeho Moon](https://sites.google.com/view/jaehomoon)1, [Jihyong Oh](https://cmlab.cau.ac.kr/)2†, [Munchurl Kim](https://www.viclab.kaist.ac.kr/)1†** 5 |
6 | 1KAIST, South Korea, 2Chung-Ang University, South Korea 7 |
8 | \*Co-first authors (equal contribution), †Co-corresponding authors 9 |

10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | GitHub Repo stars 20 |

21 | 22 |

23 | 24 |

25 | 26 | ## 📣 News 27 | ### Updates 28 | - **May 26, 2025**: Code released. 29 | - **February 26, 2025**: SplineGS accepted to CVPR 2025 🎉. 30 | - **December 13, 2024**: Paper uploaded to arXiv. Check out the manuscript [here](https://arxiv.org/abs/2412.09982). 31 | ### To-Dos 32 | - Add DAVIS dataset configurations. 33 | - Add custom dataset support. 34 | - Add iPhone dataset configurations. 35 | ## ⚙️ Environmental Setups 36 | Clone the repo and install dependencies: 37 | ```sh 38 | git clone https://github.com/KAIST-VICLab/SplineGS.git --recursive 39 | cd SplineGS 40 | 41 | # install splinegs environment 42 | conda create -n splinegs python=3.7 43 | conda activate splinegs 44 | export CUDA_HOME=$CONDA_PREFIX 45 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib 46 | 47 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 48 | conda install nvidia/label/cuda-11.7.0::cuda 49 | conda install nvidia/label/cuda-11.7.0::cuda-nvcc 50 | conda install nvidia/label/cuda-11.7.0::cuda-runtime 51 | conda install nvidia/label/cuda-11.7.0::cuda-cudart 52 | 53 | 54 | pip install -e submodules/simple-knn 55 | pip install -e submodules/co-tracker 56 | pip install -r requirements.txt 57 | 58 | # install depth environment 59 | conda deactivate 60 | conda create -n unidepth_splinegs python=3.10 61 | conda activate unidepth_splinegs 62 | 63 | pip install -r requirements_unidepth.txt 64 | conda install -c conda-forge ld_impl_linux-64 65 | export CUDA_HOME=$CONDA_PREFIX 66 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib 67 | conda install nvidia/label/cuda-12.1.0::cuda 68 | conda install nvidia/label/cuda-12.1.0::cuda-nvcc 69 | conda install nvidia/label/cuda-12.1.0::cuda-runtime 70 | conda install nvidia/label/cuda-12.1.0::cuda-cudart 71 | conda install nvidia/label/cuda-12.1.0::libcusparse 72 | conda install nvidia/label/cuda-12.1.0::libcublas 73 | cd submodules/UniDepth/unidepth/ops/knn;bash compile.sh;cd ../../../../../ 74 | cd submodules/UniDepth/unidepth/ops/extract_patches;bash compile.sh;cd ../../../../../ 75 | 76 | pip install -e submodules/UniDepth 77 | mkdir -p submodules/mega-sam/Depth-Anything/checkpoints 78 | ``` 79 | ## 📁 Data Preparations 80 | ### Nvidia Dataset 81 | 1. We follow the evaluation setup from [RoDynRF](https://robust-dynrf.github.io/). Download the training images [here](https://github.com/KAIST-VICLab/SplineGS/releases/tag/dataset) and arrange them as follows: 82 | ```bash 83 | SplineGS/data/nvidia_rodynrf 84 | ├── Balloon1 85 | │ ├── images_2 86 | │ ├── instance_masks 87 | │ ├── motion_masks 88 | │ └── gt 89 | ├── ... 90 | └── Umbrella 91 | ``` 92 | 2. Download [Depth-Anything checkpoint](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitl14.pth) and place it at `submodules/mega-sam/Depth-Anything/checkpoints`. Generate depth estimation and tracking results for all scenes as: 93 | ```sh 94 | conda activate unidepth_splinegs 95 | bash gen_depth.sh 96 | 97 | conda deactivate 98 | conda activate splinegs 99 | bash gen_tracks.sh 100 | ``` 101 | 3. To obtain motion masks, please refer to [Shape of Motion](https://github.com/vye16/shape-of-motion/).
For the Nvidia dataset, we provide precomputed motion masks in the `motion_masks` folder. 102 | ### YOUR OWN Dataset 103 | T.B.D 104 | ## 🚀 Get Started 105 | ### Nvidia Dataset 106 | #### Training 107 | ```sh 108 | # check if environment is activated properly 109 | conda activate splinegs 110 | 111 | python train.py -s data/nvidia_rodynrf/${SCENE}/ --expname "${EXP_NAME}" --configs arguments/nvidia_rodynrf/${SCENE}.py 112 | ``` 113 | #### Metrics Evaluation 114 | ```sh 115 | python eval_nvidia.py -s data/nvidia_rodynrf/${SCENE}/ --expname "${EXP_NAME}" --configs arguments/nvidia_rodynrf/${SCENE}.py --checkpoint output/${EXP_NAME}/point_cloud/fine_best 116 | ``` 117 | ### YOUR OWN Dataset 118 | #### Training 119 | T.B.D 120 | #### Evaluation 121 | T.B.D 122 | 123 | ## Acknowledgments 124 | - This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korean Government [Ministry of Science and ICT (Information and Communications Technology)] (Project Number: RS-2022-00144444, Project Title: Deep Learning Based Visual Representational Learning and Rendering of Static and Dynamic Scenes, 100%). 125 | 126 | ## ⭐ Citing SplineGS 127 | 128 | If you find our repository useful, please consider giving it a star ⭐ and citing our research papers in your work: 129 | ```bibtex 130 | @InProceedings{Park_2025_CVPR, 131 | author = {Park, Jongmin and Bui, Minh-Quan Viet and Bello, Juan Luis Gonzalez and Moon, Jaeho and Oh, Jihyong and Kim, Munchurl}, 132 | title = {SplineGS: Robust Motion-Adaptive Spline for Real-Time Dynamic 3D Gaussians from Monocular Video}, 133 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, 134 | month = {June}, 135 | year = {2025}, 136 | pages = {26866-26875} 137 | } 138 | ``` 139 | 140 | 141 | ## 📈 Star History 142 | 143 | [![Star History Chart](https://api.star-history.com/svg?repos=KAIST-VICLab/SplineGS&type=Date)](https://www.star-history.com/#KAIST-VICLab/SplineGS&Date) 144 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from argparse import ArgumentParser, Namespace 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | else: 35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | 50 | class ModelParams(ParamGroup): 51 | def __init__(self, parser, sentinel=False): 52 | self.sh_degree = 3 53 | self.deform_spatial_scale = 1e-2 54 | self.rgbfuntion = "sandwich" 55 | self.control_num = 12 56 | self.prune_error_threshold = 1.0 57 | 58 | self._source_path = "" 59 | self.dataset_type = "nvidia" 60 | self.depth_type = "depth" # depth or disp 61 | self._model_path = "" 62 | self._images = "images" 63 | self._resolution = -1 64 | self._white_background = False 65 | self.data_device = "cuda" 66 | self.eval = True 67 | self.render_process = True 68 | self.debug_process = True 69 | self.add_points = False 70 | self.extension = ".png" 71 | self.llffhold = 8 72 | super().__init__(parser, "Loading Parameters", sentinel) 73 | 74 | def extract(self, args): 75 | g = super().extract(args) 76 | g.source_path = os.path.abspath(g.source_path) 77 | return g 78 | 79 | 80 | class PipelineParams(ParamGroup): 81 | def __init__(self, parser): 82 | self.convert_SHs_python = False 83 | self.compute_cov3D_python = False 84 | self.debug = False 85 | super().__init__(parser, "Pipeline Parameters") 86 | 87 | 88 | class ModelHiddenParams(ParamGroup): 89 | def __init__(self, parser): 90 | self.timebase_pe = 10 91 | self.timenet_width = 256 92 | self.timenet_output = 6 93 | self.pixel_base_pe = 5 94 | super().__init__(parser, "ModelHiddenParams") 95 | 96 | 97 | class OptimizationParams(ParamGroup): 98 | def __init__(self, parser): 99 | self.dataloader = False 100 | self.zerostamp_init = False 101 | self.custom_sampler = None 102 | self.iterations = 30_000 103 | self.coarse_iterations = 1000 104 | self.static_iterations = 1000 105 | 106 | self.position_lr_init = 0.00016 107 | self.position_lr_final = 0.0000016 108 | self.position_lr_delay_mult = 0.01 109 | self.position_lr_max_steps = 20_000 110 | 111 | self.deformation_lr_init = 0.00016 112 | self.deformation_lr_final = 0.000016 113 | self.deformation_lr_delay_mult = 0.01 114 | 115 | self.grid_lr_init = 0.0016 116 | self.grid_lr_final = 0.00016 117 | 118 | self.pose_lr_init = 0.0005 119 | self.pose_lr_final = 0.00005 120 | self.pose_lr_delay_mult = 0.01 121 | 122 | self.feature_lr = 0.0025 123 | self.featuret_lr = 0.001 124 | self.opacity_lr = 0.05 125 | self.scaling_lr = 0.005 126 | self.rotation_lr = 0.001 127 | self.percent_dense = 0.01 128 | self.lambda_dssim = 0.2 129 | self.p_lambda_dssim = 0.0 130 | 
self.lambda_lpips = 0 131 | self.weight_constraint_init = 1 132 | self.weight_constraint_after = 0.2 133 | self.weight_decay_iteration = 5_000 134 | self.opacity_reset_interval = 3_000 135 | self.densification_interval = 1_00 136 | self.densify_from_iter = 500 137 | self.densify_until_iter = 15_000 138 | self.densify_grad_threshold_coarse = 0.0002 139 | self.densify_grad_threshold_fine_init = 0.0002 140 | self.densify_grad_threshold_after = 0.0002 141 | self.pruning_from_iter = 500 142 | self.pruning_interval = 100 143 | self.opacity_threshold_coarse = 0.005 144 | self.opacity_threshold_fine_init = 0.005 145 | self.opacity_threshold_fine_after = 0.005 146 | self.fine_batch_size = 1 147 | self.coarse_batch_size = 1 148 | self.add_point = False 149 | self.use_instance_mask = False 150 | 151 | self.prevpath = "1" 152 | self.opthr = 0.005 153 | self.desicnt = 6 154 | self.densify = 1 155 | self.densify_grad_threshold = 0.0008 156 | self.densify_grad_threshold_dynamic = 0.00008 157 | self.preprocesspoints = 0 158 | self.addsphpointsscale = 0.8 159 | self.raystart = 0.7 160 | 161 | self.soft_depth_start = 1000 162 | self.hard_depth_start = 0 163 | self.error_tolerance = 0.001 164 | 165 | self.trbfc_lr = 0.0001 # 166 | self.trbfs_lr = 0.03 167 | self.trbfslinit = 0.0 # 168 | self.omega_lr = 0.0001 169 | self.zeta_lr = 0.0001 170 | self.movelr = 3.5 171 | self.rgb_lr = 0.0001 172 | 173 | self.stat_npts = 20000 174 | self.dyn_npts = 20000 175 | 176 | self.w_depth = 1.0 177 | self.w_mask = 2.0 178 | self.w_track = 1.0 179 | self.w_normal = 0 180 | super().__init__(parser, "Optimization Parameters") 181 | 182 | 183 | def get_combined_args(parser: ArgumentParser): 184 | cmdlne_string = sys.argv[1:] 185 | cfgfile_string = "Namespace()" 186 | args_cmdline = parser.parse_args(cmdlne_string) 187 | 188 | try: 189 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 190 | print("Looking for config file in", cfgfilepath) 191 | with open(cfgfilepath) as cfg_file: 192 | print("Config file found: {}".format(cfgfilepath)) 193 | cfgfile_string = cfg_file.read() 194 | except TypeError: 195 | print("Config file not found at") 196 | args_cfgfile = eval(cfgfile_string) 197 | 198 | merged_dict = vars(args_cfgfile).copy() 199 | for k, v in vars(args_cmdline).items(): 200 | if v is not None: 201 | merged_dict[k] = v 202 | return Namespace(**merged_dict) 203 | -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Balloon1.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | ModelParams = dict( 4 | depth_type="disp", 5 | ) 6 | 7 | OptimizationParams = dict( 8 | densify_grad_threshold_dynamic = 0.0002, 9 | densify_grad_threshold = 0.0008, 10 | use_instance_mask=True, 11 | ) -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Balloon2.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | ModelParams = dict( 4 | depth_type="disp", 5 | ) 6 | 7 | OptimizationParams = dict( 8 | densify_grad_threshold_dynamic = 0.0002, 9 | densify_grad_threshold = 0.0008, 10 | ) -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Jumping.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | OptimizationParams = dict( 4 | densify_grad_threshold = 0.0002, 5 | 
densify_grad_threshold_dynamic = 0.0002 6 | ) -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Playground.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | OptimizationParams = dict( 4 | densify_grad_threshold = 0.0002, 5 | densify_grad_threshold_dynamic = 0.0002, 6 | opacity_reset_interval=30_000, 7 | use_instance_mask=True, 8 | ) 9 | -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Skating.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | OptimizationParams = dict( 4 | densify_grad_threshold = 0.0002, 5 | densify_grad_threshold_dynamic = 0.00008, 6 | ) -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Truck.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | OptimizationParams = dict( 4 | densify_grad_threshold = 0.0008, 5 | densify_grad_threshold_dynamic = 0.0002 6 | ) -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/Umbrella.py: -------------------------------------------------------------------------------- 1 | _base_ = "./default.py" 2 | 3 | OptimizationParams = dict( 4 | w_normal=1.0, 5 | opacity_reset_interval=30_000, 6 | densify_grad_threshold = 0.0002, 7 | densify_grad_threshold_dynamic = 0.00002, 8 | ) 9 | -------------------------------------------------------------------------------- /arguments/nvidia_rodynrf/default.py: -------------------------------------------------------------------------------- 1 | OptimizationParams = dict( 2 | iterations=25_000, 3 | coarse_batch_size=2, 4 | fine_batch_size=2, 5 | coarse_iterations=5_000, 6 | static_iterations=5_000, 7 | densify_from_iter=500, 8 | densify_until_iter=12000, 9 | opacity_reset_interval=3_000, 10 | densify=1, 11 | stat_npts=20000, 12 | dyn_npts=10000, 13 | desicnt=12, 14 | ) -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/assets/architecture.png -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/assets/icon.png -------------------------------------------------------------------------------- /dycheck_geometry/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | from .camera import Camera, project, get_rays_direction 20 | from .se3 import ( 21 | exp_se3, 22 | exp_so3, 23 | from_homogenous, 24 | rt_to_se3, 25 | skew, 26 | to_homogenous, 27 | ) 28 | from .trajs import get_arc_traj, get_lemniscate_traj 29 | from .utils import matmul, matv -------------------------------------------------------------------------------- /dycheck_geometry/barf_se3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : se3.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | 21 | import jax 22 | from jax import numpy as np 23 | 24 | 25 | # def func_A(x): 26 | # return np.sin(x)/x 27 | 28 | # def func_B(x): 29 | # return (1-np.cos(x))/x**2 30 | 31 | # def func_C(x): 32 | # return (x-np.sin(x))/x**3 33 | 34 | # # a recursive definition shares some work 35 | # def taylor(f, order): 36 | # def improve_approx(g, k): 37 | # return lambda x, v: jvp_first(g, (x, v), v)[1] + f(x) / factorial(k) 38 | # approx = lambda x, v: f(x) / factorial(order) 39 | # for n in range(order): 40 | # approx = improve_approx(approx, order - n - 1) 41 | # return approx 42 | 43 | # def jvp_first(f, primals, tangent): 44 | # x, xs = primals[0], primals[1:] 45 | # return jvp(lambda x: f(x, *xs), (x,), (tangent,)) 46 | 47 | 48 | def procrustes_analysis(X0, X1): # [N,3] X0 is target X1 is src 49 | # translation 50 | t0 = X0.mean(axis=0, keepdims=True) 51 | t1 = X1.mean(axis=0, keepdims=True) 52 | X0c = X0 - t0 53 | X1c = X1 - t1 54 | # scale 55 | s0 = np.sqrt((X0c**2).sum(axis=-1).mean()) 56 | s1 = np.sqrt((X1c**2).sum(axis=-1).mean()) 57 | X0cs = X0c / s0 58 | X1cs = X1c / s1 59 | # rotation (use double for SVD, float loses precision) 60 | 61 | U, S, V = np.linalg.svd((X0cs.T @ X1cs), full_matrices=False) 62 | R = U @ V.T 63 | 64 | if np.linalg.det(R) < 0: 65 | R = R.at[2].set(-R[2]) 66 | sim3 = edict(t0=t0[0], t1=t1[0], s0=s0, s1=s1, R=R) 67 | return sim3 68 | 69 | 70 | @jax.jit 71 | def skew(w: np.ndarray) -> np.ndarray: 72 | """Build a skew matrix ("cross product matrix") for vector w. 73 | Modern Robotics Eqn 3.30. 
74 | 75 | Args: 76 | w: (..., 3,) A 3-vector 77 | 78 | Returns: 79 | W: (..., 3, 3) A skew matrix such that W @ v == w x v 80 | """ 81 | zeros = np.zeros_like(w[..., 0]) 82 | return np.stack( 83 | [ 84 | np.stack([zeros, -w[..., 2], w[..., 1]], axis=-1), 85 | np.stack([w[..., 2], zeros, -w[..., 0]], axis=-1), 86 | np.stack([-w[..., 1], w[..., 0], zeros], axis=-1), 87 | ], 88 | axis=-2, 89 | ) 90 | 91 | 92 | @jax.jit 93 | def taylor_A(x, nth=10): 94 | # Taylor expansion of sin(x)/x 95 | ans = np.zeros_like(x) 96 | denom = 1.0 97 | for i in range(nth + 1): 98 | if i > 0: 99 | denom *= (2 * i) * (2 * i + 1) 100 | ans = ans + (-1) ** i * x ** (2 * i) / denom 101 | return ans 102 | 103 | 104 | @jax.jit 105 | def taylor_B(x, nth=10): 106 | # Taylor expansion of (1-cos(x))/x**2 107 | ans = np.zeros_like(x) 108 | denom = 1.0 109 | for i in range(nth + 1): 110 | denom *= (2 * i + 1) * (2 * i + 2) 111 | ans = ans + (-1) ** i * x ** (2 * i) / denom 112 | return ans 113 | 114 | 115 | @jax.jit 116 | def taylor_C(x, nth=10): 117 | # Taylor expansion of (x-sin(x))/x**3 118 | ans = np.zeros_like(x) 119 | denom = 1.0 120 | for i in range(nth + 1): 121 | denom *= (2 * i + 2) * (2 * i + 3) 122 | ans = ans + (-1) ** i * x ** (2 * i) / denom 123 | return ans 124 | 125 | 126 | def se3_to_SE3(w: np.ndarray, u: np.ndarray) -> np.ndarray: 127 | wx = skew(w) 128 | theta = np.linalg.norm(w, axis=-1)[..., None, None] 129 | I = np.identity(3, dtype=w.dtype) 130 | 131 | A = taylor_A(theta) 132 | B = taylor_B(theta) 133 | C = taylor_C(theta) 134 | 135 | # check nerfies: R is e^r, V is G, u is v, wx is [r] 136 | 137 | R = I + A * wx + B * wx @ wx 138 | V = I + B * wx + C * wx @ wx 139 | 140 | t = V @ u[..., None] 141 | return R, t 142 | -------------------------------------------------------------------------------- /dycheck_geometry/se3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : se3.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import numpy as np 21 | 22 | from . import utils 23 | 24 | 25 | def skew(w: np.ndarray) -> np.ndarray: 26 | """Build a skew matrix ("cross product matrix") for vector w. 27 | Modern Robotics Eqn 3.30. 28 | 29 | Args: 30 | w: (..., 3,) A 3-vector 31 | 32 | Returns: 33 | W: (..., 3, 3) A skew matrix such that W @ v == w x v 34 | """ 35 | zeros = np.zeros_like(w[..., 0]) 36 | return np.stack( 37 | [ 38 | np.stack([zeros, -w[..., 2], w[..., 1]], axis=-1), 39 | np.stack([w[..., 2], zeros, -w[..., 0]], axis=-1), 40 | np.stack([-w[..., 1], w[..., 0], zeros], axis=-1), 41 | ], 42 | axis=-2, 43 | ) 44 | 45 | 46 | def rt_to_se3(R: np.ndarray, t: np.ndarray) -> np.ndarray: 47 | """Rotation and translation to homogeneous transform. 48 | 49 | Args: 50 | R: (..., 3, 3) An orthonormal rotation matrix. 
51 | t: (..., 3,) A 3-vector representing an offset. 52 | 53 | Returns: 54 | X: (..., 4, 4) The homogeneous transformation matrix described by 55 | rotating by R and translating by t. 56 | """ 57 | batch_shape = R.shape[:-2] 58 | return np.concatenate( 59 | [ 60 | np.concatenate([R, t[..., None]], axis=-1), 61 | np.broadcast_to(np.array([[0, 0, 0, 1]], np.float32), batch_shape + (1, 4)), 62 | ], 63 | axis=-2, 64 | ) 65 | 66 | 67 | def exp_so3(w: np.ndarray, theta: np.ndarray) -> np.ndarray: 68 | """Exponential map from Lie algebra so3 to Lie group SO3. 69 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 70 | 71 | Args: 72 | w: (..., 3,) An axis of rotation. This is assumed to be a unit-vector. 73 | theta (...,): An angle of rotation. 74 | 75 | Returns: 76 | R: (..., 3, 3) An orthonormal rotation matrix representing a rotation 77 | of magnitude theta about axis w. 78 | """ 79 | batch_shape = w.shape[:-1] 80 | W = skew(w) 81 | return ( 82 | np.broadcast_to(np.eye(3), batch_shape + (3, 3)) 83 | + np.sin(theta)[..., None, None] * W 84 | + (1 - np.cos(theta)[..., None, None]) * utils.matmul(W, W) 85 | ) 86 | 87 | 88 | def exp_se3(S: np.ndarray, theta: np.ndarray) -> np.ndarray: 89 | """Exponential map from Lie algebra so3 to Lie group SO3. 90 | 91 | Modern Robotics Eqn 3.88. 92 | 93 | Args: 94 | S: (..., 6,) A screw axis of motion. 95 | theta (...,): Magnitude of motion. 96 | 97 | Returns: 98 | a_X_b: (..., 4, 4) The homogeneous transformation matrix attained by 99 | integrating motion of magnitude theta about S for one second. 100 | """ 101 | batch_shape = S.shape[:-1] 102 | w, v = np.split(S, 2, axis=-1) 103 | W = skew(w) 104 | R = exp_so3(w, theta) 105 | t = utils.matv( 106 | ( 107 | theta[..., None, None] * np.broadcast_to(np.eye(3), batch_shape + (3, 3)) 108 | + (1 - np.cos(theta)[..., None, None]) * W 109 | + (theta[..., None, None] - np.sin(theta)[..., None, None]) * utils.matmul(W, W) 110 | ), 111 | v, 112 | ) 113 | return rt_to_se3(R, t) 114 | 115 | 116 | def to_homogenous(v): 117 | return np.concatenate([v, np.ones_like(v[..., :1])], axis=-1) 118 | 119 | 120 | def from_homogenous(v): 121 | return v[..., :3] / v[..., -1:] 122 | -------------------------------------------------------------------------------- /dycheck_geometry/trajs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : trajs.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 
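# Descriptive note on the two trajectory helpers defined below: get_arc_traj sweeps an arc spanning `degree` degrees about the `lookat` point around the `up` axis, while get_lemniscate_traj traces a lemniscate (figure-eight) path in the reference camera's frame; both return a list of Camera objects produced with ref_camera.lookat().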
19 | 20 | from typing import List 21 | 22 | import numpy as np 23 | from scipy.spatial.transform import Rotation 24 | 25 | from .camera import Camera 26 | from .utils import matv 27 | 28 | 29 | def get_arc_traj( 30 | ref_camera: Camera, 31 | lookat: np.ndarray, 32 | up: np.ndarray, 33 | *, 34 | num_frames: int, 35 | degree: float, 36 | **_, 37 | ) -> List[Camera]: 38 | positions = [ 39 | matv(Rotation.from_rotvec(d / 180 * np.pi * up).as_matrix() @ (ref_camera.position - lookat)) + lookat 40 | for d in np.linspace(-degree / 2, degree / 2, num_frames) 41 | ] 42 | cameras = [ref_camera.lookat(p, lookat, up) for p in positions] 43 | return cameras 44 | 45 | 46 | def get_lemniscate_traj( 47 | ref_camera: Camera, 48 | lookat: np.ndarray, 49 | up: np.ndarray, 50 | *, 51 | num_frames: int, 52 | degree: float, 53 | **_, 54 | ) -> List[Camera]: 55 | a = np.linalg.norm(ref_camera.position - lookat) * np.tan(degree / 360 * np.pi) 56 | # Lemniscate curve in camera space. Starting at the origin. 57 | positions = np.array( 58 | [ 59 | np.array( 60 | [ 61 | a * np.cos(t) / (1 + np.sin(t) ** 2), 62 | a * np.cos(t) * np.sin(t) / (1 + np.sin(t) ** 2), 63 | 0, 64 | ] 65 | ) 66 | for t in (np.linspace(0, 2 * np.pi, num_frames) + np.pi / 2) 67 | ] 68 | ) 69 | # Transform to world space. 70 | positions = matv(ref_camera.orientation.T, positions) + ref_camera.position 71 | cameras = [ref_camera.lookat(p, lookat, up) for p in positions] 72 | return cameras 73 | -------------------------------------------------------------------------------- /dycheck_geometry/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : utils.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import numpy as np 21 | from utils.dycheck_utils import types 22 | 23 | 24 | def matmul(a: types.Array, b: types.Array) -> types.Array: 25 | if isinstance(a, np.ndarray): 26 | assert isinstance(b, np.ndarray) 27 | else: 28 | assert isinstance(a, np.ndarray) 29 | assert isinstance(b, np.ndarray) 30 | 31 | if isinstance(a, np.ndarray): 32 | return a @ b 33 | else: 34 | # NOTE: The original implementation uses highest precision for TPU 35 | # computation. Since we are using GPUs only, comment it out. 
36 | # return np.matmul(a, b, precision=jax.lax.Precision.HIGHEST) 37 | return np.matmul(a, b) 38 | 39 | 40 | def matv(a: types.Array, b: types.Array) -> types.Array: 41 | return matmul(a, b[..., None])[..., 0] 42 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | python eval_nvidia.py -s data/nvidia_rodynrf/Balloon1/ --expname "Balloon1" --configs arguments/nvidia_rodynrf/Balloon1.py --checkpoint output/Balloon1/point_cloud/fine_best 2 | python eval_nvidia.py -s data/nvidia_rodynrf/Balloon2/ --expname "Balloon2" --configs arguments/nvidia_rodynrf/Balloon2.py --checkpoint output/Balloon2/point_cloud/fine_best 3 | python eval_nvidia.py -s data/nvidia_rodynrf/Jumping/ --expname "Jumping" --configs arguments/nvidia_rodynrf/Jumping.py --checkpoint output/Jumping/point_cloud/fine_best 4 | python eval_nvidia.py -s data/nvidia_rodynrf/Playground/ --expname "Playground" --configs arguments/nvidia_rodynrf/Playground.py --checkpoint output/Playground/point_cloud/fine_best 5 | python eval_nvidia.py -s data/nvidia_rodynrf/Skating/ --expname "Skating" --configs arguments/nvidia_rodynrf/Skating.py --checkpoint output/Skating/point_cloud/fine_best 6 | python eval_nvidia.py -s data/nvidia_rodynrf/Umbrella/ --expname "Umbrella" --configs arguments/nvidia_rodynrf/Umbrella.py --checkpoint output/Umbrella/point_cloud/fine_best 7 | -------------------------------------------------------------------------------- /eval_nvidia.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | sys.path.append(os.path.join(sys.path[0], "..")) 5 | import cv2 6 | import lpips 7 | import numpy as np 8 | from argparse import ArgumentParser 9 | from arguments import ModelHiddenParams, ModelParams, OptimizationParams, PipelineParams 10 | from gaussian_renderer import render_infer 11 | from PIL import Image 12 | from scene import GaussianModel, Scene, dataset_readers 13 | from utils.graphics_utils import pts2pixel 14 | from utils.main_utils import get_pixels 15 | from utils.image_utils import psnr 16 | from gsplat.rendering import fully_fused_projection 17 | from scene import GaussianModel, Scene, dataset_readers, deformation 18 | import random 19 | 20 | 21 | def normalize_image(img): 22 | return (2.0 * img - 1.0)[None, ...] 
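# normalize_image above rescales a [0, 1] image tensor to [-1, 1] and adds a batch dimension, which is the input range expected by the LPIPS (net="alex") metric instantiated in training_report below.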
23 | 24 | 25 | def training_report(scene: Scene, train_cams, test_cams, renderFunc, background, stage, dataset_type, path): 26 | test_psnr = 0.0 27 | torch.cuda.empty_cache() 28 | 29 | validation_configs = ({"name": "test", "cameras": test_cams}, {"name": "train", "cameras": train_cams}) 30 | lpips_loss = lpips.LPIPS(net="alex").cuda() 31 | 32 | start_event = torch.cuda.Event(enable_timing=True) 33 | end_event = torch.cuda.Event(enable_timing=True) 34 | 35 | for config in validation_configs: 36 | if config["cameras"] and len(config["cameras"]) > 0: 37 | l1_test = 0.0 38 | psnr_test = 0.0 39 | lpips_test = 0.0 40 | run_time = 0.0 41 | elapsed_time_ms_list = [] 42 | for idx, viewpoint in enumerate(config["cameras"]): 43 | 44 | if idx == 0: # warmup iter 45 | for _ in range(5): 46 | render_pkg = renderFunc( 47 | viewpoint, scene.stat_gaussians, scene.dyn_gaussians, background 48 | ) 49 | 50 | torch.cuda.synchronize() 51 | start_event.record() 52 | render_pkg = renderFunc( 53 | viewpoint, scene.stat_gaussians, scene.dyn_gaussians, background 54 | ) 55 | end_event.record() 56 | torch.cuda.synchronize() 57 | elapsed_time_ms = start_event.elapsed_time(end_event) 58 | elapsed_time_ms_list.append(elapsed_time_ms) 59 | run_time += elapsed_time_ms 60 | 61 | image = render_pkg["render"] 62 | image = torch.clamp(image, 0.0, 1.0) 63 | 64 | img = Image.fromarray( 65 | (np.clip(image.permute(1, 2, 0).detach().cpu().numpy(), 0, 1) * 255).astype("uint8") 66 | ) 67 | os.makedirs(path + "/{}".format(config["name"]), exist_ok=True) 68 | img.save(path + "/{}/img_{}.png".format(config["name"], idx)) 69 | 70 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 71 | 72 | psnr_test += psnr(image, gt_image, mask=None).mean().double() 73 | lpips_test += lpips_loss.forward(normalize_image(image), normalize_image(gt_image)).item() 74 | 75 | psnr_test /= len(config["cameras"]) 76 | l1_test /= len(config["cameras"]) 77 | lpips_test /= len(config["cameras"]) 78 | run_time /= len(config["cameras"]) 79 | 80 | print( 81 | "\n[ITER {}] Evaluating {}: PSNR {}, LPIPS {}, FPS {}".format( 82 | -1, config["name"], psnr_test, lpips_test, 1 / (run_time / 1000) 83 | ) 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = ArgumentParser(description="Training script parameters") 89 | lp = ModelParams(parser) 90 | op = OptimizationParams(parser) 91 | pp = PipelineParams(parser) 92 | hp = ModelHiddenParams(parser) 93 | 94 | parser.add_argument( 95 | "--checkpoint", type=str, required=True, help="Path to the checkpoint file", 96 | ) 97 | parser.add_argument("--expname", type=str, default="") 98 | parser.add_argument("--configs", type=str, default="") 99 | 100 | args = parser.parse_args(sys.argv[1:]) 101 | if args.configs: 102 | import mmengine as mmcv 103 | from utils.params_utils import merge_hparams 104 | 105 | config = mmcv.Config.fromfile(args.configs) 106 | args = merge_hparams(args, config) 107 | 108 | dataset = lp.extract(args) 109 | hyper = hp.extract(args) 110 | stat_gaussians = GaussianModel(dataset) 111 | dyn_gaussians = GaussianModel(dataset) 112 | opt = op.extract(args) 113 | 114 | scene = Scene( 115 | dataset, dyn_gaussians, stat_gaussians, load_coarse=None 116 | ) # for other datasets rather than iPhone dataset 117 | 118 | dyn_gaussians.create_pose_network(hyper, scene.getTrainCameras()) # pose network with instance scaling 119 | 120 | bg_color = [1] * 9 + [0] if dataset.white_background else [0] * 9 + [0] 121 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 122 | pipe = 
pp.extract(args) 123 | 124 | test_cams = scene.getTestCameras() 125 | train_cams = scene.getTrainCameras() 126 | my_test_cams = [i for i in test_cams] 127 | viewpoint_stack = [i for i in train_cams] 128 | 129 | # if os.path.exists(os.path.join(args.checkpoint, "compact_point_cloud.npz")): 130 | # if False: # TODO: remove this after training 131 | # dyn_gaussians.load_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud.ply")) 132 | # stat_gaussians.load_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud_static.ply")) 133 | # else: 134 | dyn_gaussians.load_ply(os.path.join(args.checkpoint, "point_cloud.ply")) 135 | stat_gaussians.load_ply(os.path.join(args.checkpoint, "point_cloud_static.ply")) 136 | 137 | dyn_gaussians.flatten_control_point() # TODO: support this saving in training 138 | stat_gaussians.save_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud_static.ply")) 139 | dyn_gaussians.save_ply_compact_dy(os.path.join(args.checkpoint, "compact_point_cloud.ply")) 140 | 141 | 142 | dyn_gaussians.load_model(args.checkpoint) 143 | dyn_gaussians._posenet.eval() 144 | 145 | 146 | pixels = get_pixels( 147 | scene.train_camera.dataset[0].metadata.image_size_x, 148 | scene.train_camera.dataset[0].metadata.image_size_y, 149 | use_center=True, 150 | ) 151 | if pixels.shape[-1] != 2: 152 | raise ValueError("The last dimension of pixels must be 2.") 153 | batch_shape = pixels.shape[:-1] 154 | pixels = np.reshape(pixels, (-1, 2)) 155 | y = ( 156 | pixels[..., 1] - scene.train_camera.dataset[0].metadata.principal_point_y 157 | ) / dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy() 158 | x = ( 159 | pixels[..., 0] - scene.train_camera.dataset[0].metadata.principal_point_x 160 | ) / dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy() 161 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1) 162 | local_viewdirs = viewdirs / np.linalg.norm(viewdirs, axis=-1, keepdims=True) 163 | 164 | with torch.no_grad(): 165 | for cam in viewpoint_stack: 166 | time_in = torch.tensor(cam.time).float().cuda() 167 | pred_R, pred_T = dyn_gaussians._posenet(time_in.view(1, 1)) 168 | R_ = torch.transpose(pred_R, 2, 1).detach().cpu().numpy() 169 | t_ = pred_T.detach().cpu().numpy() 170 | cam.update_cam( 171 | R_[0], 172 | t_[0], 173 | local_viewdirs, 174 | batch_shape, 175 | dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy(), 176 | ) 177 | 178 | for view_id in range(len(my_test_cams)): 179 | my_test_cams[view_id].update_cam( 180 | viewpoint_stack[0].R, 181 | viewpoint_stack[0].T, 182 | local_viewdirs, 183 | batch_shape, 184 | dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy(), 185 | ) 186 | 187 | training_report( 188 | scene, 189 | viewpoint_stack, 190 | my_test_cams, 191 | render_infer, 192 | background, 193 | "fine", 194 | scene.dataset_type, 195 | os.path.join("output", args.expname), 196 | ) -------------------------------------------------------------------------------- /gen_depth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | from submodules.UniDepth.unidepth.models import UniDepthV1, UniDepthV2, UniDepthV2old 7 | from submodules.UniDepth.unidepth.utils import colorize, image_grid 8 | 9 | import glob 10 | import os 11 | 12 | def gen_depth(model, args): 13 | focal = 500 # default 500 for all datasets 14 | 15 | images_list = sorted(glob.glob(os.path.join(args.image_dir, '*.png'))) 16 | 
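# For every input frame, the loop below builds pseudo-intrinsics from the fixed focal length set above, runs UniDepth inference, saves the raw depth map as a .npy file, and writes a colorized .png preview.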
os.makedirs(args.out_dir, exist_ok=True) 17 | for image in images_list: 18 | rgb = np.array(Image.open(image).convert('RGB')) 19 | rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1) 20 | 21 | H, W = rgb_torch.shape[1], rgb_torch.shape[2] 22 | intrinsics_torch = torch.from_numpy(np.array([[focal, 0, H/2], 23 | [0, focal, W/2], 24 | [0, 0 ,1]])).float() 25 | # predict 26 | predictions = model.infer(rgb_torch, intrinsics_torch) 27 | depth_pred = predictions["depth"].squeeze().cpu().numpy()[..., None] 28 | 29 | fname = os.path.basename(image) 30 | np.save(os.path.join(args.out_dir, fname.replace('png', 'npy')), depth_pred) 31 | 32 | # # colorize 33 | depth_pred_col = colorize(depth_pred) 34 | Image.fromarray(depth_pred_col).save(os.path.join(args.out_dir, fname)) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--image_dir", type=str, required=True, help="image dir") 40 | parser.add_argument("--out_dir", type=str, required=True, help="out dir") 41 | parser.add_argument("--depth_type", type=str, help="depth type disp or depth", default="depth") 42 | parser.add_argument("--depth_model", type=str, help="unidepth model", default="v2old") 43 | args = parser.parse_args() 44 | 45 | print("Torch version:", torch.__version__) 46 | 47 | if args.depth_type == "disp": 48 | os.makedirs(args.out_dir, exist_ok=True) 49 | cmd = f"python submodules/mega-sam/Depth-Anything/run_videos.py --encoder vitl \ 50 | --load-from submodules/mega-sam/Depth-Anything/checkpoints/depth_anything_vitl14.pth \ 51 | --img-path {args.image_dir} \ 52 | --outdir {args.out_dir}" 53 | os.system(cmd) 54 | elif args.depth_type == "depth": 55 | type_ = "l" # available types: s, b, l 56 | name = f"unidepth-{args.depth_model}-vit{type_}14" 57 | if args.depth_model == "v2": 58 | model = UniDepthV2.from_pretrained(f"lpiccinelli/{name}") 59 | # set resolution level (only V2) 60 | # model.resolution_level = 9 61 | 62 | # set interpolation mode (only V2) 63 | model.interpolation_mode = "bilinear" 64 | elif args.depth_model == "v2old": 65 | model = UniDepthV2old.from_pretrained(f"lpiccinelli/{name}") 66 | 67 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 68 | model = model.to(device).eval() 69 | 70 | gen_depth(model, args) 71 | else: 72 | raise ValueError("depth_type must be either 'disp' or 'depth'") 73 | 74 | -------------------------------------------------------------------------------- /gen_depth.sh: -------------------------------------------------------------------------------- 1 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --out_dir data/nvidia_rodynrf/Balloon1/uni_depth 2 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --out_dir data/nvidia_rodynrf/Balloon1/depth_anything --depth_type disp 3 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --out_dir data/nvidia_rodynrf/Balloon2/uni_depth 4 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --out_dir data/nvidia_rodynrf/Balloon2/depth_anything --depth_type disp 5 | python gen_depth.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --out_dir data/nvidia_rodynrf/Jumping/uni_depth 6 | python gen_depth.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --out_dir data/nvidia_rodynrf/Jumping/depth_anything --depth_type disp 7 | python gen_depth.py --image_dir data/nvidia_rodynrf/Playground/images_2 --out_dir data/nvidia_rodynrf/Playground/uni_depth --depth_model v2 8 | python gen_depth.py --image_dir 
data/nvidia_rodynrf/Playground/images_2 --out_dir data/nvidia_rodynrf/Playground/depth_anything --depth_type disp 9 | python gen_depth.py --image_dir data/nvidia_rodynrf/Skating/images_2 --out_dir data/nvidia_rodynrf/Skating/uni_depth 10 | python gen_depth.py --image_dir data/nvidia_rodynrf/Skating/images_2 --out_dir data/nvidia_rodynrf/Skating/depth_anything --depth_type disp 11 | python gen_depth.py --image_dir data/nvidia_rodynrf/Truck/images_2 --out_dir data/nvidia_rodynrf/Truck/uni_depth 12 | python gen_depth.py --image_dir data/nvidia_rodynrf/Truck/images_2 --out_dir data/nvidia_rodynrf/Truck/depth_anything --depth_type disp 13 | python gen_depth.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --out_dir data/nvidia_rodynrf/Umbrella/uni_depth 14 | python gen_depth.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --out_dir data/nvidia_rodynrf/Umbrella/depth_anything --depth_type disp -------------------------------------------------------------------------------- /gen_tracks.py: -------------------------------------------------------------------------------- 1 | 2 | """ Code borrowed from 3 | https://github.com/vye16/shape-of-motion/blob/main/preproc/compute_tracks_torch.py 4 | """ 5 | import argparse 6 | import glob 7 | import os 8 | 9 | from PIL import Image 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from cotracker.utils.visualizer import Visualizer 15 | 16 | DEFAULT_DEVICE = ( 17 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 18 | "cuda" 19 | if torch.cuda.is_available() 20 | else "cpu" 21 | ) 22 | 23 | def read_video(folder_path): 24 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*"))) 25 | video = np.concatenate([np.array(Image.open(frame_path)).transpose(2, 0, 1)[None, None] for frame_path in frame_paths], axis=1) 26 | video = torch.from_numpy(video).float() 27 | return video 28 | 29 | def read_mask(folder_path): 30 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*"))) 31 | video = np.concatenate([np.array(Image.open(frame_path))[None, None] for frame_path in frame_paths], axis=1) 32 | video = torch.from_numpy(video).float() 33 | return video 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--image_dir", type=str, required=True, help="image dir") 38 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir") 39 | parser.add_argument("--out_dir", type=str, required=True, help="out dir") 40 | parser.add_argument("--is_static", action="store_true") 41 | parser.add_argument("--grid_size", type=int, default=100, help="Regular grid size") 42 | parser.add_argument( 43 | "--grid_query_frame", 44 | type=int, 45 | default=0, 46 | help="Compute dense and grid tracks starting from this frame", 47 | ) 48 | parser.add_argument( 49 | "--backward_tracking", 50 | action="store_true", 51 | help="Compute tracks in both directions, not only forward", 52 | ) 53 | args = parser.parse_args() 54 | 55 | folder_path = args.image_dir 56 | mask_dir = args.mask_dir 57 | frame_names = [ 58 | os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*"))) 59 | ] 60 | out_dir = args.out_dir 61 | os.makedirs(out_dir, exist_ok=True) 62 | os.makedirs(os.path.join(out_dir, "vis"), exist_ok=True) 63 | 64 | done = True 65 | for t in range(len(frame_names)): 66 | for j in range(len(frame_names)): 67 | name_t = os.path.splitext(frame_names[t])[0] 68 | name_j = os.path.splitext(frame_names[j])[0] 69 | out_path = 
f"{out_dir}/{name_t}_{name_j}.npy" 70 | if not os.path.exists(out_path): 71 | done = False 72 | break 73 | print(f"{done}") 74 | if done: 75 | print("Already done") 76 | return 77 | 78 | ## Load model 79 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(DEFAULT_DEVICE) 80 | video = read_video(folder_path).to(DEFAULT_DEVICE) 81 | 82 | masks = read_mask(mask_dir).to(DEFAULT_DEVICE) 83 | 84 | masks[masks>0] = 1. 85 | if args.is_static: 86 | masks = 1.0 - masks 87 | 88 | _, num_frames,_, height, width = video.shape 89 | vis = Visualizer(save_dir=os.path.join(out_dir, "vis"), pad_value=120, linewidth=3) 90 | 91 | for t in tqdm(range(num_frames), desc="query frames"): 92 | name_t = os.path.splitext(frame_names[t])[0] 93 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy") 94 | if len(file_matches) == num_frames: 95 | print(f"Already computed tracks with query {t} {name_t}") 96 | continue 97 | 98 | current_mask = masks[:,t].unsqueeze(1) 99 | start_pred = None 100 | 101 | for j in range(num_frames): 102 | if j > t: 103 | current_video = video[:,t:j+1] 104 | elif j < t: 105 | current_video = torch.flip(video[:,j:t+1], dims=(1,)) # reverse 106 | else: 107 | continue 108 | # current_video = video[:,t:t+1] 109 | 110 | 111 | pred_tracks, pred_visibility = model( 112 | current_video, 113 | grid_size=args.grid_size, 114 | grid_query_frame=0, 115 | backward_tracking=False, 116 | segm_mask=current_mask 117 | ) 118 | 119 | 120 | pred = torch.cat([pred_tracks, pred_visibility.unsqueeze(-1)], dim=-1) 121 | current_pred = pred[0,-1] 122 | start_pred = pred[0,0] 123 | 124 | # save 125 | name_j = os.path.splitext(frame_names[j])[0] 126 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", current_pred.cpu().numpy()) 127 | 128 | # visualize 129 | # vis.visualize(current_video, pred_tracks, pred_visibility, filename=f"{name_t}_{name_j}") 130 | 131 | np.save(f"{out_dir}/{name_t}_{name_t}.npy", start_pred.cpu().numpy()) 132 | 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /gen_tracks.sh: -------------------------------------------------------------------------------- 1 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --mask_dir data/nvidia_rodynrf/Balloon1/motion_masks --out_dir data/nvidia_rodynrf/Balloon1/bootscotracker_dynamic --grid_size 256 2 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --mask_dir data/nvidia_rodynrf/Balloon1/motion_masks --out_dir data/nvidia_rodynrf/Balloon1/bootscotracker_static --is_static --grid_size 50 3 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --mask_dir data/nvidia_rodynrf/Balloon2/motion_masks --out_dir data/nvidia_rodynrf/Balloon2/bootscotracker_dynamic --grid_size 256 4 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --mask_dir data/nvidia_rodynrf/Balloon2/motion_masks --out_dir data/nvidia_rodynrf/Balloon2/bootscotracker_static --is_static --grid_size 50 5 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --mask_dir data/nvidia_rodynrf/Jumping/motion_masks --out_dir data/nvidia_rodynrf/Jumping/bootscotracker_dynamic --grid_size 256 6 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --mask_dir data/nvidia_rodynrf/Jumping/motion_masks --out_dir data/nvidia_rodynrf/Jumping/bootscotracker_static --is_static --grid_size 50 7 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Playground/images_2 --mask_dir 
data/nvidia_rodynrf/Playground/motion_masks --out_dir data/nvidia_rodynrf/Playground/bootscotracker_dynamic --grid_size 256 8 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Playground/images_2 --mask_dir data/nvidia_rodynrf/Playground/motion_masks --out_dir data/nvidia_rodynrf/Playground/bootscotracker_static --is_static --grid_size 50 9 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Truck/images_2 --mask_dir data/nvidia_rodynrf/Truck/motion_masks --out_dir data/nvidia_rodynrf/Truck/bootscotracker_dynamic --grid_size 256 10 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Truck/images_2 --mask_dir data/nvidia_rodynrf/Truck/motion_masks --out_dir data/nvidia_rodynrf/Truck/bootscotracker_static --is_static --grid_size 50 11 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Skating/images_2 --mask_dir data/nvidia_rodynrf/Skating/motion_masks --out_dir data/nvidia_rodynrf/Skating/bootscotracker_dynamic --grid_size 256 12 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Skating/images_2 --mask_dir data/nvidia_rodynrf/Skating/motion_masks --out_dir data/nvidia_rodynrf/Skating/bootscotracker_static --is_static --grid_size 50 13 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --mask_dir data/nvidia_rodynrf/Umbrella/motion_masks --out_dir data/nvidia_rodynrf/Umbrella/bootscotracker_dynamic --grid_size 256 14 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --mask_dir data/nvidia_rodynrf/Umbrella/motion_masks --out_dir data/nvidia_rodynrf/Umbrella/bootscotracker_static --is_static --grid_size 50 -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # # install splinegs environment 2 | # ENV_NAME=splinegs 3 | 4 | # conda remove -n $ENV_NAME --all -y 5 | # conda create -n $ENV_NAME python=3.7 6 | # conda activate $ENV_NAME 7 | # export CUDA_HOME=$CONDA_PREFIX 8 | 9 | # conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 10 | # conda install nvidia/label/cuda-11.7.0::cuda 11 | # conda install nvidia/label/cuda-11.7.0::cuda-nvcc 12 | # conda install nvidia/label/cuda-11.7.0::cuda-runtime 13 | # conda install nvidia/label/cuda-11.7.0::cuda-cudart 14 | # conda install -c conda-forge ld_impl_linux-64 15 | 16 | # pip install -e submodules/simple-knn 17 | # pip install -e submodules/co-tracker 18 | # pip install -r requirements.txt 19 | 20 | # install unidepth environment 21 | UNIDEPTH_ENV_NAME=unidepth_splinegs 22 | conda remove -n $UNIDEPTH_ENV_NAME --all -y 23 | conda create -n $UNIDEPTH_ENV_NAME python=3.10 24 | conda activate $UNIDEPTH_ENV_NAME 25 | 26 | pip install -r requirements_unidepth.txt 27 | conda install -c conda-forge ld_impl_linux-64 28 | export CUDA_HOME=$CONDA_PREFIX 29 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib 30 | conda install nvidia/label/cuda-12.1.0::cuda 31 | conda install nvidia/label/cuda-12.1.0::cuda-nvcc 32 | conda install nvidia/label/cuda-12.1.0::cuda-runtime 33 | conda install nvidia/label/cuda-12.1.0::cuda-cudart 34 | conda install nvidia/label/cuda-12.1.0::libcusparse 35 | conda install nvidia/label/cuda-12.1.0::libcublas 36 | cd submodules/UniDepth/unidepth/ops/knn;bash compile.sh;cd ../../../../../ 37 | pip install -e submodules/UniDepth 38 | 39 | # install depthanything 40 | mkdir -p submodules/mega-sam/Depth-Anything/checkpoints -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/nerfstudio-project/gsplat.git@v1.4.0 2 | plyfile 3 | ffmpeg 4 | numpy==1.21.6 5 | scipy==1.7.3 6 | matplotlib==3.5.3 7 | mmengine==0.10.6 8 | imageio==2.19.3 9 | imageio-ffmpeg==0.4.7 10 | kornia==0.6.9 11 | tqdm -------------------------------------------------------------------------------- /requirements_unidepth.txt: -------------------------------------------------------------------------------- 1 | einops>=0.7.0 2 | gradio 3 | h5py>=3.10.0 4 | huggingface-hub>=0.22.0 5 | imageio 6 | matplotlib 7 | numpy==2.2.5 8 | opencv-python 9 | pandas 10 | pillow>=10.2.0 11 | protobuf>=4.25.3 12 | scipy 13 | tables 14 | tabulate 15 | termcolor 16 | timm 17 | tqdm 18 | trimesh 19 | triton>=2.4.0 20 | torch==2.4.0 21 | torchvision==0.19.0 22 | torchaudio==2.4.0 23 | wandb 24 | xformers==0.0.27.post2 -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | 14 | from arguments import ModelParams 15 | from scene.dataset import FourDGSdataset 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from utils.system_utils import searchForMaxIteration 19 | 20 | 21 | class Scene: 22 | dyn_gaussians: GaussianModel 23 | 24 | def __init__( 25 | self, 26 | args: ModelParams, 27 | gaussians: GaussianModel, 28 | static_gaussians: GaussianModel, 29 | load_iteration=None, 30 | shuffle=True, 31 | resolution_scales=[1.0], 32 | load_coarse=False, 33 | ): 34 | """b 35 | :param path: Path to colmap scene main folder. 36 | """ 37 | self.model_path = args.model_path 38 | self.loaded_iter = None 39 | self.dyn_gaussians = gaussians 40 | self.stat_gaussians = static_gaussians 41 | 42 | if load_iteration: 43 | if load_iteration == -1: 44 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 45 | else: 46 | self.loaded_iter = load_iteration 47 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 48 | 49 | self.train_cameras = {} 50 | self.test_cameras = {} 51 | self.video_cameras = {} 52 | 53 | assert args.dataset_type in sceneLoadTypeCallbacks.keys(), "Could not recognize scene type!" 
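        # (editor note) The dataset-specific loader selected below returns a scene_info
        # object; the attributes consumed in this constructor are train_cameras /
        # test_cameras / video_cameras (each wrapped in FourDGSdataset), point_cloud,
        # maxtime, and nerf_normalization["radius"] for the scene extent.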
54 | 55 | dataset_type = args.dataset_type 56 | scene_info = sceneLoadTypeCallbacks[dataset_type](args) 57 | 58 | self.maxtime = scene_info.maxtime 59 | self.dataset_type = dataset_type 60 | self.cameras_extent = scene_info.nerf_normalization["radius"] 61 | print(f"Original scene extent {self.cameras_extent}") 62 | print("Loading Training Cameras") 63 | self.train_camera = FourDGSdataset(scene_info.train_cameras, args, dataset_type) 64 | print("Loading Test Cameras") 65 | self.test_camera = FourDGSdataset(scene_info.test_cameras, args, dataset_type) 66 | print("Loading Video Cameras") 67 | self.video_camera = FourDGSdataset(scene_info.video_cameras, args, dataset_type) 68 | 69 | # self.video_camera = cameraList_from_camInfos(scene_info.video_cameras,-1,args) 70 | xyz_max = scene_info.point_cloud.points.max(axis=0) 71 | xyz_min = scene_info.point_cloud.points.min(axis=0) 72 | 73 | if self.loaded_iter: 74 | self.dyn_gaussians.load_ply( 75 | os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply") 76 | ) 77 | self.dyn_gaussians.load_model( 78 | os.path.join( 79 | self.model_path, 80 | "point_cloud", 81 | "iteration_" + str(self.loaded_iter), 82 | ) 83 | ) 84 | self.stat_gaussians.load_ply( 85 | os.path.join( 86 | self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud_static.ply" 87 | ) 88 | ) 89 | else: 90 | self.dyn_gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime) 91 | self.stat_gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime) 92 | 93 | def save(self, iteration, stage): 94 | if stage == "coarse": 95 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration)) 96 | 97 | else: 98 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 99 | self.dyn_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 100 | self.dyn_gaussians.save_deformation(point_cloud_path) 101 | 102 | self.stat_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud_static.ply")) 103 | 104 | def save_best_psnr(self, iteration, stage): 105 | if stage == "coarse": 106 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_best") 107 | 108 | else: 109 | point_cloud_path = os.path.join(self.model_path, "point_cloud/fine_best") 110 | self.dyn_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 111 | self.dyn_gaussians.save_deformation(point_cloud_path) 112 | 113 | self.stat_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud_static.ply")) 114 | 115 | def getTrainCameras(self, scale=1.0): 116 | return self.train_camera 117 | 118 | def getTestCameras(self, scale=1.0): 119 | return self.test_camera 120 | 121 | def getVideoCameras(self, scale=1.0): 122 | return self.video_camera 123 | -------------------------------------------------------------------------------- /scene/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scene.cameras import Camera 3 | from torch.utils.data import Dataset 4 | from utils.general_utils import PILtoTorch 5 | from utils.graphics_utils import focal2fov 6 | 7 | 8 | class FourDGSdataset(Dataset): 9 | def __init__(self, dataset, args, dataset_type): 10 | self.dataset = dataset 11 | self.args = args 12 | self.dataset_type = dataset_type 13 | 14 | def __getitem__(self, index): 15 | if self.dataset_type != "PanopticSports": 16 | try: 17 | image, w2c, 
time = self.dataset[index] 18 | R, T = w2c 19 | FovX = focal2fov(self.dataset.focal[0], image.shape[2]) 20 | FovY = focal2fov(self.dataset.focal[0], image.shape[1]) 21 | mask = None 22 | except: 23 | caminfo = self.dataset[index] 24 | image = PILtoTorch(caminfo.image, (caminfo.image.width, caminfo.image.height)) 25 | R = caminfo.R 26 | T = caminfo.T 27 | FovX = caminfo.FovX 28 | FovY = caminfo.FovY 29 | time = caminfo.time 30 | mask = caminfo.mask 31 | 32 | return Camera( 33 | colmap_id=index, 34 | R=R, 35 | T=T, 36 | FoVx=FovX, 37 | FoVy=FovY, 38 | image=image, 39 | gt_alpha_mask=None, 40 | image_name=f"{caminfo.image_name}", 41 | uid=index, 42 | data_device=torch.device("cuda"), 43 | time=time, 44 | mask=mask, 45 | metadata=caminfo.metadata, 46 | normal=caminfo.normal, 47 | depth=caminfo.depth, 48 | max_time=caminfo.max_time, 49 | sem_mask=caminfo.sem_mask, 50 | fwd_flow=caminfo.fwd_flow, 51 | bwd_flow=caminfo.bwd_flow, 52 | fwd_flow_mask=caminfo.fwd_flow_mask, 53 | bwd_flow_mask=caminfo.bwd_flow_mask, 54 | instance_mask=caminfo.instance_mask, 55 | target_tracks=caminfo.target_tracks, 56 | target_visibility=caminfo.target_visibility, 57 | target_tracks_static=caminfo.target_tracks_static, 58 | target_visibility_static=caminfo.target_visibility_static, 59 | ) 60 | else: 61 | return self.dataset[index] 62 | 63 | def __len__(self): 64 | return len(self.dataset) 65 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train.py -s data/nvidia_rodynrf/Balloon1/ --expname "Balloon1" --configs arguments/nvidia_rodynrf/Balloon1.py 2 | python train.py -s data/nvidia_rodynrf/Balloon2/ --expname "Balloon2" --configs arguments/nvidia_rodynrf/Balloon2.py 3 | python train.py -s data/nvidia_rodynrf/Playground/ --expname "Playground" --configs arguments/nvidia_rodynrf/Playground.py 4 | python train.py -s data/nvidia_rodynrf/Jumping/ --expname "Jumping" --configs arguments/nvidia_rodynrf/Jumping.py 5 | python train.py -s data/nvidia_rodynrf/Truck/ --expname "Truck" --configs arguments/nvidia_rodynrf/Truck.py 6 | python train.py -s data/nvidia_rodynrf/Skating/ --expname "Skating" --configs arguments/nvidia_rodynrf/Skating.py 7 | python train.py -s data/nvidia_rodynrf/Umbrella/ --expname "Umbrella" --configs arguments/nvidia_rodynrf/Umbrella.py -------------------------------------------------------------------------------- /utils/TIMES.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMES.TTF -------------------------------------------------------------------------------- /utils/TIMESBD.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESBD.TTF -------------------------------------------------------------------------------- /utils/TIMESBI.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESBI.TTF -------------------------------------------------------------------------------- /utils/TIMESI.TTF: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESI.TTF -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | from scene.cameras import Camera 14 | from utils.graphics_utils import fov2focal 15 | 16 | 17 | WARNED = False 18 | 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | # resized_image_rgb = PILtoTorch(cam_info.image, resolution) 22 | 23 | # gt_image = resized_image_rgb[:3, ...] 24 | # loaded_mask = None 25 | 26 | # if resized_image_rgb.shape[1] == 4: 27 | # loaded_mask = resized_image_rgb[3:4, ...] 28 | 29 | return Camera( 30 | colmap_id=cam_info.uid, 31 | R=cam_info.R, 32 | T=cam_info.T, 33 | FoVx=cam_info.FovX, 34 | FoVy=cam_info.FovY, 35 | image=cam_info.image, 36 | gt_alpha_mask=None, 37 | image_name=cam_info.image_name, 38 | uid=id, 39 | data_device=args.data_device, 40 | time=cam_info.time, 41 | metadata=cam_info.metadata, 42 | ) 43 | 44 | 45 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 46 | camera_list = [] 47 | 48 | for id, c in enumerate(cam_infos): 49 | camera_list.append(loadCam(args, id, c, resolution_scale)) 50 | 51 | return camera_list 52 | 53 | 54 | def camera_to_JSON(id, camera: Camera): 55 | Rt = np.zeros((4, 4)) 56 | Rt[:3, :3] = camera.R.transpose() 57 | Rt[:3, 3] = camera.T 58 | Rt[3, 3] = 1.0 59 | 60 | W2C = np.linalg.inv(Rt) 61 | pos = W2C[:3, 3] 62 | rot = W2C[:3, :3] 63 | serializable_array_2d = [x.tolist() for x in rot] 64 | camera_entry = { 65 | "id": id, 66 | "img_name": camera.image_name, 67 | "width": camera.width, 68 | "height": camera.height, 69 | "position": pos.tolist(), 70 | "rotation": serializable_array_2d, 71 | "fy": fov2focal(camera.FovY, camera.height), 72 | "fx": fov2focal(camera.FovX, camera.width), 73 | } 74 | return camera_entry 75 | -------------------------------------------------------------------------------- /utils/depth_loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from math import exp 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | 18 | 19 | def normalize(input, mean=None, std=None): 20 | input_mean = torch.mean(input, dim=1, keepdim=True) if mean is None else mean 21 | input_std = torch.std(input, dim=1, keepdim=True) if std is None else std 22 | return (input - input_mean) / (input_std + 1e-2 * torch.std(input.reshape(-1))) 23 | 24 | 25 | def shuffle(input): 26 | # shuffle dim=1 27 | idx = torch.randperm(input[0].shape[1]) 28 | for i in range(input.shape[0]): 29 | input[i] = input[i][:, idx].view(input[i].shape) 30 | 31 | 32 | def loss_depth_smoothness(depth, img): 33 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:] 34 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :] 35 | weight_x = torch.exp(-torch.abs(img_grad_x).mean(1).unsqueeze(1)) 36 | weight_y = torch.exp(-torch.abs(img_grad_y).mean(1).unsqueeze(1)) 37 | 38 | loss = ( 39 | ((depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs() * weight_x).sum() 40 | + ((depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs() * weight_y).sum() 41 | ) / (weight_x.sum() + weight_y.sum()) 42 | return loss 43 | 44 | 45 | def loss_depth_grad(depth, img): 46 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:] 47 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :] 48 | weight_x = img_grad_x / (torch.abs(img_grad_x) + 1e-6) 49 | weight_y = img_grad_y / (torch.abs(img_grad_y) + 1e-6) 50 | 51 | depth_grad_x = depth[:, :, :, :-1] - depth[:, :, :, 1:] 52 | depth_grad_y = depth[:, :, :-1, :] - depth[:, :, 1:, :] 53 | grad_x = depth_grad_x / (torch.abs(depth_grad_x) + 1e-6) 54 | grad_y = depth_grad_y / (torch.abs(depth_grad_y) + 1e-6) 55 | 56 | loss = l1_loss(grad_x, weight_x) + l1_loss(grad_y, weight_y) 57 | return loss 58 | 59 | 60 | def l1_loss(network_output, gt): 61 | return torch.abs((network_output - gt)).mean() 62 | 63 | 64 | def l2_loss(network_output, gt): 65 | return ((network_output - gt) ** 2).mean() 66 | 67 | 68 | def margin_l2_loss(network_output, gt, margin, return_mask=False): 69 | mask = (network_output - gt).abs() > margin 70 | if not return_mask: 71 | return ((network_output - gt)[mask] ** 2).mean() 72 | else: 73 | return ((network_output - gt)[mask] ** 2).mean(), mask 74 | 75 | 76 | def margin_l1_loss(network_output, gt, margin, return_mask=False): 77 | mask = (network_output - gt).abs() > margin 78 | if not return_mask: 79 | return ((network_output - gt)[mask].abs()).mean() 80 | else: 81 | return ((network_output - gt)[mask].abs()).mean(), mask 82 | 83 | 84 | def kl_loss(input, target): 85 | input = F.log_softmax(input, dim=-1) 86 | target = F.softmax(target, dim=-1) 87 | return F.kl_div(input, target, reduction="batchmean") 88 | 89 | 90 | def patchify(input, patch_size): 91 | patches = ( 92 | F.unfold(input, kernel_size=patch_size, stride=patch_size) 93 | .permute(0, 2, 1) 94 | .reshape(-1, 1 * patch_size * patch_size) 95 | ) 96 | return patches 97 | 98 | 99 | def patch_norm_mse_loss(input, target, patch_size, margin, return_mask=False): 100 | input_patches = normalize(patchify(input, patch_size)) 101 | target_patches = normalize(patchify(target, patch_size)) 102 | return margin_l2_loss(input_patches, target_patches, margin, return_mask) 103 | 104 | 105 | def patch_norm_mse_loss_global(input, target, patch_size, margin, return_mask=False): 106 | input_patches = normalize(patchify(input, patch_size), std=input.std().detach()) 107 | target_patches = normalize(patchify(target, 
patch_size), std=target.std().detach()) 108 | return margin_l2_loss(input_patches, target_patches, margin, return_mask) 109 | 110 | 111 | def patch_norm_l1_loss_global(input, target, patch_size, margin, return_mask=False): 112 | input_patches = normalize(patchify(input, patch_size), std=input.std().detach()) 113 | target_patches = normalize(patchify(target, patch_size), std=target.std().detach()) 114 | return margin_l1_loss(input_patches, target_patches, margin, return_mask) 115 | 116 | 117 | def patch_norm_l1_loss(input, target, patch_size, margin, return_mask=False): 118 | input_patches = normalize(patchify(input, patch_size)) 119 | target_patches = normalize(patchify(target, patch_size)) 120 | return margin_l1_loss(input_patches, target_patches, margin, return_mask) 121 | 122 | 123 | def gaussian(window_size, sigma): 124 | gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) 125 | return gauss / gauss.sum() 126 | 127 | 128 | def create_window(window_size, channel): 129 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 130 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 131 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 132 | return window 133 | 134 | 135 | def margin_ssim(img1, img2, window_size=11, size_average=True): 136 | result = ssim(img1, img2, window_size, False) 137 | print(result.shape) 138 | 139 | 140 | def ssim(img1, img2, window_size=11, size_average=True): 141 | channel = img1.size(-3) 142 | window = create_window(window_size, channel) 143 | 144 | if img1.is_cuda: 145 | window = window.cuda(img1.get_device()) 146 | window = window.type_as(img1) 147 | 148 | return _ssim(img1, img2, window, window_size, channel, size_average) 149 | 150 | 151 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 152 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 153 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 154 | 155 | mu1_sq = mu1.pow(2) 156 | mu2_sq = mu2.pow(2) 157 | mu1_mu2 = mu1 * mu2 158 | 159 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 160 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 161 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 162 | 163 | C1 = 0.01**2 164 | C2 = 0.03**2 165 | 166 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 167 | 168 | if size_average: 169 | return ssim_map.mean() 170 | else: 171 | return ssim_map.mean(1).mean(1).mean(1) 172 | -------------------------------------------------------------------------------- /utils/dycheck_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | -------------------------------------------------------------------------------- /utils/dycheck_utils/annotation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : annotation.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import functools 21 | from typing import List 22 | 23 | import cv2 24 | import jax 25 | import numpy as np 26 | from ipyevents import Event 27 | from IPython.display import clear_output, display 28 | from ipywidgets import HTML, Button, HBox, Image, Output 29 | 30 | from . import common, visuals 31 | 32 | 33 | def annotate_record3d_bad_frames( 34 | frames: np.ndarray, 35 | *, 36 | frame_ext: str = ".png", 37 | ) -> List[np.ndarray]: 38 | """Interactively annotate bad frames in validation set to skip later. 39 | 40 | "Bad images" exist because the hand may occlude the scene during capturing. 41 | 42 | Args: 43 | frames (np.ndarray]): images of shape (N, H, W, 3) in uint8 to annotate. 44 | skel (Skeleton): a skeleton definition object. 45 | frame_ext (str): the extension of images. This is used for decoding and 46 | then display. Default: '.png'. 47 | 48 | Return: 49 | np.ndarray: bad frame indices of shape (N_bad,). 50 | """ 51 | 52 | def _frame_to_widget(frame): 53 | value = cv2.imencode(frame_ext, frame[..., ::-1])[1].tobytes() 54 | widget = Image(value=value, format=frame_ext[1:]) 55 | widget.layout.max_width = "100%" 56 | widget.layout.height = "auto" 57 | return widget 58 | 59 | out = Output() 60 | 61 | frame_widgets = list( 62 | map( 63 | _frame_to_widget, 64 | common.tqdm(frames, desc="* Decoding frames"), 65 | ) 66 | ) 67 | bad_frames = [] 68 | 69 | frame_idx = -1 70 | frame_msg = HTML() 71 | 72 | def show_next_frame(): 73 | nonlocal frame_idx, frame_msg 74 | frame_idx += 1 75 | 76 | frame_msg.value = f"{frame_idx} frames annotated, " f"{len(frames) - frame_idx} frames left. 
" 77 | 78 | if frame_idx == len(frames): 79 | for btn in buttons: 80 | if btn is not redo_btn: 81 | btn.disabled = True 82 | print("Annotation done.") 83 | return 84 | 85 | frame_widget = frame_widgets[frame_idx] 86 | with out: 87 | clear_output(wait=True) 88 | display(frame_widget) 89 | 90 | def mark_frame(_, good): 91 | nonlocal frame_idx, bad_frames 92 | if not good: 93 | bad_frames.append(frame_idx) 94 | 95 | show_next_frame() 96 | 97 | def redo_frame(btn): 98 | nonlocal frame_idx, bad_frames 99 | if len(bad_frames) > 0 and bad_frames[-1] == frame_idx - 1: 100 | _ = bad_frames.pop() 101 | frame_idx = max(-1, frame_idx - 2) 102 | 103 | for btn in buttons: 104 | if btn is not redo_btn: 105 | btn.disabled = False 106 | 107 | show_next_frame() 108 | 109 | buttons = [] 110 | 111 | valid_btn = Button(description="👍") 112 | valid_btn.on_click(functools.partial(mark_frame, good=True)) 113 | buttons.append(valid_btn) 114 | 115 | invalid_btn = Button(description="👎") 116 | invalid_btn.on_click(functools.partial(mark_frame, good=False)) 117 | buttons.append(invalid_btn) 118 | 119 | redo_btn = Button(description="♻️") 120 | redo_btn.on_click(redo_frame) 121 | buttons.append(redo_btn) 122 | 123 | display(HBox([frame_msg])) 124 | display(HBox(buttons)) 125 | display(HBox([out])) 126 | 127 | show_next_frame() 128 | 129 | return bad_frames 130 | 131 | 132 | def annotate_keypoints( 133 | frames: np.ndarray, 134 | skeleton: visuals.Skeleton, 135 | *, 136 | frame_ext: str = ".png", 137 | **kwargs, 138 | ) -> List[np.ndarray]: 139 | """Interactively annotate keypoints on input frames. 140 | 141 | Note that each frame will only be submitted when finished. 142 | 143 | Args: 144 | frames (np.ndarray]): images of shape (N, H, W, 3) in uint8 to annotate. 145 | skel (Skeleton): a skeleton definition object. 146 | frame_ext (str): the extension of images. This is used for decoding and 147 | then display. Default: '.png'. 148 | 149 | Return: 150 | keypoints (List[np.ndarray]): a list of N annotated keypoints of shape 151 | (J, 3) where the last column is visibility in [0, 1]. 152 | """ 153 | 154 | def _frame_to_widget(frame): 155 | value = cv2.imencode(frame_ext, np.array(frame[..., ::-1]))[1].tobytes() 156 | widget = Image(value=value, format=frame_ext[1:]) 157 | widget.layout.max_width = "100%" 158 | widget.layout.height = "auto" 159 | return widget 160 | 161 | frame_widgets = jax.tree_map(_frame_to_widget, list(frames)) 162 | keypoints = [] 163 | 164 | frame_idx, kps = -1, [] 165 | 166 | event, frame_msg, kp_msg, kp_inst = Event(), HTML(), HTML(), HTML() 167 | 168 | def show_next_frame(): 169 | nonlocal frame_idx, frame_msg, event 170 | frame_idx += 1 171 | 172 | frame_msg.value = f"{len(keypoints)} frames annotated, " f"{len(frames) - frame_idx} frames left." 
173 | 174 | if frame_idx == len(frames): 175 | _show_current_kp_name() 176 | for btn in buttons: 177 | if btn is not redo_btn: 178 | btn.disabled = True 179 | event.watched_events = [] 180 | print("Annotation done.") 181 | return 182 | else: 183 | _show_current_kp_name() 184 | for btn in buttons: 185 | btn.disabled = False 186 | 187 | _mark_next_kp() 188 | 189 | frame_widget = frame_widgets[frame_idx] 190 | with out: 191 | clear_output(wait=True) 192 | display(frame_widget) 193 | event.source = frame_widget 194 | event.watched_events = ["click"] 195 | event.on_dom_event(mark_kp) 196 | 197 | def _show_kp_visual(): 198 | padded_kps = np.array( 199 | kps + [[0, 0, 0] for _ in range(skeleton.num_kps - len(kps))], 200 | dtype=np.float32, 201 | ) 202 | canvas = frames[frame_idx] 203 | kp_visual = visuals.visualize_kps(padded_kps, canvas, skeleton=skeleton, **kwargs) 204 | visual_widget = _frame_to_widget(kp_visual) 205 | with kp_visual_out: 206 | clear_output(wait=True) 207 | display(visual_widget) 208 | 209 | def _mark_next_kp(): 210 | _show_kp_visual() 211 | if len(kps) == skeleton.num_kps: 212 | submit_frame() 213 | 214 | _show_current_kp_name() 215 | 216 | nonlocal kp_msg 217 | kp_msg.value = f"{len(kps)} keypoints annotated, " f"{len(skeleton.kp_names) - len(kps)} keypoints left." 218 | 219 | def _show_current_kp_name(): 220 | nonlocal frame_idx 221 | if frame_idx == len(frames): 222 | msg = "FINISHED!" 223 | else: 224 | kp_name = skeleton.kp_names[len(kps)] 225 | msg = f"Marking [{kp_name}]..." 226 | kp_inst.value = msg 227 | 228 | def mark_kp(event): 229 | kp = np.array([event["dataX"], event["dataY"], 1], dtype=np.float32) 230 | nonlocal kps 231 | kps.append(kp) 232 | 233 | _mark_next_kp() 234 | 235 | def redo_kp(_): 236 | nonlocal kps, keypoints, frame_idx 237 | if len(kps) == 0 and len(keypoints) > 0: 238 | kps = keypoints.pop().tolist() 239 | kps = kps[:-1] 240 | frame_idx -= 2 241 | show_next_frame() 242 | else: 243 | kps.pop() 244 | _mark_next_kp() 245 | 246 | def mark_invisible_kp(_): 247 | kp = np.array([0, 0, 0], dtype=np.float32) 248 | nonlocal kps 249 | kps.append(kp) 250 | 251 | _mark_next_kp() 252 | 253 | def submit_frame(): 254 | nonlocal kps 255 | keypoints.append(np.array(kps)) 256 | kps = [] 257 | 258 | show_next_frame() 259 | 260 | buttons = [] 261 | 262 | invisible_btn = Button(description="🫥") 263 | invisible_btn.on_click(mark_invisible_kp) 264 | buttons.append(invisible_btn) 265 | 266 | redo_btn = Button(description="♻️") 267 | redo_btn.on_click(redo_kp) 268 | buttons.append(redo_btn) 269 | 270 | display(HBox([kp_inst])) 271 | display(HBox([frame_msg, kp_msg])) 272 | display(HBox(buttons)) 273 | out, kp_visual_out = Output(), Output() 274 | display(HBox([out, kp_visual_out])) 275 | show_next_frame() 276 | 277 | return keypoints 278 | -------------------------------------------------------------------------------- /utils/dycheck_utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : common.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import functools 21 | import inspect 22 | from concurrent import futures 23 | from copy import copy 24 | from typing import Any, Callable, Dict, Iterable, Optional, Sequence 25 | 26 | # import jax 27 | import numpy as np 28 | 29 | 30 | def tolerant_partial(fn: Callable, *args, **kwargs) -> Callable: 31 | """A thin wrapper around functools.partial which only binds the keyword 32 | arguments that matches the function signature. 33 | """ 34 | signatures = inspect.signature(fn) 35 | return functools.partial( 36 | fn, 37 | *args, 38 | **{k: v for k, v in kwargs.items() if k in signatures.parameters}, 39 | ) 40 | 41 | 42 | def traverse_filter( 43 | data_dict: Dict[str, Any], 44 | exclude_fields: Sequence[str] = (), 45 | return_fields: Sequence[str] = (), 46 | protect_fields: Sequence[str] = (), 47 | inplace: bool = False, 48 | ) -> Dict[str, Any]: 49 | """Keep matched field values within the dictionary, either inplace or not. 50 | 51 | Args: 52 | data_dict (Dict[str, Any]): A dictionary to be filtered. 53 | exclude_fields (Sequence[str]): A list of fields to be excluded. 54 | return_fields (Sequence[str]): A list of fields to be returned. 55 | protect_fields (Sequence[str]): A list of fields to be protected. 56 | inplace (bool): Whether to modify the input dictionary inplace. 57 | 58 | Returns: 59 | Dict[str, Any]: The filtered dictionary. 
60 | """ 61 | assert isinstance(data_dict, dict) 62 | 63 | str_to_tupid = lambda s: tuple(s.split("/")) 64 | exclude_fields = [str_to_tupid(f) for f in set(exclude_fields)] 65 | return_fields = [str_to_tupid(f) for f in set(return_fields)] 66 | protect_fields = [str_to_tupid(f) for f in set(protect_fields)] 67 | 68 | filter_fn = lambda f: f in protect_fields or ( 69 | f in return_fields if len(return_fields) > 0 else f not in exclude_fields 70 | ) 71 | 72 | if not inplace: 73 | data_dict = copy(data_dict) 74 | 75 | def delete_filtered(d, prefix): 76 | if isinstance(d, dict): 77 | for k in list(d.keys()): 78 | path = prefix + (k,) 79 | if (not isinstance(d[k], dict) or len(d[k]) == 0) and not filter_fn(path): 80 | del d[k] 81 | else: 82 | delete_filtered(d[k], path) 83 | 84 | delete_filtered(data_dict, ()) 85 | return data_dict 86 | 87 | 88 | @functools.lru_cache(maxsize=None) 89 | def in_notebook() -> bool: 90 | """Check if the code is running in a notebook.""" 91 | try: 92 | from IPython import get_ipython 93 | 94 | ipython = get_ipython() 95 | if not ipython or "IPKernelApp" not in ipython.config: 96 | return False 97 | except ImportError: 98 | return False 99 | return True 100 | 101 | 102 | def tqdm(iterable: Iterable, *args, **kwargs) -> Iterable: 103 | if not in_notebook(): 104 | from tqdm import tqdm as _tqdm 105 | else: 106 | from tqdm.notebook import tqdm as _tqdm 107 | return _tqdm(iterable, *args, **kwargs) 108 | 109 | 110 | def parallel_map( 111 | func: Callable, 112 | *iterables: Sequence[Iterable], 113 | max_threads: Optional[int] = None, 114 | show_pbar: bool = False, 115 | desc: Optional[str] = None, 116 | pbar_kwargs: Dict[str, Any] = {}, 117 | debug: bool = False, 118 | **kwargs, 119 | ) -> Sequence[Any]: 120 | """Parallel version of map().""" 121 | if not debug: 122 | with futures.ThreadPoolExecutor(max_threads) as executor: 123 | if show_pbar: 124 | results = list( 125 | tqdm( 126 | executor.map(func, *iterables, **kwargs), 127 | desc=desc, 128 | total=len(iterables[0]), 129 | **pbar_kwargs, 130 | ) 131 | ) 132 | else: 133 | results = list(executor.map(func, *iterables, **kwargs)) 134 | return results 135 | else: 136 | return list(map(func, *iterables, **kwargs)) 137 | 138 | 139 | def tree_collate(trees: Sequence[Any], collate_fn=lambda *x: np.asarray(x)): 140 | """Collates a list of pytrees with the same structure.""" 141 | return jax.tree_map(collate_fn, *trees) 142 | 143 | 144 | def strided_subset(sequence: Sequence[Any], count: int) -> Sequence[Any]: 145 | if count > len(sequence): 146 | raise ValueError("count must be less than or equal to len(sequence)") 147 | inds = np.linspace(0, len(sequence), count, dtype=int, endpoint=False) 148 | if isinstance(sequence, np.ndarray): 149 | sequence = sequence[inds] 150 | else: 151 | sequence = [sequence[i] for i in inds] 152 | return sequence 153 | 154 | 155 | def random_subset(sequence: Sequence[Any], count: int, seed: int = 0) -> Sequence[Any]: 156 | if count > len(sequence): 157 | raise ValueError("count must be less than or equal to len(sequence)") 158 | rng = np.random.default_rng(seed) 159 | inds = rng.choice(len(sequence), count, replace=False) 160 | if isinstance(sequence, np.ndarray): 161 | sequence = sequence[inds] 162 | else: 163 | sequence = [sequence[i] for i in inds] 164 | return sequence 165 | -------------------------------------------------------------------------------- /utils/dycheck_utils/image.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 
| # 3 | # File : image.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import math 21 | from typing import Any, Tuple 22 | 23 | import cv2 24 | import numpy as np 25 | 26 | from . import types 27 | 28 | 29 | # from absl import logging 30 | 31 | 32 | UINT8_MAX = 255 33 | UINT16_MAX = 65535 34 | 35 | 36 | def downscale(img: types.Array, scale: int) -> np.ndarray: 37 | if isinstance(img, np.ndarray): 38 | img = np.array(img) 39 | 40 | if scale == 1: 41 | return img 42 | 43 | height, width = img.shape[:2] 44 | if height % scale > 0 or width % scale > 0: 45 | raise ValueError(f"Image shape ({height},{width}) must be divisible by the" f" scale ({scale}).") 46 | out_height, out_width = height // scale, width // scale 47 | resized = cv2.resize(img, (out_width, out_height), cv2.INTER_AREA) 48 | return resized 49 | 50 | 51 | def upscale(img: types.Array, scale: int) -> np.ndarray: 52 | if isinstance(img, np.ndarray): 53 | img = np.array(img) 54 | 55 | if scale == 1: 56 | return img 57 | 58 | height, width = img.shape[:2] 59 | out_height, out_width = height * scale, width * scale 60 | resized = cv2.resize(img, (out_width, out_height), cv2.INTER_AREA) 61 | return resized 62 | 63 | 64 | def rescale(img: types.Array, scale_factor: float, interpolation: Any = cv2.INTER_AREA) -> np.ndarray: 65 | scale_factor = float(scale_factor) 66 | 67 | if scale_factor <= 0.0: 68 | raise ValueError("scale_factor must be a non-negative number.") 69 | if scale_factor == 1.0: 70 | return img 71 | 72 | height, width = img.shape[:2] 73 | if scale_factor.is_integer(): 74 | return upscale(img, int(scale_factor)) 75 | 76 | inv_scale = 1.0 / scale_factor 77 | if inv_scale.is_integer() and (scale_factor * height).is_integer() and (scale_factor * width).is_integer(): 78 | return downscale(img, int(inv_scale)) 79 | 80 | logging.warning( 81 | "Resizing image by non-integer factor %f, this may lead to artifacts.", 82 | scale_factor, 83 | ) 84 | 85 | height, width = img.shape[:2] 86 | out_height = math.ceil(height * scale_factor) 87 | out_height -= out_height % 2 88 | out_width = math.ceil(width * scale_factor) 89 | out_width -= out_width % 2 90 | 91 | return resize(img, (out_height, out_width), interpolation) 92 | 93 | 94 | def resize( 95 | img: types.Array, 96 | shape: Tuple[int, int], 97 | interpolation: Any = cv2.INTER_AREA, 98 | ) -> np.ndarray: 99 | if isinstance(img, np.ndarray): 100 | img = np.array(img) 101 | 102 | out_height, out_width = shape 103 | return cv2.resize( 104 | img, 105 | (out_width, out_height), 106 | interpolation=interpolation, 107 | ) 108 | 109 | 110 | def varlap(img: types.Array) -> np.ndarray: 111 | """Measure the focus/motion-blur of an image by the Laplacian variance.""" 112 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 113 | return cv2.Laplacian(gray, cv2.CV_64F).var() 114 | 115 | 116 | def to_float32(img: types.Array) 
-> np.ndarray: 117 | img = np.array(img) 118 | if img.dtype == np.float32: 119 | return img 120 | 121 | dtype = img.dtype 122 | img = img.astype(np.float32) 123 | if dtype == np.uint8: 124 | return img / UINT8_MAX 125 | elif dtype == np.uint16: 126 | return img / UINT16_MAX 127 | elif dtype == np.float64: 128 | return img 129 | elif dtype == np.float16: 130 | return img 131 | 132 | raise ValueError(f"Unexpected dtype: {dtype}.") 133 | 134 | 135 | def to_quantized_float32(img: types.Array) -> np.ndarray: 136 | return to_float32(to_uint8(img)) 137 | 138 | 139 | def to_uint8(img: types.Array) -> np.ndarray: 140 | img = np.array(img) 141 | if img.dtype == np.uint8: 142 | return img 143 | if not issubclass(img.dtype.type, np.floating): 144 | raise ValueError(f"Input image should be a floating type but is of type " f"{img.dtype!r}.") 145 | return (img * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8) 146 | 147 | 148 | def to_uint16(img: types.Array) -> np.ndarray: 149 | img = np.array(img) 150 | if img.dtype == np.uint16: 151 | return img 152 | if not issubclass(img.dtype.type, np.floating): 153 | raise ValueError(f"Input image should be a floating type but is of type " f"{img.dtype!r}.") 154 | return (img * UINT16_MAX).clip(0.0, UINT16_MAX).astype(np.uint16) 155 | 156 | 157 | # Special forms of images. 158 | def rescale_flow( 159 | flow: types.Array, 160 | scale_factor: float, 161 | interpolation: Any = cv2.INTER_LINEAR, 162 | ) -> np.ndarray: 163 | height, width = flow.shape[:2] 164 | 165 | out_flow = rescale(flow, scale_factor, interpolation) 166 | 167 | out_height, out_width = out_flow.shape[:2] 168 | out_flow[..., 0] *= float(out_width) / float(width) 169 | out_flow[..., 1] *= float(out_height) / float(height) 170 | return out_flow 171 | -------------------------------------------------------------------------------- /utils/dycheck_utils/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : io.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import json 21 | import os.path as osp 22 | import pickle as pkl 23 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union # Literal, 24 | 25 | import cv2 26 | import ffmpeg 27 | import numpy as np 28 | 29 | from . 
import common, image, path_ops, types 30 | 31 | 32 | _LOAD_REGISTRY, _DUMP_REGISTRY = {}, {} 33 | 34 | IMG_EXTS = ( 35 | ".jpg", 36 | ".jpeg", 37 | ".png", 38 | ".ppm", 39 | ".bmp", 40 | ".pgm", 41 | ".tif", 42 | ".tiff", 43 | ".webp", 44 | ) 45 | VID_EXTS = (".mov", ".avi", ".mpg", ".mpeg", ".mp4", ".mkv", ".wmv", ".gif") 46 | 47 | 48 | def _register( 49 | *, 50 | ext: Optional[Union[str, Sequence[str]]] = None, 51 | ) -> Callable: 52 | if isinstance(ext, str): 53 | ext = [ext] 54 | 55 | def _inner_register(func: Callable) -> Callable: 56 | name = func.__name__ 57 | assert name.startswith("load_") or name.startswith("dump_") 58 | 59 | nonlocal ext 60 | if ext is None: 61 | ext = ["." + name[5:]] 62 | for e in ext: 63 | if name.startswith("load_"): 64 | _LOAD_REGISTRY[e] = func 65 | else: 66 | _DUMP_REGISTRY[e] = func 67 | 68 | return func 69 | 70 | return _inner_register 71 | 72 | 73 | def _dispatch(registry: Dict[str, Callable], name) -> Callable: 74 | def _dispatched(filename: types.PathType, *args, **kwargs): 75 | ext = path_ops.get_ext(filename, match_first=False) 76 | func = registry[ext] 77 | if name == "dump" and osp.dirname(filename) != "" and not osp.exists(osp.dirname(filename)): 78 | path_ops.mkdir(osp.dirname(filename)) 79 | return func(filename, *args, **kwargs) 80 | 81 | _dispatched.__name__ = name 82 | return _dispatched 83 | 84 | 85 | load = _dispatch(_LOAD_REGISTRY, "load") 86 | dump = _dispatch(_DUMP_REGISTRY, "dump") 87 | 88 | 89 | @_register() 90 | def load_txt(filename: types.PathType, *, strip: bool = True, **kwargs) -> List[str]: 91 | with open(filename) as f: 92 | lines = f.readlines(**kwargs) 93 | if strip: 94 | lines = [line.strip() for line in lines] 95 | return lines 96 | 97 | 98 | @_register() 99 | def dump_txt(filename: types.PathType, obj: List[Any], **_) -> None: 100 | # Prefer visual appearance over compactness. 101 | obj = "\n".join([str(item) for item in obj]) 102 | with open(filename, "w") as f: 103 | f.write(obj) 104 | 105 | 106 | @_register() 107 | def load_json(filename: types.PathType, **kwargs) -> Dict: 108 | with open(filename) as f: 109 | return json.load(f, **kwargs) 110 | 111 | 112 | @_register() 113 | def dump_json( 114 | filename: types.PathType, 115 | obj: Dict, 116 | *, 117 | sort_keys: bool = True, 118 | indent: Optional[int] = 4, 119 | separators: Tuple[str, str] = (",", ": "), 120 | **kwargs, 121 | ) -> None: 122 | # Process potential numpy arrays. 123 | if isinstance(obj, dict): 124 | obj = {k: v.tolist() if hasattr(v, "tolist") else v for k, v in obj.items()} 125 | elif isinstance(obj, (list, tuple)): 126 | pass 127 | elif isinstance(obj, np.ndarray): 128 | obj = obj.tolist() 129 | else: 130 | raise ValueError(f"{type(obj)} is not a supported type.") 131 | # Prefer visual appearance over compactness. 
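    # (editor note, illustrative usage of the extension-based dispatch defined above;
    # the import path is assumed, e.g. `from utils.dycheck_utils import io`):
    #   io.dump("metadata.json", {"fps": np.float32(30.0)})  # routed to dump_json, arrays converted via .tolist()
    #   meta = io.load("metadata.json")                        # routed to load_json -> {'fps': 30.0}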
132 | with open(filename, "w") as f: 133 | json.dump( 134 | obj, 135 | f, 136 | sort_keys=sort_keys, 137 | indent=indent, 138 | separators=separators, 139 | **kwargs, 140 | ) 141 | 142 | 143 | @_register() 144 | def load_pkl(filename: types.PathType, **kwargs) -> Dict: 145 | with open(filename, "rb") as f: 146 | try: 147 | return pkl.load(f, **kwargs) 148 | except UnicodeDecodeError as e: 149 | if "encoding" in kwargs: 150 | raise e 151 | return load_pkl(filename, encoding="latin1", **kwargs) 152 | 153 | 154 | @_register() 155 | def dump_pkl(filename: types.PathType, obj: Dict, **kwargs) -> None: 156 | with open(filename, "wb") as f: 157 | pkl.dump(obj, f, **kwargs) 158 | 159 | 160 | @_register() 161 | def load_npy( 162 | filename: types.PathType, *, allow_pickle: bool = True, **kwargs 163 | ) -> Union[np.ndarray, Dict[str, np.ndarray]]: 164 | return np.load(filename, allow_pickle=allow_pickle, **kwargs) 165 | 166 | 167 | @_register() 168 | def dump_npy(filename: types.PathType, obj: np.ndarray, **kwargs) -> None: 169 | np.save(filename, obj, **kwargs) 170 | 171 | 172 | @_register() 173 | def load_npz(filename: types.PathType, **kwargs) -> np.lib.npyio.NpzFile: 174 | return np.load(filename, **kwargs) 175 | 176 | 177 | @_register() 178 | def dump_npz(filename: types.PathType, **kwargs) -> None: 179 | # Disable positional argument for np.savez. 180 | np.savez(filename, **kwargs) 181 | 182 | 183 | @_register(ext=IMG_EXTS) 184 | def load_img(filename: types.PathType, *, use_rgb: bool = True, **kwargs) -> np.ndarray: 185 | img = cv2.imread(filename, **kwargs) 186 | if use_rgb and img.shape[-1] >= 3: 187 | # Take care of RGBA case when flipping. 188 | img = np.concatenate([img[..., 2::-1], img[..., 3:]], axis=-1) 189 | return img 190 | 191 | 192 | @_register(ext=IMG_EXTS) 193 | def dump_img( 194 | filename: types.PathType, 195 | obj: np.ndarray, 196 | *, 197 | use_rgb: bool = True, 198 | **kwargs, 199 | ) -> None: 200 | if use_rgb and obj.shape[-1] >= 3: 201 | obj = np.concatenate([obj[..., 2::-1], obj[..., 3:]], axis=-1) 202 | cv2.imwrite(filename, image.to_uint8(obj), **kwargs) 203 | 204 | 205 | def load_vid_metadata(filename: types.PathType) -> np.ndarray: 206 | assert osp.exists(filename), f"{filename} does not exist!" 
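    # (editor note) ffmpeg.probe shells out to the ffprobe binary, so the system
    # ffmpeg/ffprobe executables must be installed and on PATH in addition to the
    # Python ffmpeg bindings imported at the top of this file.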
207 | try: 208 | probe = ffmpeg.probe(filename) 209 | except ffmpeg.Error as e: 210 | print("stdout:", e.stdout.decode("utf8")) 211 | print("stderr:", e.stderr.decode("utf8")) 212 | raise e 213 | metadata = next(stream for stream in probe["streams"] if stream["codec_type"] == "video") 214 | metadata["fps"] = float(eval(metadata["r_frame_rate"])) 215 | return metadata 216 | 217 | 218 | @_register(ext=VID_EXTS) 219 | def load_vid( 220 | filename: types.PathType, 221 | *, 222 | quiet: bool = True, 223 | trim_kwargs: Dict[str, Any] = {}, 224 | **_, 225 | ) -> Dict: 226 | vid_metadata = load_vid_metadata(filename) 227 | W = int(vid_metadata["width"]) 228 | H = int(vid_metadata["height"]) 229 | 230 | stream = ffmpeg.input(filename) 231 | if len(trim_kwargs) > 0: 232 | stream = ffmpeg.trim(stream, **trim_kwargs).setpts("PTS-STARTPTS") 233 | stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24") 234 | out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=quiet) 235 | out = np.frombuffer(out, np.uint8).reshape([-1, H, W, 3]) 236 | return out.copy() 237 | 238 | 239 | @_register(ext=VID_EXTS) 240 | def dump_vid( 241 | filename: types.PathType, 242 | obj: Union[List[np.ndarray], np.ndarray], 243 | *, 244 | fps: float, 245 | quiet: bool = True, 246 | show_pbar: bool = True, 247 | desc: Optional[str] = "* Dumping video", 248 | **kwargs, 249 | ) -> None: 250 | if not isinstance(obj, np.ndarray): 251 | obj = np.asarray(obj) 252 | obj = image.to_uint8(obj) 253 | 254 | H, W = obj.shape[1:3] 255 | stream = ffmpeg.input( 256 | "pipe:", 257 | format="rawvideo", 258 | pix_fmt="rgb24", 259 | s="{}x{}".format(W, H), 260 | r=fps, 261 | ) 262 | process = ( 263 | stream.output(filename, pix_fmt="yuv420p", vcodec="libx264") 264 | .overwrite_output() 265 | .run_async(pipe_stdin=True, quiet=quiet) 266 | ) 267 | obj_bytes = common.parallel_map( 268 | lambda f: f.tobytes(), 269 | list(obj), 270 | show_pbar=show_pbar, 271 | desc=desc, 272 | **kwargs, 273 | ) 274 | for b in obj_bytes: 275 | process.stdin.write(b) 276 | process.stdin.close() 277 | process.wait() 278 | -------------------------------------------------------------------------------- /utils/dycheck_utils/path_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : path_ops.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import glob 21 | import os 22 | import os.path as osp 23 | import re 24 | import shutil 25 | from typing import List 26 | 27 | from . 
import types 28 | 29 | 30 | def get_ext(filename: types.PathType, match_first: bool = False) -> types.PathType: 31 | if match_first: 32 | filename = osp.split(filename)[1] 33 | return filename[filename.find(".") :] 34 | else: 35 | return osp.splitext(filename)[1] 36 | 37 | 38 | def basename(filename: types.PathType, with_ext: bool = True, **kwargs) -> types.PathType: 39 | name = osp.basename(filename, **kwargs) 40 | if not with_ext: 41 | name = name.replace(get_ext(name), "") 42 | return name 43 | 44 | 45 | def natural_sorted(lst: List[types.PathType]) -> List[types.PathType]: 46 | convert = lambda text: int(text) if text.isdigit() else text.lower() 47 | alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)] 48 | return sorted(lst, key=alphanum_key) 49 | 50 | 51 | def mtime_sorted(lst: List[types.PathType]) -> List[types.PathType]: 52 | # Ascending order: last modified file will be the last one. 53 | return sorted(lst, key=lambda p: os.stat(p).st_mtime) 54 | 55 | 56 | def ls( 57 | pattern: str, 58 | *, 59 | type: str = "a", 60 | latestk: int = -1, 61 | exclude: bool = False, 62 | ) -> List[types.PathType]: 63 | filter_fn = { 64 | "f": lambda p: osp.isfile(p) and not osp.islink(p), 65 | "d": osp.isdir, 66 | "l": osp.islink, 67 | "a": lambda p: osp.isfile(p) or osp.isdir(p) or osp.islink(p), 68 | }[type] 69 | 70 | def _natural_sorted_latestk(fs): 71 | if latestk > 0: 72 | if not exclude: 73 | fs = sorted(fs, key=osp.getmtime)[::-1][:latestk] 74 | else: 75 | fs = sorted(fs, key=osp.getmtime)[::-1][latestk:] 76 | return natural_sorted(fs) 77 | 78 | if "**" in pattern: 79 | dsts = glob.glob(pattern, recursive=True) 80 | elif "*" in pattern: 81 | dsts = glob.glob(pattern) 82 | else: 83 | dsts = [osp.join(pattern, p) for p in os.listdir(pattern) if filter_fn(osp.join(pattern, p))] 84 | return _natural_sorted_latestk(dsts) 85 | 86 | dsts = [dst for dst in dsts if filter_fn(dst)] 87 | return _natural_sorted_latestk(dsts) 88 | 89 | 90 | def mv(src: types.PathType, dst: types.PathType) -> None: 91 | shutil.move(src, dst) 92 | 93 | 94 | def ln( 95 | src: types.PathType, 96 | dst: types.PathType, 97 | use_relpath: bool = True, 98 | exist_ok: bool = True, 99 | ) -> None: 100 | if osp.exists(dst): 101 | if exist_ok: 102 | rm(dst) 103 | else: 104 | raise FileExistsError(f'Force link from "{src}" to existed "{dst}".') 105 | if use_relpath: 106 | src = osp.relpath(src, start=osp.dirname(dst)) 107 | if not osp.exists(osp.dirname(dst)): 108 | mkdir(osp.dirname(dst)) 109 | os.symlink(src, dst) 110 | 111 | 112 | def cp(src: types.PathType, dst: types.PathType, **kwargs) -> None: 113 | try: 114 | shutil.copyfile(src, dst) 115 | except OSError: 116 | shutil.copytree(src, dst, **kwargs) 117 | 118 | 119 | def mkdir(dst: types.PathType, exist_ok: bool = True, **kwargs) -> None: 120 | os.makedirs(dst, exist_ok=exist_ok, **kwargs) 121 | 122 | 123 | def rm(dst: types.PathType) -> None: 124 | if osp.exists(dst): 125 | if osp.isdir(dst): 126 | shutil.rmtree(dst, ignore_errors=True) 127 | if osp.isfile(dst) or osp.islink(dst): 128 | os.remove(dst) 129 | -------------------------------------------------------------------------------- /utils/dycheck_utils/safe_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : safe_ops.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 
8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import functools 21 | from typing import Tuple 22 | 23 | import jax 24 | import jax.numpy as np 25 | 26 | 27 | @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 3)) 28 | def safe_norm( 29 | x: np.ndarray, 30 | axis: int = -1, 31 | keepdims: bool = False, 32 | _: float = 1e-9, 33 | ) -> np.ndarray: 34 | """Calculates a np.linalg.norm(d) that's safe for gradients at d=0. 35 | 36 | These gymnastics are to avoid a poorly defined gradient for 37 | np.linal.norm(0). see https://github.com/google/jax/issues/3058 for details 38 | 39 | Args: 40 | x (np.ndarray): A np.array. 41 | axis (int): The axis along which to compute the norm. 42 | keepdims (bool): if True don't squeeze the axis. 43 | tol (float): the absolute threshold within which to zero out the 44 | gradient. 45 | 46 | Returns: 47 | Equivalent to np.linalg.norm(d) 48 | """ 49 | return np.linalg.norm(x, axis=axis, keepdims=keepdims) 50 | 51 | 52 | @safe_norm.defjvp 53 | def _safe_norm_jvp( 54 | axis: int, keepdims: bool, tol: float, primals: Tuple, tangents: Tuple 55 | ) -> Tuple[np.ndarray, np.ndarray]: 56 | (x,) = primals 57 | (x_dot,) = tangents 58 | safe_tol = max(tol, 1e-30) 59 | y = safe_norm(x, tol=safe_tol, axis=axis, keepdims=True) 60 | y_safe = np.maximum(y, tol) # Prevent divide by zero. 61 | y_dot = np.where(y > safe_tol, x_dot * x / y_safe, np.zeros_like(x)) 62 | y_dot = np.sum(y_dot, axis=axis, keepdims=True) 63 | # Squeeze the axis if `keepdims` is True. 64 | if not keepdims: 65 | y = np.squeeze(y, axis=axis) 66 | y_dot = np.squeeze(y_dot, axis=axis) 67 | return y, y_dot 68 | 69 | 70 | def log1p_safe(x: np.ndarray) -> np.ndarray: 71 | return np.log1p(np.minimum(x, 3e37)) 72 | 73 | 74 | def exp_safe(x: np.ndarray) -> np.ndarray: 75 | return np.exp(np.minimum(x, 87.5)) 76 | 77 | 78 | def expm1_safe(x: np.ndarray) -> np.ndarray: 79 | return np.expm1(np.minimum(x, 87.5)) 80 | 81 | 82 | def safe_sqrt(x: np.ndarray, eps: float = 1e-7) -> np.ndarray: 83 | safe_x = np.where(x == 0, np.ones_like(x) * eps, x) 84 | return np.sqrt(safe_x) 85 | -------------------------------------------------------------------------------- /utils/dycheck_utils/struct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : structures.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. 
See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from typing import NamedTuple, Optional 21 | 22 | import numpy as np 23 | 24 | 25 | class Metadata(NamedTuple): 26 | time: Optional[np.ndarray] = None 27 | camera: Optional[np.ndarray] = None 28 | time_to: Optional[np.ndarray] = None 29 | 30 | 31 | class Rays(NamedTuple): 32 | origins: np.ndarray 33 | directions: np.ndarray 34 | pixels: np.ndarray 35 | local_directions: Optional[np.ndarray] = None 36 | radii: Optional[np.ndarray] = None 37 | metadata: Optional[Metadata] = None 38 | 39 | near: Optional[np.ndarray] = None 40 | far: Optional[np.ndarray] = None 41 | 42 | 43 | class Samples(NamedTuple): 44 | xs: np.ndarray 45 | directions: np.ndarray 46 | cov_diags: Optional[np.ndarray] = None 47 | metadata: Optional[Metadata] = None 48 | 49 | tvals: Optional[np.ndarray] = None 50 | -------------------------------------------------------------------------------- /utils/dycheck_utils/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : types.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from typing import Any, Callable, Tuple, Union 21 | 22 | # import jax.numpy as np 23 | import numpy as np 24 | 25 | 26 | PRNGKey = np.ndarray 27 | Shape = Tuple[int] 28 | Dtype = Any 29 | Array = Union[np.ndarray, np.ndarray] 30 | 31 | Activation = Callable[[Array], Array] 32 | Initializer = Callable[[PRNGKey, Shape, Dtype], Array] 33 | 34 | PathType = str 35 | ScheduleType = Callable[[int], float] 36 | EngineType = Any 37 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 
19 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/corrs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : corrs.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from typing import Optional 21 | 22 | import cv2 23 | import numpy as np 24 | 25 | from .. import image 26 | 27 | 28 | def visualize_corrs( 29 | corrs: np.ndarray, 30 | img: np.ndarray, 31 | img_to: np.ndarray, 32 | *, 33 | rgbs: Optional[np.ndarray] = None, 34 | min_rad: float = 5, 35 | subsample: int = 50, 36 | num_min_keeps: int = 10, 37 | circle_radius: int = 1, 38 | circle_thickness: int = -1, 39 | line_thickness: int = 1, 40 | alpha: float = 0.7, 41 | ): 42 | """Visualize a set of correspondences. 43 | 44 | By default this function visualizes a sparse subset of the correspondences 45 | with lines. 46 | 47 | Args: 48 | corrs (np.ndarray): A set of correspondences of shape (N, 2, 2), where 49 | the second dimension represents (from, to) and the last dimension 50 | represents (x, y) coordinates. 51 | img (np.ndarray): An image for start points of shape (Hi, Wi, 3) in 52 | either float32 or uint8. 53 | img_to (np.ndarray): An image for end points of shape (Ht, Wt, 3) in 54 | either float32 or uint8. 55 | rgbs (Optional[np.ndarray]): A set of rgbs for each correspondence 56 | of shape (N, 3) or (3,). If None then colorize by the start-point 57 | pixel coordinates. Default: None. 58 | min_rad (float): The minimum pixel displacement for a correspondence to be drawn. Default: 5. 59 | subsample (int): The subsampling factor; roughly one in every `subsample` valid correspondences is drawn. Default: 50. 60 | num_min_keeps (int): The minimum number of correspondences to keep, capped at the number of valid ones. Default: 61 | 10. 62 | circle_radius (int): The radius of the circle. Default: 1. 63 | circle_thickness (int): The thickness of the circle; -1 draws filled circles. Default: -1. 64 | line_thickness (int): The thickness of the line. Default: 1. 65 | alpha (float): The alpha value between [0, 1] for foreground blending. 66 | The bigger, the more prominent the visualization. Default: 0.7. 67 | 68 | Returns: 69 | np.ndarray: A visualization image of shape (H, W, 3) in uint8.
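    Example (illustrative only; `pts_src` and `pts_dst` are assumed to be
    matching (N, 2) pixel coordinates produced elsewhere, e.g. by a point
    tracker):

        corrs = np.stack([pts_src, pts_dst], axis=1)  # (N, 2, 2), (from, to)
        vis = visualize_corrs(corrs, img, img_to, subsample=100)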
70 | """ 71 | corrs = np.array(corrs) 72 | img = image.to_uint8(img) 73 | img_to = image.to_uint8(img_to) 74 | rng = np.random.default_rng(0) 75 | 76 | (Hi, Wi), (Ht, Wt) = img.shape[:2], img_to.shape[:2] 77 | combined = np.concatenate([img, img_to], axis=1) 78 | canvas = combined.copy() 79 | 80 | norm = np.linalg.norm(corrs[:, 1] - corrs[:, 0], axis=-1) 81 | mask = ( 82 | (norm >= min_rad) 83 | & (corrs[..., 0, 0] < Wi) 84 | & (corrs[..., 0, 0] >= 0) 85 | & (corrs[..., 0, 1] < Hi) 86 | & (corrs[..., 0, 1] >= 0) 87 | & (corrs[..., 1, 0] < Wt) 88 | & (corrs[..., 1, 0] >= 0) 89 | & (corrs[..., 1, 1] < Ht) 90 | & (corrs[..., 1, 1] >= 0) 91 | ) 92 | filtered_inds = np.nonzero(mask)[0] 93 | num_min_keeps = min( 94 | max(num_min_keeps, filtered_inds.shape[0] // subsample), 95 | filtered_inds.shape[0], 96 | ) 97 | filtered_inds = rng.choice(filtered_inds, num_min_keeps, replace=False) if filtered_inds.shape[0] > 0 else [] 98 | 99 | if len(filtered_inds) > 0: 100 | if rgbs is None: 101 | # Use normalized pixel coordinate of img for colorization. 102 | corr = corrs[:, 0] 103 | phi = 2 * np.pi * (corr[:, 0] / (Wi - 1) - 0.5) 104 | theta = np.pi * (corr[:, 1] / (Hi - 1) - 0.5) 105 | x = np.cos(theta) * np.cos(phi) 106 | y = np.cos(theta) * np.sin(phi) 107 | z = np.sin(theta) 108 | rgbs = image.to_uint8((np.stack([x, y, z], axis=-1) + 1) / 2) 109 | for idx in filtered_inds: 110 | start = tuple(corrs[idx, 0].astype(np.int32)) 111 | end = tuple((corrs[idx, 1] + [Wi, 0]).astype(np.int32)) 112 | rgb = tuple(int(c) for c in (rgbs[idx] if rgbs.ndim == 2 else rgbs)) 113 | cv2.circle( 114 | combined, 115 | start, 116 | radius=circle_radius, 117 | color=rgb, 118 | thickness=circle_thickness, 119 | lineType=cv2.LINE_AA, 120 | ) 121 | cv2.circle( 122 | combined, 123 | end, 124 | radius=circle_radius, 125 | color=rgb, 126 | thickness=circle_thickness, 127 | lineType=cv2.LINE_AA, 128 | ) 129 | if line_thickness > 0: 130 | cv2.line( 131 | canvas, 132 | start, 133 | end, 134 | color=rgb, 135 | thickness=line_thickness, 136 | lineType=cv2.LINE_AA, 137 | ) 138 | 139 | combined = cv2.addWeighted(combined, alpha, canvas, 1 - alpha, 0) 140 | return combined 141 | 142 | 143 | def visualize_chained_corrs( 144 | corrs: np.ndarray, 145 | imgs: np.ndarray, 146 | *, 147 | rgbs: Optional[np.ndarray] = None, 148 | circle_radius: int = 1, 149 | circle_thickness: int = -1, 150 | line_thickness: int = 1, 151 | alpha: float = 0.7, 152 | ): 153 | """Visualize a set of correspondences. 154 | 155 | By default this function visualizes a sparse set subset of correspondences 156 | with lines. 157 | 158 | Args: 159 | corrs (np.ndarray): A set of correspondences of shape (N, C, 2), where 160 | the second dimension represents chained frames and the last 161 | dimension represents (x, y) coordinates. 162 | imgs (np.ndarray): An image for start points of shape (C, H, W, 3) in 163 | either float32 or uint8. 164 | rgbs (Optional[np.ndarray]): A set of rgbs for each correspondence 165 | of shape (N, 3) or (3,). If None then use pixel coordinates. 166 | Default: None. 167 | circle_radius (int): The radius of the circle. Default: 1. 168 | circle_thickness (int): The thickness of the circle. Default: 1. 169 | line_thickness (int): The thickness of the line. Default: 1. 170 | alpha (float): The alpha value between [0, 1] for foreground blending. 171 | The bigger the more prominent of the visualization. Default: 0.7. 172 | 173 | Returns: 174 | np.ndarray: A visualization image of shape (H, W, 3) in uint8. 
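    Example (illustrative only; `tracks` is assumed to be an (N, C, 2) array
    of pixel trajectories over C frames, e.g. from a point tracker, and
    `frames` a (C, H, W, 3) image stack):

        vis = visualize_chained_corrs(tracks, frames, line_thickness=1)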
175 | """ 176 | corrs = np.array(corrs) 177 | imgs = image.to_uint8(imgs) 178 | 179 | C, H, W = imgs.shape[:3] 180 | combined = np.concatenate(list(imgs), axis=1) 181 | canvas = combined.copy() 182 | 183 | if rgbs is None: 184 | # Use normalized pixel coordinate of img for colorization. 185 | corr = corrs[:, 0] 186 | phi = 2 * np.pi * (corr[:, 0] / (W - 1) - 0.5) 187 | theta = np.pi * (corr[:, 1] / (H - 1) - 0.5) 188 | x = np.cos(theta) * np.cos(phi) 189 | y = np.cos(theta) * np.sin(phi) 190 | z = np.sin(theta) 191 | rgbs = image.to_uint8((np.stack([x, y, z], axis=-1) + 1) / 2) 192 | 193 | for i in range(C): 194 | mask = (corrs[..., i, 0] < W) & (corrs[..., i, 0] >= 0) & (corrs[..., i, 1] < H) & (corrs[..., i, 1] >= 0) 195 | filtered_inds = np.nonzero(mask)[0] 196 | 197 | for idx in filtered_inds: 198 | start = tuple((corrs[idx, i] + [W * i, 0]).astype(np.int32)) 199 | rgb = tuple(int(c) for c in (rgbs[idx] if rgbs.ndim == 2 else rgbs)) 200 | cv2.circle( 201 | combined, 202 | start, 203 | radius=circle_radius, 204 | color=rgb, 205 | thickness=circle_thickness, 206 | lineType=cv2.LINE_AA, 207 | ) 208 | if ( 209 | line_thickness > 0 210 | and i < C - 1 211 | and ( 212 | (corrs[idx, i + 1, 0] < W) 213 | & (corrs[idx, i + 1, 0] >= 0) 214 | & (corrs[idx, i + 1, 1] < H) 215 | & (corrs[idx, i + 1, 1] >= 0) 216 | ) 217 | ): 218 | end = tuple((corrs[idx, i + 1] + [W * (i + 1), 0]).astype(np.int32)) 219 | cv2.line( 220 | canvas, 221 | start, 222 | end, 223 | color=rgb, 224 | thickness=line_thickness, 225 | lineType=cv2.LINE_AA, 226 | ) 227 | 228 | combined = cv2.addWeighted(combined, alpha, canvas, 1 - alpha, 0) 229 | return combined 230 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/depth.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : depth.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from typing import Callable, Optional, Union 21 | 22 | import numpy as np 23 | from dycheck.utils import image 24 | from matplotlib import cm 25 | 26 | 27 | def visualize_depth( 28 | depth: np.ndarray, 29 | acc: Optional[np.ndarray] = None, 30 | near: Optional[float] = None, 31 | far: Optional[float] = None, 32 | ignore_frac: float = 0, 33 | curve_fn: Callable = lambda x: -np.log(x + np.finfo(np.float32).eps), 34 | cmap: Union[str, Callable] = "turbo", 35 | invalid_depth: float = 0, 36 | ) -> np.ndarray: 37 | """Visualize a depth map. 38 | 39 | Args: 40 | depth (np.ndarray): A depth map of shape (H, W, 1). 41 | acc (np.ndarray): An accumulation map of shape (H, W, 1) in [0, 1]. 42 | near (Optional[float]): The depth of the near plane. If None then just 43 | use the min. Default: None. 44 | far (Optional[float]): The depth of the far plane. If None then just 45 | use the max. 
Default: None. 46 | ignore_frac (float): The fraction of the depth map to ignore when 47 | automatically generating `near` and `far`. Depends on `acc` as well 48 | as `depth'. Default: 0. 49 | curve_fn (Callable): A curve function that gets applied to `depth`, 50 | `near`, and `far` before the rest of visualization. Good choices: 51 | x, 1/(x+eps), log(x+eps). Note that the default choice will flip 52 | the sign of depths, so that the default cmap (turbo) renders "near" 53 | as red and "far" as blue. Default: a negative log scale mapping. 54 | cmap (Union[str, Callable]): A cmap for colorization. Default: "turbo". 55 | invalid_depth (float): The value to use for invalid depths. Can be 56 | np.nan. Default: 0. 57 | 58 | Returns: 59 | np.ndarray: A depth visualzation image of shape (H, W, 3) in uint8. 60 | """ 61 | depth = np.array(depth) 62 | if acc is None: 63 | acc = np.ones_like(depth) 64 | else: 65 | acc = np.array(acc) 66 | if invalid_depth is not None: 67 | if invalid_depth is np.nan: 68 | acc = np.where(np.isnan(depth), np.zeros_like(acc), acc) 69 | else: 70 | acc = np.where(depth == invalid_depth, np.zeros_like(acc), acc) 71 | 72 | if near is None or far is None: 73 | # Sort `depth` and `acc` according to `depth`, then identify the depth 74 | # values that span the middle of `acc`, ignoring `ignore_frac` fraction 75 | # of `acc`. 76 | sortidx = np.argsort(depth.reshape((-1,))) 77 | depth_sorted = depth.reshape((-1,))[sortidx] 78 | acc_sorted = acc.reshape((-1,))[sortidx] # type: ignore 79 | cum_acc_sorted = np.cumsum(acc_sorted) 80 | mask = (cum_acc_sorted >= cum_acc_sorted[-1] * ignore_frac) & ( 81 | cum_acc_sorted <= cum_acc_sorted[-1] * (1 - ignore_frac) 82 | ) 83 | if invalid_depth is not None: 84 | mask &= (depth_sorted != invalid_depth) if invalid_depth is not np.nan else ~np.isnan(depth_sorted) 85 | depth_keep = depth_sorted[mask] 86 | eps = np.finfo(np.float32).eps 87 | # If `near` or `far` are None, use the highest and lowest non-NaN 88 | # values in `depth_keep` as automatic near/far planes. 89 | near = near or depth_keep[0] - eps 90 | far = far or depth_keep[-1] + eps 91 | 92 | assert near < far 93 | 94 | # Curve all values. 95 | depth, near, far = [curve_fn(x) for x in [depth, near, far]] 96 | 97 | # Scale to [0, 1]. 98 | value = np.nan_to_num(np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1))[..., 0] 99 | 100 | if isinstance(cmap, str): 101 | cmap = cm.get_cmap(cmap) 102 | color = cmap(value)[..., :3] 103 | 104 | # Set non-accumulated pixels to white. 105 | color = color * acc + (1 - acc) # type: ignore 106 | 107 | return image.to_uint8(color) 108 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/flow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : flow.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. 
See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from typing import Optional 21 | 22 | import cv2 23 | import jax 24 | import numpy as np 25 | 26 | from .. import image 27 | from .corrs import visualize_corrs 28 | 29 | 30 | def _make_colorwheel() -> np.ndarray: 31 | """Generates a classic color wheel for optical flow visualization. 32 | 33 | A Database and Evaluation Methodology for Optical Flow. 34 | Baker et al., ICCV 2007. 35 | http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 36 | 37 | Code follows the original C++ source code of Daniel Scharstein. 38 | Code follows the the Matlab source code of Deqing Sun. 39 | 40 | Returns: 41 | colorwheel (np.ndarray): Color wheel of shape (55, 3) in uint8. 42 | """ 43 | 44 | RY = 15 45 | YG = 6 46 | GC = 4 47 | CB = 11 48 | BM = 13 49 | MR = 6 50 | 51 | ncols = RY + YG + GC + CB + BM + MR 52 | colorwheel = np.zeros((ncols, 3)) 53 | col = 0 54 | 55 | # RY 56 | colorwheel[0:RY, 0] = 255 57 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 58 | col = col + RY 59 | # YG 60 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 61 | colorwheel[col : col + YG, 1] = 255 62 | col = col + YG 63 | # GC 64 | colorwheel[col : col + GC, 1] = 255 65 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 66 | col = col + GC 67 | # CB 68 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 69 | colorwheel[col : col + CB, 2] = 255 70 | col = col + CB 71 | # BM 72 | colorwheel[col : col + BM, 2] = 255 73 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 74 | col = col + BM 75 | # MR 76 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 77 | colorwheel[col : col + MR, 0] = 255 78 | return colorwheel 79 | 80 | 81 | def _flow_to_colors(flow: np.ndarray) -> np.ndarray: 82 | """Applies the flow color wheel to (possibly clipped) flow visualization 83 | image. 84 | 85 | According to the C++ source code of Daniel Scharstein. 86 | According to the Matlab source code of Deqing Sun. 87 | 88 | Args: 89 | flow (np.ndarray): Flow image of shape (H, W, 2). 90 | 91 | Returns: 92 | flow_visual (np.ndarray): Flow visualization image of shape (H, W, 3). 93 | """ 94 | u, v = jax.tree_map(lambda x: x[..., 0], np.split(flow, 2, axis=-1)) 95 | 96 | flow_visual = np.zeros(flow.shape[:2] + (3,), np.uint8) 97 | colorwheel = _make_colorwheel() 98 | ncols = colorwheel.shape[0] 99 | 100 | rad = np.sqrt(np.square(u) + np.square(v)) 101 | a = np.arctan2(-v, -u) / np.pi 102 | fk = (a + 1) / 2 * (ncols - 1) 103 | k0 = np.floor(fk).astype(np.int32) 104 | k1 = k0 + 1 105 | k1[k1 == ncols] = 0 106 | f = fk - k0 107 | for i in range(colorwheel.shape[1]): 108 | tmp = colorwheel[:, i] 109 | col0 = tmp[k0] / 255 110 | col1 = tmp[k1] / 255 111 | col = (1 - f) * col0 + f * col1 112 | idx = rad <= 1 113 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 114 | col[~idx] = col[~idx] * 0.75 # Out of range. 115 | flow_visual[..., i] = np.floor(255 * col) 116 | return flow_visual 117 | 118 | 119 | def visualize_flow( 120 | flow: np.ndarray, 121 | *, 122 | clip_flow: Optional[float] = None, 123 | rad_max: Optional[float] = None, 124 | ) -> np.ndarray: 125 | """Visualizei a flow image. 126 | 127 | Args: 128 | flow (np.ndarray): A flow image of shape (H, W, 2). 129 | clip_flow (Optional[float]): Clip flow to [0, clip_flow]. 130 | rad_max (Optional[float]): Maximum radius of the flow visualization. 
131 | 132 | Returns: 133 | np.ndarray: Flow visualization image of shape (H, W, 3). 134 | """ 135 | flow = np.array(flow) 136 | 137 | if clip_flow is not None: 138 | flow = np.clip(flow, 0, clip_flow) 139 | rad = np.linalg.norm(flow, axis=-1, keepdims=True).clip(min=1e-6) 140 | if rad_max is None: 141 | rad_max = np.full_like(rad, rad.max()) 142 | else: 143 | # Clip big flow to rad_max while homogenously scaling for the rest. 144 | rad_max = rad.clip(min=rad_max) 145 | flow = flow / rad_max 146 | 147 | return _flow_to_colors(flow) 148 | 149 | 150 | def visualize_flow_arrows( 151 | flow: np.ndarray, 152 | img: np.ndarray, 153 | *, 154 | rgbs: Optional[np.ndarray] = None, 155 | clip_flow: Optional[float] = None, 156 | min_thresh: float = 5, 157 | subsample: int = 50, 158 | num_min_keeps: int = 10, 159 | line_thickness: int = 1, 160 | tip_length: float = 0.2, 161 | alpha: float = 0.5, 162 | ) -> np.ndarray: 163 | """Visualize a flow image with arrows. 164 | 165 | Args: 166 | flow (np.ndarray): A flow image of shape (H, W, 2). 167 | img (np.ndarray): An image for start points of shape (H, W, 3) in 168 | float32 or uint8. 169 | rgbs (Optional[np.ndarray]): A color map for the arrows at each pixel 170 | location of shape (H, W, 3). Default: None. 171 | clip_flow (Optional[float]): Clip flow to [0, clip_flow]. 172 | min_thresh (float): Minimum threshold for flow magnitude. 173 | subsample (int): Subsample the flow to speed up visualization. 174 | num_min_keeps (int): The number of correspondences to keep. Default: 175 | 10. 176 | line_thickness (int): Line thickness. Default: 1. 177 | tip_length (float): Length of the arrow tip. Default: 0.2. 178 | alpha (float): The alpha value between [0, 1] for foreground blending. 179 | The bigger the more prominent of the visualization. Default: 0.5. 180 | 181 | Returns: 182 | canvas (np.ndarray): Flow visualization image of shape (H, W, 3). 
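    Example (illustrative only; `flow` is assumed to be an (H, W, 2) forward
    flow field and `img` the corresponding frame):

        vis = visualize_flow_arrows(flow, img, min_thresh=2, subsample=100)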
183 | """ 184 | img = image.to_uint8(img) 185 | canvas = img.copy() 186 | rng = np.random.default_rng(0) 187 | 188 | if rgbs is None: 189 | rgbs = visualize_flow(flow, clip_flow=clip_flow) 190 | H, W = flow.shape[:2] 191 | 192 | flow_start = np.stack(np.meshgrid(range(W), range(H)), 2) 193 | flow_end = (flow[flow_start[..., 1], flow_start[..., 0]] + flow_start).astype(np.int32) 194 | 195 | norm = np.linalg.norm(flow, axis=-1) 196 | valid_mask = ( 197 | (norm >= min_thresh) 198 | & (flow_end[..., 0] < flow.shape[1]) 199 | & (flow_end[..., 0] >= 0) 200 | & (flow_end[..., 1] < flow.shape[0]) 201 | & (flow_end[..., 1] >= 0) 202 | ) 203 | filtered_inds = np.stack(np.nonzero(valid_mask), axis=-1) 204 | num_min_keeps = min( 205 | max(num_min_keeps, filtered_inds.shape[0] // subsample), 206 | filtered_inds.shape[0], 207 | ) 208 | filtered_inds = rng.choice(filtered_inds, num_min_keeps, replace=False) if filtered_inds.shape[0] > 0 else [] 209 | 210 | for inds in filtered_inds: 211 | y, x = inds 212 | start = tuple(flow_start[y, x]) 213 | end = tuple(flow_end[y, x]) 214 | rgb = tuple(int(x) for x in rgbs[y, x]) 215 | cv2.arrowedLine( 216 | canvas, 217 | start, 218 | end, 219 | color=rgb, 220 | thickness=line_thickness, 221 | tipLength=tip_length, 222 | line_type=cv2.LINE_AA, 223 | ) 224 | 225 | canvas = cv2.addWeighted(img, alpha, canvas, 1 - alpha, 0) 226 | return canvas 227 | 228 | 229 | def visualize_flow_corrs( 230 | flow: np.ndarray, 231 | img: np.ndarray, 232 | img_to: np.ndarray, 233 | *, 234 | mask: Optional[np.ndarray] = None, 235 | rgbs: Optional[np.ndarray] = None, 236 | **kwargs, 237 | ) -> np.ndarray: 238 | """Visualize a flow image as a set of correspondences. 239 | 240 | Args: 241 | flow (np.ndarray): A flow image of shape (H, W, 2). 242 | img (np.ndarray): An image for start points of shape (H, W, 3) in 243 | float32 or uint8. 244 | img_to (np.ndarray): An image for end points of shape (H, W, 3) in 245 | float32 or uint8. 246 | mask (Optional[np.ndarray]): A hard mask for start points of shape 247 | (H, W, 1). Default: None. 248 | rgbs (Optional[np.ndarray]): A color map for the arrows at each pixel 249 | location of shape (H, W, 3). Default: None. 250 | 251 | Returns: 252 | canvas (np.ndarray): Flow visualization image of shape (H, W, 3). 253 | """ 254 | flow_start = np.stack(np.meshgrid(range(flow.shape[1]), range(flow.shape[0])), 2) 255 | flow_end = (flow[flow_start[..., 1], flow_start[..., 0]] + flow_start).astype(np.int32) 256 | flow_corrs = np.stack([flow_start, flow_end]) 257 | 258 | if mask is not None: 259 | # Only show correspondences inside of the mask. 260 | flow_corrs = flow_corrs[np.stack([mask[..., 0]] * 2)] 261 | if rgbs is not None: 262 | rgbs = rgbs[mask[..., 0]] 263 | if flow_corrs.shape[0] == 0: 264 | flow_corrs = np.ones((0, 4)) 265 | if rgbs is not None: 266 | rgbs = np.ones((0, 3)) 267 | flow_corrs = flow_corrs.reshape(2, -1, 2).swapaxes(0, 1) 268 | 269 | return visualize_corrs(flow_corrs, img, img_to, rgbs=rgbs, **kwargs) 270 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/kps/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : __init__.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 
8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | from copy import deepcopy 21 | from typing import Callable, Optional, Union 22 | 23 | import cv2 24 | import numpy as np 25 | from dycheck.utils import image 26 | 27 | from .skeleton import SKELETON_MAP, Skeleton 28 | 29 | 30 | def visualize_kps( 31 | kps: np.ndarray, 32 | img: np.ndarray, 33 | *, 34 | skeleton: Union[str, Skeleton, Callable[..., Skeleton]] = "unconnected", 35 | rgbs: Optional[np.ndarray] = None, 36 | kp_radius: int = 4, 37 | bone_thickness: int = 3, 38 | **kwargs, 39 | ) -> np.ndarray: 40 | """Visualize 2D keypoints with their skeleton. 41 | 42 | Args: 43 | kps (np.ndarray): an array of shape (J, 3) for keypoints. Expect the 44 | last column to be the visibility in [0, 1]. 45 | img (np.ndarray): a RGB image of shape (H, W, 3) in float32 or uint8. 46 | skeleton_cls (Union[str, Callable[..., Skeleton]]): a class name or a 47 | callable that returns a Skeleton instance. 48 | rgbs (Optional[np.ndarray]): A set of rgbs for each keypoint of shape 49 | (J, 3) or (3,). If None then use skeleton palette. Default: None. 50 | kp_radius (int): the radius of kps for visualization. Default: 4. 51 | bone_thickness (int): the thickness of bones connecting kps for 52 | visualization. Default: 3. 53 | 54 | Returns: 55 | combined (np.ndarray): Keypoint visualzation image of shape (H, W, 3) 56 | in uint8. 57 | """ 58 | if isinstance(skeleton, str): 59 | skeleton = SKELETON_MAP[skeleton] 60 | if isinstance(skeleton, Callable): 61 | skeleton = skeleton(num_kps=len(kps), **kwargs) 62 | if rgbs is not None: 63 | if rgbs.ndim == 1: 64 | rgbs = rgbs[None, :].repeat(skeleton.num_kps, axis=0) 65 | skeleton = deepcopy(skeleton) 66 | skeleton._palette = rgbs.tolist() 67 | 68 | assert skeleton.num_kps == len(kps) 69 | 70 | kps = np.array(kps) 71 | img = image.to_uint8(img) 72 | 73 | H, W = img.shape[:2] 74 | canvas = img.copy() 75 | 76 | mask = (kps[:, -1] != 0) & (kps[:, 0] >= 0) & (kps[:, 0] < W) & (kps[:, 1] >= 0) & (kps[:, 1] < H) 77 | 78 | # Visualize bones. 79 | palette = skeleton.non_root_palette 80 | bones = skeleton.non_root_bones 81 | for rgb, (j, p) in zip(palette, bones): 82 | # Skip invisible keypoints. 83 | if (~mask[[j, p]]).any(): 84 | continue 85 | 86 | kp_p, kp_j = kps[p, :2], kps[j, :2] 87 | kp_mid = (kp_p + kp_j) / 2 88 | bone_length = np.linalg.norm(kp_j - kp_p) 89 | bone_angle = (np.arctan2(kp_j[1] - kp_p[1], kp_j[0] - kp_p[0])) * 180 / np.pi 90 | polygon = cv2.ellipse2Poly( 91 | (int(kp_mid[0]), int(kp_mid[1])), 92 | (int(bone_length / 2), bone_thickness), 93 | int(bone_angle), 94 | arcStart=0, 95 | arcEnd=360, 96 | delta=5, 97 | ) 98 | cv2.fillConvexPoly(canvas, polygon, rgb, lineType=cv2.LINE_AA) 99 | canvas = cv2.addWeighted(img, 0.5, canvas, 0.5, 0) 100 | 101 | # Visualize keypoints. 102 | combined = canvas.copy() 103 | palette = skeleton.palette 104 | for rgb, kp, valid in zip(palette, kps, mask): 105 | # Skip invisible keypoints. 
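        # (`mask` above already folds the visibility flag kps[:, -1] together
        # with the in-bounds check, so `valid` is also False for keypoints
        # that fall outside the image.)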
106 | if not valid: 107 | continue 108 | 109 | cv2.circle( 110 | combined, 111 | (int(kp[0]), int(kp[1])), 112 | radius=kp_radius, 113 | color=rgb, 114 | thickness=-1, 115 | lineType=cv2.LINE_AA, 116 | ) 117 | combined = cv2.addWeighted(canvas, 0.3, combined, 0.7, 0) 118 | 119 | return combined 120 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/kps/skeleton.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : skeleton.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import re 21 | from typing import Callable, Optional, Sequence, Tuple, Union 22 | 23 | import gin 24 | import numpy as np 25 | from dycheck.utils import image 26 | from matplotlib import cm 27 | 28 | 29 | KP_PALETTE_MAP = {} 30 | 31 | 32 | @gin.configurable() 33 | class Skeleton(object): 34 | name = "skeleton" 35 | _anonymous_kp_name = "ANONYMOUS KP" 36 | 37 | def __init__( 38 | self, 39 | parents: Sequence[Optional[int]], 40 | kp_names: Optional[Sequence[str]] = None, 41 | palette: Optional[Sequence[Tuple[int, int, int]]] = None, 42 | ): 43 | if kp_names is not None: 44 | assert len(parents) == len(kp_names) 45 | if palette is not None: 46 | assert len(kp_names) == len(palette) 47 | 48 | self._parents = parents 49 | self._kp_names = kp_names if kp_names is not None else [self._anonymous_kp_name] * self.num_kps 50 | self._palette = palette 51 | 52 | def asdict(self): 53 | return { 54 | "name": self.name, 55 | "parents": self.parents, 56 | "kp_names": self.kp_names, 57 | "palette": self.palette, 58 | } 59 | 60 | @property 61 | def is_unconnected(self): 62 | return all([p is None for p in self._parents]) 63 | 64 | @property 65 | def parents(self): 66 | return self._parents 67 | 68 | @property 69 | def kp_names(self): 70 | return self._kp_names 71 | 72 | @property 73 | def palette(self): 74 | if self._palette is not None: 75 | return self._palette 76 | 77 | if self.kp_names[0] != self._anonymous_kp_name and all( 78 | [kp_name in KP_PALETTE_MAP for kp_name in self.kp_names] 79 | ): 80 | return [KP_PALETTE_MAP[kp_name] for kp_name in self.kp_names] 81 | 82 | palette = np.zeros((self.num_kps, 3), dtype=np.uint8) 83 | left_mask = np.array( 84 | [len(re.findall(r"^(\w+ |)L\w+$", kp_name)) > 0 for kp_name in self._kp_names], 85 | dtype=np.bool, 86 | ) 87 | palette[left_mask] = (255, 0, 0) 88 | return [tuple(color.tolist()) for color in palette] 89 | 90 | @property 91 | def num_kps(self): 92 | return len(self._parents) 93 | 94 | @property 95 | def root_idx(self): 96 | if self.is_unconnected: 97 | return 0 98 | return self._parents.index(-1) 99 | 100 | @property 101 | def bones(self): 102 | if self.is_unconnected: 103 | return [] 104 | return np.stack([list(range(self.num_kps)), self.parents]).T.tolist() 105 | 106 | 
@property 107 | def non_root_bones(self): 108 | if self.is_unconnected: 109 | return [] 110 | return np.delete(self.bones.copy(), self.root_idx, axis=0) 111 | 112 | @property 113 | def non_root_palette(self): 114 | if self.is_unconnected: 115 | return [] 116 | return np.delete(self.palette.copy(), self.root_idx, axis=0).tolist() 117 | 118 | 119 | @gin.configurable() 120 | class UnconnectedSkeleton(Skeleton): 121 | """A keypoint skeleton that does not define parents. This could be useful 122 | when organizing randomly annotated keypoints. 123 | """ 124 | 125 | name: str = "unconnected" 126 | 127 | def __init__(self, num_kps: int, cmap: Union[str, Callable] = "gist_rainbow"): 128 | if isinstance(cmap, str): 129 | cmap = cm.get_cmap(cmap, num_kps) 130 | pallete = image.to_uint8(np.array([cmap(i)[:3] for i in range(num_kps)], np.float32)).tolist() 131 | super().__init__( 132 | parents=[None for _ in range(num_kps)], 133 | kp_names=[f"KP_{i}" for i in range(num_kps)], 134 | palette=pallete, 135 | ) 136 | 137 | 138 | @gin.configurable() 139 | class HumanSkeleton(Skeleton): 140 | """A human skeleton following the COCO dataset. 141 | 142 | Microsoft COCO: Common Objects in Context. 143 | Lin et al., ECCV 2014. 144 | https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48 145 | 146 | For pictorial definition, see also: shorturl.at/ilnpZ. 147 | """ 148 | 149 | name: str = "human" 150 | 151 | def __init__(self, **_): 152 | super().__init__( 153 | parents=[ 154 | 1, 155 | -1, 156 | 1, 157 | 2, 158 | 3, 159 | 1, 160 | 5, 161 | 6, 162 | 1, 163 | 8, 164 | 9, 165 | 1, 166 | 11, 167 | 12, 168 | 0, 169 | 0, 170 | 14, 171 | 15, 172 | ], 173 | kp_names=[ 174 | "Nose", 175 | "Neck", 176 | "RShoulder", 177 | "RElbow", 178 | "RWrist", 179 | "LShoulder", 180 | "LElbow", 181 | "LWrist", 182 | "RHip", 183 | "RKnee", 184 | "RAnkle", 185 | "LHip", 186 | "LKnee", 187 | "LAnkle", 188 | "REye", 189 | "LEye", 190 | "REar", 191 | "LEar", 192 | ], 193 | palette=[ 194 | (255, 0, 0), 195 | (255, 85, 0), 196 | (255, 170, 0), 197 | (255, 255, 0), 198 | (170, 255, 0), 199 | (85, 255, 0), 200 | (0, 255, 0), 201 | (0, 255, 85), 202 | (0, 255, 170), 203 | (0, 255, 255), 204 | (0, 170, 255), 205 | (0, 85, 255), 206 | (0, 0, 255), 207 | (85, 0, 255), 208 | (170, 0, 255), 209 | (255, 0, 255), 210 | (255, 0, 170), 211 | (255, 0, 85), 212 | ], 213 | ) 214 | 215 | 216 | @gin.configurable() 217 | class QuadrupedSkeleton(Skeleton): 218 | """A quadruped skeleton following StanfordExtra dataset. 219 | 220 | Novel dataset for Fine-Grained Image Categorization. 221 | Khosla et al., CVPR 2011, FGVC workshop. 222 | http://vision.stanford.edu/aditya86/ImageNetDogs/main.html 223 | 224 | Who Left the Dogs Out? 3D Animal Reconstruction with Expectation 225 | Maximization in the Loop. 226 | Biggs et al., ECCV 2020. 
227 | https://arxiv.org/abs/2007.11110 228 | """ 229 | 230 | name: str = "quadruped" 231 | 232 | def __init__(self, **_): 233 | super().__init__( 234 | parents=[ 235 | 1, 236 | 2, 237 | 22, 238 | 4, 239 | 5, 240 | 12, 241 | 7, 242 | 8, 243 | 22, 244 | 10, 245 | 11, 246 | 12, 247 | -1, 248 | 12, 249 | 20, 250 | 21, 251 | 17, 252 | 23, 253 | 14, 254 | 15, 255 | 16, 256 | 16, 257 | 12, 258 | 22, 259 | ], 260 | kp_names=[ 261 | "LFrontPaw", 262 | "LFrontWrist", 263 | "LFrontElbow", 264 | "LRearPaw", 265 | "LRearWrist", 266 | "LRearElbow", 267 | "RFrontPaw", 268 | "RFrontWrist", 269 | "RFrontElbow", 270 | "RRearPaw", 271 | "RRearWrist", 272 | "RRearElbow", 273 | "TailStart", 274 | "TailEnd", 275 | "LEar", 276 | "REar", 277 | "Nose", 278 | "Chin", 279 | "LEarTip", 280 | "REarTip", 281 | "LEye", 282 | "REye", 283 | "Withers", 284 | "Throat", 285 | ], 286 | palette=[ 287 | (0, 255, 0), 288 | (63, 255, 0), 289 | (127, 255, 0), 290 | (0, 0, 255), 291 | (0, 63, 255), 292 | (0, 127, 255), 293 | (255, 255, 0), 294 | (255, 191, 0), 295 | (255, 127, 0), 296 | (0, 255, 255), 297 | (0, 255, 191), 298 | (0, 255, 127), 299 | (0, 0, 0), 300 | (0, 0, 0), 301 | (255, 0, 170), 302 | (255, 0, 170), 303 | (255, 0, 170), 304 | (255, 0, 170), 305 | (255, 0, 170), 306 | (255, 0, 170), 307 | (255, 0, 170), 308 | (255, 0, 170), 309 | (255, 0, 170), 310 | (255, 0, 170), 311 | ], 312 | ) 313 | 314 | 315 | SKELETON_MAP = { 316 | cls.name: cls 317 | for cls in [ 318 | Skeleton, 319 | UnconnectedSkeleton, 320 | HumanSkeleton, 321 | QuadrupedSkeleton, 322 | ] 323 | } 324 | -------------------------------------------------------------------------------- /utils/dycheck_utils/visuals/rendering.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File : rendering.py 4 | # Author : Hang Gao 5 | # Email : hangg.sv7@gmail.com 6 | # 7 | # Copyright 2022 Adobe. All rights reserved. 8 | # 9 | # This file is licensed to you under the Apache License, Version 2.0 (the 10 | # "License"); you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 17 | # License for the specific language governing permissions and limitations under 18 | # the License. 19 | 20 | import matplotlib 21 | 22 | 23 | matplotlib.use("agg") 24 | 25 | from typing import NamedTuple, Optional, Tuple 26 | 27 | import matplotlib.pyplot as plt 28 | import numpy as np 29 | from dycheck import geometry 30 | from matplotlib.backends.backend_agg import FigureCanvasAgg 31 | from matplotlib.collections import PolyCollection 32 | 33 | from .. import image 34 | 35 | 36 | class Renderings(NamedTuple): 37 | # (H, W, 3) in uint8. 38 | rgb: Optional[np.ndarray] = None 39 | # (H, W, 1) in float32. 40 | depth: Optional[np.ndarray] = None 41 | # (H, W, 1) in [0, 1]. 42 | acc: Optional[np.ndarray] = None 43 | 44 | 45 | def visualize_pcd_renderings(points: np.ndarray, point_rgbs: np.ndarray, camera: geometry.Camera, **_) -> Renderings: 46 | """Visualize a point cloud as a set renderings. 47 | 48 | Args: 49 | points (np.ndarray): (N, 3) array of points. 50 | point_rgbs (np.ndarray): (N, 3) array of point colors in either uint8 51 | or float32. 
52 | camera (geometry.Camera): a camera object containing view information. 53 | 54 | Returns: 55 | Renderings: the image output object. 56 | """ 57 | 58 | point_rgbs = image.to_uint8(point_rgbs) # type: ignore 59 | 60 | # Setup the camera. 61 | W, H = camera.image_size 62 | 63 | # project the 3D points to 2D on image plane 64 | pixels, depths = camera.project(points, return_depth=True, use_projective_depth=True) 65 | pixels = pixels.astype(np.int32) 66 | mask = (pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] >= 0) & (pixels[:, 1] < H) & (depths[:, 0] > 0) 67 | 68 | pixels = pixels[mask] 69 | rgbs = point_rgbs[mask] 70 | depths = depths[mask] 71 | 72 | sorted_inds = np.argsort(depths[..., 0])[::-1] 73 | pixels = pixels[sorted_inds] 74 | rgbs = rgbs[sorted_inds] 75 | depths = depths[sorted_inds] 76 | 77 | rgb = np.full((H, W, 3), 255, dtype=np.uint8) 78 | rgb[pixels[:, 1], pixels[:, 0]] = rgbs 79 | 80 | depth = np.zeros((H, W, 1), dtype=np.float32) 81 | depth[pixels[:, 1], pixels[:, 0]] = depths 82 | 83 | acc = np.zeros((H, W, 1), dtype=np.float32) 84 | acc[pixels[:, 1], pixels[:, 0]] = 1 85 | 86 | return Renderings(rgb=rgb, depth=depth, acc=acc) 87 | 88 | 89 | def _is_front(T: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 90 | # Triangle is front facing if its projection on XY plane is clock-wise. 91 | Z = ( 92 | (T[:, 1, 0] - T[:, 0, 0]) * (T[:, 1, 1] + T[:, 0, 1]) 93 | + (T[:, 2, 0] - T[:, 1, 0]) * (T[:, 2, 1] + T[:, 1, 1]) 94 | + (T[:, 0, 0] - T[:, 2, 0]) * (T[:, 0, 1] + T[:, 2, 1]) 95 | ) 96 | return Z >= 0 97 | 98 | 99 | def grid_faces(h: int, w: int, mask: Optional[np.ndarray] = None) -> np.ndarray: 100 | """Creates mesh face indices from a given pixel grid size. 101 | 102 | Args: 103 | h (int): image height. 104 | w (int): image width. 105 | mask (Optional[np.ndarray], optional): mask of valid pixels. Defaults 106 | to None. 107 | 108 | Returns: 109 | faces (np.ndarray): array of face indices. Note that the face indices 110 | include invalid pixels (they are not excluded). 111 | """ 112 | if mask is None: 113 | mask = np.ones((h, w), bool) 114 | 115 | x, y = np.meshgrid(range(w - 1), range(h - 1)) 116 | tl = y * w + x 117 | tr = y * w + x + 1 118 | bl = (y + 1) * w + x 119 | br = (y + 1) * w + x + 1 120 | faces_1 = np.stack([tl, bl, tr], axis=-1)[mask[:-1, :-1] & mask[1:, :-1] & mask[:-1, 1:]] 121 | faces_2 = np.stack([br, tr, bl], axis=-1)[mask[1:, 1:] & mask[:-1, 1:] & mask[1:, :-1]] 122 | faces = np.concatenate([faces_1, faces_2], axis=0) 123 | return faces 124 | 125 | 126 | def visualize_mesh_renderings( 127 | points: np.ndarray, faces: np.ndarray, point_rgbs: np.ndarray, camera: geometry.Camera, **_ 128 | ): 129 | """Visualize a mesh as a set renderings. 130 | 131 | Note that front facing triangles are defined in clock-wise orientation. 132 | 133 | Args: 134 | points (np.ndarray): (N, 3) array of points. 135 | faces (np.ndarray): (F, 3) array of faces. 136 | point_rgbs (np.ndarray): (N, 3) array of point colors in either uint8 137 | or float32. 138 | camera (geometry.Camera): a camera object containing view information. 139 | 140 | Returns: 141 | Renderings: the image output object. 142 | """ 143 | 144 | # High quality output. 145 | DPI = 10.0 146 | 147 | face_rgbs = image.to_float32(point_rgbs[faces]).mean(axis=-2) 148 | 149 | # Setup the camera. 150 | W, H = camera.image_size 151 | 152 | # Project the 3D points to 2D on image plane. 
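    # `camera.project` is expected to return (N, 2) pixel coordinates plus a
    # per-point depth; the depth is used below to cull back-facing triangles
    # and to sort the remaining ones back-to-front (painter's algorithm).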
153 | pixels, depth = camera.project(points, return_depth=True, use_projective_depth=True) 154 | 155 | T = pixels[faces] 156 | Z = -depth[faces][..., 0].mean(axis=1) 157 | front = _is_front(T) 158 | T, Z = T[front], Z[front] 159 | face_rgbs = face_rgbs[front] 160 | 161 | # Sort triangles according to z buffer. 162 | triangles = T[:, :, :2] 163 | sorted_inds = np.argsort(Z) 164 | triangles = triangles[sorted_inds] 165 | face_rgbs = face_rgbs[sorted_inds] 166 | 167 | # Painter's algorithm using matplotlib. 168 | fig = plt.figure(figsize=(W / DPI, H / DPI), dpi=DPI) 169 | canvas = FigureCanvasAgg(fig) 170 | ax = fig.add_axes([0, 0, 1, 1], xlim=[0, W], ylim=[H, 0], aspect=1) 171 | ax.axis("off") 172 | 173 | collection = PolyCollection([], closed=True) 174 | collection.set_verts(triangles) 175 | collection.set_linewidths(0.0) 176 | collection.set_facecolors(face_rgbs) 177 | ax.add_collection(collection) 178 | 179 | canvas.draw() 180 | s, _ = canvas.print_to_buffer() 181 | img = np.frombuffer(s, np.uint8).reshape((H, W, 4)) 182 | plt.close(fig) 183 | 184 | rgb = img[..., :3] 185 | acc = (img[..., -1:] > 0).astype(np.float32) 186 | 187 | # Depth is not fully supported by maptlotlib yet and we render it as if 188 | # it's an image by a hack. 189 | fig = plt.figure(figsize=(W / DPI, H / DPI), dpi=DPI) 190 | canvas = FigureCanvasAgg(fig) 191 | ax = fig.add_axes([0, 0, 1, 1], xlim=[0, W], ylim=[H, 0], aspect=1) 192 | ax.axis("off") 193 | 194 | collection = PolyCollection([], closed=True) 195 | collection.set_verts(triangles) 196 | collection.set_linewidths(0.0) 197 | Z = -Z[sorted_inds] 198 | Zmin = Z.min() 199 | Zmax = Z.max() 200 | Z = (Z - Zmin) / (Zmax - Zmin) 201 | collection.set_facecolors(Z[..., None].repeat(3, axis=-1)) 202 | ax.add_collection(collection) 203 | 204 | canvas.draw() 205 | s, _ = canvas.print_to_buffer() 206 | depth = image.to_float32(np.frombuffer(s, np.uint8).reshape((H, W, 4))[..., :1]) * (Zmax - Zmin) + Zmin 207 | depth[acc[..., 0] == 0] = 0 208 | plt.close(fig) 209 | 210 | return Renderings(rgb=rgb, depth=depth, acc=acc) 211 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import random 13 | import sys 14 | from datetime import datetime 15 | 16 | import numpy as np 17 | import torch 18 | 19 | 20 | def inverse_sigmoid(x): 21 | return torch.log(x / (1 - x)) 22 | 23 | 24 | def PILtoTorch(pil_image, resolution): 25 | if resolution is not None: 26 | resized_image_PIL = pil_image.resize(resolution) 27 | else: 28 | resized_image_PIL = pil_image 29 | if np.array(resized_image_PIL).max() != 1: 30 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 31 | else: 32 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) 33 | if len(resized_image.shape) == 3: 34 | return resized_image.permute(2, 0, 1) 35 | else: 36 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 37 | 38 | 39 | def get_expon_lr_func(lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000): 40 | """ 41 | Copied from Plenoxels 42 | 43 | Continuous learning rate decay function. 
Adapted from JaxNeRF 44 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 45 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 46 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 47 | function of lr_delay_mult, such that the initial learning rate is 48 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 49 | to the normal learning rate when steps>lr_delay_steps. 50 | :param conf: config subtree 'lr' or similar 51 | :param max_steps: int, the number of steps during optimization. 52 | :return HoF which takes step as input 53 | """ 54 | 55 | def helper(step): 56 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 57 | # Disable this parameter 58 | return 0.0 59 | if lr_delay_steps > 0: 60 | # A kind of reverse cosine decay. 61 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 62 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 63 | ) 64 | else: 65 | delay_rate = 1.0 66 | t = np.clip(step / max_steps, 0, 1) 67 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 68 | 69 | return delay_rate * log_lerp 70 | 71 | return helper 72 | 73 | 74 | def strip_lowerdiag(L): 75 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 76 | 77 | uncertainty[:, 0] = L[:, 0, 0] 78 | uncertainty[:, 1] = L[:, 0, 1] 79 | uncertainty[:, 2] = L[:, 0, 2] 80 | uncertainty[:, 3] = L[:, 1, 1] 81 | uncertainty[:, 4] = L[:, 1, 2] 82 | uncertainty[:, 5] = L[:, 2, 2] 83 | return uncertainty 84 | 85 | 86 | def strip_symmetric(sym): 87 | return strip_lowerdiag(sym) 88 | 89 | 90 | def build_rotation(r): 91 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 92 | 93 | q = r / norm[:, None] 94 | 95 | R = torch.zeros((q.size(0), 3, 3), device="cuda") 96 | 97 | r = q[:, 0] 98 | x = q[:, 1] 99 | y = q[:, 2] 100 | z = q[:, 3] 101 | 102 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 103 | R[:, 0, 1] = 2 * (x * y - r * z) 104 | R[:, 0, 2] = 2 * (x * z + r * y) 105 | R[:, 1, 0] = 2 * (x * y + r * z) 106 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 107 | R[:, 1, 2] = 2 * (y * z - r * x) 108 | R[:, 2, 0] = 2 * (x * z - r * y) 109 | R[:, 2, 1] = 2 * (y * z + r * x) 110 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 111 | return R 112 | 113 | 114 | def build_scaling_rotation(s, r): 115 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 116 | R = build_rotation(r) 117 | 118 | L[:, 0, 0] = s[:, 0] 119 | L[:, 1, 1] = s[:, 1] 120 | L[:, 2, 2] = s[:, 2] 121 | 122 | L = R @ L 123 | return L 124 | 125 | 126 | def safe_state(silent, seed=0): 127 | old_f = sys.stdout 128 | 129 | class F: 130 | def __init__(self, silent): 131 | self.silent = silent 132 | 133 | def write(self, x): 134 | if not self.silent: 135 | if x.endswith("\n"): 136 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 137 | else: 138 | old_f.write(x) 139 | 140 | def flush(self): 141 | old_f.flush() 142 | 143 | sys.stdout = F(silent) 144 | 145 | torch.manual_seed(seed) 146 | torch.cuda.manual_seed_all(seed) 147 | np.random.seed(seed) 148 | random.seed(seed) 149 | torch.backends.cudnn.deterministic = True 150 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import math 13 | from typing import NamedTuple 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | # from gsplat._torch_impl import clip_near_plane, scale_rot_to_cov3d, project_cov3d_ewa, compute_cov2d_bounds, project_pix 20 | 21 | 22 | class BasicPointCloud(NamedTuple): 23 | points: np.array 24 | colors: np.array 25 | normals: np.array 26 | times: np.array 27 | 28 | 29 | def geom_transform_points(points, transf_matrix): 30 | P, _ = points.shape 31 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 32 | points_hom = torch.cat([points, ones], dim=1) 33 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 34 | 35 | denom = points_out[..., 3:] + 0.0000001 36 | return (points_out[..., :3] / denom).squeeze(dim=0) 37 | 38 | 39 | def getWorld2View(R, t): 40 | Rt = np.zeros((4, 4)) 41 | Rt[:3, :3] = R.transpose() 42 | Rt[:3, 3] = t 43 | Rt[3, 3] = 1.0 44 | return np.float32(Rt) 45 | 46 | 47 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0): 48 | Rt = np.zeros((4, 4)) 49 | Rt[:3, :3] = R.transpose() 50 | Rt[:3, 3] = t 51 | Rt[3, 3] = 1.0 52 | 53 | C2W = np.linalg.inv(Rt) 54 | cam_center = C2W[:3, 3] 55 | cam_center = (cam_center + translate) * scale 56 | C2W[:3, 3] = cam_center 57 | Rt = np.linalg.inv(C2W) 58 | return np.float32(Rt) 59 | 60 | def getWorld2View2_torch(R, t, translate=torch.tensor([.0, .0, .0]), scale=1.0): 61 | Rt = torch.cat([R.transpose(0, 1), t.unsqueeze(1)], dim=1) 62 | Rt_fill = torch.tensor([0, 0, 0, 1], dtype=Rt.dtype, device=Rt.device).unsqueeze(0) 63 | Rt = torch.cat([Rt, Rt_fill], dim=0) 64 | return Rt 65 | 66 | def getProjectionMatrix(znear, zfar, fovX, fovY): 67 | tanHalfFovY = math.tan((fovY / 2)) 68 | tanHalfFovX = math.tan((fovX / 2)) 69 | 70 | top = tanHalfFovY * znear 71 | bottom = -top 72 | right = tanHalfFovX * znear 73 | left = -right 74 | 75 | P = torch.zeros(4, 4) 76 | 77 | z_sign = 1.0 78 | 79 | P[0, 0] = 2.0 * znear / (right - left) 80 | P[1, 1] = 2.0 * znear / (top - bottom) 81 | P[0, 2] = (right + left) / (right - left) 82 | P[1, 2] = (top + bottom) / (top - bottom) 83 | P[3, 2] = z_sign 84 | P[2, 2] = z_sign * zfar / (zfar - znear) 85 | P[2, 3] = -(zfar * znear) / (zfar - znear) 86 | return P 87 | 88 | 89 | def fov2focal(fov, pixels): 90 | return pixels / (2 * math.tan(fov / 2)) 91 | 92 | 93 | def focal2fov(focal, pixels): 94 | return 2 * math.atan(pixels / (2 * focal)) 95 | 96 | 97 | def apply_rotation(q1, q2): 98 | """ 99 | Applies a rotation to a quaternion. 100 | 101 | Parameters: 102 | q1 (Tensor): The original quaternion. 103 | q2 (Tensor): The rotation quaternion to be applied. 104 | 105 | Returns: 106 | Tensor: The resulting quaternion after applying the rotation. 
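    Example (illustrative): composing two 90-degree rotations about the z-axis
    yields a 180-degree rotation about z (quaternions are ordered w, x, y, z):

        q_90z = torch.tensor([0.7071, 0.0, 0.0, 0.7071])
        apply_rotation(q_90z, q_90z)  # ~ tensor([0., 0., 0., 1.])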
107 | """ 108 | # Extract components for readability 109 | w1, x1, y1, z1 = q1 110 | w2, x2, y2, z2 = q2 111 | 112 | # Compute the product of the two quaternions 113 | w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 114 | x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 115 | y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 116 | z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 117 | 118 | # Combine the components into a new quaternion tensor 119 | q3 = torch.tensor([w3, x3, y3, z3]) 120 | 121 | # Normalize the resulting quaternion 122 | q3_normalized = q3 / torch.norm(q3) 123 | 124 | return q3_normalized 125 | 126 | 127 | def batch_quaternion_multiply(q1, q2): 128 | """ 129 | Multiply batches of quaternions. 130 | 131 | Args: 132 | - q1 (torch.Tensor): A tensor of shape [N, 4] representing the first batch of quaternions. 133 | - q2 (torch.Tensor): A tensor of shape [N, 4] representing the second batch of quaternions. 134 | 135 | Returns: 136 | - torch.Tensor: The resulting batch of quaternions after applying the rotation. 137 | """ 138 | # Calculate the product of each quaternion in the batch 139 | w = q1[:, 0] * q2[:, 0] - q1[:, 1] * q2[:, 1] - q1[:, 2] * q2[:, 2] - q1[:, 3] * q2[:, 3] 140 | x = q1[:, 0] * q2[:, 1] + q1[:, 1] * q2[:, 0] + q1[:, 2] * q2[:, 3] - q1[:, 3] * q2[:, 2] 141 | y = q1[:, 0] * q2[:, 2] - q1[:, 1] * q2[:, 3] + q1[:, 2] * q2[:, 0] + q1[:, 3] * q2[:, 1] 142 | z = q1[:, 0] * q2[:, 3] + q1[:, 1] * q2[:, 2] - q1[:, 2] * q2[:, 1] + q1[:, 3] * q2[:, 0] 143 | 144 | # Combine into new quaternions 145 | q3 = torch.stack((w, x, y, z), dim=1) 146 | 147 | # Normalize the quaternions 148 | norm_q3 = q3 / torch.norm(q3, dim=1, keepdim=True) 149 | 150 | return norm_q3 151 | 152 | 153 | def cam2pixel(pts, K): 154 | pixels = torch.matmul(K, pts.unsqueeze(-1)).squeeze(-1) 155 | pixels = pixels[:, :2] / (pixels[:, 2:] + 0.0000001) 156 | return pixels 157 | 158 | 159 | def pts2pixel(pts, cam_info, K): 160 | cam_pts = geom_transform_points( 161 | pts, cam_info.world_view_transform.to(pts.device) 162 | ) #! this is column-wise transformation 163 | # # check 164 | # view_R = torch.tensor(cam_info.R.T, device=pts.device) 165 | # view_T = torch.tensor(cam_info.T, device=pts.device) 166 | # cam_pts = torch.matmul(view_R[None], pts[:, :, None]).squeeze(-1) + view_T[None] 167 | 168 | pixels = cam2pixel(cam_pts, K) 169 | return pixels 170 | 171 | 172 | def project_gaussians_forward( 173 | means3d, 174 | scales, 175 | glob_scale, 176 | quats, 177 | viewmat, 178 | intrins, 179 | img_size, 180 | clip_thresh=0.01, 181 | ): 182 | fx, fy, cx, cy = intrins 183 | tan_fovx = 0.5 * img_size[0] / fx 184 | tan_fovy = 0.5 * img_size[1] / fy 185 | p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh) 186 | cov3d = scale_rot_to_cov3d(scales, glob_scale, quats) 187 | cov2d, compensation = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy) 188 | xys = project_pix((fx, fy), p_view, (cx, cy)) 189 | 190 | return xys, cov2d 191 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | @torch.no_grad() 20 | def psnr(img1, img2, mask=None): 21 | if mask is not None: 22 | img1 = img1.flatten(1) 23 | img2 = img2.flatten(1) 24 | 25 | mask = mask.flatten(1).repeat(3, 1) 26 | mask = torch.where(mask != 0, True, False) 27 | img1 = img1[mask] 28 | img2 = img2[mask] 29 | 30 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 31 | 32 | else: 33 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 34 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float())) 35 | if mask is not None: 36 | if torch.isinf(psnr).any(): 37 | print(mse.mean(), psnr.mean()) 38 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float())) 39 | psnr = psnr[~torch.isinf(psnr)] 40 | 41 | return psnr 42 | -------------------------------------------------------------------------------- /utils/loader_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | def get_stamp_list(dataset, timestamp): 8 | frame_length = int(len(dataset) / len(dataset.dataset.poses)) 9 | # print(frame_length) 10 | if timestamp > frame_length: 11 | raise IndexError("input timestamp bigger than total timestamp.") 12 | print("select index:", [i * frame_length + timestamp for i in range(len(dataset.dataset.poses))]) 13 | return [dataset[i * frame_length + timestamp] for i in range(len(dataset.dataset.poses))] 14 | 15 | 16 | class FineSampler(Sampler): 17 | def __init__(self, dataset): 18 | self.len_dataset = len(dataset) 19 | self.len_pose = len(dataset.dataset.poses) 20 | self.frame_length = int(self.len_dataset / self.len_pose) 21 | 22 | sample_list = [] 23 | for i in range(self.frame_length): 24 | for j in range(4): 25 | idx = torch.randperm(self.len_pose) * self.frame_length + i 26 | # print(idx) 27 | # breakpoint() 28 | now_list = [] 29 | cnt = 0 30 | for item in idx.tolist(): 31 | now_list.append(item) 32 | cnt += 1 33 | if cnt % 2 == 0 and len(sample_list) > 2: 34 | select_element = [x for x in random.sample(sample_list, 2)] 35 | now_list += select_element 36 | 37 | sample_list += now_list 38 | 39 | self.sample_list = sample_list 40 | # print(self.sample_list) 41 | # breakpoint() 42 | print("one epoch containing:", len(self.sample_list)) 43 | 44 | def __iter__(self): 45 | return iter(self.sample_list) 46 | 47 | def __len__(self): 48 | return len(self.sample_list) 49 | -------------------------------------------------------------------------------- /utils/main_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from matplotlib import cm 7 | from PIL import Image 8 | 9 | 10 | __all__ = ( 11 | "save_debug_imgs", 12 | "get_normals", 13 | "sw_cams", 14 | "sw_depth_normalization", 15 | "error_to_prob", 16 | "get_rays", 17 | "get_gs_mask", 18 | "get_pixels", 19 | ) 20 | 21 | def to8b(x): 22 | return (255 * np.clip(x.cpu().numpy(), 0, 1)).astype(np.uint8) 23 | 24 | def get_gs_mask(s_image_tensor, gt_image_tensor, s_depth_tensor, depth_tensor, CVD): 25 | B, C, H, W = s_image_tensor.shape 26 | 27 | # Color based 28 | gs_error = torch.mean(torch.abs(s_image_tensor - gt_image_tensor), 1, True) 29 | gs_mask_c = 
error_to_prob(gs_error.detach()) 30 | 31 | # Depth based 32 | gs_mask_d = error_to_prob(torch.mean(torch.abs(s_depth_tensor - depth_tensor), 1, True).detach()) 33 | norm_disp = 1 / (CVD + 1e-7) 34 | norm_disp = (norm_disp + F.max_pool2d(-norm_disp, kernel_size=(H, W))) / ( 35 | F.max_pool2d(norm_disp, kernel_size=(H, W)) + F.max_pool2d(-norm_disp, kernel_size=(H, W)) 36 | ) 37 | gs_mask_d = 1 - norm_disp * (1 - gs_mask_d) 38 | 39 | return gs_mask_c.detach(), gs_mask_d.detach() 40 | 41 | 42 | def get_pixels(image_size_x, image_size_y, use_center=None): 43 | """Return the pixel at center or corner.""" 44 | xx, yy = np.meshgrid( 45 | np.arange(image_size_x, dtype=np.float32), 46 | np.arange(image_size_y, dtype=np.float32), 47 | ) 48 | offset = 0.5 if use_center else 0 49 | return np.stack([xx, yy], axis=-1) + offset 50 | 51 | 52 | def error_to_prob(error, mask=None, mean_prob=0.5): 53 | if mask is None: 54 | mean_err = torch.mean(error, dim=(3, 2, 1)) + 1e-7 55 | else: 56 | mean_err = torch.sum(mask * error, dim=(3, 2)) / (torch.sum(mask, dim=(3, 2)) + 1e-7) + 1e-7 57 | prob = mean_prob * (error / mean_err.view(error.shape[0], 1, 1, 1)) 58 | prob[prob > 1] = 1 59 | prob = 1 - prob 60 | return prob 61 | 62 | 63 | def get_rays(H, W, K, c2w): 64 | i, j = torch.meshgrid( 65 | torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H) 66 | ) # pytorch's meshgrid has indexing='ij' 67 | i = i.t().type_as(K) 68 | j = j.t().type_as(K) 69 | dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1) 70 | # Rotate ray directions from camera frame to the world frame 71 | rays_d = torch.sum( 72 | dirs[..., np.newaxis, :] * c2w[:3, :3], -1 73 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 74 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
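    # (c2w[:3, -1] is the camera center; it is broadcast so that every
    # pixel's ray shares the same origin.)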
75 | rays_o = c2w[:3, -1].expand(rays_d.shape) 76 | return rays_o, rays_d 77 | 78 | 79 | def flow2rgb(flow_map, max_value): 80 | _, h, w = flow_map.shape 81 | flow_map[:, (flow_map[0] == 0) & (flow_map[1] == 0)] = float("nan") 82 | rgb_map = np.ones((3, h, w)).astype(np.float32) 83 | if max_value is not None: 84 | normalized_flow_map = flow_map / max_value 85 | else: 86 | normalized_flow_map = flow_map / (np.abs(flow_map).max()) 87 | rgb_map[0, :, :] += normalized_flow_map[0, :, :] 88 | rgb_map[1, :, :] -= 0.5 * (normalized_flow_map[0, :, :] + normalized_flow_map[1, :, :]) 89 | rgb_map[2, :, :] += normalized_flow_map[1, :, :] 90 | return rgb_map.clip(0, 1) 91 | 92 | 93 | def save_debug_imgs(debug_dict, idx, epoch=0, deb_path=None, ext="jpg"): 94 | new_outputs = {} 95 | for key in debug_dict.keys(): 96 | out_tensor = debug_dict[key].detach() 97 | if out_tensor.shape[1] == 3: 98 | if "normal" in key: 99 | p_im = (out_tensor[idx].squeeze().cpu().numpy() + 1) / 2 100 | p_im[p_im > 1] = 1 101 | p_im[p_im < 0] = 0 102 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8) 103 | else: 104 | p_im = out_tensor[idx].squeeze().cpu().numpy() 105 | p_im[p_im > 1] = 1 106 | p_im[p_im < 0] = 0 107 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8) 108 | elif out_tensor.shape[1] == 2: 109 | p_im = flow2rgb(out_tensor[idx].squeeze().cpu().numpy(), None) 110 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8) 111 | elif out_tensor.shape[1] == 1: 112 | if "disp" in key: 113 | nmap = out_tensor[idx].squeeze().cpu().numpy() 114 | nmap = np.clip(nmap / (np.percentile(nmap, 99) + 1e-6), 0, 1) 115 | p_im = (255 * cm.plasma(nmap)).astype(np.uint8) 116 | elif "error" in key: 117 | nmap = out_tensor[idx].squeeze().cpu().numpy() 118 | nmap = np.clip(nmap, 0, 1) 119 | p_im = (255 * cm.jet(nmap)).astype(np.uint8) 120 | else: 121 | B, C, H, W = out_tensor.shape 122 | the_max = torch.max_pool2d(out_tensor, kernel_size=(H, W)) 123 | nmap = out_tensor / the_max 124 | p_im = nmap[idx].squeeze().cpu().numpy() 125 | p_im = np.rint(255 * p_im).astype(np.uint8) 126 | 127 | # Save or return normalized image 128 | if deb_path is not None: 129 | if len(p_im.shape) == 3: 130 | p_im = p_im[:, :, 0:3] 131 | im = Image.fromarray(p_im) 132 | im.save(os.path.join(deb_path, "e{}_{}.{}".format(epoch, key, ext))) 133 | else: 134 | new_outputs[key] = p_im 135 | 136 | if deb_path is None: 137 | return new_outputs 138 | 139 | 140 | def get_normals(z, camera_metadata): 141 | pixels = camera_metadata.get_pixels() 142 | y = (pixels[..., 1] - camera_metadata.principal_point_y) / camera_metadata.scale_factor_y 143 | x = ( 144 | pixels[..., 0] - camera_metadata.principal_point_x - y * camera_metadata.skew 145 | ) / camera_metadata.scale_factor_x 146 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1) 147 | viewdirs = torch.from_numpy(viewdirs).to(z.device) 148 | 149 | coords = viewdirs[None] * z[..., None] 150 | coords = coords.permute(0, 3, 1, 2) 151 | 152 | dxdu = coords[..., 0, :, 1:] - coords[..., 0, :, :-1] 153 | dydu = coords[..., 1, :, 1:] - coords[..., 1, :, :-1] 154 | dzdu = coords[..., 2, :, 1:] - coords[..., 2, :, :-1] 155 | dxdv = coords[..., 0, 1:, :] - coords[..., 0, :-1, :] 156 | dydv = coords[..., 1, 1:, :] - coords[..., 1, :-1, :] 157 | dzdv = coords[..., 2, 1:, :] - coords[..., 2, :-1, :] 158 | 159 | dxdu = torch.nn.functional.pad(dxdu, (0, 1), mode="replicate") 160 | dydu = torch.nn.functional.pad(dydu, (0, 1), mode="replicate") 161 | dzdu = torch.nn.functional.pad(dzdu, (0, 1), 
mode="replicate") 162 | 163 | dxdv = torch.cat([dxdv, dxdv[..., -1:, :]], dim=-2) 164 | dydv = torch.cat([dydv, dydv[..., -1:, :]], dim=-2) 165 | dzdv = torch.cat([dzdv, dzdv[..., -1:, :]], dim=-2) 166 | 167 | n_x = dydv * dzdu - dydu * dzdv 168 | n_y = dzdv * dxdu - dzdu * dxdv 169 | n_z = dxdv * dydu - dxdu * dydv 170 | 171 | pred_normal = torch.stack([n_x, n_y, n_z], dim=-3) 172 | pred_normal = torch.nn.functional.normalize(pred_normal, dim=-3) 173 | return pred_normal 174 | 175 | # coords = coords.squeeze(0) 176 | # hd, wd, _ = coords.shape 177 | # bottom_point = coords[..., 2:hd, 1 : wd - 1, :] 178 | # top_point = coords[..., 0 : hd - 2, 1 : wd - 1, :] 179 | # right_point = coords[..., 1 : hd - 1, 2:wd, :] 180 | # left_point = coords[..., 1 : hd - 1, 0 : wd - 2, :] 181 | # left_to_right = right_point - left_point 182 | # bottom_to_top = top_point - bottom_point 183 | # xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1) 184 | # xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1) 185 | # pred_normal = torch.nn.functional.pad(xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode="constant") 186 | # return pred_normal[None] 187 | 188 | 189 | def sw_cams(viewpoint_stack, cam_id, sw_size=2): 190 | viewpoint_cams_window = [viewpoint_stack[cam_id]] 191 | for sw in range(1, sw_size + 1): 192 | if cam_id - sw >= 0: 193 | viewpoint_cams_window.append(viewpoint_stack[cam_id - sw]) 194 | if cam_id + sw < len(viewpoint_stack): 195 | viewpoint_cams_window.append(viewpoint_stack[cam_id + sw]) 196 | return viewpoint_cams_window 197 | 198 | 199 | def sw_depth_normalization(viewpoint_cams_window_list, depth_tensor, batch_size): 200 | for n_batch in range(batch_size): 201 | depth_window = [] 202 | for viewpoint_cams_window in viewpoint_cams_window_list[n_batch]: 203 | depth_window.append(viewpoint_cams_window.depth[None].cuda()) 204 | depth_window = torch.cat(depth_window, 0) 205 | depth_window_min = torch.min(depth_window).cuda() 206 | depth_window_max = torch.max(depth_window).cuda() 207 | depth_tensor[n_batch] = (depth_tensor[n_batch] - depth_window_min) / (depth_window_max - depth_window_min) 208 | return depth_tensor 209 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2023 OPPO 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import numpy as np 24 | import torch 25 | 26 | # from mmcv.ops import knn 27 | import torch.nn as nn 28 | 29 | 30 | class Sandwich(nn.Module): 31 | def __init__(self, dim, outdim=3, bias=False): 32 | super(Sandwich, self).__init__() 33 | 34 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias) # 35 | 36 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias) 37 | self.relu = nn.ReLU() 38 | 39 | self.sigmoid = torch.nn.Sigmoid() 40 | 41 | def forward(self, input, rays, time=None): 42 | albedo, spec, timefeature = input.chunk(3, dim=1) 43 | specular = torch.cat([spec, timefeature, rays], dim=1) # 3+3 + 5 44 | specular = self.mlp1(specular) 45 | specular = self.relu(specular) 46 | specular = self.mlp2(specular) 47 | 48 | result = albedo + specular 49 | result = self.sigmoid(result) 50 | return result 51 | 52 | 53 | class Sandwichnoact(nn.Module): 54 | def __init__(self, dim, outdim=3, bias=False): 55 | super(Sandwichnoact, self).__init__() 56 | 57 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias) 58 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias) 59 | self.relu = nn.ReLU() 60 | 61 | def forward(self, input, rays, time=None): 62 | albedo, spec, timefeature = input.chunk(3, dim=1) 63 | specular = torch.cat([spec, timefeature, rays], dim=1) # 3+3 + 5 64 | specular = self.mlp1(specular) 65 | specular = self.relu(specular) 66 | specular = self.mlp2(specular) 67 | 68 | result = albedo + specular 69 | result = torch.clamp(result, min=0.0, max=1.0) 70 | return result 71 | 72 | 73 | class Sandwichnoactss(nn.Module): 74 | def __init__(self, dim, outdim=3, bias=False): 75 | super(Sandwichnoactss, self).__init__() 76 | 77 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias) 78 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias) 79 | 80 | self.relu = nn.ReLU() 81 | 82 | def forward(self, input, rays, time=None): 83 | albedo, spec, timefeature = input.chunk(3, dim=1) 84 | specular = torch.cat([spec, timefeature, rays], dim=1) # 3+3 + 5 85 | specular = self.mlp1(specular) 86 | specular = self.relu(specular) 87 | specular = self.mlp2(specular) 88 | 89 | result = albedo + specular 90 | return result 91 | 92 | 93 | ####### following are also good rgb model but not used in the paper, slower than sandwich, inspired by color shift in hyperreel 94 | # remove sigmoid for immersive dataset 95 | class RGBDecoderVRayShift(nn.Module): 96 | def __init__(self, dim, outdim=3, bias=False): 97 | super(RGBDecoderVRayShift, self).__init__() 98 | 99 | self.mlp1 = nn.Conv2d(dim, outdim, kernel_size=1, bias=bias) 100 | self.mlp2 = nn.Conv2d(15, outdim, kernel_size=1, bias=bias) 101 | self.mlp3 = nn.Conv2d(6, outdim, kernel_size=1, bias=bias) 102 | self.sigmoid = torch.nn.Sigmoid() 103 | 104 | self.dwconv1 = nn.Conv2d(9, 9, kernel_size=1, bias=bias) 105 | 106 | def forward(self, input, rays, t=None): 107 | x = self.dwconv1(input) + input 108 | albeado = self.mlp1(x) 109 | specualr = torch.cat([x, rays], dim=1) 110 | specualr = self.mlp2(specualr) 111 | 112 | finalfeature = torch.cat([albeado, specualr], dim=1) 113 | result = self.mlp3(finalfeature) 114 | result = self.sigmoid(result) 115 | return result 116 | 117 | def getcolormodel(rgbfuntion): 118 | if rgbfuntion == "sandwich": 119 | rgbdecoder = Sandwich(9, 3) 120 | 121 | elif rgbfuntion == "sandwichnoact": 122 | rgbdecoder = Sandwichnoact(9, 3) 123 | elif rgbfuntion == "sandwichnoactss": 124 | rgbdecoder = Sandwichnoactss(9, 3) 125 | else: 126 | return None 127 | return rgbdecoder 128 | 129 | 130 | def pix2ndc(v, S): 131 | return (v * 2.0 + 
1.0) / S - 1.0 132 | 133 | 134 | def ndc2pix(v, S): 135 | return ((v + 1.0) * S - 1.0) * 0.5 136 | -------------------------------------------------------------------------------- /utils/params_utils.py: -------------------------------------------------------------------------------- 1 | def merge_hparams(args, config): 2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] 3 | for param in params: 4 | if param in config.keys(): 5 | for key, value in config[param].items(): 6 | if hasattr(args, key): 7 | setattr(args, key, value) 8 | 9 | return args 10 | -------------------------------------------------------------------------------- /utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import torch 4 | from torch.utils.data import TensorDataset, random_split 5 | from torch_cluster import grid_cluster 6 | from sklearn.neighbors import NearestNeighbors 7 | 8 | def voxel_down_sample_custom(points, voxel_size): 9 | # Quantize the points onto a voxel grid 10 | voxel_grid = torch.floor(points / voxel_size) 11 | 12 | # Find the unique voxels and the voxel index of every original point 13 | unique_voxels, inverse_indices = torch.unique(voxel_grid, dim=0, return_inverse=True) 14 | 15 | # Build a new point cloud in which each point stands for all points falling into its voxel 16 | new_points = torch.zeros_like(unique_voxels) 17 | new_points_count = torch.zeros(unique_voxels.size(0), dtype=torch.long) 18 | # for i in tqdm(range(points.size(0))): 19 | new_points[inverse_indices] = points 20 | # new_points_count[inverse_indices[i]] += 1 21 | # new_points /= new_points_count.unsqueeze(-1) 22 | 23 | return new_points, inverse_indices 24 | 25 | 26 | def downsample_point_cloud(points, ratio): 27 | # Wrap the points in a TensorDataset 28 | dataset = TensorDataset(points) 29 | 30 | # Number of points to keep after downsampling 31 | num_points = len(dataset) 32 | num_downsampled_points = int(num_points * ratio) 33 | 34 | # Downsample with random_split 35 | downsampled_dataset, _ = random_split(dataset, [num_downsampled_points, num_points - num_downsampled_points]) 36 | 37 | # Indices and coordinates of the downsampled points 38 | indices = torch.tensor([i for i, _ in enumerate(downsampled_dataset)]) 39 | downsampled_points = torch.stack([x for x, in downsampled_dataset]) 40 | 41 | return indices, downsampled_points 42 | 43 | 44 | def downsample_point_cloud_open3d(points, voxel_size): 45 | # Voxel-grid downsampling (custom implementation, no Open3D object needed) 46 | 47 | downsampled_pcd, inverse_indices = voxel_down_sample_custom(points, voxel_size) 48 | downsampled_points = downsampled_pcd 49 | # Return the downsampled point matrix 50 | 51 | return torch.tensor(downsampled_points) 52 | 53 | 54 | def downsample_point_cloud_cluster(points, voxel_size): 55 | # Cluster the points on a regular grid 56 | cluster = grid_cluster(points, size=torch.tensor([1, 1, 1])) 57 | 58 | # Return the cluster assignment together with the original points 59 | # downsampled_points = np.asarray(downsampled_pcd.points) 60 | 61 | return cluster, points 62 | 63 | 64 | def upsample_point_cloud(points, density_threshold, displacement_scale, iter_pass): 65 | # Estimate the density of every point 66 | # breakpoint() 67 | try: 68 | nbrs = NearestNeighbors(n_neighbors=2 + iter_pass, algorithm="ball_tree").fit(points) 69 | distances, indices = nbrs.kneighbors(points) 70 | except Exception: 71 | print("no point added") 72 | return points, torch.tensor([]), torch.tensor([]), torch.zeros((points.shape[0]), dtype=torch.bool) 73 | 74 | # Select the low-density points 75 | low_density_points = points[distances[:, 1] > density_threshold] 76 | low_density_index = distances[:, 1] > density_threshold 77 | low_density_index = torch.from_numpy(low_density_index) 78 | # Duplicate these points and add random displacements 79 | num_points = low_density_points.shape[0] 80 | displacements = 
torch.randn(num_points, 3) * displacement_scale 81 | new_points = low_density_points + displacements 82 | # Return the new point matrices 83 | return points, low_density_points, new_points, low_density_index 84 | 85 | 86 | def visualize_point_cloud(points, low_density_points, new_points): 87 | # Create a point cloud object 88 | pcd = o3d.geometry.PointCloud() 89 | 90 | # Add a small offset to the selected points 91 | low_density_points += 0.01 92 | 93 | # Merge all points together 94 | all_points = np.concatenate([points, low_density_points, new_points], axis=0) 95 | pcd.points = o3d.utility.Vector3dVector(all_points) 96 | 97 | # Build the color array 98 | colors = np.zeros((all_points.shape[0], 3)) 99 | colors[: points.shape[0]] = [0, 0, 0] # black: the initial point cloud 100 | colors[points.shape[0] : points.shape[0] + low_density_points.shape[0]] = [1, 0, 0] # red: the selected points 101 | colors[points.shape[0] + low_density_points.shape[0] :] = [0, 1, 0] # green: the newly grown points 102 | pcd.colors = o3d.utility.Vector3dVector(colors) 103 | 104 | # Show the point cloud 105 | o3d.visualization.draw_geometries([pcd]) 106 | 107 | 108 | def combine_pointcloud(points, low_density_points, new_points): 109 | pcd = o3d.geometry.PointCloud() 110 | 111 | # Add a small offset to the selected points 112 | low_density_points += 0.01 113 | new_points -= 0.01 114 | # Merge all points together 115 | all_points = np.concatenate([points, low_density_points, new_points], axis=0) 116 | pcd.points = o3d.utility.Vector3dVector(all_points) 117 | 118 | # Build the color array 119 | colors = np.zeros((all_points.shape[0], 3)) 120 | colors[: points.shape[0]] = [0, 0, 0] # black: the initial point cloud 121 | colors[points.shape[0] : points.shape[0] + low_density_points.shape[0]] = [1, 0, 0] # red: the selected points 122 | colors[points.shape[0] + low_density_points.shape[0] :] = [0, 1, 0] # green: the newly grown points 123 | pcd.colors = o3d.utility.Vector3dVector(colors) 124 | return pcd 125 | 126 | 127 | def addpoint( 128 | point_cloud, 129 | density_threshold, 130 | displacement_scale, 131 | iter_pass, 132 | ): 133 | # density_threshold: density threshold; the larger it is, the sparser the points that get selected. 134 | # displacement_scale: new points are generated at random within a radius of displacement_scale around each selected point 135 | 136 | points, low_density_points, new_points, low_density_index = upsample_point_cloud( 137 | point_cloud, density_threshold, displacement_scale, iter_pass 138 | ) 139 | # breakpoint() 140 | # breakpoint() 141 | print("low_density_points", low_density_points.shape[0]) 142 | 143 | return point_cloud, low_density_points, new_points, low_density_index 144 | 145 | 146 | def find_point_indices(origin_point, goal_point): 147 | indices = torch.nonzero((origin_point[:, None] == goal_point).all(-1), as_tuple=True)[0] 148 | return indices 149 | 150 | 151 | def find_indices_in_A(A, B): 152 | """ 153 | Find, for every point of the subset matrix B, its index u in the point-cloud matrix A. 154 | 155 | Args: 156 | A (torch.Tensor): point-cloud matrix A of shape [N, 3]. 157 | B (torch.Tensor): subset matrix B of shape [M, 3]. 158 | 159 | Returns: 160 | torch.Tensor: tensor of shape (M,) holding the index u in A of every point in B. 161 | """ 162 | is_equal = torch.eq(B.view(1, -1, 3), A.view(-1, 1, 3)) 163 | u_indices = torch.nonzero(is_equal, as_tuple=False)[:, 0] 164 | return torch.unique(u_indices) 165 | 166 | 167 | if __name__ == "__main__": 168 | # 169 | from time import time 170 | 171 | pass_ = 0 172 | # filename=f"pointcloud/pass_{pass_}.ply" 173 | filename = "point_cloud.ply" 174 | pcd = o3d.io.read_point_cloud(filename) 175 | point_cloud = torch.tensor(pcd.points) 176 | voxel_size = 8 177 | density_threshold = 20 178 | displacement_scale = 5 179 | for i in range(pass_ + 1, 50): 180 | print("pass ", i) 181 | time0 = time() 182 | 183 | point_downsample = point_cloud 184 | flag = False 185 | while point_downsample.shape[0] > 1000: 186 | if flag: 187 | voxel_size += 8 188 | print("point 
size:", point_downsample.shape[0]) 189 | point_downsample = downsample_point_cloud_open3d(point_cloud, voxel_size=voxel_size) 190 | flag = True 191 | 192 | print("point size:", point_downsample.shape[0]) 193 | # downsampled_point_index = find_point_indices(point_cloud, point_downsample) 194 | downsampled_point_index = find_indices_in_A(point_cloud, point_downsample) 195 | print("selected_num", point_cloud[downsampled_point_index].shape[0]) 196 | _, low_density_points, new_points, low_density_index = addpoint( 197 | point_cloud[downsampled_point_index], 198 | density_threshold=density_threshold, 199 | displacement_scale=displacement_scale, 200 | iter_pass=0, 201 | ) 202 | if new_points.shape[0] < 100: 203 | density_threshold /= 2 204 | displacement_scale /= 2 205 | print("reduce displacement_scale to: ", displacement_scale) 206 | 207 | global_mask = torch.zeros((point_cloud.shape[0]), dtype=torch.bool) 208 | 209 | global_mask[downsampled_point_index] = low_density_index 210 | time1 = time() 211 | 212 | print("time cost:", time1 - time0, "new_points:", new_points.shape[0]) 213 | if low_density_points.shape[0] == 0: 214 | print("no more points.") 215 | continue 216 | # breakpoint() 217 | point = combine_pointcloud(point_cloud, low_density_points, new_points) 218 | point_cloud = torch.tensor(point.points) 219 | o3d.io.write_point_cloud(f"pointcloud/pass_{i}.ply", point) 220 | # visualize_point_cloud(point_cloud, low_density_points, new_points) 221 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | 6 | 7 | def rotation_matrix_to_quaternion(rotation_matrix): 8 | """Convert a rotation matrix to a quaternion.""" 9 | return R.from_matrix(rotation_matrix).as_quat() 10 | 11 | 12 | def quaternion_to_rotation_matrix(quat): 13 | """Convert a quaternion to a rotation matrix.""" 14 | return R.from_quat(quat).as_matrix() 15 | 16 | 17 | def quaternion_slerp(q1, q2, t): 18 | """Spherical linear interpolation (SLERP) between two quaternions.""" 19 | # Dot product between the two quaternions 20 | dot = np.dot(q1, q2) 21 | 22 | # If the dot product is negative, negate one quaternion so that interpolation follows the shortest path 23 | if dot < 0.0: 24 | q1 = -q1 25 | dot = -dot 26 | 27 | # Guard against numerical errors 28 | dot = np.clip(dot, -1.0, 1.0) 29 | 30 | # Interpolation parameters 31 | theta = np.arccos(dot) * t 32 | q3 = q2 - q1 * dot 33 | q3 = q3 / np.linalg.norm(q3) 34 | 35 | # Interpolated quaternion 36 | return np.cos(theta) * q1 + np.sin(theta) * q3 37 | 38 | 39 | def bezier_interpolation(p1, p2, t): 40 | """Interpolate between two points along a (degree-1) Bezier curve.""" 41 | return (1 - t) * p1 + t * p2 42 | 43 | 44 | def linear_interpolation(v1, v2, t): 45 | """Linear interpolation.""" 46 | return (1 - t) * v1 + t * v2 47 | 48 | 49 | def smooth_camera_poses(cameras, num_interpolations=5): 50 | """Smooth a sequence of camera poses by inserting extra poses between every pair.""" 51 | smoothed_cameras = [] 52 | smoothed_times = [] 53 | total_poses = len(cameras) - 1 + (len(cameras) - 1) * num_interpolations 54 | time_increment = 10 / total_poses 55 | 56 | for i in range(len(cameras) - 1): 57 | cam1 = cameras[i] 58 | cam2 = cameras[i + 1] 59 | 60 | # Convert the rotation matrices to quaternions 61 | quat1 = rotation_matrix_to_quaternion(cam1.orientation) 62 | quat2 = rotation_matrix_to_quaternion(cam2.orientation) 63 | 64 | for j in range(num_interpolations + 1): 65 | t = j / (num_interpolations + 1) 66 | 67 | # Interpolate the orientation 68 | interp_orientation_quat = quaternion_slerp(quat1, quat2, t) 69 | interp_orientation_matrix = quaternion_to_rotation_matrix(interp_orientation_quat) 70 | 71 | # Interpolate the position 72 | interp_position = linear_interpolation(cam1.position, 
cam2.position, t) 73 | 74 | # Compute the interpolated timestamp 75 | interp_time = i * 10 / (len(cameras) - 1) + time_increment * j 76 | 77 | # Append the new camera pose and timestamp 78 | newcam = deepcopy(cam1) 79 | newcam.orientation = interp_orientation_matrix 80 | newcam.position = interp_position 81 | smoothed_cameras.append(newcam) 82 | smoothed_times.append(interp_time) 83 | 84 | # Append the last original pose and timestamp 85 | smoothed_cameras.append(cameras[-1]) 86 | smoothed_times.append(1.0) 87 | print(smoothed_times) 88 | return smoothed_cameras, smoothed_times 89 | 90 | 91 | # # Example: two camera poses 92 | # cam1 = Camera(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([0, 0, 0])) 93 | # cam2 = Camera(np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]), np.array([1, 1, 1])) 94 | 95 | # # Apply the smoothing 96 | # smoothed_cameras = smooth_camera_poses([cam1, cam2], num_interpolations=5) 97 | 98 | # # Print the results 99 | # for cam in smoothed_cameras: 100 | # print("Orientation:\n", cam.orientation) 101 | # print("Position:", cam.position) 102 | -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.no_grad() 5 | def get_state_at_time(pc, viewpoint_camera): 6 | means3D = pc.get_xyz 7 | time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0], 1) 8 | opacity = pc._opacity 9 | shs = pc.get_features 10 | 11 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 12 | # scaling / rotation by the rasterizer. 13 | scales = pc._scaling 14 | rotations = pc._rotation 15 | 16 | # time0 = get_time() 17 | # means3D_deform, scales_deform, rotations_deform, opacity_deform = pc._deformation(means3D[deformation_point], scales[deformation_point], 18 | # rotations[deformation_point], opacity[deformation_point], 19 | # time[deformation_point]) 20 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = pc._deformation( 21 | means3D, scales, rotations, opacity, shs, time 22 | ) 23 | scales_final = pc.scaling_activation(scales_final) 24 | rotations_final = pc.rotation_activation(rotations_final) 25 | opacity = pc.opacity_activation(opacity_final) 26 | return means3D_final, scales_final, rotations_final, opacity, shs 27 | -------------------------------------------------------------------------------- /utils/scene_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from PIL import Image, ImageDraw, ImageFont 6 | 7 | 8 | plt.rcParams["font.sans-serif"] = ["Times New Roman"] 9 | 10 | 11 | import numpy as np 12 | from scene import deformation 13 | 14 | 15 | @torch.no_grad() 16 | def render_training_image( 17 | scene, 18 | stat_gaussians, 19 | dyn_gaussians, 20 | viewpoints, 21 | render_func, 22 | pipe, 23 | background, 24 | stage, 25 | iteration, 26 | time_now, 27 | dataset_type, 28 | is_train=False, 29 | over_t=None, 30 | ): 31 | # Get base parameters for warping 32 | if stage == "warm": 33 | viewpoint = viewpoints[0] 34 | gt_normal = viewpoint.normal[None].cuda() 35 | gt_normal.reshape(-1, 3) 36 | pixels = viewpoint.metadata.get_pixels(normalize=True) 37 | pixels = torch.from_numpy(pixels).cuda() 38 | pixels.reshape(-1, 2) 39 | gt_depth = viewpoint.depth[None].cuda() 40 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[2::]) 41 | depth_in = gt_depth.reshape(-1, 1) 42 | time_in_0 = torch.tensor(0).float().cuda() 43 
| time_in_0 = time_in_0.view(1, 1) 44 | # pred_R0, pred_T0, depth_scale_T0, depth_shift_T0 = dyn_gaussians._posenet(time_in_0, depth=depth_in, normals=normal_in, pixel=pixels_in) 45 | pred_R0, pred_T0, CVD0 = dyn_gaussians._posenet(time_in_0, depth=depth_in) 46 | 47 | def render(static_gaussians, dynamic_gaussians, viewpoint, path, scaling, cam_type, over_t=None): 48 | if stage == "warm": 49 | # Get warped images (all warped to time step 0) 50 | if dataset_type == "PanopticSports": 51 | image_tensor = viewpoint["image"][None].cuda() 52 | else: 53 | image_tensor = viewpoint.original_image[None].cuda() 54 | B, C, H, W = image_tensor.shape 55 | 56 | viewpoint.normal[None].cuda().reshape(-1, 3) 57 | pixels = viewpoint.metadata.get_pixels(normalize=True) 58 | torch.from_numpy(pixels).cuda().reshape(-1, 2) 59 | gt_depth = viewpoint.depth[None].cuda() 60 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[2::]) 61 | depth_in = gt_depth.reshape(-1, 1) 62 | time_in = torch.tensor(viewpoint.time).float().cuda() 63 | time_in = time_in.view(1, 1) 64 | # CVD = depth_in.view(1, 1, H, W) 65 | 66 | # CVD = CVD_T0.view(1, 1, H, W) 67 | # pred_R, pred_T, _, _ = dynamic_gaussians._posenet(time_in, depth=depth_in, normals=normal_in, pixel=pixels_in) 68 | pred_R, pred_T, CVD = dynamic_gaussians._posenet(time_in, depth=depth_in) 69 | 70 | # depth_in = depth_in.view(1, 1, H, W) 71 | # depth_shift = depth_shift_T0.view(1, 1, H, W) 72 | # depth_scale = depth_scale_T0.view(1, 1, 1, 1) 73 | # CVD = depth_scale * depth_in + depth_shift 74 | 75 | K_tensor = torch.zeros(1, 3, 3).type_as(image_tensor) 76 | K_tensor[:, 0, 0] = float(viewpoint.metadata.scale_factor_x) 77 | K_tensor[:, 1, 1] = float(viewpoint.metadata.scale_factor_y) 78 | K_tensor[:, 0, 2] = float(viewpoint.metadata.principal_point_x) 79 | K_tensor[:, 1, 2] = float(viewpoint.metadata.principal_point_y) 80 | K_tensor[:, 2, 2] = float(1) 81 | 82 | w2c_target = torch.cat((pred_R0, pred_T0[:, :, None]), -1) 83 | w2c_prev = torch.cat((pred_R, pred_T[:, :, None]), -1) 84 | warped_img = deformation.inverse_warp_rt1_rt2( 85 | image_tensor, CVD, w2c_target, w2c_prev, K_tensor, torch.inverse(K_tensor) 86 | ) 87 | 88 | p_im = warped_img.detach().squeeze().cpu().numpy() 89 | im = Image.fromarray(np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8)) 90 | im.save(path.replace(".jpg", "_warped.jpg")) 91 | return 92 | 93 | # scaling_copy = gaussians._scaling 94 | render_pkg = render_func( 95 | viewpoint, static_gaussians, dynamic_gaussians, background, get_static=True, get_dynamic=True 96 | ) 97 | 98 | label1 = f"stage:{stage},iter:{iteration}" 99 | times = time_now / 60 100 | if times < 1: 101 | end = "min" 102 | else: 103 | end = "mins" 104 | label2 = "time:%.2f" % times + end 105 | 106 | image = render_pkg["render"] 107 | 108 | d_image = render_pkg["d_render"] 109 | s_image = render_pkg["s_render"] 110 | 111 | d_alpha = render_pkg["d_alpha"] 112 | 113 | depth = render_pkg["depth"] 114 | st_depth = render_pkg["s_depth"] 115 | 116 | z = depth + 1e-6 117 | camera_metadata = viewpoint.metadata 118 | pixels = camera_metadata.get_pixels() 119 | y = ( 120 | pixels[..., 1] - camera_metadata.principal_point_y 121 | ) / dynamic_gaussians._posenet.focal_bias.exp().detach().cpu().numpy() 122 | x = ( 123 | pixels[..., 0] - camera_metadata.principal_point_x - y * camera_metadata.skew 124 | ) / dynamic_gaussians._posenet.focal_bias.exp().detach().cpu().numpy() 125 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1) 126 | viewdirs = 
torch.from_numpy(viewdirs).to(z.device) 127 | 128 | coords = viewdirs[None] * z[..., None] 129 | coords = coords.permute(0, 3, 1, 2) 130 | 131 | dxdu = coords[..., 0, :, 1:] - coords[..., 0, :, :-1] 132 | dydu = coords[..., 1, :, 1:] - coords[..., 1, :, :-1] 133 | dzdu = coords[..., 2, :, 1:] - coords[..., 2, :, :-1] 134 | dxdv = coords[..., 0, 1:, :] - coords[..., 0, :-1, :] 135 | dydv = coords[..., 1, 1:, :] - coords[..., 1, :-1, :] 136 | dzdv = coords[..., 2, 1:, :] - coords[..., 2, :-1, :] 137 | 138 | dxdu = torch.nn.functional.pad(dxdu, (0, 1), mode="replicate") 139 | dydu = torch.nn.functional.pad(dydu, (0, 1), mode="replicate") 140 | dzdu = torch.nn.functional.pad(dzdu, (0, 1), mode="replicate") 141 | 142 | dxdv = torch.cat([dxdv, dxdv[..., -1:, :]], dim=-2) 143 | dydv = torch.cat([dydv, dydv[..., -1:, :]], dim=-2) 144 | dzdv = torch.cat([dzdv, dzdv[..., -1:, :]], dim=-2) 145 | 146 | n_x = dydv * dzdu - dydu * dzdv 147 | n_y = dzdv * dxdu - dzdu * dxdv 148 | n_z = dxdv * dydu - dxdu * dydv 149 | 150 | pred_normal = torch.stack([n_x, n_y, n_z], dim=-3) 151 | pred_normal = torch.nn.functional.normalize(pred_normal, dim=-3) 152 | 153 | if dataset_type == "PanopticSports": 154 | gt_np = viewpoint["image"].permute(1, 2, 0).cpu().numpy() 155 | else: 156 | gt_np = viewpoint.original_image.permute(1, 2, 0).cpu().numpy() 157 | image_np = image.permute(1, 2, 0).cpu().numpy() # (H, W, 3) 158 | 159 | d_image_np = d_image.permute(1, 2, 0).cpu().numpy() 160 | s_image_np = s_image.permute(1, 2, 0).cpu().numpy() 161 | 162 | d_alpha_np = d_alpha.permute(1, 2, 0).cpu().numpy() 163 | d_alpha_np = np.repeat(d_alpha_np, 3, axis=2) 164 | 165 | depth_np = depth.permute(1, 2, 0).cpu().numpy() 166 | depth_np /= depth_np.max() 167 | depth_np = np.repeat(depth_np, 3, axis=2) 168 | 169 | st_depth_np = st_depth.permute(1, 2, 0).cpu().numpy() 170 | st_depth_np /= st_depth_np.max() 171 | st_depth_np = np.repeat(st_depth_np, 3, axis=2) 172 | 173 | pred_normal_np = (pred_normal[0].permute(1, 2, 0).cpu().numpy() + 1) / 2 174 | 175 | error = (image_np - gt_np) ** 2 176 | error_np = (error - np.min(error)) / (max(np.max(error) - np.min(error), 1e-8)) 177 | 178 | if is_train: 179 | gt_normal = (viewpoint.normal.cuda() + 1) / 2 180 | gt_normal_np = gt_normal.permute(1, 2, 0).cpu().numpy() 181 | 182 | gt_depth = viewpoint.depth.cuda() 183 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[1::]) 184 | gt_depth_np = gt_depth.permute(1, 2, 0).cpu().numpy() 185 | gt_depth_np /= gt_depth_np.max() 186 | gt_depth_np = np.repeat(gt_depth_np, 3, axis=2) 187 | 188 | decomp_image_np = np.concatenate((gt_normal_np, pred_normal_np, gt_depth_np, depth_np), axis=1) 189 | 190 | mask_np = viewpoint.mask.permute(1, 2, 0).cpu().numpy() 191 | mask_np = np.repeat(mask_np, 3, axis=2) 192 | image_np = np.concatenate((gt_np, image_np, mask_np, d_alpha_np, d_image_np, s_image_np), axis=1) 193 | else: 194 | decomp_image_np = np.concatenate((pred_normal_np, depth_np), axis=1) 195 | image_np = np.concatenate((gt_np, image_np, error_np, d_alpha_np, d_image_np, s_image_np), axis=1) 196 | 197 | image_with_labels = Image.fromarray((np.clip(image_np, 0, 1) * 255).astype("uint8")) # convert to an 8-bit image 198 | decomp_image_with_labels = Image.fromarray((np.clip(decomp_image_np, 0, 1) * 255).astype("uint8")) 199 | # Create a drawing context on the labelled image 200 | draw1 = ImageDraw.Draw(image_with_labels) 201 | 202 | # Choose the font and font size 203 | font = ImageFont.truetype("./utils/TIMES.TTF", size=40) # replace this path with your own font file if needed 204 | 205 | # Choose the text color 206 | text_color = (255, 0, 0) # 
red 207 | 208 | # Label positions (top-left coordinates) 209 | label1_position = (10, 10) 210 | label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # top-right corner 211 | 212 | # Draw the labels on the image 213 | draw1.text(label1_position, label1, fill=text_color, font=font) 214 | draw1.text(label2_position, label2, fill=text_color, font=font) 215 | 216 | image_with_labels.save(path) 217 | decomp_image_with_labels.save(path.replace(".jpg", "_decomp.jpg")) 218 | 219 | render_base_path = os.path.join(scene.model_path, f"{stage}_render") 220 | point_cloud_path = os.path.join(render_base_path, "pointclouds") 221 | if is_train: 222 | image_path = os.path.join(render_base_path, "train/images") 223 | else: 224 | image_path = os.path.join(render_base_path, "val/images") 225 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): 226 | os.makedirs(render_base_path) 227 | if not os.path.exists(point_cloud_path): 228 | os.makedirs(point_cloud_path) 229 | if not os.path.exists(image_path): 230 | os.makedirs(image_path) 231 | # image:3,800,800 232 | 233 | # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg") 234 | for idx in range(len(viewpoints)): 235 | image_save_path = os.path.join(image_path, f"{viewpoints[idx].image_name}.jpg") 236 | render( 237 | stat_gaussians, 238 | dyn_gaussians, 239 | viewpoints[idx], 240 | image_save_path, 241 | scaling=1, 242 | cam_type=dataset_type, 243 | over_t=over_t, 244 | ) 245 | # render(gaussians,point_save_path,scaling = 0.1) 246 | # Save the labelled image 247 | 248 | pc_mask = dyn_gaussians.get_opacity 249 | pc_mask = pc_mask > 0.1 250 | dyn_gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1, 0).numpy() 251 | # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path) 252 | # If needed, the PIL image can be converted back into a PyTorch tensor 253 | # return image 254 | # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0 255 | 256 | 257 | def visualize_and_save_point_cloud(point_cloud, R, T, filename): 258 | # Create a 3D scatter plot 259 | fig = plt.figure() 260 | ax = fig.add_subplot(111, projection="3d") 261 | R = R.T 262 | # Apply the rotation and translation 263 | T = -R.dot(T) 264 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) 265 | # pcd = o3d.geometry.PointCloud() 266 | # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # transpose the points to match Open3D's format 267 | # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:] 268 | # Visualize the point cloud 269 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c="g", marker="o") 270 | ax.axis("off") 271 | # ax.set_xlabel('X Label') 272 | # ax.set_ylabel('Y Label') 273 | # ax.set_zlabel('Z Label') 274 | 275 | # Save the rendering as an image 276 | plt.savefig(filename) 277 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 
11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | C0 = 0.28209479177387814 24 | C1 = 0.4886025119029199 25 | C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396] 26 | C3 = [ 27 | -0.5900435899266435, 28 | 2.890611442640554, 29 | -0.4570457994644658, 30 | 0.3731763325901154, 31 | -0.4570457994644658, 32 | 1.445305721320277, 33 | -0.5900435899266435, 34 | ] 35 | C4 = [ 36 | 2.5033429417967046, 37 | -1.7701307697799304, 38 | 0.9461746957575601, 39 | -0.6690465435572892, 40 | 0.10578554691520431, 41 | -0.6690465435572892, 42 | 0.47308734787878004, 43 | -1.7701307697799304, 44 | 0.6258357354491761, 45 | ] 46 | 47 | 48 | def eval_sh(deg, sh, dirs): 49 | """ 50 | Evaluate spherical harmonics at unit directions 51 | using hardcoded SH polynomials. 52 | Works with torch/np/jnp. 53 | ... Can be 0 or more batch dimensions. 54 | Args: 55 | deg: int SH deg. Currently, 0-3 supported 56 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 57 | dirs: jnp.ndarray unit directions [..., 3] 58 | Returns: 59 | [..., C] 60 | """ 61 | assert deg <= 4 and deg >= 0 62 | coeff = (deg + 1) ** 2 63 | assert sh.shape[-1] >= coeff 64 | 65 | result = C0 * sh[..., 0] 66 | if deg > 0: 67 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 68 | result = result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] 69 | 70 | if deg > 1: 71 | xx, yy, zz = x * x, y * y, z * z 72 | xy, yz, xz = x * y, y * z, x * z 73 | result = ( 74 | result 75 | + C2[0] * xy * sh[..., 4] 76 | + C2[1] * yz * sh[..., 5] 77 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] 78 | + C2[3] * xz * sh[..., 7] 79 | + C2[4] * (xx - yy) * sh[..., 8] 80 | ) 81 | 82 | if deg > 2: 83 | result = ( 84 | result 85 | + C3[0] * y * (3 * xx - yy) * sh[..., 9] 86 | + C3[1] * xy * z * sh[..., 10] 87 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] 88 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] 89 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] 90 | + C3[5] * z * (xx - yy) * sh[..., 14] 91 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15] 92 | ) 93 | 94 | if deg > 3: 95 | result = ( 96 | result 97 | + C4[0] * xy * (xx - yy) * sh[..., 16] 98 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17] 99 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18] 100 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19] 101 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] 102 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21] 103 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] 104 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] 105 | + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24] 106 | ) 107 | return result 108 | 109 | 110 | def RGB2SH(rgb): 111 | return (rgb - 0.5) / C0 112 | 113 | 114 | def SH2RGB(sh): 115 | return sh * C0 + 0.5 116 | 
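# --- Added usage sketch (editor's note, not part of the original utils/sh_utils.py) ---
# eval_sh above expects SH coefficients shaped [..., C, (deg + 1) ** 2] and unit view
# directions shaped [..., 3]. A minimal, hypothetical call with degree-2 SH and C = 3:
#   import torch
#   sh = torch.randn(1024, 3, 9)                                   # 9 = (2 + 1) ** 2 coefficients per channel
#   dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
#   rgb = eval_sh(2, sh, dirs)                                     # -> [1024, 3]; 3DGS-style pipelines typically add 0.5 afterwards (cf. SH2RGB)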
-------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from errno import EEXIST 14 | from os import makedirs, path 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.start_time = None 7 | self.elapsed = 0 8 | self.paused = False 9 | 10 | def start(self): 11 | if self.start_time is None: 12 | self.start_time = time.time() 13 | elif self.paused: 14 | self.start_time = time.time() - self.elapsed 15 | self.paused = False 16 | 17 | def pause(self): 18 | if not self.paused: 19 | self.elapsed = time.time() - self.start_time 20 | self.paused = True 21 | 22 | def get_elapsed_time(self): 23 | if self.paused: 24 | return self.elapsed 25 | else: 26 | return time.time() - self.start_time 27 | --------------------------------------------------------------------------------
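A minimal smoke test (editor's addition, not a file from the repository) for two of the helpers listed above. It assumes the mse/psnr functions shown before utils/loader_utils.py live in utils/image_utils.py, as the listing order suggests, and it uses Timer from utils/timer.py:

import torch

from utils.image_utils import mse, psnr  # assumed module path, see note above
from utils.timer import Timer

timer = Timer()
timer.start()
img1 = torch.rand(1, 3, 64, 64)
img2 = torch.clamp(img1 + 0.01 * torch.randn_like(img1), 0.0, 1.0)
print("MSE :", mse(img1, img2).item())    # per-image mean squared error, shape (1, 1)
print("PSNR:", psnr(img1, img2).item())   # 20 * log10(1 / sqrt(MSE))
timer.pause()
print("elapsed: %.3f s" % timer.get_elapsed_time())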