├── .gitignore ├── .gitmodules ├── environment.yml ├── lpipsPyTorch ├── __init__.py └── modules │ ├── utils.py │ ├── lpips.py │ └── networks.py ├── utils ├── system_utils.py ├── graphics_utils.py ├── loss_utils.py ├── image_utils.py ├── camera_utils.py ├── general_utils.py └── sh_utils.py ├── LICENSE ├── stat.py ├── scene ├── cameras.py ├── __init__.py ├── colmap_loader.py ├── dataset_readers.py └── gaussian_model.py ├── gaussian_renderer ├── network_gui.py └── __init__.py ├── README.md ├── arguments └── __init__.py ├── convert.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.sh 3 | .vscode 4 | output 5 | dataset 6 | build 7 | depth-diff-gaussian-rasterization/diff_rast.egg-info 8 | depth-diff-gaussian-rasterization/dist 9 | tensorboard_3d 10 | screenshots 11 | output 12 | *__pycache__* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/depth-diff-gaussian-rasterization"] 2 | path = submodules/depth-diff-gaussian-rasterization 3 | url = https://github.com/ingra14m/depth-diff-gaussian-rasterization.git 4 | [submodule "submodules/simple-knn"] 5 | path = submodules/simple-knn 6 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 7 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: variational_gs 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile 9 | - pytorch=1.12.1 10 | - torchaudio=0.12.1 11 | - torchvision=0.13.1 12 | - tqdm 13 | - pip: 14 | - submodules/depth-diff-gaussian-rasterization 15 | - submodules/simple-knn 16 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. 
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ruiqi LI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /stat.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import argparse 3 | 4 | def calculate_column_averages(file_path): 5 | with open(file_path, mode='r', newline='') as csvfile: 6 | csvreader = csv.reader(csvfile) 7 | 8 | # Initialize a list to store sum of each column and a counter for rows 9 | sums = None 10 | row_count = 0 11 | 12 | for row in csvreader: 13 | # Note: no header handling; every row is treated as a data row 14 | 15 | values = [float(value) for value in row[2:]] # Skip the first two columns 16 | if sums is None: 17 | sums = values 18 | else: 19 | sums = [s + v for s, v in zip(sums, values)] 20 | 21 | row_count += 1 22 | 23 | # Calculate averages 24 | averages = [s / row_count for s in sums] 25 | 26 | return row_count, averages 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--dataset_name', type=str, help='dataset name') 31 | 32 | args = parser.parse_args() 33 | # Path to your CSV file 34 | if args.dataset_name == "LF": 35 | file_path = './output/eval_results_LF.csv' 36 | 37 | # Calculate and print the averages 38 | row_count, averages = calculate_column_averages(file_path) 39 | results = f"LF dataset total {row_count} Scenes, Averaged Results: PSNR {averages[0]} SSIM {averages[1]} LPIPS {averages[2]} AUSE {averages[3]} NLL {averages[4]} Depth AUSE {averages[5]}" 40 | print(results) 41 | elif args.dataset_name == "LLFF": 42 | file_path = './output/eval_results_LLFF.csv' 43 | 44 | # Calculate and print the averages 45 | row_count, averages = calculate_column_averages(file_path) 46 | results = f"LLFF dataset total {row_count} Scenes, Averaged Results: PSNR {averages[0]} SSIM {averages[1]} LPIPS {averages[2]} AUSE {averages[3]} NLL {averages[4]}" 47 | print(results) 48 | else: 49 | raise ValueError("Dataset name is required.") -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | 16 | def mse(img1, img2): 17 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | 23 | def nll_kernel_density(pred_rgbs, pred_std, ground_truth): 24 | n = pred_std.numel() 25 | eps = 1e-05 26 | H_sqrt = pred_std.detach() * torch.pow(0.8/n,torch.tensor(-1/7)) + eps # (N_rays, 3) 27 | H_sqrt = H_sqrt[...,None] # (N_rays, 3, 1) 28 | r_P_C_1 = torch.exp( -((pred_rgbs - ground_truth[...,None])**2) / (2*H_sqrt*H_sqrt)) # [N_rays, 3, k] 29 | r_P_C_2 = torch.pow(torch.tensor(2*math.pi),-1.5) / H_sqrt # [N_rays, 3, 1] 30 | r_P_C = r_P_C_1 * r_P_C_2 # [N_rays, 3, k] 31 | r_P_C_mean = r_P_C.mean(-1) + eps 32 | loss_nll = - torch.log(r_P_C_mean).mean() 33 | return loss_nll 34 | 35 | def ause_br(unc_vec, err_vec, err_type='rmse'): 36 | ratio_removed = np.linspace(0, 1, 100, endpoint=False) 37 | # Sort the error 38 | err_vec_sorted, _ = torch.sort(err_vec) 39 | 40 | # Calculate the error when removing a fraction pixels with error 41 | n_valid_pixels = len(err_vec) 42 | ause_err = [] 43 | for r in ratio_removed: 44 | err_slice = err_vec_sorted[0:int((1-r)*n_valid_pixels)] 45 | if err_type == 'rmse': 46 | ause_err.append(torch.sqrt(err_slice.mean()).cpu().numpy()) 47 | elif err_type == 'mae' or err_type == 'mse': 48 | ause_err.append(err_slice.mean().cpu().numpy()) 49 | 50 | # Sort by variance 51 | _, var_vec_sorted_idxs = torch.sort(unc_vec) 52 | # Sort error by variance 53 | err_vec_sorted_by_var = err_vec[var_vec_sorted_idxs] 54 | ause_err_by_var = [] 55 | for r in ratio_removed: 56 | 57 | err_slice = err_vec_sorted_by_var[0:int((1 - r) * n_valid_pixels)] 58 | if err_type == 'rmse': 59 | ause_err_by_var.append(torch.sqrt(err_slice.mean()).cpu().numpy()) 60 | elif err_type == 'mae'or err_type == 'mse': 61 | ause_err_by_var.append(err_slice.mean().cpu().numpy()) 62 | 63 | #Normalize and append 64 | max_val = max(max(ause_err), max(ause_err_by_var)) 65 | ause_err = ause_err / max_val 66 | ause_err = np.array(ause_err) 67 | 68 | ause_err_by_var = ause_err_by_var / max_val 69 | ause_err_by_var = np.array(ause_err_by_var) 70 | ause = np.trapz(ause_err_by_var - ause_err, ratio_removed) 71 | 72 | return ause, ause_err, ause_err_by_var 73 | 74 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), depth=None, scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | self.depth = depth 32 | 33 | try: 34 | self.data_device = torch.device(data_device) 35 | except Exception as e: 36 | print(e) 37 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 38 | self.data_device = torch.device("cuda") 39 | 40 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 41 | self.image_width = self.original_image.shape[2] 42 | self.image_height = self.original_image.shape[1] 43 | 44 | if gt_alpha_mask is not None: 45 | self.original_image *= gt_alpha_mask.to(self.data_device) 46 | else: 47 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 48 | 49 | self.zfar = 100.0 50 | self.znear = 0.01 51 | 52 | self.trans = trans 53 | self.scale = scale 54 | 55 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 56 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 57 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 58 | self.camera_center = self.world_view_transform.inverse()[3, :3] 59 | 60 | class MiniCam: 61 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 62 | self.image_width = width 63 | self.image_height = height 64 | self.FoVy = fovy 65 | self.FoVx = fovx 66 | self.znear = znear 67 | self.zfar = zfar 68 | self.world_view_transform = world_view_transform 69 | self.full_proj_transform = full_proj_transform 70 | view_inv = torch.inverse(self.world_view_transform) 71 | self.camera_center = view_inv[3][:3] 72 | 73 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def 
__init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | import torch.nn.functional as F 17 | 18 | WARNED = False 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | orig_w, orig_h = cam_info.image.size 22 | 23 | if args.resolution in [1, 2, 4, 8]: 24 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 25 | else: # should be a type that converts to float 26 | if args.resolution == -1: 27 | if orig_w > 1600: 28 | global WARNED 29 | if not WARNED: 30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 32 | WARNED = True 33 | global_down = orig_w / 1600 34 | else: 35 | global_down = 1 36 | else: 37 | global_down = orig_w / args.resolution 38 | 39 | scale = float(global_down) * float(resolution_scale) 40 | resolution = (int(orig_w / scale), int(orig_h / scale)) 41 | 42 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 43 | 44 | gt_image = resized_image_rgb[:3, ...] 
45 | loaded_mask = None 46 | 47 | if resized_image_rgb.shape[1] == 4: 48 | loaded_mask = resized_image_rgb[3:4, ...] 49 | 50 | if cam_info.depth is not None: 51 | depth = F.interpolate(cam_info.depth[None, ...], size=(resolution[1], resolution[0]), mode='bilinear', align_corners=False)[0,0] 52 | else: 53 | depth = None 54 | 55 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 56 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 57 | image=gt_image, gt_alpha_mask=loaded_mask, 58 | image_name=cam_info.image_name, depth=depth, uid=id, data_device=args.data_device) 59 | 60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 61 | camera_list = [] 62 | 63 | for id, c in enumerate(cam_infos): 64 | camera_list.append(loadCam(args, id, c, resolution_scale)) 65 | 66 | return camera_list 67 | 68 | def camera_to_JSON(id, camera : Camera): 69 | Rt = np.zeros((4, 4)) 70 | Rt[:3, :3] = camera.R.transpose() 71 | Rt[:3, 3] = camera.T 72 | Rt[3, 3] = 1.0 73 | 74 | W2C = np.linalg.inv(Rt) 75 | pos = W2C[:3, 3] 76 | rot = W2C[:3, :3] 77 | serializable_array_2d = [x.tolist() for x in rot] 78 | camera_entry = { 79 | 'id' : id, 80 | 'img_name' : camera.image_name, 81 | 'width' : camera.width, 82 | 'height' : camera.height, 83 | 'position': pos.tolist(), 84 | 'rotation': serializable_array_2d, 85 | 'fy' : fov2focal(camera.FovY, camera.height), 86 | 'fx' : fov2focal(camera.FovX, camera.width) 87 | } 88 | return camera_entry 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Multi-scale Representation for Estimating Uncertainty in 3D Gaussian Splatting 2 | [Ruiqi Li](https://www.comp.hkbu.edu.hk/~csrqli/), [Yiu-ming Cheung](https://www.comp.hkbu.edu.hk/~ymc/)
3 | [Paper Link](https://proceedings.neurips.cc/paper_files/paper/2024/file/a076d0d1ed77364fc57693bdee1958fb-Paper-Conference.pdf) 4 | 5 | This repository contains the official open-source implementation of the paper "Variational Multi-scale Representation for Estimating Uncertainty in 3D Gaussian Splatting". We develop an uncertainty estimation method for Gaussian Splatting based on a Bayesian inference framework and a multi-scale representation. 6 | 7 | ## Citation 8 | If you find our work useful, please consider citing our paper: 9 | 10 | ``` 11 | @inproceedings{li2024variational, 12 | title={Variational Multi-scale Representation for Estimating Uncertainty in 3D Gaussian Splatting}, 13 | author={Li, Ruiqi and Cheung, Yiu-ming}, 14 | booktitle={Advances in Neural Information Processing Systems}, 15 | volume={37}, 16 | pages={87934--87958}, 17 | year={2024} 18 | } 19 | ``` 20 | 21 | ## Requirements 22 | 23 | **Hardware Requirements** 24 | 25 | CUDA-ready GPU with Compute Capability 7.0+ 26 | 27 | **Software Requirements** 28 | 29 | Conda (recommended for easy setup) 30 | 31 | C++ Compiler for PyTorch extensions 32 | 33 | CUDA SDK 11 for PyTorch extensions 34 | 35 | C++ Compiler and CUDA SDK must be compatible 36 | 37 | ## Usage 38 | 39 | ### Cloning the Repository 40 | 41 | Please clone with submodules: 42 | ```shell 43 | # SSH 44 | git clone git@github.com:csrqli/variational-3dgs.git --recursive 45 | ``` 46 | or 47 | ```shell 48 | # HTTPS 49 | git clone https://github.com/csrqli/variational-3dgs --recursive 50 | ``` 51 | 52 | ### Setup 53 | 54 | We provide a conda environment file to create the experiment environment: 55 | ```shell 56 | conda env create --file environment.yml 57 | conda activate variational_gs 58 | ``` 59 | We tested our code on Ubuntu; please refer to the original 3DGS repository for potential issues when building the environment or running on Windows. 60 | 61 | ### Preparing Dataset 62 | 63 | The LF dataset and LLFF dataset files are provided here: [LF dataset](https://drive.google.com/file/d/1RrfrMN5wSaishYJu5vYiTy6gUPZfLaDM/view?usp=sharing), [LLFF dataset](https://drive.google.com/file/d/1kDclWpEpUPm9Nw0tGoQTLWz3L4g5Hu2L/view?usp=sharing). 64 | 65 | Please unzip and put them under a dataset folder: 66 | 67 | ```bash 68 | variational-gs 69 | │ 70 | ├──dataset 71 | │ │ 72 | │ ├──── LF 73 | │ └──── nerf_llff_data 74 | ``` 75 | 76 | ### Running 77 | 78 | To train and evaluate the image quality and the image/depth uncertainty on the LF dataset: 79 | 80 | ```shell 81 | python train.py --eval --dataset_name LF -s ./dataset/LF/$scene_name --resolution 2 --iterations 3000 --densify_until_iter 2000 --model_path ./output/$scene_name 82 | ``` 83 | 84 | To get the averaged results: 85 | ``` 86 | python stat.py --dataset_name LF 87 | ``` 88 | 89 | To train and evaluate the image quality and image uncertainty on the LLFF dataset: 90 | 91 | ```shell 92 | python train.py --eval --dataset_name LLFF -s ./dataset/nerf_llff_data/$scene_name --resolution 8 --iterations 7000 --densify_until_iter 4000 --model_path ./output/$scene_name 93 | ``` 94 | 95 | and get the averaged results (a batch-run sketch covering all scenes is appended at the end of this README): 96 | ``` 97 | python stat.py --dataset_name LLFF 98 | ``` 99 | 100 | ## Funding and Acknowledgments 101 | 102 | This work was supported in part by the NSFC / Research Grants Council (RGC) Joint Research Scheme under the grant: N\_HKBU214/21, and the RGC Senior Research Fellow Scheme under the grant: SRFS2324-2S02.
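As referenced in the Running section above, here is a minimal batch-run sketch: it loops the per-scene training command over the scene folders and then calls `stat.py`, which averages the per-scene rows in `./output/eval_results_LF.csv`. The loop itself and the assumption that every sub-folder of `./dataset/LF` is one scene are illustrative only, not part of the released scripts.

```shell
# Sketch (assumption: each sub-folder of ./dataset/LF is one scene).
for scene_dir in ./dataset/LF/*/; do
    scene_name=$(basename "$scene_dir")
    python train.py --eval --dataset_name LF -s ./dataset/LF/$scene_name \
        --resolution 2 --iterations 3000 --densify_until_iter 2000 \
        --model_path ./output/$scene_name
done
python stat.py --dataset_name LF
```

The same pattern applies to the LLFF dataset by substituting the LLFF training command and `python stat.py --dataset_name LLFF`.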
103 | 104 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = False 55 | self.data_device = "cuda" 56 | self.eval = False 57 | self.dataset_name = "LF" 58 | 59 | self.spawn_interval = 1000 60 | self.spawn_percent_base = 0.01 61 | self.spawn_min_opacity = 0.0005 62 | 63 | 64 | super().__init__(parser, "Loading Parameters", sentinel) 65 | 66 | def extract(self, args): 67 | g = super().extract(args) 68 | g.source_path = os.path.abspath(g.source_path) 69 | return g 70 | 71 | class PipelineParams(ParamGroup): 72 | def __init__(self, parser): 73 | self.convert_SHs_python = False 74 | self.compute_cov3D_python = False 75 | self.debug = False 76 | super().__init__(parser, "Pipeline Parameters") 77 | 78 | class OptimizationParams(ParamGroup): 79 | def __init__(self, parser): 80 | self.iterations = 30_000 81 | self.position_lr_init = 0.00016 82 | self.position_lr_final = 0.0000016 83 | self.position_lr_delay_mult = 0.01 84 | self.position_lr_max_steps = 30_000 85 | self.feature_lr = 0.0025 86 | self.opacity_lr = 0.05 87 | self.scaling_lr = 0.005 88 | self.rotation_lr = 0.001 89 | self.percent_dense = 0.01 90 | self.lambda_dssim = 0.2 91 | self.densification_interval = 100 92 | self.opacity_reset_interval = 30000 93 | self.densify_from_iter = 100 94 | self.densify_until_iter = 15_000 95 | self.densify_grad_threshold = 0.0002 96 | self.random_background = False 97 | super().__init__(parser, "Optimization Parameters") 98 | 99 | def get_combined_args(parser : ArgumentParser): 100 | cmdlne_string = sys.argv[1:] 101 | cfgfile_string = "Namespace()" 102 | args_cmdline = parser.parse_args(cmdlne_string) 103 | 104 | try: 105 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 106 | print("Looking for config file in", cfgfilepath) 
107 | with open(cfgfilepath) as cfg_file: 108 | print("Config file found: {}".format(cfgfilepath)) 109 | cfgfile_string = cfg_file.read() 110 | except TypeError: 111 | print("Config file not found at") 112 | pass 113 | args_cfgfile = eval(cfgfile_string) 114 | 115 | merged_dict = vars(args_cfgfile).copy() 116 | for k,v in vars(args_cmdline).items(): 117 | if v != None: 118 | merged_dict[k] = v 119 | return Namespace(**merged_dict) 120 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], depth_scale=1.0): 26 | """b 27 | :param path: Path to colmap scene main folder. 
28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | self.depth_scale = depth_scale 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | 45 | if args.dataset_name == "LF": 46 | scene_info = sceneLoadTypeCallbacks["LF"](args.source_path, args.images, args.eval) 47 | self.dataset_name = "LF" 48 | elif os.path.exists(os.path.join(args.source_path, "sparse")): 49 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 50 | self.dataset_name = "colmap" 51 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 52 | print("Found transforms_train.json file, assuming Blender data set!") 53 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 54 | self.dataset_name = "blender" 55 | else: 56 | assert False, "Could not recognize scene type!" 57 | self.depth_scale = scene_info.depth_scale 58 | 59 | if not self.loaded_iter: 60 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 61 | dest_file.write(src_file.read()) 62 | json_cams = [] 63 | camlist = [] 64 | if scene_info.test_cameras: 65 | camlist.extend(scene_info.test_cameras) 66 | if scene_info.train_cameras: 67 | camlist.extend(scene_info.train_cameras) 68 | for id, cam in enumerate(camlist): 69 | json_cams.append(camera_to_JSON(id, cam)) 70 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 71 | json.dump(json_cams, file) 72 | 73 | if shuffle: 74 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 75 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 76 | 77 | self.cameras_extent = scene_info.nerf_normalization["radius"] 78 | 79 | for resolution_scale in resolution_scales: 80 | print("Loading Training Cameras") 81 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 82 | print("Loading Test Cameras") 83 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 84 | 85 | if self.loaded_iter: 86 | self.gaussians.load_ply(os.path.join(self.model_path, 87 | "point_cloud", 88 | "iteration_" + str(self.loaded_iter), 89 | "point_cloud.ply")) 90 | else: 91 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 92 | 93 | def save(self, iteration): 94 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 95 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 96 | 97 | def getTrainCameras(self, scale=1.0): 98 | return self.train_cameras[scale] 99 | 100 | def getTestCameras(self, scale=1.0): 101 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. 
Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from utils.sh_utils import eval_sh 16 | 17 | #def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp = False): 18 | def render(viewpoint_camera, pc, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if pipe.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if pipe.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii, depth_image = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 96 | # They will be excluded from value updates used in the splitting criteria. 
97 | return {"render": rendered_image, 98 | "viewspace_points": screenspace_points, 99 | "visibility_filter" : radii > 0, 100 | "radii": radii, 101 | "depth" : depth_image, 102 | } 103 | 104 | def forward_k_times(viewpoint_camera, pc, pipe, bg_color, scaling_modifier = 1.0, override_color = None, k=10): 105 | rgbs = [] 106 | depths = [] 107 | 108 | for model_id in range(pc.n_models): 109 | pc.model_id = model_id 110 | out = render(viewpoint_camera, pc, pipe, bg_color, scaling_modifier = 1.0, override_color = None) 111 | rgb = out['render'] 112 | depth = out['depth'] 113 | depths.append(depth) 114 | rgbs.append(rgb) 115 | 116 | rgbs = torch.stack(rgbs, dim=0) 117 | depths = torch.stack(depths, dim=0) 118 | depth_mean = depths.mean(dim=0) 119 | depth_var = depths.var(dim=0) 120 | 121 | std = rgbs.std(dim=0) 122 | var = rgbs.var(dim=0) 123 | 124 | mean = rgbs.mean(dim=0) 125 | 126 | return {'comp_rgb': mean, 127 | 'comp_rgbs': rgbs, 128 | 'comp_var': var, 129 | 'comp_std': std, 130 | 'depths': depths, 131 | 'depth_var': depth_var, 132 | 'depth_mean': depth_mean, 133 | } 134 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. 
Exiting.")
44 |         exit(exit_code)
45 |
46 |     ## Feature matching
47 |     feat_matching_cmd = colmap_command + " exhaustive_matcher \
48 |         --database_path " + args.source_path + "/distorted/database.db \
49 |         --SiftMatching.use_gpu " + str(use_gpu)
50 |     exit_code = os.system(feat_matching_cmd)
51 |     if exit_code != 0:
52 |         logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
53 |         exit(exit_code)
54 |
55 |     ### Bundle adjustment
56 |     # The default Mapper tolerance is unnecessarily large,
57 |     # decreasing it speeds up bundle adjustment steps.
58 |     mapper_cmd = (colmap_command + " mapper \
59 |         --database_path " + args.source_path + "/distorted/database.db \
60 |         --image_path " + args.source_path + "/input \
61 |         --output_path " + args.source_path + "/distorted/sparse \
62 |         --Mapper.ba_global_function_tolerance=0.000001")
63 |     exit_code = os.system(mapper_cmd)
64 |     if exit_code != 0:
65 |         logging.error(f"Mapper failed with code {exit_code}. Exiting.")
66 |         exit(exit_code)
67 |
68 | ### Image undistortion
69 | ## We need to undistort our images into ideal pinhole intrinsics.
70 | img_undist_cmd = (colmap_command + " image_undistorter \
71 |     --image_path " + args.source_path + "/input \
72 |     --input_path " + args.source_path + "/distorted/sparse/0 \
73 |     --output_path " + args.source_path + "\
74 |     --output_type COLMAP")
75 | exit_code = os.system(img_undist_cmd)
76 | if exit_code != 0:
77 |     logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
78 |     exit(exit_code)
79 |
80 | files = os.listdir(args.source_path + "/sparse")
81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
82 | # Move each file from the source directory to the destination directory
83 | for file in files:
84 |     if file == '0':
85 |         continue
86 |     source_file = os.path.join(args.source_path, "sparse", file)
87 |     destination_file = os.path.join(args.source_path, "sparse", "0", file)
88 |     shutil.move(source_file, destination_file)
89 |
90 | if(args.resize):
91 |     print("Copying and resizing...")
92 |
93 |     # Resize images.
94 |     os.makedirs(args.source_path + "/images_2", exist_ok=True)
95 |     os.makedirs(args.source_path + "/images_4", exist_ok=True)
96 |     os.makedirs(args.source_path + "/images_8", exist_ok=True)
97 |     # Get the list of files in the source directory
98 |     files = os.listdir(args.source_path + "/images")
99 |     # Copy each file from the source directory to the destination directory
100 |     for file in files:
101 |         source_file = os.path.join(args.source_path, "images", file)
102 |
103 |         destination_file = os.path.join(args.source_path, "images_2", file)
104 |         shutil.copy2(source_file, destination_file)
105 |         exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
106 |         if exit_code != 0:
107 |             logging.error(f"50% resize failed with code {exit_code}. Exiting.")
108 |             exit(exit_code)
109 |
110 |         destination_file = os.path.join(args.source_path, "images_4", file)
111 |         shutil.copy2(source_file, destination_file)
112 |         exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
113 |         if exit_code != 0:
114 |             logging.error(f"25% resize failed with code {exit_code}. Exiting.")
115 |             exit(exit_code)
116 |
117 |         destination_file = os.path.join(args.source_path, "images_8", file)
118 |         shutil.copy2(source_file, destination_file)
119 |         exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
120 |         if exit_code != 0:
121 |             logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
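        These are standard ``struct`` format characters (e.g. 'Q' is an 8-byte
        unsigned integer, 'd' an 8-byte double, 'i' a 4-byte signed integer), so
        num_bytes must equal the total size described by the sequence.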
77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 
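    # Layout of images.bin, as consumed below: each registered image starts with a
    # fixed 64-byte record (format "idddddddi": image_id, the 4 quaternion
    # components, the 3 translation components, camera_id), followed by a
    # null-terminated image name, the number of 2D points, and that many
    # (x, y, point3D_id) triples in format "ddq".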
187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
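    (The file begins with an ASCII header "width&height&channels&" followed by the
    raw float32 values; the code below locates the third '&' and then reads the
    remainder with np.fromfile.)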
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | from glob import glob 26 | import torchvision.transforms.functional as TF 27 | 28 | class CameraInfo: 29 | def __init__(self, uid, R, T, FovY, FovX, image, image_path, image_name, width, height, depth=None): 30 | self.uid = uid 31 | self.R = R 32 | self.T = T 33 | self.FovY = FovY 34 | self.FovX = FovX 35 | self.image = image 36 | self.image_path = image_path 37 | self.image_name = image_name 38 | self.width = width 39 | self.height = height 40 | self.depth = depth 41 | 42 | class CameraInfoOld(NamedTuple): 43 | R: np.array 44 | T: np.array 45 | FovY: np.array 46 | FovX: np.array 47 | image: np.array 48 | image_path: str 49 | image_name: str 50 | width: int 51 | height: int 52 | depth_path: str 53 | 54 | class SceneInfo(NamedTuple): 55 | point_cloud: BasicPointCloud 56 | train_cameras: list 57 | test_cameras: list 58 | nerf_normalization: dict 59 | ply_path: str 60 | depth_scale: float = 1.0 61 | 62 | def getNerfppNorm(cam_info): 63 | def get_center_and_diag(cam_centers): 64 | cam_centers = np.hstack(cam_centers) 65 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 66 | center = avg_cam_center 67 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 68 | diagonal = np.max(dist) 69 | return center.flatten(), diagonal 70 | 71 | cam_centers = [] 72 | 73 | for cam in cam_info: 74 | W2C = getWorld2View2(cam.R, cam.T) 75 | C2W = np.linalg.inv(W2C) 76 | cam_centers.append(C2W[:3, 3:4]) 77 | 78 | center, diagonal = get_center_and_diag(cam_centers) 79 | radius = diagonal * 1.1 80 | 81 | translate = -center 82 | 83 | return {"translate": translate, "radius": radius} 84 | 85 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 86 | cam_infos = [] 87 | for idx, key in enumerate(cam_extrinsics): 88 | sys.stdout.write('\r') 89 | # the exact output you're looking for: 90 | sys.stdout.write("Reading camera 
{}/{}".format(idx+1, len(cam_extrinsics))) 91 | sys.stdout.flush() 92 | 93 | extr = cam_extrinsics[key] 94 | intr = cam_intrinsics[extr.camera_id] 95 | height = intr.height 96 | width = intr.width 97 | 98 | uid = intr.id 99 | R = np.transpose(qvec2rotmat(extr.qvec)) 100 | T = np.array(extr.tvec) 101 | 102 | if intr.model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]: 103 | focal_length_x = intr.params[0] 104 | FovY = focal2fov(focal_length_x, height) 105 | FovX = focal2fov(focal_length_x, width) 106 | elif intr.model in ['PINHOLE', 'OPENCV']: 107 | focal_length_x = intr.params[0] 108 | focal_length_y = intr.params[1] 109 | FovY = focal2fov(focal_length_y, height) 110 | FovX = focal2fov(focal_length_x, width) 111 | else: 112 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 113 | 114 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 115 | image_name = os.path.basename(image_path).split(".")[0] 116 | image = Image.open(image_path) 117 | 118 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 119 | image_path=image_path, image_name=image_name, width=width, height=height) 120 | cam_infos.append(cam_info) 121 | sys.stdout.write('\n') 122 | return cam_infos 123 | 124 | def fetchPly(path): 125 | plydata = PlyData.read(path) 126 | vertices = plydata['vertex'] 127 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 128 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 129 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 130 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 131 | 132 | def storePly(path, xyz, rgb): 133 | # Define the dtype for the structured array 134 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 135 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 136 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 137 | 138 | normals = np.zeros_like(xyz) 139 | 140 | elements = np.empty(xyz.shape[0], dtype=dtype) 141 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 142 | elements[:] = list(map(tuple, attributes)) 143 | 144 | # Create the PlyData object and write to file 145 | vertex_element = PlyElement.describe(elements, 'vertex') 146 | ply_data = PlyData([vertex_element]) 147 | ply_data.write(path) 148 | 149 | def readColmapSceneInfo(path, images, eval, llffhold=8): 150 | try: 151 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 152 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 153 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 154 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 155 | except: 156 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 157 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 158 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 159 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 160 | 161 | reading_dir = "images" if images == None else images 162 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 163 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 164 | 165 | if eval: 166 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 167 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 168 | else: 169 | 
train_cam_infos = cam_infos 170 | test_cam_infos = [] 171 | 172 | nerf_normalization = getNerfppNorm(train_cam_infos) 173 | 174 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 175 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 176 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 177 | if not os.path.exists(ply_path): 178 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 179 | try: 180 | xyz, rgb, _ = read_points3D_binary(bin_path) 181 | except: 182 | xyz, rgb, _ = read_points3D_text(txt_path) 183 | storePly(ply_path, xyz, rgb) 184 | try: 185 | pcd = fetchPly(ply_path) 186 | except: 187 | pcd = None 188 | 189 | scene_info = SceneInfo(point_cloud=pcd, 190 | train_cameras=train_cam_infos, 191 | test_cameras=test_cam_infos, 192 | nerf_normalization=nerf_normalization, 193 | ply_path=ply_path) 194 | return scene_info 195 | 196 | def readLFSceneInfo(path, images, eval, llffhold=8): 197 | scene_name = os.path.basename(path) 198 | 199 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 200 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 201 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 202 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 203 | 204 | reading_dir = "images" if images == None else images 205 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 206 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 207 | 208 | #train/test split from CF-NeRF https://github.com/poetrywanderer/CFNeRF/blob/66918a9748c137e1c0242c12be7aa6efa39ece06/run_nerf_uncertainty_NF.py#L750 209 | if scene_name == 'basket': 210 | i_train = list(np.arange(43,50,2)) 211 | i_val = list(np.arange(42,50,2)) 212 | depth_scale = 1 / 8 213 | 214 | elif scene_name == 'africa': 215 | i_train = list(np.arange(5,14,2)) 216 | i_val = list(np.arange(6,14,2)) 217 | depth_scale = 4 218 | 219 | elif scene_name == 'statue': 220 | i_train = list(np.arange(67,76,2)) 221 | i_val = list(np.arange(68,76,2)) 222 | depth_scale = 1.25 223 | 224 | elif scene_name == 'torch': 225 | i_train = list(np.arange(8,17,2)) 226 | i_val = list(np.arange(9,17,2)) 227 | depth_scale = 15 228 | 229 | if eval: 230 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx in i_train] 231 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx in i_val] 232 | else: 233 | train_cam_infos = cam_infos 234 | test_cam_infos = [] 235 | 236 | all_depth = [] 237 | depth_files = glob(os.path.join(path, f'depth_*.npy')) 238 | depth_files = sorted(depth_files) 239 | for i in range(4): 240 | depth = np.ascontiguousarray(np.load(depth_files[i])) 241 | depth = TF.to_tensor(depth).cuda() 242 | test_cam_infos[i].depth = depth 243 | 244 | nerf_normalization = getNerfppNorm(train_cam_infos) 245 | 246 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 247 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 248 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 249 | if not os.path.exists(ply_path): 250 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 251 | try: 252 | xyz, rgb, _ = read_points3D_binary(bin_path) 253 | except: 254 | xyz, rgb, _ = read_points3D_text(txt_path) 255 | storePly(ply_path, xyz, rgb) 256 | try: 257 | pcd = fetchPly(ply_path) 258 | except: 259 | pcd = None 260 | 261 | scene_info = SceneInfo(point_cloud=pcd, 
262 | train_cameras=train_cam_infos, 263 | test_cameras=test_cam_infos, 264 | nerf_normalization=nerf_normalization, 265 | ply_path=ply_path, 266 | depth_scale=depth_scale) 267 | return scene_info 268 | 269 | 270 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 271 | cam_infos = [] 272 | 273 | with open(os.path.join(path, transformsfile)) as json_file: 274 | contents = json.load(json_file) 275 | fovx = contents["camera_angle_x"] 276 | 277 | frames = contents["frames"] 278 | for idx, frame in enumerate(frames): 279 | cam_name = os.path.join(path, frame["file_path"] + extension) 280 | 281 | # NeRF 'transform_matrix' is a camera-to-world transform 282 | c2w = np.array(frame["transform_matrix"]) 283 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 284 | c2w[:3, 1:3] *= -1 285 | 286 | # get the world-to-camera transform and set R, T 287 | w2c = np.linalg.inv(c2w) 288 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 289 | T = w2c[:3, 3] 290 | 291 | image_path = os.path.join(path, cam_name) 292 | image_name = Path(cam_name).stem 293 | image = Image.open(image_path) 294 | 295 | im_data = np.array(image.convert("RGBA")) 296 | 297 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 298 | 299 | norm_data = im_data / 255.0 300 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 301 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 302 | 303 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 304 | FovY = fovy 305 | FovX = fovx 306 | 307 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 308 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 309 | 310 | return cam_infos 311 | 312 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 313 | print("Reading Training Transforms") 314 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 315 | print("Reading Test Transforms") 316 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 317 | 318 | if not eval: 319 | train_cam_infos.extend(test_cam_infos) 320 | test_cam_infos = [] 321 | 322 | nerf_normalization = getNerfppNorm(train_cam_infos) 323 | 324 | ply_path = os.path.join(path, "points3d.ply") 325 | if not os.path.exists(ply_path): 326 | # Since this data set has no colmap data, we start with random points 327 | num_pts = 100_000 328 | print(f"Generating random point cloud ({num_pts})...") 329 | 330 | # We create random points inside the bounds of the synthetic Blender scenes 331 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 332 | shs = np.random.random((num_pts, 3)) / 255.0 333 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 334 | 335 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 336 | try: 337 | pcd = fetchPly(ply_path) 338 | except: 339 | pcd = None 340 | 341 | scene_info = SceneInfo(point_cloud=pcd, 342 | train_cameras=train_cam_infos, 343 | test_cameras=test_cam_infos, 344 | nerf_normalization=nerf_normalization, 345 | ply_path=ply_path) 346 | return scene_info 347 | 348 | sceneLoadTypeCallbacks = { 349 | "Colmap": readColmapSceneInfo, 350 | "Blender" : readNerfSyntheticInfo, 351 | "LF" : readLFSceneInfo 352 | } -------------------------------------------------------------------------------- /train.py: 
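(A quick orientation before the training script: train.py is driven entirely from the command line. Its parser, set up at the bottom of the file, combines the ModelParams, OptimizationParams and PipelineParams groups from arguments/__init__.py with the script-level flags --ip, --port, --debug_from, --detect_anomaly, --test_iterations, --save_iterations, --quiet, --checkpoint_iterations and --start_checkpoint. A typical invocation might look like the line below; the -s/--source_path and --eval flags are assumed to come from ModelParams, as in the original 3D Gaussian Splatting code, and are not visible in this excerpt.)

    python train.py -s <path_to_scene> --eval --test_iterations 3000 7000 --save_iterations 7000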
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | import csv 26 | from lpipsPyTorch import lpips 27 | 28 | if False: 29 | from torch.utils.tensorboard import SummaryWriter 30 | TENSORBOARD_FOUND = True 31 | else: 32 | TENSORBOARD_FOUND = False 33 | 34 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 35 | opt.position_lr_max_steps = opt.iterations 36 | 37 | first_iter = 0 38 | tb_writer = prepare_output_and_logger(dataset) 39 | gaussians = GaussianModel(dataset) 40 | scene = Scene(dataset, gaussians) 41 | gaussians.training_setup(opt) 42 | 43 | if checkpoint: 44 | (model_params, first_iter) = torch.load(checkpoint) 45 | gaussians.restore(model_params, opt) 46 | 47 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 48 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 49 | 50 | iter_start = torch.cuda.Event(enable_timing = True) 51 | iter_end = torch.cuda.Event(enable_timing = True) 52 | 53 | viewpoint_stack = None 54 | ema_loss_for_log = 0.0 55 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 56 | first_iter += 1 57 | 58 | for iteration in range(first_iter, opt.iterations + 1): 59 | if network_gui.conn == None: 60 | network_gui.try_connect() 61 | while network_gui.conn != None: 62 | try: 63 | net_image_bytes = None 64 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 65 | if custom_cam != None: 66 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 67 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 68 | network_gui.send(net_image_bytes, dataset.source_path) 69 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 70 | break 71 | except Exception as e: 72 | network_gui.conn = None 73 | 74 | iter_start.record() 75 | 76 | 77 | # Pick a random Camera 78 | if not viewpoint_stack: 79 | viewpoint_stack = scene.getTrainCameras().copy() 80 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 81 | 82 | # Render 83 | if (iteration - 1) == debug_from: 84 | pipe.debug = True 85 | 86 | bg = torch.rand((3), device="cuda") if opt.random_background else background 87 | model_id = torch.randint(0, gaussians.n_models, (1,)).item() 88 | gaussians.model_id = model_id 89 | 90 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 91 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], 
render_pkg["radii"] 92 | 93 | # Loss 94 | gt_image = viewpoint_cam.original_image.cuda() 95 | Ll1 = l1_loss(image, gt_image) 96 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 97 | 98 | loss_kl_scal = gaussians.compute_kl_uniform_scal() 99 | loss_kl_xyz = gaussians.compute_kl_xyz() 100 | loss_kl_opacity = gaussians.compute_kl_opacity() 101 | 102 | loss += 1.0*(loss_kl_scal + loss_kl_xyz + loss_kl_opacity) 103 | 104 | gaussians.update_learning_rate(iteration) 105 | 106 | # Every 1000 its we increase the levels of SH up to a maximum degree 107 | if iteration % 1000 == 0: 108 | gaussians.oneupSHdegree() 109 | 110 | 111 | loss.backward() 112 | 113 | iter_end.record() 114 | 115 | with torch.no_grad(): 116 | # Progress bar 117 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 118 | if iteration % 10 == 0: 119 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 120 | progress_bar.update(10) 121 | if iteration == opt.iterations: 122 | progress_bar.close() 123 | 124 | spawn_interval = dataset.spawn_interval 125 | 126 | # Log and save 127 | training_report(dataset, tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 128 | if (iteration in saving_iterations): 129 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 130 | scene.save(iteration) 131 | 132 | # Densification 133 | if iteration < opt.densify_until_iter: 134 | # Keep track of max radii in image-space for pruning 135 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 136 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 137 | 138 | if iteration % spawn_interval == 0: # spawn interval should be a multiple of densification interval 139 | gaussians.spawn(scene.cameras_extent) 140 | 141 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 142 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 143 | 144 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 145 | 146 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 147 | gaussians.reset_opacity() 148 | 149 | # Optimizer step 150 | if iteration < opt.iterations: 151 | gaussians.optimizer.step() 152 | gaussians.optimizer.zero_grad(set_to_none=True) 153 | 154 | if (iteration in checkpoint_iterations): 155 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 156 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 157 | 158 | 159 | def prepare_output_and_logger(args): 160 | if not args.model_path: 161 | if os.getenv('OAR_JOB_ID'): 162 | unique_str=os.getenv('OAR_JOB_ID') 163 | else: 164 | unique_str = str(uuid.uuid4()) 165 | args.model_path = os.path.join("./output/", unique_str[0:10]) 166 | 167 | # Set up output folder 168 | print("Output folder: {}".format(args.model_path)) 169 | os.makedirs(args.model_path, exist_ok = True) 170 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 171 | cfg_log_f.write(str(Namespace(**vars(args)))) 172 | 173 | # Create Tensorboard writer 174 | tb_writer = None 175 | if TENSORBOARD_FOUND: 176 | tb_writer = SummaryWriter(args.model_path) 177 | else: 178 | print("Tensorboard not available: not logging progress") 179 | return tb_writer 180 | 181 | import torchvision 
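
# --------------------------------------------------------------------------------
# The total loss assembled in training() above is the photometric term
# (L1 + D-SSIM) plus three prior-matching terms from scene/gaussian_model.py:
# compute_kl_xyz and compute_kl_opacity use the closed-form KL divergence between
# diagonal Gaussians, while compute_kl_uniform_scal penalises the distance between
# the endpoints of the learned uniform interval and those of the prior interval.
# The helper below is an illustrative restatement of the Gaussian term only; it is
# not called anywhere in this repository.
def kl_diag_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), averaged over elements."""
    return (torch.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5).mean()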
182 | 183 | def training_report(dataset, tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 184 | if tb_writer: 185 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 186 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 187 | tb_writer.add_scalar('iter_time', elapsed, iteration) 188 | 189 | # Report test and samples of training set 190 | if iteration in testing_iterations: 191 | torch.cuda.empty_cache() 192 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 193 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 194 | 195 | for config in validation_configs: 196 | if config['cameras'] and len(config['cameras']) > 0: 197 | l1_test = 0.0 198 | psnr_test = 0.0 199 | for idx, viewpoint in enumerate(config['cameras']): 200 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 201 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 202 | if tb_writer and (idx < 5): 203 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 204 | if iteration == testing_iterations[0]: 205 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 206 | l1_test += l1_loss(image, gt_image).mean().double() 207 | psnr_test += psnr(image, gt_image).mean().double() 208 | 209 | psnr_test /= len(config['cameras']) 210 | l1_test /= len(config['cameras']) 211 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 212 | if tb_writer: 213 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 214 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 215 | 216 | if tb_writer: 217 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 218 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 219 | torch.cuda.empty_cache() 220 | if iteration == testing_iterations[-1]: 221 | render_set(dataset, scene, renderArgs[0]) 222 | 223 | from utils.image_utils import psnr, nll_kernel_density, ause_br 224 | from gaussian_renderer import render, forward_k_times 225 | from os import makedirs 226 | 227 | 228 | def render_set(dataset, scene, pipeline): 229 | gaussians, views = scene.gaussians, scene.getTestCameras() 230 | 231 | bg_color = [0, 0, 0] 232 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 233 | 234 | psnr_all, ssim_all, lpips_all, ause_mae_all, mean_nll_all, depth_ause_mae_all = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 235 | eval_depth = True if dataset.dataset_name == "LF" else False 236 | 237 | scene_name = scene.model_path.split("/")[-1] 238 | 239 | render_path = f"{scene.model_path}/test/ours_7000/renders" 240 | gts_path = f"{scene.model_path}/test/ours_7000/gt" 241 | unc_path = f"{scene.model_path}/test/ours_7000/unc" 242 | 243 | makedirs(render_path, exist_ok=True) 244 | makedirs(gts_path, exist_ok=True) 245 | makedirs(unc_path, exist_ok=True) 246 | 247 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 248 | 249 | gt = view.original_image[0:3, :, :] 250 | out = forward_k_times(view, gaussians, pipeline, background) 251 | mean = out['comp_rgb'].detach() 252 | rgbs = out['comp_rgbs'].detach() 253 | std 
= out['comp_std'].detach() 254 | depths = out['depths'].detach() 255 | 256 | mae = ((mean - gt)).abs() 257 | 258 | ause_mae, ause_err_mae, ause_err_by_var_mae = ause_br(std.reshape(-1), mae.reshape(-1), err_type='mae') 259 | mean_nll = nll_kernel_density(rgbs.permute(1,2,3,0), std, gt) 260 | 261 | psnr_all += psnr(mean, gt).mean().item() 262 | ssim_all += ssim(mean, gt).mean().item() 263 | lpips_all += lpips(mean, gt, net_type="vgg").mean().item() 264 | 265 | ause_mae_all += ause_mae.item() 266 | mean_nll_all += mean_nll.item() 267 | 268 | if eval_depth: 269 | depths = depths * scene.depth_scale 270 | 271 | depth = depths.mean(dim=0) 272 | depth_std = depths.std(dim=0) 273 | depth_gt = view.depth 274 | 275 | depth_mae = ((depth - depth_gt)).abs() 276 | depth_ause_mae, depth_ause_err_mae, depth_ause_err_by_var_mae = ause_br(depth_std.reshape(-1), depth_mae.reshape(-1), err_type='mae') 277 | depth_ause_mae_all += depth_ause_mae 278 | 279 | 280 | unc_vis_multiply = 10 281 | torchvision.utils.save_image(mean, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 282 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 283 | torchvision.utils.save_image(unc_vis_multiply*std, os.path.join(unc_path, '{0:05d}'.format(idx) + ".png")) 284 | 285 | 286 | psnr_all /= len(views) 287 | ause_mae_all /= len(views) 288 | mean_nll_all /= len(views) 289 | ssim_all /= len(views) 290 | lpips_all /= len(views) 291 | 292 | depth_ause_mae_all /= len(views) 293 | 294 | csv_file = f"output/eval_results_{dataset.dataset_name}.csv" 295 | with open(csv_file, mode='a', newline='') as file: 296 | writer = csv.writer(file) 297 | 298 | if eval_depth: 299 | results = f"\nEvaluation Results: PSNR {psnr_all} SSIM {ssim_all} LPIPS {lpips_all} AUSE {ause_mae_all} NLL {mean_nll_all} Depth AUSE {depth_ause_mae_all}" 300 | print(results) 301 | writer.writerow([dataset.dataset_name, scene_name, psnr_all, ssim_all, lpips_all, ause_mae_all, mean_nll_all, depth_ause_mae_all]) 302 | else: 303 | results = f"\nEvaluation Results: PSNR {psnr_all} SSIM {ssim_all} LPIPS {lpips_all} AUSE {ause_mae_all} NLL {mean_nll_all}" 304 | print(results) 305 | writer.writerow([dataset.dataset_name, scene_name, psnr_all, ssim_all, lpips_all, ause_mae_all, mean_nll_all]) 306 | 307 | if __name__ == "__main__": 308 | # Set up command line argument parser 309 | parser = ArgumentParser(description="Training script parameters") 310 | lp = ModelParams(parser) 311 | op = OptimizationParams(parser) 312 | pp = PipelineParams(parser) 313 | parser.add_argument('--ip', type=str, default="127.0.0.1") 314 | parser.add_argument('--port', type=int, default=6009) 315 | parser.add_argument('--debug_from', type=int, default=-1) 316 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 317 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[3_000, 7_000, 30_000]) 318 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[3_000, 7_000, 30_000]) 319 | parser.add_argument("--quiet", action="store_true") 320 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 321 | parser.add_argument("--start_checkpoint", type=str, default = None) 322 | args = parser.parse_args(sys.argv[1:]) 323 | 324 | args.test_iterations.append(args.iterations) 325 | args.save_iterations.append(args.iterations) 326 | 327 | print("Optimizing " + args.model_path) 328 | 329 | # Initialize system state (RNG) 330 | safe_state(args.quiet) 331 | 332 | # Start GUI server, configure 
and run training 333 | #network_gui.init(args.ip, args.port) 334 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 335 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 336 | 337 | # All done 338 | print("\nTraining complete.") 339 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH, SH2RGB 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation 23 | 24 | class GaussianModel: 25 | def setup_functions(self): 26 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 27 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 28 | actual_covariance = L @ L.transpose(1, 2) 29 | symm = strip_symmetric(actual_covariance) 30 | return symm 31 | 32 | self.scaling_activation = torch.exp 33 | self.scaling_inverse_activation = torch.log 34 | 35 | self.covariance_activation = build_covariance_from_scaling_rotation 36 | 37 | self.opacity_activation = torch.sigmoid 38 | self.inverse_opacity_activation = inverse_sigmoid 39 | 40 | self.rotation_activation = torch.nn.functional.normalize 41 | 42 | def __init__(self, dataset): 43 | self.active_sh_degree = 0 44 | self.max_sh_degree = dataset.sh_degree 45 | self._xyz = torch.empty(0) 46 | self._features_dc = torch.empty(0) 47 | self._features_rest = torch.empty(0) 48 | self._scaling = torch.empty(0) 49 | self._rotation = torch.empty(0) 50 | self._opacity = torch.empty(0) 51 | self.max_radii2D = torch.empty(0) 52 | self.xyz_gradient_accum = torch.empty(0) 53 | self.denom = torch.empty(0) 54 | self.optimizer = None 55 | self.percent_dense = 0 56 | self.spatial_lr_scalar = 0 57 | self.tmp = 0.2 58 | self.iteration = 0 59 | self._feat_unc = torch.empty(0) 60 | self.model_id = 0 61 | 62 | self.n_models = 10 63 | self.pri_std = -6 64 | self.pri_width = 0.1 65 | self.M = 2 66 | 67 | self.pri_opacity_std = 1.85 68 | self.pri_opacity_mean = 2 69 | 70 | self.lr_scales = [0.1, 0.1, 0.1] 71 | 72 | self.spawn_percent_base = dataset.spawn_percent_base 73 | self.spawn_min_opacity = dataset.spawn_min_opacity 74 | 75 | self.setup_functions() 76 | 77 | def capture(self): 78 | return ( 79 | self.active_sh_degree, 80 | self._xyz, 81 | self._features_dc, 82 | self._features_rest, 83 | self._scaling, 84 | self._rotation, 85 | self._opacity, 86 | self.max_radii2D, 87 | self.xyz_gradient_accum, 88 | self.denom, 89 | self.optimizer.state_dict(), 90 | self.spatial_lr_scalar, 91 | ) 92 | 93 | def restore(self, model_args, training_args): 94 | (self.active_sh_degree, 95 | self._xyz, 96 | self._features_dc, 97 | self._features_rest, 98 | self._scaling, 99 
| self._rotation, 100 | self._opacity, 101 | self.max_radii2D, 102 | xyz_gradient_accum, 103 | denom, 104 | opt_dict, 105 | self.spatial_lr_scalar) = model_args 106 | self.training_setup(training_args) 107 | self.xyz_gradient_accum = xyz_gradient_accum 108 | self.denom = denom 109 | self.optimizer.load_state_dict(opt_dict) 110 | 111 | 112 | @property 113 | def get_scaling(self): 114 | scale = self.compute_scal() 115 | return scale 116 | 117 | def compute_scal(self): 118 | scal = self._scaling 119 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 120 | 121 | width = self.offsets["_scaling_offset"][...,sample_model_ids].mean(dim=-1) 122 | width = torch.nn.functional.softplus(width).clamp_(1e-2, 1e2) 123 | left = self.offsets["_scaling_offset"][...,sample_model_ids+self.n_models].mean(dim=-1) 124 | 125 | offset_scal = (width) * torch.rand_like(scal) + left 126 | offset_scal = self.mr_list*offset_scal+(1-self.mr_list)* torch.ones_like(offset_scal).cuda().requires_grad_(True) 127 | 128 | scal = scal * offset_scal 129 | scal = self.scaling_activation(scal) 130 | return scal 131 | 132 | def compute_kl_uniform_scal(self): 133 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 134 | width = self.offsets["_scaling_offset"][...,sample_model_ids].mean(dim=-1) 135 | width = torch.nn.functional.softplus(width).clamp_(1e-2, 1e2) 136 | left = self.offsets["_scaling_offset"][...,sample_model_ids+self.n_models].mean(dim=-1) 137 | 138 | pri_left = 1-self.pri_width 139 | 140 | right = left + width 141 | prior_right = 1 142 | 143 | kl = torch.abs(left - pri_left) + torch.abs(right - prior_right) 144 | 145 | return kl.mean() 146 | 147 | @property 148 | def get_rotation(self): 149 | return self.compute_rotation() 150 | def compute_rotation(self): 151 | r = self.rotation_activation(self._rotation) 152 | 153 | return r 154 | 155 | @property 156 | def get_xyz(self): 157 | xyz = self.compute_xyz() 158 | return xyz 159 | def compute_xyz(self): 160 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 161 | xyz = self._xyz 162 | 163 | std = self.offsets["_xyz_offset"][..., sample_model_ids].mean(dim=-1) 164 | std = torch.nn.functional.softplus(std) 165 | 166 | mean = self.offsets["_xyz_offset"][..., sample_model_ids+self.n_models].mean(dim=-1) 167 | 168 | offset = torch.randn_like(xyz).cuda().requires_grad_(True) 169 | offset = offset*std+mean 170 | 171 | xyz = xyz + self.mr_list*offset 172 | return xyz 173 | 174 | def compute_kl_xyz(self): 175 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 176 | std = self.offsets["_xyz_offset"][..., sample_model_ids].mean(dim=-1) 177 | std = torch.nn.functional.softplus(std) 178 | mean = self.offsets["_xyz_offset"][..., sample_model_ids+self.n_models].mean(dim=-1) 179 | 180 | pri_std = torch.nn.functional.softplus(torch.tensor(self.pri_std).float()).item() 181 | pri_mean, pri_std = torch.zeros_like(mean), pri_std*torch.ones_like(std) 182 | 183 | log_sigma_pri, log_sigma_post = torch.log(pri_std), torch.log(std) 184 | kl = log_sigma_pri - log_sigma_post + \ 185 | (torch.exp(log_sigma_post)**2 + (mean-pri_mean)**2)/(2*torch.exp(log_sigma_pri)**2) - 0.5 186 | return kl.mean() 187 | 188 | @property 189 | def get_features(self): 190 | return self.compute_features() 191 | def compute_features(self): 192 | features_dc = self._features_dc 193 | features_rest = self._features_rest 194 | 195 | return 
torch.cat((features_dc, features_rest), dim=1) 196 | 197 | @property 198 | def get_opacity(self): 199 | opacity = self.compute_opacity() 200 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 201 | std = self.offsets["_opacity_offset"][..., sample_model_ids].mean(dim=-1) 202 | std = torch.nn.functional.softplus(std) 203 | mean = self.offsets["_opacity_offset"][..., sample_model_ids+self.n_models].mean(dim=-1) 204 | 205 | p_logit = (std*torch.randn_like(opacity).cuda().requires_grad_(True) + mean) 206 | offset_opc = torch.sigmoid(p_logit / self.tmp) 207 | offset_opc = self.mr_list*offset_opc+(1-self.mr_list)*torch.ones_like(offset_opc).cuda().requires_grad_(True) 208 | 209 | opacity = opacity * offset_opc 210 | 211 | return opacity 212 | 213 | def compute_kl_opacity(self): 214 | sample_model_ids = torch.randperm(self.n_models)[:self.M].cuda().requires_grad_(False).detach() 215 | std = self.offsets["_opacity_offset"][..., sample_model_ids].mean(dim=-1) 216 | std = torch.nn.functional.softplus(std) 217 | mean = self.offsets["_opacity_offset"][..., sample_model_ids+self.n_models].mean(dim=-1) 218 | 219 | pri_std = torch.nn.functional.softplus(torch.tensor(self.pri_opacity_std).float()).item() 220 | pri_mean = self.pri_opacity_mean 221 | 222 | pri_mean, pri_std = pri_mean * torch.ones_like(mean), pri_std * torch.ones_like(std) 223 | 224 | log_sigma_pri, log_sigma_post = torch.log(pri_std), torch.log(std) 225 | kl = log_sigma_pri - log_sigma_post + \ 226 | (torch.exp(log_sigma_post)**2 + (mean-pri_mean)**2)/(2*torch.exp(log_sigma_pri)**2) - 0.5 227 | return kl.mean() 228 | 229 | 230 | def compute_opacity(self): 231 | x = self.opacity_activation(self._opacity) 232 | return x 233 | 234 | def get_covariance(self, scaling_modifier = 1): 235 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 236 | 237 | def oneupSHdegree(self): 238 | if self.active_sh_degree < self.max_sh_degree: 239 | self.active_sh_degree += 1 240 | 241 | def init_offset(self): 242 | _xyz_offset = torch.zeros([self._xyz.shape[0], 3, self.n_models*2]) 243 | _xyz_offset[..., :self.n_models] = self.pri_std 244 | _xyz_offset = nn.Parameter(_xyz_offset.requires_grad_(True).cuda()) 245 | 246 | _scaling_offset = torch.zeros([self._xyz.shape[0], 3, self.n_models*2]) 247 | 248 | _scaling_offset[..., :self.n_models] = torch.log(torch.exp(1/torch.tensor(self.n_models))-1).item() 249 | _scaling_offset[..., self.n_models:] = 1-self.pri_width 250 | _scaling_offset = nn.Parameter(_scaling_offset.requires_grad_(True).cuda()) 251 | 252 | _opacity_offset = torch.zeros([self._xyz.shape[0], 1, self.n_models*2]) 253 | _opacity_offset[..., :self.n_models] = self.pri_opacity_std 254 | _opacity_offset[..., self.n_models:] = self.pri_opacity_mean 255 | _opacity_offset = nn.Parameter(_opacity_offset.requires_grad_(True).cuda()) 256 | offsets = [ 257 | {"_xyz_offset": _xyz_offset}, 258 | {"_scaling_offset": _scaling_offset}, 259 | {"_opacity_offset": _opacity_offset} 260 | ] 261 | 262 | lr_scales = self.lr_scales 263 | self.offsets = {} 264 | 265 | for i in range(len(offsets)): 266 | if lr_scales[i] != 0.0: 267 | self.offsets[list(offsets[i].keys())[0]] = list(offsets[i].values())[0] 268 | 269 | self.mr_list = torch.zeros([self._xyz.shape[0], 1]).cuda().requires_grad_(False) 270 | 271 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scalar : float): 272 | self.spatial_lr_scalar = spatial_lr_scalar 273 | fused_point_cloud = 
torch.tensor(np.asarray(pcd.points)).float().cuda() 274 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 275 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 276 | features[:, :3, 0 ] = fused_color 277 | features[:, 3:, 1:] = 0.0 278 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 279 | 280 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 281 | scalars = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 282 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 283 | rots[:, 0] = 1 284 | 285 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 286 | 287 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 288 | self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) 289 | self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) 290 | self._scaling = nn.Parameter(scalars.requires_grad_(True)) 291 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 292 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 293 | self.init_offset() 294 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 295 | 296 | def training_setup(self, training_args): 297 | self.percent_dense = training_args.percent_dense 298 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 299 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 300 | 301 | l = [ 302 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scalar, "name": "xyz"}, 303 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 304 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 305 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 306 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 307 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}, 308 | ] 309 | 310 | lr_scales = self.lr_scales 311 | l_offset = [] 312 | 313 | _l_offset = [ 314 | {'params': [], 'lr': lr_scales[0]*training_args.position_lr_init * self.spatial_lr_scalar, "name": "_xyz_offset"}, 315 | {'params': [], 'lr': lr_scales[1]*training_args.opacity_lr, "name": "_scaling_offset"}, 316 | {'params': [], 'lr': lr_scales[2]*training_args.rotation_lr, "name": "_opacity_offset"}, 317 | ] 318 | 319 | j=0 320 | for i in range(len(_l_offset)): 321 | if lr_scales[i] != 0.0: 322 | _l_offset[i]['params'] = [self.offsets[list(self.offsets.keys())[j]]] 323 | l_offset += [_l_offset[i]] 324 | j+=1 325 | l = l + l_offset 326 | 327 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 328 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scalar, 329 | lr_final=training_args.position_lr_final*self.spatial_lr_scalar, 330 | lr_delay_mult=training_args.position_lr_delay_mult, 331 | max_steps=training_args.position_lr_max_steps) 332 | 333 | 334 | def update_learning_rate(self, iteration): 335 | ''' Learning rate scheduling per step ''' 336 | for param_group in self.optimizer.param_groups: 337 | if param_group["name"] == "xyz": 338 | lr = self.xyz_scheduler_args(iteration) 339 | param_group['lr'] = lr 340 | return lr 341 | 342 | def 
construct_list_of_attributes(self, with_offsets=True): 343 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 344 | # All channels except the 3 DC 345 | for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 346 | l.append('f_dc_{}'.format(i)) 347 | for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 348 | l.append('f_rest_{}'.format(i)) 349 | l.append('opacity') 350 | for i in range(self._scaling.shape[1]): 351 | l.append('scalar_{}'.format(i)) 352 | for i in range(self._rotation.shape[1]): 353 | l.append('rot_{}'.format(i)) 354 | 355 | if not with_offsets: 356 | return l 357 | 358 | for i in range(self.offsets["_xyz_offset"].shape[1]*self.offsets["_xyz_offset"].shape[2]): 359 | l.append('xyz_offset_{}'.format(i)) 360 | for i in range(self.offsets["_scaling_offset"].shape[1]*self.offsets["_scaling_offset"].shape[2]): 361 | l.append('scaling_offset_{}'.format(i)) 362 | for i in range(self.offsets["_opacity_offset"].shape[1]*self.offsets["_opacity_offset"].shape[2]): 363 | l.append('opacity_offset_{}'.format(i)) 364 | 365 | l.append('mr') 366 | 367 | return l 368 | 369 | def save_ply(self, path, with_offsets=True): 370 | mkdir_p(os.path.dirname(path)) 371 | 372 | xyz = self._xyz.detach().cpu().numpy() 373 | normals = np.zeros_like(xyz) 374 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 375 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 376 | opacities = self._opacity.detach().cpu().numpy() 377 | scale = self._scaling.detach().cpu().numpy() 378 | rotation = self._rotation.detach().cpu().numpy() 379 | 380 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes(with_offsets=with_offsets)] 381 | 382 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 383 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 384 | 385 | if with_offsets: 386 | _xyz_offset = self.offsets["_xyz_offset"].detach().flatten(start_dim=1).contiguous().cpu().numpy() 387 | _scaling_offset = self.offsets["_scaling_offset"].detach().flatten(start_dim=1).contiguous().cpu().numpy() 388 | _opacity_offset = self.offsets["_opacity_offset"].detach().flatten(start_dim=1).contiguous().cpu().numpy() 389 | 390 | attributes = np.concatenate((attributes, _xyz_offset, _scaling_offset, _opacity_offset), axis=1) 391 | attributes = np.concatenate((attributes, self.mr_list.detach().cpu().numpy()), axis=1) 392 | 393 | elements[:] = list(map(tuple, attributes)) 394 | el = PlyElement.describe(elements, 'vertex') 395 | PlyData([el]).write(path) 396 | 397 | def reset_opacity(self): 398 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 399 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 400 | self._opacity = optimizable_tensors["opacity"] 401 | 402 | def load_ply(self, path): 403 | plydata = PlyData.read(path) 404 | 405 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 406 | np.asarray(plydata.elements[0]["y"]), 407 | np.asarray(plydata.elements[0]["z"])), axis=1) 408 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 409 | 410 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 411 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 412 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 413 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 414 | 415 | extra_f_names = [p.name for p 
in plydata.elements[0].properties if p.name.startswith("f_rest_")] 416 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 417 | assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 418 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 419 | for idx, attr_name in enumerate(extra_f_names): 420 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 421 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 422 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 423 | 424 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scalar_")] 425 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 426 | scalars = np.zeros((xyz.shape[0], len(scale_names))) 427 | for idx, attr_name in enumerate(scale_names): 428 | scalars[:, idx] = np.asarray(plydata.elements[0][attr_name]) 429 | 430 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 431 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 432 | rots = np.zeros((xyz.shape[0], len(rot_names))) 433 | for idx, attr_name in enumerate(rot_names): 434 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 435 | 436 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 437 | self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 438 | self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 439 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 440 | self._scaling = nn.Parameter(torch.tensor(scalars, dtype=torch.float, device="cuda").requires_grad_(True)) 441 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 442 | 443 | self.init_offset() 444 | 445 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 446 | 447 | self.active_sh_degree = self.max_sh_degree 448 | 449 | 450 | xyz_offset_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("xyz_offset")] 451 | xyz_offset_names = sorted(xyz_offset_names, key = lambda x: int(x.split('_')[-1])) 452 | xyz_offset = np.zeros((xyz.shape[0], len(xyz_offset_names))) 453 | for idx, attr_name in enumerate(xyz_offset_names): 454 | xyz_offset[:, idx] = np.asarray(plydata.elements[0][attr_name]) 455 | xyz_offset = xyz_offset.reshape((xyz_offset.shape[0], 3, len(xyz_offset_names) // 3)) 456 | 457 | scaling_offset_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scaling_offset")] 458 | scaling_offset_names = sorted(scaling_offset_names, key = lambda x: int(x.split('_')[-1])) 459 | scaling_offset = np.zeros((xyz.shape[0], len(scaling_offset_names))) 460 | for idx, attr_name in enumerate(scaling_offset_names): 461 | scaling_offset[:, idx] = np.asarray(plydata.elements[0][attr_name]) 462 | scaling_offset = scaling_offset.reshape((scaling_offset.shape[0], 3, len(scaling_offset_names) // 3)) 463 | 464 | opacity_offset_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("opacity_offset")] 465 | opacity_offset_names = sorted(opacity_offset_names, key = lambda x: int(x.split('_')[-1])) 466 | opacity_offset = np.zeros((xyz.shape[0], len(opacity_offset_names))) 467 
| for idx, attr_name in enumerate(opacity_offset_names): 468 | opacity_offset[:, idx] = np.asarray(plydata.elements[0][attr_name]) 469 | opacity_offset = opacity_offset.reshape((opacity_offset.shape[0], 1, len(opacity_offset_names))) 470 | 471 | self.offsets["_xyz_offset"] = nn.Parameter(torch.tensor(xyz_offset, dtype=torch.float, device="cuda").requires_grad_(True)) 472 | self.offsets["_scaling_offset"] = nn.Parameter(torch.tensor(scaling_offset, dtype=torch.float, device="cuda").requires_grad_(True)) 473 | self.offsets["_opacity_offset"] = nn.Parameter(torch.tensor(opacity_offset, dtype=torch.float, device="cuda").requires_grad_(True)) 474 | 475 | self.mr_list = torch.tensor(np.asarray(plydata.elements[0]["mr"])[..., np.newaxis]).float().cuda().requires_grad_(False) 476 | 477 | def replace_tensor_to_optimizer(self, tensor, name): 478 | optimizable_tensors = {} 479 | for group in self.optimizer.param_groups: 480 | if group["name"] == name: 481 | stored_state = self.optimizer.state.get(group['params'][0], None) 482 | stored_state["exp_avg"] = torch.zeros_like(tensor) 483 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 484 | 485 | del self.optimizer.state[group['params'][0]] 486 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 487 | self.optimizer.state[group['params'][0]] = stored_state 488 | 489 | optimizable_tensors[group["name"]] = group["params"][0] 490 | return optimizable_tensors 491 | 492 | def _prune_optimizer(self, mask): 493 | optimizable_tensors = {} 494 | for group in self.optimizer.param_groups: 495 | stored_state = self.optimizer.state.get(group['params'][0], None) 496 | if stored_state is not None: 497 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 498 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 499 | 500 | del self.optimizer.state[group['params'][0]] 501 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 502 | self.optimizer.state[group['params'][0]] = stored_state 503 | 504 | optimizable_tensors[group["name"]] = group["params"][0] 505 | else: 506 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 507 | optimizable_tensors[group["name"]] = group["params"][0] 508 | return optimizable_tensors 509 | 510 | def prune_points(self, mask): 511 | valid_points_mask = ~mask 512 | 513 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 514 | 515 | self._xyz = optimizable_tensors["xyz"] 516 | self._features_dc = optimizable_tensors["f_dc"] 517 | self._features_rest = optimizable_tensors["f_rest"] 518 | self._opacity = optimizable_tensors["opacity"] 519 | self._scaling = optimizable_tensors["scaling"] 520 | self._rotation = optimizable_tensors["rotation"] 521 | 522 | for name in self.offsets.keys(): 523 | self.offsets[name] = optimizable_tensors[name] 524 | 525 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 526 | 527 | self.denom = self.denom[valid_points_mask] 528 | self.max_radii2D = self.max_radii2D[valid_points_mask] 529 | 530 | def cat_tensors_to_optimizer(self, tensors_dict): 531 | optimizable_tensors = {} 532 | for group in self.optimizer.param_groups: 533 | assert len(group["params"]) == 1 534 | extension_tensor = tensors_dict[group["name"]] 535 | stored_state = self.optimizer.state.get(group['params'][0], None) 536 | if stored_state is not None: 537 | 538 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 539 | stored_state["exp_avg_sq"] = 
torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 540 | 541 | del self.optimizer.state[group['params'][0]] 542 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 543 | self.optimizer.state[group['params'][0]] = stored_state 544 | 545 | optimizable_tensors[group["name"]] = group["params"][0] 546 | else: 547 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 548 | optimizable_tensors[group["name"]] = group["params"][0] 549 | 550 | return optimizable_tensors 551 | 552 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_offset): 553 | d = {"xyz": new_xyz, 554 | "f_dc": new_features_dc, 555 | "f_rest": new_features_rest, 556 | "opacity": new_opacities, 557 | "scaling" : new_scaling, 558 | "rotation" : new_rotation} 559 | 560 | d = {**d, **new_offset} 561 | 562 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 563 | self._xyz = optimizable_tensors["xyz"] 564 | self._features_dc = optimizable_tensors["f_dc"] 565 | self._features_rest = optimizable_tensors["f_rest"] 566 | self._opacity = optimizable_tensors["opacity"] 567 | self._scaling = optimizable_tensors["scaling"] 568 | self._rotation = optimizable_tensors["rotation"] 569 | 570 | for name in self.offsets.keys(): 571 | self.offsets[name] = optimizable_tensors[name] 572 | 573 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 574 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 575 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 576 | 577 | def spawn(self, extent): 578 | percent_base = self.spawn_percent_base 579 | min_opacity = self.spawn_min_opacity 580 | 581 | mr_mask = torch.norm(self.get_scaling, dim=1) > percent_base * extent 582 | self.mr_list = mr_mask.int()[...,None] 583 | 584 | transparent_mask = (self.get_opacity < min_opacity)[:,0] 585 | self.mr_list[transparent_mask] = 0 586 | 587 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 588 | n_init_points = self.get_xyz.shape[0] 589 | # Extract points that satisfy the gradient condition 590 | padded_grad = torch.zeros((n_init_points), device="cuda") 591 | padded_grad[:grads.shape[0]] = grads.squeeze() 592 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 593 | selected_pts_mask = torch.logical_and(selected_pts_mask, 594 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 595 | 596 | 597 | stds = self.get_scaling[selected_pts_mask].repeat(N,1) 598 | means =torch.zeros((stds.size(0), 3),device="cuda") 599 | samples = torch.normal(mean=means, std=stds) 600 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 601 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 602 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 603 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 604 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 605 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 606 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 607 | 608 | new_offset = {} 609 | for name in self.offsets.keys(): 610 | n_dim = len(self.offsets[name].shape) 611 | new_shape = [1 for i in range(n_dim)] 612 | new_shape[0] 
= N 613 | new_offset[name] = self.offsets[name][selected_pts_mask, ...].repeat(*new_shape) 614 | 615 | self.mr_list = torch.cat([self.mr_list, torch.ones_like(self.mr_list[selected_pts_mask].repeat(N,1))], dim=0) 616 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_offset) 617 | 618 | prune_filter = torch.cat([selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)]) 619 | 620 | self.mr_list = self.mr_list[~prune_filter] 621 | self.prune_points(prune_filter) 622 | 623 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 624 | # Extract points that satisfy the gradient condition 625 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 626 | selected_pts_mask = torch.logical_and(selected_pts_mask, 627 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 628 | 629 | new_xyz = self._xyz[selected_pts_mask] 630 | new_features_dc = self._features_dc[selected_pts_mask] 631 | new_features_rest = self._features_rest[selected_pts_mask] 632 | new_opacities = self._opacity[selected_pts_mask] 633 | new_scaling = self._scaling[selected_pts_mask] 634 | new_rotation = self._rotation[selected_pts_mask] 635 | 636 | new_offset = {} 637 | for name in self.offsets.keys(): 638 | new_offset[name] = self.offsets[name][selected_pts_mask, ...] 639 | 640 | self.mr_list = torch.cat([self.mr_list, self.mr_list[selected_pts_mask]], dim=0) 641 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_offset) 642 | 643 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): # densify by cloning/splitting high-gradient Gaussians, then prune low-opacity and oversized ones 644 | grads = self.xyz_gradient_accum / self.denom 645 | grads[grads.isnan()] = 0.0 646 | 647 | self.densify_and_clone(grads, max_grad, extent) 648 | self.densify_and_split(grads, max_grad, extent) 649 | 650 | prune_mask = (self.get_opacity < min_opacity).squeeze() 651 | if max_screen_size: 652 | big_points_vs = self.max_radii2D > max_screen_size 653 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 654 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 655 | 656 | self.mr_list = self.mr_list[~prune_mask] 657 | self.prune_points(prune_mask) 658 | torch.cuda.empty_cache() 659 | 660 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 661 | # accumulate the view-space positional gradient norm for the updated points 662 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) 663 | self.denom[update_filter] += 1 --------------------------------------------------------------------------------
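Note on the checkpoint format: save_ply above writes, for every Gaussian, the standard 3DGS attributes (position, placeholder normals, SH features, opacity, log-scales, rotation) followed by the flattened offset tensors and the per-point mr flag, each as an index-suffixed float property named by construct_list_of_attributes. The sketch below is one way such a file could be inspected offline with plyfile and numpy; it is not part of the repository, the filename point_cloud.ply and the helper collect() are illustrative assumptions, and only the property names (scalar_*, rot_*, xyz_offset_*, mr) are taken from the code above.

# Minimal sketch (not part of the repository): read a PLY written by
# GaussianModel.save_ply(with_offsets=True) and regroup its flat properties.
# "point_cloud.ply" and collect() are illustrative; property names follow
# construct_list_of_attributes above.
import numpy as np
from plyfile import PlyData

vertex = PlyData.read("point_cloud.ply").elements[0]
names = [p.name for p in vertex.properties]

def collect(prefix):
    # Gather index-suffixed properties such as f_dc_0, f_dc_1, ... into one array.
    cols = sorted((n for n in names if n.startswith(prefix)),
                  key=lambda n: int(n.split("_")[-1]))
    return np.stack([np.asarray(vertex[c]) for c in cols], axis=1)

xyz = np.stack([np.asarray(vertex[a]) for a in ("x", "y", "z")], axis=1)
opacity = np.asarray(vertex["opacity"])   # pre-sigmoid opacity values
log_scales = collect("scalar_")           # 3 log-scales per point
rotation = collect("rot_")                # 4 quaternion components per point
xyz_offset = collect("xyz_offset_")       # flattened per-point position offsets
mr = np.asarray(vertex["mr"])             # per-point flag saved from mr_list

print(xyz.shape, opacity.shape, log_scales.shape, rotation.shape, xyz_offset.shape, mr.shape)

Inside the codebase itself, load_ply is the supported way to restore a checkpoint; a standalone reader like this is only useful for quick inspection of saved attributes.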