├── .gitignore ├── .gitmodules ├── LICENSE.md ├── README.md ├── arguments └── __init__.py ├── assets ├── pipeline.png └── teaser.png ├── environment.yml ├── gaussian_renderer └── __init__.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── render.py ├── requirements.txt ├── scene ├── __init__.py ├── cameras.py ├── dataset_readers.py └── gaussian_model.py ├── submodules └── simple-knn │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn.cu │ ├── simple_knn.h │ ├── simple_knn │ └── .gitkeep │ ├── spatial.cu │ └── spatial.h └── utils ├── camera_utils.py ├── cmap.py ├── dynamic_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── iou_utils.py ├── loss_utils.py ├── nvseg_utils.py ├── semantic_utils.py ├── sh_utils.py ├── system_utils.py └── vehicle_template ├── benz_kitti.ply ├── benz_kitti360.ply ├── benz_pandaset.ply └── benz_waymo.ply /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | *.egg-info 10 | external 11 | shell 12 | .DS_Store -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/hugs-rasterization"] 5 | path = submodules/hugs-rasterization 6 | url = https://github.com/hyzhou404/hugs-rasterization 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | HUGS License 2 | =========================== 3 | 4 | **Zhejiang University** hold all the ownership rights on the *Software* named **HUGS**. 5 | 6 | The *Software* is still being developed by the *Licensor*. 7 | 8 | *Licensor*'s goal is to allow the research community to use, test and evaluate 9 | the *Software*. 10 | 11 | ## 1. Definitions 12 | 13 | *Licensee* means any person or entity that uses the *Software* and distributes 14 | its *Work*. 15 | 16 | *Licensor* means the owners of the *Software*, i.e Zhejiang University 17 | 18 | *Software* means the original work of authorship made available under this 19 | License ie HUGS. 20 | 21 | *Work* means the *Software* and any additions to or derivative works of the 22 | *Software* that are made available under this License. 23 | 24 | 25 | ## 2. Purpose 26 | This license is intended to define the rights granted to the *Licensee* by 27 | Licensors under the *Software*. 28 | 29 | ## 3. Rights granted 30 | 31 | For the above reasons Licensors have decided to distribute the *Software*. 32 | Licensors grant non-exclusive rights to use the *Software* for research purposes 33 | to research users (both academic and industrial), free of charge, without right 34 | to sublicense.. The *Software* may be used "non-commercially", i.e., for research 35 | and/or evaluation purposes only. 36 | 37 | Subject to the terms and conditions of this License, you are granted a 38 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 39 | publicly display, publicly perform and distribute its *Work* and any resulting 40 | derivative works in any form. 41 | 42 | ## 4. 
Limitations 43 | 44 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 45 | so under this License, (b) you include a complete copy of this License with 46 | your distribution, and (c) you retain without modification any copyright, 47 | patent, trademark, or attribution notices that are present in the *Work*. 48 | 49 | **4.2 Derivative Works.** You may specify that additional or different terms apply 50 | to the use, reproduction, and distribution of your derivative works of the *Work* 51 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 52 | Section 2 applies to your derivative works, and (b) you identify the specific 53 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 54 | this License (including the redistribution requirements in Section 3.1) will 55 | continue to apply to the *Work* itself. 56 | 57 | **4.3** Any other use without of prior consent of Licensors is prohibited. Research 58 | users explicitly acknowledge having received from Licensors all information 59 | allowing to appreciate the adequacy between of the *Software* and their needs and 60 | to undertake all necessary precautions for its execution and use. 61 | 62 | **4.4** The *Software* is provided both as a compiled library file and as source 63 | code. In case of using the *Software* for a publication or other results obtained 64 | through the use of the *Software*, users are strongly encouraged to cite the 65 | corresponding publications as explained in the documentation of the *Software*. 66 | 67 | ## 5. Disclaimer 68 | 69 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 70 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT Zhejiang University FOR ANY 71 | UNAUTHORIZED USE: yiyi.liao@zju.edu.cn. ANY SUCH ACTION WILL 72 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 73 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 74 | USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 75 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL Zhejiang University OR THE 76 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 77 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 78 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 79 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 80 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 81 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HUGS: Holistic Urban 3D Scene Understanding via Gaussian Splatting 2 | 3 | [Hongyu Zhou](https://github.com/hyzhou404), [Jiahao Shao](https://jhaoshao.github.io/), Lu Xu, Dongfeng Bai, [Weichao Qiu](https://weichaoqiu.com/), Bingbing Liu, [Yue Wang](https://ywang-zju.github.io/), [Andreas Geiger](https://www.cvlibs.net/) , [Yiyi Liao](https://yiyiliao.github.io/)
4 | 5 | | [Webpage](https://xdimlab.github.io/hugs_website/) | [Full Paper](https://openaccess.thecvf.com/content/CVPR2024/html/Zhou_HUGS_Holistic_Urban_3D_Scene_Understanding_via_Gaussian_Splatting_CVPR_2024_paper.html) | [Video](https://www.youtube.com/watch?v=DmPhL-8FeT4) 6 | 7 | This repository contains the official authors' implementation associated with the paper "HUGS: Holistic Urban 3D Scene Understanding via Gaussian Splatting", which can be found [here](https://xdimlab.github.io/hugs_website/). 8 | 9 | ![image teaser](./assets/teaser.png) 10 | 11 | Abstract: *Holistic understanding of urban scenes based on RGB images is a challenging yet important problem. It encompasses understanding both the geometry and appearance to enable novel view synthesis, parsing semantic labels, and tracking moving objects. Despite considerable progress, existing approaches often focus on specific aspects of this task and require additional inputs such as LiDAR scans or manually annotated 3D bounding boxes. In this paper, we introduce a novel pipeline that utilizes 3D Gaussian Splatting for holistic urban scene understanding. Our main idea involves the joint optimization of geometry, appearance, semantics, and motion using a combination of static and dynamic 3D Gaussians, where moving object poses are regularized via physical constraints. Our approach offers the ability to render new viewpoints in real-time, yielding 2D and 3D semantic information with high accuracy, and reconstruct dynamic scenes, even in scenarios where 3D bounding box detection is highly noisy. Experimental results on KITTI, KITTI-360, and Virtual KITTI 2 demonstrate the effectiveness of our approach.* 12 | 13 | 14 | 15 | ## Cloning the Repository 16 | 17 | The repository contains submodules, so please check it out with 18 | ```shell 19 | # SSH 20 | git clone git@github.com:hyzhou404/hugs.git --recursive 21 | ``` 22 | or 23 | ```shell 24 | # HTTPS 25 | git clone https://github.com/hyzhou404/hugs --recursive 26 | ``` 27 | 28 | 29 | 30 | ## Prepare Environments 31 | 32 | Create a conda environment: 33 | 34 | ```shell 35 | conda create -n hugs python=3.10 -y 36 | ``` 37 | 38 | Please install [PyTorch](https://pytorch.org/), [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), [pytorch3d](https://github.com/facebookresearch/pytorch3d/tree/main) and [flow-vis-torch](https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch) by following their official instructions. 39 | 40 | Install the submodules by running: 41 | 42 | ```shell 43 | pip install submodules/simple-knn 44 | pip install submodules/hugs-rasterization 45 | ``` 46 | 47 | Install the remaining packages by running: 48 | ```shell 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | 53 | ## Data & Checkpoints Download 54 | 55 | We have made two sequences from KITTI available, as indicated in our paper. In addition, three sequences from KITTI-360 and one sequence from Waymo are also provided. 56 | 57 | Download checkpoints from [here](https://huggingface.co/datasets/hyzhou404/hugs_release), then unzip each sequence: 58 | 59 | ```shell 60 | unzip ${sequence}.zip 61 | ``` 62 | 63 | 64 | 65 | ## Rendering 66 | 67 | Render test views by running: 68 | 69 | ```shell 70 | python render.py -m ${checkpoint_path} --data_type ${dataset_type} --iteration 30000 --affine 71 | ``` 72 | 73 | The **dataset_type** argument is a string whose value must be one of: **kitti**, **kitti360**, or **waymo**.
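For example, assuming a checkpoint archive has been unzipped to `./kitti_seq1` (a hypothetical path; substitute the folder produced by the unzip step above), a test-split-only rendering run might look like:

```shell
# Hypothetical checkpoint directory; --skip_train renders only the test cameras.
python render.py -m ./kitti_seq1 --data_type kitti --iteration 30000 --affine --skip_train
```

Note that `render.py` resolves the scene data from `${checkpoint_path}/data`, so keep the unzipped folder layout intact.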
74 | 75 | 76 | ## Evaluation 77 | 78 | ```shell 79 | python metrics.py -m ${checkpoint_path} 80 | ``` 81 | The resulting SSIM, PSNR, and LPIPS scores are written to `metric_results.json` and `per_view.json` in the checkpoint directory. 82 | ## Training 83 | This repository only includes the inference code of HUGS. The training code will be released in the future. 84 | 85 | 86 | <details>
87 |
88 | <summary>BibTeX</summary>
89 | <pre><code>@InProceedings{Zhou_2024_CVPR,
90 |     author    = {Zhou, Hongyu and Shao, Jiahao and Xu, Lu and Bai, Dongfeng and Qiu, Weichao and Liu, Bingbing and Wang, Yue and Geiger, Andreas and Liao, Yiyi},
91 |     title     = {HUGS: Holistic Urban 3D Scene Understanding via Gaussian Splatting},
92 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
93 |     month     = {June},
94 |     year      = {2024},
95 |     pages     = {21336-21345}
96 |     }</code></pre>
97 | </details>
98 |
99 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cpu" 56 | self.eval = False 57 | super().__init__(parser, "Loading Parameters", sentinel) 58 | 59 | def extract(self, args): 60 | g = super().extract(args) 61 | g.source_path = os.path.abspath(g.source_path) 62 | return g 63 | 64 | class PipelineParams(ParamGroup): 65 | def __init__(self, parser): 66 | self.convert_SHs_python = False 67 | self.compute_cov3D_python = False 68 | self.debug = False 69 | super().__init__(parser, "Pipeline Parameters") 70 | 71 | class OptimizationParams(ParamGroup): 72 | def __init__(self, parser): 73 | self.iterations = 30_000 74 | self.position_lr_init = 0.00016 75 | self.position_lr_final = 0.0000016 76 | self.position_lr_delay_mult = 0.01 77 | self.position_lr_max_steps = 30_000 78 | self.feature_lr = 0.0025 79 | self.opacity_lr = 0.05 80 | self.scaling_lr = 0.001 81 | self.rotation_lr = 0.001 82 | self.percent_dense = 0.001 83 | self.lambda_dssim = 0.2 84 | self.densification_interval = 100 85 | self.opacity_reset_interval = 3000 86 | self.densify_from_iter = 500 87 | self.densify_until_iter = 15_000 88 | self.densify_grad_threshold = 0.0002 89 | super().__init__(parser, "Optimization Parameters") 90 | 91 | def get_combined_args(parser : ArgumentParser): 92 | cmdlne_string = sys.argv[1:] 93 | cfgfile_string = "Namespace()" 94 | args_cmdline = parser.parse_args(cmdlne_string) 95 | 96 | try: 97 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 98 | print("Looking for config file in", cfgfilepath) 99 | with open(cfgfilepath) as cfg_file: 100 | print("Config file found: {}".format(cfgfilepath)) 101 | cfgfile_string = cfg_file.read() 102 | except TypeError: 103 | print("Config file not found at") 
104 | pass 105 | args_cfgfile = eval(cfgfile_string) 106 | 107 | merged_dict = vars(args_cfgfile).copy() 108 | for k,v in vars(args_cmdline).items(): 109 | if v != None: 110 | merged_dict[k] = v 111 | return Namespace(**merged_dict) 112 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/assets/pipeline.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/assets/teaser.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gaussian_splatting 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - pip: 16 | - submodules/diff-gaussian-rasterization 17 | - submodules/simple-knn -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh, RGB2SH 17 | from pytorch3d.transforms import quaternion_to_matrix, matrix_to_quaternion 18 | 19 | def euler2matrix(yaw): 20 | cos = torch.cos(-yaw) 21 | sin = torch.sin(-yaw) 22 | rot = torch.eye(3).float().cuda() 23 | rot[0,0] = cos 24 | rot[0,2] = sin 25 | rot[2,0] = -sin 26 | rot[2,2] = cos 27 | return rot 28 | 29 | def cat_bgfg(bg, fg, only_dynamic=False, only_xyz=False): 30 | if only_xyz: 31 | bg_feats = [bg.get_xyz] 32 | else: 33 | bg_feats = [bg.get_xyz, bg.get_opacity, bg.get_scaling, bg.get_rotation, bg.get_features, bg.get_3D_features] 34 | 35 | output = [] 36 | for fg_feat, bg_feat in zip(fg, bg_feats): 37 | if fg_feat is None: 38 | output.append(bg_feat) 39 | elif only_dynamic: 40 | output.append(fg_feat) 41 | else: 42 | output.append(torch.cat((bg_feat, fg_feat), dim=0)) 43 | 44 | return output 45 | 46 | 47 | def cat_all_fg(all_fg, next_fg): 48 | output = [] 49 | for feat, next_feat in zip(all_fg, next_fg): 50 | if feat is None: 51 | feat = next_feat 52 | else: 53 | feat = torch.cat((feat, next_feat), dim=0) 54 | output.append(feat) 55 | return output 56 | 57 | 58 | def proj_uv(xyz, cam): 59 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | intr = torch.as_tensor(cam.K[:3, :3]).float().to(device) # (3, 3) 61 | w2c = torch.tensor(cam.w2c).float().to(device)[:3, :] # (3, 4) 62 | 63 | c_xyz = (w2c[:3, :3] @ xyz.T).T + w2c[:3, 3] 64 | i_xyz = (intr @ c_xyz.mT).mT # (N, 3) 65 | uv = i_xyz[:, :2] / i_xyz[:, -1:].clip(1e-3) # (N, 2) 66 | return uv 67 | 68 | 69 | def unicycle_b2w(timestamp, model): 70 | # model = unicycle_models[track_id]['model'] 71 | pred = model(timestamp) 72 | if pred is None: 73 | return None 74 | pred_a, pred_b, pred_v, pred_phi, pred_h = pred 75 | # r = euler_angles_to_matrix(torch.tensor([0, pred_phi-torch.pi, 0]), 'XYZ') 76 | rt = torch.eye(4).float().cuda() 77 | rt[:3,:3] = euler2matrix(pred_phi) 78 | rt[1, 3], rt[0, 3], rt[2, 3] = pred_h, pred_a, pred_b 79 | return rt 80 | 81 | def render(viewpoint_camera, prev_viewpoint_camera, pc : GaussianModel, dynamic_gaussians : dict, 82 | unicycles : dict, pipe, bg_color : torch.Tensor, 83 | render_optical=False, scaling_modifier = 1.0, only_dynamic=False): 84 | """ 85 | Render the scene. 86 | 87 | Background tensor (bg_color) must be on GPU! 
88 | """ 89 | timestamp = viewpoint_camera.timestamp 90 | 91 | all_fg = [None, None, None, None, None, None] 92 | prev_all_fg = [None] 93 | 94 | if len(unicycles) == 0: 95 | track_dict = viewpoint_camera.dynamics 96 | if prev_viewpoint_camera is not None: 97 | prev_track_dict = prev_viewpoint_camera.dynamics 98 | else: 99 | track_dict, prev_track_dict = {}, {} 100 | for track_id, uni_model in unicycles.items(): 101 | B2W = unicycle_b2w(timestamp, uni_model['model']) 102 | track_dict[track_id] = B2W 103 | if prev_viewpoint_camera is not None: 104 | prev_B2W = unicycle_b2w(prev_viewpoint_camera.timestamp, uni_model['model']) 105 | prev_track_dict[track_id] = prev_B2W 106 | 107 | for track_id, B2W in track_dict.items(): 108 | w_dxyz = (B2W[:3, :3] @ dynamic_gaussians[track_id].get_xyz.T).T + B2W[:3, 3] 109 | drot = quaternion_to_matrix(dynamic_gaussians[track_id].get_rotation) 110 | w_drot = matrix_to_quaternion(B2W[:3, :3] @ drot) 111 | next_fg = [w_dxyz, 112 | dynamic_gaussians[track_id].get_opacity, 113 | dynamic_gaussians[track_id].get_scaling, 114 | w_drot, 115 | dynamic_gaussians[track_id].get_features, 116 | dynamic_gaussians[track_id].get_3D_features] 117 | # next_fg = get_next_fg(dynamic_gaussians[track_id], B2W) 118 | # w_dxyz = next_fg[0] 119 | all_fg = cat_all_fg(all_fg, next_fg) 120 | 121 | if render_optical and prev_viewpoint_camera is not None: 122 | if track_id in prev_track_dict: 123 | prev_B2W = prev_track_dict[track_id] 124 | prev_w_dxyz = torch.mm(prev_B2W[:3, :3], dynamic_gaussians[track_id].get_xyz.T).T + prev_B2W[:3, 3] 125 | prev_all_fg = cat_all_fg(prev_all_fg, [prev_w_dxyz]) 126 | else: 127 | prev_all_fg = cat_all_fg(prev_all_fg, [w_dxyz]) 128 | 129 | xyz, opacity, scales, rotations, shs, feats3D = cat_bgfg(pc, all_fg) 130 | if render_optical and prev_viewpoint_camera is not None: 131 | prev_xyz = cat_bgfg(pc, prev_all_fg, only_xyz=True)[0] 132 | uv = proj_uv(xyz, viewpoint_camera) 133 | prev_uv = proj_uv(prev_xyz, prev_viewpoint_camera) 134 | delta_uv = uv - prev_uv 135 | delta_uv = torch.cat([delta_uv, torch.ones_like(delta_uv[:, :1], device=delta_uv.device)], dim=-1) 136 | else: 137 | delta_uv = torch.zeros_like(xyz) 138 | 139 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 140 | screenspace_points = torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device="cuda") + 0 141 | try: 142 | screenspace_points.retain_grad() 143 | except: 144 | pass 145 | 146 | # Set up rasterization configuration 147 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 148 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 149 | 150 | if pc.affine: 151 | cam_xyz, cam_dir = viewpoint_camera.c2w[:3, 3].cuda(), viewpoint_camera.c2w[:3, 2].cuda() 152 | o_enc = pc.pos_enc(cam_xyz[None, :] / 60) 153 | d_enc = pc.dir_enc(cam_dir[None, :]) 154 | appearance = pc.appearance_model(torch.cat([o_enc, d_enc], dim=1)) * 1e-1 155 | affine_weight, affine_bias = appearance[:, :9].view(3, 3), appearance[:, -3:] 156 | affine_weight = affine_weight + torch.eye(3, device=appearance.device) 157 | 158 | # bg_img = pc.sky_model(enc).view(*rays_d.shape).permute(2, 0, 1).float() 159 | 160 | raster_settings = GaussianRasterizationSettings( 161 | image_height=int(viewpoint_camera.image_height), 162 | image_width=int(viewpoint_camera.image_width), 163 | tanfovx=tanfovx, 164 | tanfovy=tanfovy, 165 | bg=bg_color, 166 | scale_modifier=scaling_modifier, 167 | viewmatrix=viewpoint_camera.world_view_transform, 168 | projmatrix=viewpoint_camera.full_proj_transform, 169 | sh_degree=pc.active_sh_degree, 170 | campos=viewpoint_camera.camera_center, 171 | prefiltered=False, 172 | debug=pipe.debug 173 | ) 174 | 175 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 176 | 177 | means3D = xyz 178 | means2D = screenspace_points 179 | 180 | cov3D_precomp = None 181 | colors_precomp = None 182 | 183 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 184 | rendered_image, radii, feats, depth, flow = rasterizer( 185 | means3D = means3D, 186 | means2D = means2D, 187 | shs = shs, 188 | colors_precomp = colors_precomp, 189 | opacities = opacity, 190 | scales = scales, 191 | rotations = rotations, 192 | cov3D_precomp = cov3D_precomp, 193 | feats3D = feats3D, 194 | delta = delta_uv) 195 | 196 | if pc.affine: 197 | colors = rendered_image.view(3, -1).permute(1, 0) # (H*W, 3) 198 | refined_image = (colors @ affine_weight + affine_bias).clip(0, 1).permute(1, 0).view(*rendered_image.shape) 199 | else: 200 | refined_image = rendered_image 201 | 202 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 203 | # They will be excluded from value updates used in the splitting criteria. 204 | return {"render": refined_image, 205 | "feats": feats, 206 | "depth": depth, 207 | "opticalflow": flow, 208 | "viewspace_points": screenspace_points, 209 | "visibility_filter" : radii > 0, 210 | "radii": radii} 211 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | from collections import OrderedDict 24 | 25 | def readImages(renders_dir, gt_dir): 26 | renders = [] 27 | gts = [] 28 | image_names = [] 29 | for fname in os.listdir(renders_dir): 30 | render = Image.open(renders_dir / fname) 31 | gt = Image.open(gt_dir / fname) 32 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 33 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 34 | image_names.append(fname) 35 | return renders, gts, image_names 36 | 37 | def evaluate(model_paths, write): 38 | # import ipdb; ipdb.set_trace() 39 | full_dict = {} 40 | per_view_dict = {} 41 | full_dict_polytopeonly = {} 42 | per_view_dict_polytopeonly = {} 43 | print("") 44 | 45 | scene_dir = model_paths[0] 46 | 47 | print("Scene:", scene_dir) 48 | 49 | for splits in ['test', 'train']: 50 | full_dict[splits] = {} 51 | per_view_dict[splits] = {} 52 | dir_path = Path(scene_dir) / splits 53 | for method in os.listdir(dir_path): 54 | print("Method:", method) 55 | full_dict[splits][method] = {} 56 | per_view_dict[splits][method] = {} 57 | 58 | method_dir = dir_path / method 59 | gt_dir = method_dir/ "gt" 60 | renders_dir = method_dir / "renders" 61 | renders, gts, image_names = readImages(renders_dir, gt_dir) 62 | 63 | ssims = [] 64 | psnrs = [] 65 | lpipss = [] 66 | 67 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 68 | ssims.append(ssim(renders[idx], gts[idx])) 69 | psnrs.append(psnr(renders[idx], gts[idx])) 70 | lpipss.append(lpips(renders[idx], gts[idx], net_type='alex')) 71 | 72 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 73 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 74 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 75 | print("") 76 | 77 | full_dict[splits][method].update({"SSIM": torch.tensor(ssims).mean().item(), 78 | "PSNR": torch.tensor(psnrs).mean().item(), 79 | "LPIPS": torch.tensor(lpipss).mean().item()}) 80 | per_view_dict[splits][method].update({ 81 | "SSIM": OrderedDict(sorted({name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}.items())), 82 | "PSNR": OrderedDict(sorted({name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}.items())), 83 | "LPIPS": OrderedDict(sorted({name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}.items())) 84 | }) 85 | if write: 86 | with open(scene_dir + "/metric_results.json", 'w') as fp: 87 | json.dump(full_dict, fp, indent=True) 88 | with open(scene_dir + "/per_view.json", 'w') as fp: 89 | json.dump(per_view_dict, fp, indent=True) 90 | 91 | if __name__ == "__main__": 92 | device = torch.device("cuda:0") 93 | torch.cuda.set_device(device) 94 | 95 | # Set up command line argument parser 96 | parser = ArgumentParser(description="Training script parameters") 97 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 98 | parser.add_argument('--write', action='store_false', default=True) 99 | args = parser.parse_args() 100 | evaluate(args.model_paths, args.write) 101 | 
-------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | import numpy as np 24 | from copy import deepcopy 25 | from torchmetrics.functional import structural_similarity_index_measure as ssim 26 | import matplotlib.pyplot as plt 27 | from mpl_toolkits.axes_grid1 import make_axes_locatable 28 | from matplotlib import cm 29 | from utils.semantic_utils import colorize 30 | import flow_vis_torch 31 | from utils.cmap import color_depth_map 32 | from imageio.v2 import imwrite 33 | 34 | def to4x4(R, T): 35 | RT = np.eye(4,4) 36 | RT[:3, :3] = R 37 | RT[:3, 3] = T 38 | return RT 39 | 40 | def apply_colormap(image, cmap="viridis"): 41 | colormap = cm.get_cmap(cmap) 42 | colormap = torch.tensor(colormap.colors).to(image.device) # type: ignore 43 | image_long = (image * 255).long() 44 | image_long_min = torch.min(image_long) 45 | image_long_max = torch.max(image_long) 46 | assert image_long_min >= 0, f"the min value is {image_long_min}" 47 | assert image_long_max <= 255, f"the max value is {image_long_max}" 48 | return colormap[image_long[0, ...]].permute(2, 0, 1) 49 | 50 | 51 | def apply_depth_colormap(depth, near_plane=None, far_plane=None, cmap="turbo"): 52 | near_plane = near_plane or float(torch.min(depth)) 53 | far_plane = far_plane or float(torch.max(depth)) 54 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) 55 | depth = torch.clip(depth, 0, 1) 56 | 57 | colored_image = apply_colormap(depth, cmap=cmap) 58 | return colored_image 59 | 60 | 61 | def render_set(model_path, name, iteration, views, scene, pipeline, background, data_type): 62 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 63 | semantic_path = os.path.join(model_path, name, "ours_{}".format(iteration), "semantic") 64 | optical_path = os.path.join(model_path, name, "ours_{}".format(iteration), "optical") 65 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 66 | error_path = os.path.join(model_path, name, "ours_{}".format(iteration), "error_map") 67 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 68 | 69 | makedirs(render_path, exist_ok=True) 70 | makedirs(semantic_path, exist_ok=True) 71 | makedirs(optical_path, exist_ok=True) 72 | makedirs(gts_path, exist_ok=True) 73 | makedirs(error_path, exist_ok=True) 74 | makedirs(depth_path, exist_ok=True) 75 | 76 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 77 | 78 | if data_type == 'kitti': 79 | gap = 2 80 | elif data_type == 'kitti360': 81 | gap = 4 82 | elif data_type == 'waymo': 83 | gap = 1 84 | elif data_type == 'nuscenes' or data_type == 'pandaset': 85 | gap = 6 86 | 87 | if idx - gap < 0: 88 | prev_view = None 
89 | else: 90 | prev_view = views[idx-4] 91 | render_pkg = render( 92 | view, prev_view, scene.gaussians, scene.dynamic_gaussians, scene.unicycles, pipeline, background, True 93 | ) 94 | rendering = render_pkg['render'].detach().cpu() 95 | semantic = render_pkg['feats'].detach().cpu() 96 | semantic = torch.argmax(semantic, dim=0) 97 | semantic_rgb = colorize(semantic.detach().cpu().numpy()) 98 | depth = render_pkg['depth'] 99 | color_depth = color_depth_map(depth[0].detach().cpu().numpy()) 100 | color_depth[semantic == 10] = np.array([255.0, 255.0, 255.0]) 101 | gt = view.original_image[0:3, :, :] 102 | 103 | # _, ssim_map = ssim(rendering[None, ...], gt[None, ...], return_full_image=True) 104 | # ssim_map = torch.mean(ssim_map[0], dim=0).clip(0, 1)[None, ...] 105 | # error_map = 1 - ssim_maps 106 | error_map = torch.mean((rendering - gt) ** 2, dim=0)[None, ...] 107 | 108 | fig = plt.figure(frameon=False) 109 | fig.set_size_inches(1.408, 0.376) 110 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 111 | ax.set_axis_off() 112 | fig.add_axes(ax) 113 | ax.imshow((error_map.detach().cpu().numpy().transpose(1,2,0)), cmap='jet') 114 | plt.savefig(os.path.join(error_path, view.image_name + ".png"), dpi=1000) 115 | plt.close('all') 116 | 117 | torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + ".png")) 118 | torchvision.utils.save_image(gt, os.path.join(gts_path, view.image_name + ".png")) 119 | semantic_rgb.save(os.path.join(semantic_path, view.image_name + ".png")) 120 | imwrite(os.path.join(depth_path, view.image_name + ".png"), color_depth) 121 | 122 | opticalflow = render_pkg["opticalflow"] 123 | opticalflow = opticalflow.permute(1,2,0) 124 | opticalflow = opticalflow[..., :2] 125 | pytorch_optic_rgb = flow_vis_torch.flow_to_color(opticalflow.permute(2, 0, 1)) # (2, h, w) 126 | torchvision.utils.save_image(pytorch_optic_rgb.float(), os.path.join(optical_path, view.image_name + ".png"), normalize=True) 127 | # torchvision.utils.save_image(error_map, os.path.join(error_path, '{0:05d}'.format(idx) + ".png")) 128 | 129 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, 130 | skip_train : bool, skip_test : bool, data_type, affine, ignore_dynamic): 131 | with torch.no_grad(): 132 | gaussians = GaussianModel(dataset.sh_degree, affine=affine) 133 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False, data_type=data_type, ignore_dynamic=ignore_dynamic) 134 | 135 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 136 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 137 | 138 | if not skip_train: 139 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), scene, pipeline, background, data_type) 140 | 141 | if not skip_test: 142 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), scene, pipeline, background, data_type) 143 | 144 | 145 | if __name__ == "__main__": 146 | # Set up command line argument parser 147 | parser = ArgumentParser(description="Testing script parameters") 148 | model = ModelParams(parser, sentinel=True) 149 | pipeline = PipelineParams(parser) 150 | parser.add_argument("--iteration", default=-1, type=int) 151 | parser.add_argument("--data_type", default='kitti360', type=str) 152 | parser.add_argument("--affine", action="store_true") 153 | parser.add_argument("--ignore_dynamic", action="store_true") 154 | parser.add_argument("--skip_train", action="store_true") 155 | parser.add_argument("--skip_test", 
action="store_true") 156 | parser.add_argument("--quiet", action="store_true") 157 | args = get_combined_args(parser) 158 | print("Rendering " + args.model_path) 159 | args.source_path = os.path.join(args.model_path, 'data') 160 | 161 | # Initialize system state (RNG) 162 | # safe_state(args.quiet) 163 | 164 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), 165 | args.skip_train, args.skip_test, args.data_type, args.affine, args.ignore_dynamic) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | config==0.5.1 2 | datasets==2.19.2 3 | # flow_vis_torch==0.1 4 | imageio==2.34.1 5 | matplotlib==3.9.0 6 | network==0.1 7 | numpy==1.26.4 8 | open3d==0.18.0 9 | opencv_python==4.10.0.82 10 | Pillow==10.3.0 11 | plyfile==1.0.3 12 | # pytorch3d==0.7.4 13 | runx==0.0.11 14 | scipy==1.13.1 15 | setuptools==69.5.1 16 | # torch==2.3.1+cu118 17 | torchmetrics==1.4.0.post0 18 | # torchvision==0.18.1+cu118 19 | tqdm==4.66.4 -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | import torch 21 | import open3d as o3d 22 | import numpy as np 23 | from utils.dynamic_utils import create_unicycle_model 24 | import shutil 25 | 26 | class Scene: 27 | 28 | gaussians : GaussianModel 29 | 30 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, 31 | unicycle=False, uc_fit_iter=0, resolution_scales=[1.0], data_type='kitti360', ignore_dynamic=False): 32 | """b 33 | :param path: Path to colmap scene main folder. 
34 | """ 35 | self.model_path = args.model_path 36 | self.loaded_iter = None 37 | self.gaussians = gaussians 38 | 39 | if load_iteration: 40 | if load_iteration == -1: 41 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "ckpts")) 42 | else: 43 | self.loaded_iter = load_iteration 44 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 45 | 46 | self.train_cameras = {} 47 | self.test_cameras = {} 48 | if os.path.exists(os.path.join(args.source_path, "sparse")): 49 | # scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 50 | raise NotImplementedError 51 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 52 | print("Found transforms_train.json file, assuming Blender data set!") 53 | # scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 54 | raise NotImplementedError 55 | elif os.path.exists(os.path.join(args.source_path, "meta_data.json")): 56 | print("Found meta_data.json file, assuming Studio data set!") 57 | scene_info = sceneLoadTypeCallbacks['Studio'](args.source_path, args.white_background, args.eval, data_type, ignore_dynamic) 58 | else: 59 | assert False, "Could not recognize scene type!" 60 | 61 | self.dynamic_verts = scene_info.verts 62 | self.dynamic_gaussians = {} 63 | for track_id in scene_info.verts: 64 | self.dynamic_gaussians[track_id] = GaussianModel(args.sh_degree, feat_mutable=False) 65 | 66 | if unicycle: 67 | self.unicycles = create_unicycle_model(scene_info.train_cameras, self.model_path, uc_fit_iter, data_type) 68 | else: 69 | self.unicycles = {} 70 | 71 | if not self.loaded_iter: 72 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 73 | dest_file.write(src_file.read()) 74 | json_cams = [] 75 | camlist = [] 76 | if scene_info.test_cameras: 77 | camlist.extend(scene_info.test_cameras) 78 | if scene_info.train_cameras: 79 | camlist.extend(scene_info.train_cameras) 80 | for id, cam in enumerate(camlist): 81 | json_cams.append(camera_to_JSON(id, cam)) 82 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 83 | json.dump(json_cams, file) 84 | shutil.copyfile(os.path.join(args.source_path, 'meta_data.json'), os.path.join(self.model_path, 'meta_data.json')) 85 | 86 | if shuffle: 87 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 88 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 89 | 90 | self.cameras_extent = scene_info.nerf_normalization["radius"] 91 | 92 | for resolution_scale in resolution_scales: 93 | print("Loading Training Cameras") 94 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 95 | print("Loading Test Cameras") 96 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 97 | 98 | if self.loaded_iter: 99 | (model_params, first_iter) = torch.load(os.path.join(self.model_path, "ckpts", f"chkpnt{self.loaded_iter}.pth")) 100 | gaussians.restore(model_params, None) 101 | for iid, dynamic_gaussian in self.dynamic_gaussians.items(): 102 | (model_params, first_iter) = torch.load(os.path.join(self.model_path, "ckpts", f"dynamic_{iid}_chkpnt{self.loaded_iter}.pth")) 103 | dynamic_gaussian.restore(model_params, None) 104 | for iid, unicycle_pkg in self.unicycles.items(): 105 | model_params = torch.load(os.path.join(self.model_path, "ckpts", 
f"unicycle_{iid}_chkpnt{self.loaded_iter}.pth")) 106 | unicycle_pkg['model'].restore(model_params) 107 | else: 108 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 109 | for track_id in self.dynamic_gaussians.keys(): 110 | vertices = scene_info.verts[track_id] 111 | 112 | # init from template 113 | l, h, w = vertices[:, 0].max() - vertices[:, 0].min(), vertices[:, 1].max() - vertices[:, 1].min(), vertices[:, 2].max() - vertices[:, 2].min() 114 | pcd = o3d.io.read_point_cloud(f"utils/vehicle_template/benz_{data_type}.ply") 115 | points = np.array(pcd.points) * np.array([l, h, w]) 116 | pcd.points = o3d.utility.Vector3dVector(points) 117 | pcd.colors = o3d.utility.Vector3dVector(np.ones_like(points) * 0.5) 118 | 119 | self.dynamic_gaussians[track_id].create_from_pcd(pcd, self.cameras_extent) 120 | 121 | def save(self, iteration): 122 | # self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 123 | point_cloud_vis_path = os.path.join(self.model_path, "point_cloud_vis/iteration_{}".format(iteration)) 124 | self.gaussians.save_vis_ply(os.path.join(point_cloud_vis_path, "point.ply")) 125 | for iid, dynamic_gaussian in self.dynamic_gaussians.items(): 126 | dynamic_gaussian.save_vis_ply(os.path.join(point_cloud_vis_path, f"dynamic_{iid}.ply")) 127 | 128 | def getTrainCameras(self, scale=1.0): 129 | return self.train_cameras[scale] 130 | 131 | def getTestCameras(self, scale=1.0): 132 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, fov2focal 16 | from utils.general_utils import decode_op 17 | 18 | class Camera(nn.Module): 19 | def __init__(self, colmap_id, R, T, K, FoVx, FoVy, image, 20 | image_name, uid, 21 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", 22 | cx_ratio=None, cy_ratio=None, semantic2d=None, mask=None, timestamp=-1, optical_image=None, dynamics={} 23 | ): 24 | super(Camera, self).__init__() 25 | self.uid = uid 26 | self.colmap_id = colmap_id 27 | self.R = R 28 | self.T = T 29 | self.K = K 30 | self.FoVx = FoVx 31 | self.FoVy = FoVy 32 | self.image_name = image_name 33 | self.cx_ratio = cx_ratio 34 | self.cy_ratio = cy_ratio 35 | self.timestamp = timestamp 36 | _, self.H, self.W = image.shape 37 | self.w2c = np.eye(4) 38 | self.w2c[:3, :3] = self.R.T 39 | self.w2c[:3, 3] = self.T 40 | self.c2w = torch.from_numpy(np.linalg.inv(self.w2c)).cuda() 41 | self.fx = fov2focal(self.FoVx, self.W) 42 | self.fy = fov2focal(self.FoVy, self.H) 43 | self.dynamics = dynamics 44 | 45 | try: 46 | self.data_device = torch.device(data_device) 47 | except Exception as e: 48 | print(e) 49 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 50 | self.data_device = torch.device("cuda") 51 | 52 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 53 | if semantic2d is not None: 54 | self.semantic2d = semantic2d.to(self.data_device) 55 | else: 56 | self.semantic2d = None 57 | if mask is not None: 58 | self.mask = torch.from_numpy(mask).bool().to(self.data_device) 59 | else: 60 | self.mask = None 61 | self.image_width = self.original_image.shape[2] 62 | self.image_height = self.original_image.shape[1] 63 | if optical_image is not None: 64 | self.optical_gt = torch.from_numpy(optical_image).to(self.data_device) 65 | else: 66 | self.optical_gt = None 67 | 68 | self.zfar = 100.0 69 | self.znear = 0.01 70 | 71 | self.trans = trans 72 | self.scale = scale 73 | 74 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 75 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, 76 | fovX=self.FoVx, fovY=self.FoVy, cx_ratio=cx_ratio, cy_ratio=cy_ratio).transpose(0,1).cuda() 77 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 78 | self.camera_center = self.world_view_transform.inverse()[3, :3] 79 | 80 | def get_rays(self): 81 | i, j = torch.meshgrid(torch.linspace(0, self.W-1, self.W), 82 | torch.linspace(0, self.H-1, self.H)) # pytorch's meshgrid has indexing='ij' 83 | i = i.t() 84 | j = j.t() 85 | dirs = torch.stack([(i-self.cx_ratio)/self.fx, -(j-self.cy_ratio)/self.fy, -torch.ones_like(i)], -1) 86 | rays_d = torch.sum(dirs[..., np.newaxis, :] * self.c2w[:3,:3], -1).to(self.data_device) 87 | rays_o = self.c2w[:3,-1].expand(rays_d.shape).to(self.data_device) 88 | rays_d = torch.nn.functional.normalize(rays_d, dim=-1) 89 | return rays_o.permute(2,0,1), rays_d.permute(2,0,1) 90 | 91 | class MiniCam: 92 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 93 | self.image_width = width 94 | self.image_height = height 95 | self.FoVy = fovy 96 | self.FoVx = fovx 97 | self.znear = znear 98 | self.zfar = zfar 99 | self.world_view_transform = world_view_transform 100 | 
self.full_proj_transform = full_proj_transform 101 | view_inv = torch.inverse(self.world_view_transform) 102 | self.camera_center = view_inv[3][:3] 103 | 104 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 17 | import numpy as np 18 | import json 19 | from pathlib import Path 20 | from plyfile import PlyData, PlyElement 21 | from utils.sh_utils import SH2RGB 22 | from scene.gaussian_model import BasicPointCloud 23 | import torch.nn.functional as F 24 | from imageio.v2 import imread 25 | import torch 26 | import random 27 | 28 | 29 | class CameraInfo(NamedTuple): 30 | uid: int 31 | R: np.array 32 | T: np.array 33 | K: np.array 34 | FovY: np.array 35 | FovX: np.array 36 | image: np.array 37 | image_path: str 38 | image_name: str 39 | width: int 40 | height: int 41 | cx_ratio: float 42 | cy_ratio: float 43 | semantic2d: np.array 44 | optical_image: np.array 45 | mask: np.array 46 | timestamp: int 47 | dynamics: dict 48 | 49 | class SceneInfo(NamedTuple): 50 | point_cloud: BasicPointCloud 51 | train_cameras: list 52 | test_cameras: list 53 | nerf_normalization: dict 54 | ply_path: str 55 | verts: dict 56 | 57 | def getNerfppNorm(cam_info): 58 | def get_center_and_diag(cam_centers): 59 | cam_centers = np.hstack(cam_centers) 60 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 61 | center = avg_cam_center 62 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 63 | diagonal = np.max(dist) 64 | return center.flatten(), diagonal 65 | 66 | cam_centers = [] 67 | 68 | for cam in cam_info: 69 | W2C = getWorld2View2(cam.R, cam.T) 70 | C2W = np.linalg.inv(W2C) 71 | cam_centers.append(C2W[:3, 3:4]) # cam_centers in world coordinate 72 | 73 | center, diagonal = get_center_and_diag(cam_centers) 74 | # radius = diagonal * 1.1 + 30 75 | radius = 10 76 | 77 | translate = -center 78 | 79 | return {"translate": translate, "radius": radius} 80 | 81 | def fetchPly(path): 82 | plydata = PlyData.read(path) 83 | vertices = plydata['vertex'] 84 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 85 | if 'red' in vertices: 86 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 87 | else: 88 | print('Create random colors') 89 | # shs = np.random.random((positions.shape[0], 3)) / 255.0 90 | shs = np.ones((positions.shape[0], 3)) * 0.5 91 | colors = SH2RGB(shs) 92 | # shs = np.ones((positions.shape[0], 3)) * 0.5 93 | # colors = SH2RGB(shs) 94 | normals = np.zeros((positions.shape[0], 3)) 95 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 96 | 97 | def storePly(path, xyz, rgb): 98 | # Define the dtype for the structured array 99 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 100 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 101 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 102 | 103 | normals = np.zeros_like(xyz) 104 | 105 | elements = np.empty(xyz.shape[0], dtype=dtype) 106 | 
attributes = np.concatenate((xyz, normals, rgb), axis=1) 107 | elements[:] = list(map(tuple, attributes)) 108 | 109 | # Create the PlyData object and write to file 110 | vertex_element = PlyElement.describe(elements, 'vertex') 111 | ply_data = PlyData([vertex_element]) 112 | ply_data.write(path) 113 | 114 | def readStudioCameras(path, white_background, data_type, ignore_dynamic): 115 | train_cam_infos, test_cam_infos = [], [] 116 | with open(os.path.join(path, 'meta_data.json')) as json_file: 117 | meta_data = json.load(json_file) 118 | 119 | verts = {} 120 | if 'verts' in meta_data and not ignore_dynamic: 121 | verts_list = meta_data['verts'] 122 | for k, v in verts_list.items(): 123 | verts[k] = np.array(v) 124 | 125 | frames = meta_data['frames'] 126 | for idx, frame in enumerate(frames): 127 | matrix = np.linalg.inv(np.array(frame['camtoworld'])) 128 | R = matrix[:3, :3] 129 | T = matrix[:3, 3] 130 | R = np.transpose(R) 131 | 132 | rgb_path = os.path.join(path, frame['rgb_path'].replace('./', '')) 133 | 134 | rgb_split = rgb_path.split('/') 135 | image_name = '_'.join([rgb_split[-2], rgb_split[-1][:-4]]) 136 | image = Image.open(rgb_path) 137 | 138 | semantic_2d = None 139 | semantic_pth = rgb_path.replace("images", "semantics").replace('.png', '.npy').replace('.jpg', '.npy') 140 | if os.path.exists(semantic_pth): 141 | semantic_2d = np.load(semantic_pth) 142 | semantic_2d[(semantic_2d == 14) | (semantic_2d == 15)] = 13 143 | 144 | optical_path = rgb_path.replace("images", "flow").replace('.png', '_flow.npy').replace('.jpg', '_flow.npy') 145 | if os.path.exists(optical_path): 146 | optical_image = np.load(optical_path) 147 | else: 148 | optical_image = None 149 | 150 | mask = None 151 | mask_path = rgb_path.replace("images", "masks").replace('.png', '.npy').replace('.jpg', '.npy') 152 | if os.path.exists(mask_path): 153 | mask = np.load(mask_path) 154 | 155 | timestamp = frame.get('timestamp', -1) 156 | 157 | intrinsic = np.array(frame['intrinsics']) 158 | FovX = focal2fov(intrinsic[0, 0], image.size[0]) 159 | FovY = focal2fov(intrinsic[1, 1], image.size[1]) 160 | cx, cy = intrinsic[0, 2], intrinsic[1, 2] 161 | w, h = image.size 162 | 163 | dynamics = {} 164 | if 'dynamics' in frame and not ignore_dynamic: 165 | dynamics_list = frame['dynamics'] 166 | for iid in dynamics_list.keys(): 167 | dynamics[iid] = torch.tensor(dynamics_list[iid]).cuda() 168 | 169 | cam_info = CameraInfo(uid=idx, R=R, T=T, K=intrinsic, FovY=FovY, FovX=FovX, image=image, 170 | image_path=rgb_path, image_name=image_name, width=image.size[0], 171 | height=image.size[1], cx_ratio=2*cx/w, cy_ratio=2*cy/h, semantic2d=semantic_2d, 172 | optical_image=optical_image, mask=mask, timestamp=timestamp, dynamics=dynamics) 173 | 174 | # kitti360 175 | if data_type == 'kitti360': 176 | # if 'cam_2' in cam_info.image_name or 'cam_3' in cam_info.image_name: 177 | # train_cam_infos.append(cam_info) 178 | # # continue 179 | if idx < 20: 180 | train_cam_infos.append(cam_info) 181 | elif idx % 8 < 4: 182 | train_cam_infos.append(cam_info) 183 | elif idx % 8 >= 4: 184 | test_cam_infos.append(cam_info) 185 | else: 186 | continue 187 | 188 | elif data_type == 'kitti': 189 | if idx < 10 or idx >= len(frames) - 4: 190 | train_cam_infos.append(cam_info) 191 | elif idx % 4 < 2: 192 | train_cam_infos.append(cam_info) 193 | elif idx % 4 == 2: 194 | test_cam_infos.append(cam_info) 195 | else: 196 | continue 197 | 198 | elif data_type == "nuscenes": 199 | if idx < 600 or idx >= 1200: 200 | continue 201 | elif idx % 30 >= 24: 202 | # 
print('test', cam_info.image_name) 203 | test_cam_infos.append(cam_info) 204 | else: 205 | # print('train', cam_info.image_name) 206 | train_cam_infos.append(cam_info) 207 | 208 | elif data_type == "waymo": 209 | if idx > 10 and idx % 10 >= 9: 210 | test_cam_infos.append(cam_info) 211 | else: 212 | train_cam_infos.append(cam_info) 213 | 214 | elif data_type == "pandaset": 215 | # if idx >= 360: 216 | # continue 217 | if idx > 30 and idx % 30 >= 24: 218 | test_cam_infos.append(cam_info) 219 | else: 220 | train_cam_infos.append(cam_info) 221 | 222 | else: 223 | raise NotImplementedError 224 | return train_cam_infos, test_cam_infos, verts 225 | 226 | 227 | def readStudioInfo(path, white_background, eval, data_type, ignore_dynamic): 228 | train_cam_infos, test_cam_infos, verts = readStudioCameras(path, white_background, data_type, ignore_dynamic) 229 | 230 | print(f'Loaded {len(train_cam_infos)} train cameras and {len(test_cam_infos)} test cameras') 231 | nerf_normalization = getNerfppNorm(train_cam_infos) 232 | 233 | ply_path = os.path.join(path, "points3d.ply") 234 | # ply_path = os.path.join(path, 'lidar', 'cat.ply') 235 | if not os.path.exists(ply_path): 236 | # Since this data set has no colmap data, we start with random points 237 | num_pts = 500_000 238 | print(f"Generating random point cloud ({num_pts})...") 239 | 240 | # We create random points inside the bounds of the synthetic Blender scenes 241 | AABB = [[-20, -25, -20], [20, 5, 80]] 242 | xyz = np.random.uniform(AABB[0], AABB[1], (500000, 3)) 243 | # xyz = np.load(os.path.join(path, 'lidar_point.npy')) 244 | num_pts = xyz.shape[0] 245 | shs = np.ones((num_pts, 3)) * 0.5 246 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 247 | 248 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 249 | try: 250 | pcd = fetchPly(ply_path) 251 | except Exception as e: 252 | print('When loading point clound, meet error:', e) 253 | exit(0) 254 | 255 | scene_info = SceneInfo(point_cloud=pcd, 256 | train_cameras=train_cam_infos, 257 | test_cameras=test_cam_infos, 258 | nerf_normalization=nerf_normalization, 259 | ply_path=ply_path, 260 | verts=verts) 261 | return scene_info 262 | 263 | 264 | sceneLoadTypeCallbacks = { 265 | "Studio": readStudioInfo, 266 | } -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH, SH2RGB 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation 23 | import open3d as o3d 24 | import tinycudann as tcnn 25 | from math import sqrt 26 | 27 | class CustomAdam(torch.optim.Optimizer): 28 | def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8): 29 | defaults = dict(lr=lr, betas=betas, eps=eps) 30 | super(CustomAdam, self).__init__(params, defaults) 31 | 32 | def step(self, custom_lr=None, name=None): 33 | for group in self.param_groups: 34 | for p in group['params']: 35 | if p.grad is None: 36 | continue 37 | 38 | grad = p.grad.data 39 | if grad.is_sparse: 40 | raise RuntimeError('Adam does not support sparse gradients') 41 | 42 | state = self.state[p] 43 | 44 | # State initialization 45 | if len(state) == 0: 46 | state['step'] = 0 47 | # Exponential moving averages of gradient values 48 | state['exp_avg'] = torch.zeros_like(p.data) 49 | # Exponential moving averages of squared gradient values 50 | state['exp_avg_sq'] = torch.zeros_like(p.data) 51 | 52 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 53 | beta1, beta2 = group['betas'] 54 | 55 | # Add op to update moving averages 56 | state['step'] += 1 57 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 58 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 59 | 60 | denom = exp_avg_sq.sqrt().add_(group['eps']) 61 | 62 | bias_correction1 = 1.0 - beta1 ** state['step'] 63 | bias_correction2 = 1.0 - beta2 ** state['step'] 64 | 65 | if (custom_lr is not None) and (name is not None) and (group['name'] in name): 66 | step_size = custom_lr[:, None] * group['lr'] * (sqrt(bias_correction2) / bias_correction1) 67 | else: 68 | step_size = group['lr'] * (sqrt(bias_correction2) / bias_correction1) 69 | 70 | p.data -= step_size * exp_avg / denom 71 | 72 | class GaussianModel: 73 | 74 | def setup_functions(self): 75 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 76 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 77 | actual_covariance = L @ L.transpose(1, 2) 78 | symm = strip_symmetric(actual_covariance) 79 | return symm 80 | 81 | self.scaling_activation = torch.exp 82 | self.scaling_inverse_activation = torch.log 83 | 84 | self.covariance_activation = build_covariance_from_scaling_rotation 85 | 86 | self.opacity_activation = torch.sigmoid 87 | self.inverse_opacity_activation = inverse_sigmoid 88 | 89 | self.rotation_activation = torch.nn.functional.normalize 90 | 91 | 92 | def __init__(self, sh_degree : int, feat_mutable=True, affine=False): 93 | self.active_sh_degree = 0 94 | self.max_sh_degree = sh_degree 95 | self._xyz = torch.empty(0) 96 | self._features_dc = torch.empty(0) 97 | self._features_rest = torch.empty(0) 98 | self._feats3D = torch.empty(0) 99 | self._scaling = torch.empty(0) 100 | self._rotation = torch.empty(0) 101 | self._opacity = torch.empty(0) 102 | self.max_radii2D = torch.empty(0) 103 | self.xyz_gradient_accum = torch.empty(0) 104 | self.denom = torch.empty(0) 105 | self.optimizer = None 106 | self.percent_dense = 0 107 | self.spatial_lr_scale 
= 0 108 | self.feat_mutable = feat_mutable 109 | self.setup_functions() 110 | 111 | self.pos_enc = tcnn.Encoding( 112 | n_input_dims=3, 113 | encoding_config={"otype": "Frequency", "n_frequencies": 2}, 114 | ) 115 | self.dir_enc = tcnn.Encoding( 116 | n_input_dims=3, 117 | encoding_config={ 118 | "otype": "SphericalHarmonics", 119 | "degree": 3, 120 | }, 121 | ) 122 | 123 | self.affine = affine 124 | if affine: 125 | self.appearance_model = tcnn.Network( 126 | n_input_dims=self.pos_enc.n_output_dims + self.dir_enc.n_output_dims, 127 | n_output_dims=12, 128 | network_config={ 129 | "otype": "FullyFusedMLP", 130 | "activation": "ReLU", 131 | "output_activation": "None", 132 | "n_neurons": 32, 133 | "n_hidden_layers": 2, 134 | } 135 | ) 136 | else: 137 | self.appearance_model = None 138 | 139 | def capture(self): 140 | return ( 141 | self.active_sh_degree, 142 | self._xyz, 143 | self._features_dc, 144 | self._features_rest, 145 | self._feats3D, 146 | self._scaling, 147 | self._rotation, 148 | self._opacity, 149 | self.max_radii2D, 150 | self.xyz_gradient_accum, 151 | self.denom, 152 | self.optimizer.state_dict(), 153 | self.spatial_lr_scale, 154 | self.appearance_model, 155 | ) 156 | 157 | def restore(self, model_args, training_args): 158 | (self.active_sh_degree, 159 | self._xyz, 160 | self._features_dc, 161 | self._features_rest, 162 | self._feats3D, 163 | self._scaling, 164 | self._rotation, 165 | self._opacity, 166 | self.max_radii2D, 167 | xyz_gradient_accum, 168 | denom, 169 | opt_dict, 170 | self.spatial_lr_scale, 171 | self.appearance_model,) = model_args 172 | self.xyz_gradient_accum = xyz_gradient_accum 173 | self.denom = denom 174 | if training_args is not None: 175 | self.training_setup(training_args) 176 | self.optimizer.load_state_dict(opt_dict) 177 | 178 | @property 179 | def get_scaling(self): 180 | return self.scaling_activation(self._scaling) 181 | 182 | @property 183 | def get_rotation(self): 184 | return self.rotation_activation(self._rotation) 185 | 186 | # TODO add get_xyz for dynamic car 187 | @property 188 | def get_xyz(self): 189 | return self._xyz 190 | 191 | @property 192 | def get_features(self): 193 | features_dc = self._features_dc 194 | features_rest = self._features_rest 195 | return torch.cat((features_dc, features_rest), dim=1) 196 | 197 | @property 198 | def get_3D_features(self): 199 | return torch.softmax(self._feats3D, dim=-1) 200 | 201 | @property 202 | def get_opacity(self): 203 | return self.opacity_activation(self._opacity) 204 | 205 | def get_covariance(self, scaling_modifier = 1): 206 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 207 | 208 | def oneupSHdegree(self): 209 | if self.active_sh_degree < self.max_sh_degree: 210 | self.active_sh_degree += 1 211 | 212 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): 213 | # self.spatial_lr_scale = 1 214 | self.spatial_lr_scale = spatial_lr_scale 215 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 216 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 217 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 218 | features[:, :3, 0 ] = fused_color 219 | features[:, 3:, 1:] = 0.0 220 | 221 | if self.feat_mutable: 222 | feats3D = torch.zeros(fused_color.shape[0], 20).float().cuda() 223 | self._feats3D = nn.Parameter(feats3D.requires_grad_(True)) 224 | else: 225 | feats3D = torch.zeros(fused_color.shape[0], 20).float().cuda() 226 | feats3D[:, 13] = 
1 227 | self._feats3D = nn.Parameter(feats3D.requires_grad_(True)) 228 | 229 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 230 | 231 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 232 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 233 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 234 | rots[:, 0] = 1 235 | 236 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 237 | 238 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 239 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 240 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 241 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 242 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 243 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 244 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 245 | 246 | def training_setup(self, training_args): 247 | self.percent_dense = training_args.percent_dense 248 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 249 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 250 | 251 | # self.spatial_lr_scale /= 3 252 | 253 | l = [ 254 | {'params': [self._xyz], 'lr': training_args.position_lr_init*self.spatial_lr_scale, "name": "xyz"}, 255 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 256 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 257 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 258 | {'params': [self._scaling], 'lr': training_args.scaling_lr*self.spatial_lr_scale, "name": "scaling"}, 259 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}, 260 | ] 261 | 262 | if self.affine: 263 | l.append({'params': [*self.appearance_model.parameters()], 'lr': 1e-3, "name": "appearance_model"}) 264 | 265 | if self.feat_mutable: 266 | l.append({'params': [self._feats3D], 'lr': 1e-2, "name": "feats3D"}) 267 | 268 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 269 | # self.optimizer = CustomAdam(l, lr=0.0, eps=1e-15) 270 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, 271 | lr_final=training_args.position_lr_final*self.spatial_lr_scale, 272 | lr_delay_mult=training_args.position_lr_delay_mult, 273 | max_steps=training_args.position_lr_max_steps) 274 | 275 | def update_learning_rate(self, iteration): 276 | ''' Learning rate scheduling per step ''' 277 | for param_group in self.optimizer.param_groups: 278 | if param_group["name"] == "xyz": 279 | lr = self.xyz_scheduler_args(iteration) 280 | param_group['lr'] = lr 281 | return lr 282 | 283 | def construct_list_of_attributes(self): 284 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 285 | # All channels except the 3 DC 286 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 287 | l.append('f_dc_{}'.format(i)) 288 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 289 | l.append('f_rest_{}'.format(i)) 290 | for i in range(self._feats3D.shape[1]): 291 | l.append('semantic_{}'.format(i)) 292 | l.append('opacity') 293 | for i in range(self._scaling.shape[1]): 294 | l.append('scale_{}'.format(i)) 295 | for i in 
range(self._rotation.shape[1]): 296 | l.append('rot_{}'.format(i)) 297 | return l 298 | 299 | def save_ply(self, path): 300 | mkdir_p(os.path.dirname(path)) 301 | 302 | xyz = self._xyz.detach().cpu().numpy() 303 | normals = np.zeros_like(xyz) 304 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 305 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 306 | feats3D = self._feats3D.detach().cpu().numpy() 307 | opacities = self._opacity.detach().cpu().numpy() 308 | scale = self._scaling.detach().cpu().numpy() 309 | rotation = self._rotation.detach().cpu().numpy() 310 | 311 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 312 | 313 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 314 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, feats3D, opacities, scale, rotation), axis=1) 315 | elements[:] = list(map(tuple, attributes)) 316 | el = PlyElement.describe(elements, 'vertex') 317 | PlyData([el]).write(path) 318 | 319 | def save_vis_ply(self, path): 320 | mkdir_p(os.path.dirname(path)) 321 | xyz = self.get_xyz.detach().cpu().numpy() 322 | pcd = o3d.geometry.PointCloud() 323 | pcd.points = o3d.utility.Vector3dVector(xyz) 324 | colors = SH2RGB(self._features_dc[:, 0, :].detach().cpu().numpy()).clip(0, 1) 325 | pcd.colors = o3d.utility.Vector3dVector(colors) 326 | o3d.io.write_point_cloud(path, pcd) 327 | 328 | def reset_opacity(self): 329 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 330 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 331 | self._opacity = optimizable_tensors["opacity"] 332 | 333 | def load_ply(self, path): 334 | plydata = PlyData.read(path) 335 | 336 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 337 | np.asarray(plydata.elements[0]["y"]), 338 | np.asarray(plydata.elements[0]["z"])), axis=1) 339 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 340 | 341 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 342 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 343 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 344 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 345 | 346 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 347 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 348 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 349 | for idx, attr_name in enumerate(extra_f_names): 350 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 351 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 352 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 353 | 354 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 355 | scales = np.zeros((xyz.shape[0], len(scale_names))) 356 | for idx, attr_name in enumerate(scale_names): 357 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 358 | 359 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 360 | rots = np.zeros((xyz.shape[0], len(rot_names))) 361 | for idx, attr_name in enumerate(rot_names): 362 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 363 | 364 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 365 | 
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 366 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 367 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 368 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 369 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 370 | 371 | self.active_sh_degree = self.max_sh_degree 372 | 373 | def replace_tensor_to_optimizer(self, tensor, name): 374 | optimizable_tensors = {} 375 | for group in self.optimizer.param_groups: 376 | if group["name"] == name: 377 | stored_state = self.optimizer.state.get(group['params'][0], None) 378 | stored_state["exp_avg"] = torch.zeros_like(tensor) 379 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 380 | 381 | del self.optimizer.state[group['params'][0]] 382 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 383 | self.optimizer.state[group['params'][0]] = stored_state 384 | 385 | optimizable_tensors[group["name"]] = group["params"][0] 386 | return optimizable_tensors 387 | 388 | def _prune_optimizer(self, mask): 389 | optimizable_tensors = {} 390 | for group in self.optimizer.param_groups: 391 | if group['name'] == 'appearance_model': 392 | continue 393 | stored_state = self.optimizer.state.get(group['params'][0], None) 394 | if stored_state is not None: 395 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 396 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 397 | 398 | del self.optimizer.state[group['params'][0]] 399 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 400 | self.optimizer.state[group['params'][0]] = stored_state 401 | 402 | optimizable_tensors[group["name"]] = group["params"][0] 403 | else: 404 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 405 | optimizable_tensors[group["name"]] = group["params"][0] 406 | return optimizable_tensors 407 | 408 | def prune_points(self, mask): 409 | valid_points_mask = ~mask 410 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 411 | 412 | self._xyz = optimizable_tensors["xyz"] 413 | self._features_dc = optimizable_tensors["f_dc"] 414 | self._features_rest = optimizable_tensors["f_rest"] 415 | if self.feat_mutable: 416 | self._feats3D = optimizable_tensors["feats3D"] 417 | else: 418 | self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1)) 419 | self._opacity = optimizable_tensors["opacity"] 420 | self._scaling = optimizable_tensors["scaling"] 421 | self._rotation = optimizable_tensors["rotation"] 422 | 423 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 424 | 425 | self.denom = self.denom[valid_points_mask] 426 | self.max_radii2D = self.max_radii2D[valid_points_mask] 427 | 428 | def cat_tensors_to_optimizer(self, tensors_dict): 429 | optimizable_tensors = {} 430 | for group in self.optimizer.param_groups: 431 | if group['name'] not in tensors_dict: 432 | continue 433 | assert len(group["params"]) == 1 434 | extension_tensor = tensors_dict[group["name"]] 435 | stored_state = self.optimizer.state.get(group["params"][0], None) 436 | if stored_state is not None: 437 | 438 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], 
torch.zeros_like(extension_tensor)), dim=0) 439 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 440 | 441 | del self.optimizer.state[group["params"][0]] 442 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 443 | self.optimizer.state[group["params"][0]] = stored_state 444 | 445 | optimizable_tensors[group["name"]] = group["params"][0] 446 | else: 447 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 448 | optimizable_tensors[group["name"]] = group["params"][0] 449 | 450 | return optimizable_tensors 451 | 452 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation): 453 | d = {"xyz": new_xyz, 454 | "f_dc": new_features_dc, 455 | "f_rest": new_features_rest, 456 | "feats3D": new_feats3D, 457 | "opacity": new_opacities, 458 | "scaling" : new_scaling, 459 | "rotation" : new_rotation} 460 | 461 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 462 | self._xyz = optimizable_tensors["xyz"] 463 | self._features_dc = optimizable_tensors["f_dc"] 464 | if self.feat_mutable: 465 | self._feats3D = optimizable_tensors["feats3D"] 466 | else: 467 | self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1)) 468 | self._features_rest = optimizable_tensors["f_rest"] 469 | self._opacity = optimizable_tensors["opacity"] 470 | self._scaling = optimizable_tensors["scaling"] 471 | self._rotation = optimizable_tensors["rotation"] 472 | 473 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 474 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 475 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 476 | 477 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 478 | n_init_points = self.get_xyz.shape[0] 479 | # Extract points that satisfy the gradient condition 480 | padded_grad = torch.zeros((n_init_points), device="cuda") 481 | padded_grad[:grads.shape[0]] = grads.squeeze() 482 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 483 | selected_pts_mask = torch.logical_and(selected_pts_mask, 484 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 485 | 486 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 487 | means =torch.zeros((stds.size(0), 3),device="cuda") 488 | samples = torch.normal(mean=means, std=stds) 489 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 490 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 491 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 492 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 493 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 494 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 495 | new_feats3D = self._feats3D[selected_pts_mask].repeat(N,1) 496 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 497 | 498 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacity, new_scaling, new_rotation) 499 | 500 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 501 | self.prune_points(prune_filter) 502 | 503 | def 
densify_and_clone(self, grads, grad_threshold, scene_extent): 504 | # Extract points that satisfy the gradient condition 505 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 506 | selected_pts_mask = torch.logical_and(selected_pts_mask, 507 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 508 | 509 | new_xyz = self._xyz[selected_pts_mask] 510 | new_features_dc = self._features_dc[selected_pts_mask] 511 | new_features_rest = self._features_rest[selected_pts_mask] 512 | new_feats3D = self._feats3D[selected_pts_mask] 513 | new_opacities = self._opacity[selected_pts_mask] 514 | new_scaling = self._scaling[selected_pts_mask] 515 | new_rotation = self._rotation[selected_pts_mask] 516 | 517 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation) 518 | 519 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 520 | grads = self.xyz_gradient_accum / self.denom 521 | grads[grads.isnan()] = 0.0 522 | 523 | self.densify_and_clone(grads, max_grad, extent) 524 | self.densify_and_split(grads, max_grad, extent) 525 | 526 | prune_mask = (self.get_opacity < min_opacity).squeeze() 527 | if max_screen_size: 528 | big_points_vs = self.max_radii2D > max_screen_size 529 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent * 10 530 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 531 | self.prune_points(prune_mask) 532 | 533 | torch.cuda.empty_cache() 534 | 535 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 536 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 537 | self.denom[update_filter] += 1 538 | 539 | def add_densification_stats_grad(self, tensor_grad, update_filter): 540 | self.xyz_gradient_accum[update_filter] += torch.norm(tensor_grad[update_filter,:2], dim=-1, keepdim=True) 541 | self.denom[update_filter] += 1 -------------------------------------------------------------------------------- /submodules/simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /submodules/simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include <cub/cub.cuh> 18 | #include <cub/device/device_radix_sort.cuh> 19 | #include <vector> 20 | #include <cuda_runtime_api.h> 21 | #include <thrust/device_vector.h> 22 | #include <thrust/sequence.h> 23 | #define __CUDACC__ 24 | #include <cooperative_groups.h> 25 | #include <cooperative_groups/reduce.h> 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other =
redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template<int K> 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 | { 134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector<char> temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector<uint32_t> morton(P); 203 | thrust::device_vector<uint32_t> morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 | 206 | thrust::device_vector<uint32_t> indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector<uint32_t> 
indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector<MinMax> boxes(num_boxes); 217 | boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /submodules/simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/submodules/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /submodules/simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file.
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, PIL2toTorch 15 | from utils.graphics_utils import fov2focal 16 | import torch 17 | 18 | WARNED = False 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | orig_w, orig_h = cam_info.image.size 22 | 23 | if args.resolution in [1, 2, 4, 8]: 24 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 25 | else: # should be a type that converts to float 26 | if args.resolution == -1: 27 | if orig_w > 1600: 28 | global WARNED 29 | if not WARNED: 30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 32 | WARNED = True 33 | global_down = orig_w / 1600 34 | else: 35 | global_down = 1 36 | else: 37 | global_down = orig_w / args.resolution 38 | 39 | scale = float(global_down) * float(resolution_scale) 40 | resolution = (int(orig_w / scale), int(orig_h / scale)) 41 | 42 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 43 | 44 | if cam_info.semantic2d is not None: 45 | semantic2d = torch.from_numpy(cam_info.semantic2d).long()[None, ...] 46 | else: 47 | semantic2d = None 48 | 49 | optical_image = cam_info.optical_image 50 | mask = cam_info.mask 51 | 52 | gt_image = resized_image_rgb[:3, ...]
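The branching above is the part of `loadCam` that decides the working image size. Below is a minimal, self-contained sketch of the same logic; the helper name `target_resolution` and the sample image sizes are illustrative, not part of the repository:

```python
def target_resolution(orig_w, orig_h, resolution, resolution_scale=1.0):
    """Mirror of the resolution handling in loadCam above; illustrative only."""
    if resolution in [1, 2, 4, 8]:
        # Fixed integer downsampling factors.
        return (round(orig_w / (resolution_scale * resolution)),
                round(orig_h / (resolution_scale * resolution)))
    if resolution == -1:
        # Cap very wide images at 1.6K width, as the warning above explains.
        global_down = orig_w / 1600 if orig_w > 1600 else 1
    else:
        # Any other value is treated as a target width in pixels.
        global_down = orig_w / resolution
    scale = float(global_down) * float(resolution_scale)
    return int(orig_w / scale), int(orig_h / scale)

print(target_resolution(1920, 1280, -1))  # (1600, 1066): capped at 1.6K width
print(target_resolution(1920, 1280, 2))   # (960, 640): explicit 2x downscale
```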
53 | 54 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, K=cam_info.K, 55 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 56 | image=gt_image, image_name=cam_info.image_name, uid=id, data_device=args.data_device, 57 | cx_ratio=cam_info.cx_ratio, cy_ratio=cam_info.cy_ratio, semantic2d=semantic2d, mask=mask, 58 | timestamp=cam_info.timestamp, optical_image=optical_image, dynamics=cam_info.dynamics) 59 | 60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 61 | camera_list = [] 62 | 63 | for id, c in enumerate(cam_infos): 64 | camera_list.append(loadCam(args, id, c, resolution_scale)) 65 | 66 | return camera_list 67 | 68 | def camera_to_JSON(id, camera : Camera): 69 | Rt = np.zeros((4, 4)) 70 | Rt[:3, :3] = camera.R.transpose() 71 | Rt[:3, 3] = camera.T 72 | Rt[3, 3] = 1.0 73 | 74 | W2C = np.linalg.inv(Rt) 75 | pos = W2C[:3, 3] 76 | rot = W2C[:3, :3] 77 | serializable_array_2d = [x.tolist() for x in rot] 78 | camera_entry = { 79 | 'id' : id, 80 | 'img_name' : camera.image_name, 81 | 'width' : camera.width, 82 | 'height' : camera.height, 83 | 'position': pos.tolist(), 84 | 'rotation': serializable_array_2d, 85 | 'fy' : fov2focal(camera.FovY, camera.height), 86 | 'fx' : fov2focal(camera.FovX, camera.width), 87 | } 88 | return camera_entry 89 | -------------------------------------------------------------------------------- /utils/cmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | _color_map_errors = np.array([ 4 | [149, 54, 49], #0: log2(x) = -infinity 5 | [180, 117, 69], #0.0625: log2(x) = -4 6 | [209, 173, 116], #0.125: log2(x) = -3 7 | [233, 217, 171], #0.25: log2(x) = -2 8 | [248, 243, 224], #0.5: log2(x) = -1 9 | [144, 224, 254], #1.0: log2(x) = 0 10 | [97, 174, 253], #2.0: log2(x) = 1 11 | [67, 109, 244], #4.0: log2(x) = 2 12 | [39, 48, 215], #8.0: log2(x) = 3 13 | [38, 0, 165], #16.0: log2(x) = 4 14 | [38, 0, 165] #inf: log2(x) = inf 15 | ]).astype(float) 16 | 17 | def color_error_image(errors, scale=1, mask=None, BGR=True): 18 | """ 19 | Color an input error map. 
20 | 21 | Arguments: 22 | errors -- HxW numpy array of errors 23 | [scale=1] -- scaling the error map (color change at unit error) 24 | [mask=None] -- zero-pixels are masked white in the result 25 | [BGR=True] -- toggle between BGR and RGB 26 | 27 | Returns: 28 | colored_errors -- HxWx3 numpy array visualizing the errors 29 | """ 30 | 31 | errors_flat = errors.flatten() 32 | errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9) 33 | i0 = np.floor(errors_color_indices).astype(int) 34 | f1 = errors_color_indices - i0.astype(float) 35 | colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1) 36 | 37 | if mask is not None: 38 | colored_errors_flat[mask.flatten() == 0] = 255 39 | 40 | if not BGR: 41 | colored_errors_flat = colored_errors_flat[:,[2,1,0]] 42 | 43 | return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(int) 44 | 45 | _color_map_depths = np.array([ 46 | [0, 0, 0], # 0.000 47 | [0, 0, 255], # 0.114 48 | [255, 0, 0], # 0.299 49 | [255, 0, 255], # 0.413 50 | [0, 255, 0], # 0.587 51 | [0, 255, 255], # 0.701 52 | [255, 255, 0], # 0.886 53 | [255, 255, 255], # 1.000 54 | [255, 255, 255], # 1.000 55 | ]).astype(float) 56 | _color_map_bincenters = np.array([ 57 | 0.0, 58 | 0.114, 59 | 0.299, 60 | 0.413, 61 | 0.587, 62 | 0.701, 63 | 0.886, 64 | 1.000, 65 | 2.000, # doesn't make a difference, just strictly higher than 1 66 | ]) 67 | 68 | def color_depth_map(depths, scale=None): 69 | """ 70 | Color an input depth map. 71 | 72 | Arguments: 73 | depths -- HxW numpy array of depths 74 | [scale=None] -- scaling the values (defaults to the maximum depth) 75 | 76 | Returns: 77 | colored_depths -- HxWx3 numpy array visualizing the depths 78 | """ 79 | 80 | # if scale is None: 81 | # scale = depths.max() / 1.5 82 | scale = 50 83 | values = np.clip(depths.flatten() / scale, 0, 1) 84 | # for each value, figure out where they fit in the bincenters: what is the last bincenter smaller than this value? 
85 | lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1) 86 | lower_bin_value = _color_map_bincenters[lower_bin] 87 | higher_bin_value = _color_map_bincenters[lower_bin + 1] 88 | alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value) 89 | colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1) 90 | return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8) -------------------------------------------------------------------------------- /utils/dynamic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import optim 4 | from torch import nn 5 | from tqdm import tqdm 6 | from matplotlib import pyplot as plt 7 | import torch.nn.functional as F 8 | from collections import defaultdict 9 | import os 10 | 11 | 12 | def rot2Euler(R): 13 | sy = torch.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) 14 | singular = sy < 1e-6 15 | 16 | if not singular: 17 | x = torch.atan2(R[2,1] , R[2,2]) 18 | y = torch.atan2(-R[2,0], sy) 19 | z = torch.atan2(R[1,0], R[0,0]) 20 | else: 21 | x = torch.atan2(-R[1,2], R[1,1]) 22 | y = torch.atan2(-R[2,0], sy) 23 | z = 0 24 | 25 | return torch.stack([x,y,z]) 26 | 27 | class unicycle(torch.nn.Module): 28 | 29 | def __init__(self, train_timestamp, centers=None, heights=None, phis=None): 30 | super(unicycle, self).__init__() 31 | self.train_timestamp = train_timestamp 32 | self.delta = torch.diff(self.train_timestamp) 33 | 34 | self.input_a = centers[:, 0].clone() 35 | self.input_b = centers[:, 1].clone() 36 | 37 | if centers is None: 38 | self.a = nn.Parameter(torch.zeros_like(train_timestamp).float()) 39 | self.b = nn.Parameter(torch.zeros_like(train_timestamp).float()) 40 | else: 41 | self.a = nn.Parameter(centers[:, 0]) 42 | self.b = nn.Parameter(centers[:, 1]) 43 | 44 | diff_a = torch.diff(centers[:, 0]) / self.delta 45 | diff_b = torch.diff(centers[:, 1]) / self.delta 46 | v = torch.sqrt(diff_a ** 2 + diff_b**2) 47 | self.v = nn.Parameter(F.pad(v, (0, 1), 'constant', v[-1].item())) 48 | self.phi = nn.Parameter(phis) 49 | 50 | if heights is None: 51 | self.h = nn.Parameter(torch.zeros_like(train_timestamp).float()) 52 | else: 53 | self.h = nn.Parameter(heights) 54 | 55 | def acc_omega(self): 56 | acc = torch.diff(self.v) / self.delta 57 | omega = torch.diff(self.phi) / self.delta 58 | acc = F.pad(acc, (0, 1), 'constant', acc[-1].item()) 59 | omega = F.pad(omega, (0, 1), 'constant', omega[-1].item()) 60 | return acc, omega 61 | 62 | def forward(self, timestamps): 63 | idx = torch.searchsorted(self.train_timestamp, timestamps, side='left') 64 | invalid = (idx == self.train_timestamp.shape[0]) 65 | idx[invalid] -= 1 66 | idx[self.train_timestamp[idx] != timestamps] -= 1 67 | idx[invalid] += 1 68 | prev_timestamps = self.train_timestamp[idx] 69 | delta_t = timestamps - prev_timestamps 70 | prev_a, prev_b = self.a[idx], self.b[idx] 71 | prev_v, prev_phi = self.v[idx], self.phi[idx] 72 | 73 | acc, omega = self.acc_omega() 74 | v = prev_v + acc[idx] * delta_t 75 | phi = prev_phi + omega[idx] * delta_t 76 | a = prev_a + prev_v * ((torch.sin(phi) - torch.sin(prev_phi)) / (omega[idx] + 1e-6)) 77 | b = prev_b - prev_v * ((torch.cos(phi) - torch.cos(prev_phi)) / (omega[idx] + 1e-6)) 78 | h = self.h[idx] 79 | return a, b, v, phi, h 80 | 81 | def capture(self): 82 | return ( 83 | self.a, 84 | self.b, 85 | self.v, 86 | self.phi, 87 | self.h, 88 
| self.train_timestamp, 89 | self.delta 90 | ) 91 | 92 | def restore(self, model_args): 93 | ( 94 | self.a, 95 | self.b, 96 | self.v, 97 | self.phi, 98 | self.h, 99 | self.train_timestamp, 100 | self.delta 101 | ) = model_args 102 | 103 | def visualize(self, save_path, noise_centers=None, gt_centers=None): 104 | a, b, _, phi, _ = self.forward(self.train_timestamp) 105 | a = a.detach().cpu().numpy() 106 | b = b.detach().cpu().numpy() 107 | phi = phi.detach().cpu().numpy() 108 | plt.scatter(a, b, marker='x', color='b') 109 | plt.quiver(a, b, np.ones_like(a) * np.cos(phi), np.ones_like(b) * np.sin(phi), scale=20, width=0.005) 110 | if noise_centers is not None: 111 | noise_centers = noise_centers.detach().cpu().numpy() 112 | plt.scatter(noise_centers[:, 0], noise_centers[:, 1], marker='o', color='gray') 113 | if gt_centers is not None: 114 | gt_centers = gt_centers.detach().cpu().numpy() 115 | plt.scatter(gt_centers[:, 0], gt_centers[:, 1], marker='v', color='g') 116 | plt.axis('equal') 117 | plt.savefig(save_path) 118 | plt.close() 119 | 120 | def reg_loss(self): 121 | reg = 0 122 | acc, omega = self.acc_omega() 123 | reg += torch.mean(torch.abs(torch.diff(acc))) * 1 124 | reg += torch.mean(torch.abs(torch.diff(omega))) * 1 125 | reg_a_motion = self.v[:-1] * ((torch.sin(self.phi[1:]) - torch.sin(self.phi[:-1])) / (omega[:-1] + 1e-6)) 126 | reg_b_motion = -self.v[:-1] * ((torch.cos(self.phi[1:]) - torch.cos(self.phi[:-1])) / (omega[:-1] + 1e-6)) 127 | reg_a = self.a[:-1] + reg_a_motion 128 | reg_b = self.b[:-1] + reg_b_motion 129 | reg += torch.mean((reg_a - self.a[1:])**2 + (reg_b - self.b[1:])**2) * 1 130 | return reg 131 | 132 | def pos_loss(self): 133 | # a, b, _, _, _ = self.forward(self.train_timestamp) 134 | return torch.mean((self.a - self.input_a) ** 2 + (self.b - self.input_b) ** 2) * 10 135 | 136 | 137 | def create_unicycle_model(train_cams, model_path, opt_iter=0, data_type='kitti'): 138 | unicycle_models = {} 139 | if data_type == 'kitti': 140 | cameras = [cam for cam in train_cams if 'cam_0' in cam.image_name] 141 | elif data_type == 'waymo': 142 | cameras = [cam for cam in train_cams if 'cam_1' in cam.image_name] 143 | else: 144 | raise NotImplementedError 145 | 146 | all_centers, all_heights, all_phis, all_timestamps = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list) 147 | seq_timestamps = [] 148 | for cam in cameras: 149 | t = cam.timestamp 150 | seq_timestamps.append(t) 151 | for track_id, b2w in cam.dynamics.items(): 152 | all_centers[track_id].append(b2w[[0, 2], 3]) 153 | all_heights[track_id].append(b2w[1, 3]) 154 | eulers = rot2Euler(b2w[:3, :3]) 155 | all_phis[track_id].append(eulers[1]) 156 | all_timestamps[track_id].append(t) 157 | 158 | for track_id in all_centers.keys(): 159 | centers = torch.stack(all_centers[track_id], dim=0).cuda() 160 | timestamps = torch.tensor(all_timestamps[track_id]).cuda() 161 | heights = torch.tensor(all_heights[track_id]).cuda() 162 | phis = torch.tensor(all_phis[track_id]).cuda() + torch.pi 163 | model = unicycle(timestamps, centers.clone(), heights.clone(), phis.clone()) 164 | l = [ 165 | {'params': [model.a], 'lr': 1e-2, "name": "a"}, 166 | {'params': [model.b], 'lr': 1e-2, "name": "b"}, 167 | {'params': [model.v], 'lr': 1e-3, "name": "v"}, 168 | {'params': [model.phi], 'lr': 1e-4, "name": "phi"}, 169 | {'params': [model.h], 'lr': 0, "name": "h"} 170 | ] 171 | 172 | optimizer = optim.Adam(l, lr=0.0) 173 | 174 | t_range = tqdm(range(opt_iter), desc=f"Fitting {track_id}") 175 | for iter in t_range: 176 | loss = 
0.2 * model.pos_loss() + model.reg_loss() 177 | t_range.set_postfix({'loss': loss.item()}) 178 | optimizer.zero_grad() 179 | loss.backward() 180 | optimizer.step() 181 | 182 | unicycle_models[track_id] = {'model': model, 183 | 'optimizer': optimizer, 184 | 'input_centers': centers} 185 | 186 | os.makedirs(os.path.join(model_path, "unicycle"), exist_ok=True) 187 | for track_id, unicycle_pkg in unicycle_models.items(): 188 | model = unicycle_pkg['model'] 189 | optimizer = unicycle_pkg['optimizer'] 190 | 191 | model.visualize(os.path.join(model_path, "unicycle", f"{track_id}_init.png"), 192 | # noise_centers=unicycle_pkg['input_centers'] 193 | ) 194 | # gt_centers=gt_centers) 195 | 196 | return unicycle_models -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | import os 18 | import cv2 19 | 20 | def inverse_sigmoid(x): 21 | return torch.log(x/(1-x)) 22 | 23 | def PILtoTorch(pil_image, resolution): 24 | resized_image_PIL = pil_image.resize(resolution) 25 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 26 | if len(resized_image.shape) == 3: 27 | return resized_image.permute(2, 0, 1) 28 | else: 29 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 30 | 31 | def PIL2toTorch(pil_image, resolution): 32 | resized_image_PIL = pil_image.resize(resolution) 33 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 * (2.0 ** 16 - 1.0) 34 | return resized_image 35 | 36 | def decode_op(optical_png): 37 | # use 'PIL Image.Open' to READ 38 | "Convert from .png (h, w, 3-rgb) -> (h,w,2)(flow_x, flow_y) .. float32 array" 39 | optical_png = optical_png[..., [2, 1, 0]] # bgr -> rgb 40 | h, w, _c = optical_png.shape 41 | assert optical_png.dtype == np.uint16 and _c == 3 42 | "invalid flow flag: b == 0 for sky or other invalid flow" 43 | invalid_points = np.where(optical_png[..., 2] == 0) 44 | out_flow = torch.empty((h, w, 2)) 45 | decoded = 2.0 / (2**16 - 1.0) * optical_png.astype('f4') - 1 46 | out_flow[..., 0] = torch.tensor(decoded[:, :, 0] * (w - 1)) # (pixel) delta_x : R 47 | out_flow[..., 1] = torch.tensor(decoded[:, :, 1] * (h - 1)) # delta_y : G 48 | out_flow[invalid_points[0], invalid_points[1], :] = 0 # B=0 for invalid flow 49 | return out_flow 50 | 51 | def get_expon_lr_func( 52 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 53 | ): 54 | """ 55 | Copied from Plenoxels 56 | 57 | Continuous learning rate decay function. Adapted from JaxNeRF 58 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 59 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 60 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 61 | function of lr_delay_mult, such that the initial learning rate is 62 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 63 | to the normal learning rate when steps>lr_delay_steps. 
64 | :param conf: config subtree 'lr' or similar 65 | :param max_steps: int, the number of steps during optimization. 66 | :return HoF which takes step as input 67 | """ 68 | 69 | def helper(step): 70 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 71 | # Disable this parameter 72 | return 0.0 73 | if lr_delay_steps > 0: 74 | # A kind of reverse cosine decay. 75 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 76 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 77 | ) 78 | else: 79 | delay_rate = 1.0 80 | t = np.clip(step / max_steps, 0, 1) 81 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 82 | return delay_rate * log_lerp 83 | 84 | return helper 85 | 86 | def strip_lowerdiag(L): 87 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 88 | 89 | uncertainty[:, 0] = L[:, 0, 0] 90 | uncertainty[:, 1] = L[:, 0, 1] 91 | uncertainty[:, 2] = L[:, 0, 2] 92 | uncertainty[:, 3] = L[:, 1, 1] 93 | uncertainty[:, 4] = L[:, 1, 2] 94 | uncertainty[:, 5] = L[:, 2, 2] 95 | return uncertainty 96 | 97 | def strip_symmetric(sym): 98 | return strip_lowerdiag(sym) 99 | 100 | def build_rotation(r): 101 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 102 | 103 | q = r / norm[:, None] 104 | 105 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 106 | 107 | r = q[:, 0] 108 | x = q[:, 1] 109 | y = q[:, 2] 110 | z = q[:, 3] 111 | 112 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 113 | R[:, 0, 1] = 2 * (x*y - r*z) 114 | R[:, 0, 2] = 2 * (x*z + r*y) 115 | R[:, 1, 0] = 2 * (x*y + r*z) 116 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 117 | R[:, 1, 2] = 2 * (y*z - r*x) 118 | R[:, 2, 0] = 2 * (x*z - r*y) 119 | R[:, 2, 1] = 2 * (y*z + r*x) 120 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 121 | return R 122 | 123 | def build_scaling_rotation(s, r): 124 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 125 | R = build_rotation(r) 126 | 127 | L[:,0,0] = s[:,0] 128 | L[:,1,1] = s[:,1] 129 | L[:,2,2] = s[:,2] 130 | 131 | L = R @ L 132 | return L 133 | 134 | DEFAULT_RANDOM_SEED = 0 135 | 136 | def seedBasic(seed=DEFAULT_RANDOM_SEED): 137 | random.seed(seed) 138 | os.environ['PYTHONHASHSEED'] = str(seed) 139 | np.random.seed(seed) 140 | 141 | def seedTorch(seed=DEFAULT_RANDOM_SEED): 142 | torch.manual_seed(seed) 143 | torch.cuda.manual_seed(seed) 144 | torch.backends.cudnn.deterministic = True 145 | torch.backends.cudnn.benchmark = False 146 | 147 | # basic + tensorflow + torch 148 | def seedEverything(seed=DEFAULT_RANDOM_SEED): 149 | seedBasic(seed) 150 | seedTorch(seed) 151 | 152 | def safe_state(silent): 153 | old_f = sys.stdout 154 | class F: 155 | def __init__(self, silent): 156 | self.silent = silent 157 | 158 | def write(self, x): 159 | if not self.silent: 160 | if x.endswith("\n"): 161 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 162 | else: 163 | old_f.write(x) 164 | 165 | def flush(self): 166 | old_f.flush() 167 | 168 | sys.stdout = F(silent) 169 | 170 | random.seed(DEFAULT_RANDOM_SEED) 171 | np.random.seed(DEFAULT_RANDOM_SEED) 172 | torch.manual_seed(DEFAULT_RANDOM_SEED) 173 | torch.cuda.set_device(torch.device("cuda:0")) 174 | # sys.stdout = old_f 175 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | # feats3D : np.array 22 | 23 | def geom_transform_points(points, transf_matrix): 24 | P, _ = points.shape 25 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 26 | points_hom = torch.cat([points, ones], dim=1) 27 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 28 | 29 | denom = points_out[..., 3:] + 0.0000001 30 | return (points_out[..., :3] / denom).squeeze(dim=0) 31 | 32 | def getWorld2View(R, t): 33 | Rt = np.zeros((4, 4)) 34 | Rt[:3, :3] = R.transpose() 35 | Rt[:3, 3] = t 36 | Rt[3, 3] = 1.0 37 | return np.float32(Rt) 38 | 39 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 40 | Rt = np.zeros((4, 4)) 41 | Rt[:3, :3] = R.transpose() 42 | Rt[:3, 3] = t 43 | Rt[3, 3] = 1.0 44 | 45 | C2W = np.linalg.inv(Rt) 46 | cam_center = C2W[:3, 3] 47 | cam_center = (cam_center + translate) * scale 48 | C2W[:3, 3] = cam_center 49 | Rt = np.linalg.inv(C2W) 50 | return np.float32(Rt) 51 | 52 | def getProjectionMatrix(znear, zfar, fovX, fovY, cx_ratio, cy_ratio): 53 | tanHalfFovY = math.tan((fovY / 2)) 54 | tanHalfFovX = math.tan((fovX / 2)) 55 | 56 | top = tanHalfFovY * znear 57 | bottom = -top 58 | right = tanHalfFovX * znear 59 | left = -right 60 | 61 | P = torch.zeros(4, 4) 62 | 63 | z_sign = 1.0 64 | 65 | P[0, 0] = 2.0 * znear / (right - left) 66 | P[1, 1] = 2.0 * znear / (top - bottom) 67 | P[0, 2] = (right + left) / (right - left) - 1 + cx_ratio 68 | P[1, 2] = (top + bottom) / (top - bottom) - 1 + cy_ratio 69 | P[3, 2] = z_sign 70 | P[2, 2] = z_sign * (zfar + znear) / (zfar - znear) 71 | P[2, 3] = -(2 * zfar * znear) / (zfar - znear) 72 | 73 | # P[0, 0] = 2.0 * znear / (right - left) 74 | # P[1, 1] = 2.0 * znear / (top - bottom) 75 | # P[0, 2] = (right + left) / (right - left) 76 | # P[1, 2] = (top + bottom) / (top - bottom) 77 | # P[3, 2] = z_sign 78 | # P[2, 2] = z_sign * zfar / (zfar - znear) 79 | # P[2, 3] = -(zfar * znear) / (zfar - znear) 80 | return P 81 | 82 | def fov2focal(fov, pixels): 83 | return pixels / (2 * math.tan(fov / 2)) 84 | 85 | def focal2fov(focal, pixels): 86 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/iou_utils.py: -------------------------------------------------------------------------------- 1 | # 3D IoU caculate code for 3D object detection 2 | # Kent 2018/12 3 | 4 | import numpy as np 5 | from scipy.spatial import ConvexHull 6 | from numpy import * 7 | 8 | def polygon_clip(subjectPolygon, clipPolygon): 9 | """ Clip a polygon with another polygon. 10 | 11 | Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python 12 | 13 | Args: 14 | subjectPolygon: a list of (x,y) 2d points, any polygon. 15 | clipPolygon: a list of (x,y) 2d points, has to be *convex* 16 | Note: 17 | **points have to be counter-clockwise ordered** 18 | 19 | Return: 20 | a list of (x,y) vertex point for the intersection polygon. 21 | """ 22 | def inside(p): 23 | return(cp2[0]-cp1[0])*(p[1]-cp1[1]) > (cp2[1]-cp1[1])*(p[0]-cp1[0]) 24 | 25 | def computeIntersection(): 26 | dc = [ cp1[0] - cp2[0], cp1[1] - cp2[1] ] 27 | dp = [ s[0] - e[0], s[1] - e[1] ] 28 | n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0] 29 | n2 = s[0] * e[1] - s[1] * e[0] 30 | n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0]) 31 | return [(n1*dp[0] - n2*dc[0]) * n3, (n1*dp[1] - n2*dc[1]) * n3] 32 | 33 | outputList = subjectPolygon 34 | cp1 = clipPolygon[-1] 35 | 36 | for clipVertex in clipPolygon: 37 | cp2 = clipVertex 38 | inputList = outputList 39 | outputList = [] 40 | s = inputList[-1] 41 | 42 | for subjectVertex in inputList: 43 | e = subjectVertex 44 | if inside(e): 45 | if not inside(s): 46 | outputList.append(computeIntersection()) 47 | outputList.append(e) 48 | elif inside(s): 49 | outputList.append(computeIntersection()) 50 | s = e 51 | cp1 = cp2 52 | if len(outputList) == 0: 53 | return None 54 | return(outputList) 55 | 56 | def poly_area(x,y): 57 | """ Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates """ 58 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 59 | 60 | def convex_hull_intersection(p1, p2): 61 | """ Compute area of two convex hull's intersection area. 62 | p1,p2 are a list of (x,y) tuples of hull vertices. 63 | return a list of (x,y) for the intersection and its volume 64 | """ 65 | inter_p = polygon_clip(p1,p2) 66 | if inter_p is not None: 67 | hull_inter = ConvexHull(inter_p) 68 | return inter_p, hull_inter.volume 69 | else: 70 | return None, 0.0 71 | 72 | def box3d_vol(corners): 73 | ''' corners: (8,3) no assumption on axis direction ''' 74 | a = np.sqrt(np.sum((corners[0,:] - corners[1,:])**2)) 75 | b = np.sqrt(np.sum((corners[1,:] - corners[2,:])**2)) 76 | c = np.sqrt(np.sum((corners[0,:] - corners[4,:])**2)) 77 | return a*b*c 78 | 79 | def is_clockwise(p): 80 | x = p[:,0] 81 | y = p[:,1] 82 | return np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)) > 0 83 | 84 | def box3d_iou(corners1, corners2): 85 | ''' Compute 3D bounding box IoU. 
86 | 87 | Input: 88 | corners1: numpy array (8,3), assume up direction is negative Y 89 | corners2: numpy array (8,3), assume up direction is negative Y 90 | Output: 91 | iou: 3D bounding box IoU 92 | iou_2d: bird's eye view 2D bounding box IoU 93 | 94 | todo (kent): add more description on corner points' orders. 95 | ''' 96 | # corner points are in counter clockwise order 97 | rect1 = [(corners1[i,0], corners1[i,2]) for i in [4,5,1,0]] 98 | rect2 = [(corners2[i,0], corners2[i,2]) for i in [4,5,1,0]] 99 | 100 | area1 = poly_area(np.array(rect1)[:,0], np.array(rect1)[:,1]) 101 | area2 = poly_area(np.array(rect2)[:,0], np.array(rect2)[:,1]) 102 | 103 | inter, inter_area = convex_hull_intersection(rect1, rect2) 104 | iou_2d = inter_area/(area1+area2-inter_area) 105 | # if iou_2d < 0: 106 | # print(inter_area, area1, area2) 107 | # ymax = min(corners1[0,1], corners2[0,1]) 108 | # ymin = max(corners1[4,1], corners2[4,1]) 109 | 110 | # inter_vol = inter_area * max(0.0, ymax-ymin) 111 | 112 | # vol1 = box3d_vol(corners1) 113 | # vol2 = box3d_vol(corners2) 114 | # iou = inter_vol / (vol1 + vol2 - inter_vol) 115 | # return iou, iou_2d 116 | return 0, iou_2d 117 | 118 | # ---------------------------------- 119 | # Helper functions for evaluation 120 | # ---------------------------------- 121 | 122 | def get_3d_box(box_size, heading_angle, center): 123 | ''' Calculate 3D bounding box corners from its parameterization. 124 | 125 | Input: 126 | box_size: tuple of (length,wide,height) 127 | heading_angle: rad scalar, clockwise from pos x axis 128 | center: tuple of (x,y,z) 129 | Output: 130 | corners_3d: numpy array of shape (8,3) for 3D box cornders 131 | ''' 132 | def roty(t): 133 | c = np.cos(t) 134 | s = np.sin(t) 135 | return np.array([[c, 0, s], 136 | [0, 1, 0], 137 | [-s, 0, c]]) 138 | 139 | R = roty(heading_angle) 140 | l,w,h = box_size 141 | x_corners = [l/2,l/2,-l/2,-l/2,l/2,l/2,-l/2,-l/2]; 142 | y_corners = [h/2,h/2,h/2,h/2,-h/2,-h/2,-h/2,-h/2]; 143 | z_corners = [w/2,-w/2,-w/2,w/2,w/2,-w/2,-w/2,w/2]; 144 | corners_3d = np.dot(R, np.vstack([x_corners,y_corners,z_corners])) 145 | corners_3d[0,:] = corners_3d[0,:] + center[0]; 146 | corners_3d[1,:] = corners_3d[1,:] + center[1]; 147 | corners_3d[2,:] = corners_3d[2,:] + center[2]; 148 | corners_3d = np.transpose(corners_3d) 149 | return corners_3d 150 | 151 | 152 | if __name__=='__main__': 153 | print('------------------') 154 | # get_3d_box(box_size, heading_angle, center) 155 | corners_3d_ground = get_3d_box((1.497255,1.644981, 3.628938), -1.531692, (2.882992 ,1.698800 ,20.785644)) 156 | corners_3d_predict = get_3d_box((1.458242, 1.604773, 3.707947), -1.549553, (2.756923, 1.661275, 20.943280 )) 157 | (IOU_3d,IOU_2d)=box3d_iou(corners_3d_predict,corners_3d_ground) 158 | print (IOU_3d,IOU_2d) #3d IoU/ 2d IoU of BEV(bird eye's view) 159 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt, mask=None): 18 | l1 = torch.abs((network_output - gt)) 19 | if mask is not None: 20 | l1 = l1[:, mask] 21 | return l1.mean() 22 | 23 | def l2_loss(network_output, gt): 24 | return ((network_output - gt) ** 2).mean() 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | def create_window(window_size, channel): 31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 34 | return window 35 | 36 | def ssim(img1, img2, window_size=11, size_average=True): 37 | channel = img1.size(-3) 38 | window = create_window(window_size, channel) 39 | 40 | if img1.is_cuda: 41 | window = window.cuda(img1.get_device()) 42 | window = window.type_as(img1) 43 | 44 | return _ssim(img1, img2, window, window_size, channel, size_average) 45 | 46 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 49 | 50 | mu1_sq = mu1.pow(2) 51 | mu2_sq = mu2.pow(2) 52 | mu1_mu2 = mu1 * mu2 53 | 54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 57 | 58 | C1 = 0.01 ** 2 59 | C2 = 0.03 ** 2 60 | 61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 62 | 63 | if size_average: 64 | return ssim_map.mean() 65 | else: 66 | return ssim_map.mean(1).mean(1).mean(1) 67 | 68 | def ssim_loss(img1, img2, window_size=11, size_average=True, mask=None): 69 | channel = img1.size(-3) 70 | window = create_window(window_size, channel) 71 | 72 | if img1.is_cuda: 73 | window = window.cuda(img1.get_device()) 74 | window = window.type_as(img1) 75 | 76 | return _ssim_loss(img1, img2, window, window_size, channel, size_average, mask) 77 | 78 | def _ssim_loss(img1, img2, window, window_size, channel, size_average=True, mask=None): 79 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 80 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 81 | 82 | mu1_sq = mu1.pow(2) 83 | mu2_sq = mu2.pow(2) 84 | mu1_mu2 = mu1 * mu2 85 | 86 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 87 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 88 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 89 | 90 | C1 = 0.01 ** 2 91 | C2 = 0.03 ** 2 92 | 93 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 94 | ssim_map = 1 - ssim_map 95 | 96 | if mask is not None: 97 | ssim_map = ssim_map[:, mask] 98 | if size_average: 99 | return ssim_map.mean() 100 | else: 101 | return ssim_map.mean(1).mean(1).mean(1) 
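The helpers above (l1_loss, ssim, ssim_loss) are the building blocks of a photometric objective. As a minimal, hypothetical sketch — the photometric_loss name and the 0.2 D-SSIM weight are assumptions for illustration, not values taken from this repository — they are typically combined like this:

    import torch
    from utils.loss_utils import l1_loss, ssim

    def photometric_loss(render, gt, lambda_dssim=0.2):
        # render, gt: (C, H, W) tensors in [0, 1]; lambda_dssim is an assumed weight
        l1 = l1_loss(render, gt)
        # add a batch dimension so conv2d inside ssim() sees (N, C, H, W)
        d_ssim = 1.0 - ssim(render.unsqueeze(0), gt.unsqueeze(0))
        return (1.0 - lambda_dssim) * l1 + lambda_dssim * d_ssim

    # example usage with dummy images
    render = torch.rand(3, 128, 256)
    gt = torch.rand(3, 128, 256)
    loss = photometric_loss(render, gt)
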
-------------------------------------------------------------------------------- /utils/nvseg_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/data0/hyzhou/workspace/nv_seg") 3 | from network import get_model 4 | from config import cfg, torch_version_float 5 | from datasets.cityscapes import Loader as dataset_cls 6 | from runx.logx import logx 7 | import cv2 8 | import torch 9 | from imageio.v2 import imread, imwrite 10 | import os 11 | import numpy as np 12 | from glob import glob 13 | from tqdm import tqdm 14 | from torchvision.utils import save_image 15 | 16 | def restore_net(net, checkpoint): 17 | assert 'state_dict' in checkpoint, 'cant find state_dict in checkpoint' 18 | forgiving_state_restore(net, checkpoint['state_dict']) 19 | 20 | 21 | def forgiving_state_restore(net, loaded_dict): 22 | """ 23 | Handle partial loading when some tensors don't match up in size. 24 | Because we want to use models that were trained off a different 25 | number of classes. 26 | """ 27 | 28 | net_state_dict = net.state_dict() 29 | new_loaded_dict = {} 30 | for k in net_state_dict: 31 | new_k = k 32 | if new_k in loaded_dict and net_state_dict[k].size() == loaded_dict[new_k].size(): 33 | new_loaded_dict[k] = loaded_dict[new_k] 34 | else: 35 | logx.msg("Skipped loading parameter {}".format(k)) 36 | net_state_dict.update(new_loaded_dict) 37 | net.load_state_dict(net_state_dict) 38 | return net 39 | 40 | def get_nvseg_model(): 41 | logx.initialize(logdir="./results", 42 | global_rank=0) 43 | 44 | cfg.immutable(False) 45 | cfg.DATASET.NUM_CLASSES = dataset_cls.num_classes 46 | cfg.DATASET.IGNORE_LABEL = dataset_cls.ignore_label 47 | cfg.MODEL.MSCALE = True 48 | cfg.MODEL.N_SCALES = [0.5,1.0,2.0] 49 | cfg.MODEL.BNFUNC = torch.nn.BatchNorm2d 50 | cfg.OPTIONS.TORCH_VERSION = torch_version_float() 51 | cfg.DATASET_INST = dataset_cls('folder') 52 | cfg.immutable(True) 53 | colorize_mask_fn = cfg.DATASET_INST.colorize_mask 54 | 55 | net = get_model(network='network.ocrnet.HRNet_Mscale', 56 | num_classes=cfg.DATASET.NUM_CLASSES, 57 | criterion=None) 58 | 59 | snapshot = "ASSETS_PATH/seg_weights/cityscapes_trainval_ocr.HRNet_Mscale_nimble-chihuahua.pth".replace('ASSETS_PATH', cfg.ASSETS_PATH) 60 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 61 | renamed_ckpt = {'state_dict': {}} 62 | for k, v in checkpoint['state_dict'].items(): 63 | renamed_ckpt['state_dict'][k.replace('module.', '')] = v 64 | restore_net(net, renamed_ckpt) 65 | net = net.eval().cuda() 66 | return net -------------------------------------------------------------------------------- /utils/semantic_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # KITTI-360 labels 4 | # 5 | 6 | from collections import namedtuple 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | #-------------------------------------------------------------------------------- 12 | # Definitions 13 | #-------------------------------------------------------------------------------- 14 | 15 | # a label and all meta information 16 | Label = namedtuple( 'Label' , [ 17 | 18 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 19 | # We use them to uniquely name a class 20 | 21 | 'id' , # An integer ID that is associated with this label. 
22 | # The IDs are used to represent the label in ground truth images 23 | # An ID of -1 means that this label does not have an ID and thus 24 | # is ignored when creating ground truth images (e.g. license plate). 25 | # Do not modify these IDs, since exactly these IDs are expected by the 26 | # evaluation server. 27 | 28 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 29 | # ground truth images with train IDs, using the tools provided in the 30 | # 'preparation' folder. However, make sure to validate or submit results 31 | # to our evaluation server using the regular IDs above! 32 | # For trainIds, multiple labels might have the same ID. Then, these labels 33 | # are mapped to the same class in the ground truth images. For the inverse 34 | # mapping, we use the label that is defined first in the list below. 35 | # For example, mapping all void-type classes to the same ID in training, 36 | # might make sense for some approaches. 37 | # Max value is 255! 38 | 39 | 'category' , # The name of the category that this label belongs to 40 | 41 | 'categoryId' , # The ID of this category. Used to create ground truth images 42 | # on category level. 43 | 44 | 'hasInstances', # Whether this label distinguishes between single instances or not 45 | 46 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 47 | # during evaluations or not 48 | 49 | 'color' , # The color of this label 50 | ] ) 51 | 52 | 53 | #-------------------------------------------------------------------------------- 54 | # A list of all labels 55 | #-------------------------------------------------------------------------------- 56 | 57 | # Please adapt the train IDs as appropriate for your approach. 58 | # Note that you might want to ignore labels with ID 255 during training. 59 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 60 | # Make sure to provide your results using the original IDs and not the training IDs. 61 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
62 | 63 | labels = [ 64 | # name id trainId category catId hasInstances ignoreInEval color 65 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 66 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 67 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 68 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 69 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 70 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 71 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 72 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 73 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 74 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 75 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 76 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 77 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 78 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 79 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 80 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 81 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 82 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 83 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 84 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 85 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 86 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 87 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 88 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 89 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 90 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 91 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 92 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 93 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 94 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 95 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 96 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 97 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 98 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 99 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 100 | ] 101 | 102 | 103 | #-------------------------------------------------------------------------------- 104 | # Create dictionaries for a fast lookup 105 | #-------------------------------------------------------------------------------- 106 | 107 | # Please refer to the main method below for example usages! 
108 | 109 | # name to label object 110 | name2label = { label.name : label for label in labels } 111 | # id to label object 112 | id2label = { label.id : label for label in labels } 113 | # trainId to label object 114 | trainId2label = { label.trainId : label for label in reversed(labels) } 115 | # label2trainid 116 | label2trainid = { label.id : label.trainId for label in labels } 117 | # trainId to label object 118 | trainId2name = { label.trainId : label.name for label in labels } 119 | trainId2color = { label.trainId : label.color for label in labels } 120 | # category to list of label objects 121 | category2labels = {} 122 | for label in labels: 123 | category = label.category 124 | if category in category2labels: 125 | category2labels[category].append(label) 126 | else: 127 | category2labels[category] = [label] 128 | 129 | #-------------------------------------------------------------------------------- 130 | # color mapping 131 | #-------------------------------------------------------------------------------- 132 | 133 | palette = [128, 64, 128, 134 | 244, 35, 232, 135 | 70, 70, 70, 136 | 102, 102, 156, 137 | 190, 153, 153, 138 | 153, 153, 153, 139 | 250, 170, 30, 140 | 220, 220, 0, 141 | 107, 142, 35, 142 | 152, 251, 152, 143 | 70, 130, 180, 144 | 220, 20, 60, 145 | 255, 0, 0, 146 | 0, 0, 142, 147 | 0, 0, 70, 148 | 0, 60, 100, 149 | 0, 80, 100, 150 | 0, 0, 230, 151 | 119, 11, 32] 152 | zero_pad = 256 * 3 - len(palette) 153 | for i in range(zero_pad): 154 | palette.append(0) 155 | color_mapping = palette 156 | 157 | def colorize(image_array): 158 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 159 | new_mask.putpalette(color_mapping) 160 | return new_mask 161 | 162 | #-------------------------------------------------------------------------------- 163 | # Assure single instance name 164 | #-------------------------------------------------------------------------------- 165 | 166 | # returns the label name that describes a single instance (if possible) 167 | # e.g. 
input | output 168 | # ---------------------- 169 | # car | car 170 | # cargroup | car 171 | # foo | None 172 | # foogroup | None 173 | # skygroup | None 174 | def assureSingleInstanceName( name ): 175 | # if the name is known, it is not a group 176 | if name in name2label: 177 | return name 178 | # test if the name actually denotes a group 179 | if not name.endswith("group"): 180 | return None 181 | # remove group 182 | name = name[:-len("group")] 183 | # test if the new name exists 184 | if not name in name2label: 185 | return None 186 | # test if the new name denotes a label that actually has instances 187 | if not name2label[name].hasInstances: 188 | return None 189 | # all good then 190 | return name 191 | 192 | #-------------------------------------------------------------------------------- 193 | # Main for testing 194 | #-------------------------------------------------------------------------------- 195 | 196 | # just a dummy main 197 | if __name__ == "__main__": 198 | # Print all the labels 199 | print("List of KITTI-360 labels:") 200 | print("") 201 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) 202 | print(" " + ('-' * 98)) 203 | for label in labels: 204 | # print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) 205 | print(" \"{:}\"".format(label.name)) 206 | print("") 207 | 208 | print("Example usages:") 209 | 210 | # Map from name to label 211 | name = 'car' 212 | id = name2label[name].id 213 | print("ID of label '{name}': {id}".format( name=name, id=id )) 214 | 215 | # Map from ID to label 216 | category = id2label[id].category 217 | print("Category of label with ID '{id}': {category}".format( id=id, category=category )) 218 | 219 | # Map from trainID to label 220 | trainId = 0 221 | name = trainId2label[trainId].name 222 | print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )) -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) -------------------------------------------------------------------------------- /utils/vehicle_template/benz_kitti.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/utils/vehicle_template/benz_kitti.ply -------------------------------------------------------------------------------- /utils/vehicle_template/benz_kitti360.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/utils/vehicle_template/benz_kitti360.ply -------------------------------------------------------------------------------- /utils/vehicle_template/benz_pandaset.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/utils/vehicle_template/benz_pandaset.ply -------------------------------------------------------------------------------- /utils/vehicle_template/benz_waymo.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyzhou404/HUGS/dbb17df8c2b9d50fdfbcd097c93cec73d70100f9/utils/vehicle_template/benz_waymo.ply --------------------------------------------------------------------------------
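
As a closing illustration, a minimal sketch (not part of the repository) of how the spherical-harmonics helpers in utils/sh_utils.py fit together: RGB2SH stores a base color as the degree-0 coefficient, and eval_sh recovers it up to the +0.5 offset applied by SH2RGB. The tensor shapes below are assumptions chosen to match the documented [..., C, (deg + 1) ** 2] layout.

    import torch
    from utils.sh_utils import RGB2SH, SH2RGB, eval_sh

    rgb = torch.tensor([[0.2, 0.5, 0.8]])       # (N, 3) base colors in [0, 1]
    sh_dc = RGB2SH(rgb)                         # (N, 3) degree-0 coefficients
    sh = sh_dc.unsqueeze(-1)                    # (N, 3, 1): [..., C, (deg+1)**2] with deg=0
    dirs = torch.nn.functional.normalize(torch.randn(1, 3), dim=-1)  # unit view directions
    color = eval_sh(0, sh, dirs) + 0.5          # equals SH2RGB(sh_dc), i.e. recovers rgb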