├── .gitignore
├── LICENSE
├── README.md
├── downsample.py
├── renderer.py
├── requirements.txt
├── setup.py
├── splatter.py
├── src
│   ├── bindings.cpp
│   ├── gaussian.cu
│   ├── gaussian_nosh.cu
│   └── include
│       └── common.hpp
├── train.py
├── transforms
│   ├── __init__.py
│   ├── _base.py
│   ├── _se2.py
│   ├── _se3.py
│   ├── _so2.py
│   ├── _so3.py
│   ├── hints
│   │   └── __init__.py
│   └── utils
│       ├── __init__.py
│       └── _utils.py
├── utils.py
└── visergui.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | External
3 | build
4 | fmatmul.egg-info
5 | matsrc
6 | *.egg-info
7 | *.png
8 | colmap_garden
9 | *.so
10 | .vscode
11 | *.pth
12 | *.pt
13 | *.mp4
14 | src/gaussian_nosh.cu
15 | RES.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 FengWang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3d-Gaussian-Splatting
2 | An unofficial implementation of 3D Gaussian Splatting for Real-Time Radiance Field Rendering [SIGGRAPH 2023].
3 | 
4 | We implement the 3D Gaussian Splatting method in PyTorch with CUDA extensions, including the global culling, tile-based culling, and forward/backward rendering code.
5 | 
6 | Work in progress.
7 | #### Update
8 | - 6/26/2023 Fixed bugs in the SSIM criterion; PSNR improves from 24.28 to 24.85 (Garden scene).
9 | - 6/26/2023 Accelerated **training** speed from an average of 4 it/s to 13 it/s by (1) replacing part of the atomicAdd calls with warp-reduction primitives and (2) fixing bugs in the SSIM functions. Training now takes 9 minutes for 7k iterations on the Garden scene.
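10 | 
11 | The SSIM fixes above enter training through the `--ssim_weight` flag used in the commands below. Since `train.py` is not included in this listing, the exact criterion may differ; the following is a minimal sketch of one common form of the combined L1 + SSIM objective, built on `torchmetrics` (already in `requirements.txt`). `photometric_loss` and its arguments are illustrative names, not the repo's API:
12 | 
13 | ```python
14 | from torchmetrics import StructuralSimilarityIndexMeasure
15 | 
16 | ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda")
17 | 
18 | def photometric_loss(pred, gt, ssim_weight=0.1):
19 |     # pred, gt: H x W x 3 tensors in [0, 1], e.g. Splatter.forward() output and ground truth
20 |     l1 = (pred - gt).abs().mean()
21 |     # torchmetrics expects batched N x C x H x W inputs
22 |     pred_nchw = pred.permute(2, 0, 1).unsqueeze(0).float()
23 |     gt_nchw = gt.permute(2, 0, 1).unsqueeze(0).float()
24 |     return (1 - ssim_weight) * l1 + ssim_weight * (1 - ssim(pred_nchw, gt_nchw))
25 | ```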
26 | 
27 | | Scene | PSNR from paper | PSNR from this repo | Rendering Speed (official) | Rendering Speed (Ours) |
28 | | --- | --- | --- | --- | --- |
29 | | Garden | 25.82 (5k) | 24.91 (7k) | 160 FPS (avg MIPNeRF360) | 60 FPS |
30 | | Garden | 25.82 (5k) | 25.70 (7k) | 160 FPS (avg MIPNeRF360) | 25 FPS |
31 | 
32 | 
33 | 
34 | https://github.com/WangFeng18/3d-gaussian-splatting/assets/43294876/79703b5d-50ae-404b-96c9-c73690646f34
35 | 
36 | 
37 | 
38 | ### QuickStart
39 | 
40 | #### Install CUDA Extensions
41 | ```
42 | # compile CUDA extension
43 | pip install -e ./
44 | ```
45 | #### Data Preparation
46 | Put the COLMAP output (e.g., colmap_garden/sparse/0/) in this folder, together with the corresponding images.
47 | 
48 | ### Training
49 | ```
50 | python train.py --exp garden --grad_thresh 0.000004 --debug 1 --ssim_weight 0.1 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --split_thresh 0.08 # PSNR 24.75 SSIM 71.95 FPS 70 N_Gaussians 376467
51 | python train.py --exp garden --grad_thresh 0.000004 --debug 1 --ssim_weight 0.1 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 # PSNR 25.03 SSIM 0.7541 FPS 40 N_GAUSSIANS 933918
52 | python train.py --exp garden --grad_thresh 0.000002 --debug 1 --ssim_weight 0.1 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --split_thresh 0.08 # PSNR 24.91 SSIM 73.18 FPS 64 N_GAUSSIANS 506627 GOOD
53 | 
54 | python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 # PSNR 25.55 SSIM 79.83 N_GAUSSIANS 2418528 FPS 24.68
55 | 
56 | CUDA_VISIBLE_DEVICES=3 python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 # PSNR 25.5586 SSIM 80.10 FPS 25.30 N_GAUSSIANS 2401413
57 | 
58 | python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 --lr_factor_for_scale 0.2 --lr_factor_for_quat 10 --split_thresh 0.05 # PSNR 24.896 SSIM 76.55 FPS 65 N_GAUSSIANS 765932
59 | 
60 | python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 --lr_factor_for_quat 10 # PSNR 25.6906 SSIM 80.66 FPS 24.68
61 | 
62 | python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 --lr_factor_for_scale 0.5 --lr_factor_for_quat 10 --split_thresh 0.05 # PSNR 25.3769 SSIM 0.7902 FPS 41.3186
63 | 
64 | CUDA_VISIBLE_DEVICES=3 python train.py --exp garden2 --grad_thresh 0.000004 --debug 1 --ssim_weight 0.2 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 --adaptive_control_end_iter 3000 --opa_init_value 0.05 --lr_factor_for_opa 20 --lr_factor_for_quat 20 # PSNR 25.7021 SSIM 0.8052 FPS 25.3567
65 | 
66 | ```
67 | 
68 | ### Rendering With a GUI
69 | 
70 | ```
71 | python train.py --ckpt ckpt.pth --gui 1 --test 1
72 | ```
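73 | 
74 | To render a trained checkpoint without the GUI, the snippet below mirrors the `__main__` block at the bottom of `splatter.py`; adjust the paths, checkpoint name, and `render_downsample` factor to your own data:
75 | 
76 | ```python
77 | import os
78 | from splatter import Splatter
79 | 
80 | splatter = Splatter(
81 |     os.path.join("colmap_garden/sparse/0/"),  # COLMAP sparse reconstruction
82 |     "colmap_garden/images_4/",                # (downsampled) training images
83 |     render_downsample=4,
84 |     load_ckpt="ckpt.pth",
85 |     scale_activation="exp",
86 | )
87 | image = splatter.forward(camera_id=0)  # rendered H x W x 3 tensor in [0, 1]
88 | ```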
89 | The GUI is based on [Viser](https://github.com/nerfstudio-project/viser) and was written by [Zilong Chen](https://github.com/heheyas).
90 | 
91 | 
92 | The transforms folder is from [Viser](https://github.com/nerfstudio-project/viser).
93 | 
94 | ### Link
95 | Another good implementation of 3D Gaussian Splatting, by [Zilong Chen](https://github.com/heheyas/gaussian_splatting_3d)
96 | 
97 | 
--------------------------------------------------------------------------------
/downsample.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | 
5 | for i, img_path in enumerate(glob.glob("images_2/*.JPG")):
6 |     img = cv2.imread(img_path)
7 |     w, h = img.shape[1], img.shape[0]
8 |     img = cv2.resize(img, (w//2, h//2))
9 |     cv2.imwrite("images_4/{}".format(os.path.basename(img_path)), img)
10 |     print(i)
--------------------------------------------------------------------------------
/renderer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import torch
3 | import gaussian
4 | from torch.autograd.function import once_differentiable
5 | 
6 | class _Drawer(torch.autograd.Function):
7 |     @staticmethod
8 |     def forward(
9 |         ctx,
10 |         gaussians_pos,
11 |         gaussians_rgb,
12 |         gaussians_opa,
13 |         gaussians_cov,
14 |         tile_n_point_accum,
15 |         padded_height,
16 |         padded_width,
17 |         focal_x,
18 |         focal_y,
19 |         render_weight_normalize=False,
20 |         sigmoid=False,
21 |         use_sh_coeff=False,
22 |         fast=False,
23 |         rays_o=None,
24 |         lefttop_pos=None,
25 |         vec_dx=None,
26 |         vec_dy=None
27 |     ):
28 |         rendered_image = torch.zeros(padded_height, padded_width, 3, device=gaussians_pos.device, dtype=torch.float32)
29 |         gaussian.draw(
30 |             gaussians_pos,
31 |             gaussians_rgb,
32 |             gaussians_opa,
33 |             gaussians_cov,
34 |             tile_n_point_accum,
35 |             rendered_image,
36 |             focal_x,
37 |             focal_y,
38 |             render_weight_normalize,
39 |             sigmoid,
40 |             fast,
41 |             rays_o,
42 |             lefttop_pos,
43 |             vec_dx,
44 |             vec_dy,
45 |             use_sh_coeff,
46 |         )
47 |         ctx.save_for_backward(gaussians_pos, gaussians_rgb, gaussians_opa, gaussians_cov, tile_n_point_accum, rendered_image, rays_o, lefttop_pos, vec_dx, vec_dy)
48 |         ctx.focal_x = focal_x
49 |         ctx.focal_y = focal_y
50 |         ctx.render_weight_normalize = render_weight_normalize
51 |         ctx.sigmoid = sigmoid
52 |         ctx.fast = fast
53 |         ctx.use_sh_coeff = use_sh_coeff
54 |         return rendered_image
55 | 
56 |     @staticmethod
57 |     def backward(ctx, grad_output):
58 |         # grad_output
59 |         gaussians_pos, gaussians_rgb, gaussians_opa, gaussians_cov, tile_n_point_accum, rendered_image, rays_o, lefttop_pos, vec_dx, vec_dy = ctx.saved_tensors
60 |         grad_pos = torch.zeros_like(gaussians_pos)
61 |         grad_rgb = torch.zeros_like(gaussians_rgb)
62 |         grad_opa = torch.zeros_like(gaussians_opa)
63 |         grad_cov = torch.zeros_like(gaussians_cov)
64 |         gaussian.draw_backward(
65 |             gaussians_pos,
66 |             gaussians_rgb,
67 |             gaussians_opa,
68 |             gaussians_cov,
69 |             tile_n_point_accum,
70 |             rendered_image,
71 |             grad_output,
72 |             grad_pos,
73 |             grad_rgb,
74 |             grad_opa,
75 |             grad_cov,
76 |             ctx.focal_x,
77 |             ctx.focal_y,
78 |             ctx.render_weight_normalize,
79 |             ctx.sigmoid,
80 |             ctx.fast,
81 |             rays_o,
82 |             lefttop_pos,
83 |             vec_dx,
84 |             vec_dy,
85 |             ctx.use_sh_coeff,
86 |         )
87 |         return grad_pos, grad_rgb, grad_opa, grad_cov, None, None, None, None, None, None, None, None, None, None, None, None, None
88 | 
89 | draw = _Drawer.apply
90 | 
91 | class _trunc_exp(torch.autograd.Function):
92 |     @staticmethod
93 |     def forward(ctx, x):
94 |         ctx.save_for_backward(x)
95 | return torch.exp(x) 96 | 97 | @staticmethod 98 | def backward(ctx, g): 99 | x = ctx.saved_tensors[0] 100 | return g * torch.exp(x.clamp(-1, 1)) 101 | 102 | trunc_exp = _trunc_exp.apply 103 | 104 | class _world2camera(torch.autograd.Function): 105 | @staticmethod 106 | def forward(ctx, pos, rot, tran): 107 | ctx.save_for_backward(rot) 108 | res = torch.zeros_like(pos) 109 | gaussian.world2camera(pos, rot, tran, res) 110 | return res 111 | 112 | @staticmethod 113 | def backward(ctx, grad_out): 114 | rot = ctx.saved_tensors[0] 115 | grad_inp = torch.zeros_like(grad_out) 116 | gaussian.world2camera_backward(grad_out, rot, grad_inp) 117 | return grad_inp, None, None 118 | 119 | world2camera_func = _world2camera.apply 120 | 121 | class _GlobalCulling(torch.autograd.Function): 122 | @staticmethod 123 | def forward(ctx, pos, quat, scale, current_rot, current_tran, near, half_width, half_height): 124 | res_pos = torch.zeros_like(pos) 125 | res_cov = torch.zeros((pos.shape[0], 2, 2), device=pos.device) 126 | culling_mask = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device) 127 | 128 | gaussian.global_culling( 129 | pos, quat, scale, current_rot, current_tran, res_pos, res_cov, culling_mask, near, half_width, half_height 130 | ) 131 | ctx.save_for_backward(culling_mask, pos, quat, scale, current_rot, current_tran) 132 | return res_pos, res_cov, culling_mask.detach() 133 | 134 | @staticmethod 135 | def backward(ctx, gradout_pos, gradout_cov, grad_culling_mask): 136 | culling_mask = ctx.saved_tensors[0] 137 | pos = ctx.saved_tensors[1] 138 | quat = ctx.saved_tensors[2] 139 | scale = ctx.saved_tensors[3] 140 | current_rot = ctx.saved_tensors[4] 141 | current_tran = ctx.saved_tensors[5] 142 | 143 | gradinput_pos = torch.zeros_like(gradout_pos) 144 | gradinput_quat = torch.zeros((gradout_pos.shape[0], 4), device=gradout_pos.device) 145 | gradinput_scale = torch.zeros((gradout_pos.shape[0], 3), device=gradout_pos.device) 146 | gaussian.global_culling_backward( 147 | pos, quat, scale, current_rot, current_tran, 148 | gradout_pos, 149 | gradout_cov, 150 | culling_mask, 151 | gradinput_pos, 152 | gradinput_quat, 153 | gradinput_scale, 154 | ) 155 | 156 | return gradinput_pos, gradinput_quat, gradinput_scale, None, None, None, None, None 157 | 158 | global_culling = _GlobalCulling.apply 159 | 160 | 161 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | tqdm 5 | opencv-python 6 | torchmetrics 7 | einops 8 | torchgeometry 9 | kornia 10 | viser 11 | trimesh 12 | omegaconf 13 | pykdtree -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | # '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++14', '/w'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files 
(x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | include_dirs = [os.path.join("src", "include")] 33 | setup( 34 | name='gaussian', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='gaussian', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | # 'matmul.cu', 40 | 'gaussian.cu', 41 | 'bindings.cpp', 42 | ]], 43 | # include_dirs=include_dirs, 44 | extra_compile_args={ 45 | 'cxx': c_flags, 46 | 'nvcc': nvcc_flags, 47 | } 48 | ), 49 | ], 50 | cmdclass={ 51 | 'build_ext': BuildExtension, 52 | } 53 | ) -------------------------------------------------------------------------------- /splatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import gaussian 6 | from utils import read_points3d_binary, read_cameras_binary, read_images_binary, q2r, jacobian_torch, initialize_sh, inverse_sigmoid, inverse_sigmoid_torch, Timer, sample_two_point, Camera 7 | from dataclasses import dataclass, field 8 | import transforms as tf 9 | import time 10 | import math 11 | import cv2 12 | import numpy as np 13 | from typing import Any, List 14 | from einops import repeat 15 | from renderer import draw, trunc_exp, global_culling, world2camera_func 16 | from tqdm import tqdm 17 | import argparse 18 | from pykdtree.kdtree import KDTree 19 | EPS=1e-4 20 | 21 | def world_to_camera(points, rot, tran): 22 | # r = torch.empty_like(points) 23 | # gaussian.world2camera(points, rot, tran, r) 24 | # return r 25 | return world2camera_func(points, rot, tran) 26 | # _r = points @ rot.T + tran.unsqueeze(0) 27 | # return _r 28 | 29 | def camera_to_image(points_camera_space): 30 | points_image_space = [ 31 | points_camera_space[:,0]/points_camera_space[:,2], 32 | points_camera_space[:,1]/points_camera_space[:,2], 33 | points_camera_space.norm(dim=-1), 34 | ] 35 | return torch.stack(points_image_space, dim=-1) 36 | 37 | 38 | #TODO ensure if the camera coordinate system is normalized 39 | class Gaussian3ds(nn.Module): 40 | def __init__(self, pos, rgb, opa, quat=None, scale=None, cov=None, init_values=False): 41 | super().__init__() 42 | self.init_values = init_values 43 | if init_values: 44 | self.pos = nn.parameter.Parameter(pos) 45 | self.rgb = nn.parameter.Parameter(rgb) 46 | self.opa = nn.parameter.Parameter(opa) 47 | self.quat = quat if quat is None else nn.parameter.Parameter(quat) 48 | self.scale = scale if scale is None else nn.parameter.Parameter(scale) 49 | self.cov = cov if cov is None else nn.parameter.Parameter(cov) 50 | else: 51 | self.pos = pos 52 | self.rgb = rgb 53 | self.opa = opa 54 | self.quat = quat 55 | self.scale = scale 56 | self.cov = cov 57 | 58 | def to_scale_matrix(self): 59 | return torch.diag_embed(self.scale) 60 | 61 | def _tocpp(self): 62 | _cobj = gaussian.Gaussian3ds() 63 | _cobj.pos = self.pos.clone() 64 | _cobj.rgb = self.rgb.clone() 65 | _cobj.opa = self.opa.clone() 66 | _cobj.cov = self.cov.clone() 67 | return _cobj 68 | 69 | def to(self, *args, **kwargs): 70 | self.pos.to(*args, **kwargs) 71 | 
self.rgb.to(*args, **kwargs) 72 | self.opa.to(*args, **kwargs) 73 | if self.quat is not None: 74 | self.quat.to(*args, **kwargs) 75 | if self.scale is not None: 76 | self.scale.to(*args, **kwargs) 77 | if self.cov is not None: 78 | self.cov.to(*args, **kwargs) 79 | 80 | def filte(self, mask): 81 | if self.quat is not None and self.scale is not None: 82 | assert self.cov is None 83 | return Gaussian3ds( 84 | pos=self.pos[mask], 85 | rgb=self.rgb[mask], 86 | opa=self.opa[mask], 87 | quat=self.quat[mask], 88 | scale=self.scale[mask], 89 | ) 90 | else: 91 | assert self.cov is not None 92 | return Gaussian3ds( 93 | pos=self.pos[mask], 94 | rgb=self.rgb[mask], 95 | opa=self.opa[mask], 96 | cov=self.cov[mask], 97 | ) 98 | 99 | # Deprecated, now fused to CUDA kernel in gaussian.global culling 100 | def get_gaussian_3d_cov(self, scale_activation="abs"): 101 | R = q2r(self.quat) 102 | if scale_activation == "abs": 103 | _scale = self.scale.abs()+EPS 104 | elif scale_activation == "exp": 105 | _scale = trunc_exp(self.scale) 106 | else: 107 | print("No support scale activation") 108 | exit() 109 | # _scale = trunc_exp(self.scale) 110 | # _scale = torch.clamp(_scale, min=1e-4, max=0.1) 111 | S = torch.diag_embed(_scale) 112 | RS = torch.bmm(R, S) 113 | RSSR = torch.bmm(RS, RS.permute(0,2,1)) 114 | return RSSR 115 | 116 | # def stable_cov(self): 117 | # return self.cov + 1e-2*torch.eye(2).unsqueeze(dim=0).to(self.cov) 118 | 119 | def reset_opa(self): 120 | torch.nn.init.uniform_(self.opa, a=inverse_sigmoid(0.01), b=inverse_sigmoid(0.01)) 121 | 122 | def adaptive_control( 123 | self, 124 | grad, 125 | taus, 126 | delete_thresh, 127 | scale_activation="abs", 128 | grad_thresh=0.0002, 129 | grad_aggregation="max", 130 | use_clone=True, 131 | use_split=True, 132 | clone_dt=0.01, 133 | ): 134 | # grad: B x 3 135 | # densification 136 | # 1. delete gaussians with small opacities 137 | assert self.init_values # only for the initial gaussians 138 | # print(inverse_sigmoid(0.01)) 139 | # print(self.opa.min()) 140 | # print(self.opa.max()) 141 | if scale_activation == "abs": 142 | _mask = (self.opa > inverse_sigmoid(0.02)) & (self.scale.norm(dim=-1) < delete_thresh) #& (self.scale.abs().max(dim=-1)[0] > 5e-4) 143 | elif scale_activation == "exp": 144 | _mask = (self.opa > inverse_sigmoid(0.02)) & (self.scale.exp().norm(dim=-1) < delete_thresh) #& (self.scale.exp().max(dim=-1)[0] > 1e-4) 145 | else: 146 | print("Wrong activation") 147 | exit() 148 | 149 | self.pos = nn.parameter.Parameter(self.pos.detach()[_mask]) 150 | self.rgb = nn.parameter.Parameter(self.rgb.detach()[_mask]) 151 | self.opa = nn.parameter.Parameter(self.opa.detach()[_mask]) 152 | self.quat = nn.parameter.Parameter(self.quat.detach()[_mask]) 153 | self.scale = nn.parameter.Parameter(self.scale.detach()[_mask]) 154 | grad = grad[_mask] 155 | print("DELETE: {} Gaussians".format((~_mask).sum())) 156 | # 2. 
clone or split 157 | 158 | if grad_aggregation == "max": 159 | densify_mask = grad.abs().max(-1)[0] > grad_thresh 160 | else: 161 | assert grad_aggregation == "mean" 162 | densify_mask = grad.abs().mean(-1) > grad_thresh 163 | 164 | cat_pos = [self.pos.clone().detach()] 165 | cat_rgb = [self.rgb.clone().detach()] 166 | cat_opa = [self.opa.clone().detach()] 167 | cat_quat = [self.quat.clone().detach()] 168 | cat_scale = [self.scale.clone().detach()] 169 | if densify_mask.any(): 170 | scale_norm = self.scale.norm(dim=-1) if scale_activation == "abs" else self.scale.exp().norm(dim=-1) 171 | split_mask = scale_norm > taus 172 | clone_mask = scale_norm <= taus 173 | split_mask = split_mask & densify_mask 174 | clone_mask = clone_mask & densify_mask 175 | 176 | if clone_mask.any() and use_clone: 177 | cloned_pos = self.pos[clone_mask].clone().detach() 178 | cloned_pos -= grad[clone_mask] * clone_dt 179 | cloned_rgb = self.rgb[clone_mask].clone().detach() 180 | cloned_opa = self.opa[clone_mask].clone().detach() 181 | cloned_quat = self.quat[clone_mask].clone().detach() 182 | cloned_scale = self.scale[clone_mask].clone().detach() 183 | print("CLONE: {} Gaussians".format(cloned_pos.shape[0])) 184 | cat_pos.append(cloned_pos) 185 | cat_rgb.append(cloned_rgb) 186 | cat_opa.append(cloned_opa) 187 | cat_quat.append(cloned_quat) 188 | cat_scale.append(cloned_scale) 189 | 190 | if split_mask.any() and use_split: 191 | _scale = self.scale.clone().detach() 192 | if scale_activation == "abs": 193 | _scale[split_mask] /= 1.6 194 | elif scale_activation == "exp": 195 | _scale[split_mask] -= math.log(1.6) 196 | else: 197 | print("Wrong activation") 198 | exit() 199 | 200 | cat_scale[0] = _scale 201 | # cat_scale[0][split_mask] /= 1.6 202 | # self.scale = nn.parameter.Parameter(_scale) 203 | 204 | # sampling two positions 205 | this_cov = self.get_gaussian_3d_cov(scale_activation=scale_activation)[split_mask] 206 | p1, p2 = sample_two_point(self.pos[split_mask], this_cov) 207 | 208 | # split_pos = self.pos[split_mask].clone().detach() 209 | # split_pos -= grad[split_mask] * 0.01 210 | origin_pos = cat_pos[0] 211 | origin_pos[split_mask] = p1.detach() 212 | cat_pos[0] = origin_pos 213 | split_pos = p2.detach() 214 | split_rgb = self.rgb[split_mask].clone().detach() 215 | split_opa = self.opa[split_mask].clone().detach() 216 | split_quat = self.quat[split_mask].clone().detach() 217 | split_scale = _scale[split_mask].clone() 218 | print("SPLIT : {} Gaussians".format(split_pos.shape[0])) 219 | cat_pos.append(split_pos) 220 | cat_rgb.append(split_rgb) 221 | cat_opa.append(split_opa) 222 | cat_quat.append(split_quat) 223 | cat_scale.append(split_scale) 224 | self.pos = nn.parameter.Parameter(torch.cat(cat_pos)) 225 | self.rgb = nn.parameter.Parameter(torch.cat(cat_rgb)) 226 | self.opa = nn.parameter.Parameter(torch.cat(cat_opa)) 227 | self.quat = nn.parameter.Parameter(torch.cat(cat_quat)) 228 | self.scale = nn.parameter.Parameter(torch.cat(cat_scale)) 229 | 230 | # Deprecated, now fused to CUDA kernel in gaussian.global culling 231 | def project(self, rot, tran, near, jacobian_calc, scale_activation="abs"): 232 | 233 | pos_cam_space = world_to_camera(self.pos, rot, tran) 234 | pos_img_space = camera_to_image(pos_cam_space) 235 | 236 | if jacobian_calc == "cuda": 237 | jacobian = torch.empty(pos_cam_space.shape[0], 3, 3, device=self.pos.device) 238 | gaussian.jacobian(pos_cam_space, jacobian) 239 | else: 240 | jacobian = jacobian_torch(pos_cam_space) 241 | gaussian_3d_cov = 
self.get_gaussian_3d_cov(scale_activation=scale_activation) 242 | # JW = torch.einsum("bij,bjk->bik", jacobian, rot.unsqueeze(dim=0)) 243 | JW = torch.matmul(jacobian, rot.unsqueeze(dim=0)) 244 | JWC = torch.bmm(JW, gaussian_3d_cov) 245 | gaussian_2d_cov = torch.bmm(JWC, JW.permute(0,2,1))[:, :2, :2] 246 | 247 | gaussian_3ds_image_space = Gaussian3ds( 248 | pos=pos_img_space, 249 | rgb=self.rgb.sigmoid(), 250 | opa=self.opa.sigmoid(), 251 | cov=gaussian_2d_cov, 252 | ) 253 | return gaussian_3ds_image_space 254 | 255 | class Tiles: 256 | def __init__(self, width, height, focal_x, focal_y, device): 257 | self.width = width 258 | self.height = height 259 | self.padded_width = int(math.ceil(self.width/16)) * 16 260 | self.padded_height = int(math.ceil(self.height/16)) * 16 261 | self.focal_x = focal_x 262 | self.focal_y = focal_y 263 | self.n_tile_x = self.padded_width // 16 264 | self.n_tile_y = self.padded_height // 16 265 | self.device = device 266 | 267 | def crop(self, image): 268 | # image: padded_height x padded_width x 3 269 | # output: height x width x 3 270 | top = int(self.padded_height - self.height)//2 271 | left = int(self.padded_width - self.width)//2 272 | return image[top:top+int(self.height), left:left+int(self.width), :] 273 | 274 | def create_tiles(self): 275 | self.tiles_left = torch.linspace(-self.padded_width/2, self.padded_width/2, self.n_tile_x + 1, device=self.device)[:-1] 276 | self.tiles_right = self.tiles_left + 16 277 | self.tiles_top = torch.linspace(-self.padded_height/2, self.padded_height/2, self.n_tile_y + 1, device=self.device)[:-1] 278 | self.tiles_bottom = self.tiles_top + 16 279 | self.tile_geo_length_x = 16 / self.focal_x 280 | self.tile_geo_length_y = 16 / self.focal_y 281 | self.leftmost = -self.padded_width/2/self.focal_x 282 | self.topmost = -self.padded_height/2/self.focal_y 283 | 284 | self.tiles_left = self.tiles_left/self.focal_x 285 | self.tiles_top = self.tiles_top/self.focal_y 286 | self.tiles_right = self.tiles_right/self.focal_x 287 | self.tiles_bottom = self.tiles_bottom/self.focal_y 288 | 289 | self.tiles_left = repeat(self.tiles_left, "b -> (c b)", c=self.n_tile_y) 290 | self.tiles_right = repeat(self.tiles_right, "b -> (c b)", c=self.n_tile_y) 291 | 292 | self.tiles_top = repeat(self.tiles_top, "b -> (b c)", c=self.n_tile_x) 293 | self.tiles_bottom = repeat(self.tiles_bottom, "b -> (b c)", c=self.n_tile_x) 294 | 295 | _tile = gaussian.Tiles() 296 | _tile.top = self.tiles_top 297 | _tile.bottom = self.tiles_bottom 298 | _tile.left = self.tiles_left 299 | _tile.right = self.tiles_right 300 | return _tile 301 | 302 | def __len__(self): 303 | return self.tiles_top.shape[0] 304 | 305 | class RayInfo: 306 | def __init__(self, w2c, tran, H, W, focal_x, focal_y): 307 | self.w2c = w2c 308 | self.c2w = torch.inverse(w2c) 309 | self.tran = tran 310 | self.H = H 311 | self.W = W 312 | self.focal_x = focal_x 313 | self.focal_y = focal_y 314 | self.rays_o = - self.c2w @ tran 315 | 316 | lefttop_cam = torch.Tensor([(-W/2 + 0.5)/focal_x, (-H/2 + 0.5)/focal_y, 1.0]).to(self.w2c.device) 317 | dx_cam = torch.Tensor([1./focal_x, 0, 0]).to(self.w2c.device) 318 | dy_cam = torch.Tensor([0, 1./focal_y, 0]).to(self.w2c.device) 319 | self.lefttop = self.c2w @ (lefttop_cam - tran) 320 | self.dx = self.c2w @ dx_cam 321 | self.dy = self.c2w @ dy_cam 322 | 323 | class Splatter(nn.Module): 324 | def __init__(self, 325 | colmap_path, 326 | image_path, 327 | near=0.3, 328 | #near=1.1, 329 | jacobian_calc="cuda", 330 | render_downsample=1, 331 | use_sh_coeff=False, 332 | 
render_weight_normalize=False, 333 | opa_init_value=0.1, 334 | scale_init_value=0.02, 335 | tile_culling_method="dist", # dist or prob 336 | tile_culling_dist_thresh=0.5, 337 | tile_culling_prob_thresh=0.1, 338 | debug=1, 339 | scale_activation="abs", 340 | cudaculling=0, 341 | load_ckpt=None, 342 | debug_align=False, 343 | fast_drawing=False, 344 | test=False, 345 | ): 346 | super().__init__() 347 | self.device = torch.device("cuda") 348 | self.use_sh_coeff = use_sh_coeff 349 | self.near = near 350 | self.jacobian_calc = jacobian_calc 351 | self.render_downsample = render_downsample 352 | self.render_weight_normalize = render_weight_normalize 353 | self.tile_culling_method = tile_culling_method 354 | self.tile_culling_dist_thresh = tile_culling_dist_thresh 355 | self.tile_culling_prob_thresh = tile_culling_prob_thresh 356 | self.debug = debug 357 | self.scale_activation = scale_activation 358 | self.cudaculling = cudaculling 359 | assert jacobian_calc == "cuda" or jacobian_calc == "torch" 360 | self.fast_drawing = fast_drawing 361 | 362 | self.points3d = read_points3d_binary(os.path.join(colmap_path, "points3D.bin")) 363 | self.cameras = read_cameras_binary(os.path.join(colmap_path,"cameras.bin")) 364 | self.images_info = read_images_binary(os.path.join(colmap_path,"images.bin")) 365 | self.image_path = image_path 366 | self.test = test 367 | if not self.test: 368 | self.parse_imgs() 369 | # self.vis_culling = Vis() 370 | self.tic = torch.cuda.Event(enable_timing=True) 371 | self.toc = torch.cuda.Event(enable_timing=True) 372 | 373 | _points = [] 374 | _rgbs = [] 375 | for pid, point in self.points3d.items(): 376 | _points.append(torch.from_numpy(point.xyz)) 377 | if self.use_sh_coeff: 378 | _rgbs.append(inverse_sigmoid_torch(torch.from_numpy(point.rgb/255.))) 379 | else: 380 | _rgbs.append(inverse_sigmoid_torch(torch.from_numpy(point.rgb/255.))) 381 | # _rgbs.append(torch.from_numpy(point.rgb)) 382 | # self.vis_culling.add_item(point.id, point.xyz, point.rgb, point.error, point.image_ids, point.point2D_idxs) 383 | rgb = torch.stack(_rgbs).to(torch.float32).to(self.device) # B x 3 384 | if self.use_sh_coeff: 385 | rgb = initialize_sh(rgb) 386 | # rgb = torch.zeros(rgb.shape[0], 27).to(torch.float32).to(self.device) 387 | 388 | _pos=torch.stack(_points).to(torch.float32).to(self.device) 389 | if load_ckpt is None: 390 | _pos_np = _pos.cpu().numpy() 391 | kd_tree = KDTree(_pos_np) 392 | dist, idx = kd_tree.query(_pos_np, k=4) 393 | mean_min_three_dis = dist[:, 1:].mean(axis=1) 394 | mean_min_three_dis = torch.Tensor(mean_min_three_dis).to(torch.float32) * scale_init_value 395 | 396 | if scale_activation == "exp": 397 | mean_min_three_dis = mean_min_three_dis.log() 398 | 399 | self.gaussian_3ds = Gaussian3ds( 400 | pos=_pos.to(self.device), # B x 3 401 | rgb = rgb, # B x 3 or 27 402 | opa = torch.ones(len(_points)).to(torch.float32).to(self.device)*inverse_sigmoid(opa_init_value), # B 403 | quat = torch.Tensor([1, 0, 0, 0]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4 404 | scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*mean_min_three_dis.unsqueeze(dim=1).to(self.device), 405 | init_values=True, 406 | ) 407 | else: 408 | self.gaussian_3ds = Gaussian3ds( 409 | pos=_pos.to(self.device), # B x 3 410 | rgb = rgb, # B x 3 or 27 411 | opa = torch.ones(len(_points)).to(torch.float32).to(self.device)*inverse_sigmoid(opa_init_value), # B 412 | quat = torch.Tensor([1, 0, 0, 
0]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4 413 | scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device), 414 | init_values=True, 415 | ) 416 | 417 | if load_ckpt is not None: 418 | # load checkpoint 419 | ckpt = torch.load(load_ckpt) 420 | self.gaussian_3ds.pos = nn.Parameter(ckpt["pos"]) 421 | self.gaussian_3ds.opa = nn.Parameter(ckpt["opa"]) 422 | self.gaussian_3ds.rgb = nn.Parameter(ckpt["rgb"]) 423 | self.gaussian_3ds.quat = nn.Parameter(ckpt["quat"]) 424 | self.gaussian_3ds.scale = nn.Parameter(ckpt["scale"]) 425 | self.current_camera = None 426 | if not self.test: 427 | self.set_camera(0) 428 | 429 | def parse_imgs(self): 430 | img_ids = sorted([im.id for im in self.images_info.values()]) 431 | self.w2c_quats = [] 432 | self.w2c_rots = [] 433 | self.w2c_trans = [] 434 | self.cam_ids = [] 435 | self.imgs = [] 436 | for img_id in tqdm(img_ids): 437 | img_info = self.images_info[img_id] 438 | cam = self.cameras[img_info.camera_id] 439 | image_filename = os.path.join(self.image_path, img_info.name) 440 | if not os.path.exists(image_filename): 441 | continue 442 | _current_image = cv2.imread(image_filename) 443 | _current_image = cv2.cvtColor(_current_image, cv2.COLOR_BGR2RGB) 444 | self.imgs.append(torch.from_numpy(_current_image).to(torch.uint8).to(self.device)) 445 | 446 | T_world_camera = tf.SE3.from_rotation_and_translation( 447 | tf.SO3(img_info.qvec), img_info.tvec, 448 | )#.inverse() 449 | self.w2c_quats.append(torch.from_numpy(T_world_camera.rotation().wxyz).to(torch.float32).to(self.device)) 450 | self.w2c_trans.append(torch.from_numpy(T_world_camera.translation()).to(torch.float32).to(self.device)) 451 | self.w2c_rots.append(q2r(self.w2c_quats[-1].unsqueeze(0)).squeeze().to(torch.float32).to(self.device)) 452 | # print(self.w2c_trans) 453 | # print(self.w2c_rots) 454 | self.cam_ids.append(img_info.camera_id) 455 | 456 | def switch_resolution(self, downsample_factor): 457 | if downsample_factor == self.render_downsample: 458 | return 459 | self.image_path = self.image_path.replace(f"images_{self.render_downsample}", f"images_{downsample_factor}") 460 | self.render_downsample = downsample_factor 461 | self.parse_imgs() 462 | self.current_camera = None 463 | self.set_camera(0) 464 | # print(torch.stack(self.w2c_trans, dim=0).mean(0)) 465 | # print(torch.stack(self.w2c_rots, dim=0).mean(0)) 466 | 467 | def set_camera(self, idx, extrinsics=None, intrinsics=None): 468 | if idx is None: 469 | # print(extrinsics) 470 | self.current_w2c_rot = torch.from_numpy(extrinsics["rot"]).to(torch.float32).to(self.device) 471 | self.current_w2c_tran = torch.from_numpy(extrinsics["tran"]).to(torch.float32).to(self.device) 472 | self.current_w2c_quat = None 473 | self.ground_truth = None 474 | self.current_camera = Camera( 475 | id=-1, model="pinhole", width=intrinsics["width"], height=intrinsics["height"], 476 | params = np.array( 477 | [intrinsics["focal_x"], intrinsics["focal_y"]] 478 | ), 479 | ) 480 | self.tile_info = Tiles( 481 | math.ceil(intrinsics["width"]), 482 | math.ceil(intrinsics["height"]), 483 | intrinsics["focal_x"], 484 | intrinsics["focal_y"], 485 | self.device 486 | ) 487 | self.tile_info_cpp = self.tile_info.create_tiles() 488 | else: 489 | with Timer(" set image", debug=self.debug): 490 | self.current_w2c_quat = self.w2c_quats[idx] 491 | self.current_w2c_tran = self.w2c_trans[idx] 492 | self.current_w2c_rot = self.w2c_rots[idx] 493 | self.ground_truth = self.imgs[idx].to(torch.float16)/255. 
494 | with Timer(" set camera", debug=self.debug): 495 | if self.cameras[self.cam_ids[idx]] != self.current_camera: 496 | self.current_camera = self.cameras[self.cam_ids[idx]] 497 | width = self.current_camera.width / self.render_downsample 498 | height = self.current_camera.height / self.render_downsample 499 | focal_x = self.current_camera.params[0] / self.render_downsample 500 | focal_y = self.current_camera.params[1] / self.render_downsample 501 | self.tile_info = Tiles(int(self.ground_truth.shape[1]), int(self.ground_truth.shape[0]), focal_x, focal_y, self.device) 502 | self.tile_info_cpp = self.tile_info.create_tiles() 503 | 504 | self.ray_info = RayInfo( 505 | w2c=self.current_w2c_rot, 506 | tran=self.current_w2c_tran, 507 | H=self.tile_info.padded_height, 508 | W=self.tile_info.padded_width, 509 | focal_x=self.tile_info.focal_x, 510 | focal_y=self.tile_info.focal_y 511 | ) 512 | 513 | def project_and_culling(self): 514 | # project 3D to 2D 515 | # print(f"number of gaussians {len(self.gaussian_3ds.pos)}") 516 | # self.tic.record() 517 | if self.cudaculling: 518 | with Timer(" frustum cuda", debug=self.debug): 519 | normed_quat = (self.gaussian_3ds.quat/self.gaussian_3ds.quat.norm(dim=1, keepdim=True)) 520 | if self.scale_activation == "abs": 521 | normed_scale = self.gaussian_3ds.scale.abs()+EPS 522 | else: 523 | assert self.scale_activation == "exp" 524 | normed_scale = trunc_exp(self.gaussian_3ds.scale) 525 | _pos, _cov, _culling_mask = global_culling( 526 | self.gaussian_3ds.pos, 527 | normed_quat, 528 | normed_scale, 529 | self.current_w2c_rot.detach(), 530 | self.current_w2c_tran.detach(), 531 | self.near, 532 | self.current_camera.width*1.2/2/self.current_camera.params[0], 533 | self.current_camera.height*1.2/2/self.current_camera.params[1], 534 | ) 535 | 536 | self.culling_gaussian_3d_image_space = Gaussian3ds( 537 | pos=_pos[_culling_mask.bool()], 538 | cov=_cov[_culling_mask.bool()], 539 | rgb=self.gaussian_3ds.rgb[_culling_mask.bool()] if self.use_sh_coeff else self.gaussian_3ds.rgb[_culling_mask.bool()].sigmoid(), 540 | opa=self.gaussian_3ds.opa[_culling_mask.bool()].sigmoid(), 541 | ) 542 | self.culling_mask = _culling_mask 543 | else: 544 | with Timer("culling 1"): 545 | gaussian_3ds_pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran) 546 | with Timer("culling 2"): 547 | valid = gaussian_3ds_pos_camera_space[:,2] > self.near 548 | gaussian_3ds_pos_image_space = camera_to_image(gaussian_3ds_pos_camera_space) 549 | culling_mask = (gaussian_3ds_pos_image_space[:, 0].abs() < (self.current_camera.width*1.2/2/self.current_camera.params[0])) & \ 550 | (gaussian_3ds_pos_image_space[:, 1].abs() < (self.current_camera.height*1.2/2/self.current_camera.params[1])) 551 | valid &= culling_mask 552 | self.gaussian_3ds_valid = self.gaussian_3ds.filte(valid) 553 | with Timer("cullint 3"): 554 | self.culling_gaussian_3d_image_space = self.gaussian_3ds_valid.project( 555 | self.current_w2c_rot, 556 | self.current_w2c_tran, 557 | self.near, 558 | self.jacobian_calc, 559 | scale_activation=self.scale_activation, 560 | ) 561 | 562 | def render(self, out_write=True): 563 | if len(self.culling_gaussian_3d_image_space.pos) == 0: 564 | return torch.zeros(self.tile_info.padded_height, self.tile_info.padded_width, 3, device=self.device, dtype=torch.float32) 565 | # self.tic.record() 566 | with Timer(" culling tiles", debug=self.debug): 567 | tile_n_point = torch.zeros(len(self.tile_info), device=self.device, dtype=torch.int32) 568 | # MAXP = 
len(self.culling_gaussian_3d_image_space.pos)//10 569 | MAXP = len(self.culling_gaussian_3d_image_space.pos)//20 570 | tile_gaussian_list = torch.ones(len(self.tile_info), MAXP, device=self.device, dtype=torch.int32) * -1 571 | _method_config = {"dist": 0, "prob": 1, "prob2": 2} 572 | gaussian.calc_tile_list( 573 | self.culling_gaussian_3d_image_space._tocpp(), 574 | self.tile_info_cpp, 575 | tile_n_point, 576 | tile_gaussian_list, 577 | (self.tile_info.tile_geo_length_x/self.tile_culling_dist_thresh) ** 2 if self.tile_culling_method == "dist" else self.tile_culling_prob_thresh, 578 | _method_config[self.tile_culling_method], 579 | self.tile_info.tile_geo_length_x, 580 | self.tile_info.tile_geo_length_y, 581 | self.tile_info.n_tile_x, 582 | self.tile_info.n_tile_y, 583 | self.tile_info.leftmost, 584 | self.tile_info.topmost, 585 | ) 586 | tile_n_point = torch.min(tile_n_point, torch.ones_like(tile_n_point)*MAXP) 587 | 588 | if tile_n_point.sum() == 0: 589 | return torch.zeros(self.tile_info.padded_height, self.tile_info.padded_width, 3, device=self.device, dtype=torch.float32) 590 | 591 | with Timer(" gather culled tiles", debug=self.debug): 592 | gathered_list = torch.empty(tile_n_point.sum(), dtype=torch.int32, device=self.device) 593 | tile_ids_for_points = torch.empty(tile_n_point.sum(), dtype=torch.int32, device=self.device) 594 | tile_n_point_accum = torch.cat([torch.Tensor([0]).to(self.device), torch.cumsum(tile_n_point, 0)]).to(tile_n_point) 595 | max_points_for_tile = tile_n_point.max().item() 596 | # print(max_points_for_tile) 597 | gaussian.gather_gaussians( 598 | tile_n_point_accum, 599 | tile_gaussian_list, 600 | gathered_list, 601 | tile_ids_for_points, 602 | int(max_points_for_tile), 603 | ) 604 | self.tile_gaussians = self.culling_gaussian_3d_image_space.filte(gathered_list.long()) 605 | self.n_tile_gaussians = len(self.tile_gaussians.pos) 606 | self.n_gaussians = len(self.gaussian_3ds.pos) 607 | 608 | with Timer(" sorting", debug=self.debug): 609 | # cat id and sort 610 | BASE = self.tile_gaussians.pos[..., 2].max() 611 | id_and_depth = self.tile_gaussians.pos[..., 2].to(torch.float32) + tile_ids_for_points.to(torch.float32) * (BASE+1) 612 | _, sort_indices = torch.sort(id_and_depth) 613 | self.tile_gaussians = self.tile_gaussians.filte(sort_indices) 614 | 615 | with Timer(" rendering", debug=self.debug): 616 | rendered_image = draw( 617 | self.tile_gaussians.pos, 618 | self.tile_gaussians.rgb, 619 | self.tile_gaussians.opa, 620 | self.tile_gaussians.cov, 621 | tile_n_point_accum, 622 | self.tile_info.padded_height, 623 | self.tile_info.padded_width, 624 | self.tile_info.focal_x, 625 | self.tile_info.focal_y, 626 | self.render_weight_normalize, 627 | False, 628 | self.use_sh_coeff, 629 | self.fast_drawing, 630 | self.ray_info.rays_o, 631 | self.ray_info.lefttop, 632 | self.ray_info.dx, 633 | self.ray_info.dy, 634 | ) 635 | 636 | with Timer(" write out", debug=self.debug): 637 | if out_write: 638 | img_npy = rendered_image.clip(0,1).detach().cpu().numpy() 639 | cv2.imwrite("test.png", (img_npy*255).astype(np.uint8)[...,::-1]) 640 | 641 | return rendered_image 642 | 643 | def forward(self, camera_id=None, extrinsics=None, intrinsics=None): 644 | with Timer("forward", debug=self.debug): 645 | with Timer("set camera", debug=self.debug): 646 | self.set_camera(camera_id, extrinsics, intrinsics) 647 | with Timer("frustum culling", debug=self.debug): 648 | self.project_and_culling() 649 | with Timer("render function", debug=self.debug): 650 | padded_render_img = 
self.render(out_write=False) 651 | with Timer("crop", debug=self.debug): 652 | padded_render_img = torch.clamp(padded_render_img, 0, 1) 653 | ret = self.tile_info.crop(padded_render_img) 654 | 655 | return ret 656 | 657 | if __name__ == "__main__": 658 | parser = argparse.ArgumentParser() 659 | parser.add_argument("--cudaculling", type=int, default=0) 660 | opt = parser.parse_args() 661 | test = Splatter( 662 | os.path.join("colmap_garden/sparse/0/"), 663 | "colmap_garden/images_4/", 664 | render_weight_normalize=False, 665 | jacobian_calc="cuda", 666 | render_downsample=4, 667 | opa_init_value=0.8, 668 | scale_init_value=0.2, 669 | cudaculling=opt.cudaculling, 670 | load_ckpt="ckpt.pth", 671 | scale_activation="exp", 672 | ) 673 | test.forward(camera_id=0) 674 | loss = (test.ground_truth - test.forward(camera_id=0)).abs().mean() 675 | loss.backward() -------------------------------------------------------------------------------- /src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | // #include "include/gaussian.h" 3 | #include "include/common.hpp" 4 | 5 | void culling(torch::Tensor pos, torch::Tensor rgb, torch::Tensor quatenions, torch::Tensor scales, torch::Tensor w2c_quat, torch::Tensor w2c_tran); 6 | void world2camera(torch::Tensor pos, torch::Tensor rot, torch::Tensor trans, torch::Tensor res); 7 | void world2camera_backward(torch::Tensor grad_out, torch::Tensor rot, torch::Tensor grad_inp); 8 | void jacobian(torch::Tensor pos_camera_space, torch::Tensor jacobian); 9 | void calc_tile_list(Gaussian3ds & gaussians_image_space, Tiles & tile_info, torch::Tensor tile_n_point, torch::Tensor tile_gaussian_list, float thresh, int method, float tile_length_x, float tile_length_y, int n_tiles_x, int n_tiles_y, float leftmost, float topmost); 10 | void gather_gaussians(torch::Tensor tile_n_point_accum, torch::Tensor tile_gaussian_list, torch::Tensor gathered_list, torch::Tensor tile_ids_for_points, int max_points_for_tile); 11 | // void draw(Gaussian3ds & tile_sorted_gaussians, torch::Tensor tile_n_point_accum, torch::Tensor res, float focal_x, float focal_y); 12 | // void draw(torch::Tensor gaussian_pos, torch::Tensor gaussian_rgb, torch::Tensor gaussian_opa, torch::Tensor gaussian_cov, torch::Tensor tile_n_point_accum, torch::Tensor res, float focal_x, float focal_y, bool weight_normalize, bool sigmoid, bool fast); //, torch::Tensor rays_o, torch::Tensor lefttop_pos, torch::Tensor vec_dx, torch::Tensor vec_dy); 13 | // void draw_backward(torch::Tensor gaussian_pos, torch::Tensor gaussian_rgb, torch::Tensor gaussian_opa, torch::Tensor gaussian_cov, torch::Tensor tile_n_point_accum, torch::Tensor output, torch::Tensor grad_output, torch::Tensor grad_pos, torch::Tensor grad_rgb, torch::Tensor grad_opa, torch::Tensor grad_cov, float focal_x, float focal_y, bool weight_normalize, bool sigmoid, bool fast);//, torch::Tensor rays_o, torch::Tensor lefttop_pos, torch::Tensor vec_dx, torch::Tensor vec_dy); 14 | void draw(torch::Tensor gaussian_pos, torch::Tensor gaussian_rgb, torch::Tensor gaussian_opa, torch::Tensor gaussian_cov, torch::Tensor tile_n_point_accum, torch::Tensor res, float focal_x, float focal_y, bool weight_normalize, bool sigmoid, bool fast, torch::Tensor rays_o, torch::Tensor lefttop_pos, torch::Tensor vec_dx, torch::Tensor vec_dy, bool use_sh_coeff); 15 | void draw_backward(torch::Tensor gaussian_pos, torch::Tensor gaussian_rgb, torch::Tensor gaussian_opa, torch::Tensor gaussian_cov, torch::Tensor tile_n_point_accum, 
torch::Tensor output, torch::Tensor grad_output, torch::Tensor grad_pos, torch::Tensor grad_rgb, torch::Tensor grad_opa, torch::Tensor grad_cov, float focal_x, float focal_y, bool weight_normalize, bool sigmoid, bool fast, torch::Tensor rays_o, torch::Tensor lefttop_pos, torch::Tensor vec_dx, torch::Tensor vec_dy, bool use_sh_coeff); 16 | 17 | void global_culling(torch::Tensor pos, torch::Tensor quat, torch::Tensor scale, torch::Tensor current_rot, torch::Tensor current_tran, torch::Tensor res_pos, torch::Tensor res_cov, torch::Tensor culling_mask, float near, float half_width, float half_height); 18 | void global_culling_backward(torch::Tensor pos, torch::Tensor quat, torch::Tensor scale, torch::Tensor current_rot, torch::Tensor current_tran, torch::Tensor gradout_pos, torch::Tensor gradout_cov, torch::Tensor culling_mask, torch::Tensor gradinput_pos, torch::Tensor gradinput_quat, torch::Tensor gradinput_scale); 19 | // void cat_ids(torch::Tensor depth, torch::Tensor ids, torch::Tensor res); 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | // m.def("matmul", &matmul, "matmul forward (CUDA)"); 23 | m.def("culling", &culling, "gaussian culling (CUDA)"); 24 | m.def("world2camera", &world2camera, "world to camera fast (CUDA)"); 25 | m.def("world2camera_backward", &world2camera_backward, "world to camera fast backward (CUDA)"); 26 | m.def("jacobian", &jacobian, "jacobian (CUDA)"); 27 | 28 | py::class_(m, "Tiles") 29 | .def(py::init<>()) 30 | .def_readwrite("top", &Tiles::top) 31 | .def_readwrite("bottom", &Tiles::bottom) 32 | .def_readwrite("left", &Tiles::left) 33 | .def_readwrite("right", &Tiles::right); 34 | 35 | py::class_(m, "Gaussian3ds") 36 | .def(py::init<>()) 37 | .def_readwrite("pos", &Gaussian3ds::pos) 38 | .def_readwrite("rgb", &Gaussian3ds::rgb) 39 | .def_readwrite("opa", &Gaussian3ds::opa) 40 | .def_readwrite("quat", &Gaussian3ds::quat) 41 | .def_readwrite("scale", &Gaussian3ds::scale) 42 | .def_readwrite("cov", &Gaussian3ds::cov); 43 | 44 | m.def("calc_tile_list", &calc_tile_list, "calc tile list (CUDA)"); 45 | m.def("gather_gaussians", &gather_gaussians, "gather gaussian (CUDA)"); 46 | m.def("draw", &draw, "draw (CUDA)"); 47 | m.def("draw_backward", &draw_backward, "draw backward (CUDA)"); 48 | m.def("global_culling", &global_culling, "global culling (CUDA)"); 49 | m.def("global_culling_backward", &global_culling_backward, "global culling backward (CUDA)"); 50 | } -------------------------------------------------------------------------------- /src/gaussian_nosh.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "include/common.hpp" 3 | #include 4 | #include 5 | 6 | void culling(torch::Tensor pos, torch::Tensor rgb, torch::Tensor quatenions, torch::Tensor scales, torch::Tensor w2c_quat, torch::Tensor w2c_tran){ 7 | printf("hellow\n"); 8 | } 9 | 10 | __global__ void jacobian_kernel( 11 | const float* pos_camera_space, 12 | float* jacobian, 13 | uint32_t B 14 | ){ 15 | uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; 16 | if(tid >= B) return; 17 | pos_camera_space += tid * 3; 18 | jacobian += tid * 9; 19 | float _res[9]; 20 | float u0 = pos_camera_space[0]; 21 | float u1 = pos_camera_space[1]; 22 | float u2 = pos_camera_space[2]; 23 | 24 | _res[0] = 1/u2; 25 | _res[1] = 0; 26 | _res[2] = -u0/(u2*u2); 27 | _res[3] = 0; 28 | _res[4] = 1/u2; 29 | _res[5] = -u1/(u2*u2); 30 | float _rsqr = rsqrtf(u0*u0+u1*u1+u2*u2); 31 | _res[6] = _rsqr * u0; 32 | _res[7] = _rsqr * u1; 33 | _res[8] = _rsqr * u2; 34 | 35 | 
#pragma unroll 36 | for(uint32_t i=0; i<9; ++i){ 37 | jacobian[i] = _res[i]; 38 | } 39 | } 40 | 41 | void jacobian(torch::Tensor pos_camera_space, torch::Tensor jacobian){ 42 | // pos_camera_space B x 3 43 | // jacobian B x 3 x 3 44 | uint32_t B = pos_camera_space.size(0); 45 | uint32_t gridsize = DIV_ROUND_UP(B, 1024); 46 | jacobian_kernel<<>>(pos_camera_space.data_ptr(), jacobian.data_ptr(), B); 47 | } 48 | 49 | __global__ void world2camera_kernel(const float * pos, const float * rot, const float * trans, float * res, uint32_t B){ 50 | uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; 51 | if(tid >= B) return; 52 | pos += tid * 3; 53 | float _rot[9]; 54 | float _trans[3]; 55 | #pragma unroll 56 | for(int i=0;i<9;++i){ 57 | _rot[i] = rot[i]; 58 | } 59 | 60 | #pragma unroll 61 | for(int i=0;i<3;++i){ 62 | _trans[i] = trans[i]; 63 | } 64 | 65 | res += tid * 3; 66 | res[0] = pos[0] * _rot[0] + pos[1] * _rot[1] + pos[2] * _rot[2] + _trans[0]; 67 | res[1] = pos[0] * _rot[3] + pos[1] * _rot[4] + pos[2] * _rot[5] + _trans[1]; 68 | res[2] = pos[0] * _rot[6] + pos[1] * _rot[7] + pos[2] * _rot[8] + _trans[2]; 69 | } 70 | 71 | void world2camera(torch::Tensor pos, torch::Tensor rot, torch::Tensor trans, torch::Tensor res){ 72 | uint32_t B = pos.size(0); 73 | uint32_t gridsize = DIV_ROUND_UP(B, 1024); 74 | world2camera_kernel<<>>(pos.data_ptr(), rot.data_ptr(), trans.data_ptr(), res.data_ptr(), B); 75 | 76 | } 77 | 78 | __global__ void world2camera_backward_kernel(const float * grad_out, const float * rot, float * grad_inp, uint32_t B){ 79 | uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; 80 | if(tid >= B) return; 81 | grad_out += tid * 3; 82 | float _rot[9]; 83 | 84 | #pragma unroll 85 | for(int i=0;i<9;++i){ 86 | _rot[i] = rot[i]; 87 | } 88 | 89 | grad_inp += tid * 3; 90 | grad_inp[0] = grad_out[0] * _rot[0] + grad_out[1] * _rot[3] + grad_out[2] * _rot[6]; 91 | grad_inp[1] = grad_out[0] * _rot[1] + grad_out[1] * _rot[4] + grad_out[2] * _rot[7]; 92 | grad_inp[2] = grad_out[0] * _rot[2] + grad_out[1] * _rot[5] + grad_out[2] * _rot[8]; 93 | } 94 | 95 | void world2camera_backward(torch::Tensor grad_out, torch::Tensor rot, torch::Tensor grad_inp){ 96 | uint32_t B = grad_out.size(0); 97 | uint32_t gridsize = DIV_ROUND_UP(B, 1024); 98 | world2camera_backward_kernel<<>>(grad_out.data_ptr(), rot.data_ptr(), grad_inp.data_ptr(), B); 99 | } 100 | 101 | __global__ void calc_tile_info_kernel( 102 | float * gaussian_pos, 103 | float * top, 104 | float * bottom, 105 | float * left, 106 | float * right, 107 | int * tile_n_point, 108 | int * tile_gaussian_list, 109 | uint32_t n_point, 110 | uint32_t n_tiles, 111 | uint32_t max_points_per_tile, 112 | float thresh_dis 113 | ){ 114 | uint32_t pid = blockDim.x * blockIdx.x + threadIdx.x; 115 | if(pid >= n_point) return; 116 | uint32_t tid = blockDim.y * blockIdx.y + threadIdx.y; 117 | if(tid >= n_tiles) return; 118 | gaussian_pos += pid * 3; 119 | top += tid; 120 | bottom += tid; 121 | left += tid; 122 | right += tid; 123 | // simple test 124 | float center_y = (*top + *bottom) / 2; 125 | float center_x = (*left + *right) / 2; 126 | float d1 = gaussian_pos[0] - center_x; 127 | float d2 = gaussian_pos[1] - center_y; 128 | if(d1*d1 + d2*d2 < thresh_dis){ 129 | // write pid -> append 130 | uint32_t old = atomicAdd(tile_n_point + tid, 1); 131 | if(old= n_point) return; 155 | uint32_t tid = blockDim.y * blockIdx.y + threadIdx.y; 156 | if(tid >= n_tiles) return; 157 | gaussian_pos += pid * 3; 158 | top += tid; 159 | bottom += tid; 160 | left += tid; 161 | right += 
tid; 162 | // simple test 163 | float _a, _b, _c, _d; 164 | _a = gaussian_cov[pid*4+0]; 165 | _b = gaussian_cov[pid*4+1]; 166 | _c = gaussian_cov[pid*4+2]; 167 | _d = gaussian_cov[pid*4+3]; 168 | 169 | float _center_x = gaussian_pos[0]; 170 | float _center_y = gaussian_pos[1]; 171 | 172 | float det = (_a * _d - _b * _c); 173 | if(det<=0) return; 174 | 175 | float _ai = _d / (det + 1e-14); 176 | float _bi = -_b / (det + 1e-14); 177 | float _ci = -_c / (det + 1e-14); 178 | float _di = _a / (det + 1e-14); 179 | float thresh_dis_log = -2 * logf(thresh_dis); 180 | float shift_x = sqrtf(_di * thresh_dis_log * det); 181 | float shift_y = sqrtf(_ai * thresh_dis_log * det); 182 | float bbx_right = _center_x + shift_x; 183 | float bbx_left = _center_x - shift_x; 184 | float bbx_top = _center_y - shift_y; 185 | float bbx_bottom = _center_y + shift_y; 186 | 187 | if(! (*right < bbx_left || bbx_right < *left || *bottom < bbx_top || bbx_bottom < *top)){ 188 | // uint32_t old = atomicAdd(tile_n_point + tid, 1); 189 | if(tile_n_point[tid]= n_point) return; 215 | gaussian_pos += pid * 3; 216 | // simple test 217 | float _a, _b, _c, _d; 218 | _a = gaussian_cov[pid*4+0]; 219 | _b = gaussian_cov[pid*4+1]; 220 | _c = gaussian_cov[pid*4+2]; 221 | _d = gaussian_cov[pid*4+3]; 222 | 223 | float _center_x = gaussian_pos[0]; 224 | float _center_y = gaussian_pos[1]; 225 | 226 | float det = (_a * _d - _b * _c); 227 | if(det<=0) return; 228 | 229 | float _ai = _d / (det + 1e-14); 230 | float _bi = -_b / (det + 1e-14); 231 | float _ci = -_c / (det + 1e-14); 232 | float _di = _a / (det + 1e-14); 233 | float thresh_dis_log = -2 * logf(thresh_dis); 234 | float shift_x = sqrtf(_di * thresh_dis_log * det); 235 | float shift_y = sqrtf(_ai * thresh_dis_log * det); 236 | float bbx_right = _center_x + shift_x; 237 | float bbx_left = _center_x - shift_x; 238 | float bbx_top = _center_y - shift_y; 239 | float bbx_bottom = _center_y + shift_y; 240 | 241 | for(uint32_t i_top=fmaxf((bbx_top-topmost)/tile_length_y, 0); i_top<(uint32_t)((bbx_bottom-topmost)/tile_length_y+1) && i_top>>( 282 | gaussians_image_space.pos.data_ptr(), 283 | tile_info.top.data_ptr(), 284 | tile_info.bottom.data_ptr(), 285 | tile_info.left.data_ptr(), 286 | tile_info.right.data_ptr(), 287 | tile_n_point.data_ptr(), 288 | tile_gaussian_list.data_ptr(), 289 | n_point, 290 | n_tiles, 291 | max_points_per_tile, 292 | thresh 293 | ); 294 | } 295 | else if(method == 1){ 296 | uint32_t gridsize_x = DIV_ROUND_UP(n_point, 32); 297 | uint32_t gridsize_y = DIV_ROUND_UP(n_tiles, 32); 298 | dim3 gridsize(gridsize_x, gridsize_y, 1); 299 | dim3 blocksize(32, 32, 1); 300 | calc_tile_info_kernel2<<>>( 301 | gaussians_image_space.pos.data_ptr(), 302 | gaussians_image_space.cov.data_ptr(), 303 | tile_info.top.data_ptr(), 304 | tile_info.bottom.data_ptr(), 305 | tile_info.left.data_ptr(), 306 | tile_info.right.data_ptr(), 307 | tile_n_point.data_ptr(), 308 | tile_gaussian_list.data_ptr(), 309 | n_point, 310 | n_tiles, 311 | max_points_per_tile, 312 | thresh 313 | ); 314 | } 315 | else{ 316 | uint32_t gridsize_x = DIV_ROUND_UP(n_point, 1024); 317 | dim3 gridsize(gridsize_x, 1, 1); 318 | dim3 blocksize(1024, 1, 1); 319 | calc_tile_info_kernel3<<>>( 320 | gaussians_image_space.pos.data_ptr(), 321 | gaussians_image_space.cov.data_ptr(), 322 | tile_length_x, 323 | tile_length_y, 324 | tile_n_point.data_ptr(), 325 | tile_gaussian_list.data_ptr(), 326 | n_point, 327 | n_tiles_x, 328 | n_tiles_y, 329 | max_points_per_tile, 330 | thresh, 331 | leftmost, 332 | topmost 333 | ); 334 | } 335 
| } 336 | 337 | __global__ void gather_gaussians_kernel( 338 | const int * tile_n_point_accum, 339 | const int * gaussian_list, 340 | int * gather_list, 341 | int * tile_ids_for_points, 342 | int n_tiles, 343 | int max_points_for_tile, 344 | int gaussian_list_size 345 | ){ 346 | uint32_t tid = blockDim.y * blockIdx.y + threadIdx.y; 347 | if(tid >= n_tiles) return; 348 | uint32_t pid = blockDim.x * blockIdx.x + threadIdx.x; 349 | uint32_t shift_res = tile_n_point_accum[tid]; 350 | 351 | uint32_t n_point_this_tile = tile_n_point_accum[tid+1] - shift_res; 352 | if(pid >= n_point_this_tile) return; 353 | 354 | gaussian_list += tid * gaussian_list_size; 355 | gather_list[shift_res+pid] = gaussian_list[pid]; 356 | tile_ids_for_points[shift_res+pid] = tid; 357 | } 358 | 359 | void gather_gaussians( 360 | torch::Tensor tile_n_point_accum, 361 | torch::Tensor tile_gaussian_list, 362 | torch::Tensor gathered_list, 363 | torch::Tensor tile_ids_for_points, 364 | int max_points_for_tile 365 | ){ 366 | uint32_t n_tiles = tile_n_point_accum.size(0) - 1; 367 | uint32_t gridsize_x = DIV_ROUND_UP(max_points_for_tile, 32); 368 | uint32_t gridsize_y = DIV_ROUND_UP(n_tiles, 32); 369 | dim3 gridsize(gridsize_x, gridsize_y, 1); 370 | dim3 blocksize(32, 32, 1); 371 | uint32_t gaussian_list_size = tile_gaussian_list.size(1); 372 | gather_gaussians_kernel<<>>( 373 | tile_n_point_accum.data_ptr(), 374 | tile_gaussian_list.data_ptr(), 375 | gathered_list.data_ptr(), 376 | tile_ids_for_points.data_ptr(), 377 | n_tiles, 378 | max_points_for_tile, 379 | gaussian_list_size 380 | ); 381 | } 382 | 383 | 384 | // Spherical functions from svox2 385 | __device__ __constant__ const float C0 = 0.28209479177387814; 386 | __device__ __constant__ const float C1 = 0.4886025119029199; 387 | __device__ __constant__ const float C2[] = { 388 | 1.0925484305920792, 389 | -1.0925484305920792, 390 | 0.31539156525252005, 391 | -1.0925484305920792, 392 | 0.5462742152960396 393 | }; 394 | 395 | __device__ __constant__ const float C3[] = { 396 | -0.5900435899266435, 397 | 2.890611442640554, 398 | -0.4570457994644658, 399 | 0.3731763325901154, 400 | -0.4570457994644658, 401 | 1.445305721320277, 402 | -0.5900435899266435 403 | }; 404 | 405 | __device__ __inline__ void calc_sh( 406 | const int basis_dim, 407 | const float* __restrict__ dir, 408 | float* __restrict__ out) { 409 | out[0] = C0; 410 | const float x = dir[0], y = dir[1], z = dir[2]; 411 | const float xx = x * x, yy = y * y, zz = z * z; 412 | const float xy = x * y, yz = y * z, xz = x * z; 413 | switch (basis_dim) { 414 | case 9: 415 | out[4] = C2[0] * xy; 416 | out[5] = C2[1] * yz; 417 | out[6] = C2[2] * (2.0 * zz - xx - yy); 418 | out[7] = C2[3] * xz; 419 | out[8] = C2[4] * (xx - yy); 420 | [[fallthrough]]; 421 | case 4: 422 | out[1] = -C1 * y; 423 | out[2] = C1 * z; 424 | out[3] = -C1 * x; 425 | } 426 | } 427 | 428 | template 429 | __global__ void draw_backward_kernel( 430 | const float * gaussian_pos, 431 | const float * gaussian_rgb, 432 | const float * gaussian_opa, 433 | const float * gaussian_cov, 434 | const int * tile_n_point_accum, 435 | const float * output, 436 | const float * grad_output, 437 | float * grad_pos, 438 | float * grad_rgb, 439 | float * grad_opa, 440 | float * grad_cov, 441 | const float focal_x, 442 | const float focal_y, 443 | const uint32_t w, 444 | const uint32_t h, 445 | const bool weight_normalize, 446 | const bool sigmoid, 447 | const bool fast 448 | ){ 449 | uint32_t id_x = blockDim.x * blockIdx.x + threadIdx.x; 450 | uint32_t id_y = blockDim.y * 
428 | template <uint32_t SMSIZE, typename T> 429 | __global__ void draw_backward_kernel( 430 | const float * gaussian_pos, 431 | const float * gaussian_rgb, 432 | const float * gaussian_opa, 433 | const float * gaussian_cov, 434 | const int * tile_n_point_accum, 435 | const float * output, 436 | const float * grad_output, 437 | float * grad_pos, 438 | float * grad_rgb, 439 | float * grad_opa, 440 | float * grad_cov, 441 | const float focal_x, 442 | const float focal_y, 443 | const uint32_t w, 444 | const uint32_t h, 445 | const bool weight_normalize, 446 | const bool sigmoid, 447 | const bool fast 448 | ){ 449 | uint32_t id_x = blockDim.x * blockIdx.x + threadIdx.x; 450 | uint32_t id_y = blockDim.y * blockIdx.y + threadIdx.y; 451 | if(id_x>=w || id_y>=h) return; 452 | uint32_t id_tile = blockIdx.x + blockIdx.y * gridDim.x; 453 | uint32_t start_idx = tile_n_point_accum[id_tile]; 454 | uint32_t end_idx = tile_n_point_accum[id_tile+1]; 455 | // const uint32_t interval_length = end_idx - start_idx; 456 | __shared__ float _gaussian_pos[SMSIZE*2]; 457 | __shared__ float _gaussian_rgb[SMSIZE*3]; 458 | __shared__ float _gaussian_opa[SMSIZE*1]; 459 | __shared__ float _gaussian_cov[SMSIZE*4]; 460 | __shared__ float _grad_pos[SMSIZE*2]; 461 | __shared__ float _grad_rgb[SMSIZE*3]; 462 | __shared__ float _grad_opa[SMSIZE*1]; 463 | __shared__ float _grad_cov[SMSIZE*4]; 464 | uint32_t id_thread = threadIdx.x + threadIdx.y * blockDim.x; 465 | uint32_t blocksize = blockDim.x * blockDim.y; 466 | 467 | // initialize shared memory for gradients 468 | for(uint32_t i=id_thread; i<SMSIZE; i+=blocksize){ 469 | _grad_pos[i*2+0] = 0; 470 | _grad_pos[i*2+1] = 0; 471 | _grad_rgb[i*3+0] = 0; 472 | _grad_rgb[i*3+1] = 0; 473 | _grad_rgb[i*3+2] = 0; 474 | _grad_opa[i] = 0; 475 | _grad_cov[i*4+0] = 0; 476 | _grad_cov[i*4+1] = 0; 477 | _grad_cov[i*4+2] = 0; 478 | _grad_cov[i*4+3] = 0; 479 | } 480 | __syncthreads(); 481 | 482 | // pixel position in image space 483 | float pixel_x = (id_x + 0.5 - w/2)/focal_x; 484 | float pixel_y = (id_y + 0.5 - h/2)/focal_y; 485 | float color[] = {0, 0, 0}; 486 | float accum = 1.0; 487 | float accum_weight = 0.0; 488 | 489 | T current_prob = 0.0; 490 | T current_prob0 = 0.0; 491 | T current_prob1 = 0.0; 492 | T current_prob_c0 = 0.0; 493 | T _a, _b, _c, _d, _x, _y, det; 494 | float alpha, weight; 495 | 496 | // upstream gradient of this pixel and its forward output 497 | float cur_grad_out[3]; 498 | float cur_out[3]; 499 | #pragma unroll 500 | for(int _m=0; _m<3; ++_m){ 501 | cur_grad_out[_m] = grad_output[(id_x + id_y * w)*3 + _m]; 502 | cur_out[_m] = output[(id_x + id_y * w)*3 + _m]; 503 | } 504 | 505 | // load to shared memory chunk by chunk 506 | uint32_t n_loadings = DIV_ROUND_UP(end_idx - start_idx, SMSIZE); 507 | uint32_t global_idx; 508 | 509 | for(uint32_t i_loadings=0; i_loadings<n_loadings; ++i_loadings){ 510 | // load this chunk of gaussians into shared memory 511 | for(uint32_t i=id_thread; i<SMSIZE; i+=blocksize){ 512 | global_idx = i_loadings*SMSIZE + i; 513 | // bound check 514 | if(global_idx>=(end_idx - start_idx)){ 515 | break; 516 | } 517 | _gaussian_pos[i*2 + 0] = gaussian_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 0]; 518 | _gaussian_pos[i*2 + 1] = gaussian_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 1]; 519 | // rgb 520 | _gaussian_rgb[i*3 + 0] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 0]; 521 | _gaussian_rgb[i*3 + 1] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 1]; 522 | _gaussian_rgb[i*3 + 2] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 2]; 523 | // opa 524 | _gaussian_opa[i] = gaussian_opa[start_idx + i_loadings*SMSIZE + i]; 525 | // cov 526 | _gaussian_cov[i*4 + 0] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 0]; 527 | _gaussian_cov[i*4 + 1] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 1]; 528 | _gaussian_cov[i*4 + 2] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 2]; 529 | _gaussian_cov[i*4 + 3] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 3]; 530 | } 531 | __syncthreads(); 532 | 533 | for(uint32_t i=0; i<SMSIZE; ++i){ 534 | // accumulate gradients for this chunk, stopping early 535 | // once the transmittance has (almost) vanished 536 | 537 | if(i_loadings*SMSIZE+i>=(end_idx-start_idx)||accum < 0.0001){ 538 | break; 539 | } 540 | _a = _gaussian_cov[i*4+0]; 541 | _b = _gaussian_cov[i*4+1]; 542 | _c = _gaussian_cov[i*4+2]; 543 | _d = _gaussian_cov[i*4+3]; 544 | _x = pixel_x - _gaussian_pos[i*2 + 0]; 545 | _y = pixel_y - _gaussian_pos[i*2 + 1]; 546 | det = (_a * _d - _b * _c); 547 | 548 | T Pm = -(_d * _x * _x - (_b + _c) * _x * _y + _a * _y * _y); 549 | T Pn = (2 * det + 1e-14); 550 | 551 | current_prob_c0 = sigmoid ? 1.0/(2*3.1415926536) : 1.0; 552 | current_prob0 = sigmoid ? current_prob_c0 * rsqrtf(det+1e-7) : 1.0; 553 | if(fast){ 554 | current_prob1 = __expf(Pm / Pn); 555 | } 556 | else{ 557 | current_prob1 = exp(Pm / Pn); 558 | } 559 | current_prob = current_prob0 * current_prob1; 560 | 561 | // gradient primitives for 2d gaussian 562 | T dPm_da = - (_y * _y); 563 | T dPm_db = _x * _y; 564 | T dPm_dc = _x * _y; 565 | T dPm_dd = - (_x * _x); 566 | T dPn_da = 2 * _d; 567 | T dPn_db = -2 * _c; 568 | T dPn_dc = -2 * _b; 569 | T dPn_dd = 2 * _a; 570 | T dP1_da = current_prob1 * (dPm_da*Pn - dPn_da*Pm) / (Pn*Pn); 571 | T dP1_db = current_prob1 * (dPm_db*Pn - dPn_db*Pm) / (Pn*Pn); 572 | T dP1_dc = current_prob1 * (dPm_dc*Pn - dPn_dc*Pm) / (Pn*Pn); 573 | T dP1_dd = current_prob1 * (dPm_dd*Pn - dPn_dd*Pm) / (Pn*Pn); 574 | 575 | T dP0_da = sigmoid ? -0.5 * (current_prob0 * current_prob0 * current_prob0) / (current_prob_c0 * current_prob_c0) * _d : 0.0; 576 | T dP0_db = sigmoid ? 0.5 * (current_prob0 * current_prob0 * current_prob0) / (current_prob_c0 * current_prob_c0) * _c : 0.0; 577 | T dP0_dc = sigmoid ? 
0.5 * (current_prob0 * current_prob0 * current_prob0) / (current_prob_c0 * current_prob_c0) * _b : 0.0; 578 | T dP0_dd = sigmoid ? -0.5 * (current_prob0 * current_prob0 * current_prob0) / (current_prob_c0 * current_prob_c0) * _a : 0.0; 579 | 580 | T dP_da = current_prob0 * dP1_da + current_prob1 * dP0_da; 581 | T dP_db = current_prob0 * dP1_db + current_prob1 * dP0_db; 582 | T dP_dc = current_prob0 * dP1_dc + current_prob1 * dP0_dc; 583 | T dP_dd = current_prob0 * dP1_dd + current_prob1 * dP0_dd; 584 | // gradient w.r.t. the position (not _x/_y: note the sign flip) 585 | T dP_dx = current_prob / Pn * (2*_d*_x - _b*_y - _c*_y); 586 | T dP_dy = current_prob / Pn * (2*_a*_y - _b*_x - _c*_x); 587 | 588 | alpha = current_prob * _gaussian_opa[i]; 589 | // sigmoid + scale -> 0-1 590 | if(sigmoid){ 591 | alpha = 2./(exp(-alpha)+1) - 1; 592 | } 593 | weight = alpha * accum; 594 | float cur_point_color[3]; 595 | cur_point_color[0] = _gaussian_rgb[i*3 + 0]; 596 | cur_point_color[1] = _gaussian_rgb[i*3 + 1]; 597 | cur_point_color[2] = _gaussian_rgb[i*3 + 2]; 598 | color[0] += cur_point_color[0] * weight; 599 | color[1] += cur_point_color[1] * weight; 600 | color[2] += cur_point_color[2] * weight; 601 | accum_weight += weight; 602 | 603 | // grad w.r.t gaussian rgb 604 | atomicAdd(_grad_rgb+i*3+0, cur_grad_out[0]*weight); 605 | atomicAdd(_grad_rgb+i*3+1, cur_grad_out[1]*weight); 606 | atomicAdd(_grad_rgb+i*3+2, cur_grad_out[2]*weight); 607 | 608 | // grad w.r.t pos opa cov -> grad w.r.t alpha 609 | float d_alpha = 0; 610 | #pragma unroll 611 | for(int _m=0; _m<3; ++_m){ 612 | d_alpha += cur_grad_out[_m] * cur_point_color[_m]; 613 | } 614 | d_alpha *= accum; 615 | float _d_alpha_acc = 0; 616 | #pragma unroll 617 | for(int _m=0; _m<3; ++_m){ 618 | _d_alpha_acc += cur_grad_out[_m] * (cur_out[_m]-color[_m]); 619 | } 620 | _d_alpha_acc /= (1-alpha+1e-7); 621 | d_alpha -= _d_alpha_acc; 622 | // backward to activation function 623 | if(sigmoid){ 624 | d_alpha = d_alpha * (alpha + 1 - 0.5*(alpha+1)*(alpha+1)); 625 | } 626 | // grad w.r.t opa 627 | atomicAdd(_grad_opa+i*1+0, (float)(d_alpha*current_prob)); 628 | float d_current_prob = d_alpha * _gaussian_opa[i]; 629 | // grad w.r.t pos 630 | atomicAdd(_grad_pos+i*2+0, (float)(d_current_prob * dP_dx)); 631 | atomicAdd(_grad_pos+i*2+1, (float)(d_current_prob * dP_dy)); 632 | // grad w.r.t cov 633 | atomicAdd(_grad_cov+i*4+0, (float)(d_current_prob * dP_da)); 634 | atomicAdd(_grad_cov+i*4+1, (float)(d_current_prob * dP_db)); 635 | atomicAdd(_grad_cov+i*4+2, (float)(d_current_prob * dP_dc)); 636 | atomicAdd(_grad_cov+i*4+3, (float)(d_current_prob * dP_dd)); 637 | 638 | accum *= (1-alpha); 639 | } 640 | __syncthreads(); 641 |
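The `d_alpha` terms in the loop above come from differentiating front-to-back alpha compositing. With per-gaussian opacity $\alpha_i$, transmittance $T_i=\prod_{j<i}(1-\alpha_j)$ (the running `accum`), and pixel color $C=\sum_i T_i\,\alpha_i\,c_i$:

$$\frac{\partial C}{\partial \alpha_i} \;=\; T_i\,c_i \;-\; \frac{C - C_{\le i}}{1-\alpha_i},$$

where $C_{\le i}$ is the partially composited color (the running `color`) and $C$ is read back from `output` as `cur_out`. Dotting with the upstream pixel gradient `cur_grad_out` gives exactly `d_alpha = accum*<g, c_i> - <g, cur_out - color>/(1 - alpha)`. The extra factor `alpha + 1 - 0.5*(alpha+1)^2` is the derivative of the squashing `2*sigmoid(x) - 1`, rewritten in terms of its output.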
642 | // write gradients back to global memory 643 | for(uint32_t i=id_thread; i<SMSIZE; i+=blocksize){ 644 | // write only the entries that belong to this chunk, 645 | // using the same bound as the load loop above 646 | if(i_loadings*SMSIZE + i>=(end_idx - start_idx)){ 647 | break; 648 | } 649 | grad_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 0] = _grad_pos[i*2 + 0]; 650 | grad_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 1] = _grad_pos[i*2 + 1]; 651 | 652 | grad_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 0] = _grad_rgb[i*3 + 0]; 653 | grad_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 1] = _grad_rgb[i*3 + 1]; 654 | grad_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 2] = _grad_rgb[i*3 + 2]; 655 | 656 | grad_opa[(start_idx + i_loadings*SMSIZE + i)*1 + 0] = _grad_opa[i]; 657 | 658 | grad_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 0] = _grad_cov[i*4 + 0]; 659 | grad_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 1] = _grad_cov[i*4 + 1]; 660 | grad_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 2] = _grad_cov[i*4 + 2]; 661 | grad_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 3] = _grad_cov[i*4 + 3]; 662 | } 663 | // is this sync necessary? 664 | // __syncthreads(); 665 | } 666 | } 667 |
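`draw_kernel` below is the forward pass these gradients correspond to: each 16x16 pixel block loads its tile's depth-sorted gaussians into shared memory in chunks and composites front to back until the transmittance is spent. A hedged NumPy sketch of the per-pixel loop (the flags mirror the kernel's `sigmoid`/`weight_normalize`/early-stop behaviour; this is an illustration, not the repo's API):

```
import numpy as np

def composite_pixel(px, py, pos, rgb, opa, cov, sigmoid=True, weight_normalize=False):
    """Front-to-back compositing of one pixel over one tile's sorted gaussians."""
    color, accum, accum_weight = np.zeros(3), 1.0, 0.0
    for i in range(len(pos)):
        if accum < 1e-4:                     # early stop: pixel is already opaque
            break
        a, b, c, d = cov[i]
        det = a * d - b * c
        dx, dy = px - pos[i, 0], py - pos[i, 1]
        prob = np.exp(-(d*dx*dx - (b + c)*dx*dy + a*dy*dy) / (2*det + 1e-14))
        if sigmoid:                          # normalized density, then squashed alpha
            prob *= 1.0 / (2.0 * np.pi) * (det + 1e-7) ** -0.5
        alpha = prob * opa[i]
        if sigmoid:
            alpha = 2.0 / (np.exp(-alpha) + 1.0) - 1.0
        weight = alpha * accum
        color += rgb[i] * weight
        accum_weight += weight
        accum *= 1.0 - alpha
    if weight_normalize and accum_weight >= 0.01:
        color = color / accum_weight
    return color
```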
668 | 669 | template <uint32_t SMSIZE, typename T> 670 | __global__ void draw_kernel( 671 | // Gaussian3ds & tile_sorted_gaussians, 672 | const float * gaussian_pos, 673 | const float * gaussian_rgb, 674 | const float * gaussian_opa, 675 | const float * gaussian_cov, 676 | const int * tile_n_point_accum, 677 | float * res, 678 | const float focal_x, 679 | const float focal_y, 680 | const uint32_t w, 681 | const uint32_t h, 682 | const bool weight_normalize, 683 | const bool sigmoid, 684 | const bool fast 685 | ){ 686 | uint32_t id_x = blockDim.x * blockIdx.x + threadIdx.x; 687 | uint32_t id_y = blockDim.y * blockIdx.y + threadIdx.y; 688 | if(id_x>=w || id_y>=h) return; 689 | uint32_t id_tile = blockIdx.x + blockIdx.y * gridDim.x; 690 | uint32_t start_idx = tile_n_point_accum[id_tile]; 691 | uint32_t end_idx = tile_n_point_accum[id_tile+1]; 692 | // const uint32_t interval_length = end_idx - start_idx; 693 | __shared__ float _gaussian_pos[SMSIZE*2]; 694 | __shared__ float _gaussian_rgb[SMSIZE*3]; 695 | __shared__ float _gaussian_opa[SMSIZE*1]; 696 | __shared__ float _gaussian_cov[SMSIZE*4]; 697 | uint32_t id_thread = threadIdx.x + threadIdx.y * blockDim.x; 698 | uint32_t blocksize = blockDim.x * blockDim.y; 699 | //draw: access all point with early stop 700 | float pixel_x = (id_x + 0.5 - w/2)/focal_x; 701 | float pixel_y = (id_y + 0.5 - h/2)/focal_y; 702 | float color[] = {0, 0, 0}; 703 | float accum = 1.0; 704 | float accum_weight = 0.0; 705 | 706 | // double current_prob = 0.0; 707 | // double _a, _b, _c, _d, _x, _y, det; 708 | T current_prob = 0.0; 709 | T _a, _b, _c, _d, _x, _y, det; 710 | 711 | float alpha, weight; 712 | 713 | // load to memory 714 | uint32_t n_loadings = DIV_ROUND_UP(end_idx - start_idx, SMSIZE); 715 | uint32_t global_idx; 716 | for(uint32_t i_loadings=0; i_loadings<n_loadings; ++i_loadings){ 717 | // load this chunk of gaussians into shared memory 718 | for(uint32_t i=id_thread; i<SMSIZE; i+=blocksize){ 719 | global_idx = i_loadings*SMSIZE + i; 720 | if(global_idx>=(end_idx - start_idx)){ 721 | break; 722 | } 723 | _gaussian_pos[i*2 + 0] = gaussian_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 0]; 724 | _gaussian_pos[i*2 + 1] = gaussian_pos[(start_idx + i_loadings*SMSIZE + i)*3 + 1]; 725 | // rgb 726 | _gaussian_rgb[i*3 + 0] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 0]; 727 | _gaussian_rgb[i*3 + 1] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 1]; 728 | _gaussian_rgb[i*3 + 2] = gaussian_rgb[(start_idx + i_loadings*SMSIZE + i)*3 + 2]; 729 | // opa 730 | _gaussian_opa[i] = gaussian_opa[start_idx + i_loadings*SMSIZE + i]; 731 | // cov 732 | _gaussian_cov[i*4 + 0] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 0]; 733 | _gaussian_cov[i*4 + 1] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 1]; 734 | _gaussian_cov[i*4 + 2] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 2]; 735 | _gaussian_cov[i*4 + 3] = gaussian_cov[(start_idx + i_loadings*SMSIZE + i)*4 + 3]; 736 | } 737 | __syncthreads(); 738 | 739 | for(uint32_t i=0; i<SMSIZE; ++i){ 740 | // composite front-to-back, stopping early once the 741 | // transmittance is (almost) spent 742 | 743 | if(i_loadings*SMSIZE+i>=(end_idx-start_idx)||accum < 0.0001){ 744 | break; 745 | } 746 | _a = _gaussian_cov[i*4+0]; 747 | _b = _gaussian_cov[i*4+1]; 748 | _c = _gaussian_cov[i*4+2]; 749 | _d = _gaussian_cov[i*4+3]; 750 | _x = pixel_x - _gaussian_pos[i*2 + 0]; 751 | _y = pixel_y - _gaussian_pos[i*2 + 1]; 752 | det = (_a * _d - _b * _c); 753 | 754 | current_prob = sigmoid ? 1.0/(2*3.1415926536) * rsqrtf(det+1e-7) : 1; 755 | if(fast){ 756 | current_prob *= __expf(-(_d * _x * _x - (_b + _c) * _x * _y + _a * _y * _y) / (2 * det+1e-14)); 757 | } 758 | else{ 759 | current_prob *= exp(-(_d * _x * _x - (_b + _c) * _x * _y + _a * _y * _y) / (2 * det+1e-14)); 760 | } 761 | 762 | alpha = current_prob * _gaussian_opa[i]; 763 | // printf("current_alpha: %f\n", alpha); 764 | // sigmoid + scale -> 0-1 765 | if(sigmoid){ 766 | alpha = 2./(exp(-alpha)+1) - 1; 767 | } 768 | weight = alpha * accum; 769 | // printf("a: %f, b: %f, c: %f, d: %f, x: %f, y: %f, det: %f, current_prob: %f, weight: %f\n", _a, _b, _c, _d, _x, _y, det, current_prob, weight); 770 | color[0] += _gaussian_rgb[i*3 + 0] * weight; 771 | color[1] += _gaussian_rgb[i*3 + 1] * weight; 772 | color[2] += _gaussian_rgb[i*3 + 2] * weight; 773 | accum_weight += weight; 774 | accum *= (1-alpha); 775 | } 776 | } 777 | 778 | if(accum_weight < 0.01 || !weight_normalize){ 779 | accum_weight = 1; 780 | } 781 | res[(id_x + id_y * w)*3 + 0] = color[0] / accum_weight; 782 | res[(id_x + id_y * w)*3 + 1] = color[1] / accum_weight; 783 | res[(id_x + id_y * w)*3 + 2] = color[2] / accum_weight; 784 | } 785 | 786 | // void draw(Gaussian3ds & tile_sorted_gaussians, torch::Tensor tile_n_point_accum, torch::Tensor res, float focal_x, float focal_y){ 787 | void draw( 788 | torch::Tensor gaussian_pos, 789 | torch::Tensor gaussian_rgb, 790 | torch::Tensor gaussian_opa, 791 | torch::Tensor gaussian_cov, 792 | torch::Tensor tile_n_point_accum, 793 | torch::Tensor res, 794 | float focal_x, 795 | float focal_y, 796 | bool weight_normalize, 797 | bool sigmoid, 798 | bool fast 799 | ){ 800 | uint32_t h = res.size(0); 801 | uint32_t w = res.size(1); 802 | uint32_t gridsize_x = DIV_ROUND_UP(w, 16); 803 | uint32_t gridsize_y = DIV_ROUND_UP(h, 16); 804 | dim3 gridsize(gridsize_x, gridsize_y, 1); 805 | dim3 blocksize(16, 16, 1); 806 | if(fast){ 807 | draw_kernel<1200, float><<<gridsize, blocksize>>>( 808 | gaussian_pos.data_ptr<float>(), 809 | gaussian_rgb.data_ptr<float>(), 810 | gaussian_opa.data_ptr<float>(), 811 | gaussian_cov.data_ptr<float>(), 812 | tile_n_point_accum.data_ptr<int>(), 813 | res.data_ptr<float>(), 814 | focal_x, 815 | focal_y, 816 | w, 817 | h, 818 | weight_normalize, 819 | sigmoid, 820 | fast 821 | ); 822 | } 823 | else{ 824 | draw_kernel<1200, double><<<gridsize, blocksize>>>( 825 | gaussian_pos.data_ptr<float>(), 826 | gaussian_rgb.data_ptr<float>(), 827 | gaussian_opa.data_ptr<float>(), 828 | gaussian_cov.data_ptr<float>(), 829 | tile_n_point_accum.data_ptr<int>(), 830 | res.data_ptr<float>(), 831 | focal_x, 832 | focal_y, 833 | w, 834 | h, 835 | weight_normalize, 836 | sigmoid, 837 | fast 838 | ); 839 | } 840 | } 841 | 842 | void draw_backward( 843 | torch::Tensor gaussian_pos, 844 | torch::Tensor gaussian_rgb, 845 | torch::Tensor gaussian_opa, 846 | torch::Tensor gaussian_cov, 847 | torch::Tensor tile_n_point_accum, 848 | torch::Tensor output, 849 | torch::Tensor grad_output, 850 | torch::Tensor grad_pos, 851 | torch::Tensor grad_rgb, 852 | torch::Tensor grad_opa, 853 | torch::Tensor grad_cov, 854 | float focal_x, 855 | float focal_y, 856 | bool weight_normalize, 857 | bool sigmoid, 858 | bool fast 859 | ){ 860 | uint32_t h = output.size(0); 861 | uint32_t w = output.size(1); 862 | uint32_t gridsize_x = DIV_ROUND_UP(w, 16); 863 | uint32_t gridsize_y = DIV_ROUND_UP(h, 16); 864 | dim3 gridsize(gridsize_x, gridsize_y, 1); 865 | dim3 blocksize(16, 16, 1); 866 | if(fast){ 867 | draw_backward_kernel<512, float><<<gridsize, blocksize>>>( 868 | gaussian_pos.data_ptr<float>(), 869 | gaussian_rgb.data_ptr<float>(), 870 | gaussian_opa.data_ptr<float>(), 871 | gaussian_cov.data_ptr<float>(), 872 |
tile_n_point_accum.data_ptr<int>(), 873 | output.data_ptr<float>(), 874 | grad_output.data_ptr<float>(), 875 | grad_pos.data_ptr<float>(), 876 | grad_rgb.data_ptr<float>(), 877 | grad_opa.data_ptr<float>(), 878 | grad_cov.data_ptr<float>(), 879 | focal_x, 880 | focal_y, 881 | w, 882 | h, 883 | weight_normalize, 884 | sigmoid, 885 | fast 886 | ); 887 | } 888 | else{ 889 | draw_backward_kernel<512, double><<<gridsize, blocksize>>>( 890 | gaussian_pos.data_ptr<float>(), 891 | gaussian_rgb.data_ptr<float>(), 892 | gaussian_opa.data_ptr<float>(), 893 | gaussian_cov.data_ptr<float>(), 894 | tile_n_point_accum.data_ptr<int>(), 895 | output.data_ptr<float>(), 896 | grad_output.data_ptr<float>(), 897 | grad_pos.data_ptr<float>(), 898 | grad_rgb.data_ptr<float>(), 899 | grad_opa.data_ptr<float>(), 900 | grad_cov.data_ptr<float>(), 901 | focal_x, 902 | focal_y, 903 | w, 904 | h, 905 | weight_normalize, 906 | sigmoid, 907 | fast 908 | ); 909 | } 910 | } 911 | 912 | __device__ void world_to_camera( 913 | const float* pos_w, 914 | const float* current_rot, 915 | const float* current_tran, 916 | float* pos_c 917 | ){ 918 | // pos_w 3 current_rot 3*3 919 | float _rot[9]; 920 | float _trans[3]; 921 | #pragma unroll 922 | for(int i=0;i<9;++i){ 923 | _rot[i] = current_rot[i]; 924 | } 925 | 926 | #pragma unroll 927 | for(int i=0;i<3;++i){ 928 | _trans[i] = current_tran[i]; 929 | } 930 | 931 | #pragma unroll 932 | for(int i=0;i<3;++i){ 933 | pos_c[i] = _rot[i*3+0] * pos_w[0] + _rot[i*3+1] * pos_w[1] + _rot[i*3+2] * pos_w[2] + _trans[i]; 934 | } 935 | } 936 | 937 | __device__ void calc_jacobian( 938 | const float* pos_camera_space, 939 | float* jacobian 940 | ){ 941 | float _res[9]; 942 | float u0 = pos_camera_space[0]; 943 | float u1 = pos_camera_space[1]; 944 | float u2 = pos_camera_space[2]; 945 | 946 | _res[0] = 1/u2; 947 | _res[1] = 0; 948 | _res[2] = -u0/(u2*u2); 949 | _res[3] = 0; 950 | _res[4] = 1/u2; 951 | _res[5] = -u1/(u2*u2); 952 | float _rsqr = rsqrtf(u0*u0+u1*u1+u2*u2); 953 | _res[6] = _rsqr * u0; 954 | _res[7] = _rsqr * u1; 955 | _res[8] = _rsqr * u2; 956 | 957 | #pragma unroll 958 | for(uint32_t i=0; i<9; ++i){ 959 | jacobian[i] = _res[i]; 960 | } 961 | } 962 | 963 | __global__ void global_culling_kernel( 964 | const float* pos, 965 | const float* quat, 966 | const float* scale, 967 | const float* current_rot, 968 | const float* current_tran, 969 | const uint32_t n_point, 970 | const float near, 971 | const float half_width, 972 | const float half_height, 973 | float* res_pos, 974 | float* res_cov, 975 | long* culling_mask 976 | // int* res_size 977 | ){ 978 | uint32_t pid = blockDim.x * blockIdx.x + threadIdx.x; 979 | if(pid >= n_point) return; 980 | pos += pid*3; 981 | 982 | // 1. calculate the camera space coordinate 983 | float pos_c[3]; 984 | world_to_camera(pos, current_rot, current_tran, pos_c); 985 | 986 | // printf("z: %f\n", pos_c[2]); 987 | 988 | // 2. check if the point is before the near plane 989 | if(pos_c[2] <= near){ 990 | // culling_mask[pid] = 0; 991 | return; 992 | } 993 | 994 | // 3. image space transform 995 | float pos_i[3]; 996 | pos_i[0] = pos_c[0] / pos_c[2]; 997 | pos_i[1] = pos_c[1] / pos_c[2]; 998 | pos_i[2] = sqrtf(pos_c[0]*pos_c[0] + pos_c[1]*pos_c[1] + pos_c[2]*pos_c[2]); 999 | 1000 | // 4. frustum culling 1001 | if(abs(pos_i[0]) >= half_width || abs(pos_i[1]) >= half_height){ 1002 | // culling_mask[pid] = 0; 1003 | return; 1004 | } 1005 | culling_mask[pid] = 1; 1006 | res_pos[pid*3 + 0] = pos_i[0]; 1007 | res_pos[pid*3 + 1] = pos_i[1]; 1008 | res_pos[pid*3 + 2] = pos_i[2]; 1009 | 1010 |
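Steps 1-4 above amount to a pinhole projection with near-plane and frustum tests, where the third channel of `res_pos` stores the Euclidean (radial) distance used later for depth sorting, not the raw z. A NumPy sketch under that reading:

```
import numpy as np

def project_and_cull(pos_w, R, t, near, half_width, half_height):
    """World -> camera -> image space for one point; None means culled."""
    pos_c = R @ pos_w + t                      # world_to_camera
    if pos_c[2] <= near:                       # behind the near plane
        return None
    x, y = pos_c[0] / pos_c[2], pos_c[1] / pos_c[2]
    if abs(x) >= half_width or abs(y) >= half_height:
        return None                            # outside the view frustum
    return np.array([x, y, np.linalg.norm(pos_c)])
```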
1011 | // 5. calculate the covariance matrix and jacobian 1012 | float w = quat[4*pid+0]; 1013 | float x = quat[4*pid+1]; 1014 | float y = quat[4*pid+2]; 1015 | float z = quat[4*pid+3]; 1016 | 1017 | float R[9]; 1018 | R[0] = 1 - 2*y*y - 2*z*z; 1019 | R[1] = 2*x*y - 2*z*w; 1020 | R[2] = 2*x*z + 2*y*w; 1021 | R[3] = 2*x*y + 2*z*w; 1022 | R[4] = 1 - 2*x*x - 2*z*z; 1023 | R[5] = 2*y*z - 2*x*w; 1024 | R[6] = 2*x*z - 2*y*w; 1025 | R[7] = 2*y*z + 2*x*w; 1026 | R[8] = 1 - 2*x*x - 2*y*y; 1027 | 1028 | float S[9]; 1029 | S[0] = scale[pid*3+0]; 1030 | S[1] = 0; 1031 | S[2] = 0; 1032 | S[3] = 0; 1033 | S[4] = scale[pid*3+1]; 1034 | S[5] = 0; 1035 | S[6] = 0; 1036 | S[7] = 0; 1037 | S[8] = scale[pid*3+2]; 1038 | 1039 | // RS 1040 | float RS[9]; 1041 | #pragma unroll 1042 | for(uint32_t i_r=0; i_r<3; ++i_r){ 1043 | #pragma unroll 1044 | for(uint32_t i_c=0; i_c<3; ++i_c){ 1045 | RS[i_r*3+i_c] = 0; 1046 | #pragma unroll 1047 | for(uint32_t i_k=0; i_k<3; ++i_k){ 1048 | RS[i_r*3+i_c] += R[i_r*3+i_k] * S[i_k*3+i_c]; 1049 | } 1050 | } 1051 | } 1052 | 1053 | float RSSR[9]; 1054 | #pragma unroll 1055 | for(uint32_t i_r=0; i_r<3; ++i_r){ 1056 | #pragma unroll 1057 | for(uint32_t i_c=0; i_c<3; ++i_c){ 1058 | RSSR[i_r*3+i_c] = 0; 1059 | #pragma unroll 1060 | for(uint32_t i_k=0; i_k<3; ++i_k){ 1061 | RSSR[i_r*3+i_c] += RS[i_r*3+i_k] * RS[i_c*3+i_k]; 1062 | } 1063 | } 1064 | } 1065 | 1066 | 1067 | 1068 | float jacobian[9]; 1069 | calc_jacobian(pos_c, jacobian); 1070 | 1071 | // the jacobian must be multiplied by the camera rotation, so the product has the form JW RSSR W'J' 1072 | 1073 | float JW[9]; 1074 | #pragma unroll 1075 | for(int i_r=0; i_r<3; ++i_r){ 1076 | #pragma unroll 1077 | for(int i_c=0; i_c<3; ++i_c){ 1078 | JW[i_r*3+i_c] = 0; 1079 | #pragma unroll 1080 | for(int i_k=0; i_k<3; ++i_k){ 1081 | JW[i_r*3+i_c] += jacobian[i_r*3+i_k] * current_rot[i_k*3+i_c]; 1082 | } 1083 | } 1084 | } 1085 | 1086 | float JWC[9]; 1087 | #pragma unroll 1088 | for(int i_r=0; i_r<3; ++i_r){ 1089 | #pragma unroll 1090 | for(int i_c=0; i_c<3; ++i_c){ 1091 | JWC[i_r*3+i_c] = 0; 1092 | #pragma unroll 1093 | for(int i_k=0; i_k<3; ++i_k){ 1094 | JWC[i_r*3+i_c] += JW[i_r*3+i_k] * RSSR[i_k*3+i_c]; 1095 | } 1096 | } 1097 | } 1098 | 1099 | float JWCWJ[9]; 1100 | #pragma unroll 1101 | for(int i_r=0; i_r<3; ++i_r){ 1102 | #pragma unroll 1103 | for(int i_c=0; i_c<3; ++i_c){ 1104 | JWCWJ[i_r*3+i_c] = 0; 1105 | #pragma unroll 1106 | for(int i_k=0; i_k<3; ++i_k){ 1107 | JWCWJ[i_r*3+i_c] += JWC[i_r*3+i_k] * JW[i_c*3+i_k]; 1108 | } 1109 | } 1110 | } 1111 |
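In matrix form this block is the standard EWA splatting approximation: with $R(q)$ and diagonal $S$ built from the quaternion and scale, $W$ the camera rotation (`current_rot`), and $J$ the projection Jacobian,

$$\Sigma_{3D} = R\,S\,S^{\top}R^{\top}, \qquad \Sigma_{2D} = J\,W\,\Sigma_{3D}\,W^{\top}J^{\top},$$

and only the top-left 2x2 block of $\Sigma_{2D}$ is kept, as the write-back below shows.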
1112 | // write the top-left 2x2 block of the image-space covariance back to res_cov 1113 | res_cov[pid*4+0] = JWCWJ[0]; 1114 | res_cov[pid*4+1] = JWCWJ[1]; 1115 | res_cov[pid*4+2] = JWCWJ[3]; 1116 | res_cov[pid*4+3] = JWCWJ[4]; 1117 | } 1118 | 1119 | void global_culling( 1120 | torch::Tensor pos, 1121 | torch::Tensor quat, 1122 | torch::Tensor scale, 1123 | torch::Tensor current_rot, 1124 | torch::Tensor current_tran, 1125 | torch::Tensor res_pos, 1126 | torch::Tensor res_cov, 1127 | torch::Tensor culling_mask, 1128 | float near, 1129 | float half_width, 1130 | float half_height 1131 | ){ 1132 | uint32_t n_point = pos.size(0); 1133 | uint32_t gridsize_x = DIV_ROUND_UP(n_point, 1024); 1134 | dim3 gridsize(gridsize_x, 1, 1); 1135 | dim3 blocksize(1024, 1, 1); 1136 | global_culling_kernel<<<gridsize, blocksize>>>( 1137 | pos.data_ptr<float>(), 1138 | quat.data_ptr<float>(), 1139 | scale.data_ptr<float>(), 1140 | current_rot.data_ptr<float>(), 1141 | current_tran.data_ptr<float>(), 1142 | n_point, 1143 | near, 1144 | half_width, 1145 | half_height, 1146 | res_pos.data_ptr<float>(), 1147 | res_cov.data_ptr<float>(), 1148 | culling_mask.data_ptr<long>() 1149 | ); 1150 | } 1151 | 1152 | __global__ void global_culling_backward_kernel( 1153 | const float* pos, 1154 | const float* quat, 1155 | const float* scale, 1156 | const float* current_rot, 1157 | const float* current_tran, 1158 | const uint32_t n_point, 1159 | const float* gradout_pos, 1160 | const float* gradout_cov, 1161 | const long* culling_mask, 1162 | float* gradinput_pos, 1163 | float* gradinput_quat, 1164 | float* gradinput_scale 1165 | ){ 1166 | uint32_t pid = blockDim.x * blockIdx.x + threadIdx.x; 1167 | if(pid >= n_point) return; 1168 | pos += pid*3; 1169 | 1170 | if(culling_mask[pid]==0){ 1171 | return; 1172 | } 1173 | 1174 | // forward pass pos 0,1,2 -> pos_c 0,1,2 -> pos_i 0,1,2 1175 | float pos_c[3]; 1176 | world_to_camera(pos, current_rot, current_tran, pos_c); 1177 | 1178 | float grad_c[3]; 1179 | float pos_i_z = sqrtf(pos_c[0]*pos_c[0] + pos_c[1]*pos_c[1] + pos_c[2]*pos_c[2]); 1180 | float grad_i[3]; 1181 | grad_i[0] = gradout_pos[pid*3+0]; 1182 | grad_i[1] = gradout_pos[pid*3+1]; 1183 | grad_i[2] = gradout_pos[pid*3+2]; 1184 | 1185 | grad_c[0] = grad_i[0] / pos_c[2] + grad_i[2] * pos_c[0] / pos_i_z; 1186 | grad_c[1] = grad_i[1] / pos_c[2] + grad_i[2] * pos_c[1] / pos_i_z; 1187 | grad_c[2] = - grad_i[0] * pos_c[0] / (pos_c[2] * pos_c[2]) - grad_i[1] * pos_c[1] / (pos_c[2] * pos_c[2]) + grad_i[2] * pos_c[2] / pos_i_z; 1188 | 1189 | float grad_w[3]; 1190 | 1191 | #pragma unroll 1192 | for(int i_r=0; i_r<3; ++i_r){ 1193 | grad_w[i_r] = 0; 1194 | #pragma unroll 1195 | for(int i_k=0; i_k<3; ++i_k){ 1196 | grad_w[i_r] += current_rot[i_k*3+i_r] * grad_c[i_k]; 1197 | } 1198 | } 1199 | // write back the pos gradient 1200 | gradinput_pos[pid*3+0] = grad_w[0]; 1201 | gradinput_pos[pid*3+1] = grad_w[1]; 1202 | gradinput_pos[pid*3+2] = grad_w[2]; 1203 | 1204 | float jacobian[9]; 1205 | calc_jacobian(pos_c, jacobian); 1206 | // the jacobian must be multiplied by the camera rotation, so the product has the form JW RSSR W'J' 1207 | 1208 | float JW[9]; 1209 | #pragma unroll 1210 | for(int i_r=0; i_r<3; ++i_r){ 1211 | #pragma unroll 1212 | for(int i_c=0; i_c<3; ++i_c){ 1213 | JW[i_r*3+i_c] = 0; 1214 | #pragma unroll 1215 | for(int i_k=0; i_k<3; ++i_k){ 1216 | JW[i_r*3+i_c] += jacobian[i_r*3+i_k] * current_rot[i_k*3+i_c]; 1217 | } 1218 | } 1219 | } 1220 | // calc grad_3d_cov 1221 | float grad_3d_cov[9]; 1222 | 1223 | // move to register 1224 | float grad_2d_cov[4]; 1225 | #pragma unroll 1226 | for(uint32_t i=0;i<4;++i){ 1227 | grad_2d_cov[i] = gradout_cov[pid*4+i]; 1228 | } 1229 |
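The loop that follows backpropagates through $\Sigma_{2D} = M\,\Sigma_{3D}\,M^{\top}$ with $M$ the top two rows of $JW$: for an upstream gradient $G$ on the 2x2 covariance,

$$\frac{\partial L}{\partial \Sigma_{3D}} = M^{\top} G\, M,$$

which is exactly what the `grad_2d_cov[i_i*2+i_j] * JW[i_i*3+i_r] * JW[i_j*3+i_c]` accumulation computes, entry by entry.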
1230 | #pragma unroll 1231 | for(uint32_t i_r=0; i_r<3; ++i_r){ 1232 | for(uint32_t i_c=0; i_c<3; ++i_c){ 1233 | grad_3d_cov[i_r*3+i_c] = 0; 1234 | #pragma unroll 1235 | for(uint32_t i_i=0; i_i<2; ++i_i){ 1236 | #pragma unroll 1237 | for(uint32_t i_j=0; i_j<2; ++i_j){ 1238 | grad_3d_cov[i_r*3+i_c] += grad_2d_cov[i_i*2+i_j] * JW[i_i*3+i_r] * JW[i_j*3+i_c]; 1239 | } 1240 | } 1241 | } 1242 | } 1243 | 1244 | // recompute R, S and RS exactly as in the forward pass 1245 | float w = quat[4*pid+0]; 1246 | float x = quat[4*pid+1]; 1247 | float y = quat[4*pid+2]; 1248 | float z = quat[4*pid+3]; 1249 | 1250 | float R[9]; 1251 | R[0] = 1 - 2*y*y - 2*z*z; 1252 | R[1] = 2*x*y - 2*z*w; 1253 | R[2] = 2*x*z + 2*y*w; 1254 | R[3] = 2*x*y + 2*z*w; 1255 | R[4] = 1 - 2*x*x - 2*z*z; 1256 | R[5] = 2*y*z - 2*x*w; 1257 | R[6] = 2*x*z - 2*y*w; 1258 | R[7] = 2*y*z + 2*x*w; 1259 | R[8] = 1 - 2*x*x - 2*y*y; 1260 | 1261 | float S[9]; 1262 | S[0] = scale[pid*3+0]; 1263 | S[1] = 0; 1264 | S[2] = 0; 1265 | S[3] = 0; 1266 | S[4] = scale[pid*3+1]; 1267 | S[5] = 0; 1268 | S[6] = 0; 1269 | S[7] = 0; 1270 | S[8] = scale[pid*3+2]; 1271 | 1272 | // RS 1273 | float RS[9]; 1274 | #pragma unroll 1275 | for(uint32_t i_r=0; i_r<3; ++i_r){ 1276 | #pragma unroll 1277 | for(uint32_t i_c=0; i_c<3; ++i_c){ 1278 | RS[i_r*3+i_c] = 0; 1279 | #pragma unroll 1280 | for(uint32_t i_k=0; i_k<3; ++i_k){ 1281 | RS[i_r*3+i_c] += R[i_r*3+i_k] * S[i_k*3+i_c]; 1282 | } 1283 | } 1284 | } 1285 | 1286 | // gradient w.r.t M=RS 1287 | float grad_RS[9]; 1288 | #pragma unroll 1289 | for(uint32_t i_r=0; i_r<3; ++i_r){ 1290 | #pragma unroll 1291 | for(uint32_t i_c=0; i_c<3; ++i_c){ 1292 | grad_RS[i_r*3+i_c] = 0; 1293 | #pragma unroll 1294 | for(uint32_t i_k=0; i_k<3; ++i_k){ 1295 | // grad_RS[i_r*3+i_c] += 2 * grad_3d_cov[i_r*3+i_k] * RS[i_c*3+i_k]; 1296 | // grad_RS[i_r*3+i_c] += (grad_3d_cov[i_k*3+i_r] + grad_3d_cov[i_r*3+i_k]) * RS[i_c*3+i_k]; 1297 | grad_RS[i_r*3+i_c] += (grad_3d_cov[i_k*3+i_r] + grad_3d_cov[i_r*3+i_k]) * RS[i_k*3+i_c]; 1298 | } 1299 | } 1300 | } 1301 | 1302 | //gradient w.r.t scale 1303 | float grad_scale[3]; 1304 | #pragma unroll 1305 | for(uint32_t i=0; i<3; ++i){ 1306 | grad_scale[i] = grad_RS[0*3+i]*R[0*3+i] + grad_RS[1*3+i]*R[1*3+i] + grad_RS[2*3+i]*R[2*3+i]; 1307 | } 1308 | 1309 | float sx = S[0]; 1310 | float sy = S[4]; 1311 | float sz = S[8]; 1312 | float qr = w; 1313 | float qi = x; 1314 | float qj = y; 1315 | float qk = z; 1316 | float gradcoeff_qr[9] = { 1317 | 0, -2*sy*qk, 2*sz*qj, 1318 | 2*sx*qk, 0, -2*sz*qi, 1319 | -2*sx*qj, 2*sy*qi, 0 1320 | }; 1321 | float gradcoeff_qi[9] = { 1322 | 0, 2*sy*qj, 2*sz*qk, 1323 | 2*sx*qj, -4*sy*qi, -2*sz*qr, 1324 | 2*sx*qk, 2*sy*qr, -4*sz*qi 1325 | }; 1326 | float gradcoeff_qj[9] = { 1327 | -4*sx*qj, 2*sy*qi, 2*sz*qr, 1328 | 2*sx*qi, 0, 2*sz*qk, 1329 | -2*sx*qr, 2*sy*qk, -4*sz*qj 1330 | }; 1331 | float gradcoeff_qk[9] = { 1332 | -4*sx*qk, -2*sy*qr, 2*sz*qi, 1333 | 2*sx*qr, -4*sy*qk, 2*sz*qj, 1334 | 2*sx*qi, 2*sy*qj, 0 1335 | }; 1336 | float grad_qr = 0; 1337 | float grad_qi = 0; 1338 | float grad_qj = 0; 1339 | float grad_qk = 0; 1340 | 1341 | #pragma unroll 1342 | for(uint32_t i=0; i<9; ++i){ 1343 | grad_qr += gradcoeff_qr[i] * grad_RS[i]; 1344 | grad_qi += gradcoeff_qi[i] * grad_RS[i]; 1345 | grad_qj += gradcoeff_qj[i] * grad_RS[i]; 1346 | grad_qk += gradcoeff_qk[i] * grad_RS[i]; 1347 | } 1348 | //write back to global memory 1349 | gradinput_quat[pid*4+0] = grad_qr; 1350 | gradinput_quat[pid*4+1] = grad_qi; 1351 | gradinput_quat[pid*4+2] = grad_qj; 1352 | gradinput_quat[pid*4+3] = grad_qk; 1353 | 1354 | gradinput_scale[pid*3+0] = grad_scale[0]; 1355 | gradinput_scale[pid*3+1] = grad_scale[1]; 1356 | gradinput_scale[pid*3+2] = grad_scale[2]; 1357 | } 1358 |
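For $\Sigma = MM^{\top}$ the identity $\partial L/\partial M = (G + G^{\top})M$ (the `grad_3d_cov[i_k*3+i_r] + grad_3d_cov[i_r*3+i_k]` line above), together with the hand-derived `gradcoeff_q*` tables for $\partial M/\partial q$, can be sanity-checked against autograd. A small PyTorch sketch under the kernel's `(w, x, y, z)` layout, assuming a normalized quaternion:

```
import torch

def quat_to_rot(q):
    w, x, y, z = q
    return torch.stack([
        torch.stack([1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w,     2*x*z + 2*y*w]),
        torch.stack([2*x*y + 2*z*w,     1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w]),
        torch.stack([2*x*z - 2*y*w,     2*y*z + 2*x*w,     1 - 2*x*x - 2*y*y]),
    ])

q = torch.randn(4)
q = (q / q.norm()).requires_grad_(True)
s = torch.rand(3, requires_grad=True)
G = torch.randn(3, 3)                 # stand-in upstream gradient on Sigma_3D

M = quat_to_rot(q) @ torch.diag(s)    # M = RS
(G * (M @ M.T)).sum().backward()      # L = <G, M M^T>
print(q.grad, s.grad)                 # should match the closed-form gradients
```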
1359 | void global_culling_backward( 1360 | torch::Tensor pos, 1361 | torch::Tensor quat, 1362 | torch::Tensor scale, 1363 | torch::Tensor current_rot, 1364 | torch::Tensor current_tran, 1365 | torch::Tensor gradout_pos, 1366 | torch::Tensor gradout_cov, 1367 | torch::Tensor culling_mask, 1368 | torch::Tensor gradinput_pos, 1369 | torch::Tensor gradinput_quat, 1370 | torch::Tensor gradinput_scale 1371 | ){ 1372 | uint32_t n_point = pos.size(0); 1373 | uint32_t gridsize_x = DIV_ROUND_UP(n_point, 1024); 1374 | dim3 gridsize(gridsize_x, 1, 1); 1375 | dim3 blocksize(1024, 1, 1); 1376 | global_culling_backward_kernel<<<gridsize, blocksize>>>( 1377 | pos.data_ptr<float>(), 1378 | quat.data_ptr<float>(), 1379 | scale.data_ptr<float>(), 1380 | current_rot.data_ptr<float>(), 1381 | current_tran.data_ptr<float>(), 1382 | n_point, 1383 | gradout_pos.data_ptr<float>(), 1384 | gradout_cov.data_ptr<float>(), 1385 | culling_mask.data_ptr<long>(), 1386 | gradinput_pos.data_ptr<float>(), 1387 | gradinput_quat.data_ptr<float>(), 1388 | gradinput_scale.data_ptr<float>() 1389 | ); 1390 | } -------------------------------------------------------------------------------- /src/include/common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef COMMON_H 3 | #define COMMON_H 4 | 5 | #include <torch/extension.h> 6 | #include <cuda.h> 7 | #include <cuda_runtime.h> 8 | #include <cstdint> 9 | #include <cmath> 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 13 | #define CHECK_CONTIGUOUS(x) \ 14 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 15 | #define CHECK_INPUT(x) \ 16 | CHECK_CUDA(x); \ 17 | CHECK_CONTIGUOUS(x) 18 | #define CHECK_CPU_INPUT(x) \ 19 | CHECK_CPU(x); \ 20 | CHECK_CONTIGUOUS(x) 21 | 22 | #if defined(__CUDACC__) 23 | // #define _EXP(x) expf(x) // SLOW EXP 24 | #define _EXP(x) __expf(x) // FAST EXP 25 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x)))) 26 | 27 | #else 28 | 29 | #define _EXP(x) expf(x) 30 | #define _SIGMOID(x) (1 / (1 + expf(-(x)))) 31 | #endif 32 | #define _SQR(x) ((x) * (x)) 33 | 34 | #define DIV_ROUND_UP(X, Y) (((X) + (Y) - 1) / (Y)) 35 | 36 | struct Tiles{ 37 | torch::Tensor top; 38 | torch::Tensor bottom; 39 | torch::Tensor left; 40 | torch::Tensor right; 41 | 42 | inline void check() { 43 | CHECK_INPUT(top); 44 | CHECK_INPUT(bottom); 45 | CHECK_INPUT(left); 46 | CHECK_INPUT(right); 47 | } 48 | 49 | uint32_t len(){ 50 | return top.size(0); 51 | } 52 | }; 53 | 54 | struct Gaussian3ds { 55 | torch::Tensor pos; 56 | torch::Tensor rgb; 57 | torch::Tensor opa; 58 | torch::Tensor quat; 59 | torch::Tensor scale; 60 | torch::Tensor cov; 61 | 62 | inline void check() { 63 | CHECK_INPUT(pos); 64 | CHECK_INPUT(rgb); 65 | CHECK_INPUT(opa); 66 | CHECK_INPUT(quat); 67 | CHECK_INPUT(scale); 68 | CHECK_INPUT(cov); 69 | } 70 | 71 | uint32_t len(){ 72 | return pos.size(0); 73 | } 74 | }; 75 | 76 | #endif
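These macros define the call contract for every binding in `gaussian.cu`: inputs must be CUDA, contiguous, and (per the kernel pointer types) float32 for gaussian data, int32 for tile bookkeeping, and int64 for the culling mask. A hedged sketch of allocating conforming inputs on the Python side — the shapes are inferred from the indexing above, and the extension's import name lives in `setup.py`/`bindings.cpp` (not shown), so no module is imported here:

```
import torch

N, dev = 10_000, "cuda"
pos  = torch.rand(N, 3, device=dev)                  # image-space x, y + radial depth
cov  = torch.rand(N, 4, device=dev)                  # flattened 2x2 covariance [a, b, c, d]
rgb  = torch.rand(N, 3, device=dev)
opa  = torch.rand(N, device=dev)
mask = torch.zeros(N, dtype=torch.long, device=dev)  # culling_mask, written by the kernel

# CHECK_INPUT == CHECK_CUDA + CHECK_CONTIGUOUS, so e.g. a transposed view is rejected:
assert pos.is_cuda and pos.is_contiguous()
```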
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import argparse 7 | from splatter import Splatter 8 | import cv2 9 | # from torchgeometry.losses import SSIM 10 | from torchmetrics.functional import peak_signal_noise_ratio as psnr_func 11 | from torchmetrics import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio 12 | from utils import Timer 13 | # from gui import NeRFGUI 14 | from visergui import ViserViewer 15 | 16 | class Trainer: 17 | def __init__(self, gaussian_splatter, opt): 18 | self.gaussian_splatter = gaussian_splatter 19 | self.opt = opt 20 | self.lr_opa = opt.lr * opt.lr_factor_for_opa 21 | self.lr_rgb = opt.lr * opt.lr_factor_for_rgb 22 | self.lr_pos = opt.lr * 1 23 | self.lr_quat = opt.lr * opt.lr_factor_for_quat 24 | self.lr_scale = opt.lr * opt.lr_factor_for_scale 25 | self.lrs = [self.lr_opa, self.lr_rgb, self.lr_pos, self.lr_scale, self.lr_quat] 26 | 27 | warmup_iters = opt.n_iters_warmup 28 | if self.opt.lr_decay == "official": 29 | _gamma = (0.01)**(1/(self.opt.n_iters-warmup_iters)) 30 | self.lr_lambdas = [ 31 | # lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 1, 32 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 33 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 1, 34 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 35 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 1, 36 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 1, 37 | ] 38 | elif self.opt.lr_decay == "none": 39 | self.lr_lambdas = [ 40 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000), 41 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000), 42 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000), 43 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000), 44 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000), 45 | ] 46 | else: 47 | assert self.opt.lr_decay == "exp" 48 | _gamma = (0.01)**(1/(self.opt.n_iters-warmup_iters)) 49 | self.lr_lambdas = [ 50 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 51 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 52 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 53 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 54 | lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else _gamma**(i_iter-warmup_iters), 55 | ] 56 | self.optimizer = torch.optim.Adam([ 57 | {"params": gaussian_splatter.gaussian_3ds.opa, "lr": self.lr_opa * self.lr_lambdas[0](0)}, 58 | {"params": gaussian_splatter.gaussian_3ds.rgb, "lr": self.lr_rgb * self.lr_lambdas[1](0)}, 59 | {"params": gaussian_splatter.gaussian_3ds.pos, "lr": self.lr_pos * self.lr_lambdas[2](0)}, 60 | {"params": gaussian_splatter.gaussian_3ds.scale, "lr": self.lr_scale * self.lr_lambdas[3](0)}, 61 | {"params": gaussian_splatter.gaussian_3ds.quat, "lr": self.lr_quat * self.lr_lambdas[4](0)}, 62 | ], 63 | betas=(0.9, 0.99), 64 | ) 65 |
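Each Adam parameter group above runs at `base_lr * lambda(i_iter)`: a linear warmup followed (in the default `exp` branch) by a geometric decay tuned to land at 1% of the base rate by `n_iters`. The factor in isolation:

```
def lr_lambda(i_iter, warmup_iters=300, n_iters=7001, final_ratio=0.01):
    """Warmup then exponential decay; mirrors the 'exp' branch above."""
    if i_iter <= warmup_iters:
        return i_iter / warmup_iters
    gamma = final_ratio ** (1.0 / (n_iters - warmup_iters))
    return gamma ** (i_iter - warmup_iters)
```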
66 | if not opt.test: 67 | self.n_cameras = len(gaussian_splatter.imgs) 68 | self.test_split = np.arange(0, self.n_cameras, 8) 69 | self.train_split = np.array(list(set(np.arange(0, self.n_cameras, 1)) - set(self.test_split))) 70 | 71 | # self.ssim_criterion = SSIM(window_size=11, reduction='mean') 72 | self.ssim_criterion = StructuralSimilarityIndexMeasure(data_range=1.0).to(gaussian_splatter.device) 73 | self.psnr_metrics = PeakSignalNoiseRatio().to(gaussian_splatter.device) 74 | self.l1_losses = np.zeros(opt.n_history_track) 75 | self.psnrs = np.zeros(opt.n_history_track) 76 | self.ssim_losses = np.zeros(opt.n_history_track) 77 | 78 | self.grad_counter = 0 79 | self.clear_grad() 80 | 81 | def clear_grad(self): 82 | self.accum_max_grad = torch.zeros_like(self.gaussian_splatter.gaussian_3ds.pos) 83 | self.grad_counter = 0 84 | 85 | def train_step(self, i_iter, bar): 86 | opt = self.opt 87 | _reset_opa = i_iter % (opt.n_opa_reset) == 0 and i_iter > 0 88 | _in_reset_interval = (i_iter >= opt.n_opa_reset) and (i_iter % opt.n_opa_reset < opt.reset_interval) 89 | _adaptive_control_only_delete = (i_iter > 600 and i_iter % opt.n_adaptive_control == 0) 90 | _adaptive_control = (i_iter > 600 and i_iter < opt.adaptive_control_end_iter and i_iter % opt.n_adaptive_control == 0) 91 | _adaptive_control_accum_start = i_iter > 600 and (i_iter + opt.grad_accum_iters - 1) % opt.n_adaptive_control == 0 92 | self.optimizer.zero_grad() 93 | 94 | # forward 95 | camera_id = np.random.choice(self.train_split, 1)[0] 96 | rendered_img = self.gaussian_splatter(camera_id) 97 | 98 | # loss 99 | l1_loss = ((rendered_img - self.gaussian_splatter.ground_truth).abs()).mean() 100 | if opt.ssim_weight > 0: 101 | ssim_loss = 1. - self.ssim_criterion( 102 | rendered_img.unsqueeze(0).permute(0, 3, 1, 2), 103 | self.gaussian_splatter.ground_truth.unsqueeze(0).permute(0, 3, 1, 2).to(rendered_img) 104 | ) 105 | else: 106 | ssim_loss = torch.Tensor([0.0,]).to(l1_loss.device) 107 | loss = (1-opt.ssim_weight)*l1_loss + opt.ssim_weight*ssim_loss 108 | if opt.scale_reg > 0: 109 | loss += opt.scale_reg * self.gaussian_splatter.gaussian_3ds.scale.abs().mean() 110 | if opt.opa_reg > 0: 111 | opa_sigmoid = self.gaussian_splatter.gaussian_3ds.opa.sigmoid() 112 | loss += opt.opa_reg * (opa_sigmoid * (1-opa_sigmoid)).mean() 113 | 114 | psnr = self.psnr_metrics(rendered_img, self.gaussian_splatter.ground_truth) 115 | 116 | # optimize 117 | with Timer("backward", debug=opt.debug): 118 | loss.backward() 119 | with Timer("step", debug=opt.debug): 120 | self.optimizer.step() 121 | 122 | # historical losses for smoothing 123 | self.l1_losses = np.roll(self.l1_losses, 1) 124 | self.psnrs = np.roll(self.psnrs, 1) 125 | self.ssim_losses = np.roll(self.ssim_losses, 1) 126 | self.l1_losses[0] = l1_loss.item() 127 | self.psnrs[0] = psnr.item() 128 | self.ssim_losses[0] = ssim_loss.item() 129 | 130 | avg_l1_loss = self.l1_losses[:min(i_iter+1, self.l1_losses.shape[0])].mean() 131 | avg_ssim_loss = self.ssim_losses[:min(i_iter+1, self.ssim_losses.shape[0])].mean() 132 | avg_psnr = self.psnrs[:min(i_iter+1, self.psnrs.shape[0])].mean() 133 | 134 | # grad info for debugging 135 | grad_info = [ 136 | self.gaussian_splatter.gaussian_3ds.opa.grad.abs().mean(), 137 | self.gaussian_splatter.gaussian_3ds.rgb.grad.abs().mean(), 138 | self.gaussian_splatter.gaussian_3ds.pos.grad.abs().mean(), 139 | self.gaussian_splatter.gaussian_3ds.scale.grad.abs().mean(), 140 | self.gaussian_splatter.gaussian_3ds.quat.grad.abs().mean(), 141 | ] 142 | 143 | # log 144 | 145 | if _adaptive_control_accum_start: 146 | self.clear_grad() 147 | # self.accum_max_grad = torch.max(self.gaussian_splatter.gaussian_3ds.pos.grad, self.accum_max_grad) 148 | if opt.grad_accum_method == "mean": 149 | self.accum_max_grad += self.gaussian_splatter.gaussian_3ds.pos.grad.abs() 150 | self.grad_counter += self.gaussian_splatter.culling_mask.to(torch.float32) 151 | else: 152 | assert opt.grad_accum_method 
== "max" 153 | self.accum_max_grad = torch.max(self.gaussian_splatter.gaussian_3ds.pos.grad.abs(), self.accum_max_grad) 154 | self.grad_counter = 1 155 | 156 | if _adaptive_control or _adaptive_control_only_delete: 157 | # adaptive control for gaussians 158 | # grad = self.gaussian_splatter.gaussian_3ds.pos.grad 159 | # adaptive_number = (self.accum_max_grad.abs().max(-1)[0] > 0.0002).sum() 160 | # adaptive_ratio = adaptive_number / grad[..., 0].numel() 161 | self.gaussian_splatter.gaussian_3ds.adaptive_control( 162 | self.accum_max_grad/(self.grad_counter+1e-3).unsqueeze(dim=-1), 163 | taus=opt.split_thresh, 164 | delete_thresh=opt.delete_thresh, 165 | scale_activation=gaussian_splatter.scale_activation, 166 | grad_thresh=opt.grad_thresh, 167 | use_clone=opt.use_clone if (_adaptive_control and (not _in_reset_interval)) else False, 168 | use_split=opt.use_split if (_adaptive_control and (not _in_reset_interval)) else False, 169 | grad_aggregation=opt.grad_aggregation, 170 | clone_dt=opt.clone_dt, 171 | ) 172 | # optimizer = torch.optim.Adam(gaussian_splatter.parameters(), lr=lr_lambda(0), betas=(0.9, 0.99)) 173 | self.optimizer = torch.optim.Adam([ 174 | {"params": self.gaussian_splatter.gaussian_3ds.opa, "lr": self.lr_opa * self.lr_lambdas[0](i_iter)}, 175 | {"params": self.gaussian_splatter.gaussian_3ds.rgb, "lr": self.lr_rgb * self.lr_lambdas[1](i_iter)}, 176 | {"params": self.gaussian_splatter.gaussian_3ds.pos, "lr": self.lr_pos * self.lr_lambdas[2](i_iter)}, 177 | {"params": self.gaussian_splatter.gaussian_3ds.scale, "lr": self.lr_scale * self.lr_lambdas[3](i_iter)}, 178 | {"params": self.gaussian_splatter.gaussian_3ds.quat, "lr": self.lr_quat * self.lr_lambdas[4](i_iter)}, 179 | ], 180 | betas=(0.9, 0.99), 181 | ) 182 | self.clear_grad() 183 | 184 | for i_opt, (param_group, lr) in enumerate(zip(self.optimizer.param_groups, self.lrs)): 185 | param_group['lr'] = self.lr_lambdas[i_opt](i_iter) * lr 186 | # if _in_reset_interval and i_opt == 0: 187 | # param_group["lr"] = lr 188 | 189 | if i_iter % (opt.n_opa_reset) == 0 and i_iter > 0: 190 | self.gaussian_splatter.gaussian_3ds.reset_opa() 191 | 192 | return { 193 | "image": rendered_img, 194 | "loss": (1-opt.ssim_weight) * avg_l1_loss + opt.ssim_weight * avg_ssim_loss, 195 | "avg_l1_loss": avg_l1_loss, 196 | "avg_ssim_loss": avg_ssim_loss, 197 | "avg_psnr": avg_psnr, 198 | "n_tile_gaussians": self.gaussian_splatter.n_tile_gaussians, 199 | "n_gaussians": self.gaussian_splatter.n_gaussians, 200 | "grad_info": grad_info, 201 | } 202 | 203 | def train(self): 204 | bar = tqdm(range(0, opt.n_iters)) 205 | for i_iter in bar: 206 | output = self.train_step(i_iter, bar) 207 | avg_l1_loss = output["avg_l1_loss"] 208 | avg_ssim_loss = output["avg_ssim_loss"] 209 | avg_psnr = output["avg_psnr"] 210 | n_tile_gaussians = output["n_tile_gaussians"] 211 | n_gaussians = output["n_gaussians"] 212 | grad_info = output["grad_info"] 213 | 214 | grad_desc = "[{:.6f}|{:.6f}|{:.6f}|{:.6f}|{:.6f}]".format(*grad_info) 215 | bar.set_description( 216 | desc=f"loss: {avg_l1_loss:.6f}/{avg_ssim_loss:.6f}/{avg_psnr:.4f}/[{n_tile_gaussians}/{n_gaussians}]:" + 217 | f"lr: {self.optimizer.param_groups[0]['lr']:.4f}|{self.optimizer.param_groups[1]['lr']:.4f}|{self.optimizer.param_groups[2]['lr']:.4f}|{self.optimizer.param_groups[3]['lr']:.4f}|{self.optimizer.param_groups[4]['lr']:.4f} " + 218 | f"grad: {grad_desc}" 219 | ) 220 | 221 | rendered_img = output["image"] 222 | # write img 223 | if i_iter % opt.n_save_train_img == 0: 224 | img_npy = 
rendered_img.clip(0,1).detach().cpu().numpy() 225 | dirpath = f"{opt.exp}/imgs/" 226 | os.makedirs(dirpath, exist_ok=True) 227 | cv2.imwrite(f"{opt.exp}/imgs/train_{i_iter}.png", (img_npy*255).astype(np.uint8)[...,::-1]) 228 | self.save_checkpoint() 229 | 230 | if i_iter % 100 == 0: 231 | Timer.show_recorder() 232 | 233 | if i_iter == 400: 234 | gaussian_splatter.switch_resolution(opt.render_downsample) 235 | 236 | if i_iter % (opt.n_iters_test) == 0: 237 | test_psnrs = [] 238 | test_ssims = [] 239 | elapsed = 0 240 | for test_camera_id in self.test_split: 241 | output = self.test(test_camera_id) 242 | elapsed += output["render_time"] 243 | test_psnrs.append(output["psnr"]) 244 | test_ssims.append(output["ssim"]) 245 | # save imgs 246 | dirpath = f"{opt.exp}/test_imgs/" 247 | os.makedirs(dirpath, exist_ok=True) 248 | img_npy = output["image"].clip(0,1).detach().cpu().numpy() 249 | cv2.imwrite(f"{opt.exp}/test_imgs/iter_{i_iter}_cid_{test_camera_id}.png", (img_npy*255).astype(np.uint8)[...,::-1]) 250 | print(test_psnrs) 251 | print(test_ssims) 252 | print("TEST SPLIT PSNR: {:.4f}".format(np.mean(test_psnrs))) 253 | print("TEST SPLIT SSIM: {:.4f}".format(np.mean(test_ssims))) 254 | print("RENDERING SPEED: {:.4f}".format(len(self.test_split)/elapsed)) 255 | 256 | @torch.no_grad() 257 | def test(self, camera_id, extrinsics=None, intrinsics=None): 258 | 259 | tic = torch.cuda.Event(enable_timing=True) 260 | toc = torch.cuda.Event(enable_timing=True) 261 | tic.record() 262 | self.gaussian_splatter.eval() 263 | rendered_img = self.gaussian_splatter(camera_id, extrinsics, intrinsics) 264 | toc.record() 265 | torch.cuda.synchronize() 266 | render_time = tic.elapsed_time(toc)/1000 267 | if camera_id is not None: 268 | psnr = self.psnr_metrics(rendered_img, self.gaussian_splatter.ground_truth).item() 269 | ssim = self.ssim_criterion( 270 | rendered_img.unsqueeze(0).permute(0, 3, 1, 2), 271 | self.gaussian_splatter.ground_truth.unsqueeze(0).permute(0, 3, 1, 2).to(rendered_img), 272 | ).item() 273 | self.gaussian_splatter.train() 274 | output = {"image": rendered_img} 275 | if camera_id is not None: 276 | output.update({ 277 | "psnr": psnr, 278 | "ssim": ssim, 279 | "render_time": render_time, 280 | }) 281 | return output 282 | 283 | def save_checkpoint(self): 284 | ckpt = { 285 | "pos": self.gaussian_splatter.gaussian_3ds.pos, 286 | "opa": self.gaussian_splatter.gaussian_3ds.opa, 287 | "rgb": self.gaussian_splatter.gaussian_3ds.rgb, 288 | "quat": self.gaussian_splatter.gaussian_3ds.quat, 289 | "scale": self.gaussian_splatter.gaussian_3ds.scale, 290 | } 291 | torch.save(ckpt, os.path.join(opt.exp, "ckpt.pth")) 292 | 293 | 294 | if __name__ == "__main__": 295 | # CUDA_VISIBLE_DEVICES=3 python train.py --exp garden_sh --grad_thresh 0.000004 --debug 1 --ssim_weight 0.1 --lr 0.002 --use_sh_coeff 0 --grad_accum_method mean --grad_accum_iters 300 // 25 296 | parser = argparse.ArgumentParser() 297 | parser.add_argument("--n_iters", type=int, default=7001) 298 | parser.add_argument("--n_iters_warmup", type=int, default=300) 299 | parser.add_argument("--n_iters_test", type=int, default=200) 300 | parser.add_argument("--n_history_track", type=int, default=100) 301 | parser.add_argument("--n_save_train_img", type=int, default=100) 302 | parser.add_argument("--n_adaptive_control", type=int, default=100) 303 | parser.add_argument("--render_downsample_start", type=int, default=4) 304 | parser.add_argument("--render_downsample", type=int, default=4) 305 | parser.add_argument("--jacobian_track", type=int, default=0) 
306 | parser.add_argument("--data", type=str, default="colmap_garden/") 307 | parser.add_argument("--scale_init_value", type=float, default=1) 308 | parser.add_argument("--opa_init_value", type=float, default=0.3) 309 | parser.add_argument("--tile_culling_dist_thresh", type=float, default=0.5) 310 | parser.add_argument("--tile_culling_prob_thresh", type=float, default=0.05) 311 | parser.add_argument("--tile_culling_method", type=str, default="prob2", choices=["dist", "prob", "prob2"]) 312 | 313 | # learning rate 314 | parser.add_argument("--lr", type=float, default=0.003) 315 | parser.add_argument("--lr_factor_for_scale", type=float, default=1) 316 | parser.add_argument("--lr_factor_for_rgb", type=float, default=10) 317 | parser.add_argument("--lr_factor_for_opa", type=float, default=10) 318 | parser.add_argument("--lr_factor_for_quat", type=float, default=1) 319 | parser.add_argument("--lr_decay", type=str, default="exp", choices=["none", "official", "exp"]) 320 | 321 | parser.add_argument("--delete_thresh", type=float, default=1.5) 322 | parser.add_argument("--n_opa_reset", type=int, default=10000000) 323 | parser.add_argument("--reset_interval", type=int, default=500) 324 | parser.add_argument("--split_thresh", type=float, default=0.05) 325 | parser.add_argument("--ssim_weight", type=float, default=0.1) 326 | parser.add_argument("--debug", type=int, default=0) 327 | parser.add_argument("--use_sh_coeff", type=int, default=0) 328 | parser.add_argument("--scale_reg", type=float, default=0) 329 | parser.add_argument("--opa_reg", type=float, default=0) 330 | parser.add_argument("--cudaculling", type=int, default=1) 331 | parser.add_argument("--adaptive_lr", type=int, default=0) 332 | parser.add_argument("--seed", type=int, default=2023) 333 | parser.add_argument("--ckpt", type=str, default="") 334 | parser.add_argument("--scale_activation", type=str, default="abs", choices=["abs", "exp"]) 335 | parser.add_argument("--fast_drawing", type=int, default=1) 336 | parser.add_argument("--exp", type=str, default="default") 337 | 338 | # adaptive control 339 | # parser.add_argument("--grad_accum_iters", type=int, default=20) 340 | parser.add_argument("--grad_accum_iters", type=int, default=50) 341 | parser.add_argument("--grad_accum_method", type=str, default="max", choices=["mean", "max"]) 342 | parser.add_argument("--grad_thresh", type=float, default=0.0002) 343 | parser.add_argument("--use_clone", type=int, default=0) 344 | parser.add_argument("--use_split", type=int, default=1) 345 | parser.add_argument("--clone_dt", type=float, default=0.01) 346 | parser.add_argument("--grad_aggregation", type=str, default="max", choices=["max", "mean"]) 347 | parser.add_argument("--adaptive_control_end_iter", type=int, default=1000000000) 348 | 349 | # GUI related 350 | parser.add_argument("--gui", default=0, type=int) 351 | parser.add_argument("--test", default=0, type=int) 352 | parser.add_argument("--H", default=768, type=int) 353 | parser.add_argument("--W", default=1024, type=int) 354 | parser.add_argument("--radius", default=5.0, type=float) 355 | parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy") 356 | parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") 357 | #parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 358 | parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 359 | parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") 360 | parser.add_argument('--bound', type=float, default=10, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 361 | 362 | 363 | opt = parser.parse_args() 364 | np.random.seed(opt.seed) 365 | if opt.jacobian_track: 366 | jacobian_calc="torch" 367 | else: 368 | jacobian_calc="cuda" 369 | data_path = os.path.join(opt.data, 'sparse', '0') 370 | img_path = os.path.join(opt.data, f'images_{opt.render_downsample_start}') 371 | 372 | if opt.ckpt == "": 373 | opt.ckpt = None 374 | gaussian_splatter = Splatter( 375 | data_path, 376 | img_path, 377 | render_weight_normalize=False, 378 | render_downsample=opt.render_downsample, 379 | use_sh_coeff=opt.use_sh_coeff, 380 | scale_init_value=opt.scale_init_value, 381 | opa_init_value=opt.opa_init_value, 382 | tile_culling_method=opt.tile_culling_method, 383 | tile_culling_dist_thresh=opt.tile_culling_dist_thresh, 384 | tile_culling_prob_thresh=opt.tile_culling_prob_thresh, 385 | debug=opt.debug, 386 | scale_activation=opt.scale_activation, 387 | cudaculling=opt.cudaculling, 388 | load_ckpt=opt.ckpt, 389 | fast_drawing=opt.fast_drawing, 390 | test=opt.test, 391 | #jacobian_calc="torch", 392 | ) 393 | trainer = Trainer(gaussian_splatter, opt) 394 | if opt.gui: 395 | assert opt.test == 1 396 | # gui = NeRFGUI(opt, trainer) 397 | # gui.render() 398 | gui = ViserViewer(device=gaussian_splatter.device, viewer_port=6789) 399 | gui.set_renderer(trainer) 400 | while(True): 401 | gui.update() 402 | else: 403 | trainer.train() 404 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """Lie group interface for rigid transforms, ported from 2 | [jaxlie](https://github.com/brentyi/jaxlie). Used by `viser` internally and 3 | in examples. 4 | 5 | Implements SO(2), SO(3), SE(2), and SE(3) Lie groups. Rotations are parameterized 6 | via S^1 and S^3. 7 | """ 8 | 9 | from ._base import MatrixLieGroup as MatrixLieGroup 10 | from ._base import SEBase as SEBase 11 | from ._base import SOBase as SOBase 12 | from ._se2 import SE2 as SE2 13 | from ._se3 import SE3 as SE3 14 | from ._so2 import SO2 as SO2 15 | from ._so3 import SO3 as SO3 16 | -------------------------------------------------------------------------------- /transforms/_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import ClassVar, Generic, Type, TypeVar, Union, overload 3 | 4 | import numpy as onp 5 | import numpy.typing as onpt 6 | from typing_extensions import final, override 7 | 8 | from . import hints 9 | 10 | GroupType = TypeVar("GroupType", bound="MatrixLieGroup") 11 | SEGroupType = TypeVar("SEGroupType", bound="SEBase") 12 | 13 | 14 | class MatrixLieGroup(abc.ABC): 15 | """Interface definition for matrix Lie groups.""" 16 | 17 | # Class properties. 18 | # > These will be set in `_utils.register_lie_group()`. 
19 | 20 | matrix_dim: ClassVar[int] 21 | """Dimension of square matrix output from `.as_matrix()`.""" 22 | 23 | parameters_dim: ClassVar[int] 24 | """Dimension of underlying parameters, `.parameters()`.""" 25 | 26 | tangent_dim: ClassVar[int] 27 | """Dimension of tangent space.""" 28 | 29 | space_dim: ClassVar[int] 30 | """Dimension of coordinates that can be transformed.""" 31 | 32 | def __init__( 33 | # Notes: 34 | # - For the constructor signature to be consistent with subclasses, `parameters` 35 | # should be marked as positional-only. But this isn't possible in Python 3.7. 36 | # - This method is implicitly overridden by the dataclass decorator and 37 | # should _not_ be marked abstract. 38 | self, 39 | parameters: onpt.NDArray[onp.floating], 40 | ): 41 | """Construct a group object from its underlying parameters.""" 42 | raise NotImplementedError() 43 | 44 | # Shared implementations. 45 | 46 | @overload 47 | def __matmul__(self: GroupType, other: GroupType) -> GroupType: 48 | ... 49 | 50 | @overload 51 | def __matmul__(self, other: hints.Array) -> onpt.NDArray[onp.floating]: 52 | ... 53 | 54 | def __matmul__( 55 | self: GroupType, other: Union[GroupType, hints.Array] 56 | ) -> Union[GroupType, onpt.NDArray[onp.floating]]: 57 | """Overload for the `@` operator. 58 | 59 | Switches between the group action (`.apply()`) and multiplication 60 | (`.multiply()`) based on the type of `other`. 61 | """ 62 | if isinstance(other, onp.ndarray): 63 | return self.apply(target=other) 64 | elif isinstance(other, MatrixLieGroup): 65 | assert self.space_dim == other.space_dim 66 | return self.multiply(other=other) 67 | else: 68 | assert False, f"Invalid argument type for `@` operator: {type(other)}" 69 | 70 | # Factory. 71 | 72 | @classmethod 73 | @abc.abstractmethod 74 | def identity(cls: Type[GroupType]) -> GroupType: 75 | """Returns identity element. 76 | 77 | Returns: 78 | Identity element. 79 | """ 80 | 81 | @classmethod 82 | @abc.abstractmethod 83 | def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType: 84 | """Get group member from matrix representation. 85 | 86 | Args: 87 | matrix: Matrix representation. 88 | 89 | Returns: 90 | Group member. 91 | """ 92 | 93 | # Accessors. 94 | 95 | @abc.abstractmethod 96 | def as_matrix(self) -> onpt.NDArray[onp.floating]: 97 | """Get transformation as a matrix. Homogeneous for SE groups.""" 98 | 99 | @abc.abstractmethod 100 | def parameters(self) -> onpt.NDArray[onp.floating]: 101 | """Get underlying representation.""" 102 | 103 | # Operations. 104 | 105 | @abc.abstractmethod 106 | def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: 107 | """Applies group action to a point. 108 | 109 | Args: 110 | target: Point to transform. 111 | 112 | Returns: 113 | Transformed point. 114 | """ 115 | 116 | @abc.abstractmethod 117 | def multiply(self: GroupType, other: GroupType) -> GroupType: 118 | """Composes this transformation with another. 119 | 120 | Returns: 121 | self @ other 122 | """ 123 | 124 | @classmethod 125 | @abc.abstractmethod 126 | def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType: 127 | """Computes `expm(wedge(tangent))`. 128 | 129 | Args: 130 | tangent: Tangent vector to take the exponential of. 131 | 132 | Returns: 133 | Output. 134 | """ 135 | 136 | @abc.abstractmethod 137 | def log(self) -> onpt.NDArray[onp.floating]: 138 | """Computes `vee(logm(transformation matrix))`. 139 | 140 | Returns: 141 | Output. Shape should be `(tangent_dim,)`. 
142 | """ 143 | 144 | @abc.abstractmethod 145 | def adjoint(self) -> onpt.NDArray[onp.floating]: 146 | """Computes the adjoint, which transforms tangent vectors between tangent 147 | spaces. 148 | 149 | More precisely, for a transform `GroupType`: 150 | ``` 151 | GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType 152 | ``` 153 | 154 | In robotics, typically used for transforming twists, wrenches, and Jacobians 155 | across different reference frames. 156 | 157 | Returns: 158 | Output. Shape should be `(tangent_dim, tangent_dim)`. 159 | """ 160 | 161 | @abc.abstractmethod 162 | def inverse(self: GroupType) -> GroupType: 163 | """Computes the inverse of our transform. 164 | 165 | Returns: 166 | Output. 167 | """ 168 | 169 | @abc.abstractmethod 170 | def normalize(self: GroupType) -> GroupType: 171 | """Normalize/projects values and returns. 172 | 173 | Returns: 174 | GroupType: Normalized group member. 175 | """ 176 | 177 | # @classmethod 178 | # @abc.abstractmethod 179 | # def sample_uniform(cls: Type[GroupType], key: hints.KeyArray) -> GroupType: 180 | # """Draw a uniform sample from the group. Translations (if applicable) are in the 181 | # range [-1, 1]. 182 | # 183 | # Args: 184 | # key: PRNG key, as returned by `jax.random.PRNGKey()`. 185 | # 186 | # Returns: 187 | # Sampled group member. 188 | # """ 189 | 190 | 191 | class SOBase(MatrixLieGroup): 192 | """Base class for special orthogonal groups.""" 193 | 194 | 195 | ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) 196 | 197 | 198 | class SEBase(Generic[ContainedSOType], MatrixLieGroup): 199 | """Base class for special Euclidean groups. 200 | 201 | Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional 202 | translation vector. 203 | """ 204 | 205 | # SE-specific interface. 206 | 207 | @classmethod 208 | @abc.abstractmethod 209 | def from_rotation_and_translation( 210 | cls: Type[SEGroupType], 211 | rotation: ContainedSOType, 212 | translation: hints.Array, 213 | ) -> SEGroupType: 214 | """Construct a rigid transform from a rotation and a translation. 215 | 216 | Args: 217 | rotation: Rotation term. 218 | translation: translation term. 219 | 220 | Returns: 221 | Constructed transformation. 222 | """ 223 | 224 | @final 225 | @classmethod 226 | def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: 227 | return cls.from_rotation_and_translation( 228 | rotation=rotation, 229 | translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype), 230 | ) 231 | 232 | @classmethod 233 | @abc.abstractmethod 234 | def from_translation( 235 | cls: Type[SEGroupType], translation: onpt.NDArray[onp.floating] 236 | ) -> SEGroupType: 237 | """Construct a transform from a translation term.""" 238 | 239 | @abc.abstractmethod 240 | def rotation(self) -> ContainedSOType: 241 | """Returns a transform's rotation term.""" 242 | 243 | @abc.abstractmethod 244 | def translation(self) -> onpt.NDArray[onp.floating]: 245 | """Returns a transform's translation term.""" 246 | 247 | # Overrides. 
248 | 249 | @final 250 | @override 251 | def apply(self, target: hints.Array) -> onpt.NDArray[onp.floating]: 252 | return self.rotation() @ target + self.translation() # type: ignore 253 | 254 | @final 255 | @override 256 | def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: 257 | return type(self).from_rotation_and_translation( 258 | rotation=self.rotation() @ other.rotation(), 259 | translation=(self.rotation() @ other.translation()) + self.translation(), 260 | ) 261 | 262 | @final 263 | @override 264 | def inverse(self: SEGroupType) -> SEGroupType: 265 | R_inv = self.rotation().inverse() 266 | return type(self).from_rotation_and_translation( 267 | rotation=R_inv, 268 | translation=-(R_inv @ self.translation()), 269 | ) 270 | 271 | @final 272 | @override 273 | def normalize(self: SEGroupType) -> SEGroupType: 274 | return type(self).from_rotation_and_translation( 275 | rotation=self.rotation().normalize(), 276 | translation=self.translation(), 277 | ) 278 | -------------------------------------------------------------------------------- /transforms/_se2.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import cast 3 | 4 | import numpy as onp 5 | import numpy.typing as onpt 6 | from typing_extensions import override 7 | 8 | from . import _base, hints 9 | from ._so2 import SO2 10 | from .utils import get_epsilon, register_lie_group 11 | 12 | 13 | @register_lie_group( 14 | matrix_dim=3, 15 | parameters_dim=4, 16 | tangent_dim=3, 17 | space_dim=2, 18 | ) 19 | @dataclasses.dataclass 20 | class SE2(_base.SEBase[SO2]): 21 | """Special Euclidean group for proper rigid transforms in 2D. 22 | 23 | Ported to numpy from `jaxlie.SE2`. 24 | 25 | Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx, 26 | vy, omega)`. 27 | """ 28 | 29 | # SE2-specific. 30 | 31 | unit_complex_xy: onpt.NDArray[onp.floating] 32 | """Internal parameters. `(cos, sin, x, y)`.""" 33 | 34 | @override 35 | def __repr__(self) -> str: 36 | unit_complex = onp.round(self.unit_complex_xy[..., :2], 5) 37 | xy = onp.round(self.unit_complex_xy[..., 2:], 5) 38 | return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" 39 | 40 | @staticmethod 41 | def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2": 42 | """Construct a transformation from standard 2D pose parameters. 43 | 44 | Note that this is not the same as integrating over a length-3 twist. 45 | """ 46 | cos = onp.cos(theta) 47 | sin = onp.sin(theta) 48 | return SE2(unit_complex_xy=onp.array([cos, sin, x, y])) 49 | 50 | # SE-specific. 51 | 52 | @staticmethod 53 | @override 54 | def from_rotation_and_translation( 55 | rotation: SO2, 56 | translation: hints.Array, 57 | ) -> "SE2": 58 | assert translation.shape == (2,) 59 | return SE2( 60 | unit_complex_xy=onp.concatenate([rotation.unit_complex, translation]) 61 | ) 62 | 63 | @override 64 | @classmethod 65 | def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> "SE2": 66 | return SE2.from_rotation_and_translation(SO2.identity(), translation) 67 | 68 | @override 69 | def rotation(self) -> SO2: 70 | return SO2(unit_complex=self.unit_complex_xy[..., :2]) 71 | 72 | @override 73 | def translation(self) -> onpt.NDArray[onp.floating]: 74 | return self.unit_complex_xy[..., 2:] 75 | 76 | # Factory. 
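As the `from_xy_theta` docstring above warns, setting `(x, y, theta)` directly is not the same as calling `SE2.exp` on a length-3 twist: `exp` integrates the rotation along the path, so the translations differ whenever `theta != 0`. A quick illustration, under the same import assumption as the sketch above:

```
import numpy as onp
from transforms import SE2  # assumption: re-exported by the package

pose = SE2.from_xy_theta(x=1.0, y=0.0, theta=onp.pi / 2)
twist = SE2.exp(onp.array([1.0, 0.0, onp.pi / 2]))

print(pose.translation())   # [1. 0.]
print(twist.translation())  # ~[0.63662 0.63662]: the twist curves the path
```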
77 | 78 | @staticmethod 79 | @override 80 | def identity() -> "SE2": 81 | return SE2(unit_complex_xy=onp.array([1.0, 0.0, 0.0, 0.0])) 82 | 83 | @staticmethod 84 | @override 85 | def from_matrix(matrix: hints.Array) -> "SE2": 86 | assert matrix.shape == (3, 3) 87 | # Currently assumes bottom row is [0, 0, 1]. 88 | return SE2.from_rotation_and_translation( 89 | rotation=SO2.from_matrix(matrix[:2, :2]), 90 | translation=matrix[:2, 2], 91 | ) 92 | 93 | # Accessors. 94 | 95 | @override 96 | def parameters(self) -> onpt.NDArray[onp.floating]: 97 | return self.unit_complex_xy 98 | 99 | @override 100 | def as_matrix(self) -> onpt.NDArray[onp.floating]: 101 | cos, sin, x, y = self.unit_complex_xy 102 | return onp.array( 103 | [ 104 | [cos, -sin, x], 105 | [sin, cos, y], 106 | [0.0, 0.0, 1.0], 107 | ] 108 | ) 109 | 110 | # Operations. 111 | 112 | @staticmethod 113 | @override 114 | def exp(tangent: hints.Array) -> "SE2": 115 | # Reference: 116 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 117 | # Also see: 118 | # > http://ethaneade.com/lie.pdf 119 | 120 | assert tangent.shape == (3,) 121 | 122 | theta = tangent[2] 123 | use_taylor = onp.abs(theta) < get_epsilon(tangent.dtype) 124 | 125 | # Shim to avoid NaNs in onp.where branches, which cause failures for 126 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 127 | safe_theta = cast( 128 | onpt.NDArray[onp.floating], 129 | onp.where( 130 | use_taylor, 131 | 1.0, # Any non-zero value should do here. 132 | theta, 133 | ), 134 | ) 135 | 136 | theta_sq = theta**2 137 | sin_over_theta = cast( 138 | onpt.NDArray[onp.floating], 139 | onp.where( 140 | use_taylor, 141 | 1.0 - theta_sq / 6.0, 142 | onp.sin(safe_theta) / safe_theta, 143 | ), 144 | ) 145 | one_minus_cos_over_theta = cast( 146 | onpt.NDArray[onp.floating], 147 | onp.where( 148 | use_taylor, 149 | 0.5 * theta - theta * theta_sq / 24.0, 150 | (1.0 - onp.cos(safe_theta)) / safe_theta, 151 | ), 152 | ) 153 | 154 | V = onp.array( 155 | [ 156 | [sin_over_theta, -one_minus_cos_over_theta], 157 | [one_minus_cos_over_theta, sin_over_theta], 158 | ] 159 | ) 160 | return SE2.from_rotation_and_translation( 161 | rotation=SO2.from_radians(theta), 162 | translation=V @ tangent[:2], 163 | ) 164 | 165 | @override 166 | def log(self) -> onpt.NDArray[onp.floating]: 167 | # Reference: 168 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160 169 | # Also see: 170 | # > http://ethaneade.com/lie.pdf 171 | 172 | theta = self.rotation().log()[0] 173 | 174 | cos = onp.cos(theta) 175 | cos_minus_one = cos - 1.0 176 | half_theta = theta / 2.0 177 | use_taylor = onp.abs(cos_minus_one) < get_epsilon(theta.dtype) 178 | 179 | # Shim to avoid NaNs in onp.where branches, which cause failures for 180 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 181 | safe_cos_minus_one = onp.where( 182 | use_taylor, 183 | 1.0, # Any non-zero value should do here. 184 | cos_minus_one, 185 | ) 186 | 187 | half_theta_over_tan_half_theta = onp.where( 188 | use_taylor, 189 | # Taylor approximation. 190 | 1.0 - theta**2 / 12.0, 191 | # Default. 
192 | -(half_theta * onp.sin(theta)) / safe_cos_minus_one, 193 | ) 194 | 195 | V_inv = onp.array( 196 | [ 197 | [half_theta_over_tan_half_theta, half_theta], 198 | [-half_theta, half_theta_over_tan_half_theta], 199 | ] 200 | ) 201 | 202 | tangent = onp.concatenate([V_inv @ self.translation(), theta[None]]) 203 | return tangent 204 | 205 | @override 206 | def adjoint(self: "SE2") -> onpt.NDArray[onp.floating]: 207 | cos, sin, x, y = self.unit_complex_xy 208 | return onp.array( 209 | [ 210 | [cos, -sin, y], 211 | [sin, cos, -x], 212 | [0.0, 0.0, 1.0], 213 | ] 214 | ) 215 | 216 | # @staticmethod 217 | # @override 218 | # def sample_uniform(key: hints.KeyArray) -> "SE2": 219 | # key0, key1 = jax.random.split(key) 220 | # return SE2.from_rotation_and_translation( 221 | # rotation=SO2.sample_uniform(key0), 222 | # translation=jax.random.uniform( 223 | # key=key1, shape=(2,), minval=-1.0, maxval=1.0 224 | # ), 225 | # ) 226 | -------------------------------------------------------------------------------- /transforms/_se3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | from typing import cast 5 | 6 | import numpy as onp 7 | import numpy.typing as onpt 8 | from typing_extensions import override 9 | 10 | from . import _base 11 | from ._so3 import SO3 12 | from .utils import get_epsilon, register_lie_group 13 | 14 | 15 | def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: 16 | """Returns the skew-symmetric form of a length-3 vector.""" 17 | 18 | wx, wy, wz = omega 19 | return onp.array( 20 | [ # type: ignore 21 | [0.0, -wz, wy], 22 | [wz, 0.0, -wx], 23 | [-wy, wx, 0.0], 24 | ] 25 | ) 26 | 27 | 28 | @register_lie_group( 29 | matrix_dim=4, 30 | parameters_dim=7, 31 | tangent_dim=6, 32 | space_dim=3, 33 | ) 34 | @dataclasses.dataclass 35 | class SE3(_base.SEBase[SO3]): 36 | """Special Euclidean group for proper rigid transforms in 3D. 37 | 38 | Ported to numpy from `jaxlie.SE3`. 39 | 40 | Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization 41 | is `(vx, vy, vz, omega_x, omega_y, omega_z)`. 42 | """ 43 | 44 | # SE3-specific. 45 | 46 | wxyz_xyz: onpt.NDArray[onp.floating] 47 | """Internal parameters. wxyz quaternion followed by xyz translation.""" 48 | 49 | @override 50 | def __repr__(self) -> str: 51 | quat = onp.round(self.wxyz_xyz[..., :4], 5) 52 | trans = onp.round(self.wxyz_xyz[..., 4:], 5) 53 | return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" 54 | 55 | # SE-specific. 56 | 57 | @staticmethod 58 | @override 59 | def from_rotation_and_translation( 60 | rotation: SO3, 61 | translation: onpt.NDArray[onp.floating], 62 | ) -> SE3: 63 | assert translation.shape == (3,) 64 | return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation])) 65 | 66 | @override 67 | @classmethod 68 | def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> "SE3": 69 | return SE3.from_rotation_and_translation(SO3.identity(), translation) 70 | 71 | @override 72 | def rotation(self) -> SO3: 73 | return SO3(wxyz=self.wxyz_xyz[..., :4]) 74 | 75 | @override 76 | def translation(self) -> onpt.NDArray[onp.floating]: 77 | return self.wxyz_xyz[..., 4:] 78 | 79 | # Factory. 
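The `exp`/`log` pairs in this package (SE2 above; SE3 and SO3 below follow the same pattern) guard their small-angle branches with `get_epsilon` plus an `onp.where` shim, so roundtrips stay well-behaved even for near-zero tangents. A sketch of the invariant worth testing, same import assumption:

```
import numpy as onp
from transforms import SE3  # assumption: re-exported by the package

for tangent in [
    onp.array([0.1, -0.2, 0.3, 0.4, 0.5, -0.6]),     # generic twist
    onp.array([0.1, -0.2, 0.3, 1e-9, -1e-9, 1e-9]),  # near-zero rotation: Taylor branch
]:
    assert onp.allclose(SE3.exp(tangent).log(), tangent)
```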
80 | 81 | @staticmethod 82 | @override 83 | def identity() -> SE3: 84 | return SE3(wxyz_xyz=onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])) 85 | 86 | @staticmethod 87 | @override 88 | def from_matrix(matrix: onpt.NDArray[onp.floating]) -> SE3: 89 | assert matrix.shape == (4, 4) 90 | # Currently assumes bottom row is [0, 0, 0, 1]. 91 | return SE3.from_rotation_and_translation( 92 | rotation=SO3.from_matrix(matrix[:3, :3]), 93 | translation=matrix[:3, 3], 94 | ) 95 | 96 | # Accessors. 97 | 98 | @override 99 | def as_matrix(self) -> onpt.NDArray[onp.floating]: 100 | out = onp.eye(4) 101 | out[:3, :3] = self.rotation().as_matrix() 102 | out[:3, 3] = self.translation() 103 | return out 104 | # return ( 105 | # onp.eye(4) 106 | # .at[:3, :3] 107 | # .set(self.rotation().as_matrix()) 108 | # .at[:3, 3] 109 | # .set(self.translation()) 110 | # ) 111 | 112 | @override 113 | def parameters(self) -> onpt.NDArray[onp.floating]: 114 | return self.wxyz_xyz 115 | 116 | # Operations. 117 | 118 | @staticmethod 119 | @override 120 | def exp(tangent: onpt.NDArray[onp.floating]) -> SE3: 121 | # Reference: 122 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 123 | 124 | # (x, y, z, omega_x, omega_y, omega_z) 125 | assert tangent.shape == (6,) 126 | 127 | rotation = SO3.exp(tangent[3:]) 128 | 129 | theta_squared = tangent[3:] @ tangent[3:] 130 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 131 | 132 | # Shim to avoid NaNs in onp.where branches, which cause failures for 133 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 134 | theta_squared_safe = cast( 135 | onpt.NDArray[onp.floating], 136 | onp.where( 137 | use_taylor, 138 | 1.0, # Any non-zero value should do here. 139 | theta_squared, 140 | ), 141 | ) 142 | del theta_squared 143 | theta_safe = onp.sqrt(theta_squared_safe) 144 | 145 | skew_omega = _skew(tangent[3:]) 146 | V = onp.where( 147 | use_taylor, 148 | rotation.as_matrix(), 149 | ( 150 | onp.eye(3) 151 | + (1.0 - onp.cos(theta_safe)) / (theta_squared_safe) * skew_omega 152 | # We can drop this type: ignore after upgrading numpy / dropping Python 153 | # 3.7 support. 154 | + (theta_safe - onp.sin(theta_safe)) # type: ignore 155 | / (theta_squared_safe * theta_safe) 156 | * (skew_omega @ skew_omega) 157 | ), 158 | ) 159 | 160 | return SE3.from_rotation_and_translation( 161 | rotation=rotation, 162 | translation=V @ tangent[:3], 163 | ) 164 | 165 | @override 166 | def log(self) -> onpt.NDArray[onp.floating]: 167 | # Reference: 168 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 169 | omega = self.rotation().log() 170 | theta_squared = omega @ omega 171 | use_taylor = theta_squared < get_epsilon(theta_squared.dtype) 172 | 173 | skew_omega = _skew(omega) 174 | 175 | # Shim to avoid NaNs in onp.where branches, which cause failures for 176 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 177 | theta_squared_safe = onp.where( 178 | use_taylor, 179 | 1.0, # Any non-zero value should do here. 
180 | theta_squared, 181 | ) 182 | del theta_squared 183 | theta_safe = onp.sqrt(theta_squared_safe) 184 | half_theta_safe = theta_safe / 2.0 185 | 186 | V_inv = onp.where( 187 | use_taylor, 188 | onp.eye(3) - 0.5 * skew_omega + (skew_omega @ skew_omega) / 12.0, 189 | ( 190 | onp.eye(3) 191 | - 0.5 * skew_omega 192 | + ( 193 | 1.0 194 | - theta_safe 195 | * onp.cos(half_theta_safe) 196 | / (2.0 * onp.sin(half_theta_safe)) 197 | ) 198 | / theta_squared_safe 199 | * (skew_omega @ skew_omega) 200 | ), 201 | ) 202 | return onp.concatenate([V_inv @ self.translation(), omega]) 203 | 204 | @override 205 | def adjoint(self) -> onpt.NDArray[onp.floating]: 206 | R = self.rotation().as_matrix() 207 | return onp.block( 208 | [ 209 | [R, _skew(self.translation()) @ R], 210 | [onp.zeros((3, 3)), R], 211 | ] 212 | ) 213 | 214 | # @staticmethod 215 | # @override 216 | # def sample_uniform(key: hints.KeyArray) -> SE3: 217 | # key0, key1 = jax.random.split(key) 218 | # return SE3.from_rotation_and_translation( 219 | # rotation=SO3.sample_uniform(key0), 220 | # translation=jax.random.uniform( 221 | # key=key1, shape=(3,), minval=-1.0, maxval=1.0 222 | # ), 223 | # ) 224 | -------------------------------------------------------------------------------- /transforms/_so2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | 5 | import numpy as onp 6 | import numpy.typing as onpt 7 | from typing_extensions import override 8 | 9 | from . import _base, hints 10 | from .utils import register_lie_group 11 | 12 | 13 | @register_lie_group( 14 | matrix_dim=2, 15 | parameters_dim=2, 16 | tangent_dim=1, 17 | space_dim=2, 18 | ) 19 | @dataclasses.dataclass 20 | class SO2(_base.SOBase): 21 | """Special orthogonal group for 2D rotations. 22 | 23 | Ported to numpy from `jaxlie.SO2`. 24 | 25 | Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`. 26 | """ 27 | 28 | # SO2-specific. 29 | 30 | unit_complex: onpt.NDArray[onp.floating] 31 | """Internal parameters. `(cos, sin)`.""" 32 | 33 | @override 34 | def __repr__(self) -> str: 35 | unit_complex = onp.round(self.unit_complex, 5) 36 | return f"{self.__class__.__name__}(unit_complex={unit_complex})" 37 | 38 | @staticmethod 39 | def from_radians(theta: hints.Scalar) -> SO2: 40 | """Construct a rotation object from a scalar angle.""" 41 | cos = onp.cos(theta) 42 | sin = onp.sin(theta) 43 | return SO2(unit_complex=onp.array([cos, sin])) 44 | 45 | def as_radians(self) -> onpt.NDArray[onp.floating]: 46 | """Compute a scalar angle from a rotation object.""" 47 | radians = self.log()[..., 0] 48 | return radians 49 | 50 | # Factory. 51 | 52 | @staticmethod 53 | @override 54 | def identity() -> SO2: 55 | return SO2(unit_complex=onp.array([1.0, 0.0])) 56 | 57 | @staticmethod 58 | @override 59 | def from_matrix(matrix: onpt.NDArray[onp.floating]) -> SO2: 60 | assert matrix.shape == (2, 2) 61 | return SO2(unit_complex=onp.asarray(matrix[:, 0])) 62 | 63 | # Accessors. 64 | 65 | @override 66 | def as_matrix(self) -> onpt.NDArray[onp.floating]: 67 | cos_sin = self.unit_complex 68 | out = onp.array( 69 | [ 70 | # [cos, -sin], 71 | cos_sin * onp.array([1, -1]), 72 | # [sin, cos], 73 | cos_sin[::-1], 74 | ] 75 | ) 76 | assert out.shape == (2, 2) 77 | return out 78 | 79 | @override 80 | def parameters(self) -> onpt.NDArray[onp.floating]: 81 | return self.unit_complex 82 | 83 | # Operations. 
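Note that `MatrixLieGroup.__matmul__` (in `_base.py`) makes the `@` operator do double duty: an `onp.ndarray` operand goes through `.apply()`, a group member through `.multiply()`. A small SO2 usage sketch, same import assumption:

```
import numpy as onp
from transforms import SO2  # assumption: re-exported by the package

R = SO2.from_radians(onp.pi / 2)
print(R @ onp.array([1.0, 0.0]))  # array operand -> .apply(): ~[0. 1.]
print((R @ R).as_radians())       # group operand -> .multiply(): ~3.14159
assert onp.allclose((R @ R.inverse()).unit_complex, SO2.identity().unit_complex)
```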
84 | 85 | @override 86 | def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: 87 | assert target.shape == (2,) 88 | return self.as_matrix() @ target # type: ignore 89 | 90 | @override 91 | def multiply(self, other: SO2) -> SO2: 92 | return SO2(unit_complex=self.as_matrix() @ other.unit_complex) 93 | 94 | @staticmethod 95 | @override 96 | def exp(tangent: onpt.NDArray[onp.floating]) -> SO2: 97 | (theta,) = tangent 98 | cos = onp.cos(theta) 99 | sin = onp.sin(theta) 100 | return SO2(unit_complex=onp.array([cos, sin])) 101 | 102 | @override 103 | def log(self) -> onpt.NDArray[onp.floating]: 104 | return onp.arctan2( 105 | self.unit_complex[..., 1, None], self.unit_complex[..., 0, None] 106 | ) 107 | 108 | @override 109 | def adjoint(self) -> onpt.NDArray[onp.floating]: 110 | return onp.eye(1) 111 | 112 | @override 113 | def inverse(self) -> SO2: 114 | return SO2(unit_complex=self.unit_complex * onp.array([1, -1])) 115 | 116 | @override 117 | def normalize(self) -> SO2: 118 | return SO2(unit_complex=self.unit_complex / onp.linalg.norm(self.unit_complex)) 119 | 120 | # @staticmethod 121 | # @override 122 | # def sample_uniform(key: hints.KeyArray) -> SO2: 123 | # return SO2.from_radians( 124 | # jax.random.uniform(key=key, minval=0.0, maxval=2.0 * onp.pi) 125 | # ) 126 | -------------------------------------------------------------------------------- /transforms/_so3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | 5 | import numpy as onp 6 | import numpy.typing as onpt 7 | from typing_extensions import override 8 | 9 | from . import _base, hints 10 | from .utils import get_epsilon, register_lie_group 11 | 12 | 13 | @register_lie_group( 14 | matrix_dim=3, 15 | parameters_dim=4, 16 | tangent_dim=3, 17 | space_dim=3, 18 | ) 19 | @dataclasses.dataclass 20 | class SO3(_base.SOBase): 21 | """Special orthogonal group for 3D rotations. 22 | 23 | Ported to numpy from `jaxlie.SO3`. 24 | 25 | Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is 26 | `(omega_x, omega_y, omega_z)`. 27 | """ 28 | 29 | # SO3-specific. 30 | 31 | wxyz: onpt.NDArray[onp.floating] 32 | """Internal parameters. `(w, x, y, z)` quaternion.""" 33 | 34 | @override 35 | def __repr__(self) -> str: 36 | wxyz = onp.round(self.wxyz, 5) 37 | return f"{self.__class__.__name__}(wxyz={wxyz})" 38 | 39 | @staticmethod 40 | def from_x_radians(theta: hints.Scalar) -> SO3: 41 | """Generates a x-axis rotation. 42 | 43 | Args: 44 | angle: X rotation, in radians. 45 | 46 | Returns: 47 | Output. 48 | """ 49 | return SO3.exp(onp.array([theta, 0.0, 0.0])) 50 | 51 | @staticmethod 52 | def from_y_radians(theta: hints.Scalar) -> SO3: 53 | """Generates a y-axis rotation. 54 | 55 | Args: 56 | angle: Y rotation, in radians. 57 | 58 | Returns: 59 | Output. 60 | """ 61 | return SO3.exp(onp.array([0.0, theta, 0.0])) 62 | 63 | @staticmethod 64 | def from_z_radians(theta: hints.Scalar) -> SO3: 65 | """Generates a z-axis rotation. 66 | 67 | Args: 68 | angle: Z rotation, in radians. 69 | 70 | Returns: 71 | Output. 72 | """ 73 | return SO3.exp(onp.array([0.0, 0.0, theta])) 74 | 75 | @staticmethod 76 | def from_rpy_radians( 77 | roll: hints.Scalar, 78 | pitch: hints.Scalar, 79 | yaw: hints.Scalar, 80 | ) -> SO3: 81 | """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot 82 | convention. 83 | 84 | Args: 85 | roll: X rotation, in radians. Applied first. 86 | pitch: Y rotation, in radians. 
Applied second. 87 | yaw: Z rotation, in radians. Applied last. 88 | 89 | Returns: 90 | Output. 91 | """ 92 | return ( 93 | SO3.from_z_radians(yaw) 94 | @ SO3.from_y_radians(pitch) 95 | @ SO3.from_x_radians(roll) 96 | ) 97 | 98 | @staticmethod 99 | def from_quaternion_xyzw(xyzw: onpt.NDArray[onp.floating]) -> SO3: 100 | """Construct a rotation from an `xyzw` quaternion. 101 | 102 | Note that `wxyz` quaternions can be constructed using the default dataclass 103 | constructor. 104 | 105 | Args: 106 | xyzw: xyzw quaternion. Shape should be (4,). 107 | 108 | Returns: 109 | Output. 110 | """ 111 | assert xyzw.shape == (4,) 112 | return SO3(onp.roll(xyzw, shift=1)) 113 | 114 | def as_quaternion_xyzw(self) -> onpt.NDArray[onp.floating]: 115 | """Grab parameters as xyzw quaternion.""" 116 | return onp.roll(self.wxyz, shift=-1) 117 | 118 | def as_rpy_radians(self) -> hints.RollPitchYaw: 119 | """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention. 120 | 121 | Returns: 122 | Named tuple containing Euler angles in radians. 123 | """ 124 | return hints.RollPitchYaw( 125 | roll=self.compute_roll_radians(), 126 | pitch=self.compute_pitch_radians(), 127 | yaw=self.compute_yaw_radians(), 128 | ) 129 | 130 | def compute_roll_radians(self) -> onpt.NDArray[onp.floating]: 131 | """Compute roll angle. Uses the ZYX mobile robot convention. 132 | 133 | Returns: 134 | Euler angle in radians. 135 | """ 136 | # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion 137 | q0, q1, q2, q3 = self.wxyz 138 | return onp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)) 139 | 140 | def compute_pitch_radians(self) -> onpt.NDArray[onp.floating]: 141 | """Compute pitch angle. Uses the ZYX mobile robot convention. 142 | 143 | Returns: 144 | Euler angle in radians. 145 | """ 146 | # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion 147 | q0, q1, q2, q3 = self.wxyz 148 | return onp.arcsin(2 * (q0 * q2 - q3 * q1)) 149 | 150 | def compute_yaw_radians(self) -> onpt.NDArray[onp.floating]: 151 | """Compute yaw angle. Uses the ZYX mobile robot convention. 152 | 153 | Returns: 154 | Euler angle in radians. 155 | """ 156 | # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion 157 | q0, q1, q2, q3 = self.wxyz 158 | return onp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2)) 159 | 160 | # Factory. 
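The Euler-angle helpers above all use the ZYX mobile-robot convention, and the xyzw helpers are just an `onp.roll` away from the native wxyz storage; both should roundtrip. A sketch, same import assumption:

```
import numpy as onp
from transforms import SO3  # assumption: re-exported by the package

R = SO3.from_rpy_radians(roll=0.1, pitch=-0.2, yaw=0.3)
rpy = R.as_rpy_radians()
assert onp.allclose([rpy.roll, rpy.pitch, rpy.yaw], [0.1, -0.2, 0.3])

# xyzw <-> wxyz conversion is lossless.
assert onp.allclose(SO3.from_quaternion_xyzw(R.as_quaternion_xyzw()).wxyz, R.wxyz)
```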
161 | 162 | @staticmethod 163 | @override 164 | def identity() -> SO3: 165 | return SO3(wxyz=onp.array([1.0, 0.0, 0.0, 0.0])) 166 | 167 | @staticmethod 168 | @override 169 | def from_matrix(matrix: onpt.NDArray[onp.floating]) -> SO3: 170 | assert matrix.shape == (3, 3) 171 | 172 | # Modified from: 173 | # > "Converting a Rotation Matrix to a Quaternion" from Mike Day 174 | # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf 175 | 176 | def case0(m): 177 | t = 1 + m[0, 0] - m[1, 1] - m[2, 2] 178 | q = onp.array( 179 | [ 180 | m[2, 1] - m[1, 2], 181 | t, 182 | m[1, 0] + m[0, 1], 183 | m[0, 2] + m[2, 0], 184 | ] 185 | ) 186 | return t, q 187 | 188 | def case1(m): 189 | t = 1 - m[0, 0] + m[1, 1] - m[2, 2] 190 | q = onp.array( 191 | [ 192 | m[0, 2] - m[2, 0], 193 | m[1, 0] + m[0, 1], 194 | t, 195 | m[2, 1] + m[1, 2], 196 | ] 197 | ) 198 | return t, q 199 | 200 | def case2(m): 201 | t = 1 - m[0, 0] - m[1, 1] + m[2, 2] 202 | q = onp.array( 203 | [ 204 | m[1, 0] - m[0, 1], 205 | m[0, 2] + m[2, 0], 206 | m[2, 1] + m[1, 2], 207 | t, 208 | ] 209 | ) 210 | return t, q 211 | 212 | def case3(m): 213 | t = 1 + m[0, 0] + m[1, 1] + m[2, 2] 214 | q = onp.array( 215 | [ 216 | t, 217 | m[2, 1] - m[1, 2], 218 | m[0, 2] - m[2, 0], 219 | m[1, 0] - m[0, 1], 220 | ] 221 | ) 222 | return t, q 223 | 224 | # Compute four cases, then pick the most precise one. 225 | # Probably worth revisiting this! 226 | case0_t, case0_q = case0(matrix) 227 | case1_t, case1_q = case1(matrix) 228 | case2_t, case2_q = case2(matrix) 229 | case3_t, case3_q = case3(matrix) 230 | 231 | cond0 = matrix[2, 2] < 0 232 | cond1 = matrix[0, 0] > matrix[1, 1] 233 | cond2 = matrix[0, 0] < -matrix[1, 1] 234 | 235 | t = onp.where( 236 | cond0, 237 | onp.where(cond1, case0_t, case1_t), 238 | onp.where(cond2, case2_t, case3_t), 239 | ) 240 | q = onp.where( 241 | cond0, 242 | onp.where(cond1, case0_q, case1_q), 243 | onp.where(cond2, case2_q, case3_q), 244 | ) 245 | 246 | # We can also choose to branch, but this is slower. 247 | # t, q = jax.lax.cond( 248 | # matrix[2, 2] < 0, 249 | # true_fun=lambda matrix: jax.lax.cond( 250 | # matrix[0, 0] > matrix[1, 1], 251 | # true_fun=case0, 252 | # false_fun=case1, 253 | # operand=matrix, 254 | # ), 255 | # false_fun=lambda matrix: jax.lax.cond( 256 | # matrix[0, 0] < -matrix[1, 1], 257 | # true_fun=case2, 258 | # false_fun=case3, 259 | # operand=matrix, 260 | # ), 261 | # operand=matrix, 262 | # ) 263 | 264 | return SO3(wxyz=q * 0.5 / onp.sqrt(t)) 265 | 266 | # Accessors. 267 | 268 | @override 269 | def as_matrix(self) -> onpt.NDArray[onp.floating]: 270 | norm = self.wxyz @ self.wxyz 271 | q = self.wxyz * onp.sqrt(2.0 / norm) 272 | q = onp.outer(q, q) 273 | return onp.array( 274 | [ 275 | [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0]], 276 | [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0]], 277 | [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2]], 278 | ] 279 | ) 280 | 281 | @override 282 | def parameters(self) -> onpt.NDArray[onp.floating]: 283 | return self.wxyz 284 | 285 | # Operations. 286 | 287 | @override 288 | def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: 289 | assert target.shape == (3,) 290 | 291 | # Compute using quaternion multiplys. 
292 | padded_target = onp.concatenate([onp.zeros(1), target]) 293 | return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[1:] 294 | 295 | @override 296 | def multiply(self, other: SO3) -> SO3: 297 | w0, x0, y0, z0 = self.wxyz 298 | w1, x1, y1, z1 = other.wxyz 299 | return SO3( 300 | wxyz=onp.array( 301 | [ 302 | -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, 303 | x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, 304 | -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, 305 | x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, 306 | ] 307 | ) 308 | ) 309 | 310 | @staticmethod 311 | @override 312 | def exp(tangent: onpt.NDArray[onp.floating]) -> SO3: 313 | # Reference: 314 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 315 | 316 | assert tangent.shape == (3,) 317 | 318 | theta_squared = tangent @ tangent 319 | theta_pow_4 = theta_squared * theta_squared 320 | use_taylor = theta_squared < get_epsilon(tangent.dtype) 321 | 322 | # Shim to avoid NaNs in onp.where branches, which cause failures for 323 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 324 | safe_theta = onp.sqrt( 325 | onp.where( 326 | use_taylor, 327 | 1.0, # Any constant value should do here. 328 | theta_squared, 329 | ) 330 | ) 331 | safe_half_theta = 0.5 * safe_theta 332 | 333 | real_factor = onp.where( 334 | use_taylor, 335 | 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, 336 | onp.cos(safe_half_theta), 337 | ) 338 | 339 | imaginary_factor = onp.where( 340 | use_taylor, 341 | 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, 342 | onp.sin(safe_half_theta) / safe_theta, 343 | ) 344 | 345 | return SO3( 346 | wxyz=onp.concatenate( 347 | [ 348 | real_factor[None], 349 | imaginary_factor * tangent, 350 | ] 351 | ) 352 | ) 353 | 354 | @override 355 | def log(self) -> onpt.NDArray[onp.floating]: 356 | # Reference: 357 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 358 | 359 | w = self.wxyz[..., 0] 360 | norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:] 361 | use_taylor = norm_sq < get_epsilon(norm_sq.dtype) 362 | 363 | # Shim to avoid NaNs in onp.where branches, which cause failures for 364 | # reverse-mode AD. (note: this is needed in JAX, but not in numpy) 365 | norm_safe = onp.sqrt( 366 | onp.where( 367 | use_taylor, 368 | 1.0, # Any non-zero value should do here. 369 | norm_sq, 370 | ) 371 | ) 372 | w_safe = onp.where(use_taylor, w, 1.0) 373 | atan_n_over_w = onp.arctan2( 374 | onp.where(w < 0, -norm_safe, norm_safe), 375 | onp.abs(w), 376 | ) 377 | atan_factor = onp.where( 378 | use_taylor, 379 | 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, 380 | onp.where( 381 | onp.abs(w) < get_epsilon(w.dtype), 382 | onp.where(w > 0, 1.0, -1.0) * onp.pi / norm_safe, 383 | 2.0 * atan_n_over_w / norm_safe, 384 | ), 385 | ) 386 | 387 | return atan_factor * self.wxyz[1:] 388 | 389 | @override 390 | def adjoint(self) -> onpt.NDArray[onp.floating]: 391 | return self.as_matrix() 392 | 393 | @override 394 | def inverse(self) -> SO3: 395 | # Negate complex terms. 396 | return SO3(wxyz=self.wxyz * onp.array([1, -1, -1, -1])) 397 | 398 | @override 399 | def normalize(self) -> SO3: 400 | return SO3(wxyz=self.wxyz / onp.linalg.norm(self.wxyz)) 401 | 402 | # @staticmethod 403 | # @override 404 | # def sample_uniform(key: hints.KeyArray) -> SO3: 405 | # # Uniformly sample over S^3. 
406 | # # > Reference: http://planning.cs.uiuc.edu/node198.html 407 | # u1, u2, u3 = jax.random.uniform( 408 | # key=key, 409 | # shape=(3,), 410 | # minval=onp.zeros(3), 411 | # maxval=onp.array([1.0, 2.0 * onp.pi, 2.0 * onp.pi]), 412 | # ) 413 | # a = onp.sqrt(1.0 - u1) 414 | # b = onp.sqrt(u1) 415 | # 416 | # return SO3( 417 | # wxyz=onp.array( 418 | # [ 419 | # a * onp.sin(u2), 420 | # a * onp.cos(u2), 421 | # b * onp.sin(u3), 422 | # b * onp.cos(u3), 423 | # ] 424 | # ) 425 | # ) 426 | -------------------------------------------------------------------------------- /transforms/hints/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Union 2 | 3 | import numpy as onp 4 | 5 | # Type aliases for JAX/Numpy arrays; primarily for function inputs. 6 | 7 | Array = onp.ndarray 8 | """Type alias for `onp.ndarray`.""" 9 | 10 | Scalar = Union[float, Array] 11 | """Type alias for `Union[float, Array]`.""" 12 | 13 | 14 | class RollPitchYaw(NamedTuple): 15 | """Tuple containing roll, pitch, and yaw Euler angles.""" 16 | 17 | roll: Scalar 18 | pitch: Scalar 19 | yaw: Scalar 20 | 21 | 22 | __all__ = [ 23 | "Array", 24 | "Scalar", 25 | "RollPitchYaw", 26 | ] 27 | -------------------------------------------------------------------------------- /transforms/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import get_epsilon, register_lie_group 2 | 3 | __all__ = ["get_epsilon", "register_lie_group"] 4 | -------------------------------------------------------------------------------- /transforms/utils/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable, Type, TypeVar 2 | 3 | import numpy as onp 4 | 5 | if TYPE_CHECKING: 6 | from .._base import MatrixLieGroup 7 | 8 | 9 | T = TypeVar("T", bound="MatrixLieGroup") 10 | 11 | 12 | def get_epsilon(dtype: onp.dtype) -> float: 13 | """Helper for grabbing type-specific precision constants. 14 | 15 | Args: 16 | dtype: Datatype. 17 | 18 | Returns: 19 | Output float. 20 | """ 21 | return { 22 | onp.dtype("float32"): 1e-5, 23 | onp.dtype("float64"): 1e-10, 24 | }[dtype] 25 | 26 | 27 | def register_lie_group( 28 | *, 29 | matrix_dim: int, 30 | parameters_dim: int, 31 | tangent_dim: int, 32 | space_dim: int, 33 | ) -> Callable[[Type[T]], Type[T]]: 34 | """Decorator for registering Lie group dataclasses. 35 | 36 | Sets dimensionality class variables, and (formerly in the JAX version) marks all methods for JIT compilation. 37 | """ 38 | 39 | def _wrap(cls: Type[T]) -> Type[T]: 40 | # Register dimensions as class attributes. 
41 | cls.matrix_dim = matrix_dim 42 | cls.parameters_dim = parameters_dim 43 | cls.tangent_dim = tangent_dim 44 | cls.space_dim = space_dim 45 | 46 | return cls 47 | 48 | return _wrap 49 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import time 3 | import numpy as np 4 | from dataclasses import dataclass 5 | from typing import Dict, Union 6 | from pathlib import Path 7 | import torch 8 | import math 9 | from collections import defaultdict 10 | from pprint import pprint 11 | from kornia import create_meshgrid 12 | 13 | @dataclass(frozen=True) 14 | class CameraModel: 15 | model_id: int 16 | model_name: str 17 | num_params: int 18 | 19 | 20 | @dataclass(frozen=True) 21 | class Camera: 22 | id: int 23 | model: str 24 | width: int 25 | height: int 26 | params: np.ndarray 27 | 28 | 29 | @dataclass(frozen=True) 30 | class BaseImage: 31 | id: int 32 | qvec: np.ndarray 33 | tvec: np.ndarray 34 | camera_id: int 35 | name: str 36 | xys: np.ndarray 37 | point3D_ids: np.ndarray 38 | 39 | 40 | @dataclass(frozen=True) 41 | class Point3D: 42 | id: int 43 | xyz: np.ndarray 44 | rgb: np.ndarray 45 | error: Union[float, np.ndarray] 46 | image_ids: np.ndarray 47 | point2D_idxs: np.ndarray 48 | 49 | 50 | class Image(BaseImage): 51 | def qvec2rotmat(self): 52 | return qvec2rotmat(self.qvec) 53 | 54 | 55 | CAMERA_MODELS = { 56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 63 | CameraModel(model_id=7, model_name="FOV", num_params=5), 64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), 67 | } 68 | CAMERA_MODEL_IDS = dict( 69 | [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] 70 | ) 71 | 72 | 73 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 74 | """Read and unpack the next bytes from a binary file. 75 | :param fid: 76 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 77 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 78 | :param endian_character: Any of {@, =, <, >, !} 79 | :return: Tuple of read and unpacked values. 
80 | """ 81 | data = fid.read(num_bytes) 82 | return struct.unpack(endian_character + format_char_sequence, data) 83 | 84 | 85 | def read_cameras_text(path: Union[str, Path]) -> Dict[int, Camera]: 86 | """ 87 | see: src/base/reconstruction.cc 88 | void Reconstruction::WriteCamerasText(const std::string& path) 89 | void Reconstruction::ReadCamerasText(const std::string& path) 90 | """ 91 | cameras = {} 92 | with open(path, "r") as fid: 93 | while True: 94 | line = fid.readline() 95 | if not line: 96 | break 97 | line = line.strip() 98 | if len(line) > 0 and line[0] != "#": 99 | elems = line.split() 100 | camera_id = int(elems[0]) 101 | model = elems[1] 102 | width = int(elems[2]) 103 | height = int(elems[3]) 104 | params = np.array(tuple(map(float, elems[4:]))) 105 | cameras[camera_id] = Camera( 106 | id=camera_id, model=model, width=width, height=height, params=params 107 | ) 108 | return cameras 109 | 110 | 111 | def read_cameras_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Camera]: 112 | """ 113 | see: src/base/reconstruction.cc 114 | void Reconstruction::WriteCamerasBinary(const std::string& path) 115 | void Reconstruction::ReadCamerasBinary(const std::string& path) 116 | """ 117 | cameras = {} 118 | with open(path_to_model_file, "rb") as fid: 119 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 120 | for camera_line_index in range(num_cameras): 121 | camera_properties = read_next_bytes( 122 | fid, num_bytes=24, format_char_sequence="iiQQ" 123 | ) 124 | camera_id = camera_properties[0] 125 | model_id = camera_properties[1] 126 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 127 | width = camera_properties[2] 128 | height = camera_properties[3] 129 | num_params = CAMERA_MODEL_IDS[model_id].num_params 130 | params = read_next_bytes( 131 | fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params 132 | ) 133 | cameras[camera_id] = Camera( 134 | id=camera_id, 135 | model=model_name, 136 | width=width, 137 | height=height, 138 | params=np.array(params), 139 | ) 140 | assert len(cameras) == num_cameras 141 | return cameras 142 | 143 | 144 | def read_images_text(path: Union[str, Path]) -> Dict[int, Image]: 145 | """ 146 | see: src/base/reconstruction.cc 147 | void Reconstruction::ReadImagesText(const std::string& path) 148 | void Reconstruction::WriteImagesText(const std::string& path) 149 | """ 150 | images = {} 151 | with open(path, "r") as fid: 152 | while True: 153 | line = fid.readline() 154 | if not line: 155 | break 156 | line = line.strip() 157 | if len(line) > 0 and line[0] != "#": 158 | elems = line.split() 159 | image_id = int(elems[0]) 160 | qvec = np.array(tuple(map(float, elems[1:5]))) 161 | tvec = np.array(tuple(map(float, elems[5:8]))) 162 | camera_id = int(elems[8]) 163 | image_name = elems[9] 164 | elems = fid.readline().split() 165 | xys = np.column_stack( 166 | [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))] 167 | ) 168 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 169 | images[image_id] = Image( 170 | id=image_id, 171 | qvec=qvec, 172 | tvec=tvec, 173 | camera_id=camera_id, 174 | name=image_name, 175 | xys=xys, 176 | point3D_ids=point3D_ids, 177 | ) 178 | return images 179 | 180 | 181 | def read_images_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Image]: 182 | """ 183 | see: src/base/reconstruction.cc 184 | void Reconstruction::ReadImagesBinary(const std::string& path) 185 | void Reconstruction::WriteImagesBinary(const std::string& path) 186 | """ 187 | images = {} 188 | with 
open(path_to_model_file, "rb") as fid: 189 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 190 | for image_index in range(num_reg_images): 191 | binary_image_properties = read_next_bytes( 192 | fid, num_bytes=64, format_char_sequence="idddddddi" 193 | ) 194 | image_id = binary_image_properties[0] 195 | qvec = np.array(binary_image_properties[1:5]) 196 | tvec = np.array(binary_image_properties[5:8]) 197 | camera_id = binary_image_properties[8] 198 | image_name = "" 199 | current_char = read_next_bytes(fid, 1, "c")[0] 200 | while current_char != b"\x00": # look for the ASCII 0 entry 201 | image_name += current_char.decode("utf-8") 202 | current_char = read_next_bytes(fid, 1, "c")[0] 203 | num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 204 | 0 205 | ] 206 | x_y_id_s = read_next_bytes( 207 | fid, 208 | num_bytes=24 * num_points2D, 209 | format_char_sequence="ddq" * num_points2D, 210 | ) 211 | xys = np.column_stack( 212 | [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))] 213 | ) 214 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 215 | images[image_id] = Image( 216 | id=image_id, 217 | qvec=qvec, 218 | tvec=tvec, 219 | camera_id=camera_id, 220 | name=image_name, 221 | xys=xys, 222 | point3D_ids=point3D_ids, 223 | ) 224 | return images 225 | 226 | 227 | def read_points3D_text(path: Union[str, Path]): 228 | """ 229 | see: src/base/reconstruction.cc 230 | void Reconstruction::ReadPoints3DText(const std::string& path) 231 | void Reconstruction::WritePoints3DText(const std::string& path) 232 | """ 233 | points3D = {} 234 | with open(path, "r") as fid: 235 | while True: 236 | line = fid.readline() 237 | if not line: 238 | break 239 | line = line.strip() 240 | if len(line) > 0 and line[0] != "#": 241 | elems = line.split() 242 | point3D_id = int(elems[0]) 243 | xyz = np.array(tuple(map(float, elems[1:4]))) 244 | rgb = np.array(tuple(map(int, elems[4:7]))) 245 | error = float(elems[7]) 246 | image_ids = np.array(tuple(map(int, elems[8::2]))) 247 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 248 | points3D[point3D_id] = Point3D( 249 | id=point3D_id, 250 | xyz=xyz, 251 | rgb=rgb, 252 | error=error, 253 | image_ids=image_ids, 254 | point2D_idxs=point2D_idxs, 255 | ) 256 | return points3D 257 | 258 | 259 | def read_points3d_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Point3D]: 260 | """ 261 | see: src/base/reconstruction.cc 262 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 263 | void Reconstruction::WritePoints3DBinary(const std::string& path) 264 | """ 265 | points3D = {} 266 | with open(path_to_model_file, "rb") as fid: 267 | num_points = read_next_bytes(fid, 8, "Q")[0] 268 | for point_line_index in range(num_points): 269 | binary_point_line_properties = read_next_bytes( 270 | fid, num_bytes=43, format_char_sequence="QdddBBBd" 271 | ) 272 | point3D_id = binary_point_line_properties[0] 273 | xyz = np.array(binary_point_line_properties[1:4]) 274 | rgb = np.array(binary_point_line_properties[4:7]) 275 | error = np.array(binary_point_line_properties[7]) 276 | track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[ 277 | 0 278 | ] 279 | track_elems = read_next_bytes( 280 | fid, 281 | num_bytes=8 * track_length, 282 | format_char_sequence="ii" * track_length, 283 | ) 284 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 285 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 286 | points3D[point3D_id] = Point3D( 287 | id=point3D_id, 288 | xyz=xyz, 289 | rgb=rgb, 290 
| error=error, 291 | image_ids=image_ids, 292 | point2D_idxs=point2D_idxs, 293 | ) 294 | return points3D 295 | 296 | 297 | def qvec2rotmat(qvec): 298 | return np.array( 299 | [ 300 | [ 301 | 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 302 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 303 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], 304 | ], 305 | [ 306 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 307 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 308 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], 309 | ], 310 | [ 311 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 312 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 313 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, 314 | ], 315 | ] 316 | ) 317 | 318 | def q2r(qvec): 319 | # qvec B x 4 320 | qvec = qvec / qvec.norm(dim=1, keepdim=True) 321 | rot = [ 322 | 1 - 2 * qvec[:, 2] ** 2 - 2 * qvec[:, 3] ** 2, 323 | 2 * qvec[:, 1] * qvec[:, 2] - 2 * qvec[:, 0] * qvec[:, 3], 324 | 2 * qvec[:, 3] * qvec[:, 1] + 2 * qvec[:, 0] * qvec[:, 2], 325 | 2 * qvec[:, 1] * qvec[:, 2] + 2 * qvec[:, 0] * qvec[:, 3], 326 | 1 - 2 * qvec[:, 1] ** 2 - 2 * qvec[:, 3] ** 2, 327 | 2 * qvec[:, 2] * qvec[:, 3] - 2 * qvec[:, 0] * qvec[:, 1], 328 | 2 * qvec[:, 3] * qvec[:, 1] - 2 * qvec[:, 0] * qvec[:, 2], 329 | 2 * qvec[:, 2] * qvec[:, 3] + 2 * qvec[:, 0] * qvec[:, 1], 330 | 1 - 2 * qvec[:, 1] ** 2 - 2 * qvec[:, 2] ** 2, 331 | ] 332 | rot = torch.stack(rot, dim=1).reshape(-1, 3, 3) 333 | return rot 334 | 335 | def jacobian_torch(a): 336 | _rsqr = 1./(a[:, 0]**2 + a[:, 1]**2 + a[:, 2]**2).sqrt() 337 | _res = [ 338 | 1/a[:,2], torch.zeros_like(a[:,0]), -a[:,0]/(a[:,2]**2), 339 | torch.zeros_like(a[:,0]), 1/a[:,2], -a[:,1]/(a[:,2]**2), 340 | _rsqr * a[:, 0], _rsqr * a[:, 1], _rsqr * a[:, 2] 341 | ] 342 | return torch.stack(_res, dim=-1).reshape(-1, 3, 3) 343 | 344 | 345 | def initialize_sh(rgbs): 346 | sh_coeff = torch.zeros(rgbs.shape[0], 3, 9, device=rgbs.device, dtype=rgbs.dtype) 347 | sh_coeff[:, :, 0] = rgbs / 0.28209479177387814 348 | return sh_coeff.flatten(1) 349 | 350 | def inverse_sigmoid(y=0.001): 351 | return -math.log(1/y - 1) 352 | 353 | def inverse_sigmoid_torch(y): 354 | return -torch.log(1/y - 1) 355 | 356 | 357 | class Timer: 358 | recorder = defaultdict(list) 359 | 360 | def __init__(self, des="", verbose=False, record=True, debug=True) -> None: 361 | self.des = des 362 | self.verbose = verbose 363 | self.record = record 364 | self.debug = debug 365 | 366 | def __enter__(self): 367 | if not self.debug: 368 | return self 369 | self.start = time.time() 370 | self.start_cuda = torch.cuda.Event(enable_timing=True) 371 | self.end_cuda = torch.cuda.Event(enable_timing=True) 372 | self.start_cuda.record() 373 | return self 374 | 375 | def __exit__(self, *args): 376 | if not self.debug: 377 | return 378 | self.end = time.time() 379 | self.end_cuda.record() 380 | torch.cuda.synchronize() 381 | self.interval = self.start_cuda.elapsed_time(self.end_cuda)/1000. 382 | if self.verbose: 383 | print(f"[cudasync]{self.des} consuming {self.interval:.8f}") 384 | if self.record: 385 | Timer.recorder[self.des].append(self.interval) 386 | 387 | @staticmethod 388 | def show_recorder(): 389 | pprint({k: np.mean(v) for k, v in Timer.recorder.items()}) 390 | 391 | def sample_two_point(gaussian_pos, gaussian_cov): 392 | # gaussian_cov: (..., 3, 3) 393 | # gaussian_pos: (..., 3) 394 | # n_samples: (...) 
395 | # return: (..., n_samples, 3) 396 | dist = torch.distributions.multivariate_normal.MultivariateNormal( 397 | gaussian_pos, 398 | gaussian_cov, 399 | ) 400 | p1 = dist.sample() 401 | p2 = dist.sample() 402 | return p1, p2 403 | 404 | def clamp(x): 405 | return torch.clamp(x, min=0, max=1) 406 | 407 | 408 | def get_rays_direction_in_camera_space(H, W, focal): 409 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 410 | i, j = grid.unbind(-1) 411 | cent = [W/2, H/2] 412 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) 413 | return directions 414 | 415 | def get_rays_direction(w2c_rot, H, W, focal): 416 | c2w = torch.inverse(w2c_rot) 417 | directions = get_rays_direction_in_camera_space(H, W, focal) 418 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 419 | return rays_d -------------------------------------------------------------------------------- /visergui.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | import torch 3 | import numpy as np 4 | import time 5 | import viser 6 | import viser.transforms as tf 7 | from omegaconf import OmegaConf 8 | from utils import qvec2rotmat 9 | import cv2 10 | from utils import Timer 11 | from collections import deque 12 | 13 | 14 | def get_c2w(camera): 15 | c2w = np.eye(4, dtype=np.float32) 16 | c2w[:3, :3] = qvec2rotmat(camera.wxyz) 17 | c2w[:3, 3] = camera.position 18 | return c2w 19 | 20 | def get_w2c(camera): 21 | c2w = get_c2w(camera) 22 | w2c = np.linalg.inv(c2w) 23 | return w2c 24 | 25 | class RenderThread(Thread): 26 | pass 27 | 28 | 29 | class ViserViewer: 30 | def __init__(self, device, viewer_port): 31 | self.device = device 32 | self.port = viewer_port 33 | 34 | self.render_times = deque(maxlen=3) 35 | self.server = viser.ViserServer(port=self.port) 36 | self.reset_view_button = self.server.add_gui_button("Reset View") 37 | 38 | self.need_update = False 39 | 40 | self.pause_training = False 41 | self.train_viewer_update_period_slider = self.server.add_gui_slider( 42 | "Train Viewer Update Period", 43 | min=1, 44 | max=100, 45 | step=1, 46 | initial_value=10, 47 | disabled=self.pause_training, 48 | ) 49 | 50 | self.pause_training_button = self.server.add_gui_button("Pause Training") 51 | self.sh_order = self.server.add_gui_slider( 52 | "SH Order", min=1, max=4, step=1, initial_value=1 53 | ) 54 | self.resolution_slider = self.server.add_gui_slider( 55 | "Resolution", min=384, max=4096, step=2, initial_value=1024 56 | ) 57 | self.near_plane_slider = self.server.add_gui_slider( 58 | "Near", min=0.1, max=30, step=0.5, initial_value=0.1 59 | ) 60 | self.far_plane_slider = self.server.add_gui_slider( 61 | "Far", min=30.0, max=1000.0, step=10.0, initial_value=1000.0 62 | ) 63 | 64 | self.show_train_camera = self.server.add_gui_checkbox( 65 | "Show Train Camera", initial_value=False 66 | ) 67 | 68 | self.fps = self.server.add_gui_text("FPS", initial_value="-1", disabled=True) 69 | 70 | @self.show_train_camera.on_update 71 | def _(_): 72 | self.need_update = True 73 | 74 | @self.resolution_slider.on_update 75 | def _(_): 76 | self.need_update = True 77 | 78 | @self.near_plane_slider.on_update 79 | def _(_): 80 | self.need_update = True 81 | 82 | @self.far_plane_slider.on_update 83 | def _(_): 84 | self.need_update = True 85 | 86 | @self.pause_training_button.on_click 87 | def _(_): 88 | self.pause_training = not self.pause_training 89 | self.train_viewer_update_period_slider.disabled = not 
self.pause_training 90 | self.pause_training_button.name = ( 91 | "Resume Training" if self.pause_training else "Pause Training" 92 | ) 93 | 94 | @self.reset_view_button.on_click 95 | def _(_): 96 | self.need_update = True 97 | for client in self.server.get_clients().values(): 98 | client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( 99 | [0.0, -1.0, 0.0] 100 | ) 101 | 102 | self.c2ws = [] 103 | self.camera_infos = [] 104 | 105 | @self.resolution_slider.on_update 106 | def _(_): 107 | self.need_update = True 108 | 109 | @self.server.on_client_connect 110 | def _(client: viser.ClientHandle): 111 | @client.camera.on_update 112 | def _(_): 113 | self.need_update = True 114 | 115 | self.debug_idx = 0 116 | 117 | def set_renderer(self, renderer): 118 | self.renderer = renderer 119 | 120 | @torch.no_grad() 121 | def update(self): 122 | if self.need_update: 123 | start = time.time() 124 | for client in self.server.get_clients().values(): 125 | camera = client.camera 126 | w2c = get_w2c(camera) 127 | try: 128 | W = self.resolution_slider.value 129 | H = int(self.resolution_slider.value/camera.aspect) 130 | focal_x = W/2/np.tan(camera.fov/2) 131 | focal_y = H/2/np.tan(camera.fov/2) 132 | 133 | start_cuda = torch.cuda.Event(enable_timing=True) 134 | end_cuda = torch.cuda.Event(enable_timing=True) 135 | start_cuda.record() 136 | 137 | outputs = self.renderer.test( 138 | None, 139 | extrinsics={ 140 | "rot": w2c[:3,:3], 141 | "tran": w2c[:3, 3], 142 | }, 143 | intrinsics={ 144 | "width": W, 145 | "height": H, 146 | "focal_x": focal_x, 147 | "focal_y": focal_y, 148 | } 149 | ) 150 | end_cuda.record() 151 | torch.cuda.synchronize() 152 | interval = start_cuda.elapsed_time(end_cuda)/1000. 153 | 154 | out = outputs["image"].cpu().detach().numpy().astype(np.float32) 155 | except RuntimeError as e: 156 | print(e) 157 | interval = 1 158 | continue 159 | client.set_background_image(out, format="jpeg") 160 | self.debug_idx += 1 161 | # if self.debug_idx % 100 == 0: 162 | # cv2.imwrite( 163 | # f"./tmp/viewer/debug_{self.debug_idx}.png", 164 | # cv2.cvtColor(out, cv2.COLOR_RGB2BGR), 165 | # ) 166 | 167 | self.render_times.append(interval) 168 | self.fps.value = f"{1.0 / np.mean(self.render_times):.3g}" 169 | # print(f"Update time: {end - start:.3g}") 170 | 171 | --------------------------------------------------------------------------------
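Putting the COLMAP utilities in `utils.py` together: `read_cameras_binary`/`read_images_binary` parse the sparse model, and `qvec2rotmat` turns each image's quaternion into the world-to-camera rotation (COLMAP stores poses as world-to-camera). A loading sketch; the `colmap_garden/sparse/0/` path is illustrative:

```
import numpy as np
from utils import read_cameras_binary, read_images_binary, qvec2rotmat

cameras = read_cameras_binary("colmap_garden/sparse/0/cameras.bin")
images = read_images_binary("colmap_garden/sparse/0/images.bin")

for image in images.values():
    w2c = np.eye(4)
    w2c[:3, :3] = qvec2rotmat(image.qvec)  # world-to-camera rotation
    w2c[:3, 3] = image.tvec
    cam = cameras[image.camera_id]
    print(image.name, cam.model, cam.width, cam.height)
```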