├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets └── teaser.jpg ├── compress ├── compress_ckpt_2_image.py ├── compress_ckpt_2_image_precompute.py ├── compress_image_2_video.py └── decompress_video_2_ckpt.py ├── config └── config_hash.json ├── convert.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── preprocess ├── calibration.py ├── colmap2k.py ├── colmap_helper.py ├── hifi4g_process.py └── undistortion.py ├── prune_gaussian.py ├── requirements.txt ├── run.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py ├── global_rt_field.py ├── global_rtc_field.py ├── global_t_field.py └── global_t_field_wo_hash.py ├── train.py ├── train_dynamic.py ├── train_prune.py ├── train_sequence.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | exp 10 | ablation 11 | ablation_image 12 | datasets 13 | external -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Penghao Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [SIGGRAPH Asia 2024] V^3: Viewing Volumetric Videos on Mobiles via Streamable 2D Dynamic Gaussians 2 | Official implementation for _V^3: Viewing Volumetric Videos on Mobiles via Streamable 2D Dynamic Gaussians_. 3 | 4 | **[Penghao Wang*](https://authoritywang.github.io/), [Zhirui Zhang*](https://github.com/zhangzhr4), [Liao Wang*](https://aoliao12138.github.io/), [Kaixin Yao](https://yaokxx.github.io/), [Siyuan Xie](https://simonxie2004.github.io/about/), [Jingyi Yu†](http://www.yu-jingyi.com/cv/), [Minye Wu†](https://wuminye.github.io/), [Lan Xu†](https://www.xu-lan.com/)** 5 | 6 | **SIGGRAPH Asia 2024 (ACM Transactions on Graphics)** 7 | 8 | | [Webpage](https://authoritywang.github.io/v3/) | [Paper](https://arxiv.org/pdf/2409.13648) | [Video](https://youtu.be/Z5La9AporRU?si=P95fDRxVYhXZEzYT) | [Training Code](https://github.com/AuthorityWang/VideoGS) | [SIBR Viewer Code](https://github.com/AuthorityWang/VideoGS_SIBR_viewers) | [IOS Viewer Code](https://github.com/zhangzhr4/VideoGS_IOS_viewers) |
9 | ![Teaser image](assets/teaser.jpg)
10 | 
11 | 
12 | <details>
13 | <summary>BibTeX</summary>
14 | <pre><code>@article{wang2024v,
15 |   title={V\^{} 3: Viewing Volumetric Videos on Mobiles via Streamable 2D Dynamic Gaussians},
16 |   author={Wang, Penghao and Zhang, Zhirui and Wang, Liao and Yao, Kaixin and Xie, Siyuan and Yu, Jingyi and Wu, Minye and Xu, Lan},
17 |   journal={ACM Transactions on Graphics (TOG)},
18 |   volume={43},
19 |   number={6},
20 |   pages={1--13},
21 |   year={2024},
22 |   publisher={ACM New York, NY, USA}
23 | }</code></pre>
24 | 
25 | </details>
26 | 
27 | ## Install
28 | Create a new environment
29 | ```
30 | conda create -n videogs python=3.9
31 | conda activate videogs
32 | ```
33 | First install CUDA and PyTorch; our code is evaluated on [CUDA 11.6](https://developer.nvidia.com/cuda-11-6-2-download-archive) and [PyTorch 1.13.1+cu116](https://pytorch.org/get-started/previous-versions/#v1131). Then install the following dependencies:
34 | ```
35 | pip install -r requirements.txt
36 | pip install submodules/diff-gaussian-rasterization
37 | pip install submodules/simple-knn
38 | ```
39 | 
40 | 
41 | 
42 | Install the modified [NeuS2](https://vcai.mpi-inf.mpg.de/projects/NeuS2/) for key frame point cloud generation: clone it into the `external` folder and build it.
43 | ```
44 | cd external
45 | git clone --recursive https://github.com/AuthorityWang/NeuS2_K.git
46 | cd NeuS2_K
47 | cmake . -B build
48 | cmake --build build --config RelWithDebInfo -j
49 | ```
50 | 
51 | ## Dataset Preprocess
52 | 
53 | ### Download Dataset
54 | 
55 | Our code is mainly evaluated on multi-view human-centric datasets, including the [ReRF](https://github.com/aoliao12138/ReRF_Dataset), [HiFi4G](https://github.com/moqiyinlun/HiFi4G_Dataset), and [HumanRF](https://synthesiaresearch.github.io/humanrf/#dataset) datasets. Please download the data you need.
56 | 
57 | ### Format
58 | Our dataset format is structured as follows:
59 | ```
60 | datasets
61 | | |---xxx (data name)
62 | | | |---%d
63 | | | | |---images
64 | | | | | |---%d.png
65 | | | | |---transforms.json
66 | ```
67 | The transforms.json follows the NGP calibration format:
68 | ```
69 | {
70 |     "frames": [
71 |         {
72 |             "file_path": "xxx/xxx.png" (file path to the image),
73 |             "transform_matrix": [
74 |                 xxx (extrinsic)
75 |             ],
76 |             "K": [
77 |                 xxx (intrinsic; note it can differ for each view)
78 |             ],
79 |             "fl_x": xxx (focal length x),
80 |             "fl_y": xxx (focal length y),
81 |             "cx": xxx (cx),
82 |             "cy": xxx (cy),
83 |             "w": xxx (image width),
84 |             "h": xxx (image height)
85 |         },
86 |         {
87 |             ...
88 |         }
89 |     ],
90 |     "aabb_scale": xxx (aabb scale for NeuS2),
91 |     "white_transparent": true (if the background is white)
92 | }
93 | ```
94 | 
95 | ### HiFi4G Dataset
96 | 
97 | The dataset is structured as follows:
98 | ```
99 | datasets
100 | |---HiFi4G
101 | | |---xxx (data name)
102 | | | |---image_undistortion_white
103 | | | | |---%d - The frame number, starts from 0.
104 | | | | | |---%d.png - Multi-view images, starts from 0.
105 | | | |---colmap/sparse/0 - Camera extrinsics and intrinsics in Gaussian Splatting format.
106 | ```
107 | Then you need to restructure the dataset and convert the COLMAP calibration to the NGP-format transforms.json; simply run the following command (a small sketch for inspecting the result follows the argument list below):
108 | ```
109 | cd preprocess
110 | python hifi4g_process.py --input xxx --output xxx
111 | ```
112 | 
113 | <details>
114 | <summary>Command Line Arguments for hifi4g_process.py</summary>
115 | 
116 | #### --input
117 | Input folder of the original HiFi4G dataset
118 | #### --output
119 | Output folder for the processed HiFi4G dataset
120 | #### --move
121 | Whether to move the images to the output folder instead of copying them. True to move, False to copy.
122 | 
123 | 
124 | </details>
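As a quick sanity check of the conversion, the sketch below (not part of the repository; the paths are placeholders) loads one processed frame's `transforms.json` and prints the per-view fields described in the Format section above:

```python
import json

# Hypothetical example: inspect one processed frame's transforms.json
# (field names follow the Format section above).
with open("datasets/HiFi4G/xxx/0/transforms.json") as f:
    meta = json.load(f)

print("aabb_scale:", meta.get("aabb_scale"))
print("white_transparent:", meta.get("white_transparent"))
for view in meta["frames"]:
    # Each view can carry its own intrinsics, so K / fl_x / fl_y / cx / cy may differ per camera.
    print(view["file_path"], view["w"], view["h"], view["fl_x"], view["fl_y"], view["cx"], view["cy"])
```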
125 | 
126 | The processed dataset is structured as follows:
127 | ```
128 | datasets
129 | |---HiFi4G
130 | | |---xxx (data name)
131 | | | |---%d
132 | | | | |---images
133 | | | | | |---%d.png
134 | | | | |---transforms.json
135 | ```
136 | 
137 | ### ReRF dataset
138 | 
139 | To process the ReRF dataset, you need to re-calibrate, undistort the images, and then convert them to our format.
140 | 
141 | #### Calibration
142 | 
143 | Install [COLMAP](https://colmap.github.io/install.html) for calibration and undistortion. However, as images without background are hard to calibrate, we provide a COLMAP calibration for the KPOP sequence of the ReRF dataset. You can download it from [this link](https://5xmbb1-my.sharepoint.com/:f:/g/personal/auwang_5xmbb1_onmicrosoft_com/Ek6nsqEIzFxAi7j6H2FKv8UB6lNV0_h_JcLuv7JwG7ZLTg?e=SclyNr). If you need calibrations for other ReRF sequences, please contact us by email at [wangph1@shanghaitech.edu.cn](wangph1@shanghaitech.edu.cn).
144 | 
145 | #### Undistortion
146 | 
147 | With COLMAP installed and the calibration downloaded, you can undistort the other frames with the command
148 | ```
149 | cd preprocess
150 | python undistortion.py --input xxx --output xxx --calib xxx(the path to colmap calibration) --start xxx(start frame) --end xxx(end frame)
151 | ```
152 | 
153 | Then, following the code in undistortion.py, undistort the calibration and use colmap2k.py to generate the transforms.json file.
154 | 
155 | Finally, the processed dataset is structured as follows:
156 | ```
157 | datasets
158 | |---ReRF
159 | | |---xxx (data name)
160 | | | |---%d
161 | | | | |---images (undistorted images)
162 | | | | | |---%d.png
163 | | | | |---transforms.json
164 | ```
165 | 
166 | ## Train
167 | 
168 | For processed data, launch training with `train_sequence.py`:
169 | ```
170 | python train_sequence.py --start 0 --end 200 --cuda 0 --data datasets/HiFi4G/0923dancer3 --output output/0923dancer3 --sh 0 --interval 1 --group_size 20 --resolution 2
171 | ```
172 | 
173 | <details>
174 | <summary>Command Line Arguments for train_sequence.py</summary>
175 | 
176 | #### --start
177 | The frame id to start training
178 | #### --end
179 | The frame id to end training
180 | #### --cuda
181 | The CUDA device for training
182 | #### --data
183 | The path to the dataset; note that this should be the folder containing the frames from start to end
184 | #### --output
185 | The output path for the trained frames
186 | #### --sh
187 | Order of spherical harmonics to be used. ```0``` by default.
188 | #### --interval
189 | The interval between frames. For example, if set to 2, the training frames will be 0, 2, 4, 6, ...
190 | #### --group_size
191 | The number of frames to be trained in a group (see the grouping sketch after this argument list)
192 | #### --resolution
193 | Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
194 | 
195 | 
196 | 
197 | </details>
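As a rough illustration of how `--start`, `--end`, `--interval`, and `--group_size` partition a sequence (a minimal sketch based on the argument descriptions above and the grouping convention used by the compress scripts; the helper itself is not part of the repository):

```python
# Hypothetical helper: list the frame ids covered by each training group.
def group_frames(start, end, interval, group_size):
    groups = []
    for g in range((end - start) // group_size):
        g_start = start + g * group_size
        g_end = min(g_start + group_size, end)
        groups.append(list(range(g_start, g_end, interval)))
    return groups

# With the example command above (--start 0 --end 200 --interval 1 --group_size 20),
# this yields 10 groups of 20 frames: 0-19, 20-39, ..., 180-199.
for gid, frames in enumerate(group_frames(0, 200, 1, 20)):
    print(f"group {gid}: frames {frames[0]}..{frames[-1]} ({len(frames)} frames)")
```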
198 | 
199 | After training, the checkpoints in the output folder are structured as follows:
200 | ```
201 | output
202 | |---0923dancer3
203 | | |---checkpoint
204 | | | |---%d (each frame ckpt folder)
205 | | | |---record (records the config and training files)
206 | | |---neus2_output
207 | ```
208 | 
209 | ## Compress
210 | 
211 | After getting the Gaussian point clouds, we can compress them with the following command:
212 | ```
213 | python compress_ckpt_2_image_precompute.py --frame_start 100 --frame_end 140 --group_size 20 --interval 1 --ply_path ~/workspace/output/v3/0923dancer3/checkpoint/ --output_folder ~/workspace/output/v3/0923dancer3/feature_image --sh_degree 0
214 | ```
215 | The frames trained are [100, 140), i.e., 40 frames.
216 | The output structure will be:
217 | ```
218 | output
219 | |---0923dancer3
220 | | |---checkpoint
221 | | |---feature_image
222 | | | |---group%d (each group's images)
223 | | | |---min_max.json (stores the min/max values for each frame)
224 | | | |---viewer_min_max.json (same as min_max.json, different structure)
225 | | | |---group_info.json (stores each group's frame indices)
226 | | |---neus2_output
227 | ```
228 | 
229 | Then compress the images into videos with the following command:
230 | ```
231 | python compress_image_2_video.py --frame_start 100 --frame_end 140 --group_size 20 --output_path ~/workspace/output/v3/0923dancer3 --qp 25
232 | ```
233 | The qp value is the compression parameter; lower values give higher quality but larger file size.
234 | 
235 | The output structure will be:
236 | ```
237 | output
238 | |---0923dancer3
239 | | |---checkpoint
240 | | |---feature_image
241 | | |---feature_video
242 | | | |---group%d (each group's videos)
243 | | | | |---%d.mp4 (each attribute's video)
244 | | | |---viewer_min_max.json (stores each frame's min/max info)
245 | | | |---group_info.json (stores each group's frame indices)
246 | | |---neus2_output
247 | ```
248 | 
249 | Note that `compress_image_2_video.py` needs to be executed on Linux due to the video codec.
250 | 
251 | Finally, the compressed video folder can be hosted on an nginx server and played with our [volumetric video viewer]().
252 | 
253 | ## Acknowledgement
254 | Our code is based on the original [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) implementation. We also refer to [NeuS2](https://vcai.mpi-inf.mpg.de/projects/NeuS2/) for fast key frame point cloud generation, and to [3DGStream](https://sjojok.top/3dgstream/) for the inspiration of the fast training strategy.
255 | 
256 | Thanks to [Zhehao Shen](https://github.com/moqiyinlun) for his help with dataset processing.
257 | 
258 | If you find our work useful in your research, please consider citing our paper.
259 | ```
260 | @article{wang2024v,
261 |   title={V\^{} 3: Viewing Volumetric Videos on Mobiles via Streamable 2D Dynamic Gaussians},
262 |   author={Wang, Penghao and Zhang, Zhirui and Wang, Liao and Yao, Kaixin and Xie, Siyuan and Yu, Jingyi and Wu, Minye and Xu, Lan},
263 |   journal={ACM Transactions on Graphics (TOG)},
264 |   volume={43},
265 |   number={6},
266 |   pages={1--13},
267 |   year={2024},
268 |   publisher={ACM New York, NY, USA}
269 | }
270 | ```
--------------------------------------------------------------------------------
/arguments/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2023, Inria
3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
4 | # All rights reserved.
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = True 55 | self.data_device = "cuda" 56 | self.eval = False 57 | super().__init__(parser, "Loading Parameters", sentinel) 58 | 59 | def extract(self, args): 60 | g = super().extract(args) 61 | g.source_path = os.path.abspath(g.source_path) 62 | return g 63 | 64 | class PipelineParams(ParamGroup): 65 | def __init__(self, parser): 66 | self.convert_SHs_python = False 67 | self.compute_cov3D_python = False 68 | self.debug = False 69 | super().__init__(parser, "Pipeline Parameters") 70 | 71 | class OptimizationParams(ParamGroup): 72 | def __init__(self, parser): 73 | self.iterations = 30_000 74 | self.position_lr_init = 0.00016 75 | self.position_lr_final = 0.0000016 76 | self.position_lr_delay_mult = 0.01 77 | self.position_lr_max_steps = 30_000 78 | self.feature_lr = 0.0025 79 | self.opacity_lr = 0.05 80 | self.scaling_lr = 0.005 81 | self.rotation_lr = 0.001 82 | self.percent_dense = 0.01 83 | self.lambda_dssim = 0.2 84 | self.lambda_temporal = 1e-3 85 | self.lambda_entropy = 1e-4 86 | self.densification_interval = 100 87 | self.opacity_reset_interval = 3000 88 | self.densify_from_iter = 500 89 | self.densify_until_iter = 15_000 90 | self.densify_grad_threshold = 0.0002 91 | self.random_background = False 92 | self.densify_max_percent = 0.05 93 | self.first_frame_prune_iter = 10000 94 | super().__init__(parser, "Optimization Parameters") 95 | 96 | def get_combined_args(parser : ArgumentParser): 97 | cmdlne_string = sys.argv[1:] 98 | cfgfile_string = "Namespace()" 99 | args_cmdline = parser.parse_args(cmdlne_string) 100 | 101 | try: 102 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 103 | print("Looking for config file in", cfgfilepath) 104 | with open(cfgfilepath) as cfg_file: 105 | print("Config file found: {}".format(cfgfilepath)) 106 | cfgfile_string = cfg_file.read() 107 | except TypeError: 108 | print("Config file not found at") 109 | pass 110 | args_cfgfile = eval(cfgfile_string) 111 | 112 | merged_dict = vars(args_cfgfile).copy() 113 | for k,v in 
vars(args_cmdline).items(): 114 | if v != None: 115 | merged_dict[k] = v 116 | return Namespace(**merged_dict) 117 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AuthorityWang/VideoGS/99628426f5e58c1200bfbdf98f3f8e6baf31a78e/assets/teaser.jpg -------------------------------------------------------------------------------- /compress/compress_ckpt_2_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from plyfile import PlyData 5 | import json 6 | import argparse 7 | 8 | def normalize_uint8(data): 9 | min_val = np.min(data) 10 | max_val = np.max(data) 11 | normalized = (data - min_val) / (max_val - min_val) * 255.0 12 | return normalized.astype(np.uint8), min_val, max_val 13 | 14 | def normalize_uint8_tog(data, min_val, max_val): 15 | normalized = (data - min_val) / (max_val - min_val) * 255.0 16 | return normalized.astype(np.uint8), min_val, max_val 17 | 18 | def normalize_uint16(data): 19 | min_val = np.min(data) 20 | max_val = np.max(data) 21 | normalized = (data - min_val) / (max_val - min_val) * (2 ** 16 - 1) 22 | return normalized.astype(np.uint16), min_val, max_val 23 | 24 | def get_ply_matrix(file_path): 25 | plydata = PlyData.read(file_path) 26 | num_vertices = len(plydata['vertex']) 27 | num_attributes = len(plydata['vertex'].properties) 28 | data_matrix = np.zeros((num_vertices, num_attributes), dtype=float) 29 | for i, name in enumerate(plydata['vertex'].data.dtype.names): 30 | data_matrix[:, i] = plydata['vertex'].data[name] 31 | return data_matrix 32 | 33 | def calculate_image_size(num_points): 34 | image_size = 8 35 | while image_size * image_size < num_points: 36 | image_size += 8 37 | return image_size 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--frame_start", type=int, default=95) 42 | parser.add_argument("--frame_end", type=int, default=115) 43 | parser.add_argument("--group_size", type=int, default=20) 44 | parser.add_argument("--interval", type=int, default=1) 45 | parser.add_argument("--ply_path", type=str, default="/data/new_disk5/wangph1/output/xyz_smalliter_group40/checkpoint") 46 | parser.add_argument("--output_folder", type=str, default="/data/new_disk5/wangph1/output/xyz_smalliter_group40/feature_image") 47 | parser.add_argument("--sh_degree", type=int, default=0) 48 | args = parser.parse_args() 49 | 50 | frame_start_init = args.frame_start 51 | frame_end_init = args.frame_end 52 | group_size = args.group_size 53 | interval = args.interval 54 | ply_path = args.ply_path 55 | output_folder = args.output_folder 56 | sh_degree = args.sh_degree 57 | SH_N = (sh_degree + 1) * (sh_degree + 1) 58 | sh_number = SH_N * 3 59 | 60 | if not os.path.exists(output_folder): 61 | os.makedirs(output_folder) 62 | 63 | min_max_json = {} 64 | viewer_min_max_json = {} 65 | group_info_json = {} 66 | 67 | def searchForMaxIteration(folder): 68 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 69 | return max(saved_iters) 70 | 71 | 72 | for group in range(int((frame_end_init - frame_start_init) / group_size)): 73 | 74 | frame_start = group * group_size + frame_start_init 75 | frame_end = (group + 1) * group_size - 1 + frame_start_init 76 | 77 | group_info_json[str(group)] = {} 78 | group_info_json[str(group)]['frame_index'] = [group * 
group_size, (group + 1) * group_size - 1] 79 | group_info_json[str(group)]['name_index'] = [frame_start, frame_end] 80 | 81 | output_path = os.path.join(output_folder, f"group{group}") 82 | os.makedirs(output_path, exist_ok=True) 83 | 84 | for frame in range(frame_start, frame_end + 1, interval): 85 | 86 | png_ind = (frame - frame_start ) / interval 87 | 88 | ckpt_path = os.path.join(ply_path, str(frame), "point_cloud") 89 | # search max iteration 90 | max_iter = searchForMaxIteration(ckpt_path) 91 | 92 | # data = get_ply_matrix(os.path.join(ply_path, f"point_cloud_{frame}.ply")) 93 | current_data = get_ply_matrix(os.path.join(ply_path, str(frame), "point_cloud", f"iteration_{max_iter}", f"point_cloud.ply")) 94 | 95 | num_points = current_data.shape[0] 96 | image_size = calculate_image_size(num_points=num_points) 97 | num_attributes = current_data.shape[1] 98 | 99 | min_max_json[f'{frame}_num'] = num_points 100 | viewer_min_max_json[frame] = {} 101 | viewer_min_max_json[frame]['num'] = num_points 102 | viewer_min_max_json[frame]['info'] = [] 103 | 104 | # rotation_data = current_data[:, -4:] 105 | # rotation_length = np.sqrt(np.sum(rotation_data ** 2, axis=1)) 106 | # rotation_data_normalized = rotation_data / rotation_length[:, None] 107 | # current_data[:, -4:] = rotation_data_normalized 108 | # scale_data = np.exp(current_data[:, -7:-4]) 109 | # current_data[:, -7:-4] = scale_data 110 | # opacity_data = 1 / (1 + np.exp(-current_data[:, -8])) 111 | # current_data[:, -8] = opacity_data 112 | # shs_data = current_data[:, 6:6 + sh_number].copy() 113 | # current_data[:, 6] = shs_data[:, 0] 114 | # current_data[:, 7] = shs_data[:, 1] 115 | # current_data[:, 8] = shs_data[:, 2] 116 | # # rearrange 117 | # for j in range(1, SH_N): 118 | # current_data[:, j * 3 + 0 + 6] = shs_data[:, (j - 1) + 3] 119 | # current_data[:, j * 3 + 1 + 6] = shs_data[:, (j - 1) + SH_N + 2] 120 | # current_data[:, j * 3 + 2 + 6] = shs_data[:, (j - 1) + 2 * SH_N + 1] 121 | 122 | for i in range(num_attributes): 123 | if i > 2: 124 | attribute_data, min_val, max_val = normalize_uint8(current_data[:, i]) 125 | min_max_json[f'{frame}_{i}_min'] = float(min_val) 126 | min_max_json[f'{frame}_{i}_max'] = float(max_val) 127 | viewer_min_max_json[frame]['info'].append(float(min_val)) 128 | viewer_min_max_json[frame]['info'].append(float(max_val)) 129 | attribute_data_reshaped = attribute_data.reshape(-1, 1) 130 | image = np.zeros((image_size * image_size, 1), dtype=np.uint8) 131 | image[:attribute_data_reshaped.shape[0], :] = attribute_data_reshaped 132 | image_reshaped = image.reshape((image_size, image_size)) 133 | cv2.imwrite(os.path.join(output_path, f"{frame}_{i+3}.png"), image_reshaped) 134 | else: 135 | attribute_data, min_val, max_val = normalize_uint16(current_data[:, i]) 136 | min_max_json[f'{frame}_{i}_min'] = float(min_val) 137 | min_max_json[f'{frame}_{i}_max'] = float(max_val) 138 | viewer_min_max_json[frame]['info'].append(float(min_val)) 139 | viewer_min_max_json[frame]['info'].append(float(max_val)) 140 | attribute_data_reshaped = attribute_data.reshape(-1, 1) 141 | image_odd = np.zeros((image_size * image_size, 1), dtype=np.uint8) 142 | image_even = np.zeros((image_size * image_size, 1), dtype=np.uint8) 143 | #split the uint16 into two uint8, one is all the odd bits, the other is all the even bits 144 | # for j in range(16): 145 | # if j % 2 == 0: 146 | # image_even[:attribute_data_reshaped.shape[0], :] += ((attribute_data_reshaped >> j) & 1) << (j // 2) 147 | # else: 148 | # 
image_odd[:attribute_data_reshaped.shape[0], :] += ((attribute_data_reshaped >> j) & 1) << (j // 2) 149 | 150 | image_even[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped & 0xff) 151 | image_odd[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped >> 8) 152 | 153 | image_odd_reshaped = image_odd.reshape((image_size, image_size)) 154 | image_even_reshaped = image_even.reshape((image_size, image_size)) 155 | cv2.imwrite(os.path.join(output_path, f"{frame}_{2*i}.png"), image_even_reshaped) 156 | cv2.imwrite(os.path.join(output_path, f"{frame}_{2*i+1}.png"), image_odd_reshaped) 157 | 158 | with open(os.path.join(output_folder, "min_max.json"), "w") as f: 159 | json.dump(min_max_json, f, indent=4) 160 | 161 | with open(os.path.join(output_folder, "viewer_min_max.json"), "w") as f: 162 | json.dump(viewer_min_max_json, f, indent=4) 163 | 164 | with open(os.path.join(output_folder, "group_info.json"), "w") as f: 165 | json.dump(group_info_json, f, indent=4) -------------------------------------------------------------------------------- /compress/compress_ckpt_2_image_precompute.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from plyfile import PlyData 5 | import json 6 | import argparse 7 | 8 | def normalize_uint8(data): 9 | min_val = np.min(data) 10 | max_val = np.max(data) 11 | normalized = (data - min_val) / (max_val - min_val) * 255.0 12 | return normalized.astype(np.uint8), min_val, max_val 13 | 14 | def normalize_uint8_tog(data, min_val, max_val): 15 | normalized = (data - min_val) / (max_val - min_val) * 255.0 16 | return normalized.astype(np.uint8), min_val, max_val 17 | 18 | def normalize_uint16(data): 19 | min_val = np.min(data) 20 | max_val = np.max(data) 21 | normalized = (data - min_val) / (max_val - min_val) * (2 ** 16 - 1) 22 | return normalized.astype(np.uint16), min_val, max_val 23 | 24 | def get_ply_matrix(file_path): 25 | plydata = PlyData.read(file_path) 26 | num_vertices = len(plydata['vertex']) 27 | num_attributes = len(plydata['vertex'].properties) 28 | data_matrix = np.zeros((num_vertices, num_attributes), dtype=float) 29 | for i, name in enumerate(plydata['vertex'].data.dtype.names): 30 | data_matrix[:, i] = plydata['vertex'].data[name] 31 | return data_matrix 32 | 33 | def calculate_image_size(num_points): 34 | image_size = 8 35 | while image_size * image_size < num_points: 36 | image_size += 8 37 | return image_size 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--frame_start", type=int, default=95) 42 | parser.add_argument("--frame_end", type=int, default=115) 43 | parser.add_argument("--group_size", type=int, default=20) 44 | parser.add_argument("--interval", type=int, default=1) 45 | parser.add_argument("--ply_path", type=str, default="/data/new_disk5/wangph1/output/xyz_smalliter_group40/checkpoint") 46 | parser.add_argument("--output_folder", type=str, default="/data/new_disk5/wangph1/output/xyz_smalliter_group40/feature_image") 47 | parser.add_argument("--sh_degree", type=int, default=0) 48 | args = parser.parse_args() 49 | 50 | frame_start_init = args.frame_start 51 | frame_end_init = args.frame_end 52 | group_size = args.group_size 53 | interval = args.interval 54 | ply_path = args.ply_path 55 | output_folder = args.output_folder 56 | sh_degree = args.sh_degree 57 | SH_N = (sh_degree + 1) * (sh_degree + 1) 58 | sh_number = SH_N * 3 59 | 60 | if not os.path.exists(output_folder): 61 | 
os.makedirs(output_folder) 62 | 63 | min_max_json = {} 64 | viewer_min_max_json = {} 65 | group_info_json = {} 66 | 67 | def searchForMaxIteration(folder): 68 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 69 | return max(saved_iters) 70 | 71 | 72 | for group in range(int((frame_end_init - frame_start_init) / group_size)): 73 | 74 | frame_start = group * group_size + frame_start_init 75 | frame_end = (group + 1) * group_size - 1 + frame_start_init 76 | 77 | group_info_json[str(group)] = {} 78 | group_info_json[str(group)]['frame_index'] = [group * group_size, (group + 1) * group_size - 1] 79 | group_info_json[str(group)]['name_index'] = [frame_start, frame_end] 80 | 81 | output_path = os.path.join(output_folder, f"group{group}") 82 | os.makedirs(output_path, exist_ok=True) 83 | 84 | for frame in range(frame_start, frame_end + 1, interval): 85 | 86 | png_ind = (frame - frame_start ) / interval 87 | 88 | ckpt_path = os.path.join(ply_path, str(frame), "point_cloud") 89 | # search max iteration 90 | max_iter = searchForMaxIteration(ckpt_path) 91 | 92 | # data = get_ply_matrix(os.path.join(ply_path, f"point_cloud_{frame}.ply")) 93 | current_data = get_ply_matrix(os.path.join(ply_path, str(frame), "point_cloud", f"iteration_{max_iter}", f"point_cloud.ply")) 94 | 95 | num_points = current_data.shape[0] 96 | image_size = calculate_image_size(num_points=num_points) 97 | num_attributes = current_data.shape[1] 98 | 99 | min_max_json[f'{frame}_num'] = num_points 100 | viewer_min_max_json[frame] = {} 101 | viewer_min_max_json[frame]['num'] = num_points 102 | viewer_min_max_json[frame]['info'] = [] 103 | 104 | rotation_data = current_data[:, -4:] 105 | rotation_length = np.sqrt(np.sum(rotation_data ** 2, axis=1)) 106 | rotation_data_normalized = rotation_data / rotation_length[:, None] 107 | current_data[:, -4:] = rotation_data_normalized 108 | scale_data = np.exp(current_data[:, -7:-4]) 109 | current_data[:, -7:-4] = scale_data 110 | opacity_data = 1 / (1 + np.exp(-current_data[:, -8])) 111 | current_data[:, -8] = opacity_data 112 | shs_data = current_data[:, 6:6 + sh_number].copy() 113 | current_data[:, 6] = shs_data[:, 0] 114 | current_data[:, 7] = shs_data[:, 1] 115 | current_data[:, 8] = shs_data[:, 2] 116 | # rearrange 117 | for j in range(1, SH_N): 118 | current_data[:, j * 3 + 0 + 6] = shs_data[:, (j - 1) + 3] 119 | current_data[:, j * 3 + 1 + 6] = shs_data[:, (j - 1) + SH_N + 2] 120 | current_data[:, j * 3 + 2 + 6] = shs_data[:, (j - 1) + 2 * SH_N + 1] 121 | 122 | for i in range(num_attributes): 123 | if i > 2: 124 | attribute_data, min_val, max_val = normalize_uint8(current_data[:, i]) 125 | min_max_json[f'{frame}_{i}_min'] = float(min_val) 126 | min_max_json[f'{frame}_{i}_max'] = float(max_val) 127 | viewer_min_max_json[frame]['info'].append(float(min_val)) 128 | viewer_min_max_json[frame]['info'].append(float(max_val)) 129 | attribute_data_reshaped = attribute_data.reshape(-1, 1) 130 | image = np.zeros((image_size * image_size, 1), dtype=np.uint8) 131 | image[:attribute_data_reshaped.shape[0], :] = attribute_data_reshaped 132 | image_reshaped = image.reshape((image_size, image_size)) 133 | cv2.imwrite(os.path.join(output_path, f"{frame}_{i+3}.png"), image_reshaped) 134 | else: 135 | attribute_data, min_val, max_val = normalize_uint16(current_data[:, i]) 136 | min_max_json[f'{frame}_{i}_min'] = float(min_val) 137 | min_max_json[f'{frame}_{i}_max'] = float(max_val) 138 | viewer_min_max_json[frame]['info'].append(float(min_val)) 139 | 
viewer_min_max_json[frame]['info'].append(float(max_val)) 140 | attribute_data_reshaped = attribute_data.reshape(-1, 1) 141 | image_odd = np.zeros((image_size * image_size, 1), dtype=np.uint8) 142 | image_even = np.zeros((image_size * image_size, 1), dtype=np.uint8) 143 | #split the uint16 into two uint8, one is all the odd bits, the other is all the even bits 144 | # for j in range(16): 145 | # if j % 2 == 0: 146 | # image_even[:attribute_data_reshaped.shape[0], :] += ((attribute_data_reshaped >> j) & 1) << (j // 2) 147 | # else: 148 | # image_odd[:attribute_data_reshaped.shape[0], :] += ((attribute_data_reshaped >> j) & 1) << (j // 2) 149 | 150 | image_even[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped & 0xff) 151 | image_odd[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped >> 8) 152 | 153 | image_odd_reshaped = image_odd.reshape((image_size, image_size)) 154 | image_even_reshaped = image_even.reshape((image_size, image_size)) 155 | cv2.imwrite(os.path.join(output_path, f"{frame}_{2*i}.png"), image_even_reshaped) 156 | cv2.imwrite(os.path.join(output_path, f"{frame}_{2*i+1}.png"), image_odd_reshaped) 157 | 158 | with open(os.path.join(output_folder, "min_max.json"), "w") as f: 159 | json.dump(min_max_json, f, indent=4) 160 | 161 | with open(os.path.join(output_folder, "viewer_min_max.json"), "w") as f: 162 | json.dump(viewer_min_max_json, f, indent=4) 163 | 164 | with open(os.path.join(output_folder, "group_info.json"), "w") as f: 165 | json.dump(group_info_json, f, indent=4) -------------------------------------------------------------------------------- /compress/compress_image_2_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import argparse 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--frame_start", type=int, default=95) 9 | parser.add_argument("--frame_end", type=int, default=115) 10 | parser.add_argument("--group_size", type=int, default=20) 11 | parser.add_argument("--output_path", type=str) 12 | parser.add_argument("--qp", type=int, default=25) 13 | 14 | args = parser.parse_args() 15 | 16 | group_size = args.group_size 17 | init_index = args.frame_start 18 | group_num = int((args.frame_end - args.frame_start) / group_size) 19 | 20 | # base_dir = f"/data/new_disk5/wangph1/output/xyz_smalliter_group40/feature_image" 21 | # save_dir = f"/data/new_disk5/wangph1/output/xyz_smalliter_group40/feature_video" 22 | base_dir = os.path.join(args.output_path, "feature_image") 23 | save_dir = os.path.join(args.output_path, "feature_video") 24 | os.makedirs(save_dir, exist_ok=True) 25 | # qps = [12, 14, 16, 18, 20, 23, 28] 26 | # qps = [0, 5, 10, 15, 20, 22, 25] 27 | # qps = [15] 28 | qps = [args.qp] 29 | # qps = [28, 30, 32, 35, 37, 40, 50] 30 | # qps = [43, 45, 47, 50] 31 | # qps = [43, 45, 47, 50, 0, 10, 15, 22, 26, 32, 37, 40] 32 | # qps = [15] 33 | for qp in qps: 34 | out_dir = os.path.join(save_dir, "png_all_" + str(qp)) 35 | # out_dir = save_dir 36 | # if os.path.exists(out_dir): 37 | os.makedirs(out_dir, exist_ok=True) 38 | # video_path = os.path.join(out_dir, "video") 39 | # os.makedirs(video_path, exist_ok=True) 40 | 41 | for group in range(group_num): 42 | 43 | start_index = group * group_size + init_index 44 | # end_index = (group + 1) * group_size 45 | group_path = os.path.join(out_dir, "group" + str(group)) 46 | os.makedirs(group_path, exist_ok=True) 47 | # group_video_path = 
os.path.join(group_path, "video") 48 | # os.makedirs(group_video_path, exist_ok=True) 49 | group_video_path = group_path 50 | 51 | input_group_path = os.path.join(base_dir, "group" + str(group)) 52 | 53 | for i in range(0,20): 54 | if i in [1,3,5]: 55 | os.system(f"ffmpeg -start_number {start_index} -i {input_group_path}/%d_{i}.png -vframes {group_size} -c:v libx264 -qp 0 -pix_fmt yuvj444p {group_video_path}/{i}.mp4") 56 | elif i in [9, 10, 11, 13, 14, 15, 16, 17, 18, 19] and qp > 22: 57 | os.system(f"ffmpeg -start_number {start_index} -i {input_group_path}/%d_{i}.png -vframes {group_size} -c:v libx264 -qp 22 -pix_fmt yuvj444p {group_video_path}/{i}.mp4") 58 | else: 59 | os.system(f"ffmpeg -start_number {start_index} -i {input_group_path}/%d_{i}.png -vframes {group_size} -c:v libx264 -qp {qp} -pix_fmt yuvj444p {group_video_path}/{i}.mp4") 60 | # if i < 3: 61 | # os.system(f"ffmpeg -i {video_path}/{i}.mp4 -vf format=gray16le -start_number 0 {out_dir}/%d_{i}.png") 62 | # else: 63 | # os.system(f"ffmpeg -i {group_video_path}/{i}.mp4 -vf format=gray -start_number {start_index} {group_path}/%d_{i}.png") 64 | #get all the file in the dir 65 | # file_list = os.listdir(save_dir) 66 | # for file in file_list: 67 | # #check if the name of file has 'occupancy' 68 | # if 'json' in file or 'atlas' in file or 'occupancy' in file: 69 | # os.system(f"cp {save_dir}/{file} {out_dir}/{file}") 70 | 71 | # copy the json 72 | # shutil.copy(os.path.join(base_dir, "min_max.json"), os.path.join(out_dir, "min_max.json")) 73 | shutil.copy(os.path.join(base_dir, "viewer_min_max.json"), os.path.join(out_dir, "viewer_min_max.json")) 74 | shutil.copy(os.path.join(base_dir, "group_info.json"), os.path.join(out_dir, "group_info.json")) 75 | 76 | 77 | print("finish") 78 | -------------------------------------------------------------------------------- /compress/decompress_video_2_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import plyfile 4 | import json 5 | import cv2 6 | 7 | input_path = "/data/new_disk5/wangph1/output/xyz_smalliter_group40" 8 | 9 | start = 95 10 | end = 115 11 | group_size = 20 12 | interval = 1 13 | qp = "manual" 14 | sh_degree = 0 15 | SH_N = (sh_degree + 1) * (sh_degree + 1) 16 | sh_number = SH_N * 3 17 | num_video = 20 18 | output_path = f"/data/new_disk5/wangph1/output/xyz_smalliter_group40/decompress/qp_{qp}" 19 | if not os.path.exists(output_path): 20 | os.makedirs(output_path) 21 | 22 | feature_video_path = os.path.join(input_path, "feature_video", f"png_all_{qp}") 23 | feature_image_path = os.path.join(input_path, "feature_image") 24 | min_max_path = os.path.join(feature_image_path, "min_max.json") 25 | 26 | group_idx = 0 27 | 28 | def read_video(video_path): 29 | cap = cv2.VideoCapture(video_path) 30 | frames = [] 31 | while True: 32 | ret, frame = cap.read() 33 | if not ret: 34 | break 35 | gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 36 | frames.append(gray_frame) 37 | return frames 38 | 39 | def get_attribute(): 40 | attribute_names = [] 41 | attribute_names.append('x') 42 | attribute_names.append('y') 43 | attribute_names.append('z') 44 | attribute_names.append('nx') 45 | attribute_names.append('ny') 46 | attribute_names.append('nz') 47 | for i in range(3): 48 | attribute_names.append('f_dc_' + str(i)) 49 | for i in range(45): 50 | attribute_names.append('f_rest_' + str(i)) 51 | attribute_names.append('opacity') 52 | for i in range(3): 53 | attribute_names.append('scale_' + str(i)) 54 | for i in 
range(4): 55 | attribute_names.append('rot_' + str(i)) 56 | 57 | return attribute_names 58 | 59 | def denormalize_uint8(data, min_val, max_val): 60 | return data / 255.0 * (max_val - min_val) + min_val 61 | 62 | def denormalize_uint16(data, min_val, max_val): 63 | return data / (2 ** 16 - 1) * (max_val - min_val) + min_val 64 | 65 | def reconstruct_ply_from_images(frame, num_attributes, image_size, input_folder, min_max_info): 66 | reconstructed_data = np.zeros((image_size * image_size, num_attributes), dtype=float) 67 | 68 | for i in range(num_attributes): 69 | img_path = os.path.join(input_folder, f"{frame}_{i}.png") 70 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) 71 | 72 | min_val = float(min_max_info[f'{frame}_{i}_min']) 73 | max_val = float(min_max_info[f'{frame}_{i}_max']) 74 | 75 | img_denormalized = denormalize_uint16(img, min_val, max_val) 76 | 77 | reconstructed_data[:, i] = img_denormalized.flatten() 78 | 79 | actual_num_points = min_max_info[f'{frame}_num'] 80 | reconstructed_data = reconstructed_data[:actual_num_points] 81 | 82 | return reconstructed_data, actual_num_points 83 | 84 | def save_ply(residual, output_file): 85 | n, k = residual.shape 86 | 87 | attribute_names = [] 88 | attribute_names.append('x') 89 | attribute_names.append('y') 90 | attribute_names.append('z') 91 | attribute_names.append('nx') 92 | attribute_names.append('ny') 93 | attribute_names.append('nz') 94 | for i in range(3): 95 | attribute_names.append('f_dc_' + str(i)) 96 | # for i in range(sh_number): 97 | # attribute_names.append('f_rest_' + str(i)) 98 | attribute_names.append('opacity') 99 | for i in range(3): 100 | attribute_names.append('scale_' + str(i)) 101 | for i in range(4): 102 | attribute_names.append('rot_' + str(i)) 103 | 104 | assert k == len(attribute_names) 105 | 106 | with open(output_file, 'wb') as ply_file: 107 | ply_file.write(b"ply\n") 108 | ply_file.write(b"format binary_little_endian 1.0\n") 109 | ply_file.write(b"element vertex %d\n" % n) 110 | 111 | for attribute_name in attribute_names: 112 | ply_file.write(b"property float %s\n" % attribute_name.encode()) 113 | 114 | ply_file.write(b"end_header\n") 115 | 116 | for i in range(n): 117 | vertex_data = residual[i].astype(np.float32).tobytes() 118 | ply_file.write(vertex_data) 119 | 120 | with open(min_max_path, "r") as f: 121 | min_max_info = json.load(f) 122 | 123 | for frame in range(start, end, group_size * interval): 124 | group_start = frame 125 | group_end = min(frame + group_size * interval, end) 126 | print(group_start, group_end) 127 | 128 | group_video_path = os.path.join(feature_video_path, f"group{group_idx}") 129 | group_video_data = [] 130 | for video_idx in range(num_video): 131 | video_path = os.path.join(group_video_path, f"{video_idx}.mp4") 132 | frames = read_video(video_path) 133 | group_video_data.append(frames) 134 | group_idx += 1 135 | 136 | group_frame_idx = 0 137 | # reconstruct a group 138 | for group_frame in range(group_start, group_end, interval): 139 | group_frame_data = np.zeros((min_max_info[f'{group_frame}_num'], num_video - 3), dtype=float) 140 | # position 141 | for att in range(3): 142 | # concat uint8 to uint16 143 | image_even = group_video_data[att * 2][group_frame_idx] 144 | image_odd = group_video_data[att * 2 + 1][group_frame_idx] 145 | # image_even[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped & 0xff) 146 | # image_odd[:attribute_data_reshaped.shape[0], :] += (attribute_data_reshaped >> 8) 147 | image_even = image_even.astype(np.uint16) 148 | 
image_odd = image_odd.astype(np.uint16) 149 | image = image_even + (image_odd << 8) 150 | # denormalize 151 | min_val = float(min_max_info[f'{group_frame}_{att}_min']) 152 | max_val = float(min_max_info[f'{group_frame}_{att}_max']) 153 | # print(denormalize_uint16(image, min_val, max_val).shape) 154 | # print(group_frame_data[:, att].shape) 155 | group_frame_data[:, att] = denormalize_uint16(image, min_val, max_val).flatten()[:min_max_info[f'{group_frame}_num']] 156 | for att in range(3, 17): 157 | if att in [3, 4, 5]: 158 | continue 159 | image = group_video_data[att + 3][group_frame_idx] 160 | # denormalize 161 | min_val = float(min_max_info[f'{group_frame}_{att}_min']) 162 | max_val = float(min_max_info[f'{group_frame}_{att}_max']) 163 | group_frame_data[:, att] = denormalize_uint8(image, min_val, max_val).flatten()[:min_max_info[f'{group_frame}_num']] 164 | 165 | # save ply 166 | save_ply(group_frame_data, os.path.join(output_path, f"{group_frame}.ply")) 167 | 168 | group_frame_idx += 1 169 | -------------------------------------------------------------------------------- /config/config_hash.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": { 3 | "otype": "RelativeL2" 4 | }, 5 | "optimizer": { 6 | "otype": "Adam", 7 | "learning_rate": 1e-2, 8 | "beta1": 0.9, 9 | "beta2": 0.99, 10 | "epsilon": 1e-15, 11 | "l2_reg": 1e-6 12 | }, 13 | "encoding": { 14 | "otype": "HashGrid", 15 | "n_levels": 16, 16 | "n_features_per_level": 4, 17 | "log2_hashmap_size": 15, 18 | "base_resolution": 16, 19 | "per_level_scale": 1.5, 20 | "n_dims_to_encode": 3 21 | }, 22 | "network": { 23 | "otype": "FullyFusedMLP", 24 | "activation": "ReLU", 25 | "output_activation": "None", 26 | "n_neurons": 64, 27 | "n_hidden_layers": 2 28 | } 29 | } -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 
18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if pipe.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if pipe.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 96 | # They will be excluded from value updates used in the splitting criteria. 97 | return {"render": rendered_image, 98 | "viewspace_points": screenspace_points, 99 | "visibility_filter" : radii > 0, 100 | "radii": radii} 101 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /preprocess/calibration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--colmap_path",type=str,default="colmap") 7 | parser.add_argument("--data_path",type=str,default="/home/auwang/workspace/data/teaser_datasets/hanfu_cali/0") 8 | args = parser.parse_args() 9 | 10 | colmap_path = args.colmap_path 11 | data_path = args.data_path 12 | 13 | database_path = os.path.join(data_path,"database.db") 14 | image_path = os.path.join(data_path,"images") 15 | sparse_path = os.path.join(data_path,"sparse") 16 | dense_path = os.path.join(data_path,"dense") 17 | fused_path = os.path.join(dense_path,"fused.ply") 18 | os.makedirs(sparse_path,exist_ok=True) 19 | os.makedirs(dense_path,exist_ok=True) 20 | 21 | feature_extraction = "{} feature_extractor --database_path {} --image_path {}".format(colmap_path,database_path,image_path) 22 | exhaustive_matcher = "{} exhaustive_matcher --database_path {}".format(colmap_path,database_path) 23 | mapper = "{} mapper --database_path {} --image_path {} --output_path {}".format(colmap_path,database_path,image_path,sparse_path) 24 | patch_match_stereo = "{} patch_match_stereo --workspace_path {} --workspace_format COLMAP --PatchMatchStereo.geom_consistency true".format(colmap_path,dense_path) 25 | stereo_fusion = "{} stereo_fusion --workspace_path {} --workspace_format COLMAP --input_type geometric --output_path 
{}".format(colmap_path,dense_path,fused_path) 26 | 27 | os.system(feature_extraction) 28 | os.system(exhaustive_matcher) 29 | os.system(mapper) -------------------------------------------------------------------------------- /preprocess/colmap2k.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import argparse 12 | from glob import glob 13 | import os 14 | from pathlib import Path, PurePosixPath 15 | 16 | import numpy as np 17 | import json 18 | import sys 19 | import math 20 | import cv2 21 | import os 22 | import shutil 23 | from colmap_helper import read_images_binary, read_cameras_binary, write_images_text, write_cameras_text 24 | 25 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 26 | SCRIPTS_FOLDER = os.path.join(ROOT_DIR, "scripts") 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description="Convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place.") 30 | 31 | parser.add_argument("--video_in", default="", help="Run ffmpeg first to convert a provided video file into a set of images. Uses the video_fps parameter also.") 32 | parser.add_argument("--video_fps", default=2) 33 | parser.add_argument("--time_slice", default="", help="Time (in seconds) in the format t1,t2 within which the images should be generated from the video. E.g.: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video.") 34 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 35 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="Select which matcher colmap should use. Sequential for videos, exhaustive for ad-hoc images.") 36 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 37 | parser.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL", "OPENCV", "SIMPLE_RADIAL_FISHEYE", "RADIAL_FISHEYE", "OPENCV_FISHEYE"], help="Camera model") 38 | parser.add_argument("--colmap_camera_params", default="", help="Intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist") 39 | parser.add_argument("--text", default="colmap_text", help="Input path to the colmap text files (set automatically if --run_colmap is used).") 40 | parser.add_argument("--aabb_scale", default=32, choices=["1", "2", "4", "8", "16", "32", "64", "128"], help="Large scene scale factor. 
1=scene fits in unit cube; power of 2 up to 128") 41 | parser.add_argument("--skip_early", default=0, help="Skip this many images from the start.") 42 | parser.add_argument("--keep_colmap_coords", action="store_true", help="Keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering).") 43 | parser.add_argument("--out", default="transforms.json", help="Output path.") 44 | parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.") 45 | parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.") 46 | parser.add_argument("--mask_categories", nargs="*", type=str, default=[], help="Object categories that should be masked out from the training images. See `scripts/category2id.json` for supported categories.") 47 | args = parser.parse_args() 48 | return args 49 | 50 | def do_system(arg): 51 | print(f"==== running: {arg}") 52 | err = os.system(arg) 53 | if err: 54 | print("FATAL: command failed") 55 | sys.exit(err) 56 | 57 | def variance_of_laplacian(image): 58 | return cv2.Laplacian(image, cv2.CV_64F).var() 59 | 60 | def sharpness(imagePath): 61 | image = cv2.imread(imagePath) 62 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 63 | fm = variance_of_laplacian(gray) 64 | return fm 65 | 66 | def qvec2rotmat(qvec): 67 | return np.array([ 68 | [ 69 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 70 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 71 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 72 | ], [ 73 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 74 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 75 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 76 | ], [ 77 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 78 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 79 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 80 | ] 81 | ]) 82 | 83 | def rotmat(a, b): 84 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 85 | v = np.cross(a, b) 86 | c = np.dot(a, b) 87 | # handle exception for the opposite direction input 88 | if c < -1 + 1e-10: 89 | return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) 90 | s = np.linalg.norm(v) 91 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 92 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 93 | 94 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 95 | da = da / np.linalg.norm(da) 96 | db = db / np.linalg.norm(db) 97 | c = np.cross(da, db) 98 | denom = np.linalg.norm(c)**2 99 | t = ob - oa 100 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 101 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 102 | if ta > 0: 103 | ta = 0 104 | if tb > 0: 105 | tb = 0 106 | return (oa+ta*da+ob+tb*db) * 0.5, denom 107 | 108 | def gen_K(transforms): 109 | transforms["white_transparent"] = True 110 | transforms["aabb_scale"] = 1 111 | for i in range(len(transforms["frames"])): 112 | K = [[transforms["frames"][i]["fl_x"],0,transforms["frames"][i]["cx"]],[0,transforms["frames"][i]["fl_y"],transforms["frames"][i]["cy"]],[0,0,1]] 113 | transforms["frames"][i]["K"] = K 114 | del transforms["frames"][i]["k1"] 115 | del transforms["frames"][i]["k2"] 116 | del transforms["frames"][i]["k3"] 117 | del transforms["frames"][i]["p1"] 118 | del transforms["frames"][i]["p2"] 119 | del transforms["frames"][i]["camera_angle_x"] 120 | del transforms["frames"][i]["camera_angle_y"] 121 | 
del transforms["frames"][i]["fovx"] 122 | del transforms["frames"][i]["fovy"] 123 | del transforms["frames"][i]["is_fisheye"] 124 | del transforms["frames"][i]["k4"] 125 | del transforms["frames"][i]["sharpness"] 126 | return transforms 127 | 128 | def bin2txt(datapath): 129 | images_path_gaussian = os.path.join(datapath, "images.bin") 130 | cameras_path_gaussian = os.path.join(datapath, "cameras.bin") 131 | images_txt_path_gaussian = os.path.join(datapath, "images.txt") 132 | cameras_txt_path_gaussian = os.path.join(datapath, "cameras.txt") 133 | images_gaussian = read_images_binary(images_path_gaussian) 134 | cameras_gaussian = read_cameras_binary(cameras_path_gaussian) 135 | write_images_text(images_gaussian,images_txt_path_gaussian) 136 | write_cameras_text(cameras_gaussian,cameras_txt_path_gaussian) 137 | # parser = argparse.ArgumentParser() 138 | # parser.add_argument("--datapath",type=str,default="") 139 | # parser=parser.parse_args() 140 | # datapath = parser.datapath 141 | # # images_white.image_white(picture_path) 142 | # bin2txt(datapath) 143 | # script_path = "/mnt/new_disk5/wangph1/workspace/process_code_h/process_code_h/colmap2nerf.py" 144 | # gaussian_path = "{}/transforms.json".format(datapath) 145 | # gaussian_txt = "{}/sparse".format(datapath) 146 | # os.system("python {} --text {} --out {} --keep_colmap_coords".format(script_path,gaussian_txt,gaussian_path)) 147 | 148 | if __name__ == "__main__": 149 | args = parse_args() 150 | AABB_SCALE = int(args.aabb_scale) 151 | SKIP_EARLY = int(args.skip_early) 152 | TEXT_FOLDER = args.text 153 | OUT_PATH = args.out 154 | print(f"outputting to {OUT_PATH}...") 155 | 156 | # bin 2 txt 157 | bin2txt(TEXT_FOLDER) 158 | 159 | cameras = {} 160 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: 161 | camera_angle_x = math.pi / 2 162 | for line in f: 163 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 164 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 165 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 166 | if line[0] == "#": 167 | continue 168 | els = line.split(" ") 169 | camera = {} 170 | camera_id = int(els[0]) 171 | camera["w"] = float(els[2]) 172 | camera["h"] = float(els[3]) 173 | camera["fl_x"] = float(els[4]) 174 | camera["fl_y"] = float(els[4]) 175 | camera["k1"] = 0 176 | camera["k2"] = 0 177 | camera["k3"] = 0 178 | camera["k4"] = 0 179 | camera["p1"] = 0 180 | camera["p2"] = 0 181 | camera["cx"] = camera["w"] / 2 182 | camera["cy"] = camera["h"] / 2 183 | camera["is_fisheye"] = False 184 | if els[1] == "SIMPLE_PINHOLE": 185 | camera["cx"] = float(els[5]) 186 | camera["cy"] = float(els[6]) 187 | elif els[1] == "PINHOLE": 188 | camera["fl_y"] = float(els[5]) 189 | camera["cx"] = float(els[6]) 190 | camera["cy"] = float(els[7]) 191 | elif els[1] == "SIMPLE_RADIAL": 192 | camera["cx"] = float(els[5]) 193 | camera["cy"] = float(els[6]) 194 | camera["k1"] = float(els[7]) 195 | elif els[1] == "RADIAL": 196 | camera["cx"] = float(els[5]) 197 | camera["cy"] = float(els[6]) 198 | camera["k1"] = float(els[7]) 199 | camera["k2"] = float(els[8]) 200 | elif els[1] == "OPENCV": 201 | camera["fl_y"] = float(els[5]) 202 | camera["cx"] = float(els[6]) 203 | camera["cy"] = float(els[7]) 204 | camera["k1"] = float(els[8]) 205 | camera["k2"] = float(els[9]) 206 | camera["p1"] = float(els[10]) 207 | camera["p2"] = float(els[11]) 208 | elif els[1] == "SIMPLE_RADIAL_FISHEYE": 209 | camera["is_fisheye"] = True 210 | camera["cx"] = float(els[5]) 211 | camera["cy"] = 
float(els[6]) 212 | camera["k1"] = float(els[7]) 213 | elif els[1] == "RADIAL_FISHEYE": 214 | camera["is_fisheye"] = True 215 | camera["cx"] = float(els[5]) 216 | camera["cy"] = float(els[6]) 217 | camera["k1"] = float(els[7]) 218 | camera["k2"] = float(els[8]) 219 | elif els[1] == "OPENCV_FISHEYE": 220 | camera["is_fisheye"] = True 221 | camera["fl_y"] = float(els[5]) 222 | camera["cx"] = float(els[6]) 223 | camera["cy"] = float(els[7]) 224 | camera["k1"] = float(els[8]) 225 | camera["k2"] = float(els[9]) 226 | camera["k3"] = float(els[10]) 227 | camera["k4"] = float(els[11]) 228 | else: 229 | print("Unknown camera model ", els[1]) 230 | # fl = 0.5 * w / tan(0.5 * angle_x); 231 | camera["camera_angle_x"] = math.atan(camera["w"] / (camera["fl_x"] * 2)) * 2 232 | camera["camera_angle_y"] = math.atan(camera["h"] / (camera["fl_y"] * 2)) * 2 233 | camera["fovx"] = camera["camera_angle_x"] * 180 / math.pi 234 | camera["fovy"] = camera["camera_angle_y"] * 180 / math.pi 235 | 236 | print(f"camera {camera_id}:\n\tres={camera['w'],camera['h']}\n\tcenter={camera['cx'],camera['cy']}\n\tfocal={camera['fl_x'],camera['fl_y']}\n\tfov={camera['fovx'],camera['fovy']}\n\tk={camera['k1'],camera['k2']} p={camera['p1'],camera['p2']} ") 237 | cameras[camera_id] = camera 238 | 239 | if len(cameras) == 0: 240 | print("No cameras found!") 241 | sys.exit(1) 242 | 243 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: 244 | i = 0 245 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 246 | if len(cameras) == 1: 247 | camera = cameras[camera_id] 248 | out = { 249 | "camera_angle_x": camera["camera_angle_x"], 250 | "camera_angle_y": camera["camera_angle_y"], 251 | "fl_x": camera["fl_x"], 252 | "fl_y": camera["fl_y"], 253 | "k1": camera["k1"], 254 | "k2": camera["k2"], 255 | "k3": camera["k3"], 256 | "k4": camera["k4"], 257 | "p1": camera["p1"], 258 | "p2": camera["p2"], 259 | "is_fisheye": camera["is_fisheye"], 260 | "cx": camera["cx"], 261 | "cy": camera["cy"], 262 | "w": camera["w"], 263 | "h": camera["h"], 264 | "aabb_scale": AABB_SCALE, 265 | "frames": [], 266 | } 267 | else: 268 | out = { 269 | "frames": [], 270 | "aabb_scale": AABB_SCALE 271 | } 272 | 273 | up = np.zeros(3) 274 | for line in f: 275 | line = line.strip() 276 | if line[0] == "#": 277 | continue 278 | i = i + 1 279 | if i < SKIP_EARLY*2: 280 | continue 281 | if i % 2 == 1: 282 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 283 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) 284 | # why is this requireing a relitive path while using ^ 285 | image_rel = "images" 286 | name = str(f"{image_rel}/{'_'.join(elems[9:])}") 287 | b = 100 #sharpness(name) 288 | print(name, "sharpness=",b) 289 | image_id = int(elems[0]) 290 | qvec = np.array(tuple(map(float, elems[1:5]))) 291 | tvec = np.array(tuple(map(float, elems[5:8]))) 292 | R = qvec2rotmat(-qvec) 293 | t = tvec.reshape([3,1]) 294 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 295 | c2w = np.linalg.inv(m) 296 | if not args.keep_colmap_coords: 297 | c2w[0:3,2] *= -1 # flip the y and z axis 298 | c2w[0:3,1] *= -1 299 | c2w = c2w[[1,0,2,3],:] 300 | c2w[2,:] *= -1 # flip whole world upside down 301 | 302 | up += c2w[0:3,1] 303 | 304 | frame = {"file_path":name,"sharpness":b,"transform_matrix": c2w} 305 | if len(cameras) != 1: 306 | frame.update(cameras[int(elems[8])]) 307 | out["frames"].append(frame) 308 | nframes = len(out["frames"]) 309 | 310 | if args.keep_colmap_coords: 311 | flip_mat = np.array([ 312 | 
[1, 0, 0, 0], 313 | [0, -1, 0, 0], 314 | [0, 0, -1, 0], 315 | [0, 0, 0, 1] 316 | ]) 317 | 318 | for f in out["frames"]: 319 | f["transform_matrix"] = np.matmul(f["transform_matrix"], flip_mat) # flip cameras (it just works) 320 | else: 321 | # don't keep colmap coords - reorient the scene to be easier to work with 322 | 323 | up = up / np.linalg.norm(up) 324 | print("up vector was", up) 325 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] 326 | R = np.pad(R,[0,1]) 327 | R[-1, -1] = 1 328 | 329 | for f in out["frames"]: 330 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 331 | 332 | # find a central point they are all looking at 333 | print("computing center of attention...") 334 | totw = 0.0 335 | totp = np.array([0.0, 0.0, 0.0]) 336 | for f in out["frames"]: 337 | mf = f["transform_matrix"][0:3,:] 338 | for g in out["frames"]: 339 | mg = g["transform_matrix"][0:3,:] 340 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 341 | if w > 0.00001: 342 | totp += p*w 343 | totw += w 344 | if totw > 0.0: 345 | totp /= totw 346 | print(totp) # the cameras are looking at totp 347 | for f in out["frames"]: 348 | f["transform_matrix"][0:3,3] -= totp 349 | 350 | avglen = 0. 351 | for f in out["frames"]: 352 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 353 | avglen /= nframes 354 | print("avg camera distance from origin", avglen) 355 | for f in out["frames"]: 356 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 357 | 358 | for f in out["frames"]: 359 | f["transform_matrix"] = f["transform_matrix"].tolist() 360 | 361 | out = gen_K(out) 362 | print(nframes,"frames") 363 | print(f"writing {OUT_PATH}") 364 | with open(OUT_PATH, "w") as outfile: 365 | json.dump(out, outfile, indent=4) -------------------------------------------------------------------------------- /preprocess/colmap_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import numpy as np 4 | import struct 5 | import argparse 6 | 7 | CameraModel = collections.namedtuple( 8 | "CameraModel", ["model_id", "model_name", "num_params"]) 9 | Camera = collections.namedtuple( 10 | "Camera", ["id", "model", "width", "height", "params"]) 11 | BaseImage = collections.namedtuple( 12 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 13 | Point3D = collections.namedtuple( 14 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 15 | 16 | class Image(BaseImage): 17 | def qvec2rotmat(self): 18 | return qvec2rotmat(self.qvec) 19 | 20 | CAMERA_MODELS = { 21 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 22 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 23 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 24 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 25 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 26 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 27 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 28 | CameraModel(model_id=7, model_name="FOV", num_params=5), 29 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 30 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 31 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 32 | } 33 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 34 | for camera_model in CAMERA_MODELS]) 35 | 
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 36 | for camera_model in CAMERA_MODELS]) 37 | 38 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 39 | """Read and unpack the next bytes from a binary file. 40 | :param fid: 41 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 42 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 43 | :param endian_character: Any of {@, =, <, >, !} 44 | :return: Tuple of read and unpacked values. 45 | """ 46 | data = fid.read(num_bytes) 47 | return struct.unpack(endian_character + format_char_sequence, data) 48 | 49 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 50 | """pack and write to a binary file. 51 | :param fid: 52 | :param data: data to send, if multiple elements are sent at the same time, 53 | they should be encapsuled either in a list or a tuple 54 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 55 | should be the same length as the data list or tuple 56 | :param endian_character: Any of {@, =, <, >, !} 57 | """ 58 | if isinstance(data, (list, tuple)): 59 | bytes = struct.pack(endian_character + format_char_sequence, *data) 60 | else: 61 | bytes = struct.pack(endian_character + format_char_sequence, data) 62 | fid.write(bytes) 63 | 64 | def read_cameras_text(path): 65 | """ 66 | see: src/base/reconstruction.cc 67 | void Reconstruction::WriteCamerasText(const std::string& path) 68 | void Reconstruction::ReadCamerasText(const std::string& path) 69 | """ 70 | cameras = {} 71 | with open(path, "r") as fid: 72 | while True: 73 | line = fid.readline() 74 | if not line: 75 | break 76 | line = line.strip() 77 | if len(line) > 0 and line[0] != "#": 78 | elems = line.split() 79 | camera_id = int(elems[0]) 80 | model = elems[1] 81 | width = int(elems[2]) 82 | height = int(elems[3]) 83 | params = np.array(tuple(map(float, elems[4:]))) 84 | cameras[camera_id] = Camera(id=camera_id, model=model, 85 | width=width, height=height, 86 | params=params) 87 | return cameras 88 | 89 | def read_cameras_binary(path_to_model_file): 90 | """ 91 | see: src/base/reconstruction.cc 92 | void Reconstruction::WriteCamerasBinary(const std::string& path) 93 | void Reconstruction::ReadCamerasBinary(const std::string& path) 94 | """ 95 | cameras = {} 96 | with open(path_to_model_file, "rb") as fid: 97 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 98 | for _ in range(num_cameras): 99 | camera_properties = read_next_bytes( 100 | fid, num_bytes=24, format_char_sequence="iiQQ") 101 | camera_id = camera_properties[0] 102 | model_id = camera_properties[1] 103 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 104 | width = camera_properties[2] 105 | height = camera_properties[3] 106 | num_params = CAMERA_MODEL_IDS[model_id].num_params 107 | params = read_next_bytes(fid, num_bytes=8*num_params, 108 | format_char_sequence="d"*num_params) 109 | cameras[camera_id] = Camera(id=camera_id, 110 | model=model_name, 111 | width=width, 112 | height=height, 113 | params=np.array(params)) 114 | assert len(cameras) == num_cameras 115 | return cameras 116 | 117 | def write_cameras_text(cameras, path): 118 | """ 119 | see: src/base/reconstruction.cc 120 | void Reconstruction::WriteCamerasText(const std::string& path) 121 | void Reconstruction::ReadCamerasText(const std::string& path) 122 | """ 123 | HEADER = "# Camera list with one line of data per camera:\n" + \ 124 | "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \ 125 | 
"# Number of cameras: {}\n".format(len(cameras)) 126 | with open(path, "w") as fid: 127 | fid.write(HEADER) 128 | for _, cam in cameras.items(): 129 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 130 | line = " ".join([str(elem) for elem in to_write]) 131 | fid.write(line + "\n") 132 | 133 | def write_cameras_binary(cameras, path_to_model_file): 134 | """ 135 | see: src/base/reconstruction.cc 136 | void Reconstruction::WriteCamerasBinary(const std::string& path) 137 | void Reconstruction::ReadCamerasBinary(const std::string& path) 138 | """ 139 | with open(path_to_model_file, "wb") as fid: 140 | write_next_bytes(fid, len(cameras), "Q") 141 | for _, cam in cameras.items(): 142 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 143 | camera_properties = [cam.id, 144 | model_id, 145 | cam.width, 146 | cam.height] 147 | write_next_bytes(fid, camera_properties, "iiQQ") 148 | for p in cam.params: 149 | write_next_bytes(fid, float(p), "d") 150 | return cameras 151 | 152 | def read_images_text(path): 153 | """ 154 | see: src/base/reconstruction.cc 155 | void Reconstruction::ReadImagesText(const std::string& path) 156 | void Reconstruction::WriteImagesText(const std::string& path) 157 | """ 158 | images = {} 159 | with open(path, "r") as fid: 160 | while True: 161 | line = fid.readline() 162 | if not line: 163 | break 164 | line = line.strip() 165 | if len(line) > 0 and line[0] != "#": 166 | elems = line.split() 167 | image_id = int(elems[0]) 168 | qvec = np.array(tuple(map(float, elems[1:5]))) 169 | tvec = np.array(tuple(map(float, elems[5:8]))) 170 | camera_id = int(elems[8]) 171 | image_name = elems[9] 172 | elems = fid.readline().split() 173 | xys = np.column_stack([tuple(map(float, elems[0::3])), 174 | tuple(map(float, elems[1::3]))]) 175 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 176 | images[image_id] = Image( 177 | id=image_id, qvec=qvec, tvec=tvec, 178 | camera_id=camera_id, name=image_name, 179 | xys=xys, point3D_ids=point3D_ids) 180 | return images 181 | 182 | def read_images_binary(path_to_model_file): 183 | """ 184 | see: src/base/reconstruction.cc 185 | void Reconstruction::ReadImagesBinary(const std::string& path) 186 | void Reconstruction::WriteImagesBinary(const std::string& path) 187 | """ 188 | images = {} 189 | with open(path_to_model_file, "rb") as fid: 190 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 191 | for _ in range(num_reg_images): 192 | binary_image_properties = read_next_bytes( 193 | fid, num_bytes=64, format_char_sequence="idddddddi") 194 | image_id = binary_image_properties[0] 195 | qvec = np.array(binary_image_properties[1:5]) 196 | tvec = np.array(binary_image_properties[5:8]) 197 | camera_id = binary_image_properties[8] 198 | image_name = "" 199 | current_char = read_next_bytes(fid, 1, "c")[0] 200 | while current_char != b"\x00": # look for the ASCII 0 entry 201 | image_name += current_char.decode("utf-8") 202 | current_char = read_next_bytes(fid, 1, "c")[0] 203 | num_points2D = read_next_bytes(fid, num_bytes=8, 204 | format_char_sequence="Q")[0] 205 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 206 | format_char_sequence="ddq"*num_points2D) 207 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 208 | tuple(map(float, x_y_id_s[1::3]))]) 209 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 210 | images[image_id] = Image( 211 | id=image_id, qvec=qvec, tvec=tvec, 212 | camera_id=camera_id, name=image_name, 213 | xys=xys, point3D_ids=point3D_ids) 214 | 215 | return images 216 | 217 | def 
write_images_text(images, path): 218 | """ 219 | see: src/base/reconstruction.cc 220 | void Reconstruction::ReadImagesText(const std::string& path) 221 | void Reconstruction::WriteImagesText(const std::string& path) 222 | """ 223 | if len(images) == 0: 224 | mean_observations = 0 225 | else: 226 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 227 | HEADER = "# Image list with two lines of data per image:\n" + \ 228 | "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \ 229 | "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \ 230 | "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations) 231 | 232 | with open(path, "w") as fid: 233 | fid.write(HEADER) 234 | for _, img in images.items(): 235 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 236 | first_line = " ".join(map(str, image_header)) 237 | fid.write(first_line + "\n") 238 | 239 | points_strings = [] 240 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 241 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 242 | fid.write(" ".join(points_strings) + "\n") 243 | 244 | def write_images_binary(images, path_to_model_file): 245 | """ 246 | see: src/base/reconstruction.cc 247 | void Reconstruction::ReadImagesBinary(const std::string& path) 248 | void Reconstruction::WriteImagesBinary(const std::string& path) 249 | """ 250 | with open(path_to_model_file, "wb") as fid: 251 | write_next_bytes(fid, len(images), "Q") 252 | for _, img in images.items(): 253 | write_next_bytes(fid, img.id, "i") 254 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 255 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 256 | write_next_bytes(fid, img.camera_id, "i") 257 | for char in img.name: 258 | write_next_bytes(fid, char.encode("utf-8"), "c") 259 | write_next_bytes(fid, b"\x00", "c") 260 | write_next_bytes(fid, len(img.point3D_ids), "Q") 261 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 262 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 263 | 264 | def qvec2rotmat(qvec): 265 | return np.array([ 266 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 267 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 268 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 269 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 270 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 271 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 272 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 273 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 274 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 275 | 276 | def rotmat2qvec(R): 277 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 278 | K = np.array([ 279 | [Rxx - Ryy - Rzz, 0, 0, 0], 280 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 281 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 282 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 283 | eigvals, eigvecs = np.linalg.eigh(K) 284 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 285 | if qvec[0] < 0: 286 | qvec *= -1 287 | return qvec -------------------------------------------------------------------------------- /preprocess/hifi4g_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import natsort 4 | from tqdm import tqdm 5 | import shutil 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Process HIFI4G data') 9 | parser.add_argument('--input', type=str, required=True, help='Path to the input data directory') 10 | parser.add_argument('--output', type=str, 
required=True, help='Path to the output data directory') 11 | parser.add_argument('--move', type=bool, default=False, help='If move the original data to the target folder') 12 | args = parser.parse_args() 13 | 14 | # assert input and output are same 15 | assert args.input != args.output, 'Input and output directories are same' 16 | 17 | if not os.path.exists(args.input): 18 | raise ValueError('Input directory does not exist') 19 | if not os.path.exists(args.output): 20 | os.makedirs(args.output) 21 | 22 | # generate transforms.json 23 | text_path = os.path.join(args.input, 'colmap', 'sparse') 24 | output_json_path = os.path.join(args.output, 'transforms.json') 25 | colmap2k_cmd = f"python colmap2k.py --text {text_path} --out {output_json_path} --keep_colmap_coords" 26 | os.system(colmap2k_cmd) 27 | 28 | # move the data 29 | images_folder_path = os.path.join(args.input, 'image_undistortion_white') 30 | frames = os.listdir(images_folder_path) 31 | frames = natsort.natsorted(frames) 32 | 33 | for frame in tqdm(frames): 34 | frame_source_path = os.path.join(images_folder_path, frame) 35 | # frame_source_path = os.path.join(images_folder_path, frame, 'image_undistortion_white', 'images') 36 | frame_target_path = os.path.join(args.output, frame, 'images') 37 | if args.move: 38 | shutil.move(frame_source_path, frame_target_path) 39 | else: 40 | shutil.copytree(frame_source_path, frame_target_path) 41 | 42 | # copy json 43 | shutil.copy(output_json_path, os.path.join(args.output, frame, 'transforms.json')) -------------------------------------------------------------------------------- /preprocess/undistortion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def undistortion(image_path, input_path, output_path): 5 | os.makedirs(output_path, exist_ok=True) 6 | image_undistorter = "colmap image_undistorter --image_path {} --input_path {} --output_path {} --output_type COLMAP".format(image_path,input_path,output_path) 7 | os.system(image_undistorter) 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--input",type = str) 12 | parser.add_argument("--output",type = str) 13 | parser.add_argument("--calib", type = str) 14 | parser.add_argument("--start", type = str) 15 | parser.add_argument("--end", type = str) 16 | parser.add_argument("--interval", type = str, default="1") 17 | parser=parser.parse_args() 18 | start_frame = int(parser.start) 19 | end_frame = int(parser.end) 20 | interval = int(parser.interval) 21 | 22 | input_path = parser.input 23 | output_path = parser.output 24 | calib_path = parser.calib 25 | for frame in range(start_frame, end_frame, interval): 26 | print("Processing frame: ", frame) 27 | input_image_path = os.path.join(input_path, str(frame), "images") 28 | if not os.path.exists(input_image_path): 29 | raise Exception(f"Path {input_image_path} does not exist.") 30 | 31 | output_image_path = os.path.join(output_path, str(frame), "image_undistortion_white") 32 | undistortion(input_image_path, calib_path, output_image_path) -------------------------------------------------------------------------------- /prune_gaussian.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import DynamicScene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | from utils.system_utils import searchForMaxIteration 26 | 27 | import numpy as np 28 | from plyfile import PlyData 29 | 30 | def finetune(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, last_ckpt_path, last_ckpt_iter): 31 | first_iter = 0 32 | gaussians = GaussianModel(0) 33 | scene = DynamicScene(dataset) 34 | gaussians.load_ply(last_ckpt_path) 35 | gaussians.training_setup(opt) 36 | 37 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 38 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 39 | 40 | iter_start = torch.cuda.Event(enable_timing = True) 41 | iter_end = torch.cuda.Event(enable_timing = True) 42 | 43 | viewpoint_stack = None 44 | ema_loss_for_log = 0.0 45 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 46 | first_iter += 1 47 | for iteration in range(first_iter, opt.iterations + 1): 48 | 49 | iter_start.record() 50 | 51 | gaussians.update_learning_rate(iteration) 52 | 53 | # Pick a random Camera 54 | if not viewpoint_stack: 55 | viewpoint_stack = scene.getTrainCameras().copy() 56 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 57 | 58 | # Render 59 | if (iteration - 1) == debug_from: 60 | pipe.debug = True 61 | 62 | bg = torch.rand((3), device="cuda") if opt.random_background else background 63 | 64 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 65 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 66 | 67 | # Loss 68 | gt_image = viewpoint_cam.original_image.cuda() 69 | Ll1 = l1_loss(image, gt_image) 70 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 71 | loss.backward() 72 | 73 | iter_end.record() 74 | 75 | with torch.no_grad(): 76 | # Progress bar 77 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 78 | if iteration % 10 == 0: 79 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 80 | progress_bar.update(10) 81 | if iteration == opt.iterations: 82 | progress_bar.close() 83 | 84 | # Log and save 85 | training_report(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, gaussians, scene, render, (pipe, background)) 86 | 87 | # Optimizer step 88 | if iteration < opt.iterations: 89 | gaussians.optimizer.step() 90 | gaussians.optimizer.zero_grad(set_to_none = True) 91 | 92 | with torch.no_grad(): 93 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 94 | save_pcd_path = os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(last_ckpt_iter + opt.iterations)) 95 | gaussians.save_ply(os.path.join(save_pcd_path, "point_cloud.ply")) 96 | 97 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, gaussians, scene : DynamicScene, renderFunc, renderArgs): 98 | 99 | # Report test and 
samples of training set 100 | if iteration in testing_iterations: 101 | torch.cuda.empty_cache() 102 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 103 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 104 | 105 | for config in validation_configs: 106 | if config['cameras'] and len(config['cameras']) > 0: 107 | l1_test = 0.0 108 | psnr_test = 0.0 109 | for idx, viewpoint in enumerate(config['cameras']): 110 | image = torch.clamp(renderFunc(viewpoint, gaussians, *renderArgs)["render"], 0.0, 1.0) 111 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 112 | l1_test += l1_loss(image, gt_image).mean().double() 113 | psnr_test += psnr(image, gt_image).mean().double() 114 | psnr_test /= len(config['cameras']) 115 | l1_test /= len(config['cameras']) 116 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 117 | torch.cuda.empty_cache() 118 | 119 | def get_ply_matrix(file_path): 120 | plydata = PlyData.read(file_path) 121 | num_vertices = len(plydata['vertex']) 122 | num_attributes = len(plydata['vertex'].properties) 123 | data_matrix = np.zeros((num_vertices, num_attributes), dtype=float) 124 | for i, name in enumerate(plydata['vertex'].data.dtype.names): 125 | data_matrix[:, i] = plydata['vertex'].data[name] 126 | return data_matrix 127 | 128 | def get_attribute(sh_degree): 129 | frest_dim = 3 * (sh_degree + 1) * (sh_degree + 1) - 3 130 | attribute_names = [] 131 | attribute_names.append('x') 132 | attribute_names.append('y') 133 | attribute_names.append('z') 134 | attribute_names.append('nx') 135 | attribute_names.append('ny') 136 | attribute_names.append('nz') 137 | for i in range(3): 138 | attribute_names.append('f_dc_' + str(i)) 139 | for i in range(frest_dim): 140 | attribute_names.append('f_rest_' + str(i)) 141 | attribute_names.append('opacity') 142 | for i in range(3): 143 | attribute_names.append('scale_' + str(i)) 144 | for i in range(4): 145 | attribute_names.append('rot_' + str(i)) 146 | 147 | return attribute_names 148 | 149 | if __name__ == "__main__": 150 | # Set up command line argument parser 151 | parser = ArgumentParser(description="Training script parameters") 152 | lp = ModelParams(parser) 153 | op = OptimizationParams(parser) 154 | pp = PipelineParams(parser) 155 | parser.add_argument('--debug_from', type=int, default=-1) 156 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 157 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[2000]) 158 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 159 | parser.add_argument("--quiet", action="store_true") 160 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 161 | parser.add_argument("--start_checkpoint", type=str, default = None) 162 | args = parser.parse_args(sys.argv[1:]) 163 | args.save_iterations.append(args.iterations) 164 | 165 | print("Optimizing " + args.model_path) 166 | 167 | # Initialize system state (RNG) 168 | safe_state(args.quiet) 169 | 170 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 171 | 172 | # prune percentage 173 | prune_percentage = 0.5 # 20% 174 | last_ckpt_iter = 12000 175 | # search for the last checkpoint 176 | pcd_path = os.path.join(args.model_path, "point_cloud") 177 | last_ckpt_path = os.path.join(pcd_path, "iteration_{}".format(last_ckpt_iter), "point_cloud.ply") 178 | 179 | sh_degree = 0 180 | 181 
| pcd = get_ply_matrix(last_ckpt_path) 182 | print("Loaded point cloud with shape: ", pcd.shape) 183 | num_points = pcd.shape[0] 184 | num_points_to_prune = int(num_points * prune_percentage) 185 | # sort by opacity 186 | # opacity is the -8th column 187 | sorted_indices = np.argsort(pcd[:, -8]) 188 | # prune the first num_points_to_prune points 189 | pruned_pcd = pcd[sorted_indices[num_points_to_prune:]] 190 | pruned_num_points = pruned_pcd.shape[0] 191 | print("Pruned point cloud with shape: ", pruned_pcd.shape) 192 | 193 | # save the pruned pcd 194 | pruned_pcd_path = last_ckpt_path.replace(".ply", "_pruned.ply") 195 | attribute_list = get_attribute(sh_degree) 196 | 197 | # write the new ply file 198 | with open(os.path.join(pruned_pcd_path), 'wb') as ply_file: 199 | ply_file.write(b"ply\n") 200 | ply_file.write(b"format binary_little_endian 1.0\n") 201 | ply_file.write(b"element vertex %d\n" % pruned_num_points) 202 | 203 | for attribute_name in attribute_list: 204 | ply_file.write(b"property float %s\n" % attribute_name.encode()) 205 | 206 | ply_file.write(b"end_header\n") 207 | 208 | for i in range(pruned_num_points): 209 | vertex_data = pruned_pcd[i].astype(np.float32).tobytes() 210 | ply_file.write(vertex_data) 211 | 212 | finetune(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, pruned_pcd_path, last_ckpt_iter) 213 | 214 | # All done 215 | print("\nTraining complete.") 216 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | plyfile 2 | tqdm 3 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 4 | pymeshlab 5 | open3d 6 | commentjson 7 | imageio 8 | pybind11 9 | scipy 10 | opencv-python 11 | trimesh 12 | tensorboard 13 | natsort -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import pymeshlab 5 | import open3d as o3d 6 | import numpy as np 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--start', type=int, default='') 11 | parser.add_argument('--end', type=int, default='') 12 | parser.add_argument('--cuda', type=int, default='') 13 | parser.add_argument('--data', type=str, default='') 14 | parser.add_argument('--output', type=str, default='') 15 | parser.add_argument('--sh', type=str, default='') 16 | parser.add_argument('--interval', type=str, default='') 17 | parser.add_argument('--group_size', type=str, default='') 18 | args = parser.parse_args() 19 | 20 | print(args.start, args.end) 21 | 22 | # os.system("conda activate torch") 23 | card_id = args.cuda 24 | data_root_path = args.data 25 | output_path = args.output 26 | sh = args.sh 27 | interval = int(args.interval) 28 | group_size = int(args.group_size) 29 | 30 | # neus2_meshlab_filter_path = os.path.join(data_root_path, "luoxi_filter.mlx") 31 | 32 | neus2_output_path = os.path.join(output_path, "neus2_output") 33 | if not os.path.exists(neus2_output_path): 34 | os.makedirs(neus2_output_path) 35 | 36 | gaussian_output_path = os.path.join(output_path, "checkpoint") 37 | 38 | for i in range(args.start, args.end, group_size * interval): 39 | group_start = i 40 | group_end = min(i + group_size * interval, args.end) 41 | 
print(group_start, group_end) 42 | 43 | frame_path = os.path.join(data_root_path, str(i)) 44 | if not os.path.exists(frame_path): 45 | os.makedirs(frame_path) 46 | frame_neus2_output_path = os.path.join(neus2_output_path, str(i)) 47 | if not os.path.exists(frame_neus2_output_path): 48 | os.makedirs(frame_neus2_output_path) 49 | frame_neus2_ckpt_output_path = os.path.join(frame_neus2_output_path, "frame.msgpack") 50 | frame_neus2_mesh_output_path = os.path.join(frame_neus2_output_path, "points3d.obj") 51 | 52 | """NeuS2""" 53 | # neus2 command 54 | script_path = "scripts/run.py" 55 | neus2_command = f"cd external/NeuS2 && CUDA_VISIBLE_DEVICES={card_id} python {script_path} --scene {frame_path} --name neus --mode nerf --save_snapshot {frame_neus2_ckpt_output_path} --save_mesh --save_mesh_path {frame_neus2_mesh_output_path} && cd ../.." 56 | os.system(neus2_command) 57 | delete_neus2_output_path = os.path.join(frame_path, "output") 58 | shutil.rmtree(delete_neus2_output_path) 59 | 60 | # revert axis 61 | mesh1 = o3d.io.read_triangle_mesh(frame_neus2_mesh_output_path) 62 | vertices = np.asarray(mesh1.vertices) 63 | vertices = vertices[:,[2,0,1]] 64 | mesh1.vertices = o3d.utility.Vector3dVector(vertices) 65 | o3d.io.write_triangle_mesh(frame_neus2_mesh_output_path, mesh1) 66 | 67 | # use pymeshlab to convert obj to point cloud 68 | ms = pymeshlab.MeshSet() 69 | ms.load_new_mesh(frame_neus2_mesh_output_path) 70 | # ms.load_filter_script(neus2_meshlab_filter_path) 71 | # ms.apply_filter_script() 72 | ms.generate_simplified_point_cloud(samplenum = 100000) 73 | frame_points3d_output_path = os.path.join(frame_path, "points3d.ply") 74 | ms.save_current_mesh(frame_points3d_output_path, binary = True, save_vertex_normal = False) 75 | 76 | 77 | """ Gaussian """ 78 | # generate output 79 | frame_model_path = os.path.join(gaussian_output_path, str(i)) 80 | first_frame_iteration = 12000 81 | first_frame_save_iterations = first_frame_iteration 82 | first_gaussian_command = f"CUDA_VISIBLE_DEVICES={card_id} python train.py -s {frame_path} -m {frame_model_path} --iterations {first_frame_iteration} --save_iterations {first_frame_save_iterations} --sh_degree {sh} --port 600{card_id}" 83 | os.system(first_gaussian_command) 84 | 85 | # prune 86 | prune_iterations = 4000 87 | prune_gaussian_command = f"CUDA_VISIBLE_DEVICES={card_id} python prune_gaussian.py -s {frame_path} -m {frame_model_path} --sh_degree {sh} --iterations {prune_iterations}" 88 | os.system(prune_gaussian_command) 89 | 90 | # rest frame 91 | dynamic_command = f"CUDA_VISIBLE_DEVICES={card_id} python train_dynamic_t.py -s {data_root_path} -m {gaussian_output_path} --sh_degree {sh} --st {group_start} --ed {group_end} --interval {interval}" 92 | os.system(dynamic_command) 93 | 94 | print(f"Finish {group_start} to {group_end}") -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel, GTPRTGaussianModel, GTPRTCGaussianModel, GTPTGaussianModel, GTPT_wo_hash_GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 26 | """b 27 | :param path: Path to colmap scene main folder. 28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | if load_iteration: 34 | if load_iteration == -1: 35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 36 | else: 37 | self.loaded_iter = load_iteration 38 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 39 | 40 | self.train_cameras = {} 41 | self.test_cameras = {} 42 | 43 | if os.path.exists(os.path.join(args.source_path, "sparse")): 44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 45 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 46 | print("Found transforms_train.json file, assuming Blender data set!") 47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 48 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 49 | print("Found transforms.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 51 | else: 52 | assert False, "Could not recognize scene type!" 
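        # Note: per-frame folders produced by preprocess/hifi4g_process.py contain a
        # transforms.json, so dynamic sequences are picked up by the "transforms.json"
        # branch above.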
53 | 54 | if not self.loaded_iter: 55 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 56 | dest_file.write(src_file.read()) 57 | json_cams = [] 58 | camlist = [] 59 | if scene_info.test_cameras: 60 | camlist.extend(scene_info.test_cameras) 61 | if scene_info.train_cameras: 62 | camlist.extend(scene_info.train_cameras) 63 | for id, cam in enumerate(camlist): 64 | json_cams.append(camera_to_JSON(id, cam)) 65 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 66 | json.dump(json_cams, file) 67 | 68 | if shuffle: 69 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 70 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 71 | 72 | self.cameras_extent = scene_info.nerf_normalization["radius"] 73 | 74 | for resolution_scale in resolution_scales: 75 | print("Loading Training Cameras") 76 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 77 | print("Loading Test Cameras") 78 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 79 | 80 | if self.loaded_iter: 81 | self.gaussians.load_ply(os.path.join(self.model_path, 82 | "point_cloud", 83 | "iteration_" + str(self.loaded_iter), 84 | "point_cloud.ply")) 85 | else: 86 | try: 87 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 88 | except Exception as e: 89 | print("Training RT, do not use scene gaussians") 90 | 91 | def save(self, iteration): 92 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 93 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 94 | 95 | def save_gtp(self, iteration): 96 | point_cloud_path = os.path.join(self.model_path, "gtp_pcd/iteration_{}".format(iteration)) 97 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 98 | 99 | def getTrainCameras(self, scale=1.0): 100 | return self.train_cameras[scale] 101 | 102 | def getTestCameras(self, scale=1.0): 103 | return self.test_cameras[scale] 104 | 105 | class DynamicScene: 106 | 107 | gaussians : GaussianModel 108 | 109 | def __init__(self, args : ModelParams, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 110 | """b 111 | :param path: Path to colmap scene main folder. 
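        Camera-only variant of Scene: it builds the train/test camera lists but
        leaves Gaussian loading and saving to the caller (e.g. prune_gaussian.py
        loads a point_cloud.ply itself before fine-tuning).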
112 | """ 113 | self.model_path = args.model_path 114 | self.loaded_iter = None 115 | # self.gaussians = gaussians 116 | 117 | if load_iteration: 118 | if load_iteration == -1: 119 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 120 | else: 121 | self.loaded_iter = load_iteration 122 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 123 | 124 | self.train_cameras = {} 125 | self.test_cameras = {} 126 | 127 | if os.path.exists(os.path.join(args.source_path, "sparse")): 128 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 129 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 130 | print("Found transforms_train.json file, assuming Blender data set!") 131 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 132 | elif os.path.exists(os.path.join(args.source_path, "transforms.json")): 133 | print("Found transforms.json file, assuming Blender data set!") 134 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 135 | else: 136 | assert False, "Could not recognize scene type!" 137 | 138 | if not self.loaded_iter: 139 | # with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 140 | # dest_file.write(src_file.read()) 141 | json_cams = [] 142 | camlist = [] 143 | if scene_info.test_cameras: 144 | camlist.extend(scene_info.test_cameras) 145 | if scene_info.train_cameras: 146 | camlist.extend(scene_info.train_cameras) 147 | for id, cam in enumerate(camlist): 148 | json_cams.append(camera_to_JSON(id, cam)) 149 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 150 | json.dump(json_cams, file) 151 | 152 | if shuffle: 153 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 154 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 155 | 156 | self.cameras_extent = scene_info.nerf_normalization["radius"] 157 | 158 | for resolution_scale in resolution_scales: 159 | print("Loading Training Cameras") 160 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 161 | print("Loading Test Cameras") 162 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 163 | 164 | # if self.loaded_iter: 165 | # self.gaussians.load_ply(os.path.join(self.model_path, 166 | # "point_cloud", 167 | # "iteration_" + str(self.loaded_iter), 168 | # "point_cloud.ply")) 169 | # else: 170 | # try: 171 | # self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 172 | # except Exception as e: 173 | # print("Training RT, do not use scene gaussians") 174 | 175 | # def save(self, iteration): 176 | # point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 177 | # self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 178 | 179 | # def save_gtp(self, iteration): 180 | # point_cloud_path = os.path.join(self.model_path, "gtp_pcd/iteration_{}".format(iteration)) 181 | # self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 182 | 183 | def getTrainCameras(self, scale=1.0): 184 | return self.train_cameras[scale] 185 | 186 | def getTestCameras(self, scale=1.0): 187 | return self.test_cameras[scale] 
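# A minimal usage sketch for the two scene classes (it mirrors prune_gaussian.finetune;
# kept as comments so importing this module is unchanged, and the checkpoint path
# below is only an example):
#
#     gaussians = GaussianModel(0)                 # SH degree 0, as in prune_gaussian.py
#     scene = DynamicScene(dataset)                # dataset = ModelParams(...).extract(args)
#     gaussians.load_ply("output/point_cloud/iteration_12000/point_cloud.ply")
#     gaussians.training_setup(opt)                # opt = OptimizationParams(...).extract(args)
#     for viewpoint in scene.getTrainCameras():
#         pass  # render(viewpoint, gaussians, pipe, bg), L1 + D-SSIM loss, optimizer step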
-------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | 39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | 43 | if gt_alpha_mask is not None: 44 | self.original_image *= gt_alpha_mask.to(self.data_device) 45 | else: 46 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 47 | 48 | self.zfar = 100.0 49 | self.znear = 0.01 50 | 51 | self.trans = trans 52 | self.scale = scale 53 | 54 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 55 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 56 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 57 | self.camera_center = self.world_view_transform.inverse()[3, :3] 58 | 59 | class MiniCam: 60 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 61 | self.image_width = width 62 | self.image_height = height 63 | self.FoVy = fovy 64 | self.FoVx = fovx 65 | self.znear = znear 66 | self.zfar = zfar 67 | self.world_view_transform = world_view_transform 68 | self.full_proj_transform = full_proj_transform 69 | view_inv = torch.inverse(self.world_view_transform) 70 | self.camera_center = view_inv[3][:3] 71 | 72 | -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
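    Example (illustrative): read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")
        reads 8 bytes and unpacks them as a single little-endian uint64; struct.unpack
        always returns a tuple, which is why callers below index the result with [0].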
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, rotmat2qvec, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | 26 | class CameraInfo(NamedTuple): 27 | uid: int 28 | R: np.array 29 | T: np.array 30 | FovY: np.array 31 | FovX: np.array 32 | image: np.array 33 | image_path: str 34 | image_name: str 35 | width: int 36 | height: int 37 | 38 | class SceneInfo(NamedTuple): 39 | point_cloud: BasicPointCloud 40 | train_cameras: list 41 | test_cameras: list 42 | nerf_normalization: dict 43 | ply_path: str 44 | 45 | def getNerfppNorm(cam_info): 46 | def get_center_and_diag(cam_centers): 47 | cam_centers = np.hstack(cam_centers) 48 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 49 | center = avg_cam_center 50 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 51 | diagonal = np.max(dist) 52 | return center.flatten(), diagonal 53 | 54 | cam_centers = [] 55 | 56 | for cam in cam_info: 57 | W2C = getWorld2View2(cam.R, cam.T) 58 | C2W = np.linalg.inv(W2C) 59 | cam_centers.append(C2W[:3, 3:4]) 60 | 61 | center, diagonal = get_center_and_diag(cam_centers) 62 | radius = diagonal * 1.1 63 | 64 | translate = -center 65 | 66 | return {"translate": translate, "radius": radius} 67 | 68 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 69 | cam_infos = [] 70 | for idx, key in enumerate(cam_extrinsics): 71 | sys.stdout.write('\r') 72 | # the exact output you're looking for: 73 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 74 | sys.stdout.flush() 75 | 76 | extr = cam_extrinsics[key] 77 | intr = cam_intrinsics[extr.camera_id] 78 | height = intr.height 79 | width = intr.width 80 | 81 | uid = intr.id 82 | R = np.transpose(qvec2rotmat(extr.qvec)) 83 | T = np.array(extr.tvec) 84 | 85 | if intr.model=="SIMPLE_PINHOLE": 86 | focal_length_x = intr.params[0] 87 | FovY = focal2fov(focal_length_x, height) 88 | FovX = focal2fov(focal_length_x, width) 89 | elif intr.model=="PINHOLE": 
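            # Note added for clarity: COLMAP's PINHOLE model stores params as
            # (fx, fy, cx, cy). focal2fov() converts a pixel focal length into a
            # field of view, fov = 2 * atan(pixels / (2 * focal)), so FovX/FovY
            # below are derived from fx/width and fy/height respectively.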
90 | focal_length_x = intr.params[0] 91 | focal_length_y = intr.params[1] 92 | FovY = focal2fov(focal_length_y, height) 93 | FovX = focal2fov(focal_length_x, width) 94 | else: 95 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 96 | 97 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 98 | image_name = os.path.basename(image_path).split(".")[0] 99 | image = Image.open(image_path) 100 | 101 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 102 | image_path=image_path, image_name=image_name, width=width, height=height) 103 | cam_infos.append(cam_info) 104 | sys.stdout.write('\n') 105 | return cam_infos 106 | 107 | def fetchPly(path): 108 | plydata = PlyData.read(path) 109 | vertices = plydata['vertex'] 110 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 111 | try: 112 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 113 | except: 114 | colors = positions * 0 115 | try: 116 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 117 | except: 118 | normals = positions * 0 119 | 120 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 121 | 122 | def storePly(path, xyz, rgb): 123 | # Define the dtype for the structured array 124 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 125 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 126 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 127 | 128 | normals = np.zeros_like(xyz) 129 | 130 | elements = np.empty(xyz.shape[0], dtype=dtype) 131 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 132 | elements[:] = list(map(tuple, attributes)) 133 | 134 | # Create the PlyData object and write to file 135 | vertex_element = PlyElement.describe(elements, 'vertex') 136 | ply_data = PlyData([vertex_element]) 137 | ply_data.write(path) 138 | 139 | def readColmapSceneInfo(path, images, eval, llffhold=8): 140 | try: 141 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 142 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 143 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 144 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 145 | except: 146 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 147 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 148 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 149 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 150 | 151 | reading_dir = "images" if images == None else images 152 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 153 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 154 | 155 | if eval: 156 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 157 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 158 | else: 159 | train_cam_infos = cam_infos 160 | test_cam_infos = [] 161 | 162 | nerf_normalization = getNerfppNorm(train_cam_infos) 163 | 164 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 165 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 166 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 167 | if not os.path.exists(ply_path): 168 | print("Converting point3d.bin to .ply, will happen only the first time you open the 
scene.") 169 | try: 170 | xyz, rgb, _ = read_points3D_binary(bin_path) 171 | except: 172 | xyz, rgb, _ = read_points3D_text(txt_path) 173 | storePly(ply_path, xyz, rgb) 174 | try: 175 | pcd = fetchPly(ply_path) 176 | except: 177 | pcd = None 178 | 179 | scene_info = SceneInfo(point_cloud=pcd, 180 | train_cameras=train_cam_infos, 181 | test_cameras=test_cam_infos, 182 | nerf_normalization=nerf_normalization, 183 | ply_path=ply_path) 184 | return scene_info 185 | 186 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 187 | cam_infos = [] 188 | 189 | with open(os.path.join(path, transformsfile)) as json_file: 190 | contents = json.load(json_file) 191 | frames = contents["frames"] 192 | for idx, frame in enumerate(frames): 193 | cam_name = frame["file_path"] 194 | 195 | flip_mat = np.array([ 196 | [1, 0, 0, 0], 197 | [0, -1, 0, 0], 198 | [0, 0, -1, 0], 199 | [0, 0, 0, 1] 200 | ]) 201 | matrix = np.linalg.inv(np.matmul(np.array(frame["transform_matrix"]), flip_mat)) 202 | R = np.transpose(qvec2rotmat(-rotmat2qvec(matrix[:3,:3]))) 203 | T = matrix[:3, 3] 204 | image_path = os.path.join(path, cam_name) 205 | image_name = Path(image_path).stem 206 | 207 | image = Image.open(image_path) 208 | image = np.array(image) 209 | # im_data = np.array(image.convert("RGBA")) 210 | # bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 211 | norm_data = image / 255.0 212 | # arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 213 | image = Image.fromarray(np.array(norm_data*255.0, dtype=np.byte), "RGB") 214 | 215 | 216 | # image = Image.open(image_path) 217 | # im_data = np.array(image.convert("RGBA")) 218 | # bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 219 | # norm_data = im_data / 255.0 220 | # arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 221 | # image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 222 | 223 | fx = frame["fl_x"] 224 | fy = frame["fl_y"] 225 | cx = frame["cx"] 226 | cy = frame["cy"] 227 | FovY = focal2fov(fy, image.size[1]) 228 | FovX = focal2fov(fx, image.size[0]) 229 | 230 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 231 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 232 | 233 | return cam_infos 234 | 235 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png", ply_path = None): 236 | print("Reading Training Transforms") 237 | train_cam_infos = readCamerasFromTransforms(path, "transforms.json", white_background, extension) 238 | print("Reading Test Transforms") 239 | # test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 240 | test_cam_infos = [] 241 | if not eval: 242 | train_cam_infos.extend(test_cam_infos) 243 | test_cam_infos = [] 244 | 245 | nerf_normalization = getNerfppNorm(train_cam_infos) 246 | 247 | ply_path = os.path.join(path, "points3d.ply") 248 | if not os.path.exists(ply_path): 249 | # # Since this data set has no colmap data, we start with random points 250 | # num_pts = 100_000 251 | # print(f"Generating random point cloud ({num_pts})...") 252 | 253 | # # We create random points inside the bounds of the synthetic Blender scenes 254 | # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 255 | # shs = np.random.random((num_pts, 3)) / 255.0 256 | # pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 257 | 258 | # storePly(ply_path, xyz, 
SH2RGB(shs) * 255) 259 | pass 260 | try: 261 | pcd = fetchPly(ply_path) 262 | except: 263 | pcd = None 264 | 265 | scene_info = SceneInfo(point_cloud=pcd, 266 | train_cameras=train_cam_infos, 267 | test_cameras=test_cam_infos, 268 | nerf_normalization=nerf_normalization, 269 | ply_path=ply_path) 270 | return scene_info 271 | 272 | sceneLoadTypeCallbacks = { 273 | "Colmap": readColmapSceneInfo, 274 | "Blender" : readNerfSyntheticInfo 275 | } -------------------------------------------------------------------------------- /scene/global_rt_field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tinycudann as tcnn 3 | import numpy as np 4 | import json 5 | 6 | class GlobalRTField(torch.nn.Module): 7 | def __init__(self, config_path = "config/config_hash.json") -> None: 8 | super().__init__() 9 | 10 | self.config_path = config_path 11 | with open(self.config_path) as config_file: 12 | self.config = json.load(config_file) 13 | config_file.close() 14 | 15 | self.model = tcnn.NetworkWithInputEncoding( 16 | n_input_dims=3, 17 | n_output_dims=3 + 4, 18 | encoding_config=self.config["encoding"], 19 | network_config=self.config["network"] 20 | ).to("cuda") 21 | 22 | def forward(self, pcd): 23 | deform_params = self.model(pcd) 24 | global_translation = deform_params[:, :3] 25 | global_quaternion = deform_params[:, 3:] 26 | return global_translation.float(), global_quaternion.float() 27 | 28 | def dump_ckpt(self, output_path): 29 | print(f"Saving model to {output_path}") 30 | torch.save(self.model.state_dict(), output_path) 31 | 32 | def load_ckpt(self, input_path): 33 | print(f"Loading model from {input_path}") 34 | self.model.load_state_dict(torch.load(input_path)) -------------------------------------------------------------------------------- /scene/global_rtc_field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tinycudann as tcnn 3 | import numpy as np 4 | import json 5 | 6 | class GlobalRTCField(torch.nn.Module): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | 10 | self.config_path = "config/config_hash.json" 11 | with open(self.config_path) as config_file: 12 | self.config = json.load(config_file) 13 | config_file.close() 14 | 15 | self.model = tcnn.NetworkWithInputEncoding( 16 | n_input_dims=3, 17 | n_output_dims=3 + 4 + 3, 18 | encoding_config=self.config["encoding"], 19 | network_config=self.config["network"] 20 | ).to("cuda") 21 | 22 | def forward(self, pcd): 23 | deform_params = self.model(pcd) 24 | global_translation = deform_params[:, :3] 25 | global_quaternion = deform_params[:, 3:7] 26 | global_scale = deform_params[:, 7:] 27 | # return global_translation, global_quaternion 28 | # float16 to float32 29 | return global_translation.float(), global_quaternion.float(), global_scale.float() -------------------------------------------------------------------------------- /scene/global_t_field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tinycudann as tcnn 3 | import numpy as np 4 | import json 5 | 6 | class GlobalTField(torch.nn.Module): 7 | def __init__(self, config_path = "config/config_hash.json") -> None: 8 | super().__init__() 9 | 10 | self.config_path = config_path 11 | with open(self.config_path) as config_file: 12 | self.config = json.load(config_file) 13 | config_file.close() 14 | 15 | self.model = tcnn.NetworkWithInputEncoding( 16 | n_input_dims=3, 17 | n_output_dims=3, 18 | 
encoding_config=self.config["encoding"], 19 | network_config=self.config["network"] 20 | ).to("cuda") 21 | 22 | def forward(self, pcd): 23 | deform_params = self.model(pcd) 24 | global_translation = deform_params 25 | # global_quaternion = deform_params[:, 3:] 26 | return global_translation.float() 27 | 28 | def dump_ckpt(self, output_path): 29 | print(f"Saving model to {output_path}") 30 | torch.save(self.model.state_dict(), output_path) 31 | 32 | def load_ckpt(self, input_path): 33 | print(f"Loading model from {input_path}") 34 | self.model.load_state_dict(torch.load(input_path)) -------------------------------------------------------------------------------- /scene/global_t_field_wo_hash.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tinycudann as tcnn 3 | import numpy as np 4 | import json 5 | 6 | class GlobalTField_wo_hash(torch.nn.Module): 7 | def __init__(self, config_path = "config/config_wo_hash.json") -> None: 8 | super().__init__() 9 | 10 | self.config_path = config_path 11 | with open(self.config_path) as config_file: 12 | self.config = json.load(config_file) 13 | config_file.close() 14 | 15 | self.model = tcnn.Network( 16 | n_input_dims=3, 17 | n_output_dims=3, 18 | network_config=self.config["network"] 19 | ).to("cuda") 20 | 21 | def forward(self, pcd): 22 | deform_params = self.model(pcd) 23 | global_translation = deform_params 24 | # global_quaternion = deform_params[:, 3:] 25 | return global_translation.float() 26 | 27 | def dump_ckpt(self, output_path): 28 | print(f"Saving model to {output_path}") 29 | torch.save(self.model.state_dict(), output_path) 30 | 31 | def load_ckpt(self, input_path): 32 | print(f"Loading model from {input_path}") 33 | self.model.load_state_dict(torch.load(input_path)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | try: 26 | from torch.utils.tensorboard import SummaryWriter 27 | TENSORBOARD_FOUND = True 28 | except ImportError: 29 | TENSORBOARD_FOUND = False 30 | 31 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 32 | first_iter = 0 33 | tb_writer = prepare_output_and_logger(dataset) 34 | gaussians = GaussianModel(dataset.sh_degree) 35 | scene = Scene(dataset, gaussians) 36 | gaussians.training_setup(opt) 37 | if checkpoint: 38 | (model_params, first_iter) = torch.load(checkpoint) 39 | gaussians.restore(model_params, opt) 40 | 41 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 42 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 43 | 44 | iter_start = torch.cuda.Event(enable_timing = True) 45 | iter_end = torch.cuda.Event(enable_timing = True) 46 | 47 | viewpoint_stack = None 48 | ema_loss_for_log = 0.0 49 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 50 | first_iter += 1 51 | for iteration in range(first_iter, opt.iterations + 1): 52 | if network_gui.conn == None: 53 | network_gui.try_connect() 54 | while network_gui.conn != None: 55 | try: 56 | net_image_bytes = None 57 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 58 | if custom_cam != None: 59 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 60 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 61 | network_gui.send(net_image_bytes, dataset.source_path) 62 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 63 | break 64 | except Exception as e: 65 | network_gui.conn = None 66 | 67 | iter_start.record() 68 | 69 | gaussians.update_learning_rate(iteration) 70 | 71 | # Every 1000 its we increase the levels of SH up to a maximum degree 72 | if iteration % 1000 == 0: 73 | gaussians.oneupSHdegree() 74 | 75 | # Pick a random Camera 76 | if not viewpoint_stack: 77 | viewpoint_stack = scene.getTrainCameras().copy() 78 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 79 | 80 | # Render 81 | if (iteration - 1) == debug_from: 82 | pipe.debug = True 83 | 84 | bg = torch.rand((3), device="cuda") if opt.random_background else background 85 | 86 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 87 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 88 | 89 | # Loss 90 | gt_image = viewpoint_cam.original_image.cuda() 91 | Ll1 = l1_loss(image, gt_image) 92 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 93 | loss.backward() 94 | 95 | iter_end.record() 96 | 97 | with torch.no_grad(): 98 | # Progress bar 99 | ema_loss_for_log = 0.4 * loss.item() + 0.6 
* ema_loss_for_log 100 | if iteration % 10 == 0: 101 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 102 | progress_bar.update(10) 103 | if iteration == opt.iterations: 104 | progress_bar.close() 105 | 106 | # Log and save 107 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 108 | if (iteration in saving_iterations): 109 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 110 | scene.save(iteration) 111 | 112 | # Densification 113 | if iteration < opt.densify_until_iter: 114 | # Keep track of max radii in image-space for pruning 115 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 116 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 117 | 118 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 119 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 120 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 121 | 122 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 123 | gaussians.reset_opacity() 124 | 125 | # Optimizer step 126 | if iteration < opt.iterations: 127 | gaussians.optimizer.step() 128 | gaussians.optimizer.zero_grad(set_to_none = True) 129 | 130 | if (iteration in checkpoint_iterations): 131 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 132 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 133 | 134 | def prepare_output_and_logger(args): 135 | if not args.model_path: 136 | if os.getenv('OAR_JOB_ID'): 137 | unique_str=os.getenv('OAR_JOB_ID') 138 | else: 139 | unique_str = str(uuid.uuid4()) 140 | args.model_path = os.path.join("./output/", unique_str[0:10]) 141 | 142 | # Set up output folder 143 | print("Output folder: {}".format(args.model_path)) 144 | os.makedirs(args.model_path, exist_ok = True) 145 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 146 | cfg_log_f.write(str(Namespace(**vars(args)))) 147 | 148 | # Create Tensorboard writer 149 | tb_writer = None 150 | if TENSORBOARD_FOUND: 151 | tb_writer = SummaryWriter(args.model_path) 152 | else: 153 | print("Tensorboard not available: not logging progress") 154 | return tb_writer 155 | 156 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 157 | if tb_writer: 158 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 159 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 160 | tb_writer.add_scalar('iter_time', elapsed, iteration) 161 | 162 | # Report test and samples of training set 163 | if iteration in testing_iterations: 164 | torch.cuda.empty_cache() 165 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 166 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 167 | 168 | for config in validation_configs: 169 | if config['cameras'] and len(config['cameras']) > 0: 170 | l1_test = 0.0 171 | psnr_test = 0.0 172 | for idx, viewpoint in enumerate(config['cameras']): 173 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 174 | gt_image = 
torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 175 | if tb_writer and (idx < 5): 176 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 177 | if iteration == testing_iterations[0]: 178 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 179 | l1_test += l1_loss(image, gt_image).mean().double() 180 | psnr_test += psnr(image, gt_image).mean().double() 181 | psnr_test /= len(config['cameras']) 182 | l1_test /= len(config['cameras']) 183 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 184 | if tb_writer: 185 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 186 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 187 | 188 | if tb_writer: 189 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 190 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 191 | torch.cuda.empty_cache() 192 | 193 | if __name__ == "__main__": 194 | # Set up command line argument parser 195 | parser = ArgumentParser(description="Training script parameters") 196 | lp = ModelParams(parser) 197 | op = OptimizationParams(parser) 198 | pp = PipelineParams(parser) 199 | parser.add_argument('--ip', type=str, default="127.0.0.1") 200 | parser.add_argument('--port', type=int, default=6009) 201 | parser.add_argument('--debug_from', type=int, default=-1) 202 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 203 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) 204 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 205 | parser.add_argument("--quiet", action="store_true") 206 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 207 | parser.add_argument("--start_checkpoint", type=str, default = None) 208 | args = parser.parse_args(sys.argv[1:]) 209 | args.save_iterations.append(args.iterations) 210 | 211 | print("Optimizing " + args.model_path) 212 | 213 | # Initialize system state (RNG) 214 | safe_state(args.quiet) 215 | 216 | # Start GUI server, configure and run training 217 | network_gui.init(args.ip, args.port) 218 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 219 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 220 | 221 | # All done 222 | print("\nTraining complete.") 223 | -------------------------------------------------------------------------------- /train_dynamic.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim, l2_loss 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel, DynamicScene, GTPTGaussianModel 19 | from utils.general_utils import safe_state 20 | from utils.system_utils import searchForMaxIteration 21 | import uuid 22 | from tqdm import tqdm 23 | from utils.image_utils import psnr 24 | from argparse import ArgumentParser, Namespace 25 | from arguments import ModelParams, PipelineParams, OptimizationParams 26 | import shutil 27 | import copy 28 | import torch.nn.functional as F 29 | import plyfile 30 | import numpy as np 31 | try: 32 | from torch.utils.tensorboard import SummaryWriter 33 | TENSORBOARD_FOUND = True 34 | except ImportError: 35 | TENSORBOARD_FOUND = False 36 | 37 | def entropy_regularization_loss(current_frame_gaussian, last_frame_gaussian): 38 | # current_frest = current_frame_gaussian._features_rest 39 | current_scale = current_frame_gaussian._scaling 40 | current_rotation = current_frame_gaussian._rotation 41 | current_opacity = current_frame_gaussian._opacity 42 | current_attribute = [] 43 | for i in range(current_scale.shape[1]): 44 | current_attribute.append(current_scale[:, i]) 45 | for i in range(current_rotation.shape[1]): 46 | current_attribute.append(current_rotation[:, i]) 47 | for i in range(current_opacity.shape[1]): 48 | current_attribute.append(current_opacity[:, i]) 49 | 50 | last_scale = last_frame_gaussian._scaling 51 | last_rotation = last_frame_gaussian._rotation 52 | last_opacity = last_frame_gaussian._opacity 53 | last_attribute = [] 54 | for i in range(last_scale.shape[1]): 55 | last_attribute.append(last_scale[:, i]) 56 | for i in range(last_rotation.shape[1]): 57 | last_attribute.append(last_rotation[:, i]) 58 | for i in range(last_opacity.shape[1]): 59 | last_attribute.append(last_opacity[:, i]) 60 | 61 | quantization_range = 255 62 | loss = 0.0 63 | for idx in range(len(current_attribute)): 64 | delta_attribute = current_attribute[idx] - last_attribute[idx] 65 | # if zero return 66 | if torch.sum(delta_attribute) == 0: 67 | return 0.0 68 | # delta_attribute_normalize = (delta_attribute - torch.min(delta_attribute)) / (torch.max(delta_attribute) - torch.min(delta_attribute)) 69 | delta_attribute_min = torch.min(delta_attribute) 70 | delta_attribute_max = torch.max(delta_attribute) 71 | delta_attribute_normalize = (delta_attribute - delta_attribute_min) / (delta_attribute_max - delta_attribute_min) * quantization_range 72 | # generate -1/2 to 1/2 noise 73 | disturb_noise = np.random.uniform(-0.5, 0.5) 74 | disturb_delta_attribute = delta_attribute_normalize + disturb_noise 75 | disturb_delta_attribute_up = disturb_delta_attribute + 0.5 76 | disturb_delta_attribute_down = disturb_delta_attribute - 0.5 77 | 78 | m1 = torch.distributions.normal.Normal(torch.mean(disturb_delta_attribute_up), torch.std(disturb_delta_attribute_up)) 79 | m2 = torch.distributions.normal.Normal(torch.mean(disturb_delta_attribute_down), torch.std(disturb_delta_attribute_down)) 80 | 81 | cdf1 = m1.cdf(disturb_delta_attribute) 82 | cdf2 = m2.cdf(disturb_delta_attribute) 83 | 84 | cdf_diff = cdf1 - cdf2 85 | loss += -torch.log2(torch.abs(cdf_diff).sum()) / current_attribute[idx].shape[0] 86 | 87 | return loss 88 | 89 | def temporal_loss(current_frame_gaussian, last_frame_gaussian): 90 | # current_frest = 
current_frame_gaussian._features_rest 91 | # current_xyz = current_frame_gaussian._xyz 92 | # current_fdc = current_frame_gaussian._features_dc 93 | current_scale = current_frame_gaussian._scaling 94 | current_rotation = current_frame_gaussian._rotation 95 | current_opacity = current_frame_gaussian._opacity 96 | 97 | current_attribute = [current_scale, current_rotation, current_opacity] 98 | 99 | # last_frest = last_frame_gaussian._features_rest 100 | # last_xyz = last_frame_gaussian._xyz 101 | # last_fdc = last_frame_gaussian._features_dc 102 | last_scale = last_frame_gaussian._scaling 103 | last_rotation = last_frame_gaussian._rotation 104 | last_opacity = last_frame_gaussian._opacity 105 | 106 | last_attribute = [last_scale, last_rotation, last_opacity] 107 | 108 | loss = 0.0 109 | for idx in range(len(current_attribute)): 110 | att_loss = l2_loss(current_attribute[idx], last_attribute[idx]) 111 | loss += att_loss 112 | 113 | return loss 114 | 115 | def train_rt_network(dataset, scene, pipe, last_model_path, init_model_path, gtp_iter, load_last_rt_model, load_init_rt_model): 116 | first_iter = 0 117 | gaussians = GTPTGaussianModel(dataset.sh_degree) 118 | # find the last checkpoint 119 | last_pcd_iter = searchForMaxIteration(os.path.join(last_model_path, "point_cloud")) 120 | last_pcd_path = os.path.join(last_model_path, "point_cloud", "iteration_" + str(last_pcd_iter), "point_cloud.ply") 121 | print("Loading last pcd model from: ", last_pcd_path) 122 | gaussians.load_ply(last_pcd_path) 123 | 124 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 125 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 126 | lambda_dssim = 0.2 127 | 128 | iter_start = torch.cuda.Event(enable_timing = True) 129 | iter_end = torch.cuda.Event(enable_timing = True) 130 | 131 | viewpoint_stack = None 132 | ema_loss_for_log = 0.0 133 | progress_bar = tqdm(range(first_iter, gtp_iter), desc="Training T progress") 134 | first_iter += 1 135 | for iteration in range(first_iter, gtp_iter + 1): 136 | 137 | iter_start.record() 138 | 139 | # Pick a random Camera 140 | if not viewpoint_stack: 141 | viewpoint_stack = scene.getTrainCameras().copy() 142 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 143 | 144 | bg = background 145 | 146 | # first predict gtp 147 | gaussians.global_predict() 148 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 149 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 150 | 151 | # Loss 152 | gt_image = viewpoint_cam.original_image.cuda() 153 | Ll1 = l1_loss(image, gt_image) 154 | loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim(image, gt_image)) 155 | loss.backward() 156 | 157 | iter_end.record() 158 | 159 | with torch.no_grad(): 160 | # Progress bar 161 | if iteration == gtp_iter: 162 | progress_bar.close() 163 | if iteration % 10 == 0: 164 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 165 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 166 | progress_bar.update(10) 167 | # Optimizer step 168 | if iteration < gtp_iter: 169 | gaussians.optimizer.step() 170 | gaussians.optimizer.zero_grad(set_to_none = True) 171 | 172 | with torch.no_grad(): 173 | print("\n[ITER {}] Saving GTP Gaussians".format(iteration)) 174 | save_gtp_pcd_path = os.path.join(dataset.model_path, "gtp_pcd/iteration_{}".format(iteration)) 175 | 
gaussians.save_ply(os.path.join(save_gtp_pcd_path, "point_cloud.ply")) 176 | print("\n[ITER {}] Saving GTP Checkpoint".format(iteration)) 177 | # save_gtp_ckpt_path = os.path.join(dataset.model_path, "gtp_ckpt/iteration_{}".format(iteration)) 178 | # os.makedirs(save_gtp_ckpt_path, exist_ok = True) 179 | # gaussians.rt_model.dump_ckpt(os.path.join(save_gtp_ckpt_path, "rt_ckpt.pth")) 180 | 181 | def finetune(dataset, scene, opt, pipe, last_model_path, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 182 | first_iter = 0 183 | gaussians = GaussianModel(dataset.sh_degree) 184 | if checkpoint: 185 | (model_params, first_iter) = torch.load(checkpoint) 186 | gaussians.restore(model_params, opt) 187 | 188 | 189 | # load gtp pcd and finetune 190 | last_model_iter = searchForMaxIteration(os.path.join(dataset.model_path, "gtp_pcd")) 191 | print("Loading last gtp model from: ", os.path.join(dataset.model_path, "gtp_pcd", "iteration_" + str(last_model_iter))) 192 | gaussians.load_ply(os.path.join(dataset.model_path, "gtp_pcd", "iteration_" + str(last_model_iter), "point_cloud.ply")) 193 | gaussians.training_setup(opt) 194 | 195 | last_gaussians = GaussianModel(dataset.sh_degree) 196 | last_model_iter = searchForMaxIteration(os.path.join(last_model_path, "point_cloud")) 197 | print("Temporal loss and entropy Loading last pcd model from: ", os.path.join(last_model_path, "point_cloud", "iteration_" + str(last_model_iter), "point_cloud.ply")) 198 | last_gaussians.load_ply(os.path.join(last_model_path, "point_cloud", "iteration_" + str(last_model_iter), "point_cloud.ply")) 199 | 200 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 201 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 202 | 203 | iter_start = torch.cuda.Event(enable_timing = True) 204 | iter_end = torch.cuda.Event(enable_timing = True) 205 | 206 | viewpoint_stack = None 207 | ema_loss_for_log = 0.0 208 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Finetune progress") 209 | first_iter += 1 210 | for iteration in range(first_iter, opt.iterations + 1): 211 | 212 | iter_start.record() 213 | 214 | gaussians.update_learning_rate(iteration) 215 | 216 | # Pick a random Camera 217 | if not viewpoint_stack: 218 | viewpoint_stack = scene.getTrainCameras().copy() 219 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 220 | 221 | # Render 222 | if (iteration - 1) == debug_from: 223 | pipe.debug = True 224 | 225 | bg = torch.rand((3), device="cuda") if opt.random_background else background 226 | 227 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 228 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 229 | 230 | # Loss 231 | gt_image = viewpoint_cam.original_image.cuda() 232 | Ll1 = l1_loss(image, gt_image) 233 | temporal_loss_value = temporal_loss(gaussians, last_gaussians) 234 | entropy_loss = entropy_regularization_loss(gaussians, last_gaussians) 235 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + opt.lambda_temporal * temporal_loss_value + opt.lambda_entropy * entropy_loss 236 | loss.backward() 237 | 238 | iter_end.record() 239 | 240 | with torch.no_grad(): 241 | # Progress bar 242 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 243 | if iteration % 10 == 0: 244 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 245 | 
progress_bar.update(10) 246 | if iteration == opt.iterations: 247 | progress_bar.close() 248 | 249 | # Log and save 250 | training_report(iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), gaussians) 251 | 252 | # Optimizer step 253 | if iteration < opt.iterations: 254 | gaussians.optimizer.step() 255 | gaussians.optimizer.zero_grad(set_to_none = True) 256 | 257 | # always save gaussian after finetune 258 | with torch.no_grad(): 259 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 260 | save_pcd_path = os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)) 261 | gaussians.save_ply(os.path.join(save_pcd_path, "point_cloud.ply")) 262 | 263 | def dynamic_training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, start_frame, end_frame, interval_frame): 264 | 265 | if not os.path.exists(dataset.model_path): 266 | os.makedirs(dataset.model_path) 267 | 268 | print("Using keyframe {}".format(start_frame)) 269 | # test if keyframe is in the dataset 270 | frame_list = os.listdir(dataset.model_path) 271 | if str(start_frame) not in frame_list: 272 | print("Keyframe model not find") 273 | return 274 | 275 | gtp_iterations = 800 276 | # gtp_iterations = 4000 277 | # finetune_iterations = 3500 278 | finetune_iterations = 2000 279 | load_last_rt_model = False 280 | load_init_rt_model = False 281 | 282 | testing_iterations = [finetune_iterations] 283 | 284 | # record code 285 | os.makedirs(os.path.join(dataset.model_path, "record"), exist_ok = True) 286 | shutil.copy(__file__, os.path.join(dataset.model_path, "record", "train_dynamic.py")) 287 | shutil.copy("config/config_hash.json", os.path.join(dataset.model_path, "record", "config_hash.json")) 288 | shutil.copy("scene/gaussian_model.py", os.path.join(dataset.model_path, "record", "gaussian_model.py")) 289 | shutil.copy("scene/global_t_field.py", os.path.join(dataset.model_path, "record", "global_rt_field.py")) 290 | 291 | init_model_path = os.path.join(dataset.model_path, str(0)) 292 | 293 | for frame in range(start_frame + 1, end_frame + 1, interval_frame): 294 | print("Training frame {}".format(frame)) 295 | # ready dataset and opt into frame 296 | frame_dataset = copy.copy(dataset) 297 | frame_dataset.model_path = os.path.join(dataset.model_path, str(frame)) 298 | frame_dataset.source_path = os.path.join(dataset.source_path, str(frame)) 299 | 300 | frame_opt = copy.copy(opt) 301 | frame_opt.iterations = finetune_iterations 302 | 303 | # ready scene 304 | tb_writer = prepare_output_and_logger(frame_dataset) 305 | scene = DynamicScene(frame_dataset) 306 | 307 | # learn from last frame 308 | last_frame = frame - interval_frame 309 | last_model_path = os.path.join(dataset.model_path, str(last_frame)) 310 | 311 | # train rt for the frame using the keyframe model 312 | train_rt_network(frame_dataset, scene, pipe, last_model_path, init_model_path, gtp_iterations, load_last_rt_model, load_init_rt_model) 313 | 314 | # finetune wrapped model 315 | finetune(frame_dataset, scene, frame_opt, pipe, last_model_path, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from) 316 | 317 | # clean up 318 | del scene 319 | del tb_writer 320 | 321 | print("Training frame {} done".format(frame)) 322 | 323 | 324 | def prepare_output_and_logger(args): 325 | if not args.model_path: 326 | if os.getenv('OAR_JOB_ID'): 327 | unique_str=os.getenv('OAR_JOB_ID') 328 | else: 329 | unique_str = 
str(uuid.uuid4()) 330 | args.model_path = os.path.join("./output/", unique_str[0:10]) 331 | 332 | # Set up output folder 333 | print("Output folder: {}".format(args.model_path)) 334 | os.makedirs(args.model_path, exist_ok = True) 335 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 336 | cfg_log_f.write(str(Namespace(**vars(args)))) 337 | 338 | # Create Tensorboard writer 339 | tb_writer = None 340 | if TENSORBOARD_FOUND: 341 | tb_writer = SummaryWriter(args.model_path) 342 | else: 343 | print("Tensorboard not available: not logging progress") 344 | return tb_writer 345 | 346 | def training_report(iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, gaussians): 347 | # Report test and samples of training set 348 | if iteration in testing_iterations: 349 | torch.cuda.empty_cache() 350 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 351 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 352 | 353 | for config in validation_configs: 354 | if config['cameras'] and len(config['cameras']) > 0: 355 | l1_test = 0.0 356 | psnr_test = 0.0 357 | for idx, viewpoint in enumerate(config['cameras']): 358 | image = torch.clamp(renderFunc(viewpoint, gaussians, *renderArgs)["render"], 0.0, 1.0) 359 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 360 | l1_test += l1_loss(image, gt_image).mean().double() 361 | psnr_test += psnr(image, gt_image).mean().double() 362 | psnr_test /= len(config['cameras']) 363 | l1_test /= len(config['cameras']) 364 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 365 | torch.cuda.empty_cache() 366 | 367 | if __name__ == "__main__": 368 | # Set up command line argument parser 369 | parser = ArgumentParser(description="Training script parameters") 370 | lp = ModelParams(parser) 371 | op = OptimizationParams(parser) 372 | pp = PipelineParams(parser) 373 | parser.add_argument('--debug_from', type=int, default=-1) 374 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 375 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[3_500]) 376 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[3_500]) 377 | parser.add_argument("--st", type=int, default=0) 378 | parser.add_argument("--ed", type=int, default=0) 379 | parser.add_argument("--interval", type=int, default=0) 380 | 381 | parser.add_argument("--quiet", action="store_true") 382 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 383 | parser.add_argument("--start_checkpoint", type=str, default = None) 384 | args = parser.parse_args(sys.argv[1:]) 385 | args.save_iterations.append(args.iterations) 386 | 387 | print("Optimizing " + args.model_path) 388 | 389 | # Initialize system state (RNG) 390 | safe_state(args.quiet) 391 | 392 | # Start GUI server, configure and run training 393 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 394 | 395 | print(f"train with keyframe {args.st}") 396 | print(f"train from frame {args.st + args.interval} to frame {args.ed}") 397 | dynamic_training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.st, args.ed, args.interval) 398 | 399 | # All done 400 | print("\nTraining complete.") 401 | 
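Usage note (illustrative sketch, not part of the repository): dynamic_training() above expects a per-frame layout, with the multi-view input for each frame under <source_path>/<frame>/ (containing sparse/0 or transforms.json) and an already-trained keyframe model under <model_path>/<st>/point_cloud/iteration_*. Assuming the standard 3DGS -s/-m flags and hypothetical paths, a sequence could be launched as follows:

# Hypothetical driver: train frames 1..99 of a sequence from a keyframe
# previously optimized into output/actor1/0 (e.g. with train_prune.py below).
import subprocess

source_path = "datasets/actor1"   # expects datasets/actor1/<frame>/sparse/0 per frame
model_path = "output/actor1"      # expects the keyframe at output/actor1/0/point_cloud/iteration_*

subprocess.run([
    "python", "train_dynamic.py",
    "-s", source_path,            # assumed 3DGS source-path flag
    "-m", model_path,             # assumed 3DGS model-path flag
    "--st", "0",                  # keyframe index; model_path/0 must already exist
    "--ed", "99",                 # last frame index to train
    "--interval", "1",            # step between consecutive trained frames
], check=True)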
-------------------------------------------------------------------------------- /train_prune.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import torch 14 | from random import randint 15 | from utils.loss_utils import l1_loss, ssim 16 | from gaussian_renderer import render, network_gui 17 | import sys 18 | from scene import Scene, GaussianModel 19 | from utils.general_utils import safe_state 20 | import uuid 21 | from tqdm import tqdm 22 | from utils.image_utils import psnr 23 | from argparse import ArgumentParser, Namespace 24 | from arguments import ModelParams, PipelineParams, OptimizationParams 25 | try: 26 | from torch.utils.tensorboard import SummaryWriter 27 | TENSORBOARD_FOUND = True 28 | except ImportError: 29 | TENSORBOARD_FOUND = False 30 | 31 | def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): 32 | first_iter = 0 33 | tb_writer = prepare_output_and_logger(dataset) 34 | gaussians = GaussianModel(dataset.sh_degree) 35 | scene = Scene(dataset, gaussians) 36 | gaussians.training_setup(opt) 37 | if checkpoint: 38 | (model_params, first_iter) = torch.load(checkpoint) 39 | gaussians.restore(model_params, opt) 40 | 41 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 42 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 43 | 44 | iter_start = torch.cuda.Event(enable_timing = True) 45 | iter_end = torch.cuda.Event(enable_timing = True) 46 | 47 | viewpoint_stack = None 48 | ema_loss_for_log = 0.0 49 | progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") 50 | first_iter += 1 51 | for iteration in range(first_iter, opt.iterations + 1): 52 | if network_gui.conn == None: 53 | network_gui.try_connect() 54 | while network_gui.conn != None: 55 | try: 56 | net_image_bytes = None 57 | custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() 58 | if custom_cam != None: 59 | net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] 60 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 61 | network_gui.send(net_image_bytes, dataset.source_path) 62 | if do_training and ((iteration < int(opt.iterations)) or not keep_alive): 63 | break 64 | except Exception as e: 65 | network_gui.conn = None 66 | 67 | iter_start.record() 68 | 69 | gaussians.update_learning_rate(iteration) 70 | 71 | # Every 1000 its we increase the levels of SH up to a maximum degree 72 | if iteration % 1000 == 0: 73 | gaussians.oneupSHdegree() 74 | 75 | # Pick a random Camera 76 | if not viewpoint_stack: 77 | viewpoint_stack = scene.getTrainCameras().copy() 78 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) 79 | 80 | # Render 81 | if (iteration - 1) == debug_from: 82 | pipe.debug = True 83 | 84 | bg = torch.rand((3), device="cuda") if opt.random_background else background 85 | 86 | render_pkg = render(viewpoint_cam, gaussians, pipe, bg) 87 | image, viewspace_point_tensor, visibility_filter, radii = 
render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] 88 | 89 | # Loss 90 | gt_image = viewpoint_cam.original_image.cuda() 91 | Ll1 = l1_loss(image, gt_image) 92 | loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) 93 | loss.backward() 94 | 95 | iter_end.record() 96 | 97 | with torch.no_grad(): 98 | # Progress bar 99 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 100 | if iteration % 10 == 0: 101 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 102 | progress_bar.update(10) 103 | if iteration == opt.iterations: 104 | progress_bar.close() 105 | 106 | # Log and save 107 | training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) 108 | if (iteration in saving_iterations): 109 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 110 | scene.save(iteration) 111 | 112 | if iteration < opt.first_frame_prune_iter: 113 | 114 | # Densification 115 | if iteration < opt.densify_until_iter: 116 | # Keep track of max radii in image-space for pruning 117 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) 118 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 119 | 120 | if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: 121 | size_threshold = 20 if iteration > opt.opacity_reset_interval else None 122 | gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 123 | 124 | if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): 125 | gaussians.reset_opacity() 126 | 127 | # Pruning 128 | if iteration == opt.first_frame_prune_iter: 129 | gaussians.prune_opacity(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) 130 | 131 | # Optimizer step 132 | if iteration < opt.iterations: 133 | gaussians.optimizer.step() 134 | gaussians.optimizer.zero_grad(set_to_none = True) 135 | 136 | if (iteration in checkpoint_iterations): 137 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 138 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 139 | 140 | def prepare_output_and_logger(args): 141 | if not args.model_path: 142 | if os.getenv('OAR_JOB_ID'): 143 | unique_str=os.getenv('OAR_JOB_ID') 144 | else: 145 | unique_str = str(uuid.uuid4()) 146 | args.model_path = os.path.join("./output/", unique_str[0:10]) 147 | 148 | # Set up output folder 149 | print("Output folder: {}".format(args.model_path)) 150 | os.makedirs(args.model_path, exist_ok = True) 151 | with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: 152 | cfg_log_f.write(str(Namespace(**vars(args)))) 153 | 154 | # Create Tensorboard writer 155 | tb_writer = None 156 | if TENSORBOARD_FOUND: 157 | tb_writer = SummaryWriter(args.model_path) 158 | else: 159 | print("Tensorboard not available: not logging progress") 160 | return tb_writer 161 | 162 | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): 163 | if tb_writer: 164 | tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) 165 | tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) 166 | tb_writer.add_scalar('iter_time', elapsed, iteration) 167 | 
168 | # Report test and samples of training set 169 | if iteration in testing_iterations: 170 | torch.cuda.empty_cache() 171 | validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 172 | {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) 173 | 174 | for config in validation_configs: 175 | if config['cameras'] and len(config['cameras']) > 0: 176 | l1_test = 0.0 177 | psnr_test = 0.0 178 | for idx, viewpoint in enumerate(config['cameras']): 179 | image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) 180 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 181 | if tb_writer and (idx < 5): 182 | tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) 183 | if iteration == testing_iterations[0]: 184 | tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) 185 | l1_test += l1_loss(image, gt_image).mean().double() 186 | psnr_test += psnr(image, gt_image).mean().double() 187 | psnr_test /= len(config['cameras']) 188 | l1_test /= len(config['cameras']) 189 | print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) 190 | if tb_writer: 191 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) 192 | tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) 193 | 194 | if tb_writer: 195 | tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) 196 | tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) 197 | torch.cuda.empty_cache() 198 | 199 | if __name__ == "__main__": 200 | # Set up command line argument parser 201 | parser = ArgumentParser(description="Training script parameters") 202 | lp = ModelParams(parser) 203 | op = OptimizationParams(parser) 204 | pp = PipelineParams(parser) 205 | parser.add_argument('--ip', type=str, default="127.0.0.1") 206 | parser.add_argument('--port', type=int, default=6009) 207 | parser.add_argument('--debug_from', type=int, default=-1) 208 | parser.add_argument('--detect_anomaly', action='store_true', default=False) 209 | parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) 210 | parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) 211 | parser.add_argument("--quiet", action="store_true") 212 | parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) 213 | parser.add_argument("--start_checkpoint", type=str, default = None) 214 | args = parser.parse_args(sys.argv[1:]) 215 | args.save_iterations.append(args.iterations) 216 | 217 | print("Optimizing " + args.model_path) 218 | 219 | # Initialize system state (RNG) 220 | safe_state(args.quiet) 221 | 222 | # Start GUI server, configure and run training 223 | network_gui.init(args.ip, args.port) 224 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 225 | training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) 226 | 227 | # All done 228 | print("\nTraining complete.") 229 | -------------------------------------------------------------------------------- /train_sequence.py: -------------------------------------------------------------------------------- 1 
| import os 2 | import argparse 3 | import shutil 4 | import pymeshlab 5 | import open3d as o3d 6 | import numpy as np 7 | 8 | # group_size = 10 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--start', type=int, default='') 13 | parser.add_argument('--end', type=int, default='') 14 | parser.add_argument('--cuda', type=int, default='') 15 | parser.add_argument('--data', type=str, default='') 16 | parser.add_argument('--output', type=str, default='') 17 | parser.add_argument('--sh', type=str, default='0') 18 | parser.add_argument('--interval', type=str, default='') 19 | parser.add_argument('--group_size', type=str, default='') 20 | parser.add_argument('--resolution', type=int, default=2) 21 | args = parser.parse_args() 22 | 23 | print(args.start, args.end) 24 | 25 | # os.system("conda activate torch") 26 | card_id = args.cuda 27 | data_root_path = args.data 28 | output_path = args.output 29 | sh = args.sh 30 | interval = int(args.interval) 31 | group_size = int(args.group_size) 32 | resolution_scale = int(args.resolution) 33 | 34 | # neus2_meshlab_filter_path = os.path.join(data_root_path, "luoxi_filter.mlx") 35 | 36 | neus2_output_path = os.path.join(output_path, "neus2_output") 37 | if not os.path.exists(neus2_output_path): 38 | os.makedirs(neus2_output_path) 39 | 40 | gaussian_output_path = os.path.join(output_path, "checkpoint") 41 | 42 | for i in range(args.start, args.end, group_size * interval): 43 | group_start = i 44 | group_end = min(i + group_size * interval, args.end) - 1 45 | print(group_start, group_end) 46 | 47 | frame_path = os.path.join(data_root_path, str(i)) 48 | if not os.path.exists(frame_path): 49 | os.makedirs(frame_path) 50 | frame_neus2_output_path = os.path.join(neus2_output_path, str(i)) 51 | if not os.path.exists(frame_neus2_output_path): 52 | os.makedirs(frame_neus2_output_path) 53 | frame_neus2_ckpt_output_path = os.path.join(frame_neus2_output_path, "frame.msgpack") 54 | frame_neus2_mesh_output_path = os.path.join(frame_neus2_output_path, "points3d.obj") 55 | 56 | """NeuS2""" 57 | # neus2 command 58 | script_path = "scripts/run.py" 59 | neus2_command = f"cd external/NeuS2_K && CUDA_VISIBLE_DEVICES={card_id} python {script_path} --scene {frame_path} --name neus --mode nerf --save_snapshot {frame_neus2_ckpt_output_path} --save_mesh --save_mesh_path {frame_neus2_mesh_output_path} && cd ../.." 
60 | os.system(neus2_command) 61 | delete_neus2_output_path = os.path.join(frame_path, "output") 62 | shutil.rmtree(delete_neus2_output_path) 63 | 64 | # revert axis 65 | mesh1 = o3d.io.read_triangle_mesh(frame_neus2_mesh_output_path) 66 | vertices = np.asarray(mesh1.vertices) 67 | vertices = vertices[:,[2,0,1]] 68 | mesh1.vertices = o3d.utility.Vector3dVector(vertices) 69 | o3d.io.write_triangle_mesh(frame_neus2_mesh_output_path, mesh1) 70 | 71 | # use pymeshlab to convert obj to point cloud 72 | ms = pymeshlab.MeshSet() 73 | ms.load_new_mesh(frame_neus2_mesh_output_path) 74 | # ms.load_filter_script(neus2_meshlab_filter_path) 75 | # ms.apply_filter_script() 76 | ms.generate_simplified_point_cloud(samplenum = 100000) 77 | frame_points3d_output_path = os.path.join(frame_path, "points3d.ply") 78 | ms.save_current_mesh(frame_points3d_output_path, binary = True, save_vertex_normal = False) 79 | 80 | 81 | """ Gaussian """ 82 | # generate output 83 | frame_model_path = os.path.join(gaussian_output_path, str(i)) 84 | first_frame_iteration = 12000 85 | first_frame_save_iterations = first_frame_iteration 86 | first_gaussian_command = f"CUDA_VISIBLE_DEVICES={card_id} python train.py -s {frame_path} -m {frame_model_path} --iterations {first_frame_iteration} --save_iterations {first_frame_save_iterations} --sh_degree {sh} -r {resolution_scale} --port 600{card_id}" 87 | os.system(first_gaussian_command) 88 | 89 | # prune 90 | prune_iterations = 4000 91 | prune_gaussian_command = f"CUDA_VISIBLE_DEVICES={card_id} python prune_gaussian.py -s {frame_path} -m {frame_model_path} --sh_degree {sh} -r {resolution_scale} --iterations {prune_iterations}" 92 | os.system(prune_gaussian_command) 93 | 94 | # rest frame 95 | dynamic_command = f"CUDA_VISIBLE_DEVICES={card_id} python train_dynamic.py -s {data_root_path} -m {gaussian_output_path} --sh_degree {sh} -r {resolution_scale} --st {group_start} --ed {group_end} --interval {interval}" 96 | os.system(dynamic_command) 97 | 98 | print(f"Finish {group_start} to {group_end}") -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | if resized_image_rgb.shape[1] == 4: 47 | loaded_mask = resized_image_rgb[3:4, ...] 48 | 49 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 50 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 51 | image=gt_image, gt_alpha_mask=loaded_mask, 52 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 53 | 54 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 55 | camera_list = [] 56 | 57 | for id, c in enumerate(cam_infos): 58 | camera_list.append(loadCam(args, id, c, resolution_scale)) 59 | 60 | return camera_list 61 | 62 | def camera_to_JSON(id, camera : Camera): 63 | Rt = np.zeros((4, 4)) 64 | Rt[:3, :3] = camera.R.transpose() 65 | Rt[:3, 3] = camera.T 66 | Rt[3, 3] = 1.0 67 | 68 | W2C = np.linalg.inv(Rt) 69 | pos = W2C[:3, 3] 70 | rot = W2C[:3, :3] 71 | serializable_array_2d = [x.tolist() for x in rot] 72 | camera_entry = { 73 | 'id' : id, 74 | 'img_name' : camera.image_name, 75 | 'width' : camera.width, 76 | 'height' : camera.height, 77 | 'position': pos.tolist(), 78 | 'rotation': serializable_array_2d, 79 | 'fy' : fov2focal(camera.FovY, camera.height), 80 | 'fx' : fov2focal(camera.FovX, camera.width) 81 | } 82 | return camera_entry 83 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " 
[{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | def create_window(window_size, channel): 28 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 29 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 30 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 31 | return window 32 | 33 | def ssim(img1, img2, window_size=11, size_average=True): 34 | channel = img1.size(-3) 35 | window = create_window(window_size, channel) 36 | 37 | if img1.is_cuda: 38 | window = window.cuda(img1.get_device()) 39 | window = window.type_as(img1) 40 | 41 | return _ssim(img1, img2, window, window_size, channel, size_average) 42 | 43 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 44 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 45 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 46 | 47 | mu1_sq = mu1.pow(2) 48 | mu2_sq = mu2.pow(2) 49 | mu1_mu2 = mu1 * mu2 50 | 51 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 52 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 53 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 54 | 55 | C1 = 0.01 ** 2 56 | C2 = 0.03 ** 2 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 59 | 60 | if size_average: 61 | return ssim_map.mean() 62 | else: 63 | return ssim_map.mean(1).mean(1).mean(1) 64 | 65 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. 
Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
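As a concrete illustration of the spherical harmonics helpers above, the short sketch below (not part of the repository) checks that RGB2SH and SH2RGB are exact inverses and that a degree-0 eval_sh call reduces to C0 * sh[..., 0], i.e. the stored color minus the 0.5 offset. It assumes it is run from the repository root so that utils.sh_utils is importable.

import torch
from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

# Colors in [0, 1]; RGB2SH stores them as DC (degree-0) SH coefficients.
rgb = torch.tensor([[0.2, 0.5, 0.9]])        # [N, 3]
sh_dc = RGB2SH(rgb)                          # (rgb - 0.5) / C0

# eval_sh expects coefficients shaped [..., C, (deg + 1) ** 2]; for degree 0
# that is a single coefficient per channel and the view direction is unused.
sh = sh_dc.unsqueeze(-1)                     # [N, 3, 1]
dirs = torch.nn.functional.normalize(torch.randn(1, 3), dim=-1)
out = eval_sh(0, sh, dirs)                   # C0 * sh[..., 0] == rgb - 0.5

assert torch.allclose(SH2RGB(sh_dc), rgb)    # exact round trip
assert torch.allclose(out + 0.5, rgb)        # degree-0 evaluation recovers the color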