├── submodules ├── simple-knn │ └── README.md └── depth-diff-gaussian-rasterization │ └── README.md ├── assets └── method.png ├── requirements.txt ├── arguments ├── endonerf │ ├── pulling.py │ └── cutting.py └── __init__.py ├── utils ├── image_utils.py ├── system_utils.py ├── graphics_utils.py ├── preprocess.py ├── time_utils.py ├── pose_utils.py ├── loss_utils.py ├── rigid_utils.py ├── camera_utils.py ├── sh_utils.py ├── initial_utils.py └── general_utils.py ├── LICENSE ├── scene ├── deform_model.py ├── cameras.py ├── __init__.py ├── dataset_readers.py ├── colmap_loader.py └── gaussian_model.py ├── README.md ├── gaussian_renderer ├── network_gui.py └── __init__.py ├── metrics.py ├── train.py └── render.py /submodules/simple-knn/README.md: -------------------------------------------------------------------------------- 1 | # simple-knn 2 | 3 | -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwx0924/SurgicalGaussian/HEAD/assets/method.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | plyfile==0.8.1 2 | tqdm 3 | imageio==2.27.0 4 | opencv-python 5 | imageio-ffmpeg 6 | scipy 7 | dearpygui 8 | lpips 9 | -------------------------------------------------------------------------------- /arguments/endonerf/pulling.py: -------------------------------------------------------------------------------- 1 | ModelParams = dict( 2 | dataset_type='endonerf', 3 | depth_scale=100.0, 4 | frame_nums=63, 5 | test_id=[1, 9, 17, 25, 33, 41, 49, 57], 6 | is_mask=True, 7 | depth_initial=True, 8 | accurate_mask=True, 9 | is_depth=True, 10 | ) 11 | 12 | OptimizationParams = dict( 13 | iterations=40_000, 14 | ) -------------------------------------------------------------------------------- /arguments/endonerf/cutting.py: -------------------------------------------------------------------------------- 1 | ModelParams = dict( 2 | dataset_type='endonerf', 3 | depth_scale=100.0, 4 | frame_nums=156, 5 | test_id=[1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105, 113, 121, 129, 137, 145, 153], 6 | is_mask=True, 7 | depth_initial=True, 8 | accurate_mask=False, 9 | is_depth=True, 10 | ) 11 | OptimizationParams = dict( 12 | iterations=40_000, 13 | densify_grad_threshold=0.0003, 14 | lambda_cov=40, 15 | lambda_pos=0.2, 16 | ) 17 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | 15 | def mse(img1, img2): 16 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 17 | 18 | 19 | def psnr(img1, img2): 20 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 21 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 22 | 23 | -------------------------------------------------------------------------------- /submodules/depth-diff-gaussian-rasterization/README.md: -------------------------------------------------------------------------------- 1 | # Differential Gaussian Rasterization 2 | 3 | 4 |
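This is the depth-aware fork of the 3DGS differentiable rasterizer used as the rendering backend of this repository: it is installed with `pip install -e submodules/depth-diff-gaussian-rasterization` and imported as `diff_gaussian_rasterization` in `gaussian_renderer/__init__.py`.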
 5 |
 6 | ## BibTeX
 7 | ```bibtex
@Article{kerbl3Dgaussians,
 8 |       author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
 9 |       title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
10 |       journal      = {ACM Transactions on Graphics},
11 |       number       = {4},
12 |       volume       = {42},
13 |       month        = {July},
14 |       year         = {2023},
15 |       url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
16 | }
17 | ```
18 |
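## Usage in this repository

A minimal sketch of how `gaussian_renderer/__init__.py` drives the rasterizer. The camera- and Gaussian-derived tensors (`fovx`, `world_view_transform`, `means3D`, ...) are placeholders for the values the renderer takes from its `Camera` and `GaussianModel` inputs; note that, unlike the stock 3DGS rasterizer, the call returns a per-pixel depth map as a third output.

```python
import math
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer

# Per-camera rasterization settings (values normally come from scene.cameras.Camera).
raster_settings = GaussianRasterizationSettings(
    image_height=512, image_width=640,
    tanfovx=math.tan(0.5 * fovx), tanfovy=math.tan(0.5 * fovy),
    bg=torch.zeros(3, device="cuda"),     # background color
    scale_modifier=1.0,
    viewmatrix=world_view_transform,      # (4, 4) world-to-view transform on CUDA
    projmatrix=full_proj_transform,       # (4, 4) full projection transform on CUDA
    sh_degree=3,
    campos=camera_center,                 # (3,) camera center on CUDA
    prefiltered=False,
    debug=False,
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)

rendered_image, radii, depth = rasterizer(
    means3D=means3D,             # (N, 3) deformed Gaussian centers
    means2D=screenspace_points,  # (N, 3) zero tensor kept for screen-space gradients
    shs=shs,                     # SH coefficients, or None if colors_precomp is given
    colors_precomp=None,
    opacities=opacity,           # (N, 1)
    scales=scales,               # (N, 3)
    rotations=rotations,         # (N, 4) unit quaternions
    cov3D_precomp=None)
```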
19 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | 17 | def mkdir_p(folder_path): 18 | # Creates a directory. equivalent to using mkdir -p on the command line 19 | try: 20 | makedirs(folder_path) 21 | except OSError as exc: # Python >2.5 22 | if exc.errno == EEXIST and path.isdir(folder_path): 23 | pass 24 | else: 25 | raise 26 | 27 | 28 | def searchForMaxIteration(folder): 29 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 30 | return max(saved_iters) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 IGLICT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scene/deform_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.time_utils import DeformNetwork 5 | import os 6 | from utils.system_utils import searchForMaxIteration 7 | from utils.general_utils import get_expon_lr_func 8 | 9 | 10 | class DeformModel: 11 | def __init__(self): 12 | self.deform = DeformNetwork().cuda() 13 | self.optimizer = None 14 | 15 | self.spatial_lr_scale = 1 16 | 17 | def step(self, xyz, time_emb): 18 | return self.deform(xyz, time_emb) 19 | 20 | def train_setting(self, training_args): 21 | l = [ 22 | {'params': list(self.deform.parameters()), 23 | 'lr': training_args.position_lr_init * self.spatial_lr_scale, 24 | "name": "deform"} 25 | ] 26 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 27 | 28 | self.deform_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 29 | lr_final=training_args.position_lr_final, 30 | lr_delay_mult=training_args.position_lr_delay_mult, 31 | max_steps=training_args.deform_lr_max_steps) 32 | 33 | def save_weights(self, model_path, iteration): 34 | out_weights_path = os.path.join(model_path, "deform/iteration_{}".format(iteration)) 35 | os.makedirs(out_weights_path, exist_ok=True) 36 | torch.save(self.deform.state_dict(), os.path.join(out_weights_path, 'deform.pth')) 37 | 38 | def load_weights(self, model_path, iteration=-1): 39 | if iteration == -1: 40 | loaded_iter = searchForMaxIteration(os.path.join(model_path, "deform")) 41 | else: 42 | loaded_iter = iteration 43 | weights_path = os.path.join(model_path, "deform/iteration_{}/deform.pth".format(loaded_iter)) 44 | self.deform.load_state_dict(torch.load(weights_path)) 45 | 46 | def update_learning_rate(self, iteration): 47 | for param_group in self.optimizer.param_groups: 48 | if param_group["name"] == "deform": 49 | lr = self.deform_scheduler_args(iteration) 50 | param_group['lr'] = lr 51 | return lr 52 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | 18 | class BasicPointCloud(NamedTuple): 19 | points: np.array 20 | colors: np.array 21 | normals: np.array 22 | 23 | 24 | def geom_transform_points(points, transf_matrix): 25 | P, _ = points.shape 26 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 27 | points_hom = torch.cat([points, ones], dim=1) 28 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 29 | 30 | denom = points_out[..., 3:] + 0.0000001 31 | return (points_out[..., :3] / denom).squeeze(dim=0) 32 | 33 | 34 | def getWorld2View(R, t): 35 | Rt = np.zeros((4, 4)) 36 | Rt[:3, :3] = R.transpose() 37 | Rt[:3, 3] = t 38 | Rt[3, 3] = 1.0 39 | return np.float32(Rt) 40 | 41 | 42 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 43 | Rt = np.zeros((4, 4)) 44 | Rt[:3, :3] = R.transpose() 45 | Rt[:3, 3] = t 46 | Rt[3, 3] = 1.0 47 | 48 | C2W = np.linalg.inv(Rt) 49 | cam_center = C2W[:3, 3] 50 | cam_center = (cam_center + translate) * scale 51 | C2W[:3, 3] = cam_center 52 | Rt = np.linalg.inv(C2W) 53 | return np.float32(Rt) 54 | 55 | 56 | def getProjectionMatrix(znear, zfar, fovX, fovY): 57 | tanHalfFovY = math.tan((fovY / 2)) 58 | tanHalfFovX = math.tan((fovX / 2)) 59 | 60 | top = tanHalfFovY * znear 61 | bottom = -top 62 | right = tanHalfFovX * znear 63 | left = -right 64 | 65 | P = torch.zeros(4, 4) 66 | 67 | z_sign = 1.0 68 | 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | 79 | def fov2focal(fov, pixels): 80 | return pixels / (2 * math.tan(fov / 2)) 81 | 82 | 83 | def focal2fov(focal, pixels): 84 | return 2 * math.atan(pixels / (2 * focal)) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SurgicalGaussian: Deformable 3D Gaussians for High-Fidelity Surgical Scene Reconstruction 2 | 3 | ### [Project Page](https://surgicalgaussian.github.io/) 4 | 5 | 6 | ------------------------------------------- 7 | ![introduction](assets/method.png) 8 | 9 | ## Environment 10 | Please follow the [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting) and [4DGS](https://github.com/hustvl/4DGaussians) to install the relative packages. 11 | ```bash 12 | git clone https://github.com/xwx0924/SurgicalGaussian.git 13 | cd SurgicalGaussian 14 | 15 | conda create -n SurgicalGaussian python=3.7 16 | conda activate SurgicalGaussian 17 | 18 | # install pytorch and others. 19 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 20 | pip install -r requirements.txt 21 | # You also need to install the pytorch3d library to compute Gaussian neighborhoods. 22 | 23 | # You can follow 4DGS to download depth-diff-gaussian-rasterization and simple-knn. 24 | pip install -e submodules/depth-diff-gaussian-rasterization 25 | pip install -e submodules/simple-knn 26 | ``` 27 | 28 | ## Dataset 29 | **EndoNeRF Dataset:** 30 | The dataset from [EndoNeRF](https://github.com/med-air/EndoNeRF) is used, which can be downloaded from their website. 
We use the clips 'pulling_soft_tissues' and 'cutting_tissues_twice'. 31 | 32 | **StereoMIS Dataset:** 33 | The dataset provided in [StereoMIS](https://zenodo.org/records/7727692) is used. We use the clips 'p2-7' and 'p2-8'. The resulted file structure is as follows. 34 | ``` 35 | ├── data 36 | │ | EndoNeRF 37 | │ ├── pulling 38 | │ ├── cutting 39 | │ | StereoMIS 40 | │ ├── intestine 41 | │ ├── liver 42 | | ├── ... 43 | ``` 44 | 45 | 46 | ## Training 47 | For surgical scene `pulling_soft_tissues`, run 48 | ``` 49 | python train.py -s data/EndoNeRF/pulling -m output/pulling --config arguments/endonerf/pulling.py 50 | ``` 51 | 52 | ## Rendering 53 | Run the following script to render the images. 54 | 55 | ``` 56 | python render.py -m output/pulling 57 | ``` 58 | 59 | 60 | ## Evaluation 61 | Run the following script to evaluate the model. 62 | 63 | ``` 64 | python metrics.py -m output/pulling 65 | ``` 66 | 67 | --- 68 | ## Acknowledgement 69 | 70 | 71 | 72 | Some source code is borrowed from [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [4DGS](https://github.com/hustvl/4DGaussians), and [Deformable-3D-Gaussian](https://github.com/ingra14m/Deformable-3D-Gaussians/tree/main). Thanks for their excellent code. 73 | 74 | 75 | ## Citation 76 | If you find this work helpful, welcome to cite this paper. 77 | ``` 78 | @article{xie2024surgicalgaussian, 79 | author = {Xie, Weixing and Yao, Junfeng and Cao, Xianpeng and Lin, Qiqin and Tang, Zerui and Dong, Xiao and Guo, Xiaohu}, 80 | title = {SurgicalGaussian: Deformable 3D Gaussians for High-Fidelity Surgical Scene Reconstruction}, 81 | journal = {arXiv preprint arXiv:2407.05023}, 82 | year = {2024}, 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | 27 | def init(wish_host, wish_port): 28 | global host, port, listener 29 | host = wish_host 30 | port = wish_port 31 | listener.bind((host, port)) 32 | listener.listen() 33 | listener.settimeout(0) 34 | 35 | 36 | def try_connect(): 37 | global conn, addr, listener 38 | try: 39 | conn, addr = listener.accept() 40 | print(f"\nConnected by {addr}") 41 | conn.settimeout(None) 42 | except Exception as inst: 43 | pass 44 | 45 | 46 | def read(): 47 | global conn 48 | messageLength = conn.recv(4) 49 | messageLength = int.from_bytes(messageLength, 'little') 50 | message = conn.recv(messageLength) 51 | return json.loads(message.decode("utf-8")) 52 | 53 | 54 | def send(message_bytes, verify): 55 | global conn 56 | if message_bytes != None: 57 | conn.sendall(message_bytes) 58 | conn.sendall(len(verify).to_bytes(4, 'little')) 59 | conn.sendall(bytes(verify, 'ascii')) 60 | 61 | 62 | def receive(): 63 | message = read() 64 | 65 | width = message["resolution_x"] 66 | height = message["resolution_y"] 67 | 68 | if width != 0 and height != 0: 69 | try: 70 | do_training = bool(message["train"]) 71 | fovy = message["fov_y"] 72 | fovx = message["fov_x"] 73 | znear = message["z_near"] 74 | zfar = message["z_far"] 75 | do_shs_python = bool(message["shs_python"]) 76 | do_rot_scale_python = bool(message["rot_scale_python"]) 77 | keep_alive = bool(message["keep_alive"]) 78 | scaling_modifier = message["scaling_modifier"] 79 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 80 | world_view_transform[:, 1] = -world_view_transform[:, 1] 81 | world_view_transform[:, 2] = -world_view_transform[:, 2] 82 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 83 | full_proj_transform[:, 1] = -full_proj_transform[:, 1] 84 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 85 | except Exception as e: 86 | print("") 87 | traceback.print_exc() 88 | raise e 89 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 90 | else: 91 | return None, None, None, None, None, None 92 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | # @title Configure dataset directories 2 | import os 3 | from pathlib import Path 4 | 5 | # @markdown The base directory for all captures. This can be anything if you're running this notebook on your own Jupyter runtime. 6 | save_dir = '/data00/yzy/Git_Project/data/dynamic/mine/' # @param {type: 'string'} 7 | capture_name = 'lemon' # @param {type: 'string'} 8 | # The root directory for this capture. 9 | root_dir = Path(save_dir, capture_name) 10 | # Where to save RGB images. 11 | rgb_dir = root_dir / 'rgb' 12 | rgb_raw_dir = root_dir / 'rgb-raw' 13 | # Where to save the COLMAP outputs. 
14 | colmap_dir = root_dir / 'colmap' 15 | colmap_db_path = colmap_dir / 'database.db' 16 | colmap_out_path = colmap_dir / 'sparse' 17 | 18 | colmap_out_path.mkdir(exist_ok=True, parents=True) 19 | rgb_raw_dir.mkdir(exist_ok=True, parents=True) 20 | 21 | print(f"""Directories configured: 22 | root_dir = {root_dir} 23 | rgb_raw_dir = {rgb_raw_dir} 24 | rgb_dir = {rgb_dir} 25 | colmap_dir = {colmap_dir} 26 | """) 27 | 28 | # ==================== colmap ========================= 29 | # @title Extract features. 30 | # @markdown Computes SIFT features and saves them to the COLMAP DB. 31 | share_intrinsics = True # @param {type: 'boolean'} 32 | assume_upright_cameras = True # @param {type: 'boolean'} 33 | 34 | # @markdown This sets the scale at which we will run COLMAP. A scale of 1 will be more accurate but will be slow. 35 | colmap_image_scale = 4 # @param {type: 'number'} 36 | colmap_rgb_dir = rgb_dir / f'{colmap_image_scale}x' 37 | 38 | # @markdown Check this if you want to re-process SfM. 39 | overwrite = False # @param {type: 'boolean'} 40 | 41 | if overwrite and colmap_db_path.exists(): 42 | colmap_db_path.unlink() 43 | 44 | os.system('colmap feature_extractor \ 45 | --SiftExtraction.use_gpu 0 \ 46 | --SiftExtraction.upright {int(assume_upright_cameras)} \ 47 | --ImageReader.camera_model OPENCV \ 48 | --ImageReader.single_camera {int(share_intrinsics)} \ 49 | --database_path "{str(colmap_db_path)}" \ 50 | --image_path "{str(colmap_rgb_dir)}"') 51 | 52 | # @title Match features. 53 | # @markdown Match the SIFT features between images. Use `exhaustive` if you only have a few images and use `vocab_tree` if you have a lot of images. 54 | 55 | match_method = 'exhaustive' # @param ["exhaustive", "vocab_tree"] 56 | 57 | if match_method == 'exhaustive': 58 | os.system('colmap exhaustive_matcher \ 59 | --SiftMatching.use_gpu 0 \ 60 | --database_path "{str(colmap_db_path)}"') 61 | 62 | # @title Reconstruction. 63 | # @markdown Run structure-from-motion to compute camera parameters. 
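# Note: the command strings handed to os.system() in this script embed Python
# expressions such as {str(colmap_db_path)} (and, below, shell-style names such
# as $min_num_matches), but they are plain string literals, so the placeholders
# reach the shell unexpanded. A minimal corrected sketch, assuming the same
# variables, builds each command as an f-string, e.g. for the matcher above:
#
#   os.system(f'colmap exhaustive_matcher '
#             f'--SiftMatching.use_gpu 0 '
#             f'--database_path "{colmap_db_path}"')
#
# The mapper invocation below needs the same treatment.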
64 | 65 | refine_principal_point = True # @param {type:"boolean"} 66 | min_num_matches = 32 # @param {type: 'number'} 67 | filter_max_reproj_error = 2 # @param {type: 'number'} 68 | tri_complete_max_reproj_error = 2 # @param {type: 'number'} 69 | 70 | os.system('colmap mapper \ 71 | --Mapper.ba_refine_principal_point {int(refine_principal_point)} \ 72 | --Mapper.filter_max_reproj_error $filter_max_reproj_error \ 73 | --Mapper.tri_complete_max_reproj_error $tri_complete_max_reproj_error \ 74 | --Mapper.min_num_matches $min_num_matches \ 75 | --database_path "{str(colmap_db_path)}" \ 76 | --image_path "{str(colmap_rgb_dir)}" \ 77 | --export_path "{str(colmap_out_path)}"') 78 | 79 | print("debug") 80 | -------------------------------------------------------------------------------- /utils/time_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.rigid_utils import exp_se3 5 | 6 | 7 | def get_embedder(multires, i=1): 8 | if i == -1: 9 | return nn.Identity(), 3 10 | 11 | embed_kwargs = { 12 | 'include_input': True, 13 | 'input_dims': i, 14 | 'max_freq_log2': multires - 1, 15 | 'num_freqs': multires, 16 | 'log_sampling': True, 17 | 'periodic_fns': [torch.sin, torch.cos], 18 | } 19 | 20 | embedder_obj = Embedder(**embed_kwargs) 21 | embed = lambda x, eo=embedder_obj: eo.embed(x) 22 | return embed, embedder_obj.out_dim 23 | 24 | 25 | class Embedder: 26 | def __init__(self, **kwargs): 27 | self.kwargs = kwargs 28 | self.create_embedding_fn() 29 | 30 | def create_embedding_fn(self): 31 | embed_fns = [] 32 | d = self.kwargs['input_dims'] 33 | out_dim = 0 34 | if self.kwargs['include_input']: 35 | embed_fns.append(lambda x: x) 36 | out_dim += d 37 | 38 | max_freq = self.kwargs['max_freq_log2'] 39 | N_freqs = self.kwargs['num_freqs'] 40 | 41 | if self.kwargs['log_sampling']: 42 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 43 | else: 44 | freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) 45 | 46 | for freq in freq_bands: 47 | for p_fn in self.kwargs['periodic_fns']: 48 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 49 | out_dim += d 50 | 51 | self.embed_fns = embed_fns 52 | self.out_dim = out_dim 53 | 54 | def embed(self, inputs): 55 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 56 | 57 | 58 | class DeformNetwork(nn.Module): 59 | def __init__(self, D=8, W=256, input_ch=3, output_ch=59, multires=10): 60 | super(DeformNetwork, self).__init__() 61 | self.D = D 62 | self.W = W 63 | self.input_ch = input_ch 64 | self.output_ch = output_ch 65 | self.t_multires = 4 66 | self.skips = [D // 2] 67 | 68 | self.embed_time_fn, time_input_ch = get_embedder(self.t_multires, 1) 69 | self.embed_fn, xyz_input_ch = get_embedder(multires, 3) 70 | self.input_ch = xyz_input_ch + time_input_ch 71 | 72 | self.linear = nn.ModuleList( 73 | [nn.Linear(self.input_ch, W)] + [ 74 | nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) 75 | for i in range(D - 1)] 76 | ) 77 | 78 | self.gaussian_warp = nn.Linear(W, 3) 79 | self.gaussian_rotation = nn.Linear(W, 4) 80 | self.gaussian_scaling = nn.Linear(W, 3) 81 | 82 | def forward(self, x, t): 83 | t_emb = self.embed_time_fn(t) 84 | x_emb = self.embed_fn(x) 85 | h = torch.cat([x_emb, t_emb], dim=-1) 86 | for i, l in enumerate(self.linear): 87 | h = self.linear[i](h) 88 | h = F.relu(h) 89 | if i in self.skips: 90 | h = torch.cat([x_emb, t_emb, h], -1) 91 | 92 | d_xyz = self.gaussian_warp(h) 93 | scaling = self.gaussian_scaling(h) 94 | rotation = self.gaussian_rotation(h) 95 | 96 | return d_xyz, rotation, scaling 97 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.graphics_utils import fov2focal 4 | 5 | trans_t = lambda t: torch.Tensor([ 6 | [1, 0, 0, 0], 7 | [0, 1, 0, 0], 8 | [0, 0, 1, t], 9 | [0, 0, 0, 1]]).float() 10 | 11 | rot_phi = lambda phi: torch.Tensor([ 12 | [1, 0, 0, 0], 13 | [0, np.cos(phi), -np.sin(phi), 0], 14 | [0, np.sin(phi), np.cos(phi), 0], 15 | [0, 0, 0, 1]]).float() 16 | 17 | rot_theta = lambda th: torch.Tensor([ 18 | [np.cos(th), 0, -np.sin(th), 0], 19 | [0, 1, 0, 0], 20 | [np.sin(th), 0, np.cos(th), 0], 21 | [0, 0, 0, 1]]).float() 22 | 23 | 24 | def rodrigues_mat_to_rot(R): 25 | eps = 1e-16 26 | trc = np.trace(R) 27 | trc2 = (trc - 1.) / 2. 28 | # sinacostrc2 = np.sqrt(1 - trc2 * trc2) 29 | s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]]) 30 | if (1 - trc2 * trc2) >= eps: 31 | tHeta = np.arccos(trc2) 32 | tHetaf = tHeta / (2 * (np.sin(tHeta))) 33 | else: 34 | tHeta = np.real(np.arccos(trc2)) 35 | tHetaf = 0.5 / (1 - tHeta / 6) 36 | omega = tHetaf * s 37 | return omega 38 | 39 | 40 | def rodrigues_rot_to_mat(r): 41 | wx, wy, wz = r 42 | theta = np.sqrt(wx * wx + wy * wy + wz * wz) 43 | a = np.cos(theta) 44 | b = (1 - np.cos(theta)) / (theta * theta) 45 | c = np.sin(theta) / theta 46 | R = np.zeros([3, 3]) 47 | R[0, 0] = a + b * (wx * wx) 48 | R[0, 1] = b * wx * wy - c * wz 49 | R[0, 2] = b * wx * wz + c * wy 50 | R[1, 0] = b * wx * wy + c * wz 51 | R[1, 1] = a + b * (wy * wy) 52 | R[1, 2] = b * wy * wz - c * wx 53 | R[2, 0] = b * wx * wz - c * wy 54 | R[2, 1] = b * wz * wy + c * wx 55 | R[2, 2] = a + b * (wz * wz) 56 | return R 57 | 58 | 59 | def pose_spherical(theta, phi, radius): 60 | c2w = trans_t(radius) 61 | c2w = rot_phi(phi / 180. 
* np.pi) @ c2w 62 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 63 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 64 | return c2w 65 | 66 | 67 | def render_wander_path(view): 68 | focal_length = fov2focal(view.FoVy, view.image_height) 69 | R = view.R 70 | R[:, 1] = -R[:, 1] 71 | R[:, 2] = -R[:, 2] 72 | T = -view.T.reshape(-1, 1) 73 | pose = np.concatenate([R, T], -1) 74 | 75 | num_frames = 60 76 | max_disp = 5000.0 # 64 , 48 77 | 78 | max_trans = max_disp / focal_length # Maximum camera translation to satisfy max_disp parameter 79 | output_poses = [] 80 | 81 | for i in range(num_frames): 82 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 83 | y_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 # * 3.0 / 4.0 84 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 3.0 85 | 86 | i_pose = np.concatenate([ 87 | np.concatenate( 88 | [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1), 89 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :] 90 | ], axis=0) # [np.newaxis, :, :] 91 | 92 | i_pose = np.linalg.inv(i_pose) # torch.tensor(np.linalg.inv(i_pose)).float() 93 | 94 | ref_pose = np.concatenate([pose, np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0) 95 | 96 | render_pose = np.dot(ref_pose, i_pose) 97 | output_poses.append(torch.Tensor(render_pose)) 98 | 99 | return output_poses 100 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | from pytorch3d.ops.knn import knn_points 17 | 18 | 19 | def l1_loss(network_output, gt): 20 | return torch.abs((network_output - gt)).mean() 21 | 22 | 23 | def tv_loss(x): 24 | # K-Plane 25 | tv_h = torch.abs(x[:,1:,:] - x[:,:-1,:]).sum() 26 | tv_w = torch.abs(x[:,:,1:] - x[:,:,:-1]).sum() 27 | return (tv_h + tv_w) * 2 / x.numel() 28 | 29 | 30 | 31 | def l2_loss(network_output, gt): 32 | return ((network_output - gt) ** 2).mean() 33 | 34 | 35 | def gaussian(window_size, sigma): 36 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 37 | return gauss / gauss.sum() 38 | 39 | 40 | def create_window(window_size, channel): 41 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 42 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 43 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 44 | return window 45 | 46 | 47 | def ssim(img1, img2, window_size=11, size_average=True): 48 | channel = img1.size(-3) 49 | window = create_window(window_size, channel) 50 | 51 | if img1.is_cuda: 52 | window = window.cuda(img1.get_device()) 53 | window = window.type_as(img1) 54 | 55 | return _ssim(img1, img2, window, window_size, channel, size_average) 56 | 57 | 58 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 59 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 60 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 61 | 62 | mu1_sq = mu1.pow(2) 63 | mu2_sq = mu2.pow(2) 64 | mu1_mu2 = mu1 * mu2 65 | 66 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 67 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 68 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 69 | 70 | C1 = 0.01 ** 2 71 | C2 = 0.03 ** 2 72 | 73 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 74 | 75 | if size_average: 76 | return ssim_map.mean() 77 | else: 78 | return ssim_map.mean(1).mean(1).mean(1) 79 | 80 | 81 | def def_reg_loss(gs_can, d_xyz, d_rotation, d_scaling, K=5): 82 | xyz_can = gs_can.get_xyz 83 | xyz_obs = xyz_can + d_xyz 84 | 85 | cov_can = gs_can.get_covariance() 86 | cov_obs = gs_can.get_covariance_obs(d_rotation, d_scaling) 87 | 88 | _, nn_ix, _ = knn_points(xyz_can.unsqueeze(0), xyz_can.unsqueeze(0), K=K, return_sorted=True) 89 | nn_ix = nn_ix.squeeze(0) 90 | 91 | dis_xyz_can = torch.cdist(xyz_can.unsqueeze(1), xyz_can[nn_ix])[:, 0, 1:] 92 | dis_xyz_obs = torch.cdist(xyz_obs.unsqueeze(1), xyz_obs[nn_ix])[:, 0, 1:] 93 | loss_pos = F.l1_loss(dis_xyz_can, dis_xyz_obs) 94 | 95 | dis_cov_can = torch.cdist(cov_can.unsqueeze(1), cov_can[nn_ix])[:, 0, 1:] 96 | dis_cov_obs = torch.cdist(cov_obs.unsqueeze(1), cov_obs[nn_ix])[:, 0, 1:] 97 | loss_cov = F.l1_loss(dis_cov_can, dis_cov_obs) 98 | 99 | return loss_pos, loss_cov 100 | -------------------------------------------------------------------------------- /utils/rigid_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def skew(w: torch.Tensor) -> torch.Tensor: 5 | """Build a skew matrix ("cross product matrix") for vector w. 
6 | 7 | Modern Robotics Eqn 3.30. 8 | 9 | Args: 10 | w: (N, 3) A 3-vector 11 | 12 | Returns: 13 | W: (N, 3, 3) A skew matrix such that W @ v == w x v 14 | """ 15 | zeros = torch.zeros(w.shape[0], device=w.device) 16 | w_skew_list = [zeros, -w[:, 2], w[:, 1], 17 | w[:, 2], zeros, -w[:, 0], 18 | -w[:, 1], w[:, 0], zeros] 19 | w_skew = torch.stack(w_skew_list, dim=-1).reshape(-1, 3, 3) 20 | return w_skew 21 | 22 | 23 | def rp_to_se3(R: torch.Tensor, p: torch.Tensor) -> torch.Tensor: 24 | """Rotation and translation to homogeneous transform. 25 | 26 | Args: 27 | R: (3, 3) An orthonormal rotation matrix. 28 | p: (3,) A 3-vector representing an offset. 29 | 30 | Returns: 31 | X: (4, 4) The homogeneous transformation matrix described by rotating by R 32 | and translating by p. 33 | """ 34 | bottom_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=R.device).repeat(R.shape[0], 1, 1) 35 | transform = torch.cat([torch.cat([R, p], dim=-1), bottom_row], dim=1) 36 | 37 | return transform 38 | 39 | 40 | def exp_so3(w: torch.Tensor, theta: float) -> torch.Tensor: 41 | """Exponential map from Lie algebra so3 to Lie group SO3. 42 | 43 | Modern Robotics Eqn 3.51, a.k.a. Rodrigues' formula. 44 | 45 | Args: 46 | w: (3,) An axis of rotation. 47 | theta: An angle of rotation. 48 | 49 | Returns: 50 | R: (3, 3) An orthonormal rotation matrix representing a rotation of 51 | magnitude theta about axis w. 52 | """ 53 | W = skew(w) 54 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 55 | W_sqr = torch.bmm(W, W) # batch matrix multiplication 56 | R = identity + torch.sin(theta.unsqueeze(-1)) * W + (1.0 - torch.cos(theta.unsqueeze(-1))) * W_sqr 57 | return R 58 | 59 | 60 | def exp_se3(S: torch.Tensor, theta: float) -> torch.Tensor: 61 | """Exponential map from Lie algebra so3 to Lie group SO3. 62 | 63 | Modern Robotics Eqn 3.88. 64 | 65 | Args: 66 | S: (6,) A screw axis of motion. 67 | theta: Magnitude of motion. 68 | 69 | Returns: 70 | a_X_b: (4, 4) The homogeneous transformation matrix attained by integrating 71 | motion of magnitude theta about S for one second. 72 | """ 73 | w, v = torch.split(S, 3, dim=-1) 74 | W = skew(w) 75 | R = exp_so3(w, theta) 76 | 77 | identity = torch.eye(3).unsqueeze(0).repeat(W.shape[0], 1, 1).to(W.device) 78 | W_sqr = torch.bmm(W, W) 79 | theta = theta.view(-1, 1, 1) 80 | 81 | p = torch.bmm((theta * identity + (1.0 - torch.cos(theta)) * W + (theta - torch.sin(theta)) * W_sqr), 82 | v.unsqueeze(-1)) 83 | return rp_to_se3(R, p) 84 | 85 | 86 | def to_homogenous(v: torch.Tensor) -> torch.Tensor: 87 | """Converts a vector to a homogeneous coordinate vector by appending a 1. 88 | 89 | Args: 90 | v: A tensor representing a vector or batch of vectors. 91 | 92 | Returns: 93 | A tensor with an additional dimension set to 1. 94 | """ 95 | return torch.cat([v, torch.ones_like(v[..., :1])], dim=-1) 96 | 97 | 98 | def from_homogenous(v: torch.Tensor) -> torch.Tensor: 99 | """Converts a homogeneous coordinate vector to a standard vector by dividing by the last element. 100 | 101 | Args: 102 | v: A tensor representing a homogeneous coordinate vector or batch of homogeneous coordinate vectors. 103 | 104 | Returns: 105 | A tensor with the last dimension removed. 
106 | """ 107 | return v[..., :3] / v[..., -1:] 108 | -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | 18 | class Camera(nn.Module): 19 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", fid=None, depth=None, mask_depth=None, mask=None): 21 | super(Camera, self).__init__() 22 | 23 | self.uid = uid 24 | self.colmap_id = colmap_id 25 | self.R = R 26 | self.T = T 27 | self.FoVx = FoVx 28 | self.FoVy = FoVy 29 | self.image_name = image_name 30 | 31 | try: 32 | self.data_device = torch.device(data_device) 33 | except Exception as e: 34 | print(e) 35 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device") 36 | self.data_device = torch.device("cuda") 37 | 38 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 39 | self.fid = torch.Tensor(np.array([fid])).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | 43 | self.depth = torch.Tensor(depth).to(self.data_device) if depth is not None else None 44 | self.mask_depth = torch.Tensor(mask_depth).to(self.data_device) if depth is not None else None 45 | self.mask = torch.Tensor(mask).to(self.data_device) if mask is not None else None 46 | 47 | if gt_alpha_mask is not None: 48 | self.original_image *= gt_alpha_mask.to(self.data_device) 49 | else: 50 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 51 | 52 | self.zfar = 100.0 53 | self.znear = 0.01 54 | 55 | self.trans = trans 56 | self.scale = scale 57 | 58 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).to( 59 | self.data_device) 60 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, 61 | fovY=self.FoVy).transpose(0, 1).to(self.data_device) 62 | self.full_proj_transform = ( 63 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 64 | self.camera_center = self.world_view_transform.inverse()[3, :3] 65 | 66 | def reset_extrinsic(self, R, T): 67 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, self.trans, self.scale)).transpose(0, 1).cuda() 68 | self.full_proj_transform = ( 69 | self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 70 | self.camera_center = self.world_view_transform.inverse()[3, :3] 71 | 72 | def load2device(self, data_device='cuda'): 73 | self.original_image = self.original_image.to(data_device) 74 | self.world_view_transform = self.world_view_transform.to(data_device) 75 | self.projection_matrix = self.projection_matrix.to(data_device) 76 | self.full_proj_transform = self.full_proj_transform.to(data_device) 77 | self.camera_center = self.camera_center.to(data_device) 78 | self.fid = self.fid.to(data_device) 79 | 80 | 81 
| class MiniCam: 82 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 83 | self.image_width = width 84 | self.image_height = height 85 | self.FoVy = fovy 86 | self.FoVx = fovx 87 | self.znear = znear 88 | self.zfar = zfar 89 | self.world_view_transform = world_view_transform 90 | self.full_proj_transform = full_proj_transform 91 | view_inv = torch.inverse(self.world_view_transform) 92 | self.camera_center = view_inv[3][:3] 93 | -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch, ArrayToTorch 15 | from utils.graphics_utils import fov2focal 16 | import json 17 | 18 | WARNED = False 19 | 20 | 21 | def loadCam(args, id, cam_info, resolution_scale): 22 | orig_w, orig_h = cam_info.image.size 23 | 24 | if args.resolution in [1, 2, 4, 8]: 25 | resolution = round(orig_w / (resolution_scale * args.resolution)), round( 26 | orig_h / (resolution_scale * args.resolution)) 27 | else: # should be a type that converts to float 28 | if args.resolution == -1: 29 | if orig_w > 1600: 30 | global WARNED 31 | if not WARNED: 32 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 33 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 34 | WARNED = True 35 | global_down = orig_w / 1600 36 | else: 37 | global_down = 1 38 | else: 39 | global_down = orig_w / args.resolution 40 | 41 | scale = float(global_down) * float(resolution_scale) 42 | resolution = (int(orig_w / scale), int(orig_h / scale)) 43 | 44 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) # 这里把图像转换到0-1,同时[H W 3] 调整为 [3 H W]? 45 | 46 | gt_image = resized_image_rgb[:3, ...] 47 | loaded_mask = None 48 | 49 | if resized_image_rgb.shape[1] == 4: 50 | loaded_mask = resized_image_rgb[3:4, ...] 
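    # Note: only the RGB image is resized above; the depth and mask tensors are
    # forwarded to Camera at their original resolution, so non-unit resolution
    # scales are effectively unsupported for them.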
51 | # 这里没有对depth和mask同步缩放,不支持调整分辨率。*Camera()返回class Camera(nn.Module): 52 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 53 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 54 | image=gt_image, gt_alpha_mask=loaded_mask, 55 | image_name=cam_info.image_name, uid=id, 56 | data_device=args.data_device if not args.load2gpu_on_the_fly else 'cpu', fid=cam_info.fid, 57 | depth=cam_info.depth, mask_depth=cam_info.mask_depth, mask=cam_info.mask) 58 | 59 | 60 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 61 | camera_list = [] 62 | 63 | for id, c in enumerate(cam_infos): 64 | camera_list.append(loadCam(args, id, c, resolution_scale)) 65 | 66 | return camera_list 67 | 68 | 69 | def camera_to_JSON(id, camera: Camera): 70 | Rt = np.zeros((4, 4)) 71 | Rt[:3, :3] = camera.R.transpose() 72 | Rt[:3, 3] = camera.T 73 | Rt[3, 3] = 1.0 74 | 75 | W2C = np.linalg.inv(Rt) 76 | pos = W2C[:3, 3] 77 | rot = W2C[:3, :3] 78 | serializable_array_2d = [x.tolist() for x in rot] 79 | camera_entry = { 80 | 'id': id, 81 | 'img_name': camera.image_name, 82 | 'width': camera.width, 83 | 'height': camera.height, 84 | 'position': pos.tolist(), 85 | 'rotation': serializable_array_2d, 86 | 'fy': fov2focal(camera.FovY, camera.height), 87 | 'fx': fov2focal(camera.FovX, camera.width) 88 | } 89 | return camera_entry 90 | 91 | 92 | def camera_nerfies_from_JSON(path, scale): 93 | """Loads a JSON camera into memory.""" 94 | with open(path, 'r') as fp: 95 | camera_json = json.load(fp) 96 | 97 | # Fix old camera JSON. 98 | if 'tangential' in camera_json: 99 | camera_json['tangential_distortion'] = camera_json['tangential'] 100 | 101 | return dict( 102 | orientation=np.array(camera_json['orientation']), 103 | position=np.array(camera_json['position']), 104 | focal_length=camera_json['focal_length'] * scale, 105 | principal_point=np.array(camera_json['principal_point']) * scale, 106 | skew=camera_json['skew'], 107 | pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], 108 | radial_distortion=np.array(camera_json['radial_distortion']), 109 | tangential_distortion=np.array(camera_json['tangential_distortion']), 110 | image_size=np.array((int(round(camera_json['image_size'][0] * scale)), 111 | int(round(camera_json['image_size'][1] * scale)))), 112 | ) 113 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.deform_model import DeformModel 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 21 | 22 | 23 | class Scene: 24 | gaussians: GaussianModel 25 | 26 | def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, shuffle=True, 27 | resolution_scales=[1.0]): 28 | """b 29 | :param path: Path to colmap scene main folder. 
30 | """ 31 | self.model_path = args.model_path 32 | self.loaded_iter = None 33 | self.gaussians = gaussians 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | 45 | if os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")): 46 | print("Found calibration_full.json, assuming Endonerf data set!") 47 | scene_info = sceneLoadTypeCallbacks["Endonerf"](args.source_path, args.dataset_type, args.eval, args.is_depth, args.depth_scale, args.is_mask, 48 | args.depth_initial, args.frame_nums, args.test_id) 49 | else: 50 | assert False, "Could not recognize scene type!" 51 | 52 | if not self.loaded_iter: # scene_info里读取对应的数据 53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 54 | 'wb') as dest_file: # input.ply 55 | dest_file.write(src_file.read()) 56 | json_cams = [] 57 | camlist = [] 58 | if scene_info.test_cameras: 59 | camlist.extend(scene_info.test_cameras) 60 | if scene_info.train_cameras: 61 | camlist.extend(scene_info.train_cameras) 62 | for id, cam in enumerate(camlist): 63 | json_cams.append(camera_to_JSON(id, cam)) 64 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: # cameras.json 65 | json.dump(json_cams, file) 66 | 67 | if shuffle: # 随机打乱 68 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 69 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 70 | 71 | self.cameras_extent = scene_info.nerf_normalization["radius"] # 场景大小值 72 | 73 | for resolution_scale in resolution_scales: 74 | print("Loading Training Cameras") 75 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, 76 | args) 77 | print("Loading Test Cameras") 78 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, 79 | args) 80 | 81 | if self.loaded_iter: # 迭代一定次数后会生成新的高斯,会从新的初始高斯加载 82 | self.gaussians.load_ply(os.path.join(self.model_path, 83 | "point_cloud", 84 | "iteration_" + str(self.loaded_iter), 85 | "point_cloud.ply"), 86 | og_number_points=len(scene_info.point_cloud.points)) 87 | else: 88 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) # 场景大小值传入 89 | 90 | def save(self, iteration): 91 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 92 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"), iteration) 93 | 94 | def getTrainCameras(self, scale=1.0): 95 | return self.train_cameras[scale] 96 | 97 | def getTestCameras(self, scale=1.0): 98 | return self.test_cameras[scale] 99 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from utils.rigid_utils import from_homogenous, to_homogenous 18 | 19 | 20 | def quaternion_multiply(q1, q2): 21 | w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] 22 | w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] 23 | 24 | w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 25 | x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 26 | y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 27 | z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 28 | 29 | return torch.stack((w, x, y, z), dim=-1) 30 | 31 | 32 | def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, d_xyz, d_rotation, d_scaling, 33 | scaling_modifier=1.0, override_color=None): 34 | """ 35 | Render the scene. 36 | 37 | Background tensor (bg_color) must be on GPU! 38 | """ 39 | 40 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 41 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 42 | try: 43 | screenspace_points.retain_grad() 44 | except: 45 | pass 46 | 47 | # Set up rasterization configuration 48 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 49 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 50 | 51 | raster_settings = GaussianRasterizationSettings( 52 | image_height=int(viewpoint_camera.image_height), 53 | image_width=int(viewpoint_camera.image_width), 54 | tanfovx=tanfovx, 55 | tanfovy=tanfovy, 56 | bg=bg_color, 57 | scale_modifier=scaling_modifier, 58 | viewmatrix=viewpoint_camera.world_view_transform, 59 | projmatrix=viewpoint_camera.full_proj_transform, 60 | sh_degree=pc.active_sh_degree, 61 | campos=viewpoint_camera.camera_center, 62 | prefiltered=False, 63 | debug=pipe.debug, 64 | ) 65 | 66 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 67 | 68 | means3D = pc.get_xyz + d_xyz 69 | means2D = screenspace_points 70 | opacity = pc.get_opacity 71 | 72 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 73 | # scaling / rotation by the rasterizer. 74 | scales = None 75 | rotations = None 76 | cov3D_precomp = None 77 | if pipe.compute_cov3D_python: 78 | cov3D_precomp = pc.get_covariance(scaling_modifier) 79 | else: 80 | scales = pc.get_scaling + d_scaling 81 | rotations = pc.get_rotation + d_rotation 82 | 83 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 84 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 85 | shs = None 86 | colors_precomp = None 87 | if colors_precomp is None: 88 | if pipe.convert_SHs_python: 89 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2) 90 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 91 | dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) 92 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 93 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 94 | else: 95 | shs = pc.get_features 96 | else: 97 | colors_precomp = override_color 98 | 99 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
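    # The depth-diff rasterizer also returns a per-pixel depth map here, which is
    # passed through in the output dictionary below.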
100 | rendered_image, radii, depth = rasterizer( 101 | means3D=means3D, 102 | means2D=means2D, 103 | shs=shs, 104 | colors_precomp=colors_precomp, 105 | opacities=opacity, 106 | scales=scales, 107 | rotations=rotations, 108 | cov3D_precomp=cov3D_precomp) 109 | 110 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 111 | # They will be excluded from value updates used in the splitting criteria. 112 | return {"render": rendered_image, 113 | "viewspace_points": screenspace_points, 114 | "visibility_filter": radii > 0, 115 | "radii": radii, 116 | "depth": depth, 117 | } 118 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | 17 | class GroupParams: 18 | pass 19 | 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name: str, fill_none=False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | else: 35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | 50 | class ModelParams(ParamGroup): 51 | def __init__(self, parser, sentinel=False): 52 | self.sh_degree = 3 53 | self._source_path = "" 54 | self._model_path = "" 55 | self._images = "images" 56 | self._resolution = -1 57 | self._white_background = True 58 | self.data_device = "cuda" 59 | self.eval = True 60 | self.load2gpu_on_the_fly = False 61 | self.dataset_type = None 62 | self.frame_nums = None 63 | self.test_id = None 64 | 65 | self.depth_scale = 100 66 | self.is_depth = False 67 | self.is_mask = False 68 | self.depth_initial = False 69 | self.accurate_mask = False 70 | 71 | super().__init__(parser, "Loading Parameters", sentinel) 72 | 73 | def extract(self, args): 74 | g = super().extract(args) 75 | g.source_path = os.path.abspath(g.source_path) 76 | return g 77 | 78 | 79 | class PipelineParams(ParamGroup): 80 | def __init__(self, parser): 81 | self.convert_SHs_python = False 82 | self.compute_cov3D_python = False 83 | self.debug = False 84 | super().__init__(parser, "Pipeline Parameters") 85 | 86 | 87 | class OptimizationParams(ParamGroup): 88 | def __init__(self, parser): 89 | self.iterations = 40_000 90 | self.position_lr_init = 0.00016 91 | self.position_lr_final = 0.0000016 92 | self.position_lr_delay_mult = 0.01 93 | self.position_lr_max_steps = 30000 94 | self.deform_lr_max_steps = 40_000 95 | 
self.feature_lr = 0.0025 96 | self.opacity_lr = 0.05 97 | self.scaling_lr = 0.001 98 | self.rotation_lr = 0.001 99 | self.percent_dense = 0.01 100 | self.lambda_dssim = 0.2 101 | self.lambda_smooth = 0.02 102 | self.lambda_cov = 0.0 103 | self.lambda_pos = 0.0 104 | self.densification_interval = 100 105 | self.opacity_reset_interval = 3000 106 | self.densify_from_iter = 500 107 | self.densify_until_iter = 15_000 108 | self.densify_grad_threshold = 0.0002 109 | 110 | super().__init__(parser, "Optimization Parameters") 111 | 112 | 113 | def get_combined_args(parser: ArgumentParser): 114 | cmdlne_string = sys.argv[1:] 115 | cfgfile_string = "Namespace()" 116 | args_cmdline = parser.parse_args(cmdlne_string) 117 | 118 | try: 119 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 120 | print("Looking for config file in", cfgfilepath) 121 | with open(cfgfilepath) as cfg_file: 122 | print("Config file found: {}".format(cfgfilepath)) 123 | cfgfile_string = cfg_file.read() 124 | except TypeError: 125 | print("Config file not found at") 126 | pass 127 | args_cfgfile = eval(cfgfile_string) 128 | 129 | merged_dict = vars(args_cfgfile).copy() 130 | for k, v in vars(args_cmdline).items(): 131 | if v != None: 132 | merged_dict[k] = v 133 | return Namespace(**merged_dict) 134 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | 115 | def RGB2SH(rgb): 116 | return (rgb - 0.5) / C0 117 | 118 | 119 | def SH2RGB(sh): 120 | return sh * C0 + 0.5 121 | -------------------------------------------------------------------------------- /utils/initial_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | from glob import glob 5 | import cv2 as cv 6 | 7 | 8 | def imread(f): 9 | if f.endswith('png'): 10 | return imageio.imread(f, ignoregamma=True) 11 | else: 12 | return imageio.imread(f) 13 | 14 | def get_all_initial_data_endo(path, data_type, depth_scale, is_mask, npy_file): 15 | poses_bounds = np.load(os.path.join(path, npy_file)) 16 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) 17 | H, W, focal = poses[0, :, -1] 18 | H = H.astype(int) 19 | W = W.astype(int) 20 | 21 | video_path = sorted(glob(os.path.join(path, 'images/*'))) 22 | if is_mask: 23 
| masks_path = sorted(glob(os.path.join(path, 'gt_masks/*'))) 24 | GT_masks_path = os.path.join(path, "gt_masks") 25 | 26 | inpaint_mask_all = np.zeros((512, 640)) 27 | 28 | for i in range(0, len(masks_path)): 29 | img_name = masks_path[i] 30 | f_mask = os.path.join(GT_masks_path, img_name) 31 | m = 1.0 - np.array(imread(f_mask) / 255.0) 32 | inpaint_mask_all = inpaint_mask_all + m 33 | inpaint_mask_all[inpaint_mask_all >= 1] = 1 34 | 35 | inpaint_mask_all = (1.0 - inpaint_mask_all) * 255.0 36 | inpaint_mask_all = inpaint_mask_all.astype(np.uint8) 37 | 38 | fn = os.path.join(path, f"invisible_mask.png") 39 | imageio.imwrite(fn, inpaint_mask_all) 40 | kernel = np.ones((5, 5), np.uint8) 41 | 42 | dilated_mask = cv.dilate(inpaint_mask_all, kernel, iterations=2) 43 | if data_type == 'endonerf': 44 | dilated_mask[-12:, :] = 255 45 | fn = os.path.join(path, f"dilated_invisible_mask.png") 46 | imageio.imwrite(fn, dilated_mask) 47 | 48 | depths_path = sorted(glob(os.path.join(path, 'depth/*'))) 49 | 50 | print(f"Using the all depth map as the initial point cloud") 51 | print(f"Using the depth_scale:{depth_scale} scale the depth map") 52 | 53 | depth_all = np.zeros((H, W)) 54 | color_all = np.zeros((3, H, W)) 55 | mask_all = np.zeros((H, W)) 56 | inv_mask_all = np.ones((H, W)) 57 | for i in range(poses_bounds.shape[0]): 58 | images_path = video_path[i] 59 | image = np.array(imread(images_path) * 1.0) 60 | color = np.transpose(image, (2, 0, 1)) # (H, W, C) -> (C, H, W) 61 | 62 | mask_image = None 63 | if is_mask: 64 | mask_path = masks_path[i] 65 | mask_image = np.array(imread(mask_path) / 255.0) 66 | # mask is 0 or 1 67 | mask_image = np.where(mask_image > 0.5, 1.0, 0.0) 68 | # Convert 0 for tool, 1 for not tool 69 | mask_image = 1.0 - mask_image 70 | 71 | color_mask = np.expand_dims(mask_image, axis=0) 72 | color = color * color_mask 73 | 74 | depth_path = depths_path[i] 75 | depth_image = np.array(imread(depth_path) * 1.0) 76 | depth_image = depth_image / depth_scale 77 | near = np.percentile(depth_image, 3) 78 | far = np.percentile(depth_image, 98) 79 | mask_depth = np.bitwise_and(depth_image > near, depth_image < far) 80 | mask_depth = mask_depth * mask_image 81 | depth = depth_image * mask_depth 82 | mask_plus = mask_depth * inv_mask_all 83 | color_all_mask = np.expand_dims(mask_plus, axis=0) 84 | depth_all = depth_all + depth * mask_plus 85 | color_all = color_all + color * color_all_mask 86 | 87 | mask_all = mask_all + mask_depth 88 | mask_all[mask_all >= 1] = 1 89 | inv_mask_all = 1.0 - mask_all 90 | 91 | depth_all = np.expand_dims(depth_all, axis=0) 92 | 93 | intrinsics = np.zeros((3, 3)) 94 | intrinsics[0, 0] = focal 95 | intrinsics[1, 1] = focal 96 | intrinsics[0, 2] = W / 2.0 # CX: W/2 97 | intrinsics[1, 2] = H / 2.0 # CY: H/2 98 | intrinsics[2, 2] = 1.0 99 | 100 | 101 | return color_all, depth_all, intrinsics, mask_all 102 | 103 | def get_pointcloud(color, depth, intrinsics, mask, w2c=None, transform_pts=False): 104 | width, height = color.shape[2], color.shape[1] 105 | CX = intrinsics[0][2] 106 | CY = intrinsics[1][2] 107 | FX = intrinsics[0][0] 108 | FY = intrinsics[1][1] 109 | 110 | # Compute indices of pixels 111 | x_grid, y_grid = np.meshgrid(np.arange(width).astype(np.float32), 112 | np.arange(height).astype(np.float32), 113 | indexing='xy') 114 | xx = (x_grid - CX) / FX 115 | yy = (y_grid - CY) / FY 116 | xx = xx.reshape(-1) # 117 | yy = yy.reshape(-1) # 118 | depth_z = depth[0].reshape(-1) # 119 | 120 | # Initialize point cloud 121 | pts_cam = np.stack((xx * depth_z, yy * 
depth_z, depth_z), axis=-1) 122 | 123 | if transform_pts: 124 | pix_ones = np.ones(height * width, 1).astype(np.float32) 125 | pts4 = np.concatenate((pts_cam, pix_ones), axis=1) 126 | c2w = np.linalg.inv(w2c) 127 | pts = np.dot(pts4, c2w.T)[:, :3] 128 | else: 129 | pts = pts_cam 130 | 131 | # Colorize point cloud 132 | cols = np.transpose(color, (1, 2, 0)).reshape(-1, 3) # (C, H, W) -> (H, W, C) -> (H * W, C) 133 | mask_sample = sample_pts(height, width, 3) 134 | mask_sample = (mask_sample != 0) 135 | mask_sample = mask_sample.reshape(-1) 136 | 137 | mask = mask.reshape(-1) 138 | mask = (mask != 0) 139 | pts = pts[mask & mask_sample] 140 | 141 | cols = cols[mask & mask_sample] 142 | print(f"Using the {pts.shape[0]} points as initial") 143 | 144 | return pts, cols 145 | 146 | def sample_pts(height, width, factor=2): 147 | mask_sample_h = np.zeros((height, width)).astype(np.int) 148 | mask_sample_w = np.zeros((height, width)).astype(np.int) 149 | mask_sample_h[:, 1::factor] = 1 150 | mask_sample_w[1::factor, :] = 1 151 | mask_sample = mask_sample_h & mask_sample_w 152 | 153 | return mask_sample 154 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | 19 | def inverse_sigmoid(x): 20 | return torch.log(x / (1 - x)) 21 | 22 | def np_inverse_sigmoid(x): 23 | return np.log(x / (1 - x)) 24 | 25 | def PILtoTorch(pil_image, resolution): 26 | resized_image_PIL = pil_image.resize(resolution) 27 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 28 | if len(resized_image.shape) == 3: 29 | return resized_image.permute(2, 0, 1) 30 | else: 31 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 32 | 33 | 34 | def ArrayToTorch(array, resolution): 35 | # resized_image = np.resize(array, resolution) 36 | resized_image_torch = torch.from_numpy(array) 37 | 38 | if len(resized_image_torch.shape) == 3: 39 | return resized_image_torch.permute(2, 0, 1) 40 | else: 41 | return resized_image_torch.unsqueeze(dim=-1).permute(2, 0, 1) 42 | 43 | 44 | def get_expon_lr_func( 45 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 46 | ): 47 | """ 48 | Copied from Plenoxels 49 | 50 | Continuous learning rate decay function. Adapted from JaxNeRF 51 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 52 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 53 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 54 | function of lr_delay_mult, such that the initial learning rate is 55 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 56 | to the normal learning rate when steps>lr_delay_steps. 57 | :param conf: config subtree 'lr' or similar 58 | :param max_steps: int, the number of steps during optimization. 
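As this docstring notes, the rate is log-linearly interpolated between lr_init and lr_final (ignoring the optional delay ramp). A tiny numeric check of that formula, illustrative only, using the position learning-rate defaults listed in OptimizationParams above:

import numpy as np

lr_init, lr_final, max_steps = 1.6e-4, 1.6e-6, 30_000
for step in (0, 15_000, 30_000):
    t = np.clip(step / max_steps, 0, 1)
    lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
    print(step, lr)   # 1.6e-4 at step 0, 1.6e-5 at the midpoint, 1.6e-6 at the end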
59 | :return HoF which takes step as input 60 | """ 61 | 62 | def helper(step): 63 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 64 | # Disable this parameter 65 | return 0.0 66 | if lr_delay_steps > 0: 67 | # A kind of reverse cosine decay. 68 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 69 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 70 | ) 71 | else: 72 | delay_rate = 1.0 73 | t = np.clip(step / max_steps, 0, 1) 74 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 75 | return delay_rate * log_lerp 76 | 77 | return helper 78 | 79 | 80 | def get_linear_noise_func( 81 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 82 | ): 83 | """ 84 | Copied from Plenoxels 85 | 86 | Continuous learning rate decay function. Adapted from JaxNeRF 87 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 88 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 89 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 90 | function of lr_delay_mult, such that the initial learning rate is 91 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 92 | to the normal learning rate when steps>lr_delay_steps. 93 | :param conf: config subtree 'lr' or similar 94 | :param max_steps: int, the number of steps during optimization. 95 | :return HoF which takes step as input 96 | """ 97 | 98 | def helper(step): 99 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 100 | # Disable this parameter 101 | return 0.0 102 | if lr_delay_steps > 0: 103 | # A kind of reverse cosine decay. 104 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 105 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 106 | ) 107 | else: 108 | delay_rate = 1.0 109 | t = np.clip(step / max_steps, 0, 1) 110 | log_lerp = lr_init * (1 - t) + lr_final * t 111 | return delay_rate * log_lerp 112 | 113 | return helper 114 | 115 | 116 | def strip_lowerdiag(L): 117 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 118 | 119 | uncertainty[:, 0] = L[:, 0, 0] 120 | uncertainty[:, 1] = L[:, 0, 1] 121 | uncertainty[:, 2] = L[:, 0, 2] 122 | uncertainty[:, 3] = L[:, 1, 1] 123 | uncertainty[:, 4] = L[:, 1, 2] 124 | uncertainty[:, 5] = L[:, 2, 2] 125 | return uncertainty 126 | 127 | 128 | def strip_symmetric(sym): 129 | return strip_lowerdiag(sym) 130 | 131 | 132 | def build_rotation(r): 133 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 134 | 135 | q = r / norm[:, None] 136 | 137 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 138 | 139 | r = q[:, 0] 140 | x = q[:, 1] 141 | y = q[:, 2] 142 | z = q[:, 3] 143 | 144 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 145 | R[:, 0, 1] = 2 * (x * y - r * z) 146 | R[:, 0, 2] = 2 * (x * z + r * y) 147 | R[:, 1, 0] = 2 * (x * y + r * z) 148 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 149 | R[:, 1, 2] = 2 * (y * z - r * x) 150 | R[:, 2, 0] = 2 * (x * z - r * y) 151 | R[:, 2, 1] = 2 * (y * z + r * x) 152 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 153 | return R 154 | 155 | 156 | def build_scaling_rotation(s, r): 157 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 158 | R = build_rotation(r) 159 | 160 | L[:, 0, 0] = s[:, 0] 161 | L[:, 1, 1] = s[:, 1] 162 | L[:, 2, 2] = s[:, 2] 163 | 164 | L = R @ L 165 | return L 166 | 167 | 168 | def safe_state(silent): 169 | old_f = sys.stdout 170 | 171 | class F: 172 | def __init__(self, silent): 173 | self.silent = silent 
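An illustrative numpy re-implementation (not repository code) of the quaternion-to-rotation formula used by build_rotation() above, checking that a normalized quaternion yields an orthonormal matrix with unit determinant:

import numpy as np

def quat_to_rotmat(q):
    q = q / np.linalg.norm(q)            # build_rotation likewise normalizes first
    r, x, y, z = q
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - r * z),     2 * (x * z + r * y)],
        [2 * (x * y + r * z),     1 - 2 * (x * x + z * z), 2 * (y * z - r * x)],
        [2 * (x * z - r * y),     2 * (y * z + r * x),     1 - 2 * (x * x + y * y)],
    ])

R = quat_to_rotmat(np.array([0.92, 0.02, 0.38, 0.05]))
print(np.allclose(R @ R.T, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))   # True True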
174 | 175 | def write(self, x): 176 | if not self.silent: 177 | if x.endswith("\n"): 178 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 179 | else: 180 | old_f.write(x) 181 | 182 | def flush(self): 183 | old_f.flush() 184 | 185 | sys.stdout = F(silent) 186 | 187 | random.seed(0) 188 | np.random.seed(0) 189 | torch.manual_seed(0) 190 | torch.cuda.set_device(torch.device("cuda:0")) 191 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | import numpy as np 15 | from PIL import Image 16 | import torch 17 | import torchvision.transforms.functional as tf 18 | from utils.loss_utils import ssim 19 | # from lpipsPyTorch import lpips 20 | import lpips 21 | import json 22 | from tqdm import tqdm 23 | from utils.image_utils import psnr 24 | from utils.image_utils import mse 25 | from argparse import ArgumentParser 26 | 27 | 28 | def readImages(renders_dir, gt_dir, masks_dir): 29 | renders = [] 30 | gts = [] 31 | image_names = [] 32 | for fname in os.listdir(renders_dir): 33 | render = Image.open(renders_dir / fname) 34 | gt = Image.open(gt_dir / fname) 35 | mask_tensor = tf.to_tensor(Image.open(masks_dir / fname)) 36 | renders.append((tf.to_tensor(render) * mask_tensor).unsqueeze(0)[:, :3, :, :].cuda()) 37 | gts.append((tf.to_tensor(gt) * mask_tensor).unsqueeze(0)[:, :3, :, :].cuda()) 38 | image_names.append(fname) 39 | return renders, gts, image_names 40 | 41 | def read_depth_np(renders_dir, gt_dir): 42 | renders = [] 43 | gts = [] 44 | image_names = [] 45 | for fname in os.listdir(renders_dir): 46 | render = np.load(renders_dir / fname) # [1,H,W] 47 | gt = np.load(gt_dir / fname) 48 | 49 | renders.append(torch.tensor(render).cuda()) 50 | gts.append(torch.tensor(gt).cuda()) 51 | image_names.append(fname) 52 | return renders, gts, image_names 53 | 54 | def read_depth_Images(renders_dir, gt_dir): 55 | renders = [] 56 | gts = [] 57 | image_names = [] 58 | for fname in os.listdir(renders_dir): 59 | render = Image.open(renders_dir / fname).convert("L") 60 | gt = Image.open(gt_dir / fname).convert("L") 61 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :, :].cuda()) 62 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :, :].cuda()) 63 | image_names.append(fname) 64 | return renders, gts, image_names 65 | 66 | def evaluate(model_paths): 67 | full_dict = {} 68 | per_view_dict = {} 69 | full_dict_polytopeonly = {} 70 | per_view_dict_polytopeonly = {} 71 | print("") 72 | 73 | for scene_dir in model_paths: 74 | try: 75 | print("Scene:", scene_dir) 76 | full_dict[scene_dir] = {} 77 | per_view_dict[scene_dir] = {} 78 | full_dict_polytopeonly[scene_dir] = {} 79 | per_view_dict_polytopeonly[scene_dir] = {} 80 | 81 | test_dir = Path(scene_dir) / "test" 82 | 83 | for method in os.listdir(test_dir): 84 | if not method.startswith("ours"): 85 | continue 86 | print("Method:", method) 87 | 88 | full_dict[scene_dir][method] = {} 89 | per_view_dict[scene_dir][method] = {} 90 | full_dict_polytopeonly[scene_dir][method] = {} 91 | per_view_dict_polytopeonly[scene_dir][method] = {} 
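An illustrative sketch (not repository code) of the per-image masked comparison that readImages() above sets up: the render and the ground truth are both multiplied by the tool mask before any metric is computed. The output directory and file name below are hypothetical.

from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as tf
from utils.image_utils import psnr

method_dir = Path("output/pulling/test/ours_40000")   # hypothetical model output layout
fname = "00000.png"
mask = tf.to_tensor(Image.open(method_dir / "masks" / fname))
render = (tf.to_tensor(Image.open(method_dir / "renders" / fname)) * mask).unsqueeze(0)[:, :3]
gt = (tf.to_tensor(Image.open(method_dir / "gt_color" / fname)) * mask).unsqueeze(0)[:, :3]
print("masked PSNR:", psnr(render, gt).item())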
92 | 93 | method_dir = test_dir / method 94 | gt_dir = method_dir / "gt_color" 95 | renders_dir = method_dir / "renders" 96 | masks_dir = method_dir / "masks" # tool mask,not contain depth lack 97 | renders, gts, image_names = readImages(renders_dir, gt_dir, masks_dir) 98 | 99 | ssims = [] 100 | psnrs = [] 101 | lpipss = [] 102 | 103 | for idx in tqdm(range(len(renders)), desc="Color metric evaluation progress"): 104 | ssims.append(ssim(renders[idx], gts[idx])) 105 | psnrs.append(psnr(renders[idx], gts[idx])) 106 | lpipss.append(lpips_fn(renders[idx], gts[idx]).detach()) 107 | 108 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 109 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 110 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 111 | print("") 112 | 113 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 114 | "PSNR": torch.tensor(psnrs).mean().item(), 115 | "LPIPS": torch.tensor(lpipss).mean().item()}) 116 | per_view_dict[scene_dir][method].update( 117 | {"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 118 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 119 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 120 | 121 | Depth_psnrs = [] 122 | gt_depth_dir = method_dir / "gt_depth" 123 | if len(os.listdir(gt_depth_dir)) > 0: 124 | depth_dir = method_dir / "depth" 125 | depths, gt_depths, depth_names = read_depth_Images(depth_dir, gt_depth_dir) 126 | 127 | for idx in tqdm(range(len(depths)), desc="Depth_psnr metric evaluation progress"): 128 | Depth_psnrs.append(psnr(depths[idx], gt_depths[idx])) 129 | print(" Depth_psnr : {:>12.7f}".format(torch.tensor(Depth_psnrs).mean(), ".5")) 130 | full_dict[scene_dir][method].update({"Depth_psnr": torch.tensor(Depth_psnrs).mean().item()}) 131 | per_view_dict[scene_dir][method].update( 132 | {"Depth_psnr": {name: dpsnr for dpsnr, name in zip(torch.tensor(Depth_psnrs).tolist(), image_names)}}) 133 | 134 | Depth_mses = [] 135 | 136 | gt_depth_np_dir = method_dir / "gt_depth_np" 137 | if len(os.listdir(gt_depth_np_dir)) > 0: 138 | depth_np_dir = method_dir / "depth_np" 139 | depths_np, gt_depths_np, depth_np_names = read_depth_np(depth_np_dir, gt_depth_np_dir) 140 | 141 | for idx in tqdm(range(len(depths_np)), desc="Depth_mse metric evaluation progress"): 142 | Depth_mses.append(mse(depths_np[idx], gt_depths_np[idx])) 143 | print(" Depth_mse : {:>12.7f}".format(torch.tensor(Depth_mses).mean(), ".5")) 144 | full_dict[scene_dir][method].update({"Depth_mse": torch.tensor(Depth_mses).mean().item()}) 145 | per_view_dict[scene_dir][method].update( 146 | {"Depth_mse": {name: mse for mse, name in zip(torch.tensor(Depth_mses).tolist(), image_names)}}) 147 | 148 | with open(scene_dir + "/results.json", 'w') as fp: 149 | json.dump(full_dict[scene_dir], fp, indent=True) 150 | with open(scene_dir + "/per_view.json", 'w') as fp: 151 | json.dump(per_view_dict[scene_dir], fp, indent=True) 152 | except: 153 | print("Unable to compute metrics for model", scene_dir) 154 | 155 | 156 | if __name__ == "__main__": 157 | device = torch.device("cuda:0") 158 | torch.cuda.set_device(device) 159 | # vgg and alex 160 | # lpips_fn = lpips.LPIPS(net='vgg').to(device) 161 | lpips_fn = lpips.LPIPS(net='alex').to(device) 162 | 163 | # Set up command line argument parser 164 | parser = ArgumentParser(description="Training script parameters") 165 | 
parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 166 | args = parser.parse_args() 167 | evaluate(args.model_paths) 168 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | import torch 15 | from PIL import Image 16 | from typing import NamedTuple, Optional 17 | 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | 21 | import imageio 22 | from glob import glob 23 | import cv2 as cv 24 | from pathlib import Path 25 | from plyfile import PlyData, PlyElement 26 | from utils.sh_utils import SH2RGB 27 | from scene.gaussian_model import BasicPointCloud 28 | 29 | from utils.initial_utils import get_pointcloud, get_all_initial_data_endo 30 | 31 | 32 | class CameraInfo(NamedTuple): 33 | uid: int 34 | R: np.array 35 | T: np.array 36 | FovY: np.array 37 | FovX: np.array 38 | image: np.array 39 | image_path: str 40 | image_name: str 41 | width: int 42 | height: int 43 | fid: float 44 | depth: Optional[np.array] = None 45 | mask_depth: Optional[np.array] = None 46 | mask: Optional[np.array] = None 47 | 48 | 49 | class SceneInfo(NamedTuple): 50 | point_cloud: BasicPointCloud 51 | train_cameras: list 52 | test_cameras: list 53 | nerf_normalization: dict 54 | ply_path: str 55 | 56 | 57 | def load_K_Rt_from_P(filename, P=None): 58 | if P is None: 59 | lines = open(filename).read().splitlines() 60 | if len(lines) == 4: 61 | lines = lines[1:] 62 | lines = [[x[0], x[1], x[2], x[3]] 63 | for x in (x.split(" ") for x in lines)] 64 | P = np.asarray(lines).astype(np.float32).squeeze() 65 | 66 | out = cv.decomposeProjectionMatrix(P) 67 | K = out[0] 68 | R = out[1] 69 | t = out[2] 70 | 71 | K = K / K[2, 2] 72 | 73 | pose = np.eye(4, dtype=np.float32) 74 | pose[:3, :3] = R.transpose() 75 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 76 | 77 | return K, pose 78 | 79 | def imread(f): 80 | if f.endswith('png'): 81 | return imageio.imread(f, ignoregamma=True) 82 | else: 83 | return imageio.imread(f) 84 | 85 | def getNerfppNorm(cam_info): 86 | def get_center_and_diag(cam_centers): 87 | cam_centers = np.hstack(cam_centers) 88 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 89 | center = avg_cam_center 90 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 91 | diagonal = np.max(dist) 92 | return center.flatten(), diagonal 93 | 94 | cam_centers = [] 95 | 96 | for cam in cam_info: 97 | W2C = getWorld2View2(cam.R, cam.T) 98 | C2W = np.linalg.inv(W2C) 99 | cam_centers.append(C2W[:3, 3:4]) 100 | 101 | center, diagonal = get_center_and_diag(cam_centers) 102 | radius = diagonal * 1.1 103 | 104 | translate = -center 105 | 106 | return {"translate": translate, "radius": radius} 107 | 108 | 109 | def fetchPly(path): 110 | plydata = PlyData.read(path) 111 | vertices = plydata['vertex'] 112 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 113 | colors = np.vstack([vertices['red'], vertices['green'], 114 | vertices['blue']]).T / 255.0 115 | normals = np.vstack([vertices['nx'], vertices['ny'], 
vertices['nz']]).T 116 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 117 | 118 | 119 | def storePly(path, xyz, rgb): 120 | # Define the dtype for the structured array 121 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 122 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 123 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 124 | 125 | normals = np.zeros_like(xyz) 126 | 127 | elements = np.empty(xyz.shape[0], dtype=dtype) 128 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 129 | elements[:] = list(map(tuple, attributes)) 130 | 131 | # Create the PlyData object and write to file 132 | vertex_element = PlyElement.describe(elements, 'vertex') 133 | ply_data = PlyData([vertex_element]) 134 | ply_data.write(path) 135 | 136 | 137 | def readCamerasdavinci(path, data_type, is_depth, depth_scale, is_mask, npy_file, split, hold_id, num_images): 138 | cam_infos = [] 139 | poses_bounds = np.load(os.path.join(path, npy_file)) 140 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) 141 | H, W, focal = poses[0, :, -1] 142 | video_path = sorted(glob(os.path.join(path, 'images/*'))) 143 | if is_mask: 144 | masks_path = sorted(glob(os.path.join(path, 'gt_masks/*'))) 145 | if is_depth: 146 | depths_path = sorted(glob(os.path.join(path, 'depth/*'))) 147 | bds = poses_bounds[:, -2:] 148 | close_depth, inf_depth = np.ndarray.min(bds), np.ndarray.max(bds) 149 | 150 | n_cameras = poses.shape[0] 151 | poses = poses[:, :, :4] 152 | bottoms = np.array([0, 0, 0, 1]).reshape( 153 | 1, -1, 4).repeat(poses.shape[0], axis=0) 154 | poses = np.concatenate([poses, bottoms], axis=1) 155 | 156 | i_test = np.array(hold_id) 157 | video_list = i_test if split != 'train' else list( 158 | set(np.arange(n_cameras)) - set(i_test)) 159 | 160 | for i in video_list: 161 | c2w = poses[i] 162 | images_path = video_path[i] 163 | image_name = Path(images_path).stem 164 | n_frames = num_images 165 | 166 | matrix = np.linalg.inv(np.array(c2w)) 167 | R = np.transpose(matrix[:3, :3]) 168 | T = matrix[:3, 3] 169 | 170 | image = Image.open(images_path) 171 | 172 | mask_image = None 173 | if is_mask: 174 | mask_path = masks_path[i] 175 | mask_image = np.array(imread(mask_path) / 255.0) 176 | # mask is 0 or 1 177 | mask_image = np.where(mask_image > 0.5, 1.0, 0.0) 178 | # Convert 0 for tool, 1 for not tool 179 | mask_image = 1.0 - mask_image 180 | if data_type == 'endonerf': 181 | mask_image[-12:, :] = 0 182 | 183 | depth_image = None 184 | mask_depth = None 185 | if is_depth: 186 | depth_path = depths_path[i] 187 | depth_image = np.array(imread(depth_path) * 1.0) 188 | depth_image = depth_image / depth_scale 189 | near = np.percentile(depth_image, 3) 190 | far = np.percentile(depth_image, 98) 191 | mask_depth = np.bitwise_and(depth_image > near, depth_image < far) 192 | if is_mask: 193 | mask_depth = mask_depth * mask_image 194 | depth_image = depth_image * mask_depth 195 | 196 | frame_time = i / (n_frames - 1) 197 | FovX = focal2fov(focal, image.size[0]) 198 | FovY = focal2fov(focal, image.size[1]) 199 | cam_infos.append(CameraInfo(uid=i, R=R, T=T, FovX=FovX, FovY=FovY, 200 | image=image, 201 | image_path=images_path, image_name=image_name, 202 | width=image.size[0], height=image.size[1], fid=frame_time, depth=depth_image, mask_depth=mask_depth, mask=mask_image)) 203 | 204 | 205 | return cam_infos 206 | 207 | 208 | def readEndonerfInfo(path, data_type, eval, is_depth, depth_scale, is_mask, depth_initial, num_images, hold_id): # hold_id选择test的帧数ID,这个是endosurf的测试集 209 | print("Reading Training Camera") 210 | 
train_cam_infos = readCamerasdavinci( 211 | path, data_type, is_depth, depth_scale, is_mask, 'poses_bounds.npy', split="train", hold_id=hold_id, num_images=num_images) 212 | 213 | print("Reading Test Camera") 214 | test_cam_infos = readCamerasdavinci( 215 | path, data_type, is_depth, depth_scale, is_mask, 'poses_bounds.npy', split="test", hold_id=hold_id, num_images=num_images) 216 | 217 | if not eval: 218 | train_cam_infos.extend(test_cam_infos) 219 | test_cam_infos = [] 220 | 221 | nerf_normalization = getNerfppNorm(train_cam_infos) 222 | nerf_normalization["radius"] = 1 223 | 224 | ply_path = os.path.join(path, 'points3D.ply') 225 | if not os.path.exists(ply_path): 226 | if depth_initial: 227 | color, depth, intrinsics, mask = get_all_initial_data_endo(path, data_type, depth_scale, is_mask, 'poses_bounds.npy') 228 | 229 | xyz, RGB = get_pointcloud(color, depth, intrinsics, mask) 230 | storePly(ply_path, xyz, RGB) 231 | 232 | else: 233 | num_pts = 100_000 234 | print(f"Generating random point cloud ({num_pts})...") 235 | 236 | # We create random points inside the bounds of the synthetic Blender scenes 237 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 238 | shs = np.random.random((num_pts, 3)) / 255.0 239 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB( 240 | shs), normals=np.zeros((num_pts, 3))) 241 | 242 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 243 | 244 | try: 245 | pcd = fetchPly(ply_path) 246 | except: 247 | pcd = None 248 | 249 | scene_info = SceneInfo(point_cloud=pcd, 250 | train_cameras=train_cam_infos, 251 | test_cameras=test_cam_infos, 252 | nerf_normalization=nerf_normalization, 253 | ply_path=ply_path) 254 | return scene_info 255 | 256 | 257 | sceneLoadTypeCallbacks = { 258 | "Endonerf": readEndonerfInfo, 259 | } 260 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
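An illustrative sketch (not repository code) of the LLFF-style poses_bounds.npy layout that readCamerasdavinci() above relies on: each row packs a flattened 3x5 matrix (a camera-to-world [R|t] plus an [H, W, focal] column) followed by the near/far depth bounds. The dataset path is hypothetical.

import numpy as np
from utils.graphics_utils import focal2fov

poses_bounds = np.load("data/pulling/poses_bounds.npy")      # hypothetical dataset path
poses = poses_bounds[:, :15].reshape(-1, 3, 5)
H, W, focal = poses[0, :, -1]
near, far = poses_bounds[0, -2:]
c2w = np.concatenate([poses[0, :, :4], np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
w2c = np.linalg.inv(c2w)
R, T = np.transpose(w2c[:3, :3]), w2c[:3, 3]                 # the convention stored in CameraInfo
FovX, FovY = focal2fov(focal, int(W)), focal2fov(focal, int(H))
print(poses.shape[0], "frames", FovX, FovY, near, far)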
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2]]) 54 | 55 | 56 | def rotmat2qvec(R): 57 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 58 | K = np.array([ 59 | [Rxx - Ryy - Rzz, 0, 0, 0], 60 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 61 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 62 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 63 | eigvals, eigvecs = np.linalg.eigh(K) 64 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 65 | if qvec[0] < 0: 66 | qvec *= -1 67 | return qvec 68 | 69 | 70 | class Image(BaseImage): 71 | def qvec2rotmat(self): 72 | return qvec2rotmat(self.qvec) 73 | 74 | 75 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 76 | """Read and unpack the next bytes from a binary file. 77 | :param fid: 78 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 79 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 80 | :param endian_character: Any of {@, =, <, >, !} 81 | :return: Tuple of read and unpacked values. 
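A small round-trip check (illustrative only, not repository code) for the quaternion helpers defined above: converting a w-first unit quaternion to a rotation matrix with qvec2rotmat() and back with rotmat2qvec() reproduces the original coefficients.

import numpy as np
from scene.colmap_loader import qvec2rotmat, rotmat2qvec

qvec = np.array([0.5, 0.5, 0.5, 0.5])        # unit length, w-first as COLMAP stores it
R = qvec2rotmat(qvec)
print(np.allclose(rotmat2qvec(R), qvec))      # expected: True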
82 | """ 83 | data = fid.read(num_bytes) 84 | return struct.unpack(endian_character + format_char_sequence, data) 85 | 86 | 87 | def read_points3D_text(path): 88 | """ 89 | see: src/base/reconstruction.cc 90 | void Reconstruction::ReadPoints3DText(const std::string& path) 91 | void Reconstruction::WritePoints3DText(const std::string& path) 92 | """ 93 | xyzs = None 94 | rgbs = None 95 | errors = None 96 | with open(path, "r") as fid: 97 | while True: 98 | line = fid.readline() 99 | if not line: 100 | break 101 | line = line.strip() 102 | if len(line) > 0 and line[0] != "#": 103 | elems = line.split() 104 | xyz = np.array(tuple(map(float, elems[1:4]))) 105 | rgb = np.array(tuple(map(int, elems[4:7]))) 106 | error = np.array(float(elems[7])) 107 | if xyzs is None: 108 | xyzs = xyz[None, ...] 109 | rgbs = rgb[None, ...] 110 | errors = error[None, ...] 111 | else: 112 | xyzs = np.append(xyzs, xyz[None, ...], axis=0) 113 | rgbs = np.append(rgbs, rgb[None, ...], axis=0) 114 | errors = np.append(errors, error[None, ...], axis=0) 115 | return xyzs, rgbs, errors 116 | 117 | 118 | def read_points3D_binary(path_to_model_file): 119 | """ 120 | see: src/base/reconstruction.cc 121 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 122 | void Reconstruction::WritePoints3DBinary(const std::string& path) 123 | """ 124 | 125 | with open(path_to_model_file, "rb") as fid: 126 | num_points = read_next_bytes(fid, 8, "Q")[0] 127 | 128 | xyzs = np.empty((num_points, 3)) 129 | rgbs = np.empty((num_points, 3)) 130 | errors = np.empty((num_points, 1)) 131 | 132 | for p_id in range(num_points): 133 | binary_point_line_properties = read_next_bytes( 134 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 135 | xyz = np.array(binary_point_line_properties[1:4]) 136 | rgb = np.array(binary_point_line_properties[4:7]) 137 | error = np.array(binary_point_line_properties[7]) 138 | track_length = read_next_bytes( 139 | fid, num_bytes=8, format_char_sequence="Q")[0] 140 | track_elems = read_next_bytes( 141 | fid, num_bytes=8 * track_length, 142 | format_char_sequence="ii" * track_length) 143 | xyzs[p_id] = xyz 144 | rgbs[p_id] = rgb 145 | errors[p_id] = error 146 | return xyzs, rgbs, errors 147 | 148 | 149 | def read_intrinsics_text(path): 150 | """ 151 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 152 | """ 153 | cameras = {} 154 | with open(path, "r") as fid: 155 | while True: 156 | line = fid.readline() 157 | if not line: 158 | break 159 | line = line.strip() 160 | if len(line) > 0 and line[0] != "#": 161 | elems = line.split() 162 | camera_id = int(elems[0]) 163 | model = elems[1] 164 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 165 | width = int(elems[2]) 166 | height = int(elems[3]) 167 | params = np.array(tuple(map(float, elems[4:]))) 168 | cameras[camera_id] = Camera(id=camera_id, model=model, 169 | width=width, height=height, 170 | params=params) 171 | return cameras 172 | 173 | 174 | def read_extrinsics_binary(path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::ReadImagesBinary(const std::string& path) 178 | void Reconstruction::WriteImagesBinary(const std::string& path) 179 | """ 180 | images = {} 181 | with open(path_to_model_file, "rb") as fid: 182 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 183 | for _ in range(num_reg_images): 184 | binary_image_properties = read_next_bytes( 185 | fid, num_bytes=64, format_char_sequence="idddddddi") 186 | 
image_id = binary_image_properties[0] 187 | qvec = np.array(binary_image_properties[1:5]) 188 | tvec = np.array(binary_image_properties[5:8]) 189 | camera_id = binary_image_properties[8] 190 | image_name = "" 191 | current_char = read_next_bytes(fid, 1, "c")[0] 192 | while current_char != b"\x00": # look for the ASCII 0 entry 193 | image_name += current_char.decode("utf-8") 194 | current_char = read_next_bytes(fid, 1, "c")[0] 195 | num_points2D = read_next_bytes(fid, num_bytes=8, 196 | format_char_sequence="Q")[0] 197 | x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, 198 | format_char_sequence="ddq" * num_points2D) 199 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 200 | tuple(map(float, x_y_id_s[1::3]))]) 201 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 202 | images[image_id] = Image( 203 | id=image_id, qvec=qvec, tvec=tvec, 204 | camera_id=camera_id, name=image_name, 205 | xys=xys, point3D_ids=point3D_ids) 206 | return images 207 | 208 | 209 | def read_intrinsics_binary(path_to_model_file): 210 | """ 211 | see: src/base/reconstruction.cc 212 | void Reconstruction::WriteCamerasBinary(const std::string& path) 213 | void Reconstruction::ReadCamerasBinary(const std::string& path) 214 | """ 215 | cameras = {} 216 | with open(path_to_model_file, "rb") as fid: 217 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 218 | for _ in range(num_cameras): 219 | camera_properties = read_next_bytes( 220 | fid, num_bytes=24, format_char_sequence="iiQQ") 221 | camera_id = camera_properties[0] 222 | model_id = camera_properties[1] 223 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 224 | width = camera_properties[2] 225 | height = camera_properties[3] 226 | num_params = CAMERA_MODEL_IDS[model_id].num_params 227 | params = read_next_bytes(fid, num_bytes=8 * num_params, 228 | format_char_sequence="d" * num_params) 229 | cameras[camera_id] = Camera(id=camera_id, 230 | model=model_name, 231 | width=width, 232 | height=height, 233 | params=np.array(params)) 234 | assert len(cameras) == num_cameras 235 | return cameras 236 | 237 | 238 | def read_extrinsics_text(path): 239 | """ 240 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 241 | """ 242 | images = {} 243 | with open(path, "r") as fid: 244 | while True: 245 | line = fid.readline() 246 | if not line: 247 | break 248 | line = line.strip() 249 | if len(line) > 0 and line[0] != "#": 250 | elems = line.split() 251 | image_id = int(elems[0]) 252 | qvec = np.array(tuple(map(float, elems[1:5]))) 253 | tvec = np.array(tuple(map(float, elems[5:8]))) 254 | camera_id = int(elems[8]) 255 | image_name = elems[9] 256 | elems = fid.readline().split() 257 | xys = np.column_stack([tuple(map(float, elems[0::3])), 258 | tuple(map(float, elems[1::3]))]) 259 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 260 | images[image_id] = Image( 261 | id=image_id, qvec=qvec, tvec=tvec, 262 | camera_id=camera_id, name=image_name, 263 | xys=xys, point3D_ids=point3D_ids) 264 | return images 265 | 266 | 267 | def read_colmap_bin_array(path): 268 | """ 269 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 270 | 271 | :param path: path to the colmap binary file. 
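An illustrative sketch (not repository code) of driving the binary readers above on a COLMAP sparse model; the usual sparse/0 directory layout is assumed.

import os
from scene.colmap_loader import read_extrinsics_binary, read_intrinsics_binary, qvec2rotmat

sparse_dir = "data/scene/sparse/0"                     # hypothetical COLMAP output directory
images = read_extrinsics_binary(os.path.join(sparse_dir, "images.bin"))
cameras = read_intrinsics_binary(os.path.join(sparse_dir, "cameras.bin"))
first = next(iter(images.values()))
print(len(images), "images,", len(cameras), "cameras,",
      cameras[first.camera_id].model, qvec2rotmat(first.qvec).shape)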
272 | :return: nd array with the floating point values in the value 273 | """ 274 | with open(path, "rb") as fid: 275 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 276 | usecols=(0, 1, 2), dtype=int) 277 | fid.seek(0) 278 | num_delimiter = 0 279 | byte = fid.read(1) 280 | while True: 281 | if byte == b"&": 282 | num_delimiter += 1 283 | if num_delimiter >= 3: 284 | break 285 | byte = fid.read(1) 286 | array = np.fromfile(fid, np.float32) 287 | array = array.reshape((width, height, channels), order="F") 288 | return np.transpose(array, (1, 0, 2)).squeeze() 289 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | import torchvision 15 | import mmcv 16 | from random import randint 17 | from utils.loss_utils import l1_loss, ssim, tv_loss, def_reg_loss 18 | from gaussian_renderer import render, network_gui 19 | import sys 20 | from scene import Scene, GaussianModel, DeformModel 21 | from utils.general_utils import safe_state, get_linear_noise_func 22 | import uuid 23 | from tqdm import tqdm 24 | from utils.image_utils import psnr 25 | from utils.initial_utils import imread 26 | from argparse import ArgumentParser, Namespace 27 | from arguments import ModelParams, PipelineParams, OptimizationParams 28 | import torch.nn.functional as F 29 | 30 | 31 | try: 32 | from torch.utils.tensorboard import SummaryWriter 33 | 34 | TENSORBOARD_FOUND = True 35 | except ImportError: 36 | TENSORBOARD_FOUND = False 37 | 38 | 39 | def training(dataset, opt, pipe, testing_iterations, saving_iterations): 40 | tb_writer = prepare_output_and_logger(dataset) 41 | gaussians = GaussianModel(dataset.sh_degree) 42 | deform = DeformModel() 43 | deform.train_setting(opt) 44 | 45 | scene = Scene(dataset, gaussians) 46 | gaussians.training_setup(opt) 47 | 48 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 49 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 50 | 51 | iter_start = torch.cuda.Event(enable_timing=True) 52 | iter_end = torch.cuda.Event(enable_timing=True) 53 | 54 | if dataset.accurate_mask: 55 | invisible_mask_path = os.path.join(dataset.source_path, "dilated_invisible_mask.png") 56 | inpaint_mask = imread(invisible_mask_path) / 255.0 57 | inpaint_mask_tensor = torch.tensor(inpaint_mask, dtype=torch.float32, device="cuda") 58 | viewpoint_stack = None 59 | ema_loss_for_log = 0.0 60 | best_psnr = 0.0 61 | best_iteration = 0 62 | 63 | progress_bar = tqdm(range(opt.iterations), desc="Training progress") 64 | 65 | for iteration in range(1, opt.iterations + 1): 66 | 67 | iter_start.record() 68 | 69 | # Every 1000 its we increase the levels of SH up to a maximum degree 70 | if iteration % 1000 == 0: 71 | gaussians.oneupSHdegree() 72 | 73 | # Pick a random Camera 74 | if not viewpoint_stack: 75 | viewpoint_stack = scene.getTrainCameras().copy() 76 | 77 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 78 | if dataset.load2gpu_on_the_fly: 79 | viewpoint_cam.load2device() 80 | fid = viewpoint_cam.fid 81 | 82 | N = gaussians.get_xyz.shape[0] 83 | 
time_input = fid.unsqueeze(0).expand(N, -1) 84 | d_xyz, d_rotation, d_scaling = deform.step(gaussians.get_xyz.detach(), time_input) 85 | 86 | # Render 87 | render_pkg_re = render(viewpoint_cam, gaussians, pipe, background, d_xyz, d_rotation, d_scaling) 88 | image, viewspace_point_tensor, visibility_filter, radii, depth = render_pkg_re["render"], render_pkg_re[ 89 | "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"], render_pkg_re["depth"] 90 | 91 | # Loss 92 | gt_image = viewpoint_cam.original_image.cuda() 93 | if dataset.is_mask: 94 | mask = viewpoint_cam.mask.unsqueeze(0).cuda() 95 | gt_image = gt_image * mask 96 | if dataset.accurate_mask: 97 | img_tv_loss = tv_loss(image * inpaint_mask_tensor) 98 | else: 99 | img_tv_loss = tv_loss(image * (1-mask)) 100 | 101 | image = image * mask 102 | Ll1 = l1_loss(image, gt_image) 103 | 104 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 105 | if opt.lambda_smooth != 0: 106 | loss += opt.lambda_smooth * img_tv_loss 107 | 108 | depth_loss = None 109 | if dataset.is_depth: 110 | gt_depth = viewpoint_cam.depth.unsqueeze(0).cuda() 111 | mask_depth = viewpoint_cam.mask_depth.unsqueeze(0).cuda() 112 | depth = depth * mask_depth 113 | depth_loss = l1_loss(depth, gt_depth) 114 | 115 | loss = loss + 0.001 * depth_loss 116 | 117 | # deformation loss 118 | loss_pos, loss_cov = def_reg_loss(scene.gaussians, d_xyz, d_rotation, d_scaling) 119 | 120 | loss += opt.lambda_pos * loss_pos 121 | loss += opt.lambda_cov * loss_cov 122 | 123 | loss.backward() 124 | 125 | iter_end.record() 126 | 127 | if dataset.load2gpu_on_the_fly: 128 | viewpoint_cam.load2device('cpu') 129 | 130 | with torch.no_grad(): 131 | # Progress bar 132 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 133 | if iteration % 10 == 0: 134 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 135 | progress_bar.update(10) 136 | if iteration == opt.iterations: 137 | progress_bar.close() 138 | 139 | # Keep track of max radii in image-space for pruning 140 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], 141 | radii[visibility_filter]) 142 | 143 | # Log and save 144 | cur_psnr = training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), 145 | testing_iterations, scene, render, (pipe, background), deform, 146 | dataset.load2gpu_on_the_fly, depth_loss) 147 | if iteration in testing_iterations: 148 | if cur_psnr.item() > best_psnr: 149 | best_psnr = cur_psnr.item() 150 | best_iteration = iteration 151 | 152 | if iteration in saving_iterations: 153 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 154 | scene.save(iteration) 155 | deform.save_weights(args.model_path, iteration) 156 | 157 | 158 | optimize_path = os.path.join(args.model_path, "optimize/iteration_{}".format(iteration)) 159 | 160 | render_path = os.path.join(optimize_path, "train/renders") 161 | depth_path = os.path.join(optimize_path, "train/depth") 162 | os.makedirs(render_path, exist_ok=True) 163 | os.makedirs(depth_path, exist_ok=True) 164 | 165 | train_view = scene.getTrainCameras().copy() 166 | for idx, view in enumerate(tqdm(train_view, desc="Rendering progress")): 167 | fid = view.fid 168 | name = view.colmap_id 169 | xyz = gaussians.get_xyz 170 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 171 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 172 | render_pkg_re = render(view, gaussians, pipe, background, d_xyz, 
d_rotation, d_scaling) 173 | rendering = render_pkg_re["render"] 174 | depth_np = render_pkg_re["depth"] 175 | depth = depth_np / (depth_np.max() + 1e-5) 176 | 177 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(name) + ".png")) 178 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(name) + ".png")) 179 | 180 | # Densification 181 | if iteration < opt.densify_until_iter: # < 15_000 182 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 183 | 184 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 185 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 186 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 187 | 188 | if iteration % opt.opacity_reset_interval == 0 or ( 189 | dataset.white_background and iteration == opt.densify_from_iter): 190 | gaussians.reset_opacity() 191 | 192 | # Optimizer step 193 | if iteration < opt.iterations: 194 | gaussians.optimizer.step() 195 | gaussians.update_learning_rate(iteration) 196 | deform.optimizer.step() 197 | gaussians.optimizer.zero_grad(set_to_none=True) 198 | deform.optimizer.zero_grad() 199 | deform.update_learning_rate(iteration) 200 | 201 | print("Best PSNR = {} in Iteration {}".format(best_psnr, best_iteration)) 202 | 203 | 204 | def prepare_output_and_logger(args): 205 | if not args.model_path: 206 | if os.getenv('OAR_JOB_ID'): 207 | unique_str = os.getenv('OAR_JOB_ID') 208 | else: 209 | unique_str = str(uuid.uuid4()) 210 | args.model_path = os.path.join("./output/", unique_str[0:10]) 211 | 212 | # Set up output folder 213 | print("Output folder: {}".format(args.model_path)) 214 | os.makedirs(args.model_path, exist_ok=True) 215 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 216 | cfg_log_f.write(str(Namespace(**vars(args)))) 217 | 218 | # Create Tensorboard writer 219 | tb_writer = None 220 | if TENSORBOARD_FOUND: 221 | tb_writer = SummaryWriter(args.model_path) 222 | else: 223 | print("Tensorboard not available: not logging progress") 224 | return tb_writer 225 | 226 | def read_config_params(args, config): 227 | params = ["OptimizationParams", "ModelParams", "PipelineParams"] 228 | for param in params: 229 | if param in config.keys(): 230 | for key, value in config[param].items(): 231 | if hasattr(args, key): 232 | setattr(args, key, value) 233 | return args 234 | 235 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc, 236 | renderArgs, deform, load2gpu_on_the_fly, depth_loss=None ): 237 | if tb_writer: 238 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 239 | 240 | if depth_loss is not None: 241 | tb_writer.add_scalar('train_loss_patches/depth_loss', depth_loss.item(), iteration) 242 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 243 | tb_writer.add_scalar('iter_time', elapsed, iteration) 244 | 245 | test_psnr = 0.0 246 | # Report test and samples of training set 247 | if iteration in testing_iterations: 248 | torch.cuda.empty_cache() 249 | validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()}, 250 | {'name': 'train', 251 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in 252 | range(5, 30, 5)]}) 253 | 254 | for config in validation_configs: 255 | if config['cameras'] and len(config['cameras']) > 0: 256 | images = torch.tensor([], 
device="cuda") 257 | gts = torch.tensor([], device="cuda") 258 | for idx, viewpoint in enumerate(config['cameras']): 259 | if load2gpu_on_the_fly: 260 | viewpoint.load2device() 261 | fid = viewpoint.fid 262 | xyz = scene.gaussians.get_xyz 263 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 264 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 265 | image = torch.clamp( 266 | renderFunc(viewpoint, scene.gaussians, *renderArgs, d_xyz, d_rotation, d_scaling)["render"], 267 | 0.0, 1.0) 268 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 269 | 270 | 271 | if viewpoint.mask is not None: 272 | mask = viewpoint.mask.unsqueeze(0).cuda() 273 | image = image * mask 274 | gt_image = gt_image * mask 275 | images = torch.cat((images, image.unsqueeze(0)), dim=0) 276 | gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) 277 | 278 | if load2gpu_on_the_fly: 279 | viewpoint.load2device('cpu') 280 | if tb_writer and (idx < 5): 281 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), 282 | image[None], global_step=iteration) 283 | if iteration == testing_iterations[0]: 284 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), 285 | gt_image[None], global_step=iteration) 286 | 287 | l1_test = l1_loss(images, gts) 288 | psnr_test = psnr(images, gts).mean() 289 | if config['name'] == 'test' or len(validation_configs[0]['cameras']) == 0: 290 | test_psnr = psnr_test 291 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 292 | if tb_writer: 293 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 294 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 295 | 296 | if tb_writer: 297 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 298 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 299 | torch.cuda.empty_cache() 300 | 301 | return test_psnr 302 | 303 | 304 | if __name__ == "__main__": 305 | # Set up command line argument parser 306 | parser = ArgumentParser(description="Training script parameters") 307 | lp = ModelParams(parser) 308 | op = OptimizationParams(parser) 309 | pp = PipelineParams(parser) 310 | parser.add_argument('--ip', type=str, default="127.0.0.1") 311 | parser.add_argument('--port', type=int, default=6009) 312 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 313 | parser.add_argument("--test_iterations", nargs="+", type=int, 314 | default=[3000, 5000, 6000, 7_000] + list(range(10000, 60001, 1000))) 315 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[20_000, 40000, 60000]) 316 | 317 | parser.add_argument("--quiet", action="store_true") 318 | parser.add_argument("--config", type=str, default="") 319 | args = parser.parse_args(sys.argv[1:]) 320 | args.save_iterations.append(args.iterations) 321 | config = mmcv.Config.fromfile(args.config) 322 | args = read_config_params(args, config) 323 | 324 | print("Optimizing " + args.model_path) 325 | 326 | # Initialize system state (RNG) 327 | safe_state(args.quiet) 328 | 329 | # Start GUI server, configure and run training 330 | # network_gui.init(args.ip, args.port) 331 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 332 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) 333 | 334 | # All done 335 | 
print("\nTraining complete.") 336 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene, DeformModel 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from utils.pose_utils import pose_spherical, render_wander_path 21 | from argparse import ArgumentParser 22 | from arguments import ModelParams, PipelineParams, get_combined_args 23 | from gaussian_renderer import GaussianModel 24 | import imageio 25 | import numpy as np 26 | from utils.rigid_utils import from_homogenous, to_homogenous 27 | import open3d as o3d 28 | 29 | def render_set(model_path, load2gpu_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 30 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 31 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_color") 32 | depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") 33 | gts_depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_depth") 34 | 35 | depth_np_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth_np") 36 | gts_depth_np_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_depth_np") 37 | masks_path = os.path.join(model_path, name, "ours_{}".format(iteration), "masks") 38 | 39 | makedirs(render_path, exist_ok=True) 40 | makedirs(gts_path, exist_ok=True) 41 | makedirs(depth_path, exist_ok=True) 42 | makedirs(gts_depth_path, exist_ok=True) 43 | makedirs(depth_np_path, exist_ok=True) 44 | makedirs(gts_depth_np_path, exist_ok=True) 45 | makedirs(masks_path, exist_ok=True) 46 | 47 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 48 | if load2gpu_on_the_fly: 49 | view.load2device() 50 | fid = view.fid 51 | xyz = gaussians.get_xyz 52 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 53 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 54 | 55 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 56 | rendering = results["render"] 57 | depth_np = results["depth"] 58 | depth = depth_np / (depth_np.max() + 1e-5) 59 | mask = view.mask 60 | 61 | gt = view.original_image[0:3, :, :] 62 | if view.depth is not None: 63 | gts_depth_np = view.depth.unsqueeze(0).cpu().numpy() 64 | np.save(os.path.join(gts_depth_np_path, '{0:05d}'.format(idx) + ".npy"), gts_depth_np) 65 | 66 | gts_depth = view.depth.unsqueeze(0) 67 | gts_depth = gts_depth / (gts_depth.max() + 1e-5) 68 | torchvision.utils.save_image(gts_depth, os.path.join(gts_depth_path, '{0:05d}'.format(idx) + ".png")) 69 | 70 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 71 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 72 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) 73 | torchvision.utils.save_image(mask, 
os.path.join(masks_path, '{0:05d}'.format(idx) + ".png")) 74 | np.save(os.path.join(depth_np_path, '{0:05d}'.format(idx) + ".npy"), depth_np.cpu().numpy()) 75 | 76 | 77 | def interpolate_time(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 78 | render_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "renders") 79 | depth_path = os.path.join(model_path, name, "interpolate_{}".format(iteration), "depth") 80 | 81 | makedirs(render_path, exist_ok=True) 82 | makedirs(depth_path, exist_ok=True) 83 | 84 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 85 | 86 | frame = 150 87 | idx = torch.randint(0, len(views), (1,)).item() 88 | view = views[idx] 89 | renderings = [] 90 | for t in tqdm(range(0, frame, 1), desc="Rendering progress"): 91 | fid = torch.Tensor([t / (frame - 1)]).cuda() 92 | xyz = gaussians.get_xyz 93 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 94 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 95 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 96 | rendering = results["render"] 97 | renderings.append(to8b(rendering.cpu().numpy())) 98 | depth = results["depth"] 99 | depth = depth / (depth.max() + 1e-5) 100 | 101 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(t) + ".png")) 102 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(t) + ".png")) 103 | 104 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 105 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 106 | 107 | 108 | def interpolate_view(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, timer): 109 | render_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "renders") 110 | depth_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "depth") 111 | # acc_path = os.path.join(model_path, name, "interpolate_view_{}".format(iteration), "acc") 112 | 113 | makedirs(render_path, exist_ok=True) 114 | makedirs(depth_path, exist_ok=True) 115 | # makedirs(acc_path, exist_ok=True) 116 | 117 | frame = 150 118 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 119 | 120 | idx = torch.randint(0, len(views), (1,)).item() 121 | view = views[idx] # Choose a specific time for rendering 122 | 123 | render_poses = torch.stack(render_wander_path(view), 0) 124 | 125 | 126 | renderings = [] 127 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 128 | fid = view.fid 129 | 130 | matrix = np.linalg.inv(np.array(pose)) 131 | R = -np.transpose(matrix[:3, :3]) 132 | R[:, 0] = -R[:, 0] 133 | T = -matrix[:3, 3] 134 | 135 | view.reset_extrinsic(R, T) 136 | 137 | xyz = gaussians.get_xyz 138 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 139 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 140 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 141 | rendering = results["render"] 142 | renderings.append(to8b(rendering.cpu().numpy())) 143 | depth = results["depth"] 144 | depth = depth / (depth.max() + 1e-5) 145 | # acc = results["acc"] 146 | 147 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 148 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 149 | # torchvision.utils.save_image(acc, 
os.path.join(acc_path, '{0:05d}'.format(i) + ".png")) 150 | 151 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 152 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 153 | 154 | 155 | def interpolate_all(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, deform): 156 | render_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "renders") 157 | depth_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "depth") 158 | 159 | makedirs(render_path, exist_ok=True) 160 | makedirs(depth_path, exist_ok=True) 161 | 162 | frame = 150 163 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, frame + 1)[:-1]], 164 | 0) 165 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 166 | 167 | idx = torch.randint(0, len(views), (1,)).item() 168 | view = views[idx] # Choose a specific time for rendering 169 | 170 | renderings = [] 171 | for i, pose in enumerate(tqdm(render_poses, desc="Rendering progress")): 172 | fid = torch.Tensor([i / (frame - 1)]).cuda() 173 | 174 | matrix = np.linalg.inv(np.array(pose)) 175 | R = -np.transpose(matrix[:3, :3]) 176 | R[:, 0] = -R[:, 0] 177 | T = -matrix[:3, 3] 178 | 179 | view.reset_extrinsic(R, T) 180 | 181 | xyz = gaussians.get_xyz 182 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 183 | d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input) 184 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 185 | rendering = results["render"] 186 | renderings.append(to8b(rendering.cpu().numpy())) 187 | depth = results["depth"] 188 | depth = depth / (depth.max() + 1e-5) 189 | 190 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(i) + ".png")) 191 | torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(i) + ".png")) 192 | 193 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 194 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=30, quality=8) 195 | 196 | 197 | def interpolate_poses(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, timer): 198 | render_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "renders") 199 | depth_path = os.path.join(model_path, name, "interpolate_pose_{}".format(iteration), "depth") 200 | 201 | makedirs(render_path, exist_ok=True) 202 | makedirs(depth_path, exist_ok=True) 203 | # makedirs(acc_path, exist_ok=True) 204 | frame = 520 205 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 206 | 207 | idx = torch.randint(0, len(views), (1,)).item() 208 | view_begin = views[0] # Choose a specific time for rendering 209 | view_end = views[-1] 210 | view = views[idx] 211 | 212 | R_begin = view_begin.R 213 | R_end = view_end.R 214 | t_begin = view_begin.T 215 | t_end = view_end.T 216 | 217 | renderings = [] 218 | for i in tqdm(range(frame), desc="Rendering progress"): 219 | fid = view.fid 220 | 221 | ratio = i / (frame - 1) 222 | 223 | R_cur = (1 - ratio) * R_begin + ratio * R_end 224 | T_cur = (1 - ratio) * t_begin + ratio * t_end 225 | 226 | view.reset_extrinsic(R_cur, T_cur) 227 | 228 | xyz = gaussians.get_xyz 229 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 230 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 231 | 232 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 
233 | rendering = results["render"] 234 | renderings.append(to8b(rendering.cpu().numpy())) 235 | depth = results["depth"] 236 | depth = depth / (depth.max() + 1e-5) 237 | 238 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 239 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 240 | 241 | 242 | def interpolate_view_original(model_path, load2gpt_on_the_fly, name, iteration, views, gaussians, pipeline, background, 243 | timer): 244 | render_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "renders") 245 | depth_path = os.path.join(model_path, name, "interpolate_hyper_view_{}".format(iteration), "depth") 246 | # acc_path = os.path.join(model_path, name, "interpolate_all_{}".format(iteration), "acc") 247 | 248 | makedirs(render_path, exist_ok=True) 249 | makedirs(depth_path, exist_ok=True) 250 | 251 | frame = 1000 252 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 253 | 254 | R = [] 255 | T = [] 256 | for view in views: 257 | R.append(view.R) 258 | T.append(view.T) 259 | 260 | view = views[0] 261 | renderings = [] 262 | for i in tqdm(range(frame), desc="Rendering progress"): 263 | fid = torch.Tensor([i / (frame - 1)]).cuda() 264 | 265 | query_idx = i / frame * len(views) 266 | begin_idx = int(np.floor(query_idx)) 267 | end_idx = int(np.ceil(query_idx)) 268 | if end_idx == len(views): 269 | break 270 | view_begin = views[begin_idx] 271 | view_end = views[end_idx] 272 | R_begin = view_begin.R 273 | R_end = view_end.R 274 | t_begin = view_begin.T 275 | t_end = view_end.T 276 | 277 | ratio = query_idx - begin_idx 278 | 279 | R_cur = (1 - ratio) * R_begin + ratio * R_end 280 | T_cur = (1 - ratio) * t_begin + ratio * t_end 281 | 282 | view.reset_extrinsic(R_cur, T_cur) 283 | 284 | xyz = gaussians.get_xyz 285 | time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1) 286 | d_xyz, d_rotation, d_scaling = timer.step(xyz.detach(), time_input) 287 | 288 | results = render(view, gaussians, pipeline, background, d_xyz, d_rotation, d_scaling) 289 | rendering = results["render"] 290 | renderings.append(to8b(rendering.cpu().numpy())) 291 | depth = results["depth"] 292 | depth = depth / (depth.max() + 1e-5) 293 | 294 | renderings = np.stack(renderings, 0).transpose(0, 2, 3, 1) 295 | imageio.mimwrite(os.path.join(render_path, 'video.mp4'), renderings, fps=60, quality=8) 296 | 297 | 298 | def render_sets(dataset: ModelParams, iteration: int, pipeline: PipelineParams, skip_train: bool, skip_test: bool, 299 | mode: str): 300 | with torch.no_grad(): 301 | gaussians = GaussianModel(dataset.sh_degree) 302 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 303 | deform = DeformModel() 304 | deform.load_weights(dataset.model_path) 305 | 306 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 307 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 308 | 309 | if mode == "render": 310 | render_func = render_set 311 | elif mode == "time": 312 | render_func = interpolate_time 313 | elif mode == "view": 314 | render_func = interpolate_view 315 | elif mode == "pose": 316 | render_func = interpolate_poses 317 | elif mode == "original": 318 | render_func = interpolate_view_original 319 | else: 320 | render_func = interpolate_all 321 | 322 | if not skip_train: 323 | render_func(dataset.model_path, dataset.load2gpu_on_the_fly, "train", scene.loaded_iter, 324 | scene.getTrainCameras(), gaussians, pipeline, 325 | background, deform) 326 | 327 | if not skip_test: 328 | 
render_func(dataset.model_path, dataset.load2gpu_on_the_fly, "test", scene.loaded_iter, 329 | scene.getTestCameras(), gaussians, pipeline, 330 | background, deform) 331 | 332 | 333 | if __name__ == "__main__": 334 | # Set up command line argument parser 335 | parser = ArgumentParser(description="Testing script parameters") 336 | model = ModelParams(parser, sentinel=True) 337 | pipeline = PipelineParams(parser) 338 | parser.add_argument("--iteration", default=-1, type=int) 339 | parser.add_argument("--skip_train", action="store_true") 340 | parser.add_argument("--skip_test", action="store_true") 341 | parser.add_argument("--quiet", action="store_true") 342 | parser.add_argument("--mode", default='render', choices=['render', 'time', 'view', 'all', 'pose', 'original']) 343 | args = get_combined_args(parser) 344 | print("Rendering " + args.model_path) 345 | 346 | # Initialize system state (RNG) 347 | safe_state(args.quiet) 348 | 349 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.mode) 350 | -------------------------------------------------------------------------------- /scene/gaussian_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import numpy as np 14 | from utils.general_utils import inverse_sigmoid, np_inverse_sigmoid, get_expon_lr_func, build_rotation 15 | from torch import nn 16 | import os 17 | from utils.system_utils import mkdir_p 18 | from plyfile import PlyData, PlyElement 19 | from utils.sh_utils import RGB2SH 20 | from simple_knn._C import distCUDA2 21 | from utils.graphics_utils import BasicPointCloud 22 | from utils.general_utils import strip_symmetric, build_scaling_rotation 23 | import tinycudann as tcnn 24 | 25 | class GaussianModel: 26 | def __init__(self, sh_degree: int): 27 | 28 | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): 29 | L = build_scaling_rotation(scaling_modifier * scaling, rotation) 30 | actual_covariance = L @ L.transpose(1, 2) 31 | symm = strip_symmetric(actual_covariance) 32 | return symm 33 | 34 | self.active_sh_degree = 0 35 | self.max_sh_degree = sh_degree 36 | 37 | self._xyz = torch.empty(0) 38 | self._features_dc = torch.empty(0) 39 | self._features_rest = torch.empty(0) 40 | self._scaling = torch.empty(0) 41 | self._rotation = torch.empty(0) 42 | self._opacity = torch.empty(0) 43 | self.max_radii2D = torch.empty(0) 44 | self.xyz_gradient_accum = torch.empty(0) 45 | 46 | self.optimizer = None 47 | 48 | self.scaling_activation = torch.exp 49 | self.scaling_inverse_activation = torch.log 50 | self.covariance_activation = build_covariance_from_scaling_rotation 51 | self.opacity_activation = torch.sigmoid 52 | self.inverse_opacity_activation = inverse_sigmoid 53 | self.rotation_activation = torch.nn.functional.normalize 54 | 55 | @property 56 | def get_scaling(self): 57 | return self.scaling_activation(self._scaling) 58 | 59 | @property 60 | def get_rotation(self): 61 | return self.rotation_activation(self._rotation) 62 | 63 | @property 64 | def get_xyz(self): 65 | return self._xyz 66 | 67 | @property 68 | def get_features(self): 69 | features_dc = self._features_dc 70 
| features_rest = self._features_rest 71 | return torch.cat((features_dc, features_rest), dim=1) 72 | 73 | @property 74 | def get_opacity(self): 75 | return self.opacity_activation(self._opacity) 76 | 77 | def get_covariance(self, scaling_modifier=1): 78 | return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) 79 | 80 | def get_covariance_obs(self, d_rotation, d_scaling, scaling_modifier=1): 81 | return self.covariance_activation(self.get_scaling + d_scaling, scaling_modifier, self._rotation + d_rotation) 82 | 83 | def oneupSHdegree(self): 84 | if self.active_sh_degree < self.max_sh_degree: 85 | self.active_sh_degree += 1 86 | 87 | def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): 88 | # self.spatial_lr_scale = 5 # 89 | self.spatial_lr_scale = spatial_lr_scale 90 | fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() 91 | fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) 92 | features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() 93 | features[:, :3, 0] = fused_color 94 | features[:, 3:, 1:] = 0.0 95 | 96 | print("Number of points at initialisation : ", fused_point_cloud.shape[0]) 97 | 98 | dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) 99 | scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) 100 | rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") 101 | rots[:, 0] = 1 102 | 103 | opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) 104 | 105 | self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) 106 | self._features_dc = nn.Parameter( 107 | features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)) 108 | self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)) 109 | self._scaling = nn.Parameter(scales.requires_grad_(True)) 110 | self._rotation = nn.Parameter(rots.requires_grad_(True)) 111 | self._opacity = nn.Parameter(opacities.requires_grad_(True)) 112 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 113 | 114 | def training_setup(self, training_args): 115 | self.percent_dense = training_args.percent_dense 116 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 117 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 118 | 119 | l = [ 120 | {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, 121 | {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, 122 | {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, 123 | {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, 124 | {'params': [self._scaling], 'lr': training_args.scaling_lr * self.spatial_lr_scale, "name": "scaling"}, 125 | 126 | {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} 127 | ] 128 | self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) 129 | self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale, 130 | lr_final=training_args.position_lr_final * self.spatial_lr_scale, 131 | lr_delay_mult=training_args.position_lr_delay_mult, 132 | max_steps=training_args.position_lr_max_steps) 133 | 134 | def update_learning_rate(self, iteration): 135 | ''' Learning rate scheduling per step ''' 
136 | for param_group in self.optimizer.param_groups: 137 | if param_group["name"] == "xyz": 138 | lr = self.xyz_scheduler_args(iteration) 139 | param_group['lr'] = lr 140 | return lr 141 | 142 | def construct_list_of_attributes(self): 143 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 144 | # All channels except the 3 DC 145 | for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): 146 | l.append('f_dc_{}'.format(i)) 147 | for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): 148 | l.append('f_rest_{}'.format(i)) 149 | l.append('opacity') 150 | for i in range(self._scaling.shape[1]): 151 | l.append('scale_{}'.format(i)) 152 | for i in range(self._rotation.shape[1]): 153 | l.append('rot_{}'.format(i)) 154 | return l 155 | 156 | def save_ply(self, path, iteration): 157 | mkdir_p(os.path.dirname(path)) 158 | 159 | xyz = self._xyz.detach().cpu().numpy() 160 | normals = np.zeros_like(xyz) 161 | f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 162 | f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 163 | opacities = self._opacity.detach().cpu().numpy() 164 | scale = self._scaling.detach().cpu().numpy() 165 | rotation = self._rotation.detach().cpu().numpy() 166 | 167 | dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] 168 | 169 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 170 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 171 | elements[:] = list(map(tuple, attributes)) 172 | el = PlyElement.describe(elements, 'vertex') 173 | PlyData([el]).write(path) 174 | 175 | 176 | def reset_opacity(self): 177 | opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)) 178 | optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") 179 | self._opacity = optimizable_tensors["opacity"] 180 | 181 | def load_ply(self, path, og_number_points=-1): 182 | self.og_number_points = og_number_points 183 | plydata = PlyData.read(path) 184 | 185 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 186 | np.asarray(plydata.elements[0]["y"]), 187 | np.asarray(plydata.elements[0]["z"])), axis=1) 188 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 189 | 190 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 191 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 192 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 193 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 194 | 195 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 196 | assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 197 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 198 | for idx, attr_name in enumerate(extra_f_names): 199 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 200 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 201 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) 202 | 203 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 204 | scales = np.zeros((xyz.shape[0], len(scale_names))) 205 | for idx, attr_name in enumerate(scale_names): 206 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 207 | 208 | rot_names = [p.name for p in plydata.elements[0].properties if 
p.name.startswith("rot")] 209 | rots = np.zeros((xyz.shape[0], len(rot_names))) 210 | for idx, attr_name in enumerate(rot_names): 211 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 212 | 213 | self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) 214 | self._features_dc = nn.Parameter( 215 | torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 216 | True)) 217 | self._features_rest = nn.Parameter( 218 | torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_( 219 | True)) 220 | self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) 221 | self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) 222 | self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) 223 | 224 | self.active_sh_degree = self.max_sh_degree 225 | 226 | def replace_tensor_to_optimizer(self, tensor, name): 227 | optimizable_tensors = {} 228 | for group in self.optimizer.param_groups: 229 | if group["name"] == name: 230 | stored_state = self.optimizer.state.get(group['params'][0], None) 231 | stored_state["exp_avg"] = torch.zeros_like(tensor) 232 | stored_state["exp_avg_sq"] = torch.zeros_like(tensor) 233 | 234 | del self.optimizer.state[group['params'][0]] 235 | group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) 236 | self.optimizer.state[group['params'][0]] = stored_state 237 | 238 | optimizable_tensors[group["name"]] = group["params"][0] 239 | return optimizable_tensors 240 | 241 | def _prune_optimizer(self, mask): 242 | optimizable_tensors = {} 243 | for group in self.optimizer.param_groups: 244 | stored_state = self.optimizer.state.get(group['params'][0], None) 245 | if stored_state is not None: 246 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 247 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 248 | 249 | del self.optimizer.state[group['params'][0]] 250 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 251 | self.optimizer.state[group['params'][0]] = stored_state 252 | 253 | optimizable_tensors[group["name"]] = group["params"][0] 254 | else: 255 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 256 | optimizable_tensors[group["name"]] = group["params"][0] 257 | return optimizable_tensors 258 | 259 | def prune_points(self, mask): 260 | valid_points_mask = ~mask 261 | optimizable_tensors = self._prune_optimizer(valid_points_mask) 262 | 263 | self._xyz = optimizable_tensors["xyz"] 264 | self._features_dc = optimizable_tensors["f_dc"] 265 | self._features_rest = optimizable_tensors["f_rest"] 266 | self._opacity = optimizable_tensors["opacity"] 267 | self._scaling = optimizable_tensors["scaling"] 268 | self._rotation = optimizable_tensors["rotation"] 269 | 270 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 271 | 272 | self.denom = self.denom[valid_points_mask] 273 | self.max_radii2D = self.max_radii2D[valid_points_mask] 274 | 275 | def cat_tensors_to_optimizer(self, tensors_dict): 276 | optimizable_tensors = {} 277 | for group in self.optimizer.param_groups: 278 | assert len(group["params"]) == 1 279 | extension_tensor = tensors_dict[group["name"]] 280 | stored_state = self.optimizer.state.get(group['params'][0], None) 281 | if stored_state is not None: 282 | 283 | 
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), 284 | dim=0) 285 | stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), 286 | dim=0) 287 | 288 | del self.optimizer.state[group['params'][0]] 289 | group["params"][0] = nn.Parameter( 290 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 291 | self.optimizer.state[group['params'][0]] = stored_state 292 | 293 | optimizable_tensors[group["name"]] = group["params"][0] 294 | else: 295 | group["params"][0] = nn.Parameter( 296 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) 297 | optimizable_tensors[group["name"]] = group["params"][0] 298 | 299 | return optimizable_tensors 300 | 301 | def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 302 | new_rotation): 303 | d = {"xyz": new_xyz, 304 | "f_dc": new_features_dc, 305 | "f_rest": new_features_rest, 306 | "opacity": new_opacities, 307 | "scaling": new_scaling, 308 | "rotation": new_rotation} 309 | 310 | optimizable_tensors = self.cat_tensors_to_optimizer(d) 311 | self._xyz = optimizable_tensors["xyz"] 312 | self._features_dc = optimizable_tensors["f_dc"] 313 | self._features_rest = optimizable_tensors["f_rest"] 314 | self._opacity = optimizable_tensors["opacity"] 315 | self._scaling = optimizable_tensors["scaling"] 316 | self._rotation = optimizable_tensors["rotation"] 317 | 318 | self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 319 | self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") 320 | self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") 321 | 322 | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): 323 | n_init_points = self.get_xyz.shape[0] 324 | # Extract points that satisfy the gradient condition 325 | padded_grad = torch.zeros((n_init_points), device="cuda") 326 | padded_grad[:grads.shape[0]] = grads.squeeze() 327 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 328 | selected_pts_mask = torch.logical_and(selected_pts_mask, 329 | torch.max(self.get_scaling, 330 | dim=1).values > self.percent_dense * scene_extent) 331 | 332 | stds = self.get_scaling[selected_pts_mask].repeat(N, 1) 333 | means = torch.zeros((stds.size(0), 3), device="cuda") 334 | samples = torch.normal(mean=means, std=stds) 335 | rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) 336 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) 337 | new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)) 338 | new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) 339 | new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1) 340 | new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) 341 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) 342 | 343 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) 344 | 345 | prune_filter = torch.cat( 346 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) 347 | self.prune_points(prune_filter) 348 | 349 | def densify_and_clone(self, grads, grad_threshold, scene_extent): 350 | # Extract points that satisfy the gradient condition 351 | selected_pts_mask = torch.where(torch.norm(grads, 
dim=-1) >= grad_threshold, True, False) 352 | selected_pts_mask = torch.logical_and(selected_pts_mask, 353 | torch.max(self.get_scaling, 354 | dim=1).values <= self.percent_dense * scene_extent) 355 | 356 | new_xyz = self._xyz[selected_pts_mask] 357 | new_features_dc = self._features_dc[selected_pts_mask] 358 | new_features_rest = self._features_rest[selected_pts_mask] 359 | new_opacities = self._opacity[selected_pts_mask] 360 | new_scaling = self._scaling[selected_pts_mask] 361 | new_rotation = self._rotation[selected_pts_mask] 362 | 363 | self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, 364 | new_rotation) 365 | 366 | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): 367 | grads = self.xyz_gradient_accum / self.denom 368 | grads[grads.isnan()] = 0.0 369 | 370 | self.densify_and_clone(grads, max_grad, extent) 371 | self.densify_and_split(grads, max_grad, extent) 372 | 373 | prune_mask = (self.get_opacity < min_opacity).squeeze() 374 | if max_screen_size: 375 | big_points_vs = self.max_radii2D > max_screen_size 376 | big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent 377 | prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) 378 | self.prune_points(prune_mask) 379 | 380 | torch.cuda.empty_cache() 381 | 382 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 383 | self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, 384 | keepdim=True) 385 | self.denom[update_filter] += 1 --------------------------------------------------------------------------------
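The densification interface at the end of gaussian_model.py splits the adaptive-density step into two branches: densify_and_clone duplicates high-gradient Gaussians whose largest scale is still small, while densify_and_split resamples the large ones into N tighter Gaussians before pruning the originals. A minimal, self-contained sketch of that clone-versus-split decision follows; the threshold values are toy numbers chosen for illustration, not the repository's defaults.

import torch

# Accumulated view-space gradient norms (xyz_gradient_accum / denom) and the
# largest scaling per Gaussian, as consumed by densify_and_prune above.
grads = torch.tensor([1e-5, 5e-4, 8e-4, 2e-3])
max_scale = torch.tensor([0.01, 0.02, 0.30, 0.40])

# Toy hyperparameters (assumed values, not taken from this repository).
grad_threshold, percent_dense, scene_extent = 2e-4, 0.01, 5.0

high_grad = grads >= grad_threshold
small = max_scale <= percent_dense * scene_extent

clone_mask = high_grad & small      # handled by densify_and_clone
split_mask = high_grad & ~small     # handled by densify_and_split
print(clone_mask, split_mask)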
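Both train.py and render.py query the deformation network with the same pattern: the normalized frame id of a view is broadcast to one row per Gaussian and passed to deform.step together with the detached canonical positions. The snippet below only reproduces that broadcasting step with random stand-in data so the expected tensor shapes are explicit; deform.step itself is the DeformModel from scene/deform_model.py and is not re-implemented here.

import torch

num_gaussians = 4
xyz = torch.rand(num_gaussians, 3)        # stand-in for gaussians.get_xyz
fid = torch.tensor([0.5])                 # normalized frame id, shape (1,)

time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)   # shape (N, 1)
assert time_input.shape == (num_gaussians, 1)

# deform.step(xyz.detach(), time_input) returns (d_xyz, d_rotation, d_scaling),
# which render(...) then applies on top of the canonical Gaussian parameters.
print(xyz.shape, time_input.shape)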
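interpolate_poses and interpolate_view_original generate fly-through videos by blending the extrinsics of two key views with a per-frame ratio; rotation matrices are mixed linearly, which is an approximation rather than a proper rotation interpolation, but it matches what those functions do. The helper name blend_extrinsics below is introduced only for this sketch and the example poses are toy values.

import numpy as np

def blend_extrinsics(R_begin, t_begin, R_end, t_end, ratio):
    # Linear blend of two camera poses, as in interpolate_poses above.
    R_cur = (1 - ratio) * R_begin + ratio * R_end
    t_cur = (1 - ratio) * t_begin + ratio * t_end
    return R_cur, t_cur

R0, t0 = np.eye(3), np.zeros(3)                                   # toy begin pose
R1, t1 = np.diag([1.0, -1.0, -1.0]), np.array([0.0, 0.0, 1.0])    # toy end pose

frame = 5
for i in range(frame):
    ratio = i / (frame - 1)
    R_cur, t_cur = blend_extrinsics(R0, t0, R1, t1, ratio)
    # render.py would now call view.reset_extrinsic(R_cur, t_cur) and render the frame.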