├── .gitignore ├── .gitmodules ├── LICENSE ├── arguments └── __init__.py ├── assets ├── D-NeRF-Results.png ├── Editing_mode.png ├── HKU.png ├── VAST.png ├── ZJU.png ├── badge-website.svg ├── bear.gif ├── bear_editing.gif ├── edited_hook.gif ├── edited_jumpingjacks.gif ├── edited_lego.gif ├── edited_mutant.gif ├── face.gif ├── face_editing.gif ├── family.gif ├── hand.gif ├── horse.gif ├── kitchen.gif ├── person.gif ├── plant.gif └── teaser.png ├── cam_utils.py ├── convert.py ├── data_tools ├── colmap2nerf.py ├── interactive_invoke.py └── phone_catch.py ├── full_eval.py ├── gaussian_renderer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── network_gui.cpython-38.pyc └── network_gui.py ├── lap_deform.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── readme.md ├── render.py ├── requirements.txt ├── scene ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── cameras.cpython-38.pyc │ ├── colmap_loader.cpython-38.pyc │ ├── dataset_readers.cpython-38.pyc │ ├── deform_model.cpython-38.pyc │ └── gaussian_model.cpython-38.pyc ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── deform_model.py └── gaussian_model.py ├── train.py ├── train_gui.py ├── train_gui.sh ├── train_gui_utils.py └── utils ├── arap_deform.py ├── bezier.py ├── camera_utils.py ├── deform_utils.py ├── dual_quaternion.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── interactive_utils.py ├── loss_utils.py ├── other_utils.py ├── pickle_utils.py ├── pose_utils.py ├── preprocess.py ├── rigid_utils.py ├── sh_utils.py ├── system_utils.py ├── time_utils.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/* 2 | .vscode/* 3 | __pycache__/* 4 | __pycache__/ 5 | *.pyc 6 | *.sh -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/ashawkey/diff-gaussian-rasterization 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yihua Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | # if shorthand: 32 | # if t == bool: 33 | # group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | # else: 35 | # group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | # else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | 50 | class ModelParams(ParamGroup): 51 | def __init__(self, parser, sentinel=False): 52 | self.sh_degree = 3 53 | self.K = 3 54 | self._source_path = "" 55 | self._model_path = "" 56 | self._images = "images" 57 | self._resolution = -1 58 | self._white_background = False 59 | self.data_device = "cuda" 60 | self.eval = False 61 | self.load2gpu_on_the_fly = False 62 | self.is_blender = False 63 | self.deform_type = 'node' 64 | self.skinning = False 65 | self.hyper_dim = 8 66 | self.node_num = 1024 67 | self.pred_opacity = False 68 | self.pred_color = False 69 | self.use_hash = False 70 | self.hash_time = False 71 | self.d_rot_as_rotmat = False # Debug!!! 72 | self.d_rot_as_res = True # Debug!!! 73 | self.local_frame = False 74 | self.progressive_brand_time = False 75 | self.gs_with_motion_mask = False 76 | self.init_isotropic_gs_with_all_colmap_pcl = False 77 | self.as_gs_force_with_motion_mask = False # Only for scenes with both static and dynamic parts and without alpha mask 78 | self.max_d_scale = -1. 
79 | self.is_scene_static = False 80 | super().__init__(parser, "Loading Parameters", sentinel) 81 | 82 | def extract(self, args): 83 | g = super().extract(args) 84 | g.source_path = os.path.abspath(g.source_path) 85 | if not g.model_path.endswith(g.deform_type): 86 | g.model_path = os.path.join(os.path.dirname(os.path.normpath(g.model_path)), os.path.basename(os.path.normpath(g.model_path)) + f'_{g.deform_type}') 87 | return g 88 | 89 | 90 | class PipelineParams(ParamGroup): 91 | def __init__(self, parser): 92 | self.convert_SHs_python = False 93 | self.compute_cov3D_python = False 94 | self.debug = False 95 | super().__init__(parser, "Pipeline Parameters") 96 | 97 | 98 | class OptimizationParams(ParamGroup): 99 | def __init__(self, parser): 100 | self.iterations = 80_000 101 | self.warm_up = 3_000 102 | self.dynamic_color_warm_up = 20_000 103 | self.position_lr_init = 0.00016 104 | self.position_lr_final = 0.0000016 105 | self.position_lr_delay_mult = 0.01 106 | self.position_lr_max_steps = 30_000 107 | self.deform_lr_max_steps = 40_000 108 | self.feature_lr = 0.0025 109 | self.opacity_lr = 0.05 110 | self.scaling_lr = 0.001 111 | self.rotation_lr = 0.001 112 | self.percent_dense = 0.01 113 | self.lambda_dssim = 0.2 114 | self.densification_interval = 100 115 | self.opacity_reset_interval = 3000 116 | self.densify_from_iter = 500 117 | self.densify_until_iter = 50_000 118 | self.densify_grad_threshold = 0.0002 119 | self.oneupSHdegree_step = 1000 120 | self.random_bg_color = False 121 | 122 | self.deform_lr_scale = 1. 123 | self.deform_downsamp_strategy = 'samp_hyper' 124 | self.deform_downsamp_with_dynamic_mask = False 125 | self.node_enable_densify_prune = False 126 | self.node_densification_interval = 5000 127 | self.node_densify_from_iter = 1000 128 | self.node_densify_until_iter = 25_000 129 | self.node_force_densify_prune_step = 10_000 130 | self.node_max_num_ratio_during_init = 16 131 | 132 | self.random_init_deform_gs = False 133 | self.node_warm_up = 2_000 134 | self.iterations_node_sampling = 7500 135 | self.iterations_node_rendering = 10000 136 | 137 | self.progressive_train = False 138 | self.progressive_train_node = False 139 | self.progressive_stage_ratio = .2 # The ratio of the number of images added per stage 140 | self.progressive_stage_steps = 3000 # The training steps of each stage 141 | 142 | self.lambda_optical_landmarks = [1e-1, 1e-1, 1e-3, 0] 143 | self.lambda_optical_steps = [0, 15_000, 25_000, 25_001] 144 | 145 | self.lambda_motion_mask_landmarks = [5e-1, 1e-2, 0] 146 | self.lambda_motion_mask_steps = [0, 10_000, 10_001] 147 | self.no_motion_mask_loss = False # Camera pose may be inaccurate and should model the whole scene motion 148 | 149 | self.gt_alpha_mask_as_scene_mask = False 150 | self.gt_alpha_mask_as_dynamic_mask = False 151 | self.no_arap_loss = False # For large scenes arap is too slow 152 | self.with_temporal_smooth_loss = False 153 | 154 | super().__init__(parser, "Optimization Parameters") 155 | 156 | 157 | def get_combined_args(parser: ArgumentParser): 158 | cmdlne_string = sys.argv[1:] 159 | cfgfile_string = "Namespace()" 160 | args_cmdline = parser.parse_args(cmdlne_string) 161 | 162 | if not args_cmdline.model_path.endswith(args_cmdline.deform_type): 163 | args_cmdline.model_path = os.path.join(os.path.dirname(os.path.normpath(args_cmdline.model_path)), os.path.basename(os.path.normpath(args_cmdline.model_path)) + f'_{args_cmdline.deform_type}') 164 | 165 | try: 166 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 167 | 
print("Looking for config file in", cfgfilepath) 168 | with open(cfgfilepath) as cfg_file: 169 | print("Config file found: {}".format(cfgfilepath)) 170 | cfgfile_string = cfg_file.read() 171 | except TypeError: 172 | print("Config file not found at") 173 | pass 174 | args_cfgfile = eval(cfgfile_string) 175 | 176 | merged_dict = vars(args_cfgfile).copy() 177 | for k, v in vars(args_cmdline).items(): 178 | if v != None: 179 | merged_dict[k] = v 180 | return Namespace(**merged_dict) 181 | -------------------------------------------------------------------------------- /assets/D-NeRF-Results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/D-NeRF-Results.png -------------------------------------------------------------------------------- /assets/Editing_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/Editing_mode.png -------------------------------------------------------------------------------- /assets/HKU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/HKU.png -------------------------------------------------------------------------------- /assets/VAST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/VAST.png -------------------------------------------------------------------------------- /assets/ZJU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/ZJU.png -------------------------------------------------------------------------------- /assets/badge-website.svg: -------------------------------------------------------------------------------- 1 | 2 | 15 | 17 | 35 | project: website 37 | 38 | 42 | 47 | 51 | 52 | 54 | 60 | 61 | 64 | 69 | 75 | 80 | 81 | 88 | 98 | Project 105 | 106 | 116 | Website 123 | 124 | 125 | 129 | 130 | -------------------------------------------------------------------------------- /assets/bear.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/bear.gif -------------------------------------------------------------------------------- /assets/bear_editing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/bear_editing.gif -------------------------------------------------------------------------------- /assets/edited_hook.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/edited_hook.gif -------------------------------------------------------------------------------- /assets/edited_jumpingjacks.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/edited_jumpingjacks.gif 
-------------------------------------------------------------------------------- /assets/edited_lego.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/edited_lego.gif -------------------------------------------------------------------------------- /assets/edited_mutant.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/edited_mutant.gif -------------------------------------------------------------------------------- /assets/face.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/face.gif -------------------------------------------------------------------------------- /assets/face_editing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/face_editing.gif -------------------------------------------------------------------------------- /assets/family.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/family.gif -------------------------------------------------------------------------------- /assets/hand.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/hand.gif -------------------------------------------------------------------------------- /assets/horse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/horse.gif -------------------------------------------------------------------------------- /assets/kitchen.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/kitchen.gif -------------------------------------------------------------------------------- /assets/person.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/person.gif -------------------------------------------------------------------------------- /assets/plant.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/plant.gif -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/assets/teaser.png -------------------------------------------------------------------------------- /cam_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation as R 3 | 4 | import torch 5 | 6 | def dot(x, y): 7 | if isinstance(x, np.ndarray): 8 | 
return np.sum(x * y, -1, keepdims=True) 9 | else: 10 | return torch.sum(x * y, -1, keepdim=True) 11 | 12 | 13 | def length(x, eps=1e-20): 14 | if isinstance(x, np.ndarray): 15 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 16 | else: 17 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 18 | 19 | 20 | def safe_normalize(x, eps=1e-20): 21 | return x / length(x, eps) 22 | 23 | 24 | def look_at(campos, target, opengl=True): 25 | # campos: [N, 3], camera/eye position 26 | # target: [N, 3], object to look at 27 | # return: [N, 3, 3], rotation matrix 28 | if not opengl: 29 | # camera forward aligns with -z 30 | forward_vector = safe_normalize(target - campos) 31 | up_vector = np.array([0, 1, 0], dtype=np.float32) 32 | right_vector = safe_normalize(np.cross(forward_vector, up_vector)) 33 | up_vector = safe_normalize(np.cross(right_vector, forward_vector)) 34 | else: 35 | # camera forward aligns with +z 36 | forward_vector = safe_normalize(campos - target) 37 | up_vector = np.array([0, 1, 0], dtype=np.float32) 38 | right_vector = safe_normalize(np.cross(up_vector, forward_vector)) 39 | up_vector = safe_normalize(np.cross(forward_vector, right_vector)) 40 | R = np.stack([right_vector, up_vector, forward_vector], axis=1) 41 | return R 42 | 43 | 44 | # elevation & azimuth to pose (cam2world) matrix 45 | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): 46 | # radius: scalar 47 | # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) 48 | # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) 49 | # return: [4, 4], camera pose matrix 50 | if is_degree: 51 | elevation = np.deg2rad(elevation) 52 | azimuth = np.deg2rad(azimuth) 53 | x = radius * np.cos(elevation) * np.sin(azimuth) 54 | y = - radius * np.sin(elevation) 55 | z = radius * np.cos(elevation) * np.cos(azimuth) 56 | if target is None: 57 | target = np.zeros([3], dtype=np.float32) 58 | campos = np.array([x, y, z]) + target # [3] 59 | T = np.eye(4, dtype=np.float32) 60 | T[:3, :3] = look_at(campos, target, opengl) 61 | T[:3, 3] = campos 62 | return T 63 | 64 | 65 | class OrbitCamera: 66 | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): 67 | self.W = W 68 | self.H = H 69 | self.radius = r # camera distance from center 70 | self.fovy = np.deg2rad(fovy) # deg 2 rad 71 | self.near = near 72 | self.far = far 73 | self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point 74 | # self.rot = R.from_matrix(np.eye(3)) 75 | self.rot = R.from_matrix(np.array([[1., 0., 0.,], 76 | [0., 0., -1.], 77 | [0., 1., 0.]])) 78 | self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! 79 | self.side = np.array([1, 0, 0], dtype=np.float32) 80 | 81 | @property 82 | def fovx(self): 83 | return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H) 84 | 85 | @property 86 | def campos(self): 87 | return self.pose[:3, 3] 88 | 89 | # pose (c2w) 90 | @property 91 | def pose(self): 92 | # first move camera to radius 93 | res = np.eye(4, dtype=np.float32) 94 | res[2, 3] = self.radius # opengl convention... 
95 | # rotate 96 | rot = np.eye(4, dtype=np.float32) 97 | rot[:3, :3] = self.rot.as_matrix() 98 | res = rot @ res 99 | # translate 100 | res[:3, 3] -= self.center 101 | return res 102 | 103 | # view (w2c) 104 | @property 105 | def view(self): 106 | return np.linalg.inv(self.pose) 107 | 108 | # projection (perspective) 109 | @property 110 | def perspective(self): 111 | y = np.tan(self.fovy / 2) 112 | aspect = self.W / self.H 113 | return np.array( 114 | [ 115 | [1 / (y * aspect), 0, 0, 0], 116 | [0, -1 / y, 0, 0], 117 | [ 118 | 0, 119 | 0, 120 | -(self.far + self.near) / (self.far - self.near), 121 | -(2 * self.far * self.near) / (self.far - self.near), 122 | ], 123 | [0, 0, -1, 0], 124 | ], 125 | dtype=np.float32, 126 | ) 127 | 128 | # intrinsics 129 | @property 130 | def intrinsics(self): 131 | focal = self.H / (2 * np.tan(self.fovy / 2)) 132 | return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) 133 | 134 | @property 135 | def mvp(self): 136 | return self.perspective @ np.linalg.inv(self.pose) # [4, 4] 137 | 138 | def orbit(self, dx, dy): 139 | # rotate along camera up/side axis! 140 | side = self.rot.as_matrix()[:3, 0] 141 | up = self.rot.as_matrix()[:3, 1] 142 | rotvec_x = up * np.radians(-0.05 * dx) 143 | rotvec_y = side * np.radians(-0.05 * dy) 144 | self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot 145 | 146 | def scale(self, delta): 147 | self.radius *= 1.1 ** (-delta) 148 | 149 | def pan(self, dx, dy, dz=0): 150 | # pan in camera coordinate system (careful on the sensitivity!) 151 | self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz]) 152 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | import shutil 15 | 16 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
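# Usage sketch (hypothetical paths, not taken from this repo): raw frames are expected under
# <source_path>/input, and the script leaves undistorted images in <source_path>/images and a
# COLMAP sparse model in <source_path>/sparse/0, e.g.
#   python convert.py -s /path/to/scene --resize
# COLMAP must be on PATH or given via --colmap_executable; the optional --resize step
# additionally requires ImageMagick (or --magick_executable).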
17 | parser = ArgumentParser("Colmap converter") 18 | parser.add_argument("--no_gpu", action='store_true') 19 | parser.add_argument("--skip_matching", action='store_true') 20 | parser.add_argument("--source_path", "-s", required=True, type=str) 21 | parser.add_argument("--camera", default="OPENCV", type=str) 22 | parser.add_argument("--colmap_executable", default="", type=str) 23 | parser.add_argument("--resize", action="store_true") 24 | parser.add_argument("--magick_executable", default="", type=str) 25 | args = parser.parse_args() 26 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 27 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 28 | use_gpu = 1 if not args.no_gpu else 0 29 | 30 | if not args.skip_matching: 31 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 32 | 33 | ## Feature extraction 34 | os.system(colmap_command + " feature_extractor "\ 35 | "--database_path " + args.source_path + "/distorted/database.db \ 36 | --image_path " + args.source_path + "/input \ 37 | --ImageReader.single_camera 1 \ 38 | --ImageReader.camera_model " + args.camera + " \ 39 | --SiftExtraction.use_gpu " + str(use_gpu)) 40 | 41 | ## Feature matching 42 | os.system(colmap_command + " exhaustive_matcher \ 43 | --database_path " + args.source_path + "/distorted/database.db \ 44 | --SiftMatching.use_gpu " + str(use_gpu)) 45 | 46 | ### Bundle adjustment 47 | # The default Mapper tolerance is unnecessarily large, 48 | # decreasing it speeds up bundle adjustment steps. 49 | os.system(colmap_command + " mapper \ 50 | --database_path " + args.source_path + "/distorted/database.db \ 51 | --image_path " + args.source_path + "/input \ 52 | --output_path " + args.source_path + "/distorted/sparse \ 53 | --Mapper.ba_global_function_tolerance=0.000001") 54 | 55 | ### Image undistortion 56 | ## We need to undistort our images into ideal pinhole intrinsics. 57 | os.system(colmap_command + " image_undistorter \ 58 | --image_path " + args.source_path + "/input \ 59 | --input_path " + args.source_path + "/distorted/sparse/0 \ 60 | --output_path " + args.source_path + "\ 61 | --output_type COLMAP") 62 | 63 | files = os.listdir(args.source_path + "/sparse") 64 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 65 | # Copy each file from the source directory to the destination directory 66 | for file in files: 67 | if file == '0': 68 | continue 69 | source_file = os.path.join(args.source_path, "sparse", file) 70 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 71 | shutil.move(source_file, destination_file) 72 | 73 | if(args.resize): 74 | print("Copying and resizing...") 75 | 76 | # Resize images. 
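# The three downsampled copies below are made by copying each full-resolution image and then
# shrinking it in place with ImageMagick's mogrify: images_2 (50%), images_4 (25%), images_8 (12.5%).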
77 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 78 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 79 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 80 | # Get the list of files in the source directory 81 | files = os.listdir(args.source_path + "/images") 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | source_file = os.path.join(args.source_path, "images", file) 85 | 86 | destination_file = os.path.join(args.source_path, "images_2", file) 87 | shutil.copy2(source_file, destination_file) 88 | os.system(magick_command + " mogrify -resize 50% " + destination_file) 89 | 90 | destination_file = os.path.join(args.source_path, "images_4", file) 91 | shutil.copy2(source_file, destination_file) 92 | os.system(magick_command + " mogrify -resize 25% " + destination_file) 93 | 94 | destination_file = os.path.join(args.source_path, "images_8", file) 95 | shutil.copy2(source_file, destination_file) 96 | os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 97 | 98 | print("Done.") -------------------------------------------------------------------------------- /data_tools/phone_catch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import glob 5 | import torch 6 | import shutil 7 | import numpy as np 8 | from PIL import Image 9 | from scipy import optimize 10 | import matplotlib.pyplot as plt 11 | import matplotlib 12 | matplotlib.use('Agg') 13 | 14 | 15 | ''' 16 | MiVOS: https://github.com/hkchengrex/MiVOS 17 | ''' 18 | MIVOS_PATH='YOUR/PATH/TO/MiVOS/' 19 | sys.path.append(MIVOS_PATH) 20 | from interactive_invoke import seg_video 21 | 22 | from colmap2nerf import colmap2nerf_invoke 23 | 24 | 25 | def Laplacian(img): 26 | return cv2.Laplacian(img, cv2.CV_64F).var() 27 | 28 | 29 | def cal_ambiguity(path): 30 | imgs = sorted(glob.glob(path + '/*.png')) 31 | laplace = np.zeros(len(imgs), np.float32) 32 | laplace_dict = {} 33 | for i in range(len(imgs)): 34 | laplace[i] = Laplacian(cv2.cvtColor(cv2.imread(imgs[i]), cv2.COLOR_BGR2GRAY)) 35 | laplace_dict[imgs[i]] = laplace[i] 36 | fig = plt.figure() 37 | fig.add_subplot(1, 2, 1) 38 | plt.hist(laplace) 39 | fig.add_subplot(1, 2, 2) 40 | plt.plot(np.arange(len(laplace)), laplace) 41 | if not os.path.exists(path + '/../noise/'): 42 | os.makedirs(path + '/../noise/') 43 | elif os.path.exists(path + '../noise/'): 44 | return None, None 45 | else: 46 | return None, None 47 | plt.savefig(path+'/../noise/laplace.png') 48 | return laplace, laplace_dict 49 | 50 | 51 | def select_ambiguity(path, nb=10, threshold=0.8, mv_files=False): 52 | if mv_files and os.path.exists(path + '/../noise/'): 53 | print('No need to select. 
Already done.') 54 | return None, None 55 | def linear(x, a, b): 56 | return a * x + b 57 | laplace, laplace_dic = cal_ambiguity(path) 58 | if laplace is None: 59 | return None, None 60 | imgs = list(laplace_dic.keys()) 61 | amb_img = [] 62 | amb_lap = [] 63 | for i in range(len(laplace)): 64 | i1 = max(0, int(i - nb / 2)) 65 | i2 = min(len(laplace), int(i + nb / 2)) 66 | lap = laplace[i1: i2] 67 | para, _ = optimize.curve_fit(linear, np.arange(i1, i2), lap) 68 | lapi_ = i * para[0] + para[1] 69 | if laplace[i] / lapi_ < threshold: 70 | amb_img.append(imgs[i]) 71 | amb_lap.append(laplace[i]) 72 | if mv_files: 73 | if not os.path.exists(path + '/../noise/'): 74 | os.makedirs(path + '/../noise/') 75 | file_name = amb_img[-1].split('/')[-1].split('\\')[-1] 76 | shutil.move(amb_img[-1], path + '/../noise/' + file_name) 77 | return amb_img, amb_lap 78 | 79 | 80 | def mask_images(img_path, msk_path, sv_path=None, no_mask=False): 81 | image_names = sorted(os.listdir(img_path)) 82 | image_names = [img for img in image_names if img.endswith('.png') or img.endswith('.jpg')] 83 | msk_names = sorted(os.listdir(msk_path)) 84 | msk_names = [img for img in msk_names if img.endswith('.png') or img.endswith('.jpg')] 85 | 86 | if sv_path is None: 87 | if img_path.endswith('/'): 88 | img_path = img_path[:-1] 89 | sv_path = '/'.join(img_path.split('/')[:-1]) + '/masked_images/' 90 | if not os.path.exists(sv_path) and not os.path.exists(sv_path + '../unmasked_images/'): 91 | os.makedirs(sv_path) 92 | else: 93 | return sv_path 94 | 95 | for i in range(len(image_names)): 96 | image_name, msk_name = image_names[i], msk_names[i] 97 | mask = np.array(Image.open(msk_path + '/' + image_name)) 98 | image = np.array(Image.open(img_path + '/' + image_name)) 99 | mask = cv2.resize(mask, (image.shape[1], image.shape[0])) 100 | if no_mask: 101 | mask = np.ones_like(mask) 102 | if mask.max() == 1: 103 | mask = mask * 255 104 | # image[mask==0] = 0 105 | masked_image = np.concatenate([image, mask[..., np.newaxis]], axis=-1) 106 | Image.fromarray(masked_image).save(sv_path + image_name) 107 | return sv_path 108 | 109 | 110 | def extract_frames_mp4(path, gap=None, frame_num=300, sv_path=None): 111 | if sv_path is None: 112 | sv_path = '/'.join(path.split('/')[:-1]) + '/images/' 113 | if not os.path.exists(sv_path): 114 | os.makedirs(sv_path) 115 | else: 116 | return sv_path 117 | if not os.path.exists(path): 118 | raise NotADirectoryError(path + ' does not exists.') 119 | vidcap = cv2.VideoCapture(path) 120 | if gap is None: 121 | total_frame_num = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 122 | gap = int(total_frame_num / frame_num) 123 | gap = max(gap, 1) 124 | 125 | success, image = vidcap.read() 126 | cv2.imwrite(sv_path + "/%05d.png" % 0, image) 127 | count = 1 128 | image_count = 1 129 | while success: 130 | success, image = vidcap.read() 131 | if count % gap == 0 and success: 132 | cv2.imwrite(sv_path + "/%05d.png" % image_count, image) 133 | image_count += 1 134 | count += 1 135 | return sv_path 136 | 137 | 138 | def rename_images(path): 139 | image_names = sorted(os.listdir(path)) 140 | image_names = [img for img in image_names if img.endswith('.png') or img.endswith('.jpg')] 141 | for i in range(len(image_names)): 142 | shutil.move(path + '/' + image_names[i], path + '/%05d.png' % i) 143 | 144 | 145 | if __name__ == '__main__': 146 | gap = None 147 | no_mask = False 148 | dataset_name = 'DATA_NAME' 149 | video_path = f'YOUR/PATH/TO/{dataset_name}/{dataset_name}.mp4' 150 | print('Extracting frames from video: ', 
video_path, ' with gap: ', gap) 151 | img_path = extract_frames_mp4(video_path, gap=gap) 152 | 153 | # print('Removing Blurry Images') 154 | # laplace, _ = select_ambiguity(img_path, nb=10, threshold=0.8, mv_files=True) 155 | # if laplace is not None: 156 | # rename_images(img_path) 157 | if not no_mask: 158 | print('Segmenting images with MiVOS ...') 159 | msk_path = seg_video(img_path=img_path) 160 | torch.cuda.empty_cache() 161 | print('Masking images with masks ...') 162 | msked_path = mask_images(img_path, msk_path, no_mask=no_mask) 163 | 164 | 165 | print('Running COLMAP ...') 166 | colmap2nerf_invoke(img_path) 167 | if img_path.endswith('/'): 168 | img_path = img_path[:-1] 169 | unmsk_path = '/'.join(img_path.split('/')[:-1]) + '/unmasked_images/' 170 | print('Rename masked and unmasked paths.') 171 | if not no_mask: 172 | os.rename(img_path, unmsk_path) 173 | os.rename(msked_path, img_path) 174 | 175 | 176 | def red2mask(img_dir): 177 | img_paths = glob.glob(os.path.join(img_dir, "*.png")) 178 | imgs = [cv2.cvtColor(cv2.imread(x), cv2.COLOR_BGR2GRAY) for x in img_paths] 179 | save_dir = os.path.join(os.path.dirname(img_dir), "white_mask") 180 | os.makedirs(save_dir, exist_ok=True) 181 | for idx, img_path in enumerate(img_paths): 182 | save_path = os.path.join(save_dir, os.path.basename(img_path)) 183 | cv2.imwrite(save_path, imgs[idx]) -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system( 68 | "python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | os.system( 70 | "python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 71 | 72 | if not args.skip_metrics: 73 | scenes_string = "" 74 | for scene in all_scenes: 75 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 76 | 77 | os.system("python metrics.py -m " + scenes_string) 78 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 
4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from utils.rigid_utils import from_homogenous, to_homogenous 18 | 19 | 20 | # def quaternion_multiply(q1, q2): 21 | # w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] 22 | # w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] 23 | 24 | # w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 25 | # x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 26 | # y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 27 | # z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 28 | 29 | # return torch.stack((w, x, y, z), dim=-1) 30 | 31 | 32 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 33 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 34 | 35 | def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 36 | aw, ax, ay, az = torch.unbind(a, -1) 37 | bw, bx, by, bz = torch.unbind(b, -1) 38 | ow = aw * bw - ax * bx - ay * by - az * bz 39 | ox = aw * bx + ax * bw + ay * bz - az * by 40 | oy = aw * by - ax * bz + ay * bw + az * bx 41 | oz = aw * bz + ax * by - ay * bx + az * bw 42 | return torch.stack((ow, ox, oy, oz), -1) 43 | 44 | def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 45 | ab = quaternion_raw_multiply(a, b) 46 | return standardize_quaternion(ab) 47 | 48 | 49 | def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, d_xyz, d_rotation, d_scaling, d_opacity=None, d_color=None, scaling_modifier=1.0, override_color=None, random_bg_color=False, render_motion=False, detach_xyz=False, detach_scale=False, detach_rot=False, detach_opacity=False, d_rot_as_res=True, scale_const=None, d_rotation_bias=None, force_visible=False): 50 | """ 51 | Render the scene. 52 | 53 | Background tensor (bg_color) must be on GPU! 54 | """ 55 | 56 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 57 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 58 | try: 59 | screenspace_points.retain_grad() 60 | except: 61 | pass 62 | 63 | # Set up rasterization configuration 64 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 65 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 66 | 67 | bg = bg_color if not random_bg_color else torch.rand_like(bg_color) 68 | 69 | raster_settings = GaussianRasterizationSettings( 70 | image_height=int(viewpoint_camera.image_height), 71 | image_width=int(viewpoint_camera.image_width), 72 | tanfovx=tanfovx, 73 | tanfovy=tanfovy, 74 | bg=bg, 75 | scale_modifier=scaling_modifier, 76 | viewmatrix=viewpoint_camera.world_view_transform, 77 | projmatrix=viewpoint_camera.full_proj_transform, 78 | sh_degree=pc.active_sh_degree, 79 | campos=viewpoint_camera.camera_center, 80 | prefiltered=False, 81 | debug=pipe.debug, 82 | ) 83 | 84 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 85 | 86 | # if torch.is_tensor(d_xyz) is False: 87 | # means3D = pc.get_xyz 88 | # else: 89 | # means3D = from_homogenous( 90 | # torch.bmm(d_xyz, to_homogenous(pc.get_xyz).unsqueeze(-1)).squeeze(-1)) 91 | means3D = pc.get_xyz + d_xyz 92 | means2D = screenspace_points 93 | if scale_const is not None: 94 | opacity = torch.ones_like(pc.get_opacity) 95 | else: 96 | opacity = pc.get_opacity if d_opacity is None else pc.get_opacity + d_opacity 97 | 98 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 99 | # scaling / rotation by the rasterizer. 100 | scales = None 101 | rotations = None 102 | cov3D_precomp = None 103 | if pipe.compute_cov3D_python: 104 | cov3D_precomp = pc.get_covariance(scaling_modifier, d_rotation=None if type(d_rotation) is float else d_rotation, gs_rot_bias=d_rotation_bias) 105 | else: 106 | scales = pc.get_scaling + d_scaling 107 | rotations = pc.get_rotation_bias(d_rotation) 108 | if d_rotation_bias is not None: 109 | rotations = quaternion_multiply(d_rotation_bias, rotations) 110 | 111 | if render_motion: 112 | shs = None 113 | colors_precomp = torch.zeros_like(pc.get_xyz) 114 | colors_precomp[..., :1] = pc.motion_mask 115 | colors_precomp[..., -1:] = 1 - pc.motion_mask 116 | else: 117 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 118 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
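# When d_color is provided, it is added only to the DC (degree-0) SH coefficients below;
# the higher-order SH bands are left unchanged.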
119 | shs = None 120 | colors_precomp = None 121 | if colors_precomp is None: 122 | sh_features = torch.cat([pc.get_features[:, :1] + d_color[:, None], pc.get_features[:, 1:]], dim=1) if d_color is not None and type(d_color) is not float else pc.get_features 123 | if pipe.convert_SHs_python: 124 | shs_view = sh_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2) 125 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(sh_features.shape[0], 1)) 126 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 127 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 128 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 129 | else: 130 | shs = sh_features 131 | else: 132 | colors_precomp = override_color 133 | 134 | if detach_xyz: 135 | means3D = means3D.detach() 136 | if detach_rot or detach_scale: 137 | if cov3D_precomp is not None: 138 | cov3D_precomp = cov3D_precomp.detach() 139 | else: 140 | rotations = rotations.detach() if detach_rot else rotations 141 | scales = scales.detach() if detach_scale else scales 142 | if detach_opacity: 143 | opacity = opacity.detach() 144 | 145 | if scale_const is not None: 146 | scales = scale_const * torch.ones_like(scales) 147 | 148 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 149 | rendered_image, radii, depth, alpha = rasterizer( 150 | means3D=means3D, 151 | means2D=means2D, 152 | shs=shs, 153 | colors_precomp=colors_precomp, 154 | opacities=opacity, 155 | scales=scales, 156 | rotations=rotations, 157 | cov3D_precomp=cov3D_precomp) 158 | 159 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 160 | # They will be excluded from value updates used in the splitting criteria. 161 | return {"render": rendered_image, 162 | "viewspace_points": screenspace_points, 163 | "visibility_filter": radii > 0, 164 | "radii": radii, 165 | "depth": depth, 166 | "alpha": alpha, 167 | "bg_color": bg} 168 | 169 | 170 | def render_flow( 171 | pc: GaussianModel, 172 | viewpoint_camera1, 173 | viewpoint_camera2, 174 | d_xyz1, d_xyz2, 175 | d_rotation1, d_scaling1, 176 | scaling_modifier=1.0, 177 | compute_cov3D_python=False, 178 | scale_const=None, 179 | d_rot_as_res=True, 180 | **kwargs 181 | ): 182 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 183 | screenspace_points = ( 184 | torch.zeros_like( 185 | pc.get_xyz, 186 | dtype=pc.get_xyz.dtype, 187 | requires_grad=True, 188 | device="cuda", 189 | ) 190 | + 0 191 | ) 192 | try: 193 | screenspace_points.retain_grad() 194 | except: 195 | pass 196 | 197 | # Set up rasterization configuration 198 | tanfovx = math.tan(viewpoint_camera1.FoVx * 0.5) 199 | tanfovy = math.tan(viewpoint_camera1.FoVy * 0.5) 200 | 201 | # About Motion 202 | carnonical_xyz = pc.get_xyz.clone() 203 | xyz_at_t1 = xyz_at_t2 = carnonical_xyz.detach() # Detach coordinates of Gaussians here 204 | xyz_at_t1 = xyz_at_t1 + d_xyz1 205 | xyz_at_t2 = xyz_at_t2 + d_xyz2 206 | gaussians_homogeneous_coor_t2 = torch.cat([xyz_at_t2, torch.ones_like(xyz_at_t2[..., :1])], dim=-1) 207 | full_proj_transform = viewpoint_camera2.full_proj_transform if viewpoint_camera2 is not None else viewpoint_camera1.full_proj_transform 208 | gaussians_uvz_coor_at_cam2 = gaussians_homogeneous_coor_t2 @ full_proj_transform 209 | gaussians_uvz_coor_at_cam2 = gaussians_uvz_coor_at_cam2[..., :3] / gaussians_uvz_coor_at_cam2[..., -1:] 210 | 211 | gaussians_homogeneous_coor_t1 = torch.cat([xyz_at_t1, torch.ones_like(xyz_at_t1[..., :1])], dim=-1) 212 | gaussians_uvz_coor_at_cam1 = gaussians_homogeneous_coor_t1 @ viewpoint_camera1.full_proj_transform 213 | gaussians_uvz_coor_at_cam1 = gaussians_uvz_coor_at_cam1[..., :3] / gaussians_uvz_coor_at_cam1[..., -1:] 214 | 215 | flow_uvz_1to2 = gaussians_uvz_coor_at_cam2 - gaussians_uvz_coor_at_cam1 216 | 217 | # Rendering motion mask 218 | flow_uvz_1to2[..., -1:] = pc.motion_mask 219 | 220 | raster_settings = GaussianRasterizationSettings( 221 | image_height=int(viewpoint_camera1.image_height), 222 | image_width=int(viewpoint_camera1.image_width), 223 | tanfovx=tanfovx, 224 | tanfovy=tanfovy, 225 | bg = torch.zeros_like(flow_uvz_1to2[0]), # Background set as 0 226 | scale_modifier=scaling_modifier, 227 | viewmatrix=viewpoint_camera1.world_view_transform, 228 | projmatrix=viewpoint_camera1.full_proj_transform, 229 | sh_degree=0, 230 | campos=viewpoint_camera1.camera_center, 231 | prefiltered=False, 232 | debug=False, 233 | ) 234 | 235 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 236 | 237 | means3D = pc.get_xyz + d_xyz1 # About Motion 238 | means2D = screenspace_points 239 | opacity = pc.get_opacity 240 | 241 | if scale_const is not None: 242 | # If providing scale_const, directly use scale_const 243 | scales = torch.ones_like(pc.get_scaling) * scale_const 244 | if d_rot_as_res: 245 | rotations = pc.get_rotation + d_rotation1 246 | else: 247 | rotations = pc.get_rotation if type(d_rotation1) is float else quaternion_multiply(d_rotation1, pc.get_rotation) 248 | cov3D_precomp = None 249 | else: 250 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 251 | # scaling / rotation by the rasterizer. 252 | scales = None 253 | rotations = None 254 | cov3D_precomp = None 255 | if compute_cov3D_python: 256 | cov3D_precomp = pc.get_covariance(scaling_modifier, d_rotation=None if type(d_rotation1) is float else d_rotation1) 257 | else: 258 | scales = pc.get_scaling + d_scaling1 259 | if d_rot_as_res: 260 | rotations = pc.get_rotation + d_rotation1 261 | else: 262 | rotations = pc.get_rotation if type(d_rotation1) is float else quaternion_multiply(d_rotation1, pc.get_rotation) 263 | 264 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
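# colors_precomp is set to flow_uvz_1to2 here, so the "image" rendered below is a splatted
# screen-space flow map from (camera1, t1) to (camera2, t2), with the motion mask in the last channel.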
265 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 266 | means3D=means3D, 267 | means2D=means2D, 268 | shs=None, 269 | colors_precomp=flow_uvz_1to2, 270 | opacities=opacity, 271 | scales=scales, 272 | rotations=rotations, 273 | cov3D_precomp=cov3D_precomp, 274 | ) 275 | 276 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 277 | # They will be excluded from value updates used in the splitting criteria. 278 | return { 279 | "render": rendered_image, 280 | "depth": rendered_depth, 281 | "alpha": rendered_alpha, 282 | "viewspace_points": screenspace_points, 283 | "visibility_filter": radii > 0, 284 | "radii": radii, 285 | } 286 | -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/gaussian_renderer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /gaussian_renderer/__pycache__/network_gui.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/gaussian_renderer/__pycache__/network_gui.cpython-38.pyc -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, 'little') 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, 'little')) 59 | conn.sendall(bytes(verify, 'ascii')) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 80 | world_view_transform[:, 1] = -world_view_transform[:, 1] 81 | world_view_transform[:, 2] = -world_view_transform[:, 2] 82 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 83 | full_proj_transform[:, 1] = -full_proj_transform[:, 1] 84 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 85 | except Exception as e: 86 | print("") 87 | traceback.print_exc() 88 | raise e 89 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 90 | else: 91 | return None, None, None, None, None, None 92 | -------------------------------------------------------------------------------- /lap_deform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch3d.ops 4 | from utils.arap_deform import ARAPDeformer 5 | from utils.deform_utils import cal_arap_error 6 | 7 | 8 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 9 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 10 | 11 | def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 12 | aw, ax, ay, az = torch.unbind(a, -1) 13 | bw, bx, by, bz = torch.unbind(b, -1) 14 | ow = aw * bw - ax * bx - ay * by - az * bz 15 | ox = aw * bx + ax * bw + ay * bz - az * by 16 | oy = aw * by - ax * bz + ay * bw + az * bx 17 | oz = aw * bz + ax * by - ay * bx + az * bw 18 | return torch.stack((ow, ox, oy, oz), -1) 19 | 20 | def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 21 | ab = quaternion_raw_multiply(a, b) 22 
| return standardize_quaternion(ab) 23 | 24 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 25 | """ 26 | Returns torch.sqrt(torch.max(0, x)) 27 | but with a zero subgradient where x is 0. 28 | """ 29 | ret = torch.zeros_like(x) 30 | positive_mask = x > 0 31 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 32 | return ret 33 | 34 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 35 | """ 36 | Convert rotations given as rotation matrices to quaternions. 37 | 38 | Args: 39 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 40 | 41 | Returns: 42 | quaternions with real part first, as tensor of shape (..., 4). 43 | """ 44 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 45 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 46 | 47 | batch_dim = matrix.shape[:-2] 48 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 49 | matrix.reshape(batch_dim + (9,)), dim=-1 50 | ) 51 | 52 | q_abs = _sqrt_positive_part( 53 | torch.stack( 54 | [ 55 | 1.0 + m00 + m11 + m22, 56 | 1.0 + m00 - m11 - m22, 57 | 1.0 - m00 + m11 - m22, 58 | 1.0 - m00 - m11 + m22, 59 | ], 60 | dim=-1, 61 | ) 62 | ) 63 | 64 | # we produce the desired quaternion multiplied by each of r, i, j, k 65 | quat_by_rijk = torch.stack( 66 | [ 67 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 68 | # `int`. 69 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 70 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 71 | # `int`. 72 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 73 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 74 | # `int`. 75 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 76 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 77 | # `int`. 78 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 79 | ], 80 | dim=-2, 81 | ) 82 | 83 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 84 | # the candidate won't be picked. 
85 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 86 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 87 | 88 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 89 | # forall i; we pick the best-conditioned one (with the largest denominator) 90 | 91 | return quat_candidates[ 92 | torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 93 | ].reshape(batch_dim + (4,)) 94 | 95 | 96 | class LapDeform(nn.Module): 97 | def __init__(self, init_pcl, K=4, trajectory=None, node_radius=None): 98 | super().__init__() 99 | self.K = K 100 | self.N = init_pcl.shape[0] 101 | nn_dist, nn_idxs, _ = pytorch3d.ops.knn_points(init_pcl[None], init_pcl[None], None, None, K=K+1) # N, K 102 | nn_dist, nn_idxs = nn_dist[0,:,1:], nn_idxs[0,:,1:] 103 | nn_dist = 1 / (nn_dist + 1e-7) 104 | self.nn_idxs = nn_idxs 105 | self._weight = nn.Parameter(torch.log(nn_dist / (nn_dist.sum(dim=1, keepdim=True) + 1e-5) + 1e-5)) 106 | self.init_pcl = init_pcl 107 | self.init_pcl_copy = init_pcl.clone() 108 | self.tensors = {} 109 | # self.optimizer = torch.optim.Adam([self._weight], lr=1e-5) 110 | self.mask_control_points = False 111 | if self.mask_control_points: 112 | self.generate_mask_init_pcl() 113 | radius = torch.linalg.norm(self.init_pcl_reduced.max(dim=0).values - self.init_pcl_reduced.min(dim=0).values) / 10 * 3 114 | print("Set ball query radius to %f" % radius.item()) 115 | self.arap_deformer = ARAPDeformer(self.init_pcl_reduced, radius=radius, K=30, point_mask=self.init_pcl_mask, trajectory=trajectory, node_radius=node_radius) 116 | else: 117 | radius = torch.linalg.norm(self.init_pcl.max(dim=0).values - self.init_pcl.min(dim=0).values) / 8 118 | print("Set ball query radius to %f" % radius.item()) 119 | self.arap_deformer = ARAPDeformer(init_pcl, radius=radius, K=16, trajectory=trajectory, node_radius=node_radius) 120 | 121 | self.optimizer = torch.optim.Adam([self.arap_deformer.weight], lr=1e-3) 122 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.99) 123 | self.optim_step = 0 124 | 125 | def generate_mask_init_pcl(self): 126 | init_pcl_mask = torch.linalg.norm(self.init_pcl, dim=-1) < 5 127 | self.init_pcl_mask = init_pcl_mask 128 | # init_pcl[~init_pcl_mask] = 0 129 | self.init_pcl_reduced = self.init_pcl[self.init_pcl_mask] 130 | 131 | 132 | def reset(self, ): 133 | self.init_pcl = self.init_pcl_copy.clone() 134 | self.arap_deformer.reset() 135 | self.optim_step = 0 136 | self.generate_mask_init_pcl() 137 | 138 | @property 139 | def weight(self): 140 | return torch.softmax(self._weight, dim=-1) 141 | 142 | @property 143 | def L(self): 144 | L = torch.eye(self.N).cuda() 145 | L.scatter_add_(dim=1, index=self.nn_idxs, src=-self.weight) 146 | return L 147 | 148 | def add_one_ring_nbs(self, idxs): 149 | if type(idxs) is list: 150 | idxs = torch.tensor(idxs).cuda() 151 | elif idxs.dim() == 0: 152 | idxs = idxs[None] 153 | nn_idxs = self.nn_idxs[idxs].reshape([-1]) 154 | return torch.unique(torch.cat([nn_idxs, idxs])) 155 | 156 | def add_n_ring_nbs(self, idxs, n=2): 157 | for i in range(n): 158 | idxs = self.add_one_ring_nbs(idxs) 159 | return idxs 160 | 161 | def initialize(self, pcl): 162 | b = self.L @ pcl 163 | self.tensors['b'] = b 164 | 165 | def estimate_R(self, pcl, return_quaternion=True): 166 | old_edges = torch.gather(input=self.init_pcl[:, None].repeat(1,self.K,1), dim=0, index=self.nn_idxs[..., None].repeat(1,1,3)) - self.init_pcl[:, None] # N, K, 3 167 | edges = 
torch.gather(input=pcl[:, None].repeat(1,self.K,1), dim=0, index=self.nn_idxs[..., None].repeat(1,1,3)) - pcl[:, None] # N, K, 3 168 | D = torch.diag_embed(self.weight, dim1=1, dim2=2) # N, K, K 169 | S = torch.bmm(old_edges.permute(0, 2, 1), torch.bmm(D, edges)) # N, 3, 3 170 | unchanged = torch.unique(torch.where((edges == old_edges).all(dim=1))[0]) 171 | S[unchanged] = 0 172 | U, _, W = torch.svd(S) 173 | R = torch.bmm(W, U.permute(0, 2, 1)) 174 | if return_quaternion: 175 | q = matrix_to_quaternion(R) 176 | return q 177 | else: 178 | return R 179 | 180 | def energy(self, pcl, prev_pcl=None): 181 | if prev_pcl is None: 182 | if 'b' not in self.tensors: 183 | print('Have not initialized yet and start with init pcl') 184 | self.initialize(self.init_pcl) 185 | b = self.tensors['b'] 186 | else: 187 | b = self.L @ prev_pcl 188 | loss = (self.L @ pcl - b).square().mean() 189 | return loss 190 | 191 | def energy_arap(self, pcl, prev_pcl): 192 | # loss = (self.arap_deformer.L_opt @ pcl - b).square().mean() 193 | self.optim_step += 1 194 | self.arap_deformer.cal_L_opt() 195 | node_seq = torch.stack([prev_pcl, pcl], dim=0) 196 | # print(self.arap_deformer.weight) 197 | loss = cal_arap_error(node_seq, self.arap_deformer.ii, self.arap_deformer.jj, self.arap_deformer.nn, K=self.arap_deformer.K, weight=self.arap_deformer.normalized_weight) 198 | return loss 199 | 200 | def deform(self, handle_idx, handle_pos, static_idx=None): 201 | if 'b' not in self.tensors: 202 | print('Have not initialized yet and start with init pcl') 203 | self.initialize(self.init_pcl) 204 | b = self.tensors['b'] 205 | handle_pos = torch.tensor(handle_pos).float().cuda() 206 | if static_idx is not None: 207 | static_pos = self.init_pcl[static_idx] 208 | handle_idx = handle_idx + static_idx 209 | handle_pos = torch.cat([handle_pos.cuda(), static_pos.cuda()], dim=0) 210 | return lstsq_with_handles(A=self.L, b=b, handle_idx=handle_idx, handle_pos=handle_pos) 211 | 212 | def deform_arap(self, handle_idx, handle_pos, init_verts=None, return_R=False): 213 | handle_idx = torch.tensor(handle_idx).long().cuda() 214 | if type(handle_pos) is not torch.Tensor: 215 | handle_pos = torch.from_numpy(handle_pos).float().cuda() 216 | deformed_p, deformed_r, deformed_s = self.arap_deformer.deform(handle_idx, handle_pos, init_verts=init_verts, return_R=return_R) 217 | if self.mask_control_points: 218 | deformed_p_all = self.init_pcl.clone() 219 | deformed_p_all[self.init_pcl_mask] = deformed_p 220 | deformed_r_all = torch.tensor([[1,0,0,0]]).to(deformed_r.dtype).to(deformed_r.device).repeat(deformed_p_all.shape[0],1) 221 | deformed_r_all[self.init_pcl_mask] = deformed_r 222 | return deformed_p_all, deformed_r_all, deformed_s 223 | else: 224 | return deformed_p, deformed_r, deformed_s 225 | 226 | 227 | def lstsq_with_handles(A, b, handle_idx, handle_pos): 228 | b = b - A[:, handle_idx] @ handle_pos 229 | handle_mask = torch.zeros_like(A[:, 0], dtype=bool) 230 | handle_mask[handle_idx] = 1 231 | L = A[:, handle_mask.logical_not()] 232 | x = torch.linalg.lstsq(L, b)[0] 233 | x_out = torch.zeros_like(b) 234 | x_out[handle_idx] = handle_pos 235 | x_out[handle_mask.logical_not()] = x 236 | return x_out 237 | 238 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | 
version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | 18 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in 
enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | 25 | def readImages(renders_dir, gt_dir): 26 | renders = [] 27 | gts = [] 28 | image_names = [] 29 | for fname in os.listdir(renders_dir): 30 | render = Image.open(renders_dir / fname) 31 | gt = Image.open(gt_dir / fname) 32 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 33 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 34 | image_names.append(fname) 35 | return renders, gts, image_names 36 | 37 | 38 | def evaluate(model_paths): 39 | full_dict = {} 40 | per_view_dict = {} 41 | full_dict_polytopeonly = {} 42 | per_view_dict_polytopeonly = {} 43 | print("") 44 | 45 | for scene_dir in model_paths: 46 | try: 47 | print("Scene:", scene_dir) 48 | full_dict[scene_dir] = {} 49 | per_view_dict[scene_dir] = {} 50 | full_dict_polytopeonly[scene_dir] = {} 51 | per_view_dict_polytopeonly[scene_dir] = {} 52 | 53 | test_dir = Path(scene_dir) / "test" 54 | 55 | for method in os.listdir(test_dir): 56 | if not method.startswith("ours"): 57 | continue 58 | print("Method:", method) 59 | 60 | full_dict[scene_dir][method] = {} 61 | per_view_dict[scene_dir][method] = {} 62 | full_dict_polytopeonly[scene_dir][method] = {} 63 | per_view_dict_polytopeonly[scene_dir][method] = {} 64 | 65 | method_dir = test_dir / method 66 | gt_dir = method_dir / "gt" 67 | renders_dir = method_dir / "renders" 68 | renders, gts, image_names = readImages(renders_dir, gt_dir) 69 | 70 | ssims = [] 71 | psnrs = [] 72 | lpipss = [] 73 | 74 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 75 | ssims.append(ssim(renders[idx], gts[idx])) 76 | psnrs.append(psnr(renders[idx], gts[idx])) 77 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 78 | 79 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 80 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 81 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 82 | print("") 83 | 84 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 85 | "PSNR": torch.tensor(psnrs).mean().item(), 86 | "LPIPS": torch.tensor(lpipss).mean().item()}) 87 | per_view_dict[scene_dir][method].update( 88 | {"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 89 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 90 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 91 | 92 | with open(scene_dir + "/results.json", 'w') as fp: 93 | json.dump(full_dict[scene_dir], fp, indent=True) 94 | with open(scene_dir + "/per_view.json", 'w') as fp: 95 | json.dump(per_view_dict[scene_dir], fp, indent=True) 96 | except: 97 | print("Unable to compute metrics for model", scene_dir) 98 | 99 | 100 | if __name__ == "__main__": 101 | device = torch.device("cuda:0") 102 | torch.cuda.set_device(device) 103 | 104 | # Set up command line argument parser 105 | parser = ArgumentParser(description="Training script parameters") 106 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 107 | args = parser.parse_args() 108 | 
evaluate(args.model_paths) 109 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

2 | SC-GS: Sparse-Controlled Gaussian Splatting for Editable Dynamic Scenes 3 | 4 | 5 | 6 | 7 |

8 | 9 | This is the code for SC-GS: Sparse-Controlled Gaussian Splatting for Editable Dynamic Scenes. 10 | 11 |
12 | 13 | [![Website](assets/badge-website.svg)](https://yihua7.github.io/SC-GS-web/) 14 | [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2312.14937) 15 | 16 |
17 | 18 |
19 | 20 | 21 | 22 | 23 |
24 | 25 |
26 | 27 | 28 | 29 | 30 |
31 | 32 | *With the interactive editing empowered by SC-GS, users can effortlessly edit and customize their digital assets.* 33 | 34 |
35 | 36 |
37 | 38 | *Given (a) an image sequence from a monocular dynamic video, we propose to represent the motion with a set of sparse control points, which can be used to drive 3D Gaussians for high-fidelity rendering. Our approach enables both (b) dynamic view synthesis and (c) motion editing due to the motion representation based on sparse control points.* 39 | 40 | 41 | ## Updates 42 | 43 | ### 2025.05.21 44 | 45 | #### Editing Real-World Static Objects 46 | 47 | Solved a reported issue in which inverting the Laplacian matrix caused slow editing. Editing real-world static objects is now flexible and shows interesting results. 48 | 49 | #### 1. Masking the Object to Edit 50 | When editing, remember to mask the object you want to modify. If you are using MiVOS and encounter an issue with image names that are not purely numeric (e.g., `frame_000.jpg` causing errors), you can resolve it by replacing the line in [this file](https://github.com/hkchengrex/MiVOS/blob/f2600a6eea8709c7b9f1a7575adc725def680b81/interact/interactive_utils.py#L26) with the following: 51 | 52 | ```python 53 | fnames = sorted(glob.glob(os.path.join(path, '*.jpg')), key=lambda x: int(''.join(char for char in os.path.basename(x).split('.')[0] if char.isdigit()))) 54 | ``` 55 | 56 | #### 2. Training and Editing Static Scenes 57 | 58 | Run the following command to train and edit a static scene: 59 | 60 | ```bash 61 | CUDA_VISIBLE_DEVICES=0 python train_gui.py \ 62 | --source_path "XXX/person-small" \ 63 | --model_path "outputs/person/" \ 64 | --is_scene_static \ 65 | --gui \ 66 | --deform_type "node" \ 67 | --node_num "512" \ 68 | --gt_alpha_mask_as_dynamic_mask \ 69 | --gs_with_motion_mask \ 70 | --W "800" \ 71 | --H "800" \ 72 | --white_background \ 73 | --init_isotropic_gs_with_all_colmap_pcl 74 | ``` 75 | 76 | #### 3. Editing Results on I-N2N Scenes 77 | By following the editing guidance, you can easily achieve satisfactory geometry editing results on static scenes, as demonstrated in [Instruct-NeRF2NeRF](https://instruct-nerf2nerf.github.io/): 78 | 79 |
80 | 81 | 82 | 83 | 84 |
85 | 86 |
87 | 88 | 89 |
90 | 91 | ### 2024-03-17: 92 | 93 | 1. Editing **static scenes** is now supported! Simply include the `--is_scene_static` argument and you are good to go! 94 | 95 | 2. Video rendering is now supported with interpolation of editing results. Press the `sv_kpt` button to save each edited result and press `render_traj` to render the interpolated motions as a video. Click `spiral` to switch the camera-motion pattern of the rendered video between a spiral trace and a fixed pose. 96 | 97 | 3. On self-captured real-world scenes where the number of Gaussians can become very large, the dimension of the hyper coordinates that separate close but disconnected parts can be set to 2 to speed up the rendering: `--hyper_dim 2`. Also remember to remove `--is_blender` in such cases! 98 | 99 | ### 2024-03-07 100 | 101 | We offer two ARAP deformation strategies for motion editing: 1. iterative deformation and 2. deformation from Laplacian initialization. 102 | 103 | ### 2024-03-06 104 | 105 | To prevent initialization failure of control points, you can use the argument `--init_isotropic_gs_with_all_colmap_pcl` on self-captured datasets. 106 | 107 | 108 | ## Install 109 | 110 | ```bash 111 | git clone https://github.com/yihua7/SC-GS --recursive 112 | cd SC-GS 113 | 114 | pip install -r requirements.txt 115 | 116 | # a modified gaussian splatting (+ depth, alpha rendering) 117 | pip install ./submodules/diff-gaussian-rasterization 118 | 119 | # simple-knn 120 | pip install ./submodules/simple-knn 121 | ``` 122 | 123 | ## Run 124 | 125 | ### Train with GUI 126 | 127 | * To begin the training, select the 'start' button. The program will begin by pre-training the control points in the form of Gaussians for 10,000 steps before progressing to train the dynamic Gaussians. 128 | 129 | * To view the control points, click on the 'Node' button found on the panel located after 'Visualization'. 130 | 131 | ```bash 132 | # Train with GUI (for the resolution of 400*400 with best PSNR) 133 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800 --gui 134 | 135 | # Train with GUI (for the resolution of 800*800) 136 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --W 800 --H 800 --random_bg_color --white_background --gui 137 | ``` 138 | 139 | ### Train with terminal 140 | 141 | * Simply remove the option `--gui`, as follows: 142 | 143 | ```bash 144 | # Train with terminal only (for the resolution of 400*400 with best PSNR) 145 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800 146 | ``` 147 | 148 | ### Evaluate 149 | 150 | * Every 1000 steps during training, the program evaluates SC-GS on the test set and prints the results **on the UI interface and in the terminal**, so you can view them easily. 151 | 152 | * You can also run the evaluation by replacing `train_gui.py` with `render.py` in the training command. Results will be saved in the specified log directory `outputs/XXX`.
The following is an example: 153 | 154 | ```bash 155 | # Evaluate with GUI (for the resolution of 400*400 with best PSNR) 156 | CUDA_VISIBLE_DEVICES=0 python render.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --hyper_dim 8 --is_blender --eval --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800 157 | ``` 158 | 159 | ## Editing 160 | 161 | ### 2-minute editing guidance: 162 | 163 | (The video was recorded prior to the addition of the editing mode selection menu in the UI. In the video, the deformation was performed using the `arap_from_init` method.) 164 | 165 | https://github.com/yihua7/SC-GS/assets/35869256/7a71d29b-975e-4870-afb1-7cdc96bb9482 166 | 167 | ### Editing Mode 168 | 169 | We offer two deformation strategies for editing: **(1)** iterative ARAP deformation and **(2)** ARAP starting from the initial frozen moment. Users can select their preferred strategy from the Editing Mode drop-down menu on the UI interface. 170 | 171 | 172 | 173 | (1) **Iterative deformation (`arap_iterative`)**: 174 | 175 | - **Pros**: Large-scale deformation is easy to achieve without rotation artifacts. 176 | 177 | - **Cons**: It may be difficult to revert to a previous state after unintentionally obtaining an unwanted deformation, because the state is updated iteratively. 178 | 179 | (2) **Deformation from the initial frozen moment (`arap_from_init`)**: 180 | 181 | - **Pros**: It ensures that the deformed state can be restored when the control points return to their previous positions, making the editing easier to control without deviation. 182 | 183 | - **Cons**: For large-scale rotational deformation, the ARAP algorithm may fail to reach the optimum, since the initialization from the Laplacian deformation is not robust to rotation. This may result in certain areas not receiving the corresponding large-scale rotations. 184 | 185 | **Users can try both and experience the differences between the two strategies, then choose the one best suited to achieving their desired editing effect.** 186 | 187 | ### Tips on Editing with the deformation from the initial frozen moment (`arap_from_init`) 188 | 189 | 1. **When and why will artifacts appear when using `arap_from_init`?** Most editing artifacts are caused by an inaccurate initialization of the ARAP deformation, which is an iterative optimization over positions and rotations. To drive both toward a global optimum, a good ARAP initialization is essential. The mode `arap_from_init` uses Laplacian deformation for initialization, which only minimizes the error of the Laplacian coordinates, and these coordinates change under rotation. Hence Laplacian deformation is not robust enough to rotation, resulting in an inaccurate initialization in the face of large rotations. As a result, some areas fail to achieve the correct rotations in the subsequent ARAP deformation results. 190 | 191 | 2. **How to deal with artifacts?** To address this issue, the following steps are recommended; the core idea is to **include as many control points as possible** for large-scale deformation: (1) If you treat a big region as a rigid part and would like to apply a large deformation, use more control points to cover the whole part and manipulate these control points together. This yields a better Laplacian deformation result and therefore a better initialization of the ARAP deformation. (2) Edit hierarchically.
If you need to apply deformations at different levels, first add control points at the finest part and deform it. After that, you can include more control points, treat them as a rigid body, and perform deformations at larger levels. 192 | 193 | 3. More tips: (1) To add handle points more efficiently, you can set the parameter `n_rings` to 3 or 4 on the GUI interface. (2) You can press the `Node` button to visualize the control points and check whether any points in the region of interest are missing. Press `RGB` to switch back to the Gaussian rendering. 194 | 195 | 4. The above are operational tricks for editing with `arap_from_init`; they require a sufficient understanding of ARAP deformation, or some practice and experimentation, to develop a clear sense of how to operate and achieve the desired deformation results. 196 | 197 | ## SOTA Performance 198 | 199 | Quantitative comparison on D-NeRF datasets. We present the average PSNR/SSIM/LPIPS (VGG) values for novel view synthesis on dynamic scenes from D-NeRF, with each cell colored to indicate the best, second-best, and third-best result. 200 |
201 | 202 |
203 | 204 | ## Dataset 205 | 206 | Our datareader script can recognize and read the following dataset formats automatically: 207 | 208 | * [D-NeRF](https://www.albertpumarola.com/research/D-NeRF/index.html): dynamic scenes of synthetic objects ([download](https://www.dropbox.com/s/0bf6fl0ye2vz3vr/data.zip?e=1&dl=0)) 209 | 210 | * [NeRF-DS](https://jokeryan.github.io/projects/nerf-ds/): dynamic scenes of specular objects ([download](https://github.com/JokerYan/NeRF-DS/releases/tag/v0.1-pre-release)) 211 | 212 | * Self-captured videos: 1. install [MiVOS](https://github.com/hkchengrex/MiVOS) and place [interactive_invoke.py](data_tools/interactive_invoke.py) under the installed path. 2. Set the video path in [phone_catch.py](data_tools/phone_catch.py) and run ```python ./data_tools/phone_catch.py``` to perform frame extraction, video segmentation, and COLMAP pose estimation in sequence. Please refer to [NeRF-Texture](https://github.com/yihua7/NeRF-Texture) for detailed tutorials. 213 | 214 | * Static self-captured scenes: For self-captured static scenes, editing is now also supported! Simply include the `--is_scene_static` argument and you are good to go! 215 | 216 | **Important Note for Using Self-captured Videos**: 217 | 218 | * Please remember to remove the `--is_blender` option from your command; this option causes the control points to be initialized from random point clouds instead of COLMAP point clouds. 219 | * Additionally, you can remove `--gt_alpha_mask_as_scene_mask` and add `--gt_alpha_mask_as_dynamic_mask --gs_with_motion_mask` if you want to model both the dynamic foreground masked by MiVOS and the static background simultaneously. 220 | * If control point initialization still fails after removing `--is_blender`, please use the option `--init_isotropic_gs_with_all_colmap_pcl`. This will initialize the isotropic Gaussians with all COLMAP point clouds, which helps avoid the risk of the control points dying out. 221 | * The dimension of the hyper coordinates that separate close but disconnected parts can be set to 2 to avoid slow rendering: `--hyper_dim 2`. 222 | 223 | 224 | ## Acknowledgement 225 | 226 | * This framework has been adapted from the notable [Deformable 3D Gaussians](https://github.com/ingra14m/Deformable-3D-Gaussians), an excellent and pioneering work by [Ziyi Yang](https://github.com/ingra14m). 227 | ``` 228 | @article{yang2023deformable3dgs, 229 | title={Deformable 3D Gaussians for High-Fidelity Monocular Dynamic Scene Reconstruction}, 230 | author={Yang, Ziyi and Gao, Xinyu and Zhou, Wen and Jiao, Shaohui and Zhang, Yuqing and Jin, Xiaogang}, 231 | journal={arXiv preprint arXiv:2309.13101}, 232 | year={2023} 233 | } 234 | ``` 235 | 236 | * Credits to the authors of [3D Gaussians](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) for their excellent code.
237 | ``` 238 | @Article{kerbl3Dgaussians, 239 | author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George}, 240 | title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering}, 241 | journal = {ACM Transactions on Graphics}, 242 | number = {4}, 243 | volume = {42}, 244 | month = {July}, 245 | year = {2023}, 246 | url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/} 247 | } 248 | ``` 249 | 250 | ## Citing 251 | If you find our work useful, please consider citing: 252 | ```BibTeX 253 | @article{huang2023sc, 254 | title={SC-GS: Sparse-Controlled Gaussian Splatting for Editable Dynamic Scenes}, 255 | author={Huang, Yi-Hua and Sun, Yang-Tian and Yang, Ziyi and Lyu, Xiaoyang and Cao, Yan-Pei and Qi, Xiaojuan}, 256 | journal={arXiv preprint arXiv:2312.14937}, 257 | year={2023} 258 | } 259 | ``` 260 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene, DeformModel 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from utils.pose_utils import pose_spherical 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams 23 | from gaussian_renderer import GaussianModel 24 | import imageio 25 | import numpy as np 26 | from pytorch_msssim import ms_ssim 27 | from piq import LPIPS 28 | lpips = LPIPS() 29 | from utils.image_utils import ssim as ssim_func 30 | from utils.image_utils import psnr, lpips, alex_lpips 31 | 32 | 33 | def render_set(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 34 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 35 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 36 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 37 | 38 | makedirs(render_path, exist_ok=True) 39 | makedirs(gts_path, exist_ok=True) 40 | makedirs(depth_path, exist_ok=True) 41 | 42 | # Measurement 43 | psnr_list, ssim_list, lpips_list = [], [], [] 44 | ms_ssim_list, alex_lpips_list = [], [] 45 | 46 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 47 | renderings = [] 48 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 49 | if load2gpt_on_the_fly: 50 | view.load2device() 51 | fid = view.fid 52 | xyz = gaussians.get_xyz 53 | if deform.name == 'mlp': 54 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 55 | elif deform.name == 'node': 56 | time_input = deform.deform.expand_time(fid) 57 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask) 58 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 59 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, 
d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res) 60 | alpha = results["alpha"] 61 | rendering = torch.clamp(torch.cat([results["render"], alpha]), 0.0, 1.0) 62 | 63 | # Measurement 64 | image = rendering[:3] 65 | gt_image = torch.clamp(view.original_image.to("cuda"), 0.0, 1.0) 66 | psnr_list.append(psnr(image[None], gt_image[None]).mean()) 67 | ssim_list.append(ssim_func(image[None], gt_image[None], data_range=1.).mean()) 68 | lpips_list.append(lpips(image[None], gt_image[None]).mean()) 69 | ms_ssim_list.append(ms_ssim(image[None], gt_image[None], data_range=1.).mean()) 70 | alex_lpips_list.append(alex_lpips(image[None], gt_image[None]).mean()) 71 | 72 | renderings.append(to8b(rendering.cpu().numpy())) 73 | depth = results["depth"] 74 | depth = depth / (depth.max() + 1e-5) 75 | 76 | gt = view.original_image[0:4, :, :] 77 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 78 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 79 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 80 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 81 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 82 | 83 | # Measurement 84 | psnr_test = torch.stack(psnr_list).mean() 85 | ssim_test = torch.stack(ssim_list).mean() 86 | lpips_test = torch.stack(lpips_list).mean() 87 | ms_ssim_test = torch.stack(ms_ssim_list).mean() 88 | alex_lpips_test = torch.stack(alex_lpips_list).mean() 89 | print("\n[ITER {}] Evaluating {}: PSNR {} SSIM {} LPIPS {} MS SSIM{} ALEX_LPIPS {}".format(iteration, name, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test)) 90 | 91 | 92 | def interpolate_time(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 93 | render_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "renders") 94 | depth_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "depth") 95 | 96 | makedirs(render_path, exist_ok=True) 97 | makedirs(depth_path, exist_ok=True) 98 | 99 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 100 | 101 | frame = 150 102 | idx = torch.randint(0, len(views), (1,)).item() 103 | view = views[idx] 104 | renderings = [] 105 | for t in tqdm(range(0, frame, 1), desc="Rendering progress"): 106 | fid = torch.Tensor([t / (frame - 1)]).cuda() 107 | xyz = gaussians.get_xyz 108 | if deform.name == 'deform': 109 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 110 | elif deform.name == 'node': 111 | time_input = deform.deform.expand_time(fid) 112 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask) 113 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 114 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res) 115 | rendering = results["render"] 116 | renderings.append(to8b(rendering.cpu().numpy())) 117 | depth = results["depth"] 118 | depth = depth / (depth.max() + 1e-5) 119 | 120 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(t) + ".png")) 121 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(t) + ".png")) 122 | 123 | renderings = 
np.stack(renderings, 0).transpose(0, 2, 3, 1) 124 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 125 | 126 | 127 | def interpolate_all(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 128 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders") 129 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth") 130 | 131 | makedirs(render_path, exist_ok=True) 132 | makedirs(depth_path, exist_ok=True) 133 | 134 | frame = 150 135 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 0) 136 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 137 | 138 | idx = torch.randint(0, len(views), (1,)).item() 139 | view = views[idx] # Choose a specific time for rendering 140 | 141 | renderings = [] 142 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 143 | fid = torch.Tensor([i / (frame - 1)]).cuda() 144 | 145 | matrix = np.linalg.inv(np.array(pose)) 146 | R = -np.transpose(matrix[:3, :3]) 147 | R[:, 0] = -R[:, 0] 148 | T = -matrix[:3, 3] 149 | 150 | view.reset_extrinsic(R, T) 151 | 152 | xyz = gaussians.get_xyz 153 | if deform.name == 'mlp': 154 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 155 | elif deform.name == 'node': 156 | time_input = deform.deform.expand_time(fid) 157 | 158 | d_values = deform.step(xyz.detach(), time_input, feature=gaussians.feature, motion_mask=gaussians.motion_mask) 159 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 160 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res) 161 | rendering = torch.clamp(results["render"], 0.0, 1.0) 162 | renderings.append(to8b(rendering.cpu().numpy())) 163 | depth = results["depth"] 164 | depth = depth / (depth.max() + 1e-5) 165 | 166 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 167 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 168 | 169 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 170 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 171 | 172 | 173 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool, mode: str, load2device_on_the_fly=False): 174 | with torch.no_grad(): 175 | 176 | deform = DeformModel(K=dataset.K, deform_type=dataset.deform_type, is_blender=dataset.is_blender, skinning=dataset.skinning, hyper_dim=dataset.hyper_dim, node_num=dataset.node_num, pred_opacity=dataset.pred_opacity, pred_color=dataset.pred_color, use_hash=dataset.use_hash, hash_time=dataset.hash_time, d_rot_as_res=dataset.d_rot_as_res, local_frame=dataset.local_frame, progressive_brand_time=dataset.progressive_brand_time, max_d_scale=dataset.max_d_scale) 177 | deform.load_weights(dataset.model_path, iteration=iteration) 178 | 179 | gs_fea_dim = deform.deform.node_num if dataset.skinning and deform.name == 'node' else dataset.hyper_dim 180 | gaussians = GaussianModel(dataset.sh_degree, fea_dim=gs_fea_dim, with_motion_mask=dataset.gs_with_motion_mask) 181 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 182 | 183 | bg_color = [1, 1, 
1] if dataset.white_background else [0, 0, 0] 184 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 185 | 186 | if mode == "render": 187 | render_func = render_set 188 | elif mode == "time": 189 | render_func = interpolate_time 190 | else: 191 | render_func = interpolate_all 192 | 193 | if not skip_train: 194 | render_func(dataset.model_path, load2device_on_the_fly, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, deform) 195 | 196 | if not skip_test: 197 | render_func(dataset.model_path, load2device_on_the_fly, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, deform) 198 | 199 | 200 | if __name__ == "__main__": 201 | # Set up command line argument parser 202 | parser = ArgumentParser(description="Testing script parameters") 203 | model = ModelParams(parser, sentinel=True) 204 | pipeline = PipelineParams(parser) 205 | op = OptimizationParams(parser) 206 | parser.add_argument("--iteration", default=-1, type=int) 207 | parser.add_argument("--skip_train", action="store_true") 208 | parser.add_argument("--skip_test", action="store_true") 209 | parser.add_argument("--quiet", action="store_true") 210 | parser.add_argument("--mode", default='render', choices=['render', 'time', 'view', 'all', 'pose', 'original']) 211 | 212 | parser.add_argument('--gui', action='store_true', help="start a GUI") 213 | parser.add_argument('--W', type=int, default=800, help="GUI width") 214 | parser.add_argument('--H', type=int, default=800, help="GUI height") 215 | parser.add_argument('--elevation', type=float, default=0, help="default GUI camera elevation") 216 | parser.add_argument('--radius', type=float, default=5, help="default GUI camera radius from center") 217 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") 218 | 219 | parser.add_argument('--ip', type=str, default="127.0.0.1") 220 | parser.add_argument('--port', type=int, default=6009) 221 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 222 | parser.add_argument("--test_iterations", nargs="+", type=int, 223 | default=[5000, 6000, 7_000] + list(range(10000, 80_0001, 1000))) 224 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 10_000, 20_000, 30_000, 40000]) 225 | # parser.add_argument("--quiet", action="store_true") 226 | parser.add_argument("--deform-type", type=str, default='mlp') 227 | 228 | args = get_combined_args(parser) 229 | if not args.model_path.endswith(args.deform_type): 230 | args.model_path = os.path.join(os.path.dirname(os.path.normpath(args.model_path)), os.path.basename(os.path.normpath(args.model_path)) + f'_{args.deform_type}') 231 | print("Rendering " + args.model_path) 232 | 233 | # Initialize system state (RNG) 234 | safe_state(args.quiet) 235 | 236 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode, load2device_on_the_fly=args.load2gpu_on_the_fly) 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | opencv-python==4.5.5.62 3 | Pillow==7.0.0 4 | PyYAML==6.0 5 | scipy==1.10.1 6 | tensorboard==2.14.0 7 | torch==1.12.1+cu113 # or any later versions 8 | tqdm==4.66.1 9 | imageio 10 | plyfile 11 | piq 12 | dearpygui 13 | lpips 14 | pytorch_msssim 15 | matplotlib 16 | scikit-image 17 | 
git+https://github.com/Po-Hsun-Su/pytorch-ssim.git 18 | git+https://github.com/facebookresearch/pytorch3d.git 19 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.deform_model import DeformModel 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 21 | 22 | 23 | class Scene: 24 | gaussians: GaussianModel 25 | 26 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 27 | """b 28 | :param path: Path to colmap scene main folder. 29 | """ 30 | self.model_path = args.model_path 31 | self.loaded_iter = None 32 | self.gaussians = gaussians 33 | 34 | if load_iteration: 35 | if load_iteration == -1: 36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 37 | else: 38 | self.loaded_iter = load_iteration 39 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 40 | 41 | self.train_cameras = {} 42 | self.test_cameras = {} 43 | 44 | if os.path.exists(os.path.join(args.source_path, "sparse")) or os.path.exists(os.path.join(args.source_path, "colmap_sparse")): 45 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 47 | print("Found transforms_train.json file, assuming Blender data set!") 48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 49 | elif os.path.exists(os.path.join(args.source_path, "cameras_sphere.npz")): 50 | print("Found cameras_sphere.npz file, assuming DTU data set!") 51 | scene_info = sceneLoadTypeCallbacks["DTU"](args.source_path, "cameras_sphere.npz", "cameras_sphere.npz") 52 | elif os.path.exists(os.path.join(args.source_path, "dataset.json")): 53 | print("Found dataset.json file, assuming Nerfies data set!") 54 | scene_info = sceneLoadTypeCallbacks["nerfies"](args.source_path, args.eval) 55 | elif os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")): 56 | print("Found calibration_full.json, assuming Neu3D data set!") 57 | scene_info = sceneLoadTypeCallbacks["plenopticVideo"](args.source_path, args.eval, 24) 58 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 59 | print("Found calibration_full.json, assuming Dynamic-360 data set!") 60 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.images, args.eval) 61 | elif os.path.exists(os.path.join(args.source_path, "train_meta.json")): 62 | print("Found train_meta.json, assuming CMU data set!") 63 | scene_info = sceneLoadTypeCallbacks["CMU"](args.source_path) 64 | else: 65 | assert False, "Could not recognize scene type!" 
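        # The chain above infers the scene type from marker files in source_path:
        # sparse/ or colmap_sparse/ -> COLMAP, transforms_train.json -> Blender-style,
        # cameras_sphere.npz -> DTU, dataset.json -> Nerfies, poses_bounds.npy -> Neu3D,
        # transforms.json -> Dynamic-360, train_meta.json -> CMU.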
66 | 67 | if not self.loaded_iter: 68 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file: 69 | dest_file.write(src_file.read()) 70 | json_cams = [] 71 | camlist = [] 72 | if scene_info.test_cameras: 73 | camlist.extend(scene_info.test_cameras) 74 | if scene_info.train_cameras: 75 | camlist.extend(scene_info.train_cameras) 76 | for id, cam in enumerate(camlist): 77 | json_cams.append(camera_to_JSON(id, cam)) 78 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 79 | json.dump(json_cams, file) 80 | 81 | # Read flow data 82 | self.flow_dir = os.path.join(args.source_path, "raft_neighbouring") 83 | flow_list = os.listdir(self.flow_dir) if os.path.exists(self.flow_dir) else [] 84 | flow_dirs_list = [] 85 | for cam in scene_info.train_cameras: 86 | flow_dirs_list.append([os.path.join(self.flow_dir, flow_dir) for flow_dir in flow_list if flow_dir.startswith(cam.image_name+'.')]) 87 | 88 | # if shuffle: 89 | # random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 90 | # random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 91 | 92 | self.cameras_extent = scene_info.nerf_normalization["radius"] 93 | 94 | for resolution_scale in resolution_scales: 95 | print("Loading Training Cameras") 96 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, flow_dirs_list) 97 | print("Loading Test Cameras") 98 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 99 | 100 | if self.loaded_iter: 101 | self.gaussians.load_ply(os.path.join(self.model_path, 102 | "point_cloud", 103 | "iteration_" + str(self.loaded_iter), 104 | "point_cloud.ply"), 105 | og_number_points=len(scene_info.point_cloud.points)) 106 | else: 107 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 108 | 109 | def save(self, iteration): 110 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 111 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 112 | 113 | def getTrainCameras(self, scale=1.0): 114 | return self.train_cameras[scale] 115 | 116 | def getTestCameras(self, scale=1.0): 117 | return self.test_cameras[scale] 118 | -------------------------------------------------------------------------------- /scene/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /scene/__pycache__/cameras.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/cameras.cpython-38.pyc -------------------------------------------------------------------------------- /scene/__pycache__/colmap_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/colmap_loader.cpython-38.pyc -------------------------------------------------------------------------------- /scene/__pycache__/dataset_readers.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/dataset_readers.cpython-38.pyc -------------------------------------------------------------------------------- /scene/__pycache__/deform_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/deform_model.cpython-38.pyc -------------------------------------------------------------------------------- /scene/__pycache__/gaussian_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SC-GS/d052a57d000cbf14e4bdf993f30102376cb3effa/scene/__pycache__/gaussian_model.cpython-38.pyc -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class Camera(nn.Module): 19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", fid=None, depth=None, flow_dirs=[]): 20 | super(Camera, self).__init__() 21 | 22 | self.uid = uid 23 | self.colmap_id = colmap_id 24 | self.R = R 25 | self.T = T 26 | self.FoVx = FoVx 27 | self.FoVy = FoVy 28 | self.image_name = image_name 29 | self.flow_dirs = flow_dirs 30 | 31 | try: 32 | self.data_device = torch.device(data_device) 33 | except Exception as e: 34 | print(e) 35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 36 | self.data_device = torch.device("cuda") 37 | 38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 39 | self.fid = torch.Tensor(np.array([fid])).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None 43 | self.gt_alpha_mask = gt_alpha_mask 44 | 45 | if gt_alpha_mask is not None: 46 | self.gt_alpha_mask = self.gt_alpha_mask.to(self.data_device) 47 | # self.original_image *= gt_alpha_mask.to(self.data_device) 48 | 49 | self.zfar = 100.0 50 | self.znear = 0.01 51 | 52 | self.trans = trans 53 | self.scale = scale 54 | 55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to(self.data_device) 56 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).to(self.data_device) 57 | self.full_proj_transform = ( 58 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 59 | self.camera_center = self.world_view_transform.inverse()[3, :3] 60 | 61 | def reset_extrinsic(self, R, T): 62 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, 
self.scale)).transpose(0, 1).cuda() 63 | self.full_proj_transform = ( 64 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 65 | self.camera_center = self.world_view_transform.inverse()[3, :3] 66 | 67 | def load2device(self, data_device='cuda'): 68 | self.original_image = self.original_image.to(data_device) 69 | self.world_view_transform = self.world_view_transform.to(data_device) 70 | self.projection_matrix = self.projection_matrix.to(data_device) 71 | self.full_proj_transform = self.full_proj_transform.to(data_device) 72 | self.camera_center = self.camera_center.to(data_device) 73 | self.fid = self.fid.to(data_device) 74 | 75 | 76 | class MiniCam: 77 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 78 | self.image_width = width 79 | self.image_height = height 80 | self.FoVy = fovy 81 | self.FoVx = fovx 82 | self.znear = znear 83 | self.zfar = zfar 84 | self.world_view_transform = world_view_transform 85 | self.full_proj_transform = full_proj_transform 86 | view_inv = torch.inverse(self.world_view_transform) 87 | self.camera_center = view_inv[3][:3] 88 | 89 | def reset_extrinsic(self, R, T): 90 | self.world_view_transform = torch.tensor(getWorld2View2(R, T)).transpose(0, 1).cuda() 91 | self.full_proj_transform = ( 92 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 93 | self.camera_center = self.world_view_transform.inverse()[3, :3] 94 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) 54 | 55 | 56 | def rotmat2qvec(R): 57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 58 | K = np.array([ 59 | [Rxx - Ryy - Rzz, 0, 0, 0], 60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 63 | eigvals, eigvecs = np.linalg.eigh(K) 64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 65 | if qvec[0] < 0: 66 | qvec *= -1 67 | return qvec 68 | 69 | 70 | class Image(BaseImage): 71 | def qvec2rotmat(self): 72 | return qvec2rotmat(self.qvec) 73 | 74 | 75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 76 | """Read and unpack the next bytes from a binary file. 77 | :param fid: 78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 80 | :param endian_character: Any of {@, =, <, >, !} 81 | :return: Tuple of read and unpacked values. 
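    Example: read_next_bytes(fid, 8, "Q") returns a 1-tuple holding one unsigned 64-bit integer.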
82 | """ 83 | data = fid.read(num_bytes) 84 | return struct.unpack(endian_character + format_char_sequence, data) 85 | 86 | 87 | def read_points3D_text(path): 88 | """ 89 | see: src/base/reconstruction.cc 90 | void Reconstruction::ReadPoints3DText(const std::string& path) 91 | void Reconstruction::WritePoints3DText(const std::string& path) 92 | """ 93 | xyzs = None 94 | rgbs = None 95 | errors = None 96 | with open(path, "r") as fid: 97 | while True: 98 | line = fid.readline() 99 | if not line: 100 | break 101 | line = line.strip() 102 | if len(line) > 0 and line[0] != "#": 103 | elems = line.split() 104 | xyz = np.array(tuple(map(float, elems[1:4]))) 105 | rgb = np.array(tuple(map(int, elems[4:7]))) 106 | error = np.array(float(elems[7])) 107 | if xyzs is None: 108 | xyzs = xyz[None, ...] 109 | rgbs = rgb[None, ...] 110 | errors = error[None, ...] 111 | else: 112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 114 | errors = np.append(errors, error[None, ...], axis=0) 115 | return xyzs, rgbs, errors 116 | 117 | 118 | def read_points3D_binary(path_to_model_file): 119 | """ 120 | see: src/base/reconstruction.cc 121 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 122 | void Reconstruction::WritePoints3DBinary(const std::string& path) 123 | """ 124 | 125 | with open(path_to_model_file, "rb") as fid: 126 | num_points = read_next_bytes(fid, 8, "Q")[0] 127 | 128 | xyzs = np.empty((num_points, 3)) 129 | rgbs = np.empty((num_points, 3)) 130 | errors = np.empty((num_points, 1)) 131 | 132 | for p_id in range(num_points): 133 | binary_point_line_properties = read_next_bytes( 134 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 135 | xyz = np.array(binary_point_line_properties[1:4]) 136 | rgb = np.array(binary_point_line_properties[4:7]) 137 | error = np.array(binary_point_line_properties[7]) 138 | track_length = read_next_bytes( 139 | fid, num_bytes=8, format_char_sequence="Q")[0] 140 | track_elems = read_next_bytes( 141 | fid, num_bytes=8 * track_length, 142 | format_char_sequence="ii" * track_length) 143 | xyzs[p_id] = xyz 144 | rgbs[p_id] = rgb 145 | errors[p_id] = error 146 | return xyzs, rgbs, errors 147 | 148 | 149 | def read_intrinsics_text(path): 150 | """ 151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 152 | """ 153 | cameras = {} 154 | with open(path, "r") as fid: 155 | while True: 156 | line = fid.readline() 157 | if not line: 158 | break 159 | line = line.strip() 160 | if len(line) > 0 and line[0] != "#": 161 | elems = line.split() 162 | camera_id = int(elems[0]) 163 | model = elems[1] 164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 165 | width = int(elems[2]) 166 | height = int(elems[3]) 167 | params = np.array(tuple(map(float, elems[4:]))) 168 | cameras[camera_id] = Camera(id=camera_id, model=model, 169 | width=width, height=height, 170 | params=params) 171 | return cameras 172 | 173 | 174 | def read_extrinsics_binary(path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::ReadImagesBinary(const std::string& path) 178 | void Reconstruction::WriteImagesBinary(const std::string& path) 179 | """ 180 | images = {} 181 | with open(path_to_model_file, "rb") as fid: 182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 183 | for _ in range(num_reg_images): 184 | binary_image_properties = read_next_bytes( 185 | fid, num_bytes=64, format_char_sequence="idddddddi") 186 | 
image_id = binary_image_properties[0] 187 | qvec = np.array(binary_image_properties[1:5]) 188 | tvec = np.array(binary_image_properties[5:8]) 189 | camera_id = binary_image_properties[8] 190 | image_name = "" 191 | current_char = read_next_bytes(fid, 1, "c")[0] 192 | while current_char != b"\x00": # look for the ASCII 0 entry 193 | image_name += current_char.decode("utf-8") 194 | current_char = read_next_bytes(fid, 1, "c")[0] 195 | num_points2D = read_next_bytes(fid, num_bytes=8, 196 | format_char_sequence="Q")[0] 197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, 198 | format_char_sequence="ddq" * num_points2D) 199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 200 | tuple(map(float, x_y_id_s[1::3]))]) 201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 202 | images[image_id] = Image( 203 | id=image_id, qvec=qvec, tvec=tvec, 204 | camera_id=camera_id, name=image_name, 205 | xys=xys, point3D_ids=point3D_ids) 206 | return images 207 | 208 | 209 | def read_intrinsics_binary(path_to_model_file): 210 | """ 211 | see: src/base/reconstruction.cc 212 | void Reconstruction::WriteCamerasBinary(const std::string& path) 213 | void Reconstruction::ReadCamerasBinary(const std::string& path) 214 | """ 215 | cameras = {} 216 | with open(path_to_model_file, "rb") as fid: 217 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 218 | for _ in range(num_cameras): 219 | camera_properties = read_next_bytes( 220 | fid, num_bytes=24, format_char_sequence="iiQQ") 221 | camera_id = camera_properties[0] 222 | model_id = camera_properties[1] 223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 224 | width = camera_properties[2] 225 | height = camera_properties[3] 226 | num_params = CAMERA_MODEL_IDS[model_id].num_params 227 | params = read_next_bytes(fid, num_bytes=8 * num_params, 228 | format_char_sequence="d" * num_params) 229 | cameras[camera_id] = Camera(id=camera_id, 230 | model=model_name, 231 | width=width, 232 | height=height, 233 | params=np.array(params)) 234 | assert len(cameras) == num_cameras 235 | return cameras 236 | 237 | 238 | def read_extrinsics_text(path): 239 | """ 240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 241 | """ 242 | images = {} 243 | with open(path, "r") as fid: 244 | while True: 245 | line = fid.readline() 246 | if not line: 247 | break 248 | line = line.strip() 249 | if len(line) > 0 and line[0] != "#": 250 | elems = line.split() 251 | image_id = int(elems[0]) 252 | qvec = np.array(tuple(map(float, elems[1:5]))) 253 | tvec = np.array(tuple(map(float, elems[5:8]))) 254 | camera_id = int(elems[8]) 255 | image_name = elems[9] 256 | elems = fid.readline().split() 257 | xys = np.column_stack([tuple(map(float, elems[0::3])), 258 | tuple(map(float, elems[1::3]))]) 259 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 260 | images[image_id] = Image( 261 | id=image_id, qvec=qvec, tvec=tvec, 262 | camera_id=camera_id, name=image_name, 263 | xys=xys, point3D_ids=point3D_ids) 264 | return images 265 | 266 | 267 | def read_colmap_bin_array(path): 268 | """ 269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 270 | 271 | :param path: path to the colmap binary file. 
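    The file begins with an ASCII header of the form "width&height&channels&"; the remaining bytes are raw float32 values stored in column-major (Fortran) order.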
272 | :return: nd array with the floating point values in the value 273 | """ 274 | with open(path, "rb") as fid: 275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 276 | usecols=(0, 1, 2), dtype=int) 277 | fid.seek(0) 278 | num_delimiter = 0 279 | byte = fid.read(1) 280 | while True: 281 | if byte == b"&": 282 | num_delimiter += 1 283 | if num_delimiter >= 3: 284 | break 285 | byte = fid.read(1) 286 | array = np.fromfile(fid, np.float32) 287 | array = array.reshape((width, height, channels), order="F") 288 | return np.transpose(array, (1, 0, 2)).squeeze() 289 | -------------------------------------------------------------------------------- /scene/deform_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.time_utils import DeformNetwork, ControlNodeWarp, StaticNetwork 5 | import os 6 | from utils.system_utils import searchForMaxIteration 7 | from utils.general_utils import get_expon_lr_func 8 | 9 | 10 | model_dict = {'mlp': DeformNetwork, 'node': ControlNodeWarp, 'static': StaticNetwork} 11 | 12 | 13 | class DeformModel: 14 | def __init__(self, deform_type='node', is_blender=False, d_rot_as_res=True, **kwargs): 15 | self.deform = model_dict[deform_type](is_blender=is_blender, d_rot_as_res=d_rot_as_res, **kwargs).cuda() 16 | self.name = self.deform.name 17 | self.optimizer = None 18 | self.spatial_lr_scale = 5 19 | self.d_rot_as_res = d_rot_as_res 20 | 21 | @property 22 | def reg_loss(self): 23 | return self.deform.reg_loss 24 | 25 | def step(self, xyz, time_emb, iteration=0, **kwargs): 26 | return self.deform(xyz, time_emb, iteration=iteration, **kwargs) 27 | 28 | def train_setting(self, training_args): 29 | l = [ 30 | {'params': group['params'], 31 | 'lr': training_args.position_lr_init * self.spatial_lr_scale * training_args.deform_lr_scale, 32 | "name": group['name']} 33 | for group in self.deform.trainable_parameters() 34 | ] 35 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 36 | 37 | self.deform_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale * training_args.deform_lr_scale, lr_final=training_args.position_lr_final * training_args.deform_lr_scale, lr_delay_mult=training_args.position_lr_delay_mult, max_steps=training_args.deform_lr_max_steps) 38 | if self.name == 'node': 39 | self.deform.as_gaussians.training_setup(training_args) 40 | 41 | def save_weights(self, model_path, iteration): 42 | out_weights_path = os.path.join(model_path, "deform/iteration_{}".format(iteration)) 43 | os.makedirs(out_weights_path, exist_ok=True) 44 | torch.save(self.deform.state_dict(), os.path.join(out_weights_path, 'deform.pth')) 45 | 46 | def load_weights(self, model_path, iteration=-1): 47 | if iteration == -1: 48 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "deform")) 49 | else: 50 | loaded_iter = iteration 51 | weights_path = os.path.join(model_path, "deform/iteration_{}/deform.pth".format(loaded_iter)) 52 | if os.path.exists(weights_path): 53 | self.deform.load_state_dict(torch.load(weights_path)) 54 | return True 55 | else: 56 | return False 57 | 58 | def update_learning_rate(self, iteration): 59 | for param_group in self.optimizer.param_groups: 60 | if param_group["name"] == "deform" or param_group["name"] == "mlp" or 'node' in param_group['name']: 61 | lr = self.deform_scheduler_args(iteration) 62 | param_group['lr'] = lr 63 | return lr 64 | 65 | def 
densify(self, max_grad, x, x_grad, **kwargs): 66 | if self.name == 'node': 67 | self.deform.densify(max_grad=max_grad, optimizer=self.optimizer, x=x, x_grad=x_grad, **kwargs) 68 | else: 69 | return 70 | 71 | def update(self, iteration): 72 | self.deform.update(iteration) 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from scene import Scene 15 | import uuid 16 | from utils.image_utils import psnr, lpips, alex_lpips 17 | from utils.image_utils import ssim as ssim_func 18 | from piq import LPIPS 19 | lpips = LPIPS() 20 | from argparse import Namespace 21 | from pytorch_msssim import ms_ssim 22 | 23 | try: 24 | from torch.utils.tensorboard import SummaryWriter 25 | 26 | TENSORBOARD_FOUND = True 27 | except ImportError: 28 | TENSORBOARD_FOUND = False 29 | 30 | 31 | def prepare_output_and_logger(args): 32 | if not args.model_path: 33 | if os.getenv('OAR_JOB_ID'): 34 | unique_str = os.getenv('OAR_JOB_ID') 35 | else: 36 | unique_str = str(uuid.uuid4()) 37 | args.model_path = os.path.join("./output/", unique_str[0:10]) 38 | 39 | # Set up output folder 40 | print("Output folder: {}".format(args.model_path)) 41 | os.makedirs(args.model_path, exist_ok=True) 42 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 43 | cfg_log_f.write(str(Namespace(**vars(args)))) 44 | 45 | # Create Tensorboard writer 46 | tb_writer = None 47 | if TENSORBOARD_FOUND: 48 | tb_writer = SummaryWriter(args.model_path) 49 | else: 50 | print("Tensorboard not available: not logging progress") 51 | return tb_writer 52 | 53 | 54 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, renderArgs, deform, load2gpu_on_the_fly, progress_bar=None): 55 | if tb_writer: 56 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 57 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 58 | tb_writer.add_scalar('iter_time', elapsed, iteration) 59 | 60 | test_psnr = 0.0 61 | test_ssim = 0.0 62 | test_lpips = 1e10 63 | test_ms_ssim = 0.0 64 | test_alex_lpips = 1e10 65 | # Report test and samples of training set 66 | if iteration in testing_iterations: 67 | torch.cuda.empty_cache() 68 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()}, 69 | {'name': 'train', 70 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 71 | for config in validation_configs: 72 | if config['cameras'] and len(config['cameras']) > 0: 73 | # images = torch.tensor([], device="cuda") 74 | # gts = torch.tensor([], device="cuda") 75 | psnr_list, ssim_list, lpips_list, l1_list = [], [], [], [] 76 | ms_ssim_list, alex_lpips_list = [], [] 77 | for idx, viewpoint in enumerate(config['cameras']): 78 | if load2gpu_on_the_fly: 79 | viewpoint.load2device() 80 | fid = viewpoint.fid 81 | xyz = scene.gaussians.get_xyz 82 | 83 | if deform.name == 'mlp': 84 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 85 | elif deform.name == 'node': 86 | time_input = deform.deform.expand_time(fid) 
87 | else: 88 | time_input = 0 89 | 90 | d_values = deform.step(xyz.detach(), time_input, feature=scene.gaussians.feature, is_training=False, motion_mask=scene.gaussians.motion_mask, camera_center=viewpoint.camera_center) 91 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 92 | 93 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=deform.d_rot_as_res)["render"], 0.0, 1.0) 94 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 95 | 96 | l1_list.append(l1_loss(image[None], gt_image[None]).mean()) 97 | psnr_list.append(psnr(image[None], gt_image[None]).mean()) 98 | ssim_list.append(ssim_func(image[None], gt_image[None], data_range=1.).mean()) 99 | lpips_list.append(lpips(image[None], gt_image[None]).mean()) 100 | ms_ssim_list.append(ms_ssim(image[None], gt_image[None], data_range=1.).mean()) 101 | alex_lpips_list.append(alex_lpips(image[None], gt_image[None]).mean()) 102 | 103 | # images = torch.cat((images, image.unsqueeze(0)), dim=0) 104 | # gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) 105 | 106 | if load2gpu_on_the_fly: 107 | viewpoint.load2device('cpu') 108 | if tb_writer and (idx < 5): 109 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 110 | if iteration == testing_iterations[0]: 111 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 112 | 113 | l1_test = torch.stack(l1_list).mean() 114 | psnr_test = torch.stack(psnr_list).mean() 115 | ssim_test = torch.stack(ssim_list).mean() 116 | lpips_test = torch.stack(lpips_list).mean() 117 | ms_ssim_test = torch.stack(ms_ssim_list).mean() 118 | alex_lpips_test = torch.stack(alex_lpips_list).mean() 119 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0: 120 | test_psnr = psnr_test 121 | test_ssim = ssim_test 122 | test_lpips = lpips_test 123 | test_ms_ssim = ms_ssim_test 124 | test_alex_lpips = alex_lpips_test 125 | if progress_bar is None: 126 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {} MS SSIM{} ALEX_LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test)) 127 | else: 128 | progress_bar.set_description("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {} MS SSIM {} ALEX_LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test, ms_ssim_test, alex_lpips_test)) 129 | if tb_writer: 130 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 131 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 132 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', test_ssim, iteration) 133 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - lpips', test_lpips, iteration) 134 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ms-ssim', test_ms_ssim, iteration) 135 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - alex-lpips', test_alex_lpips, iteration) 136 | 137 | if tb_writer: 138 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 139 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 140 | torch.cuda.empty_cache() 141 
| 142 | return test_psnr, test_ssim, test_lpips, test_ms_ssim, test_alex_lpips 143 | 144 | -------------------------------------------------------------------------------- /train_gui.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_gui.py --source_path YOUR/PATH/TO/DATASET/jumpingjacks --model_path outputs/jumpingjacks --deform_type node --node_num 512 --is_blender --eval --gui --gt_alpha_mask_as_scene_mask --local_frame --resolution 2 --W 800 --H 800 -------------------------------------------------------------------------------- /train_gui_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class DeformKeypoints: 6 | def __init__(self) -> None: 7 | self.keypoints3d_list = [] # list of keypoints group 8 | self.keypoints_idx_list = [] # keypoints index 9 | self.keypoints3d_delta_list = [] 10 | self.selective_keypoints_idx_list = [] # keypoints index 11 | self.idx2group = {} 12 | 13 | self.selective_rotation_keypoints_idx_list = [] 14 | # self.rotation_idx2group = {} 15 | 16 | def get_kpt_idx(self,): 17 | return self.keypoints_idx_list 18 | 19 | def get_kpt(self,): 20 | return self.keypoints3d_list 21 | 22 | def get_kpt_delta(self,): 23 | return self.keypoints3d_delta_list 24 | 25 | def get_deformed_kpt_np(self, rate=1.): 26 | return np.array(self.keypoints3d_list) + np.array(self.keypoints3d_delta_list) * rate 27 | 28 | def add_kpts(self, keypoints_coord, keypoints_idx, expand=False): 29 | # keypoints3d: [N, 3], keypoints_idx: [N,], torch.tensor 30 | # self.selective_keypoints_idx_list.clear() 31 | selective_keypoints_idx_list = [] if not expand else self.selective_keypoints_idx_list 32 | for idx in range(len(keypoints_idx)): 33 | if not self.contain_kpt(keypoints_idx[idx].item()): 34 | selective_keypoints_idx_list.append(len(self.keypoints_idx_list)) 35 | self.keypoints_idx_list.append(keypoints_idx[idx].item()) 36 | self.keypoints3d_list.append(keypoints_coord[idx].cpu().numpy()) 37 | self.keypoints3d_delta_list.append(np.zeros_like(self.keypoints3d_list[-1])) 38 | 39 | for kpt_idx in keypoints_idx: 40 | self.idx2group[kpt_idx.item()] = selective_keypoints_idx_list 41 | 42 | self.selective_keypoints_idx_list = selective_keypoints_idx_list 43 | 44 | def contain_kpt(self, idx): 45 | # idx: int 46 | if idx in self.keypoints_idx_list: 47 | return True 48 | else: 49 | return False 50 | 51 | def select_kpt(self, idx): 52 | # idx: int 53 | # output: idx list of this group 54 | if idx in self.keypoints_idx_list: 55 | self.selective_keypoints_idx_list = self.idx2group[idx] 56 | 57 | def select_rotation_kpt(self, idx): 58 | if idx in self.keypoints_idx_list: 59 | self.selective_rotation_keypoints_idx_list = self.idx2group[idx] 60 | 61 | def get_rotation_center(self,): 62 | selected_rotation_points = self.get_deformed_kpt_np()[self.selective_rotation_keypoints_idx_list] 63 | return selected_rotation_points.mean(axis=0) 64 | 65 | def get_selective_center(self,): 66 | selected_points = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list] 67 | return selected_points.mean(axis=0) 68 | 69 | def delete_kpt(self, idx): 70 | for kidx in self.selective_keypoints_idx_list: 71 | list_idx = self.idx2group.pop(kidx) 72 | self.keypoints3d_delta_list.pop(list_idx) 73 | self.keypoints3d_list.pop(list_idx) 74 | self.keypoints_idx_list.pop(list_idx) 75 | 76 | def delete_batch_ktps(self, batch_idx): 77 | pass 78 | 79 | def update_delta(self, 
delta): 80 | # delta: [3,], np.array 81 | for idx in self.selective_keypoints_idx_list: 82 | self.keypoints3d_delta_list[idx] += delta 83 | 84 | def set_delta(self, delta): 85 | # delta: [N, 3], np.array 86 | for id, idx in enumerate(self.selective_keypoints_idx_list): 87 | self.keypoints3d_delta_list[idx] = delta[id] 88 | 89 | 90 | def set_rotation_delta(self, rot_mat): 91 | kpts3d = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list] 92 | kpts3d_mean = kpts3d.mean(axis=0) 93 | kpts3d = (kpts3d - kpts3d_mean) @ rot_mat.T + kpts3d_mean 94 | delta = kpts3d - np.array(self.keypoints3d_list)[self.selective_keypoints_idx_list] 95 | for id, idx in enumerate(self.selective_keypoints_idx_list): 96 | self.keypoints3d_delta_list[idx] = delta[id] 97 | -------------------------------------------------------------------------------- /utils/arap_deform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils.deform_utils import cal_laplacian, cal_connectivity_from_points,\ 4 | produce_edge_matrix_nfmt, lstsq_with_handles, cal_verts_deg, rigid_align 5 | from utils.other_utils import matrix_to_quaternion 6 | 7 | 8 | def cal_L_from_points(points, return_nn_idx=False): 9 | # points: (N, 3) 10 | Nv = len(points) 11 | L = torch.eye(Nv).cuda() 12 | 13 | radius = 0.3 # 14 | K = 10 15 | knn_res = ball_query(points[None], points[None], K=K, radius=radius, return_nn=False) 16 | nn_dist, nn_idx = knn_res.dists[0], knn_res.idx[0] # [Nv, K], [Nv, K] 17 | 18 | for idx, cur_nn_idx in enumerate(nn_idx): 19 | real_cur_nn_idx = cur_nn_idx[cur_nn_idx != -1] 20 | real_cur_nn_idx = real_cur_nn_idx[real_cur_nn_idx != idx] 21 | L[idx, idx] = len(real_cur_nn_idx) 22 | L[idx][real_cur_nn_idx] = -1 23 | 24 | if return_nn_idx: 25 | return L, nn_idx 26 | else: 27 | return L 28 | 29 | 30 | def mask_softmax(x, mask, dim=1): 31 | # x: (N, K), mask: (N, K) 0/1 32 | x = torch.exp(x) 33 | x = x * mask 34 | x = x / x.sum(dim=dim, keepdim=True) 35 | return x 36 | 37 | 38 | class ARAPDeformer: 39 | def __init__(self, verts, K=10, radius=0.3, point_mask=None, trajectory=None, node_radius=None) -> None: 40 | # verts: (N, 3), one_ring_idx: (N, K) 41 | self.device = verts.device 42 | self.verts = verts 43 | self.verts_copy = verts.clone() 44 | self.radius = radius 45 | self.K = K 46 | self.N = len(verts) 47 | 48 | self.ii, self.jj, self.nn, weight = cal_connectivity_from_points(self.verts, self.radius, self.K, trajectory=trajectory, node_radius=node_radius) 49 | self.L = cal_laplacian(Nv=self.N, ii=self.ii, jj=self.jj, nn=self.nn) 50 | # self.L = cal_L_from_points(points=self.verts) 51 | 52 | ##### add learnable deformation weights ##### 53 | self.vert_deg = cal_verts_deg(self.N, self.ii) 54 | # weight = torch.ones(self.N, K).float().cuda() # [Nv, K] 55 | # weight[self.ii, self.nn] = -1 / self.vert_deg[self.ii] 56 | self.weight = torch.nn.Parameter(weight, requires_grad=True) # [Nv, K] 57 | self.weight_mask = torch.zeros(self.N, K).float().cuda() # [Nv, K] 58 | self.weight_mask[self.ii, self.nn] = 1 59 | 60 | self.L_opt = torch.eye(self.N).cuda() # replace all the self.L with self.L_opt! s.t. weight is in [0,1], easy to optimize. 
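        # L_opt starts as the identity; cal_L_opt() below fills its off-diagonal entries from the learnable neighbour weights.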
61 | self.L_is_degenerate = False 62 | self.cal_L_opt() 63 | self.b = torch.mm(self.L_opt, self.verts) # [Nv, 3] 64 | 65 | self.point_mask = point_mask # [N,] 66 | 67 | def cal_L_opt(self): 68 | self.normalized_weight = self.weight 69 | self.L_opt[self.ii, self.jj] = - self.normalized_weight[self.ii, self.nn] # [Nv, Nv] 70 | self.L_is_degenerate = (torch.linalg.matrix_rank(self.L_opt) < self.L_opt.shape[0]) 71 | if self.L_is_degenerate: 72 | print("L_opt is not invertible, use pseudo inverse instead") 73 | 74 | def reset(self): 75 | self.verts = self.verts_copy.clone() 76 | 77 | def world_2_local_index(self, handle_idx): 78 | # handle_idx: [m,] 79 | # point mask [N,] 80 | # idx_offset = torch.cat([torch.zeros_like(self.point_mask[:1]), torch.cumsum(self.point_mask, dim=0)]) 81 | idx_offset = torch.cumsum(~self.point_mask, dim=0) 82 | handle_idx_offset = idx_offset[handle_idx] 83 | return handle_idx - handle_idx_offset 84 | 85 | 86 | def deform(self, handle_idx, handle_pos, init_verts=None, return_R=False): 87 | # handle_idx: (M, ), handle_pos: (M, 3) 88 | 89 | if self.point_mask is not None: 90 | handle_idx = self.world_2_local_index(handle_idx) 91 | 92 | ##### calculate b ##### 93 | ### b_fixed 94 | unknown_verts = [n for n in range(self.N) if n not in handle_idx.tolist()] # all unknown verts 95 | b_fixed = torch.zeros((self.N, 3), device=self.device) # factor to be subtracted from b, due to constraints 96 | for k, pos in zip(handle_idx, handle_pos): 97 | # b_fixed += torch.einsum("i,j->ij", self.L[:, k], pos) # [Nv,3] 98 | b_fixed += torch.einsum("i,j->ij", self.L_opt[:, k], pos) # [Nv,3] 99 | 100 | ### prepare for b_all 101 | P = produce_edge_matrix_nfmt(self.verts, (self.N, self.K, 3), self.ii, self.jj, self.nn, device=self.device) # [Nv, K, 3] 102 | if init_verts is None: 103 | p_prime = lstsq_with_handles(self.L_opt, self.L_opt@self.verts, handle_idx, handle_pos, A_is_degenarate=self.L_is_degenerate) 104 | else: 105 | p_prime = init_verts 106 | 107 | p_prime_seq = [p_prime] 108 | R = torch.eye(3)[None].repeat(self.N, 1,1).cuda() # compute rotations 109 | 110 | NUM_ITER = 3 111 | D = torch.diag_embed(self.normalized_weight, dim1=1, dim2=2) # [Nv, K, K] 112 | for _ in range(NUM_ITER): 113 | P_prime = produce_edge_matrix_nfmt(p_prime, (self.N, self.K, 3), self.ii, self.jj, self.nn, device=self.device) # [Nv, K, 3] 114 | ### Calculate covariance matrix in bulk 115 | S = torch.bmm(P.permute(0, 2, 1), torch.bmm(D, P_prime)) # [Nv, 3, 3] 116 | 117 | ## in the case of no deflection, set S = 0, such that R = I. 
This is to avoid numerical errors 118 | unchanged_verts = torch.unique(torch.where((P == P_prime).all(dim=1))[0]) # any verts which are undeformed 119 | S[unchanged_verts] = 0 120 | 121 | U, sig, W = torch.svd(S) 122 | R = torch.bmm(W, U.permute(0, 2, 1)) # compute rotations 123 | 124 | # Need to flip the column of U corresponding to smallest singular value 125 | # for any det(Ri) <= 0 126 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0 127 | if len(entries_to_flip) > 0: 128 | Umod = U.clone() 129 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each entry 130 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols 131 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1)) 132 | 133 | ### RHS of minimum energy equation 134 | Rsum_shape = (self.N, self.K, 3, 3) 135 | Rsum = torch.zeros(Rsum_shape).to(self.device) # Ri + Rj, as in eq (8) 136 | Rsum[self.ii, self.nn] = R[self.ii] + R[self.jj] 137 | 138 | ### Rsum has shape (V, max_neighbours, 3, 3). P has shape (V, max_neighbours, 3) 139 | ### To batch multiply, collapse first 2 dims into a single batch dim 140 | Rsum_batch, P_batch = Rsum.view(-1, 3, 3), P.view(-1, 3).unsqueeze(-1) 141 | 142 | # RHS of minimum energy equation 143 | b = 0.5 * (torch.bmm(Rsum_batch, P_batch).squeeze(-1).reshape(self.N, self.K, 3) * self.normalized_weight[...,None]).sum(dim=1) 144 | 145 | ### calculate p_prime 146 | p_prime = lstsq_with_handles(self.L_opt, b, handle_idx, handle_pos, A_is_degenarate=self.L_is_degenerate) 147 | 148 | p_prime_seq.append(p_prime) 149 | d_scaling = None 150 | 151 | if return_R: 152 | quat = matrix_to_quaternion(R) 153 | return p_prime, quat, d_scaling 154 | else: 155 | # return p_prime, p_prime_seq 156 | return p_prime 157 | 158 | 159 | 160 | if __name__ == "__main__": 161 | from pytorch3d.io import load_ply 162 | from pytorch3d.ops import ball_query 163 | import pickle 164 | with open("./control_kpt.pkl", "rb") as f: 165 | data = pickle.load(f) 166 | 167 | points = data["pts"] 168 | handle_idx = data["handle_idx"] 169 | handle_pos = data["handle_pos"] 170 | 171 | import trimesh 172 | trimesh.Trimesh(vertices=points).export('deformation_before.ply') 173 | 174 | #### prepare data 175 | points = torch.from_numpy(points).float().cuda() 176 | handle_idx = torch.tensor(handle_idx).long().cuda() 177 | handle_pos = torch.from_numpy(handle_pos).float().cuda() 178 | 179 | deformer = ARAPDeformer(points) 180 | 181 | with torch.no_grad(): 182 | points_prime, p_prime_seq = deformer.deform(handle_idx, handle_pos) 183 | 184 | trimesh.Trimesh(vertices=points_prime.cpu().numpy()).export('deformation_after.ply') 185 | 186 | from utils.deform_utils import cal_arap_error 187 | for p_prime in p_prime_seq: 188 | nodes_sequence = torch.cat([points[None], p_prime[None]], dim=0) 189 | arap_error = cal_arap_error(nodes_sequence, deformer.ii, deformer.jj, deformer.nn, K=deformer.K, weight=deformer.normalized_weight) 190 | print(arap_error) -------------------------------------------------------------------------------- /utils/bezier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BezierCurve: 5 | def __init__(self, points: np.ndarray) -> None: 6 | if points.ndim == 2: 7 | points = points[None] 8 | self.points = points # N, T, D 9 | self.T = points.shape[1] 10 | 11 | def __call__(self, t: float): 12 | assert 0 <= t <= 1, f't: {t} out of range 
[0, 1]!' 13 | return self.interpolate(t, self.points) 14 | 15 | def interpolate(self, t, points): 16 | if points.shape[1] < 2: 17 | raise ValueError(f"points shape error: {points.shape}") 18 | elif points.shape[1] == 2: 19 | point0, point1 = points[:, 0], points[:, 1] 20 | else: 21 | point0 = self.interpolate(t, points[:, :-1]) 22 | point1 = self.interpolate(t, points[:, 1:]) 23 | return (1 - t) * point0 + t * point1 24 | 25 | 26 | class PieceWiseLinear: 27 | def __init__(self, points: np.ndarray) -> None: 28 | if points.ndim == 2: 29 | points = points[None] 30 | self.points = points # N, T, D 31 | self.T = points.shape[1] 32 | 33 | def __call__(self, t: float): 34 | assert 0 <= t <= 1, f't: {t} out of range [0, 1]!' 35 | return self.interpolate(t, self.points) 36 | 37 | def interpolate(self, t, points): 38 | if points.shape[1] < 2: 39 | raise ValueError(f"points shape error: {points.shape}") 40 | else: 41 | t_scaled = t * (self.T - 1) 42 | t_floor = min(self.T - 2, max(0, int(np.floor(t_scaled)))) 43 | t_ceil = t_floor + 1 44 | point0, point1 = points[:, t_floor], points[:, t_ceil] 45 | return (t_ceil - t_scaled) * point0 + (t_scaled - t_floor) * point1 46 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, ArrayToTorch 15 | from utils.graphics_utils import fov2focal 16 | import json 17 | 18 | WARNED = False 19 | 20 | 21 | def loadCam(args, id, cam_info, resolution_scale, flow_dirs): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution)) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 1600 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | 44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 45 | 46 | gt_image = resized_image_rgb[:3, ...] 47 | loaded_mask = None 48 | 49 | if resized_image_rgb.shape[0] == 4: 50 | loaded_mask = resized_image_rgb[3:4, ...] 
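    # A 4-channel image carries an alpha channel; it is kept here and handed to the Camera below as gt_alpha_mask.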
51 | 52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 54 | image=gt_image, gt_alpha_mask=loaded_mask, 55 | image_name=cam_info.image_name, uid=id, 56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', fid=cam_info.fid, 57 | depth=cam_info.depth, flow_dirs=flow_dirs) 58 | 59 | 60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args, flow_dirs_list=None): 61 | camera_list = [] 62 | 63 | for id, c in enumerate(cam_infos): 64 | camera_list.append(loadCam(args, id, c, resolution_scale, [] if flow_dirs_list is None else flow_dirs_list[id])) 65 | 66 | return camera_list 67 | 68 | 69 | def camera_to_JSON(id, camera: Camera): 70 | Rt = np.zeros((4, 4)) 71 | Rt[:3, :3] = camera.R.transpose() 72 | Rt[:3, 3] = camera.T 73 | Rt[3, 3] = 1.0 74 | 75 | W2C = np.linalg.inv(Rt) 76 | pos = W2C[:3, 3] 77 | rot = W2C[:3, :3] 78 | serializable_array_2d = [x.tolist() for x in rot] 79 | camera_entry = { 80 | 'id': id, 81 | 'img_name': camera.image_name, 82 | 'width': camera.width, 83 | 'height': camera.height, 84 | 'position': pos.tolist(), 85 | 'rotation': serializable_array_2d, 86 | 'fy': fov2focal(camera.FovY, camera.height), 87 | 'fx': fov2focal(camera.FovX, camera.width) 88 | } 89 | return camera_entry 90 | 91 | 92 | def camera_nerfies_from_JSON(path, scale): 93 | """Loads a JSON camera into memory.""" 94 | with open(path, 'r') as fp: 95 | camera_json = json.load(fp) 96 | 97 | # Fix old camera JSON. 98 | if 'tangential' in camera_json: 99 | camera_json['tangential_distortion'] = camera_json['tangential'] 100 | 101 | return dict( 102 | orientation=np.array(camera_json['orientation']), 103 | position=np.array(camera_json['position']), 104 | focal_length=camera_json['focal_length'] * scale, 105 | principal_point=np.array(camera_json['principal_point']) * scale, 106 | skew=camera_json['skew'], 107 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], 108 | radial_distortion=np.array(camera_json['radial_distortion']), 109 | tangential_distortion=np.array(camera_json['tangential_distortion']), 110 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)), 111 | int(round(camera_json['image_size'][1] * scale)))), 112 | ) 113 | -------------------------------------------------------------------------------- /utils/deform_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch3d.loss.mesh_laplacian_smoothing import cot_laplacian 4 | from pytorch3d.ops import ball_query 5 | from pytorch3d.io import load_ply 6 | # try: 7 | # print('Using speed up torch_batch_svd!') 8 | # from torch_batch_svd import svd 9 | # except: 10 | # print('Use original torch svd!') 11 | svd = torch.svd 12 | import pytorch3d.ops 13 | 14 | 15 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 16 | r, i, j, k = torch.unbind(quaternions, -1) 17 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 18 | o = torch.stack( 19 | ( 20 | 1 - two_s * (j * j + k * k), 21 | two_s * (i * j - k * r), 22 | two_s * (i * k + j * r), 23 | two_s * (i * j + k * r), 24 | 1 - two_s * (i * i + k * k), 25 | two_s * (j * k - i * r), 26 | two_s * (i * k - j * r), 27 | two_s * (j * k + i * r), 28 | 1 - two_s * (i * i + j * j), 29 | ), 30 | -1, 31 | ) 32 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 33 | 34 | 35 | def produce_edge_matrix_nfmt(verts: torch.Tensor, edge_shape, ii, jj, nn, device="cuda") -> torch.Tensor: 36 | """Given a tensor of 
verts postion, p (V x 3), produce a tensor E, where, for neighbour list J, 37 | E_in = p_i - p_(J[n])""" 38 | 39 | E = torch.zeros(edge_shape).to(device) 40 | E[ii, nn] = verts[ii] - verts[jj] 41 | 42 | return E 43 | 44 | 45 | ####################### utils for arap ####################### 46 | 47 | def geodesic_distance_floyd(cur_node, K=8): 48 | node_num = cur_node.shape[0] 49 | nn_dist, nn_idx, _ = pytorch3d.ops.knn_points(cur_node[None], cur_node[None], None, None, K=K+1) 50 | nn_dist, nn_idx = nn_dist[0]**.5, nn_idx[0] 51 | dist_mat = torch.inf * torch.ones([node_num, node_num], dtype=torch.float32, device=cur_node.device) 52 | dist_mat.scatter_(dim=1, index=nn_idx, src=nn_dist) 53 | dist_mat = torch.minimum(dist_mat, dist_mat.T) 54 | for i in range(nn_dist.shape[0]): 55 | dist_mat = torch.minimum((dist_mat[:, i, None] + dist_mat[None, i, :]), dist_mat) 56 | return dist_mat 57 | 58 | def cal_connectivity_from_points(points=None, radius=0.1, K=10, trajectory=None, least_edge_num=3, node_radius=None, mode='nn', GraphK=4, adaptive_weighting=True): 59 | # input: [Nv,3] 60 | # output: information of edges 61 | # ii : [Ne,] the i th vert 62 | # jj: [Ne,] j th vert is connect to i th vert. 63 | # nn: , [Ne,] the n th neighbour of i th vert is j th vert. 64 | Nv = points.shape[0] if points is not None else trajectory.shape[0] 65 | if trajectory is None: 66 | if mode == 'floyd': 67 | dist_mat = geodesic_distance_floyd(points, K=GraphK) 68 | dist_mat = dist_mat ** 2 69 | mask = torch.eye(Nv).bool() 70 | dist_mat[mask] = torch.inf 71 | nn_dist, nn_idx = dist_mat.sort(dim=1) 72 | nn_dist, nn_idx = nn_dist[:, :K], nn_idx[:, :K] 73 | else: 74 | knn_res = pytorch3d.ops.knn_points(points[None], points[None], None, None, K=K+1) 75 | # Remove themselves 76 | nn_dist, nn_idx = knn_res.dists[0, :, 1:], knn_res.idx[0, :, 1:] # [Nv, K], [Nv, K] 77 | else: 78 | trajectory = trajectory.reshape([Nv, -1]) / trajectory.shape[1] # Average distance of trajectory 79 | if mode == 'floyd': 80 | dist_mat = geodesic_distance_floyd(trajectory, K=GraphK) 81 | dist_mat = dist_mat ** 2 82 | mask = torch.eye(Nv).bool() 83 | dist_mat[mask] = torch.inf 84 | nn_dist, nn_idx = dist_mat.sort(dim=1) 85 | nn_dist, nn_idx = nn_dist[:, :K], nn_idx[:, :K] 86 | else: 87 | knn_res = pytorch3d.ops.knn_points(trajectory[None], trajectory[None], None, None, K=K+1) 88 | # Remove themselves 89 | nn_dist, nn_idx = knn_res.dists[0, :, 1:], knn_res.idx[0, :, 1:] # [Nv, K], [Nv, K] 90 | 91 | # Make sure ranges are within the radius 92 | nn_idx[:, least_edge_num:] = torch.where(nn_dist[:, least_edge_num:] < radius ** 2, nn_idx[:, least_edge_num:], - torch.ones_like(nn_idx[:, least_edge_num:])) 93 | 94 | nn_dist[:, least_edge_num:] = torch.where(nn_dist[:, least_edge_num:] < radius ** 2, nn_dist[:, least_edge_num:], torch.ones_like(nn_dist[:, least_edge_num:]) * torch.inf) 95 | if adaptive_weighting: 96 | nn_dist_1d = nn_dist.reshape(-1) 97 | weight = torch.exp(-nn_dist / nn_dist_1d[~torch.isnan(nn_dist_1d) & ~torch.isinf(nn_dist_1d)].mean()) 98 | elif node_radius is None: 99 | weight = torch.exp(-nn_dist) 100 | else: 101 | nn_radius = node_radius[nn_idx] 102 | weight = torch.exp(-nn_dist / (2 * nn_radius ** 2)) 103 | weight = weight / weight.sum(dim=-1, keepdim=True) 104 | 105 | ii = torch.arange(Nv)[:, None].cuda().long().expand(Nv, K).reshape([-1]) 106 | jj = nn_idx.reshape([-1]) 107 | nn = torch.arange(K)[None].cuda().long().expand(Nv, K).reshape([-1]) 108 | mask = jj != -1 109 | ii, jj, nn = ii[mask], jj[mask], nn[mask] 110 | 111 | return 
ii, jj, nn, weight 112 | 113 | 114 | def cal_laplacian(Nv, ii, jj, nn): 115 | # input: Nv: int; ii, jj, nn: [Ne,] 116 | # output: laplacian_mat: [Nv, Nv] 117 | laplacian_mat = torch.zeros(Nv, Nv).cuda() 118 | laplacian_mat[ii, jj] = -1 119 | for idx in ii: 120 | laplacian_mat[idx, idx] += 1 # TODO test whether it is correct 121 | return laplacian_mat 122 | 123 | def cal_verts_deg(Nv, ii): 124 | # input: Nv: int; ii, jj, nn: [Ne,] 125 | # output: verts_deg: [Nv,] 126 | verts_deg = torch.zeros(Nv).cuda() 127 | for idx in ii: 128 | verts_deg[idx] += 1 129 | return verts_deg 130 | 131 | def estimate_rotation(source, target, ii, jj, nn, K=10, weight=None, sample_idx=None): 132 | # input: source, target: [Nv, 3]; ii, jj, nn: [Ne,], weight: [Nv, K] 133 | # output: rotation: [Nv, 3, 3] 134 | Nv = len(source) 135 | source_edge_mat = produce_edge_matrix_nfmt(source, (Nv, K, 3), ii, jj, nn) # [Nv, K, 3] 136 | target_edge_mat = produce_edge_matrix_nfmt(target, (Nv, K, 3), ii, jj, nn) # [Nv, K, 3] 137 | if weight is None: 138 | weight = torch.zeros(Nv, K).cuda() 139 | weight[ii, nn] = 1 140 | print("!!! Edge weight is None !!!") 141 | if sample_idx is not None: 142 | source_edge_mat = source_edge_mat[sample_idx] 143 | target_edge_mat = target_edge_mat[sample_idx] 144 | ### Calculate covariance matrix in bulk 145 | D = torch.diag_embed(weight, dim1=1, dim2=2) # [Nv, K, K] 146 | # S = torch.bmm(source_edge_mat.permute(0, 2, 1), target_edge_mat) # [Nv, 3, 3] 147 | S = torch.bmm(source_edge_mat.permute(0, 2, 1), torch.bmm(D, target_edge_mat)) # [Nv, 3, 3] 148 | ## in the case of no deflection, set S = 0, such that R = I. This is to avoid numerical errors 149 | unchanged_verts = torch.unique(torch.where((source_edge_mat == target_edge_mat).all(dim=1))[0]) # any verts which are undeformed 150 | S[unchanged_verts] = 0 151 | 152 | # t2 = time.time() 153 | U, sig, W = svd(S) 154 | R = torch.bmm(W, U.permute(0, 2, 1)) # compute rotations 155 | # t3 = time.time() 156 | 157 | # Need to flip the column of U corresponding to smallest singular value 158 | # for any det(Ri) <= 0 159 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0 160 | if len(entries_to_flip) > 0: 161 | Umod = U.clone() 162 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each entry 163 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols 164 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1)) 165 | # t4 = time.time() 166 | # print(f'0-1: {t1-t0}, 1-2: {t2-t1}, 2-3: {t3-t2}, 3-4: {t4-t3}') 167 | return R 168 | 169 | import time 170 | def cal_arap_error(nodes_sequence, ii, jj, nn, K=10, weight=None, sample_num=512): 171 | # input: nodes_sequence: [Nt, Nv, 3]; ii, jj, nn: [Ne,], weight: [Nv, K] 172 | # output: arap error: float 173 | Nt, Nv, _ = nodes_sequence.shape 174 | arap_error = 0 175 | if weight is None: 176 | weight = torch.zeros(Nv, K).cuda() 177 | weight[ii, nn] = 1 178 | source_edge_mat = produce_edge_matrix_nfmt(nodes_sequence[0], (Nv, K, 3), ii, jj, nn) # [Nv, K, 3] 179 | sample_idx = torch.arange(Nv).cuda() 180 | if Nv > sample_num: 181 | sample_idx = torch.from_numpy(np.random.choice(Nv, sample_num)).long().cuda() 182 | else: 183 | source_edge_mat = source_edge_mat[sample_idx] 184 | weight = weight[sample_idx] 185 | for idx in range(1, Nt): 186 | # t1 = time.time() 187 | with torch.no_grad(): 188 | rotation = estimate_rotation(nodes_sequence[0], nodes_sequence[idx], ii, jj, nn, K=K, 
weight=weight, sample_idx=sample_idx) # [Nv, 3, 3] 189 | # Compute energy 190 | target_edge_mat = produce_edge_matrix_nfmt(nodes_sequence[idx], (Nv, K, 3), ii, jj, nn) # [Nv, K, 3] 191 | target_edge_mat = target_edge_mat[sample_idx] 192 | rot_rigid = torch.bmm(rotation, source_edge_mat[sample_idx].permute(0, 2, 1)).permute(0, 2, 1) # [Nv, K, 3] 193 | stretch_vec = target_edge_mat - rot_rigid # stretch vector 194 | stretch_norm = (torch.norm(stretch_vec, dim=2) ** 2) # norm over (x,y,z) space 195 | arap_error += (weight * stretch_norm).sum() 196 | return arap_error 197 | 198 | def cal_L_from_points(points, return_nn_idx=False): 199 | # points: (N, 3) 200 | Nv = len(points) 201 | L = torch.eye(Nv).cuda() 202 | radius = 0.1 # 203 | K = 20 204 | knn_res = ball_query(points[None], points[None], K=K, radius=radius, return_nn=False) 205 | nn_dist, nn_idx = knn_res.dists[0], knn_res.idx[0] # [Nv, K], [Nv, K] 206 | for idx, cur_nn_idx in enumerate(nn_idx): 207 | real_cur_nn_idx = cur_nn_idx[cur_nn_idx != -1] 208 | real_cur_nn_idx = real_cur_nn_idx[real_cur_nn_idx != idx] 209 | L[idx, idx] = len(real_cur_nn_idx) 210 | L[idx][real_cur_nn_idx] = -1 211 | if return_nn_idx: 212 | return L, nn_idx 213 | else: 214 | return L 215 | 216 | def lstsq_with_handles(A, b, handle_idx, handle_pos, A_is_degenarate=False): 217 | b = b - A[:, handle_idx] @ handle_pos 218 | handle_mask = torch.zeros_like(A[:, 0], dtype=bool) 219 | handle_mask[handle_idx] = 1 220 | L = A[:, handle_mask.logical_not()] 221 | if not A_is_degenarate: 222 | x = torch.linalg.lstsq(L, b)[0] 223 | else: 224 | x = torch.linalg.pinv(L) @ b 225 | x_out = torch.zeros_like(b) 226 | x_out[handle_idx] = handle_pos 227 | x_out[handle_mask.logical_not()] = x 228 | return x_out 229 | 230 | def rigid_align(x, y): 231 | x_bar, y_bar = x.mean(0), y.mean(0) 232 | x, y = x - x_bar, y - y_bar 233 | S = x.permute(1, 0) @ y # 3 * 3 234 | U, _, W = svd(S) 235 | R = W @ U.permute(1, 0) 236 | t = y_bar - R @ x_bar 237 | x2y = x @ R.T + t 238 | return x2y, R, t 239 | 240 | def arap_deformation_loss(trajectory, node_radius=None, trajectory_rot=None, K=50, with_rot=True): 241 | init_pcl = trajectory[:, 0] 242 | radius = torch.linalg.norm(init_pcl.max(dim=0).values - init_pcl.min(dim=0).values) / 8 243 | fid = torch.randint(1, trajectory.shape[1], []) 244 | tar_pcl = trajectory[:, fid] 245 | 246 | N = init_pcl.shape[0] 247 | with torch.no_grad(): 248 | radius = torch.linalg.norm(init_pcl.max(dim=0).values - init_pcl.min(dim=0).values) / 8 249 | device = init_pcl.device 250 | ii, jj, nn, weight = cal_connectivity_from_points(init_pcl, radius, K, trajectory=trajectory.detach(), node_radius=node_radius, mode='nn') 251 | L_opt = torch.eye(N).cuda() 252 | L_opt[ii, jj] = - weight[ii, nn] 253 | 254 | P = produce_edge_matrix_nfmt(init_pcl, (N, K, 3), ii, jj, nn, device=device) 255 | P_prime = produce_edge_matrix_nfmt(tar_pcl, (N, K, 3), ii, jj, nn, device=device) 256 | 257 | with torch.no_grad(): 258 | D = torch.diag_embed(weight, dim1=1, dim2=2) 259 | S = torch.bmm(P.permute(0, 2, 1), torch.bmm(D, P_prime)) 260 | U, sig, W = torch.svd(S) 261 | R = torch.bmm(W, U.permute(0, 2, 1)) 262 | with torch.no_grad(): 263 | # Need to flip the column of U corresponding to smallest singular value 264 | # for any det(Ri) <= 0 265 | entries_to_flip = torch.nonzero(torch.det(R) <= 0, as_tuple=False).flatten() # idxs where det(R) <= 0 266 | if len(entries_to_flip) > 0: 267 | Umod = U.clone() 268 | cols_to_flip = torch.argmin(sig[entries_to_flip], dim=1) # Get minimum singular value for each 
entry 269 | Umod[entries_to_flip, :, cols_to_flip] *= -1 # flip cols 270 | R[entries_to_flip] = torch.bmm(W[entries_to_flip], Umod[entries_to_flip].permute(0, 2, 1)) 271 | arap_error = (weight[..., None] * (P_prime - torch.einsum('bxy,bky->bkx', R, P))).square().mean(dim=0).sum() 272 | 273 | if with_rot: 274 | init_rot = quaternion_to_matrix(trajectory_rot[:, 0]) 275 | tar_rot = quaternion_to_matrix(trajectory_rot[:, fid]) 276 | R_rot = torch.bmm(R, init_rot) 277 | rot_error = (R_rot - tar_rot).square().mean(dim=0).sum() 278 | else: 279 | rot_error = 0. 280 | 281 | return arap_error, rot_error * 1e2 282 | -------------------------------------------------------------------------------- /utils/dual_quaternion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 5 | """ 6 | Returns torch.sqrt(torch.max(0, x)) 7 | but with a zero subgradient where x is 0. 8 | """ 9 | ret = torch.zeros_like(x) 10 | positive_mask = x > 0 11 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 12 | return ret 13 | 14 | 15 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 16 | """ 17 | Convert rotations given as rotation matrices to quaternions. 18 | 19 | Args: 20 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 21 | 22 | Returns: 23 | quaternions with real part first, as tensor of shape (..., 4). 24 | """ 25 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 26 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 27 | 28 | batch_dim = matrix.shape[:-2] 29 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 30 | matrix.reshape(batch_dim + (9,)), dim=-1 31 | ) 32 | 33 | q_abs = _sqrt_positive_part( 34 | torch.stack( 35 | [ 36 | 1.0 + m00 + m11 + m22, 37 | 1.0 + m00 - m11 - m22, 38 | 1.0 - m00 + m11 - m22, 39 | 1.0 - m00 - m11 + m22, 40 | ], 41 | dim=-1, 42 | ) 43 | ) 44 | 45 | # we produce the desired quaternion multiplied by each of r, i, j, k 46 | quat_by_rijk = torch.stack( 47 | [ 48 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 49 | # `int`. 50 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 51 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 52 | # `int`. 53 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 54 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 55 | # `int`. 56 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 57 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 58 | # `int`. 59 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 60 | ], 61 | dim=-2, 62 | ) 63 | 64 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 65 | # the candidate won't be picked. 
66 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 67 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 68 | 69 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 70 | # forall i; we pick the best-conditioned one (with the largest denominator) 71 | 72 | return quat_candidates[ 73 | torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 74 | ].reshape(batch_dim + (4,)) 75 | 76 | 77 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 78 | r, i, j, k = torch.unbind(quaternions, -1) 79 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 80 | o = torch.stack( 81 | ( 82 | 1 - two_s * (j * j + k * k), 83 | two_s * (i * j - k * r), 84 | two_s * (i * k + j * r), 85 | two_s * (i * j + k * r), 86 | 1 - two_s * (i * i + k * k), 87 | two_s * (j * k - i * r), 88 | two_s * (i * k - j * r), 89 | two_s * (j * k + i * r), 90 | 1 - two_s * (i * i + j * j), 91 | ), 92 | -1, 93 | ) 94 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 95 | 96 | 97 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 98 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 99 | 100 | 101 | def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 102 | aw, ax, ay, az = torch.unbind(a, -1) 103 | bw, bx, by, bz = torch.unbind(b, -1) 104 | ow = aw * bw - ax * bx - ay * by - az * bz 105 | ox = aw * bx + ax * bw + ay * bz - az * by 106 | oy = aw * by - ax * bz + ay * bw + az * bx 107 | oz = aw * bz + ax * by - ay * bx + az * bw 108 | return torch.stack((ow, ox, oy, oz), -1) 109 | 110 | 111 | def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 112 | ab = quaternion_raw_multiply(a, b) 113 | return standardize_quaternion(ab) 114 | 115 | 116 | def dualquaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 117 | a_real, b_real = a[..., :4], b[..., :4] 118 | a_imag, b_imag = a[..., 4:], b[..., 4:] 119 | o_real = quaternion_multiply(a_real, b_real) 120 | o_imag = quaternion_multiply(a_imag, b_real) + quaternion_multiply(a_real, b_imag) 121 | o = torch.cat([o_real, o_imag], dim=-1) 122 | return o 123 | 124 | 125 | def conjugation(q): 126 | if q.shape[-1] == 4: 127 | q = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) 128 | elif q.shape[-1] == 8: 129 | q = torch.cat([q[..., :1], -q[..., 1:4], q[..., 4:5], -q[..., 5:]], dim=-1) 130 | else: 131 | raise TypeError(f'q should be of [..., 4] or [..., 8] but got {q.shape}!') 132 | return q 133 | 134 | 135 | def QT2DQ(q, t, rot_as_q=True): 136 | if not rot_as_q: 137 | q = matrix_to_quaternion(q) 138 | q = torch.nn.functional.normalize(q) 139 | real = q 140 | t = torch.cat([torch.zeros_like(t[..., :1]), t], dim=-1) 141 | image = quaternion_multiply(t, q) / 2 142 | dq = torch.cat([real, image], dim=-1) 143 | return dq 144 | 145 | 146 | def DQ2QT(dq, rot_as_q=False): 147 | real = dq[..., :4] 148 | imag = dq[..., 4:] 149 | real_norm = real.norm(dim=-1, keepdim=True).clamp(min=1e-8) 150 | real, imag = real / real_norm, imag / real_norm 151 | 152 | w0, x0, y0, z0 = torch.unbind(real, -1) 153 | w1, x1, y1, z1 = torch.unbind(imag, -1) 154 | 155 | t = 2* torch.stack([- w1*x0 + x1*w0 - y1*z0 + z1*y0, 156 | - w1*y0 + x1*z0 + y1*w0 - z1*x0, 157 | - w1*z0 - x1*y0 + y1*x0 + z1*w0], dim=-1) 158 | R = torch.stack([1-2*y0**2-2*z0**2, 2*x0*y0-2*w0*z0, 2*x0*z0+2*w0*y0, 159 | 2*x0*y0+2*w0*z0, 1-2*x0**2-2*z0**2, 2*y0*z0-2*w0*x0, 160 | 2*x0*z0-2*w0*y0, 2*y0*z0+2*w0*x0, 1-2*x0**2-2*y0**2], 
dim=-1).reshape([*w0.shape, 3, 3]) 161 | if rot_as_q: 162 | q = matrix_to_quaternion(R) 163 | return q, t 164 | else: 165 | return R, t 166 | 167 | 168 | def DQBlending(q, t, weights, rot_as_q=True): 169 | ''' 170 | Input: 171 | q: [..., k, 4]; t: [..., k, 3]; weights: [..., k] 172 | Output: 173 | q_: [..., 4]; t_: [..., 3] 174 | ''' 175 | dq = QT2DQ(q=q, t=t) 176 | dq_avg = (dq * weights[..., None]).sum(dim=-2) 177 | q_, t_ = DQ2QT(dq_avg, rot_as_q=rot_as_q) 178 | return q_, t_ 179 | 180 | 181 | def interpolate(q0, t0, q1, t1, weight, rot_as_q=True): 182 | dq0 = QT2DQ(q=q0, t=t0) 183 | dq1 = QT2DQ(q=q1, t=t1) 184 | dq_avg = dq0 * weight + dq1 * (1 - weight) 185 | q, t = DQ2QT(dq=dq_avg, rot_as_q=rot_as_q) 186 | return q, t 187 | 188 | 189 | def transformation_blending(transformations, weights): 190 | Rs, Ts = transformations[:, :3, :3], transformations[:, :3, 3] 191 | qs = matrix_to_quaternion(Rs) 192 | q, T = DQBlending(qs[None], Ts[None], weights) 193 | R = quaternion_to_matrix(q) 194 | transformation = torch.eye(4).to(transformations.device)[None].expand(weights.shape[0], 4, 4).clone() 195 | transformation[:, :3, :3] = R 196 | transformation[:, :3, 3] = T 197 | return transformation 198 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | from PIL import Image 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | if np.asarray(pil_image).shape[-1] == 4: 25 | # Process rgb and alpha respectively to avoid mask rgb with alpha 26 | rgb = Image.fromarray(np.asarray(pil_image)[..., :3]) 27 | a = Image.fromarray(np.asarray(pil_image)[..., 3]) 28 | rgb, a = np.asarray(rgb.resize(resolution)), np.asarray(a.resize(resolution)) 29 | resized_image = torch.from_numpy(np.concatenate([rgb, a[..., None]], axis=-1)) / 255.0 30 | else: 31 | resized_image_PIL = pil_image.resize(resolution) 32 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 33 | if len(resized_image.shape) == 3: 34 | return resized_image.permute(2, 0, 1) 35 | else: 36 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 37 | 38 | 39 | def ArrayToTorch(array, resolution): 40 | # resized_image = np.resize(array, resolution) 41 | resized_image_torch = torch.from_numpy(array) 42 | 43 | if len(resized_image_torch.shape) == 3: 44 | return resized_image_torch.permute(2, 0, 1) 45 | else: 46 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1) 47 | 48 | 49 | def get_expon_lr_func( 50 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 51 | ): 52 | """ 53 | Copied from Plenoxels 54 | 55 | Continuous learning rate decay function. Adapted from JaxNeRF 56 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 57 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 
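    Concretely, for t = clip(step / max_steps, 0, 1), the base rate is exp((1 - t) * log(lr_init) + t * log(lr_final)).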
58 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 59 | function of lr_delay_mult, such that the initial learning rate is 60 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 61 | to the normal learning rate when steps>lr_delay_steps. 62 | :param conf: config subtree 'lr' or similar 63 | :param max_steps: int, the number of steps during optimization. 64 | :return HoF which takes step as input 65 | """ 66 | 67 | def helper(step): 68 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 69 | # Disable this parameter 70 | return 0.0 71 | if lr_delay_steps > 0: 72 | # A kind of reverse cosine decay. 73 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 74 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 75 | ) 76 | else: 77 | delay_rate = 1.0 78 | t = np.clip(step / max_steps, 0, 1) 79 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 80 | return delay_rate * log_lerp 81 | 82 | return helper 83 | 84 | 85 | def get_linear_noise_func( 86 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 87 | ): 88 | """ 89 | Copied from Plenoxels 90 | 91 | Continuous learning rate decay function. Adapted from JaxNeRF 92 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 93 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 94 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 95 | function of lr_delay_mult, such that the initial learning rate is 96 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 97 | to the normal learning rate when steps>lr_delay_steps. 98 | :param conf: config subtree 'lr' or similar 99 | :param max_steps: int, the number of steps during optimization. 100 | :return HoF which takes step as input 101 | """ 102 | 103 | def helper(step): 104 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 105 | # Disable this parameter 106 | return 0.0 107 | if lr_delay_steps > 0: 108 | # A kind of reverse cosine decay. 
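# The ramp below eases the rate in from lr_init * lr_delay_mult at step 0 up to the nominal
# schedule once step >= lr_delay_steps; note that, unlike get_expon_lr_func above, the
# "log_lerp" computed afterwards is a plain linear interpolation between lr_init and lr_final.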
109 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 110 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 111 | ) 112 | else: 113 | delay_rate = 1.0 114 | t = np.clip(step / max_steps, 0, 1) 115 | log_lerp = lr_init * (1 - t) + lr_final * t 116 | return delay_rate * log_lerp 117 | 118 | return helper 119 | 120 | 121 | def strip_lowerdiag(L): 122 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 123 | 124 | uncertainty[:, 0] = L[:, 0, 0] 125 | uncertainty[:, 1] = L[:, 0, 1] 126 | uncertainty[:, 2] = L[:, 0, 2] 127 | uncertainty[:, 3] = L[:, 1, 1] 128 | uncertainty[:, 4] = L[:, 1, 2] 129 | uncertainty[:, 5] = L[:, 2, 2] 130 | return uncertainty 131 | 132 | 133 | def strip_symmetric(sym): 134 | return strip_lowerdiag(sym) 135 | 136 | 137 | def build_rotation(r): 138 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 139 | 140 | q = r / norm[:, None] 141 | 142 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 143 | 144 | r = q[:, 0] 145 | x = q[:, 1] 146 | y = q[:, 2] 147 | z = q[:, 3] 148 | 149 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 150 | R[:, 0, 1] = 2 * (x * y - r * z) 151 | R[:, 0, 2] = 2 * (x * z + r * y) 152 | R[:, 1, 0] = 2 * (x * y + r * z) 153 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 154 | R[:, 1, 2] = 2 * (y * z - r * x) 155 | R[:, 2, 0] = 2 * (x * z - r * y) 156 | R[:, 2, 1] = 2 * (y * z + r * x) 157 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 158 | return R 159 | 160 | 161 | def build_scaling_rotation(s, r): 162 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 163 | R = build_rotation(r) 164 | 165 | L[:, 0, 0] = s[:, 0] 166 | L[:, 1, 1] = s[:, 1] 167 | L[:, 2, 2] = s[:, 2] 168 | 169 | L = R @ L 170 | return L 171 | 172 | 173 | def build_scaling_rotation_inverse(s, r): 174 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 175 | R = build_rotation(r) 176 | 177 | L[:, 0, 0] = 1 / s[:, 0] 178 | L[:, 1, 1] = 1 / s[:, 1] 179 | L[:, 2, 2] = 1 / s[:, 2] 180 | 181 | L = R.permute(0, 2, 1) @ L 182 | return L 183 | 184 | 185 | def safe_state(silent): 186 | old_f = sys.stdout 187 | 188 | class F: 189 | def __init__(self, silent): 190 | self.silent = silent 191 | 192 | def write(self, x): 193 | if not self.silent: 194 | if x.endswith("\n"): 195 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 196 | else: 197 | old_f.write(x) 198 | 199 | def flush(self): 200 | old_f.flush() 201 | 202 | sys.stdout = F(silent) 203 | 204 | random.seed(0) 205 | np.random.seed(0) 206 | torch.manual_seed(0) 207 | torch.cuda.set_device(torch.device("cuda:0")) 208 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | # NeRF-DS Alex LPIPS 15 | import lpips as lpips_lib 16 | loss_fn_alex = lpips_lib.LPIPS(net='alex') 17 | loss_fn_alex.net.cuda() 18 | loss_fn_alex.scaling_layer.cuda() 19 | loss_fn_alex.lins.cuda() 20 | def alex_lpips(image1, image2): 21 | image1 = image1 * 2 - 1 22 | image2 = image2 * 2 - 1 23 | lpips = loss_fn_alex(image1, image2) 24 | return lpips 25 | 26 | 27 | def mse(img1, img2): 28 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 29 | 30 | 31 | def psnr(img1, img2): 32 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 33 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 34 | 35 | 36 | from piq import ssim, LPIPS 37 | lpips = LPIPS() 38 | -------------------------------------------------------------------------------- /utils/interactive_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class DeformKeypoints: 6 | def __init__(self) -> None: 7 | self.keypoints3d_list = [] # list of keypoints group 8 | self.keypoints_idx_list = [] # keypoints index 9 | self.keypoints3d_delta_list = [] 10 | self.selective_keypoints_idx_list = [] # keypoints index 11 | self.idx2group = {} 12 | 13 | self.selective_rotation_keypoints_idx_list = [] 14 | # self.rotation_idx2group = {} 15 | 16 | def get_kpt_idx(self,): 17 | return self.keypoints_idx_list 18 | 19 | def get_kpt(self,): 20 | return self.keypoints3d_list 21 | 22 | def get_kpt_delta(self,): 23 | return self.keypoints3d_delta_list 24 | 25 | def get_deformed_kpt_np(self, rate=1.): 26 | return np.array(self.keypoints3d_list) + np.array(self.keypoints3d_delta_list) * rate 27 | 28 | def add_kpts(self, keypoints_coord, keypoints_idx, expand=False): 29 | # keypoints3d: [N, 3], keypoints_idx: [N,], torch.tensor 30 | # self.selective_keypoints_idx_list.clear() 31 | selective_keypoints_idx_list = [] if not expand else self.selective_keypoints_idx_list 32 | for idx in range(len(keypoints_idx)): 33 | if not self.contain_kpt(keypoints_idx[idx].item()): 34 | selective_keypoints_idx_list.append(len(self.keypoints_idx_list)) 35 | self.keypoints_idx_list.append(keypoints_idx[idx].item()) 36 | self.keypoints3d_list.append(keypoints_coord[idx].cpu().numpy()) 37 | self.keypoints3d_delta_list.append(np.zeros_like(self.keypoints3d_list[-1])) 38 | 39 | for kpt_idx in keypoints_idx: 40 | self.idx2group[kpt_idx.item()] = selective_keypoints_idx_list 41 | 42 | self.selective_keypoints_idx_list = selective_keypoints_idx_list 43 | 44 | def contain_kpt(self, idx): 45 | # idx: int 46 | if idx in self.keypoints_idx_list: 47 | return True 48 | else: 49 | return False 50 | 51 | def select_kpt(self, idx): 52 | # idx: int 53 | # output: idx list of this group 54 | if idx in self.keypoints_idx_list: 55 | self.selective_keypoints_idx_list = self.idx2group[idx] 56 | 57 | def select_rotation_kpt(self, idx): 58 | if idx in self.keypoints_idx_list: 59 | self.selective_rotation_keypoints_idx_list = self.idx2group[idx] 60 | 61 | def get_rotation_center(self,): 62 | selected_rotation_points = self.get_deformed_kpt_np()[self.selective_rotation_keypoints_idx_list] 63 | return selected_rotation_points.mean(axis=0) 64 | 65 | def get_selective_center(self,): 66 | selected_points = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list] 67 | return selected_points.mean(axis=0) 68 | 69 | def delete_kpt(self, idx): 70 | pass 71 | 72 | def 
delete_batch_ktps(self, batch_idx): 73 | pass 74 | 75 | def update_delta(self, delta): 76 | # delta: [3,], np.array 77 | for idx in self.selective_keypoints_idx_list: 78 | self.keypoints3d_delta_list[idx] += delta 79 | 80 | def set_delta(self, delta): 81 | # delta: [N, 3], np.array 82 | for id, idx in enumerate(self.selective_keypoints_idx_list): 83 | self.keypoints3d_delta_list[idx] = delta[id] 84 | 85 | 86 | def set_rotation_delta(self, rot_mat): 87 | kpts3d = self.get_deformed_kpt_np()[self.selective_keypoints_idx_list] 88 | kpts3d_mean = self.get_rotation_center() 89 | kpts3d = (kpts3d - kpts3d_mean) @ rot_mat.T + kpts3d_mean 90 | delta = kpts3d - np.array(self.keypoints3d_list)[self.selective_keypoints_idx_list] 91 | for id, idx in enumerate(self.selective_keypoints_idx_list): 92 | self.keypoints3d_delta_list[idx] = delta[id] 93 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def kl_divergence(rho, rho_hat): 23 | rho_hat = torch.mean(torch.sigmoid(rho_hat), 0) 24 | rho = torch.tensor([rho] * len(rho_hat)).cuda() 25 | return torch.mean( 26 | rho * torch.log(rho / (rho_hat + 1e-5)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-5))) 27 | 28 | 29 | def l2_loss(network_output, gt): 30 | return ((network_output - gt) ** 2).mean() 31 | 32 | 33 | def gaussian(window_size, sigma): 34 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 35 | return gauss / gauss.sum() 36 | 37 | 38 | def create_window(window_size, channel): 39 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 40 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 41 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 42 | return window 43 | 44 | 45 | def ssim(img1, img2, window_size=11, size_average=True): 46 | channel = img1.size(-3) 47 | window = create_window(window_size, channel) 48 | 49 | if img1.is_cuda: 50 | window = window.cuda(img1.get_device()) 51 | window = window.type_as(img1) 52 | 53 | return _ssim(img1, img2, window, window_size, channel, size_average) 54 | 55 | 56 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 57 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 58 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 59 | 60 | mu1_sq = mu1.pow(2) 61 | mu2_sq = mu2.pow(2) 62 | mu1_mu2 = mu1 * mu2 63 | 64 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 65 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 66 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 67 | 68 | C1 = 0.01 ** 2 69 | C2 = 0.03 ** 2 70 | 71 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 
(sigma1_sq + sigma2_sq + C2)) 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | -------------------------------------------------------------------------------- /utils/other_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Returns torch.sqrt(torch.max(0, x)) 8 | but with a zero subgradient where x is 0. 9 | """ 10 | ret = torch.zeros_like(x) 11 | positive_mask = x > 0 12 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 13 | return ret 14 | 15 | 16 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 17 | """ 18 | Convert rotations given as rotation matrices to quaternions. 19 | 20 | Args: 21 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 22 | 23 | Returns: 24 | quaternions with real part first, as tensor of shape (..., 4). 25 | """ 26 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 27 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 28 | 29 | batch_dim = matrix.shape[:-2] 30 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 31 | matrix.reshape(batch_dim + (9,)), dim=-1 32 | ) 33 | 34 | q_abs = _sqrt_positive_part( 35 | torch.stack( 36 | [ 37 | 1.0 + m00 + m11 + m22, 38 | 1.0 + m00 - m11 - m22, 39 | 1.0 - m00 + m11 - m22, 40 | 1.0 - m00 - m11 + m22, 41 | ], 42 | dim=-1, 43 | ) 44 | ) 45 | 46 | # we produce the desired quaternion multiplied by each of r, i, j, k 47 | quat_by_rijk = torch.stack( 48 | [ 49 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 50 | # `int`. 51 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 52 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 53 | # `int`. 54 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 55 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 56 | # `int`. 57 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 58 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 59 | # `int`. 60 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 61 | ], 62 | dim=-2, 63 | ) 64 | 65 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 66 | # the candidate won't be picked. 
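# Each row i of quat_by_rijk is the target quaternion scaled by 2 * q_abs[..., i] (up to sign),
# so dividing by the clamped 2 * q_abs yields four candidate quaternions; the one_hot/argmax
# selection below keeps the candidate with the largest, i.e. best-conditioned, denominator.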
67 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 68 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 69 | 70 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 71 | # forall i; we pick the best-conditioned one (with the largest denominator) 72 | 73 | return quat_candidates[ 74 | torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 75 | ].reshape(batch_dim + (4,)) 76 | 77 | 78 | def depth2normal(depth:torch.Tensor, focal:float=None): 79 | if depth.dim() == 2: 80 | depth = depth[None, None] 81 | elif depth.dim() == 3: 82 | depth = depth.squeeze()[None, None] 83 | if focal is None: 84 | focal = depth.shape[-1] / 2 / np.tan(torch.pi/6) 85 | depth = torch.cat([depth[:, :, :1], depth, depth[:, :, -1:]], dim=2) 86 | depth = torch.cat([depth[..., :1], depth, depth[..., -1:]], dim=3) 87 | kernel = torch.tensor([[[ 0, 0, 0], 88 | [-.5, 0, .5], 89 | [ 0, 0, 0]], 90 | [[ 0, -.5, 0], 91 | [ 0, 0, 0], 92 | [ 0, .5, 0]]], device=depth.device, dtype=depth.dtype)[:, None] 93 | normal = torch.nn.functional.conv2d(depth, kernel, padding='valid')[0].permute(1, 2, 0) 94 | normal = normal / (depth[0, 0, 1:-1, 1:-1, None] + 1e-10) * focal 95 | normal = torch.cat([normal, torch.ones_like(normal[..., :1])], dim=-1) 96 | normal = normal / normal.norm(dim=-1, keepdim=True) 97 | return normal.permute(2, 0, 1) 98 | -------------------------------------------------------------------------------- /utils/pickle_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | 4 | def save_obj(path, obj): 5 | file = open(path, 'wb') 6 | obj_str = pickle.dumps(obj) 7 | file.write(obj_str) 8 | file.close() 9 | 10 | 11 | def load_obj(path): 12 | file = open(path, 'rb') 13 | obj = pickle.loads(file.read()) 14 | file.close() 15 | return obj 16 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.graphics_utils import fov2focal 4 | 5 | trans_t = lambda t: torch.Tensor([ 6 | [1, 0, 0, 0], 7 | [0, 1, 0, 0], 8 | [0, 0, 1, t], 9 | [0, 0, 0, 1]]).float() 10 | 11 | rot_phi = lambda phi: torch.Tensor([ 12 | [1, 0, 0, 0], 13 | [0, np.cos(phi), -np.sin(phi), 0], 14 | [0, np.sin(phi), np.cos(phi), 0], 15 | [0, 0, 0, 1]]).float() 16 | 17 | rot_theta = lambda th: torch.Tensor([ 18 | [np.cos(th), 0, -np.sin(th), 0], 19 | [0, 1, 0, 0], 20 | [np.sin(th), 0, np.cos(th), 0], 21 | [0, 0, 0, 1]]).float() 22 | 23 | 24 | def rodrigues_mat_to_rot(R): 25 | eps = 1e-16 26 | trc = np.trace(R) 27 | trc2 = (trc - 1.) / 2. 
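# trc2 equals cos(theta), since trace(R) = 1 + 2 * cos(theta) for a rotation by angle theta;
# the else-branch below is a small-angle fallback used when sin(theta) is numerically close to zero.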
28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2) 29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]) 30 | if (1 - trc2 * trc2) >= eps: 31 | tHeta = np.arccos(trc2) 32 | tHetaf = tHeta / (2 * (np.sin(tHeta))) 33 | else: 34 | tHeta = np.real(np.arccos(trc2)) 35 | tHetaf = 0.5 / (1 - tHeta / 6) 36 | omega = tHetaf * s 37 | return omega 38 | 39 | 40 | def rodrigues_rot_to_mat(r): 41 | wx, wy, wz = r 42 | theta = np.sqrt(wx * wx + wy * wy + wz * wz) 43 | a = np.cos(theta) 44 | b = (1 - np.cos(theta)) / (theta * theta) 45 | c = np.sin(theta) / theta 46 | R = np.zeros([3, 3]) 47 | R[0, 0] = a + b * (wx * wx) 48 | R[0, 1] = b * wx * wy - c * wz 49 | R[0, 2] = b * wx * wz + c * wy 50 | R[1, 0] = b * wx * wy + c * wz 51 | R[1, 1] = a + b * (wy * wy) 52 | R[1, 2] = b * wy * wz - c * wx 53 | R[2, 0] = b * wx * wz - c * wy 54 | R[2, 1] = b * wz * wy + c * wx 55 | R[2, 2] = a + b * (wz * wz) 56 | return R 57 | 58 | 59 | def normalize(x): 60 | return x / np.linalg.norm(x) 61 | 62 | 63 | def pose_spherical(theta, phi, radius): 64 | c2w = trans_t(radius) 65 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 66 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 67 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 68 | return c2w 69 | 70 | def viewmatrix(z, up, pos): 71 | vec2 = normalize(z) 72 | vec1_avg = up 73 | vec0 = normalize(np.cross(vec1_avg, vec2)) 74 | vec1 = normalize(np.cross(vec2, vec0)) 75 | m = np.stack([vec0, vec1, vec2, pos], 1) 76 | return m 77 | 78 | def poses_avg(poses): 79 | center = poses[:, :3, 3].mean(0) 80 | vec2 = normalize(poses[:, :3, 2].sum(0)) 81 | up = poses[:, :3, 1].sum(0) 82 | c2w = viewmatrix(vec2, up, center) 83 | return c2w 84 | 85 | def render_path_spiral(c2ws, focal, zrate=.1, rots=3, N=300): 86 | c2w = poses_avg(c2ws) 87 | up = normalize(c2ws[:, :3, 1].sum(0)) 88 | tt = c2ws[:,:3,3] 89 | rads = np.percentile(np.abs(tt), 90, 0) 90 | rads[:] = rads.max() * .05 91 | 92 | render_poses = [] 93 | rads = np.array(list(rads) + [1.]) 94 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 95 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 96 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 97 | # c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 98 | # z = normalize(c2w[:3, 2]) 99 | render_poses.append(viewmatrix(z, up, c)) 100 | render_poses = np.stack(render_poses, axis=0) 101 | render_poses = np.concatenate([render_poses, np.zeros_like(render_poses[..., :1, :])], axis=1) 102 | render_poses[..., 3, 3] = 1 103 | render_poses = np.array(render_poses, dtype=np.float32) 104 | return render_poses 105 | 106 | def render_wander_path(view): 107 | focal_length = fov2focal(view.FoVy, view.image_height) 108 | R = view.R 109 | R[:, 1] = -R[:, 1] 110 | R[:, 2] = -R[:, 2] 111 | T = -view.T.reshape(-1, 1) 112 | pose = np.concatenate([R, T], -1) 113 | 114 | num_frames = 60 115 | max_disp = 5000.0 # 64 , 48 116 | 117 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter 118 | output_poses = [] 119 | 120 | for i in range(num_frames): 121 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 122 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0 123 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 124 | 125 | i_pose = np.concatenate([ 126 | np.concatenate( 127 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1), 128 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :] 129 | ], axis=0) # [np.newaxis, :, :] 130 | 131 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float() 132 | 133 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0) 134 | 135 | render_pose = np.dot(ref_pose, i_pose) 136 | output_poses.append(torch.Tensor(render_pose)) 137 | 138 | return output_poses 139 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | # @title Configure dataset directories 2 | import os 3 | from pathlib import Path 4 | 5 | # @markdown The base directory for all captures. This can be anything if you're running this notebook on your own Jupyter runtime. 6 | save_dir = '/data00/yzy/Git_Project/data/dynamic/mine/' # @param {type: 'string'} 7 | capture_name = 'lemon' # @param {type: 'string'} 8 | # The root directory for this capture. 9 | root_dir = Path(save_dir, capture_name) 10 | # Where to save RGB images. 11 | rgb_dir = root_dir / 'rgb' 12 | rgb_raw_dir = root_dir / 'rgb-raw' 13 | # Where to save the COLMAP outputs. 14 | colmap_dir = root_dir / 'colmap' 15 | colmap_db_path = colmap_dir / 'database.db' 16 | colmap_out_path = colmap_dir / 'sparse' 17 | 18 | colmap_out_path.mkdir(exist_ok=True, parents=True) 19 | rgb_raw_dir.mkdir(exist_ok=True, parents=True) 20 | 21 | print(f"""Directories configured: 22 | root_dir = {root_dir} 23 | rgb_raw_dir = {rgb_raw_dir} 24 | rgb_dir = {rgb_dir} 25 | colmap_dir = {colmap_dir} 26 | """) 27 | 28 | # ==================== colmap ========================= 29 | # @title Extract features. 30 | # @markdown Computes SIFT features and saves them to the COLMAP DB. 31 | share_intrinsics = True # @param {type: 'boolean'} 32 | assume_upright_cameras = True # @param {type: 'boolean'} 33 | 34 | # @markdown This sets the scale at which we will run COLMAP. 
A scale of 1 will be more accurate but will be slow. 35 | colmap_image_scale = 4 # @param {type: 'number'} 36 | colmap_rgb_dir = rgb_dir / f'{colmap_image_scale}x' 37 | 38 | # @markdown Check this if you want to re-process SfM. 39 | overwrite = False # @param {type: 'boolean'} 40 | 41 | if overwrite and colmap_db_path.exists(): 42 | colmap_db_path.unlink() 43 | 44 | os.system(f'colmap feature_extractor \ 45 | --SiftExtraction.use_gpu 0 \ 46 | --SiftExtraction.upright {int(assume_upright_cameras)} \ 47 | --ImageReader.camera_model OPENCV \ 48 | --ImageReader.single_camera {int(share_intrinsics)} \ 49 | --database_path "{str(colmap_db_path)}" \ 50 | --image_path "{str(colmap_rgb_dir)}"') 51 | 52 | # @title Match features. 53 | # @markdown Match the SIFT features between images. Use `exhaustive` if you only have a few images and use `vocab_tree` if you have a lot of images. 54 | 55 | match_method = 'exhaustive' # @param ["exhaustive", "vocab_tree"] 56 | 57 | if match_method == 'exhaustive': 58 | os.system(f'colmap exhaustive_matcher \ 59 | --SiftMatching.use_gpu 0 \ 60 | --database_path "{str(colmap_db_path)}"') 61 | 62 | # @title Reconstruction. 63 | # @markdown Run structure-from-motion to compute camera parameters. 64 | 65 | refine_principal_point = True # @param {type:"boolean"} 66 | min_num_matches = 32 # @param {type: 'number'} 67 | filter_max_reproj_error = 2 # @param {type: 'number'} 68 | tri_complete_max_reproj_error = 2 # @param {type: 'number'} 69 | 70 | os.system(f'colmap mapper \ 71 | --Mapper.ba_refine_principal_point {int(refine_principal_point)} \ 72 | --Mapper.filter_max_reproj_error {filter_max_reproj_error} \ 73 | --Mapper.tri_complete_max_reproj_error {tri_complete_max_reproj_error} \ 74 | --Mapper.min_num_matches {min_num_matches} \ 75 | --database_path "{str(colmap_db_path)}" \ 76 | --image_path "{str(colmap_rgb_dir)}" \ 77 | --export_path "{str(colmap_out_path)}"') 78 | 79 | print("debug") 80 | -------------------------------------------------------------------------------- /utils/rigid_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def skew(w: torch.Tensor) -> torch.Tensor: 5 | """Build a skew matrix ("cross product matrix") for vector w. 6 | 7 | Modern Robotics Eqn 3.30. 8 | 9 | Args: 10 | w: (N, 3) A 3-vector 11 | 12 | Returns: 13 | W: (N, 3, 3) A skew matrix such that W @ v == w x v 14 | """ 15 | zeros = torch.zeros(w.shape[0], device=w.device) 16 | w_skew_list = [zeros, -w[:, 2], w[:, 1], 17 | w[:, 2], zeros, -w[:, 0], 18 | -w[:, 1], w[:, 0], zeros] 19 | w_skew = torch.stack(w_skew_list, dim=-1).reshape(-1, 3, 3) 20 | return w_skew 21 | 22 | 23 | def rp_to_se3(R: torch.Tensor, p: torch.Tensor) -> torch.Tensor: 24 | """Rotation and translation to homogeneous transform. 25 | 26 | Args: 27 | R: (3, 3) An orthonormal rotation matrix. 28 | p: (3,) A 3-vector representing an offset. 29 | 30 | Returns: 31 | X: (4, 4) The homogeneous transformation matrix described by rotating by R 32 | and translating by p. 33 | """ 34 | bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=R.device).repeat(R.shape[0], 1, 1) 35 | transform = torch.cat([torch.cat([R, p], dim=-1), bottom_row], dim=1) 36 | 37 | return transform 38 | 39 | 40 | def exp_so3(w: torch.Tensor, theta: float) -> torch.Tensor: 41 | """Exponential map from Lie algebra so3 to Lie group SO3. 42 | 43 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 44 | 45 | Args: 46 | w: (3,) An axis of rotation. 47 | theta: An angle of rotation.
48 | 49 | Returns: 50 | R: (3, 3) An orthonormal rotation matrix representing a rotation of 51 | magnitude theta about axis w. 52 | """ 53 | W = skew(w) 54 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 55 | W_sqr = torch.bmm(W, W) # batch matrix multiplication 56 | R = identity + torch.sin(theta.unsqueeze(-1)) * W + (1.0 - torch.cos(theta.unsqueeze(-1))) * W_sqr 57 | return R 58 | 59 | 60 | def exp_se3(S: torch.Tensor, theta: float) -> torch.Tensor: 61 | """Exponential map from Lie algebra se3 to Lie group SE3. 62 | 63 | Modern Robotics Eqn 3.88. 64 | 65 | Args: 66 | S: (6,) A screw axis of motion. 67 | theta: Magnitude of motion. 68 | 69 | Returns: 70 | a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating 71 | motion of magnitude theta about S for one second. 72 | """ 73 | w, v = torch.split(S, 3, dim=-1) 74 | W = skew(w) 75 | R = exp_so3(w, theta) 76 | 77 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 78 | W_sqr = torch.bmm(W, W) 79 | theta = theta.view(-1, 1, 1) 80 | 81 | p = torch.bmm((theta * identity + (1.0 - torch.cos(theta)) * W + (theta - torch.sin(theta)) * W_sqr), 82 | v.unsqueeze(-1)) 83 | return rp_to_se3(R, p) 84 | 85 | 86 | def to_homogenous(v: torch.Tensor) -> torch.Tensor: 87 | """Converts a vector to a homogeneous coordinate vector by appending a 1. 88 | 89 | Args: 90 | v: A tensor representing a vector or batch of vectors. 91 | 92 | Returns: 93 | A tensor with an additional dimension set to 1. 94 | """ 95 | return torch.cat([v, torch.ones_like(v[..., :1])], dim=-1) 96 | 97 | 98 | def from_homogenous(v: torch.Tensor) -> torch.Tensor: 99 | """Converts a homogeneous coordinate vector to a standard vector by dividing by the last element. 100 | 101 | Args: 102 | v: A tensor representing a homogeneous coordinate vector or batch of homogeneous coordinate vectors. 103 | 104 | Returns: 105 | A tensor with the last dimension removed. 106 | """ 107 | return v[..., :3] / v[..., -1:] 108 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | 115 | def RGB2SH(rgb): 116 | return (rgb - 0.5) / C0 117 | 118 | 119 | def SH2RGB(sh): 120 | return sh * C0 + 0.5 121 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # 
GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | if not os.path.exists(folder): 30 | return None 31 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder) if "_" in fname] 32 | return max(saved_iters) if saved_iters != [] else None 33 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | from gaussian_renderer import render 2 | 3 | 4 | def render_cur_cam(self, cur_cam): 5 | fid = cur_cam.fid 6 | if self.deform.name == 'node': 7 | if 'Node' in self.visualization_mode: 8 | gaussians = self.deform.deform.as_gaussians # if self.iteration_node_rendering < self.opt.iterations_node_rendering else self.deform.deform.as_gaussians_visualization 9 | time_input = fid.unsqueeze(0).expand(gaussians.get_xyz.shape[0], -1) 10 | d_values = self.deform.deform.query_network(x=gaussians.get_xyz.detach(), t=time_input) 11 | if self.motion_animation_d_values is not None: 12 | for key in self.motion_animation_d_values: 13 | d_values[key] = self.motion_animation_d_values[key] 14 | d_xyz, d_opacity, d_color = d_values['d_xyz'] * gaussians.motion_mask, d_values['d_opacity'] * gaussians.motion_mask if d_values['d_opacity'] is not None else None, d_values['d_color'] * gaussians.motion_mask if d_values['d_color'] is not None else None 15 | d_rotation, d_scaling = 0., 0. 
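# In 'Node' visualization mode the control nodes themselves are rendered as Gaussians: only the
# translation (and optional opacity/color) offsets are queried from the deformation network, so the
# rotation and scaling deltas are kept at zero.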
16 | if self.animation_trans_bias is not None: 17 | d_xyz = d_xyz + self.animation_trans_bias 18 | gs_rot_bias = None 19 | vis_scale_const = self.vis_scale_const 20 | else: 21 | time_input = self.deform.deform.expand_time(fid) 22 | d_values = self.deform.step(self.gaussians.get_xyz.detach(), time_input, feature=self.gaussians.feature, is_training=False, node_trans_bias=self.animation_trans_bias, node_rot_bias=self.animation_rot_bias, motion_mask=self.gaussians.motion_mask, camera_center=cur_cam.camera_center, animation_d_values=self.motion_animation_d_values) 23 | gaussians = self.gaussians 24 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 25 | gs_rot_bias = d_values['gs_rot_bias'] # GS rotation bias 26 | vis_scale_const = None 27 | else: 28 | vis_scale_const = None 29 | if self.iteration < self.opt.warm_up: 30 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = 0.0, 0.0, 0.0, 0.0, 0.0 31 | gaussians = self.gaussians 32 | else: 33 | N = self.gaussians.get_xyz.shape[0] 34 | time_input = fid.unsqueeze(0).expand(N, -1) 35 | gaussians = self.gaussians 36 | d_values = self.deform.step(self.gaussians.get_xyz.detach(), time_input, feature=self.gaussians.feature, camera_center=cur_cam.camera_center) 37 | d_xyz, d_rotation, d_scaling, d_opacity, d_color = d_values['d_xyz'], d_values['d_rotation'], d_values['d_scaling'], d_values['d_opacity'], d_values['d_color'] 38 | gs_rot_bias = None 39 | 40 | render_motion = "Motion" in self.visualization_mode 41 | if render_motion: 42 | vis_scale_const = self.vis_scale_const 43 | if type(d_rotation) is not float and gaussians._rotation.shape[0] != d_rotation.shape[0]: 44 | d_xyz, d_rotation, d_scaling = 0, 0, 0 45 | print('Async in Gaussian Switching') 46 | out = render(viewpoint_camera=cur_cam, pc=gaussians, pipe=self.pipe, bg_color=self.background, d_xyz=d_xyz, d_rotation=d_rotation, d_scaling=d_scaling, render_motion=render_motion, d_opacity=d_opacity, d_color=d_color, d_rot_as_res=self.deform.d_rot_as_res, gs_rot_bias=gs_rot_bias, scale_const=vis_scale_const) 47 | 48 | buffer_image = out[self.mode] # [3, H, W] 49 | return buffer_image 50 | --------------------------------------------------------------------------------