├── vgg.pth
├── assets
│   └── pipeline.png
├── environment.yml
├── lpipsPyTorch
│   ├── __init__.py
│   └── modules
│       ├── utils.py
│       ├── lpips.py
│       └── networks.py
├── utils
│   ├── system_utils.py
│   ├── loss_utils.py
│   ├── image_utils.py
│   ├── graphics_utils.py
│   ├── camera_utils.py
│   ├── sh_utils.py
│   └── general_utils.py
├── README.md
├── gaussian_renderer
│   ├── network_gui.py
│   └── __init__.py
├── full_eval.py
├── scene
│   ├── cameras.py
│   ├── __init__.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   └── gaussian_model.py
├── render_fps.py
├── arguments
│   └── __init__.py
├── metrics.py
├── render.py
├── convert.py
├── train.py
├── mynet.py
├── train_2.py
└── model.py
/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcWangzhiru/SpeclatentGS/HEAD/vgg.pth
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcWangzhiru/SpeclatentGS/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: speclatentgs
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - cudatoolkit=11.6
8 | - plyfile=0.8.1
9 | - python=3.7.13
10 | - pip=22.3.1
11 | - pytorch=1.12.1
12 | - torchaudio=0.12.1
13 | - torchvision=0.13.1
14 | - tqdm
15 | - imgviz
16 | - imageio
--------------------------------------------------------------------------------
/lpipsPyTorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'vgg',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'vgg'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
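
A minimal usage sketch for the `lpips` helper above (not part of the repository); the tensor shape [N, 3, H, W] and the [0, 1] value range mirror how `metrics.py` calls it:

```python
import torch

from lpipsPyTorch import lpips

# Stand-in images shaped [N, 3, H, W] in [0, 1].
device = 'cuda' if torch.cuda.is_available() else 'cpu'
render = torch.rand(1, 3, 256, 256, device=device)
gt = torch.rand(1, 3, 256, 256, device=device)

# Lower scores mean the two images are perceptually more similar.
score = lpips(render, gt, net_type='vgg')
print(score.item())
```
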
/utils/system_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from errno import EEXIST
13 | from os import makedirs, path
14 | import os
15 |
16 | def mkdir_p(folder_path):
17 | # Creates a directory. equivalent to using mkdir -p on the command line
18 | try:
19 | makedirs(folder_path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == EEXIST and path.isdir(folder_path):
22 | pass
23 | else:
24 | raise
25 |
26 | def searchForMaxIteration(folder):
27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
28 | return max(saved_iters)
29 |
--------------------------------------------------------------------------------
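
An illustrative check (not from the repository) of the two helpers above; `searchForMaxIteration` keys on the number after the last underscore, matching the `iteration_<N>` naming used for checkpoint folders:

```python
import os
import tempfile

from utils.system_utils import mkdir_p, searchForMaxIteration

# Build a throwaway checkpoint layout like point_cloud/iteration_<N>.
root = tempfile.mkdtemp()
for it in (7000, 30000):
    mkdir_p(os.path.join(root, f"iteration_{it}"))

# mkdir_p tolerates pre-existing directories, mirroring `mkdir -p`.
mkdir_p(os.path.join(root, "iteration_7000"))

print(searchForMaxIteration(root))  # 30000
```
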
/lpipsPyTorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'v0.1 is only supported now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpeclatentGS
2 | This repo contains the official implementation of the **ACM MM 2024** [paper](https://arxiv.org/abs/2409.05868):
3 |
4 |
5 |
6 |
7 | SpecGaussian with latent features: A high-quality modeling of the view-dependent appearance for 3D Gaussian Splatting
8 |
9 |
10 |
11 |
12 | Zhiru Wang, Shiyun Xie, Chengwei Pan and Guoping Wang
13 |
14 |
15 |
16 |
17 | ### PIPELINE
18 | ![pipeline](assets/pipeline.png)
19 |
20 | ## Environment Installation
21 | You can install the base environment using:
22 | ```shell
23 | git clone https://github.com/MarcWangzhiru/SpeclatentGS.git
24 | cd SpeclatentGS
25 | conda env create --file environment.yml
26 | ```
27 | For the installation of **submodules**, you can use the following command:
28 | ```shell
29 | cd submodules/diff-gaussian-rasterization
30 | python setup.py install
31 | ```
32 | and
33 | ```shell
34 | cd submodules/simple-knn
35 | python setup.py install
36 | ```
37 | You also need to install the **tinycudann** library. In general, you can use the following command:
38 | ```shell
39 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
40 | ```
41 |
42 | ## Dataset Preparation
43 | Our method uses datasets in the same format as Gaussian Splatting. If you want to use a custom dataset, follow the preparation process described in [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting.git). We obtained our own [shiny_dataset](https://drive.google.com/file/d/1mmBmptl9Pd8crLfO9y2E51F_R4Ecy0R1/view?usp=sharing) by resizing the images of the original [Shiny Dataset](https://vistec-my.sharepoint.com/:f:/g/personal/pakkapon_p_s19_vistec_ac_th/EnIUhsRVJOdNsZ_4smdhye0B8z0VlxqOR35IR3bp0uGupQ?e=TsaQgM) and re-running COLMAP.
44 |
45 |
46 | ## Training
47 | For training, you can use the following command:
48 | ```shell
49 | python train.py -s <path to dataset> --eval
50 | ```
51 | ## Evaluation
52 | For evaluation, you can use the following command:
53 | ```shell
54 | python render.py -m <path to trained model> --eval
55 | ```
56 |
57 |
58 | ## Citation
59 |
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 |
17 | def l1_loss(network_output, gt):
18 | return torch.abs((network_output - gt)).mean()
19 |
20 | def l2_loss(network_output, gt):
21 | return ((network_output - gt) ** 2).mean()
22 |
23 | def gaussian(window_size, sigma):
24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 | return gauss / gauss.sum()
26 |
27 | def create_window(window_size, channel):
28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 | return window
32 |
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 | channel = img1.size(-3)
35 | window = create_window(window_size, channel)
36 |
37 | if img1.is_cuda:
38 | window = window.cuda(img1.get_device())
39 | window = window.type_as(img1)
40 |
41 | return _ssim(img1, img2, window, window_size, channel, size_average)
42 |
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 |
47 | mu1_sq = mu1.pow(2)
48 | mu2_sq = mu2.pow(2)
49 | mu1_mu2 = mu1 * mu2
50 |
51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 |
55 | C1 = 0.01 ** 2
56 | C2 = 0.03 ** 2
57 |
58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 |
60 | if size_average:
61 | return ssim_map.mean()
62 | else:
63 | return ssim_map.mean(1).mean(1).mean(1)
64 |
65 |
--------------------------------------------------------------------------------
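
A sketch of how these losses are typically combined in 3D Gaussian Splatting training, weighted by `lambda_dssim` (0.2 in `OptimizationParams`); `train.py` in this repository may differ in detail:

```python
import torch

from utils.loss_utils import l1_loss, ssim


def photometric_loss(image: torch.Tensor, gt_image: torch.Tensor, lambda_dssim: float = 0.2) -> torch.Tensor:
    # Weighted mix of L1 and D-SSIM, following the original Gaussian Splatting recipe.
    Ll1 = l1_loss(image, gt_image)
    return (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim(image, gt_image))


# Example with stand-in [3, H, W] images.
loss = photometric_loss(torch.rand(3, 64, 64), torch.rand(3, 64, 64))
print(loss.item())
```
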
/gaussian_renderer/network_gui.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | from scene.cameras import MiniCam
17 |
18 | host = "127.0.0.1"
19 | port = 6009
20 |
21 | conn = None
22 | addr = None
23 |
24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25 |
26 | def init(wish_host, wish_port):
27 | global host, port, listener
28 | host = wish_host
29 | port = wish_port
30 | listener.bind((host, port))
31 | listener.listen()
32 | listener.settimeout(0)
33 |
34 | def try_connect():
35 | global conn, addr, listener
36 | try:
37 | conn, addr = listener.accept()
38 | print(f"\nConnected by {addr}")
39 | conn.settimeout(None)
40 | except Exception as inst:
41 | pass
42 |
43 | def read():
44 | global conn
45 | messageLength = conn.recv(4)
46 | messageLength = int.from_bytes(messageLength, 'little')
47 | message = conn.recv(messageLength)
48 | return json.loads(message.decode("utf-8"))
49 |
50 | def send(message_bytes, verify):
51 | global conn
52 | if message_bytes != None:
53 | conn.sendall(message_bytes)
54 | conn.sendall(len(verify).to_bytes(4, 'little'))
55 | conn.sendall(bytes(verify, 'ascii'))
56 |
57 | def receive():
58 | message = read()
59 |
60 | width = message["resolution_x"]
61 | height = message["resolution_y"]
62 |
63 | if width != 0 and height != 0:
64 | try:
65 | do_training = bool(message["train"])
66 | fovy = message["fov_y"]
67 | fovx = message["fov_x"]
68 | znear = message["z_near"]
69 | zfar = message["z_far"]
70 | do_shs_python = bool(message["shs_python"])
71 | do_rot_scale_python = bool(message["rot_scale_python"])
72 | keep_alive = bool(message["keep_alive"])
73 | scaling_modifier = message["scaling_modifier"]
74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
75 | world_view_transform[:,1] = -world_view_transform[:,1]
76 | world_view_transform[:,2] = -world_view_transform[:,2]
77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
78 | full_proj_transform[:,1] = -full_proj_transform[:,1]
79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
80 | except Exception as e:
81 | print("")
82 | traceback.print_exc()
83 | raise e
84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
85 | else:
86 | return None, None, None, None, None, None
--------------------------------------------------------------------------------
/lpipsPyTorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from sklearn.metrics import confusion_matrix
14 | import numpy as np
15 |
16 | def mse(img1, img2):
17 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
18 |
19 | def psnr(img1, img2):
20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
21 | return 20 * torch.log10(1.0 / torch.sqrt(mse))
22 |
23 | def nanmean(data, **args):
24 | # This makes it ignore the first 'background' class
25 | return np.ma.masked_array(data, np.isnan(data)).mean(**args)
26 | # In np.ma.masked_array(data, np.isnan(data)), elements where data is NaN are masked out and ignored when computing the mean
27 |
28 |
29 | def calculate_segmentation_metrics(true_labels, predicted_labels, number_classes, ignore_label):
30 | # if (true_labels == ignore_label).all():
31 | # return [0]*4
32 |
33 | true_labels = true_labels.flatten()
34 | predicted_labels = predicted_labels.flatten()
35 | valid_pix_ids = true_labels != ignore_label
36 | predicted_labels = predicted_labels[valid_pix_ids]
37 | true_labels = true_labels[valid_pix_ids]
38 |
39 | conf_mat = confusion_matrix(true_labels, predicted_labels, labels=list(range(number_classes)))
40 | norm_conf_mat = np.transpose(
41 | np.transpose(conf_mat) / conf_mat.astype(np.float64).sum(axis=1))
42 |
43 | missing_class_mask = np.isnan(norm_conf_mat.sum(1)) # missing class will have NaN at corresponding class
44 | existing_class_mask = ~ missing_class_mask
45 |
46 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat))
47 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat))
48 | ious = np.zeros(number_classes)
49 | for class_id in range(number_classes):
50 | ious[class_id] = (conf_mat[class_id, class_id] / (
51 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) -
52 | conf_mat[class_id, class_id]))
53 | miou = nanmean(ious)
54 | miou_valid_class = np.mean(ious[existing_class_mask])
55 | return miou, miou_valid_class, total_accuracy, class_average_accuracy, ious
56 |
57 | def _fast_hist(num_classes, label_pred, label_true):
58 | # Select the label classes to be counted, excluding the background
59 | mask = (label_true >= 0) & (label_true < num_classes)
60 | # np.bincount counts how often each value in [0, n**2 - 1] occurs; the result is reshaped to (n, n)
61 | hist = np.bincount(
62 | num_classes * label_true[mask].astype(int) +
63 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
64 | return hist
65 |
66 |
67 | def evaluate(num_classes,predictions, gts):
68 | hist = np.zeros((num_classes, num_classes))
69 | for lp, lt in zip(predictions, gts):
70 | assert len(lp.flatten()) == len(lt.flatten())
71 | hist += _fast_hist(num_classes,lp.flatten(), lt.flatten())
72 |
73 | # miou
74 | iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
75 | miou = np.nanmean(iou)
76 |
77 | acc = np.diag(hist).sum() / hist.sum()
78 | acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))
79 |
80 | freq = hist.sum(axis=1) / hist.sum()
81 | fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()
82 |
83 | return acc, acc_cls, iou, miou, fwavacc
--------------------------------------------------------------------------------
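
An illustrative sanity check (not from the repository) for the PSNR helper; inputs are assumed to be batched images in [0, 1]:

```python
import torch

from utils.image_utils import psnr

img = torch.rand(1, 3, 64, 64)
noisy = (img + 0.01 * torch.randn_like(img)).clamp(0.0, 1.0)

# Gaussian noise with std 0.01 gives an MSE of roughly 1e-4, i.e. a PSNR around 40 dB.
print(psnr(img, noisy))
```
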
/full_eval.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | from argparse import ArgumentParser
14 |
15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
17 | tanks_and_temples_scenes = ["truck", "train"]
18 | deep_blending_scenes = ["drjohnson", "playroom"]
19 |
20 | parser = ArgumentParser(description="Full evaluation script parameters")
21 | parser.add_argument("--skip_training", action="store_true")
22 | parser.add_argument("--skip_rendering", action="store_true")
23 | parser.add_argument("--skip_metrics", action="store_true")
24 | parser.add_argument("--output_path", default="./eval")
25 | args, _ = parser.parse_known_args()
26 |
27 | all_scenes = []
28 | all_scenes.extend(mipnerf360_outdoor_scenes)
29 | all_scenes.extend(mipnerf360_indoor_scenes)
30 | all_scenes.extend(tanks_and_temples_scenes)
31 | all_scenes.extend(deep_blending_scenes)
32 |
33 | if not args.skip_training or not args.skip_rendering:
34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
36 | parser.add_argument("--deepblending", "-db", required=True, type=str)
37 | args = parser.parse_args()
38 |
39 | if not args.skip_training:
40 | common_args = " --quiet --eval --test_iterations -1 "
41 | for scene in mipnerf360_outdoor_scenes:
42 | source = args.mipnerf360 + "/" + scene
43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
44 | for scene in mipnerf360_indoor_scenes:
45 | source = args.mipnerf360 + "/" + scene
46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
47 | for scene in tanks_and_temples_scenes:
48 | source = args.tanksandtemples + "/" + scene
49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
50 | for scene in deep_blending_scenes:
51 | source = args.deepblending + "/" + scene
52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
53 |
54 | if not args.skip_rendering:
55 | all_sources = []
56 | for scene in mipnerf360_outdoor_scenes:
57 | all_sources.append(args.mipnerf360 + "/" + scene)
58 | for scene in mipnerf360_indoor_scenes:
59 | all_sources.append(args.mipnerf360 + "/" + scene)
60 | for scene in tanks_and_temples_scenes:
61 | all_sources.append(args.tanksandtemples + "/" + scene)
62 | for scene in deep_blending_scenes:
63 | all_sources.append(args.deepblending + "/" + scene)
64 |
65 | common_args = " --quiet --eval --skip_train"
66 | for scene, source in zip(all_scenes, all_sources):
67 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
68 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
69 |
70 | if not args.skip_metrics:
71 | scenes_string = ""
72 | for scene in all_scenes:
73 | scenes_string += "\"" + args.output_path + "/" + scene + "\" "
74 |
75 | os.system("python metrics.py -m " + scenes_string)
--------------------------------------------------------------------------------
/utils/graphics_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | import numpy as np
15 | from typing import NamedTuple
16 |
17 |
18 | class BasicPointCloud(NamedTuple):
19 | points: np.array
20 | colors: np.array
21 | normals: np.array
22 |
23 |
24 | def geom_transform_points(points, transf_matrix):
25 | P, _ = points.shape
26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
27 | points_hom = torch.cat([points, ones], dim=1)
28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
29 |
30 | denom = points_out[..., 3:] + 0.0000001
31 | return (points_out[..., :3] / denom).squeeze(dim=0)
32 |
33 |
34 | def getWorld2View(R, t):
35 | Rt = np.zeros((4, 4))
36 | Rt[:3, :3] = R.transpose()
37 | Rt[:3, 3] = t
38 | Rt[3, 3] = 1.0
39 | return np.float32(Rt)
40 |
41 |
42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
43 | Rt = np.zeros((4, 4))
44 | Rt[:3, :3] = R.transpose()
45 | Rt[:3, 3] = t
46 | Rt[3, 3] = 1.0
47 |
48 | C2W = np.linalg.inv(Rt)
49 | cam_center = C2W[:3, 3]
50 | cam_center = (cam_center + translate) * scale
51 | C2W[:3, 3] = cam_center
52 | Rt = np.linalg.inv(C2W)
53 | return np.float32(Rt)
54 |
55 |
56 | def getProjectionMatrix(znear, zfar, fovX, fovY):
57 | tanHalfFovY = math.tan((fovY / 2))
58 | tanHalfFovX = math.tan((fovX / 2))
59 |
60 | top = tanHalfFovY * znear
61 | bottom = -top
62 | right = tanHalfFovX * znear
63 | left = -right
64 |
65 | P = torch.zeros(4, 4)
66 |
67 | z_sign = 1.0
68 |
69 | P[0, 0] = 2.0 * znear / (right - left)
70 | P[1, 1] = 2.0 * znear / (top - bottom)
71 | P[0, 2] = (right + left) / (right - left)
72 | P[1, 2] = (top + bottom) / (top - bottom)
73 | P[3, 2] = z_sign
74 | P[2, 2] = z_sign * zfar / (zfar - znear)
75 | P[2, 3] = -(zfar * znear) / (zfar - znear)
76 | return P
77 |
78 |
79 | def fov2focal(fov, pixels):
80 | return pixels / (2 * math.tan(fov / 2))
81 |
82 |
83 | def focal2fov(focal, pixels):
84 | return 2 * math.atan(pixels / (2 * focal))
85 |
86 |
87 | def getc2w(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
88 | Rt = np.zeros((4, 4))
89 | Rt[:3, :3] = R.transpose()
90 | Rt[:3, 3] = t
91 | Rt[3, 3] = 1.0
92 |
93 | C2W = np.linalg.inv(Rt)
94 | cam_center = C2W[:3, 3]
95 | cam_center = (cam_center + translate) * scale
96 | C2W[:3, 3] = cam_center
97 |
98 | return np.float32(C2W)
99 |
100 |
101 | def views_dir(H, W, K, c2w):
102 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
103 | # build the pixel grid
104 |
105 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1)
106 | # Camera-frame ray directions centered at the camera origin, [H, W, 3] where 3 -> [i-cx, j-cy, f]/f = [(i-cx)/f, (j-cy)/f, 1]
107 | # Rotate ray directions from camera frame to the world frame
108 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
109 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] [H, W, 1, 3]
110 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
111 | # rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
112 | # return rays_o, rays_d
113 | # print(rays_d.shape)
114 | return rays_d
115 |
--------------------------------------------------------------------------------
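
A quick consistency check (illustrative only): `fov2focal` and `focal2fov` invert each other for a fixed image width, which is how the loaders convert between COLMAP focal lengths and fields of view:

```python
import math

from utils.graphics_utils import focal2fov, fov2focal

width = 1600
fov_x = math.radians(60.0)

focal = fov2focal(fov_x, width)  # ~1385.6 pixels
assert abs(focal2fov(focal, width) - fov_x) < 1e-9
print(focal)
```
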
/scene/cameras.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from torch import nn
14 | import numpy as np
15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix, getc2w
16 | from imgviz import label_colormap
17 | import math
18 | from mynet import MyNet, embedding_fn
19 | from utils.graphics_utils import views_dir
20 |
21 | class Camera(nn.Module):
22 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
23 | image_name, semantic_image, semantic_image_name, semantic_classes,
24 | uid, trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda"
25 | ):
26 | super(Camera, self).__init__()
27 |
28 | self.uid = uid
29 | self.colmap_id = colmap_id
30 | self.R = R
31 | self.T = T
32 | self.FoVx = FoVx
33 | self.FoVy = FoVy
34 | self.image_name = image_name
35 | # self.semantic_image_name = semantic_image_name
36 | # self.color_map_np = label_colormap()[semantic_classes]
37 |
38 | try:
39 | self.data_device = torch.device(data_device)
40 | except Exception as e:
41 | print(e)
42 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
43 | self.data_device = torch.device("cuda")
44 |
45 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
46 | # self.semantic_image = semantic_image.to(self.data_device)
47 | self.image_width = self.original_image.shape[2]
48 | self.image_height = self.original_image.shape[1]
49 | self.K = np.array([
50 | [self.image_width / (2 * math.tan(self.FoVx / 2)), 0, 0.5 * self.image_width],
51 | [0, self.image_height / (2 * math.tan(self.FoVy / 2)), 0.5 * self.image_height],
52 | [0, 0, 1]
53 | ])
54 |
55 | if gt_alpha_mask is not None:
56 | self.original_image *= gt_alpha_mask.to(self.data_device)
57 | else:
58 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
59 |
60 | self.zfar = 100.0
61 | self.znear = 0.01
62 |
63 | self.trans = trans
64 | self.scale = scale
65 |
66 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
67 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
68 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
69 | self.camera_center = self.world_view_transform.inverse()[3, :3]
70 | self.c2w = getc2w(R, T, trans, scale)
71 |
72 | rays_d = torch.from_numpy(
73 | views_dir(self.image_height, self.image_width, self.K, self.c2w)).cuda()
74 |
75 | self.views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
76 | # print(self.K, self.c2w)
77 |
78 |
79 |
80 | class MiniCam:
81 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
82 | self.image_width = width
83 | self.image_height = height
84 | self.FoVy = fovy
85 | self.FoVx = fovx
86 | self.znear = znear
87 | self.zfar = zfar
88 | self.world_view_transform = world_view_transform
89 | self.full_proj_transform = full_proj_transform
90 | view_inv = torch.inverse(self.world_view_transform)
91 | self.camera_center = view_inv[3][:3]
92 |
93 |
94 |
--------------------------------------------------------------------------------
/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | from scene.cameras import Camera
14 | import numpy as np
15 | from utils.general_utils import PILtoTorch
16 | from utils.graphics_utils import fov2focal
17 |
18 | WARNED = False
19 |
20 | def loadCam(args, id, cam_info, resolution_scale):
21 | orig_w, orig_h = cam_info.image.size
22 |
23 | if args.resolution in [1, 2, 4, 8]:
24 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
25 | else: # should be a type that converts to float
26 | if args.resolution == -1:
27 | if orig_w > 1600:
28 | global WARNED
29 | if not WARNED:
30 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
31 | "If this is not desired, please explicitly specify '--resolution/-r' as 1")
32 | WARNED = True
33 | global_down = orig_w / 1600
34 | else:
35 | global_down = 1
36 | else:
37 | global_down = orig_w / args.resolution
38 |
39 | scale = float(global_down) * float(resolution_scale)
40 | resolution = (int(orig_w / scale), int(orig_h / scale))
41 |
42 | resized_image_rgb = PILtoTorch(cam_info.image, resolution)
43 | origin_semantic_image = cam_info.semantic_image
44 |
45 | # semantic_remap = np.empty(origin_semantic_image.shape)
46 | # for i in range(origin_semantic_image.shape[0]):
47 | # for j in range(origin_semantic_image.shape[1]):
48 | # semantic_value = origin_semantic_image[i][j]
49 | # if semantic_value in cam_info.semantic_classes:
50 | # semantic_remap[i, j] = cam_info.semantic_classes.index(semantic_value)
51 |
52 | gt_image = resized_image_rgb[:3, ...]
53 | # gt_semantic_image = torch.from_numpy(np.array(semantic_remap)).long()
54 | gt_semantic_image=None
55 | loaded_mask = None
56 |
57 | if resized_image_rgb.shape[1] == 4:
58 | loaded_mask = resized_image_rgb[3:4, ...]
59 |
60 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
61 | FoVx=cam_info.FovX, FoVy=cam_info.FovY,
62 | image=gt_image, gt_alpha_mask=loaded_mask,
63 | image_name=cam_info.image_name,
64 | semantic_image=gt_semantic_image,
65 | semantic_image_name=cam_info.semantic_image_name,
66 | semantic_classes=cam_info.semantic_classes,
67 | uid=id, data_device=args.data_device)
68 |
69 | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
70 | camera_list = []
71 |
72 | for id, c in enumerate(cam_infos):
73 | camera_list.append(loadCam(args, id, c, resolution_scale))
74 |
75 | return camera_list
76 |
77 | def camera_to_JSON(id, camera : Camera):
78 | Rt = np.zeros((4, 4))
79 | Rt[:3, :3] = camera.R.transpose()
80 | Rt[:3, 3] = camera.T
81 | Rt[3, 3] = 1.0
82 |
83 | W2C = np.linalg.inv(Rt)
84 | pos = W2C[:3, 3]
85 | rot = W2C[:3, :3]
86 | serializable_array_2d = [x.tolist() for x in rot]
87 | camera_entry = {
88 | 'id' : id,
89 | 'img_name' : camera.image_name,
90 | 'sem_img_name': camera.semantic_image_name,
91 | 'width' : camera.width,
92 | 'height' : camera.height,
93 | 'position': pos.tolist(),
94 | 'rotation': serializable_array_2d,
95 | 'fy' : fov2focal(camera.FovY, camera.height),
96 | 'fx' : fov2focal(camera.FovX, camera.width)
97 | }
98 | return camera_entry
99 |
--------------------------------------------------------------------------------
/render_fps.py:
--------------------------------------------------------------------------------
1 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
2 | # All rights reserved.
3 | #
4 | # This software is free for non-commercial, research and evaluation use
5 | # under the terms of the LICENSE.md file.
6 | #
7 | # For inquiries contact george.drettakis@inria.fr
8 | #
9 |
10 | import torch
11 | import numpy as np
12 | from scene import Scene
13 | import os
14 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
15 | from tqdm import tqdm
16 | from os import makedirs
17 | from gaussian_renderer import render
18 | import torchvision
19 | import imageio
20 | from utils.general_utils import safe_state, calculate_reflection_direction, visualize_normal_map
21 | from argparse import ArgumentParser
22 | from arguments import ModelParams, PipelineParams, get_combined_args
23 | from gaussian_renderer import GaussianModel
24 | from utils.image_utils import psnr, calculate_segmentation_metrics, evaluate
25 | from model import UNet, SimpleUNet, SimpleNet
26 | from mynet import MyNet, embedding_fn
27 | from utils.graphics_utils import views_dir
28 | import time
29 |
30 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
31 | net_path = os.path.join(model_path, "model_ckpt{}pth".format(iteration))
32 | model = MyNet().to("cuda")
33 | net_weights = torch.load(net_path)
34 | # print(net_weights)
35 | model.load_state_dict(net_weights)
36 |
37 | t_list = []
38 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
39 | torch.cuda.synchronize()
40 | t0 = time.time()
41 | # color
42 | render_pkg = render(view, gaussians, pipeline, background)
43 |
44 | rendered_features, viewspace_point_tensor, visibility_filter, radii, depths, _ = render_pkg
45 |
46 | # append the view-direction embedding
47 | rays_d = torch.from_numpy(views_dir(view.image_height, view.image_width, view.K, view.c2w)).cuda()
48 |
49 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
50 |
51 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
52 | image = model(*rendered_features)
53 | rendering = image['im_out'].squeeze(0)
54 |
55 | torch.cuda.synchronize()
56 | t1 = time.time()
57 |
58 | t_list.append(t1 - t0)
59 |
60 | t = np.array(t_list[3:])
61 | fps = 1.0 / t.mean()
62 | print(f'Test FPS: \033[1;35m{fps:.5f}\033[0m')
63 |
64 |
65 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
66 | with torch.no_grad():
67 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes)
68 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
69 |
70 |
71 |
72 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
73 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
74 |
75 | if not skip_train:
76 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
77 |
78 | if not skip_test:
79 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
80 |
81 | if __name__ == "__main__":
82 | # Set up command line argument parser
83 | parser = ArgumentParser(description="Testing script parameters")
84 | model = ModelParams(parser, sentinel=True)
85 | pipeline = PipelineParams(parser)
86 | parser.add_argument("--iteration", default=-1, type=int)
87 | parser.add_argument("--skip_train", action="store_true")
88 | parser.add_argument("--skip_test", action="store_true")
89 | parser.add_argument("--quiet", action="store_true")
90 | args = get_combined_args(parser)
91 | print("Rendering " + args.model_path)
92 |
93 | # Initialize system state (RNG)
94 | safe_state(args.quiet)
95 |
96 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from argparse import ArgumentParser, Namespace
13 | import sys
14 | import os
15 |
16 | class GroupParams:
17 | pass
18 |
19 | class ParamGroup:
20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21 | group = parser.add_argument_group(name)
22 | for key, value in vars(self).items():
23 | shorthand = False
24 | if key.startswith("_"):
25 | shorthand = True
26 | key = key[1:]
27 | t = type(value)
28 | value = value if not fill_none else None
29 | if shorthand:
30 | if t == bool:
31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32 | else:
33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34 | else:
35 | if t == bool:
36 | group.add_argument("--" + key, default=value, action="store_true")
37 | else:
38 | group.add_argument("--" + key, default=value, type=t)
39 |
40 | def extract(self, args):
41 | group = GroupParams()
42 | for arg in vars(args).items():
43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44 | setattr(group, arg[0], arg[1])
45 | return group
46 |
47 | class ModelParams(ParamGroup):
48 | def __init__(self, parser, sentinel=False):
49 | self.sh_degree = 3
50 | self.num_sem_classes = 16
51 | self._source_path = ""
52 | self._model_path = ""
53 | self._images = "images"
54 | # self._semantic_class = "semantic_class"
55 | self._resolution = -1
56 | self._white_background = False
57 | self.data_device = "cuda"
58 | self.eval = False
59 | super().__init__(parser, "Loading Parameters", sentinel)
60 |
61 | def extract(self, args):
62 | g = super().extract(args)
63 | g.source_path = os.path.abspath(g.source_path)
64 | return g
65 |
66 | class PipelineParams(ParamGroup):
67 | def __init__(self, parser):
68 | self.convert_SHs_python = False
69 | self.compute_cov3D_python = False
70 | self.debug = False
71 | super().__init__(parser, "Pipeline Parameters")
72 |
73 | class OptimizationParams(ParamGroup):
74 | def __init__(self, parser):
75 | self.iterations = 30_000
76 | self.position_lr_init = 0.00016
77 | self.position_lr_final = 0.0000016
78 | self.position_lr_delay_mult = 0.01
79 | self.position_lr_max_steps = 30_000 # 30_000
80 | self.feature_lr = 0.0025
81 | self.semantic_lr = 0.0025 #0.0025
82 | self.opacity_lr = 0.025 # 0.025
83 | self.scaling_lr = 0.005
84 | self.rotation_lr = 0.001
85 | self.percent_dense = 0.01
86 | self.lambda_dssim = 0.2
87 | self.densification_interval = 100
88 | self.opacity_reset_interval = 3000
89 | self.densify_from_iter = 500
90 | self.densify_until_iter = 15_000 # 15_000
91 | self.densify_grad_threshold = 0.0002
92 | super().__init__(parser, "Optimization Parameters")
93 |
94 | def get_combined_args(parser : ArgumentParser):
95 | cmdlne_string = sys.argv[1:]
96 | cfgfile_string = "Namespace()"
97 | args_cmdline = parser.parse_args(cmdlne_string)
98 |
99 | try:
100 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
101 | print("Looking for config file in", cfgfilepath)
102 | with open(cfgfilepath) as cfg_file:
103 | print("Config file found: {}".format(cfgfilepath))
104 | cfgfile_string = cfg_file.read()
105 | except TypeError:
106 | print("Config file not found")
107 | pass
108 | args_cfgfile = eval(cfgfile_string)
109 |
110 | merged_dict = vars(args_cfgfile).copy()
111 | for k,v in vars(args_cmdline).items():
112 | if v != None:
113 | merged_dict[k] = v
114 | return Namespace(**merged_dict)
115 |
--------------------------------------------------------------------------------
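
A sketch of how these parameter groups are intended to be wired together; the pattern mirrors `render_fps.py` and the original Gaussian Splatting scripts, and the paths below are placeholders:

```python
from argparse import ArgumentParser

from arguments import ModelParams, OptimizationParams, PipelineParams

parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)         # fields starting with "_" also get a one-letter shorthand, e.g. -s, -m
op = OptimizationParams(parser)
pp = PipelineParams(parser)

# Placeholder paths, purely for illustration.
args = parser.parse_args(["-s", "./data/scene", "-m", "./output/scene"])

dataset = lp.extract(args)       # plain namespaces consumed by Scene and the training loop
opt = op.extract(args)
pipe = pp.extract(args)
print(dataset.source_path, opt.iterations, pipe.debug)
```
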
/metrics.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 | def readImages(renders_dir, gt_dir):
25 | renders = []
26 | gts = []
27 | image_names = []
28 | for fname in os.listdir(renders_dir):
29 | render = Image.open(renders_dir / fname)
30 | gt = Image.open(gt_dir / fname)
31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
33 | image_names.append(fname)
34 | return renders, gts, image_names
35 |
36 | def evaluate(model_paths):
37 |
38 | full_dict = {}
39 | per_view_dict = {}
40 | full_dict_polytopeonly = {}
41 | per_view_dict_polytopeonly = {}
42 | print("")
43 |
44 | for scene_dir in model_paths:
45 | try:
46 | print("Scene:", scene_dir)
47 | full_dict[scene_dir] = {}
48 | per_view_dict[scene_dir] = {}
49 | full_dict_polytopeonly[scene_dir] = {}
50 | per_view_dict_polytopeonly[scene_dir] = {}
51 |
52 | test_dir = Path(scene_dir) / "test"
53 |
54 | for method in os.listdir(test_dir):
55 | print("Method:", method)
56 |
57 | full_dict[scene_dir][method] = {}
58 | per_view_dict[scene_dir][method] = {}
59 | full_dict_polytopeonly[scene_dir][method] = {}
60 | per_view_dict_polytopeonly[scene_dir][method] = {}
61 |
62 | method_dir = test_dir / method
63 | gt_dir = method_dir/ "gt"
64 | renders_dir = method_dir / "renders"
65 | renders, gts, image_names = readImages(renders_dir, gt_dir)
66 |
67 | ssims = []
68 | psnrs = []
69 | lpipss = []
70 |
71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
72 | ssims.append(ssim(renders[idx], gts[idx]))
73 | psnrs.append(psnr(renders[idx], gts[idx]))
74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
75 |
76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"))
79 | print("")
80 |
81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
82 | "PSNR": torch.tensor(psnrs).mean().item(),
83 | "LPIPS": torch.tensor(lpipss).mean().item()})
84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
87 |
88 | with open(scene_dir + "/results.json", 'w') as fp:
89 | json.dump(full_dict[scene_dir], fp, indent=True)
90 | with open(scene_dir + "/per_view.json", 'w') as fp:
91 | json.dump(per_view_dict[scene_dir], fp, indent=True)
92 | except:
93 | print("Unable to compute metrics for model", scene_dir)
94 |
95 | if __name__ == "__main__":
96 | device = torch.device("cuda:0")
97 | torch.cuda.set_device(device)
98 |
99 | # Set up command line argument parser
100 | parser = ArgumentParser(description="Training script parameters")
101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
102 | args = parser.parse_args()
103 | evaluate(args.model_paths)
104 |
--------------------------------------------------------------------------------
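
An illustrative programmatic call (assumption): each model directory handed to `evaluate` must contain `test/<method>/renders` and `test/<method>/gt` folders as written by `render.py`; the path below is hypothetical:

```python
from metrics import evaluate

# Hypothetical output directory produced by train.py / render.py.
evaluate(["./output/shiny_scene"])
```
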
/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import math
14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15 | # from typing import NamedTuple
16 | # from diff_gaussian_rasterization import GaussianRasterizer
17 | from scene.gaussian_model import GaussianModel
18 | from utils.sh_utils import eval_sh
19 | from utils.general_utils import compute_gaussian_normals, flip_align_view
20 | import torch.nn.functional as F
21 | import time
22 |
23 |
24 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
25 | """
26 | Render the scene.
27 |
28 | Background tensor (bg_color) must be on GPU!
29 | """
30 |
31 | scale_factor = [1]
32 |
33 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
34 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
35 | try:
36 | screenspace_points.retain_grad()
37 | except:
38 | pass
39 |
40 | # Set up rasterization configuration
41 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
42 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
43 |
44 | rasterizers = []
45 | for scale in scale_factor:
46 | raster_setting = GaussianRasterizationSettings(
47 | image_height=int(viewpoint_camera.image_height/scale),
48 | image_width=int(viewpoint_camera.image_width/scale),
49 | tanfovx=tanfovx,
50 | tanfovy=tanfovy,
51 | bg=bg_color,
52 | scale_modifier=scaling_modifier,
53 | viewmatrix=viewpoint_camera.world_view_transform,
54 | projmatrix=viewpoint_camera.full_proj_transform,
55 | sh_degree=pc.active_sh_degree,
56 | # num_sem_classes=pc.num_sem_classes,
57 | campos=viewpoint_camera.camera_center,
58 | prefiltered=False,
59 | debug=pipe.debug
60 | )
61 | rasterizer = GaussianRasterizer(raster_settings=raster_setting)
62 | rasterizers.append(rasterizer)
63 |
64 | means3D = pc.get_xyz
65 | means2D = screenspace_points
66 | opacity = pc.get_opacity
67 |
68 | # xyz = pc.get_xyz
69 | dir_pp = (means3D - viewpoint_camera.camera_center.repeat(means3D.shape[0], 1))
70 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True)
71 | color_features = pc.get_semantic
72 |
73 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
74 | # scaling / rotation by the rasterizer.
75 | scales = None
76 | rotations = None
77 | cov3D_precomp = None
78 | if pipe.compute_cov3D_python:
79 | cov3D_precomp = pc.get_covariance(scaling_modifier)
80 | else:
81 | scales = pc.get_scaling
82 | rotations = pc.get_rotation
83 |
84 | normals = pc.get_normal
85 |
86 | direction_features = pc.mlp_direction_head(
87 | torch.cat([pc.direction_encoding(dir_pp), normals], dim=-1)).float()
88 |
89 | Ln = 0
90 |
91 | semantic = torch.cat([color_features, direction_features, normals], dim=-1)
92 |
93 | rendered_features = []
94 | rendered_depths = []
95 | visibility_filter = 0
96 | radii = 0
97 |
98 | for i in range(1):
99 | semantic_logits, rendered_depth, rendered_alpha, radii = rasterizers[i](
100 | means3D = means3D,
101 | means2D = means2D,
102 | # shs = shs,
103 | # colors_precomp = colors_precomp,
104 | semantic = semantic,
105 | opacities = opacity,
106 | scales = scales,
107 | rotations = rotations,
108 | cov3D_precomp = cov3D_precomp)
109 | rendered_features.append(semantic_logits.unsqueeze(0))
110 | rendered_depths.append(rendered_depth.unsqueeze(0))
111 | if i == 0:
112 | visibility_filter = radii > 0
113 | radii = radii
114 | # rendered_image: [3, 376, 1408] [3, image_height, image_width]
115 | # radii: [10458]
116 |
117 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
118 | # They will be excluded from value updates used in the splitting criteria.
119 | return rendered_features, screenspace_points, visibility_filter, radii, rendered_depths, Ln
120 |
121 |
122 |
--------------------------------------------------------------------------------
/utils/sh_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The PlenOctree Authors.
2 | # Redistribution and use in source and binary forms, with or without
3 | # modification, are permitted provided that the following conditions are met:
4 | #
5 | # 1. Redistributions of source code must retain the above copyright notice,
6 | # this list of conditions and the following disclaimer.
7 | #
8 | # 2. Redistributions in binary form must reproduce the above copyright notice,
9 | # this list of conditions and the following disclaimer in the documentation
10 | # and/or other materials provided with the distribution.
11 | #
12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22 | # POSSIBILITY OF SUCH DAMAGE.
23 |
24 | import torch
25 |
26 | C0 = 0.28209479177387814
27 | C1 = 0.4886025119029199
28 | C2 = [
29 | 1.0925484305920792,
30 | -1.0925484305920792,
31 | 0.31539156525252005,
32 | -1.0925484305920792,
33 | 0.5462742152960396
34 | ]
35 | C3 = [
36 | -0.5900435899266435,
37 | 2.890611442640554,
38 | -0.4570457994644658,
39 | 0.3731763325901154,
40 | -0.4570457994644658,
41 | 1.445305721320277,
42 | -0.5900435899266435
43 | ]
44 | C4 = [
45 | 2.5033429417967046,
46 | -1.7701307697799304,
47 | 0.9461746957575601,
48 | -0.6690465435572892,
49 | 0.10578554691520431,
50 | -0.6690465435572892,
51 | 0.47308734787878004,
52 | -1.7701307697799304,
53 | 0.6258357354491761,
54 | ]
55 |
56 |
57 | def eval_sh(deg, sh, dirs):
58 | """
59 | Evaluate spherical harmonics at unit directions
60 | using hardcoded SH polynomials.
61 | Works with torch/np/jnp.
62 | ... Can be 0 or more batch dimensions.
63 | Args:
64 | deg: int SH deg. Currently, 0-3 supported
65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66 | dirs: jnp.ndarray unit directions [..., 3]
67 | Returns:
68 | [..., C]
69 | """
70 | assert deg <= 4 and deg >= 0
71 | coeff = (deg + 1) ** 2
72 | assert sh.shape[-1] >= coeff
73 |
74 | result = C0 * sh[..., 0]
75 | if deg > 0:
76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77 | result = (result -
78 | C1 * y * sh[..., 1] +
79 | C1 * z * sh[..., 2] -
80 | C1 * x * sh[..., 3])
81 |
82 | if deg > 1:
83 | xx, yy, zz = x * x, y * y, z * z
84 | xy, yz, xz = x * y, y * z, x * z
85 | result = (result +
86 | C2[0] * xy * sh[..., 4] +
87 | C2[1] * yz * sh[..., 5] +
88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89 | C2[3] * xz * sh[..., 7] +
90 | C2[4] * (xx - yy) * sh[..., 8])
91 |
92 | if deg > 2:
93 | result = (result +
94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95 | C3[1] * xy * z * sh[..., 10] +
96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99 | C3[5] * z * (xx - yy) * sh[..., 14] +
100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101 |
102 | if deg > 3:
103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112 | return result
113 |
114 | def RGB2SH(rgb):
115 | return (rgb - 0.5) / C0
116 |
117 | def SH2RGB(sh):
118 | return sh * C0 + 0.5
--------------------------------------------------------------------------------
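
A small numeric check (illustrative): `RGB2SH` and `SH2RGB` invert each other, and a degree-0 evaluation with `eval_sh` simply scales the DC coefficient by `C0`:

```python
import torch

from utils.sh_utils import RGB2SH, SH2RGB, eval_sh

rgb = torch.tensor([0.2, 0.5, 0.8])
sh_dc = RGB2SH(rgb)                          # (rgb - 0.5) / C0
assert torch.allclose(SH2RGB(sh_dc), rgb)

# Degree 0: sh has shape [..., C, 1]; the direction is ignored.
dirs = torch.tensor([[0.0, 0.0, 1.0]])
out = eval_sh(0, sh_dc.view(1, 3, 1), dirs)  # equals rgb - 0.5
assert torch.allclose(out, rgb - 0.5)
```
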
/scene/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import random
14 | import json
15 | import torch
16 | from utils.system_utils import searchForMaxIteration
17 | from scene.dataset_readers import sceneLoadTypeCallbacks
18 | from scene.gaussian_model import GaussianModel
19 | from arguments import ModelParams
20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
21 |
22 | class Scene:
23 |
24 | gaussians : GaussianModel
25 |
26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
27 | """
28 | :param path: Path to colmap scene main folder.
29 | """
30 | self.model_path = args.model_path
31 | self.loaded_iter = None
32 | self.gaussians = gaussians
33 |
34 | if load_iteration:
35 | if load_iteration == -1:
36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
37 | else:
38 | self.loaded_iter = load_iteration
39 | print("Loading trained model at iteration {}".format(self.loaded_iter))
40 |
41 | self.train_cameras = {}
42 | self.test_cameras = {}
43 |
44 | if os.path.exists(os.path.join(args.source_path, "sparse")):
45 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
47 | print("Found transforms_train.json file, assuming Blender data set!")
48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
49 | else:
50 | assert False, "Could not recognize scene type!"
51 |
52 | if not self.loaded_iter:
53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
54 | dest_file.write(src_file.read())
55 | json_cams = []
56 | camlist = []
57 | if scene_info.test_cameras:
58 | camlist.extend(scene_info.test_cameras)
59 | if scene_info.train_cameras:
60 | camlist.extend(scene_info.train_cameras)
61 | for id, cam in enumerate(camlist):
62 | json_cams.append(camera_to_JSON(id, cam))
63 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
64 | json.dump(json_cams, file)
65 |
66 | if shuffle:
67 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
68 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
69 |
70 | self.cameras_extent = scene_info.nerf_normalization["radius"]
71 |
72 | for resolution_scale in resolution_scales:
73 | print("Loading Training Cameras")
74 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
75 | print("Loading Test Cameras")
76 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
77 |
78 | if self.loaded_iter:
79 | self.gaussians.load_ply(os.path.join(self.model_path,
80 | "point_cloud",
81 | "iteration_" + str(self.loaded_iter),
82 | "point_cloud.ply"))
83 | # torch.nn.ModuleList([self.gaussians.recolor, self.gaussians.mlp_head, self.gaussians.mlp_direction_head, self.gaussians.direction_encoding]).load_state_dict(
84 | # torch.load(os.path.join(self.model_path,
85 | # "point_cloud",
86 | # "iteration_" + str(self.loaded_iter),
87 | # "point_cloud.pth")))
88 | torch.nn.ModuleList([ self.gaussians.mlp_direction_head, self.gaussians.direction_encoding, self.gaussians.mlp_normal_head, self.gaussians.positional_encoding]).load_state_dict(
89 | torch.load(os.path.join(self.model_path,
90 | "point_cloud",
91 | "iteration_" + str(self.loaded_iter),
92 | "point_cloud.pth")))
93 | else:
94 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
95 |
96 | def save(self, iteration):
97 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
98 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
99 |
100 | torch.save(torch.nn.ModuleList([self.gaussians.mlp_direction_head, self.gaussians.direction_encoding, self.gaussians.mlp_normal_head, self.gaussians.positional_encoding]).state_dict(),
101 | os.path.join(point_cloud_path, "point_cloud.pth"))
102 |
103 | def getTrainCameras(self, scale=1.0):
104 | return self.train_cameras[scale]
105 |
106 | def getTestCameras(self, scale=1.0):
107 | return self.test_cameras[scale]
--------------------------------------------------------------------------------
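A rough sketch of how the Scene class above is typically driven (the argument object and iteration number are placeholders; the GaussianModel constructor call mirrors its use in train.py):

    from scene import Scene, GaussianModel

    # assumes `args` is an extracted ModelParams namespace
    gaussians = GaussianModel(args.sh_degree, args.num_sem_classes)
    scene = Scene(args, gaussians)              # fresh run: initializes Gaussians from the COLMAP point cloud
    # scene = Scene(args, gaussians, load_iteration=-1, shuffle=False)  # or reload the latest saved iteration
    train_cams = scene.getTrainCameras()        # camera list at resolution scale 1.0
    scene.save(7000)                            # writes point_cloud.ply and point_cloud.pth for this iteration
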
/render.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import numpy as np
14 | from scene import Scene
15 | import time
16 | import os
17 |
18 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
19 | from tqdm import tqdm
20 | from os import makedirs
21 | from gaussian_renderer import render
22 | import torchvision
23 | import imageio
24 | from utils.general_utils import safe_state, calculate_reflection_direction, visualize_normal_map
25 | from argparse import ArgumentParser
26 | from arguments import ModelParams, PipelineParams, get_combined_args
27 | from gaussian_renderer import GaussianModel
28 | from utils.image_utils import psnr, calculate_segmentation_metrics, evaluate
29 | from mynet import MyNet, embedding_fn
30 | from utils.graphics_utils import views_dir
31 |
32 |
33 |
34 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
35 | net_path = os.path.join(model_path, "model_ckpt{}pth".format(iteration))
36 | model = MyNet().to("cuda")
37 | net_weights = torch.load(net_path)
38 | model.load_state_dict(net_weights)
39 |
40 | model.eval()
41 |
42 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
43 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
44 |
45 | mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "mask")
46 | f_path = os.path.join(model_path, name, "ours_{}".format(iteration), "f")
47 | highlight_path = os.path.join(model_path, name, "ours_{}".format(iteration), "highlight")
48 | color_path = os.path.join(model_path, name, "ours_{}".format(iteration), "color")
49 | normal_path = os.path.join(model_path, name, "ours_{}".format(iteration), "normal")
50 |
51 | makedirs(render_path, exist_ok=True)
52 | makedirs(gts_path, exist_ok=True)
53 |
54 | makedirs(mask_path, exist_ok=True)
55 | makedirs(f_path, exist_ok=True)
56 | makedirs(highlight_path, exist_ok=True)
57 | makedirs(color_path, exist_ok=True)
58 | makedirs(normal_path, exist_ok=True)
59 |
60 | all_time = []
61 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
62 | start_time = time.time()
63 |
64 | render_pkg = render(view, gaussians, pipeline, background)
65 |
66 | rendered_features, viewspace_point_tensor, visibility_filter, radii, depths, _ = render_pkg
67 |
68 | views_emd = view.views_emd
69 |
70 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
71 | image = model(*rendered_features)
72 | rendering = image['im_out'].squeeze(0)
73 |
74 | all_time.append(time.time() - start_time)
75 |
76 | gt = view.original_image[0:3, :, :]
77 |
78 |
79 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
80 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
81 | #
82 | torchvision.utils.save_image(image['mask_out'].squeeze(0),
83 | os.path.join(mask_path, '{0:05d}'.format(idx) + ".png"))
84 |
85 | torchvision.utils.save_image(image['f_out'].squeeze(0), os.path.join(f_path, '{0:05d}'.format(idx) + ".png"))
86 | torchvision.utils.save_image(image['highlight_out'].squeeze(0),
87 | os.path.join(highlight_path, '{0:05d}'.format(idx) + ".png"))
88 | torchvision.utils.save_image(image['color_out'].squeeze(0),
89 | os.path.join(color_path, '{0:05d}'.format(idx) + ".png"))
90 | print("Average time per image: {}".format(sum(all_time) / len(all_time)))
91 | print("Render FPS: {}".format(1 / (sum(all_time) / len(all_time))))
92 |
93 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool):
94 | with torch.no_grad():
95 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes)
96 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
97 |
98 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
99 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
100 |
101 | # for camera in scene.getTestCameras():
102 | # print(camera.R)
103 |
104 | if not skip_train:
105 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline,
106 | background)
107 |
108 | if not skip_test:
109 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline,
110 | background)
111 |
112 |
113 | if __name__ == "__main__":
114 | # Set up command line argument parser
115 | parser = ArgumentParser(description="Testing script parameters")
116 | model = ModelParams(parser, sentinel=True)
117 | pipeline = PipelineParams(parser)
118 | parser.add_argument("--iteration", default=-1, type=int)
119 | parser.add_argument("--skip_train", action="store_true")
120 | parser.add_argument("--skip_test", action="store_true")
121 | parser.add_argument("--quiet", action="store_true")
122 | args = get_combined_args(parser)
123 | print("Rendering " + args.model_path)
124 |
125 | # Initialize system state (RNG)
126 | safe_state(args.quiet)
127 |
128 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
129 |
--------------------------------------------------------------------------------
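One caveat on the per-image timing in render_set above: CUDA kernels are launched asynchronously, so stopping the wall clock right after the network forward pass can under-count GPU work still in flight. A hedged sketch of a tighter measurement helper (the lambda usage below is hypothetical):

    import time
    import torch

    def timed(fn):
        # synchronize before and after so the interval covers all queued CUDA work
        torch.cuda.synchronize()
        start = time.time()
        out = fn()
        torch.cuda.synchronize()
        return out, time.time() - start

    # inside the render loop, e.g.:
    # render_pkg, elapsed = timed(lambda: render(view, gaussians, pipeline, background))
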
/convert.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import logging
14 | from argparse import ArgumentParser
15 | import shutil
16 |
17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository.
18 | parser = ArgumentParser("Colmap converter")
19 | parser.add_argument("--no_gpu", action='store_true')
20 | parser.add_argument("--skip_matching", action='store_true')
21 | parser.add_argument("--source_path", "-s", required=True, type=str)
22 | parser.add_argument("--camera", default="OPENCV", type=str)
23 | parser.add_argument("--colmap_executable", default="", type=str)
24 | parser.add_argument("--resize", action="store_true")
25 | parser.add_argument("--magick_executable", default="", type=str)
26 | args = parser.parse_args()
27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
29 | use_gpu = 1 if not args.no_gpu else 0
30 |
31 | if not args.skip_matching:
32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
33 |
34 | ## Feature extraction
35 | feat_extraction_cmd = colmap_command + " feature_extractor "\
36 | "--database_path " + args.source_path + "/distorted/database.db \
37 | --image_path " + args.source_path + "/input \
38 | --ImageReader.single_camera 1 \
39 | --ImageReader.camera_model " + args.camera + " \
40 | --SiftExtraction.use_gpu " + str(use_gpu)
41 | exit_code = os.system(feat_extraction_cmd)
42 | if exit_code != 0:
43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
44 | exit(exit_code)
45 |
46 | ## Feature matching
47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \
48 | --database_path " + args.source_path + "/distorted/database.db \
49 | --SiftMatching.use_gpu " + str(use_gpu)
50 | exit_code = os.system(feat_matching_cmd)
51 | if exit_code != 0:
52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
53 | exit(exit_code)
54 |
55 | ### Bundle adjustment
56 | # The default Mapper tolerance is unnecessarily large,
57 | # decreasing it speeds up bundle adjustment steps.
58 | mapper_cmd = (colmap_command + " mapper \
59 | --database_path " + args.source_path + "/distorted/database.db \
60 | --image_path " + args.source_path + "/input \
61 | --output_path " + args.source_path + "/distorted/sparse \
62 | --Mapper.ba_global_function_tolerance=0.000001")
63 | exit_code = os.system(mapper_cmd)
64 | if exit_code != 0:
65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
66 | exit(exit_code)
67 |
68 | ### Image undistortion
69 | ## We need to undistort our images into ideal pinhole intrinsics.
70 | img_undist_cmd = (colmap_command + " image_undistorter \
71 | --image_path " + args.source_path + "/input \
72 | --input_path " + args.source_path + "/distorted/sparse/0 \
73 | --output_path " + args.source_path + "\
74 | --output_type COLMAP")
75 | exit_code = os.system(img_undist_cmd)
76 | if exit_code != 0:
77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.")
78 | exit(exit_code)
79 |
80 | files = os.listdir(args.source_path + "/sparse")
81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
82 | # Copy each file from the source directory to the destination directory
83 | for file in files:
84 | if file == '0':
85 | continue
86 | source_file = os.path.join(args.source_path, "sparse", file)
87 | destination_file = os.path.join(args.source_path, "sparse", "0", file)
88 | shutil.move(source_file, destination_file)
89 |
90 | if(args.resize):
91 | print("Copying and resizing...")
92 |
93 | # Resize images.
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True)
95 | os.makedirs(args.source_path + "/images_4", exist_ok=True)
96 | os.makedirs(args.source_path + "/images_8", exist_ok=True)
97 | # Get the list of files in the source directory
98 | files = os.listdir(args.source_path + "/images")
99 | # Copy each file from the source directory to the destination directory
100 | for file in files:
101 | source_file = os.path.join(args.source_path, "images", file)
102 |
103 | destination_file = os.path.join(args.source_path, "images_2", file)
104 | shutil.copy2(source_file, destination_file)
105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
106 | if exit_code != 0:
107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.")
108 | exit(exit_code)
109 |
110 | destination_file = os.path.join(args.source_path, "images_4", file)
111 | shutil.copy2(source_file, destination_file)
112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
113 | if exit_code != 0:
114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.")
115 | exit(exit_code)
116 |
117 | destination_file = os.path.join(args.source_path, "images_8", file)
118 | shutil.copy2(source_file, destination_file)
119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
120 | if exit_code != 0:
121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
122 | exit(exit_code)
123 |
124 | print("Done.")
125 |
--------------------------------------------------------------------------------
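convert.py assembles each COLMAP call by string concatenation and runs it with os.system, which breaks if the source path contains spaces. A hedged alternative sketch using an argument list (this is not how the script above currently works; run_colmap is a hypothetical helper):

    import subprocess

    def run_colmap(colmap_executable, step_args):
        # pass arguments as a list so paths with spaces survive intact
        cmd = [colmap_executable or "colmap"] + step_args
        result = subprocess.run(cmd)
        if result.returncode != 0:
            raise RuntimeError(f"COLMAP step failed with code {result.returncode}: {step_args[0]}")

    # e.g. feature extraction with the same flags as above:
    # run_colmap(args.colmap_executable, [
    #     "feature_extractor",
    #     "--database_path", args.source_path + "/distorted/database.db",
    #     "--image_path", args.source_path + "/input",
    #     "--ImageReader.single_camera", "1",
    #     "--ImageReader.camera_model", args.camera,
    #     "--SiftExtraction.use_gpu", str(use_gpu),
    # ])
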
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import sys
14 | from datetime import datetime
15 | import numpy as np
16 | import random
17 | import torch.nn.functional as F
18 |
19 | def inverse_sigmoid(x):
20 | return torch.log(x/(1-x))
21 |
22 | def PILtoTorch(pil_image, resolution):
23 | resized_image_PIL = pil_image.resize(resolution)
24 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
25 | if len(resized_image.shape) == 3:
26 | return resized_image.permute(2, 0, 1)
27 | else:
28 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
29 |
30 | def get_expon_lr_func(
31 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
32 | ):
33 | """
34 | Copied from Plenoxels
35 |
36 | Continuous learning rate decay function. Adapted from JaxNeRF
37 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
38 | is log-linearly interpolated elsewhere (equivalent to exponential decay).
39 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth
40 | function of lr_delay_mult, such that the initial learning rate is
41 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back
42 | to the normal learning rate when steps>lr_delay_steps.
43 | :param conf: config subtree 'lr' or similar
44 | :param max_steps: int, the number of steps during optimization.
45 | :return HoF which takes step as input
46 | """
47 |
48 | def helper(step):
49 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
50 | # Disable this parameter
51 | return 0.0
52 | if lr_delay_steps > 0:
53 | # A kind of reverse cosine decay.
54 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
55 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
56 | )
57 | else:
58 | delay_rate = 1.0
59 | t = np.clip(step / max_steps, 0, 1)
60 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
61 | return delay_rate * log_lerp
62 |
63 | return helper
64 |
65 | def strip_lowerdiag(L):
66 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
67 |
68 | uncertainty[:, 0] = L[:, 0, 0]
69 | uncertainty[:, 1] = L[:, 0, 1]
70 | uncertainty[:, 2] = L[:, 0, 2]
71 | uncertainty[:, 3] = L[:, 1, 1]
72 | uncertainty[:, 4] = L[:, 1, 2]
73 | uncertainty[:, 5] = L[:, 2, 2]
74 | return uncertainty
75 |
76 | def strip_symmetric(sym):
77 | return strip_lowerdiag(sym)
78 |
79 | def build_rotation(r):
80 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
81 |
82 | q = r / norm[:, None]
83 |
84 | R = torch.zeros((q.size(0), 3, 3), device='cuda')
85 |
86 | r = q[:, 0]
87 | x = q[:, 1]
88 | y = q[:, 2]
89 | z = q[:, 3]
90 |
91 | R[:, 0, 0] = 1 - 2 * (y*y + z*z)
92 | R[:, 0, 1] = 2 * (x*y - r*z)
93 | R[:, 0, 2] = 2 * (x*z + r*y)
94 | R[:, 1, 0] = 2 * (x*y + r*z)
95 | R[:, 1, 1] = 1 - 2 * (x*x + z*z)
96 | R[:, 1, 2] = 2 * (y*z - r*x)
97 | R[:, 2, 0] = 2 * (x*z - r*y)
98 | R[:, 2, 1] = 2 * (y*z + r*x)
99 | R[:, 2, 2] = 1 - 2 * (x*x + y*y)
100 | return R
101 |
102 | def build_scaling_rotation(s, r):
103 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
104 | R = build_rotation(r)
105 |
106 | L[:,0,0] = s[:,0]
107 | L[:,1,1] = s[:,1]
108 | L[:,2,2] = s[:,2]
109 |
110 | L = R @ L
111 | return L
112 |
113 | def safe_state(silent):
114 | old_f = sys.stdout
115 | class F:
116 | def __init__(self, silent):
117 | self.silent = silent
118 |
119 | def write(self, x):
120 | if not self.silent:
121 | if x.endswith("\n"):
122 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
123 | else:
124 | old_f.write(x)
125 |
126 | def flush(self):
127 | old_f.flush()
128 |
129 | sys.stdout = F(silent)
130 |
131 | random.seed(0)
132 | np.random.seed(0)
133 | torch.manual_seed(0)
134 | torch.cuda.set_device(torch.device("cuda:0"))
135 |
136 | # def compute_gaussian_normals(cov_matrices):
137 | # eigenvalues, eigenvectors = np.linalg.eig(cov_matrices)
138 | # min_eigenvalue_indices = np.argmin(eigenvalues, axis=1)
139 | # min_eigenvectors = eigenvectors[np.arange(cov_matrices.shape[0]), :, min_eigenvalue_indices]
140 | # normals = min_eigenvectors / np.linalg.norm(min_eigenvectors, axis=1)[:, np.newaxis]
141 | # # print(normals.shape)
142 | # return normals
143 |
144 | def compute_gaussian_normals(s, r):
145 | L = build_scaling_rotation(s, r)
146 | actual_covariance = L @ L.transpose(1, 2)
147 |
148 | covariance_matrix = actual_covariance
149 | eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix, UPLO='U')
150 | min_eigenvalue_idx = torch.argmin(eigenvalues)
151 | min_eigenvector = eigenvectors[:, min_eigenvalue_idx]
152 | normal = min_eigenvector / torch.norm(min_eigenvector)
153 | normal = normal.cuda()
154 | return normal
155 |
156 | def get_minimum_axis(scales, rotations):
157 | sorted_idx = torch.argsort(scales, descending=False, dim=-1)
158 | R = build_rotation(rotations)
159 | R_sorted = torch.gather(R, dim=2, index=sorted_idx[:,None,:].repeat(1, 3, 1)).squeeze()
160 | x_axis = R_sorted[:,0,:] # normalized by default
161 |
162 | return x_axis
163 |
164 | def flip_align_view(normal, viewdir):
165 | # normal: (N, 3), viewdir: (N, 3)
166 | dotprod = torch.sum(normal * (-viewdir), dim=-1, keepdims=True) # (N, 1)
167 | non_flip = dotprod>=0 # (N, 1)
168 | normal_flipped = normal*torch.where(non_flip, 1, -1) # (N, 3)
169 | return normal_flipped, non_flip
170 |
171 | def calculate_reflection_direction(wo, n):
172 | wo_dot_n = torch.sum(wo * n, dim=-1, keepdim=True)
173 | reflection_direction = 2 * wo_dot_n * n - wo
174 | return reflection_direction
175 |
176 | # def calculate_reflection_direction(wo, n):
177 | # normal_dot_viewdir = ((-wo) * n).sum(dim=-1, keepdim=True)
178 | # return normal_dot_viewdir
179 |
180 |
181 |
182 |
183 | def visualize_normal_map(normal_map):
184 | # normal_map = F.normalize(normal_map, dim=0)
185 | normal_map = (normal_map + 1.0) / 2.0
186 | normal_map = torch.clamp(normal_map, 0, 1)
187 | return normal_map
--------------------------------------------------------------------------------
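A small sketch of the log-linear schedule produced by get_expon_lr_func above (the rates and step count are illustrative, not taken from this repo's configuration):

    from utils.general_utils import get_expon_lr_func

    lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
    print(lr_fn(0))        # 1.6e-4
    print(lr_fn(15_000))   # geometric mean of init and final, about 1.6e-5
    print(lr_fn(30_000))   # 1.6e-6
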
/scene/colmap_loader.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import numpy as np
13 | import collections
14 | import struct
15 |
16 | CameraModel = collections.namedtuple(
17 | "CameraModel", ["model_id", "model_name", "num_params"])
18 | Camera = collections.namedtuple(
19 | "Camera", ["id", "model", "width", "height", "params"])
20 | BaseImage = collections.namedtuple(
21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22 | Point3D = collections.namedtuple(
23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 | def qvec2rotmat(qvec):
44 | return np.array([
45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54 |
55 | def rotmat2qvec(R):
56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57 | K = np.array([
58 | [Rxx - Ryy - Rzz, 0, 0, 0],
59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62 | eigvals, eigvecs = np.linalg.eigh(K)
63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64 | if qvec[0] < 0:
65 | qvec *= -1
66 | return qvec
67 |
68 | class Image(BaseImage):
69 | def qvec2rotmat(self):
70 | return qvec2rotmat(self.qvec)
71 |
72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73 | """Read and unpack the next bytes from a binary file.
74 | :param fid:
75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77 | :param endian_character: Any of {@, =, <, >, !}
78 | :return: Tuple of read and unpacked values.
79 | """
80 | data = fid.read(num_bytes)
81 | return struct.unpack(endian_character + format_char_sequence, data)
82 |
83 | def read_points3D_text(path):
84 | """
85 | see: src/base/reconstruction.cc
86 | void Reconstruction::ReadPoints3DText(const std::string& path)
87 | void Reconstruction::WritePoints3DText(const std::string& path)
88 | """
89 | xyzs = None
90 | rgbs = None
91 | errors = None
92 | num_points = 0
93 | with open(path, "r") as fid:
94 | while True:
95 | line = fid.readline()
96 | if not line:
97 | break
98 | line = line.strip()
99 | if len(line) > 0 and line[0] != "#":
100 | num_points += 1
101 |
102 |
103 | xyzs = np.empty((num_points, 3))
104 | rgbs = np.empty((num_points, 3))
105 | errors = np.empty((num_points, 1))
106 | count = 0
107 | with open(path, "r") as fid:
108 | while True:
109 | line = fid.readline()
110 | if not line:
111 | break
112 | line = line.strip()
113 | if len(line) > 0 and line[0] != "#":
114 | elems = line.split()
115 | xyz = np.array(tuple(map(float, elems[1:4])))
116 | rgb = np.array(tuple(map(int, elems[4:7])))
117 | error = np.array(float(elems[7]))
118 | xyzs[count] = xyz
119 | rgbs[count] = rgb
120 | errors[count] = error
121 | count += 1
122 |
123 | return xyzs, rgbs, errors
124 |
125 | def read_points3D_binary(path_to_model_file):
126 | """
127 | see: src/base/reconstruction.cc
128 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
129 | void Reconstruction::WritePoints3DBinary(const std::string& path)
130 | """
131 |
132 |
133 | with open(path_to_model_file, "rb") as fid:
134 | num_points = read_next_bytes(fid, 8, "Q")[0]
135 |
136 | xyzs = np.empty((num_points, 3))
137 | rgbs = np.empty((num_points, 3))
138 | errors = np.empty((num_points, 1))
139 |
140 | for p_id in range(num_points):
141 | binary_point_line_properties = read_next_bytes(
142 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
143 | xyz = np.array(binary_point_line_properties[1:4])
144 | rgb = np.array(binary_point_line_properties[4:7])
145 | error = np.array(binary_point_line_properties[7])
146 | track_length = read_next_bytes(
147 | fid, num_bytes=8, format_char_sequence="Q")[0]
148 | track_elems = read_next_bytes(
149 | fid, num_bytes=8*track_length,
150 | format_char_sequence="ii"*track_length)
151 | xyzs[p_id] = xyz
152 | rgbs[p_id] = rgb
153 | errors[p_id] = error
154 | return xyzs, rgbs, errors
155 |
156 | def read_intrinsics_text(path):
157 | """
158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159 | """
160 | cameras = {}
161 | with open(path, "r") as fid:
162 | while True:
163 | line = fid.readline()
164 | if not line:
165 | break
166 | line = line.strip()
167 | if len(line) > 0 and line[0] != "#":
168 | elems = line.split()
169 | camera_id = int(elems[0])
170 | model = elems[1]
171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
172 | width = int(elems[2])
173 | height = int(elems[3])
174 | params = np.array(tuple(map(float, elems[4:])))
175 | cameras[camera_id] = Camera(id=camera_id, model=model,
176 | width=width, height=height,
177 | params=params)
178 | return cameras
179 |
180 | def read_extrinsics_binary(path_to_model_file):
181 | """
182 | see: src/base/reconstruction.cc
183 | void Reconstruction::ReadImagesBinary(const std::string& path)
184 | void Reconstruction::WriteImagesBinary(const std::string& path)
185 | """
186 | images = {}
187 | with open(path_to_model_file, "rb") as fid:
188 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189 | for _ in range(num_reg_images):
190 | binary_image_properties = read_next_bytes(
191 | fid, num_bytes=64, format_char_sequence="idddddddi")
192 | image_id = binary_image_properties[0]
193 | qvec = np.array(binary_image_properties[1:5])
194 | tvec = np.array(binary_image_properties[5:8])
195 | camera_id = binary_image_properties[8]
196 | image_name = ""
197 | current_char = read_next_bytes(fid, 1, "c")[0]
198 | while current_char != b"\x00": # look for the ASCII 0 entry
199 | image_name += current_char.decode("utf-8")
200 | current_char = read_next_bytes(fid, 1, "c")[0]
201 | num_points2D = read_next_bytes(fid, num_bytes=8,
202 | format_char_sequence="Q")[0]
203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204 | format_char_sequence="ddq"*num_points2D)
205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206 | tuple(map(float, x_y_id_s[1::3]))])
207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208 | images[image_id] = Image(
209 | id=image_id, qvec=qvec, tvec=tvec,
210 | camera_id=camera_id, name=image_name,
211 | xys=xys, point3D_ids=point3D_ids)
212 | return images
213 |
214 |
215 | def read_intrinsics_binary(path_to_model_file):
216 | """
217 | see: src/base/reconstruction.cc
218 | void Reconstruction::WriteCamerasBinary(const std::string& path)
219 | void Reconstruction::ReadCamerasBinary(const std::string& path)
220 | """
221 | cameras = {}
222 | with open(path_to_model_file, "rb") as fid:
223 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
224 | for _ in range(num_cameras):
225 | camera_properties = read_next_bytes(
226 | fid, num_bytes=24, format_char_sequence="iiQQ")
227 | camera_id = camera_properties[0]
228 | model_id = camera_properties[1]
229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230 | width = camera_properties[2]
231 | height = camera_properties[3]
232 | num_params = CAMERA_MODEL_IDS[model_id].num_params
233 | params = read_next_bytes(fid, num_bytes=8*num_params,
234 | format_char_sequence="d"*num_params)
235 | cameras[camera_id] = Camera(id=camera_id,
236 | model=model_name,
237 | width=width,
238 | height=height,
239 | params=np.array(params))
240 | assert len(cameras) == num_cameras
241 | return cameras
242 |
243 |
244 | def read_extrinsics_text(path):
245 | """
246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247 | """
248 | images = {}
249 | with open(path, "r") as fid:
250 | while True:
251 | line = fid.readline()
252 | if not line:
253 | break
254 | line = line.strip()
255 | if len(line) > 0 and line[0] != "#":
256 | elems = line.split()
257 | image_id = int(elems[0])
258 | qvec = np.array(tuple(map(float, elems[1:5])))
259 | tvec = np.array(tuple(map(float, elems[5:8])))
260 | camera_id = int(elems[8])
261 | image_name = elems[9]
262 | elems = fid.readline().split()
263 | xys = np.column_stack([tuple(map(float, elems[0::3])),
264 | tuple(map(float, elems[1::3]))])
265 | point3D_ids = np.array(tuple(map(int, elems[2::3])))
266 | images[image_id] = Image(
267 | id=image_id, qvec=qvec, tvec=tvec,
268 | camera_id=camera_id, name=image_name,
269 | xys=xys, point3D_ids=point3D_ids)
270 | return images
271 |
272 |
273 | def read_colmap_bin_array(path):
274 | """
275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276 |
277 | :param path: path to the colmap binary file.
278 | :return: nd array with the floating point values read from the file
279 | """
280 | with open(path, "rb") as fid:
281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282 | usecols=(0, 1, 2), dtype=int)
283 | fid.seek(0)
284 | num_delimiter = 0
285 | byte = fid.read(1)
286 | while True:
287 | if byte == b"&":
288 | num_delimiter += 1
289 | if num_delimiter >= 3:
290 | break
291 | byte = fid.read(1)
292 | array = np.fromfile(fid, np.float32)
293 | array = array.reshape((width, height, channels), order="F")
294 | return np.transpose(array, (1, 0, 2)).squeeze()
295 |
--------------------------------------------------------------------------------
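A sketch of reading a COLMAP sparse reconstruction with the loaders above and turning one image pose into a 4x4 world-to-camera matrix (the paths are placeholders):

    import numpy as np
    from scene.colmap_loader import read_extrinsics_binary, read_intrinsics_binary, qvec2rotmat

    images = read_extrinsics_binary("my_scene/sparse/0/images.bin")
    cameras = read_intrinsics_binary("my_scene/sparse/0/cameras.bin")

    img = next(iter(images.values()))
    R = qvec2rotmat(img.qvec)                 # world -> camera rotation
    t = np.array(img.tvec).reshape(3, 1)
    W2C = np.concatenate([np.concatenate([R, t], axis=1), [[0.0, 0.0, 0.0, 1.0]]], axis=0)
    print(cameras[img.camera_id].model, W2C.shape)
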
/scene/dataset_readers.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import sys
14 | from PIL import Image
15 | from typing import NamedTuple
16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19 | import numpy as np
20 | import json
21 | from pathlib import Path
22 | from plyfile import PlyData, PlyElement
23 | from utils.sh_utils import SH2RGB
24 | from scene.gaussian_model import BasicPointCloud
25 |
26 | class CameraInfo(NamedTuple):
27 | uid: int
28 | R: np.array
29 | T: np.array
30 | FovY: np.array
31 | FovX: np.array
32 | image: np.array
33 | image_path: str
34 | image_name: str
35 | # semantic_image: np.array
36 | # semantic_path: str
37 | # semantic_image_name: str
38 | # semantic_classes: list
39 | # num_semantic_classes: int
40 | semantic_image: None
41 | semantic_path: None
42 | semantic_image_name: None
43 | semantic_classes: None
44 | num_semantic_classes: None
45 | width: int
46 | height: int
47 |
48 | class SceneInfo(NamedTuple):
49 | point_cloud: BasicPointCloud
50 | train_cameras: list
51 | test_cameras: list
52 | nerf_normalization: dict
53 | ply_path: str
54 |
55 | def getNerfppNorm(cam_info):
56 | def get_center_and_diag(cam_centers):
57 | cam_centers = np.hstack(cam_centers)
58 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
59 | center = avg_cam_center
60 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
61 | diagonal = np.max(dist)
62 | return center.flatten(), diagonal
63 |
64 | cam_centers = []
65 |
66 | for cam in cam_info:
67 | W2C = getWorld2View2(cam.R, cam.T)
68 | C2W = np.linalg.inv(W2C)
69 | cam_centers.append(C2W[:3, 3:4])
70 |
71 | center, diagonal = get_center_and_diag(cam_centers)
72 | radius = diagonal * 1.1
73 |
74 | translate = -center
75 |
76 | return {"translate": translate, "radius": radius}
77 |
78 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
79 | cam_infos = []
80 | for idx, key in enumerate(cam_extrinsics):
81 | sys.stdout.write('\r')
82 | # the exact output you're looking for:
83 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
84 | sys.stdout.flush()
85 |
86 | extr = cam_extrinsics[key]
87 | intr = cam_intrinsics[extr.camera_id]
88 | height = intr.height
89 | width = intr.width
90 |
91 | uid = intr.id
92 | R = np.transpose(qvec2rotmat(extr.qvec))
93 | T = np.array(extr.tvec)
94 |
95 | if intr.model=="SIMPLE_PINHOLE":
96 | focal_length_x = intr.params[0]
97 | FovY = focal2fov(focal_length_x, height)
98 | FovX = focal2fov(focal_length_x, width)
99 | elif intr.model=="PINHOLE":
100 | focal_length_x = intr.params[0]
101 | focal_length_y = intr.params[1]
102 | FovY = focal2fov(focal_length_y, height)
103 | FovX = focal2fov(focal_length_x, width)
104 | else:
105 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
106 |
107 | image_path = os.path.join(images_folder, os.path.basename(extr.name))
108 | image_name = os.path.basename(image_path).split(".")[0]
109 | image = Image.open(image_path)
110 |
111 | # semantic_image
112 | base_dir, _ = os.path.split(images_folder)
113 | # semantic_folder = os.path.join(base_dir, "semantic_class")
114 | # semantic_path = os.path.join(semantic_folder, os.path.basename(extr.name))
115 |
116 | # prefix = '0000_'
117 | # semantic_image_name = prefix + extr.name
118 |
119 | # semantic_path = os.path.join(semantic_folder, os.path.basename(semantic_image_name))
120 | # semantic_image = np.array(Image.open(semantic_path))
121 |
122 | # semantic_classes
123 | # semantic_all_imgs = np.empty(0)
124 | # semantic_all_files = [file for file in os.listdir(semantic_folder) if file.endswith(".png")]
125 | # for semantic_file in semantic_all_files:
126 | # semantic_img_path = os.path.join(semantic_folder, semantic_file)
127 | # semantic_img = np.array(Image.open(semantic_img_path))
128 | # if semantic_all_imgs.size == 0:
129 | # semantic_all_imgs = semantic_img
130 | # else:
131 | # semantic_all_imgs = np.concatenate((semantic_all_imgs, semantic_img),axis=0)
132 | #
133 | # semantic_classes = np.unique(semantic_all_imgs).astype(np.uint8)
134 | # num_semantic_classes = semantic_classes.shape[0]
135 | # semantic_classes = list(semantic_classes)
136 |
137 | # quick path: hard-coded semantic classes (disabled)
138 | # num_semantic_classes = 16
139 | # semantic_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
140 |
141 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
142 | image_path=image_path, image_name=image_name,
143 | semantic_image=None, semantic_path=None,
144 | semantic_image_name=None,
145 | semantic_classes=None, num_semantic_classes=None,
146 | width=width, height=height)
147 | cam_infos.append(cam_info)
148 | sys.stdout.write('\n')
149 | return cam_infos
150 |
151 | def fetchPly(path):
152 | plydata = PlyData.read(path)
153 | vertices = plydata['vertex']
154 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
155 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
156 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
157 | return BasicPointCloud(points=positions, colors=colors, normals=normals)
158 |
159 | def storePly(path, xyz, rgb):
160 | # Define the dtype for the structured array
161 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
162 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
163 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
164 |
165 | normals = np.zeros_like(xyz)
166 |
167 | elements = np.empty(xyz.shape[0], dtype=dtype)
168 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
169 | elements[:] = list(map(tuple, attributes))
170 |
171 | # Create the PlyData object and write to file
172 | vertex_element = PlyElement.describe(elements, 'vertex')
173 | ply_data = PlyData([vertex_element])
174 | ply_data.write(path)
175 |
176 | def readColmapSceneInfo(path, images, eval, llffhold=8):
177 | try:
178 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
179 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
180 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
181 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
182 | except:
183 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
184 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
185 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
186 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
187 |
188 | reading_dir = "images" if images == None else images
189 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
190 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
191 |
192 | if eval:
193 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
194 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
195 |
196 | # train_cam_infos = [c for idx, c in enumerate(train_cam_infos) if idx % 2 == 0]
197 | # test_cam_infos = [c for idx, c in enumerate(test_cam_infos) if idx % 4 == 0]
198 | else:
199 | train_cam_infos = cam_infos
200 | test_cam_infos = []
201 |
202 | nerf_normalization = getNerfppNorm(train_cam_infos)
203 |
204 | ply_path = os.path.join(path, "sparse/0/points3D.ply")
205 | bin_path = os.path.join(path, "sparse/0/points3D.bin")
206 | txt_path = os.path.join(path, "sparse/0/points3D.txt")
207 | if not os.path.exists(ply_path):
208 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
209 | try:
210 | xyz, rgb, _ = read_points3D_binary(bin_path)
211 | except:
212 | xyz, rgb, _ = read_points3D_text(txt_path)
213 | storePly(ply_path, xyz, rgb)
214 | try:
215 | pcd = fetchPly(ply_path)
216 | except:
217 | pcd = None
218 |
219 | scene_info = SceneInfo(point_cloud=pcd,
220 | train_cameras=train_cam_infos,
221 | test_cameras=test_cam_infos,
222 | nerf_normalization=nerf_normalization,
223 | ply_path=ply_path)
224 | return scene_info
225 |
226 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
227 | cam_infos = []
228 |
229 | with open(os.path.join(path, transformsfile)) as json_file:
230 | contents = json.load(json_file)
231 | fovx = contents["camera_angle_x"]
232 |
233 | frames = contents["frames"]
234 | for idx, frame in enumerate(frames):
235 | cam_name = os.path.join(path, frame["file_path"] + extension)
236 |
237 | # NeRF 'transform_matrix' is a camera-to-world transform
238 | c2w = np.array(frame["transform_matrix"])
239 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
240 | c2w[:3, 1:3] *= -1
241 |
242 | # get the world-to-camera transform and set R, T
243 | w2c = np.linalg.inv(c2w)
244 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
245 | T = w2c[:3, 3]
246 |
247 | image_path = os.path.join(path, cam_name)
248 | image_name = Path(cam_name).stem
249 | image = Image.open(image_path)
250 |
251 | im_data = np.array(image.convert("RGBA"))
252 |
253 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
254 |
255 | norm_data = im_data / 255.0
256 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
257 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
258 |
259 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
260 | FovY = fovy
261 | FovX = fovx
262 |
263 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
264 | image_path=image_path, image_name=image_name, semantic_image=None, semantic_path=None, semantic_image_name=None, semantic_classes=None, num_semantic_classes=None, width=image.size[0], height=image.size[1]))
265 |
266 | return cam_infos
267 |
268 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
269 | print("Reading Training Transforms")
270 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
271 | print("Reading Test Transforms")
272 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
273 |
274 | if not eval:
275 | train_cam_infos.extend(test_cam_infos)
276 | test_cam_infos = []
277 |
278 | nerf_normalization = getNerfppNorm(train_cam_infos)
279 |
280 | ply_path = os.path.join(path, "points3d.ply")
281 | if not os.path.exists(ply_path):
282 | # Since this data set has no colmap data, we start with random points
283 | num_pts = 100_000
284 | print(f"Generating random point cloud ({num_pts})...")
285 |
286 | # We create random points inside the bounds of the synthetic Blender scenes
287 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
288 | shs = np.random.random((num_pts, 3)) / 255.0
289 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
290 |
291 | storePly(ply_path, xyz, SH2RGB(shs) * 255)
292 | try:
293 | pcd = fetchPly(ply_path)
294 | except:
295 | pcd = None
296 |
297 | scene_info = SceneInfo(point_cloud=pcd,
298 | train_cameras=train_cam_infos,
299 | test_cameras=test_cam_infos,
300 | nerf_normalization=nerf_normalization,
301 | ply_path=ply_path)
302 | return scene_info
303 |
304 | sceneLoadTypeCallbacks = {
305 | "Colmap": readColmapSceneInfo,
306 | "Blender" : readNerfSyntheticInfo
307 | }
--------------------------------------------------------------------------------
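A sketch of loading a scene description through the callbacks defined above (the source path is a placeholder; eval=True reproduces the every-8th-view test split controlled by llffhold):

    from scene.dataset_readers import sceneLoadTypeCallbacks

    scene_info = sceneLoadTypeCallbacks["Colmap"]("/path/to/scene", None, True)
    print(len(scene_info.train_cameras), "train /", len(scene_info.test_cameras), "test views")
    print("scene radius:", scene_info.nerf_normalization["radius"])
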
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torchvision
14 | import torch.optim as optim
15 | os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
16 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
17 | import torch
18 | import torch.nn as nn
19 | from random import randint
20 | from utils.loss_utils import l1_loss, ssim
21 | from gaussian_renderer import render, network_gui
22 | import sys
23 | from scene import Scene, GaussianModel
24 | from utils.general_utils import safe_state, calculate_reflection_direction
25 | from utils.general_utils import compute_gaussian_normals, flip_align_view
26 | import uuid
27 | from tqdm import tqdm
28 | from utils.image_utils import psnr
29 | from utils.graphics_utils import views_dir
30 | from argparse import ArgumentParser, Namespace
31 | from arguments import ModelParams, PipelineParams, OptimizationParams
32 | from mynet import MyNet, embedding_fn
33 | from lpipsPyTorch import lpips
34 | try:
35 | from torch.utils.tensorboard import SummaryWriter
36 | TENSORBOARD_FOUND = True
37 | except ImportError:
38 | TENSORBOARD_FOUND = False
39 |
40 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
41 | first_iter = 0
42 | tb_writer = prepare_output_and_logger(dataset)
43 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes)
44 | scene = Scene(dataset, gaussians)
45 | gaussians.training_setup(opt)
46 |
47 | model = MyNet().to('cuda')
48 |
49 | learning_rate = 7e-6 # 5e-6
50 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
51 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2000, verbose=True, factor=0.25, min_lr=1e-7)
52 |
53 | if checkpoint:
54 | (model_params, first_iter) = torch.load(checkpoint)
55 | gaussians.restore(model_params, opt)
56 |
57 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
58 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
59 |
60 | iter_start = torch.cuda.Event(enable_timing=True)
61 | iter_end = torch.cuda.Event(enable_timing=True)
62 |
63 | CE_loss = nn.CrossEntropyLoss()
64 |
65 | viewpoint_stack = None
66 | ema_loss_for_log = 0.0
67 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
68 | first_iter += 1
69 | for iteration in range(first_iter, opt.iterations + 1):
70 | # if network_gui.conn == None:
71 | # network_gui.try_connect()
72 | # while network_gui.conn != None:
73 | # try:
74 | # net_image_bytes = None
75 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
76 | # if custom_cam != None:
77 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
78 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
79 | # network_gui.send(net_image_bytes, dataset.source_path)
80 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
81 | # break
82 | # except Exception as e:
83 | # network_gui.conn = None
84 |
85 | iter_start.record()
86 |
87 | gaussians.update_learning_rate(iteration)
88 |
89 | # Every 1000 its we increase the levels of SH up to a maximum degree
90 | if iteration % 1000 == 0:
91 | gaussians.oneupSHdegree()
92 |
93 | # Pick a random Camera
94 | if not viewpoint_stack:
95 | viewpoint_stack = scene.getTrainCameras().copy()
96 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
97 | viewpoint_camera_center = viewpoint_cam.camera_center
98 |
99 | # Render
100 | if (iteration - 1) == debug_from:
101 | pipe.debug = True
102 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
103 |
104 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = render_pkg
105 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:,:, 17:]
106 |
107 | views_emd = viewpoint_cam.views_emd
108 | # print(rendered_features[0].shape, views_emd.shape)
109 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
110 |
111 |
112 | image = model(*rendered_features)
113 |
114 | mask = image['mask_out'].squeeze(0)
115 | color = image['color_out'].squeeze(0)
116 | image = image['im_out'].squeeze(0)
117 |
118 | # Color Loss
119 | semantic_loss = 0
120 |
121 | gt_image = viewpoint_cam.original_image.cuda()
122 |
123 | dir_pp = (gaussians.get_xyz - viewpoint_camera_center.repeat(gaussians.get_xyz.shape[0], 1))
124 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True)
125 | normals = gaussians.get_normal
126 | pseudo_normals = gaussians.get_minimum_axis
127 | pseudo_normals, _ = flip_align_view(pseudo_normals, dir_pp)
128 | Ln = 1 - torch.nn.functional.cosine_similarity(normals, pseudo_normals).mean()
129 |
130 | Ll1 = l1_loss(image, gt_image)
131 | color_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
132 |
133 | LB = l1_loss(color, gt_image)
134 | basic_loss = (1.0 - opt.lambda_dssim) * LB + opt.lambda_dssim * (1.0 - ssim(color, gt_image))
135 |
136 | loss = color_loss + 0.001 * Ln + 0.05 * basic_loss
137 |
138 | loss.backward()
139 |
140 | iter_end.record()
141 |
142 | with torch.no_grad():
143 | # Progress bar
144 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
145 | if iteration % 10 == 0:
146 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "C_Loss": f"{color_loss:.{5}f}", "S_Loss": f"{semantic_loss:.{5}f}"})
147 | progress_bar.update(10)
148 | if iteration == opt.iterations:
149 | progress_bar.close()
150 |
151 | # Log and save
152 |
153 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), model=model)
154 | if (iteration in saving_iterations):
155 | print("\n[ITER {}] Saving Gaussians".format(iteration))
156 | scene.save(iteration)
157 |
158 | # Densification
159 | if iteration < opt.densify_until_iter:
160 | # Keep track of max radii in image-space for pruning
161 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
162 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
163 |
164 |
165 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
166 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
167 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) # 0.005
168 |
169 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
170 | gaussians.reset_opacity()
171 |
172 | optimizer.step()
173 |
174 | # Optimizer step
175 | if iteration < opt.iterations:
176 | gaussians.optimizer.step()
177 | gaussians.optimizer.zero_grad(set_to_none=True)
178 | gaussians.optimizer_net.step()
179 | gaussians.optimizer_net.zero_grad(set_to_none=True)
180 | gaussians.scheduler_net.step()
181 | scheduler.step(loss)
182 |
183 | if (iteration in checkpoint_iterations):
184 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
185 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
186 |
187 |
188 |
189 | def prepare_output_and_logger(args):
190 | if not args.model_path:
191 | if os.getenv('OAR_JOB_ID'):
192 | unique_str = os.getenv('OAR_JOB_ID')
193 | else:
194 | unique_str = str(uuid.uuid4())
195 | args.model_path = os.path.join("./output/", unique_str[0:10])
196 |
197 | # Set up output folder
198 | print("Output folder: {}".format(args.model_path))
199 | os.makedirs(args.model_path, exist_ok=True)
200 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
201 | cfg_log_f.write(str(Namespace(**vars(args))))
202 |
203 | # Create Tensorboard writer
204 | tb_writer = None
205 | if TENSORBOARD_FOUND:
206 | tb_writer = SummaryWriter(args.model_path)
207 | else:
208 | print("Tensorboard not available: not logging progress")
209 | return tb_writer
210 |
211 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, model=None):
212 | if tb_writer:
213 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
214 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
215 | tb_writer.add_scalar('iter_time', elapsed, iteration)
216 |
217 | # Report test and samples of training set
218 | if iteration in testing_iterations:
219 | torch.save(model.state_dict(), scene.model_path + "/model_ckpt" + str(iteration) + "pth")
220 | torch.cuda.empty_cache()
221 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
222 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
223 | for config in validation_configs:
224 | if config['cameras'] and len(config['cameras']) > 0:
225 | l1_test = 0.0
226 | psnr_test = 0.0
227 | ssim_test = 0.0
228 | lpips_test = 0.0
229 | for idx, viewpoint in enumerate(config['cameras']):
230 | if model:
231 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = renderFunc(viewpoint, scene.gaussians, *renderArgs)
232 | rays_d = torch.from_numpy(
233 | views_dir(viewpoint.image_height, viewpoint.image_width, viewpoint.K, viewpoint.c2w)).cuda()
234 |
235 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:, :, 17:]
236 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
237 |
238 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
239 |
240 |
241 | image = model(*rendered_features)
242 | image_save = image['im_out'].squeeze(0)
243 | image = torch.clamp(image_save, 0.0, 1.0)
244 | else:
245 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
246 |
247 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
248 |
249 | if tb_writer and (idx < 5):
250 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
251 | if iteration == testing_iterations[0]:
252 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
253 | l1_test += l1_loss(image, gt_image).mean().double()
254 | psnr_test += psnr(image, gt_image).mean().double()
255 | ssim_test += ssim(image, gt_image).mean().double()
256 | lpips_test += lpips(image, gt_image).mean().double()
257 | psnr_test /= len(config['cameras'])
258 | ssim_test /= len(config['cameras'])
259 | lpips_test /= len(config['cameras'])
260 | l1_test /= len(config['cameras'])
261 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test))
262 | if tb_writer:
263 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
264 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
265 |
266 | if tb_writer:
267 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
268 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
269 | torch.cuda.empty_cache()
270 |
271 | if __name__ == "__main__":
272 | # Set up command line argument parser
273 | parser = ArgumentParser(description="Training script parameters")
274 | lp = ModelParams(parser)
275 | op = OptimizationParams(parser)
276 | pp = PipelineParams(parser)
277 | parser.add_argument('--ip', type=str, default="127.0.0.1")
278 | parser.add_argument('--port', type=int, default=6009)
279 | parser.add_argument('--debug_from', type=int, default=-1)
280 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
281 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000])
282 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000])
283 | parser.add_argument("--quiet", action="store_true")
284 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
285 | parser.add_argument("--start_checkpoint", type=str, default = None)
286 | args = parser.parse_args(sys.argv[1:])
287 | args.save_iterations.append(args.iterations)
288 |
289 | print("Scene path" + args.source_path)
290 | print("Optimizing " + args.model_path)
291 |
292 | # Initialize system state (RNG)
293 | safe_state(args.quiet)
294 |
295 | # Start GUI server, configure and run training
296 | # network_gui.init(args.ip, args.port)
297 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
298 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
299 |
300 | # All done
301 | print("\nTraining complete.")
302 |
--------------------------------------------------------------------------------
/mynet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 | from functools import partial
7 |
8 |
9 | def embedding_fn(inputs):
10 | embed_fns = []
11 | d = 3
12 | out_dim = 0
13 | max_freq = 3 # 3
14 |
15 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=4) # [2**0, 2**1,...,]
16 | for freq in freq_bands:
17 | for p_fn in [torch.sin, torch.cos]:
18 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) # [torch.sin, torch.cos]
19 | out_dim += d
20 |
21 | return torch.cat([fn(inputs) for fn in embed_fns], -1)
22 |
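# Note (sketch, not part of the original comments): with max_freq = 3 the four frequency bands
# are [1, 2, 4, 8]; applying sin and cos at every band to a 3-channel direction map yields
# 4 * 2 * 3 = 24 output channels, which is where the "+ 24" channel counts below come from.
#   dirs = torch.rand(64, 64, 3)      # hypothetical H x W x 3 ray directions
#   emb = embedding_fn(dirs)          # -> torch.Size([64, 64, 24])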
23 | class BasicConv(nn.Module):
24 | # Gated_conv
25 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, dilation=1,
26 | padding_mode='reflect', act_fun=nn.ELU, normalization=nn.BatchNorm2d):
27 | super().__init__()
28 | self.pad_mode = padding_mode
29 | self.filter_size = kernel_size
30 | self.stride = stride
31 | self.dilation = dilation
32 | group_normalization = nn.GroupNorm
33 |
34 | n_pad_pxl = int(self.dilation * (self.filter_size - 1) / 2)
35 | self.flag = relu
36 |
37 |         # this is for backward compatibility with older model checkpoints
38 | self.block = nn.ModuleDict(
39 | {
40 | 'conv_f': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
41 | padding=n_pad_pxl),
42 | 'act_f': act_fun(),
43 | 'conv_m': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
44 | padding=n_pad_pxl),
45 | 'act_m': nn.Sigmoid(),
46 | 'norm': normalization(out_channels),
47 | 'group_norm': group_normalization(num_groups=out_channels, num_channels=out_channels),
48 | }
49 | )
50 |
51 | def forward(self, x, *args, **kwargs):
52 | features = self.block.act_f(self.block.conv_f(x))
53 | output = features
54 | return output
55 |
56 |
57 | class ResBlock(nn.Module):
58 | def __init__(self, in_channel, out_channel):
59 | super(ResBlock, self).__init__()
60 | self.main = nn.Sequential(
61 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
62 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
63 | )
64 | self.relu = nn.ELU()
65 | self.norm = nn.GroupNorm(1, out_channel)
66 |
67 | def forward(self, x):
68 | return self.relu(self.norm(self.main(x) + (x)))
69 |
70 | class ResNet(nn.Module):
71 | def __init__(self, in_channel, out_channel):
72 | super(ResNet, self).__init__()
73 | self.main = nn.Sequential(
74 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
75 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
76 | nn.BatchNorm2d(out_channel)
77 |
78 | )
79 | self.net = BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True)
80 | self.norm = nn.BatchNorm2d(out_channel)
81 | self.relu = nn.ELU()
82 |
83 | def forward(self, x):
84 | output = self.main(x) + self.norm(self.net(x))
85 | return self.relu(output)
86 |
87 |
88 | class SCM(nn.Module):
89 | def __init__(self, out_plane):
90 | super(SCM, self).__init__()
91 | self.main = nn.Sequential(
92 | BasicConv(8, out_plane - 8, kernel_size=3, stride=1, relu=True),
93 | )
94 |
95 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)
96 |
97 | def forward(self, x):
98 | x = torch.cat([x, self.main(x)], dim=1)
99 | return self.conv(x)
100 |
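# Note: SCM lifts the raw 8-channel feature map to out_plane channels by concatenating it with
# (out_plane - 8) learned channels, then mixes the result with a 1x1 convolution.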
101 |
102 | class EBlock(nn.Module):
103 | def __init__(self, out_channel, num_res=1):
104 | super(EBlock, self).__init__()
105 |
106 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
107 |
108 | self.layers = nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | return self.layers(x)
112 |
113 |
114 | class DBlock(nn.Module):
115 | def __init__(self, channel, num_res=1):
116 | super(DBlock, self).__init__()
117 |
118 | layers = [ResBlock(channel, channel) for _ in range(num_res)]
119 | self.layers = nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | return self.layers(x)
123 |
124 |
125 | class FAM(nn.Module):
126 | def __init__(self, channel):
127 | super(FAM, self).__init__()
128 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)
129 |
130 | def forward(self, x1, x2):
131 | x = x1 * x2
132 | out = x1 + self.merge(x)
133 | return out
134 |
135 |
136 | class AFF(nn.Module):
137 | def __init__(self, in_channel, out_channel):
138 | super(AFF, self).__init__()
139 | self.conv = nn.Sequential(
140 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
141 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
142 | )
143 |
144 | def forward(self, x1, x2, x3, x4):
145 | x = torch.cat([x1, x2, x3, x4], dim=1)
146 | return self.conv(x)
147 |
148 |
149 | class MyNet(nn.Module):
150 | def __init__(
151 | self,
152 | num_input_channels=8,
153 | num_output_channels=3,
154 | feature_scale=4,
155 | num_res=1
156 |
157 | ):
158 | super().__init__()
159 |
160 | self.feature_scale = feature_scale
161 | base_channel = 8
162 |
163 | filters = [64, 128, 256, 512, 1024]
164 | filters = [x // self.feature_scale for x in filters]
165 |
166 | base_channel = 8
167 |
168 | self.resnet = nn.ModuleList([
169 | ResNet(base_channel, base_channel * 2),
170 | ResNet(base_channel * 2, base_channel * 4),
171 | ResNet(base_channel * 4, base_channel * 8),
172 | ResNet(base_channel * 8, base_channel * 4),
173 | ResNet(base_channel * 4, base_channel * 2),
174 | ResNet(base_channel * 2, base_channel),
175 | ])
176 |
177 | self.feat_extract = nn.ModuleList([
178 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1),
179 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
180 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
181 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2),
182 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2),
183 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
184 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2),
185 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2),
186 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=1),
187 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=3, relu=True, stride=1),
188 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=3, relu=True, stride=1),
189 | BasicConv(base_channel * 2, base_channel * 1, kernel_size=3, relu=True, stride=1),
190 | BasicConv(base_channel, base_channel * 4, kernel_size=3, relu=True, stride=1),
191 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=1),
192 | BasicConv(base_channel * 8 + 24, base_channel * 16, kernel_size=3, relu=True, stride=1),
193 | # nn.Linear(base_channel * 8 + 24, base_channel*16),
194 | BasicConv(base_channel * 16, base_channel * 4, kernel_size=3, relu=True, stride=1),
195 | BasicConv(base_channel * 4, base_channel, kernel_size=3, relu=True, stride=1),
196 | BasicConv(24, base_channel * 8, kernel_size=3, relu=True, stride=1),
197 | BasicConv(base_channel * 8, base_channel * 2, kernel_size=3, relu=True, stride=1),
198 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
199 | BasicConv(base_channel * 8, 20, kernel_size=3, relu=False, stride=1),
200 |
201 | BasicConv(base_channel, base_channel * 8, kernel_size=3, relu=False, stride=1),
202 | BasicConv(base_channel * 8 + 24, base_channel, kernel_size=3, relu=True, stride=1),
203 | # BasicConv(base_channel * 8 + 24, base_channel, kernel_size=3, relu=True, stride=1),
204 | # BasicConv(base_channel * 8 , base_channel, kernel_size=3, relu=True, stride=1),
205 | BasicConv(base_channel, 3, kernel_size=3, relu=True, stride=1),
206 |
207 | BasicConv(24, base_channel, kernel_size=3, relu=True, stride=1),
208 | BasicConv(base_channel * 4, 20, kernel_size=3, relu=False, stride=1),
209 | # BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
210 | BasicConv(base_channel, 1, kernel_size=3, relu=False, stride=1),
211 | # BasicConv(base_channel*2+24, base_channel, kernel_size=3, relu=False, stride=1),
212 | BasicConv(base_channel * 2, base_channel, kernel_size=3, relu=False, stride=1),
213 |
214 | BasicConv(40, base_channel, kernel_size=3, relu=False, stride=1),
215 |
216 |             # hash-based variants introduced from index 29
217 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
218 | BasicConv(base_channel * 8, 3, kernel_size=3, relu=False, stride=1),
219 |
220 |             # MLP layers from index 31:
221 | nn.Linear(8, 64, bias=True),
222 | nn.Linear(64 + 24, 64, bias=True),
223 | nn.Linear(64, 3, bias=True)
224 |
225 | ])
226 |
227 | self.SCM0 = SCM(base_channel * 8)
228 | self.SCM1 = SCM(base_channel * 4)
229 | self.SCM2 = SCM(base_channel * 2)
230 |
231 | self.FAM0 = FAM(base_channel * 8)
232 | self.FAM1 = FAM(base_channel * 4)
233 | self.FAM2 = FAM(base_channel * 2)
234 |
235 | self.AFFs = nn.ModuleList([
236 | AFF(base_channel * 15, base_channel * 1),
237 | AFF(base_channel * 15, base_channel * 2),
238 | AFF(base_channel * 15, base_channel * 4),
239 | ])
240 |
241 | self.Encoder = nn.ModuleList([
242 | EBlock(base_channel, num_res),
243 | EBlock(base_channel * 2, num_res),
244 | EBlock(base_channel * 4, num_res),
245 | EBlock(base_channel * 8, num_res)
246 | ])
247 |
248 | self.Decoder = nn.ModuleList([
249 | DBlock(base_channel * 8, num_res),
250 | DBlock(base_channel * 4, num_res),
251 | DBlock(base_channel * 2, num_res),
252 | DBlock(base_channel, num_res)
253 | ])
254 |
255 | self.Convs = nn.ModuleList([
256 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1),
257 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
258 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
259 | BasicConv(base_channel * 8, base_channel * 8, kernel_size=1, relu=True, stride=1),
260 | BasicConv(base_channel * 4, base_channel * 4, kernel_size=1, relu=True, stride=1),
261 | BasicConv(base_channel * 2, base_channel * 2, kernel_size=1, relu=True, stride=1),
262 |
263 | ])
264 |
265 | self.ConvsOut = nn.ModuleList(
266 | [
267 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
268 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
269 | ]
270 | )
271 |
272 | self.up = nn.Upsample(scale_factor=2, mode='bilinear')
273 | self.up1 = nn.Upsample(scale_factor=4, mode='bilinear')
274 | self.sigmoid = nn.Sigmoid()
275 |
276 | # basic_model
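    # Note (layout inferred from the slicing below; not documented in the source):
    # inputs[0] is assumed to stack the rasterized per-pixel features and the view embedding
    # along the channel dimension:
    #   [:, :8]    latent color features                    -> x (color branch)
    #   [:, 8:16]  specular/highlight features              -> f
    #   [:, 16]    tint, used directly as the mask
    #   [:, 17:20] normals (read by the training script, unused here)
    #   [:, 20:]   24-channel embedding_fn(view directions) -> views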
277 | def forward(self, *inputs, **kwargs):
278 | inputs = list(inputs)
279 |
280 | n_input = len(inputs)
281 |
282 | x = inputs[0][:, :8, :, :]
283 |
284 | f = inputs[0][:, 8:16, :, :]
285 |
286 | sh_f = list(f.shape)
287 | sh_f[1] = 3
288 |
289 | tint = inputs[0][:, 16, :, :]
290 |
291 | views = inputs[0][:, 20:, :, :]
292 |
293 | # highlight
294 | f = self.feat_extract[21](f)
295 | f = self.Encoder[3](f)
296 | # print(f.shape, views.shape)
297 | f = torch.cat((f, views), dim=1)
298 | f = self.feat_extract[22](f)
299 | f = self.Decoder[3](f)
300 | f = self.feat_extract[23](f)
301 |
302 | mask = tint
303 |
304 | highlight = f * mask
305 |
306 | s = x
307 | z1 = x
308 |
309 | # color
310 | x = self.feat_extract[1](x)
311 | x = self.Encoder[1](x)
312 | z2 = x
313 |
314 | x = self.feat_extract[2](x)
315 | x = self.Encoder[2](x)
316 | z3 = x
317 |
318 | x = self.feat_extract[6](x)
319 | x = self.Encoder[3](x)
320 |
321 | x = self.up(x)
322 |
323 | x = self.Convs[0](x)
324 | x = F.interpolate(x, z3.shape[-2:])
325 | x = torch.cat([x, z3], dim=1)
326 | x = self.Decoder[0](x)
327 |
328 | x = self.feat_extract[9](x)
329 | x = self.up(x)
330 |
331 | x = self.Convs[1](x)
332 | x = F.interpolate(x, z2.shape[-2:])
333 | x = torch.cat([x, z2], dim=1)
334 | x = self.Decoder[1](x)
335 |
336 | x = self.feat_extract[10](x)
337 | x = self.up(x)
338 |
339 | x = self.Convs[2](x)
340 | x = F.interpolate(x, z1.shape[-2:])
341 | x = torch.cat([x, z1], dim=1)
342 | x = self.Decoder[2](x)
343 |
344 | x = self.feat_extract[11](x)
345 |
346 | z = self.feat_extract[5](x)
347 |
348 | color = z
349 |
350 | z = z + highlight
351 |
352 | return {'im_out': z,
353 | 's_out': s,
354 | 'mask_out': mask,
355 | 'highlight_out': highlight,
356 | 'f_out': f,
357 | 'color_out': color}
358 |
359 |
360 | if __name__ == '__main__':
361 | import pdb
362 | import time
363 | import numpy as np
364 |
365 | # model = UNet().to('cuda')
366 | model = MyNet().to('cuda')
367 | input = []
368 | img_sh = [1408, 376]
369 | sh_unit = 8
370 | # img_sh = list(map(lambda a: a - a % sh_unit + sh_unit if a % sh_unit != 0 else a, img_sh))
371 |
372 | # print(img_sh)
373 | down = lambda a, b: a // 2 ** b
374 |     input.append(torch.zeros((1, 44, down(img_sh[0], 0), down(img_sh[1], 0)), requires_grad=True).cuda())  # 44 = 20 rendered feature channels + 24-channel view embedding
375 | input.append(F.interpolate(input[0], scale_factor=0.5))
376 | input.append(F.interpolate(input[1], scale_factor=0.5))
377 | input.append(F.interpolate(input[2], scale_factor=0.5))
378 | print(input)
379 |
380 | model.eval()
381 | st = time.time()
382 | print(input[0].max(), input[0].min())
383 | print(input[0].shape, input[1].shape, input[2].shape, input[3].shape)
384 | with torch.set_grad_enabled(False):
385 | out = model(*input)
386 | pdb.set_trace()
387 | print('model', time.time() - st)
388 | print(out['im_out'], out['im_out'].shape)
389 | print(out['s_out'], out['s_out'].shape)
390 | model.to('cpu')
391 |
--------------------------------------------------------------------------------
/train_2.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import os
13 | import torchvision
14 | import torch.optim as optim
15 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
16 | os.environ['CUDA_VISIBLE_DEVICES']= '0'
17 | import torch
18 | import torch.nn as nn
19 | from random import randint
20 | from utils.loss_utils import l1_loss, ssim
21 | from gaussian_renderer import render, network_gui
22 | import sys
23 | from scene import Scene, GaussianModel
24 | from utils.general_utils import safe_state, calculate_reflection_direction
25 | import uuid
26 | from tqdm import tqdm
27 | from utils.image_utils import psnr
28 | from utils.graphics_utils import views_dir
29 | from argparse import ArgumentParser, Namespace
30 | from model import UNet, SimpleNet, SimpleUNet
31 | from arguments import ModelParams, PipelineParams, OptimizationParams
32 | from mynet import MyNet, embedding_fn
33 | from lpipsPyTorch import lpips
34 | try:
35 | from torch.utils.tensorboard import SummaryWriter
36 | TENSORBOARD_FOUND = True
37 | except ImportError:
38 | TENSORBOARD_FOUND = False
39 |
40 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
41 | first_iter = 0
42 | tb_writer = prepare_output_and_logger(dataset)
43 | gaussians = GaussianModel(dataset.sh_degree, dataset.num_sem_classes)
44 | scene = Scene(dataset, gaussians)
45 | gaussians.training_setup(opt)
46 |
47 |     # Initialize the model
48 | # model = SimpleNet().to('cuda')
49 | # model = SimpleUNet().to('cuda')
50 | model = MyNet().to('cuda')
51 |
52 |     # Load pretrained weights
53 | # net_path = "/home/wangzhiru/projects/nerf/latent-gaussian-splatting/output/bonsai_output/model_ckpt30000pth"
54 | # net_weights = torch.load(net_path)
55 | # # print(net_weights)
56 | # model.load_state_dict(net_weights)
57 |
58 | learning_rate = 5e-6 # 5e-6 1e-5
59 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
60 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2000, verbose=True, factor=0.25, min_lr=1e-7)
61 |
62 | if checkpoint:
63 | (model_params, first_iter) = torch.load(checkpoint)
64 | gaussians.restore(model_params, opt)
65 |
66 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
67 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
68 |
69 | iter_start = torch.cuda.Event(enable_timing=True)
70 | iter_end = torch.cuda.Event(enable_timing=True)
71 |
72 | CE_loss = nn.CrossEntropyLoss()
73 |
74 | viewpoint_stack = None
75 | ema_loss_for_log = 0.0
76 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
77 | first_iter += 1
78 | for iteration in range(first_iter, opt.iterations + 1):
79 | # if network_gui.conn == None:
80 | # network_gui.try_connect()
81 | # while network_gui.conn != None:
82 | # try:
83 | # net_image_bytes = None
84 | # custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
85 | # if custom_cam != None:
86 | # net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
87 | # net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
88 | # network_gui.send(net_image_bytes, dataset.source_path)
89 | # if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
90 | # break
91 | # except Exception as e:
92 | # network_gui.conn = None
93 |
94 | iter_start.record()
95 |
96 | gaussians.update_learning_rate(iteration)
97 |
98 | # Every 1000 its we increase the levels of SH up to a maximum degree
99 | if iteration % 1000 == 0:
100 | gaussians.oneupSHdegree()
101 |
102 | # Pick a random Camera
103 | if not viewpoint_stack:
104 | viewpoint_stack = scene.getTrainCameras().copy()
105 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
106 |
107 | # Render
108 | if (iteration - 1) == debug_from:
109 | pipe.debug = True
110 | render_pkg = render(viewpoint_cam, gaussians, pipe, background)
111 |
112 | rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = render_pkg
113 | # print(rendered_features[0].shape)
114 | # print(rendered_features, rendered_features[0].shape)
115 | # depths = rendered_depths[0]
116 |
117 |         # Append view-direction information
118 | rays_d = torch.from_numpy(views_dir(viewpoint_cam.image_height, viewpoint_cam.image_width, viewpoint_cam.K, viewpoint_cam.c2w)).cuda()
119 | # print(rays_d.shape)
120 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:,:, 17:]
121 | # print(normals.shape)
122 |
123 |
124 | # rays_r = calculate_reflection_direction(rays_d, normals).permute(2,0,1).unsqueeze(0)
125 | # rays_d = calculate_reflection_direction(rays_d, normals)
126 |
127 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
128 | # rays_o, rays_d = views_dir(viewpoint_cam.image_height, viewpoint_cam.image_width, viewpoint_cam.K, viewpoint_cam.c2w)
129 | # rays_o, rays_d = torch.from_numpy(rays_o.copy()).cuda(), torch.from_numpy(rays_d.copy()).cuda(),
130 | # views_emdd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
131 | # views_emdo = embedding_fn(rays_o).permute(2, 0, 1).unsqueeze(0)
132 | # views_emd = torch.cat((views_emdo, views_emdd), dim = 1)
133 | rays_d = rays_d.permute(2, 0, 1).unsqueeze(0)
134 | normals = normals.permute(2, 0, 1).unsqueeze(0)
135 |
136 | # print(rays_dot.shape, views_emd.shape)
137 | # views_emd = torch.cat((views_emd, -rays_d, normals), dim=1)
138 | # views_emd = torch.cat((views_emd,normals), dim=1)
139 |
140 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
141 | # rendered_features[0] = torch.cat((rendered_features[0], depths), dim=1)
142 | # print(rendered_features[0].shape)
143 |
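        # Note: after the concatenation above, rendered_features[0] carries the rasterized
        # per-pixel features followed by the 24-channel view embedding; MyNet.forward slices
        # these groups back out by channel index (see mynet.py).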
144 |
145 | image = model(*rendered_features)
146 | # semantic_logits = image['s_out'].squeeze(0)
147 | mask = image['mask_out'].squeeze(0)
148 | color = image['color_out'].squeeze(0)
149 | image = image['im_out'].squeeze(0)
150 |
151 | # image_mask = image * (torch.ones_like(mask)-mask)
152 |
153 | # image, semantic_logits, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["semantic"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
154 | # print(image.shape, semantic_logits.shape)
155 | # print("semantic:", semantic_logits.shape)
156 |
157 | # Semantic Loss
158 | # semantic_logits = semantic_logits.permute(1, 2, 0)
159 | # gt_semantic_image = viewpoint_cam.semantic_image.cuda()
160 | # Lce = CE_loss(semantic_logits.reshape(-1, 20), gt_semantic_image.reshape(-1).long())
161 |
162 | # semantic_loss = Lce
163 | # Color Loss
164 | semantic_loss = 0
165 |
166 | gt_image = viewpoint_cam.original_image.cuda()
167 | # gt_image_mask = gt_image * (torch.ones_like(mask)-mask)
168 | #
169 | # Ln = torch.mean((gs_normals - pseudo_normals) ** 2)
170 |
171 | Ll1 = l1_loss(image, gt_image)
172 | color_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
173 |
174 | LB = l1_loss(color, gt_image)
175 | basic_loss = (1.0 - opt.lambda_dssim) * LB + opt.lambda_dssim * (1.0 - ssim(color, gt_image))
176 |
177 | # LD = l1_loss(image_mask, gt_image_mask)
178 | # diffuse_loss = (1.0 - opt.lambda_dssim) * LD + opt.lambda_dssim * (1.0 - ssim(image_mask, gt_image_mask))
179 |
180 | loss = color_loss + 0.001 * Ln + 0.05 * basic_loss #+ diffuse_loss #+ basic_loss # + semantic_loss * 0.2
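        # Note: the total objective is L = L_color + 0.001 * Ln + 0.05 * L_basic, where L_color
        # and L_basic are (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM) terms on the
        # composed image and on the color branch alone, and Ln is an extra regularization term
        # returned by the renderer (a normal-consistency loss in the commented-out variant above).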
181 |
182 | loss.backward()
183 |
184 | iter_end.record()
185 |
186 | with torch.no_grad():
187 | # Progress bar
188 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
189 | if iteration % 10 == 0:
190 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{5}f}", "C_Loss": f"{color_loss:.{5}f}", "S_Loss": f"{semantic_loss:.{5}f}"})
191 | progress_bar.update(10)
192 | if iteration == opt.iterations:
193 | progress_bar.close()
194 |
195 | # Log and save
196 |
197 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), model=model)
198 | if (iteration in saving_iterations):
199 | print("\n[ITER {}] Saving Gaussians".format(iteration))
200 | scene.save(iteration)
201 |
202 | # Densification
203 | if iteration < opt.densify_until_iter:
204 | # Keep track of max radii in image-space for pruning
205 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
206 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
207 |
208 |
209 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
210 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None
211 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) # 0.005
212 |
213 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
214 | gaussians.reset_opacity()
215 |
216 |             optimizer.step()
217 |             optimizer.zero_grad(set_to_none=True)
218 | # Optimizer step
219 | if iteration < opt.iterations:
220 | gaussians.optimizer.step()
221 | gaussians.optimizer.zero_grad(set_to_none=True)
222 | gaussians.optimizer_net.step()
223 | gaussians.optimizer_net.zero_grad(set_to_none=True)
224 | gaussians.scheduler_net.step()
225 | scheduler.step(loss)
226 |
227 | if (iteration in checkpoint_iterations):
228 | print("\n[ITER {}] Saving Checkpoint".format(iteration))
229 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
230 |
231 |
232 |
233 | def prepare_output_and_logger(args):
234 | if not args.model_path:
235 | if os.getenv('OAR_JOB_ID'):
236 | unique_str = os.getenv('OAR_JOB_ID')
237 | else:
238 | unique_str = str(uuid.uuid4())
239 | args.model_path = os.path.join("./output/", unique_str[0:10])
240 |
241 | # Set up output folder
242 | print("Output folder: {}".format(args.model_path))
243 | os.makedirs(args.model_path, exist_ok=True)
244 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
245 | cfg_log_f.write(str(Namespace(**vars(args))))
246 |
247 | # Create Tensorboard writer
248 | tb_writer = None
249 | if TENSORBOARD_FOUND:
250 | tb_writer = SummaryWriter(args.model_path)
251 | else:
252 | print("Tensorboard not available: not logging progress")
253 | return tb_writer
254 |
255 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, model=None):
256 | if tb_writer:
257 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
258 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
259 | tb_writer.add_scalar('iter_time', elapsed, iteration)
260 |
261 | # Report test and samples of training set
262 | if iteration in testing_iterations:
263 |         torch.save(model.state_dict(), scene.model_path + "/model_ckpt" + str(iteration) + ".pth")
264 | torch.cuda.empty_cache()
265 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
266 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
267 | for config in validation_configs:
268 | if config['cameras'] and len(config['cameras']) > 0:
269 | l1_test = 0.0
270 | psnr_test = 0.0
271 | ssim_test = 0.0
272 | lpips_test = 0.0
273 | for idx, viewpoint in enumerate(config['cameras']):
274 | if model:
275 |                     rendered_features, viewspace_point_tensor, visibility_filter, radii, rendered_depths, Ln = renderFunc(viewpoint, scene.gaussians, *renderArgs)
276 | # depths = rendered_depths[0]
277 | rays_d = torch.from_numpy(
278 | views_dir(viewpoint.image_height, viewpoint.image_width, viewpoint.K, viewpoint.c2w)).cuda()
279 |
280 | normals = rendered_features[0].squeeze(0).permute(1, 2, 0)[:, :, 17:]
281 | # print(normals.shape)
282 | #
283 | # rays_d = calculate_reflection_direction(rays_d, normals)
284 | # rays_r = calculate_reflection_direction(rays_d, normals).permute(2,0,1).unsqueeze(0)
285 | views_emd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
286 |
287 | # rays_o, rays_d = views_dir(viewpoint.image_height, viewpoint.image_width,viewpoint.K, viewpoint.c2w)
288 | # rays_o, rays_d = torch.from_numpy(rays_o.copy()).cuda(), torch.from_numpy(rays_d.copy()).cuda(),
289 | # views_emdd = embedding_fn(rays_d).permute(2, 0, 1).unsqueeze(0)
290 | # views_emdo = embedding_fn(rays_o).permute(2, 0, 1).unsqueeze(0)
291 | # views_emd = torch.cat((views_emdo, views_emdd), dim=1)
292 | # views_emd = torch.cat((views_emd, -rays_d.permute(2, 0, 1).unsqueeze(0), normals.permute(2, 0, 1).unsqueeze(0)), dim=1)
293 | # views_emd = torch.cat((views_emd, normals.permute(2, 0, 1).unsqueeze(0)),dim=1)
294 | rendered_features[0] = torch.cat((rendered_features[0], views_emd), dim=1)
295 | # rendered_features[0] = torch.cat((rendered_features[0], depths), dim=1)
296 |
297 |
298 | image = model(*rendered_features)
299 | image_save = image['im_out'].squeeze(0)
300 | image = torch.clamp(image_save, 0.0, 1.0)
301 | else:
302 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
303 |
304 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
305 |
306 | if tb_writer and (idx < 5):
307 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
308 | if iteration == testing_iterations[0]:
309 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
310 | l1_test += l1_loss(image, gt_image).mean().double()
311 | psnr_test += psnr(image, gt_image).mean().double()
312 | ssim_test += ssim(image, gt_image).mean().double()
313 | lpips_test += lpips(image, gt_image).mean().double()
314 | psnr_test /= len(config['cameras'])
315 | ssim_test /= len(config['cameras'])
316 | lpips_test /= len(config['cameras'])
317 | l1_test /= len(config['cameras'])
318 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {} SSIM {} LPIPS {}".format(iteration, config['name'], l1_test, psnr_test, ssim_test, lpips_test))
319 | if tb_writer:
320 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
321 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
322 |
323 | if tb_writer:
324 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
325 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
326 | torch.cuda.empty_cache()
327 |
328 | if __name__ == "__main__":
329 | # Set up command line argument parser
330 | parser = ArgumentParser(description="Training script parameters")
331 | lp = ModelParams(parser)
332 | op = OptimizationParams(parser)
333 | pp = PipelineParams(parser)
334 | parser.add_argument('--ip', type=str, default="127.0.0.1")
335 | parser.add_argument('--port', type=int, default=6009)
336 | parser.add_argument('--debug_from', type=int, default=-1)
337 | parser.add_argument('--detect_anomaly', action='store_true', default=False)
338 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000, 50_000, 60_000, 70_000, 80_000]) #, 40_000, 50_000, 60_000, 70_000, 80_000])
339 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[1000, 7000, 11_000, 15_000, 20_000, 25_000, 30_000, 40_000, 50_000, 60_000, 70_000, 80_000]) # ,40_000, 50_000, 60_000, 70_000, 80_000])
340 | parser.add_argument("--quiet", action="store_true")
341 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
342 | parser.add_argument("--start_checkpoint", type=str, default = None)
343 | args = parser.parse_args(sys.argv[1:])
344 | args.save_iterations.append(args.iterations)
345 |
346 |     print("Scene path: " + args.source_path)
347 | print("Optimizing " + args.model_path)
348 |
349 | # Initialize system state (RNG)
350 | safe_state(args.quiet)
351 |
352 | # Start GUI server, configure and run training
353 | # network_gui.init(args.ip, args.port)
354 | torch.autograd.set_detect_anomaly(args.detect_anomaly)
355 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
356 |
357 | # All done
358 | print("\nTraining complete.")
359 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 | from functools import partial
7 |
8 |
9 | class BasicConv(nn.Module):
10 | # Gated_conv
11 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, dilation=1,
12 | padding_mode='reflect', act_fun=nn.ELU, normalization=nn.BatchNorm2d):
13 | super().__init__()
14 | self.pad_mode = padding_mode
15 | self.filter_size = kernel_size
16 | self.stride = stride
17 | self.dilation = dilation
18 |
19 | n_pad_pxl = int(self.dilation * (self.filter_size - 1) / 2)
20 | self.flag = relu
21 |
22 |         # this is for backward compatibility with older model checkpoints
23 | self.block = nn.ModuleDict(
24 | {
25 | 'conv_f': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
26 | padding=n_pad_pxl),
27 | 'act_f': act_fun(),
28 | 'conv_m': nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
29 | padding=n_pad_pxl),
30 | 'act_m': nn.Sigmoid(),
31 | 'norm': normalization(out_channels)
32 | }
33 | )
34 |
35 | def forward(self, x, *args, **kwargs):
36 | # if self.flag:
37 | # features = self.block.act_f(self.block.conv_f(x))
38 | features = self.block.act_f(self.block.conv_f(x))
39 | # else:
40 | # features = self.block.conv_f(x)
41 | # mask = self.block.act_m(self.block.conv_m(x))
42 | # output = features * mask
43 | output = features
44 | # output = self.block.norm(output)
45 |
46 | return output
47 |
48 |
49 | class ResBlock(nn.Module):
50 | def __init__(self, in_channel, out_channel):
51 | super(ResBlock, self).__init__()
52 | self.main = nn.Sequential(
53 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
54 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
55 | nn.BatchNorm2d(out_channel)
56 | )
57 |
58 | def forward(self, x):
59 | return self.main(x) + x
60 |
61 |
62 | class SCM(nn.Module):
63 | def __init__(self, out_plane):
64 | super(SCM, self).__init__()
65 | self.main = nn.Sequential(
66 | BasicConv(16, out_plane - 8, kernel_size=3, stride=1, relu=True),
67 | # BasicConv(8, out_plane // 4, kernel_size=3, stride=1, relu=True),
68 | # BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
69 | # BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
70 | # BasicConv(out_plane // 2, out_plane - 8, kernel_size=1, stride=1, relu=True)
71 | )
72 |
73 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)
74 |
75 | def forward(self, x):
76 | x = torch.cat([x, self.main(x)], dim=1)
77 | return self.conv(x)
78 |
79 |
80 | class EBlock(nn.Module):
81 | def __init__(self, out_channel, num_res=1):
82 | super(EBlock, self).__init__()
83 |
84 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
85 |
86 | self.layers = nn.Sequential(*layers)
87 |
88 | def forward(self, x):
89 | return self.layers(x)
90 |
91 |
92 | class DBlock(nn.Module):
93 | def __init__(self, channel, num_res=1):
94 | super(DBlock, self).__init__()
95 |
96 | layers = [ResBlock(channel, channel) for _ in range(num_res)]
97 | self.layers = nn.Sequential(*layers)
98 |
99 | def forward(self, x):
100 | return self.layers(x)
101 |
102 |
103 | class FAM(nn.Module):
104 | def __init__(self, channel):
105 | super(FAM, self).__init__()
106 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)
107 |
108 | def forward(self, x1, x2):
109 | x = x1 * x2
110 | out = x1 + self.merge(x)
111 | return out
112 |
113 |
114 | class AFF(nn.Module):
115 | def __init__(self, in_channel, out_channel):
116 | super(AFF, self).__init__()
117 | self.conv = nn.Sequential(
118 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
119 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
120 | )
121 |
122 | def forward(self, x1, x2, x3, x4):
123 | x = torch.cat([x1, x2, x3, x4], dim=1)
124 | return self.conv(x)
125 |
126 |
127 | class UNet(nn.Module):
128 | r""" Rendering network with UNet architecture and multi-scale input.
129 |
130 | Args:
131 | num_input_channels: Number of channels in the input tensor or list of tensors. An integer or a list of integers for each input tensor.
132 | num_output_channels: Number of output channels.
133 |         feature_scale: Division factor for the number of convolutional channels. The larger it is, the fewer parameters the model has.
134 |         num_res: Number of residual blocks in each encoder/decoder stage.
135 | """
136 |
137 | def __init__(
138 | self,
139 | num_input_channels=8,
140 | num_output_channels=3,
141 | feature_scale=4,
142 | num_res=1
143 |
144 | ):
145 | super().__init__()
146 |
147 | self.feature_scale = feature_scale
148 | base_channel = 32
149 |
150 | filters = [64, 128, 256, 512, 1024]
151 | filters = [x // self.feature_scale for x in filters]
152 |
153 | base_channel = 32
154 |
155 | self.feat_extract = nn.ModuleList([
156 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1),
157 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
158 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
159 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2),
160 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2),
161 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
162 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2),
163 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2),
164 | ])
165 |
166 | self.SCM0 = SCM(base_channel * 8)
167 | self.SCM1 = SCM(base_channel * 4)
168 | self.SCM2 = SCM(base_channel * 2)
169 |
170 | self.FAM0 = FAM(base_channel * 8)
171 | self.FAM1 = FAM(base_channel * 4)
172 | self.FAM2 = FAM(base_channel * 2)
173 |
174 | self.AFFs = nn.ModuleList([
175 | AFF(base_channel * 15, base_channel * 1),
176 | AFF(base_channel * 15, base_channel * 2),
177 | AFF(base_channel * 15, base_channel * 4),
178 | ])
179 |
180 | self.Encoder = nn.ModuleList([
181 | EBlock(base_channel, num_res),
182 | EBlock(base_channel * 2, num_res),
183 | EBlock(base_channel * 4, num_res),
184 | EBlock(base_channel * 8, num_res)
185 | ])
186 |
187 | self.Decoder = nn.ModuleList([
188 | DBlock(base_channel * 8, num_res),
189 | DBlock(base_channel * 4, num_res),
190 | DBlock(base_channel * 2, num_res),
191 | DBlock(base_channel, num_res)
192 | ])
193 |
194 | self.Convs = nn.ModuleList([
195 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1),
196 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
197 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1)
198 |
199 | ])
200 |
201 | self.ConvsOut = nn.ModuleList(
202 | [
203 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
204 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
205 | ]
206 | )
207 |
208 | self.up = nn.Upsample(scale_factor=4, mode='bilinear')
209 |
210 | def forward(self, *inputs, **kwargs):
211 | inputs = list(inputs)
212 |
213 | n_input = len(inputs)
214 |
215 | x = inputs[0]
216 | x_2 = inputs[1]
217 | x_4 = inputs[2]
218 | x_8 = inputs[3]
219 |
220 | z2 = self.SCM2(x_2)
221 | z4 = self.SCM1(x_4)
222 | z8 = self.SCM0(x_8)
223 |
224 | x_ = self.feat_extract[0](x)
225 | res1 = self.Encoder[0](x_)
226 |
227 | z = self.feat_extract[1](res1)
228 | z = self.FAM2(z, z2)
229 | res2 = self.Encoder[1](z)
230 |
231 | z = self.feat_extract[2](res2)
232 | z = self.FAM1(z, z4)
233 | res3 = self.Encoder[2](z)
234 |
235 | z = self.feat_extract[6](res3)
236 | z = self.FAM0(z, z8)
237 | z = self.Encoder[3](z)
238 |
239 | z12 = F.interpolate(res1, scale_factor=0.5)
240 | z13 = F.interpolate(res1, scale_factor=0.25)
241 |
242 | z21 = F.interpolate(res2, scale_factor=2)
243 | z23 = F.interpolate(res2, scale_factor=0.5)
244 |
245 | z32 = F.interpolate(res3, scale_factor=2)
246 | z31 = F.interpolate(res3, scale_factor=4)
247 |
248 | z43 = F.interpolate(z, scale_factor=2)
249 | z42 = F.interpolate(z43, scale_factor=2)
250 | z41 = F.interpolate(z42, scale_factor=2)
251 |
252 | res1 = self.AFFs[0](res1, z21, z31, z41)
253 | res2 = self.AFFs[1](z12, res2, z32, z42)
254 | res3 = self.AFFs[2](z13, z23, res3, z43)
255 | z = self.Decoder[0](z)
256 |
257 | z = self.feat_extract[7](z)
258 |
259 | z = self.up(z)
260 | z = torch.cat([z, res3], dim=1)
261 | z = self.Convs[0](z)
262 | z = self.Decoder[1](z)
263 |
264 | z = self.feat_extract[3](z)
265 | z = self.up(z)
266 |
267 | z = torch.cat([z, res2], dim=1)
268 | z = self.Convs[1](z)
269 | z = self.Decoder[2](z)
270 |
271 | z = self.feat_extract[4](z)
272 | z = self.up(z)
273 |
274 | z = torch.cat([z, res1], dim=1)
275 | z = self.Convs[2](z)
276 | z = self.Decoder[3](z)
277 | z = self.feat_extract[5](z)
278 |
279 | return {'im_out': z}
280 |
281 | class SimpleUNet(nn.Module):
282 | r""" Rendering network with UNet architecture and multi-scale input.
283 |
284 | Args:
285 | num_input_channels: Number of channels in the input tensor or list of tensors. An integer or a list of integers for each input tensor.
286 | num_output_channels: Number of output channels.
287 |         feature_scale: Division factor for the number of convolutional channels. The larger it is, the fewer parameters the model has.
288 |         num_res: Number of residual blocks in each encoder/decoder stage.
289 | """
290 |
291 | def __init__(
292 | self,
293 | num_input_channels=8,
294 | num_output_channels=3,
295 | feature_scale=4,
296 | num_res=1
297 |
298 | ):
299 | super().__init__()
300 |
301 | self.feature_scale = feature_scale
302 | base_channel = 8
303 |
304 | filters = [64, 128, 256, 512, 1024]
305 | filters = [x // self.feature_scale for x in filters]
306 |
307 | base_channel = 8
308 |
309 | self.feat_extract = nn.ModuleList([
310 | BasicConv(8, base_channel, kernel_size=3, relu=True, stride=1),
311 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
312 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
313 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2),
314 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2),
315 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
316 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2),
317 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2),
318 | ])
319 |
320 | self.SCM0 = SCM(base_channel * 8)
321 | self.SCM1 = SCM(base_channel * 4)
322 | self.SCM2 = SCM(base_channel * 2)
323 |
324 | self.FAM0 = FAM(base_channel * 8)
325 | self.FAM1 = FAM(base_channel * 4)
326 | self.FAM2 = FAM(base_channel * 2)
327 |
328 | self.AFFs = nn.ModuleList([
329 | AFF(base_channel * 15, base_channel * 1),
330 | AFF(base_channel * 15, base_channel * 2),
331 | AFF(base_channel * 15, base_channel * 4),
332 | ])
333 |
334 | self.Encoder = nn.ModuleList([
335 | EBlock(base_channel, num_res),
336 | EBlock(base_channel * 2, num_res),
337 | EBlock(base_channel * 4, num_res),
338 | EBlock(base_channel * 8, num_res)
339 | ])
340 |
341 | self.Decoder = nn.ModuleList([
342 | DBlock(base_channel * 8, num_res),
343 | DBlock(base_channel * 4, num_res),
344 | DBlock(base_channel * 2, num_res),
345 | DBlock(base_channel, num_res)
346 | ])
347 |
348 | self.Convs = nn.ModuleList([
349 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1),
350 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
351 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1)
352 |
353 | ])
354 |
355 | self.ConvsOut = nn.ModuleList(
356 | [
357 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
358 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
359 | ]
360 | )
361 |
362 | self.up = nn.Upsample(scale_factor=4, mode='bilinear')
363 |
364 | def forward(self, *inputs, **kwargs):
365 | inputs = list(inputs)
366 |
367 | n_input = len(inputs)
368 |
369 | x = inputs[0]
370 | x_2 = inputs[1]
371 | x_4 = inputs[2]
372 | x_8 = inputs[3]
373 |
374 | z2 = self.SCM2(x_2)
375 | z4 = self.SCM1(x_4)
376 | z8 = self.SCM0(x_8)
377 |
378 | x_ = self.feat_extract[0](x)
379 | res1 = self.Encoder[0](x_)
380 |
381 | z = self.feat_extract[1](res1)
382 | z = self.FAM2(z, z2)
383 | res2 = self.Encoder[1](z)
384 |
385 | z = self.feat_extract[2](res2)
386 | z = self.FAM1(z, z4)
387 | res3 = self.Encoder[2](z)
388 |
389 | z = self.feat_extract[6](res3)
390 | z = self.FAM0(z, z8)
391 | z = self.Encoder[3](z)
392 |
393 | z12 = F.interpolate(res1, scale_factor=0.5)
394 | z13 = F.interpolate(res1, scale_factor=0.25)
395 |
396 | z21 = F.interpolate(res2, scale_factor=2)
397 | z23 = F.interpolate(res2, scale_factor=0.5)
398 |
399 | z32 = F.interpolate(res3, scale_factor=2)
400 | z31 = F.interpolate(res3, scale_factor=4)
401 |
402 | z43 = F.interpolate(z, scale_factor=2)
403 | z42 = F.interpolate(z43, scale_factor=2)
404 | z41 = F.interpolate(z42, scale_factor=2)
405 |
406 | res1 = self.AFFs[0](res1, z21, z31, z41)
407 | res2 = self.AFFs[1](z12, res2, z32, z42)
408 | res3 = self.AFFs[2](z13, z23, res3, z43)
409 | z = self.Decoder[0](z)
410 |
411 | z = self.feat_extract[7](z)
412 |
413 | z = self.up(z)
414 | z = F.interpolate(z, res3.shape[-2:])
415 | z = torch.cat([z, res3], dim=1)
416 | z = self.Convs[0](z)
417 | z = self.Decoder[1](z)
418 |
419 | z = self.feat_extract[3](z)
420 | z = self.up(z)
421 |
422 | z = torch.cat([z, res2], dim=1)
423 | z = self.Convs[1](z)
424 | z = self.Decoder[2](z)
425 |
426 | z = self.feat_extract[4](z)
427 | z = self.up(z)
428 |
429 | z = torch.cat([z, res1], dim=1)
430 | z = self.Convs[2](z)
431 | z = self.Decoder[3](z)
432 | z = self.feat_extract[5](z)
433 |
434 | return {'im_out': z}
435 |
436 |
437 | class SimpleNet(nn.Module):
438 | def __init__(
439 | self,
440 | num_input_channels=8,
441 | num_output_channels=3,
442 | feature_scale=4,
443 | num_res=2
444 |
445 | ):
446 | super().__init__()
447 |
448 | self.feature_scale = feature_scale
449 | base_channel = 8
450 |
451 | filters = [64, 128, 256, 512, 1024]
452 | filters = [x // self.feature_scale for x in filters]
453 |
454 | base_channel = 16
455 |
456 | self.feat_extract = nn.ModuleList([
457 | BasicConv(16, base_channel, kernel_size=3, relu=True, stride=1),
458 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
459 | BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
460 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2),
461 | BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2),
462 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1),
463 | BasicConv(base_channel * 4, base_channel * 8, kernel_size=3, relu=True, stride=2),
464 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=4, relu=True, stride=2),
465 | BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=1),
466 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=3, relu=True, stride=1),
467 | BasicConv(base_channel*4, base_channel * 2, kernel_size=3, relu=True, stride=1),
468 | BasicConv(base_channel * 2, base_channel * 1, kernel_size=3, relu=True, stride=1)
469 | ])
470 |
471 | self.SCM0 = SCM(base_channel * 8)
472 | self.SCM1 = SCM(base_channel * 4)
473 | self.SCM2 = SCM(base_channel * 2)
474 |
475 | self.FAM0 = FAM(base_channel * 8)
476 | self.FAM1 = FAM(base_channel * 4)
477 | self.FAM2 = FAM(base_channel * 2)
478 |
479 | self.AFFs = nn.ModuleList([
480 | AFF(base_channel * 15, base_channel * 1),
481 | AFF(base_channel * 15, base_channel * 2),
482 | AFF(base_channel * 15, base_channel * 4),
483 | ])
484 |
485 | self.Encoder = nn.ModuleList([
486 | EBlock(base_channel, num_res),
487 | EBlock(base_channel * 2, num_res),
488 | EBlock(base_channel * 4, num_res),
489 | EBlock(base_channel * 8, num_res)
490 | ])
491 |
492 | self.Decoder = nn.ModuleList([
493 | DBlock(base_channel * 8, num_res),
494 | DBlock(base_channel * 4, num_res),
495 | DBlock(base_channel * 2, num_res),
496 | DBlock(base_channel, num_res)
497 | ])
498 |
499 | self.Convs = nn.ModuleList([
500 | BasicConv(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1),
501 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
502 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
503 | BasicConv(base_channel * 8, base_channel * 8, kernel_size=1, relu=True, stride=1),
504 | BasicConv(base_channel * 4, base_channel * 4, kernel_size=1, relu=True, stride=1),
505 | BasicConv(base_channel * 2, base_channel * 2, kernel_size=1, relu=True, stride=1),
506 |
507 | ])
508 |
509 | self.ConvsOut = nn.ModuleList(
510 | [
511 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
512 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
513 | ]
514 | )
515 |
516 | self.up = nn.Upsample(scale_factor=2, mode='bilinear')
517 |
518 | def forward(self, *inputs, **kwargs):
519 | inputs = list(inputs)
520 |
521 | n_input = len(inputs)
522 |
523 | x = inputs[0]
524 |
525 | # x = self.feat_extract[0](x)
526 | x = self.Encoder[0](x)
527 | z1 = x
528 | # print(x.shape, "1")
529 |
530 | x = self.feat_extract[1](x)
531 | x = self.Encoder[1](x)
532 | z2 = x
533 | # print(x.shape, "2")
534 |
535 | x = self.feat_extract[2](x)
536 | x = self.Encoder[2](x)
537 | z3 = x
538 | # print(x.shape, "3")
539 |
540 | x = self.feat_extract[6](x)
541 | x = self.Encoder[3](x)
542 | # print(x.shape, "4")
543 |
544 | x = self.up(x)
545 | # print(x.shape, "5")
546 |
547 | x = self.Convs[0](x)
548 | x = torch.cat([x, z3], dim=1)
549 | x = self.Decoder[0](x)
550 |
551 | # print(x.shape, "6")
552 |
553 |
554 | x = self.feat_extract[9](x)
555 | x = self.up(x)
556 | # print(x.shape, "7")
557 |
558 |
559 | x = self.Convs[1](x)
560 | x = torch.cat([x, z2], dim=1)
561 | x = self.Decoder[1](x)
562 |
563 | # print(x.shape, "8")
564 |
565 | x = self.feat_extract[10](x)
566 | x = self.up(x)
567 | # print(x.shape, "9")
568 |
569 | x = self.Convs[2](x)
570 | x = torch.cat([x, z1], dim=1)
571 | x = self.Decoder[2](x)
572 | # print(x.shape, "10")
573 |
574 | x = self.feat_extract[11](x)
575 | # print(x.shape, "11")
576 |
577 | z = self.feat_extract[5](x)
578 | # print(z.shape, "12")
579 |
580 | return {'im_out': z}
581 |
582 | # # x = self.feat_extract[0](x)
583 | # x = self.Encoder[0](x)
584 | # print(x.shape, "1")
585 | #
586 | # x = self.feat_extract[1](x)
587 | # x = self.Encoder[1](x)
588 | # print(x.shape, "2")
589 | #
590 | # x = self.feat_extract[2](x)
591 | # x = self.Encoder[2](x)
592 | # print(x.shape, "3")
593 | #
594 | # x = self.feat_extract[6](x)
595 | # x = self.Encoder[3](x)
596 | # print(x.shape, "4")
597 | #
598 | # x = self.up(x)
599 | # print(x.shape, "5")
600 | #
601 | # x = self.Convs[3](x)
602 | # x = self.Decoder[0](x)
603 | # print(x.shape, "6")
604 | #
605 | # x = self.feat_extract[9](x)
606 | # x = self.up(x)
607 | # print(x.shape, "7")
608 | #
609 | # x = self.Convs[4](x)
610 | # x = self.Decoder[1](x)
611 | # print(x.shape, "8")
612 | #
613 | # x = self.feat_extract[10](x)
614 | # x = self.up(x)
615 | # print(x.shape, "9")
616 | #
617 | # x = self.Convs[5](x)
618 | # x = self.Decoder[2](x)
619 | # print(x.shape, "10")
620 | #
621 | # x = self.feat_extract[11](x)
622 | # print(x.shape, "11")
623 | #
624 | # z = self.feat_extract[5](x)
625 | # print(z.shape, "12")
626 | #
627 | # return {'im_out': z}
628 |
629 | if __name__ == '__main__':
630 | import pdb
631 | import time
632 | import numpy as np
633 |
634 | # model = UNet().to('cuda')
635 | model = SimpleNet().to('cuda')
636 | input = []
637 | img_sh = [1408, 376]
638 | sh_unit = 8
639 | # img_sh = list(map(lambda a: a - a % sh_unit + sh_unit if a % sh_unit != 0 else a, img_sh))
640 |
641 | # print(img_sh)
642 | down = lambda a, b: a // 2 ** b
643 | input.append(torch.zeros((1, 8, down(img_sh[0], 0), down(img_sh[1], 0)), requires_grad=True).cuda())
644 | input.append(F.interpolate(input[0], scale_factor=0.5))
645 | input.append(F.interpolate(input[1], scale_factor=0.5))
646 | input.append(F.interpolate(input[2], scale_factor=0.5))
647 | print(input)
648 |
649 | model.eval()
650 | st = time.time()
651 | print(input[0].max(), input[0].min())
652 | print(input[0].shape, input[1].shape, input[2].shape, input[3].shape)
653 | with torch.set_grad_enabled(False):
654 | out = model(*input)
655 | pdb.set_trace()
656 | print('model', time.time() - st)
657 | print(out['im_out'], out['im_out'].shape)
658 | model.to('cpu')
659 |
--------------------------------------------------------------------------------
/scene/gaussian_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | #
6 | # This software is free for non-commercial, research and evaluation use
7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import numpy as np
14 | from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
15 | from torch import nn
16 | import os
17 | from utils.system_utils import mkdir_p
18 | from plyfile import PlyData, PlyElement
19 | from utils.sh_utils import RGB2SH
20 | from simple_knn._C import distCUDA2
21 | from utils.graphics_utils import BasicPointCloud
22 | from utils.general_utils import strip_symmetric, build_scaling_rotation, get_minimum_axis, flip_align_view
23 | import tinycudann as tcnn
24 |
25 | class GaussianModel:
26 |
27 | def setup_functions(self):
28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation)
30 | actual_covariance = L @ L.transpose(1, 2)
31 | symm = strip_symmetric(actual_covariance)
32 | return symm
33 |
34 | self.scaling_activation = torch.exp
35 | self.scaling_inverse_activation = torch.log
36 |
37 | self.covariance_activation = build_covariance_from_scaling_rotation
38 |
39 | self.opacity_activation = torch.sigmoid
40 | self.inverse_opacity_activation = inverse_sigmoid
41 |
42 | self.rotation_activation = torch.nn.functional.normalize
43 |
44 |     def __init__(self, sh_degree : int, num_sem_classes : int): # num_sem_classes: number of semantic classes
45 | self.active_sh_degree = 0
46 | self.max_sh_degree = sh_degree
47 | self.num_sem_classes = num_sem_classes
48 | self._xyz = torch.empty(0)
49 | # self._features_dc = torch.empty(0)
50 | # self._features_rest = torch.empty(0)
51 |         self._semantic = torch.empty(0) # semantic parameters
52 | self._scaling = torch.empty(0)
53 | self._rotation = torch.empty(0)
54 | self._opacity = torch.empty(0)
55 | self.max_radii2D = torch.empty(0)
56 | self.xyz_gradient_accum = torch.empty(0)
57 | self.denom = torch.empty(0)
58 | self.optimizer = None
59 | self.percent_dense = 0
60 | self.spatial_lr_scale = 0
61 | self.setup_functions()
62 |
63 | self.recolor = tcnn.Encoding(
64 | n_input_dims=3,
65 | encoding_config={
66 | "otype": "HashGrid",
67 | "n_levels": 16,
68 | "n_features_per_level": 2,
69 | "log2_hashmap_size": 22,
70 | "base_resolution": 128,
71 | "per_level_scale": 1.4,
72 | },
73 | )
74 | self.positional_encoding = tcnn.Encoding(
75 | n_input_dims=3,
76 | encoding_config={
77 | "otype": "Frequency",
78 | "n_frequencies": 4
79 | },
80 | )
81 | self.direction_encoding = tcnn.Encoding(
82 | n_input_dims=3,
83 | encoding_config={
84 | "otype": "SphericalHarmonics",
85 | "degree": 3
86 | },
87 | )
88 | self.mlp_head = tcnn.Network(
89 | n_input_dims=(self.recolor.n_output_dims),
90 | n_output_dims=8,
91 | network_config={
92 | "otype": "FullyFusedMLP",
93 | "activation": "ReLU",
94 | "output_activation": "None",
95 | "n_neurons": 64,
96 | "n_hidden_layers": 2,
97 | },
98 | )
99 | self.mlp_direction_head = tcnn.Network(
100 | n_input_dims=(self.direction_encoding.n_output_dims + 3),
101 | n_output_dims=1,
102 | network_config={
103 | "otype": "FullyFusedMLP",
104 | "activation": "ReLU",
105 | "output_activation": "Sigmoid",
106 | "n_neurons": 32,
107 | "n_hidden_layers": 1,
108 | },
109 | )
110 | self.mlp_normal_head = tcnn.Network(
111 | n_input_dims=16,
112 | n_output_dims=3,
113 | network_config={
114 | "otype": "FullyFusedMLP",
115 | "activation": "ReLU",
116 | "output_activation": "None",
117 | "n_neurons": 32,
118 | "n_hidden_layers": 1,
119 | },
120 | )
121 |
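    # Note: the tiny-cuda-nn modules above act as the appearance heads. The hash-grid encoding
    # of position feeds mlp_head to produce an 8-channel latent feature; the spherical-harmonics
    # direction encoding (plus the raw direction) feeds mlp_direction_head to predict a scalar
    # in [0, 1] (presumably the tint/mask channel); mlp_normal_head maps the learned per-Gaussian
    # feature (expected to be 16-dim) to a 3-channel normal.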
122 | def capture(self):
123 | return (
124 | self.active_sh_degree,
125 | self.num_sem_classes,
126 | self._xyz,
127 | # self._features_dc,
128 | # self._features_rest,
129 | self._semantic,
130 | self._scaling,
131 | self._rotation,
132 | self._opacity,
133 | self.max_radii2D,
134 | self.xyz_gradient_accum,
135 | self.denom,
136 | self.optimizer.state_dict(),
137 | self.spatial_lr_scale,
138 | )
139 |
140 | def restore(self, model_args, training_args):
141 |         (self.active_sh_degree,
142 |         self.num_sem_classes,
143 |         self._xyz,
144 |         # self._features_dc,
145 |         # self._features_rest,
146 |         self._semantic,
147 | self._scaling,
148 | self._rotation,
149 | self._opacity,
150 | self.max_radii2D,
151 | xyz_gradient_accum,
152 | denom,
153 | opt_dict,
154 | self.spatial_lr_scale) = model_args
155 | self.training_setup(training_args)
156 | self.xyz_gradient_accum = xyz_gradient_accum
157 | self.denom = denom
158 | self.optimizer.load_state_dict(opt_dict)
159 |
160 | @property
161 | def get_scaling(self):
162 | return self.scaling_activation(self._scaling)
163 |
164 | @property
165 | def get_rotation(self):
166 | return self.rotation_activation(self._rotation)
167 |
168 | @property
169 | def get_xyz(self):
170 | return self._xyz
171 |
172 | @property
173 | def get_minimum_axis(self):
174 | return get_minimum_axis(self.get_scaling, self.get_rotation)
175 |
176 | # takes an argument, so this must be a plain method rather than a @property
177 | def get_pseudo_normal(self, dir_pp_normalized):
178 | normal_axis = self.get_minimum_axis
179 | normal_axis, positive = flip_align_view(normal_axis, dir_pp_normalized)
180 | normal = normal_axis
181 | normal = normal / normal.norm(dim=1, keepdim=True) # (N, 3)
182 | return normal
183 |
184 | # takes a view argument, so this must be a plain method rather than a @property
185 | def get_dir_pp(self, view):
186 | means3D = self.get_xyz
187 | dir_pp = (means3D - view.camera_center.repeat(means3D.shape[0], 1))
188 | dir_pp = dir_pp / dir_pp.norm(dim=1, keepdim=True)
189 | return dir_pp
190 |
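   | # note: `mlp_normal_head` was configured with n_input_dims=16, so the call below assumes the
   | # per-Gaussian semantic/latent vector (num_sem_classes) is 16-dimensional.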
191 | @property
192 | def get_normal(self):
193 | normal = self.mlp_normal_head(self.get_semantic)
194 | # normal = self.mlp_normal_head(
195 | # torch.cat((self.get_semantic, self.get_rotation, self.get_scaling, self._xyz), dim=1))
196 | return normal
197 |
198 | # @property
199 | # def get_features(self):
200 | # features_dc = self._features_dc
201 | # features_rest = self._features_rest
202 | # return torch.cat((features_dc, features_rest), dim=1)
203 |
204 | @property
205 | def get_semantic(self):
206 | return self._semantic
207 |
208 | @property
209 | def get_opacity(self):
210 | return self.opacity_activation(self._opacity)
211 |
212 | def get_covariance(self, scaling_modifier = 1):
213 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
214 |
215 | def oneupSHdegree(self):
216 | if self.active_sh_degree < self.max_sh_degree:
217 | self.active_sh_degree += 1
218 |
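   | # Initialise Gaussians from the SfM point cloud: zero latent/semantic features, isotropic
   | # log-scales from the (clamped) squared nearest-neighbour distances returned by distCUDA2,
   | # identity quaternions, and opacity 0.1 stored in inverse-sigmoid space.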
219 | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
220 | self.spatial_lr_scale = spatial_lr_scale
221 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() # [10458, 3]
222 | # fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) # [10458, 3]
223 | # features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() # [10458, 3, 16]
224 | # features[:, :3, 0 ] = fused_color
225 | # features[:, 3:, 1:] = 0.0 # [10458, 3, 16]
226 |
227 | semantic = torch.zeros((fused_point_cloud.shape[0], self.num_sem_classes)).float().cuda() # semantic/latent features, [N, num_sem_classes]
228 |
229 | print("Number of points at initialisation : ", fused_point_cloud.shape[0])
230 |
231 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
232 | scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
233 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
234 | rots[:, 0] = 1
235 |
236 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
237 |
238 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
239 | # self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) # [10458, 1, 3]
240 | # self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) # [10458, 15, 3]
241 | self._semantic = nn.Parameter(semantic.requires_grad_(True)) # [10458, num_sem_classes]
242 | self._scaling = nn.Parameter(scales.requires_grad_(True))
243 | self._rotation = nn.Parameter(rots.requires_grad_(True))
244 | self._opacity = nn.Parameter(opacities.requires_grad_(True))
245 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
246 |
247 |
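   | # Two optimisers are used: `optimizer` (Adam) over the per-Gaussian tensors with per-group
   | # learning rates from training_args, and `optimizer_net` (Adam, lr 1e-2) over the direction/
   | # normal heads and the direction/positional encodings (recolor and mlp_head are not included
   | # here), with a 100-iteration linear warm-up followed by 0.33x decays at 5k/15k/25k steps.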
248 | def training_setup(self, training_args):
249 | self.percent_dense = training_args.percent_dense
250 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
251 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
252 |
253 | other_params = []
254 | # for params in self.recolor.parameters():
255 | # other_params.append(params)
256 | # for params in self.mlp_head.parameters():
257 | # other_params.append(params)
258 | for params in self.mlp_direction_head.parameters():
259 | other_params.append(params)
260 | for params in self.direction_encoding.parameters():
261 | other_params.append(params)
262 | for params in self.mlp_normal_head.parameters():
263 | other_params.append(params)
264 | for params in self.positional_encoding.parameters():
265 | other_params.append(params)
266 |
267 | l = [
268 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
269 | # {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
270 | # {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
271 | {'params': [self._semantic], 'lr': training_args.semantic_lr, "name": "semantic"},
272 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
273 | {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
274 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
275 | ]
276 |
277 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
278 | self.optimizer_net = torch.optim.Adam(other_params, lr=1e-2, eps=1e-15) # 5e-3
279 | self.scheduler_net = torch.optim.lr_scheduler.ChainedScheduler(
280 | [
281 | torch.optim.lr_scheduler.LinearLR(
282 | self.optimizer_net, start_factor=0.01, total_iters=100
283 | ),
284 | torch.optim.lr_scheduler.MultiStepLR(
285 | self.optimizer_net,
286 | milestones=[5_000, 15_000, 25_000],
287 | gamma=0.33,
288 | ),
289 | ]
290 | )
291 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
292 | lr_final=training_args.position_lr_final * self.spatial_lr_scale,
293 | lr_delay_mult=training_args.position_lr_delay_mult,
294 | max_steps=training_args.position_lr_max_steps)
295 |
296 | def update_learning_rate(self, iteration):
297 | ''' Learning rate scheduling per step '''
298 | for param_group in self.optimizer.param_groups:
299 | if param_group["name"] == "xyz":
300 | lr = self.xyz_scheduler_args(iteration)
301 | param_group['lr'] = lr
302 | return lr
303 |
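   | # Attribute order written to the PLY: position, placeholder normals, semantic_i features,
   | # opacity, log-scale components, and quaternion components, all stored as float32 ('f4').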
304 | def construct_list_of_attributes(self):
305 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
306 | # All channels except the 3 DC
307 | # for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
308 | # l.append('f_dc_{}'.format(i))
309 | # for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
310 | # l.append('f_rest_{}'.format(i))
311 | for i in range(self._semantic.shape[1]):
312 | l.append('semantic_{}'.format(i))
313 | l.append('opacity')
314 | for i in range(self._scaling.shape[1]):
315 | l.append('scale_{}'.format(i))
316 | for i in range(self._rotation.shape[1]):
317 | l.append('rot_{}'.format(i))
318 | return l
319 |
320 | def save_ply(self, path):
321 | mkdir_p(os.path.dirname(path))
322 |
323 | xyz = self._xyz.detach().cpu().numpy()
324 | normals = np.zeros_like(xyz)
325 | # f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
326 | # f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
327 | semantic = self._semantic.detach().cpu().numpy()
328 | opacities = self._opacity.detach().cpu().numpy()
329 | scale = self._scaling.detach().cpu().numpy()
330 | rotation = self._rotation.detach().cpu().numpy()
331 |
332 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
333 |
334 | elements = np.empty(xyz.shape[0], dtype=dtype_full)
335 | # attributes = np.concatenate((xyz, normals, f_dc, f_rest, semantic, opacities, scale, rotation), axis=1)
336 | attributes = np.concatenate((xyz, normals, semantic, opacities, scale, rotation), axis=1)
337 | # attributes = np.concatenate((xyz, normals, opacities, scale, rotation), axis=1)
338 | elements[:] = list(map(tuple, attributes))
339 | el = PlyElement.describe(elements, 'vertex')
340 | PlyData([el]).write(path)
341 |
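   | # Opacity reset used during densification: clamp every opacity to at most 0.01 (in
   | # inverse-sigmoid space) and zero the Adam moments for the opacity parameter group.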
342 | def reset_opacity(self):
343 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
344 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
345 | self._opacity = optimizable_tensors["opacity"]
346 |
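   | # Rebuild the model from a saved PLY, reading the semantic_i / scale_i / rot_i columns back
   | # in index order (the commented-out SH feature loading mirrors vanilla 3DGS).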
347 | def load_ply(self, path):
348 | plydata = PlyData.read(path)
349 |
350 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
351 | np.asarray(plydata.elements[0]["y"]),
352 | np.asarray(plydata.elements[0]["z"])), axis=1)
353 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
354 |
355 | semantic_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("semantic_")]
356 | semantic_names = sorted(semantic_names, key = lambda x: int(x.split('_')[-1]))
357 | semantic = np.zeros((xyz.shape[0], len(semantic_names)))
358 | for idx, attr_name in enumerate(semantic_names):
359 | semantic[:, idx] = np.asarray(plydata.elements[0][attr_name])
360 |
361 | # features_dc = np.zeros((xyz.shape[0], 3, 1))
362 | # features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
363 | # features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
364 | # features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
365 |
366 | # extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
367 | # extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
368 | # assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
369 | # features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
370 | # for idx, attr_name in enumerate(extra_f_names):
371 | # features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
372 | # # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
373 | # features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
374 |
375 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
376 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
377 | scales = np.zeros((xyz.shape[0], len(scale_names)))
378 | for idx, attr_name in enumerate(scale_names):
379 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
380 |
381 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
382 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
383 | rots = np.zeros((xyz.shape[0], len(rot_names)))
384 | for idx, attr_name in enumerate(rot_names):
385 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
386 |
387 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
388 | # self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
389 | # self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
390 | self._semantic = nn.Parameter(torch.tensor(semantic, dtype=torch.float, device="cuda").requires_grad_(True))
391 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
392 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
393 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
394 |
395 | self.active_sh_degree = self.max_sh_degree
396 |
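   | # Swap one named parameter tensor inside the Adam optimiser, zeroing its running moments.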
397 | def replace_tensor_to_optimizer(self, tensor, name):
398 | optimizable_tensors = {}
399 | for group in self.optimizer.param_groups:
400 | if group["name"] == name:
401 | stored_state = self.optimizer.state.get(group['params'][0], None)
402 | stored_state["exp_avg"] = torch.zeros_like(tensor)
403 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
404 |
405 | del self.optimizer.state[group['params'][0]]
406 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
407 | self.optimizer.state[group['params'][0]] = stored_state
408 |
409 | optimizable_tensors[group["name"]] = group["params"][0]
410 | return optimizable_tensors
411 |
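   | # Keep only the rows selected by `mask` in every optimised tensor and its Adam state.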
412 | def _prune_optimizer(self, mask):
413 | optimizable_tensors = {}
414 | for group in self.optimizer.param_groups:
415 | stored_state = self.optimizer.state.get(group['params'][0], None)
416 | if stored_state is not None:
417 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
418 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
419 |
420 | del self.optimizer.state[group['params'][0]]
421 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
422 | self.optimizer.state[group['params'][0]] = stored_state
423 |
424 | optimizable_tensors[group["name"]] = group["params"][0]
425 | else:
426 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
427 | optimizable_tensors[group["name"]] = group["params"][0]
428 | return optimizable_tensors
429 |
430 | def prune_points(self, mask):
431 | valid_points_mask = ~mask
432 | optimizable_tensors = self._prune_optimizer(valid_points_mask)
433 |
434 | self._xyz = optimizable_tensors["xyz"]
435 | # self._features_dc = optimizable_tensors["f_dc"]
436 | # self._features_rest = optimizable_tensors["f_rest"]
437 | self._semantic = optimizable_tensors["semantic"]
438 | self._opacity = optimizable_tensors["opacity"]
439 | self._scaling = optimizable_tensors["scaling"]
440 | self._rotation = optimizable_tensors["rotation"]
441 |
442 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
443 |
444 | self.denom = self.denom[valid_points_mask]
445 | self.max_radii2D = self.max_radii2D[valid_points_mask]
446 |
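   | # Append new Gaussians to each optimised tensor, extending the Adam moments with zeros.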
447 | def cat_tensors_to_optimizer(self, tensors_dict):
448 | optimizable_tensors = {}
449 | for group in self.optimizer.param_groups:
450 | assert len(group["params"]) == 1
451 | extension_tensor = tensors_dict[group["name"]]
452 | stored_state = self.optimizer.state.get(group['params'][0], None)
453 | if stored_state is not None:
454 |
455 | stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
456 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
457 |
458 | del self.optimizer.state[group['params'][0]]
459 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
460 | self.optimizer.state[group['params'][0]] = stored_state
461 |
462 | optimizable_tensors[group["name"]] = group["params"][0]
463 | else:
464 | group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
465 | optimizable_tensors[group["name"]] = group["params"][0]
466 |
467 | return optimizable_tensors
468 |
469 | # def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacities, new_scaling, new_rotation):
470 | def densification_postfix(self, new_xyz, new_semantic, new_opacities, new_scaling, new_rotation):
471 | # def densification_postfix(self, new_xyz, new_opacities, new_scaling, new_rotation):
472 | d = {"xyz": new_xyz,
473 | # "f_dc": new_features_dc,
474 | # "f_rest": new_features_rest,
475 | "semantic": new_semantic,
476 | "opacity": new_opacities,
477 | "scaling" : new_scaling,
478 | "rotation" : new_rotation}
479 |
480 | optimizable_tensors = self.cat_tensors_to_optimizer(d)
481 | self._xyz = optimizable_tensors["xyz"]
482 | # self._features_dc = optimizable_tensors["f_dc"]
483 | # self._features_rest = optimizable_tensors["f_rest"]
484 | self._semantic = optimizable_tensors["semantic"]
485 | self._opacity = optimizable_tensors["opacity"]
486 | self._scaling = optimizable_tensors["scaling"]
487 | self._rotation = optimizable_tensors["rotation"]
488 |
489 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
490 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
491 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
492 |
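   | # Split over-reconstructed Gaussians: points with large view-space gradients and a largest
   | # scale above percent_dense * scene_extent are replaced by N samples drawn from their own
   | # covariance, with scales shrunk by 1/(0.8*N); the originals are then pruned.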
493 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
494 | n_init_points = self.get_xyz.shape[0]
495 | # Extract points that satisfy the gradient condition
496 | padded_grad = torch.zeros((n_init_points), device="cuda")
497 | padded_grad[:grads.shape[0]] = grads.squeeze()
498 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
499 | selected_pts_mask = torch.logical_and(selected_pts_mask,
500 | torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
501 |
502 | stds = self.get_scaling[selected_pts_mask].repeat(N,1)
503 | means = torch.zeros((stds.size(0), 3),device="cuda")
504 | samples = torch.normal(mean=means, std=stds)
505 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
506 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
507 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
508 | new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
509 | # new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
510 | # new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
511 | new_semantic = self._semantic[selected_pts_mask].repeat(N,1)
512 | new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
513 |
514 | # self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacity, new_scaling, new_rotation)
515 | self.densification_postfix(new_xyz, new_semantic, new_opacity, new_scaling, new_rotation)
516 | # self.densification_postfix(new_xyz, new_opacity, new_scaling, new_rotation)
517 |
518 | prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
519 | self.prune_points(prune_filter)
520 |
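   | # Clone under-reconstructed Gaussians: points with large view-space gradients but a largest
   | # scale no bigger than percent_dense * scene_extent are duplicated with unchanged attributes.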
521 | def densify_and_clone(self, grads, grad_threshold, scene_extent):
522 | # Extract points that satisfy the gradient condition
523 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
524 | selected_pts_mask = torch.logical_and(selected_pts_mask,
525 | torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
526 |
527 | new_xyz = self._xyz[selected_pts_mask]
528 | # new_features_dc = self._features_dc[selected_pts_mask]
529 | # new_features_rest = self._features_rest[selected_pts_mask]
530 | new_semantic = self._semantic[selected_pts_mask]
531 | new_opacities = self._opacity[selected_pts_mask]
532 | new_scaling = self._scaling[selected_pts_mask]
533 | new_rotation = self._rotation[selected_pts_mask]
534 |
535 | # self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_semantic, new_opacities, new_scaling, new_rotation)
536 | self.densification_postfix(new_xyz, new_semantic, new_opacities, new_scaling, new_rotation)
537 | # self.densification_postfix(new_xyz, new_opacities, new_scaling, new_rotation)
538 |
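   | # Densification driver: clone then split based on the accumulated positional gradients, then
   | # prune points that are nearly transparent, too large on screen, or too large in world space.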
539 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
540 | grads = self.xyz_gradient_accum / self.denom
541 | grads[grads.isnan()] = 0.0
542 |
543 | self.densify_and_clone(grads, max_grad, extent)
544 | self.densify_and_split(grads, max_grad, extent)
545 |
546 | prune_mask = (self.get_opacity < min_opacity).squeeze()
547 | if max_screen_size:
548 | big_points_vs = self.max_radii2D > max_screen_size
549 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
550 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
551 | self.prune_points(prune_mask)
552 |
553 | torch.cuda.empty_cache()
554 |
555 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
556 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
557 | self.denom[update_filter] += 1
558 |
559 |
--------------------------------------------------------------------------------