├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── arguments
│   ├── __init__.py
│   └── nvidia_rodynrf
│       ├── Balloon1.py
│       ├── Balloon2.py
│       ├── Jumping.py
│       ├── Playground.py
│       ├── Skating.py
│       ├── Truck.py
│       ├── Umbrella.py
│       └── default.py
├── assets
│   ├── architecture.png
│   └── icon.png
├── dycheck_geometry
│   ├── __init__.py
│   ├── barf_se3.py
│   ├── camera.py
│   ├── se3.py
│   ├── trajs.py
│   └── utils.py
├── eval.sh
├── eval_nvidia.py
├── gaussian_renderer
│   └── __init__.py
├── gen_depth.py
├── gen_depth.sh
├── gen_tracks.py
├── gen_tracks.sh
├── install.sh
├── requirements.txt
├── requirements_unidepth.txt
├── scene
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset.py
│   ├── dataset_readers.py
│   ├── deformation.py
│   └── gaussian_model.py
├── train.py
├── train.sh
└── utils
    ├── TIMES.TTF
    ├── TIMESBD.TTF
    ├── TIMESBI.TTF
    ├── TIMESI.TTF
    ├── camera_utils.py
    ├── depth_loss_utils.py
    ├── dycheck_utils
    │   ├── __init__.py
    │   ├── annotation.py
    │   ├── common.py
    │   ├── flax_multioptim.py
    │   ├── image.py
    │   ├── io.py
    │   ├── path_ops.py
    │   ├── safe_ops.py
    │   ├── struct.py
    │   ├── types.py
    │   └── visuals
    │       ├── __init__.py
    │       ├── corrs.py
    │       ├── depth.py
    │       ├── flow.py
    │       ├── kps
    │       │   ├── __init__.py
    │       │   └── skeleton.py
    │       ├── plotly.py
    │       └── rendering.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loader_utils.py
    ├── loss_utils.py
    ├── main_utils.py
    ├── model_utils.py
    ├── params_utils.py
    ├── point_utils.py
    ├── pose_utils.py
    ├── render_utils.py
    ├── scene_utils.py
    ├── sh_utils.py
    ├── system_utils.py
    └── timer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .vscode
3 | .idea
4 | output
5 | output_1
6 | build
7 | diff_rasterization/diff_rast.egg-info
8 | diff_rasterization/dist
9 | tensorboard_3d
10 | screenshots
11 | data/
12 | data
13 | extensions
14 | lab4d
15 | preprocess
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/simple-knn"]
2 | path = submodules/simple-knn
3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git
4 | [submodule "submodules/co-tracker"]
5 | path = submodules/co-tracker
6 | url = https://github.com/facebookresearch/co-tracker.git
7 | [submodule "submodules/UniDepth"]
8 | path = submodules/UniDepth
9 | url = https://github.com/lpiccinelli-eth/UniDepth.git
10 | [submodule "submodules/mega-sam"]
11 | path = submodules/mega-sam
12 | url = https://github.com/mega-sam/mega-sam.git
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Minh-Quan Viet Bui
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR'25] SplineGS: Robust Motion-Adaptive Spline for Real-Time Dynamic 3D Gaussians from Monocular Video
2 |
3 |
4 | **[Jongmin Park](https://sites.google.com/view/jongmin-park)1\*, [Minh-Quan Viet Bui](https://quan5609.github.io/)1\*, [Juan Luis Gonzalez Bello](https://sites.google.com/view/juan-luis-gb/home)1, [Jaeho Moon](https://sites.google.com/view/jaehomoon)1, [Jihyong Oh](https://cmlab.cau.ac.kr/)2†, [Munchurl Kim](https://www.viclab.kaist.ac.kr/)1†**
5 |
6 | 1KAIST, South Korea, 2Chung-Ang University, South Korea
7 |
8 | \*Co-first authors (equal contribution), †Co-corresponding authors
9 |
26 | ## 📣 News
27 | ### Updates
28 | - **May 26, 2025**: Code released.
29 | - **February 26, 2025**: SplineGS accepted to CVPR 2025 🎉.
30 | - **December 13, 2024**: Paper uploaded to arXiv. Check out the manuscript [here](https://arxiv.org/abs/2412.09982).
31 | ### To-Dos
32 | - Add DAVIS dataset configurations.
33 | - Add custom dataset support.
34 | - Add iPhone dataset configurations.
35 | ## ⚙️ Environment Setup
36 | Clone the repo and install dependencies:
37 | ```sh
38 | git clone https://github.com/KAIST-VICLab/SplineGS.git --recursive
39 | cd SplineGS
40 |
41 | # install splinegs environment
42 | conda create -n splinegs python=3.7
43 | conda activate splinegs
44 | export CUDA_HOME=$CONDA_PREFIX
45 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib
46 |
47 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
48 | conda install nvidia/label/cuda-11.7.0::cuda
49 | conda install nvidia/label/cuda-11.7.0::cuda-nvcc
50 | conda install nvidia/label/cuda-11.7.0::cuda-runtime
51 | conda install nvidia/label/cuda-11.7.0::cuda-cudart
52 |
53 |
54 | pip install -e submodules/simple-knn
55 | pip install -e submodules/co-tracker
56 | pip install -r requirements.txt
57 |
58 | # install depth environment
59 | conda deactivate
60 | conda create -n unidepth_splinegs python=3.10
61 | conda activate unidepth_splinegs
62 |
63 | pip install -r requirements_unidepth.txt
64 | conda install -c conda-forge ld_impl_linux-64
65 | export CUDA_HOME=$CONDA_PREFIX
66 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib
67 | conda install nvidia/label/cuda-12.1.0::cuda
68 | conda install nvidia/label/cuda-12.1.0::cuda-nvcc
69 | conda install nvidia/label/cuda-12.1.0::cuda-runtime
70 | conda install nvidia/label/cuda-12.1.0::cuda-cudart
71 | conda install nvidia/label/cuda-12.1.0::libcusparse
72 | conda install nvidia/label/cuda-12.1.0::libcublas
73 | cd submodules/UniDepth/unidepth/ops/knn;bash compile.sh;cd ../../../../../
74 | cd submodules/UniDepth/unidepth/ops/extract_patches;bash compile.sh;cd ../../../../../
75 |
76 | pip install -e submodules/UniDepth
77 | mkdir -p submodules/mega-sam/Depth-Anything/checkpoints
78 | ```
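Before moving on, it is worth confirming that both environments see the GPU and picked up the pinned PyTorch builds (1.13.1 for `splinegs`, 2.4.0 for `unidepth_splinegs`); a minimal sanity check:
```sh
conda activate splinegs
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"   # expect 1.13.1 / True

conda activate unidepth_splinegs
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"   # expect 2.4.0 / True
```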
79 | ## 📁 Data Preparation
80 | ### Nvidia Dataset
81 | 1. We follow the evaluation setup from [RoDynRF](https://robust-dynrf.github.io/). Download the training images [here](https://github.com/KAIST-VICLab/SplineGS/releases/tag/dataset) and arrange them as follows:
82 | ```bash
83 | SplineGS/data/nvidia_rodynrf
84 | ├── Balloon1
85 | │ ├── images_2
86 | │ ├── instance_masks
87 | │ ├── motion_masks
88 | │ └── gt
89 | ├── ...
90 | └── Umbrella
91 | ```
92 | 2. Download [Depth-Anything checkpoint](https://huggingface.co/spaces/LiheYoung/Depth-Anything/blob/main/checkpoints/depth_anything_vitl14.pth) and place it at `submodules/mega-sam/Depth-Anything/checkpoints`. Generate depth estimation and tracking results for all scenes as:
93 | ```sh
94 | conda activate unidepth_splinegs
95 | bash gen_depth.sh
96 |
97 | conda deactivate
98 | conda activate splinegs
99 | bash gen_tracks.sh
100 | ```
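These scripts populate each scene folder with `uni_depth/` (UniDepth metric depth, one `H×W×1` `.npy` per frame plus a colorized preview), `depth_anything/` (Depth-Anything disparity), and `bootscotracker_dynamic/` / `bootscotracker_static/` (pairwise CoTracker tracks saved as `{frame_t}_{frame_j}.npy`, each an `N×3` array of `x, y, visibility`). A minimal sketch for spot-checking the outputs (Balloon1 is just an example scene):
```sh
python - <<'EOF'
import glob
import numpy as np

scene = "data/nvidia_rodynrf/Balloon1"  # example scene
depth = np.load(sorted(glob.glob(f"{scene}/uni_depth/*.npy"))[0])
track = np.load(sorted(glob.glob(f"{scene}/bootscotracker_dynamic/*.npy"))[0])
print("depth map:", depth.shape)  # (H, W, 1) per frame
print("tracks:", track.shape)     # (N, 3): x, y, visibility
EOF
```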
101 | 3. To obtain motion masks, please refer to [Shape of Motion](https://github.com/vye16/shape-of-motion/). For the Nvidia dataset, we provide precomputed masks in the `motion_masks` folder of each scene.
102 | ### YOUR OWN Dataset
103 | T.B.D
104 | ## 🚀 Get Started
105 | ### Nvidia Dataset
106 | #### Training
107 | ```sh
108 | # check if environment is activated properly
109 | conda activate splinegs
110 |
111 | python train.py -s data/nvidia_rodynrf/${SCENE}/ --expname "${EXP_NAME}" --configs arguments/nvidia_rodynrf/${SCENE}.py
112 | ```
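For example, to train on the Balloon1 scene with its provided config (the experiment name is arbitrary):
```sh
SCENE=Balloon1
EXP_NAME=Balloon1
python train.py -s data/nvidia_rodynrf/${SCENE}/ --expname "${EXP_NAME}" --configs arguments/nvidia_rodynrf/${SCENE}.py
```
The evaluation step below expects the resulting checkpoint at `output/${EXP_NAME}/point_cloud/fine_best`.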
113 | #### Metrics Evaluation
114 | ```sh
115 | python eval_nvidia.py -s data/nvidia_rodynrf/${SCENE}/ --expname "${EXP_NAME}" --configs arguments/nvidia_rodynrf/${SCENE}.py --checkpoint output/${EXP_NAME}/point_cloud/fine_best
116 | ```
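The provided `eval.sh` runs this evaluation over the Nvidia scenes in one pass; it assumes each scene was trained with `--expname` set to the scene name (e.g. the Balloon1 checkpoint at `output/Balloon1/point_cloud/fine_best`). Rendered train/test images and the PSNR/LPIPS/FPS numbers are written under `output/${EXP_NAME}`.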
117 | ### YOUR OWN Dataset
118 | #### Training
119 | T.B.D
120 | #### Evaluation
121 | T.B.D
122 |
123 | ## Acknowledgments
124 | - This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korean Government [Ministry of Science and ICT (Information and Communications Technology)] (Project Number: RS-2022-00144444, Project Title: Deep Learning Based Visual Representational Learning and Rendering of Static and Dynamic Scenes, 100%).
125 |
126 | ## ⭐ Citing SplineGS
127 |
128 | If you find our repository useful, please consider giving it a star ⭐ and citing our research papers in your work:
129 | ```bibtex
130 | @InProceedings{Park_2025_CVPR,
131 | author = {Park, Jongmin and Bui, Minh-Quan Viet and Bello, Juan Luis Gonzalez and Moon, Jaeho and Oh, Jihyong and Kim, Munchurl},
132 | title = {SplineGS: Robust Motion-Adaptive Spline for Real-Time Dynamic 3D Gaussians from Monocular Video},
133 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
134 | month = {June},
135 | year = {2025},
136 | pages = {26866-26875}
137 | }
138 | ```
139 |
140 |
141 | ## 📈 Star History
142 |
143 | [Star History Chart](https://www.star-history.com/#KAIST-VICLab/SplineGS&Date)
144 |
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from argparse import ArgumentParser, Namespace
15 |
16 |
17 | class GroupParams:
18 | pass
19 |
20 |
21 | class ParamGroup:
22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
23 | group = parser.add_argument_group(name)
24 | for key, value in vars(self).items():
25 | shorthand = False
26 | if key.startswith("_"):
27 | shorthand = True
28 | key = key[1:]
29 | t = type(value)
30 | value = value if not fill_none else None
31 | if shorthand:
32 | if t == bool:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
34 | else:
35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
36 | else:
37 | if t == bool:
38 | group.add_argument("--" + key, default=value, action="store_true")
39 | else:
40 | group.add_argument("--" + key, default=value, type=t)
41 |
42 | def extract(self, args):
43 | group = GroupParams()
44 | for arg in vars(args).items():
45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
46 | setattr(group, arg[0], arg[1])
47 | return group
48 |
49 |
50 | class ModelParams(ParamGroup):
51 | def __init__(self, parser, sentinel=False):
52 | self.sh_degree = 3
53 | self.deform_spatial_scale = 1e-2
54 | self.rgbfuntion = "sandwich"
55 | self.control_num = 12
56 | self.prune_error_threshold = 1.0
57 |
58 | self._source_path = ""
59 | self.dataset_type = "nvidia"
60 | self.depth_type = "depth" # depth or disp
61 | self._model_path = ""
62 | self._images = "images"
63 | self._resolution = -1
64 | self._white_background = False
65 | self.data_device = "cuda"
66 | self.eval = True
67 | self.render_process = True
68 | self.debug_process = True
69 | self.add_points = False
70 | self.extension = ".png"
71 | self.llffhold = 8
72 | super().__init__(parser, "Loading Parameters", sentinel)
73 |
74 | def extract(self, args):
75 | g = super().extract(args)
76 | g.source_path = os.path.abspath(g.source_path)
77 | return g
78 |
79 |
80 | class PipelineParams(ParamGroup):
81 | def __init__(self, parser):
82 | self.convert_SHs_python = False
83 | self.compute_cov3D_python = False
84 | self.debug = False
85 | super().__init__(parser, "Pipeline Parameters")
86 |
87 |
88 | class ModelHiddenParams(ParamGroup):
89 | def __init__(self, parser):
90 | self.timebase_pe = 10
91 | self.timenet_width = 256
92 | self.timenet_output = 6
93 | self.pixel_base_pe = 5
94 | super().__init__(parser, "ModelHiddenParams")
95 |
96 |
97 | class OptimizationParams(ParamGroup):
98 | def __init__(self, parser):
99 | self.dataloader = False
100 | self.zerostamp_init = False
101 | self.custom_sampler = None
102 | self.iterations = 30_000
103 | self.coarse_iterations = 1000
104 | self.static_iterations = 1000
105 |
106 | self.position_lr_init = 0.00016
107 | self.position_lr_final = 0.0000016
108 | self.position_lr_delay_mult = 0.01
109 | self.position_lr_max_steps = 20_000
110 |
111 | self.deformation_lr_init = 0.00016
112 | self.deformation_lr_final = 0.000016
113 | self.deformation_lr_delay_mult = 0.01
114 |
115 | self.grid_lr_init = 0.0016
116 | self.grid_lr_final = 0.00016
117 |
118 | self.pose_lr_init = 0.0005
119 | self.pose_lr_final = 0.00005
120 | self.pose_lr_delay_mult = 0.01
121 |
122 | self.feature_lr = 0.0025
123 | self.featuret_lr = 0.001
124 | self.opacity_lr = 0.05
125 | self.scaling_lr = 0.005
126 | self.rotation_lr = 0.001
127 | self.percent_dense = 0.01
128 | self.lambda_dssim = 0.2
129 | self.p_lambda_dssim = 0.0
130 | self.lambda_lpips = 0
131 | self.weight_constraint_init = 1
132 | self.weight_constraint_after = 0.2
133 | self.weight_decay_iteration = 5_000
134 | self.opacity_reset_interval = 3_000
135 | self.densification_interval = 1_00
136 | self.densify_from_iter = 500
137 | self.densify_until_iter = 15_000
138 | self.densify_grad_threshold_coarse = 0.0002
139 | self.densify_grad_threshold_fine_init = 0.0002
140 | self.densify_grad_threshold_after = 0.0002
141 | self.pruning_from_iter = 500
142 | self.pruning_interval = 100
143 | self.opacity_threshold_coarse = 0.005
144 | self.opacity_threshold_fine_init = 0.005
145 | self.opacity_threshold_fine_after = 0.005
146 | self.fine_batch_size = 1
147 | self.coarse_batch_size = 1
148 | self.add_point = False
149 | self.use_instance_mask = False
150 |
151 | self.prevpath = "1"
152 | self.opthr = 0.005
153 | self.desicnt = 6
154 | self.densify = 1
155 | self.densify_grad_threshold = 0.0008
156 | self.densify_grad_threshold_dynamic = 0.00008
157 | self.preprocesspoints = 0
158 | self.addsphpointsscale = 0.8
159 | self.raystart = 0.7
160 |
161 | self.soft_depth_start = 1000
162 | self.hard_depth_start = 0
163 | self.error_tolerance = 0.001
164 |
165 | self.trbfc_lr = 0.0001 #
166 | self.trbfs_lr = 0.03
167 | self.trbfslinit = 0.0 #
168 | self.omega_lr = 0.0001
169 | self.zeta_lr = 0.0001
170 | self.movelr = 3.5
171 | self.rgb_lr = 0.0001
172 |
173 | self.stat_npts = 20000
174 | self.dyn_npts = 20000
175 |
176 | self.w_depth = 1.0
177 | self.w_mask = 2.0
178 | self.w_track = 1.0
179 | self.w_normal = 0
180 | super().__init__(parser, "Optimization Parameters")
181 |
182 |
183 | def get_combined_args(parser: ArgumentParser):
184 | cmdlne_string = sys.argv[1:]
185 | cfgfile_string = "Namespace()"
186 | args_cmdline = parser.parse_args(cmdlne_string)
187 |
188 | try:
189 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
190 | print("Looking for config file in", cfgfilepath)
191 | with open(cfgfilepath) as cfg_file:
192 | print("Config file found: {}".format(cfgfilepath))
193 | cfgfile_string = cfg_file.read()
194 | except TypeError:
195 | print("Config file not found at")
196 | args_cfgfile = eval(cfgfile_string)
197 |
198 | merged_dict = vars(args_cfgfile).copy()
199 | for k, v in vars(args_cmdline).items():
200 | if v is not None:
201 | merged_dict[k] = v
202 | return Namespace(**merged_dict)
203 |
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Balloon1.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | ModelParams = dict(
4 | depth_type="disp",
5 | )
6 |
7 | OptimizationParams = dict(
8 | densify_grad_threshold_dynamic = 0.0002,
9 | densify_grad_threshold = 0.0008,
10 | use_instance_mask=True,
11 | )
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Balloon2.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | ModelParams = dict(
4 | depth_type="disp",
5 | )
6 |
7 | OptimizationParams = dict(
8 | densify_grad_threshold_dynamic = 0.0002,
9 | densify_grad_threshold = 0.0008,
10 | )
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Jumping.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | OptimizationParams = dict(
4 | densify_grad_threshold = 0.0002,
5 | densify_grad_threshold_dynamic = 0.0002
6 | )
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Playground.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | OptimizationParams = dict(
4 | densify_grad_threshold = 0.0002,
5 | densify_grad_threshold_dynamic = 0.0002,
6 | opacity_reset_interval=30_000,
7 | use_instance_mask=True,
8 | )
9 |
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Skating.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | OptimizationParams = dict(
4 | densify_grad_threshold = 0.0002,
5 | densify_grad_threshold_dynamic = 0.00008,
6 | )
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Truck.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | OptimizationParams = dict(
4 | densify_grad_threshold = 0.0008,
5 | densify_grad_threshold_dynamic = 0.0002
6 | )
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/Umbrella.py:
--------------------------------------------------------------------------------
1 | _base_ = "./default.py"
2 |
3 | OptimizationParams = dict(
4 | w_normal=1.0,
5 | opacity_reset_interval=30_000,
6 | densify_grad_threshold = 0.0002,
7 | densify_grad_threshold_dynamic = 0.00002,
8 | )
9 |
--------------------------------------------------------------------------------
/arguments/nvidia_rodynrf/default.py:
--------------------------------------------------------------------------------
1 | OptimizationParams = dict(
2 | iterations=25_000,
3 | coarse_batch_size=2,
4 | fine_batch_size=2,
5 | coarse_iterations=5_000,
6 | static_iterations=5_000,
7 | densify_from_iter=500,
8 | densify_until_iter=12000,
9 | opacity_reset_interval=3_000,
10 | densify=1,
11 | stat_npts=20000,
12 | dyn_npts=10000,
13 | desicnt=12,
14 | )
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/assets/architecture.png
--------------------------------------------------------------------------------
/assets/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/assets/icon.png
--------------------------------------------------------------------------------
/dycheck_geometry/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 | from .camera import Camera, project, get_rays_direction
20 | from .se3 import (
21 | exp_se3,
22 | exp_so3,
23 | from_homogenous,
24 | rt_to_se3,
25 | skew,
26 | to_homogenous,
27 | )
28 | from .trajs import get_arc_traj, get_lemniscate_traj
29 | from .utils import matmul, matv
--------------------------------------------------------------------------------
/dycheck_geometry/barf_se3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : se3.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 |
21 | import jax
22 | from jax import numpy as np
23 |
24 |
25 | # def func_A(x):
26 | # return np.sin(x)/x
27 |
28 | # def func_B(x):
29 | # return (1-np.cos(x))/x**2
30 |
31 | # def func_C(x):
32 | # return (x-np.sin(x))/x**3
33 |
34 | # # a recursive definition shares some work
35 | # def taylor(f, order):
36 | # def improve_approx(g, k):
37 | # return lambda x, v: jvp_first(g, (x, v), v)[1] + f(x) / factorial(k)
38 | # approx = lambda x, v: f(x) / factorial(order)
39 | # for n in range(order):
40 | # approx = improve_approx(approx, order - n - 1)
41 | # return approx
42 |
43 | # def jvp_first(f, primals, tangent):
44 | # x, xs = primals[0], primals[1:]
45 | # return jvp(lambda x: f(x, *xs), (x,), (tangent,))
46 |
47 |
48 | def procrustes_analysis(X0, X1): # [N,3] X0 is target X1 is src
49 | # translation
50 | t0 = X0.mean(axis=0, keepdims=True)
51 | t1 = X1.mean(axis=0, keepdims=True)
52 | X0c = X0 - t0
53 | X1c = X1 - t1
54 | # scale
55 | s0 = np.sqrt((X0c**2).sum(axis=-1).mean())
56 | s1 = np.sqrt((X1c**2).sum(axis=-1).mean())
57 | X0cs = X0c / s0
58 | X1cs = X1c / s1
59 | # rotation (use double for SVD, float loses precision)
60 |
61 | U, S, V = np.linalg.svd((X0cs.T @ X1cs), full_matrices=False)
62 | R = U @ V.T
63 |
64 | if np.linalg.det(R) < 0:
65 | R = R.at[2].set(-R[2])
66 | sim3 = edict(t0=t0[0], t1=t1[0], s0=s0, s1=s1, R=R)
67 | return sim3
68 |
69 |
70 | @jax.jit
71 | def skew(w: np.ndarray) -> np.ndarray:
72 | """Build a skew matrix ("cross product matrix") for vector w.
73 | Modern Robotics Eqn 3.30.
74 |
75 | Args:
76 | w: (..., 3,) A 3-vector
77 |
78 | Returns:
79 | W: (..., 3, 3) A skew matrix such that W @ v == w x v
80 | """
81 | zeros = np.zeros_like(w[..., 0])
82 | return np.stack(
83 | [
84 | np.stack([zeros, -w[..., 2], w[..., 1]], axis=-1),
85 | np.stack([w[..., 2], zeros, -w[..., 0]], axis=-1),
86 | np.stack([-w[..., 1], w[..., 0], zeros], axis=-1),
87 | ],
88 | axis=-2,
89 | )
90 |
91 |
92 | @jax.jit
93 | def taylor_A(x, nth=10):
94 | # Taylor expansion of sin(x)/x
95 | ans = np.zeros_like(x)
96 | denom = 1.0
97 | for i in range(nth + 1):
98 | if i > 0:
99 | denom *= (2 * i) * (2 * i + 1)
100 | ans = ans + (-1) ** i * x ** (2 * i) / denom
101 | return ans
102 |
103 |
104 | @jax.jit
105 | def taylor_B(x, nth=10):
106 | # Taylor expansion of (1-cos(x))/x**2
107 | ans = np.zeros_like(x)
108 | denom = 1.0
109 | for i in range(nth + 1):
110 | denom *= (2 * i + 1) * (2 * i + 2)
111 | ans = ans + (-1) ** i * x ** (2 * i) / denom
112 | return ans
113 |
114 |
115 | @jax.jit
116 | def taylor_C(x, nth=10):
117 | # Taylor expansion of (x-sin(x))/x**3
118 | ans = np.zeros_like(x)
119 | denom = 1.0
120 | for i in range(nth + 1):
121 | denom *= (2 * i + 2) * (2 * i + 3)
122 | ans = ans + (-1) ** i * x ** (2 * i) / denom
123 | return ans
124 |
125 |
126 | def se3_to_SE3(w: np.ndarray, u: np.ndarray) -> np.ndarray:
127 | wx = skew(w)
128 | theta = np.linalg.norm(w, axis=-1)[..., None, None]
129 | I = np.identity(3, dtype=w.dtype)
130 |
131 | A = taylor_A(theta)
132 | B = taylor_B(theta)
133 | C = taylor_C(theta)
134 |
135 | # check nerfies: R is e^r, V is G, u is v, wx is [r]
136 |
137 | R = I + A * wx + B * wx @ wx
138 | V = I + B * wx + C * wx @ wx
139 |
140 | t = V @ u[..., None]
141 | return R, t
142 |
--------------------------------------------------------------------------------
/dycheck_geometry/se3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : se3.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import numpy as np
21 |
22 | from . import utils
23 |
24 |
25 | def skew(w: np.ndarray) -> np.ndarray:
26 | """Build a skew matrix ("cross product matrix") for vector w.
27 | Modern Robotics Eqn 3.30.
28 |
29 | Args:
30 | w: (..., 3,) A 3-vector
31 |
32 | Returns:
33 | W: (..., 3, 3) A skew matrix such that W @ v == w x v
34 | """
35 | zeros = np.zeros_like(w[..., 0])
36 | return np.stack(
37 | [
38 | np.stack([zeros, -w[..., 2], w[..., 1]], axis=-1),
39 | np.stack([w[..., 2], zeros, -w[..., 0]], axis=-1),
40 | np.stack([-w[..., 1], w[..., 0], zeros], axis=-1),
41 | ],
42 | axis=-2,
43 | )
44 |
45 |
46 | def rt_to_se3(R: np.ndarray, t: np.ndarray) -> np.ndarray:
47 | """Rotation and translation to homogeneous transform.
48 |
49 | Args:
50 | R: (..., 3, 3) An orthonormal rotation matrix.
51 | t: (..., 3,) A 3-vector representing an offset.
52 |
53 | Returns:
54 | X: (..., 4, 4) The homogeneous transformation matrix described by
55 | rotating by R and translating by t.
56 | """
57 | batch_shape = R.shape[:-2]
58 | return np.concatenate(
59 | [
60 | np.concatenate([R, t[..., None]], axis=-1),
61 | np.broadcast_to(np.array([[0, 0, 0, 1]], np.float32), batch_shape + (1, 4)),
62 | ],
63 | axis=-2,
64 | )
65 |
66 |
67 | def exp_so3(w: np.ndarray, theta: np.ndarray) -> np.ndarray:
68 | """Exponential map from Lie algebra so3 to Lie group SO3.
69 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula.
70 |
71 | Args:
72 | w: (..., 3,) An axis of rotation. This is assumed to be a unit-vector.
73 | theta (...,): An angle of rotation.
74 |
75 | Returns:
76 | R: (..., 3, 3) An orthonormal rotation matrix representing a rotation
77 | of magnitude theta about axis w.
78 | """
79 | batch_shape = w.shape[:-1]
80 | W = skew(w)
81 | return (
82 | np.broadcast_to(np.eye(3), batch_shape + (3, 3))
83 | + np.sin(theta)[..., None, None] * W
84 | + (1 - np.cos(theta)[..., None, None]) * utils.matmul(W, W)
85 | )
86 |
87 |
88 | def exp_se3(S: np.ndarray, theta: np.ndarray) -> np.ndarray:
89 | """Exponential map from Lie algebra so3 to Lie group SO3.
90 |
91 | Modern Robotics Eqn 3.88.
92 |
93 | Args:
94 | S: (..., 6,) A screw axis of motion.
95 | theta (...,): Magnitude of motion.
96 |
97 | Returns:
98 | a_X_b: (..., 4, 4) The homogeneous transformation matrix attained by
99 | integrating motion of magnitude theta about S for one second.
100 | """
101 | batch_shape = S.shape[:-1]
102 | w, v = np.split(S, 2, axis=-1)
103 | W = skew(w)
104 | R = exp_so3(w, theta)
105 | t = utils.matv(
106 | (
107 | theta[..., None, None] * np.broadcast_to(np.eye(3), batch_shape + (3, 3))
108 | + (1 - np.cos(theta)[..., None, None]) * W
109 | + (theta[..., None, None] - np.sin(theta)[..., None, None]) * utils.matmul(W, W)
110 | ),
111 | v,
112 | )
113 | return rt_to_se3(R, t)
114 |
115 |
116 | def to_homogenous(v):
117 | return np.concatenate([v, np.ones_like(v[..., :1])], axis=-1)
118 |
119 |
120 | def from_homogenous(v):
121 | return v[..., :3] / v[..., -1:]
122 |
--------------------------------------------------------------------------------
/dycheck_geometry/trajs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : trajs.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import List
21 |
22 | import numpy as np
23 | from scipy.spatial.transform import Rotation
24 |
25 | from .camera import Camera
26 | from .utils import matv
27 |
28 |
29 | def get_arc_traj(
30 | ref_camera: Camera,
31 | lookat: np.ndarray,
32 | up: np.ndarray,
33 | *,
34 | num_frames: int,
35 | degree: float,
36 | **_,
37 | ) -> List[Camera]:
38 | positions = [
39 |         matv(Rotation.from_rotvec(d / 180 * np.pi * up).as_matrix(), ref_camera.position - lookat) + lookat
40 | for d in np.linspace(-degree / 2, degree / 2, num_frames)
41 | ]
42 | cameras = [ref_camera.lookat(p, lookat, up) for p in positions]
43 | return cameras
44 |
45 |
46 | def get_lemniscate_traj(
47 | ref_camera: Camera,
48 | lookat: np.ndarray,
49 | up: np.ndarray,
50 | *,
51 | num_frames: int,
52 | degree: float,
53 | **_,
54 | ) -> List[Camera]:
55 | a = np.linalg.norm(ref_camera.position - lookat) * np.tan(degree / 360 * np.pi)
56 | # Lemniscate curve in camera space. Starting at the origin.
57 | positions = np.array(
58 | [
59 | np.array(
60 | [
61 | a * np.cos(t) / (1 + np.sin(t) ** 2),
62 | a * np.cos(t) * np.sin(t) / (1 + np.sin(t) ** 2),
63 | 0,
64 | ]
65 | )
66 | for t in (np.linspace(0, 2 * np.pi, num_frames) + np.pi / 2)
67 | ]
68 | )
69 | # Transform to world space.
70 | positions = matv(ref_camera.orientation.T, positions) + ref_camera.position
71 | cameras = [ref_camera.lookat(p, lookat, up) for p in positions]
72 | return cameras
73 |
--------------------------------------------------------------------------------
/dycheck_geometry/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : utils.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import numpy as np
21 | from utils.dycheck_utils import types
22 |
23 |
24 | def matmul(a: types.Array, b: types.Array) -> types.Array:
25 | if isinstance(a, np.ndarray):
26 | assert isinstance(b, np.ndarray)
27 | else:
28 | assert isinstance(a, np.ndarray)
29 | assert isinstance(b, np.ndarray)
30 |
31 | if isinstance(a, np.ndarray):
32 | return a @ b
33 | else:
34 | # NOTE: The original implementation uses highest precision for TPU
35 | # computation. Since we are using GPUs only, comment it out.
36 | # return np.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
37 | return np.matmul(a, b)
38 |
39 |
40 | def matv(a: types.Array, b: types.Array) -> types.Array:
41 | return matmul(a, b[..., None])[..., 0]
42 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | python eval_nvidia.py -s data/nvidia_rodynrf/Balloon1/ --expname "Balloon1" --configs arguments/nvidia_rodynrf/Balloon1.py --checkpoint output/Balloon1/point_cloud/fine_best
2 | python eval_nvidia.py -s data/nvidia_rodynrf/Balloon2/ --expname "Balloon2" --configs arguments/nvidia_rodynrf/Balloon2.py --checkpoint output/Balloon2/point_cloud/fine_best
3 | python eval_nvidia.py -s data/nvidia_rodynrf/Jumping/ --expname "Jumping" --configs arguments/nvidia_rodynrf/Jumping.py --checkpoint output/Jumping/point_cloud/fine_best
4 | python eval_nvidia.py -s data/nvidia_rodynrf/Playground/ --expname "Playground" --configs arguments/nvidia_rodynrf/Playground.py --checkpoint output/Playground/point_cloud/fine_best
5 | python eval_nvidia.py -s data/nvidia_rodynrf/Skating/ --expname "Skating" --configs arguments/nvidia_rodynrf/Skating.py --checkpoint output/Skating/point_cloud/fine_best
6 | python eval_nvidia.py -s data/nvidia_rodynrf/Umbrella/ --expname "Umbrella" --configs arguments/nvidia_rodynrf/Umbrella.py --checkpoint output/Umbrella/point_cloud/fine_best
7 |
--------------------------------------------------------------------------------
/eval_nvidia.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | sys.path.append(os.path.join(sys.path[0], ".."))
5 | import cv2
6 | import lpips
7 | import numpy as np
8 | from argparse import ArgumentParser
9 | from arguments import ModelHiddenParams, ModelParams, OptimizationParams, PipelineParams
10 | from gaussian_renderer import render_infer
11 | from PIL import Image
12 | from scene import GaussianModel, Scene, dataset_readers
13 | from utils.graphics_utils import pts2pixel
14 | from utils.main_utils import get_pixels
15 | from utils.image_utils import psnr
16 | from gsplat.rendering import fully_fused_projection
17 | from scene import GaussianModel, Scene, dataset_readers, deformation
18 | import random
19 |
20 |
21 | def normalize_image(img):
22 | return (2.0 * img - 1.0)[None, ...]
23 |
24 |
25 | def training_report(scene: Scene, train_cams, test_cams, renderFunc, background, stage, dataset_type, path):
26 | test_psnr = 0.0
27 | torch.cuda.empty_cache()
28 |
29 | validation_configs = ({"name": "test", "cameras": test_cams}, {"name": "train", "cameras": train_cams})
30 | lpips_loss = lpips.LPIPS(net="alex").cuda()
31 |
32 | start_event = torch.cuda.Event(enable_timing=True)
33 | end_event = torch.cuda.Event(enable_timing=True)
34 |
35 | for config in validation_configs:
36 | if config["cameras"] and len(config["cameras"]) > 0:
37 | l1_test = 0.0
38 | psnr_test = 0.0
39 | lpips_test = 0.0
40 | run_time = 0.0
41 | elapsed_time_ms_list = []
42 | for idx, viewpoint in enumerate(config["cameras"]):
43 |
44 | if idx == 0: # warmup iter
45 | for _ in range(5):
46 | render_pkg = renderFunc(
47 | viewpoint, scene.stat_gaussians, scene.dyn_gaussians, background
48 | )
49 |
50 | torch.cuda.synchronize()
51 | start_event.record()
52 | render_pkg = renderFunc(
53 | viewpoint, scene.stat_gaussians, scene.dyn_gaussians, background
54 | )
55 | end_event.record()
56 | torch.cuda.synchronize()
57 | elapsed_time_ms = start_event.elapsed_time(end_event)
58 | elapsed_time_ms_list.append(elapsed_time_ms)
59 | run_time += elapsed_time_ms
60 |
61 | image = render_pkg["render"]
62 | image = torch.clamp(image, 0.0, 1.0)
63 |
64 | img = Image.fromarray(
65 | (np.clip(image.permute(1, 2, 0).detach().cpu().numpy(), 0, 1) * 255).astype("uint8")
66 | )
67 | os.makedirs(path + "/{}".format(config["name"]), exist_ok=True)
68 | img.save(path + "/{}/img_{}.png".format(config["name"], idx))
69 |
70 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
71 |
72 | psnr_test += psnr(image, gt_image, mask=None).mean().double()
73 | lpips_test += lpips_loss.forward(normalize_image(image), normalize_image(gt_image)).item()
74 |
75 | psnr_test /= len(config["cameras"])
76 | l1_test /= len(config["cameras"])
77 | lpips_test /= len(config["cameras"])
78 | run_time /= len(config["cameras"])
79 |
80 | print(
81 | "\n[ITER {}] Evaluating {}: PSNR {}, LPIPS {}, FPS {}".format(
82 | -1, config["name"], psnr_test, lpips_test, 1 / (run_time / 1000)
83 | )
84 | )
85 |
86 |
87 | if __name__ == "__main__":
88 | parser = ArgumentParser(description="Training script parameters")
89 | lp = ModelParams(parser)
90 | op = OptimizationParams(parser)
91 | pp = PipelineParams(parser)
92 | hp = ModelHiddenParams(parser)
93 |
94 | parser.add_argument(
95 | "--checkpoint", type=str, required=True, help="Path to the checkpoint file",
96 | )
97 | parser.add_argument("--expname", type=str, default="")
98 | parser.add_argument("--configs", type=str, default="")
99 |
100 | args = parser.parse_args(sys.argv[1:])
101 | if args.configs:
102 | import mmengine as mmcv
103 | from utils.params_utils import merge_hparams
104 |
105 | config = mmcv.Config.fromfile(args.configs)
106 | args = merge_hparams(args, config)
107 |
108 | dataset = lp.extract(args)
109 | hyper = hp.extract(args)
110 | stat_gaussians = GaussianModel(dataset)
111 | dyn_gaussians = GaussianModel(dataset)
112 | opt = op.extract(args)
113 |
114 | scene = Scene(
115 | dataset, dyn_gaussians, stat_gaussians, load_coarse=None
116 |     )  # for datasets other than the iPhone dataset
117 |
118 | dyn_gaussians.create_pose_network(hyper, scene.getTrainCameras()) # pose network with instance scaling
119 |
120 | bg_color = [1] * 9 + [0] if dataset.white_background else [0] * 9 + [0]
121 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
122 | pipe = pp.extract(args)
123 |
124 | test_cams = scene.getTestCameras()
125 | train_cams = scene.getTrainCameras()
126 | my_test_cams = [i for i in test_cams]
127 | viewpoint_stack = [i for i in train_cams]
128 |
129 | # if os.path.exists(os.path.join(args.checkpoint, "compact_point_cloud.npz")):
130 | # if False: # TODO: remove this after training
131 | # dyn_gaussians.load_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud.ply"))
132 | # stat_gaussians.load_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud_static.ply"))
133 | # else:
134 | dyn_gaussians.load_ply(os.path.join(args.checkpoint, "point_cloud.ply"))
135 | stat_gaussians.load_ply(os.path.join(args.checkpoint, "point_cloud_static.ply"))
136 |
137 | dyn_gaussians.flatten_control_point() # TODO: support this saving in training
138 | stat_gaussians.save_ply_compact(os.path.join(args.checkpoint, "compact_point_cloud_static.ply"))
139 | dyn_gaussians.save_ply_compact_dy(os.path.join(args.checkpoint, "compact_point_cloud.ply"))
140 |
141 |
142 | dyn_gaussians.load_model(args.checkpoint)
143 | dyn_gaussians._posenet.eval()
144 |
145 |
146 | pixels = get_pixels(
147 | scene.train_camera.dataset[0].metadata.image_size_x,
148 | scene.train_camera.dataset[0].metadata.image_size_y,
149 | use_center=True,
150 | )
151 | if pixels.shape[-1] != 2:
152 | raise ValueError("The last dimension of pixels must be 2.")
153 | batch_shape = pixels.shape[:-1]
154 | pixels = np.reshape(pixels, (-1, 2))
155 | y = (
156 | pixels[..., 1] - scene.train_camera.dataset[0].metadata.principal_point_y
157 | ) / dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy()
158 | x = (
159 | pixels[..., 0] - scene.train_camera.dataset[0].metadata.principal_point_x
160 | ) / dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy()
161 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1)
162 | local_viewdirs = viewdirs / np.linalg.norm(viewdirs, axis=-1, keepdims=True)
163 |
164 | with torch.no_grad():
165 | for cam in viewpoint_stack:
166 | time_in = torch.tensor(cam.time).float().cuda()
167 | pred_R, pred_T = dyn_gaussians._posenet(time_in.view(1, 1))
168 | R_ = torch.transpose(pred_R, 2, 1).detach().cpu().numpy()
169 | t_ = pred_T.detach().cpu().numpy()
170 | cam.update_cam(
171 | R_[0],
172 | t_[0],
173 | local_viewdirs,
174 | batch_shape,
175 | dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy(),
176 | )
177 |
178 | for view_id in range(len(my_test_cams)):
179 | my_test_cams[view_id].update_cam(
180 | viewpoint_stack[0].R,
181 | viewpoint_stack[0].T,
182 | local_viewdirs,
183 | batch_shape,
184 | dyn_gaussians._posenet.focal_bias.exp().detach().cpu().numpy(),
185 | )
186 |
187 | training_report(
188 | scene,
189 | viewpoint_stack,
190 | my_test_cams,
191 | render_infer,
192 | background,
193 | "fine",
194 | scene.dataset_type,
195 | os.path.join("output", args.expname),
196 | )
--------------------------------------------------------------------------------
/gen_depth.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 |
6 | from submodules.UniDepth.unidepth.models import UniDepthV1, UniDepthV2, UniDepthV2old
7 | from submodules.UniDepth.unidepth.utils import colorize, image_grid
8 |
9 | import glob
10 | import os
11 |
12 | def gen_depth(model, args):
13 | focal = 500 # default 500 for all datasets
14 |
15 | images_list = sorted(glob.glob(os.path.join(args.image_dir, '*.png')))
16 | os.makedirs(args.out_dir, exist_ok=True)
17 | for image in images_list:
18 | rgb = np.array(Image.open(image).convert('RGB'))
19 | rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
20 |
21 | H, W = rgb_torch.shape[1], rgb_torch.shape[2]
22 |         intrinsics_torch = torch.from_numpy(np.array([[focal, 0, W/2],  # principal point at the image center (cx = W/2, cy = H/2)
23 |                                                     [0, focal, H/2],
24 |                                                     [0, 0, 1]])).float()
25 | # predict
26 | predictions = model.infer(rgb_torch, intrinsics_torch)
27 | depth_pred = predictions["depth"].squeeze().cpu().numpy()[..., None]
28 |
29 | fname = os.path.basename(image)
30 | np.save(os.path.join(args.out_dir, fname.replace('png', 'npy')), depth_pred)
31 |
32 | # # colorize
33 | depth_pred_col = colorize(depth_pred)
34 | Image.fromarray(depth_pred_col).save(os.path.join(args.out_dir, fname))
35 |
36 |
37 | if __name__ == "__main__":
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument("--image_dir", type=str, required=True, help="image dir")
40 | parser.add_argument("--out_dir", type=str, required=True, help="out dir")
41 | parser.add_argument("--depth_type", type=str, help="depth type disp or depth", default="depth")
42 | parser.add_argument("--depth_model", type=str, help="unidepth model", default="v2old")
43 | args = parser.parse_args()
44 |
45 | print("Torch version:", torch.__version__)
46 |
47 | if args.depth_type == "disp":
48 | os.makedirs(args.out_dir, exist_ok=True)
49 | cmd = f"python submodules/mega-sam/Depth-Anything/run_videos.py --encoder vitl \
50 | --load-from submodules/mega-sam/Depth-Anything/checkpoints/depth_anything_vitl14.pth \
51 | --img-path {args.image_dir} \
52 | --outdir {args.out_dir}"
53 | os.system(cmd)
54 | elif args.depth_type == "depth":
55 | type_ = "l" # available types: s, b, l
56 | name = f"unidepth-{args.depth_model}-vit{type_}14"
57 | if args.depth_model == "v2":
58 | model = UniDepthV2.from_pretrained(f"lpiccinelli/{name}")
59 | # set resolution level (only V2)
60 | # model.resolution_level = 9
61 |
62 | # set interpolation mode (only V2)
63 | model.interpolation_mode = "bilinear"
64 | elif args.depth_model == "v2old":
65 | model = UniDepthV2old.from_pretrained(f"lpiccinelli/{name}")
66 |
67 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68 | model = model.to(device).eval()
69 |
70 | gen_depth(model, args)
71 | else:
72 | raise ValueError("depth_type must be either 'disp' or 'depth'")
73 |
74 |
--------------------------------------------------------------------------------
/gen_depth.sh:
--------------------------------------------------------------------------------
1 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --out_dir data/nvidia_rodynrf/Balloon1/uni_depth
2 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --out_dir data/nvidia_rodynrf/Balloon1/depth_anything --depth_type disp
3 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --out_dir data/nvidia_rodynrf/Balloon2/uni_depth
4 | python gen_depth.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --out_dir data/nvidia_rodynrf/Balloon2/depth_anything --depth_type disp
5 | python gen_depth.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --out_dir data/nvidia_rodynrf/Jumping/uni_depth
6 | python gen_depth.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --out_dir data/nvidia_rodynrf/Jumping/depth_anything --depth_type disp
7 | python gen_depth.py --image_dir data/nvidia_rodynrf/Playground/images_2 --out_dir data/nvidia_rodynrf/Playground/uni_depth --depth_model v2
8 | python gen_depth.py --image_dir data/nvidia_rodynrf/Playground/images_2 --out_dir data/nvidia_rodynrf/Playground/depth_anything --depth_type disp
9 | python gen_depth.py --image_dir data/nvidia_rodynrf/Skating/images_2 --out_dir data/nvidia_rodynrf/Skating/uni_depth
10 | python gen_depth.py --image_dir data/nvidia_rodynrf/Skating/images_2 --out_dir data/nvidia_rodynrf/Skating/depth_anything --depth_type disp
11 | python gen_depth.py --image_dir data/nvidia_rodynrf/Truck/images_2 --out_dir data/nvidia_rodynrf/Truck/uni_depth
12 | python gen_depth.py --image_dir data/nvidia_rodynrf/Truck/images_2 --out_dir data/nvidia_rodynrf/Truck/depth_anything --depth_type disp
13 | python gen_depth.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --out_dir data/nvidia_rodynrf/Umbrella/uni_depth
14 | python gen_depth.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --out_dir data/nvidia_rodynrf/Umbrella/depth_anything --depth_type disp
--------------------------------------------------------------------------------
/gen_tracks.py:
--------------------------------------------------------------------------------
1 |
2 | """ Code borrowed from
3 | https://github.com/vye16/shape-of-motion/blob/main/preproc/compute_tracks_torch.py
4 | """
5 | import argparse
6 | import glob
7 | import os
8 |
9 | from PIL import Image
10 | import numpy as np
11 | import torch
12 | from tqdm import tqdm
13 |
14 | from cotracker.utils.visualizer import Visualizer
15 |
16 | DEFAULT_DEVICE = (
17 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18 | "cuda"
19 | if torch.cuda.is_available()
20 | else "cpu"
21 | )
22 |
23 | def read_video(folder_path):
24 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*")))
25 | video = np.concatenate([np.array(Image.open(frame_path)).transpose(2, 0, 1)[None, None] for frame_path in frame_paths], axis=1)
26 | video = torch.from_numpy(video).float()
27 | return video
28 |
29 | def read_mask(folder_path):
30 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*")))
31 | video = np.concatenate([np.array(Image.open(frame_path))[None, None] for frame_path in frame_paths], axis=1)
32 | video = torch.from_numpy(video).float()
33 | return video
34 |
35 | def main():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--image_dir", type=str, required=True, help="image dir")
38 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir")
39 | parser.add_argument("--out_dir", type=str, required=True, help="out dir")
40 | parser.add_argument("--is_static", action="store_true")
41 | parser.add_argument("--grid_size", type=int, default=100, help="Regular grid size")
42 | parser.add_argument(
43 | "--grid_query_frame",
44 | type=int,
45 | default=0,
46 | help="Compute dense and grid tracks starting from this frame",
47 | )
48 | parser.add_argument(
49 | "--backward_tracking",
50 | action="store_true",
51 | help="Compute tracks in both directions, not only forward",
52 | )
53 | args = parser.parse_args()
54 |
55 | folder_path = args.image_dir
56 | mask_dir = args.mask_dir
57 | frame_names = [
58 | os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*")))
59 | ]
60 | out_dir = args.out_dir
61 | os.makedirs(out_dir, exist_ok=True)
62 | os.makedirs(os.path.join(out_dir, "vis"), exist_ok=True)
63 |
64 | done = True
65 | for t in range(len(frame_names)):
66 | for j in range(len(frame_names)):
67 | name_t = os.path.splitext(frame_names[t])[0]
68 | name_j = os.path.splitext(frame_names[j])[0]
69 | out_path = f"{out_dir}/{name_t}_{name_j}.npy"
70 | if not os.path.exists(out_path):
71 | done = False
72 | break
73 | print(f"{done}")
74 | if done:
75 | print("Already done")
76 | return
77 |
78 | ## Load model
79 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(DEFAULT_DEVICE)
80 | video = read_video(folder_path).to(DEFAULT_DEVICE)
81 |
82 | masks = read_mask(mask_dir).to(DEFAULT_DEVICE)
83 |
84 | masks[masks>0] = 1.
85 | if args.is_static:
86 | masks = 1.0 - masks
87 |
88 | _, num_frames,_, height, width = video.shape
89 | vis = Visualizer(save_dir=os.path.join(out_dir, "vis"), pad_value=120, linewidth=3)
90 |
91 | for t in tqdm(range(num_frames), desc="query frames"):
92 | name_t = os.path.splitext(frame_names[t])[0]
93 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy")
94 | if len(file_matches) == num_frames:
95 | print(f"Already computed tracks with query {t} {name_t}")
96 | continue
97 |
98 | current_mask = masks[:,t].unsqueeze(1)
99 | start_pred = None
100 |
101 | for j in range(num_frames):
102 | if j > t:
103 | current_video = video[:,t:j+1]
104 | elif j < t:
105 | current_video = torch.flip(video[:,j:t+1], dims=(1,)) # reverse
106 | else:
107 | continue
108 | # current_video = video[:,t:t+1]
109 |
110 |
111 | pred_tracks, pred_visibility = model(
112 | current_video,
113 | grid_size=args.grid_size,
114 | grid_query_frame=0,
115 | backward_tracking=False,
116 | segm_mask=current_mask
117 | )
118 |
119 |
120 | pred = torch.cat([pred_tracks, pred_visibility.unsqueeze(-1)], dim=-1)
121 | current_pred = pred[0,-1]
122 | start_pred = pred[0,0]
123 |
124 | # save
125 | name_j = os.path.splitext(frame_names[j])[0]
126 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", current_pred.cpu().numpy())
127 |
128 | # visualize
129 | # vis.visualize(current_video, pred_tracks, pred_visibility, filename=f"{name_t}_{name_j}")
130 |
131 | np.save(f"{out_dir}/{name_t}_{name_t}.npy", start_pred.cpu().numpy())
132 |
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
137 |
--------------------------------------------------------------------------------
/gen_tracks.sh:
--------------------------------------------------------------------------------
1 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --mask_dir data/nvidia_rodynrf/Balloon1/motion_masks --out_dir data/nvidia_rodynrf/Balloon1/bootscotracker_dynamic --grid_size 256
2 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon1/images_2 --mask_dir data/nvidia_rodynrf/Balloon1/motion_masks --out_dir data/nvidia_rodynrf/Balloon1/bootscotracker_static --is_static --grid_size 50
3 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --mask_dir data/nvidia_rodynrf/Balloon2/motion_masks --out_dir data/nvidia_rodynrf/Balloon2/bootscotracker_dynamic --grid_size 256
4 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Balloon2/images_2 --mask_dir data/nvidia_rodynrf/Balloon2/motion_masks --out_dir data/nvidia_rodynrf/Balloon2/bootscotracker_static --is_static --grid_size 50
5 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --mask_dir data/nvidia_rodynrf/Jumping/motion_masks --out_dir data/nvidia_rodynrf/Jumping/bootscotracker_dynamic --grid_size 256
6 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Jumping/images_2 --mask_dir data/nvidia_rodynrf/Jumping/motion_masks --out_dir data/nvidia_rodynrf/Jumping/bootscotracker_static --is_static --grid_size 50
7 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Playground/images_2 --mask_dir data/nvidia_rodynrf/Playground/motion_masks --out_dir data/nvidia_rodynrf/Playground/bootscotracker_dynamic --grid_size 256
8 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Playground/images_2 --mask_dir data/nvidia_rodynrf/Playground/motion_masks --out_dir data/nvidia_rodynrf/Playground/bootscotracker_static --is_static --grid_size 50
9 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Truck/images_2 --mask_dir data/nvidia_rodynrf/Truck/motion_masks --out_dir data/nvidia_rodynrf/Truck/bootscotracker_dynamic --grid_size 256
10 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Truck/images_2 --mask_dir data/nvidia_rodynrf/Truck/motion_masks --out_dir data/nvidia_rodynrf/Truck/bootscotracker_static --is_static --grid_size 50
11 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Skating/images_2 --mask_dir data/nvidia_rodynrf/Skating/motion_masks --out_dir data/nvidia_rodynrf/Skating/bootscotracker_dynamic --grid_size 256
12 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Skating/images_2 --mask_dir data/nvidia_rodynrf/Skating/motion_masks --out_dir data/nvidia_rodynrf/Skating/bootscotracker_static --is_static --grid_size 50
13 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --mask_dir data/nvidia_rodynrf/Umbrella/motion_masks --out_dir data/nvidia_rodynrf/Umbrella/bootscotracker_dynamic --grid_size 256
14 | python gen_tracks.py --image_dir data/nvidia_rodynrf/Umbrella/images_2 --mask_dir data/nvidia_rodynrf/Umbrella/motion_masks --out_dir data/nvidia_rodynrf/Umbrella/bootscotracker_static --is_static --grid_size 50
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | # # install splinegs environment
2 | # ENV_NAME=splinegs
3 |
4 | # conda remove -n $ENV_NAME --all -y
5 | # conda create -n $ENV_NAME python=3.7
6 | # conda activate $ENV_NAME
7 | # export CUDA_HOME=$CONDA_PREFIX
8 |
9 | # conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
10 | # conda install nvidia/label/cuda-11.7.0::cuda
11 | # conda install nvidia/label/cuda-11.7.0::cuda-nvcc
12 | # conda install nvidia/label/cuda-11.7.0::cuda-runtime
13 | # conda install nvidia/label/cuda-11.7.0::cuda-cudart
14 | # conda install -c conda-forge ld_impl_linux-64
15 |
16 | # pip install -e submodules/simple-knn
17 | # pip install -e submodules/co-tracker
18 | # pip install -r requirements.txt
19 |
20 | # install unidepth environment
21 | UNIDEPTH_ENV_NAME=unidepth_splinegs
22 | conda remove -n $UNIDEPTH_ENV_NAME --all -y
23 | conda create -n $UNIDEPTH_ENV_NAME python=3.10
24 | conda activate $UNIDEPTH_ENV_NAME
25 |
26 | pip install -r requirements_unidepth.txt
27 | conda install -c conda-forge ld_impl_linux-64
28 | export CUDA_HOME=$CONDA_PREFIX
29 | export LD_LIBRARY_PATH=$CONDA_PREFIX/lib
30 | conda install nvidia/label/cuda-12.1.0::cuda
31 | conda install nvidia/label/cuda-12.1.0::cuda-nvcc
32 | conda install nvidia/label/cuda-12.1.0::cuda-runtime
33 | conda install nvidia/label/cuda-12.1.0::cuda-cudart
34 | conda install nvidia/label/cuda-12.1.0::libcusparse
35 | conda install nvidia/label/cuda-12.1.0::libcublas
36 | cd submodules/UniDepth/unidepth/ops/knn; bash compile.sh; cd ../../../../../
37 | pip install -e submodules/UniDepth
38 |
39 | # install depthanything
40 | mkdir -p submodules/mega-sam/Depth-Anything/checkpoints
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/nerfstudio-project/gsplat.git@v1.4.0
2 | plyfile
3 | ffmpeg
4 | numpy==1.21.6
5 | scipy==1.7.3
6 | matplotlib==3.5.3
7 | mmengine==0.10.6
8 | imageio==2.19.3
9 | imageio-ffmpeg==0.4.7
10 | kornia==0.6.9
11 | tqdm
--------------------------------------------------------------------------------
/requirements_unidepth.txt:
--------------------------------------------------------------------------------
1 | einops>=0.7.0
2 | gradio
3 | h5py>=3.10.0
4 | huggingface-hub>=0.22.0
5 | imageio
6 | matplotlib
7 | numpy==2.2.5
8 | opencv-python
9 | pandas
10 | pillow>=10.2.0
11 | protobuf>=4.25.3
12 | scipy
13 | tables
14 | tabulate
15 | termcolor
16 | timm
17 | tqdm
18 | trimesh
19 | triton>=2.4.0
20 | torch==2.4.0
21 | torchvision==0.19.0
22 | torchaudio==2.4.0
23 | wandb
24 | xformers==0.0.27.post2
--------------------------------------------------------------------------------
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 |
14 | from arguments import ModelParams
15 | from scene.dataset import FourDGSdataset
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from utils.system_utils import searchForMaxIteration
19 |
20 |
21 | class Scene:
22 | dyn_gaussians: GaussianModel
23 |
24 | def __init__(
25 | self,
26 | args: ModelParams,
27 | gaussians: GaussianModel,
28 | static_gaussians: GaussianModel,
29 | load_iteration=None,
30 | shuffle=True,
31 | resolution_scales=[1.0],
32 | load_coarse=False,
33 | ):
34 |         """
35 |         :param args: model parameters, including the path to the COLMAP scene main folder.
36 | """
37 | self.model_path = args.model_path
38 | self.loaded_iter = None
39 | self.dyn_gaussians = gaussians
40 | self.stat_gaussians = static_gaussians
41 |
42 | if load_iteration:
43 | if load_iteration == -1:
44 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
45 | else:
46 | self.loaded_iter = load_iteration
47 | print("Loading trained model at iteration {}".format(self.loaded_iter))
48 |
49 | self.train_cameras = {}
50 | self.test_cameras = {}
51 | self.video_cameras = {}
52 |
53 | assert args.dataset_type in sceneLoadTypeCallbacks.keys(), "Could not recognize scene type!"
54 |
55 | dataset_type = args.dataset_type
56 | scene_info = sceneLoadTypeCallbacks[dataset_type](args)
57 |
58 | self.maxtime = scene_info.maxtime
59 | self.dataset_type = dataset_type
60 | self.cameras_extent = scene_info.nerf_normalization["radius"]
61 | print(f"Original scene extent {self.cameras_extent}")
62 | print("Loading Training Cameras")
63 | self.train_camera = FourDGSdataset(scene_info.train_cameras, args, dataset_type)
64 | print("Loading Test Cameras")
65 | self.test_camera = FourDGSdataset(scene_info.test_cameras, args, dataset_type)
66 | print("Loading Video Cameras")
67 | self.video_camera = FourDGSdataset(scene_info.video_cameras, args, dataset_type)
68 |
69 | # self.video_camera = cameraList_from_camInfos(scene_info.video_cameras,-1,args)
70 | xyz_max = scene_info.point_cloud.points.max(axis=0)
71 | xyz_min = scene_info.point_cloud.points.min(axis=0)
72 |
73 | if self.loaded_iter:
74 | self.dyn_gaussians.load_ply(
75 | os.path.join(self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud.ply")
76 | )
77 | self.dyn_gaussians.load_model(
78 | os.path.join(
79 | self.model_path,
80 | "point_cloud",
81 | "iteration_" + str(self.loaded_iter),
82 | )
83 | )
84 | self.stat_gaussians.load_ply(
85 | os.path.join(
86 | self.model_path, "point_cloud", "iteration_" + str(self.loaded_iter), "point_cloud_static.ply"
87 | )
88 | )
89 | else:
90 | self.dyn_gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime)
91 | self.stat_gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime)
92 |
93 | def save(self, iteration, stage):
94 | if stage == "coarse":
95 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration))
96 |
97 | else:
98 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
99 | self.dyn_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
100 | self.dyn_gaussians.save_deformation(point_cloud_path)
101 |
102 | self.stat_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud_static.ply"))
103 |
104 | def save_best_psnr(self, iteration, stage):
105 | if stage == "coarse":
106 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_best")
107 |
108 | else:
109 | point_cloud_path = os.path.join(self.model_path, "point_cloud/fine_best")
110 | self.dyn_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
111 | self.dyn_gaussians.save_deformation(point_cloud_path)
112 |
113 | self.stat_gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud_static.ply"))
114 |
115 | def getTrainCameras(self, scale=1.0):
116 | return self.train_camera
117 |
118 | def getTestCameras(self, scale=1.0):
119 | return self.test_camera
120 |
121 | def getVideoCameras(self, scale=1.0):
122 | return self.video_camera
123 |
--------------------------------------------------------------------------------
/scene/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from scene.cameras import Camera
3 | from torch.utils.data import Dataset
4 | from utils.general_utils import PILtoTorch
5 | from utils.graphics_utils import focal2fov
6 |
7 |
8 | class FourDGSdataset(Dataset):
9 | def __init__(self, dataset, args, dataset_type):
10 | self.dataset = dataset
11 | self.args = args
12 | self.dataset_type = dataset_type
13 |
14 | def __getitem__(self, index):
15 | if self.dataset_type != "PanopticSports":
16 | try:
17 | image, w2c, time = self.dataset[index]
18 | R, T = w2c
19 | FovX = focal2fov(self.dataset.focal[0], image.shape[2])
20 | FovY = focal2fov(self.dataset.focal[0], image.shape[1])
21 | mask = None
22 |             except Exception:
23 | caminfo = self.dataset[index]
24 | image = PILtoTorch(caminfo.image, (caminfo.image.width, caminfo.image.height))
25 | R = caminfo.R
26 | T = caminfo.T
27 | FovX = caminfo.FovX
28 | FovY = caminfo.FovY
29 | time = caminfo.time
30 | mask = caminfo.mask
31 |
32 | return Camera(
33 | colmap_id=index,
34 | R=R,
35 | T=T,
36 | FoVx=FovX,
37 | FoVy=FovY,
38 | image=image,
39 | gt_alpha_mask=None,
40 | image_name=f"{caminfo.image_name}",
41 | uid=index,
42 | data_device=torch.device("cuda"),
43 | time=time,
44 | mask=mask,
45 | metadata=caminfo.metadata,
46 | normal=caminfo.normal,
47 | depth=caminfo.depth,
48 | max_time=caminfo.max_time,
49 | sem_mask=caminfo.sem_mask,
50 | fwd_flow=caminfo.fwd_flow,
51 | bwd_flow=caminfo.bwd_flow,
52 | fwd_flow_mask=caminfo.fwd_flow_mask,
53 | bwd_flow_mask=caminfo.bwd_flow_mask,
54 | instance_mask=caminfo.instance_mask,
55 | target_tracks=caminfo.target_tracks,
56 | target_visibility=caminfo.target_visibility,
57 | target_tracks_static=caminfo.target_tracks_static,
58 | target_visibility_static=caminfo.target_visibility_static,
59 | )
60 | else:
61 | return self.dataset[index]
62 |
63 | def __len__(self):
64 | return len(self.dataset)
65 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py -s data/nvidia_rodynrf/Balloon1/ --expname "Balloon1" --configs arguments/nvidia_rodynrf/Balloon1.py
2 | python train.py -s data/nvidia_rodynrf/Balloon2/ --expname "Balloon2" --configs arguments/nvidia_rodynrf/Balloon2.py
3 | python train.py -s data/nvidia_rodynrf/Playground/ --expname "Playground" --configs arguments/nvidia_rodynrf/Playground.py
4 | python train.py -s data/nvidia_rodynrf/Jumping/ --expname "Jumping" --configs arguments/nvidia_rodynrf/Jumping.py
5 | python train.py -s data/nvidia_rodynrf/Truck/ --expname "Truck" --configs arguments/nvidia_rodynrf/Truck.py
6 | python train.py -s data/nvidia_rodynrf/Skating/ --expname "Skating" --configs arguments/nvidia_rodynrf/Skating.py
7 | python train.py -s data/nvidia_rodynrf/Umbrella/ --expname "Umbrella" --configs arguments/nvidia_rodynrf/Umbrella.py
--------------------------------------------------------------------------------
/utils/TIMES.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMES.TTF
--------------------------------------------------------------------------------
/utils/TIMESBD.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESBD.TTF
--------------------------------------------------------------------------------
/utils/TIMESBI.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESBI.TTF
--------------------------------------------------------------------------------
/utils/TIMESI.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-VICLab/SplineGS/5030b35285b91afb73eccb9b7783797f31d97d39/utils/TIMESI.TTF
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | from scene.cameras import Camera
14 | from utils.graphics_utils import fov2focal
15 |
16 |
17 | WARNED = False
18 |
19 |
20 | def loadCam(args, id, cam_info, resolution_scale):
21 | # resized_image_rgb = PILtoTorch(cam_info.image, resolution)
22 |
23 | # gt_image = resized_image_rgb[:3, ...]
24 | # loaded_mask = None
25 |
26 | # if resized_image_rgb.shape[1] == 4:
27 | # loaded_mask = resized_image_rgb[3:4, ...]
28 |
29 | return Camera(
30 | colmap_id=cam_info.uid,
31 | R=cam_info.R,
32 | T=cam_info.T,
33 | FoVx=cam_info.FovX,
34 | FoVy=cam_info.FovY,
35 | image=cam_info.image,
36 | gt_alpha_mask=None,
37 | image_name=cam_info.image_name,
38 | uid=id,
39 | data_device=args.data_device,
40 | time=cam_info.time,
41 | metadata=cam_info.metadata,
42 | )
43 |
44 |
45 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
46 | camera_list = []
47 |
48 | for id, c in enumerate(cam_infos):
49 | camera_list.append(loadCam(args, id, c, resolution_scale))
50 |
51 | return camera_list
52 |
53 |
54 | def camera_to_JSON(id, camera: Camera):
55 | Rt = np.zeros((4, 4))
56 | Rt[:3, :3] = camera.R.transpose()
57 | Rt[:3, 3] = camera.T
58 | Rt[3, 3] = 1.0
59 |
60 | W2C = np.linalg.inv(Rt)
61 | pos = W2C[:3, 3]
62 | rot = W2C[:3, :3]
63 | serializable_array_2d = [x.tolist() for x in rot]
64 | camera_entry = {
65 | "id": id,
66 | "img_name": camera.image_name,
67 | "width": camera.width,
68 | "height": camera.height,
69 | "position": pos.tolist(),
70 | "rotation": serializable_array_2d,
71 | "fy": fov2focal(camera.FovY, camera.height),
72 | "fx": fov2focal(camera.FovX, camera.width),
73 | }
74 | return camera_entry
75 |
--------------------------------------------------------------------------------
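camera_to_JSON above only reads a pose (R, T), the image size, and the two FoV angles, so a lightweight stand-in object is enough to preview the JSON layout it produces. A hypothetical sketch (SimpleNamespace instead of scene.cameras.Camera, made-up pose values; assumes the repository root is on PYTHONPATH with its dependencies installed):

import numpy as np
from types import SimpleNamespace
from utils.camera_utils import camera_to_JSON

# Hypothetical camera stand-in; only the attributes read by camera_to_JSON are set.
cam = SimpleNamespace(
    R=np.eye(3), T=np.zeros(3), image_name="frame_000",
    width=640, height=360, FovX=0.9, FovY=0.6,
)
print(camera_to_JSON(0, cam))  # id, img_name, width/height, position, rotation, fx/fy

--------------------------------------------------------------------------------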
/utils/depth_loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from math import exp
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from torch.autograd import Variable
17 |
18 |
19 | def normalize(input, mean=None, std=None):
20 | input_mean = torch.mean(input, dim=1, keepdim=True) if mean is None else mean
21 | input_std = torch.std(input, dim=1, keepdim=True) if std is None else std
22 | return (input - input_mean) / (input_std + 1e-2 * torch.std(input.reshape(-1)))
23 |
24 |
25 | def shuffle(input):
26 | # shuffle dim=1
27 | idx = torch.randperm(input[0].shape[1])
28 | for i in range(input.shape[0]):
29 | input[i] = input[i][:, idx].view(input[i].shape)
30 |
31 |
32 | def loss_depth_smoothness(depth, img):
33 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:]
34 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :]
35 | weight_x = torch.exp(-torch.abs(img_grad_x).mean(1).unsqueeze(1))
36 | weight_y = torch.exp(-torch.abs(img_grad_y).mean(1).unsqueeze(1))
37 |
38 | loss = (
39 | ((depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs() * weight_x).sum()
40 | + ((depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs() * weight_y).sum()
41 | ) / (weight_x.sum() + weight_y.sum())
42 | return loss
43 |
44 |
45 | def loss_depth_grad(depth, img):
46 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:]
47 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :]
48 | weight_x = img_grad_x / (torch.abs(img_grad_x) + 1e-6)
49 | weight_y = img_grad_y / (torch.abs(img_grad_y) + 1e-6)
50 |
51 | depth_grad_x = depth[:, :, :, :-1] - depth[:, :, :, 1:]
52 | depth_grad_y = depth[:, :, :-1, :] - depth[:, :, 1:, :]
53 | grad_x = depth_grad_x / (torch.abs(depth_grad_x) + 1e-6)
54 | grad_y = depth_grad_y / (torch.abs(depth_grad_y) + 1e-6)
55 |
56 | loss = l1_loss(grad_x, weight_x) + l1_loss(grad_y, weight_y)
57 | return loss
58 |
59 |
60 | def l1_loss(network_output, gt):
61 | return torch.abs((network_output - gt)).mean()
62 |
63 |
64 | def l2_loss(network_output, gt):
65 | return ((network_output - gt) ** 2).mean()
66 |
67 |
68 | def margin_l2_loss(network_output, gt, margin, return_mask=False):
69 | mask = (network_output - gt).abs() > margin
70 | if not return_mask:
71 | return ((network_output - gt)[mask] ** 2).mean()
72 | else:
73 | return ((network_output - gt)[mask] ** 2).mean(), mask
74 |
75 |
76 | def margin_l1_loss(network_output, gt, margin, return_mask=False):
77 | mask = (network_output - gt).abs() > margin
78 | if not return_mask:
79 | return ((network_output - gt)[mask].abs()).mean()
80 | else:
81 | return ((network_output - gt)[mask].abs()).mean(), mask
82 |
83 |
84 | def kl_loss(input, target):
85 | input = F.log_softmax(input, dim=-1)
86 | target = F.softmax(target, dim=-1)
87 | return F.kl_div(input, target, reduction="batchmean")
88 |
89 |
90 | def patchify(input, patch_size):
91 | patches = (
92 | F.unfold(input, kernel_size=patch_size, stride=patch_size)
93 | .permute(0, 2, 1)
94 | .reshape(-1, 1 * patch_size * patch_size)
95 | )
96 | return patches
97 |
98 |
99 | def patch_norm_mse_loss(input, target, patch_size, margin, return_mask=False):
100 | input_patches = normalize(patchify(input, patch_size))
101 | target_patches = normalize(patchify(target, patch_size))
102 | return margin_l2_loss(input_patches, target_patches, margin, return_mask)
103 |
104 |
105 | def patch_norm_mse_loss_global(input, target, patch_size, margin, return_mask=False):
106 | input_patches = normalize(patchify(input, patch_size), std=input.std().detach())
107 | target_patches = normalize(patchify(target, patch_size), std=target.std().detach())
108 | return margin_l2_loss(input_patches, target_patches, margin, return_mask)
109 |
110 |
111 | def patch_norm_l1_loss_global(input, target, patch_size, margin, return_mask=False):
112 | input_patches = normalize(patchify(input, patch_size), std=input.std().detach())
113 | target_patches = normalize(patchify(target, patch_size), std=target.std().detach())
114 | return margin_l1_loss(input_patches, target_patches, margin, return_mask)
115 |
116 |
117 | def patch_norm_l1_loss(input, target, patch_size, margin, return_mask=False):
118 | input_patches = normalize(patchify(input, patch_size))
119 | target_patches = normalize(patchify(target, patch_size))
120 | return margin_l1_loss(input_patches, target_patches, margin, return_mask)
121 |
122 |
123 | def gaussian(window_size, sigma):
124 | gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
125 | return gauss / gauss.sum()
126 |
127 |
128 | def create_window(window_size, channel):
129 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
130 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
131 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
132 | return window
133 |
134 |
135 | def margin_ssim(img1, img2, window_size=11, size_average=True):
136 |     result = ssim(img1, img2, window_size, False)
137 |     return result
138 |
139 |
140 | def ssim(img1, img2, window_size=11, size_average=True):
141 | channel = img1.size(-3)
142 | window = create_window(window_size, channel)
143 |
144 | if img1.is_cuda:
145 | window = window.cuda(img1.get_device())
146 | window = window.type_as(img1)
147 |
148 | return _ssim(img1, img2, window, window_size, channel, size_average)
149 |
150 |
151 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
152 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
153 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
154 |
155 | mu1_sq = mu1.pow(2)
156 | mu2_sq = mu2.pow(2)
157 | mu1_mu2 = mu1 * mu2
158 |
159 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
160 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
161 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
162 |
163 | C1 = 0.01**2
164 | C2 = 0.03**2
165 |
166 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
167 |
168 | if size_average:
169 | return ssim_map.mean()
170 | else:
171 | return ssim_map.mean(1).mean(1).mean(1)
172 |
--------------------------------------------------------------------------------
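A minimal sketch (not from the repository) exercising the depth losses above on dummy tensors; shapes follow the (B, C, H, W) slicing in loss_depth_smoothness and the single-channel layout assumed by patchify, and the import path assumes the repository root is on PYTHONPATH:

import torch
from utils.depth_loss_utils import loss_depth_smoothness, patch_norm_mse_loss, ssim

img = torch.rand(1, 3, 64, 64)           # dummy RGB image
depth = torch.rand(1, 1, 64, 64)         # rendered depth
depth_prior = torch.rand(1, 1, 64, 64)   # e.g. a monocular depth estimate

smooth = loss_depth_smoothness(depth, img)                # edge-aware smoothness term
patch = patch_norm_mse_loss(depth, depth_prior, 16, 0.1)  # patch-normalized depth alignment
sim = ssim(img, img)                                      # ~1.0 for identical images
print(smooth.item(), patch.item(), sim.item())

--------------------------------------------------------------------------------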
/utils/dycheck_utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
--------------------------------------------------------------------------------
/utils/dycheck_utils/annotation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : annotation.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import functools
21 | from typing import List
22 |
23 | import cv2
24 | import jax
25 | import numpy as np
26 | from ipyevents import Event
27 | from IPython.display import clear_output, display
28 | from ipywidgets import HTML, Button, HBox, Image, Output
29 |
30 | from . import common, visuals
31 |
32 |
33 | def annotate_record3d_bad_frames(
34 | frames: np.ndarray,
35 | *,
36 | frame_ext: str = ".png",
37 | ) -> List[np.ndarray]:
38 | """Interactively annotate bad frames in validation set to skip later.
39 |
40 | "Bad images" exist because the hand may occlude the scene during capturing.
41 |
42 | Args:
43 |         frames (np.ndarray): images of shape (N, H, W, 3) in uint8 to annotate.
44 |         frame_ext (str): the extension of images. This is used for decoding
45 |             and then displaying each frame before marking it 👍 or 👎.
46 |             Default: '.png'.
47 |
48 | Return:
49 | np.ndarray: bad frame indices of shape (N_bad,).
50 | """
51 |
52 | def _frame_to_widget(frame):
53 | value = cv2.imencode(frame_ext, frame[..., ::-1])[1].tobytes()
54 | widget = Image(value=value, format=frame_ext[1:])
55 | widget.layout.max_width = "100%"
56 | widget.layout.height = "auto"
57 | return widget
58 |
59 | out = Output()
60 |
61 | frame_widgets = list(
62 | map(
63 | _frame_to_widget,
64 | common.tqdm(frames, desc="* Decoding frames"),
65 | )
66 | )
67 | bad_frames = []
68 |
69 | frame_idx = -1
70 | frame_msg = HTML()
71 |
72 | def show_next_frame():
73 | nonlocal frame_idx, frame_msg
74 | frame_idx += 1
75 |
76 | frame_msg.value = f"{frame_idx} frames annotated, " f"{len(frames) - frame_idx} frames left. "
77 |
78 | if frame_idx == len(frames):
79 | for btn in buttons:
80 | if btn is not redo_btn:
81 | btn.disabled = True
82 | print("Annotation done.")
83 | return
84 |
85 | frame_widget = frame_widgets[frame_idx]
86 | with out:
87 | clear_output(wait=True)
88 | display(frame_widget)
89 |
90 | def mark_frame(_, good):
91 | nonlocal frame_idx, bad_frames
92 | if not good:
93 | bad_frames.append(frame_idx)
94 |
95 | show_next_frame()
96 |
97 | def redo_frame(btn):
98 | nonlocal frame_idx, bad_frames
99 | if len(bad_frames) > 0 and bad_frames[-1] == frame_idx - 1:
100 | _ = bad_frames.pop()
101 | frame_idx = max(-1, frame_idx - 2)
102 |
103 | for btn in buttons:
104 | if btn is not redo_btn:
105 | btn.disabled = False
106 |
107 | show_next_frame()
108 |
109 | buttons = []
110 |
111 | valid_btn = Button(description="👍")
112 | valid_btn.on_click(functools.partial(mark_frame, good=True))
113 | buttons.append(valid_btn)
114 |
115 | invalid_btn = Button(description="👎")
116 | invalid_btn.on_click(functools.partial(mark_frame, good=False))
117 | buttons.append(invalid_btn)
118 |
119 | redo_btn = Button(description="♻️")
120 | redo_btn.on_click(redo_frame)
121 | buttons.append(redo_btn)
122 |
123 | display(HBox([frame_msg]))
124 | display(HBox(buttons))
125 | display(HBox([out]))
126 |
127 | show_next_frame()
128 |
129 | return bad_frames
130 |
131 |
132 | def annotate_keypoints(
133 | frames: np.ndarray,
134 | skeleton: visuals.Skeleton,
135 | *,
136 | frame_ext: str = ".png",
137 | **kwargs,
138 | ) -> List[np.ndarray]:
139 | """Interactively annotate keypoints on input frames.
140 |
141 | Note that each frame will only be submitted when finished.
142 |
143 | Args:
144 |         frames (np.ndarray): images of shape (N, H, W, 3) in uint8 to annotate.
145 |         skeleton (Skeleton): a skeleton definition object.
146 |         frame_ext (str): the extension of images. This is used for decoding and
147 |             then displaying each frame. Default: '.png'.
148 |
149 | Return:
150 | keypoints (List[np.ndarray]): a list of N annotated keypoints of shape
151 | (J, 3) where the last column is visibility in [0, 1].
152 | """
153 |
154 | def _frame_to_widget(frame):
155 | value = cv2.imencode(frame_ext, np.array(frame[..., ::-1]))[1].tobytes()
156 | widget = Image(value=value, format=frame_ext[1:])
157 | widget.layout.max_width = "100%"
158 | widget.layout.height = "auto"
159 | return widget
160 |
161 | frame_widgets = jax.tree_map(_frame_to_widget, list(frames))
162 | keypoints = []
163 |
164 | frame_idx, kps = -1, []
165 |
166 | event, frame_msg, kp_msg, kp_inst = Event(), HTML(), HTML(), HTML()
167 |
168 | def show_next_frame():
169 | nonlocal frame_idx, frame_msg, event
170 | frame_idx += 1
171 |
172 | frame_msg.value = f"{len(keypoints)} frames annotated, " f"{len(frames) - frame_idx} frames left."
173 |
174 | if frame_idx == len(frames):
175 | _show_current_kp_name()
176 | for btn in buttons:
177 | if btn is not redo_btn:
178 | btn.disabled = True
179 | event.watched_events = []
180 | print("Annotation done.")
181 | return
182 | else:
183 | _show_current_kp_name()
184 | for btn in buttons:
185 | btn.disabled = False
186 |
187 | _mark_next_kp()
188 |
189 | frame_widget = frame_widgets[frame_idx]
190 | with out:
191 | clear_output(wait=True)
192 | display(frame_widget)
193 | event.source = frame_widget
194 | event.watched_events = ["click"]
195 | event.on_dom_event(mark_kp)
196 |
197 | def _show_kp_visual():
198 | padded_kps = np.array(
199 | kps + [[0, 0, 0] for _ in range(skeleton.num_kps - len(kps))],
200 | dtype=np.float32,
201 | )
202 | canvas = frames[frame_idx]
203 | kp_visual = visuals.visualize_kps(padded_kps, canvas, skeleton=skeleton, **kwargs)
204 | visual_widget = _frame_to_widget(kp_visual)
205 | with kp_visual_out:
206 | clear_output(wait=True)
207 | display(visual_widget)
208 |
209 | def _mark_next_kp():
210 | _show_kp_visual()
211 | if len(kps) == skeleton.num_kps:
212 | submit_frame()
213 |
214 | _show_current_kp_name()
215 |
216 | nonlocal kp_msg
217 | kp_msg.value = f"{len(kps)} keypoints annotated, " f"{len(skeleton.kp_names) - len(kps)} keypoints left."
218 |
219 | def _show_current_kp_name():
220 | nonlocal frame_idx
221 | if frame_idx == len(frames):
222 | msg = "FINISHED!"
223 | else:
224 | kp_name = skeleton.kp_names[len(kps)]
225 | msg = f"Marking [{kp_name}]..."
226 | kp_inst.value = msg
227 |
228 | def mark_kp(event):
229 | kp = np.array([event["dataX"], event["dataY"], 1], dtype=np.float32)
230 | nonlocal kps
231 | kps.append(kp)
232 |
233 | _mark_next_kp()
234 |
235 | def redo_kp(_):
236 | nonlocal kps, keypoints, frame_idx
237 | if len(kps) == 0 and len(keypoints) > 0:
238 | kps = keypoints.pop().tolist()
239 | kps = kps[:-1]
240 | frame_idx -= 2
241 | show_next_frame()
242 | else:
243 | kps.pop()
244 | _mark_next_kp()
245 |
246 | def mark_invisible_kp(_):
247 | kp = np.array([0, 0, 0], dtype=np.float32)
248 | nonlocal kps
249 | kps.append(kp)
250 |
251 | _mark_next_kp()
252 |
253 | def submit_frame():
254 | nonlocal kps
255 | keypoints.append(np.array(kps))
256 | kps = []
257 |
258 | show_next_frame()
259 |
260 | buttons = []
261 |
262 | invisible_btn = Button(description="🫥")
263 | invisible_btn.on_click(mark_invisible_kp)
264 | buttons.append(invisible_btn)
265 |
266 | redo_btn = Button(description="♻️")
267 | redo_btn.on_click(redo_kp)
268 | buttons.append(redo_btn)
269 |
270 | display(HBox([kp_inst]))
271 | display(HBox([frame_msg, kp_msg]))
272 | display(HBox(buttons))
273 | out, kp_visual_out = Output(), Output()
274 | display(HBox([out, kp_visual_out]))
275 | show_next_frame()
276 |
277 | return keypoints
278 |
--------------------------------------------------------------------------------
/utils/dycheck_utils/common.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : common.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import functools
21 | import inspect
22 | from concurrent import futures
23 | from copy import copy
24 | from typing import Any, Callable, Dict, Iterable, Optional, Sequence
25 |
26 | # import jax
27 | import numpy as np
28 |
29 |
30 | def tolerant_partial(fn: Callable, *args, **kwargs) -> Callable:
31 | """A thin wrapper around functools.partial which only binds the keyword
32 |     arguments that match the function signature.
33 | """
34 | signatures = inspect.signature(fn)
35 | return functools.partial(
36 | fn,
37 | *args,
38 | **{k: v for k, v in kwargs.items() if k in signatures.parameters},
39 | )
40 |
41 |
42 | def traverse_filter(
43 | data_dict: Dict[str, Any],
44 | exclude_fields: Sequence[str] = (),
45 | return_fields: Sequence[str] = (),
46 | protect_fields: Sequence[str] = (),
47 | inplace: bool = False,
48 | ) -> Dict[str, Any]:
49 | """Keep matched field values within the dictionary, either inplace or not.
50 |
51 | Args:
52 | data_dict (Dict[str, Any]): A dictionary to be filtered.
53 | exclude_fields (Sequence[str]): A list of fields to be excluded.
54 | return_fields (Sequence[str]): A list of fields to be returned.
55 | protect_fields (Sequence[str]): A list of fields to be protected.
56 | inplace (bool): Whether to modify the input dictionary inplace.
57 |
58 | Returns:
59 | Dict[str, Any]: The filtered dictionary.
60 | """
61 | assert isinstance(data_dict, dict)
62 |
63 | str_to_tupid = lambda s: tuple(s.split("/"))
64 | exclude_fields = [str_to_tupid(f) for f in set(exclude_fields)]
65 | return_fields = [str_to_tupid(f) for f in set(return_fields)]
66 | protect_fields = [str_to_tupid(f) for f in set(protect_fields)]
67 |
68 | filter_fn = lambda f: f in protect_fields or (
69 | f in return_fields if len(return_fields) > 0 else f not in exclude_fields
70 | )
71 |
72 | if not inplace:
73 | data_dict = copy(data_dict)
74 |
75 | def delete_filtered(d, prefix):
76 | if isinstance(d, dict):
77 | for k in list(d.keys()):
78 | path = prefix + (k,)
79 | if (not isinstance(d[k], dict) or len(d[k]) == 0) and not filter_fn(path):
80 | del d[k]
81 | else:
82 | delete_filtered(d[k], path)
83 |
84 | delete_filtered(data_dict, ())
85 | return data_dict
86 |
87 |
88 | @functools.lru_cache(maxsize=None)
89 | def in_notebook() -> bool:
90 | """Check if the code is running in a notebook."""
91 | try:
92 | from IPython import get_ipython
93 |
94 | ipython = get_ipython()
95 | if not ipython or "IPKernelApp" not in ipython.config:
96 | return False
97 | except ImportError:
98 | return False
99 | return True
100 |
101 |
102 | def tqdm(iterable: Iterable, *args, **kwargs) -> Iterable:
103 | if not in_notebook():
104 | from tqdm import tqdm as _tqdm
105 | else:
106 | from tqdm.notebook import tqdm as _tqdm
107 | return _tqdm(iterable, *args, **kwargs)
108 |
109 |
110 | def parallel_map(
111 | func: Callable,
112 | *iterables: Sequence[Iterable],
113 | max_threads: Optional[int] = None,
114 | show_pbar: bool = False,
115 | desc: Optional[str] = None,
116 | pbar_kwargs: Dict[str, Any] = {},
117 | debug: bool = False,
118 | **kwargs,
119 | ) -> Sequence[Any]:
120 | """Parallel version of map()."""
121 | if not debug:
122 | with futures.ThreadPoolExecutor(max_threads) as executor:
123 | if show_pbar:
124 | results = list(
125 | tqdm(
126 | executor.map(func, *iterables, **kwargs),
127 | desc=desc,
128 | total=len(iterables[0]),
129 | **pbar_kwargs,
130 | )
131 | )
132 | else:
133 | results = list(executor.map(func, *iterables, **kwargs))
134 | return results
135 | else:
136 | return list(map(func, *iterables, **kwargs))
137 |
138 |
139 | def tree_collate(trees: Sequence[Any], collate_fn=lambda *x: np.asarray(x)):
140 |     """Collates a list of pytrees with the same structure."""
141 |     import jax  # imported lazily; the module-level jax import is intentionally commented out above
142 |     return jax.tree_map(collate_fn, *trees)
142 |
143 |
144 | def strided_subset(sequence: Sequence[Any], count: int) -> Sequence[Any]:
145 | if count > len(sequence):
146 | raise ValueError("count must be less than or equal to len(sequence)")
147 | inds = np.linspace(0, len(sequence), count, dtype=int, endpoint=False)
148 | if isinstance(sequence, np.ndarray):
149 | sequence = sequence[inds]
150 | else:
151 | sequence = [sequence[i] for i in inds]
152 | return sequence
153 |
154 |
155 | def random_subset(sequence: Sequence[Any], count: int, seed: int = 0) -> Sequence[Any]:
156 | if count > len(sequence):
157 | raise ValueError("count must be less than or equal to len(sequence)")
158 | rng = np.random.default_rng(seed)
159 | inds = rng.choice(len(sequence), count, replace=False)
160 | if isinstance(sequence, np.ndarray):
161 | sequence = sequence[inds]
162 | else:
163 | sequence = [sequence[i] for i in inds]
164 | return sequence
165 |
--------------------------------------------------------------------------------
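A short usage sketch of the helpers above (assumes the repository root is on PYTHONPATH; scaled_mean and unused_kwarg are made up for illustration). parallel_map fans a function out over a thread pool and can show a tqdm progress bar, while tolerant_partial silently drops keyword arguments the callee does not accept:

import numpy as np
from utils.dycheck_utils.common import parallel_map, strided_subset, tolerant_partial


def scaled_mean(x, scale=1.0):
    return float(x.mean()) * scale


frames = [np.full((4, 4), i, dtype=np.float32) for i in range(10)]
means = parallel_map(scaled_mean, frames, show_pbar=True, desc="frame means")
subset = strided_subset(frames, 3)  # 3 evenly strided frames out of 10

# unused_kwarg is not in scaled_mean's signature, so tolerant_partial drops it.
doubled_mean = tolerant_partial(scaled_mean, scale=2.0, unused_kwarg=123)
print(means[:3], len(subset), doubled_mean(frames[1]))  # -> [0.0, 1.0, 2.0] 3 2.0

--------------------------------------------------------------------------------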
/utils/dycheck_utils/image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : image.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import math
21 | from typing import Any, Tuple
22 |
23 | import cv2
24 | import numpy as np
25 |
26 | from . import types
27 |
28 |
29 | import logging  # stdlib logging stands in for the absl logger used upstream
30 |
31 |
32 | UINT8_MAX = 255
33 | UINT16_MAX = 65535
34 |
35 |
36 | def downscale(img: types.Array, scale: int) -> np.ndarray:
37 | if isinstance(img, np.ndarray):
38 | img = np.array(img)
39 |
40 | if scale == 1:
41 | return img
42 |
43 | height, width = img.shape[:2]
44 | if height % scale > 0 or width % scale > 0:
45 | raise ValueError(f"Image shape ({height},{width}) must be divisible by the" f" scale ({scale}).")
46 | out_height, out_width = height // scale, width // scale
47 |     resized = cv2.resize(img, (out_width, out_height), interpolation=cv2.INTER_AREA)
48 | return resized
49 |
50 |
51 | def upscale(img: types.Array, scale: int) -> np.ndarray:
52 | if isinstance(img, np.ndarray):
53 | img = np.array(img)
54 |
55 | if scale == 1:
56 | return img
57 |
58 | height, width = img.shape[:2]
59 | out_height, out_width = height * scale, width * scale
60 |     resized = cv2.resize(img, (out_width, out_height), interpolation=cv2.INTER_AREA)
61 | return resized
62 |
63 |
64 | def rescale(img: types.Array, scale_factor: float, interpolation: Any = cv2.INTER_AREA) -> np.ndarray:
65 | scale_factor = float(scale_factor)
66 |
67 | if scale_factor <= 0.0:
68 |         raise ValueError("scale_factor must be a positive number.")
69 | if scale_factor == 1.0:
70 | return img
71 |
72 | height, width = img.shape[:2]
73 | if scale_factor.is_integer():
74 | return upscale(img, int(scale_factor))
75 |
76 | inv_scale = 1.0 / scale_factor
77 | if inv_scale.is_integer() and (scale_factor * height).is_integer() and (scale_factor * width).is_integer():
78 | return downscale(img, int(inv_scale))
79 |
80 | logging.warning(
81 | "Resizing image by non-integer factor %f, this may lead to artifacts.",
82 | scale_factor,
83 | )
84 |
85 | height, width = img.shape[:2]
86 | out_height = math.ceil(height * scale_factor)
87 | out_height -= out_height % 2
88 | out_width = math.ceil(width * scale_factor)
89 | out_width -= out_width % 2
90 |
91 | return resize(img, (out_height, out_width), interpolation)
92 |
93 |
94 | def resize(
95 | img: types.Array,
96 | shape: Tuple[int, int],
97 | interpolation: Any = cv2.INTER_AREA,
98 | ) -> np.ndarray:
99 | if isinstance(img, np.ndarray):
100 | img = np.array(img)
101 |
102 | out_height, out_width = shape
103 | return cv2.resize(
104 | img,
105 | (out_width, out_height),
106 | interpolation=interpolation,
107 | )
108 |
109 |
110 | def varlap(img: types.Array) -> np.ndarray:
111 | """Measure the focus/motion-blur of an image by the Laplacian variance."""
112 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
113 | return cv2.Laplacian(gray, cv2.CV_64F).var()
114 |
115 |
116 | def to_float32(img: types.Array) -> np.ndarray:
117 | img = np.array(img)
118 | if img.dtype == np.float32:
119 | return img
120 |
121 | dtype = img.dtype
122 | img = img.astype(np.float32)
123 | if dtype == np.uint8:
124 | return img / UINT8_MAX
125 | elif dtype == np.uint16:
126 | return img / UINT16_MAX
127 | elif dtype == np.float64:
128 | return img
129 | elif dtype == np.float16:
130 | return img
131 |
132 | raise ValueError(f"Unexpected dtype: {dtype}.")
133 |
134 |
135 | def to_quantized_float32(img: types.Array) -> np.ndarray:
136 | return to_float32(to_uint8(img))
137 |
138 |
139 | def to_uint8(img: types.Array) -> np.ndarray:
140 | img = np.array(img)
141 | if img.dtype == np.uint8:
142 | return img
143 | if not issubclass(img.dtype.type, np.floating):
144 | raise ValueError(f"Input image should be a floating type but is of type " f"{img.dtype!r}.")
145 | return (img * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8)
146 |
147 |
148 | def to_uint16(img: types.Array) -> np.ndarray:
149 | img = np.array(img)
150 | if img.dtype == np.uint16:
151 | return img
152 | if not issubclass(img.dtype.type, np.floating):
153 | raise ValueError(f"Input image should be a floating type but is of type " f"{img.dtype!r}.")
154 | return (img * UINT16_MAX).clip(0.0, UINT16_MAX).astype(np.uint16)
155 |
156 |
157 | # Special forms of images.
158 | def rescale_flow(
159 | flow: types.Array,
160 | scale_factor: float,
161 | interpolation: Any = cv2.INTER_LINEAR,
162 | ) -> np.ndarray:
163 | height, width = flow.shape[:2]
164 |
165 | out_flow = rescale(flow, scale_factor, interpolation)
166 |
167 | out_height, out_width = out_flow.shape[:2]
168 | out_flow[..., 0] *= float(out_width) / float(width)
169 | out_flow[..., 1] *= float(out_height) / float(height)
170 | return out_flow
171 |
--------------------------------------------------------------------------------
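A quick self-contained sketch of the image helpers above (dummy arrays only; repository root assumed on PYTHONPATH): integer downscaling, dtype conversion, and flow rescaling, which also rescales the flow vectors themselves:

import numpy as np
from utils.dycheck_utils.image import downscale, rescale_flow, to_uint8

img = np.random.rand(64, 96, 3).astype(np.float32)
flow = np.random.randn(64, 96, 2).astype(np.float32)

small = downscale(img, 2)            # (32, 48, 3)
half_flow = rescale_flow(flow, 0.5)  # (32, 48, 2); the u, v components are halved as well
print(small.shape, half_flow.shape, to_uint8(img).dtype)  # -> (32, 48, 3) (32, 48, 2) uint8

--------------------------------------------------------------------------------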
/utils/dycheck_utils/io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : io.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import json
21 | import os.path as osp
22 | import pickle as pkl
23 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union # Literal,
24 |
25 | import cv2
26 | import ffmpeg
27 | import numpy as np
28 |
29 | from . import common, image, path_ops, types
30 |
31 |
32 | _LOAD_REGISTRY, _DUMP_REGISTRY = {}, {}
33 |
34 | IMG_EXTS = (
35 | ".jpg",
36 | ".jpeg",
37 | ".png",
38 | ".ppm",
39 | ".bmp",
40 | ".pgm",
41 | ".tif",
42 | ".tiff",
43 | ".webp",
44 | )
45 | VID_EXTS = (".mov", ".avi", ".mpg", ".mpeg", ".mp4", ".mkv", ".wmv", ".gif")
46 |
47 |
48 | def _register(
49 | *,
50 | ext: Optional[Union[str, Sequence[str]]] = None,
51 | ) -> Callable:
52 | if isinstance(ext, str):
53 | ext = [ext]
54 |
55 | def _inner_register(func: Callable) -> Callable:
56 | name = func.__name__
57 | assert name.startswith("load_") or name.startswith("dump_")
58 |
59 | nonlocal ext
60 | if ext is None:
61 | ext = ["." + name[5:]]
62 | for e in ext:
63 | if name.startswith("load_"):
64 | _LOAD_REGISTRY[e] = func
65 | else:
66 | _DUMP_REGISTRY[e] = func
67 |
68 | return func
69 |
70 | return _inner_register
71 |
72 |
73 | def _dispatch(registry: Dict[str, Callable], name) -> Callable:
74 | def _dispatched(filename: types.PathType, *args, **kwargs):
75 | ext = path_ops.get_ext(filename, match_first=False)
76 | func = registry[ext]
77 | if name == "dump" and osp.dirname(filename) != "" and not osp.exists(osp.dirname(filename)):
78 | path_ops.mkdir(osp.dirname(filename))
79 | return func(filename, *args, **kwargs)
80 |
81 | _dispatched.__name__ = name
82 | return _dispatched
83 |
84 |
85 | load = _dispatch(_LOAD_REGISTRY, "load")
86 | dump = _dispatch(_DUMP_REGISTRY, "dump")
87 |
88 |
89 | @_register()
90 | def load_txt(filename: types.PathType, *, strip: bool = True, **kwargs) -> List[str]:
91 | with open(filename) as f:
92 | lines = f.readlines(**kwargs)
93 | if strip:
94 | lines = [line.strip() for line in lines]
95 | return lines
96 |
97 |
98 | @_register()
99 | def dump_txt(filename: types.PathType, obj: List[Any], **_) -> None:
100 | # Prefer visual appearance over compactness.
101 | obj = "\n".join([str(item) for item in obj])
102 | with open(filename, "w") as f:
103 | f.write(obj)
104 |
105 |
106 | @_register()
107 | def load_json(filename: types.PathType, **kwargs) -> Dict:
108 | with open(filename) as f:
109 | return json.load(f, **kwargs)
110 |
111 |
112 | @_register()
113 | def dump_json(
114 | filename: types.PathType,
115 | obj: Dict,
116 | *,
117 | sort_keys: bool = True,
118 | indent: Optional[int] = 4,
119 | separators: Tuple[str, str] = (",", ": "),
120 | **kwargs,
121 | ) -> None:
122 | # Process potential numpy arrays.
123 | if isinstance(obj, dict):
124 | obj = {k: v.tolist() if hasattr(v, "tolist") else v for k, v in obj.items()}
125 | elif isinstance(obj, (list, tuple)):
126 | pass
127 | elif isinstance(obj, np.ndarray):
128 | obj = obj.tolist()
129 | else:
130 | raise ValueError(f"{type(obj)} is not a supported type.")
131 | # Prefer visual appearance over compactness.
132 | with open(filename, "w") as f:
133 | json.dump(
134 | obj,
135 | f,
136 | sort_keys=sort_keys,
137 | indent=indent,
138 | separators=separators,
139 | **kwargs,
140 | )
141 |
142 |
143 | @_register()
144 | def load_pkl(filename: types.PathType, **kwargs) -> Dict:
145 | with open(filename, "rb") as f:
146 | try:
147 | return pkl.load(f, **kwargs)
148 | except UnicodeDecodeError as e:
149 | if "encoding" in kwargs:
150 | raise e
151 | return load_pkl(filename, encoding="latin1", **kwargs)
152 |
153 |
154 | @_register()
155 | def dump_pkl(filename: types.PathType, obj: Dict, **kwargs) -> None:
156 | with open(filename, "wb") as f:
157 | pkl.dump(obj, f, **kwargs)
158 |
159 |
160 | @_register()
161 | def load_npy(
162 | filename: types.PathType, *, allow_pickle: bool = True, **kwargs
163 | ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
164 | return np.load(filename, allow_pickle=allow_pickle, **kwargs)
165 |
166 |
167 | @_register()
168 | def dump_npy(filename: types.PathType, obj: np.ndarray, **kwargs) -> None:
169 | np.save(filename, obj, **kwargs)
170 |
171 |
172 | @_register()
173 | def load_npz(filename: types.PathType, **kwargs) -> np.lib.npyio.NpzFile:
174 | return np.load(filename, **kwargs)
175 |
176 |
177 | @_register()
178 | def dump_npz(filename: types.PathType, **kwargs) -> None:
179 | # Disable positional argument for np.savez.
180 | np.savez(filename, **kwargs)
181 |
182 |
183 | @_register(ext=IMG_EXTS)
184 | def load_img(filename: types.PathType, *, use_rgb: bool = True, **kwargs) -> np.ndarray:
185 | img = cv2.imread(filename, **kwargs)
186 | if use_rgb and img.shape[-1] >= 3:
187 | # Take care of RGBA case when flipping.
188 | img = np.concatenate([img[..., 2::-1], img[..., 3:]], axis=-1)
189 | return img
190 |
191 |
192 | @_register(ext=IMG_EXTS)
193 | def dump_img(
194 | filename: types.PathType,
195 | obj: np.ndarray,
196 | *,
197 | use_rgb: bool = True,
198 | **kwargs,
199 | ) -> None:
200 | if use_rgb and obj.shape[-1] >= 3:
201 | obj = np.concatenate([obj[..., 2::-1], obj[..., 3:]], axis=-1)
202 | cv2.imwrite(filename, image.to_uint8(obj), **kwargs)
203 |
204 |
205 | def load_vid_metadata(filename: types.PathType) -> Dict:
206 | assert osp.exists(filename), f"{filename} does not exist!"
207 | try:
208 | probe = ffmpeg.probe(filename)
209 | except ffmpeg.Error as e:
210 | print("stdout:", e.stdout.decode("utf8"))
211 | print("stderr:", e.stderr.decode("utf8"))
212 | raise e
213 | metadata = next(stream for stream in probe["streams"] if stream["codec_type"] == "video")
214 | metadata["fps"] = float(eval(metadata["r_frame_rate"]))
215 | return metadata
216 |
217 |
218 | @_register(ext=VID_EXTS)
219 | def load_vid(
220 | filename: types.PathType,
221 | *,
222 | quiet: bool = True,
223 | trim_kwargs: Dict[str, Any] = {},
224 | **_,
225 | ) -> Dict:
226 | vid_metadata = load_vid_metadata(filename)
227 | W = int(vid_metadata["width"])
228 | H = int(vid_metadata["height"])
229 |
230 | stream = ffmpeg.input(filename)
231 | if len(trim_kwargs) > 0:
232 | stream = ffmpeg.trim(stream, **trim_kwargs).setpts("PTS-STARTPTS")
233 | stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24")
234 | out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=quiet)
235 | out = np.frombuffer(out, np.uint8).reshape([-1, H, W, 3])
236 | return out.copy()
237 |
238 |
239 | @_register(ext=VID_EXTS)
240 | def dump_vid(
241 | filename: types.PathType,
242 | obj: Union[List[np.ndarray], np.ndarray],
243 | *,
244 | fps: float,
245 | quiet: bool = True,
246 | show_pbar: bool = True,
247 | desc: Optional[str] = "* Dumping video",
248 | **kwargs,
249 | ) -> None:
250 | if not isinstance(obj, np.ndarray):
251 | obj = np.asarray(obj)
252 | obj = image.to_uint8(obj)
253 |
254 | H, W = obj.shape[1:3]
255 | stream = ffmpeg.input(
256 | "pipe:",
257 | format="rawvideo",
258 | pix_fmt="rgb24",
259 | s="{}x{}".format(W, H),
260 | r=fps,
261 | )
262 | process = (
263 | stream.output(filename, pix_fmt="yuv420p", vcodec="libx264")
264 | .overwrite_output()
265 | .run_async(pipe_stdin=True, quiet=quiet)
266 | )
267 | obj_bytes = common.parallel_map(
268 | lambda f: f.tobytes(),
269 | list(obj),
270 | show_pbar=show_pbar,
271 | desc=desc,
272 | **kwargs,
273 | )
274 | for b in obj_bytes:
275 | process.stdin.write(b)
276 | process.stdin.close()
277 | process.wait()
278 |
--------------------------------------------------------------------------------
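The load/dump pair above dispatches on the file extension. A hedged sketch (the /tmp paths and values are illustrative; assumes a POSIX temp directory and the repository root on PYTHONPATH):

import numpy as np
from utils.dycheck_utils import io

io.dump("/tmp/example_meta.json", {"scale": 2, "frame_ids": np.arange(3)})
meta = io.load("/tmp/example_meta.json")        # parsed back as a dict

io.dump("/tmp/example_img.png", np.random.rand(8, 8, 3))  # float -> uint8, RGB flipped to BGR on write
img = io.load("/tmp/example_img.png")                     # uint8, flipped back to RGB on read
print(meta, img.shape, img.dtype)

--------------------------------------------------------------------------------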
/utils/dycheck_utils/path_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : path_ops.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import glob
21 | import os
22 | import os.path as osp
23 | import re
24 | import shutil
25 | from typing import List
26 |
27 | from . import types
28 |
29 |
30 | def get_ext(filename: types.PathType, match_first: bool = False) -> types.PathType:
31 | if match_first:
32 | filename = osp.split(filename)[1]
33 | return filename[filename.find(".") :]
34 | else:
35 | return osp.splitext(filename)[1]
36 |
37 |
38 | def basename(filename: types.PathType, with_ext: bool = True, **kwargs) -> types.PathType:
39 | name = osp.basename(filename, **kwargs)
40 | if not with_ext:
41 | name = name.replace(get_ext(name), "")
42 | return name
43 |
44 |
45 | def natural_sorted(lst: List[types.PathType]) -> List[types.PathType]:
46 | convert = lambda text: int(text) if text.isdigit() else text.lower()
47 | alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)]
48 | return sorted(lst, key=alphanum_key)
49 |
50 |
51 | def mtime_sorted(lst: List[types.PathType]) -> List[types.PathType]:
52 | # Ascending order: last modified file will be the last one.
53 | return sorted(lst, key=lambda p: os.stat(p).st_mtime)
54 |
55 |
56 | def ls(
57 | pattern: str,
58 | *,
59 | type: str = "a",
60 | latestk: int = -1,
61 | exclude: bool = False,
62 | ) -> List[types.PathType]:
63 | filter_fn = {
64 | "f": lambda p: osp.isfile(p) and not osp.islink(p),
65 | "d": osp.isdir,
66 | "l": osp.islink,
67 | "a": lambda p: osp.isfile(p) or osp.isdir(p) or osp.islink(p),
68 | }[type]
69 |
70 | def _natural_sorted_latestk(fs):
71 | if latestk > 0:
72 | if not exclude:
73 | fs = sorted(fs, key=osp.getmtime)[::-1][:latestk]
74 | else:
75 | fs = sorted(fs, key=osp.getmtime)[::-1][latestk:]
76 | return natural_sorted(fs)
77 |
78 | if "**" in pattern:
79 | dsts = glob.glob(pattern, recursive=True)
80 | elif "*" in pattern:
81 | dsts = glob.glob(pattern)
82 | else:
83 | dsts = [osp.join(pattern, p) for p in os.listdir(pattern) if filter_fn(osp.join(pattern, p))]
84 | return _natural_sorted_latestk(dsts)
85 |
86 | dsts = [dst for dst in dsts if filter_fn(dst)]
87 | return _natural_sorted_latestk(dsts)
88 |
89 |
90 | def mv(src: types.PathType, dst: types.PathType) -> None:
91 | shutil.move(src, dst)
92 |
93 |
94 | def ln(
95 | src: types.PathType,
96 | dst: types.PathType,
97 | use_relpath: bool = True,
98 | exist_ok: bool = True,
99 | ) -> None:
100 | if osp.exists(dst):
101 | if exist_ok:
102 | rm(dst)
103 | else:
104 | raise FileExistsError(f'Force link from "{src}" to existed "{dst}".')
105 | if use_relpath:
106 | src = osp.relpath(src, start=osp.dirname(dst))
107 | if not osp.exists(osp.dirname(dst)):
108 | mkdir(osp.dirname(dst))
109 | os.symlink(src, dst)
110 |
111 |
112 | def cp(src: types.PathType, dst: types.PathType, **kwargs) -> None:
113 | try:
114 | shutil.copyfile(src, dst)
115 | except OSError:
116 | shutil.copytree(src, dst, **kwargs)
117 |
118 |
119 | def mkdir(dst: types.PathType, exist_ok: bool = True, **kwargs) -> None:
120 | os.makedirs(dst, exist_ok=exist_ok, **kwargs)
121 |
122 |
123 | def rm(dst: types.PathType) -> None:
124 | if osp.exists(dst):
125 | if osp.isdir(dst):
126 | shutil.rmtree(dst, ignore_errors=True)
127 | if osp.isfile(dst) or osp.islink(dst):
128 | os.remove(dst)
129 |
--------------------------------------------------------------------------------
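A small sketch of the path helpers above against a temporary directory (file names are made up). ls applies natural sorting, so frame_10 comes after frame_2 rather than before it:

import os.path as osp
import tempfile
from utils.dycheck_utils import path_ops

root = tempfile.mkdtemp()
path_ops.mkdir(osp.join(root, "frames"))
for name in ["frame_2.png", "frame_10.png", "frame_1.png"]:
    open(osp.join(root, "frames", name), "w").close()

files = path_ops.ls(osp.join(root, "frames", "*.png"))
print([path_ops.basename(f, with_ext=False) for f in files])  # ['frame_1', 'frame_2', 'frame_10']

--------------------------------------------------------------------------------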
/utils/dycheck_utils/safe_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : safe_ops.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import functools
21 | from typing import Tuple
22 |
23 | import jax
24 | import jax.numpy as np
25 |
26 |
27 | @functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 3))
28 | def safe_norm(
29 | x: np.ndarray,
30 | axis: int = -1,
31 | keepdims: bool = False,
32 |     tol: float = 1e-9,
33 | ) -> np.ndarray:
34 | """Calculates a np.linalg.norm(d) that's safe for gradients at d=0.
35 |
36 | These gymnastics are to avoid a poorly defined gradient for
37 |     np.linalg.norm(0). See https://github.com/google/jax/issues/3058 for details.
38 |
39 | Args:
40 | x (np.ndarray): A np.array.
41 | axis (int): The axis along which to compute the norm.
42 | keepdims (bool): if True don't squeeze the axis.
43 | tol (float): the absolute threshold within which to zero out the
44 | gradient.
45 |
46 | Returns:
47 | Equivalent to np.linalg.norm(d)
48 | """
49 | return np.linalg.norm(x, axis=axis, keepdims=keepdims)
50 |
51 |
52 | @safe_norm.defjvp
53 | def _safe_norm_jvp(
54 | axis: int, keepdims: bool, tol: float, primals: Tuple, tangents: Tuple
55 | ) -> Tuple[np.ndarray, np.ndarray]:
56 | (x,) = primals
57 | (x_dot,) = tangents
58 | safe_tol = max(tol, 1e-30)
59 | y = safe_norm(x, tol=safe_tol, axis=axis, keepdims=True)
60 | y_safe = np.maximum(y, tol) # Prevent divide by zero.
61 | y_dot = np.where(y > safe_tol, x_dot * x / y_safe, np.zeros_like(x))
62 | y_dot = np.sum(y_dot, axis=axis, keepdims=True)
63 | # Squeeze the axis if `keepdims` is True.
64 | if not keepdims:
65 | y = np.squeeze(y, axis=axis)
66 | y_dot = np.squeeze(y_dot, axis=axis)
67 | return y, y_dot
68 |
69 |
70 | def log1p_safe(x: np.ndarray) -> np.ndarray:
71 | return np.log1p(np.minimum(x, 3e37))
72 |
73 |
74 | def exp_safe(x: np.ndarray) -> np.ndarray:
75 | return np.exp(np.minimum(x, 87.5))
76 |
77 |
78 | def expm1_safe(x: np.ndarray) -> np.ndarray:
79 | return np.expm1(np.minimum(x, 87.5))
80 |
81 |
82 | def safe_sqrt(x: np.ndarray, eps: float = 1e-7) -> np.ndarray:
83 | safe_x = np.where(x == 0, np.ones_like(x) * eps, x)
84 | return np.sqrt(safe_x)
85 |
--------------------------------------------------------------------------------
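A brief sketch of why safe_norm exists (assumes jax is installed, as elsewhere in dycheck_utils, and the repository root is on PYTHONPATH): its forward-mode derivative stays finite at the origin, where the plain norm yields NaN:

import jax
import jax.numpy as jnp
from utils.dycheck_utils.safe_ops import safe_norm

x, v = jnp.zeros(3), jnp.ones(3)
_, safe_tangent = jax.jvp(safe_norm, (x,), (v,))         # 0.0 -- tangent is zeroed where ||x|| <= tol
_, naive_tangent = jax.jvp(jnp.linalg.norm, (x,), (v,))  # nan at the origin
print(safe_tangent, naive_tangent)

--------------------------------------------------------------------------------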
/utils/dycheck_utils/struct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : structures.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import NamedTuple, Optional
21 |
22 | import numpy as np
23 |
24 |
25 | class Metadata(NamedTuple):
26 | time: Optional[np.ndarray] = None
27 | camera: Optional[np.ndarray] = None
28 | time_to: Optional[np.ndarray] = None
29 |
30 |
31 | class Rays(NamedTuple):
32 | origins: np.ndarray
33 | directions: np.ndarray
34 | pixels: np.ndarray
35 | local_directions: Optional[np.ndarray] = None
36 | radii: Optional[np.ndarray] = None
37 | metadata: Optional[Metadata] = None
38 |
39 | near: Optional[np.ndarray] = None
40 | far: Optional[np.ndarray] = None
41 |
42 |
43 | class Samples(NamedTuple):
44 | xs: np.ndarray
45 | directions: np.ndarray
46 | cov_diags: Optional[np.ndarray] = None
47 | metadata: Optional[Metadata] = None
48 |
49 | tvals: Optional[np.ndarray] = None
50 |
--------------------------------------------------------------------------------
/utils/dycheck_utils/types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : types.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import Any, Callable, Tuple, Union
21 |
22 | # import jax.numpy as np
23 | import numpy as np
24 |
25 |
26 | PRNGKey = np.ndarray
27 | Shape = Tuple[int]
28 | Dtype = Any
29 | Array = Union[np.ndarray, np.ndarray]
30 |
31 | Activation = Callable[[Array], Array]
32 | Initializer = Callable[[PRNGKey, Shape, Dtype], Array]
33 |
34 | PathType = str
35 | ScheduleType = Callable[[int], float]
36 | EngineType = Any
37 |
--------------------------------------------------------------------------------
/utils/dycheck_utils/visuals/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
--------------------------------------------------------------------------------
/utils/dycheck_utils/visuals/corrs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : corrs.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import Optional
21 |
22 | import cv2
23 | import numpy as np
24 |
25 | from .. import image
26 |
27 |
28 | def visualize_corrs(
29 | corrs: np.ndarray,
30 | img: np.ndarray,
31 | img_to: np.ndarray,
32 | *,
33 | rgbs: Optional[np.ndarray] = None,
34 | min_rad: float = 5,
35 | subsample: int = 50,
36 | num_min_keeps: int = 10,
37 | circle_radius: int = 1,
38 | circle_thickness: int = -1,
39 | line_thickness: int = 1,
40 | alpha: float = 0.7,
41 | ):
42 | """Visualize a set of correspondences.
43 |
44 | By default this function visualizes a sparse subset of the correspondences
45 | with lines.
46 |
47 | Args:
48 | corrs (np.ndarray): A set of correspondences of shape (N, 2, 2), where
49 | the second dimension represents (from, to) and the last dimension
50 | represents (x, y) coordinates.
51 | img (np.ndarray): An image for start points of shape (Hi, Wi, 3) in
52 | either float32 or uint8.
53 | img_to (np.ndarray): An image for end points of shape (Ht, Wt, 3) in
54 | either float32 or uint8.
55 | rgbs (Optional[np.ndarray]): A set of rgbs for each correspondence
56 | of shape (N, 3) or (3,). If None then use pixel coordinates.
57 | Default: None.
58 | min_rad (float): The minimum displacement (in pixels) for a
59 | correspondence to be drawn. Default: 5.
60 | subsample (int): The subsampling factor; roughly one in every
61 | `subsample` valid correspondences is kept. Default: 50.
62 | circle_radius (int): The radius of the circle. Default: 1.
63 | circle_thickness (int): The thickness of the circle. Default: -1 (filled).
64 | line_thickness (int): The thickness of the line. Default: 1.
65 | alpha (float): The alpha value between [0, 1] for foreground blending.
66 | The bigger the more prominent of the visualization. Default: 0.7.
67 |
68 | Returns:
69 | np.ndarray: A visualization image of shape (H, Wi + Wt, 3) in uint8.
70 | """
71 | corrs = np.array(corrs)
72 | img = image.to_uint8(img)
73 | img_to = image.to_uint8(img_to)
74 | rng = np.random.default_rng(0)
75 |
76 | (Hi, Wi), (Ht, Wt) = img.shape[:2], img_to.shape[:2]
77 | combined = np.concatenate([img, img_to], axis=1)
78 | canvas = combined.copy()
79 |
80 | norm = np.linalg.norm(corrs[:, 1] - corrs[:, 0], axis=-1)
81 | mask = (
82 | (norm >= min_rad)
83 | & (corrs[..., 0, 0] < Wi)
84 | & (corrs[..., 0, 0] >= 0)
85 | & (corrs[..., 0, 1] < Hi)
86 | & (corrs[..., 0, 1] >= 0)
87 | & (corrs[..., 1, 0] < Wt)
88 | & (corrs[..., 1, 0] >= 0)
89 | & (corrs[..., 1, 1] < Ht)
90 | & (corrs[..., 1, 1] >= 0)
91 | )
92 | filtered_inds = np.nonzero(mask)[0]
93 | num_min_keeps = min(
94 | max(num_min_keeps, filtered_inds.shape[0] // subsample),
95 | filtered_inds.shape[0],
96 | )
97 | filtered_inds = rng.choice(filtered_inds, num_min_keeps, replace=False) if filtered_inds.shape[0] > 0 else []
98 |
99 | if len(filtered_inds) > 0:
100 | if rgbs is None:
101 | # Use normalized pixel coordinate of img for colorization.
102 | corr = corrs[:, 0]
103 | phi = 2 * np.pi * (corr[:, 0] / (Wi - 1) - 0.5)
104 | theta = np.pi * (corr[:, 1] / (Hi - 1) - 0.5)
105 | x = np.cos(theta) * np.cos(phi)
106 | y = np.cos(theta) * np.sin(phi)
107 | z = np.sin(theta)
108 | rgbs = image.to_uint8((np.stack([x, y, z], axis=-1) + 1) / 2)
109 | for idx in filtered_inds:
110 | start = tuple(corrs[idx, 0].astype(np.int32))
111 | end = tuple((corrs[idx, 1] + [Wi, 0]).astype(np.int32))
112 | rgb = tuple(int(c) for c in (rgbs[idx] if rgbs.ndim == 2 else rgbs))
113 | cv2.circle(
114 | combined,
115 | start,
116 | radius=circle_radius,
117 | color=rgb,
118 | thickness=circle_thickness,
119 | lineType=cv2.LINE_AA,
120 | )
121 | cv2.circle(
122 | combined,
123 | end,
124 | radius=circle_radius,
125 | color=rgb,
126 | thickness=circle_thickness,
127 | lineType=cv2.LINE_AA,
128 | )
129 | if line_thickness > 0:
130 | cv2.line(
131 | canvas,
132 | start,
133 | end,
134 | color=rgb,
135 | thickness=line_thickness,
136 | lineType=cv2.LINE_AA,
137 | )
138 |
139 | combined = cv2.addWeighted(combined, alpha, canvas, 1 - alpha, 0)
140 | return combined
141 |
142 |
143 | def visualize_chained_corrs(
144 | corrs: np.ndarray,
145 | imgs: np.ndarray,
146 | *,
147 | rgbs: Optional[np.ndarray] = None,
148 | circle_radius: int = 1,
149 | circle_thickness: int = -1,
150 | line_thickness: int = 1,
151 | alpha: float = 0.7,
152 | ):
153 | """Visualize a set of correspondences.
154 |
155 | By default this function visualizes a sparse subset of the correspondences
156 | with lines.
157 |
158 | Args:
159 | corrs (np.ndarray): A set of correspondences of shape (N, C, 2), where
160 | the second dimension represents chained frames and the last
161 | dimension represents (x, y) coordinates.
162 | imgs (np.ndarray): A stack of C images of shape (C, H, W, 3) in
163 | either float32 or uint8.
164 | rgbs (Optional[np.ndarray]): A set of rgbs for each correspondence
165 | of shape (N, 3) or (3,). If None then use pixel coordinates.
166 | Default: None.
167 | circle_radius (int): The radius of the circle. Default: 1.
168 | circle_thickness (int): The thickness of the circle. Default: -1 (filled).
169 | line_thickness (int): The thickness of the line. Default: 1.
170 | alpha (float): The alpha value between [0, 1] for foreground blending.
171 | The bigger the more prominent of the visualization. Default: 0.7.
172 |
173 | Returns:
174 | np.ndarray: A visualization image of shape (H, C * W, 3) in uint8.
175 | """
176 | corrs = np.array(corrs)
177 | imgs = image.to_uint8(imgs)
178 |
179 | C, H, W = imgs.shape[:3]
180 | combined = np.concatenate(list(imgs), axis=1)
181 | canvas = combined.copy()
182 |
183 | if rgbs is None:
184 | # Use normalized pixel coordinate of img for colorization.
185 | corr = corrs[:, 0]
186 | phi = 2 * np.pi * (corr[:, 0] / (W - 1) - 0.5)
187 | theta = np.pi * (corr[:, 1] / (H - 1) - 0.5)
188 | x = np.cos(theta) * np.cos(phi)
189 | y = np.cos(theta) * np.sin(phi)
190 | z = np.sin(theta)
191 | rgbs = image.to_uint8((np.stack([x, y, z], axis=-1) + 1) / 2)
192 |
193 | for i in range(C):
194 | mask = (corrs[..., i, 0] < W) & (corrs[..., i, 0] >= 0) & (corrs[..., i, 1] < H) & (corrs[..., i, 1] >= 0)
195 | filtered_inds = np.nonzero(mask)[0]
196 |
197 | for idx in filtered_inds:
198 | start = tuple((corrs[idx, i] + [W * i, 0]).astype(np.int32))
199 | rgb = tuple(int(c) for c in (rgbs[idx] if rgbs.ndim == 2 else rgbs))
200 | cv2.circle(
201 | combined,
202 | start,
203 | radius=circle_radius,
204 | color=rgb,
205 | thickness=circle_thickness,
206 | lineType=cv2.LINE_AA,
207 | )
208 | if (
209 | line_thickness > 0
210 | and i < C - 1
211 | and (
212 | (corrs[idx, i + 1, 0] < W)
213 | & (corrs[idx, i + 1, 0] >= 0)
214 | & (corrs[idx, i + 1, 1] < H)
215 | & (corrs[idx, i + 1, 1] >= 0)
216 | )
217 | ):
218 | end = tuple((corrs[idx, i + 1] + [W * (i + 1), 0]).astype(np.int32))
219 | cv2.line(
220 | canvas,
221 | start,
222 | end,
223 | color=rgb,
224 | thickness=line_thickness,
225 | lineType=cv2.LINE_AA,
226 | )
227 |
228 | combined = cv2.addWeighted(combined, alpha, canvas, 1 - alpha, 0)
229 | return combined
230 |
--------------------------------------------------------------------------------
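
A hypothetical call to visualize_corrs, mainly to document the (N, 2, 2) layout of `corrs` (the import path is an assumption; OpenCV is required):

import numpy as np
from utils.dycheck_utils.visuals.corrs import visualize_corrs  # assumed import path

H, W = 240, 320
img = np.random.rand(H, W, 3).astype(np.float32)
img_to = np.random.rand(H, W, 3).astype(np.float32)
# A sparse grid of start points, each shifted by a constant offset in the second frame.
start = np.stack(np.meshgrid(np.arange(W), np.arange(H)), -1).reshape(-1, 2)[::997].astype(np.float32)
corrs = np.stack([start, start + np.array([12.0, 7.0], np.float32)], axis=1)  # (N, 2, 2): (from, to) x (x, y)
vis = visualize_corrs(corrs, img, img_to, min_rad=5, subsample=20)
print(vis.shape, vis.dtype)  # (240, 640, 3) uint8, the two frames side by side
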
/utils/dycheck_utils/visuals/depth.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : depth.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import Callable, Optional, Union
21 |
22 | import numpy as np
23 | from .. import image
24 | from matplotlib import cm
25 |
26 |
27 | def visualize_depth(
28 | depth: np.ndarray,
29 | acc: Optional[np.ndarray] = None,
30 | near: Optional[float] = None,
31 | far: Optional[float] = None,
32 | ignore_frac: float = 0,
33 | curve_fn: Callable = lambda x: -np.log(x + np.finfo(np.float32).eps),
34 | cmap: Union[str, Callable] = "turbo",
35 | invalid_depth: float = 0,
36 | ) -> np.ndarray:
37 | """Visualize a depth map.
38 |
39 | Args:
40 | depth (np.ndarray): A depth map of shape (H, W, 1).
41 | acc (np.ndarray): An accumulation map of shape (H, W, 1) in [0, 1].
42 | near (Optional[float]): The depth of the near plane. If None then just
43 | use the min. Default: None.
44 | far (Optional[float]): The depth of the far plane. If None then just
45 | use the max. Default: None.
46 | ignore_frac (float): The fraction of the depth map to ignore when
47 | automatically generating `near` and `far`. Depends on `acc` as well
48 | as `depth`. Default: 0.
49 | curve_fn (Callable): A curve function that gets applied to `depth`,
50 | `near`, and `far` before the rest of visualization. Good choices:
51 | x, 1/(x+eps), log(x+eps). Note that the default choice will flip
52 | the sign of depths, so that the default cmap (turbo) renders "near"
53 | as red and "far" as blue. Default: a negative log scale mapping.
54 | cmap (Union[str, Callable]): A cmap for colorization. Default: "turbo".
55 | invalid_depth (float): The value to use for invalid depths. Can be
56 | np.nan. Default: 0.
57 |
58 | Returns:
59 | np.ndarray: A depth visualization image of shape (H, W, 3) in uint8.
60 | """
61 | depth = np.array(depth)
62 | if acc is None:
63 | acc = np.ones_like(depth)
64 | else:
65 | acc = np.array(acc)
66 | if invalid_depth is not None:
67 | if invalid_depth is np.nan:
68 | acc = np.where(np.isnan(depth), np.zeros_like(acc), acc)
69 | else:
70 | acc = np.where(depth == invalid_depth, np.zeros_like(acc), acc)
71 |
72 | if near is None or far is None:
73 | # Sort `depth` and `acc` according to `depth`, then identify the depth
74 | # values that span the middle of `acc`, ignoring `ignore_frac` fraction
75 | # of `acc`.
76 | sortidx = np.argsort(depth.reshape((-1,)))
77 | depth_sorted = depth.reshape((-1,))[sortidx]
78 | acc_sorted = acc.reshape((-1,))[sortidx] # type: ignore
79 | cum_acc_sorted = np.cumsum(acc_sorted)
80 | mask = (cum_acc_sorted >= cum_acc_sorted[-1] * ignore_frac) & (
81 | cum_acc_sorted <= cum_acc_sorted[-1] * (1 - ignore_frac)
82 | )
83 | if invalid_depth is not None:
84 | mask &= (depth_sorted != invalid_depth) if invalid_depth is not np.nan else ~np.isnan(depth_sorted)
85 | depth_keep = depth_sorted[mask]
86 | eps = np.finfo(np.float32).eps
87 | # If `near` or `far` are None, use the lowest and highest remaining
88 | # values in `depth_keep` as automatic near/far planes.
89 | near = near or depth_keep[0] - eps
90 | far = far or depth_keep[-1] + eps
91 |
92 | assert near < far
93 |
94 | # Curve all values.
95 | depth, near, far = [curve_fn(x) for x in [depth, near, far]]
96 |
97 | # Scale to [0, 1].
98 | value = np.nan_to_num(np.clip((depth - np.minimum(near, far)) / np.abs(far - near), 0, 1))[..., 0]
99 |
100 | if isinstance(cmap, str):
101 | cmap = cm.get_cmap(cmap)
102 | color = cmap(value)[..., :3]
103 |
104 | # Set non-accumulated pixels to white.
105 | color = color * acc + (1 - acc) # type: ignore
106 |
107 | return image.to_uint8(color)
108 |
--------------------------------------------------------------------------------
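
A hypothetical call to visualize_depth with explicit near/far planes, where zero marks invalid pixels (assumed import path):

import numpy as np
from utils.dycheck_utils.visuals.depth import visualize_depth  # assumed import path

depth = np.random.uniform(1.0, 5.0, size=(120, 160, 1)).astype(np.float32)
depth[:10] = 0  # the first rows are treated as invalid and rendered white
vis = visualize_depth(depth, near=1.0, far=5.0, invalid_depth=0)
print(vis.shape, vis.dtype)  # (120, 160, 3) uint8
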
/utils/dycheck_utils/visuals/flow.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : flow.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from typing import Optional
21 |
22 | import cv2
23 | import jax
24 | import numpy as np
25 |
26 | from .. import image
27 | from .corrs import visualize_corrs
28 |
29 |
30 | def _make_colorwheel() -> np.ndarray:
31 | """Generates a classic color wheel for optical flow visualization.
32 |
33 | A Database and Evaluation Methodology for Optical Flow.
34 | Baker et al., ICCV 2007.
35 | http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
36 |
37 | Code follows the original C++ source code of Daniel Scharstein.
38 | Code follows the Matlab source code of Deqing Sun.
39 |
40 | Returns:
41 | colorwheel (np.ndarray): Color wheel of shape (55, 3) with values in [0, 255].
42 | """
43 |
44 | RY = 15
45 | YG = 6
46 | GC = 4
47 | CB = 11
48 | BM = 13
49 | MR = 6
50 |
51 | ncols = RY + YG + GC + CB + BM + MR
52 | colorwheel = np.zeros((ncols, 3))
53 | col = 0
54 |
55 | # RY
56 | colorwheel[0:RY, 0] = 255
57 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
58 | col = col + RY
59 | # YG
60 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
61 | colorwheel[col : col + YG, 1] = 255
62 | col = col + YG
63 | # GC
64 | colorwheel[col : col + GC, 1] = 255
65 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
66 | col = col + GC
67 | # CB
68 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
69 | colorwheel[col : col + CB, 2] = 255
70 | col = col + CB
71 | # BM
72 | colorwheel[col : col + BM, 2] = 255
73 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
74 | col = col + BM
75 | # MR
76 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
77 | colorwheel[col : col + MR, 0] = 255
78 | return colorwheel
79 |
80 |
81 | def _flow_to_colors(flow: np.ndarray) -> np.ndarray:
82 | """Applies the flow color wheel to (possibly clipped) flow visualization
83 | image.
84 |
85 | According to the C++ source code of Daniel Scharstein.
86 | According to the Matlab source code of Deqing Sun.
87 |
88 | Args:
89 | flow (np.ndarray): Flow image of shape (H, W, 2).
90 |
91 | Returns:
92 | flow_visual (np.ndarray): Flow visualization image of shape (H, W, 3).
93 | """
94 | u, v = jax.tree_map(lambda x: x[..., 0], np.split(flow, 2, axis=-1))
95 |
96 | flow_visual = np.zeros(flow.shape[:2] + (3,), np.uint8)
97 | colorwheel = _make_colorwheel()
98 | ncols = colorwheel.shape[0]
99 |
100 | rad = np.sqrt(np.square(u) + np.square(v))
101 | a = np.arctan2(-v, -u) / np.pi
102 | fk = (a + 1) / 2 * (ncols - 1)
103 | k0 = np.floor(fk).astype(np.int32)
104 | k1 = k0 + 1
105 | k1[k1 == ncols] = 0
106 | f = fk - k0
107 | for i in range(colorwheel.shape[1]):
108 | tmp = colorwheel[:, i]
109 | col0 = tmp[k0] / 255
110 | col1 = tmp[k1] / 255
111 | col = (1 - f) * col0 + f * col1
112 | idx = rad <= 1
113 | col[idx] = 1 - rad[idx] * (1 - col[idx])
114 | col[~idx] = col[~idx] * 0.75 # Out of range.
115 | flow_visual[..., i] = np.floor(255 * col)
116 | return flow_visual
117 |
118 |
119 | def visualize_flow(
120 | flow: np.ndarray,
121 | *,
122 | clip_flow: Optional[float] = None,
123 | rad_max: Optional[float] = None,
124 | ) -> np.ndarray:
125 | """Visualizei a flow image.
126 |
127 | Args:
128 | flow (np.ndarray): A flow image of shape (H, W, 2).
129 | clip_flow (Optional[float]): Clip flow to [0, clip_flow].
130 | rad_max (Optional[float]): Maximum radius of the flow visualization.
131 |
132 | Returns:
133 | np.ndarray: Flow visualization image of shape (H, W, 3).
134 | """
135 | flow = np.array(flow)
136 |
137 | if clip_flow is not None:
138 | flow = np.clip(flow, 0, clip_flow)
139 | rad = np.linalg.norm(flow, axis=-1, keepdims=True).clip(min=1e-6)
140 | if rad_max is None:
141 | rad_max = np.full_like(rad, rad.max())
142 | else:
143 | # Clip big flow to rad_max while homogeneously scaling the rest.
144 | rad_max = rad.clip(min=rad_max)
145 | flow = flow / rad_max
146 |
147 | return _flow_to_colors(flow)
148 |
149 |
150 | def visualize_flow_arrows(
151 | flow: np.ndarray,
152 | img: np.ndarray,
153 | *,
154 | rgbs: Optional[np.ndarray] = None,
155 | clip_flow: Optional[float] = None,
156 | min_thresh: float = 5,
157 | subsample: int = 50,
158 | num_min_keeps: int = 10,
159 | line_thickness: int = 1,
160 | tip_length: float = 0.2,
161 | alpha: float = 0.5,
162 | ) -> np.ndarray:
163 | """Visualize a flow image with arrows.
164 |
165 | Args:
166 | flow (np.ndarray): A flow image of shape (H, W, 2).
167 | img (np.ndarray): An image for start points of shape (H, W, 3) in
168 | float32 or uint8.
169 | rgbs (Optional[np.ndarray]): A color map for the arrows at each pixel
170 | location of shape (H, W, 3). Default: None.
171 | clip_flow (Optional[float]): Clip flow to [0, clip_flow].
172 | min_thresh (float): Minimum threshold for flow magnitude.
173 | subsample (int): Subsample the flow to speed up visualization.
174 | num_min_keeps (int): The number of correspondences to keep. Default:
175 | 10.
176 | line_thickness (int): Line thickness. Default: 1.
177 | tip_length (float): Length of the arrow tip. Default: 0.2.
178 | alpha (float): The alpha value between [0, 1] for foreground blending.
179 | The bigger the more prominent of the visualization. Default: 0.5.
180 |
181 | Returns:
182 | canvas (np.ndarray): Flow visualization image of shape (H, W, 3).
183 | """
184 | img = image.to_uint8(img)
185 | canvas = img.copy()
186 | rng = np.random.default_rng(0)
187 |
188 | if rgbs is None:
189 | rgbs = visualize_flow(flow, clip_flow=clip_flow)
190 | H, W = flow.shape[:2]
191 |
192 | flow_start = np.stack(np.meshgrid(range(W), range(H)), 2)
193 | flow_end = (flow[flow_start[..., 1], flow_start[..., 0]] + flow_start).astype(np.int32)
194 |
195 | norm = np.linalg.norm(flow, axis=-1)
196 | valid_mask = (
197 | (norm >= min_thresh)
198 | & (flow_end[..., 0] < flow.shape[1])
199 | & (flow_end[..., 0] >= 0)
200 | & (flow_end[..., 1] < flow.shape[0])
201 | & (flow_end[..., 1] >= 0)
202 | )
203 | filtered_inds = np.stack(np.nonzero(valid_mask), axis=-1)
204 | num_min_keeps = min(
205 | max(num_min_keeps, filtered_inds.shape[0] // subsample),
206 | filtered_inds.shape[0],
207 | )
208 | filtered_inds = rng.choice(filtered_inds, num_min_keeps, replace=False) if filtered_inds.shape[0] > 0 else []
209 |
210 | for inds in filtered_inds:
211 | y, x = inds
212 | start = tuple(flow_start[y, x])
213 | end = tuple(flow_end[y, x])
214 | rgb = tuple(int(x) for x in rgbs[y, x])
215 | cv2.arrowedLine(
216 | canvas,
217 | start,
218 | end,
219 | color=rgb,
220 | thickness=line_thickness,
221 | tipLength=tip_length,
222 | line_type=cv2.LINE_AA,
223 | )
224 |
225 | canvas = cv2.addWeighted(img, alpha, canvas, 1 - alpha, 0)
226 | return canvas
227 |
228 |
229 | def visualize_flow_corrs(
230 | flow: np.ndarray,
231 | img: np.ndarray,
232 | img_to: np.ndarray,
233 | *,
234 | mask: Optional[np.ndarray] = None,
235 | rgbs: Optional[np.ndarray] = None,
236 | **kwargs,
237 | ) -> np.ndarray:
238 | """Visualize a flow image as a set of correspondences.
239 |
240 | Args:
241 | flow (np.ndarray): A flow image of shape (H, W, 2).
242 | img (np.ndarray): An image for start points of shape (H, W, 3) in
243 | float32 or uint8.
244 | img_to (np.ndarray): An image for end points of shape (H, W, 3) in
245 | float32 or uint8.
246 | mask (Optional[np.ndarray]): A hard mask for start points of shape
247 | (H, W, 1). Default: None.
248 | rgbs (Optional[np.ndarray]): A color map for the arrows at each pixel
249 | location of shape (H, W, 3). Default: None.
250 |
251 | Returns:
252 | canvas (np.ndarray): Flow visualization image of shape (H, W, 3).
253 | """
254 | flow_start = np.stack(np.meshgrid(range(flow.shape[1]), range(flow.shape[0])), 2)
255 | flow_end = (flow[flow_start[..., 1], flow_start[..., 0]] + flow_start).astype(np.int32)
256 | flow_corrs = np.stack([flow_start, flow_end])
257 |
258 | if mask is not None:
259 | # Only show correspondences inside of the mask.
260 | flow_corrs = flow_corrs[np.stack([mask[..., 0]] * 2)]
261 | if rgbs is not None:
262 | rgbs = rgbs[mask[..., 0]]
263 | if flow_corrs.shape[0] == 0:
264 | flow_corrs = np.ones((0, 4))
265 | if rgbs is not None:
266 | rgbs = np.ones((0, 3))
267 | flow_corrs = flow_corrs.reshape(2, -1, 2).swapaxes(0, 1)
268 |
269 | return visualize_corrs(flow_corrs, img, img_to, rgbs=rgbs, **kwargs)
270 |
--------------------------------------------------------------------------------
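
A hypothetical call to visualize_flow on a synthetic radial flow field (assumed import path; note this module also imports jax for the channel split):

import numpy as np
from utils.dycheck_utils.visuals.flow import visualize_flow  # assumed import path

H, W = 120, 160
y, x = np.mgrid[:H, :W].astype(np.float32)
flow = np.stack([x - W / 2, y - H / 2], axis=-1)  # flow pointing away from the image center
vis = visualize_flow(flow)
print(vis.shape, vis.dtype)  # (120, 160, 3) uint8; hue encodes direction, saturation grows with magnitude
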
/utils/dycheck_utils/visuals/kps/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : __init__.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | from copy import deepcopy
21 | from typing import Callable, Optional, Union
22 |
23 | import cv2
24 | import numpy as np
25 | from ... import image
26 |
27 | from .skeleton import SKELETON_MAP, Skeleton
28 |
29 |
30 | def visualize_kps(
31 | kps: np.ndarray,
32 | img: np.ndarray,
33 | *,
34 | skeleton: Union[str, Skeleton, Callable[..., Skeleton]] = "unconnected",
35 | rgbs: Optional[np.ndarray] = None,
36 | kp_radius: int = 4,
37 | bone_thickness: int = 3,
38 | **kwargs,
39 | ) -> np.ndarray:
40 | """Visualize 2D keypoints with their skeleton.
41 |
42 | Args:
43 | kps (np.ndarray): an array of shape (J, 3) for keypoints, where the
44 | last column is the visibility in [0, 1].
45 | img (np.ndarray): an RGB image of shape (H, W, 3) in float32 or uint8.
46 | skeleton (Union[str, Skeleton, Callable[..., Skeleton]]): a skeleton name,
47 | a Skeleton instance, or a callable that returns one. Default: "unconnected".
48 | rgbs (Optional[np.ndarray]): A set of rgbs for each keypoint of shape
49 | (J, 3) or (3,). If None then use skeleton palette. Default: None.
50 | kp_radius (int): the radius of kps for visualization. Default: 4.
51 | bone_thickness (int): the thickness of bones connecting kps for
52 | visualization. Default: 3.
53 |
54 | Returns:
55 | combined (np.ndarray): Keypoint visualization image of shape (H, W, 3)
56 | in uint8.
57 | """
58 | if isinstance(skeleton, str):
59 | skeleton = SKELETON_MAP[skeleton]
60 | if isinstance(skeleton, Callable):
61 | skeleton = skeleton(num_kps=len(kps), **kwargs)
62 | if rgbs is not None:
63 | if rgbs.ndim == 1:
64 | rgbs = rgbs[None, :].repeat(skeleton.num_kps, axis=0)
65 | skeleton = deepcopy(skeleton)
66 | skeleton._palette = rgbs.tolist()
67 |
68 | assert skeleton.num_kps == len(kps)
69 |
70 | kps = np.array(kps)
71 | img = image.to_uint8(img)
72 |
73 | H, W = img.shape[:2]
74 | canvas = img.copy()
75 |
76 | mask = (kps[:, -1] != 0) & (kps[:, 0] >= 0) & (kps[:, 0] < W) & (kps[:, 1] >= 0) & (kps[:, 1] < H)
77 |
78 | # Visualize bones.
79 | palette = skeleton.non_root_palette
80 | bones = skeleton.non_root_bones
81 | for rgb, (j, p) in zip(palette, bones):
82 | # Skip invisible keypoints.
83 | if (~mask[[j, p]]).any():
84 | continue
85 |
86 | kp_p, kp_j = kps[p, :2], kps[j, :2]
87 | kp_mid = (kp_p + kp_j) / 2
88 | bone_length = np.linalg.norm(kp_j - kp_p)
89 | bone_angle = (np.arctan2(kp_j[1] - kp_p[1], kp_j[0] - kp_p[0])) * 180 / np.pi
90 | polygon = cv2.ellipse2Poly(
91 | (int(kp_mid[0]), int(kp_mid[1])),
92 | (int(bone_length / 2), bone_thickness),
93 | int(bone_angle),
94 | arcStart=0,
95 | arcEnd=360,
96 | delta=5,
97 | )
98 | cv2.fillConvexPoly(canvas, polygon, rgb, lineType=cv2.LINE_AA)
99 | canvas = cv2.addWeighted(img, 0.5, canvas, 0.5, 0)
100 |
101 | # Visualize keypoints.
102 | combined = canvas.copy()
103 | palette = skeleton.palette
104 | for rgb, kp, valid in zip(palette, kps, mask):
105 | # Skip invisible keypoints.
106 | if not valid:
107 | continue
108 |
109 | cv2.circle(
110 | combined,
111 | (int(kp[0]), int(kp[1])),
112 | radius=kp_radius,
113 | color=rgb,
114 | thickness=-1,
115 | lineType=cv2.LINE_AA,
116 | )
117 | combined = cv2.addWeighted(canvas, 0.3, combined, 0.7, 0)
118 |
119 | return combined
120 |
--------------------------------------------------------------------------------
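
A hypothetical call to visualize_kps with the default unconnected skeleton; the last column of `kps` is visibility, so the third point below is skipped (assumed import path):

import numpy as np
from utils.dycheck_utils.visuals.kps import visualize_kps  # assumed import path

img = np.zeros((240, 320, 3), np.uint8)
kps = np.array([[60.0, 80.0, 1.0], [160.0, 120.0, 1.0], [260.0, 200.0, 0.0]])  # (J, 3): x, y, visibility
vis = visualize_kps(kps, img, skeleton="unconnected")
print(vis.shape, vis.dtype)  # (240, 320, 3) uint8
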
/utils/dycheck_utils/visuals/kps/skeleton.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : skeleton.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import re
21 | from typing import Callable, Optional, Sequence, Tuple, Union
22 |
23 | import gin
24 | import numpy as np
25 | from ... import image
26 | from matplotlib import cm
27 |
28 |
29 | KP_PALETTE_MAP = {}
30 |
31 |
32 | @gin.configurable()
33 | class Skeleton(object):
34 | name = "skeleton"
35 | _anonymous_kp_name = "ANONYMOUS KP"
36 |
37 | def __init__(
38 | self,
39 | parents: Sequence[Optional[int]],
40 | kp_names: Optional[Sequence[str]] = None,
41 | palette: Optional[Sequence[Tuple[int, int, int]]] = None,
42 | ):
43 | if kp_names is not None:
44 | assert len(parents) == len(kp_names)
45 | if palette is not None:
46 | assert len(kp_names) == len(palette)
47 |
48 | self._parents = parents
49 | self._kp_names = kp_names if kp_names is not None else [self._anonymous_kp_name] * self.num_kps
50 | self._palette = palette
51 |
52 | def asdict(self):
53 | return {
54 | "name": self.name,
55 | "parents": self.parents,
56 | "kp_names": self.kp_names,
57 | "palette": self.palette,
58 | }
59 |
60 | @property
61 | def is_unconnected(self):
62 | return all([p is None for p in self._parents])
63 |
64 | @property
65 | def parents(self):
66 | return self._parents
67 |
68 | @property
69 | def kp_names(self):
70 | return self._kp_names
71 |
72 | @property
73 | def palette(self):
74 | if self._palette is not None:
75 | return self._palette
76 |
77 | if self.kp_names[0] != self._anonymous_kp_name and all(
78 | [kp_name in KP_PALETTE_MAP for kp_name in self.kp_names]
79 | ):
80 | return [KP_PALETTE_MAP[kp_name] for kp_name in self.kp_names]
81 |
82 | palette = np.zeros((self.num_kps, 3), dtype=np.uint8)
83 | left_mask = np.array(
84 | [len(re.findall(r"^(\w+ |)L\w+$", kp_name)) > 0 for kp_name in self._kp_names],
85 | dtype=bool,
86 | )
87 | palette[left_mask] = (255, 0, 0)
88 | return [tuple(color.tolist()) for color in palette]
89 |
90 | @property
91 | def num_kps(self):
92 | return len(self._parents)
93 |
94 | @property
95 | def root_idx(self):
96 | if self.is_unconnected:
97 | return 0
98 | return self._parents.index(-1)
99 |
100 | @property
101 | def bones(self):
102 | if self.is_unconnected:
103 | return []
104 | return np.stack([list(range(self.num_kps)), self.parents]).T.tolist()
105 |
106 | @property
107 | def non_root_bones(self):
108 | if self.is_unconnected:
109 | return []
110 | return np.delete(self.bones.copy(), self.root_idx, axis=0)
111 |
112 | @property
113 | def non_root_palette(self):
114 | if self.is_unconnected:
115 | return []
116 | return np.delete(self.palette.copy(), self.root_idx, axis=0).tolist()
117 |
118 |
119 | @gin.configurable()
120 | class UnconnectedSkeleton(Skeleton):
121 | """A keypoint skeleton that does not define parents. This could be useful
122 | when organizing randomly annotated keypoints.
123 | """
124 |
125 | name: str = "unconnected"
126 |
127 | def __init__(self, num_kps: int, cmap: Union[str, Callable] = "gist_rainbow"):
128 | if isinstance(cmap, str):
129 | cmap = cm.get_cmap(cmap, num_kps)
130 | palette = image.to_uint8(np.array([cmap(i)[:3] for i in range(num_kps)], np.float32)).tolist()
131 | super().__init__(
132 | parents=[None for _ in range(num_kps)],
133 | kp_names=[f"KP_{i}" for i in range(num_kps)],
134 | palette=palette,
135 | )
136 |
137 |
138 | @gin.configurable()
139 | class HumanSkeleton(Skeleton):
140 | """A human skeleton following the COCO dataset.
141 |
142 | Microsoft COCO: Common Objects in Context.
143 | Lin et al., ECCV 2014.
144 | https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48
145 |
146 | For pictorial definition, see also: shorturl.at/ilnpZ.
147 | """
148 |
149 | name: str = "human"
150 |
151 | def __init__(self, **_):
152 | super().__init__(
153 | parents=[
154 | 1,
155 | -1,
156 | 1,
157 | 2,
158 | 3,
159 | 1,
160 | 5,
161 | 6,
162 | 1,
163 | 8,
164 | 9,
165 | 1,
166 | 11,
167 | 12,
168 | 0,
169 | 0,
170 | 14,
171 | 15,
172 | ],
173 | kp_names=[
174 | "Nose",
175 | "Neck",
176 | "RShoulder",
177 | "RElbow",
178 | "RWrist",
179 | "LShoulder",
180 | "LElbow",
181 | "LWrist",
182 | "RHip",
183 | "RKnee",
184 | "RAnkle",
185 | "LHip",
186 | "LKnee",
187 | "LAnkle",
188 | "REye",
189 | "LEye",
190 | "REar",
191 | "LEar",
192 | ],
193 | palette=[
194 | (255, 0, 0),
195 | (255, 85, 0),
196 | (255, 170, 0),
197 | (255, 255, 0),
198 | (170, 255, 0),
199 | (85, 255, 0),
200 | (0, 255, 0),
201 | (0, 255, 85),
202 | (0, 255, 170),
203 | (0, 255, 255),
204 | (0, 170, 255),
205 | (0, 85, 255),
206 | (0, 0, 255),
207 | (85, 0, 255),
208 | (170, 0, 255),
209 | (255, 0, 255),
210 | (255, 0, 170),
211 | (255, 0, 85),
212 | ],
213 | )
214 |
215 |
216 | @gin.configurable()
217 | class QuadrupedSkeleton(Skeleton):
218 | """A quadruped skeleton following StanfordExtra dataset.
219 |
220 | Novel dataset for Fine-Grained Image Categorization.
221 | Khosla et al., CVPR 2011, FGVC workshop.
222 | http://vision.stanford.edu/aditya86/ImageNetDogs/main.html
223 |
224 | Who Left the Dogs Out? 3D Animal Reconstruction with Expectation
225 | Maximization in the Loop.
226 | Biggs et al., ECCV 2020.
227 | https://arxiv.org/abs/2007.11110
228 | """
229 |
230 | name: str = "quadruped"
231 |
232 | def __init__(self, **_):
233 | super().__init__(
234 | parents=[
235 | 1,
236 | 2,
237 | 22,
238 | 4,
239 | 5,
240 | 12,
241 | 7,
242 | 8,
243 | 22,
244 | 10,
245 | 11,
246 | 12,
247 | -1,
248 | 12,
249 | 20,
250 | 21,
251 | 17,
252 | 23,
253 | 14,
254 | 15,
255 | 16,
256 | 16,
257 | 12,
258 | 22,
259 | ],
260 | kp_names=[
261 | "LFrontPaw",
262 | "LFrontWrist",
263 | "LFrontElbow",
264 | "LRearPaw",
265 | "LRearWrist",
266 | "LRearElbow",
267 | "RFrontPaw",
268 | "RFrontWrist",
269 | "RFrontElbow",
270 | "RRearPaw",
271 | "RRearWrist",
272 | "RRearElbow",
273 | "TailStart",
274 | "TailEnd",
275 | "LEar",
276 | "REar",
277 | "Nose",
278 | "Chin",
279 | "LEarTip",
280 | "REarTip",
281 | "LEye",
282 | "REye",
283 | "Withers",
284 | "Throat",
285 | ],
286 | palette=[
287 | (0, 255, 0),
288 | (63, 255, 0),
289 | (127, 255, 0),
290 | (0, 0, 255),
291 | (0, 63, 255),
292 | (0, 127, 255),
293 | (255, 255, 0),
294 | (255, 191, 0),
295 | (255, 127, 0),
296 | (0, 255, 255),
297 | (0, 255, 191),
298 | (0, 255, 127),
299 | (0, 0, 0),
300 | (0, 0, 0),
301 | (255, 0, 170),
302 | (255, 0, 170),
303 | (255, 0, 170),
304 | (255, 0, 170),
305 | (255, 0, 170),
306 | (255, 0, 170),
307 | (255, 0, 170),
308 | (255, 0, 170),
309 | (255, 0, 170),
310 | (255, 0, 170),
311 | ],
312 | )
313 |
314 |
315 | SKELETON_MAP = {
316 | cls.name: cls
317 | for cls in [
318 | Skeleton,
319 | UnconnectedSkeleton,
320 | HumanSkeleton,
321 | QuadrupedSkeleton,
322 | ]
323 | }
324 |
--------------------------------------------------------------------------------
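
A small, hypothetical peek at the registry defined above, to show how the skeleton metadata is organized (assumed import path):

from utils.dycheck_utils.visuals.kps.skeleton import SKELETON_MAP  # assumed import path

human = SKELETON_MAP["human"]()
print(human.num_kps)    # 18
print(human.root_idx)   # 1, i.e. the neck is the root joint
print(human.bones[:3])  # [[0, 1], [1, -1], [2, 1]] as (joint, parent) pairs
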
/utils/dycheck_utils/visuals/rendering.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # File : rendering.py
4 | # Author : Hang Gao
5 | # Email : hangg.sv7@gmail.com
6 | #
7 | # Copyright 2022 Adobe. All rights reserved.
8 | #
9 | # This file is licensed to you under the Apache License, Version 2.0 (the
10 | # "License"); you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the
17 | # License for the specific language governing permissions and limitations under
18 | # the License.
19 |
20 | import matplotlib
21 |
22 |
23 | matplotlib.use("agg")
24 |
25 | from typing import NamedTuple, Optional, Tuple
26 |
27 | import matplotlib.pyplot as plt
28 | import numpy as np
29 | import dycheck_geometry as geometry
30 | from matplotlib.backends.backend_agg import FigureCanvasAgg
31 | from matplotlib.collections import PolyCollection
32 |
33 | from .. import image
34 |
35 |
36 | class Renderings(NamedTuple):
37 | # (H, W, 3) in uint8.
38 | rgb: Optional[np.ndarray] = None
39 | # (H, W, 1) in float32.
40 | depth: Optional[np.ndarray] = None
41 | # (H, W, 1) in [0, 1].
42 | acc: Optional[np.ndarray] = None
43 |
44 |
45 | def visualize_pcd_renderings(points: np.ndarray, point_rgbs: np.ndarray, camera: geometry.Camera, **_) -> Renderings:
46 | """Visualize a point cloud as a set renderings.
47 |
48 | Args:
49 | points (np.ndarray): (N, 3) array of points.
50 | point_rgbs (np.ndarray): (N, 3) array of point colors in either uint8
51 | or float32.
52 | camera (geometry.Camera): a camera object containing view information.
53 |
54 | Returns:
55 | Renderings: the image output object.
56 | """
57 |
58 | point_rgbs = image.to_uint8(point_rgbs) # type: ignore
59 |
60 | # Setup the camera.
61 | W, H = camera.image_size
62 |
63 | # project the 3D points to 2D on image plane
64 | pixels, depths = camera.project(points, return_depth=True, use_projective_depth=True)
65 | pixels = pixels.astype(np.int32)
66 | mask = (pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] >= 0) & (pixels[:, 1] < H) & (depths[:, 0] > 0)
67 |
68 | pixels = pixels[mask]
69 | rgbs = point_rgbs[mask]
70 | depths = depths[mask]
71 |
72 | sorted_inds = np.argsort(depths[..., 0])[::-1]
73 | pixels = pixels[sorted_inds]
74 | rgbs = rgbs[sorted_inds]
75 | depths = depths[sorted_inds]
76 |
77 | rgb = np.full((H, W, 3), 255, dtype=np.uint8)
78 | rgb[pixels[:, 1], pixels[:, 0]] = rgbs
79 |
80 | depth = np.zeros((H, W, 1), dtype=np.float32)
81 | depth[pixels[:, 1], pixels[:, 0]] = depths
82 |
83 | acc = np.zeros((H, W, 1), dtype=np.float32)
84 | acc[pixels[:, 1], pixels[:, 0]] = 1
85 |
86 | return Renderings(rgb=rgb, depth=depth, acc=acc)
87 |
88 |
89 | def _is_front(T: np.ndarray) -> np.ndarray:
90 | # Triangle is front facing if its projection on XY plane is clock-wise.
91 | Z = (
92 | (T[:, 1, 0] - T[:, 0, 0]) * (T[:, 1, 1] + T[:, 0, 1])
93 | + (T[:, 2, 0] - T[:, 1, 0]) * (T[:, 2, 1] + T[:, 1, 1])
94 | + (T[:, 0, 0] - T[:, 2, 0]) * (T[:, 0, 1] + T[:, 2, 1])
95 | )
96 | return Z >= 0
97 |
98 |
99 | def grid_faces(h: int, w: int, mask: Optional[np.ndarray] = None) -> np.ndarray:
100 | """Creates mesh face indices from a given pixel grid size.
101 |
102 | Args:
103 | h (int): image height.
104 | w (int): image width.
105 | mask (Optional[np.ndarray], optional): mask of valid pixels. Defaults
106 | to None.
107 |
108 | Returns:
109 | faces (np.ndarray): array of face indices into the flattened h*w grid.
110 | Note that the indices still refer to invalid pixels (the vertex numbering is not compacted).
111 | """
112 | if mask is None:
113 | mask = np.ones((h, w), bool)
114 |
115 | x, y = np.meshgrid(range(w - 1), range(h - 1))
116 | tl = y * w + x
117 | tr = y * w + x + 1
118 | bl = (y + 1) * w + x
119 | br = (y + 1) * w + x + 1
120 | faces_1 = np.stack([tl, bl, tr], axis=-1)[mask[:-1, :-1] & mask[1:, :-1] & mask[:-1, 1:]]
121 | faces_2 = np.stack([br, tr, bl], axis=-1)[mask[1:, 1:] & mask[:-1, 1:] & mask[1:, :-1]]
122 | faces = np.concatenate([faces_1, faces_2], axis=0)
123 | return faces
124 |
125 |
126 | def visualize_mesh_renderings(
127 | points: np.ndarray, faces: np.ndarray, point_rgbs: np.ndarray, camera: geometry.Camera, **_
128 | ):
129 | """Visualize a mesh as a set renderings.
130 |
131 | Note that front facing triangles are defined in clock-wise orientation.
132 |
133 | Args:
134 | points (np.ndarray): (N, 3) array of points.
135 | faces (np.ndarray): (F, 3) array of faces.
136 | point_rgbs (np.ndarray): (N, 3) array of point colors in either uint8
137 | or float32.
138 | camera (geometry.Camera): a camera object containing view information.
139 |
140 | Returns:
141 | Renderings: the image output object.
142 | """
143 |
144 | # High quality output.
145 | DPI = 10.0
146 |
147 | face_rgbs = image.to_float32(point_rgbs[faces]).mean(axis=-2)
148 |
149 | # Setup the camera.
150 | W, H = camera.image_size
151 |
152 | # Project the 3D points to 2D on image plane.
153 | pixels, depth = camera.project(points, return_depth=True, use_projective_depth=True)
154 |
155 | T = pixels[faces]
156 | Z = -depth[faces][..., 0].mean(axis=1)
157 | front = _is_front(T)
158 | T, Z = T[front], Z[front]
159 | face_rgbs = face_rgbs[front]
160 |
161 | # Sort triangles according to z buffer.
162 | triangles = T[:, :, :2]
163 | sorted_inds = np.argsort(Z)
164 | triangles = triangles[sorted_inds]
165 | face_rgbs = face_rgbs[sorted_inds]
166 |
167 | # Painter's algorithm using matplotlib.
168 | fig = plt.figure(figsize=(W / DPI, H / DPI), dpi=DPI)
169 | canvas = FigureCanvasAgg(fig)
170 | ax = fig.add_axes([0, 0, 1, 1], xlim=[0, W], ylim=[H, 0], aspect=1)
171 | ax.axis("off")
172 |
173 | collection = PolyCollection([], closed=True)
174 | collection.set_verts(triangles)
175 | collection.set_linewidths(0.0)
176 | collection.set_facecolors(face_rgbs)
177 | ax.add_collection(collection)
178 |
179 | canvas.draw()
180 | s, _ = canvas.print_to_buffer()
181 | img = np.frombuffer(s, np.uint8).reshape((H, W, 4))
182 | plt.close(fig)
183 |
184 | rgb = img[..., :3]
185 | acc = (img[..., -1:] > 0).astype(np.float32)
186 |
187 | # Depth is not fully supported by matplotlib yet, so we render it as if
188 | # it were an image, as a hack.
189 | fig = plt.figure(figsize=(W / DPI, H / DPI), dpi=DPI)
190 | canvas = FigureCanvasAgg(fig)
191 | ax = fig.add_axes([0, 0, 1, 1], xlim=[0, W], ylim=[H, 0], aspect=1)
192 | ax.axis("off")
193 |
194 | collection = PolyCollection([], closed=True)
195 | collection.set_verts(triangles)
196 | collection.set_linewidths(0.0)
197 | Z = -Z[sorted_inds]
198 | Zmin = Z.min()
199 | Zmax = Z.max()
200 | Z = (Z - Zmin) / (Zmax - Zmin)
201 | collection.set_facecolors(Z[..., None].repeat(3, axis=-1))
202 | ax.add_collection(collection)
203 |
204 | canvas.draw()
205 | s, _ = canvas.print_to_buffer()
206 | depth = image.to_float32(np.frombuffer(s, np.uint8).reshape((H, W, 4))[..., :1]) * (Zmax - Zmin) + Zmin
207 | depth[acc[..., 0] == 0] = 0
208 | plt.close(fig)
209 |
210 | return Renderings(rgb=rgb, depth=depth, acc=acc)
211 |
--------------------------------------------------------------------------------
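
A hypothetical use of grid_faces above, showing how a pixel mask removes the triangles that touch invalid pixels (assumed import path):

import numpy as np
from utils.dycheck_utils.visuals.rendering import grid_faces  # assumed import path

# A 3x4 pixel grid has 2x3 cells, i.e. 12 triangles when every pixel is valid.
print(grid_faces(3, 4).shape)  # (12, 3)

mask = np.ones((3, 4), bool)
mask[0, 0] = False  # invalidate the top-left pixel
print(grid_faces(3, 4, mask).shape)  # (11, 3): the one triangle using that corner is dropped
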
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import random
13 | import sys
14 | from datetime import datetime
15 |
16 | import numpy as np
17 | import torch
18 |
19 |
20 | def inverse_sigmoid(x):
21 | return torch.log(x / (1 - x))
22 |
23 |
24 | def PILtoTorch(pil_image, resolution):
25 | if resolution is not None:
26 | resized_image_PIL = pil_image.resize(resolution)
27 | else:
28 | resized_image_PIL = pil_image
29 | if np.array(resized_image_PIL).max() != 1:
30 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
31 | else:
32 | resized_image = torch.from_numpy(np.array(resized_image_PIL))
33 | if len(resized_image.shape) == 3:
34 | return resized_image.permute(2, 0, 1)
35 | else:
36 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
37 |
38 |
39 | def get_expon_lr_func(lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000):
40 | """
41 | Copied from Plenoxels
42 |
43 | Continuous learning rate decay function. Adapted from JaxNeRF
44 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
45 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
46 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
47 | function of lr_delay_mult, such that the initial learning rate is
48 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
49 | to the normal learning rate when steps>lr_delay_steps.
50 | :param lr_init, lr_final: the learning rates at step 0 and at step max_steps.
51 | :param max_steps: int, the number of steps during optimization.
52 | :return: a function that maps the current step to a learning rate.
53 | """
54 |
55 | def helper(step):
56 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
57 | # Disable this parameter
58 | return 0.0
59 | if lr_delay_steps > 0:
60 | # A kind of reverse cosine decay.
61 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
62 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
63 | )
64 | else:
65 | delay_rate = 1.0
66 | t = np.clip(step / max_steps, 0, 1)
67 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
68 |
69 | return delay_rate * log_lerp
70 |
71 | return helper
72 |
73 |
74 | def strip_lowerdiag(L):
75 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
76 |
77 | uncertainty[:, 0] = L[:, 0, 0]
78 | uncertainty[:, 1] = L[:, 0, 1]
79 | uncertainty[:, 2] = L[:, 0, 2]
80 | uncertainty[:, 3] = L[:, 1, 1]
81 | uncertainty[:, 4] = L[:, 1, 2]
82 | uncertainty[:, 5] = L[:, 2, 2]
83 | return uncertainty
84 |
85 |
86 | def strip_symmetric(sym):
87 | return strip_lowerdiag(sym)
88 |
89 |
90 | def build_rotation(r):
91 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
92 |
93 | q = r / norm[:, None]
94 |
95 | R = torch.zeros((q.size(0), 3, 3), device="cuda")
96 |
97 | r = q[:, 0]
98 | x = q[:, 1]
99 | y = q[:, 2]
100 | z = q[:, 3]
101 |
102 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
103 | R[:, 0, 1] = 2 * (x * y - r * z)
104 | R[:, 0, 2] = 2 * (x * z + r * y)
105 | R[:, 1, 0] = 2 * (x * y + r * z)
106 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
107 | R[:, 1, 2] = 2 * (y * z - r * x)
108 | R[:, 2, 0] = 2 * (x * z - r * y)
109 | R[:, 2, 1] = 2 * (y * z + r * x)
110 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
111 | return R
112 |
113 |
114 | def build_scaling_rotation(s, r):
115 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
116 | R = build_rotation(r)
117 |
118 | L[:, 0, 0] = s[:, 0]
119 | L[:, 1, 1] = s[:, 1]
120 | L[:, 2, 2] = s[:, 2]
121 |
122 | L = R @ L
123 | return L
124 |
125 |
126 | def safe_state(silent, seed=0):
127 | old_f = sys.stdout
128 |
129 | class F:
130 | def __init__(self, silent):
131 | self.silent = silent
132 |
133 | def write(self, x):
134 | if not self.silent:
135 | if x.endswith("\n"):
136 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
137 | else:
138 | old_f.write(x)
139 |
140 | def flush(self):
141 | old_f.flush()
142 |
143 | sys.stdout = F(silent)
144 |
145 | torch.manual_seed(seed)
146 | torch.cuda.manual_seed_all(seed)
147 | np.random.seed(seed)
148 | random.seed(seed)
149 | torch.backends.cudnn.deterministic = True
150 |
--------------------------------------------------------------------------------
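
A hypothetical use of get_expon_lr_func, the log-linear schedule described in its docstring (assumed import path):

from utils.general_utils import get_expon_lr_func  # assumed import path

lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30000)
for step in (0, 15000, 30000):
    print(step, lr_fn(step))  # 1.6e-4 at step 0, ~1.6e-5 at the midpoint, 1.6e-6 at max_steps
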
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import math
13 | from typing import NamedTuple
14 |
15 | import numpy as np
16 | import torch
17 |
18 |
19 | # from gsplat._torch_impl import clip_near_plane, scale_rot_to_cov3d, project_cov3d_ewa, compute_cov2d_bounds, project_pix  # required by project_gaussians_forward below
20 |
21 |
22 | class BasicPointCloud(NamedTuple):
23 | points: np.array
24 | colors: np.array
25 | normals: np.array
26 | times: np.array
27 |
28 |
29 | def geom_transform_points(points, transf_matrix):
30 | P, _ = points.shape
31 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
32 | points_hom = torch.cat([points, ones], dim=1)
33 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
34 |
35 | denom = points_out[..., 3:] + 0.0000001
36 | return (points_out[..., :3] / denom).squeeze(dim=0)
37 |
38 |
39 | def getWorld2View(R, t):
40 | Rt = np.zeros((4, 4))
41 | Rt[:3, :3] = R.transpose()
42 | Rt[:3, 3] = t
43 | Rt[3, 3] = 1.0
44 | return np.float32(Rt)
45 |
46 |
47 | def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
48 | Rt = np.zeros((4, 4))
49 | Rt[:3, :3] = R.transpose()
50 | Rt[:3, 3] = t
51 | Rt[3, 3] = 1.0
52 |
53 | C2W = np.linalg.inv(Rt)
54 | cam_center = C2W[:3, 3]
55 | cam_center = (cam_center + translate) * scale
56 | C2W[:3, 3] = cam_center
57 | Rt = np.linalg.inv(C2W)
58 | return np.float32(Rt)
59 |
60 | def getWorld2View2_torch(R, t, translate=torch.tensor([.0, .0, .0]), scale=1.0):
61 | Rt = torch.cat([R.transpose(0, 1), t.unsqueeze(1)], dim=1)
62 | Rt_fill = torch.tensor([0, 0, 0, 1], dtype=Rt.dtype, device=Rt.device).unsqueeze(0)
63 | Rt = torch.cat([Rt, Rt_fill], dim=0)
64 | return Rt
65 |
66 | def getProjectionMatrix(znear, zfar, fovX, fovY):
67 | tanHalfFovY = math.tan((fovY / 2))
68 | tanHalfFovX = math.tan((fovX / 2))
69 |
70 | top = tanHalfFovY * znear
71 | bottom = -top
72 | right = tanHalfFovX * znear
73 | left = -right
74 |
75 | P = torch.zeros(4, 4)
76 |
77 | z_sign = 1.0
78 |
79 | P[0, 0] = 2.0 * znear / (right - left)
80 | P[1, 1] = 2.0 * znear / (top - bottom)
81 | P[0, 2] = (right + left) / (right - left)
82 | P[1, 2] = (top + bottom) / (top - bottom)
83 | P[3, 2] = z_sign
84 | P[2, 2] = z_sign * zfar / (zfar - znear)
85 | P[2, 3] = -(zfar * znear) / (zfar - znear)
86 | return P
87 |
88 |
89 | def fov2focal(fov, pixels):
90 | return pixels / (2 * math.tan(fov / 2))
91 |
92 |
93 | def focal2fov(focal, pixels):
94 | return 2 * math.atan(pixels / (2 * focal))
95 |
96 |
97 | def apply_rotation(q1, q2):
98 | """
99 | Applies a rotation to a quaternion.
100 |
101 | Parameters:
102 | q1 (Tensor): The original quaternion.
103 | q2 (Tensor): The rotation quaternion to be applied.
104 |
105 | Returns:
106 | Tensor: The resulting quaternion after applying the rotation.
107 | """
108 | # Extract components for readability
109 | w1, x1, y1, z1 = q1
110 | w2, x2, y2, z2 = q2
111 |
112 | # Compute the product of the two quaternions
113 | w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
114 | x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
115 | y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
116 | z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
117 |
118 | # Combine the components into a new quaternion tensor
119 | q3 = torch.tensor([w3, x3, y3, z3])
120 |
121 | # Normalize the resulting quaternion
122 | q3_normalized = q3 / torch.norm(q3)
123 |
124 | return q3_normalized
125 |
126 |
127 | def batch_quaternion_multiply(q1, q2):
128 | """
129 | Multiply batches of quaternions.
130 |
131 | Args:
132 | - q1 (torch.Tensor): A tensor of shape [N, 4] representing the first batch of quaternions.
133 | - q2 (torch.Tensor): A tensor of shape [N, 4] representing the second batch of quaternions.
134 |
135 | Returns:
136 | - torch.Tensor: The resulting batch of quaternions after applying the rotation.
137 | """
138 | # Calculate the product of each quaternion in the batch
139 | w = q1[:, 0] * q2[:, 0] - q1[:, 1] * q2[:, 1] - q1[:, 2] * q2[:, 2] - q1[:, 3] * q2[:, 3]
140 | x = q1[:, 0] * q2[:, 1] + q1[:, 1] * q2[:, 0] + q1[:, 2] * q2[:, 3] - q1[:, 3] * q2[:, 2]
141 | y = q1[:, 0] * q2[:, 2] - q1[:, 1] * q2[:, 3] + q1[:, 2] * q2[:, 0] + q1[:, 3] * q2[:, 1]
142 | z = q1[:, 0] * q2[:, 3] + q1[:, 1] * q2[:, 2] - q1[:, 2] * q2[:, 1] + q1[:, 3] * q2[:, 0]
143 |
144 | # Combine into new quaternions
145 | q3 = torch.stack((w, x, y, z), dim=1)
146 |
147 | # Normalize the quaternions
148 | norm_q3 = q3 / torch.norm(q3, dim=1, keepdim=True)
149 |
150 | return norm_q3
151 |
152 |
153 | def cam2pixel(pts, K):
154 | pixels = torch.matmul(K, pts.unsqueeze(-1)).squeeze(-1)
155 | pixels = pixels[:, :2] / (pixels[:, 2:] + 0.0000001)
156 | return pixels
157 |
158 |
159 | def pts2pixel(pts, cam_info, K):
160 | cam_pts = geom_transform_points(
161 | pts, cam_info.world_view_transform.to(pts.device)
162 | ) #! this is column-wise transformation
163 | # # check
164 | # view_R = torch.tensor(cam_info.R.T, device=pts.device)
165 | # view_T = torch.tensor(cam_info.T, device=pts.device)
166 | # cam_pts = torch.matmul(view_R[None], pts[:, :, None]).squeeze(-1) + view_T[None]
167 |
168 | pixels = cam2pixel(cam_pts, K)
169 | return pixels
170 |
171 |
172 | def project_gaussians_forward(
173 | means3d,
174 | scales,
175 | glob_scale,
176 | quats,
177 | viewmat,
178 | intrins,
179 | img_size,
180 | clip_thresh=0.01,
181 | ):
182 | fx, fy, cx, cy = intrins
183 | tan_fovx = 0.5 * img_size[0] / fx
184 | tan_fovy = 0.5 * img_size[1] / fy
185 | p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
186 | cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
187 | cov2d, compensation = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
188 | xys = project_pix((fx, fy), p_view, (cx, cy))
189 |
190 | return xys, cov2d
191 |
--------------------------------------------------------------------------------
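
A hypothetical round trip through the pinhole helpers above: fov2focal and focal2fov are inverses, and getProjectionMatrix builds a 4x4 perspective projection matrix (assumed import path):

from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix  # assumed import path

width, height, focal = 640, 480, 500.0
fovx, fovy = focal2fov(focal, width), focal2fov(focal, height)
assert abs(fov2focal(fovx, width) - focal) < 1e-6  # round trip recovers the focal length

P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovy)
print(P.shape)  # torch.Size([4, 4]); P[3, 2] == 1, so the projected w equals the view-space depth
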
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 |
14 |
15 | def mse(img1, img2):
16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
17 |
18 |
19 | @torch.no_grad()
20 | def psnr(img1, img2, mask=None):
21 | if mask is not None:
22 | img1 = img1.flatten(1)
23 | img2 = img2.flatten(1)
24 |
25 | mask = mask.flatten(1).repeat(3, 1)
26 | mask = torch.where(mask != 0, True, False)
27 | img1 = img1[mask]
28 | img2 = img2[mask]
29 |
30 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
31 |
32 | else:
33 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
34 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float()))
35 | if mask is not None:
36 | if torch.isinf(psnr).any():
37 | print(mse.mean(), psnr.mean())
38 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float()))
39 | psnr = psnr[~torch.isinf(psnr)]
40 |
41 | return psnr
42 |
--------------------------------------------------------------------------------
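
A hypothetical use of the metrics above (assumed import path):

import torch
from utils.image_utils import mse, psnr  # assumed import path

gt = torch.rand(3, 64, 64)
noisy = (gt + 0.05 * torch.randn_like(gt)).clamp(0, 1)
print(mse(noisy, gt).shape)    # (3, 1): mean squared error per channel
print(psnr(noisy, gt).mean())  # scalar PSNR in dB, roughly 26 dB for noise with std 0.05
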
/utils/loader_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 | def get_stamp_list(dataset, timestamp):
8 | frame_length = int(len(dataset) / len(dataset.dataset.poses))
9 | # print(frame_length)
10 | if timestamp > frame_length:
11 | raise IndexError("input timestamp is bigger than the total number of frames.")
12 | print("select index:", [i * frame_length + timestamp for i in range(len(dataset.dataset.poses))])
13 | return [dataset[i * frame_length + timestamp] for i in range(len(dataset.dataset.poses))]
14 |
15 |
16 | class FineSampler(Sampler):
17 | def __init__(self, dataset):
18 | self.len_dataset = len(dataset)
19 | self.len_pose = len(dataset.dataset.poses)
20 | self.frame_length = int(self.len_dataset / self.len_pose)
21 |
22 | sample_list = []
23 | for i in range(self.frame_length):
24 | for j in range(4):
25 | idx = torch.randperm(self.len_pose) * self.frame_length + i
26 | # print(idx)
27 | # breakpoint()
28 | now_list = []
29 | cnt = 0
30 | for item in idx.tolist():
31 | now_list.append(item)
32 | cnt += 1
33 | if cnt % 2 == 0 and len(sample_list) > 2:
34 | select_element = [x for x in random.sample(sample_list, 2)]
35 | now_list += select_element
36 |
37 | sample_list += now_list
38 |
39 | self.sample_list = sample_list
40 | # print(self.sample_list)
41 | # breakpoint()
42 |         print("one epoch contains:", len(self.sample_list))
43 |
44 | def __iter__(self):
45 | return iter(self.sample_list)
46 |
47 | def __len__(self):
48 | return len(self.sample_list)
49 |
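50 | # Minimal usage sketch (illustrative only). FineSampler assumes the dataset is a
51 | # Subset-like wrapper whose underlying dataset exposes a `poses` list, one entry per
52 | # training camera; `train_dataset` below is a hypothetical name:
53 | #
54 | #     sampler = FineSampler(train_dataset)
55 | #     loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, sampler=sampler)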
--------------------------------------------------------------------------------
/utils/main_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from matplotlib import cm
7 | from PIL import Image
8 |
9 |
10 | __all__ = (
11 | "save_debug_imgs",
12 | "get_normals",
13 | "sw_cams",
14 | "sw_depth_normalization",
15 | "error_to_prob",
16 | "get_rays",
17 | "get_gs_mask",
18 | "get_pixels",
19 | )
20 |
21 | def to8b(x):
22 | return (255 * np.clip(x.cpu().numpy(), 0, 1)).astype(np.uint8)
23 |
24 | def get_gs_mask(s_image_tensor, gt_image_tensor, s_depth_tensor, depth_tensor, CVD):
25 | B, C, H, W = s_image_tensor.shape
26 |
27 | # Color based
28 | gs_error = torch.mean(torch.abs(s_image_tensor - gt_image_tensor), 1, True)
29 | gs_mask_c = error_to_prob(gs_error.detach())
30 |
31 | # Depth based
32 | gs_mask_d = error_to_prob(torch.mean(torch.abs(s_depth_tensor - depth_tensor), 1, True).detach())
33 | norm_disp = 1 / (CVD + 1e-7)
34 | norm_disp = (norm_disp + F.max_pool2d(-norm_disp, kernel_size=(H, W))) / (
35 | F.max_pool2d(norm_disp, kernel_size=(H, W)) + F.max_pool2d(-norm_disp, kernel_size=(H, W))
36 | )
37 | gs_mask_d = 1 - norm_disp * (1 - gs_mask_d)
38 |
39 | return gs_mask_c.detach(), gs_mask_d.detach()
40 |
41 |
42 | def get_pixels(image_size_x, image_size_y, use_center=None):
43 |     """Return per-pixel coordinates, offset to pixel centers if use_center is truthy, otherwise at pixel corners."""
44 | xx, yy = np.meshgrid(
45 | np.arange(image_size_x, dtype=np.float32),
46 | np.arange(image_size_y, dtype=np.float32),
47 | )
48 | offset = 0.5 if use_center else 0
49 | return np.stack([xx, yy], axis=-1) + offset
50 |
51 |
52 | def error_to_prob(error, mask=None, mean_prob=0.5):
53 | if mask is None:
54 | mean_err = torch.mean(error, dim=(3, 2, 1)) + 1e-7
55 | else:
56 | mean_err = torch.sum(mask * error, dim=(3, 2)) / (torch.sum(mask, dim=(3, 2)) + 1e-7) + 1e-7
57 | prob = mean_prob * (error / mean_err.view(error.shape[0], 1, 1, 1))
58 | prob[prob > 1] = 1
59 | prob = 1 - prob
60 | return prob
61 |
62 |
63 | def get_rays(H, W, K, c2w):
64 | i, j = torch.meshgrid(
65 | torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H)
66 | ) # pytorch's meshgrid has indexing='ij'
67 | i = i.t().type_as(K)
68 | j = j.t().type_as(K)
69 | dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
70 | # Rotate ray directions from camera frame to the world frame
71 | rays_d = torch.sum(
72 | dirs[..., np.newaxis, :] * c2w[:3, :3], -1
73 | ) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
74 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
75 | rays_o = c2w[:3, -1].expand(rays_d.shape)
76 | return rays_o, rays_d
77 |
78 |
79 | def flow2rgb(flow_map, max_value):
80 | _, h, w = flow_map.shape
81 | flow_map[:, (flow_map[0] == 0) & (flow_map[1] == 0)] = float("nan")
82 | rgb_map = np.ones((3, h, w)).astype(np.float32)
83 | if max_value is not None:
84 | normalized_flow_map = flow_map / max_value
85 | else:
86 | normalized_flow_map = flow_map / (np.abs(flow_map).max())
87 | rgb_map[0, :, :] += normalized_flow_map[0, :, :]
88 | rgb_map[1, :, :] -= 0.5 * (normalized_flow_map[0, :, :] + normalized_flow_map[1, :, :])
89 | rgb_map[2, :, :] += normalized_flow_map[1, :, :]
90 | return rgb_map.clip(0, 1)
91 |
92 |
93 | def save_debug_imgs(debug_dict, idx, epoch=0, deb_path=None, ext="jpg"):
94 | new_outputs = {}
95 | for key in debug_dict.keys():
96 | out_tensor = debug_dict[key].detach()
97 | if out_tensor.shape[1] == 3:
98 | if "normal" in key:
99 | p_im = (out_tensor[idx].squeeze().cpu().numpy() + 1) / 2
100 | p_im[p_im > 1] = 1
101 | p_im[p_im < 0] = 0
102 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8)
103 | else:
104 | p_im = out_tensor[idx].squeeze().cpu().numpy()
105 | p_im[p_im > 1] = 1
106 | p_im[p_im < 0] = 0
107 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8)
108 | elif out_tensor.shape[1] == 2:
109 | p_im = flow2rgb(out_tensor[idx].squeeze().cpu().numpy(), None)
110 | p_im = np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8)
111 | elif out_tensor.shape[1] == 1:
112 | if "disp" in key:
113 | nmap = out_tensor[idx].squeeze().cpu().numpy()
114 | nmap = np.clip(nmap / (np.percentile(nmap, 99) + 1e-6), 0, 1)
115 | p_im = (255 * cm.plasma(nmap)).astype(np.uint8)
116 | elif "error" in key:
117 | nmap = out_tensor[idx].squeeze().cpu().numpy()
118 | nmap = np.clip(nmap, 0, 1)
119 | p_im = (255 * cm.jet(nmap)).astype(np.uint8)
120 | else:
121 | B, C, H, W = out_tensor.shape
122 | the_max = torch.max_pool2d(out_tensor, kernel_size=(H, W))
123 | nmap = out_tensor / the_max
124 | p_im = nmap[idx].squeeze().cpu().numpy()
125 | p_im = np.rint(255 * p_im).astype(np.uint8)
126 |
127 | # Save or return normalized image
128 | if deb_path is not None:
129 | if len(p_im.shape) == 3:
130 | p_im = p_im[:, :, 0:3]
131 | im = Image.fromarray(p_im)
132 | im.save(os.path.join(deb_path, "e{}_{}.{}".format(epoch, key, ext)))
133 | else:
134 | new_outputs[key] = p_im
135 |
136 | if deb_path is None:
137 | return new_outputs
138 |
139 |
140 | def get_normals(z, camera_metadata):
141 | pixels = camera_metadata.get_pixels()
142 | y = (pixels[..., 1] - camera_metadata.principal_point_y) / camera_metadata.scale_factor_y
143 | x = (
144 | pixels[..., 0] - camera_metadata.principal_point_x - y * camera_metadata.skew
145 | ) / camera_metadata.scale_factor_x
146 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1)
147 | viewdirs = torch.from_numpy(viewdirs).to(z.device)
148 |
149 | coords = viewdirs[None] * z[..., None]
150 | coords = coords.permute(0, 3, 1, 2)
151 |
152 | dxdu = coords[..., 0, :, 1:] - coords[..., 0, :, :-1]
153 | dydu = coords[..., 1, :, 1:] - coords[..., 1, :, :-1]
154 | dzdu = coords[..., 2, :, 1:] - coords[..., 2, :, :-1]
155 | dxdv = coords[..., 0, 1:, :] - coords[..., 0, :-1, :]
156 | dydv = coords[..., 1, 1:, :] - coords[..., 1, :-1, :]
157 | dzdv = coords[..., 2, 1:, :] - coords[..., 2, :-1, :]
158 |
159 | dxdu = torch.nn.functional.pad(dxdu, (0, 1), mode="replicate")
160 | dydu = torch.nn.functional.pad(dydu, (0, 1), mode="replicate")
161 | dzdu = torch.nn.functional.pad(dzdu, (0, 1), mode="replicate")
162 |
163 | dxdv = torch.cat([dxdv, dxdv[..., -1:, :]], dim=-2)
164 | dydv = torch.cat([dydv, dydv[..., -1:, :]], dim=-2)
165 | dzdv = torch.cat([dzdv, dzdv[..., -1:, :]], dim=-2)
166 |
167 | n_x = dydv * dzdu - dydu * dzdv
168 | n_y = dzdv * dxdu - dzdu * dxdv
169 | n_z = dxdv * dydu - dxdu * dydv
170 |
171 | pred_normal = torch.stack([n_x, n_y, n_z], dim=-3)
172 | pred_normal = torch.nn.functional.normalize(pred_normal, dim=-3)
173 | return pred_normal
174 |
175 | # coords = coords.squeeze(0)
176 | # hd, wd, _ = coords.shape
177 | # bottom_point = coords[..., 2:hd, 1 : wd - 1, :]
178 | # top_point = coords[..., 0 : hd - 2, 1 : wd - 1, :]
179 | # right_point = coords[..., 1 : hd - 1, 2:wd, :]
180 | # left_point = coords[..., 1 : hd - 1, 0 : wd - 2, :]
181 | # left_to_right = right_point - left_point
182 | # bottom_to_top = top_point - bottom_point
183 | # xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1)
184 | # xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1)
185 | # pred_normal = torch.nn.functional.pad(xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode="constant")
186 | # return pred_normal[None]
187 |
188 |
189 | def sw_cams(viewpoint_stack, cam_id, sw_size=2):
190 | viewpoint_cams_window = [viewpoint_stack[cam_id]]
191 | for sw in range(1, sw_size + 1):
192 | if cam_id - sw >= 0:
193 | viewpoint_cams_window.append(viewpoint_stack[cam_id - sw])
194 | if cam_id + sw < len(viewpoint_stack):
195 | viewpoint_cams_window.append(viewpoint_stack[cam_id + sw])
196 | return viewpoint_cams_window
197 |
198 |
199 | def sw_depth_normalization(viewpoint_cams_window_list, depth_tensor, batch_size):
200 | for n_batch in range(batch_size):
201 | depth_window = []
202 | for viewpoint_cams_window in viewpoint_cams_window_list[n_batch]:
203 | depth_window.append(viewpoint_cams_window.depth[None].cuda())
204 | depth_window = torch.cat(depth_window, 0)
205 | depth_window_min = torch.min(depth_window).cuda()
206 | depth_window_max = torch.max(depth_window).cuda()
207 | depth_tensor[n_batch] = (depth_tensor[n_batch] - depth_window_min) / (depth_window_max - depth_window_min)
208 | return depth_tensor
209 |
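210 | if __name__ == "__main__":
211 |     # Minimal sketch (illustrative only): generate rays for a toy pinhole camera with
212 |     # an identity camera-to-world pose; K and c2w here are made-up values.
213 |     H, W = 4, 6
214 |     K = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
215 |     c2w = torch.eye(4)
216 |     rays_o, rays_d = get_rays(H, W, K, c2w)
217 |     print(rays_o.shape, rays_d.shape)  # both (H, W, 3)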
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | # Copyright (c) 2023 OPPO
4 |
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 |
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 |
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | import numpy as np
24 | import torch
25 |
26 | # from mmcv.ops import knn
27 | import torch.nn as nn
28 |
29 |
30 | class Sandwich(nn.Module):
31 | def __init__(self, dim, outdim=3, bias=False):
32 | super(Sandwich, self).__init__()
33 |
34 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias) #
35 |
36 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
37 | self.relu = nn.ReLU()
38 |
39 | self.sigmoid = torch.nn.Sigmoid()
40 |
41 | def forward(self, input, rays, time=None):
42 | albedo, spec, timefeature = input.chunk(3, dim=1)
43 |         specular = torch.cat([spec, timefeature, rays], dim=1)  # 3 + 3 + 6 = 12 channels
44 | specular = self.mlp1(specular)
45 | specular = self.relu(specular)
46 | specular = self.mlp2(specular)
47 |
48 | result = albedo + specular
49 | result = self.sigmoid(result)
50 | return result
51 |
52 |
53 | class Sandwichnoact(nn.Module):
54 | def __init__(self, dim, outdim=3, bias=False):
55 | super(Sandwichnoact, self).__init__()
56 |
57 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias)
58 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
59 | self.relu = nn.ReLU()
60 |
61 | def forward(self, input, rays, time=None):
62 | albedo, spec, timefeature = input.chunk(3, dim=1)
63 |         specular = torch.cat([spec, timefeature, rays], dim=1)  # 3 + 3 + 6 = 12 channels
64 | specular = self.mlp1(specular)
65 | specular = self.relu(specular)
66 | specular = self.mlp2(specular)
67 |
68 | result = albedo + specular
69 | result = torch.clamp(result, min=0.0, max=1.0)
70 | return result
71 |
72 |
73 | class Sandwichnoactss(nn.Module):
74 | def __init__(self, dim, outdim=3, bias=False):
75 | super(Sandwichnoactss, self).__init__()
76 |
77 | self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias)
78 | self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
79 |
80 | self.relu = nn.ReLU()
81 |
82 | def forward(self, input, rays, time=None):
83 | albedo, spec, timefeature = input.chunk(3, dim=1)
84 |         specular = torch.cat([spec, timefeature, rays], dim=1)  # 3 + 3 + 6 = 12 channels
85 | specular = self.mlp1(specular)
86 | specular = self.relu(specular)
87 | specular = self.mlp2(specular)
88 |
89 | result = albedo + specular
90 | return result
91 |
92 |
93 | ####### The following RGB decoder also works well but is not used in the paper; it is slower than Sandwich and was inspired by the color shift in HyperReel.
94 | # Remove the sigmoid for the immersive dataset.
95 | class RGBDecoderVRayShift(nn.Module):
96 | def __init__(self, dim, outdim=3, bias=False):
97 | super(RGBDecoderVRayShift, self).__init__()
98 |
99 | self.mlp1 = nn.Conv2d(dim, outdim, kernel_size=1, bias=bias)
100 | self.mlp2 = nn.Conv2d(15, outdim, kernel_size=1, bias=bias)
101 | self.mlp3 = nn.Conv2d(6, outdim, kernel_size=1, bias=bias)
102 | self.sigmoid = torch.nn.Sigmoid()
103 |
104 | self.dwconv1 = nn.Conv2d(9, 9, kernel_size=1, bias=bias)
105 |
106 | def forward(self, input, rays, t=None):
107 | x = self.dwconv1(input) + input
108 | albeado = self.mlp1(x)
109 | specualr = torch.cat([x, rays], dim=1)
110 | specualr = self.mlp2(specualr)
111 |
112 | finalfeature = torch.cat([albeado, specualr], dim=1)
113 | result = self.mlp3(finalfeature)
114 | result = self.sigmoid(result)
115 | return result
116 |
117 | def getcolormodel(rgbfuntion):
118 | if rgbfuntion == "sandwich":
119 | rgbdecoder = Sandwich(9, 3)
120 |
121 | elif rgbfuntion == "sandwichnoact":
122 | rgbdecoder = Sandwichnoact(9, 3)
123 | elif rgbfuntion == "sandwichnoactss":
124 | rgbdecoder = Sandwichnoactss(9, 3)
125 | else:
126 | return None
127 | return rgbdecoder
128 |
129 |
130 | def pix2ndc(v, S):
131 | return (v * 2.0 + 1.0) / S - 1.0
132 |
133 |
134 | def ndc2pix(v, S):
135 | return ((v + 1.0) * S - 1.0) * 0.5
136 |
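137 | if __name__ == "__main__":
138 |     # Minimal sketch (illustrative only): decode per-pixel colours with the Sandwich
139 |     # head. The channel counts are assumptions implied by mlp1's 12-channel input
140 |     # (3 albedo + 3 specular + 3 time-feature channels from `input`, plus 6 ray channels).
141 |     decoder = getcolormodel("sandwich")
142 |     feats = torch.rand(1, 9, 8, 8)  # albedo (3) + specular (3) + time feature (3)
143 |     rays = torch.rand(1, 6, 8, 8)   # per-pixel ray encoding (6 channels assumed)
144 |     print(decoder(feats, rays).shape)  # torch.Size([1, 3, 8, 8])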
--------------------------------------------------------------------------------
/utils/params_utils.py:
--------------------------------------------------------------------------------
1 | def merge_hparams(args, config):
2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"]
3 | for param in params:
4 | if param in config.keys():
5 | for key, value in config[param].items():
6 | if hasattr(args, key):
7 | setattr(args, key, value)
8 |
9 | return args
10 |
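11 | # Minimal usage sketch (illustrative only); `config` is a hypothetical dict such as
12 | # one produced by the per-scene configuration files under `arguments/`:
13 | #
14 | #     config = {"OptimizationParams": {"iterations": 30000}}
15 | #     args = merge_hparams(args, config)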
--------------------------------------------------------------------------------
/utils/point_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import torch
4 | from torch.utils.data import TensorDataset, random_split
5 | from torch_cluster import grid_cluster
6 | from sklearn.neighbors import NearestNeighbors
7 |
8 | def voxel_down_sample_custom(points, voxel_size):
9 |     # Quantize the points onto the voxel grid
10 | voxel_grid = torch.floor(points / voxel_size)
11 |
12 |     # Find the unique voxels and each point's index into them
13 | unique_voxels, inverse_indices = torch.unique(voxel_grid, dim=0, return_inverse=True)
14 |
15 |     # Build a new point cloud in which each point is meant to be the mean of the points in its voxel
16 | new_points = torch.zeros_like(unique_voxels)
17 | new_points_count = torch.zeros(unique_voxels.size(0), dtype=torch.long)
18 | # for i in tqdm(range(points.size(0))):
19 | new_points[inverse_indices] = points
20 | # new_points_count[inverse_indices[i]] += 1
21 | # new_points /= new_points_count.unsqueeze(-1)
22 |
23 | return new_points, inverse_indices
24 |
25 |
26 | def downsample_point_cloud(points, ratio):
27 |     # Wrap the points in a TensorDataset
28 | dataset = TensorDataset(points)
29 |
30 |     # Number of points to keep after downsampling
31 | num_points = len(dataset)
32 | num_downsampled_points = int(num_points * ratio)
33 |
34 |     # Downsample with random_split
35 | downsampled_dataset, _ = random_split(dataset, [num_downsampled_points, num_points - num_downsampled_points])
36 |
37 |     # Collect the indices and coordinates of the downsampled points
38 | indices = torch.tensor([i for i, _ in enumerate(downsampled_dataset)])
39 | downsampled_points = torch.stack([x for x, in downsampled_dataset])
40 |
41 | return indices, downsampled_points
42 |
43 |
44 | def downsample_point_cloud_open3d(points, voxel_size):
45 |     # Voxel-downsample the points with the custom helper
46 |
47 | downsampled_pcd, inverse_indices = voxel_down_sample_custom(points, voxel_size)
48 | downsampled_points = downsampled_pcd
49 |     # The downsampled point matrix
50 |
51 | return torch.tensor(downsampled_points)
52 |
53 |
54 | def downsample_point_cloud_cluster(points, voxel_size):
55 |     # Assign each point to a regular grid cell with grid_cluster
56 | cluster = grid_cluster(points, size=torch.tensor([1, 1, 1]))
57 |
58 |     # The downsampled point matrix
59 | # downsampled_points = np.asarray(downsampled_pcd.points)
60 |
61 | return cluster, points
62 |
63 |
64 | def upsample_point_cloud(points, density_threshold, displacement_scale, iter_pass):
65 |     # Estimate each point's local density
66 | # breakpoint()
67 | try:
68 | nbrs = NearestNeighbors(n_neighbors=2 + iter_pass, algorithm="ball_tree").fit(points)
69 | distances, indices = nbrs.kneighbors(points)
70 | except:
71 | print("no point added")
72 | return points, torch.tensor([]), torch.tensor([]), torch.zeros((points.shape[0]), dtype=torch.bool)
73 |
74 |     # Select the low-density points
75 | low_density_points = points[distances[:, 1] > density_threshold]
76 | low_density_index = distances[:, 1] > density_threshold
77 | low_density_index = torch.from_numpy(low_density_index)
78 |     # Duplicate these points and add random displacements
79 | num_points = low_density_points.shape[0]
80 | displacements = torch.randn(num_points, 3) * displacement_scale
81 | new_points = low_density_points + displacements
82 |     # Return the new point matrix
83 | return points, low_density_points, new_points, low_density_index
84 |
85 |
86 | def visualize_point_cloud(points, low_density_points, new_points):
87 | # 创建一个点云对象
88 | pcd = o3d.geometry.PointCloud()
89 |
90 |     # Add a small offset to the selected points
91 |     low_density_points += 0.01
92 | 
93 |     # Merge all points together
94 |     all_points = np.concatenate([points, low_density_points, new_points], axis=0)
95 |     pcd.points = o3d.utility.Vector3dVector(all_points)
96 | 
97 |     # Build the color array
98 |     colors = np.zeros((all_points.shape[0], 3))
99 |     colors[: points.shape[0]] = [0, 0, 0]  # black: initial points
100 |     colors[points.shape[0] : points.shape[0] + low_density_points.shape[0]] = [1, 0, 0]  # red: selected points
101 |     colors[points.shape[0] + low_density_points.shape[0] :] = [0, 1, 0]  # green: newly grown points
102 |     pcd.colors = o3d.utility.Vector3dVector(colors)
103 | 
104 |     # Display the point cloud
105 | o3d.visualization.draw_geometries([pcd])
106 |
107 |
108 | def combine_pointcloud(points, low_density_points, new_points):
109 | pcd = o3d.geometry.PointCloud()
110 |
111 |     # Add a small offset to the selected points
112 |     low_density_points += 0.01
113 |     new_points -= 0.01
114 |     # Merge all points together
115 |     all_points = np.concatenate([points, low_density_points, new_points], axis=0)
116 |     pcd.points = o3d.utility.Vector3dVector(all_points)
117 | 
118 |     # Build the color array
119 |     colors = np.zeros((all_points.shape[0], 3))
120 |     colors[: points.shape[0]] = [0, 0, 0]  # black: initial points
121 |     colors[points.shape[0] : points.shape[0] + low_density_points.shape[0]] = [1, 0, 0]  # red: selected points
122 |     colors[points.shape[0] + low_density_points.shape[0] :] = [0, 1, 0]  # green: newly grown points
123 | pcd.colors = o3d.utility.Vector3dVector(colors)
124 | return pcd
125 |
126 |
127 | def addpoint(
128 | point_cloud,
129 | density_threshold,
130 | displacement_scale,
131 | iter_pass,
132 | ):
133 |     # density_threshold: density threshold; larger values select sparser points.
134 |     # displacement_scale: new points are randomly generated within a ball of radius displacement_scale around each selected point.
135 |
136 | points, low_density_points, new_points, low_density_index = upsample_point_cloud(
137 | point_cloud, density_threshold, displacement_scale, iter_pass
138 | )
139 | # breakpoint()
140 | # breakpoint()
141 | print("low_density_points", low_density_points.shape[0])
142 |
143 | return point_cloud, low_density_points, new_points, low_density_index
144 |
145 |
146 | def find_point_indices(origin_point, goal_point):
147 | indices = torch.nonzero((origin_point[:, None] == goal_point).all(-1), as_tuple=True)[0]
148 | return indices
149 |
150 |
151 | def find_indices_in_A(A, B):
152 |     """
153 |     Find, for every point in the subset matrix B, its index u in the point-cloud matrix A.
154 | 
155 |     Args:
156 |         A (torch.Tensor): point-cloud matrix A of shape [N, 3].
157 |         B (torch.Tensor): subset matrix B of shape [M, 3].
158 | 
159 |     Returns:
160 |         torch.Tensor: tensor of shape (M,) with the index u in A of every point of B.
161 |     """
162 | is_equal = torch.eq(B.view(1, -1, 3), A.view(-1, 1, 3))
163 | u_indices = torch.nonzero(is_equal, as_tuple=False)[:, 0]
164 | return torch.unique(u_indices)
165 |
166 |
167 | if __name__ == "__main__":
168 | #
169 | from time import time
170 |
171 | pass_ = 0
172 | # filename=f"pointcloud/pass_{pass_}.ply"
173 | filename = "point_cloud.ply"
174 | pcd = o3d.io.read_point_cloud(filename)
175 | point_cloud = torch.tensor(pcd.points)
176 | voxel_size = 8
177 | density_threshold = 20
178 | displacement_scale = 5
179 | for i in range(pass_ + 1, 50):
180 | print("pass ", i)
181 | time0 = time()
182 |
183 | point_downsample = point_cloud
184 | flag = False
185 | while point_downsample.shape[0] > 1000:
186 | if flag:
187 | voxel_size += 8
188 | print("point size:", point_downsample.shape[0])
189 | point_downsample = downsample_point_cloud_open3d(point_cloud, voxel_size=voxel_size)
190 | flag = True
191 |
192 | print("point size:", point_downsample.shape[0])
193 | # downsampled_point_index = find_point_indices(point_cloud, point_downsample)
194 | downsampled_point_index = find_indices_in_A(point_cloud, point_downsample)
195 | print("selected_num", point_cloud[downsampled_point_index].shape[0])
196 | _, low_density_points, new_points, low_density_index = addpoint(
197 | point_cloud[downsampled_point_index],
198 | density_threshold=density_threshold,
199 | displacement_scale=displacement_scale,
200 | iter_pass=0,
201 | )
202 | if new_points.shape[0] < 100:
203 | density_threshold /= 2
204 | displacement_scale /= 2
205 |             print("reduce displacement_scale to: ", displacement_scale)
206 |
207 | global_mask = torch.zeros((point_cloud.shape[0]), dtype=torch.bool)
208 |
209 | global_mask[downsampled_point_index] = low_density_index
210 | time1 = time()
211 |
212 | print("time cost:", time1 - time0, "new_points:", new_points.shape[0])
213 | if low_density_points.shape[0] == 0:
214 | print("no more points.")
215 | continue
216 | # breakpoint()
217 | point = combine_pointcloud(point_cloud, low_density_points, new_points)
218 | point_cloud = torch.tensor(point.points)
219 | o3d.io.write_point_cloud(f"pointcloud/pass_{i}.ply", point)
220 |         # visualize_point_cloud(point_cloud, low_density_points, new_points)
221 |
--------------------------------------------------------------------------------
/utils/pose_utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 |
6 |
7 | def rotation_matrix_to_quaternion(rotation_matrix):
8 |     """Convert a rotation matrix to a quaternion."""
9 | return R.from_matrix(rotation_matrix).as_quat()
10 |
11 |
12 | def quaternion_to_rotation_matrix(quat):
13 |     """Convert a quaternion to a rotation matrix."""
14 | return R.from_quat(quat).as_matrix()
15 |
16 |
17 | def quaternion_slerp(q1, q2, t):
18 |     """Spherical linear interpolation (SLERP) between two quaternions."""
19 |     # Dot product between the two quaternions
20 | dot = np.dot(q1, q2)
21 |
22 |     # If the dot product is negative, negate one quaternion so the interpolation follows the shorter arc
23 | if dot < 0.0:
24 | q1 = -q1
25 | dot = -dot
26 |
27 |     # Clamp to guard against numerical error
28 | dot = np.clip(dot, -1.0, 1.0)
29 |
30 |     # Interpolation parameters
31 | theta = np.arccos(dot) * t
32 | q3 = q2 - q1 * dot
33 | q3 = q3 / np.linalg.norm(q3)
34 |
35 |     # Interpolated quaternion
36 | return np.cos(theta) * q1 + np.sin(theta) * q3
37 |
38 |
39 | def bezier_interpolation(p1, p2, t):
40 |     """Interpolate between two points along a Bezier curve."""
41 | return (1 - t) * p1 + t * p2
42 |
43 |
44 | def linear_interpolation(v1, v2, t):
45 |     """Linear interpolation."""
46 | return (1 - t) * v1 + t * v2
47 |
48 |
49 | def smooth_camera_poses(cameras, num_interpolations=5):
50 |     """Smooth a sequence of camera poses by inserting extra poses between each consecutive pair."""
51 | smoothed_cameras = []
52 | smoothed_times = []
53 | total_poses = len(cameras) - 1 + (len(cameras) - 1) * num_interpolations
54 | time_increment = 10 / total_poses
55 |
56 | for i in range(len(cameras) - 1):
57 | cam1 = cameras[i]
58 | cam2 = cameras[i + 1]
59 |
60 |         # Convert the rotation matrices to quaternions
61 | quat1 = rotation_matrix_to_quaternion(cam1.orientation)
62 | quat2 = rotation_matrix_to_quaternion(cam2.orientation)
63 |
64 | for j in range(num_interpolations + 1):
65 | t = j / (num_interpolations + 1)
66 |
67 |             # Interpolate the orientation
68 | interp_orientation_quat = quaternion_slerp(quat1, quat2, t)
69 | interp_orientation_matrix = quaternion_to_rotation_matrix(interp_orientation_quat)
70 |
71 |             # Interpolate the position
72 | interp_position = linear_interpolation(cam1.position, cam2.position, t)
73 |
74 |             # Interpolated timestamp
75 | interp_time = i * 10 / (len(cameras) - 1) + time_increment * j
76 |
77 |             # Append the new camera pose and timestamp
78 | newcam = deepcopy(cam1)
79 | newcam.orientation = interp_orientation_matrix
80 | newcam.position = interp_position
81 | smoothed_cameras.append(newcam)
82 | smoothed_times.append(interp_time)
83 |
84 |     # Append the last original pose and timestamp
85 | smoothed_cameras.append(cameras[-1])
86 | smoothed_times.append(1.0)
87 | print(smoothed_times)
88 | return smoothed_cameras, smoothed_times
89 |
90 |
91 | # # Example: two camera poses
92 | # cam1 = Camera(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), np.array([0, 0, 0]))
93 | # cam2 = Camera(np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]), np.array([1, 1, 1]))
94 |
95 | # # Apply the smoothing
96 | # smoothed_cameras = smooth_camera_poses([cam1, cam2], num_interpolations=5)
97 |
98 | # # Print the results
99 | # for cam in smoothed_cameras:
100 | # print("Orientation:\n", cam.orientation)
101 | # print("Position:", cam.position)
102 |
--------------------------------------------------------------------------------
/utils/render_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.no_grad()
5 | def get_state_at_time(pc, viewpoint_camera):
6 | means3D = pc.get_xyz
7 | time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0], 1)
8 | opacity = pc._opacity
9 | shs = pc.get_features
10 |
11 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
12 | # scaling / rotation by the rasterizer.
13 | scales = pc._scaling
14 | rotations = pc._rotation
15 |
16 | # time0 = get_time()
17 | # means3D_deform, scales_deform, rotations_deform, opacity_deform = pc._deformation(means3D[deformation_point], scales[deformation_point],
18 | # rotations[deformation_point], opacity[deformation_point],
19 | # time[deformation_point])
20 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = pc._deformation(
21 | means3D, scales, rotations, opacity, shs, time
22 | )
23 | scales_final = pc.scaling_activation(scales_final)
24 | rotations_final = pc.rotation_activation(rotations_final)
25 | opacity = pc.opacity_activation(opacity_final)
26 | return means3D_final, scales_final, rotations_final, opacity, shs
27 |
--------------------------------------------------------------------------------
/utils/scene_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from matplotlib import pyplot as plt
5 | from PIL import Image, ImageDraw, ImageFont
6 |
7 |
8 | plt.rcParams["font.sans-serif"] = ["Times New Roman"]
9 |
10 |
11 | import numpy as np
12 | from scene import deformation
13 |
14 |
15 | @torch.no_grad()
16 | def render_training_image(
17 | scene,
18 | stat_gaussians,
19 | dyn_gaussians,
20 | viewpoints,
21 | render_func,
22 | pipe,
23 | background,
24 | stage,
25 | iteration,
26 | time_now,
27 | dataset_type,
28 | is_train=False,
29 | over_t=None,
30 | ):
31 | # Get base parameters for warping
32 | if stage == "warm":
33 | viewpoint = viewpoints[0]
34 | gt_normal = viewpoint.normal[None].cuda()
35 | gt_normal.reshape(-1, 3)
36 | pixels = viewpoint.metadata.get_pixels(normalize=True)
37 | pixels = torch.from_numpy(pixels).cuda()
38 | pixels.reshape(-1, 2)
39 | gt_depth = viewpoint.depth[None].cuda()
40 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[2::])
41 | depth_in = gt_depth.reshape(-1, 1)
42 | time_in_0 = torch.tensor(0).float().cuda()
43 | time_in_0 = time_in_0.view(1, 1)
44 | # pred_R0, pred_T0, depth_scale_T0, depth_shift_T0 = dyn_gaussians._posenet(time_in_0, depth=depth_in, normals=normal_in, pixel=pixels_in)
45 | pred_R0, pred_T0, CVD0 = dyn_gaussians._posenet(time_in_0, depth=depth_in)
46 |
47 | def render(static_gaussians, dynamic_gaussians, viewpoint, path, scaling, cam_type, over_t=None):
48 | if stage == "warm":
49 | # Get warped images (all warped to time step 0)
50 | if dataset_type == "PanopticSports":
51 | image_tensor = viewpoint["image"][None].cuda()
52 | else:
53 | image_tensor = viewpoint.original_image[None].cuda()
54 | B, C, H, W = image_tensor.shape
55 |
56 | viewpoint.normal[None].cuda().reshape(-1, 3)
57 | pixels = viewpoint.metadata.get_pixels(normalize=True)
58 | torch.from_numpy(pixels).cuda().reshape(-1, 2)
59 | gt_depth = viewpoint.depth[None].cuda()
60 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[2::])
61 | depth_in = gt_depth.reshape(-1, 1)
62 | time_in = torch.tensor(viewpoint.time).float().cuda()
63 | time_in = time_in.view(1, 1)
64 | # CVD = depth_in.view(1, 1, H, W)
65 |
66 | # CVD = CVD_T0.view(1, 1, H, W)
67 | # pred_R, pred_T, _, _ = dynamic_gaussians._posenet(time_in, depth=depth_in, normals=normal_in, pixel=pixels_in)
68 | pred_R, pred_T, CVD = dynamic_gaussians._posenet(time_in, depth=depth_in)
69 |
70 | # depth_in = depth_in.view(1, 1, H, W)
71 | # depth_shift = depth_shift_T0.view(1, 1, H, W)
72 | # depth_scale = depth_scale_T0.view(1, 1, 1, 1)
73 | # CVD = depth_scale * depth_in + depth_shift
74 |
75 | K_tensor = torch.zeros(1, 3, 3).type_as(image_tensor)
76 | K_tensor[:, 0, 0] = float(viewpoint.metadata.scale_factor_x)
77 | K_tensor[:, 1, 1] = float(viewpoint.metadata.scale_factor_y)
78 | K_tensor[:, 0, 2] = float(viewpoint.metadata.principal_point_x)
79 | K_tensor[:, 1, 2] = float(viewpoint.metadata.principal_point_y)
80 | K_tensor[:, 2, 2] = float(1)
81 |
82 | w2c_target = torch.cat((pred_R0, pred_T0[:, :, None]), -1)
83 | w2c_prev = torch.cat((pred_R, pred_T[:, :, None]), -1)
84 | warped_img = deformation.inverse_warp_rt1_rt2(
85 | image_tensor, CVD, w2c_target, w2c_prev, K_tensor, torch.inverse(K_tensor)
86 | )
87 |
88 | p_im = warped_img.detach().squeeze().cpu().numpy()
89 | im = Image.fromarray(np.rint(255 * p_im.transpose(1, 2, 0)).astype(np.uint8))
90 | im.save(path.replace(".jpg", "_warped.jpg"))
91 | return
92 |
93 | # scaling_copy = gaussians._scaling
94 | render_pkg = render_func(
95 | viewpoint, static_gaussians, dynamic_gaussians, background, get_static=True, get_dynamic=True
96 | )
97 |
98 | label1 = f"stage:{stage},iter:{iteration}"
99 | times = time_now / 60
100 | if times < 1:
101 | end = "min"
102 | else:
103 | end = "mins"
104 | label2 = "time:%.2f" % times + end
105 |
106 | image = render_pkg["render"]
107 |
108 | d_image = render_pkg["d_render"]
109 | s_image = render_pkg["s_render"]
110 |
111 | d_alpha = render_pkg["d_alpha"]
112 |
113 | depth = render_pkg["depth"]
114 | st_depth = render_pkg["s_depth"]
115 |
116 | z = depth + 1e-6
117 | camera_metadata = viewpoint.metadata
118 | pixels = camera_metadata.get_pixels()
119 | y = (
120 | pixels[..., 1] - camera_metadata.principal_point_y
121 | ) / dynamic_gaussians._posenet.focal_bias.exp().detach().cpu().numpy()
122 | x = (
123 | pixels[..., 0] - camera_metadata.principal_point_x - y * camera_metadata.skew
124 | ) / dynamic_gaussians._posenet.focal_bias.exp().detach().cpu().numpy()
125 | viewdirs = np.stack([x, y, np.ones_like(x)], axis=-1)
126 | viewdirs = torch.from_numpy(viewdirs).to(z.device)
127 |
128 | coords = viewdirs[None] * z[..., None]
129 | coords = coords.permute(0, 3, 1, 2)
130 |
131 | dxdu = coords[..., 0, :, 1:] - coords[..., 0, :, :-1]
132 | dydu = coords[..., 1, :, 1:] - coords[..., 1, :, :-1]
133 | dzdu = coords[..., 2, :, 1:] - coords[..., 2, :, :-1]
134 | dxdv = coords[..., 0, 1:, :] - coords[..., 0, :-1, :]
135 | dydv = coords[..., 1, 1:, :] - coords[..., 1, :-1, :]
136 | dzdv = coords[..., 2, 1:, :] - coords[..., 2, :-1, :]
137 |
138 | dxdu = torch.nn.functional.pad(dxdu, (0, 1), mode="replicate")
139 | dydu = torch.nn.functional.pad(dydu, (0, 1), mode="replicate")
140 | dzdu = torch.nn.functional.pad(dzdu, (0, 1), mode="replicate")
141 |
142 | dxdv = torch.cat([dxdv, dxdv[..., -1:, :]], dim=-2)
143 | dydv = torch.cat([dydv, dydv[..., -1:, :]], dim=-2)
144 | dzdv = torch.cat([dzdv, dzdv[..., -1:, :]], dim=-2)
145 |
146 | n_x = dydv * dzdu - dydu * dzdv
147 | n_y = dzdv * dxdu - dzdu * dxdv
148 | n_z = dxdv * dydu - dxdu * dydv
149 |
150 | pred_normal = torch.stack([n_x, n_y, n_z], dim=-3)
151 | pred_normal = torch.nn.functional.normalize(pred_normal, dim=-3)
152 |
153 | if dataset_type == "PanopticSports":
154 | gt_np = viewpoint["image"].permute(1, 2, 0).cpu().numpy()
155 | else:
156 | gt_np = viewpoint.original_image.permute(1, 2, 0).cpu().numpy()
157 | image_np = image.permute(1, 2, 0).cpu().numpy() # (H, W, 3)
158 |
159 | d_image_np = d_image.permute(1, 2, 0).cpu().numpy()
160 | s_image_np = s_image.permute(1, 2, 0).cpu().numpy()
161 |
162 | d_alpha_np = d_alpha.permute(1, 2, 0).cpu().numpy()
163 | d_alpha_np = np.repeat(d_alpha_np, 3, axis=2)
164 |
165 | depth_np = depth.permute(1, 2, 0).cpu().numpy()
166 | depth_np /= depth_np.max()
167 | depth_np = np.repeat(depth_np, 3, axis=2)
168 |
169 | st_depth_np = st_depth.permute(1, 2, 0).cpu().numpy()
170 |         st_depth_np /= st_depth_np.max()
171 | st_depth_np = np.repeat(st_depth_np, 3, axis=2)
172 |
173 | pred_normal_np = (pred_normal[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
174 |
175 | error = (image_np - gt_np) ** 2
176 | error_np = (error - np.min(error)) / (max(np.max(error) - np.min(error), 1e-8))
177 |
178 | if is_train:
179 | gt_normal = (viewpoint.normal.cuda() + 1) / 2
180 | gt_normal_np = gt_normal.permute(1, 2, 0).cpu().numpy()
181 |
182 | gt_depth = viewpoint.depth.cuda()
183 | # gt_depth = gt_depth / F.avg_pool2d(gt_depth.detach(), kernel_size=gt_depth.shape[1::])
184 | gt_depth_np = gt_depth.permute(1, 2, 0).cpu().numpy()
185 | gt_depth_np /= gt_depth_np.max()
186 | gt_depth_np = np.repeat(gt_depth_np, 3, axis=2)
187 |
188 | decomp_image_np = np.concatenate((gt_normal_np, pred_normal_np, gt_depth_np, depth_np), axis=1)
189 |
190 | mask_np = viewpoint.mask.permute(1, 2, 0).cpu().numpy()
191 | mask_np = np.repeat(mask_np, 3, axis=2)
192 | image_np = np.concatenate((gt_np, image_np, mask_np, d_alpha_np, d_image_np, s_image_np), axis=1)
193 | else:
194 | decomp_image_np = np.concatenate((pred_normal_np, depth_np), axis=1)
195 | image_np = np.concatenate((gt_np, image_np, error_np, d_alpha_np, d_image_np, s_image_np), axis=1)
196 |
197 |         image_with_labels = Image.fromarray((np.clip(image_np, 0, 1) * 255).astype("uint8"))  # convert to an 8-bit image
198 | decomp_image_with_labels = Image.fromarray((np.clip(decomp_image_np, 0, 1) * 255).astype("uint8"))
199 |         # Create a drawing object to add labels to the image
200 | draw1 = ImageDraw.Draw(image_with_labels)
201 |
202 |         # Choose the font and font size
203 |         font = ImageFont.truetype("./utils/TIMES.TTF", size=40)  # replace the path with the font file of your choice
204 |
205 |         # Choose the text color
206 |         text_color = (255, 0, 0)  # red
207 |
208 |         # Label positions (top-left corner coordinates)
209 | label1_position = (10, 10)
210 |         label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10)  # near the top-right corner
211 |
212 |         # Draw the labels on the image
213 | draw1.text(label1_position, label1, fill=text_color, font=font)
214 | draw1.text(label2_position, label2, fill=text_color, font=font)
215 |
216 | image_with_labels.save(path)
217 | decomp_image_with_labels.save(path.replace(".jpg", "_decomp.jpg"))
218 |
219 | render_base_path = os.path.join(scene.model_path, f"{stage}_render")
220 | point_cloud_path = os.path.join(render_base_path, "pointclouds")
221 | if is_train:
222 | image_path = os.path.join(render_base_path, "train/images")
223 | else:
224 | image_path = os.path.join(render_base_path, "val/images")
225 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")):
226 | os.makedirs(render_base_path)
227 | if not os.path.exists(point_cloud_path):
228 | os.makedirs(point_cloud_path)
229 | if not os.path.exists(image_path):
230 | os.makedirs(image_path)
231 | # image:3,800,800
232 |
233 | # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg")
234 | for idx in range(len(viewpoints)):
235 | image_save_path = os.path.join(image_path, f"{viewpoints[idx].image_name}.jpg")
236 | render(
237 | stat_gaussians,
238 | dyn_gaussians,
239 | viewpoints[idx],
240 | image_save_path,
241 | scaling=1,
242 | cam_type=dataset_type,
243 | over_t=over_t,
244 | )
245 | # render(gaussians,point_save_path,scaling = 0.1)
246 |     # Save the labeled images
247 |
248 | pc_mask = dyn_gaussians.get_opacity
249 | pc_mask = pc_mask > 0.1
250 | dyn_gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1, 0).numpy()
251 | # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path)
252 |     # If needed, the PIL image can be converted back to a PyTorch tensor
253 | # return image
254 | # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0
255 |
256 |
257 | def visualize_and_save_point_cloud(point_cloud, R, T, filename):
258 |     # Create a 3D scatter plot
259 | fig = plt.figure()
260 | ax = fig.add_subplot(111, projection="3d")
261 | R = R.T
262 |     # Apply the rotation and translation
263 | T = -R.dot(T)
264 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1)
265 | # pcd = o3d.geometry.PointCloud()
266 |     # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T)  # transpose the points to match Open3D's format
267 | # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:]
268 |     # Visualize the point cloud
269 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c="g", marker="o")
270 | ax.axis("off")
271 | # ax.set_xlabel('X Label')
272 | # ax.set_ylabel('Y Label')
273 | # ax.set_zlabel('Z Label')
274 |
275 |     # Save the rendering as an image
276 | plt.savefig(filename)
277 |
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 | C0 = 0.28209479177387814
24 | C1 = 0.4886025119029199
25 | C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396]
26 | C3 = [
27 | -0.5900435899266435,
28 | 2.890611442640554,
29 | -0.4570457994644658,
30 | 0.3731763325901154,
31 | -0.4570457994644658,
32 | 1.445305721320277,
33 | -0.5900435899266435,
34 | ]
35 | C4 = [
36 | 2.5033429417967046,
37 | -1.7701307697799304,
38 | 0.9461746957575601,
39 | -0.6690465435572892,
40 | 0.10578554691520431,
41 | -0.6690465435572892,
42 | 0.47308734787878004,
43 | -1.7701307697799304,
44 | 0.6258357354491761,
45 | ]
46 |
47 |
48 | def eval_sh(deg, sh, dirs):
49 | """
50 | Evaluate spherical harmonics at unit directions
51 | using hardcoded SH polynomials.
52 | Works with torch/np/jnp.
53 | ... Can be 0 or more batch dimensions.
54 | Args:
55 | deg: int SH deg. Currently, 0-3 supported
56 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
57 | dirs: jnp.ndarray unit directions [..., 3]
58 | Returns:
59 | [..., C]
60 | """
61 | assert deg <= 4 and deg >= 0
62 | coeff = (deg + 1) ** 2
63 | assert sh.shape[-1] >= coeff
64 |
65 | result = C0 * sh[..., 0]
66 | if deg > 0:
67 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
68 | result = result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
69 |
70 | if deg > 1:
71 | xx, yy, zz = x * x, y * y, z * z
72 | xy, yz, xz = x * y, y * z, x * z
73 | result = (
74 | result
75 | + C2[0] * xy * sh[..., 4]
76 | + C2[1] * yz * sh[..., 5]
77 | + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
78 | + C2[3] * xz * sh[..., 7]
79 | + C2[4] * (xx - yy) * sh[..., 8]
80 | )
81 |
82 | if deg > 2:
83 | result = (
84 | result
85 | + C3[0] * y * (3 * xx - yy) * sh[..., 9]
86 | + C3[1] * xy * z * sh[..., 10]
87 | + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
88 | + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
89 | + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
90 | + C3[5] * z * (xx - yy) * sh[..., 14]
91 | + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
92 | )
93 |
94 | if deg > 3:
95 | result = (
96 | result
97 | + C4[0] * xy * (xx - yy) * sh[..., 16]
98 | + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
99 | + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
100 | + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
101 | + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
102 | + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
103 | + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
104 | + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
105 | + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]
106 | )
107 | return result
108 |
109 |
110 | def RGB2SH(rgb):
111 | return (rgb - 0.5) / C0
112 |
113 |
114 | def SH2RGB(sh):
115 | return sh * C0 + 0.5
116 |
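117 | if __name__ == "__main__":
118 |     # Minimal sketch (illustrative only): evaluate degree-2 SH for a small batch of
119 |     # random coefficients and unit directions.
120 |     import numpy as np
121 |     sh = np.random.rand(10, 3, 9)  # [batch, C, (deg + 1) ** 2] with deg = 2
122 |     dirs = np.random.randn(10, 3)
123 |     dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True)
124 |     print(eval_sh(2, sh, dirs).shape)  # (10, 3)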
--------------------------------------------------------------------------------
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from errno import EEXIST
14 | from os import makedirs, path
15 |
16 |
17 | def mkdir_p(folder_path):
18 |     # Creates a directory; equivalent to using mkdir -p on the command line
19 | try:
20 | makedirs(folder_path)
21 | except OSError as exc: # Python >2.5
22 | if exc.errno == EEXIST and path.isdir(folder_path):
23 | pass
24 | else:
25 | raise
26 |
27 |
28 | def searchForMaxIteration(folder):
29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
30 | return max(saved_iters)
31 |
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class Timer:
5 | def __init__(self):
6 | self.start_time = None
7 | self.elapsed = 0
8 | self.paused = False
9 |
10 | def start(self):
11 | if self.start_time is None:
12 | self.start_time = time.time()
13 | elif self.paused:
14 | self.start_time = time.time() - self.elapsed
15 | self.paused = False
16 |
17 | def pause(self):
18 | if not self.paused:
19 | self.elapsed = time.time() - self.start_time
20 | self.paused = True
21 |
22 | def get_elapsed_time(self):
23 | if self.paused:
24 | return self.elapsed
25 | else:
26 | return time.time() - self.start_time
27 |
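28 | if __name__ == "__main__":
29 |     # Minimal sketch (illustrative only): time two work phases while excluding a
30 |     # paused section in between.
31 |     timer = Timer()
32 |     timer.start()
33 |     time.sleep(0.1)  # timed work
34 |     timer.pause()
35 |     time.sleep(0.1)  # excluded from the measurement
36 |     timer.start()
37 |     time.sleep(0.1)  # timed work resumes
38 |     print("elapsed: %.2fs" % timer.get_elapsed_time())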
--------------------------------------------------------------------------------