├── vgg.pth
├── assets
│   └── pipeline.png
├── environment.yml
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── utils.py
│       ├── lpips.py
│       └── networks.py
├── utils
│   ├── system_utils.py
│   ├── loss_utils.py
│   ├── image_utils.py
│   ├── graphics_utils.py
│   ├── camera_utils.py
│   ├── sh_utils.py
│   └── general_utils.py
├── README.md
├── gaussian_renderer
│   ├── network_gui.py
│   └── __init__.py
├── full_eval.py
├── scene
│   ├── cameras.py
│   ├── __init__.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── render_fps.py
├── arguments
│   └── __init__.py
├── metrics.py
├── render.py
├── convert.py
├── train.py
├── mynet.py
├── train_2.py
└── model.py
/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcWangzhiru/SpeclatentGS/HEAD/vgg.pth --------------------------------------------------------------------------------
/assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcWangzhiru/SpeclatentGS/HEAD/assets/pipeline.png --------------------------------------------------------------------------------
/environment.yml: -------------------------------------------------------------------------------- 1 | name: speclatentgs 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.6 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - imgviz 16 | - imageio --------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'vgg', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'vgg'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | --------------------------------------------------------------------------------
/utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory.
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpeclatentGS 2 | This repo contains the official implementation of the **ACM MM 2024** [paper](https://arxiv.org/abs/2409.05868) : 3 | 4 |
5 |

6 | 7 | SpecGaussian with latent features: A high-quality modeling of the view-dependent appearance for 3D Gaussian Splatting 8 | 9 |

10 |

11 | 12 | Zhiru Wang, Shiyun Xie, Chengwei Pan and Guoping Wang 13 | 14 |

15 |
16 | 17 | ### PIPELINE 18 | ![pipeline](/assets/pipeline.png) 19 | 20 | ## Environment Installation 21 | You can install the base environment using: 22 | ```shell 23 | git clone https://github.com/MarcWangzhiru/SpeclatentGS.git 24 | cd SpeclatentGS 25 | conda env create --file environment.yml 26 | ``` 27 | For the installation of the **submodules**, you can use the following commands: 28 | ```shell 29 | cd submodules/diff-gaussian-rasterization 30 | python setup.py install 31 | ``` 32 | and 33 | ```shell 34 | cd submodules/simple-knn 35 | python setup.py install 36 | ``` 37 | You also need to install the **tinycudann** library. In general, you can use the following command: 38 | ```shell 39 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 40 | ``` 41 | 42 | ## Dataset Preparation 43 | The dataset used in our method is in the same format as the dataset in Gaussian Splatting. If you want to use your own custom dataset, follow the process described in [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting.git). We obtained our own [shiny_dataset](https://drive.google.com/file/d/1mmBmptl9Pd8crLfO9y2E51F_R4Ecy0R1/view?usp=sharing) by resizing the images of the original [Shiny Dataset](https://vistec-my.sharepoint.com/:f:/g/personal/pakkapon_p_s19_vistec_ac_th/EnIUhsRVJOdNsZ_4smdhye0B8z0VlxqOR35IR3bp0uGupQ?e=TsaQgM) and re-running COLMAP on the resized images. 44 | 45 | 46 | ## Training 47 | For training, you can use the following command: 48 | ```shell 49 | python train.py -s <path to dataset> --eval 50 | ``` 51 | ## Evaluation 52 | For evaluation, you can use the following command: 53 | ```shell 54 | python render.py -m <path to trained model> --eval 55 | ``` 56 | 57 | 58 | ## Citation 59 |
-------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def 
__init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from sklearn.metrics import confusion_matrix 14 | import numpy as np 15 | 16 | def mse(img1, img2): 17 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | 23 | def nanmean(data, **args): 24 | # This makes it ignore the first 'background' class 25 | return np.ma.masked_array(data, np.isnan(data)).mean(**args) 26 | # In np.ma.masked_array(data, np.isnan(data)), NaN entries of data are marked invalid and are ignored when the mean is computed 27 | 28 | 29 | def calculate_segmentation_metrics(true_labels, predicted_labels, number_classes, ignore_label): 30 | # if (true_labels == ignore_label).all(): 31 | # return [0]*4 32 | 33 | true_labels = true_labels.flatten() 34 | predicted_labels = predicted_labels.flatten() 35 | valid_pix_ids = true_labels != ignore_label 36 | predicted_labels = predicted_labels[valid_pix_ids] 37 | true_labels = true_labels[valid_pix_ids] 38 | 39 | conf_mat = confusion_matrix(true_labels, predicted_labels, labels=list(range(number_classes))) 40 | norm_conf_mat = np.transpose( 41 | np.transpose(conf_mat) / conf_mat.astype(np.float64).sum(axis=1)) 42 | 43 | missing_class_mask = np.isnan(norm_conf_mat.sum(1)) # missing class will have NaN at corresponding class 44 | existing_class_mask = ~ missing_class_mask 45 | 46 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat)) 47 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat)) 48 | ious = np.zeros(number_classes) 49 | for class_id in range(number_classes): 50 | ious[class_id] = (conf_mat[class_id, class_id] / ( 51 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) - 52 | conf_mat[class_id, class_id])) 53 | miou = nanmean(ious) 54 | miou_valid_class = np.mean(ious[existing_class_mask]) 55 | return miou, miou_valid_class, total_accuracy, class_average_accuracy, ious 56 | 57 | def _fast_hist(num_classes, label_pred, label_true): 58 | # Select the label entries that should be counted, dropping the background 59 | mask = (label_true >= 0) & (label_true < num_classes) 60 | # np.bincount counts how often each value in [0, num_classes**2 - 1] occurs; the result is reshaped to (n, n) 61 | hist = np.bincount( 62 | num_classes * label_true[mask].astype(int) + 63 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 64 | return hist 65 | 66 | 67 | def evaluate(num_classes, predictions, gts): 68 | hist = np.zeros((num_classes, num_classes)) 69 | for lp, lt in zip(predictions, gts): 70 | assert len(lp.flatten()) == len(lt.flatten()) 71 | hist += _fast_hist(num_classes, lp.flatten(), lt.flatten()) 72 | 73 | # miou 74 | iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 75 | miou = np.nanmean(iou) 76 | 77 | acc = np.diag(hist).sum() / hist.sum() 78 | acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1)) 79 | 80 | freq = hist.sum(axis=1) / hist.sum() 81 | fwavacc = (freq[freq > 0] * iou[freq > 0]).sum() 82 | 83 | return acc, acc_cls, iou, miou, fwavacc
-------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved.
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 68 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | 70 | if not args.skip_metrics: 71 | scenes_string = "" 72 | for scene in all_scenes: 73 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 74 | 75 | os.system("python metrics.py -m " + scenes_string) -------------------------------------------------------------------------------- /utils/graphics_utils.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | 86 | 87 | def getc2w(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 88 | Rt = np.zeros((4, 4)) 89 | Rt[:3, :3] = R.transpose() 90 | Rt[:3, 3] = t 91 | Rt[3, 3] = 1.0 92 | 93 | C2W = np.linalg.inv(Rt) 94 | cam_center = C2W[:3, 3] 95 | cam_center = (cam_center + translate) * scale 96 | C2W[:3, 3] = cam_center 97 | 98 | return np.float32(C2W) 99 | 100 | 101 | def views_dir(H, W, K, c2w): 102 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 103 | # 生成网格 104 | 105 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1) 106 | # 获取以相机原点为世界坐标系中心的相机光线,[H, W, 3] 3->[i-cx,j-cy,f]/f = [(i-cx)/f,(j-cy)/f,1] 107 | # Rotate ray directions from camera frame to the world frame 108 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) 109 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] [H, W, 1, 3] 110 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
111 | # rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 112 | # return rays_o, rays_d 113 | # print(rays_d.shape) 114 | return rays_d 115 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getc2w 16 | from imgviz import label_colormap 17 | import math 18 | from mynet import MyNet, embedding_fn 19 | from utils.graphics_utils import views_dir 20 | 21 | class Camera(nn.Module): 22 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 23 | image_name, semantic_image, semantic_image_name, semantic_classes, 24 | uid, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda" 25 | ): 26 | super(Camera, self).__init__() 27 | 28 | self.uid = uid 29 | self.colmap_id = colmap_id 30 | self.R = R 31 | self.T = T 32 | self.FoVx = FoVx 33 | self.FoVy = FoVy 34 | self.image_name = image_name 35 | # self.semantic_image_name = semantic_image_name 36 | # self.color_map_np = label_colormap()[semantic_classes] 37 | 38 | try: 39 | self.data_device = torch.device(data_device) 40 | except Exception as e: 41 | print(e) 42 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 43 | self.data_device = torch.device("cuda") 44 | 45 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 46 | # self.semantic_image = semantic_image.to(self.data_device) 47 | self.image_width = self.original_image.shape[2] 48 | self.image_height = self.original_image.shape[1] 49 | self.K = np.array([ 50 | [self.image_width / (2 * math.tan(self.FoVx / 2)), 0, 0.5 * self.image_width], 51 | [0, self.image_height / (2 * math.tan(self.FoVy / 2)), 0.5 * self.image_height], 52 | [0, 0, 1] 53 | ]) 54 | 55 | if gt_alpha_mask is not None: 56 | self.original_image *= gt_alpha_mask.to(self.data_device) 57 | else: 58 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 59 | 60 | self.zfar = 100.0 61 | self.znear = 0.01 62 | 63 | self.trans = trans 64 | self.scale = scale 65 | 66 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 67 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 68 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 69 | self.camera_center = self.world_view_transform.inverse()[3, :3] 70 | self.c2w = getc2w(R, T, trans, scale) 71 | 72 | rays_d = torch.from_numpy( 73 | views_dir(self.image_height, self.image_width, self.K, self.c2w)).cuda() 74 | 75 | self.views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 76 | # print(self.K, self.c2w) 77 | 78 | 79 | 80 | class MiniCam: 81 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 82 | self.image_width = width 83 | self.image_height = height 84 | self.FoVy = fovy 85 | self.FoVx = fovx 86 | self.znear = znear 
87 | self.zfar = zfar 88 | self.world_view_transform = world_view_transform 89 | self.full_proj_transform = full_proj_transform 90 | view_inv = torch.inverse(self.world_view_transform) 91 | self.camera_center = view_inv[3][:3] 92 | 93 | 94 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene.cameras import Camera 14 | import numpy as np 15 | from utils.general_utils import PILtoTorch 16 | from utils.graphics_utils import fov2focal 17 | 18 | WARNED = False 19 | 20 | def loadCam(args, id, cam_info, resolution_scale): 21 | orig_w, orig_h = cam_info.image.size 22 | 23 | if args.resolution in [1, 2, 4, 8]: 24 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 25 | else: # should be a type that converts to float 26 | if args.resolution == -1: 27 | if orig_w > 1600: 28 | global WARNED 29 | if not WARNED: 30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 32 | WARNED = True 33 | global_down = orig_w / 1600 34 | else: 35 | global_down = 1 36 | else: 37 | global_down = orig_w / args.resolution 38 | 39 | scale = float(global_down) * float(resolution_scale) 40 | resolution = (int(orig_w / scale), int(orig_h / scale)) 41 | 42 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 43 | origin_semantic_image = cam_info.semantic_image 44 | 45 | # semantic_remap = np.empty(origin_semantic_image.shape) 46 | # for i in range(origin_semantic_image.shape[0]): 47 | # for j in range(origin_semantic_image.shape[1]): 48 | # semantic_value = origin_semantic_image[i][j] 49 | # if semantic_value in cam_info.semantic_classes: 50 | # semantic_remap[i, j] = cam_info.semantic_classes.index(semantic_value) 51 | 52 | gt_image = resized_image_rgb[:3, ...] 53 | # gt_semantic_image = torch.from_numpy(np.array(semantic_remap)).long() 54 | gt_semantic_image=None 55 | loaded_mask = None 56 | 57 | if resized_image_rgb.shape[1] == 4: 58 | loaded_mask = resized_image_rgb[3:4, ...] 
59 | 60 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 61 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 62 | image=gt_image, gt_alpha_mask=loaded_mask, 63 | image_name=cam_info.image_name, 64 | semantic_image=gt_semantic_image, 65 | semantic_image_name=cam_info.semantic_image_name, 66 | semantic_classes=cam_info.semantic_classes, 67 | uid=id, data_device=args.data_device) 68 | 69 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 70 | camera_list = [] 71 | 72 | for id, c in enumerate(cam_infos): 73 | camera_list.append(loadCam(args, id, c, resolution_scale)) 74 | 75 | return camera_list 76 | 77 | def camera_to_JSON(id, camera : Camera): 78 | Rt = np.zeros((4, 4)) 79 | Rt[:3, :3] = camera.R.transpose() 80 | Rt[:3, 3] = camera.T 81 | Rt[3, 3] = 1.0 82 | 83 | W2C = np.linalg.inv(Rt) 84 | pos = W2C[:3, 3] 85 | rot = W2C[:3, :3] 86 | serializable_array_2d = [x.tolist() for x in rot] 87 | camera_entry = { 88 | 'id' : id, 89 | 'img_name' : camera.image_name, 90 | 'sem_img_name': camera.semantic_image_name, 91 | 'width' : camera.width, 92 | 'height' : camera.height, 93 | 'position': pos.tolist(), 94 | 'rotation': serializable_array_2d, 95 | 'fy' : fov2focal(camera.FovY, camera.height), 96 | 'fx' : fov2focal(camera.FovX, camera.width) 97 | } 98 | return camera_entry 99 | -------------------------------------------------------------------------------- /render_fps.py: -------------------------------------------------------------------------------- 1 | # roup, https://team.inria.fr/graphdeco 2 | # All rights reserved. 3 | # 4 | # This software is free for non-commercial, research and evaluation use 5 | # under the terms of the LICENSE.md file. 6 | # 7 | # For inquiries contact george.drettakis@inria.fr 8 | # 9 | 10 | import torch 11 | import numpy as np 12 | from scene import Scene 13 | import os 14 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0' 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | import imageio 20 | from utils.general_utils import safe_state, calculate_reflection_direction, visualize_normal_map 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | from gaussian_renderer import GaussianModel 24 | from utils.image_utils import psnr, calculate_segmentation_metrics, evaluate 25 | from model import UNet, SimpleUNet, SimpleNet 26 | from mynet import MyNet, embedding_fn 27 | from utils.graphics_utils import views_dir 28 | import time 29 | 30 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 31 | net_path = os.path.join(model_path, "model_ckpt{}pth".format(iteration)) 32 | model = MyNet().to("cuda") 33 | net_weights = torch.load(net_path) 34 | # print(net_weights) 35 | model.load_state_dict(net_weights) 36 | 37 | t_list = [] 38 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 39 | torch.cuda.synchronize(); 40 | t0 = time.time() 41 | # color 42 | render_pkg = render(view, gaussians, pipeline, background) 43 | 44 | rendered_features, viewspace_point_tensor, visibility_filter, radii, depths, _ = render_pkg 45 | 46 | # 加入视角信息 47 | rays_d = torch.from_numpy(views_dir(view.image_height, view.image_width, view.K, view.c2w)).cuda() 48 | 49 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 50 | 51 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 52 | image = model(*rendered_features) 53 | rendering = image['im_out'].squeeze(0) 54 | 55 
| torch.cuda.synchronize(); 56 | t1 = time.time() 57 | 58 | t_list.append(t1 - t0) 59 | 60 | t = np.array(t_list[3:]) 61 | fps = 1.0 / t.mean() 62 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m') 63 | 64 | 65 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): 66 | with torch.no_grad(): 67 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes) 68 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 69 | 70 | 71 | 72 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 73 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 74 | 75 | if not skip_train: 76 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 77 | 78 | if not skip_test: 79 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 80 | 81 | if __name__ == "__main__": 82 | # Set up command line argument parser 83 | parser = ArgumentParser(description="Testing script parameters") 84 | model = ModelParams(parser, sentinel=True) 85 | pipeline = PipelineParams(parser) 86 | parser.add_argument("--iteration", default=-1, type=int) 87 | parser.add_argument("--skip_train", action="store_true") 88 | parser.add_argument("--skip_test", action="store_true") 89 | parser.add_argument("--quiet", action="store_true") 90 | args = get_combined_args(parser) 91 | print("Rendering " + args.model_path) 92 | 93 | # Initialize system state (RNG) 94 | safe_state(args.quiet) 95 | 96 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self.num_sem_classes = 16 51 | self._source_path = "" 52 | self._model_path = "" 53 | self._images = "images" 54 | # self._semantic_class = "semantic_class" 55 | self._resolution = -1 56 | self._white_background = False 57 | self.data_device = "cuda" 58 | self.eval = False 59 | super().__init__(parser, "Loading Parameters", sentinel) 60 | 61 | def extract(self, args): 62 | g = super().extract(args) 63 | g.source_path = os.path.abspath(g.source_path) 64 | return g 65 | 66 | class PipelineParams(ParamGroup): 67 | def __init__(self, parser): 68 | self.convert_SHs_python = False 69 | self.compute_cov3D_python = False 70 | self.debug = False 71 | super().__init__(parser, "Pipeline Parameters") 72 | 73 | class OptimizationParams(ParamGroup): 74 | def __init__(self, parser): 75 | self.iterations = 30_000 76 | self.position_lr_init = 0.00016 77 | self.position_lr_final = 0.0000016 78 | self.position_lr_delay_mult = 0.01 79 | self.position_lr_max_steps = 30_000 # 30_000 80 | self.feature_lr = 0.0025 81 | self.semantic_lr = 0.0025 #0.0025 82 | self.opacity_lr = 0.025 # 0.025 83 | self.scaling_lr = 0.005 84 | self.rotation_lr = 0.001 85 | self.percent_dense = 0.01 86 | self.lambda_dssim = 0.2 87 | self.densification_interval = 100 88 | self.opacity_reset_interval = 3000 89 | self.densify_from_iter = 500 90 | self.densify_until_iter = 15_000 # 15_000 91 | self.densify_grad_threshold = 0.0002 92 | super().__init__(parser, "Optimization Parameters") 93 | 94 | def get_combined_args(parser : ArgumentParser): 95 | cmdlne_string = sys.argv[1:] 96 | cfgfile_string = "Namespace()" 97 | args_cmdline = parser.parse_args(cmdlne_string) 98 | 99 | try: 100 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 101 | print("Looking for config file in", cfgfilepath) 102 | with open(cfgfilepath) as cfg_file: 103 | print("Config file found: {}".format(cfgfilepath)) 104 | cfgfile_string = cfg_file.read() 105 | except TypeError: 106 | print("Config file not found at") 107 | pass 108 | args_cfgfile = eval(cfgfile_string) 109 | 110 | merged_dict = vars(args_cfgfile).copy() 111 | for k,v in vars(args_cmdline).items(): 112 | if v != None: 113 | merged_dict[k] = v 114 | return Namespace(**merged_dict) 115 | 
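Note on the argument plumbing above: `ParamGroup` turns every attribute of a subclass into a command-line flag (a leading underscore such as `_source_path` additionally registers a one-letter shorthand like `-s`, and bool attributes become `store_true` flags), and `extract` copies the parsed values back into a plain `GroupParams` object. A minimal usage sketch under those assumptions (the dataset path below is hypothetical; train.py and render.py wire the groups up in essentially this way):
```python
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)          # adds --source_path/-s, --model_path/-m, --eval, ...
op = OptimizationParams(parser)   # adds --iterations, --position_lr_init, ...
pp = PipelineParams(parser)       # adds --convert_SHs_python, --debug, ...

# Hypothetical invocation; only flags defined by the groups above are used.
args = parser.parse_args(["-s", "/data/shiny/lab", "--eval", "--iterations", "30000"])

dataset = lp.extract(args)        # GroupParams with source_path (made absolute), eval, ...
opt = op.extract(args)            # GroupParams with iterations, densify_until_iter, ...
pipe = pp.extract(args)
print(dataset.source_path, dataset.eval, opt.iterations)
```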
-------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 75 | 76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 79 | print("") 80 | 81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 82 | "PSNR": torch.tensor(psnrs).mean().item(), 83 | "LPIPS": torch.tensor(lpipss).mean().item()}) 84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 87 | 88 | with open(scene_dir + "/results.json", 'w') as fp: 89 | json.dump(full_dict[scene_dir], fp, indent=True) 90 | with open(scene_dir + "/per_view.json", 'w') as fp: 91 | json.dump(per_view_dict[scene_dir], fp, indent=True) 92 | except: 93 | print("Unable to compute metrics for model", scene_dir) 94 | 95 | if 
__name__ == "__main__": 96 | device = torch.device("cuda:0") 97 | torch.cuda.set_device(device) 98 | 99 | # Set up command line argument parser 100 | parser = ArgumentParser(description="Training script parameters") 101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 102 | args = parser.parse_args() 103 | evaluate(args.model_paths) 104 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | # from typing import NamedTuple 16 | # from diff_gaussian_rasterization import GaussianRasterizer 17 | from scene.gaussian_model import GaussianModel 18 | from utils.sh_utils import eval_sh 19 | from utils.general_utils import compute_gaussian_normals, flip_align_view 20 | import torch.nn.functional as F 21 | import time 22 | 23 | 24 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 25 | """ 26 | Render the scene. 27 | 28 | Background tensor (bg_color) must be on GPU! 29 | """ 30 | 31 | scale_factor = [1] 32 | 33 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 34 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 35 | try: 36 | screenspace_points.retain_grad() 37 | except: 38 | pass 39 | 40 | # Set up rasterization configuration 41 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 42 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 43 | 44 | rasterizers = [] 45 | for scale in scale_factor: 46 | raster_setting = GaussianRasterizationSettings( 47 | image_height=int(viewpoint_camera.image_height/scale), 48 | image_width=int(viewpoint_camera.image_width/scale), 49 | tanfovx=tanfovx, 50 | tanfovy=tanfovy, 51 | bg=bg_color, 52 | scale_modifier=scaling_modifier, 53 | viewmatrix=viewpoint_camera.world_view_transform, 54 | projmatrix=viewpoint_camera.full_proj_transform, 55 | sh_degree=pc.active_sh_degree, 56 | # num_sem_classes=pc.num_sem_classes, 57 | campos=viewpoint_camera.camera_center, 58 | prefiltered=False, 59 | debug=pipe.debug 60 | ) 61 | rasterizer = GaussianRasterizer(raster_settings=raster_setting) 62 | rasterizers.append(rasterizer) 63 | 64 | means3D = pc.get_xyz 65 | means2D = screenspace_points 66 | opacity = pc.get_opacity 67 | 68 | # xyz = pc.get_xyz 69 | dir_pp = (means3D - viewpoint_camera.camera_center.repeat(means3D.shape[0], 1)) 70 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True) 71 | color_features = pc.get_semantic 72 | 73 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 74 | # scaling / rotation by the rasterizer. 
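# (Note on the two paths below: when pipe.compute_cov3D_python is set, pc.get_covariance(scaling_modifier)
# precomputes the per-Gaussian world-space covariance in Python — in the standard 3DGS formulation
# Σ = R · S · Sᵀ · Rᵀ, built from the activated scaling S and the rotation R given by the quaternion —
# and scales/rotations stay None; otherwise the activated scaling and rotation are passed through and
# the CUDA rasterizer assembles the covariance itself.)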
75 | scales = None 76 | rotations = None 77 | cov3D_precomp = None 78 | if pipe.compute_cov3D_python: 79 | cov3D_precomp = pc.get_covariance(scaling_modifier) 80 | else: 81 | scales = pc.get_scaling 82 | rotations = pc.get_rotation 83 | 84 | normals = pc.get_normal 85 | 86 | direction_features = pc.mlp_direction_head( 87 | torch.cat([pc.direction_encoding(dir_pp), normals], dim=-1)).float() 88 | 89 | Ln = 0 90 | 91 | semantic = torch.cat([color_features, direction_features, normals], dim=-1) 92 | 93 | rendered_features = [] 94 | rendered_depths = [] 95 | visibility_filter = 0 96 | radii = 0 97 | 98 | for i in range(1): 99 | semantic_logits, rendered_depth, rendered_alpha, radii = rasterizers[i]( 100 | means3D = means3D, 101 | means2D = means2D, 102 | # shs = shs, 103 | # colors_precomp = colors_precomp, 104 | semantic = semantic, 105 | opacities = opacity, 106 | scales = scales, 107 | rotations = rotations, 108 | cov3D_precomp = cov3D_precomp) 109 | rendered_features.append(semantic_logits.unsqueeze(0)) 110 | rendered_depths.append(rendered_depth.unsqueeze(0)) 111 | if i == 0: 112 | visibility_filter = radii > 0 113 | radii = radii 114 | # rendered_image: [3, 376, 1408] [3, image_height, image_width] 115 | # radii: [10458] 116 | 117 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 118 | # They will be excluded from value updates used in the splitting criteria. 119 | return rendered_features, screenspace_points, visibility_filter, radii, rendered_depths, Ln 120 | 121 | 122 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | import torch 16 | from utils.system_utils import searchForMaxIteration 17 | from scene.dataset_readers import sceneLoadTypeCallbacks 18 | from scene.gaussian_model import GaussianModel 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 21 | 22 | class Scene: 23 | 24 | gaussians : GaussianModel 25 | 26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 27 | """b 28 | :param path: Path to colmap scene main folder. 29 | """ 30 | self.model_path = args.model_path 31 | self.loaded_iter = None 32 | self.gaussians = gaussians 33 | 34 | if load_iteration: 35 | if load_iteration == -1: 36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 37 | else: 38 | self.loaded_iter = load_iteration 39 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 40 | 41 | self.train_cameras = {} 42 | self.test_cameras = {} 43 | 44 | if os.path.exists(os.path.join(args.source_path, "sparse")): 45 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 47 | print("Found transforms_train.json file, assuming Blender data set!") 48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 49 | else: 50 | assert False, "Could not recognize scene type!" 51 | 52 | if not self.loaded_iter: 53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 54 | dest_file.write(src_file.read()) 55 | json_cams = [] 56 | camlist = [] 57 | if scene_info.test_cameras: 58 | camlist.extend(scene_info.test_cameras) 59 | if scene_info.train_cameras: 60 | camlist.extend(scene_info.train_cameras) 61 | for id, cam in enumerate(camlist): 62 | json_cams.append(camera_to_JSON(id, cam)) 63 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 64 | json.dump(json_cams, file) 65 | 66 | if shuffle: 67 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 68 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 69 | 70 | self.cameras_extent = scene_info.nerf_normalization["radius"] 71 | 72 | for resolution_scale in resolution_scales: 73 | print("Loading Training Cameras") 74 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 75 | print("Loading Test Cameras") 76 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 77 | 78 | if self.loaded_iter: 79 | self.gaussians.load_ply(os.path.join(self.model_path, 80 | "point_cloud", 81 | "iteration_" + str(self.loaded_iter), 82 | "point_cloud.ply")) 83 | # torch.nn.ModuleList([self.gaussians.recolor, self.gaussians.mlp_head, self.gaussians.mlp_direction_head, self.gaussians.direction_encoding]).load_state_dict( 84 | # torch.load(os.path.join(self.model_path, 85 | # "point_cloud", 86 | # "iteration_" + str(self.loaded_iter), 87 | # "point_cloud.pth"))) 88 | torch.nn.ModuleList([ self.gaussians.mlp_direction_head, self.gaussians.direction_encoding, self.gaussians.mlp_normal_head, self.gaussians.positional_encoding]).load_state_dict( 89 | 
torch.load(os.path.join(self.model_path, 90 | "point_cloud", 91 | "iteration_" + str(self.loaded_iter), 92 | "point_cloud.pth"))) 93 | else: 94 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 95 | 96 | def save(self, iteration): 97 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 98 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 99 | 100 | torch.save(torch.nn.ModuleList([self.gaussians.mlp_direction_head, self.gaussians.direction_encoding, self.gaussians.mlp_normal_head, self.gaussians.positional_encoding]).state_dict(), 101 | os.path.join(point_cloud_path, "point_cloud.pth")) 102 | 103 | def getTrainCameras(self, scale=1.0): 104 | return self.train_cameras[scale] 105 | 106 | def getTestCameras(self, scale=1.0): 107 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from scene import Scene 15 | import time 16 | import os 17 | 18 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0' 19 | from tqdm import tqdm 20 | from os import makedirs 21 | from gaussian_renderer import render 22 | import torchvision 23 | import imageio 24 | from utils.general_utils import safe_state, calculate_reflection_direction, visualize_normal_map 25 | from argparse import ArgumentParser 26 | from arguments import ModelParams, PipelineParams, get_combined_args 27 | from gaussian_renderer import GaussianModel 28 | from utils.image_utils import psnr, calculate_segmentation_metrics, evaluate 29 | from mynet import MyNet, embedding_fn 30 | from utils.graphics_utils import views_dir 31 | import time 32 | 33 | 34 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 35 | net_path = os.path.join(model_path, "model_ckpt{}pth".format(iteration)) 36 | model = MyNet().to("cuda") 37 | net_weights = torch.load(net_path) 38 | model.load_state_dict(net_weights) 39 | 40 | model.eval() 41 | 42 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 43 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 44 | 45 | mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "mask") 46 | f_path = os.path.join(model_path, name, "ours_{}".format(iteration), "f") 47 | highlight_path = os.path.join(model_path, name, "ours_{}".format(iteration), "highlight") 48 | color_path = os.path.join(model_path, name, "ours_{}".format(iteration), "color") 49 | normal_path = os.path.join(model_path, name, "ours_{}".format(iteration), "normal") 50 | 51 | makedirs(render_path, exist_ok=True) 52 | makedirs(gts_path, exist_ok=True) 53 | 54 | makedirs(mask_path, exist_ok=True) 55 | makedirs(f_path, exist_ok=True) 56 | makedirs(highlight_path, exist_ok=True) 57 | makedirs(color_path, exist_ok=True) 58 | makedirs(normal_path, exist_ok=True) 59 | 60 | all_time = [] 61 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 62 | start_time = time.time() 63 | 64 | render_pkg = render(view, gaussians, pipeline, background) 65 | 66 | 
rendered_features, viewspace_point_tensor, visibility_filter, radii, depths, _ = render_pkg 67 | 68 | views_emd = view.views_emd 69 | 70 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 71 | image = model(*rendered_features) 72 | rendering = image['im_out'].squeeze(0) 73 | 74 | all_time.append(time.time() - start_time) 75 | 76 | gt = view.original_image[0:3, :, :] 77 | 78 | 79 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 80 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 81 | # 82 | torchvision.utils.save_image(image['mask_out'].squeeze(0), 83 | os.path.join(mask_path, '{0:05d}'.format(idx) + ".png")) 84 | 85 | torchvision.utils.save_image(image['f_out'].squeeze(0), os.path.join(f_path, '{0:05d}'.format(idx) + ".png")) 86 | torchvision.utils.save_image(image['highlight_out'].squeeze(0), 87 | os.path.join(highlight_path, '{0:05d}'.format(idx) + ".png")) 88 | torchvision.utils.save_image(image['color_out'].squeeze(0), 89 | os.path.join(color_path, '{0:05d}'.format(idx) + ".png")) 90 | print("Average time per image: {}".format(sum(all_time) / len(all_time))) 91 | print("Render FPS: {}".format(1 / (sum(all_time) / len(all_time)))) 92 | 93 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool): 94 | with torch.no_grad(): 95 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes) 96 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 97 | 98 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 99 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 100 | 101 | # for camera in scene.getTestCameras(): 102 | # print(camera.R) 103 | 104 | if not skip_train: 105 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, 106 | background) 107 | 108 | if not skip_test: 109 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, 110 | background) 111 | 112 | 113 | if __name__ == "__main__": 114 | # Set up command line argument parser 115 | parser = ArgumentParser(description="Testing script parameters") 116 | model = ModelParams(parser, sentinel=True) 117 | pipeline = PipelineParams(parser) 118 | parser.add_argument("--iteration", default=-1, type=int) 119 | parser.add_argument("--skip_train", action="store_true") 120 | parser.add_argument("--skip_test", action="store_true") 121 | parser.add_argument("--quiet", action="store_true") 122 | args = get_combined_args(parser) 123 | print("Rendering " + args.model_path) 124 | 125 | # Initialize system state (RNG) 126 | safe_state(args.quiet) 127 | 128 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) 129 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extraction_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extraction_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Image undistortion failed with code {exit_code}. 
Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | import torch.nn.functional as F 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x/(1-x)) 21 | 22 | def PILtoTorch(pil_image, resolution): 23 | resized_image_PIL = pil_image.resize(resolution) 24 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 25 | if len(resized_image.shape) == 3: 26 | return resized_image.permute(2, 0, 1) 27 | else: 28 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 29 | 30 | def get_expon_lr_func( 31 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 32 | ): 33 | """ 34 | Copied from Plenoxels 35 | 36 | Continuous learning rate decay function. Adapted from JaxNeRF 37 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 38 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 
39 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 40 | function of lr_delay_mult, such that the initial learning rate is 41 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 42 | to the normal learning rate when steps>lr_delay_steps. 43 | :param conf: config subtree 'lr' or similar 44 | :param max_steps: int, the number of steps during optimization. 45 | :return HoF which takes step as input 46 | """ 47 | 48 | def helper(step): 49 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 50 | # Disable this parameter 51 | return 0.0 52 | if lr_delay_steps > 0: 53 | # A kind of reverse cosine decay. 54 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 55 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 56 | ) 57 | else: 58 | delay_rate = 1.0 59 | t = np.clip(step / max_steps, 0, 1) 60 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 61 | return delay_rate * log_lerp 62 | 63 | return helper 64 | 65 | def strip_lowerdiag(L): 66 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 67 | 68 | uncertainty[:, 0] = L[:, 0, 0] 69 | uncertainty[:, 1] = L[:, 0, 1] 70 | uncertainty[:, 2] = L[:, 0, 2] 71 | uncertainty[:, 3] = L[:, 1, 1] 72 | uncertainty[:, 4] = L[:, 1, 2] 73 | uncertainty[:, 5] = L[:, 2, 2] 74 | return uncertainty 75 | 76 | def strip_symmetric(sym): 77 | return strip_lowerdiag(sym) 78 | 79 | def build_rotation(r): 80 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 81 | 82 | q = r / norm[:, None] 83 | 84 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 85 | 86 | r = q[:, 0] 87 | x = q[:, 1] 88 | y = q[:, 2] 89 | z = q[:, 3] 90 | 91 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 92 | R[:, 0, 1] = 2 * (x*y - r*z) 93 | R[:, 0, 2] = 2 * (x*z + r*y) 94 | R[:, 1, 0] = 2 * (x*y + r*z) 95 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 96 | R[:, 1, 2] = 2 * (y*z - r*x) 97 | R[:, 2, 0] = 2 * (x*z - r*y) 98 | R[:, 2, 1] = 2 * (y*z + r*x) 99 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 100 | return R 101 | 102 | def build_scaling_rotation(s, r): 103 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 104 | R = build_rotation(r) 105 | 106 | L[:,0,0] = s[:,0] 107 | L[:,1,1] = s[:,1] 108 | L[:,2,2] = s[:,2] 109 | 110 | L = R @ L 111 | return L 112 | 113 | def safe_state(silent): 114 | old_f = sys.stdout 115 | class F: 116 | def __init__(self, silent): 117 | self.silent = silent 118 | 119 | def write(self, x): 120 | if not self.silent: 121 | if x.endswith("\n"): 122 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 123 | else: 124 | old_f.write(x) 125 | 126 | def flush(self): 127 | old_f.flush() 128 | 129 | sys.stdout = F(silent) 130 | 131 | random.seed(0) 132 | np.random.seed(0) 133 | torch.manual_seed(0) 134 | torch.cuda.set_device(torch.device("cuda:0")) 135 | 136 | # def compute_gaussian_normals(cov_matrices): 137 | # eigenvalues, eigenvectors = np.linalg.eig(cov_matrices) 138 | # min_eigenvalue_indices = np.argmin(eigenvalues, axis=1) 139 | # min_eigenvectors = eigenvectors[np.arange(cov_matrices.shape[0]), :, min_eigenvalue_indices] 140 | # normals = min_eigenvectors / np.linalg.norm(min_eigenvectors, axis=1)[:, np.newaxis] 141 | # # print(normals.shape) 142 | # return normals 143 | 144 | def compute_gaussian_normals(s, r): 145 | L = build_scaling_rotation(s, r) 146 | actual_covariance = L @ L.transpose(1, 2) 147 | 148 | covariance_matrix = actual_covariance 149 | eigenvalues, eigenvectors = 
torch.linalg.eigh(covariance_matrix, UPLO='U') 150 | min_eigenvalue_idx = torch.argmin(eigenvalues) 151 | min_eigenvector = eigenvectors[:, min_eigenvalue_idx] 152 | normal = min_eigenvector / torch.norm(min_eigenvector) 153 | normal = normal.cuda() 154 | return normal 155 | 156 | def get_minimum_axis(scales, rotations): 157 | sorted_idx = torch.argsort(scales, descending=False, dim=-1) 158 | R = build_rotation(rotations) 159 | R_sorted = torch.gather(R, dim=2, index=sorted_idx[:,None,:].repeat(1, 3, 1)).squeeze() 160 | x_axis = R_sorted[:,0,:] # normalized by defaut 161 | 162 | return x_axis 163 | 164 | def flip_align_view(normal, viewdir): 165 | # normal: (N, 3), viewdir: (N, 3) 166 | dotprod = torch.sum(normal * (-viewdir), dim=-1, keepdims=True) # (N, 1) 167 | non_flip = dotprod>=0 # (N, 1) 168 | normal_flipped = normal*torch.where(non_flip, 1, -1) # (N, 3) 169 | return normal_flipped, non_flip 170 | 171 | def calculate_reflection_direction(wo, n): 172 | wo_dot_n = torch.sum(wo * n, dim=-1, keepdim=True) 173 | reflection_direction = 2 * wo_dot_n * n - wo 174 | return reflection_direction 175 | 176 | # def calculate_reflection_direction(wo, n): 177 | # normal_dot_viewdir = ((-wo) * n).sum(dim=-1, keepdim=True) 178 | # return normal_dot_viewdir 179 | 180 | 181 | import torch 182 | 183 | def visualize_normal_map(normal_map): 184 | # normal_map = F.normalize(normal_map, dim=0) 185 | normal_map = (normal_map + 1.0) /2.0 186 | normal_map = torch.clamp(normal_map, 0, 1) 187 | return normal_map -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | 26 | class CameraInfo(NamedTuple): 27 | uid: int 28 | R: np.array 29 | T: np.array 30 | FovY: np.array 31 | FovX: np.array 32 | image: np.array 33 | image_path: str 34 | image_name: str 35 | # semantic_image: np.array 36 | # semantic_path: str 37 | # semantic_image_name: str 38 | # semantic_classes: list 39 | # num_semantic_classes: int 40 | semantic_image: None 41 | semantic_path: None 42 | semantic_image_name: None 43 | semantic_classes: None 44 | num_semantic_classes: None 45 | width: int 46 | height: int 47 | 48 | class SceneInfo(NamedTuple): 49 | point_cloud: BasicPointCloud 50 | train_cameras: list 51 | test_cameras: list 52 | nerf_normalization: dict 53 | ply_path: str 54 | 55 | def getNerfppNorm(cam_info): 56 | def get_center_and_diag(cam_centers): 57 | cam_centers = np.hstack(cam_centers) 58 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 59 | center = avg_cam_center 60 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 61 | diagonal = np.max(dist) 62 | return center.flatten(), diagonal 63 | 64 | cam_centers = [] 65 | 66 | for cam in cam_info: 67 | W2C = getWorld2View2(cam.R, cam.T) 68 | C2W = np.linalg.inv(W2C) 69 | cam_centers.append(C2W[:3, 3:4]) 70 | 71 | center, diagonal = get_center_and_diag(cam_centers) 72 | radius = diagonal * 1.1 73 | 74 | translate = -center 75 | 76 | return {"translate": translate, "radius": radius} 77 | 78 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 79 | cam_infos = [] 80 | for idx, key in enumerate(cam_extrinsics): 81 | sys.stdout.write('\r') 82 | # the exact output you're looking for: 83 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 84 | sys.stdout.flush() 85 | 86 | extr = cam_extrinsics[key] 87 | intr = cam_intrinsics[extr.camera_id] 88 | height = intr.height 89 | width = intr.width 90 | 91 | uid = 
intr.id 92 | R = np.transpose(qvec2rotmat(extr.qvec)) 93 | T = np.array(extr.tvec) 94 | 95 | if intr.model=="SIMPLE_PINHOLE": 96 | focal_length_x = intr.params[0] 97 | FovY = focal2fov(focal_length_x, height) 98 | FovX = focal2fov(focal_length_x, width) 99 | elif intr.model=="PINHOLE": 100 | focal_length_x = intr.params[0] 101 | focal_length_y = intr.params[1] 102 | FovY = focal2fov(focal_length_y, height) 103 | FovX = focal2fov(focal_length_x, width) 104 | else: 105 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 106 | 107 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 108 | image_name = os.path.basename(image_path).split(".")[0] 109 | image = Image.open(image_path) 110 | 111 | # semantic_image 112 | base_dir, _ = os.path.split(images_folder) 113 | # semantic_folder = os.path.join(base_dir, "semantic_class") 114 | # semantic_path = os.path.join(semantic_folder, os.path.basename(extr.name)) 115 | 116 | # prefix = '0000_' 117 | # semantic_image_name = prefix + extr.name 118 | 119 | # semantic_path = os.path.join(semantic_folder, os.path.basename(semantic_image_name)) 120 | # semantic_image = np.array(Image.open(semantic_path)) 121 | 122 | # semantic_classes 123 | # semantic_all_imgs = np.empty(0) 124 | # semantic_all_files = [file for file in os.listdir(semantic_folder) if file.endswith(".png")] 125 | # for semantic_file in semantic_all_files: 126 | # semantic_img_path = os.path.join(semantic_folder, semantic_file) 127 | # semantic_img = np.array(Image.open(semantic_img_path)) 128 | # if semantic_all_imgs.size == 0: 129 | # semantic_all_imgs = semantic_img 130 | # else: 131 | # semantic_all_imgs = np.concatenate((semantic_all_imgs, semantic_img),axis=0) 132 | # 133 | # semantic_classes = np.unique(semantic_all_imgs).astype(np.uint8) 134 | # num_semantic_classes = semantic_classes.shape[0] 135 | # semantic_classes = list(semantic_classes) 136 | 137 | # Quick version: hard-coded semantic classes 138 | # num_semantic_classes = 16 139 | # semantic_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] 140 | 141 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 142 | image_path=image_path, image_name=image_name, 143 | semantic_image=None, semantic_path=None, 144 | semantic_image_name=None, 145 | semantic_classes=None, num_semantic_classes=None, 146 | width=width, height=height) 147 | cam_infos.append(cam_info) 148 | sys.stdout.write('\n') 149 | return cam_infos 150 | 151 | def fetchPly(path): 152 | plydata = PlyData.read(path) 153 | vertices = plydata['vertex'] 154 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 155 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 156 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 157 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 158 | 159 | def storePly(path, xyz, rgb): 160 | # Define the dtype for the structured array 161 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 162 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 163 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 164 | 165 | normals = np.zeros_like(xyz) 166 | 167 | elements = np.empty(xyz.shape[0], dtype=dtype) 168 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 169 | elements[:] = list(map(tuple, attributes)) 170 | 171 | # Create the PlyData object and write to file 172 | vertex_element = PlyElement.describe(elements, 'vertex') 173 | ply_data = PlyData([vertex_element]) 174 | 
ply_data.write(path) 175 | 176 | def readColmapSceneInfo(path, images, eval, llffhold=8): 177 | try: 178 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 179 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 180 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 181 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 182 | except: 183 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 184 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 185 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 186 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 187 | 188 | reading_dir = "images" if images == None else images 189 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 190 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 191 | 192 | if eval: 193 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 194 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 195 | 196 | # train_cam_infos = [c for idx, c in enumerate(train_cam_infos) if idx % 2 == 0] 197 | # test_cam_infos = [c for idx, c in enumerate(test_cam_infos) if idx % 4 == 0] 198 | else: 199 | train_cam_infos = cam_infos 200 | test_cam_infos = [] 201 | 202 | nerf_normalization = getNerfppNorm(train_cam_infos) 203 | 204 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 205 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 206 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 207 | if not os.path.exists(ply_path): 208 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 209 | try: 210 | xyz, rgb, _ = read_points3D_binary(bin_path) 211 | except: 212 | xyz, rgb, _ = read_points3D_text(txt_path) 213 | storePly(ply_path, xyz, rgb) 214 | try: 215 | pcd = fetchPly(ply_path) 216 | except: 217 | pcd = None 218 | 219 | scene_info = SceneInfo(point_cloud=pcd, 220 | train_cameras=train_cam_infos, 221 | test_cameras=test_cam_infos, 222 | nerf_normalization=nerf_normalization, 223 | ply_path=ply_path) 224 | return scene_info 225 | 226 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 227 | cam_infos = [] 228 | 229 | with open(os.path.join(path, transformsfile)) as json_file: 230 | contents = json.load(json_file) 231 | fovx = contents["camera_angle_x"] 232 | 233 | frames = contents["frames"] 234 | for idx, frame in enumerate(frames): 235 | cam_name = os.path.join(path, frame["file_path"] + extension) 236 | 237 | # NeRF 'transform_matrix' is a camera-to-world transform 238 | c2w = np.array(frame["transform_matrix"]) 239 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 240 | c2w[:3, 1:3] *= -1 241 | 242 | # get the world-to-camera transform and set R, T 243 | w2c = np.linalg.inv(c2w) 244 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 245 | T = w2c[:3, 3] 246 | 247 | image_path = os.path.join(path, cam_name) 248 | image_name = Path(cam_name).stem 249 | image = Image.open(image_path) 250 | 251 | im_data = np.array(image.convert("RGBA")) 252 | 253 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 254 | 255 | norm_data = im_data / 255.0 256 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 257 | image = 
Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 258 | 259 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 260 | FovY = fovy 261 | FovX = fovx 262 | 263 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 264 | image_path=image_path, image_name=image_name, semantic_image=None, semantic_path=None, semantic_image_name=None, semantic_classes=None, num_semantic_classes=None, width=image.size[0], height=image.size[1])) 265 | 266 | return cam_infos 267 | 268 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 269 | print("Reading Training Transforms") 270 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 271 | print("Reading Test Transforms") 272 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 273 | 274 | if not eval: 275 | train_cam_infos.extend(test_cam_infos) 276 | test_cam_infos = [] 277 | 278 | nerf_normalization = getNerfppNorm(train_cam_infos) 279 | 280 | ply_path = os.path.join(path, "points3d.ply") 281 | if not os.path.exists(ply_path): 282 | # Since this data set has no colmap data, we start with random points 283 | num_pts = 100_000 284 | print(f"Generating random point cloud ({num_pts})...") 285 | 286 | # We create random points inside the bounds of the synthetic Blender scenes 287 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 288 | shs = np.random.random((num_pts, 3)) / 255.0 289 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 290 | 291 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 292 | try: 293 | pcd = fetchPly(ply_path) 294 | except: 295 | pcd = None 296 | 297 | scene_info = SceneInfo(point_cloud=pcd, 298 | train_cameras=train_cam_infos, 299 | test_cameras=test_cam_infos, 300 | nerf_normalization=nerf_normalization, 301 | ply_path=ply_path) 302 | return scene_info 303 | 304 | sceneLoadTypeCallbacks = { 305 | "Colmap": readColmapSceneInfo, 306 | "Blender" : readNerfSyntheticInfo 307 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torchvision 14 | import torch.optim as optim 15 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0' 16 | os.environ['CUDA_VISIBLE_DEVICES']= '1' 17 | import torch 18 | import torch.nn as nn 19 | from random import randint 20 | from utils.loss_utils import l1_loss, ssim 21 | from gaussian_renderer import render, network_gui 22 | import sys 23 | from scene import Scene, GaussianModel 24 | from utils.general_utils import safe_state, calculate_reflection_direction 25 | from utils.general_utils import compute_gaussian_normals, flip_align_view 26 | import uuid 27 | from tqdm import tqdm 28 | from utils.image_utils import psnr 29 | from utils.graphics_utils import views_dir 30 | from argparse import ArgumentParser, Namespace 31 | from arguments import ModelParams, PipelineParams, OptimizationParams 32 | from mynet import MyNet, embedding_fn 33 | from lpipsPyTorch import lpips 34 | try: 35 | from torch.utils.tensorboard import SummaryWriter 36 | TENSORBOARD_FOUND = True 37 | except ImportError: 38 | TENSORBOARD_FOUND = False 39 | 40 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 41 | first_iter = 0 42 | tb_writer = prepare_output_and_logger(dataset) 43 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes) 44 | scene = Scene(dataset, gaussians) 45 | gaussians.training_setup(opt) 46 | 47 | model = MyNet().to('cuda') 48 | 49 | learning_rate = 7e-6 # 5e-6 50 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 51 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2000, verbose=True, factor=0.25, min_lr=1e-7) 52 | 53 | if checkpoint: 54 | (model_params, first_iter) = torch.load(checkpoint) 55 | gaussians.restore(model_params, opt) 56 | 57 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 58 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 59 | 60 | iter_start = torch.cuda.Event(enable_timing=True) 61 | iter_end = torch.cuda.Event(enable_timing=True) 62 | 63 | CE_loss = nn.CrossEntropyLoss() 64 | 65 | viewpoint_stack = None 66 | ema_loss_for_log = 0.0 67 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 68 | first_iter += 1 69 | for iteration in range(first_iter, opt.iterations + 1): 70 | # if network_gui.conn == None: 71 | # network_gui.try_connect() 72 | # while network_gui.conn != None: 73 | # try: 74 | # net_image_bytes = None 75 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 76 | # if custom_cam != None: 77 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 78 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 79 | # network_gui.send(net_image_bytes, dataset.source_path) 80 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 81 | # break 82 | # except Exception as e: 83 | # network_gui.conn = None 84 | 85 | iter_start.record() 86 | 87 | gaussians.update_learning_rate(iteration) 88 | 89 | # Every 1000 its we increase the levels of SH up to a maximum degree 90 | if iteration % 1000 == 0: 91 | gaussians.oneupSHdegree() 92 | 93 | # Pick a random Camera 94 | if not viewpoint_stack: 95 | viewpoint_stack = scene.getTrainCameras().copy() 96 | viewpoint_cam = 
viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 97 | viewpoint_camera_center = viewpoint_cam.camera_center 98 | 99 | # Render 100 | if (iteration - 1) == debug_from: 101 | pipe.debug = True 102 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 103 | 104 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = render_pkg 105 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:,:, 17:] 106 | 107 | views_emd = viewpoint_cam.views_emd 108 | # print(rendered_features[0].shape, views_emd.shape) 109 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 110 | 111 | 112 | image = model(*rendered_features) 113 | 114 | mask = image['mask_out'].squeeze(0) 115 | color = image['color_out'].squeeze(0) 116 | image = image['im_out'].squeeze(0) 117 | 118 | # Color Loss 119 | semantic_loss = 0 120 | 121 | gt_image = viewpoint_cam.original_image.cuda() 122 | 123 | dir_pp = (gaussians.get_xyz - viewpoint_camera_center.repeat(gaussians.get_xyz.shape[0], 1)) 124 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True) 125 | normals = gaussians.get_normal 126 | pseudo_normals = gaussians.get_minimum_axis 127 | pseudo_normals, _ = flip_align_view(pseudo_normals, dir_pp) 128 | Ln = 1 - torch.nn.functional.cosine_similarity(normals, pseudo_normals).mean() 129 | 130 | Ll1 = l1_loss(image, gt_image) 131 | color_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 132 | 133 | LB = l1_loss(color, gt_image) 134 | basic_loss = (1.0 - opt.lambda_dssim) * LB + opt.lambda_dssim * (1.0 - ssim(color, gt_image)) 135 | 136 | loss = color_loss + 0.001 * Ln + 0.05 * basic_loss 137 | 138 | loss.backward() 139 | 140 | iter_end.record() 141 | 142 | with torch.no_grad(): 143 | # Progress bar 144 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 145 | if iteration % 10 == 0: 146 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "C_Loss": f"{color_loss:.{5}f}", "S_Loss": f"{semantic_loss:.{5}f}"}) 147 | progress_bar.update(10) 148 | if iteration == opt.iterations: 149 | progress_bar.close() 150 | 151 | # Log and save 152 | 153 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), model=model) 154 | if (iteration in saving_iterations): 155 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 156 | scene.save(iteration) 157 | 158 | # Densification 159 | if iteration < opt.densify_until_iter: 160 | # Keep track of max radii in image-space for pruning 161 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 162 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 163 | 164 | 165 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 166 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 167 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) # 0.005 168 | 169 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 170 | gaussians.reset_opacity() 171 | 172 | optimizer.step() 173 | 174 | # Optimizer step 175 | if iteration < opt.iterations: 176 | gaussians.optimizer.step() 177 | gaussians.optimizer.zero_grad(set_to_none=True) 178 | gaussians.optimizer_net.step() 179 | gaussians.optimizer_net.zero_grad(set_to_none=True) 180 | 
gaussians.scheduler_net.step() 181 | scheduler.step(loss) 182 | 183 | if (iteration in checkpoint_iterations): 184 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 185 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 186 | 187 | 188 | 189 | def prepare_output_and_logger(args): 190 | if not args.model_path: 191 | if os.getenv('OAR_JOB_ID'): 192 | unique_str = os.getenv('OAR_JOB_ID') 193 | else: 194 | unique_str = str(uuid.uuid4()) 195 | args.model_path = os.path.join("./output/", unique_str[0:10]) 196 | 197 | # Set up output folder 198 | print("Output folder: {}".format(args.model_path)) 199 | os.makedirs(args.model_path, exist_ok=True) 200 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 201 | cfg_log_f.write(str(Namespace(**vars(args)))) 202 | 203 | # Create Tensorboard writer 204 | tb_writer = None 205 | if TENSORBOARD_FOUND: 206 | tb_writer = SummaryWriter(args.model_path) 207 | else: 208 | print("Tensorboard not available: not logging progress") 209 | return tb_writer 210 | 211 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, model=None): 212 | if tb_writer: 213 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 214 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 215 | tb_writer.add_scalar('iter_time', elapsed, iteration) 216 | 217 | # Report test and samples of training set 218 | if iteration in testing_iterations: 219 | torch.save(model.state_dict(), scene.model_path + "/model_ckpt" + str(iteration) + "pth") 220 | torch.cuda.empty_cache() 221 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 222 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 223 | for config in validation_configs: 224 | if config['cameras'] and len(config['cameras']) > 0: 225 | l1_test = 0.0 226 | psnr_test = 0.0 227 | ssim_test = 0.0 228 | lpips_test = 0.0 229 | for idx, viewpoint in enumerate(config['cameras']): 230 | if model: 231 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths,Ln= renderFunc(viewpoint, scene.gaussians, *renderArgs) 232 | rays_d = torch.from_numpy( 233 | views_dir(viewpoint.image_height, viewpoint.image_width, viewpoint.K, viewpoint.c2w)).cuda() 234 | 235 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:, :, 17:] 236 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 237 | 238 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 239 | 240 | 241 | image = model(*rendered_features) 242 | image_save = image['im_out'].squeeze(0) 243 | image = torch.clamp(image_save, 0.0, 1.0) 244 | else: 245 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 246 | 247 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 248 | 249 | if tb_writer and (idx < 5): 250 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 251 | if iteration == testing_iterations[0]: 252 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 253 | l1_test += l1_loss(image, gt_image).mean().double() 254 | psnr_test += psnr(image, gt_image).mean().double() 255 | ssim_test += ssim(image, gt_image).mean().double() 256 | 
lpips_test += lpips(image, gt_image).mean().double() 257 | psnr_test /= len(config['cameras']) 258 | ssim_test /= len(config['cameras']) 259 | lpips_test /= len(config['cameras']) 260 | l1_test /= len(config['cameras']) 261 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test)) 262 | if tb_writer: 263 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 264 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 265 | 266 | if tb_writer: 267 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 268 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 269 | torch.cuda.empty_cache() 270 | 271 | if __name__ == "__main__": 272 | # Set up command line argument parser 273 | parser = ArgumentParser(description="Training script parameters") 274 | lp = ModelParams(parser) 275 | op = OptimizationParams(parser) 276 | pp = PipelineParams(parser) 277 | parser.add_argument('--ip', type=str, default="127.0.0.1") 278 | parser.add_argument('--port', type=int, default=6009) 279 | parser.add_argument('--debug_from', type=int, default=-1) 280 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 281 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000]) 282 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000]) 283 | parser.add_argument("--quiet", action="store_true") 284 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 285 | parser.add_argument("--start_checkpoint", type=str, default = None) 286 | args = parser.parse_args(sys.argv[1:]) 287 | args.save_iterations.append(args.iterations) 288 | 289 | print("Scene path" + args.source_path) 290 | print("Optimizing " + args.model_path) 291 | 292 | # Initialize system state (RNG) 293 | safe_state(args.quiet) 294 | 295 | # Start GUI server, configure and run training 296 | # network_gui.init(args.ip, args.port) 297 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 298 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 299 | 300 | # All done 301 | print("\nTraining complete.") 302 | -------------------------------------------------------------------------------- /mynet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from functools import partial 7 | 8 | 9 | def embedding_fn(inputs): 10 | embed_fns = [] 11 | d = 3 12 | out_dim = 0 13 | max_freq = 3 # 3 14 | 15 | freq_bands = 2. 
** torch.linspace(0., max_freq, steps=4) # [2**0, 2**1,...,] 16 | for freq in freq_bands: 17 | for p_fn in [torch.sin, torch.cos]: 18 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) # [torch.sin, torch.cos] 19 | out_dim += d 20 | 21 | return torch.cat([fn(inputs) for fn in embed_fns], -1) 22 | 23 | class BasicConv(nn.Module): 24 | # Gated_conv 25 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, dilation=1, 26 | padding_mode='reflect', act_fun=nn.ELU, normalization=nn.BatchNorm2d): 27 | super().__init__() 28 | self.pad_mode = padding_mode 29 | self.filter_size = kernel_size 30 | self.stride = stride 31 | self.dilation = dilation 32 | group_normalization = nn.GroupNorm 33 | 34 | n_pad_pxl = int(self.dilation * (self.filter_size - 1) / 2) 35 | self.flag = relu 36 | 37 | # this is for backward compatibility with older model checkpoints 38 | self.block = nn.ModuleDict( 39 | { 40 | 'conv_f': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, 41 | padding=n_pad_pxl), 42 | 'act_f': act_fun(), 43 | 'conv_m': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, 44 | padding=n_pad_pxl), 45 | 'act_m': nn.Sigmoid(), 46 | 'norm': normalization(out_channels), 47 | 'group_norm': group_normalization(num_groups=out_channels, num_channels=out_channels), 48 | } 49 | ) 50 | 51 | def forward(self, x, *args, **kwargs): 52 | features = self.block.act_f(self.block.conv_f(x)) 53 | output = features 54 | return output 55 | 56 | 57 | class ResBlock(nn.Module): 58 | def __init__(self, in_channel, out_channel): 59 | super(ResBlock, self).__init__() 60 | self.main = nn.Sequential( 61 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 62 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False), 63 | ) 64 | self.relu = nn.ELU() 65 | self.norm = nn.GroupNorm(1, out_channel) 66 | 67 | def forward(self, x): 68 | return self.relu(self.norm(self.main(x) + (x))) 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, in_channel, out_channel): 72 | super(ResNet, self).__init__() 73 | self.main = nn.Sequential( 74 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 75 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False), 76 | nn.BatchNorm2d(out_channel) 77 | 78 | ) 79 | self.net = BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True) 80 | self.norm = nn.BatchNorm2d(out_channel) 81 | self.relu = nn.ELU() 82 | 83 | def forward(self, x): 84 | output = self.main(x) + self.norm(self.net(x)) 85 | return self.relu(output) 86 | 87 | 88 | class SCM(nn.Module): 89 | def __init__(self, out_plane): 90 | super(SCM, self).__init__() 91 | self.main = nn.Sequential( 92 | BasicConv(8, out_plane - 8, kernel_size=3, stride=1, relu=True), 93 | ) 94 | 95 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False) 96 | 97 | def forward(self, x): 98 | x = torch.cat([x, self.main(x)], dim=1) 99 | return self.conv(x) 100 | 101 | 102 | class EBlock(nn.Module): 103 | def __init__(self, out_channel, num_res=1): 104 | super(EBlock, self).__init__() 105 | 106 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)] 107 | 108 | self.layers = nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | return self.layers(x) 112 | 113 | 114 | class DBlock(nn.Module): 115 | def __init__(self, channel, num_res=1): 116 | super(DBlock, self).__init__() 117 | 118 | layers = [ResBlock(channel, channel) for _ 
in range(num_res)] 119 | self.layers = nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | return self.layers(x) 123 | 124 | 125 | class FAM(nn.Module): 126 | def __init__(self, channel): 127 | super(FAM, self).__init__() 128 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False) 129 | 130 | def forward(self, x1, x2): 131 | x = x1 * x2 132 | out = x1 + self.merge(x) 133 | return out 134 | 135 | 136 | class AFF(nn.Module): 137 | def __init__(self, in_channel, out_channel): 138 | super(AFF, self).__init__() 139 | self.conv = nn.Sequential( 140 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 141 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 142 | ) 143 | 144 | def forward(self, x1, x2, x3, x4): 145 | x = torch.cat([x1, x2, x3, x4], dim=1) 146 | return self.conv(x) 147 | 148 | 149 | class MyNet(nn.Module): 150 | def __init__( 151 | self, 152 | num_input_channels=8, 153 | num_output_channels=3, 154 | feature_scale=4, 155 | num_res=1 156 | 157 | ): 158 | super().__init__() 159 | 160 | self.feature_scale = feature_scale 161 | base_channel = 8 162 | 163 | filters = [64, 128, 256, 512, 1024] 164 | filters = [x // self.feature_scale for x in filters] 165 | 166 | base_channel = 8 167 | 168 | self.resnet = nn.ModuleList([ 169 | ResNet(base_channel, base_channel * 2), 170 | ResNet(base_channel * 2, base_channel * 4), 171 | ResNet(base_channel * 4, base_channel * 8), 172 | ResNet(base_channel * 8, base_channel * 4), 173 | ResNet(base_channel * 4, base_channel * 2), 174 | ResNet(base_channel * 2, base_channel), 175 | ]) 176 | 177 | self.feat_extract = nn.ModuleList([ 178 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1), 179 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2), 180 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2), 181 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2), 182 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2), 183 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 184 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2), 185 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2), 186 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=1), 187 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=3, relu=True, stride=1), 188 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=3, relu=True, stride=1), 189 | BasicConv(base_channel * 2, base_channel * 1, kernel_size=3, relu=True, stride=1), 190 | BasicConv(base_channel, base_channel * 4, kernel_size=3, relu=True, stride=1), 191 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=1), 192 | BasicConv(base_channel * 8 + 24, base_channel * 16, kernel_size=3, relu=True, stride=1), 193 | # nn.Linear(base_channel * 8 + 24, base_channel*16), 194 | BasicConv(base_channel * 16, base_channel * 4, kernel_size=3, relu=True, stride=1), 195 | BasicConv(base_channel * 4, base_channel, kernel_size=3, relu=True, stride=1), 196 | BasicConv(24, base_channel * 8, kernel_size=3, relu=True, stride=1), 197 | BasicConv(base_channel * 8, base_channel * 2, kernel_size=3, relu=True, stride=1), 198 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 199 | BasicConv(base_channel * 8, 20, kernel_size=3, relu=False, stride=1), 200 | 201 | BasicConv(base_channel, 
base_channel * 8, kernel_size=3, relu=False, stride=1), 202 | BasicConv(base_channel * 8 + 24, base_channel, kernel_size=3, relu=True, stride=1), 203 | # BasicConv(base_channel * 8 + 24, base_channel, kernel_size=3, relu=True, stride=1), 204 | # BasicConv(base_channel * 8 , base_channel, kernel_size=3, relu=True, stride=1), 205 | BasicConv(base_channel, 3, kernel_size=3, relu=True, stride=1), 206 | 207 | BasicConv(24, base_channel, kernel_size=3, relu=True, stride=1), 208 | BasicConv(base_channel * 4, 20, kernel_size=3, relu=False, stride=1), 209 | # BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 210 | BasicConv(base_channel, 1, kernel_size=3, relu=False, stride=1), 211 | # BasicConv(base_channel*2+24, base_channel, kernel_size=3, relu=False, stride=1), 212 | BasicConv(base_channel * 2, base_channel, kernel_size=3, relu=False, stride=1), 213 | 214 | BasicConv(40, base_channel, kernel_size=3, relu=False, stride=1), 215 | 216 | # 29开始引入hash 217 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 218 | BasicConv(base_channel * 8, 3, kernel_size=3, relu=False, stride=1), 219 | 220 | # 31 开始MLP: 221 | nn.Linear(8, 64, bias=True), 222 | nn.Linear(64 + 24, 64, bias=True), 223 | nn.Linear(64, 3, bias=True) 224 | 225 | ]) 226 | 227 | self.SCM0 = SCM(base_channel * 8) 228 | self.SCM1 = SCM(base_channel * 4) 229 | self.SCM2 = SCM(base_channel * 2) 230 | 231 | self.FAM0 = FAM(base_channel * 8) 232 | self.FAM1 = FAM(base_channel * 4) 233 | self.FAM2 = FAM(base_channel * 2) 234 | 235 | self.AFFs = nn.ModuleList([ 236 | AFF(base_channel * 15, base_channel * 1), 237 | AFF(base_channel * 15, base_channel * 2), 238 | AFF(base_channel * 15, base_channel * 4), 239 | ]) 240 | 241 | self.Encoder = nn.ModuleList([ 242 | EBlock(base_channel, num_res), 243 | EBlock(base_channel * 2, num_res), 244 | EBlock(base_channel * 4, num_res), 245 | EBlock(base_channel * 8, num_res) 246 | ]) 247 | 248 | self.Decoder = nn.ModuleList([ 249 | DBlock(base_channel * 8, num_res), 250 | DBlock(base_channel * 4, num_res), 251 | DBlock(base_channel * 2, num_res), 252 | DBlock(base_channel, num_res) 253 | ]) 254 | 255 | self.Convs = nn.ModuleList([ 256 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 257 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 258 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 259 | BasicConv(base_channel * 8, base_channel * 8, kernel_size=1, relu=True, stride=1), 260 | BasicConv(base_channel * 4, base_channel * 4, kernel_size=1, relu=True, stride=1), 261 | BasicConv(base_channel * 2, base_channel * 2, kernel_size=1, relu=True, stride=1), 262 | 263 | ]) 264 | 265 | self.ConvsOut = nn.ModuleList( 266 | [ 267 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 268 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 269 | ] 270 | ) 271 | 272 | self.up = nn.Upsample(scale_factor=2, mode='bilinear') 273 | self.up1 = nn.Upsample(scale_factor=4, mode='bilinear') 274 | self.sigmoid = nn.Sigmoid() 275 | 276 | # basic_model 277 | def forward(self, *inputs, **kwargs): 278 | inputs = list(inputs) 279 | 280 | n_input = len(inputs) 281 | 282 | x = inputs[0][:, :8, :, :] 283 | 284 | f = inputs[0][:, 8:16, :, :] 285 | 286 | sh_f = list(f.shape) 287 | sh_f[1] = 3 288 | 289 | tint = inputs[0][:, 16, :, :] 290 | 291 | views = inputs[0][:, 20:, :, :] 292 | 293 | # highlight 294 | f = self.feat_extract[21](f) 295 | f = self.Encoder[3](f) 296 | # 
print(f.shape, views.shape) 297 | f = torch.cat((f, views), dim=1) 298 | f = self.feat_extract[22](f) 299 | f = self.Decoder[3](f) 300 | f = self.feat_extract[23](f) 301 | 302 | mask = tint 303 | 304 | highlight = f * mask 305 | 306 | s = x 307 | z1 = x 308 | 309 | # color 310 | x = self.feat_extract[1](x) 311 | x = self.Encoder[1](x) 312 | z2 = x 313 | 314 | x = self.feat_extract[2](x) 315 | x = self.Encoder[2](x) 316 | z3 = x 317 | 318 | x = self.feat_extract[6](x) 319 | x = self.Encoder[3](x) 320 | 321 | x = self.up(x) 322 | 323 | x = self.Convs[0](x) 324 | x = F.interpolate(x, z3.shape[-2:]) 325 | x = torch.cat([x, z3], dim=1) 326 | x = self.Decoder[0](x) 327 | 328 | x = self.feat_extract[9](x) 329 | x = self.up(x) 330 | 331 | x = self.Convs[1](x) 332 | x = F.interpolate(x, z2.shape[-2:]) 333 | x = torch.cat([x, z2], dim=1) 334 | x = self.Decoder[1](x) 335 | 336 | x = self.feat_extract[10](x) 337 | x = self.up(x) 338 | 339 | x = self.Convs[2](x) 340 | x = F.interpolate(x, z1.shape[-2:]) 341 | x = torch.cat([x, z1], dim=1) 342 | x = self.Decoder[2](x) 343 | 344 | x = self.feat_extract[11](x) 345 | 346 | z = self.feat_extract[5](x) 347 | 348 | color = z 349 | 350 | z = z + highlight 351 | 352 | return {'im_out': z, 353 | 's_out': s, 354 | 'mask_out': mask, 355 | 'highlight_out': highlight, 356 | 'f_out': f, 357 | 'color_out': color} 358 | 359 | 360 | if __name__ == '__main__': 361 | import pdb 362 | import time 363 | import numpy as np 364 | 365 | # model = UNet().to('cuda') 366 | model = MyNet().to('cuda') 367 | input = [] 368 | img_sh = [1408, 376] 369 | sh_unit = 8 370 | # img_sh = list(map(lambda a: a - a % sh_unit + sh_unit if a % sh_unit != 0 else a, img_sh)) 371 | 372 | # print(img_sh) 373 | down = lambda a, b: a // 2 ** b 374 | input.append(torch.zeros((1, 43, down(img_sh[0], 0), down(img_sh[1], 0)), requires_grad=True).cuda()) 375 | input.append(F.interpolate(input[0], scale_factor=0.5)) 376 | input.append(F.interpolate(input[1], scale_factor=0.5)) 377 | input.append(F.interpolate(input[2], scale_factor=0.5)) 378 | print(input) 379 | 380 | model.eval() 381 | st = time.time() 382 | print(input[0].max(), input[0].min()) 383 | print(input[0].shape, input[1].shape, input[2].shape, input[3].shape) 384 | with torch.set_grad_enabled(False): 385 | out = model(*input) 386 | pdb.set_trace() 387 | print('model', time.time() - st) 388 | print(out['im_out'], out['im_out'].shape) 389 | print(out['s_out'], out['s_out'].shape) 390 | model.to('cpu') 391 | -------------------------------------------------------------------------------- /train_2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
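# train_2.py: training entry point that jointly optimizes the latent Gaussians and the MyNet
# feature decoder. The total loss below combines an L1 + D-SSIM term on the decoded image with
# a small normal term (Ln) and a down-weighted basic-color term.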
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torchvision 14 | import torch.optim as optim 15 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 16 | os.environ['CUDA_VISIBLE_DEVICES']= '0' 17 | import torch 18 | import torch.nn as nn 19 | from random import randint 20 | from utils.loss_utils import l1_loss, ssim 21 | from gaussian_renderer import render, network_gui 22 | import sys 23 | from scene import Scene, GaussianModel 24 | from utils.general_utils import safe_state, calculate_reflection_direction 25 | import uuid 26 | from tqdm import tqdm 27 | from utils.image_utils import psnr 28 | from utils.graphics_utils import views_dir 29 | from argparse import ArgumentParser, Namespace 30 | from model import UNet, SimpleNet, SimpleUNet 31 | from arguments import ModelParams, PipelineParams, OptimizationParams 32 | from mynet import MyNet, embedding_fn 33 | from lpipsPyTorch import lpips 34 | try: 35 | from torch.utils.tensorboard import SummaryWriter 36 | TENSORBOARD_FOUND = True 37 | except ImportError: 38 | TENSORBOARD_FOUND = False 39 | 40 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 41 | first_iter = 0 42 | tb_writer = prepare_output_and_logger(dataset) 43 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes) 44 | scene = Scene(dataset, gaussians) 45 | gaussians.training_setup(opt) 46 | 47 | # 初始化模型 48 | # model = SimpleNet().to('cuda') 49 | # model = SimpleUNet().to('cuda') 50 | model = MyNet().to('cuda') 51 | 52 | # 加载预训练权重 53 | # net_path = "/home/wangzhiru/projects/nerf/latent-gaussian-splatting/output/bonsai_output/model_ckpt30000pth" 54 | # net_weights = torch.load(net_path) 55 | # # print(net_weights) 56 | # model.load_state_dict(net_weights) 57 | 58 | learning_rate = 5e-6 # 5e-6 1e-5 59 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 60 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2000, verbose=True, factor=0.25, min_lr=1e-7) 61 | 62 | if checkpoint: 63 | (model_params, first_iter) = torch.load(checkpoint) 64 | gaussians.restore(model_params, opt) 65 | 66 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 67 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 68 | 69 | iter_start = torch.cuda.Event(enable_timing=True) 70 | iter_end = torch.cuda.Event(enable_timing=True) 71 | 72 | CE_loss = nn.CrossEntropyLoss() 73 | 74 | viewpoint_stack = None 75 | ema_loss_for_log = 0.0 76 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 77 | first_iter += 1 78 | for iteration in range(first_iter, opt.iterations + 1): 79 | # if network_gui.conn == None: 80 | # network_gui.try_connect() 81 | # while network_gui.conn != None: 82 | # try: 83 | # net_image_bytes = None 84 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 85 | # if custom_cam != None: 86 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 87 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 88 | # network_gui.send(net_image_bytes, dataset.source_path) 89 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 90 | # break 91 | # except Exception as e: 92 | # network_gui.conn = None 93 | 94 | iter_start.record() 95 | 96 | 
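        # Per-iteration pipeline: step the position learning-rate schedule, raise the active SH
        # degree every 1000 iterations, pick a random training camera, rasterize the latent
        # feature maps, append positionally-embedded view directions, decode them with MyNet,
        # optimize the L1 + D-SSIM color loss plus the Ln and basic-color regularizers, then
        # densify/prune the Gaussians and step all optimizers.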
gaussians.update_learning_rate(iteration) 97 | 98 | # Every 1000 its we increase the levels of SH up to a maximum degree 99 | if iteration % 1000 == 0: 100 | gaussians.oneupSHdegree() 101 | 102 | # Pick a random Camera 103 | if not viewpoint_stack: 104 | viewpoint_stack = scene.getTrainCameras().copy() 105 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 106 | 107 | # Render 108 | if (iteration - 1) == debug_from: 109 | pipe.debug = True 110 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 111 | 112 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = render_pkg 113 | # print(rendered_features[0].shape) 114 | # print(rendered_features, rendered_features[0].shape) 115 | # depths = rendered_depths[0] 116 | 117 | # 加入视角信息 118 | rays_d = torch.from_numpy(views_dir(viewpoint_cam.image_height, viewpoint_cam.image_width, viewpoint_cam.K, viewpoint_cam.c2w)).cuda() 119 | # print(rays_d.shape) 120 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:,:, 17:] 121 | # print(normals.shape) 122 | 123 | 124 | # rays_r = calculate_reflection_direction(rays_d, normals).permute(2,0,1).unsqueeze(0) 125 | # rays_d = calculate_reflection_direction(rays_d, normals) 126 | 127 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 128 | # rays_o, rays_d = views_dir(viewpoint_cam.image_height, viewpoint_cam.image_width, viewpoint_cam.K, viewpoint_cam.c2w) 129 | # rays_o, rays_d = torch.from_numpy(rays_o.copy()).cuda(), torch.from_numpy(rays_d.copy()).cuda(), 130 | # views_emdd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 131 | # views_emdo = embedding_fn(rays_o).permute(2, 0, 1).unsqueeze(0) 132 | # views_emd = torch.cat((views_emdo, views_emdd), dim = 1) 133 | rays_d = rays_d.permute(2, 0, 1).unsqueeze(0) 134 | normals = normals.permute(2, 0, 1).unsqueeze(0) 135 | 136 | # print(rays_dot.shape, views_emd.shape) 137 | # views_emd = torch.cat((views_emd, -rays_d, normals), dim=1) 138 | # views_emd = torch.cat((views_emd,normals), dim=1) 139 | 140 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 141 | # rendered_features[0] = torch.cat((rendered_features[0], depths), dim=1) 142 | # print(rendered_features[0].shape) 143 | 144 | 145 | image = model(*rendered_features) 146 | # semantic_logits = image['s_out'].squeeze(0) 147 | mask = image['mask_out'].squeeze(0) 148 | color = image['color_out'].squeeze(0) 149 | image = image['im_out'].squeeze(0) 150 | 151 | # image_mask = image * (torch.ones_like(mask)-mask) 152 | 153 | # image, semantic_logits, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["semantic"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 154 | # print(image.shape, semantic_logits.shape) 155 | # print("semantic:", semantic_logits.shape) 156 | 157 | # Semantic Loss 158 | # semantic_logits = semantic_logits.permute(1, 2, 0) 159 | # gt_semantic_image = viewpoint_cam.semantic_image.cuda() 160 | # Lce = CE_loss(semantic_logits.reshape(-1, 20), gt_semantic_image.reshape(-1).long()) 161 | 162 | # semantic_loss = Lce 163 | # Color Loss 164 | semantic_loss = 0 165 | 166 | gt_image = viewpoint_cam.original_image.cuda() 167 | # gt_image_mask = gt_image * (torch.ones_like(mask)-mask) 168 | # 169 | # Ln = torch.mean((gs_normals - pseudo_normals) ** 2) 170 | 171 | Ll1 = l1_loss(image, gt_image) 172 | color_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 173 | 174 | LB = 
l1_loss(color, gt_image) 175 | basic_loss = (1.0 - opt.lambda_dssim) * LB + opt.lambda_dssim * (1.0 - ssim(color, gt_image)) 176 | 177 | # LD = l1_loss(image_mask, gt_image_mask) 178 | # diffuse_loss = (1.0 - opt.lambda_dssim) * LD + opt.lambda_dssim * (1.0 - ssim(image_mask, gt_image_mask)) 179 | 180 | loss = color_loss + 0.001 * Ln + 0.05 * basic_loss #+ diffuse_loss #+ basic_loss # + semantic_loss * 0.2 181 | 182 | loss.backward() 183 | 184 | iter_end.record() 185 | 186 | with torch.no_grad(): 187 | # Progress bar 188 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 189 | if iteration % 10 == 0: 190 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "C_Loss": f"{color_loss:.{5}f}", "S_Loss": f"{semantic_loss:.{5}f}"}) 191 | progress_bar.update(10) 192 | if iteration == opt.iterations: 193 | progress_bar.close() 194 | 195 | # Log and save 196 | 197 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), model=model) 198 | if (iteration in saving_iterations): 199 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 200 | scene.save(iteration) 201 | 202 | # Densification 203 | if iteration < opt.densify_until_iter: 204 | # Keep track of max radii in image-space for pruning 205 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 206 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 207 | 208 | 209 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 210 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 211 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) # 0.005 212 | 213 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 214 | gaussians.reset_opacity() 215 | 216 | optimizer.step() 217 | 218 | # Optimizer step 219 | if iteration < opt.iterations: 220 | gaussians.optimizer.step() 221 | gaussians.optimizer.zero_grad(set_to_none=True) 222 | gaussians.optimizer_net.step() 223 | gaussians.optimizer_net.zero_grad(set_to_none=True) 224 | gaussians.scheduler_net.step() 225 | scheduler.step(loss) 226 | 227 | if (iteration in checkpoint_iterations): 228 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 229 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 230 | 231 | 232 | 233 | def prepare_output_and_logger(args): 234 | if not args.model_path: 235 | if os.getenv('OAR_JOB_ID'): 236 | unique_str = os.getenv('OAR_JOB_ID') 237 | else: 238 | unique_str = str(uuid.uuid4()) 239 | args.model_path = os.path.join("./output/", unique_str[0:10]) 240 | 241 | # Set up output folder 242 | print("Output folder: {}".format(args.model_path)) 243 | os.makedirs(args.model_path, exist_ok=True) 244 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 245 | cfg_log_f.write(str(Namespace(**vars(args)))) 246 | 247 | # Create Tensorboard writer 248 | tb_writer = None 249 | if TENSORBOARD_FOUND: 250 | tb_writer = SummaryWriter(args.model_path) 251 | else: 252 | print("Tensorboard not available: not logging progress") 253 | return tb_writer 254 | 255 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, model=None): 256 | if tb_writer: 257 | 
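        # Log the per-iteration L1 loss, total loss and elapsed time; at each test iteration
        # below, the MyNet weights are also saved and L1 / PSNR / SSIM / LPIPS are computed over
        # the test cameras and a small fixed sample of training cameras.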
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 258 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 259 | tb_writer.add_scalar('iter_time', elapsed, iteration) 260 | 261 | # Report test and samples of training set 262 | if iteration in testing_iterations: 263 | torch.save(model.state_dict(), scene.model_path + "/model_ckpt" + str(iteration) + "pth") 264 | torch.cuda.empty_cache() 265 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 266 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 267 | for config in validation_configs: 268 | if config['cameras'] and len(config['cameras']) > 0: 269 | l1_test = 0.0 270 | psnr_test = 0.0 271 | ssim_test = 0.0 272 | lpips_test = 0.0 273 | for idx, viewpoint in enumerate(config['cameras']): 274 | if model: 275 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths,Ln= renderFunc(viewpoint, scene.gaussians, *renderArgs) 276 | # depths = rendered_depths[0] 277 | rays_d = torch.from_numpy( 278 | views_dir(viewpoint.image_height, viewpoint.image_width, viewpoint.K, viewpoint.c2w)).cuda() 279 | 280 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:, :, 17:] 281 | # print(normals.shape) 282 | # 283 | # rays_d = calculate_reflection_direction(rays_d, normals) 284 | # rays_r = calculate_reflection_direction(rays_d, normals).permute(2,0,1).unsqueeze(0) 285 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 286 | 287 | # rays_o, rays_d = views_dir(viewpoint.image_height, viewpoint.image_width,viewpoint.K, viewpoint.c2w) 288 | # rays_o, rays_d = torch.from_numpy(rays_o.copy()).cuda(), torch.from_numpy(rays_d.copy()).cuda(), 289 | # views_emdd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0) 290 | # views_emdo = embedding_fn(rays_o).permute(2, 0, 1).unsqueeze(0) 291 | # views_emd = torch.cat((views_emdo, views_emdd), dim=1) 292 | # views_emd = torch.cat((views_emd, -rays_d.permute(2, 0, 1).unsqueeze(0), normals.permute(2, 0, 1).unsqueeze(0)), dim=1) 293 | # views_emd = torch.cat((views_emd, normals.permute(2, 0, 1).unsqueeze(0)),dim=1) 294 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1) 295 | # rendered_features[0] = torch.cat((rendered_features[0], depths), dim=1) 296 | 297 | 298 | image = model(*rendered_features) 299 | image_save = image['im_out'].squeeze(0) 300 | image = torch.clamp(image_save, 0.0, 1.0) 301 | else: 302 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 303 | 304 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 305 | 306 | if tb_writer and (idx < 5): 307 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 308 | if iteration == testing_iterations[0]: 309 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 310 | l1_test += l1_loss(image, gt_image).mean().double() 311 | psnr_test += psnr(image, gt_image).mean().double() 312 | ssim_test += ssim(image, gt_image).mean().double() 313 | lpips_test += lpips(image, gt_image).mean().double() 314 | psnr_test /= len(config['cameras']) 315 | ssim_test /= len(config['cameras']) 316 | lpips_test /= len(config['cameras']) 317 | l1_test /= len(config['cameras']) 318 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, 
config['name'], l1_test, psnr_test, ssim_test, lpips_test)) 319 | if tb_writer: 320 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 321 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 322 | 323 | if tb_writer: 324 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 325 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 326 | torch.cuda.empty_cache() 327 | 328 | if __name__ == "__main__": 329 | # Set up command line argument parser 330 | parser = ArgumentParser(description="Training script parameters") 331 | lp = ModelParams(parser) 332 | op = OptimizationParams(parser) 333 | pp = PipelineParams(parser) 334 | parser.add_argument('--ip', type=str, default="127.0.0.1") 335 | parser.add_argument('--port', type=int, default=6009) 336 | parser.add_argument('--debug_from', type=int, default=-1) 337 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 338 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000, 50_000, 60_000, 70_000, 80_000]) #, 40_000, 50_000, 60_000, 70_000, 80_000]) 339 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000, 50_000, 60_000, 70_000, 80_000]) # ,40_000, 50_000, 60_000, 70_000, 80_000]) 340 | parser.add_argument("--quiet", action="store_true") 341 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 342 | parser.add_argument("--start_checkpoint", type=str, default = None) 343 | args = parser.parse_args(sys.argv[1:]) 344 | args.save_iterations.append(args.iterations) 345 | 346 | print("Scene path" + args.source_path) 347 | print("Optimizing " + args.model_path) 348 | 349 | # Initialize system state (RNG) 350 | safe_state(args.quiet) 351 | 352 | # Start GUI server, configure and run training 353 | # network_gui.init(args.ip, args.port) 354 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 355 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 356 | 357 | # All done 358 | print("\nTraining complete.") 359 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from functools import partial 7 | 8 | 9 | class BasicConv(nn.Module): 10 | # Gated_conv 11 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, dilation=1, 12 | padding_mode='reflect', act_fun=nn.ELU, normalization=nn.BatchNorm2d): 13 | super().__init__() 14 | self.pad_mode = padding_mode 15 | self.filter_size = kernel_size 16 | self.stride = stride 17 | self.dilation = dilation 18 | 19 | n_pad_pxl = int(self.dilation * (self.filter_size - 1) / 2) 20 | self.flag = relu 21 | 22 | # this is for backward campatibility with older model checkpoints 23 | self.block = nn.ModuleDict( 24 | { 25 | 'conv_f': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, 26 | padding=n_pad_pxl), 27 | 'act_f': act_fun(), 28 | 'conv_m': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, 29 | padding=n_pad_pxl), 30 | 'act_m': 
nn.Sigmoid(), 31 | 'norm': normalization(out_channels) 32 | } 33 | ) 34 | 35 | def forward(self, x, *args, **kwargs): 36 | # if self.flag: 37 | # features = self.block.act_f(self.block.conv_f(x)) 38 | features = self.block.act_f(self.block.conv_f(x)) 39 | # else: 40 | # features = self.block.conv_f(x) 41 | # mask = self.block.act_m(self.block.conv_m(x)) 42 | # output = features * mask 43 | output = features 44 | # output = self.block.norm(output) 45 | 46 | return output 47 | 48 | 49 | class ResBlock(nn.Module): 50 | def __init__(self, in_channel, out_channel): 51 | super(ResBlock, self).__init__() 52 | self.main = nn.Sequential( 53 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 54 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False), 55 | nn.BatchNorm2d(out_channel) 56 | ) 57 | 58 | def forward(self, x): 59 | return self.main(x) + x 60 | 61 | 62 | class SCM(nn.Module): 63 | def __init__(self, out_plane): 64 | super(SCM, self).__init__() 65 | self.main = nn.Sequential( 66 | BasicConv(16, out_plane - 8, kernel_size=3, stride=1, relu=True), 67 | # BasicConv(8, out_plane // 4, kernel_size=3, stride=1, relu=True), 68 | # BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 69 | # BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 70 | # BasicConv(out_plane // 2, out_plane - 8, kernel_size=1, stride=1, relu=True) 71 | ) 72 | 73 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False) 74 | 75 | def forward(self, x): 76 | x = torch.cat([x, self.main(x)], dim=1) 77 | return self.conv(x) 78 | 79 | 80 | class EBlock(nn.Module): 81 | def __init__(self, out_channel, num_res=1): 82 | super(EBlock, self).__init__() 83 | 84 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)] 85 | 86 | self.layers = nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | return self.layers(x) 90 | 91 | 92 | class DBlock(nn.Module): 93 | def __init__(self, channel, num_res=1): 94 | super(DBlock, self).__init__() 95 | 96 | layers = [ResBlock(channel, channel) for _ in range(num_res)] 97 | self.layers = nn.Sequential(*layers) 98 | 99 | def forward(self, x): 100 | return self.layers(x) 101 | 102 | 103 | class FAM(nn.Module): 104 | def __init__(self, channel): 105 | super(FAM, self).__init__() 106 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False) 107 | 108 | def forward(self, x1, x2): 109 | x = x1 * x2 110 | out = x1 + self.merge(x) 111 | return out 112 | 113 | 114 | class AFF(nn.Module): 115 | def __init__(self, in_channel, out_channel): 116 | super(AFF, self).__init__() 117 | self.conv = nn.Sequential( 118 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 119 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 120 | ) 121 | 122 | def forward(self, x1, x2, x3, x4): 123 | x = torch.cat([x1, x2, x3, x4], dim=1) 124 | return self.conv(x) 125 | 126 | 127 | class UNet(nn.Module): 128 | r""" Rendering network with UNet architecture and multi-scale input. 129 | 130 | Args: 131 | num_input_channels: Number of channels in the input tensor or list of tensors. An integer or a list of integers for each input tensor. 132 | num_output_channels: Number of output channels. 133 | feature_scale: Division factor of number of convolutional channels. The bigger the less parameters in the model. 134 | num_res: Number of block resnet. 
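    Note: forward() expects four feature maps at full, 1/2, 1/4 and 1/8 resolution; the coarser
    scales are injected through the SCM/FAM modules, fused across scales by the AFF blocks, and
    the decoded image is returned under the 'im_out' key.

    Construction sketch (illustration only; the channel counts and input shapes are assumptions,
    not values verified against the rest of the pipeline):

        net = UNet(num_input_channels=8, num_output_channels=3, feature_scale=4, num_res=1)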
135 | """ 136 | 137 | def __init__( 138 | self, 139 | num_input_channels=8, 140 | num_output_channels=3, 141 | feature_scale=4, 142 | num_res=1 143 | 144 | ): 145 | super().__init__() 146 | 147 | self.feature_scale = feature_scale 148 | base_channel = 32 149 | 150 | filters = [64, 128, 256, 512, 1024] 151 | filters = [x // self.feature_scale for x in filters] 152 | 153 | base_channel = 32 154 | 155 | self.feat_extract = nn.ModuleList([ 156 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1), 157 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2), 158 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2), 159 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2), 160 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2), 161 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 162 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2), 163 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2), 164 | ]) 165 | 166 | self.SCM0 = SCM(base_channel * 8) 167 | self.SCM1 = SCM(base_channel * 4) 168 | self.SCM2 = SCM(base_channel * 2) 169 | 170 | self.FAM0 = FAM(base_channel * 8) 171 | self.FAM1 = FAM(base_channel * 4) 172 | self.FAM2 = FAM(base_channel * 2) 173 | 174 | self.AFFs = nn.ModuleList([ 175 | AFF(base_channel * 15, base_channel * 1), 176 | AFF(base_channel * 15, base_channel * 2), 177 | AFF(base_channel * 15, base_channel * 4), 178 | ]) 179 | 180 | self.Encoder = nn.ModuleList([ 181 | EBlock(base_channel, num_res), 182 | EBlock(base_channel * 2, num_res), 183 | EBlock(base_channel * 4, num_res), 184 | EBlock(base_channel * 8, num_res) 185 | ]) 186 | 187 | self.Decoder = nn.ModuleList([ 188 | DBlock(base_channel * 8, num_res), 189 | DBlock(base_channel * 4, num_res), 190 | DBlock(base_channel * 2, num_res), 191 | DBlock(base_channel, num_res) 192 | ]) 193 | 194 | self.Convs = nn.ModuleList([ 195 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 196 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 197 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1) 198 | 199 | ]) 200 | 201 | self.ConvsOut = nn.ModuleList( 202 | [ 203 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 204 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 205 | ] 206 | ) 207 | 208 | self.up = nn.Upsample(scale_factor=4, mode='bilinear') 209 | 210 | def forward(self, *inputs, **kwargs): 211 | inputs = list(inputs) 212 | 213 | n_input = len(inputs) 214 | 215 | x = inputs[0] 216 | x_2 = inputs[1] 217 | x_4 = inputs[2] 218 | x_8 = inputs[3] 219 | 220 | z2 = self.SCM2(x_2) 221 | z4 = self.SCM1(x_4) 222 | z8 = self.SCM0(x_8) 223 | 224 | x_ = self.feat_extract[0](x) 225 | res1 = self.Encoder[0](x_) 226 | 227 | z = self.feat_extract[1](res1) 228 | z = self.FAM2(z, z2) 229 | res2 = self.Encoder[1](z) 230 | 231 | z = self.feat_extract[2](res2) 232 | z = self.FAM1(z, z4) 233 | res3 = self.Encoder[2](z) 234 | 235 | z = self.feat_extract[6](res3) 236 | z = self.FAM0(z, z8) 237 | z = self.Encoder[3](z) 238 | 239 | z12 = F.interpolate(res1, scale_factor=0.5) 240 | z13 = F.interpolate(res1, scale_factor=0.25) 241 | 242 | z21 = F.interpolate(res2, scale_factor=2) 243 | z23 = F.interpolate(res2, scale_factor=0.5) 244 | 245 | z32 = F.interpolate(res3, scale_factor=2) 246 | z31 = F.interpolate(res3, 
scale_factor=4) 247 | 248 | z43 = F.interpolate(z, scale_factor=2) 249 | z42 = F.interpolate(z43, scale_factor=2) 250 | z41 = F.interpolate(z42, scale_factor=2) 251 | 252 | res1 = self.AFFs[0](res1, z21, z31, z41) 253 | res2 = self.AFFs[1](z12, res2, z32, z42) 254 | res3 = self.AFFs[2](z13, z23, res3, z43) 255 | z = self.Decoder[0](z) 256 | 257 | z = self.feat_extract[7](z) 258 | 259 | z = self.up(z) 260 | z = torch.cat([z, res3], dim=1) 261 | z = self.Convs[0](z) 262 | z = self.Decoder[1](z) 263 | 264 | z = self.feat_extract[3](z) 265 | z = self.up(z) 266 | 267 | z = torch.cat([z, res2], dim=1) 268 | z = self.Convs[1](z) 269 | z = self.Decoder[2](z) 270 | 271 | z = self.feat_extract[4](z) 272 | z = self.up(z) 273 | 274 | z = torch.cat([z, res1], dim=1) 275 | z = self.Convs[2](z) 276 | z = self.Decoder[3](z) 277 | z = self.feat_extract[5](z) 278 | 279 | return {'im_out': z} 280 | 281 | class SimpleUNet(nn.Module): 282 | r""" Rendering network with UNet architecture and multi-scale input. 283 | 284 | Args: 285 | num_input_channels: Number of channels in the input tensor or list of tensors. An integer or a list of integers for each input tensor. 286 | num_output_channels: Number of output channels. 287 | feature_scale: Division factor of number of convolutional channels. The bigger the less parameters in the model. 288 | num_res: Number of block resnet. 289 | """ 290 | 291 | def __init__( 292 | self, 293 | num_input_channels=8, 294 | num_output_channels=3, 295 | feature_scale=4, 296 | num_res=1 297 | 298 | ): 299 | super().__init__() 300 | 301 | self.feature_scale = feature_scale 302 | base_channel = 8 303 | 304 | filters = [64, 128, 256, 512, 1024] 305 | filters = [x // self.feature_scale for x in filters] 306 | 307 | base_channel = 8 308 | 309 | self.feat_extract = nn.ModuleList([ 310 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1), 311 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2), 312 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2), 313 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2), 314 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2), 315 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 316 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2), 317 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2), 318 | ]) 319 | 320 | self.SCM0 = SCM(base_channel * 8) 321 | self.SCM1 = SCM(base_channel * 4) 322 | self.SCM2 = SCM(base_channel * 2) 323 | 324 | self.FAM0 = FAM(base_channel * 8) 325 | self.FAM1 = FAM(base_channel * 4) 326 | self.FAM2 = FAM(base_channel * 2) 327 | 328 | self.AFFs = nn.ModuleList([ 329 | AFF(base_channel * 15, base_channel * 1), 330 | AFF(base_channel * 15, base_channel * 2), 331 | AFF(base_channel * 15, base_channel * 4), 332 | ]) 333 | 334 | self.Encoder = nn.ModuleList([ 335 | EBlock(base_channel, num_res), 336 | EBlock(base_channel * 2, num_res), 337 | EBlock(base_channel * 4, num_res), 338 | EBlock(base_channel * 8, num_res) 339 | ]) 340 | 341 | self.Decoder = nn.ModuleList([ 342 | DBlock(base_channel * 8, num_res), 343 | DBlock(base_channel * 4, num_res), 344 | DBlock(base_channel * 2, num_res), 345 | DBlock(base_channel, num_res) 346 | ]) 347 | 348 | self.Convs = nn.ModuleList([ 349 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 350 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, 
relu=True, stride=1), 351 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1) 352 | 353 | ]) 354 | 355 | self.ConvsOut = nn.ModuleList( 356 | [ 357 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 358 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 359 | ] 360 | ) 361 | 362 | self.up = nn.Upsample(scale_factor=4, mode='bilinear') 363 | 364 | def forward(self, *inputs, **kwargs): 365 | inputs = list(inputs) 366 | 367 | n_input = len(inputs) 368 | 369 | x = inputs[0] 370 | x_2 = inputs[1] 371 | x_4 = inputs[2] 372 | x_8 = inputs[3] 373 | 374 | z2 = self.SCM2(x_2) 375 | z4 = self.SCM1(x_4) 376 | z8 = self.SCM0(x_8) 377 | 378 | x_ = self.feat_extract[0](x) 379 | res1 = self.Encoder[0](x_) 380 | 381 | z = self.feat_extract[1](res1) 382 | z = self.FAM2(z, z2) 383 | res2 = self.Encoder[1](z) 384 | 385 | z = self.feat_extract[2](res2) 386 | z = self.FAM1(z, z4) 387 | res3 = self.Encoder[2](z) 388 | 389 | z = self.feat_extract[6](res3) 390 | z = self.FAM0(z, z8) 391 | z = self.Encoder[3](z) 392 | 393 | z12 = F.interpolate(res1, scale_factor=0.5) 394 | z13 = F.interpolate(res1, scale_factor=0.25) 395 | 396 | z21 = F.interpolate(res2, scale_factor=2) 397 | z23 = F.interpolate(res2, scale_factor=0.5) 398 | 399 | z32 = F.interpolate(res3, scale_factor=2) 400 | z31 = F.interpolate(res3, scale_factor=4) 401 | 402 | z43 = F.interpolate(z, scale_factor=2) 403 | z42 = F.interpolate(z43, scale_factor=2) 404 | z41 = F.interpolate(z42, scale_factor=2) 405 | 406 | res1 = self.AFFs[0](res1, z21, z31, z41) 407 | res2 = self.AFFs[1](z12, res2, z32, z42) 408 | res3 = self.AFFs[2](z13, z23, res3, z43) 409 | z = self.Decoder[0](z) 410 | 411 | z = self.feat_extract[7](z) 412 | 413 | z = self.up(z) 414 | z = F.interpolate(z, res3.shape[-2:]) 415 | z = torch.cat([z, res3], dim=1) 416 | z = self.Convs[0](z) 417 | z = self.Decoder[1](z) 418 | 419 | z = self.feat_extract[3](z) 420 | z = self.up(z) 421 | 422 | z = torch.cat([z, res2], dim=1) 423 | z = self.Convs[1](z) 424 | z = self.Decoder[2](z) 425 | 426 | z = self.feat_extract[4](z) 427 | z = self.up(z) 428 | 429 | z = torch.cat([z, res1], dim=1) 430 | z = self.Convs[2](z) 431 | z = self.Decoder[3](z) 432 | z = self.feat_extract[5](z) 433 | 434 | return {'im_out': z} 435 | 436 | 437 | class SimpleNet(nn.Module): 438 | def __init__( 439 | self, 440 | num_input_channels=8, 441 | num_output_channels=3, 442 | feature_scale=4, 443 | num_res=2 444 | 445 | ): 446 | super().__init__() 447 | 448 | self.feature_scale = feature_scale 449 | base_channel = 8 450 | 451 | filters = [64, 128, 256, 512, 1024] 452 | filters = [x // self.feature_scale for x in filters] 453 | 454 | base_channel = 16 455 | 456 | self.feat_extract = nn.ModuleList([ 457 | BasicConv(16, base_channel, kernel_size=3, relu=True, stride=1), 458 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2), 459 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2), 460 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2), 461 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2), 462 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1), 463 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2), 464 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2), 465 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=1), 466 | BasicConv(base_channel * 8, 
base_channel * 4, kernel_size=3, relu=True, stride=1), 467 | BasicConv(base_channel*4, base_channel * 2, kernel_size=3, relu=True, stride=1), 468 | BasicConv(base_channel * 2, base_channel * 1, kernel_size=3, relu=True, stride=1) 469 | ]) 470 | 471 | self.SCM0 = SCM(base_channel * 8) 472 | self.SCM1 = SCM(base_channel * 4) 473 | self.SCM2 = SCM(base_channel * 2) 474 | 475 | self.FAM0 = FAM(base_channel * 8) 476 | self.FAM1 = FAM(base_channel * 4) 477 | self.FAM2 = FAM(base_channel * 2) 478 | 479 | self.AFFs = nn.ModuleList([ 480 | AFF(base_channel * 15, base_channel * 1), 481 | AFF(base_channel * 15, base_channel * 2), 482 | AFF(base_channel * 15, base_channel * 4), 483 | ]) 484 | 485 | self.Encoder = nn.ModuleList([ 486 | EBlock(base_channel, num_res), 487 | EBlock(base_channel * 2, num_res), 488 | EBlock(base_channel * 4, num_res), 489 | EBlock(base_channel * 8, num_res) 490 | ]) 491 | 492 | self.Decoder = nn.ModuleList([ 493 | DBlock(base_channel * 8, num_res), 494 | DBlock(base_channel * 4, num_res), 495 | DBlock(base_channel * 2, num_res), 496 | DBlock(base_channel, num_res) 497 | ]) 498 | 499 | self.Convs = nn.ModuleList([ 500 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 501 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 502 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 503 | BasicConv(base_channel * 8, base_channel * 8, kernel_size=1, relu=True, stride=1), 504 | BasicConv(base_channel * 4, base_channel * 4, kernel_size=1, relu=True, stride=1), 505 | BasicConv(base_channel * 2, base_channel * 2, kernel_size=1, relu=True, stride=1), 506 | 507 | ]) 508 | 509 | self.ConvsOut = nn.ModuleList( 510 | [ 511 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 512 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 513 | ] 514 | ) 515 | 516 | self.up = nn.Upsample(scale_factor=2, mode='bilinear') 517 | 518 | def forward(self, *inputs, **kwargs): 519 | inputs = list(inputs) 520 | 521 | n_input = len(inputs) 522 | 523 | x = inputs[0] 524 | 525 | # x = self.feat_extract[0](x) 526 | x = self.Encoder[0](x) 527 | z1 = x 528 | # print(x.shape, "1") 529 | 530 | x = self.feat_extract[1](x) 531 | x = self.Encoder[1](x) 532 | z2 = x 533 | # print(x.shape, "2") 534 | 535 | x = self.feat_extract[2](x) 536 | x = self.Encoder[2](x) 537 | z3 = x 538 | # print(x.shape, "3") 539 | 540 | x = self.feat_extract[6](x) 541 | x = self.Encoder[3](x) 542 | # print(x.shape, "4") 543 | 544 | x = self.up(x) 545 | # print(x.shape, "5") 546 | 547 | x = self.Convs[0](x) 548 | x = torch.cat([x, z3], dim=1) 549 | x = self.Decoder[0](x) 550 | 551 | # print(x.shape, "6") 552 | 553 | 554 | x = self.feat_extract[9](x) 555 | x = self.up(x) 556 | # print(x.shape, "7") 557 | 558 | 559 | x = self.Convs[1](x) 560 | x = torch.cat([x, z2], dim=1) 561 | x = self.Decoder[1](x) 562 | 563 | # print(x.shape, "8") 564 | 565 | x = self.feat_extract[10](x) 566 | x = self.up(x) 567 | # print(x.shape, "9") 568 | 569 | x = self.Convs[2](x) 570 | x = torch.cat([x, z1], dim=1) 571 | x = self.Decoder[2](x) 572 | # print(x.shape, "10") 573 | 574 | x = self.feat_extract[11](x) 575 | # print(x.shape, "11") 576 | 577 | z = self.feat_extract[5](x) 578 | # print(z.shape, "12") 579 | 580 | return {'im_out': z} 581 | 582 | # # x = self.feat_extract[0](x) 583 | # x = self.Encoder[0](x) 584 | # print(x.shape, "1") 585 | # 586 | # x = self.feat_extract[1](x) 587 | # x = self.Encoder[1](x) 588 | # print(x.shape, 
"2") 589 | # 590 | # x = self.feat_extract[2](x) 591 | # x = self.Encoder[2](x) 592 | # print(x.shape, "3") 593 | # 594 | # x = self.feat_extract[6](x) 595 | # x = self.Encoder[3](x) 596 | # print(x.shape, "4") 597 | # 598 | # x = self.up(x) 599 | # print(x.shape, "5") 600 | # 601 | # x = self.Convs[3](x) 602 | # x = self.Decoder[0](x) 603 | # print(x.shape, "6") 604 | # 605 | # x = self.feat_extract[9](x) 606 | # x = self.up(x) 607 | # print(x.shape, "7") 608 | # 609 | # x = self.Convs[4](x) 610 | # x = self.Decoder[1](x) 611 | # print(x.shape, "8") 612 | # 613 | # x = self.feat_extract[10](x) 614 | # x = self.up(x) 615 | # print(x.shape, "9") 616 | # 617 | # x = self.Convs[5](x) 618 | # x = self.Decoder[2](x) 619 | # print(x.shape, "10") 620 | # 621 | # x = self.feat_extract[11](x) 622 | # print(x.shape, "11") 623 | # 624 | # z = self.feat_extract[5](x) 625 | # print(z.shape, "12") 626 | # 627 | # return {'im_out': z} 628 | 629 | if __name__ == '__main__': 630 | import pdb 631 | import time 632 | import numpy as np 633 | 634 | # model = UNet().to('cuda') 635 | model = SimpleNet().to('cuda') 636 | input = [] 637 | img_sh = [1408, 376] 638 | sh_unit = 8 639 | # img_sh = list(map(lambda a: a - a % sh_unit + sh_unit if a % sh_unit != 0 else a, img_sh)) 640 | 641 | # print(img_sh) 642 | down = lambda a, b: a // 2 ** b 643 | input.append(torch.zeros((1, 8, down(img_sh[0], 0), down(img_sh[1], 0)), requires_grad=True).cuda()) 644 | input.append(F.interpolate(input[0], scale_factor=0.5)) 645 | input.append(F.interpolate(input[1], scale_factor=0.5)) 646 | input.append(F.interpolate(input[2], scale_factor=0.5)) 647 | print(input) 648 | 649 | model.eval() 650 | st = time.time() 651 | print(input[0].max(), input[0].min()) 652 | print(input[0].shape, input[1].shape, input[2].shape, input[3].shape) 653 | with torch.set_grad_enabled(False): 654 | out = model(*input) 655 | pdb.set_trace() 656 | print('model', time.time() - st) 657 | print(out['im_out'], out['im_out'].shape) 658 | model.to('cpu') 659 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation, get_minimum_axis, flip_align_view 23 | import tinycudann as tcnn 24 | 25 | class GaussianModel: 26 | 27 | def setup_functions(self): 28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 30 | actual_covariance = L @ L.transpose(1, 2) 31 | symm = strip_symmetric(actual_covariance) 32 | return symm 33 | 34 | self.scaling_activation = torch.exp 35 | self.scaling_inverse_activation = torch.log 36 | 37 | self.covariance_activation = build_covariance_from_scaling_rotation 38 | 39 | self.opacity_activation = torch.sigmoid 40 | self.inverse_opacity_activation = inverse_sigmoid 41 | 42 | self.rotation_activation = torch.nn.functional.normalize 43 | 44 | def __init__(self, sh_degree : int, num_sem_classes : int): # num_sem_classes: 语义类别数 45 | self.active_sh_degree = 0 46 | self.max_sh_degree = sh_degree 47 | self.num_sem_classes = num_sem_classes 48 | self._xyz = torch.empty(0) 49 | # self._features_dc = torch.empty(0) 50 | # self._features_rest = torch.empty(0) 51 | self._semantic = torch.empty(0) # 语义参数 52 | self._scaling = torch.empty(0) 53 | self._rotation = torch.empty(0) 54 | self._opacity = torch.empty(0) 55 | self.max_radii2D = torch.empty(0) 56 | self.xyz_gradient_accum = torch.empty(0) 57 | self.denom = torch.empty(0) 58 | self.optimizer = None 59 | self.percent_dense = 0 60 | self.spatial_lr_scale = 0 61 | self.setup_functions() 62 | 63 | self.recolor = tcnn.Encoding( 64 | n_input_dims=3, 65 | encoding_config={ 66 | "otype": "HashGrid", 67 | "n_levels": 16, 68 | "n_features_per_level": 2, 69 | "log2_hashmap_size": 22, 70 | "base_resolution": 128, 71 | "per_level_scale": 1.4, 72 | }, 73 | ) 74 | self.positional_encoding = tcnn.Encoding( 75 | n_input_dims=3, 76 | encoding_config={ 77 | "otype": "Frequency", 78 | "n_frequencies": 4 79 | }, 80 | ) 81 | self.direction_encoding = tcnn.Encoding( 82 | n_input_dims=3, 83 | encoding_config={ 84 | "otype": "SphericalHarmonics", 85 | "degree": 3 86 | }, 87 | ) 88 | self.mlp_head = tcnn.Network( 89 | n_input_dims=(self.recolor.n_output_dims), 90 | n_output_dims=8, 91 | network_config={ 92 | "otype": "FullyFusedMLP", 93 | "activation": "ReLU", 94 | "output_activation": "None", 95 | "n_neurons": 64, 96 | "n_hidden_layers": 2, 97 | }, 98 | ) 99 | self.mlp_direction_head = tcnn.Network( 100 | n_input_dims=(self.direction_encoding.n_output_dims + 3), 101 | n_output_dims=1, 102 | network_config={ 103 | "otype": "FullyFusedMLP", 104 | "activation": "ReLU", 105 | "output_activation": "Sigmoid", 106 | "n_neurons": 32, 107 | "n_hidden_layers": 1, 108 | }, 109 | ) 110 | self.mlp_normal_head = tcnn.Network( 111 | n_input_dims=16, 112 | n_output_dims=3, 113 | network_config={ 114 | "otype": "FullyFusedMLP", 115 | "activation": "ReLU", 116 | "output_activation": "None", 117 | "n_neurons": 32, 118 | "n_hidden_layers": 1, 119 | }, 120 | ) 121 | 122 | def capture(self): 123 | return ( 124 | self.active_sh_degree, 125 | 
self.num_sem_classes, 126 | self._xyz, 127 | # self._features_dc, 128 | # self._features_rest, 129 | self._semantic, 130 | self._scaling, 131 | self._rotation, 132 | self._opacity, 133 | self.max_radii2D, 134 | self.xyz_gradient_accum, 135 | self.denom, 136 | self.optimizer.state_dict(), 137 | self.spatial_lr_scale, 138 | ) 139 | 140 | def restore(self, model_args, training_args): 141 | (self.active_sh_degree, 142 | self._xyz, 143 | # self._features_dc, 144 | # self._features_rest, 145 | self._semantic, 146 | self._sem_class, 147 | self._scaling, 148 | self._rotation, 149 | self._opacity, 150 | self.max_radii2D, 151 | xyz_gradient_accum, 152 | denom, 153 | opt_dict, 154 | self.spatial_lr_scale) = model_args 155 | self.training_setup(training_args) 156 | self.xyz_gradient_accum = xyz_gradient_accum 157 | self.denom = denom 158 | self.optimizer.load_state_dict(opt_dict) 159 | 160 | @property 161 | def get_scaling(self): 162 | return self.scaling_activation(self._scaling) 163 | 164 | @property 165 | def get_rotation(self): 166 | return self.rotation_activation(self._rotation) 167 | 168 | @property 169 | def get_xyz(self): 170 | return self._xyz 171 | 172 | @property 173 | def get_minimum_axis(self): 174 | return get_minimum_axis(self.get_scaling, self.get_rotation) 175 | 176 | @property 177 | def get_pseudo_normal(self, dir_pp_normalized): 178 | normal_axis = self.get_minimum_axis 179 | normal_axis, positive = flip_align_view(normal_axis, dir_pp_normalized) 180 | normal = normal_axis 181 | normal = normal / normal.norm(dim=1, keepdim=True) # (N, 3) 182 | return normal 183 | 184 | @property 185 | def get_dir_pp(self, view): 186 | means3D = self.get_xyz 187 | dir_pp = (means3D - view.camera_center.repeat(means3D.shape[0], 1)) 188 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True) 189 | return dir_pp 190 | 191 | @property 192 | def get_normal(self): 193 | normal = self.mlp_normal_head(self.get_semantic) 194 | # normal = self.mlp_normal_head( 195 | # torch.cat((self.get_semantic, self.get_rotation, self.get_scaling, self._xyz), dim=1)) 196 | return normal 197 | 198 | # @property 199 | # def get_features(self): 200 | # features_dc = self._features_dc 201 | # features_rest = self._features_rest 202 | # return torch.cat((features_dc, features_rest), dim=1) 203 | 204 | @property 205 | def get_semantic(self): 206 | return self._semantic 207 | 208 | @property 209 | def get_opacity(self): 210 | return self.opacity_activation(self._opacity) 211 | 212 | def get_covariance(self, scaling_modifier = 1): 213 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 214 | 215 | def oneupSHdegree(self): 216 | if self.active_sh_degree < self.max_sh_degree: 217 | self.active_sh_degree += 1 218 | 219 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): 220 | self.spatial_lr_scale = spatial_lr_scale 221 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() # [10458, 3] 222 | # fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) # [10458, 3] 223 | # features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() # [10458, 3, 16] 224 | # features[:, :3, 0 ] = fused_color 225 | # features[:, 3:, 1:] = 0.0 # [10458, 3, 16] 226 | 227 | semantic = torch.zeros((fused_point_cloud.shape[0], self.num_sem_classes)).float().cuda() # 语义 [..., 20, 1] 228 | 229 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 230 | 231 | dist2 = 
torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 232 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) 233 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 234 | rots[:, 0] = 1 235 | 236 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 237 | 238 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 239 | # self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) # [10458, 1, 3] 240 | # self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) # [10458, 15, 3] 241 | self._semantic = nn.Parameter(semantic.requires_grad_(True)) # [10458, 1, 20] ??? 242 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 243 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 244 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 245 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 246 | 247 | 248 | def training_setup(self, training_args): 249 | self.percent_dense = training_args.percent_dense 250 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 251 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 252 | 253 | other_params = [] 254 | # for params in self.recolor.parameters(): 255 | # other_params.append(params) 256 | # for params in self.mlp_head.parameters(): 257 | # other_params.append(params) 258 | for params in self.mlp_direction_head.parameters(): 259 | other_params.append(params) 260 | for params in self.direction_encoding.parameters(): 261 | other_params.append(params) 262 | for params in self.mlp_normal_head.parameters(): 263 | other_params.append(params) 264 | for params in self.positional_encoding.parameters(): 265 | other_params.append(params) 266 | 267 | l = [ 268 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 269 | # {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 270 | # {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 271 | {'params': [self._semantic], 'lr': training_args.semantic_lr, "name": "semantic"}, 272 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 273 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, 274 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 275 | ] 276 | 277 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 278 | self.optimizer_net = torch.optim.Adam(other_params, lr=1e-2, eps=1e-15) # 5e-3 279 | self.scheduler_net = torch.optim.lr_scheduler.ChainedScheduler( 280 | [ 281 | torch.optim.lr_scheduler.LinearLR( 282 | self.optimizer_net, start_factor=0.01, total_iters=100 283 | ), 284 | torch.optim.lr_scheduler.MultiStepLR( 285 | self.optimizer_net, 286 | milestones=[5_000, 15_000, 25_000], 287 | gamma=0.33, 288 | ), 289 | ] 290 | ) 291 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 292 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 293 | lr_delay_mult=training_args.position_lr_delay_mult, 294 | max_steps=training_args.position_lr_max_steps) 295 | 296 | def update_learning_rate(self, iteration): 297 | ''' Learning rate scheduling per step ''' 298 | for param_group in self.optimizer.param_groups: 299 | if 
param_group["name"] == "xyz": 300 | lr = self.xyz_scheduler_args(iteration) 301 | param_group['lr'] = lr 302 | return lr 303 | 304 | def construct_list_of_attributes(self): 305 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 306 | # All channels except the 3 DC 307 | # for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): 308 | # l.append('f_dc_{}'.format(i)) 309 | # for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): 310 | # l.append('f_rest_{}'.format(i)) 311 | for i in range(self._semantic.shape[1]): 312 | l.append('semantic_{}'.format(i)) 313 | l.append('opacity') 314 | for i in range(self._scaling.shape[1]): 315 | l.append('scale_{}'.format(i)) 316 | for i in range(self._rotation.shape[1]): 317 | l.append('rot_{}'.format(i)) 318 | return l 319 | 320 | def save_ply(self, path): 321 | mkdir_p(os.path.dirname(path)) 322 | 323 | xyz = self._xyz.detach().cpu().numpy() 324 | normals = np.zeros_like(xyz) 325 | # f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 326 | # f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 327 | semantic = self._semantic.detach().cpu().numpy() 328 | opacities = self._opacity.detach().cpu().numpy() 329 | scale = self._scaling.detach().cpu().numpy() 330 | rotation = self._rotation.detach().cpu().numpy() 331 | 332 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 333 | 334 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 335 | # attributes = np.concatenate((xyz, normals, f_dc, f_rest, semantic, opacities, scale, rotation), axis=1) 336 | attributes = np.concatenate((xyz, normals, semantic, opacities, scale, rotation), axis=1) 337 | # attributes = np.concatenate((xyz, normals, opacities, scale, rotation), axis=1) 338 | elements[:] = list(map(tuple, attributes)) 339 | el = PlyElement.describe(elements, 'vertex') 340 | PlyData([el]).write(path) 341 | 342 | def reset_opacity(self): 343 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) 344 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 345 | self._opacity = optimizable_tensors["opacity"] 346 | 347 | def load_ply(self, path): 348 | plydata = PlyData.read(path) 349 | 350 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 351 | np.asarray(plydata.elements[0]["y"]), 352 | np.asarray(plydata.elements[0]["z"])), axis=1) 353 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 354 | 355 | semantic_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("semantic_")] 356 | semantic_names = sorted(semantic_names, key = lambda x: int(x.split('_')[-1])) 357 | semantic = np.zeros((xyz.shape[0], len(semantic_names))) 358 | for idx, attr_name in enumerate(semantic_names): 359 | semantic[:, idx] = np.asarray(plydata.elements[0][attr_name]) 360 | 361 | # features_dc = np.zeros((xyz.shape[0], 3, 1)) 362 | # features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 363 | # features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 364 | # features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 365 | 366 | # extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 367 | # extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 368 | # assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 369 | # features_extra = np.zeros((xyz.shape[0], 
len(extra_f_names))) 370 | # for idx, attr_name in enumerate(extra_f_names): 371 | # features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 372 | # # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 373 | # features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 374 | 375 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 376 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 377 | scales = np.zeros((xyz.shape[0], len(scale_names))) 378 | for idx, attr_name in enumerate(scale_names): 379 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 380 | 381 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 382 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 383 | rots = np.zeros((xyz.shape[0], len(rot_names))) 384 | for idx, attr_name in enumerate(rot_names): 385 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 386 | 387 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 388 | # self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 389 | # self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) 390 | self._semantic = nn.Parameter(torch.tensor(semantic, dtype=torch.float, device="cuda").requires_grad_(True)) 391 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 392 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 393 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 394 | 395 | self.active_sh_degree = self.max_sh_degree 396 | 397 | def replace_tensor_to_optimizer(self, tensor, name): 398 | optimizable_tensors = {} 399 | for group in self.optimizer.param_groups: 400 | if group["name"] == name: 401 | stored_state = self.optimizer.state.get(group['params'][0], None) 402 | stored_state["exp_avg"] = torch.zeros_like(tensor) 403 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 404 | 405 | del self.optimizer.state[group['params'][0]] 406 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 407 | self.optimizer.state[group['params'][0]] = stored_state 408 | 409 | optimizable_tensors[group["name"]] = group["params"][0] 410 | return optimizable_tensors 411 | 412 | def _prune_optimizer(self, mask): 413 | optimizable_tensors = {} 414 | for group in self.optimizer.param_groups: 415 | stored_state = self.optimizer.state.get(group['params'][0], None) 416 | if stored_state is not None: 417 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 418 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 419 | 420 | del self.optimizer.state[group['params'][0]] 421 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 422 | self.optimizer.state[group['params'][0]] = stored_state 423 | 424 | optimizable_tensors[group["name"]] = group["params"][0] 425 | else: 426 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 427 | optimizable_tensors[group["name"]] = group["params"][0] 428 | return optimizable_tensors 429 | 430 | def prune_points(self, mask): 431 | valid_points_mask = ~mask 432 | optimizable_tensors = 
self._prune_optimizer(valid_points_mask) 433 | 434 | self._xyz = optimizable_tensors["xyz"] 435 | # self._features_dc = optimizable_tensors["f_dc"] 436 | # self._features_rest = optimizable_tensors["f_rest"] 437 | self._semantic = optimizable_tensors["semantic"] 438 | self._opacity = optimizable_tensors["opacity"] 439 | self._scaling = optimizable_tensors["scaling"] 440 | self._rotation = optimizable_tensors["rotation"] 441 | 442 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 443 | 444 | self.denom = self.denom[valid_points_mask] 445 | self.max_radii2D = self.max_radii2D[valid_points_mask] 446 | 447 | def cat_tensors_to_optimizer(self, tensors_dict): 448 | optimizable_tensors = {} 449 | for group in self.optimizer.param_groups: 450 | assert len(group["params"]) == 1 451 | extension_tensor = tensors_dict[group["name"]] 452 | stored_state = self.optimizer.state.get(group['params'][0], None) 453 | if stored_state is not None: 454 | 455 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) 456 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) 457 | 458 | del self.optimizer.state[group['params'][0]] 459 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 460 | self.optimizer.state[group['params'][0]] = stored_state 461 | 462 | optimizable_tensors[group["name"]] = group["params"][0] 463 | else: 464 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 465 | optimizable_tensors[group["name"]] = group["params"][0] 466 | 467 | return optimizable_tensors 468 | 469 | # def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacities, new_scaling, new_rotation): 470 | def densification_postfix(self, new_xyz, new_semantic, new_opacities,new_scaling, new_rotation): 471 | # def densification_postfix(self, new_xyz, new_opacities, new_scaling, new_rotation): 472 | d = {"xyz": new_xyz, 473 | # "f_dc": new_features_dc, 474 | # "f_rest": new_features_rest, 475 | "semantic": new_semantic, 476 | "opacity": new_opacities, 477 | "scaling" : new_scaling, 478 | "rotation" : new_rotation} 479 | 480 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 481 | self._xyz = optimizable_tensors["xyz"] 482 | # self._features_dc = optimizable_tensors["f_dc"] 483 | # self._features_rest = optimizable_tensors["f_rest"] 484 | self._semantic = optimizable_tensors["semantic"] 485 | self._opacity = optimizable_tensors["opacity"] 486 | self._scaling = optimizable_tensors["scaling"] 487 | self._rotation = optimizable_tensors["rotation"] 488 | 489 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 490 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 491 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 492 | 493 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 494 | n_init_points = self.get_xyz.shape[0] 495 | # Extract points that satisfy the gradient condition 496 | padded_grad = torch.zeros((n_init_points), device="cuda") 497 | padded_grad[:grads.shape[0]] = grads.squeeze() 498 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 499 | selected_pts_mask = torch.logical_and(selected_pts_mask, 500 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) 501 | 502 | stds = 
self.get_scaling[selected_pts_mask].repeat(N,1) 503 | means = torch.zeros((stds.size(0), 3),device="cuda") 504 | samples = torch.normal(mean=means, std=stds) 505 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) 506 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 507 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) 508 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1) 509 | # new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) 510 | # new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) 511 | new_semantic = self._semantic[selected_pts_mask].repeat(N,1) 512 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1) 513 | 514 | # self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacity, new_scaling, new_rotation) 515 | self.densification_postfix(new_xyz, new_semantic, new_opacity, new_scaling,new_rotation) 516 | # self.densification_postfix(new_xyz, new_opacity, new_scaling, new_rotation) 517 | 518 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 519 | self.prune_points(prune_filter) 520 | 521 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 522 | # Extract points that satisfy the gradient condition 523 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 524 | selected_pts_mask = torch.logical_and(selected_pts_mask, 525 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) 526 | 527 | new_xyz = self._xyz[selected_pts_mask] 528 | # new_features_dc = self._features_dc[selected_pts_mask] 529 | # new_features_rest = self._features_rest[selected_pts_mask] 530 | new_semantic = self._semantic[selected_pts_mask] 531 | new_opacities = self._opacity[selected_pts_mask] 532 | new_scaling = self._scaling[selected_pts_mask] 533 | new_rotation = self._rotation[selected_pts_mask] 534 | 535 | # self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacities, new_scaling, new_rotation) 536 | self.densification_postfix(new_xyz, new_semantic, new_opacities, new_scaling, new_rotation) 537 | # self.densification_postfix(new_xyz, new_opacities, new_scaling, new_rotation) 538 | 539 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 540 | grads = self.xyz_gradient_accum / self.denom 541 | grads[grads.isnan()] = 0.0 542 | 543 | self.densify_and_clone(grads, max_grad, extent) 544 | self.densify_and_split(grads, max_grad, extent) 545 | 546 | prune_mask = (self.get_opacity < min_opacity).squeeze() 547 | if max_screen_size: 548 | big_points_vs = self.max_radii2D > max_screen_size 549 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 550 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 551 | self.prune_points(prune_mask) 552 | 553 | torch.cuda.empty_cache() 554 | 555 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 556 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True) 557 | self.denom[update_filter] += 1 558 | 559 | --------------------------------------------------------------------------------
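The following is a minimal, self-contained sketch (not part of the repository) illustrating the sampling step used in `densify_and_split` above: new centres are drawn from a zero-mean Gaussian whose standard deviations are the point's own scales, rotated into world space by the point's quaternion, and added to the original centre. The helper `quat_to_rotmat` and the wrapper `split_points` are illustrative names only; `quat_to_rotmat` stands in for what `build_rotation` from `utils/general_utils.py` is assumed to do, with the (w, x, y, z) quaternion convention implied by `rots[:, 0] = 1` in `create_from_pcd`.

```python
# Sketch of the densify_and_split sampling math (assumed behaviour, not the repo's code).
import torch

def quat_to_rotmat(q: torch.Tensor) -> torch.Tensor:
    # q: (N, 4) quaternions in (w, x, y, z) order -> (N, 3, 3) rotation matrices
    q = q / q.norm(dim=1, keepdim=True)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=1).reshape(-1, 3, 3)

def split_points(xyz, scaling, rotation, mask, N=2):
    # Mirrors the tensor shapes used in densify_and_split: each selected point is
    # replaced by N children sampled inside its own (rotated) ellipsoid.
    stds = scaling[mask].repeat(N, 1)                       # (M*N, 3) per-axis std devs
    samples = torch.normal(mean=torch.zeros_like(stds), std=stds)
    rots = quat_to_rotmat(rotation[mask]).repeat(N, 1, 1)   # (M*N, 3, 3)
    return torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + xyz[mask].repeat(N, 1)

if __name__ == "__main__":
    xyz = torch.randn(5, 3)
    scaling = torch.rand(5, 3) * 0.1
    rotation = torch.zeros(5, 4); rotation[:, 0] = 1.0      # identity quaternions
    mask = torch.tensor([True, False, True, False, True])
    print(split_points(xyz, scaling, rotation, mask).shape)  # torch.Size([6, 3])
```

In the actual method the scales of the children are additionally shrunk by `1 / (0.8 * N)` (in inverse-activation space) and the parent points are pruned afterwards via `prune_points`.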