├── .gitignore ├── .gitmodules ├── .ipynb_checkpoints ├── mask_psnr-checkpoint.ipynb ├── test-checkpoint.ipynb └── weight_visualization-checkpoint.ipynb ├── 4DGaussians.ipynb ├── LICENSE.md ├── README.md ├── arguments ├── __init__.py ├── dnerf │ ├── bouncingballs.py │ ├── dnerf_default.py │ ├── hellwarrior.py │ ├── hook.py │ ├── jumpingjacks.py │ ├── lego.py │ ├── mutant.py │ ├── standup.py │ └── trex.py ├── dycheck │ └── default.py ├── dynerf │ ├── coffee_martini.py │ ├── cook_spinach.py │ ├── cut_roasted_beef.py │ ├── default.py │ ├── flame_salmon_1.py │ ├── flame_steak.py │ └── sear_steak.py ├── hypernerf │ ├── 3dprinter.py │ ├── banana.py │ ├── broom2.py │ ├── chicken.py │ └── default.py └── multipleview │ └── default.py ├── assets ├── cut_roasted_beef_time.mp4 ├── pipeline.png ├── port_forward.png ├── teaserfig.jpg ├── teaservideo.mp4 └── viewer.mp4 ├── colmap.sh ├── convert.py ├── database.py ├── docs └── viewer_usage.md ├── export_perframe_3DGS.py ├── full_eval.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── merge_many_4dgs.py ├── metrics.py ├── multipleviewprogress.sh ├── render.py ├── requirements.txt ├── scene ├── __init__.py ├── camera.py ├── cameras.py ├── colmap_loader.py ├── dataset.py ├── dataset_readers.py ├── deformation.py ├── gaussian_model.py ├── grid.py ├── hexplane.py ├── hyper_loader.py ├── multipleview_dataset.py ├── neural_3D_dataset_NDC.py ├── regulation.py └── utils.py ├── scripts ├── blender2colmap.py ├── cal_modelsize.py ├── colmap_converter.py ├── downsample_point.py ├── extractimages.py ├── grow_point.py ├── hypernerf2colmap.py ├── llff2colmap.py ├── merge_point.py ├── preprocess_dynerf.py ├── process_dnerf.sh ├── read_all_metrics.py ├── select_image.py ├── train_dnerf.sh ├── train_dycheck.sh ├── train_dynamic3dgs.sh ├── train_dynerf.sh ├── train_hyper_interp.sh ├── train_hyper_virg.sh └── train_test_split.py ├── train.py ├── utils ├── TIMES.TTF ├── TIMESBD.TTF ├── TIMESBI.TTF ├── TIMESI.TTF ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loader_utils.py ├── loss_utils.py ├── params_utils.py ├── point_utils.py ├── pose_utils.py ├── render_utils.py ├── scene_utils.py ├── sh_utils.py ├── system_utils.py └── timer.py └── weight_visualization.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots 9 | data/ 10 | data 11 | submodules/ 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | 5 | [submodule "submodules/depth-diff-gaussian-rasterization"] 6 | path = submodules/depth-diff-gaussian-rasterization 7 | url = https://github.com/ingra14m/depth-diff-gaussian-rasterization 8 | [submodule "SIBR_viewers"] 9 | path = SIBR_viewers 10 | url = https://gitlab.inria.fr/sibr/sibr_core 11 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/test-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": 
{}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import torch\n", 11 | "import os\n", 12 | "import imageio" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 48, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# path = \"/data3/guanjunwu/project_scp/TiNeuVox/logs/interp_data/interp/chicken/render_test_fine_last\"\n", 22 | "path = \"output/hypernerf4/interp/americano/test/ours_14000/renders\"\n", 23 | "# \n", 24 | "# path = \"output/dynamic3dgs/dynamic3dgs/basketball/test/ours_30000/renders\"\n", 25 | "image_list = os.listdir(path)\n", 26 | "len_image = len(image_list)\n", 27 | "tile = image_list[0].split('.')[-1]" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 49, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import re\n", 37 | "def sort_numeric_filenames(filenames):\n", 38 | " \"\"\"\n", 39 | " Sort a list of filenames based on the numeric part of the filename.\n", 40 | " Assumes filenames have a format like '0000.png', '0001.png', etc.\n", 41 | " \"\"\"\n", 42 | " def extract_number(filename):\n", 43 | " # Extract the numeric part of the filename with a regular expression\n", 44 | " match = re.search(r'\\d+', filename)\n", 45 | " return int(match.group()) if match else 0\n", 46 | "\n", 47 | " # Sort by the extracted number\n", 48 | " return sorted(filenames, key=extract_number)\n", 49 | "\n", 50 | "# Example list of filenames\n", 51 | "filenames = image_list\n", 52 | "\n", 53 | "# Sort the list\n", 54 | "sorted_filenames = sort_numeric_filenames(filenames)\n", 55 | "sorted_filenames = [i for i in sorted_filenames if 'png' in i]" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 50, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "['000.png',\n", 67 | " '001.png',\n", 68 | " '002.png',\n", 69 | " '003.png',\n", 70 | " '004.png',\n", 71 | " '005.png',\n", 72 | " '006.png',\n", 73 | " '007.png',\n", 74 | " '008.png',\n", 75 | " '009.png',\n", 76 | " '010.png',\n", 77 | " '011.png',\n", 78 | " '012.png',\n", 79 | " '013.png',\n", 80 | " '014.png',\n", 81 | " '015.png',\n", 82 | " '016.png',\n", 83 | " '017.png',\n", 84 | " '018.png',\n", 85 | " '019.png',\n", 86 | " '020.png',\n", 87 | " '021.png',\n", 88 | " '022.png',\n", 89 | " '023.png',\n", 90 | " '024.png',\n", 91 | " '025.png',\n", 92 | " '026.png',\n", 93 | " '027.png',\n", 94 | " '028.png',\n", 95 | " '029.png',\n", 96 | " '030.png',\n", 97 | " '031.png',\n", 98 | " '032.png',\n", 99 | " '033.png',\n", 100 | " '034.png',\n", 101 | " '035.png',\n", 102 | " '036.png',\n", 103 | " '037.png',\n", 104 | " '038.png',\n", 105 | " '039.png',\n", 106 | " '040.png',\n", 107 | " '041.png',\n", 108 | " '042.png',\n", 109 | " '043.png',\n", 110 | " '044.png',\n", 111 | " '045.png',\n", 112 | " '046.png',\n", 113 | " '047.png',\n", 114 | " '048.png',\n", 115 | " '049.png',\n", 116 | " '050.png',\n", 117 | " '051.png',\n", 118 | " '052.png',\n", 119 | " '053.png',\n", 120 | " '054.png',\n", 121 | " '055.png',\n", 122 | " '056.png',\n", 123 | " '057.png',\n", 124 | " '058.png',\n", 125 | " '059.png',\n", 126 | " '060.png',\n", 127 | " '061.png',\n", 128 | " '062.png',\n", 129 | " '063.png',\n", 130 | " '064.png',\n", 131 | " '065.png',\n", 132 | " '066.png',\n", 133 | " '067.png',\n", 134 | " '068.png',\n", 135 | " '069.png',\n", 136 | " '070.png',\n", 137 | " '071.png',\n", 138 | " '072.png',\n", 139 | " '073.png',\n", 140 | " '074.png',\n", 141 | " '075.png',\n", 142 | " '076.png',\n", 143 | " '077.png',\n", 144 | " '078.png',\n", 145 | " '079.png',\n", 146 | " '080.png',\n", 147 |
" '081.png',\n", 148 | " '082.png',\n", 149 | " '083.png',\n", 150 | " '084.png',\n", 151 | " '085.png',\n", 152 | " '086.png',\n", 153 | " '087.png',\n", 154 | " '088.png',\n", 155 | " '089.png',\n", 156 | " '090.png',\n", 157 | " '091.png',\n", 158 | " '092.png',\n", 159 | " '093.png',\n", 160 | " '094.png',\n", 161 | " '095.png',\n", 162 | " '096.png',\n", 163 | " '097.png',\n", 164 | " '098.png',\n", 165 | " '099.png',\n", 166 | " '100.png',\n", 167 | " '101.png',\n", 168 | " '102.png',\n", 169 | " '103.png',\n", 170 | " '104.png',\n", 171 | " '105.png',\n", 172 | " '106.png',\n", 173 | " '107.png',\n", 174 | " '108.png',\n", 175 | " '109.png',\n", 176 | " '110.png',\n", 177 | " '111.png',\n", 178 | " '112.png']" 179 | ] 180 | }, 181 | "execution_count": 50, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "sorted_filenames" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 51, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stderr", 197 | "output_type": "stream", 198 | "text": [ 199 | "/data/guanjunwu/disk2/miniconda3/envs/Gaussians4D/lib/python3.7/site-packages/ipykernel_launcher.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n", 200 | " \n", 201 | "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (536, 960) to (544, 960) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n", 202 | "[swscaler @ 0x67a2580] Warning: data is not aligned! This can lead to a speed loss\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "writer = imageio.get_writer(os.path.join(path,\"video111.mp4\"),fps=10)\n", 208 | "video_num = 1\n", 209 | "video_list = [[] for i in range(video_num)]\n", 210 | "for i, image in enumerate(sorted_filenames):\n", 211 | " if i % video_num == 0:\n", 212 | " image = imageio.imread(os.path.join(path,image))\n", 213 | " writer.append_data(image)\n", 214 | "writer.close()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "!pip install imageio[ffmpeg]" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [] 232 | } 233 | ], 234 | "metadata": { 235 | "kernelspec": { 236 | "display_name": "Python 3 (ipykernel)", 237 | "language": "python", 238 | "name": "python3" 239 | }, 240 | "language_info": { 241 | "codemirror_mode": { 242 | "name": "ipython", 243 | "version": 3 244 | }, 245 | "file_extension": ".py", 246 | "mimetype": "text/x-python", 247 | "name": "python", 248 | "nbconvert_exporter": "python", 249 | "pygments_lexer": "ipython3", 250 | "version": "3.7.16" 251 | } 252 | }, 253 | "nbformat": 4, 254 | "nbformat_minor": 2 255 | } 256 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | 16 | class GroupParams: 17 | pass 18 | 19 | class ParamGroup: 20 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 21 | group = parser.add_argument_group(name) 22 | for key, value in vars(self).items(): 23 | shorthand = False 24 | if key.startswith("_"): 25 | shorthand = True 26 | key = key[1:] 27 | t = type(value) 28 | value = value if not fill_none else None 29 | if shorthand: 30 | if t == bool: 31 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 32 | else: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 34 | else: 35 | if t == bool: 36 | group.add_argument("--" + key, default=value, action="store_true") 37 | else: 38 | group.add_argument("--" + key, default=value, type=t) 39 | 40 | def extract(self, args): 41 | group = GroupParams() 42 | for arg in vars(args).items(): 43 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 44 | setattr(group, arg[0], arg[1]) 45 | return group 46 | 47 | class ModelParams(ParamGroup): 48 | def __init__(self, parser, sentinel=False): 49 | self.sh_degree = 3 50 | self._source_path = "" 51 | self._model_path = "" 52 | self._images = "images" 53 | self._resolution = -1 54 | self._white_background = True 55 | self.data_device = "cuda" 56 | self.eval = True 57 | self.render_process=False 58 | self.add_points=False 59 | self.extension=".png" 60 | self.llffhold=8 61 | super().__init__(parser, "Loading Parameters", sentinel) 62 | 63 | def extract(self, args): 64 | g = super().extract(args) 65 | g.source_path = os.path.abspath(g.source_path) 66 | return g 67 | 68 | class PipelineParams(ParamGroup): 69 | def __init__(self, parser): 70 | self.convert_SHs_python = False 71 | self.compute_cov3D_python = False 72 | self.debug = False 73 | super().__init__(parser, "Pipeline Parameters") 74 | class ModelHiddenParams(ParamGroup): 75 | def __init__(self, parser): 76 | self.net_width = 64 # width of the deformation MLP; larger values increase rendering quality but decrease training/rendering speed. 77 | self.timebase_pe = 4 # unused 78 | self.defor_depth = 1 # depth of the deformation MLP; larger values increase rendering quality but decrease training/rendering speed. 79 | self.posebase_pe = 10 # unused 80 | self.scale_rotation_pe = 2 # unused 81 | self.opacity_pe = 2 # unused 82 | self.timenet_width = 64 # unused 83 | self.timenet_output = 32 # unused 84 | self.bounds = 1.6 85 | self.plane_tv_weight = 0.0001 # TV loss on the spatial grid 86 | self.time_smoothness_weight = 0.01 # TV loss on the temporal grid 87 | self.l1_time_planes = 0.0001 # L1 loss on the temporal grid 88 | self.kplanes_config = { 89 | 'grid_dimensions': 2, 90 | 'input_coordinate_dim': 4, 91 | 'output_coordinate_dim': 32, 92 | 'resolution': [64, 64, 64, 25] # [64, 64, 64]: resolution of the spatial grid. 25: resolution of the temporal grid, best set to about half the number of dynamic frames 93 | } 94 | self.multires = [1, 2, 4, 8] # multi-resolution levels of the voxel grid 95 | self.no_dx=False # cancel the deformation of Gaussians' positions 96 | self.no_grid=False # cancel the spatial-temporal hexplane.
97 | self.no_ds=False # cancel the deformation of Gaussians' scaling 98 | self.no_dr=False # cancel the deformation of Gaussians' rotations 99 | self.no_do=True # cancel the deformation of Gaussians' opacity 100 | self.no_dshs=True # cancel the deformation of SH colors. 101 | self.empty_voxel=False # useless 102 | self.grid_pe=0 # useless, I was trying to add positional encoding to hexplane's features 103 | self.static_mlp=False # useless 104 | self.apply_rotation=False # useless 105 | 106 | 107 | super().__init__(parser, "ModelHiddenParams") 108 | 109 | class OptimizationParams(ParamGroup): 110 | def __init__(self, parser): 111 | self.dataloader=False 112 | self.zerostamp_init=False 113 | self.custom_sampler=None 114 | self.iterations = 30_000 115 | self.coarse_iterations = 3000 116 | self.position_lr_init = 0.00016 117 | self.position_lr_final = 0.0000016 118 | self.position_lr_delay_mult = 0.01 119 | self.position_lr_max_steps = 20_000 120 | self.deformation_lr_init = 0.00016 121 | self.deformation_lr_final = 0.000016 122 | self.deformation_lr_delay_mult = 0.01 123 | self.grid_lr_init = 0.0016 124 | self.grid_lr_final = 0.00016 125 | 126 | self.feature_lr = 0.0025 127 | self.opacity_lr = 0.05 128 | self.scaling_lr = 0.005 129 | self.rotation_lr = 0.001 130 | self.percent_dense = 0.01 131 | self.lambda_dssim = 0 132 | self.lambda_lpips = 0 133 | self.weight_constraint_init= 1 134 | self.weight_constraint_after = 0.2 135 | self.weight_decay_iteration = 5000 136 | self.opacity_reset_interval = 3000 137 | self.densification_interval = 100 138 | self.densify_from_iter = 500 139 | self.densify_until_iter = 15_000 140 | self.densify_grad_threshold_coarse = 0.0002 141 | self.densify_grad_threshold_fine_init = 0.0002 142 | self.densify_grad_threshold_after = 0.0002 143 | self.pruning_from_iter = 500 144 | self.pruning_interval = 100 145 | self.opacity_threshold_coarse = 0.005 146 | self.opacity_threshold_fine_init = 0.005 147 | self.opacity_threshold_fine_after = 0.005 148 | self.batch_size=1 149 | self.add_point=False 150 | super().__init__(parser, "Optimization Parameters") 151 | 152 | def get_combined_args(parser : ArgumentParser): 153 | cmdlne_string = sys.argv[1:] 154 | cfgfile_string = "Namespace()" 155 | args_cmdline = parser.parse_args(cmdlne_string) 156 | 157 | try: 158 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 159 | print("Looking for config file in", cfgfilepath) 160 | with open(cfgfilepath) as cfg_file: 161 | print("Config file found: {}".format(cfgfilepath)) 162 | cfgfile_string = cfg_file.read() 163 | except TypeError: 164 | print("Config file not found at") 165 | pass 166 | args_cfgfile = eval(cfgfile_string) 167 | 168 | merged_dict = vars(args_cfgfile).copy() 169 | for k,v in vars(args_cmdline).items(): 170 | if v != None: 171 | merged_dict[k] = v 172 | return Namespace(**merged_dict) 173 | -------------------------------------------------------------------------------- /arguments/dnerf/bouncingballs.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 75] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/dnerf_default.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | OptimizationParams = dict( 4 | 5 | 
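# (Editor's note, not part of the original file.) These per-dataset dicts override the
# defaults declared in arguments/__init__.py: when a script is launched with --configs
# pointing at a file like this one, the config is loaded with mmcv.Config (which also
# resolves the `_base_ = '...'` inheritance used by the per-scene configs) and merged into
# the parsed arguments via utils.params_utils.merge_hparams; see export_perframe_3DGS.py
# later in this dump for the call site.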
coarse_iterations = 3000, 6 | deformation_lr_init = 0.00016, 7 | deformation_lr_final = 0.0000016, 8 | deformation_lr_delay_mult = 0.01, 9 | grid_lr_init = 0.0016, 10 | grid_lr_final = 0.000016, 11 | iterations = 20000, 12 | pruning_interval = 8000, 13 | percent_dense = 0.01, 14 | render_process=False, 15 | # no_do=False, 16 | # no_dshs=False 17 | 18 | # opacity_reset_interval=30000 19 | 20 | ) 21 | 22 | ModelHiddenParams = dict( 23 | 24 | multires = [1, 2], 25 | defor_depth = 0, 26 | net_width = 64, 27 | plane_tv_weight = 0.0001, 28 | time_smoothness_weight = 0.01, 29 | l1_time_planes = 0.0001, 30 | weight_decay_iteration=0, 31 | bounds=1.6 32 | ) 33 | -------------------------------------------------------------------------------- /arguments/dnerf/hellwarrior.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 50] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/hook.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 50] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/jumpingjacks.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 100] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/lego.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 25] 9 | }, 10 | 11 | # deformation_lr_init = 0.001, 12 | # deformation_lr_final = 0.001, 13 | # deformation_lr_delay_mult = 0.01, 14 | # grid_lr_init = 0.001, 15 | # grid_lr_final = 0.001, 16 | ) -------------------------------------------------------------------------------- /arguments/dnerf/mutant.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 75] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/standup.py: -------------------------------------------------------------------------------- 1 | _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 75] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dnerf/trex.py: -------------------------------------------------------------------------------- 1 
| _base_ = './dnerf_default.py' 2 | 3 | ModelHiddenParams = dict( 4 | kplanes_config = { 5 | 'grid_dimensions': 2, 6 | 'input_coordinate_dim': 4, 7 | 'output_coordinate_dim': 32, 8 | 'resolution': [64, 64, 64, 100] 9 | } 10 | ) -------------------------------------------------------------------------------- /arguments/dycheck/default.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2,4], 9 | defor_depth = 1, 10 | net_width = 128, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.0001, 14 | render_process=True 15 | ) 16 | OptimizationParams = dict( 17 | # dataloader=True, 18 | iterations = 60_000, 19 | batch_size=2, 20 | coarse_iterations = 3000, 21 | densify_until_iter = 10_000, 22 | opacity_reset_interval = 300000, 23 | # grid_lr_init = 0.0016, 24 | # grid_lr_final = 16, 25 | # opacity_threshold_coarse = 0.005, 26 | # opacity_threshold_fine_init = 0.005, 27 | # opacity_threshold_fine_after = 0.005, 28 | # pruning_interval = 2000 29 | ) -------------------------------------------------------------------------------- /arguments/dynerf/coffee_martini.py: -------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | 4 | ) -------------------------------------------------------------------------------- /arguments/dynerf/cook_spinach.py: -------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | batch_size=2, 4 | 5 | ) -------------------------------------------------------------------------------- /arguments/dynerf/cut_roasted_beef.py: -------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | batch_size=2, 4 | ) -------------------------------------------------------------------------------- /arguments/dynerf/default.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2], 9 | defor_depth = 0, 10 | net_width = 128, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.0001, 14 | no_do=False, 15 | no_dshs=False, 16 | no_ds=False, 17 | empty_voxel=False, 18 | render_process=False, 19 | static_mlp=False 20 | 21 | ) 22 | OptimizationParams = dict( 23 | dataloader=True, 24 | iterations = 14000, 25 | batch_size=4, 26 | coarse_iterations = 3000, 27 | densify_until_iter = 10_000, 28 | opacity_reset_interval = 60000, 29 | opacity_threshold_coarse = 0.005, 30 | opacity_threshold_fine_init = 0.005, 31 | opacity_threshold_fine_after = 0.005, 32 | # pruning_interval = 2000 33 | ) -------------------------------------------------------------------------------- /arguments/dynerf/flame_salmon_1.py: -------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | 4 | ) -------------------------------------------------------------------------------- /arguments/dynerf/flame_steak.py: 
-------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | batch_size=2, 4 | 5 | ) -------------------------------------------------------------------------------- /arguments/dynerf/sear_steak.py: -------------------------------------------------------------------------------- 1 | _base_ = './default.py' 2 | OptimizationParams = dict( 3 | batch_size=2, 4 | ) -------------------------------------------------------------------------------- /arguments/hypernerf/3dprinter.py: -------------------------------------------------------------------------------- 1 | _base_="default.py" 2 | ModelParams=dict( 3 | kplanes_config = { 4 | 'grid_dimensions': 2, 5 | 'input_coordinate_dim': 4, 6 | 'output_coordinate_dim': 16, 7 | 'resolution': [64, 64, 64, 100] 8 | }, 9 | ) 10 | OptimizationParams=dict( 11 | ) -------------------------------------------------------------------------------- /arguments/hypernerf/banana.py: -------------------------------------------------------------------------------- 1 | _base_="default.py" 2 | ModelParams=dict( 3 | kplanes_config = { 4 | 'grid_dimensions': 2, 5 | 'input_coordinate_dim': 4, 6 | 'output_coordinate_dim': 16, 7 | 'resolution': [64, 64, 64, 250] 8 | }, 9 | ) 10 | OptimizationParams=dict( 11 | ) -------------------------------------------------------------------------------- /arguments/hypernerf/broom2.py: -------------------------------------------------------------------------------- 1 | _base_="default.py" 2 | ModelParams=dict( 3 | kplanes_config = { 4 | 'grid_dimensions': 2, 5 | 'input_coordinate_dim': 4, 6 | 'output_coordinate_dim': 16, 7 | 'resolution': [64, 64, 64, 100] 8 | }, 9 | ) 10 | OptimizationParams=dict( 11 | ) -------------------------------------------------------------------------------- /arguments/hypernerf/chicken.py: -------------------------------------------------------------------------------- 1 | _base_="default.py" 2 | ModelParams=dict( 3 | kplanes_config = { 4 | 'grid_dimensions': 2, 5 | 'input_coordinate_dim': 4, 6 | 'output_coordinate_dim': 16, 7 | 'resolution': [64, 64, 64, 80] 8 | }, 9 | ) 10 | OptimizationParams=dict( 11 | ) -------------------------------------------------------------------------------- /arguments/hypernerf/default.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 | 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2,4], 9 | defor_depth = 1, 10 | net_width = 128, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.0001, 14 | render_process=True 15 | ) 16 | OptimizationParams = dict( 17 | # dataloader=True, 18 | iterations = 14_000, 19 | batch_size=2, 20 | coarse_iterations = 3000, 21 | densify_until_iter = 10_000, 22 | opacity_reset_interval = 300000, 23 | # grid_lr_init = 0.0016, 24 | # grid_lr_final = 16, 25 | # opacity_threshold_coarse = 0.005, 26 | # opacity_threshold_fine_init = 0.005, 27 | # opacity_threshold_fine_after = 0.005, 28 | # pruning_interval = 2000 29 | ) -------------------------------------------------------------------------------- /arguments/multipleview/default.py: -------------------------------------------------------------------------------- 1 | ModelHiddenParams = dict( 2 | kplanes_config = { 3 | 'grid_dimensions': 2, 4 | 'input_coordinate_dim': 4, 5 | 'output_coordinate_dim': 16, 6 
| 'resolution': [64, 64, 64, 150] 7 | }, 8 | multires = [1,2], 9 | defor_depth = 0, 10 | net_width = 128, 11 | plane_tv_weight = 0.0002, 12 | time_smoothness_weight = 0.001, 13 | l1_time_planes = 0.0001, 14 | no_do=False, 15 | no_dshs=False, 16 | no_ds=False, 17 | empty_voxel=False, 18 | render_process=False, 19 | static_mlp=False 20 | 21 | ) 22 | OptimizationParams = dict( 23 | dataloader=True, 24 | iterations = 15000, 25 | batch_size=1, 26 | coarse_iterations = 3000, 27 | densify_until_iter = 10_000, 28 | # opacity_reset_interval = 60000, 29 | opacity_threshold_coarse = 0.005, 30 | opacity_threshold_fine_init = 0.005, 31 | opacity_threshold_fine_after = 0.005, 32 | # pruning_interval = 2000 33 | ) -------------------------------------------------------------------------------- /assets/cut_roasted_beef_time.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/cut_roasted_beef_time.mp4 -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/pipeline.png -------------------------------------------------------------------------------- /assets/port_forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/port_forward.png -------------------------------------------------------------------------------- /assets/teaserfig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/teaserfig.jpg -------------------------------------------------------------------------------- /assets/teaservideo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/teaservideo.mp4 -------------------------------------------------------------------------------- /assets/viewer.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/assets/viewer.mp4 -------------------------------------------------------------------------------- /colmap.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | workdir=$1 4 | datatype=$2 # blender, hypernerf, llff 5 | export CUDA_VISIBLE_DEVICES=0 6 | rm -rf $workdir/sparse_ 7 | rm -rf $workdir/image_colmap 8 | python scripts/"$datatype"2colmap.py $workdir 9 | rm -rf $workdir/colmap 10 | rm -rf $workdir/colmap/sparse/0 11 | 12 | mkdir $workdir/colmap 13 | cp -r $workdir/image_colmap $workdir/colmap/images 14 | cp -r $workdir/sparse_ $workdir/colmap/sparse_custom 15 | colmap feature_extractor --database_path $workdir/colmap/database.db --image_path $workdir/colmap/images --SiftExtraction.max_image_size 4096 --SiftExtraction.max_num_features 16384 --SiftExtraction.estimate_affine_shape 1 --SiftExtraction.domain_size_pooling 1 16 | python database.py --database_path $workdir/colmap/database.db --txt_path $workdir/colmap/sparse_custom/cameras.txt 17 | colmap exhaustive_matcher 
--database_path $workdir/colmap/database.db 18 | mkdir -p $workdir/colmap/sparse/0 19 | 20 | colmap point_triangulator --database_path $workdir/colmap/database.db --image_path $workdir/colmap/images --input_path $workdir/colmap/sparse_custom --output_path $workdir/colmap/sparse/0 --clear_points 1 21 | 22 | mkdir -p $workdir/colmap/dense/workspace 23 | colmap image_undistorter --image_path $workdir/colmap/images --input_path $workdir/colmap/sparse/0 --output_path $workdir/colmap/dense/workspace 24 | colmap patch_match_stereo --workspace_path $workdir/colmap/dense/workspace 25 | colmap stereo_fusion --workspace_path $workdir/colmap/dense/workspace --output_path $workdir/colmap/dense/workspace/fused.ply 26 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the MipNerF 360 repository. 18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extracton_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extracton_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 
58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Copy each file from the source directory to the destination directory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images. 94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | # This script is based on an original implementation by True Price. 
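# (Editor's note, not part of the original file.) camTodatabase() below copies the camera
# intrinsics listed in a COLMAP cameras.txt (default colmap/sparse_cameras.txt, overridable
# with --txt_path) into the cameras table of an existing database.db; colmap.sh runs this
# script after feature extraction and before exhaustive matching and point_triangulator,
# so the reconstruction uses the provided intrinsics.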
2 | # Created by liminghao 3 | import sys 4 | import numpy as np 5 | import sqlite3 6 | 7 | IS_PYTHON3 = sys.version_info[0] >= 3 8 | 9 | def array_to_blob(array): 10 | if IS_PYTHON3: 11 | return array.tostring() 12 | else: 13 | return np.getbuffer(array) 14 | 15 | def blob_to_array(blob, dtype, shape=(-1,)): 16 | if IS_PYTHON3: 17 | return np.fromstring(blob, dtype=dtype).reshape(*shape) 18 | else: 19 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 20 | 21 | class COLMAPDatabase(sqlite3.Connection): 22 | 23 | @staticmethod 24 | def connect(database_path): 25 | return sqlite3.connect(database_path, factory=COLMAPDatabase) 26 | 27 | def __init__(self, *args, **kwargs): 28 | super(COLMAPDatabase, self).__init__(*args, **kwargs) 29 | 30 | self.create_tables = lambda: self.executescript(CREATE_ALL) 31 | self.create_cameras_table = \ 32 | lambda: self.executescript(CREATE_CAMERAS_TABLE) 33 | self.create_descriptors_table = \ 34 | lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) 35 | self.create_images_table = \ 36 | lambda: self.executescript(CREATE_IMAGES_TABLE) 37 | self.create_two_view_geometries_table = \ 38 | lambda: self.executescript(CREATE_TWO_VIEW_GEOMETRIES_TABLE) 39 | self.create_keypoints_table = \ 40 | lambda: self.executescript(CREATE_KEYPOINTS_TABLE) 41 | self.create_matches_table = \ 42 | lambda: self.executescript(CREATE_MATCHES_TABLE) 43 | self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) 44 | 45 | def update_camera(self, model, width, height, params, camera_id): 46 | params = np.asarray(params, np.float64) 47 | cursor = self.execute( 48 | "UPDATE cameras SET model=?, width=?, height=?, params=?, prior_focal_length=True WHERE camera_id=?", 49 | (model, width, height, array_to_blob(params),camera_id)) 50 | return cursor.lastrowid 51 | 52 | def camTodatabase(): 53 | import os 54 | import argparse 55 | 56 | camModelDict = {'SIMPLE_PINHOLE': 0, 57 | 'PINHOLE': 1, 58 | 'SIMPLE_RADIAL': 2, 59 | 'RADIAL': 3, 60 | 'OPENCV': 4, 61 | 'FULL_OPENCV': 5, 62 | 'SIMPLE_RADIAL_FISHEYE': 6, 63 | 'RADIAL_FISHEYE': 7, 64 | 'OPENCV_FISHEYE': 8, 65 | 'FOV': 9, 66 | 'THIN_PRISM_FISHEYE': 10} 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("--database_path", type=str, default="database.db") 69 | parser.add_argument("--txt_path", type=str, default="colmap/sparse_cameras.txt") 70 | # breakpoint() 71 | args = parser.parse_args() 72 | if os.path.exists(args.database_path)==False: 73 | print("ERROR: database path doesn't exist -- please check database.db.") 74 | return 75 | # Open the database. 76 | db = COLMAPDatabase.connect(args.database_path) 77 | 78 | idList=list() 79 | modelList=list() 80 | widthList=list() 81 | heightList=list() 82 | paramsList=list() 83 | # Update real cameras from .txt 84 | with open(args.txt_path, "r") as cam: 85 | lines = cam.readlines() 86 | for i in range(0,len(lines),1): 87 | if lines[i][0]!='#': 88 | strLists = lines[i].split() 89 | cameraId=int(strLists[0]) 90 | cameraModel=camModelDict[strLists[1]] #SelectCameraModel 91 | width=int(strLists[2]) 92 | height=int(strLists[3]) 93 | paramstr=np.array(strLists[4:12]) 94 | params = paramstr.astype(np.float64) 95 | idList.append(cameraId) 96 | modelList.append(cameraModel) 97 | widthList.append(width) 98 | heightList.append(height) 99 | paramsList.append(params) 100 | camera_id = db.update_camera(cameraModel, width, height, params, cameraId) 101 | 102 | # Commit the data to the file. 103 | db.commit() 104 | # Read and check cameras.
105 | rows = db.execute("SELECT * FROM cameras") 106 | for i in range(0,len(idList),1): 107 | camera_id, model, width, height, params, prior = next(rows) 108 | params = blob_to_array(params, np.float64) 109 | assert camera_id == idList[i] 110 | assert model == modelList[i] and width == widthList[i] and height == heightList[i] 111 | assert np.allclose(params, paramsList[i]) 112 | 113 | # Close database.db. 114 | db.close() 115 | 116 | if __name__ == "__main__": 117 | import sys,os 118 | 119 | camTodatabase() -------------------------------------------------------------------------------- /docs/viewer_usage.md: -------------------------------------------------------------------------------- 1 | # 4D Gaussian Splatting 2 | The viewer can be downloaded from [3D-GS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/binaries/viewers.zip); extract the zip file under the 4D-GS folder like this: 3 | ``` 4 | ├── 4DGaussians 5 | | |viewers 6 | | ├── bin 7 | | ├── resources 8 | | ├── shaders 9 | | |... 10 | | | train.py 11 | | | test.py 12 | | | ...(other files) 13 | ``` 14 | ## How to use the viewer? 15 | If you train 4D-GS locally: 16 | ``` 17 | ./viewers/bin/SIBR_remoteGaussian_app.exe --port 6017 # the port should match the one used by your training code. 18 | ``` 19 | If you train 4D-GS on a server, you need to add a port forward; in VSCode it looks like this: 20 | ![port_forward](../assets/port_forward.png) 21 | Then you can clone this repo on your personal computer and download a D-NeRF dataset, like this: 22 | ``` 23 | ├── 4DGaussians 24 | | |viewers 25 | | ├── bin 26 | | ├── resources 27 | | ├── shaders 28 | | |... 29 | │ | data 30 | │ ├── dnerf 31 | | | train.py 32 | | | test.py 33 | | | ...(other files) 34 | ``` 35 | The rendering speed mainly depends on your network bandwidth.
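(Editor's addition.) If VSCode is not available, a plain SSH tunnel gives the same result. A minimal sketch, where the host alias `my-server` and the port `6017` are placeholders; use your own server and the port your training run listens on:
```
ssh -N -L 6017:127.0.0.1:6017 my-server
```
With the tunnel open, the local viewer connects to `127.0.0.1:6017` exactly as in the local setup.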
36 | -------------------------------------------------------------------------------- /export_perframe_3DGS.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | from scene import Scene 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | from os import makedirs 9 | from gaussian_renderer import render 10 | import torchvision 11 | from utils.general_utils import safe_state 12 | from argparse import ArgumentParser 13 | from arguments import ModelParams, PipelineParams, get_combined_args, ModelHiddenParams 14 | from gaussian_renderer import GaussianModel 15 | from time import time 16 | import open3d as o3d 17 | from plyfile import PlyData, PlyElement 18 | # import torch.multiprocessing as mp 19 | import threading 20 | from utils.render_utils import get_state_at_time 21 | import concurrent.futures 22 | def render_sets(dataset : ModelParams, hyperparam, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, skip_video: bool): 23 | with torch.no_grad(): 24 | gaussians = GaussianModel(dataset.sh_degree, hyperparam) 25 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 26 | 27 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 28 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 29 | 30 | return gaussians, scene 31 | 32 | def save_point_cloud(points, model_path, timestamp): 33 | output_path = os.path.join(model_path,"point_pertimestamp") 34 | if not os.path.exists(output_path): 35 | os.makedirs(output_path,exist_ok=True) 36 | points = points.detach().cpu().numpy() 37 | pcd = o3d.geometry.PointCloud() 38 | pcd.points = o3d.utility.Vector3dVector(points) 39 | ply_path = os.path.join(output_path,f"points_{timestamp}.ply") 40 | o3d.io.write_point_cloud(ply_path, pcd) 41 | def construct_list_of_attributes(feature_dc_shape, feature_rest_shape, scaling_shape,rotation_shape): 42 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 43 | # All channels except the 3 DC 44 | for i in range(feature_dc_shape[1]*feature_dc_shape[2]): 45 | l.append('f_dc_{}'.format(i)) 46 | for i in range(feature_rest_shape[1]*feature_rest_shape[2]): 47 | l.append('f_rest_{}'.format(i)) 48 | l.append('opacity') 49 | for i in range(scaling_shape[1]): 50 | l.append('scale_{}'.format(i)) 51 | for i in range(rotation_shape[1]): 52 | l.append('rot_{}'.format(i)) 53 | # breakpoint() 54 | return l 55 | def init_3DGaussians_ply(points, scales, rotations, opactiy, shs, feature_shape): 56 | xyz = points.detach().cpu().numpy() 57 | normals = np.zeros_like(xyz) 58 | feature_dc = shs[:,0:feature_shape[0],:] 59 | feature_rest = shs[:,feature_shape[0]:,:] 60 | f_dc = shs[:,:feature_shape[0],:].detach().transpose(1,2).flatten(start_dim=1).contiguous().cpu().numpy() 61 | # breakpoint() 62 | f_rest = shs[:,feature_shape[0]:,:].detach().transpose(1,2).flatten(start_dim=1).contiguous().cpu().numpy() 63 | opacities = opactiy.detach().cpu().numpy() 64 | scale = scales.detach().cpu().numpy() 65 | rotation = rotations.detach().cpu().numpy() 66 | 67 | dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(feature_dc.shape, feature_rest.shape, scales.shape, rotations.shape)] 68 | elements = np.empty(xyz.shape[0], dtype=dtype_full) 69 | attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) 70 | elements[:] = list(map(tuple, attributes)) 71 | el = PlyElement.describe(elements, 'vertex') 72 | # breakpoint() 73 | return 
PlyData([el]) 74 | 75 | parser = ArgumentParser(description="Testing script parameters") 76 | model = ModelParams(parser, sentinel=True) 77 | pipeline = PipelineParams(parser) 78 | hyperparam = ModelHiddenParams(parser) 79 | parser.add_argument("--iteration", default=-1, type=int) 80 | parser.add_argument("--skip_train", action="store_true") 81 | parser.add_argument("--skip_test", action="store_true") 82 | parser.add_argument("--quiet", action="store_true") 83 | parser.add_argument("--skip_video", action="store_true") 84 | parser.add_argument("--configs", type=str) 85 | # parser.add_argument("--model_path", type=str) 86 | 87 | args = get_combined_args(parser) 88 | print("Rendering " , args.model_path) 89 | if args.configs: 90 | import mmcv 91 | from utils.params_utils import merge_hparams 92 | config = mmcv.Config.fromfile(args.configs) 93 | args = merge_hparams(args, config) 94 | # Initialize system state (RNG) 95 | safe_state(args.quiet) 96 | gaussians, scene = render_sets(model.extract(args), hyperparam.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.skip_video) 97 | output_path = os.path.join(args.model_path,"gaussian_pertimestamp") 98 | os.makedirs(output_path,exist_ok=True) 99 | print("Computing Gaussians.") 100 | for index, viewpoint in enumerate(scene.getTestCameras()): 101 | 102 | points, scales_final, rotations_final, opacity_final, shs_final = get_state_at_time(gaussians, viewpoint) 103 | feature_dc_shape = gaussians._features_dc.shape[1] 104 | feature_rest_shape = gaussians._features_rest.shape[1] 105 | gs_ply = init_3DGaussians_ply(points, scales_final, rotations_final, opacity_final, shs_final, [feature_dc_shape, feature_rest_shape]) 106 | gs_ply.write(os.path.join(output_path,"time_{0:05d}.ply".format(index))) 107 | print("done") -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 68 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | 70 | if not args.skip_metrics: 71 | scenes_string = "" 72 | for scene in all_scenes: 73 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 74 | 75 | os.system("python metrics.py -m " + scenes_string) -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights 
reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | from time import time as get_time 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, stage="fine", cam_type=None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | 34 | means3D = pc.get_xyz 35 | if cam_type != "PanopticSports": 36 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 37 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 38 | raster_settings = GaussianRasterizationSettings( 39 | image_height=int(viewpoint_camera.image_height), 40 | image_width=int(viewpoint_camera.image_width), 41 | tanfovx=tanfovx, 42 | tanfovy=tanfovy, 43 | bg=bg_color, 44 | scale_modifier=scaling_modifier, 45 | viewmatrix=viewpoint_camera.world_view_transform.cuda(), 46 | projmatrix=viewpoint_camera.full_proj_transform.cuda(), 47 | sh_degree=pc.active_sh_degree, 48 | campos=viewpoint_camera.camera_center.cuda(), 49 | prefiltered=False, 50 | debug=pipe.debug 51 | ) 52 | time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1) 53 | else: 54 | raster_settings = viewpoint_camera['camera'] 55 | time=torch.tensor(viewpoint_camera['time']).to(means3D.device).repeat(means3D.shape[0],1) 56 | 57 | 58 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 59 | 60 | # means3D = pc.get_xyz 61 | # add deformation to each points 62 | # deformation = pc.get_deformation 63 | 64 | 65 | means2D = screenspace_points 66 | opacity = pc._opacity 67 | shs = pc.get_features 68 | 69 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 70 | # scaling / rotation by the rasterizer. 
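# (Editor's note, not part of the original file.) The precomputed path follows the standard
# 3DGS covariance parameterization: Sigma = R S S^T R^T, with S = diag(scaling_modifier * scale)
# and R the rotation matrix built from the normalized quaternion; when it is not used, the
# CUDA rasterizer derives the same quantity from the scales and rotations passed further below.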
71 | scales = None 72 | rotations = None 73 | cov3D_precomp = None 74 | if pipe.compute_cov3D_python: 75 | cov3D_precomp = pc.get_covariance(scaling_modifier) 76 | else: 77 | scales = pc._scaling 78 | rotations = pc._rotation 79 | deformation_point = pc._deformation_table 80 | if "coarse" in stage: 81 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = means3D, scales, rotations, opacity, shs 82 | elif "fine" in stage: 83 | # time0 = get_time() 84 | # means3D_deform, scales_deform, rotations_deform, opacity_deform = pc._deformation(means3D[deformation_point], scales[deformation_point], 85 | # rotations[deformation_point], opacity[deformation_point], 86 | # time[deformation_point]) 87 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = pc._deformation(means3D, scales, 88 | rotations, opacity, shs, 89 | time) 90 | else: 91 | raise NotImplementedError 92 | 93 | 94 | 95 | # time2 = get_time() 96 | # print("asset value:",time2-time1) 97 | scales_final = pc.scaling_activation(scales_final) 98 | rotations_final = pc.rotation_activation(rotations_final) 99 | opacity = pc.opacity_activation(opacity_final) 100 | # print(opacity.max()) 101 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 102 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 103 | # shs = None 104 | colors_precomp = None 105 | if override_color is None: 106 | if pipe.convert_SHs_python: 107 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 108 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.cuda().repeat(pc.get_features.shape[0], 1)) 109 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 110 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 111 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 112 | else: 113 | pass 114 | # shs = 115 | else: 116 | colors_precomp = override_color 117 | 118 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 119 | # time3 = get_time() 120 | rendered_image, radii, depth = rasterizer( 121 | means3D = means3D_final, 122 | means2D = means2D, 123 | shs = shs_final, 124 | colors_precomp = colors_precomp, 125 | opacities = opacity, 126 | scales = scales_final, 127 | rotations = rotations_final, 128 | cov3D_precomp = cov3D_precomp) 129 | # time4 = get_time() 130 | # print("rasterization:",time4-time3) 131 | # breakpoint() 132 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 133 | # They will be excluded from value updates used in the splitting criteria. 134 | return {"render": rendered_image, 135 | "viewspace_points": screenspace_points, 136 | "visibility_filter" : radii > 0, 137 | "radii": radii, 138 | "depth":depth} 139 | 140 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform,time=0) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /merge_many_4dgs.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import torch 4 | from scene import Scene 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | from os import makedirs 9 | from gaussian_renderer import render 10 | import torchvision 11 | from utils.general_utils import safe_state 12 | from argparse import ArgumentParser 13 | from arguments import ModelParams, PipelineParams, get_combined_args, ModelHiddenParams 14 | from gaussian_renderer import GaussianModel 15 | from time import time 16 | import open3d as o3d 17 | # import torch.multiprocessing as mp 18 | import threading 19 | import concurrent.futures 20 | from copy import deepcopy 21 | # 22 | # Copyright (C) 2023, Inria 23 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 24 | # All rights reserved. 25 | # 26 | # This software is free for non-commercial, research and evaluation use 27 | # under the terms of the LICENSE.md file. 
28 | # 29 | # For inquiries contact george.drettakis@inria.fr 30 | # 31 | import torch 32 | import math 33 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 34 | from scene.gaussian_model import GaussianModel 35 | from utils.render_utils import get_state_at_time 36 | from tqdm import tqdm 37 | def rotate_point_cloud(point_cloud, displacement, rotation_angles, scales_bias): 38 | 39 | theta, phi = rotation_angles 40 | 41 | rotation_matrix_z = torch.tensor([ 42 | [torch.cos(theta), -torch.sin(theta), 0], 43 | [torch.sin(theta), torch.cos(theta), 0], 44 | [0, 0, 1] 45 | ]).to(point_cloud) 46 | rotation_matrix_x = torch.tensor([ 47 | [1, 0, 0], 48 | [0, torch.cos(phi), -torch.sin(phi)], 49 | [0, torch.sin(phi), torch.cos(phi)] 50 | ]).to(point_cloud) 51 | rotation_matrix = torch.matmul(rotation_matrix_z, rotation_matrix_x) 52 | # print(rotation_matrix) 53 | point_cloud = point_cloud*scales_bias 54 | rotated_point_cloud = torch.matmul(point_cloud, rotation_matrix.t()) 55 | displaced_point_cloud = rotated_point_cloud + displacement 56 | 57 | return displaced_point_cloud 58 | @torch.no_grad() 59 | def render(viewpoint_camera, gaussians, bg_color : torch.Tensor, scaling_modifier = 1.0, motion_bias = [torch.tensor([0,0,0])], rotation_bias = [torch.tensor([0,0])], 60 | scales_bias=[1,1]): 61 | """ 62 | Render the scene. 63 | 64 | Background tensor (bg_color) must be on GPU! 65 | """ 66 | 67 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 68 | 69 | 70 | # Set up rasterization configuration 71 | 72 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 73 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 74 | screenspace_points = None 75 | for pc in gaussians: 76 | if screenspace_points is None: 77 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 78 | else: 79 | screenspace_points1 = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 80 | screenspace_points = torch.cat([screenspace_points,screenspace_points1],dim=0) 81 | try: 82 | screenspace_points.retain_grad() 83 | except: 84 | pass 85 | raster_settings = GaussianRasterizationSettings( 86 | image_height=int(viewpoint_camera.image_height), 87 | image_width=int(viewpoint_camera.image_width), 88 | tanfovx=tanfovx, 89 | tanfovy=tanfovy, 90 | bg=bg_color, 91 | scale_modifier=scaling_modifier, 92 | viewmatrix=viewpoint_camera.world_view_transform.cuda(), 93 | projmatrix=viewpoint_camera.full_proj_transform.cuda(), 94 | sh_degree=gaussians[0].active_sh_degree, 95 | campos=viewpoint_camera.camera_center.cuda(), 96 | prefiltered=False, 97 | debug=False 98 | ) 99 | 100 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 101 | # means3D = pc.get_xyz 102 | # add deformation to each points 103 | # deformation = pc.get_deformation 104 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = None, None, None, None, None 105 | for index, pc in enumerate(gaussians): 106 | 107 | means3D_final1, scales_final1, rotations_final1, opacity_final1, shs_final1 = get_state_at_time(pc, viewpoint_camera) 108 | scales_final1 = pc.scaling_activation(scales_final1) 109 | rotations_final1 = pc.rotation_activation(rotations_final1) 110 | opacity_final1 = pc.opacity_activation(opacity_final1) 111 | if index == 0: 112 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = means3D_final1, scales_final1, rotations_final1, opacity_final1, 
shs_final1 113 | else: 114 | motion_bias_t = motion_bias[index-1].to(means3D_final) 115 | rotation_bias_t = rotation_bias[index-1].to(means3D_final) 116 | means3D_final1 = rotate_point_cloud(means3D_final1,motion_bias_t,rotation_bias_t,scales_bias[index-1]) 117 | # breakpoint() 118 | scales_final1 = scales_final1*scales_bias[index-1] 119 | means3D_final = torch.cat([means3D_final,means3D_final1],dim=0) 120 | scales_final = torch.cat([scales_final,scales_final1],dim=0) 121 | rotations_final = torch.cat([rotations_final,rotations_final1],dim=0) 122 | opacity_final = torch.cat([opacity_final,opacity_final1],dim=0) 123 | shs_final = torch.cat([shs_final,shs_final1],dim=0) 124 | 125 | colors_precomp = None 126 | cov3D_precomp = None 127 | rendered_image, radii, depth = rasterizer( 128 | means3D = means3D_final, 129 | means2D = screenspace_points, 130 | shs = shs_final, 131 | colors_precomp = colors_precomp, 132 | opacities = opacity_final, 133 | scales = scales_final, 134 | rotations = rotations_final, 135 | cov3D_precomp = cov3D_precomp) 136 | 137 | return {"render": rendered_image, 138 | "viewspace_points": screenspace_points, 139 | "visibility_filter" : radii > 0, 140 | "radii": radii, 141 | "depth":depth} 142 | 143 | 144 | def init_gaussians(dataset : ModelParams, hyperparam, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, skip_video: bool): 145 | with torch.no_grad(): 146 | gaussians = GaussianModel(dataset.sh_degree, hyperparam) 147 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 148 | 149 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 150 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 151 | 152 | print("hello!!") 153 | return gaussians, scene, background 154 | 155 | def save_point_cloud(points, model_path, timestamp): 156 | output_path = os.path.join(model_path,"point_pertimestamp") 157 | if not os.path.exists(output_path): 158 | os.makedirs(output_path,exist_ok=True) 159 | points = points.detach().cpu().numpy() 160 | pcd = o3d.geometry.PointCloud() 161 | pcd.points = o3d.utility.Vector3dVector(points) 162 | ply_path = os.path.join(output_path,f"points_{timestamp}.ply") 163 | o3d.io.write_point_cloud(ply_path, pcd) 164 | # This scripts can help you to merge many 4DGS. 
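# Example invocation (all paths are illustrative placeholders -- substitute your own
# trained 4DGS outputs and config files):
#
#   python merge_many_4dgs.py --model_path output/dynerf/flame_salmon_1 \
#       --configs1 arguments/dynerf/flame_salmon_1.py \
#       --configs2 arguments/dnerf/hellwarrior.py --modelpath2 output/dnerf/hellwarrior \
#       --configs3 arguments/dnerf/mutant.py --modelpath3 output/dnerf/mutant
#
# The first model supplies the cameras (scene1.getVideoCameras()) that the merged
# scene is rendered from; the displacement / rotation / scale biases applied to the
# second and third models, as well as the output directory `render_path`, are
# hard-coded further down in this script and should be edited to taste.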
165 | parser = ArgumentParser(description="Testing script parameters") 166 | model = ModelParams(parser, sentinel=True) 167 | pipeline = PipelineParams(parser) 168 | hyperparam = ModelHiddenParams(parser) 169 | parser.add_argument("--iteration", default=-1, type=int) 170 | parser.add_argument("--skip_train", action="store_true") 171 | parser.add_argument("--skip_test", action="store_true") 172 | parser.add_argument("--quiet", action="store_true") 173 | parser.add_argument("--skip_video", action="store_true") 174 | parser.add_argument("--configs1", type=str, default="arguments/dynerf_9/flame_salmon_1.py") 175 | parser.add_argument("--configs2", type=str, default="arguments/dnerf_tv_2/hellwarrior.py") 176 | parser.add_argument("--modelpath2", type=str, default="output/dnerf_tv_2/hellwarrior") 177 | parser.add_argument("--configs3", type=str, default="arguments/dnerf_tv_2/mutant.py") 178 | parser.add_argument("--modelpath3", type=str, default="output/dnerf_tv_2/mutant") 179 | render_path = "output/editing_render_flame_salmon" 180 | 181 | args = get_combined_args(parser) 182 | print("Rendering " , args.model_path) 183 | args2 = deepcopy(args) 184 | args3 = deepcopy(args) 185 | 186 | if args.configs1: 187 | import mmcv 188 | from utils.params_utils import merge_hparams 189 | config = mmcv.Config.fromfile(args.configs1) 190 | args1 = merge_hparams(args, config) 191 | # breakpoint() 192 | if args2.configs2: 193 | import mmcv 194 | from utils.params_utils import merge_hparams 195 | config = mmcv.Config.fromfile(args2.configs2) 196 | args2 = merge_hparams(args2, config) 197 | args2.model_path = args2.modelpath2 198 | if args3.configs3: 199 | import mmcv 200 | from utils.params_utils import merge_hparams 201 | config = mmcv.Config.fromfile(args3.configs3) 202 | args3 = merge_hparams(args3, config) 203 | args3.model_path = args3.modelpath3 204 | safe_state(args.quiet) 205 | gaussians1, scene1, background = init_gaussians(model.extract(args1), hyperparam.extract(args1), args1.iteration, pipeline.extract(args1), args1.skip_train, args1.skip_test, args1.skip_video) 206 | gaussians2, scene2, background = init_gaussians(model.extract(args2), hyperparam.extract(args2), args2.iteration, pipeline.extract(args2), args2.skip_train, args2.skip_test, args2.skip_video) 207 | gaussians3, scene3, background = init_gaussians(model.extract(args3), hyperparam.extract(args3), args3.iteration, pipeline.extract(args3), args3.skip_train, args3.skip_test, args3.skip_video) 208 | gaussians = [gaussians1,gaussians2,gaussians3] 209 | # breakpoint() 210 | to8b = lambda x : (255*np.clip(x.cpu().numpy(),0,1)).astype(np.uint8) 211 | 212 | render_images=[] 213 | if not os.path.exists(render_path): 214 | os.makedirs(render_path,exist_ok=True) 215 | for index, viewpoint in tqdm(enumerate(scene1.getVideoCameras())): 216 | result = render(viewpoint, gaussians, 217 | bg_color=background, 218 | motion_bias=[ 219 | torch.tensor([4,4,12]), 220 | torch.tensor([-2,4,12]) 221 | ] 222 | ,rotation_bias=[ 223 | torch.tensor([0,1.9*np.pi/4]), 224 | torch.tensor([0,1.9*np.pi/4]) 225 | ], 226 | scales_bias = [1,1]) 227 | render_images.append(to8b(result["render"]).transpose(1,2,0)) 228 | 229 | torchvision.utils.save_image(result["render"],os.path.join(render_path,f"output_image{index}.png")) 230 | 231 | imageio.mimwrite(os.path.join(render_path, 'video_rgb.mp4'), render_images, fps=30, codec='libx265') 232 | -------------------------------------------------------------------------------- /metrics.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | from pytorch_msssim import ms_ssim 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | lpipsa = [] 71 | ms_ssims = [] 72 | Dssims = [] 73 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 74 | ssims.append(ssim(renders[idx], gts[idx])) 75 | psnrs.append(psnr(renders[idx], gts[idx])) 76 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 77 | ms_ssims.append(ms_ssim(renders[idx], gts[idx],data_range=1, size_average=True )) 78 | lpipsa.append(lpips(renders[idx], gts[idx], net_type='alex')) 79 | Dssims.append((1-ms_ssims[-1])/2) 80 | 81 | print("Scene: ", scene_dir, "SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 82 | print("Scene: ", scene_dir, "PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 83 | print("Scene: ", scene_dir, "LPIPS-vgg: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 84 | print("Scene: ", scene_dir, "LPIPS-alex: {:>12.7f}".format(torch.tensor(lpipsa).mean(), ".5")) 85 | print("Scene: ", scene_dir, "MS-SSIM: {:>12.7f}".format(torch.tensor(ms_ssims).mean(), ".5")) 86 | print("Scene: ", scene_dir, "D-SSIM: {:>12.7f}".format(torch.tensor(Dssims).mean(), ".5")) 87 | 88 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 89 | "PSNR": torch.tensor(psnrs).mean().item(), 90 | "LPIPS-vgg": torch.tensor(lpipss).mean().item(), 91 | "LPIPS-alex": torch.tensor(lpipsa).mean().item(), 92 | "MS-SSIM": 
torch.tensor(ms_ssims).mean().item(), 93 | "D-SSIM": torch.tensor(Dssims).mean().item()}, 94 | 95 | ) 96 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 97 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 98 | "LPIPS-vgg": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}, 99 | "LPIPS-alex": {name: lp for lp, name in zip(torch.tensor(lpipsa).tolist(), image_names)}, 100 | "MS-SSIM": {name: lp for lp, name in zip(torch.tensor(ms_ssims).tolist(), image_names)}, 101 | "D-SSIM": {name: lp for lp, name in zip(torch.tensor(Dssims).tolist(), image_names)}, 102 | 103 | } 104 | ) 105 | 106 | with open(scene_dir + "/results.json", 'w') as fp: 107 | json.dump(full_dict[scene_dir], fp, indent=True) 108 | with open(scene_dir + "/per_view.json", 'w') as fp: 109 | json.dump(per_view_dict[scene_dir], fp, indent=True) 110 | except Exception as e: 111 | 112 | print("Unable to compute metrics for model", scene_dir) 113 | raise e 114 | 115 | if __name__ == "__main__": 116 | device = torch.device("cuda:0") 117 | torch.cuda.set_device(device) 118 | 119 | # Set up command line argument parser 120 | parser = ArgumentParser(description="Training script parameters") 121 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 122 | args = parser.parse_args() 123 | evaluate(args.model_paths) 124 | -------------------------------------------------------------------------------- /multipleviewprogress.sh: -------------------------------------------------------------------------------- 1 | workdir=$1 2 | python scripts/extractimages.py multipleview/$workdir 3 | colmap feature_extractor --database_path ./colmap_tmp/database.db --image_path ./colmap_tmp/images --SiftExtraction.max_image_size 4096 --SiftExtraction.max_num_features 16384 --SiftExtraction.estimate_affine_shape 1 --SiftExtraction.domain_size_pooling 1 4 | colmap exhaustive_matcher --database_path ./colmap_tmp/database.db 5 | mkdir ./colmap_tmp/sparse 6 | colmap mapper --database_path ./colmap_tmp/database.db --image_path ./colmap_tmp/images --output_path ./colmap_tmp/sparse 7 | mkdir ./data/multipleview/$workdir/sparse_ 8 | cp -r ./colmap_tmp/sparse/0/* ./data/multipleview/$workdir/sparse_ 9 | 10 | mkdir ./colmap_tmp/dense 11 | colmap image_undistorter --image_path ./colmap_tmp/images --input_path ./colmap_tmp/sparse/0 --output_path ./colmap_tmp/dense --output_type COLMAP 12 | colmap patch_match_stereo --workspace_path ./colmap_tmp/dense --workspace_format COLMAP --PatchMatchStereo.geom_consistency true 13 | colmap stereo_fusion --workspace_path ./colmap_tmp/dense --workspace_format COLMAP --input_type geometric --output_path ./colmap_tmp/dense/fused.ply 14 | 15 | python scripts/downsample_point.py ./colmap_tmp/dense/fused.ply ./data/multipleview/$workdir/points3D_multipleview.ply 16 | 17 | git clone https://github.com/Fyusion/LLFF.git 18 | pip install scikit-image 19 | python LLFF/imgs2poses.py ./colmap_tmp/ 20 | 21 | cp ./colmap_tmp/poses_bounds.npy ./data/multipleview/$workdir/poses_bounds_multipleview.npy 22 | 23 | rm -rf ./colmap_tmp 24 | rm -rf ./LLFF 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import imageio 12 | import numpy as np 13 | import torch 14 | from scene import Scene 15 | import os 16 | import cv2 17 | from tqdm import tqdm 18 | from os import makedirs 19 | from gaussian_renderer import render 20 | import torchvision 21 | from utils.general_utils import safe_state 22 | from argparse import ArgumentParser 23 | from arguments import ModelParams, PipelineParams, get_combined_args, ModelHiddenParams 24 | from gaussian_renderer import GaussianModel 25 | from time import time 26 | import threading 27 | import concurrent.futures 28 | def multithread_write(image_list, path): 29 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=None) 30 | def write_image(image, count, path): 31 | try: 32 | torchvision.utils.save_image(image, os.path.join(path, '{0:05d}'.format(count) + ".png")) 33 | return count, True 34 | except: 35 | return count, False 36 | 37 | tasks = [] 38 | for index, image in enumerate(image_list): 39 | tasks.append(executor.submit(write_image, image, index, path)) 40 | executor.shutdown() 41 | for index, status in enumerate(tasks): 42 | if status == False: 43 | write_image(image_list[index], index, path) 44 | 45 | to8b = lambda x : (255*np.clip(x.cpu().numpy(),0,1)).astype(np.uint8) 46 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background, cam_type): 47 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 48 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 49 | 50 | makedirs(render_path, exist_ok=True) 51 | makedirs(gts_path, exist_ok=True) 52 | render_images = [] 53 | gt_list = [] 54 | render_list = [] 55 | print("point nums:",gaussians._xyz.shape[0]) 56 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 57 | if idx == 0:time1 = time() 58 | 59 | rendering = render(view, gaussians, pipeline, background,cam_type=cam_type)["render"] 60 | render_images.append(to8b(rendering).transpose(1,2,0)) 61 | render_list.append(rendering) 62 | if name in ["train", "test"]: 63 | if cam_type != "PanopticSports": 64 | gt = view.original_image[0:3, :, :] 65 | else: 66 | gt = view['image'].cuda() 67 | gt_list.append(gt) 68 | 69 | time2=time() 70 | print("FPS:",(len(views)-1)/(time2-time1)) 71 | 72 | multithread_write(gt_list, gts_path) 73 | 74 | multithread_write(render_list, render_path) 75 | 76 | 77 | imageio.mimwrite(os.path.join(model_path, name, "ours_{}".format(iteration), 'video_rgb.mp4'), render_images, fps=30) 78 | def render_sets(dataset : ModelParams, hyperparam, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, skip_video: bool): 79 | with torch.no_grad(): 80 | gaussians = GaussianModel(dataset.sh_degree, hyperparam) 81 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 82 | cam_type=scene.dataset_type 83 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 84 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 85 | 86 | if not skip_train: 87 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background,cam_type) 88 | 89 | if not skip_test: 90 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background,cam_type) 91 | if not skip_video: 92 | 
render_set(dataset.model_path,"video",scene.loaded_iter,scene.getVideoCameras(),gaussians,pipeline,background,cam_type) 93 | if __name__ == "__main__": 94 | # Set up command line argument parser 95 | parser = ArgumentParser(description="Testing script parameters") 96 | model = ModelParams(parser, sentinel=True) 97 | pipeline = PipelineParams(parser) 98 | hyperparam = ModelHiddenParams(parser) 99 | parser.add_argument("--iteration", default=-1, type=int) 100 | parser.add_argument("--skip_train", action="store_true") 101 | parser.add_argument("--skip_test", action="store_true") 102 | parser.add_argument("--quiet", action="store_true") 103 | parser.add_argument("--skip_video", action="store_true") 104 | parser.add_argument("--configs", type=str) 105 | args = get_combined_args(parser) 106 | print("Rendering " , args.model_path) 107 | if args.configs: 108 | import mmcv 109 | from utils.params_utils import merge_hparams 110 | config = mmcv.Config.fromfile(args.configs) 111 | args = merge_hparams(args, config) 112 | # Initialize system state (RNG) 113 | safe_state(args.quiet) 114 | 115 | render_sets(model.extract(args), hyperparam.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.skip_video) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | torchaudio==0.13.1 4 | mmcv==1.6.0 5 | matplotlib 6 | argparse 7 | lpips 8 | plyfile 9 | pytorch_msssim 10 | open3d 11 | imageio[ffmpeg] 12 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from scene.dataset import FourDGSdataset 19 | from arguments import ModelParams 20 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 21 | from torch.utils.data import Dataset 22 | from scene.dataset_readers import add_points 23 | class Scene: 24 | 25 | gaussians : GaussianModel 26 | 27 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], load_coarse=False): 28 | """b 29 | :param path: Path to colmap scene main folder. 
30 | """ 31 | self.model_path = args.model_path 32 | self.loaded_iter = None 33 | self.gaussians = gaussians 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | self.video_cameras = {} 45 | if os.path.exists(os.path.join(args.source_path, "sparse")): 46 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args.llffhold) 47 | dataset_type="colmap" 48 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 49 | print("Found transforms_train.json file, assuming Blender data set!") 50 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, args.extension) 51 | dataset_type="blender" 52 | elif os.path.exists(os.path.join(args.source_path, "poses_bounds.npy")): 53 | scene_info = sceneLoadTypeCallbacks["dynerf"](args.source_path, args.white_background, args.eval) 54 | dataset_type="dynerf" 55 | elif os.path.exists(os.path.join(args.source_path,"dataset.json")): 56 | scene_info = sceneLoadTypeCallbacks["nerfies"](args.source_path, False, args.eval) 57 | dataset_type="nerfies" 58 | elif os.path.exists(os.path.join(args.source_path,"train_meta.json")): 59 | scene_info = sceneLoadTypeCallbacks["PanopticSports"](args.source_path) 60 | dataset_type="PanopticSports" 61 | elif os.path.exists(os.path.join(args.source_path,"points3D_multipleview.ply")): 62 | scene_info = sceneLoadTypeCallbacks["MultipleView"](args.source_path) 63 | dataset_type="MultipleView" 64 | else: 65 | assert False, "Could not recognize scene type!" 
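# Summary of the auto-detection above: the scene loader infers the dataset type
# from marker files found in args.source_path --
#   sparse/                    -> "colmap"         (COLMAP reconstruction)
#   transforms_train.json      -> "blender"        (synthetic / D-NeRF style)
#   poses_bounds.npy           -> "dynerf"         (LLFF-style poses)
#   dataset.json               -> "nerfies"        (Nerfies / HyperNeRF)
#   train_meta.json            -> "PanopticSports"
#   points3D_multipleview.ply  -> "MultipleView"   (produced by multipleviewprogress.sh)
# A new dataset format therefore typically needs its own marker file, a reader
# registered in sceneLoadTypeCallbacks (scene/dataset_readers.py), and, if its
# camera handling differs, a branch in FourDGSdataset (scene/dataset.py).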
66 | self.maxtime = scene_info.maxtime 67 | self.dataset_type = dataset_type 68 | self.cameras_extent = scene_info.nerf_normalization["radius"] 69 | print("Loading Training Cameras") 70 | self.train_camera = FourDGSdataset(scene_info.train_cameras, args, dataset_type) 71 | print("Loading Test Cameras") 72 | self.test_camera = FourDGSdataset(scene_info.test_cameras, args, dataset_type) 73 | print("Loading Video Cameras") 74 | self.video_camera = FourDGSdataset(scene_info.video_cameras, args, dataset_type) 75 | 76 | # self.video_camera = cameraList_from_camInfos(scene_info.video_cameras,-1,args) 77 | xyz_max = scene_info.point_cloud.points.max(axis=0) 78 | xyz_min = scene_info.point_cloud.points.min(axis=0) 79 | if args.add_points: 80 | print("add points.") 81 | # breakpoint() 82 | scene_info = scene_info._replace(point_cloud=add_points(scene_info.point_cloud, xyz_max=xyz_max, xyz_min=xyz_min)) 83 | self.gaussians._deformation.deformation_net.set_aabb(xyz_max,xyz_min) 84 | if self.loaded_iter: 85 | self.gaussians.load_ply(os.path.join(self.model_path, 86 | "point_cloud", 87 | "iteration_" + str(self.loaded_iter), 88 | "point_cloud.ply")) 89 | self.gaussians.load_model(os.path.join(self.model_path, 90 | "point_cloud", 91 | "iteration_" + str(self.loaded_iter), 92 | )) 93 | else: 94 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent, self.maxtime) 95 | 96 | def save(self, iteration, stage): 97 | if stage == "coarse": 98 | point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration)) 99 | 100 | else: 101 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 102 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 103 | self.gaussians.save_deformation(point_cloud_path) 104 | def getTrainCameras(self, scale=1.0): 105 | return self.train_camera 106 | 107 | def getTestCameras(self, scale=1.0): 108 | return self.test_camera 109 | def getVideoCameras(self, scale=1.0): 110 | return self.video_camera -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", time = 0, 21 | mask = None, depth=None 22 | ): 23 | super(Camera, self).__init__() 24 | 25 | self.uid = uid 26 | self.colmap_id = colmap_id 27 | self.R = R 28 | self.T = T 29 | self.FoVx = FoVx 30 | self.FoVy = FoVy 31 | self.image_name = image_name 32 | self.time = time 33 | try: 34 | self.data_device = torch.device(data_device) 35 | except Exception as e: 36 | print(e) 37 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 38 | self.data_device = torch.device("cuda") 39 | self.original_image = image.clamp(0.0, 1.0)[:3,:,:] 40 | # breakpoint() 41 | # .to(self.data_device) 42 | self.image_width = self.original_image.shape[2] 43 | self.image_height = self.original_image.shape[1] 44 | 45 | if gt_alpha_mask is not None: 46 | self.original_image *= gt_alpha_mask 47 | # .to(self.data_device) 48 | else: 49 | self.original_image *= torch.ones((1, self.image_height, self.image_width)) 50 | # , device=self.data_device) 51 | self.depth = depth 52 | self.mask = mask 53 | self.zfar = 100.0 54 | self.znear = 0.01 55 | 56 | self.trans = trans 57 | self.scale = scale 58 | 59 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) 60 | # .cuda() 61 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1) 62 | # .cuda() 63 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 64 | self.camera_center = self.world_view_transform.inverse()[3, :3] 65 | 66 | class MiniCam: 67 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform, time): 68 | self.image_width = width 69 | self.image_height = height 70 | self.FoVy = fovy 71 | self.FoVx = fovx 72 | self.znear = znear 73 | self.zfar = zfar 74 | self.world_view_transform = world_view_transform 75 | self.full_proj_transform = full_proj_transform 76 | view_inv = torch.inverse(self.world_view_transform) 77 | self.camera_center = view_inv[3][:3] 78 | self.time = time 79 | 80 | -------------------------------------------------------------------------------- /scene/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from scene.cameras import Camera 3 | import numpy as np 4 | from utils.general_utils import PILtoTorch 5 | from utils.graphics_utils import fov2focal, focal2fov 6 | import torch 7 | from utils.camera_utils import loadCam 8 | from utils.graphics_utils import focal2fov 9 | class FourDGSdataset(Dataset): 10 | def __init__( 11 | self, 12 | dataset, 13 | args, 14 | dataset_type 15 | ): 16 | self.dataset = dataset 17 | self.args = args 18 | self.dataset_type=dataset_type 19 | def __getitem__(self, index): 20 | # breakpoint() 21 | 22 | if self.dataset_type != "PanopticSports": 23 | try: 24 | image, w2c, time = self.dataset[index] 25 | R,T = w2c 26 | FovX = focal2fov(self.dataset.focal[0], image.shape[2]) 27 | FovY = focal2fov(self.dataset.focal[0], image.shape[1]) 28 | mask=None 29 | except: 30 | caminfo = 
self.dataset[index] 31 | image = caminfo.image 32 | R = caminfo.R 33 | T = caminfo.T 34 | FovX = caminfo.FovX 35 | FovY = caminfo.FovY 36 | time = caminfo.time 37 | 38 | mask = caminfo.mask 39 | return Camera(colmap_id=index,R=R,T=T,FoVx=FovX,FoVy=FovY,image=image,gt_alpha_mask=None, 40 | image_name=f"{index}",uid=index,data_device=torch.device("cuda"),time=time, 41 | mask=mask) 42 | else: 43 | return self.dataset[index] 44 | def __len__(self): 45 | 46 | return len(self.dataset) 47 | -------------------------------------------------------------------------------- /scene/deformation.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | import os 4 | import time 5 | from tkinter import W 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init as init 12 | from utils.graphics_utils import apply_rotation, batch_quaternion_multiply 13 | from scene.hexplane import HexPlaneField 14 | from scene.grid import DenseGrid 15 | # from scene.grid import HashHexPlane 16 | class Deformation(nn.Module): 17 | def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, grid_pe=0, skips=[], args=None): 18 | super(Deformation, self).__init__() 19 | self.D = D 20 | self.W = W 21 | self.input_ch = input_ch 22 | self.input_ch_time = input_ch_time 23 | self.skips = skips 24 | self.grid_pe = grid_pe 25 | self.no_grid = args.no_grid 26 | self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires) 27 | # breakpoint() 28 | self.args = args 29 | # self.args.empty_voxel=True 30 | if self.args.empty_voxel: 31 | self.empty_voxel = DenseGrid(channels=1, world_size=[64,64,64]) 32 | if self.args.static_mlp: 33 | self.static_mlp = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1)) 34 | 35 | self.ratio=0 36 | self.create_net() 37 | @property 38 | def get_aabb(self): 39 | return self.grid.get_aabb 40 | def set_aabb(self, xyz_max, xyz_min): 41 | print("Deformation Net Set aabb",xyz_max, xyz_min) 42 | self.grid.set_aabb(xyz_max, xyz_min) 43 | if self.args.empty_voxel: 44 | self.empty_voxel.set_aabb(xyz_max, xyz_min) 45 | def create_net(self): 46 | mlp_out_dim = 0 47 | if self.grid_pe !=0: 48 | 49 | grid_out_dim = self.grid.feat_dim+(self.grid.feat_dim)*2 50 | else: 51 | grid_out_dim = self.grid.feat_dim 52 | if self.no_grid: 53 | self.feature_out = [nn.Linear(4,self.W)] 54 | else: 55 | self.feature_out = [nn.Linear(mlp_out_dim + grid_out_dim ,self.W)] 56 | 57 | for i in range(self.D-1): 58 | self.feature_out.append(nn.ReLU()) 59 | self.feature_out.append(nn.Linear(self.W,self.W)) 60 | self.feature_out = nn.Sequential(*self.feature_out) 61 | self.pos_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)) 62 | self.scales_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)) 63 | self.rotations_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)) 64 | self.opacity_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1)) 65 | self.shs_deform = nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 16*3)) 66 | 67 | def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb): 68 | 69 | if self.no_grid: 70 | h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) 71 | else: 72 | 73 | grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) 74 | # 
breakpoint() 75 | if self.grid_pe > 1: 76 | grid_feature = poc_fre(grid_feature,self.grid_pe) 77 | hidden = torch.cat([grid_feature],-1) 78 | 79 | 80 | hidden = self.feature_out(hidden) 81 | 82 | 83 | return hidden 84 | @property 85 | def get_empty_ratio(self): 86 | return self.ratio 87 | def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None,shs_emb=None, time_feature=None, time_emb=None): 88 | if time_emb is None: 89 | return self.forward_static(rays_pts_emb[:,:3]) 90 | else: 91 | return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, shs_emb, time_feature, time_emb) 92 | 93 | def forward_static(self, rays_pts_emb): 94 | grid_feature = self.grid(rays_pts_emb[:,:3]) 95 | dx = self.static_mlp(grid_feature) 96 | return rays_pts_emb[:, :3] + dx 97 | def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, shs_emb, time_feature, time_emb): 98 | hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_feature, time_emb) 99 | if self.args.static_mlp: 100 | mask = self.static_mlp(hidden) 101 | elif self.args.empty_voxel: 102 | mask = self.empty_voxel(rays_pts_emb[:,:3]) 103 | else: 104 | mask = torch.ones_like(opacity_emb[:,0]).unsqueeze(-1) 105 | # breakpoint() 106 | if self.args.no_dx: 107 | pts = rays_pts_emb[:,:3] 108 | else: 109 | dx = self.pos_deform(hidden) 110 | pts = torch.zeros_like(rays_pts_emb[:,:3]) 111 | pts = rays_pts_emb[:,:3]*mask + dx 112 | if self.args.no_ds : 113 | 114 | scales = scales_emb[:,:3] 115 | else: 116 | ds = self.scales_deform(hidden) 117 | 118 | scales = torch.zeros_like(scales_emb[:,:3]) 119 | scales = scales_emb[:,:3]*mask + ds 120 | 121 | if self.args.no_dr : 122 | rotations = rotations_emb[:,:4] 123 | else: 124 | dr = self.rotations_deform(hidden) 125 | 126 | rotations = torch.zeros_like(rotations_emb[:,:4]) 127 | if self.args.apply_rotation: 128 | rotations = batch_quaternion_multiply(rotations_emb, dr) 129 | else: 130 | rotations = rotations_emb[:,:4] + dr 131 | 132 | if self.args.no_do : 133 | opacity = opacity_emb[:,:1] 134 | else: 135 | do = self.opacity_deform(hidden) 136 | 137 | opacity = torch.zeros_like(opacity_emb[:,:1]) 138 | opacity = opacity_emb[:,:1]*mask + do 139 | if self.args.no_dshs: 140 | shs = shs_emb 141 | else: 142 | dshs = self.shs_deform(hidden).reshape([shs_emb.shape[0],16,3]) 143 | 144 | shs = torch.zeros_like(shs_emb) 145 | # breakpoint() 146 | shs = shs_emb*mask.unsqueeze(-1) + dshs 147 | 148 | return pts, scales, rotations, opacity, shs 149 | def get_mlp_parameters(self): 150 | parameter_list = [] 151 | for name, param in self.named_parameters(): 152 | if "grid" not in name: 153 | parameter_list.append(param) 154 | return parameter_list 155 | def get_grid_parameters(self): 156 | parameter_list = [] 157 | for name, param in self.named_parameters(): 158 | if "grid" in name: 159 | parameter_list.append(param) 160 | return parameter_list 161 | class deform_network(nn.Module): 162 | def __init__(self, args) : 163 | super(deform_network, self).__init__() 164 | net_width = args.net_width 165 | timebase_pe = args.timebase_pe 166 | defor_depth= args.defor_depth 167 | posbase_pe= args.posebase_pe 168 | scale_rotation_pe = args.scale_rotation_pe 169 | opacity_pe = args.opacity_pe 170 | timenet_width = args.timenet_width 171 | timenet_output = args.timenet_output 172 | grid_pe = args.grid_pe 173 | times_ch = 2*timebase_pe+1 174 | self.timenet = nn.Sequential( 175 | nn.Linear(times_ch, timenet_width), nn.ReLU(), 176 | nn.Linear(timenet_width, 
timenet_output)) 177 | self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(3)+(3*(posbase_pe))*2, grid_pe=grid_pe, input_ch_time=timenet_output, args=args) 178 | self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)])) 179 | self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)])) 180 | self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)])) 181 | self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)])) 182 | self.apply(initialize_weights) 183 | # print(self) 184 | 185 | def forward(self, point, scales=None, rotations=None, opacity=None, shs=None, times_sel=None): 186 | return self.forward_dynamic(point, scales, rotations, opacity, shs, times_sel) 187 | @property 188 | def get_aabb(self): 189 | 190 | return self.deformation_net.get_aabb 191 | @property 192 | def get_empty_ratio(self): 193 | return self.deformation_net.get_empty_ratio 194 | 195 | def forward_static(self, points): 196 | points = self.deformation_net(points) 197 | return points 198 | def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, shs=None, times_sel=None): 199 | # times_emb = poc_fre(times_sel, self.time_poc) 200 | point_emb = poc_fre(point,self.pos_poc) 201 | scales_emb = poc_fre(scales,self.rotation_scaling_poc) 202 | rotations_emb = poc_fre(rotations,self.rotation_scaling_poc) 203 | # time_emb = poc_fre(times_sel, self.time_poc) 204 | # times_feature = self.timenet(time_emb) 205 | means3D, scales, rotations, opacity, shs = self.deformation_net( point_emb, 206 | scales_emb, 207 | rotations_emb, 208 | opacity, 209 | shs, 210 | None, 211 | times_sel) 212 | return means3D, scales, rotations, opacity, shs 213 | def get_mlp_parameters(self): 214 | return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters()) 215 | def get_grid_parameters(self): 216 | return self.deformation_net.get_grid_parameters() 217 | 218 | def initialize_weights(m): 219 | if isinstance(m, nn.Linear): 220 | # init.constant_(m.weight, 0) 221 | init.xavier_uniform_(m.weight,gain=1) 222 | if m.bias is not None: 223 | init.xavier_uniform_(m.weight,gain=1) 224 | # init.constant_(m.bias, 0) 225 | def poc_fre(input_data,poc_buf): 226 | 227 | input_data_emb = (input_data.unsqueeze(-1) * poc_buf).flatten(-2) 228 | input_data_sin = input_data_emb.sin() 229 | input_data_cos = input_data_emb.cos() 230 | input_data_emb = torch.cat([input_data, input_data_sin,input_data_cos], -1) 231 | return input_data_emb -------------------------------------------------------------------------------- /scene/grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import functools 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | # import tinycudann as tcnn 10 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | 13 | ''' Dense 3D grid 14 | ''' 15 | class DenseGrid(nn.Module): 16 | def __init__(self, channels, world_size, **kwargs): 17 | super(DenseGrid, self).__init__() 18 | self.channels = channels 19 | self.world_size = world_size 20 | 21 | self.grid = nn.Parameter(torch.ones([1, channels, *world_size])) 22 | 23 | def forward(self, xyz): 24 | ''' 25 | xyz: global coordinates to query 26 | ''' 27 | shape = xyz.shape[:-1] 28 | xyz = xyz.reshape(1,1,1,-1,3) 29 | ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - 
self.xyz_min)).flip((-1,)) * 2 - 1 30 | out = F.grid_sample(self.grid, ind_norm, mode='bilinear', align_corners=True) 31 | out = out.reshape(self.channels,-1).T.reshape(*shape,self.channels) 32 | # if self.channels == 1: 33 | # out = out.squeeze(-1) 34 | return out 35 | 36 | def scale_volume_grid(self, new_world_size): 37 | if self.channels == 0: 38 | self.grid = nn.Parameter(torch.ones([1, self.channels, *new_world_size])) 39 | else: 40 | self.grid = nn.Parameter( 41 | F.interpolate(self.grid.data, size=tuple(new_world_size), mode='trilinear', align_corners=True)) 42 | def set_aabb(self, xyz_max, xyz_min): 43 | self.register_buffer('xyz_min', torch.Tensor(xyz_min)) 44 | self.register_buffer('xyz_max', torch.Tensor(xyz_max)) 45 | def get_dense_grid(self): 46 | return self.grid 47 | 48 | @torch.no_grad() 49 | def __isub__(self, val): 50 | self.grid.data -= val 51 | return self 52 | 53 | def extra_repr(self): 54 | return f'channels={self.channels}, world_size={self.world_size}' 55 | 56 | # class HashHexPlane(nn.Module): 57 | # def __init__(self,hparams, 58 | # desired_resolution=1024, 59 | # base_solution=128, 60 | # n_levels=4, 61 | # ): 62 | # super(HashHexPlane, self).__init__() 63 | 64 | # per_level_scale = np.exp2(np.log2(desired_resolution / base_solution) / (int(n_levels) - 1)) 65 | # encoding_2d_config = { 66 | # "otype": "Grid", 67 | # "type": "Hash", 68 | # "n_levels": n_levels, 69 | # "n_features_per_level": 2, 70 | # "base_resolution": base_solution, 71 | # "per_level_scale":per_level_scale, 72 | # } 73 | # self.xy = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 74 | # self.yz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 75 | # self.xz = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 76 | # self.xt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 77 | # self.yt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 78 | # self.zt = tcnn.Encoding(n_input_dims=2, encoding_config=encoding_2d_config) 79 | 80 | # self.feat_dim = n_levels * 2 *3 81 | 82 | # def forward(self, x, bound): 83 | # x = (x + bound) / (2 * bound) # zyq: map to [0, 1] 84 | # xy_feat = self.xy(x[:, [0, 1]]) 85 | # yz_feat = self.yz(x[:, [0, 2]]) 86 | # xz_feat = self.xz(x[:, [1, 2]]) 87 | # xt_feat = self.xt(x[:, []]) 88 | # return torch.cat([xy_feat, yz_feat, xz_feat], dim=-1) -------------------------------------------------------------------------------- /scene/hexplane.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging as log 3 | from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def get_normalized_directions(directions): 11 | """SH encoding must be in the range [0, 1] 12 | 13 | Args: 14 | directions: batch of directions 15 | """ 16 | return (directions + 1.0) / 2.0 17 | 18 | 19 | def normalize_aabb(pts, aabb): 20 | return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 21 | def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: 22 | grid_dim = coords.shape[-1] 23 | 24 | if grid.dim() == grid_dim + 1: 25 | # no batch dimension present, need to add it 26 | grid = grid.unsqueeze(0) 27 | if coords.dim() == 2: 28 | coords = coords.unsqueeze(0) 29 | 30 | if grid_dim == 2 or grid_dim == 3: 31 | grid_sampler = F.grid_sample 32 | else: 33 | raise 
NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " 34 | f"implemented for 2 and 3D data.") 35 | 36 | coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) 37 | B, feature_dim = grid.shape[:2] 38 | n = coords.shape[-2] 39 | interp = grid_sampler( 40 | grid, # [B, feature_dim, reso, ...] 41 | coords, # [B, 1, ..., n, grid_dim] 42 | align_corners=align_corners, 43 | mode='bilinear', padding_mode='border') 44 | interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] 45 | interp = interp.squeeze() # [B?, n, feature_dim?] 46 | return interp 47 | 48 | def init_grid_param( 49 | grid_nd: int, 50 | in_dim: int, 51 | out_dim: int, 52 | reso: Sequence[int], 53 | a: float = 0.1, 54 | b: float = 0.5): 55 | assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" 56 | has_time_planes = in_dim == 4 57 | assert grid_nd <= in_dim 58 | coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) 59 | grid_coefs = nn.ParameterList() 60 | for ci, coo_comb in enumerate(coo_combs): 61 | new_grid_coef = nn.Parameter(torch.empty( 62 | [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] 63 | )) 64 | if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 65 | nn.init.ones_(new_grid_coef) 66 | else: 67 | nn.init.uniform_(new_grid_coef, a=a, b=b) 68 | grid_coefs.append(new_grid_coef) 69 | 70 | return grid_coefs 71 | 72 | 73 | def interpolate_ms_features(pts: torch.Tensor, 74 | ms_grids: Collection[Iterable[nn.Module]], 75 | grid_dimensions: int, 76 | concat_features: bool, 77 | num_levels: Optional[int], 78 | ) -> torch.Tensor: 79 | coo_combs = list(itertools.combinations( 80 | range(pts.shape[-1]), grid_dimensions) 81 | ) 82 | if num_levels is None: 83 | num_levels = len(ms_grids) 84 | multi_scale_interp = [] if concat_features else 0. 85 | grid: nn.ParameterList 86 | for scale_id, grid in enumerate(ms_grids[:num_levels]): 87 | interp_space = 1. 88 | for ci, coo_comb in enumerate(coo_combs): 89 | # interpolate in plane 90 | feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso 91 | interp_out_plane = ( 92 | grid_sample_wrapper(grid[ci], pts[..., coo_comb]) 93 | .view(-1, feature_dim) 94 | ) 95 | # compute product over planes 96 | interp_space = interp_space * interp_out_plane 97 | 98 | # combine over scales 99 | if concat_features: 100 | multi_scale_interp.append(interp_space) 101 | else: 102 | multi_scale_interp = multi_scale_interp + interp_space 103 | 104 | if concat_features: 105 | multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) 106 | return multi_scale_interp 107 | 108 | 109 | class HexPlaneField(nn.Module): 110 | def __init__( 111 | self, 112 | 113 | bounds, 114 | planeconfig, 115 | multires 116 | ) -> None: 117 | super().__init__() 118 | aabb = torch.tensor([[bounds,bounds,bounds], 119 | [-bounds,-bounds,-bounds]]) 120 | self.aabb = nn.Parameter(aabb, requires_grad=False) 121 | self.grid_config = [planeconfig] 122 | self.multiscale_res_multipliers = multires 123 | self.concat_features = True 124 | 125 | # 1. 
Init planes 126 | self.grids = nn.ModuleList() 127 | self.feat_dim = 0 128 | for res in self.multiscale_res_multipliers: 129 | # initialize coordinate grid 130 | config = self.grid_config[0].copy() 131 | # Resolution fix: multi-res only on spatial planes 132 | config["resolution"] = [ 133 | r * res for r in config["resolution"][:3] 134 | ] + config["resolution"][3:] 135 | gp = init_grid_param( 136 | grid_nd=config["grid_dimensions"], 137 | in_dim=config["input_coordinate_dim"], 138 | out_dim=config["output_coordinate_dim"], 139 | reso=config["resolution"], 140 | ) 141 | # shape[1] is out-dim - Concatenate over feature len for each scale 142 | if self.concat_features: 143 | self.feat_dim += gp[-1].shape[1] 144 | else: 145 | self.feat_dim = gp[-1].shape[1] 146 | self.grids.append(gp) 147 | # print(f"Initialized model grids: {self.grids}") 148 | print("feature_dim:",self.feat_dim) 149 | @property 150 | def get_aabb(self): 151 | return self.aabb[0], self.aabb[1] 152 | def set_aabb(self,xyz_max, xyz_min): 153 | aabb = torch.tensor([ 154 | xyz_max, 155 | xyz_min 156 | ],dtype=torch.float32) 157 | self.aabb = nn.Parameter(aabb,requires_grad=False) 158 | print("Voxel Plane: set aabb=",self.aabb) 159 | 160 | def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): 161 | """Computes and returns the densities.""" 162 | # breakpoint() 163 | pts = normalize_aabb(pts, self.aabb) 164 | pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] 165 | 166 | pts = pts.reshape(-1, pts.shape[-1]) 167 | features = interpolate_ms_features( 168 | pts, ms_grids=self.grids, # noqa 169 | grid_dimensions=self.grid_config[0]["grid_dimensions"], 170 | concat_features=self.concat_features, num_levels=None) 171 | if len(features) < 1: 172 | features = torch.zeros((0, 1)).to(features.device) 173 | 174 | 175 | return features 176 | 177 | def forward(self, 178 | pts: torch.Tensor, 179 | timestamps: Optional[torch.Tensor] = None): 180 | 181 | features = self.get_density(pts, timestamps) 182 | 183 | return features 184 | -------------------------------------------------------------------------------- /scene/hyper_loader.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import json 6 | import os 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | import math 13 | from tqdm import tqdm 14 | from scene.utils import Camera 15 | from typing import NamedTuple 16 | from torch.utils.data import Dataset 17 | from utils.general_utils import PILtoTorch 18 | # from scene.dataset_readers import 19 | import torch.nn.functional as F 20 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 21 | from utils.pose_utils import smooth_camera_poses 22 | class CameraInfo(NamedTuple): 23 | uid: int 24 | R: np.array 25 | T: np.array 26 | FovY: np.array 27 | FovX: np.array 28 | image: np.array 29 | image_path: str 30 | image_name: str 31 | width: int 32 | height: int 33 | time : float 34 | mask: np.array 35 | 36 | 37 | class Load_hyper_data(Dataset): 38 | def __init__(self, 39 | datadir, 40 | ratio=1.0, 41 | use_bg_points=False, 42 | split="train" 43 | ): 44 | 45 | from .utils import Camera 46 | datadir = os.path.expanduser(datadir) 47 | with open(f'{datadir}/scene.json', 'r') as f: 48 | scene_json = json.load(f) 49 | with open(f'{datadir}/metadata.json', 'r') as f: 50 | meta_json = json.load(f) 51 | with open(f'{datadir}/dataset.json', 'r') as f: 52 
| dataset_json = json.load(f) 53 | 54 | self.near = scene_json['near'] 55 | self.far = scene_json['far'] 56 | self.coord_scale = scene_json['scale'] 57 | self.scene_center = scene_json['center'] 58 | 59 | self.all_img = dataset_json['ids'] 60 | self.val_id = dataset_json['val_ids'] 61 | self.split = split 62 | if len(self.val_id) == 0: 63 | self.i_train = np.array([i for i in np.arange(len(self.all_img)) if 64 | (i%4 == 0)]) 65 | self.i_test = self.i_train+2 66 | self.i_test = self.i_test[:-1,] 67 | else: 68 | self.train_id = dataset_json['train_ids'] 69 | self.i_test = [] 70 | self.i_train = [] 71 | for i in range(len(self.all_img)): 72 | id = self.all_img[i] 73 | if id in self.val_id: 74 | self.i_test.append(i) 75 | if id in self.train_id: 76 | self.i_train.append(i) 77 | 78 | self.all_cam = [meta_json[i]['camera_id'] for i in self.all_img] 79 | self.all_time = [meta_json[i]['warp_id'] for i in self.all_img] 80 | max_time = max(self.all_time) 81 | self.all_time = [meta_json[i]['warp_id']/max_time for i in self.all_img] 82 | self.selected_time = set(self.all_time) 83 | self.ratio = ratio 84 | self.max_time = max(self.all_time) 85 | self.min_time = min(self.all_time) 86 | self.i_video = [i for i in range(len(self.all_img))] 87 | self.i_video.sort() 88 | self.all_cam_params = [] 89 | for im in self.all_img: 90 | camera = Camera.from_json(f'{datadir}/camera/{im}.json') 91 | 92 | self.all_cam_params.append(camera) 93 | self.all_img_origin = self.all_img 94 | self.all_depth = [f'{datadir}/depth/{int(1/ratio)}x/{i}.npy' for i in self.all_img] 95 | 96 | self.all_img = [f'{datadir}/rgb/{int(1/ratio)}x/{i}.png' for i in self.all_img] 97 | 98 | self.h, self.w = self.all_cam_params[0].image_shape 99 | self.map = {} 100 | self.image_one = Image.open(self.all_img[0]) 101 | self.image_one_torch = PILtoTorch(self.image_one,None).to(torch.float32) 102 | if os.path.exists(os.path.join(datadir,"covisible")): 103 | self.image_mask = [f'{datadir}/covisible/{int(2)}x/val/{i}.png' for i in self.all_img_origin] 104 | else: 105 | self.image_mask = None 106 | 107 | # self.generate_video_path() 108 | # self.i_test 109 | def generate_video_path(self): 110 | 111 | self.select_video_cams = [item for i, item in enumerate(self.all_cam_params) if i % 1 == 0 ] 112 | self.video_path, self.video_time = smooth_camera_poses(self.select_video_cams,10) 113 | # breakpoint() 114 | self.video_path = self.video_path[:500] 115 | self.video_time = self.video_time[:500] 116 | # breakpoint() 117 | def __getitem__(self, index): 118 | if self.split == "train": 119 | return self.load_raw(self.i_train[index]) 120 | 121 | elif self.split == "test": 122 | return self.load_raw(self.i_test[index]) 123 | elif self.split == "video": 124 | return self.load_raw(index) 125 | def __len__(self): 126 | if self.split == "train": 127 | return len(self.i_train) 128 | elif self.split == "test": 129 | return len(self.i_test) 130 | elif self.split == "video": 131 | return len(self.i_test) 132 | # return len(self.video_v2) 133 | def load_video(self, idx): 134 | if idx in self.map.keys(): 135 | return self.map[idx] 136 | camera = self.all_cam_params[idx] 137 | w = self.image_one.size[0] 138 | h = self.image_one.size[1] 139 | time = self.video_time[idx] 140 | R = camera.orientation.T 141 | T = - camera.position @ R 142 | FovY = focal2fov(camera.focal_length, self.h) 143 | FovX = focal2fov(camera.focal_length, self.w) 144 | image_path = "/".join(self.all_img[idx].split("/")[:-1]) 145 | image_name = self.all_img[idx].split("/")[-1] 146 | caminfo = 
CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=self.image_one_torch, 147 | image_path=image_path, image_name=image_name, width=w, height=h, time=time, mask=None 148 | ) 149 | self.map[idx] = caminfo 150 | return caminfo 151 | def load_raw(self, idx): 152 | if idx in self.map.keys(): 153 | return self.map[idx] 154 | camera = self.all_cam_params[idx] 155 | image = Image.open(self.all_img[idx]) 156 | w = image.size[0] 157 | h = image.size[1] 158 | image = PILtoTorch(image,None) 159 | image = image.to(torch.float32)[:3,:,:] 160 | time = self.all_time[idx] 161 | R = camera.orientation.T 162 | T = - camera.position @ R 163 | FovY = focal2fov(camera.focal_length, self.h) 164 | FovX = focal2fov(camera.focal_length, self.w) 165 | image_path = "/".join(self.all_img[idx].split("/")[:-1]) 166 | image_name = self.all_img[idx].split("/")[-1] 167 | if self.image_mask is not None and self.split == "test": 168 | mask = Image.open(self.image_mask[idx]) 169 | mask = PILtoTorch(mask,None) 170 | mask = mask.to(torch.float32)[0:1,:,:] 171 | 172 | mask = F.interpolate(mask.unsqueeze(0), size=[self.h, self.w], mode='bilinear', align_corners=False).squeeze(0) 173 | else: 174 | mask = None 175 | 176 | 177 | caminfo = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 178 | image_path=image_path, image_name=image_name, width=w, height=h, time=time, mask=mask 179 | ) 180 | self.map[idx] = caminfo 181 | return caminfo 182 | 183 | 184 | def format_hyper_data(data_class, split): 185 | if split == "train": 186 | data_idx = data_class.i_train 187 | elif split == "test": 188 | data_idx = data_class.i_test 189 | # dataset = data_class.copy() 190 | # dataset.mode = split 191 | cam_infos = [] 192 | for uid, index in tqdm(enumerate(data_idx)): 193 | camera = data_class.all_cam_params[index] 194 | # image = Image.open(data_class.all_img[index]) 195 | # image = PILtoTorch(image,None) 196 | time = data_class.all_time[index] 197 | R = camera.orientation.T 198 | T = - camera.position @ R 199 | FovY = focal2fov(camera.focal_length, data_class.h) 200 | FovX = focal2fov(camera.focal_length, data_class.w) 201 | image_path = "/".join(data_class.all_img[index].split("/")[:-1]) 202 | image_name = data_class.all_img[index].split("/")[-1] 203 | 204 | if data_class.image_mask is not None and data_class.split == "test": 205 | mask = Image.open(data_class.image_mask[index]) 206 | mask = PILtoTorch(mask,None) 207 | 208 | mask = mask.to(torch.float32)[0:1,:,:] 209 | 210 | 211 | else: 212 | mask = None 213 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=None, 214 | image_path=image_path, image_name=image_name, width=int(data_class.w), 215 | height=int(data_class.h), time=time, mask=mask 216 | ) 217 | cam_infos.append(cam_info) 218 | return cam_infos -------------------------------------------------------------------------------- /scene/multipleview_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from utils.graphics_utils import focal2fov 6 | from scene.colmap_loader import qvec2rotmat 7 | from scene.dataset_readers import CameraInfo 8 | from scene.neural_3D_dataset_NDC import get_spiral 9 | from torchvision import transforms as T 10 | 11 | 12 | class multipleview_dataset(Dataset): 13 | def __init__( 14 | self, 15 | cam_extrinsics, 16 | cam_intrinsics, 17 | cam_folder, 18 | split 19 | ): 20 | self.focal = [cam_intrinsics[1].params[0], 
cam_intrinsics[1].params[0]] 21 | height=cam_intrinsics[1].height 22 | width=cam_intrinsics[1].width 23 | self.FovY = focal2fov(self.focal[0], height) 24 | self.FovX = focal2fov(self.focal[0], width) 25 | self.transform = T.ToTensor() 26 | self.image_paths, self.image_poses, self.image_times= self.load_images_path(cam_folder, cam_extrinsics,cam_intrinsics,split) 27 | if split=="test": 28 | self.video_cam_infos=self.get_video_cam_infos(cam_folder) 29 | 30 | 31 | def load_images_path(self, cam_folder, cam_extrinsics,cam_intrinsics,split): 32 | image_length = len(os.listdir(os.path.join(cam_folder,"cam01"))) 33 | #len_cam=len(cam_extrinsics) 34 | image_paths=[] 35 | image_poses=[] 36 | image_times=[] 37 | for idx, key in enumerate(cam_extrinsics): 38 | extr = cam_extrinsics[key] 39 | R = np.transpose(qvec2rotmat(extr.qvec)) 40 | T = np.array(extr.tvec) 41 | 42 | number = os.path.basename(extr.name)[5:-4] 43 | images_folder=os.path.join(cam_folder,"cam"+number.zfill(2)) 44 | 45 | image_range=range(image_length) 46 | if split=="test": 47 | image_range = [image_range[0],image_range[int(image_length/3)],image_range[int(image_length*2/3)]] 48 | 49 | for i in image_range: 50 | num=i+1 51 | image_path=os.path.join(images_folder,"frame_"+str(num).zfill(5)+".jpg") 52 | image_paths.append(image_path) 53 | image_poses.append((R,T)) 54 | image_times.append(float(i/image_length)) 55 | 56 | return image_paths, image_poses,image_times 57 | 58 | def get_video_cam_infos(self,datadir): 59 | poses_arr = np.load(os.path.join(datadir, "poses_bounds_multipleview.npy")) 60 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5) 61 | near_fars = poses_arr[:, -2:] 62 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 63 | N_views = 300 64 | val_poses = get_spiral(poses, near_fars, N_views=N_views) 65 | 66 | cameras = [] 67 | len_poses = len(val_poses) 68 | times = [i/len_poses for i in range(len_poses)] 69 | image = Image.open(self.image_paths[0]) 70 | image = self.transform(image) 71 | 72 | for idx, p in enumerate(val_poses): 73 | image_path = None 74 | image_name = f"{idx}" 75 | time = times[idx] 76 | pose = np.eye(4) 77 | pose[:3,:] = p[:3,:] 78 | R = pose[:3,:3] 79 | R = - R 80 | R[:,0] = -R[:,0] 81 | T = -pose[:3,3].dot(R) 82 | FovX = self.FovX 83 | FovY = self.FovY 84 | cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 85 | image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1], 86 | time = time, mask=None)) 87 | return cameras 88 | def __len__(self): 89 | return len(self.image_paths) 90 | def __getitem__(self, index): 91 | img = Image.open(self.image_paths[index]) 92 | img = self.transform(img) 93 | return img, self.image_poses[index], self.image_times[index] 94 | def load_pose(self,index): 95 | return self.image_poses[index] -------------------------------------------------------------------------------- /scene/regulation.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | from typing import Sequence 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.optim.lr_scheduler 9 | from torch import nn 10 | 11 | 12 | 13 | def compute_plane_tv(t): 14 | batch_size, c, h, w = t.shape 15 | count_h = batch_size * c * (h - 1) * w 16 | count_w = batch_size * c * h * (w - 1) 17 | h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() 18 | w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() 19 | return 
2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg 20 | 21 | 22 | def compute_plane_smoothness(t): 23 | batch_size, c, h, w = t.shape 24 | # Convolve with a second derivative filter, in the time dimension which is dimension 2 25 | first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] 26 | second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] 27 | # Take the L2 norm of the result 28 | return torch.square(second_difference).mean() 29 | 30 | 31 | class Regularizer(): 32 | def __init__(self, reg_type, initialization): 33 | self.reg_type = reg_type 34 | self.initialization = initialization 35 | self.weight = float(self.initialization) 36 | self.last_reg = None 37 | 38 | def step(self, global_step): 39 | pass 40 | 41 | def report(self, d): 42 | if self.last_reg is not None: 43 | d[self.reg_type].update(self.last_reg.item()) 44 | 45 | def regularize(self, *args, **kwargs) -> torch.Tensor: 46 | out = self._regularize(*args, **kwargs) * self.weight 47 | self.last_reg = out.detach() 48 | return out 49 | 50 | @abc.abstractmethod 51 | def _regularize(self, *args, **kwargs) -> torch.Tensor: 52 | raise NotImplementedError() 53 | 54 | def __str__(self): 55 | return f"Regularizer({self.reg_type}, weight={self.weight})" 56 | 57 | 58 | class PlaneTV(Regularizer): 59 | def __init__(self, initial_value, what: str = 'field'): 60 | if what not in {'field', 'proposal_network'}: 61 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 62 | f'but {what} was passed.') 63 | name = f'planeTV-{what[:2]}' 64 | super().__init__(name, initial_value) 65 | self.what = what 66 | 67 | def step(self, global_step): 68 | pass 69 | 70 | def _regularize(self, model, **kwargs): 71 | multi_res_grids: Sequence[nn.ParameterList] 72 | if self.what == 'field': 73 | multi_res_grids = model.field.grids 74 | elif self.what == 'proposal_network': 75 | multi_res_grids = [p.grids for p in model.proposal_networks] 76 | else: 77 | raise NotImplementedError(self.what) 78 | total = 0 79 | # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] 80 | for grids in multi_res_grids: 81 | if len(grids) == 3: 82 | spatial_grids = [0, 1, 2] 83 | else: 84 | spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal 85 | for grid_id in spatial_grids: 86 | total += compute_plane_tv(grids[grid_id]) 87 | for grid in grids: 88 | # grid: [1, c, h, w] 89 | total += compute_plane_tv(grid) 90 | return total 91 | 92 | 93 | class TimeSmoothness(Regularizer): 94 | def __init__(self, initial_value, what: str = 'field'): 95 | if what not in {'field', 'proposal_network'}: 96 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 97 | f'but {what} was passed.') 98 | name = f'time-smooth-{what[:2]}' 99 | super().__init__(name, initial_value) 100 | self.what = what 101 | 102 | def _regularize(self, model, **kwargs) -> torch.Tensor: 103 | multi_res_grids: Sequence[nn.ParameterList] 104 | if self.what == 'field': 105 | multi_res_grids = model.field.grids 106 | elif self.what == 'proposal_network': 107 | multi_res_grids = [p.grids for p in model.proposal_networks] 108 | else: 109 | raise NotImplementedError(self.what) 110 | total = 0 111 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 112 | for grids in multi_res_grids: 113 | if len(grids) == 3: 114 | time_grids = [] 115 | else: 116 | time_grids = [2, 4, 5] 117 | for grid_id in time_grids: 118 | total += 
compute_plane_smoothness(grids[grid_id]) 119 | return torch.as_tensor(total) 120 | 121 | 122 | 123 | class L1ProposalNetwork(Regularizer): 124 | def __init__(self, initial_value): 125 | super().__init__('l1-proposal-network', initial_value) 126 | 127 | def _regularize(self, model, **kwargs) -> torch.Tensor: 128 | grids = [p.grids for p in model.proposal_networks] 129 | total = 0.0 130 | for pn_grids in grids: 131 | for grid in pn_grids: 132 | total += torch.abs(grid).mean() 133 | return torch.as_tensor(total) 134 | 135 | 136 | class DepthTV(Regularizer): 137 | def __init__(self, initial_value): 138 | super().__init__('tv-depth', initial_value) 139 | 140 | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: 141 | depth = model_out['depth'] 142 | tv = compute_plane_tv( 143 | depth.reshape(64, 64)[None, None, :, :] 144 | ) 145 | return tv 146 | 147 | 148 | class L1TimePlanes(Regularizer): 149 | def __init__(self, initial_value, what='field'): 150 | if what not in {'field', 'proposal_network'}: 151 | raise ValueError(f'what must be one of "field" or "proposal_network" ' 152 | f'but {what} was passed.') 153 | super().__init__(f'l1-time-{what[:2]}', initial_value) 154 | self.what = what 155 | 156 | def _regularize(self, model, **kwargs) -> torch.Tensor: 157 | # model.grids is 6 x [1, rank * F_dim, reso, reso] 158 | multi_res_grids: Sequence[nn.ParameterList] 159 | if self.what == 'field': 160 | multi_res_grids = model.field.grids 161 | elif self.what == 'proposal_network': 162 | multi_res_grids = [p.grids for p in model.proposal_networks] 163 | else: 164 | raise NotImplementedError(self.what) 165 | 166 | total = 0.0 167 | for grids in multi_res_grids: 168 | if len(grids) == 3: 169 | continue 170 | else: 171 | # These are the spatiotemporal grids 172 | spatiotemporal_grids = [2, 4, 5] 173 | for grid_id in spatiotemporal_grids: 174 | total += torch.abs(1 - grids[grid_id]).mean() 175 | return torch.as_tensor(total) 176 | 177 | -------------------------------------------------------------------------------- /scripts/blender2colmap.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import glob 5 | import sys 6 | import json 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import shutil 10 | import math 11 | def fov2focal(fov, pixels): 12 | return pixels / (2 * math.tan(fov / 2)) 13 | def rotmat2qvec(R): 14 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 15 | K = np.array([ 16 | [Rxx - Ryy - Rzz, 0, 0, 0], 17 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 18 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 19 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 20 | eigvals, eigvecs = np.linalg.eigh(K) 21 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 22 | if qvec[0] < 0: 23 | qvec *= -1 24 | return qvec 25 | 26 | root_dir = sys.argv[1] 27 | colmap_dir = os.path.join(root_dir,"sparse_") 28 | if not os.path.exists(colmap_dir): 29 | os.makedirs(colmap_dir) 30 | imagecolmap_dir = os.path.join(root_dir,"image_colmap") 31 | if not os.path.exists(imagecolmap_dir): 32 | os.makedirs(imagecolmap_dir) 33 | 34 | image_dir = os.path.join(root_dir) 35 | images = os.listdir(image_dir) 36 | images.sort() 37 | camera_json = os.path.join(root_dir,"transforms_train.json") 38 | 39 | 40 | with open (camera_json) as f: 41 | meta = json.load(f) 42 | try: 43 | image_size = meta['w'], meta['h'] 44 | focal = [meta['fl_x'],meta['fl_y']] 45 | except: 46 | try: 47 | image_size = meta['frames'][0]['w'], 
meta['frames'][0]['h'] 48 | focal = [meta['frames'][0]['fl_x'],meta['frames'][0]['fl_y']] 49 | except: 50 | image_size = 800,800 51 | focal = fov2focal(meta['camera_angle_x'], 800) 52 | focal = [focal,focal] 53 | # size = image.size 54 | # breakpoint() 55 | object_images_file = open(os.path.join(colmap_dir,"images.txt"),"w") 56 | object_cameras_file = open(os.path.join(colmap_dir,"cameras.txt"),"w") 57 | 58 | idx=0 59 | sizes=1 60 | cnt=0 61 | while len(meta['frames'])//sizes > 200: 62 | sizes += 1 63 | for frame in meta['frames']: 64 | cnt+=1 65 | if cnt % sizes != 0: 66 | continue 67 | matrix = np.linalg.inv(np.array(frame["transform_matrix"])) 68 | R = -np.transpose(matrix[:3,:3]) 69 | R[:,0] = -R[:,0] 70 | T = -matrix[:3, 3] 71 | T = -np.matmul(R,T) 72 | T = [str(i) for i in T] 73 | qevc = [str(i) for i in rotmat2qvec(np.transpose(R))] 74 | print(idx+1," ".join(qevc)," ".join(T),1,frame['file_path'].split('/')[-1]+".png","\n",file=object_images_file) 75 | 76 | print(idx,"SIMPLE_PINHOLE",image_size[0],image_size[1],focal[0],image_size[0]/2,image_size[1]/2,file=object_cameras_file) 77 | idx+=1 78 | # breakpoint() 79 | print(os.path.join(image_dir,frame['file_path']),os.path.join(imagecolmap_dir,frame['file_path'].split('/')[-1]+".png")) 80 | shutil.copy(os.path.join(image_dir,frame['file_path']+".png"),os.path.join(imagecolmap_dir,frame['file_path'].split('/')[-1]+".png")) 81 | # write camera infomation. 82 | # print(1,"SIMPLE_PINHOLE",image_size[0],image_size[1],focal[0],image_sizep0/2,image_size[1]/2,file=object_cameras_file) 83 | object_point_file = open(os.path.join(colmap_dir,"points3D.txt"),"w") 84 | 85 | object_cameras_file.close() 86 | object_images_file.close() 87 | object_point_file.close() 88 | 89 | -------------------------------------------------------------------------------- /scripts/cal_modelsize.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def calculate_total_size_of_files(folders): 4 | total_size = 0 5 | 6 | for folder_name in folders: 7 | deformation_path = os.path.join(folder_name, "./point_cloud/coarse_iteration_3000/deformation.pth") 8 | point_cloud_path = os.path.join(folder_name, "./point_cloud/coarse_iteration_3000/point_cloud.ply") 9 | # print(point_cloud_path) 10 | if os.path.exists(deformation_path): 11 | deformation_size = os.path.getsize(deformation_path)/(1024*1024) 12 | total_size += deformation_size 13 | 14 | if os.path.exists(point_cloud_path): 15 | point_cloud_size = os.path.getsize(point_cloud_path)/(1024*1024) 16 | total_size += point_cloud_size 17 | 18 | return total_size 19 | 20 | for model_name in ["dnerf_3dgs"]: 21 | # model_name = "dnerf_tv" 22 | folder_names = ["bouncingball", "hook", "hellwarrior","jumpingjack","lego","mutant","standup","trex"] 23 | new_folder_names = [os.path.join("output",model_name,i) for i in folder_names] 24 | total_size = calculate_total_size_of_files(new_folder_names) 25 | print(model_name, "average size (MB):", total_size/len(folder_names)) 26 | -------------------------------------------------------------------------------- /scripts/downsample_point.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import sys 3 | def process_ply_file(input_file, output_file): 4 | # 读取输入的ply文件 5 | pcd = o3d.io.read_point_cloud(input_file) 6 | print(f"Total points: {len(pcd.points)}") 7 | 8 | # 通过点云下采样将输入的点云减少 9 | voxel_size=0.02 10 | while len(pcd.points) > 40000: 11 | pcd = 
pcd.voxel_down_sample(voxel_size=voxel_size) 12 | print(f"Downsampled points: {len(pcd.points)}") 13 | voxel_size+=0.01 14 | 15 | # 将结果保存到输入的路径中 16 | o3d.io.write_point_cloud(output_file, pcd) 17 | 18 | # 使用函数 19 | process_ply_file(sys.argv[1], sys.argv[2]) -------------------------------------------------------------------------------- /scripts/extractimages.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | 5 | folder_path = sys.argv[1] 6 | 7 | colmap_path = "./colmap_tmp" 8 | images_path = os.path.join(colmap_path, "images") 9 | os.makedirs(images_path, exist_ok=True) 10 | i=0 11 | 12 | dir1=os.path.join("data",folder_path) 13 | for folder_name in sorted(os.listdir(dir1)): 14 | dir2=os.path.join(dir1,folder_name) 15 | for file_name in os.listdir(dir2): 16 | if file_name.startswith("frame_00001"): 17 | i=i+1 18 | src_path = os.path.join(dir2, file_name) 19 | dst_path = os.path.join(images_path, f"image{i}.jpg") 20 | shutil.copyfile(src_path, dst_path) 21 | 22 | print("End!") 23 | -------------------------------------------------------------------------------- /scripts/grow_point.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | def grow_sparse_regions(input_file, output_file): 5 | pcd = o3d.io.read_point_cloud(input_file) 6 | densities = o3d.geometry.PointCloud.compute_nearest_neighbor_distance(pcd) 7 | avg_density = np.average(densities) 8 | print(f"Average density: {avg_density}") 9 | sparse_indices = np.where(densities > avg_density * 1.2)[0] 10 | sparse_points = np.asarray(pcd.points)[sparse_indices] 11 | 12 | 13 | o3d.io.write_point_cloud(output_file, pcd) 14 | 15 | grow_sparse_regions("data/hypernerf/vrig/chickchicken/dense_downsample.ply", "data/hypernerf/interp/chickchicken/dense_downsample.ply") -------------------------------------------------------------------------------- /scripts/hypernerf2colmap.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import glob 5 | import sys 6 | import json 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import shutil 10 | def rotmat2qvec(R): 11 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 12 | K = np.array([ 13 | [Rxx - Ryy - Rzz, 0, 0, 0], 14 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 15 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 16 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 17 | eigvals, eigvecs = np.linalg.eigh(K) 18 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 19 | if qvec[0] < 0: 20 | qvec *= -1 21 | return qvec 22 | 23 | root_dir = sys.argv[1] 24 | colmap_dir = os.path.join(root_dir,"sparse_") 25 | if not os.path.exists(colmap_dir): 26 | os.makedirs(colmap_dir) 27 | imagecolmap_dir = os.path.join(root_dir,"image_colmap") 28 | if not os.path.exists(imagecolmap_dir): 29 | os.makedirs(imagecolmap_dir) 30 | 31 | image_dir = os.path.join(root_dir,"rgb","2x") 32 | images = os.listdir(image_dir) 33 | images.sort() 34 | camera_dir = os.path.join(root_dir,"camera") 35 | cameras = os.listdir(camera_dir) 36 | cameras.sort() 37 | cams = [] 38 | for jsonfile in tqdm(cameras): 39 | with open (os.path.join(camera_dir,jsonfile)) as f: 40 | cams.append(json.load(f)) 41 | image_size = cams[0]['image_size'] 42 | image = Image.open(os.path.join(image_dir,images[0])) 43 | size = image.size 44 | object_images_file = open(os.path.join(colmap_dir,"images.txt"),"w") 45 | 
object_cameras_file = open(os.path.join(colmap_dir,"cameras.txt"),"w") 46 | 47 | idx=0 48 | cnt=0 49 | sizes=2 50 | while len(cams)//sizes > 200: 51 | sizes += 1 52 | for cam, image in zip(cams, images): 53 | cnt+=1 54 | 55 | if cnt % sizes != 0: 56 | continue 57 | R = np.array(cam['orientation']).T 58 | T = -np.array(cam['position'])@R 59 | 60 | T = [str(i) for i in T] 61 | qevc = [str(i) for i in rotmat2qvec(R.T)] 62 | print(idx+1," ".join(qevc)," ".join(T),1,image,"\n",file=object_images_file) 63 | 64 | print(idx,"SIMPLE_PINHOLE",image_size[0]/2,image_size[1]/2,cam['focal_length']/2,cam['principal_point'][0]/2,cam['principal_point'][1]/2,file=object_cameras_file) 65 | idx+=1 66 | shutil.copy(os.path.join(image_dir,image),os.path.join(imagecolmap_dir,image)) 67 | print(idx) 68 | # write camera infomation. 69 | object_point_file = open(os.path.join(colmap_dir,"points3D.txt"),"w") 70 | 71 | object_cameras_file.close() 72 | object_images_file.close() 73 | object_point_file.close() 74 | -------------------------------------------------------------------------------- /scripts/llff2colmap.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import glob 5 | import sys 6 | def rotmat2qvec(R): 7 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 8 | K = np.array([ 9 | [Rxx - Ryy - Rzz, 0, 0, 0], 10 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 11 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 12 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 13 | eigvals, eigvecs = np.linalg.eigh(K) 14 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 15 | if qvec[0] < 0: 16 | qvec *= -1 17 | return qvec 18 | def normalize(v): 19 | """Normalize a vector.""" 20 | return v / np.linalg.norm(v) 21 | 22 | def average_poses(poses): 23 | """ 24 | Calculate the average pose, which is then used to center all poses 25 | using @center_poses. Its computation is as follows: 26 | 1. Compute the center: the average of pose centers. 27 | 2. Compute the z axis: the normalized average z axis. 28 | 3. Compute axis y': the average y axis. 29 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 30 | 5. Compute the y axis: z cross product x. 31 | 32 | Note that at step 3, we cannot directly use y' as y axis since it's 33 | not necessarily orthogonal to z axis. We need to pass from x to y. 34 | Inputs: 35 | poses: (N_images, 3, 4) 36 | Outputs: 37 | pose_avg: (3, 4) the average pose 38 | """ 39 | # 1. Compute the center 40 | center = poses[..., 3].mean(0) # (3) 41 | 42 | # 2. Compute the z axis 43 | z = normalize(poses[..., 2].mean(0)) # (3) 44 | 45 | # 3. Compute axis y' (no need to normalize as it's not the final output) 46 | y_ = poses[..., 1].mean(0) # (3) 47 | 48 | # 4. Compute the x axis 49 | x = normalize(np.cross(z, y_)) # (3) 50 | 51 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 52 | y = np.cross(x, z) # (3) 53 | 54 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 55 | 56 | return pose_avg 57 | 58 | blender2opencv = np.eye(4) 59 | def center_poses(poses, blender2opencv): 60 | """ 61 | Center the poses so that we can use NDC. 
62 | See https://github.com/bmild/nerf/issues/34 63 | Inputs: 64 | poses: (N_images, 3, 4) 65 | Outputs: 66 | poses_centered: (N_images, 3, 4) the centered poses 67 | pose_avg: (3, 4) the average pose 68 | """ 69 | poses = poses @ blender2opencv 70 | pose_avg = average_poses(poses) # (3, 4) 71 | pose_avg_homo = np.eye(4) 72 | pose_avg_homo[ 73 | :3 74 | ] = pose_avg # convert to homogeneous coordinate for faster computation 75 | pose_avg_homo = pose_avg_homo 76 | # by simply adding 0, 0, 0, 1 as the last row 77 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 78 | poses_homo = np.concatenate( 79 | [poses, last_row], 1 80 | ) # (N_images, 4, 4) homogeneous coordinate 81 | 82 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 83 | # poses_centered = poses_centered @ blender2opencv 84 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 85 | 86 | return poses_centered, pose_avg_homo 87 | root_dir = sys.argv[1] 88 | colmap_dir = os.path.join(root_dir,"sparse_") 89 | if not os.path.exists(colmap_dir): 90 | os.makedirs(colmap_dir) 91 | poses_arr = np.load(os.path.join(root_dir, "poses_bounds.npy")) 92 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5) 93 | near_fars = poses_arr[:, -2:] 94 | videos = glob.glob(os.path.join(root_dir, "cam[0-9][0-9]")) 95 | videos = sorted(videos) 96 | assert len(videos) == poses_arr.shape[0] 97 | H, W, focal = poses[0, :, -1] 98 | focal = focal/2 99 | focal = [focal, focal] 100 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 101 | videos = glob.glob(os.path.join(root_dir, "cam[0-9][0-9]")) 102 | videos = sorted(videos) 103 | image_paths = [] 104 | for index, video_path in enumerate(videos): 105 | image_path = os.path.join(video_path,"images","0000.png") 106 | image_paths.append(image_path) 107 | print(image_paths) 108 | goal_dir = os.path.join(root_dir,"image_colmap") 109 | if not os.path.exists(goal_dir): 110 | os.makedirs(goal_dir) 111 | import shutil 112 | image_name_list =[] 113 | for index, image in enumerate(image_paths): 114 | image_name = image.split("/")[-1].split('.') 115 | image_name[0] = "r_%03d" % index 116 | print(image_name) 117 | # breakpoint() 118 | image_name = ".".join(image_name) 119 | image_name_list.append(image_name) 120 | goal_path = os.path.join(goal_dir,image_name) 121 | shutil.copy(image,goal_path) 122 | 123 | print(poses) 124 | # write image information. 125 | object_images_file = open(os.path.join(colmap_dir,"images.txt"),"w") 126 | for idx, pose in enumerate(poses): 127 | # pose_44 = np.eye(4) 128 | 129 | R = pose[:3,:3] 130 | R = -R 131 | R[:,0] = -R[:,0] 132 | T = pose[:3,3] 133 | 134 | R = np.linalg.inv(R) 135 | T = -np.matmul(R,T) 136 | T = [str(i) for i in T] 137 | qevc = [str(i) for i in rotmat2qvec(R)] 138 | print(idx+1," ".join(qevc)," ".join(T),1,image_name_list[idx],"\n",file=object_images_file) 139 | 140 | # write camera infomation. 
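# Note on the cameras.txt block below: COLMAP's cameras.txt rows follow
#   CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]
# and for SIMPLE_PINHOLE the params are (focal, cx, cy). The 1352x1014 size and the
# halved focal written here are hard-coded, presumably matching half-resolution
# DyNeRF-style frames; adjust these constants if your extracted images differ in size.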
141 | object_cameras_file = open(os.path.join(colmap_dir,"cameras.txt"),"w") 142 | print(1,"SIMPLE_PINHOLE",1352,1014,focal[0],1352/2,1014/2,file=object_cameras_file) # 143 | object_point_file = open(os.path.join(colmap_dir,"points3D.txt"),"w") 144 | 145 | object_cameras_file.close() 146 | object_images_file.close() 147 | object_point_file.close() 148 | -------------------------------------------------------------------------------- /scripts/merge_point.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import os 3 | from tqdm import tqdm 4 | def merge_point_clouds(directory, output_file): 5 | merged_pcd = o3d.geometry.PointCloud() 6 | 7 | for filename in tqdm(os.listdir(directory)): 8 | if filename.endswith('.ply'): 9 | pcd = o3d.io.read_point_cloud(os.path.join(directory, filename)) 10 | merged_pcd += pcd 11 | 12 | merged_pcd = merged_pcd.remove_duplicate_points() 13 | 14 | o3d.io.write_point_cloud(output_file, merged_pcd) 15 | 16 | merge_point_clouds("point_clouds_directory", "merged.ply") -------------------------------------------------------------------------------- /scripts/preprocess_dynerf.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import sys 3 | sys.path.append('./scene') 4 | from neural_3D_dataset_NDC import Neural3D_NDC_Dataset 5 | # import scene 6 | # from scene.neural_3D_dataset_NDC import Neural3D_NDC_Dataset 7 | 8 | if __name__ == '__main__': 9 | parser = ArgumentParser(description="Extract images from dynerf videos") 10 | parser.add_argument("--datadir", default='data/dynerf/cut_roasted_beef', type=str) 11 | args = parser.parse_args() 12 | train_dataset = Neural3D_NDC_Dataset(args.datadir, "train", 1.0, time_scale=1, 13 | scene_bbox_min=[-2.5, -2.0, -1.0], scene_bbox_max=[2.5, 2.0, 1.0], eval_index=0) 14 | test_dataset = Neural3D_NDC_Dataset(args.datadir, "test", 1.0, time_scale=1, 15 | scene_bbox_min=[-2.5, -2.0, -1.0], scene_bbox_max=[2.5, 2.0, 1.0], eval_index=0) 16 | -------------------------------------------------------------------------------- /scripts/process_dnerf.sh: -------------------------------------------------------------------------------- 1 | exp_name1=$1 2 | 3 | 4 | 5 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dnerf/jumpingjacks --port 7169 --expname "$exp_name1/jumpingjacks" --configs arguments/$exp_name1/jumpingjacks.py & 6 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/trex --port 7170 --expname "$exp_name1/trex" --configs arguments/$exp_name1/trex.py 7 | 8 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path "output/$exp_name1/jumpingjacks/" --skip_train --configs arguments/$exp_name1/jumpingjacks.py & 9 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path "output/$exp_name1/trex/" --skip_train --configs arguments/$exp_name1/trex.py 10 | wait 11 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name1/jumpingjacks/" & 12 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name1/trex/" 13 | 14 | wait 15 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/mutant --port 7168 --expname "$exp_name1/mutant" --configs arguments/$exp_name1/mutant.py & 16 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dnerf/standup --port 7166 --expname "$exp_name1/standup" --configs arguments/$exp_name1/standup.py 17 | 18 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path 
"output/$exp_name1/mutant/" --skip_train --configs arguments/$exp_name1/mutant.py & 19 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path "output/$exp_name1/standup/" --skip_train --configs arguments/$exp_name1/standup.py 20 | wait 21 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name1/mutant/" & 22 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name1/standup/" 23 | wait 24 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/hook --port 7369 --expname "$exp_name1/hook" --configs arguments/$exp_name1/hook.py & 25 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dnerf/hellwarrior --port 7370 --expname "$exp_name1/hellwarrior" --configs arguments/$exp_name1/hellwarrior.py 26 | wait 27 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path "output/$exp_name1/hellwarrior/" --skip_train --configs arguments/$exp_name1/hellwarrior.py & 28 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path "output/$exp_name1/hook/" --skip_train --configs arguments/$exp_name1/hook.py 29 | wait 30 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name1/hellwarrior/" & 31 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name1/hook/" 32 | wait 33 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/lego --port 7168 --expname "$exp_name1/lego" --configs arguments/$exp_name1/lego.py & 34 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dnerf/bouncingballs --port 7166 --expname "$exp_name1/bouncingballs" --configs arguments/$exp_name1/bouncingballs.py 35 | wait 36 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path "output/$exp_name1/bouncingballs/" --skip_train --configs arguments/$exp_name1/bouncingballs.py & 37 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path "output/$exp_name1/lego/" --skip_train --configs arguments/$exp_name1/lego.py 38 | wait 39 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name1/bouncingballs/" & 40 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name1/lego/" 41 | wait 42 | echo "Done" -------------------------------------------------------------------------------- /scripts/read_all_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | # exp_name = ["hypernerf"] 4 | # exp_name= ["dnerf"] 5 | exp_name=["dynerf"] 6 | scene_name = ["coffee_martini", "cook_spinach", "cut_roasted_beef", "flame_salmon_1", "flame_steak", "sear_steak"] 7 | # scene_name = ["bouncingballs","jumpingjacks","lego","standup","hook","mutant","hellwarrior","trex"] 8 | # scene_name = ["3dprinter","broom2","peel-banana","vrig-chicken"] 9 | json_name = "results.json" 10 | result_json = {"PSNR":0,"SSIM":0,"MS-SSIM":0,"D-SSIM":0,"LPIPS-vgg":0,"LPIPS-alex":0,"LPIPS":0} 11 | exp_json = {} 12 | for exps in exp_name: 13 | exp_json[exps] = result_json.copy() 14 | for scene in scene_name: 15 | for experiment in exp_name: 16 | load_path = os.path.join("output",experiment,scene,json_name) 17 | with open(load_path) as f: 18 | js = json.load(f) 19 | for res in ["ours_30000","ours_20000","ours_14000","ours_10000","ours_7000","ours_3000"]: 20 | if res in js.keys(): 21 | for key, item in js[res].items(): 22 | if key in exp_json[experiment].keys(): 23 | exp_json[experiment][key] += item 24 | print(scene, key, item) 25 | break 26 | 27 | # for scene in scene_name: 28 | 29 | for experiment in exp_name: 30 | 
print(exp_json[experiment]) 31 | for key, item in exp_json[experiment].items(): 32 | exp_json[experiment][key] /= len(scene_name) 33 | for key,item in exp_json.items(): 34 | print(key) 35 | print("PSNR,SSIM,D-SSIM,MS-SSIM,LPIPS-alex,LPIPS-vgg","LPIPS") 36 | print("%.4f"%item["PSNR"],"&","%.4f"%item["SSIM"],"%.4f"%item["D-SSIM"], 37 | "%.4f"%item["MS-SSIM"],"&","%.4f"%item["LPIPS-alex"],"%.4f"%item["LPIPS-vgg"], 38 | "%.4f"%item["LPIPS"]) 39 | # break -------------------------------------------------------------------------------- /scripts/select_image.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | 6 | import imageio 7 | data_path = "output/hypernerf_render/split-cookie/" 8 | # coarse_id = [i*50 for i in range(1,60)] 9 | # fine_id = [i * 50 for i in range(1,399)] 10 | coarse_id = [i * 50 for i in range(1, 10)] + [i * 50+1000 for i in range(40)] 11 | fine_id = [i * 10 for i in range(1, 100)] + [i * 50 for i in range(20,60)] + [i* 100 for i in range(30,100)] + [i*200 for i in range(50,140)] 12 | # breakpoint() 13 | times = 268 14 | # loading coarse images 15 | coarse_path = os.path.join(data_path,"coarse_render","images") 16 | fine_path = os.path.join(data_path,"fine_render","images") 17 | 18 | load_path = [] 19 | for index, frame in enumerate(coarse_id): 20 | idx = index * 2 21 | if (index // times) % 2 ==0: 22 | time_stamp = index % times 23 | else: 24 | time_stamp = times - 1 - (index % times) 25 | load_path.append(os.path.join(coarse_path,f"{frame}_{time_stamp}.jpg")) 26 | last_index = index 27 | for index, frame in enumerate(fine_id): 28 | thisindex = index + last_index 29 | if (thisindex // times) % 2 ==0: 30 | time_stamp = thisindex % times 31 | else: 32 | time_stamp = times - 1 - (thisindex % times) 33 | load_path.append(os.path.join(fine_path,f"{frame}_{time_stamp}.jpg")) 34 | # print(load_path,sep="\n") 35 | # breakpoint() 36 | writer = imageio.get_writer(os.path.join(data_path,"trainingstep.mp4"), fps=15) 37 | for image_file in load_path: 38 | image = imageio.imread(image_file) 39 | writer.append_data(image) 40 | 41 | writer.close() -------------------------------------------------------------------------------- /scripts/train_dnerf.sh: -------------------------------------------------------------------------------- 1 | exp_name1=$1 2 | 3 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/lego --port 6068 --expname "$exp_name1/lego" --configs arguments/$exp_name1/lego.py & 4 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dnerf/bouncingballs --port 6266 --expname "$exp_name1/bouncingballs" --configs arguments/$exp_name1/bouncingballs.py & 5 | wait 6 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/jumpingjacks --port 6069 --expname "$exp_name1/jumpingjacks" --configs arguments/$exp_name1/jumpingjacks.py & 7 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dnerf/trex --port 6070 --expname "$exp_name1/trex" --configs arguments/$exp_name1/trex.py & 8 | wait 9 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/mutant --port 6068 --expname "$exp_name1/mutant" --configs arguments/$exp_name1/mutant.py & 10 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dnerf/standup --port 6066 --expname "$exp_name1/standup" --configs arguments/$exp_name1/standup.py & 11 | wait 12 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dnerf/hook --port 6069 --expname "$exp_name1/hook" --configs arguments/$exp_name1/hook.py & 13 | export CUDA_VISIBLE_DEVICES=3&&python 
train.py -s data/dnerf/hellwarrior --port 6070 --expname "$exp_name1/hellwarrior" --configs arguments/$exp_name1/hellwarrior.py & 14 | wait 15 | echo "Done" -------------------------------------------------------------------------------- /scripts/train_dycheck.sh: -------------------------------------------------------------------------------- 1 | exp_name1=$1 2 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dycheck/spin --port 6084 --expname $exp_name1/spin/ --configs arguments/$exp_name1/default.py & 3 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dycheck/space-out --port 6083 --expname $exp_name1/space-out/ --configs arguments/$exp_name1/default.py & 4 | wait 5 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name1/space-out/ --configs arguments/$exp_name1/default.py & 6 | export CUDA_VISIBLE_DEVICES=3&&python render.py --model_path output/$exp_name1/spin/ --configs arguments/$exp_name1/default.py 7 | wait 8 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dycheck/teddy/ --port 6081 --expname $exp_name1/teddy/ --configs arguments/$exp_name1/default.py & 9 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dycheck/apple/ --port 6082 --expname $exp_name1/apple/ --configs arguments/$exp_name1/default.py 10 | 11 | wait 12 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name1/teddy/ --skip_train --configs arguments/$exp_name1/default.py & 13 | export CUDA_VISIBLE_DEVICES=3&&python render.py --model_path output/$exp_name1/apple/ --skip_train --configs arguments/$exp_name1/default.py 14 | 15 | 16 | wait 17 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path output/$exp_name1/apple/ & 18 | export CUDA_VISIBLE_DEVICES=3&&python metrics.py --model_path output/$exp_name1/teddy/ & 19 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path output/$exp_name1/space-out/ & 20 | export CUDA_VISIBLE_DEVICES=3&&python metrics.py --model_path output/$exp_name1/spin/ 21 | echo "Done" -------------------------------------------------------------------------------- /scripts/train_dynamic3dgs.sh: -------------------------------------------------------------------------------- 1 | exp_name1=$1 2 | 3 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/basketball --port 6068 --expname "$exp_name1/dynamic3dgs/basketball" --configs arguments/$exp_name1/default.py 4 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/boxes --port 6069 --expname "$exp_name1/dynamic3dgs/boxes" --configs arguments/$exp_name1/default.py 5 | wait 6 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/football --port 6068 --expname "$exp_name1/dynamic3dgs/football" --configs arguments/$exp_name1/default.py 7 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/juggle --port 6069 --expname "$exp_name1/dynamic3dgs/juggle" --configs arguments/$exp_name1/default.py 8 | wait 9 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/softball --port 6068 --expname "$exp_name1/dynamic3dgs/softball" --configs arguments/$exp_name1/default.py 10 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/dynamic3dgs/data/tennis --port 6069 --expname "$exp_name1/dynamic3dgs/tennis" --configs arguments/$exp_name1/default.py 11 | 12 | 13 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name1/dynamic3dgs/basketball --configs arguments/$exp_name1/default.py --skip_train 14 | export CUDA_VISIBLE_DEVICES=0&&python render.py 
--model_path output/$exp_name1/dynamic3dgs/boxes --configs arguments/$exp_name1/default.py --skip_train 15 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name1/dynamic3dgs/football --configs arguments/$exp_name1/default.py --skip_train 16 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name1/dynamic3dgs/juggle --configs arguments/$exp_name1/default.py --skip_train 17 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name1/dynamic3dgs/softball --configs arguments/$exp_name1/default.py --skip_train 18 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name1/dynamic3dgs/tennis --configs arguments/$exp_name1/default.py --skip_train 19 | 20 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/basketball 21 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/boxes 22 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/football 23 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/juggle 24 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/softball 25 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path output/$exp_name/dynamic3dgs/tennis 26 | -------------------------------------------------------------------------------- /scripts/train_dynerf.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | export CUDA_VISIBLE_DEVICES=1&&python train.py -s data/dynerf/flame_salmon_1 --port 6468 --expname "$exp_name/flame_salmon_1" --configs arguments/$exp_name/flame_salmon_1.py & 3 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dynerf/coffee_martini --port 6472 --expname "$exp_name/coffee_martini" --configs arguments/$exp_name/coffee_martini.py & 4 | wait 5 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dynerf/cook_spinach --port 6436 --expname "$exp_name/cook_spinach" --configs arguments/$exp_name/cook_spinach.py & 6 | # wait 7 | export CUDA_VISIBLE_DEVICES=3&&python train.py -s data/dynerf/cut_roasted_beef --port 6470 --expname "$exp_name/cut_roasted_beef" --configs arguments/$exp_name/cut_roasted_beef.py 8 | wait 9 | export CUDA_VISIBLE_DEVICES=1&&python train.py -s data/dynerf/flame_steak --port 6471 --expname "$exp_name/flame_steak" --configs arguments/$exp_name/flame_steak.py & 10 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/dynerf/sear_steak --port 6569 --expname "$exp_name/sear_steak" --configs arguments/$exp_name/sear_steak.py 11 | wait 12 | 13 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/cut_roasted_beef --configs arguments/$exp_name/cut_roasted_beef.py --skip_train & 14 | export CUDA_VISIBLE_DEVICES=3&&python render.py --model_path output/$exp_name/sear_steak --configs arguments/$exp_name/sear_steak.py --skip_train 15 | wait 16 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/flame_steak --configs arguments/$exp_name/flame_steak.py --skip_train & 17 | export CUDA_VISIBLE_DEVICES=3&&python render.py --model_path output/$exp_name/flame_salmon_1 --configs arguments/$exp_name/flame_salmon_1.py --skip_train 18 | wait 19 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/cook_spinach --configs arguments/$exp_name/cook_spinach.py --skip_train & 20 | export CUDA_VISIBLE_DEVICES=3&&python render.py --model_path 
output/$exp_name/coffee_martini --configs arguments/$exp_name/coffee_martini.py --skip_train & 21 | wait 22 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/cut_roasted_beef/" & 23 | export CUDA_VISIBLE_DEVICES=3&&python metrics.py --model_path "output/$exp_name/cook_spinach/" 24 | wait 25 | export CUDA_VISIBLE_DEVICES=3&&python metrics.py --model_path "output/$exp_name/sear_steak/" & 26 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/flame_salmon_1/" 27 | wait 28 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/flame_steak/" & 29 | export CUDA_VISIBLE_DEVICES=3&&python metrics.py --model_path "output/$exp_name/coffee_martini/" 30 | echo "Done" -------------------------------------------------------------------------------- /scripts/train_hyper_interp.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/hypernerf/interp/aleks-teapot --port 6568 --expname "$exp_name/interp/aleks-teapot" --configs arguments/$exp_name/default.py & 3 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/hypernerf/interp/slice-banana --port 6566 --expname "$exp_name/interp/slice-banana" --configs arguments/$exp_name/default.py & 4 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/hypernerf/interp/chickchicken --port 6569 --expname "$exp_name/interp/interp-chicken" --configs arguments/$exp_name/default.py & 5 | 6 | wait 7 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/hypernerf/interp/cut-lemon1 --port 6670 --expname $exp_name/interp/cut-lemon1 --configs arguments/$exp_name/default.py & 8 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/hypernerf/interp/hand1-dense-v2 --port 6671 --expname $exp_name/interp/hand1-dense-v2 --configs arguments/$exp_name/default.py & 9 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/hypernerf/interp/torchocolate --port 6672 --expname $exp_name/interp/torchocolate --configs arguments/$exp_name/default.py & 10 | wait 11 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/interp/aleks-teapot --configs arguments/$exp_name/default.py --skip_train & 12 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/interp/slice-banana --configs arguments/$exp_name/default.py --skip_train & 13 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/interp/interp-chicken --configs arguments/$exp_name/default.py --skip_train & 14 | wait 15 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/interp/cut-lemon1 --configs arguments/$exp_name/default.py --skip_train & 16 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/interp/hand1-dense-v2 --configs arguments/$exp_name/default.py --skip_train& 17 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/interp/torchocolate --configs arguments/$exp_name/default.py --skip_train & 18 | 19 | wait 20 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/interp/aleks-teapot/" & 21 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/interp/slice-banana/" & 22 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/interp/interp-chicken/" 23 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/interp/cut-lemon1/" & 24 | export CUDA_VISIBLE_DEVICES=2&&python 
metrics.py --model_path "output/$exp_name/interp/hand1-dense-v2/" & 25 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/interp/torchocolate/" 26 | wait 27 | echo "Done" -------------------------------------------------------------------------------- /scripts/train_hyper_virg.sh: -------------------------------------------------------------------------------- 1 | exp_name=$1 2 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/hypernerf/virg/broom2 --port 6068 --expname "$exp_name/broom2" --configs arguments/$exp_name/broom2.py & 3 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/hypernerf/virg/vrig-3dprinter --port 6066 --expname "$exp_name/3dprinter" --configs arguments/$exp_name/3dprinter.py & 4 | export CUDA_VISIBLE_DEVICES=2&&python train.py -s data/hypernerf/virg/peel-banana --port 6069 --expname "$exp_name/peel-banana" --configs arguments/$exp_name/banana.py & 5 | export CUDA_VISIBLE_DEVICES=0&&python train.py -s data/hypernerf/virg/vrig-chicken --port 6070 --expname "$exp_name/vrig-chicken" --configs arguments/$exp_name/chicken.py 6 | wait 7 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/broom2 --configs arguments/$exp_name/broom2.py --skip_train --skip_test & 8 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/3dprinter --configs arguments/$exp_name/3dprinter.py --skip_train --skip_test & 9 | export CUDA_VISIBLE_DEVICES=2&&python render.py --model_path output/$exp_name/peel-banana --configs arguments/$exp_name/banana.py --skip_train --skip_test & 10 | export CUDA_VISIBLE_DEVICES=0&&python render.py --model_path output/$exp_name/vrig-chicken --configs arguments/$exp_name/chicken.py --skip_train --skip_test & 11 | wait 12 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/broom2/" & 13 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/3dprinter/" & 14 | export CUDA_VISIBLE_DEVICES=2&&python metrics.py --model_path "output/$exp_name/peel-banana/" & 15 | export CUDA_VISIBLE_DEVICES=0&&python metrics.py --model_path "output/$exp_name/vrig-chicken/" & 16 | wait 17 | echo "Done" -------------------------------------------------------------------------------- /scripts/train_test_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import shutil 5 | from tqdm import tqdm 6 | def resort(frames): 7 | newframes = {} 8 | min_frameid = 10000000 9 | for frame in frames: 10 | frameid = int(frame["file_path"].split('/')[1].split('.')[0]) 11 | # print() 12 | if frameid < min_frameid:min_frameid = frameid 13 | newframes[frameid] = frame 14 | return [newframes[i+min_frameid] for i in range(len(frames))] 15 | inputpath = "data/custom/wave-ns/" 16 | outputpath = "data/custom/wave-train/" 17 | testskip = 10 18 | if not os.path.exists(outputpath): 19 | os.makedirs(outputpath) 20 | image_path = os.listdir(os.path.join(inputpath,"images")) 21 | import json 22 | with open(os.path.join(inputpath,"transforms.json"),"r") as f: 23 | 24 | meta = json.load(f) 25 | 26 | cnt = 0 27 | train_json = { 28 | "w": meta["w"], 29 | "h": meta["h"], 30 | "fl_x": meta["fl_x"], 31 | "fl_y": meta["fl_y"], 32 | "cx": meta["cx"], 33 | "cy": meta["cy"], 34 | 35 | "camera_model" : meta["camera_model"], 36 | "frames":[] 37 | } 38 | test_json = { 39 | "w": meta["w"], 40 | "h": meta["h"], 41 | "fl_x": meta["fl_x"], 42 | "fl_y": meta["fl_y"], 43 | "cx": meta["cx"], 44 | 
"cy": meta["cy"], 45 | "camera_model" : meta["camera_model"], 46 | "frames":[] 47 | } 48 | train_image_path = os.path.join(outputpath,"train") 49 | os.makedirs(train_image_path) 50 | test_image_path = os.path.join(outputpath,"test") 51 | os.makedirs(test_image_path) 52 | # meta["frames"] = resort(meta["frames"]) 53 | totallen = len(meta["frames"]) 54 | for index, frame in tqdm(enumerate(meta["frames"])): 55 | image_path = os.path.join(inputpath,frame["file_path"]) 56 | 57 | frame["time"] = index/totallen 58 | if index % testskip == 0: 59 | frame["file_path"] = "test/" + frame["file_path"].split("/")[-1] 60 | test_json["frames"].append(frame) 61 | shutil.copy(image_path, test_image_path) 62 | else: 63 | frame["file_path"] = "train/" + frame["file_path"].split("/")[-1] 64 | train_json["frames"].append(frame) 65 | shutil.copy(image_path, train_image_path) 66 | with open(os.path.join(outputpath,"transforms_train.json"),"w") as f: 67 | json.dump(train_json, f) 68 | with open(os.path.join(outputpath,"transforms_test.json"),"w") as f: 69 | json.dump(test_json, f) 70 | print("done") -------------------------------------------------------------------------------- /utils/TIMES.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/utils/TIMES.TTF -------------------------------------------------------------------------------- /utils/TIMESBD.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/utils/TIMESBD.TTF -------------------------------------------------------------------------------- /utils/TIMESBI.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/utils/TIMESBI.TTF -------------------------------------------------------------------------------- /utils/TIMESI.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/4DGaussians/843d5ac636c37e4b611242287754f3d4ed150144/utils/TIMESI.TTF -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | 21 | 22 | # resized_image_rgb = PILtoTorch(cam_info.image, resolution) 23 | 24 | # gt_image = resized_image_rgb[:3, ...] 25 | # loaded_mask = None 26 | 27 | # if resized_image_rgb.shape[1] == 4: 28 | # loaded_mask = resized_image_rgb[3:4, ...] 
29 | 30 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 31 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 32 | image=cam_info.image, gt_alpha_mask=None, 33 | image_name=cam_info.image_name, uid=id, data_device=args.data_device, 34 | time = cam_info.time, 35 | ) 36 | 37 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 38 | camera_list = [] 39 | 40 | for id, c in enumerate(cam_infos): 41 | camera_list.append(loadCam(args, id, c, resolution_scale)) 42 | 43 | return camera_list 44 | 45 | def camera_to_JSON(id, camera : Camera): 46 | Rt = np.zeros((4, 4)) 47 | Rt[:3, :3] = camera.R.transpose() 48 | Rt[:3, 3] = camera.T 49 | Rt[3, 3] = 1.0 50 | 51 | W2C = np.linalg.inv(Rt) 52 | pos = W2C[:3, 3] 53 | rot = W2C[:3, :3] 54 | serializable_array_2d = [x.tolist() for x in rot] 55 | camera_entry = { 56 | 'id' : id, 57 | 'img_name' : camera.image_name, 58 | 'width' : camera.width, 59 | 'height' : camera.height, 60 | 'position': pos.tolist(), 61 | 'rotation': serializable_array_2d, 62 | 'fy' : fov2focal(camera.FovY, camera.height), 63 | 'fx' : fov2focal(camera.FovX, camera.width) 64 | } 65 | return camera_entry 66 | -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | if resolution is not None: 23 | resized_image_PIL = pil_image.resize(resolution) 24 | else: 25 | resized_image_PIL = pil_image 26 | if np.array(resized_image_PIL).max()!=1: 27 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 28 | else: 29 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) 30 | if len(resized_image.shape) == 3: 31 | return resized_image.permute(2, 0, 1) 32 | else: 33 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 34 | 35 | def get_expon_lr_func( 36 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 37 | ): 38 | """ 39 | Copied from Plenoxels 40 | 41 | Continuous learning rate decay function. Adapted from JaxNeRF 42 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 43 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 44 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 45 | function of lr_delay_mult, such that the initial learning rate is 46 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 47 | to the normal learning rate when steps>lr_delay_steps. 48 | :param conf: config subtree 'lr' or similar 49 | :param max_steps: int, the number of steps during optimization. 50 | :return HoF which takes step as input 51 | """ 52 | 53 | def helper(step): 54 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 55 | # Disable this parameter 56 | return 0.0 57 | if lr_delay_steps > 0: 58 | # A kind of reverse cosine decay. 
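        # Editor's note (added comment): for 0 < lr_delay_mult <= 1 this warm-up factor rises
        # smoothly from lr_delay_mult at step 0 to 1.0 once step >= lr_delay_steps (the sin term
        # sweeps 0 -> 1); the returned rate below is delay_rate times the log-linear
        # interpolation between lr_init and lr_final.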
59 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 60 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 61 | ) 62 | else: 63 | delay_rate = 1.0 64 | t = np.clip(step / max_steps, 0, 1) 65 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 66 | return delay_rate * log_lerp 67 | 68 | return helper 69 | 70 | def strip_lowerdiag(L): 71 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 72 | 73 | uncertainty[:, 0] = L[:, 0, 0] 74 | uncertainty[:, 1] = L[:, 0, 1] 75 | uncertainty[:, 2] = L[:, 0, 2] 76 | uncertainty[:, 3] = L[:, 1, 1] 77 | uncertainty[:, 4] = L[:, 1, 2] 78 | uncertainty[:, 5] = L[:, 2, 2] 79 | return uncertainty 80 | 81 | def strip_symmetric(sym): 82 | return strip_lowerdiag(sym) 83 | 84 | def build_rotation(r): 85 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 86 | 87 | q = r / norm[:, None] 88 | 89 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 90 | 91 | r = q[:, 0] 92 | x = q[:, 1] 93 | y = q[:, 2] 94 | z = q[:, 3] 95 | 96 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 97 | R[:, 0, 1] = 2 * (x*y - r*z) 98 | R[:, 0, 2] = 2 * (x*z + r*y) 99 | R[:, 1, 0] = 2 * (x*y + r*z) 100 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 101 | R[:, 1, 2] = 2 * (y*z - r*x) 102 | R[:, 2, 0] = 2 * (x*z - r*y) 103 | R[:, 2, 1] = 2 * (y*z + r*x) 104 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 105 | return R 106 | 107 | def build_scaling_rotation(s, r): 108 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 109 | R = build_rotation(r) 110 | 111 | L[:,0,0] = s[:,0] 112 | L[:,1,1] = s[:,1] 113 | L[:,2,2] = s[:,2] 114 | 115 | L = R @ L 116 | return L 117 | 118 | def safe_state(silent): 119 | old_f = sys.stdout 120 | class F: 121 | def __init__(self, silent): 122 | self.silent = silent 123 | 124 | def write(self, x): 125 | if not self.silent: 126 | if x.endswith("\n"): 127 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 128 | else: 129 | old_f.write(x) 130 | 131 | def flush(self): 132 | old_f.flush() 133 | 134 | sys.stdout = F(silent) 135 | 136 | random.seed(0) 137 | np.random.seed(0) 138 | torch.manual_seed(0) 139 | torch.cuda.set_device(torch.device("cuda:0")) 140 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) 78 | 79 | def apply_rotation(q1, q2): 80 | """ 81 | Applies a rotation to a quaternion. 82 | 83 | Parameters: 84 | q1 (Tensor): The original quaternion. 85 | q2 (Tensor): The rotation quaternion to be applied. 86 | 87 | Returns: 88 | Tensor: The resulting quaternion after applying the rotation. 89 | """ 90 | # Extract components for readability 91 | w1, x1, y1, z1 = q1 92 | w2, x2, y2, z2 = q2 93 | 94 | # Compute the product of the two quaternions 95 | w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 96 | x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 97 | y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 98 | z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 99 | 100 | # Combine the components into a new quaternion tensor 101 | q3 = torch.tensor([w3, x3, y3, z3]) 102 | 103 | # Normalize the resulting quaternion 104 | q3_normalized = q3 / torch.norm(q3) 105 | 106 | return q3_normalized 107 | 108 | 109 | def batch_quaternion_multiply(q1, q2): 110 | """ 111 | Multiply batches of quaternions. 112 | 113 | Args: 114 | - q1 (torch.Tensor): A tensor of shape [N, 4] representing the first batch of quaternions. 115 | - q2 (torch.Tensor): A tensor of shape [N, 4] representing the second batch of quaternions. 116 | 117 | Returns: 118 | - torch.Tensor: The resulting batch of quaternions after applying the rotation. 
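    Example (editor's illustrative sketch, not part of the original docstring):
        >>> q_identity = torch.tensor([[1.0, 0.0, 0.0, 0.0]])   # (w, x, y, z)
        >>> q_rot = torch.tensor([[0.7071, 0.7071, 0.0, 0.0]])  # ~90 deg about x
        >>> batch_quaternion_multiply(q_identity, q_rot)        # ~= q_rot, renormalized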
119 | """ 120 | # Calculate the product of each quaternion in the batch 121 | w = q1[:, 0] * q2[:, 0] - q1[:, 1] * q2[:, 1] - q1[:, 2] * q2[:, 2] - q1[:, 3] * q2[:, 3] 122 | x = q1[:, 0] * q2[:, 1] + q1[:, 1] * q2[:, 0] + q1[:, 2] * q2[:, 3] - q1[:, 3] * q2[:, 2] 123 | y = q1[:, 0] * q2[:, 2] - q1[:, 1] * q2[:, 3] + q1[:, 2] * q2[:, 0] + q1[:, 3] * q2[:, 1] 124 | z = q1[:, 0] * q2[:, 3] + q1[:, 1] * q2[:, 2] - q1[:, 2] * q2[:, 1] + q1[:, 3] * q2[:, 0] 125 | 126 | # Combine into new quaternions 127 | q3 = torch.stack((w, x, y, z), dim=1) 128 | 129 | # Normalize the quaternions 130 | norm_q3 = q3 / torch.norm(q3, dim=1, keepdim=True) 131 | 132 | return norm_q3 133 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | @torch.no_grad() 17 | def psnr(img1, img2, mask=None): 18 | if mask is not None: 19 | img1 = img1.flatten(1) 20 | img2 = img2.flatten(1) 21 | 22 | mask = mask.flatten(1).repeat(3,1) 23 | mask = torch.where(mask!=0,True,False) 24 | img1 = img1[mask] 25 | img2 = img2[mask] 26 | 27 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 28 | 29 | else: 30 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 31 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float())) 32 | if mask is not None: 33 | if torch.isinf(psnr).any(): 34 | print(mse.mean(),psnr.mean()) 35 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse.float())) 36 | psnr = psnr[~torch.isinf(psnr)] 37 | 38 | return psnr 39 | -------------------------------------------------------------------------------- /utils/loader_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import random 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.utils.data.sampler import Sampler 11 | from torchvision import transforms, utils 12 | import random 13 | def get_stamp_list(dataset, timestamp): 14 | frame_length = int(len(dataset)/len(dataset.dataset.poses)) 15 | # print(frame_length) 16 | if timestamp > frame_length: 17 | raise IndexError("input timestamp bigger than total timestamp.") 18 | print("select index:",[i*frame_length+timestamp for i in range(len(dataset.dataset.poses))]) 19 | return [dataset[i*frame_length+timestamp] for i in range(len(dataset.dataset.poses))] 20 | class FineSampler(Sampler): 21 | def __init__(self, dataset): 22 | self.len_dataset = len(dataset) 23 | self.len_pose = len(dataset.dataset.poses) 24 | self.frame_length = int(self.len_dataset/ self.len_pose) 25 | 26 | sample_list = [] 27 | for i in range(self.frame_length): 28 | for j in range(4): 29 | idx = torch.randperm(self.len_pose) *self.frame_length + i 30 | # print(idx) 31 | # breakpoint() 32 | now_list = [] 33 | cnt = 0 34 | for item in idx.tolist(): 35 | now_list.append(item) 36 | cnt+=1 37 | if cnt % 2 == 0 and len(sample_list)>2: 38 | select_element = [x for x in 
random.sample(sample_list,2)] 39 | now_list += select_element 40 | 41 | sample_list += now_list 42 | 43 | self.sample_list = sample_list 44 | # print(self.sample_list) 45 | # breakpoint() 46 | print("one epoch containing:",len(self.sample_list)) 47 | def __iter__(self): 48 | 49 | return iter(self.sample_list) 50 | 51 | def __len__(self): 52 | return len(self.sample_list) 53 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | import lpips 17 | def lpips_loss(img1, img2, lpips_model): 18 | loss = lpips_model(img1,img2) 19 | return loss.mean() 20 | def l1_loss(network_output, gt): 21 | return torch.abs((network_output - gt)).mean() 22 | 23 | def l2_loss(network_output, gt): 24 | return ((network_output - gt) ** 2).mean() 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | def create_window(window_size, channel): 31 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 32 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 33 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 34 | return window 35 | 36 | def ssim(img1, img2, window_size=11, size_average=True): 37 | channel = img1.size(-3) 38 | window = create_window(window_size, channel) 39 | 40 | if img1.is_cuda: 41 | window = window.cuda(img1.get_device()) 42 | window = window.type_as(img1) 43 | 44 | return _ssim(img1, img2, window, window_size, channel, size_average) 45 | 46 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 47 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 48 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 49 | 50 | mu1_sq = mu1.pow(2) 51 | mu2_sq = mu2.pow(2) 52 | mu1_mu2 = mu1 * mu2 53 | 54 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 55 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 56 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 57 | 58 | C1 = 0.01 ** 2 59 | C2 = 0.03 ** 2 60 | 61 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 62 | 63 | if size_average: 64 | return ssim_map.mean() 65 | else: 66 | return ssim_map.mean(1).mean(1).mean(1) 67 | 68 | -------------------------------------------------------------------------------- /utils/params_utils.py: -------------------------------------------------------------------------------- 1 | def merge_hparams(args, config): 2 | params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] 3 | for param in params: 4 | if param in config.keys(): 5 | for key, value in config[param].items(): 6 | if hasattr(args, key): 7 | setattr(args, key, value) 8 | 9 | return args 
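
# Editor's note: a minimal usage sketch for merge_hparams, added for illustration only (not part
# of the original file). The config dict shape is an assumption inferred from the group names
# above; in the repository such configs come from the files under arguments/.
if __name__ == "__main__":
    from argparse import Namespace
    demo_args = Namespace(iterations=30000, data_device="cuda")
    demo_config = {"OptimizationParams": {"iterations": 14000, "unknown_key": 1}}
    demo_args = merge_hparams(demo_args, demo_config)
    print(demo_args.iterations)  # -> 14000; unknown_key is skipped because args has no such attribute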
-------------------------------------------------------------------------------- /utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open3d as o3d 3 | 4 | from torch.utils.data import TensorDataset, random_split 5 | from tqdm import tqdm 6 | import open3d as o3d 7 | import numpy as np 8 | from torch_cluster import grid_cluster 9 | def voxel_down_sample_custom(points, voxel_size): 10 | voxel_grid = torch.floor(points / voxel_size) 11 | unique_voxels, inverse_indices = torch.unique(voxel_grid, dim=0, return_inverse=True) 12 | new_points = torch.zeros_like(unique_voxels) 13 | new_points_count = torch.zeros(unique_voxels.size(0), dtype=torch.long) 14 | new_points[inverse_indices] = points 15 | 16 | 17 | return new_points, inverse_indices 18 | def downsample_point_cloud(points, ratio): 19 | dataset = TensorDataset(points) 20 | num_points = len(dataset) 21 | num_downsampled_points = int(num_points * ratio) 22 | downsampled_dataset, _ = random_split(dataset, [num_downsampled_points, num_points - num_downsampled_points]) 23 | indices = torch.tensor([i for i, _ in enumerate(downsampled_dataset)]) 24 | downsampled_points = torch.stack([x for x, in downsampled_dataset]) 25 | return indices, downsampled_points 26 | 27 | def downsample_point_cloud_open3d(points, voxel_size): 28 | downsampled_pcd, inverse_indices = voxel_down_sample_custom(points, voxel_size) 29 | downsampled_points = downsampled_pcd 30 | return torch.tensor(downsampled_points) 31 | def downsample_point_cloud_cluster(points, voxel_size): 32 | cluster = grid_cluster(points, size=torch.tensor([1,1,1])) 33 | return cluster, points 34 | import torch 35 | from sklearn.neighbors import NearestNeighbors 36 | 37 | def upsample_point_cloud(points, density_threshold, displacement_scale, iter_pass): 38 | try: 39 | nbrs = NearestNeighbors(n_neighbors=2+iter_pass, algorithm='ball_tree').fit(points) 40 | distances, indices = nbrs.kneighbors(points) 41 | except: 42 | print("no point added") 43 | return points, torch.tensor([]), torch.tensor([]), torch.zeros((points.shape[0]), dtype=torch.bool) 44 | 45 | low_density_points = points[distances[:,1] > density_threshold] 46 | low_density_index = distances[:,1] > density_threshold 47 | low_density_index = torch.from_numpy(low_density_index) 48 | num_points = low_density_points.shape[0] 49 | displacements = torch.randn(num_points, 3) * displacement_scale 50 | new_points = low_density_points + displacements 51 | return points, low_density_points, new_points, low_density_index 52 | 53 | 54 | def visualize_point_cloud(points, low_density_points, new_points): 55 | pcd = o3d.geometry.PointCloud() 56 | low_density_points += 0.01 57 | all_points = np.concatenate([points, low_density_points, new_points], axis=0) 58 | pcd.points = o3d.utility.Vector3dVector(all_points) 59 | colors = np.zeros((all_points.shape[0], 3)) 60 | colors[:points.shape[0]] = [0, 0, 0] 61 | colors[points.shape[0]:points.shape[0]+low_density_points.shape[0]] = [1, 0, 0] 62 | colors[points.shape[0]+low_density_points.shape[0]:] = [0, 1, 0] 63 | pcd.colors = o3d.utility.Vector3dVector(colors) 64 | o3d.visualization.draw_geometries([pcd]) 65 | def combine_pointcloud(points, low_density_points, new_points): 66 | pcd = o3d.geometry.PointCloud() 67 | low_density_points += 0.01 68 | new_points -= 0.01 69 | all_points = np.concatenate([points, low_density_points, new_points], axis=0) 70 | pcd.points = o3d.utility.Vector3dVector(all_points) 71 | colors = 
np.zeros((all_points.shape[0], 3)) 72 | colors[:points.shape[0]] = [0, 0, 0] 73 | colors[points.shape[0]:points.shape[0]+low_density_points.shape[0]] = [1, 0, 0] 74 | colors[points.shape[0]+low_density_points.shape[0]:] = [0, 1, 0] 75 | pcd.colors = o3d.utility.Vector3dVector(colors) 76 | return pcd 77 | def addpoint(point_cloud,density_threshold,displacement_scale, iter_pass,): 78 | points, low_density_points, new_points, low_density_index = upsample_point_cloud(point_cloud,density_threshold,displacement_scale, iter_pass) 79 | print("low_density_points",low_density_points.shape[0]) 80 | return point_cloud, low_density_points, new_points, low_density_index 81 | def find_point_indices(origin_point, goal_point): 82 | indices = torch.nonzero((origin_point[:, None] == goal_point).all(-1), as_tuple=True)[0] 83 | return indices 84 | def find_indices_in_A(A, B): 85 | is_equal = torch.eq(B.view(1, -1, 3), A.view(-1, 1, 3)) 86 | u_indices = torch.nonzero(is_equal, as_tuple=False)[:, 0] 87 | return torch.unique(u_indices) 88 | if __name__ =="__main__": 89 | from time import time 90 | pass_=0 91 | filename = "point_cloud.ply" 92 | pcd = o3d.io.read_point_cloud(filename) 93 | point_cloud = torch.tensor(pcd.points) 94 | voxel_size = 8 95 | density_threshold=20 96 | displacement_scale=5 97 | for i in range(pass_+1, 50): 98 | print("pass ",i) 99 | time0 = time() 100 | 101 | point_downsample = point_cloud 102 | flag = False 103 | while point_downsample.shape[0]>1000: 104 | if flag: 105 | voxel_size+=8 106 | print("point size:",point_downsample.shape[0]) 107 | point_downsample = downsample_point_cloud_open3d(point_cloud,voxel_size=voxel_size) 108 | flag = True 109 | 110 | print("point size:",point_downsample.shape[0]) 111 | downsampled_point_index = find_indices_in_A(point_cloud, point_downsample) 112 | print("selected_num",point_cloud[downsampled_point_index].shape[0]) 113 | _, low_density_points, new_points, low_density_index = addpoint(point_cloud[downsampled_point_index],density_threshold=density_threshold,displacement_scale=displacement_scale,iter_pass=0) 114 | if new_points.shape[0] < 100: 115 | density_threshold /= 2 116 | displacement_scale /= 2 117 | print("reduce displacement_scale to: ",displacement_scale) 118 | 119 | global_mask = torch.zeros((point_cloud.shape[0]), dtype=torch.bool) 120 | 121 | global_mask[downsampled_point_index] = low_density_index 122 | time1 = time() 123 | 124 | print("time cost:",time1-time0,"new_points:",new_points.shape[0]) 125 | if low_density_points.shape[0] == 0: 126 | print("no more points.") 127 | continue 128 | point = combine_pointcloud(point_cloud, low_density_points, new_points) 129 | point_cloud = torch.tensor(point.points) 130 | o3d.io.write_point_cloud(f"pointcloud/pass_{i}.ply",point) 131 | -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation as R 3 | from scene.utils import Camera 4 | from copy import deepcopy 5 | def rotation_matrix_to_quaternion(rotation_matrix): 6 | return R.from_matrix(rotation_matrix).as_quat() 7 | 8 | def quaternion_to_rotation_matrix(quat): 9 | return R.from_quat(quat).as_matrix() 10 | 11 | def quaternion_slerp(q1, q2, t): 12 | # Compute the dot product between the two quaternions 13 | dot = np.dot(q1, q2) 14 | 15 | # If the dot product is negative, negate one quaternion to ensure shortest-path interpolation 16 | if dot < 0.0: 17 | q1 = -q1 18 | dot = -dot 19 | 20 | # Guard against numerical errors 21 | dot = np.clip(dot, -1.0, 1.0) 22 | 23 | # Compute the interpolation parameter 24 | 
theta = np.arccos(dot) * t 25 | q3 = q2 - q1 * dot 26 | q3 = q3 / np.linalg.norm(q3) 27 | 28 | # Compute the interpolated result 29 | return np.cos(theta) * q1 + np.sin(theta) * q3 30 | 31 | def bezier_interpolation(p1, p2, t): 32 | return (1 - t) * p1 + t * p2 33 | def linear_interpolation(v1, v2, t): 34 | return (1 - t) * v1 + t * v2 35 | def smooth_camera_poses(cameras, num_interpolations=5): 36 | smoothed_cameras = [] 37 | smoothed_times = [] 38 | total_poses = len(cameras) - 1 + (len(cameras) - 1) * num_interpolations 39 | time_increment = 10 / total_poses 40 | 41 | for i in range(len(cameras) - 1): 42 | cam1 = cameras[i] 43 | cam2 = cameras[i + 1] 44 | 45 | quat1 = rotation_matrix_to_quaternion(cam1.orientation) 46 | quat2 = rotation_matrix_to_quaternion(cam2.orientation) 47 | 48 | for j in range(num_interpolations + 1): 49 | t = j / (num_interpolations + 1) 50 | 51 | interp_orientation_quat = quaternion_slerp(quat1, quat2, t) 52 | interp_orientation_matrix = quaternion_to_rotation_matrix(interp_orientation_quat) 53 | 54 | interp_position = linear_interpolation(cam1.position, cam2.position, t) 55 | 56 | interp_time = i*10 / (len(cameras) - 1) + time_increment * j 57 | 58 | newcam = deepcopy(cam1) 59 | newcam.orientation = interp_orientation_matrix 60 | newcam.position = interp_position 61 | smoothed_cameras.append(newcam) 62 | smoothed_times.append(interp_time) 63 | smoothed_cameras.append(cameras[-1]) 64 | smoothed_times.append(1.0) 65 | print(smoothed_times) 66 | return smoothed_cameras, smoothed_times 67 | 68 | -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | @torch.no_grad() 3 | def get_state_at_time(pc,viewpoint_camera): 4 | means3D = pc.get_xyz 5 | time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1) 6 | opacity = pc._opacity 7 | shs = pc.get_features 8 | 9 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 10 | # scaling / rotation by the rasterizer. 
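    # Editor's note (added comment): the raw (pre-activation) scaling, rotation and opacity are fed
    # to the deformation network below together with the per-Gaussian timestamp; note that the
    # return statement hands back the undeformed `opacity` rather than `opacity_final`.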
11 | scales = pc._scaling 12 | rotations = pc._rotation 13 | cov3D_precomp = None 14 | means3D_final, scales_final, rotations_final, opacity_final, shs_final = pc._deformation(means3D, scales, 15 | rotations, opacity, shs, 16 | time) 17 | 18 | return means3D_final, scales_final, rotations_final, opacity, shs_final -------------------------------------------------------------------------------- /utils/scene_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image, ImageDraw, ImageFont 4 | from matplotlib import pyplot as plt 5 | plt.rcParams['font.sans-serif'] = ['Times New Roman'] 6 | 7 | import numpy as np 8 | 9 | import copy 10 | @torch.no_grad() 11 | def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now, dataset_type): 12 | def render(gaussians, viewpoint, path, scaling, cam_type): 13 | # scaling_copy = gaussians._scaling 14 | render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage, cam_type=cam_type) 15 | label1 = f"stage:{stage},iter:{iteration}" 16 | times = time_now/60 17 | if times < 1: 18 | end = "min" 19 | else: 20 | end = "mins" 21 | label2 = "time:%.2f" % times + end 22 | image = render_pkg["render"] 23 | depth = render_pkg["depth"] 24 | if dataset_type == "PanopticSports": 25 | gt_np = viewpoint['image'].permute(1,2,0).cpu().numpy() 26 | else: 27 | gt_np = viewpoint.original_image.permute(1,2,0).cpu().numpy() 28 | image_np = image.permute(1, 2, 0).cpu().numpy() # (H, W, 3) 29 | depth_np = depth.permute(1, 2, 0).cpu().numpy() 30 | depth_np /= depth_np.max() 31 | depth_np = np.repeat(depth_np, 3, axis=2) 32 | image_np = np.concatenate((gt_np, image_np, depth_np), axis=1) 33 | image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) 34 | draw1 = ImageDraw.Draw(image_with_labels) 35 | font = ImageFont.truetype('./utils/TIMES.TTF', size=40) 36 | text_color = (255, 0, 0) 37 | label1_position = (10, 10) 38 | label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) 39 | draw1.text(label1_position, label1, fill=text_color, font=font) 40 | draw1.text(label2_position, label2, fill=text_color, font=font) 41 | 42 | image_with_labels.save(path) 43 | render_base_path = os.path.join(scene.model_path, f"{stage}_render") 44 | point_cloud_path = os.path.join(render_base_path,"pointclouds") 45 | image_path = os.path.join(render_base_path,"images") 46 | if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): 47 | os.makedirs(render_base_path) 48 | if not os.path.exists(point_cloud_path): 49 | os.makedirs(point_cloud_path) 50 | if not os.path.exists(image_path): 51 | os.makedirs(image_path) 52 | 53 | for idx in range(len(viewpoints)): 54 | image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg") 55 | render(gaussians,viewpoints[idx],image_save_path,scaling = 1,cam_type=dataset_type) 56 | pc_mask = gaussians.get_opacity 57 | pc_mask = pc_mask > 0.1 58 | 59 | def visualize_and_save_point_cloud(point_cloud, R, T, filename): 60 | fig = plt.figure() 61 | ax = fig.add_subplot(111, projection='3d') 62 | R = R.T 63 | T = -R.dot(T) 64 | transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) 65 | ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o') 66 | ax.axis("off") 67 | plt.savefig(filename) 68 | 69 | -------------------------------------------------------------------------------- 
/utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. 
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | class Timer: 3 | def __init__(self): 4 | self.start_time = None 5 | self.elapsed = 0 6 | self.paused = False 7 | 8 | def start(self): 9 | if self.start_time is None: 10 | self.start_time = time.time() 11 | elif self.paused: 12 | self.start_time = time.time() - self.elapsed 13 | self.paused = False 14 | 15 | def pause(self): 16 | if not self.paused: 17 | self.elapsed = time.time() - self.start_time 18 | self.paused = True 19 | 20 | def get_elapsed_time(self): 21 | if self.paused: 22 | return self.elapsed 23 | else: 24 | return time.time() - self.start_time --------------------------------------------------------------------------------
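
# Editor's note: a minimal usage sketch for utils/timer.py, added for illustration only.
# It assumes the repository root is on PYTHONPATH so that `utils.timer` is importable.
import time as _time
from utils.timer import Timer

timer = Timer()
timer.start()                                   # begins timing
_time.sleep(0.1)                                # stand-in for a training step
timer.pause()                                   # freezes the elapsed counter
print(f"elapsed: {timer.get_elapsed_time():.2f}s")
timer.start()                                   # resumes from the paused elapsed time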