├── .github
│   └── workflows
│       └── stale.yml
├── .gitignore
├── .gitmodules
├── LICENSE.md
├── README.md
├── assets
│   ├── analysis.png
│   └── demo.gif
├── blender
│   ├── .gitignore
│   ├── README.md
│   ├── cam_pose_utils
│   │   ├── cam_reader.py
│   │   ├── colmap_loader.py
│   │   └── graphic_utils.py
│   ├── generate_video.py
│   ├── render.py
│   ├── render_cfgs
│   │   ├── dtu
│   │   │   ├── background.json
│   │   │   ├── background.py
│   │   │   ├── light.json
│   │   │   └── light.py
│   │   ├── gauu
│   │   │   ├── background.json
│   │   │   └── light.json
│   │   ├── matrixcity_aerial
│   │   │   ├── background.json
│   │   │   ├── background.py
│   │   │   ├── light.json
│   │   │   └── light.py
│   │   ├── matrixcity_street
│   │   │   ├── background.json
│   │   │   └── light.json
│   │   ├── mip
│   │   │   ├── background.json
│   │   │   ├── background.py
│   │   │   ├── light.json
│   │   │   └── light.py
│   │   └── pixsfm
│   │       ├── background.json
│   │       ├── background.py
│   │       ├── light.json
│   │       └── light.py
│   ├── render_sun.py
│   └── render_utils
│       ├── background_generator.py
│       └── texture_allocator.py
├── configs
│   ├── .gitignore
│   ├── appearance.yaml
│   ├── appearance_embedding_renderer
│   │   ├── distributed-view_independent-128_dims-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth-scale_reg-mip.yaml
│   │   ├── sh_view_dependent-128_dims-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth.yaml
│   │   ├── sh_view_dependent-128_dims-lr_0.005-with_scheduler.yaml
│   │   ├── sh_view_dependent-accel.yaml
│   │   ├── sh_view_dependent-estimated_depth_reg-hard_depth.yaml
│   │   ├── sh_view_dependent-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth.yaml
│   │   ├── sh_view_dependent-lr_0.005-with_scheduler.yaml
│   │   ├── sh_view_dependent-mip.yaml
│   │   ├── sh_view_dependent-normalized-accel.yaml
│   │   ├── sh_view_dependent-normalized-sim_reg-accel.yaml
│   │   ├── sh_view_dependent.yaml
│   │   ├── snippets
│   │   │   ├── .gitignore
│   │   │   └── tcnn.yaml
│   │   ├── view_dependent-distributed.yaml
│   │   ├── view_dependent-estimated_depth_reg-hard_depth.yaml
│   │   ├── view_dependent-estimated_depth_reg.yaml
│   │   ├── view_dependent.yaml
│   │   ├── view_independent-2dgs.yaml
│   │   ├── view_independent-distributed.yaml
│   │   ├── view_independent-estimated_depth_reg-hard_depth.yaml
│   │   ├── view_independent-estimated_depth_reg.yaml
│   │   ├── view_independent-phototourism-sls-opacity_reg_0.01.yaml
│   │   ├── view_independent-phototourism-sls_high_occlusion-opacity_reg_0.01.yaml
│   │   └── view_independent.yaml
│   ├── appearance_embedding_visibility_map_renderer
│   │   ├── view_independent-2x_ds-exp.yaml
│   │   └── view_independent-2x_ds.yaml
│   ├── background_sphere-cameras_center-random_background.yaml
│   ├── background_sphere.yaml
│   ├── blender.yaml
│   ├── blender_gsplat.yaml
│   ├── citygs_lfls_coarse_sh2.yaml
│   ├── citygs_lfls_sh2_trim.yaml
│   ├── citygsv2_lfls_coarse_sh2.yaml
│   ├── citygsv2_lfls_sh2_trim.yaml
│   ├── citygsv2_mc_aerial_coarse_sh2.yaml
│   ├── citygsv2_mc_aerial_sh2_trim.yaml
│   ├── citygsv2_mc_street_coarse_sh2.yaml
│   ├── citygsv2_mc_street_sh2.yaml
│   ├── citygsv2_smbu_coarse_sh2.yaml
│   ├── citygsv2_smbu_sh2_trim.yaml
│   ├── citygsv2_upper_coarse_sh2.yaml
│   ├── citygsv2_upper_sh2_trim.yaml
│   ├── colmap_exp.yaml
│   ├── ddp.yaml
│   ├── ddp_not_find_unused.yaml
│   ├── deformable_6dof_blender.yaml
│   ├── deformable_6dof_real.yaml
│   ├── deformable_blender.yaml
│   ├── deformable_blender_rotate_xyz.yaml
│   ├── deformable_real.yaml
│   ├── depth_regularization
│   │   ├── estimated_inverse_depth-alternating-l1-accel.yaml
│   │   ├── estimated_inverse_depth-hard_depth-l1.yaml
│   │   ├── estimated_inverse_depth-hard_depth-l1_ssim.yaml
│   │   ├── estimated_inverse_depth-l1.yaml
│   │   ├── estimated_inverse_depth-l1_ssim.yaml
│   │   ├── estimated_inverse_depth-l2.yaml
│   │   └── estimated_inverse_depth-normalized-l1.yaml
│   ├── distributed-accel.yaml
│   ├── distributed.yaml
│   ├── feature_3dgs
│   │   ├── lseg-speedup.yaml
│   │   ├── lseg.yaml
│   │   ├── sam-speedup.yaml
│   │   └── sam.yaml
│   ├── fused_ssim.yaml
│   ├── gsplat-absgrad-experiment.yaml
│   ├── gsplat-absgrad.yaml
│   ├── gsplat-mcmc.yaml
│   ├── gsplat.yaml
│   ├── gsplat_v1-accel-steerable.yaml
│   ├── gsplat_v1-accel.yaml
│   ├── gsplat_v1-accel_more.yaml
│   ├── gsplat_v1-tile_based_culling-selective_adam.yaml
│   ├── gsplat_v1-tile_based_culling.yaml
│   ├── gsplat_v1.yaml
│   ├── image_on_gpu-uint8.yaml
│   ├── image_on_gpu.yaml
│   ├── larger_dataset.yaml
│   ├── light_gaussian
│   │   ├── prune_finetune-gsplat-experiment.yaml
│   │   ├── prune_finetune-gsplat.yaml
│   │   ├── prune_finetune.yaml
│   │   ├── train_densify_prune-gsplat-experiment.yaml
│   │   └── train_densify_prune-gsplat.yaml
│   ├── matrixcity
│   │   ├── README.md
│   │   ├── depth-up_background_sphere.yaml
│   │   ├── depth.yaml
│   │   ├── gsplat-aerial.yaml
│   │   ├── gsplat-aerial_street-depth_reg-example.yaml
│   │   ├── gsplat-aerial_street-example.yaml
│   │   └── hard_depth.yaml
│   ├── mcmc.yaml
│   ├── mip_splatting_gsplat_v2-blender.yaml
│   ├── mip_splatting_gsplat_v2.yaml
│   ├── pvg_dynamic.yaml
│   ├── pypreprocess_gsplat.yaml
│   ├── random_background.yaml
│   ├── reorient.yaml
│   ├── scale_reg.yaml
│   ├── segany_splatting.yaml
│   ├── spot_less_splats
│   │   ├── gsplat-cluster.yaml
│   │   ├── gsplat-mlp-high_occlusion.yaml
│   │   ├── gsplat-mlp-mask_size_400-with_ssim-opacity_reg_0.01.yaml
│   │   ├── gsplat-mlp-mask_size_400-with_ssim.yaml
│   │   ├── gsplat-mlp-mask_size_400.yaml
│   │   ├── gsplat-mlp-opacity_reg_0.01.yaml
│   │   ├── gsplat-mlp-with_ssim.yaml
│   │   ├── gsplat-mlp.yaml
│   │   └── mlp.yaml
│   ├── stp
│   │   └── baseline.yaml
│   ├── swag_baseline.yaml
│   ├── taming_3dgs
│   │   ├── fused_ssim.yaml
│   │   ├── rasterizer-fused_ssim-aa.yaml
│   │   ├── rasterizer-fused_ssim-sparse_adam-aa.yaml
│   │   ├── rasterizer-fused_ssim-sparse_adam.yaml
│   │   ├── rasterizer-fused_ssim.yaml
│   │   └── rasterizer.yaml
│   └── vanilla_2dgs.yaml
├── dataset.py
├── doc
│   ├── data_preparation.md
│   ├── installation.md
│   ├── render_video.md
│   └── run&eval.md
├── internal
│   ├── __init__.py
│   ├── callbacks.py
│   ├── cameras
│   │   ├── __init__.py
│   │   └── cameras.py
│   ├── cli.py
│   ├── configs
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── appearance.py
│   │   ├── dataset.py
│   │   ├── instantiate_config.py
│   │   ├── light_gaussian.py
│   │   ├── model.py
│   │   ├── optimization.py
│   │   ├── segany_splatting.py
│   │   └── tcnn_encoding_config.py
│   ├── dataparsers
│   │   ├── __init__.py
│   │   ├── blender_dataparser.py
│   │   ├── colmap_block_dataparser.py
│   │   ├── colmap_dataparser.py
│   │   ├── dataparser.py
│   │   ├── estimated_depth_colmap_block_dataparser.py
│   │   ├── estimated_depth_colmap_dataparser.py
│   │   ├── feature_3dgs_dataparser.py
│   │   ├── matrix_city_dataparser.py
│   │   ├── nerfies_dataparser.py
│   │   ├── ngp_dataparser.py
│   │   ├── nsvf_dataparser.py
│   │   ├── phototourism_dataparser.py
│   │   ├── segany_colmap_dataparser.py
│   │   ├── silvr_dataparser.py
│   │   └── spotless_colmap_dataparser.py
│   ├── dataset.py
│   ├── density_controllers
│   │   ├── __init__.py
│   │   ├── accurate_visibility_filter_density_controller.py
│   │   ├── citygsv2_density_controller.py
│   │   ├── density_controller.py
│   │   ├── distributed_vanilla_density_controller.py
│   │   ├── foreground_first_density_controller.py
│   │   ├── gs2d_density_controller.py
│   │   ├── h3dgs_density_controller.py
│   │   ├── logger_mixin.py
│   │   ├── mcmc_density_controller.py
│   │   ├── static_density_controller.py
│   │   ├── taming_3dgs_density_controller.py
│   │   └── vanilla_density_controller.py
│   ├── encodings
│   │   ├── __init__.py
│   │   └── positional_encoding.py
│   ├── entrypoints
│   │   ├── __init__.py
│   │   ├── gs2d_mesh_extraction.py
│   │   ├── gspl.py
│   │   ├── seganygs.py
│   │   └── viewer.py
│   ├── gaussian_splatting.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── appearance_feature_similarity_regularization_metrics.py
│   │   ├── citygsv2_metrics.py
│   │   ├── depth_metrics.py
│   │   ├── feature_3dgs_metrics.py
│   │   ├── gs2d_metrics.py
│   │   ├── inverse_depth_metrics.py
│   │   ├── mcmc_metrics.py
│   │   ├── metric.py
│   │   ├── pvg_dynamic_metrics.py
│   │   ├── scale_regularization_metrics.py
│   │   ├── spotless_metrics.py
│   │   ├── vanilla_metrics.py
│   │   ├── vanilla_with_fused_ssim_metrics.py
│   │   └── visibility_map_metrics.py
│   ├── model_components
│   │   ├── __init__.py
│   │   ├── envlight.py
│   │   ├── gs4d_deformation.py
│   │   ├── gs4d_grid.py
│   │   └── gs4d_hexplane.py
│   ├── models
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── appearance_feature_gaussian.py
│   │   ├── appearance_gs2d.py
│   │   ├── appearance_mip_gaussian.py
│   │   ├── appearance_model.py
│   │   ├── deform_model.py
│   │   ├── flatten_gaussian_model.py
│   │   ├── gaussian.py
│   │   ├── gaussian_2d.py
│   │   ├── mip_splatting.py
│   │   ├── periodic_vibration_gaussian.py
│   │   ├── sparse_adam_gaussian.py
│   │   ├── swag_model.py
│   │   ├── vanilla_deform_model.py
│   │   ├── vanilla_gaussian.py
│   │   └── vast_model.py
│   ├── mp_strategy.py
│   ├── optimizers.py
│   ├── renderers
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── appearance_2dgs_renderer.py
│   │   ├── appearance_mlp_renderer.py
│   │   ├── contrastive_feature_renderer.py
│   │   ├── deformable_renderer.py
│   │   ├── feature_3dgs_renderer.py
│   │   ├── gsplat_appearance_embedding_renderer.py
│   │   ├── gsplat_appearance_embedding_visibility_map_renderer.py
│   │   ├── gsplat_contrastive_feature_renderer.py
│   │   ├── gsplat_distributed_appearance_embedding_renderer.py
│   │   ├── gsplat_distributed_renderer.py
│   │   ├── gsplat_hit_pixel_count_renderer.py
│   │   ├── gsplat_mip_splatting_renderer_v2.py
│   │   ├── gsplat_renderer.py
│   │   ├── gsplat_v1_renderer.py
│   │   ├── mip_splatting_gsplat_renderer.py
│   │   ├── partition_lod_renderer.py
│   │   ├── periodic_vibration_gaussian_renderer.py
│   │   ├── pypreprocess_gsplat_renderer.py
│   │   ├── renderer.py
│   │   ├── rgb_mlp_renderer.py
│   │   ├── seganygs_renderer.py
│   │   ├── sep_depth_trim_2dgs_renderer.py
│   │   ├── stp_renderer.py
│   │   ├── swag_renderer.py
│   │   ├── taming_3dgs_renderer.py
│   │   ├── vanilla_2dgs_renderer.py
│   │   ├── vanilla_deformable_renderer.py
│   │   ├── vanilla_gs4d_renderer.py
│   │   ├── vanilla_renderer.py
│   │   └── vanilla_trim_renderer.py
│   ├── schedulers.py
│   ├── segany_splatting.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── citygs_partitioning_utils.py
│   │   ├── colmap.py
│   │   ├── common.py
│   │   ├── depth_map_utils.py
│   │   ├── fisheye_utils.py
│   │   ├── fix_lightning_save_hyperparameters.py
│   │   ├── gaussian_containers.py
│   │   ├── gaussian_model_editor.py
│   │   ├── gaussian_model_loader.py
│   │   ├── gaussian_projection.py
│   │   ├── gaussian_utils.py
│   │   ├── general_utils.py
│   │   ├── graphics_utils.py
│   │   ├── gs2d_mesh_utils.py
│   │   ├── image_utils.py
│   │   ├── las_utils.py
│   │   ├── light_gaussian.py
│   │   ├── lpips.py
│   │   ├── lpipsPyTorch
│   │   │   └── modules
│   │   │       ├── lpips.py
│   │   │       ├── networks.py
│   │   │       └── utils.py
│   │   ├── network_factory.py
│   │   ├── partitioning_utils.py
│   │   ├── psnr.py
│   │   ├── render_utils.py
│   │   ├── rigid_utils.py
│   │   ├── rotation.py
│   │   ├── seganygs.py
│   │   ├── sfm_outlier_detection.py
│   │   ├── sh_utils.py
│   │   ├── ssim.py
│   │   ├── visualizers.py
│   │   ├── vq.py
│   │   └── vq_utils.py
│   └── viewer
│       ├── __init__.py
│       ├── client.py
│       ├── renderer.py
│       ├── training_viewer.py
│       ├── ui
│       │   ├── __init__.py
│       │   ├── edit_panel.py
│       │   ├── render_panel.py
│       │   ├── transform_panel.py
│       │   └── up_direction_folder.py
│       └── viewer.py
├── main.py
├── notebooks
│   ├── citygs_split_v2.ipynb
│   ├── colmap_aerial_split.ipynb
│   ├── colmap_split_v2.ipynb
│   ├── context2colmap.ipynb
│   ├── dji2pose.ipynb
│   ├── foreground_first_density_controller_test.ipynb
│   ├── gps_based_sfm_outlier_detection.ipynb
│   ├── gsplat.ipynb
│   ├── matrix_city_aerial_split.ipynb
│   ├── matrix_city_split.ipynb
│   ├── matrixcity2meganerf.ipynb
│   ├── meganerf_rubble_split.ipynb
│   ├── merge_partitions.ipynb
│   ├── partition_light_gaussian_pruning.ipynb
│   ├── preprocess.ipynb
│   ├── prior_pose_guided_sfm.ipynb
│   ├── prompt_segmenting.ipynb
│   └── rotate_shs.ipynb
├── pyproject.toml
├── render.py
├── requirements.txt
├── requirements
│   ├── 2DGS.txt
│   ├── CityGS.txt
│   ├── SpotLessSplats.txt
│   ├── StopThePop.txt
│   ├── common.txt
│   ├── diff-accel-rasterization.txt
│   ├── diff-surfel-rasterization.txt
│   ├── fused-ssim.txt
│   ├── gsplat.txt
│   ├── lightning23.txt
│   ├── lightning25.txt
│   ├── pyt201_cu118.txt
│   ├── pyt251_cu124.txt
│   ├── pytorch3d-compile.txt
│   ├── pytorch3d-pre.txt
│   ├── pytorch3d-py39_cu118_pyt201.txt
│   ├── sam.txt
│   └── tcnn.txt
├── scripts
│   ├── data_proc_mc.sh
│   ├── data_proc_mc_scratch.sh
│   ├── data_proc_mill19.sh
│   ├── data_proc_mill19_scratch.sh
│   ├── data_proc_us3d.sh
│   ├── data_proc_us3d_scratch.sh
│   ├── estimate_dataset_depths.slurm
│   ├── get_sam_masks.slurm
│   ├── gt_generate.sh
│   ├── run_citygs_lfls.sh
│   ├── run_citygs_mc_aerial.sh
│   ├── run_citygs_mc_street.sh
│   ├── run_citygs_smbu.sh
│   ├── run_citygs_upper.sh
│   ├── sd_feature_extraction.slurm
│   ├── train-meganerf_rubble-partitions.slurm
│   ├── untar_matrixcity_test.sh
│   └── untar_matrixcity_train.sh
├── seganygs.py
├── submodules
│   └── README.md
├── tests
│   ├── .gitignore
│   ├── dataset
│   │   ├── .gitignore
│   │   ├── blender_dataparser_test.py
│   │   ├── colmap_dataparser_test.py
│   │   ├── matrix_city_dataparser_test.py
│   │   └── nerfies_dataparser_test.py
│   ├── deformable_model_test.py
│   ├── density_controller_utils_test.py
│   ├── gaussian_containers_test.py
│   ├── gaussian_projection_test.py
│   ├── network_factory_test.py
│   ├── positional_encoding_test.py
│   └── vanilla_gaussian_model_test.py
├── tools
│   ├── add_pypath.py
│   ├── block_wandb_sync.py
│   ├── clean_outputs.py
│   ├── convert.py
│   ├── convert_cam.py
│   ├── copy_images.py
│   ├── eval_tnt
│   │   ├── README.md
│   │   ├── compute_bbox_for_mesh.py
│   │   ├── config.py
│   │   ├── cull_mesh.py
│   │   ├── evaluate_single_scene.py
│   │   ├── evaluation.py
│   │   ├── help_func.py
│   │   ├── plot.py
│   │   ├── registration.py
│   │   ├── requirements.txt
│   │   ├── run.py
│   │   ├── trajectory_io.py
│   │   └── util.py
│   ├── render_traj.py
│   ├── transform_json2txt_mc_aerial.py
│   ├── transform_json2txt_mc_street.py
│   ├── transform_pt2txt.py
│   └── vectree_lightning.py
├── utils
│   ├── PolyCam.md
│   ├── add_pypath.py
│   ├── argparser_utils.py
│   ├── auto_hyper_parameter.py
│   ├── ckpt2ply.py
│   ├── colmap_undistort_mask.py
│   ├── common.py
│   ├── convert2splat.py
│   ├── depths_downsample.py
│   ├── distibuted_tasks.py
│   ├── downsample_pcd.py
│   ├── dump_ckpt.py
│   ├── edit_with_histories.py
│   ├── estimate_dataset_depths.py
│   ├── eval_blender.py
│   ├── eval_mipnerf360.py
│   ├── finetune_partition.py
│   ├── finetune_pruned_partitions_v2.py
│   ├── fuse_appearance_embeddings_into_shs_dc.py
│   ├── fuse_mip_filter.py
│   ├── gaussian_transform.py
│   ├── generate_crop_volume.py
│   ├── generate_image_apperance_groups.py
│   ├── generate_image_apperance_groups_by_exposure.py
│   ├── generate_image_list.py
│   ├── get_depth_scales.py
│   ├── get_sam_embeddings.py
│   ├── get_sam_mask_scales.py
│   ├── get_sam_masks.py
│   ├── gs2d_mesh_extraction.py
│   ├── image_downsample.py
│   ├── matrix_city_depth_to_point_cloud.py
│   ├── matrix_city_frame_group_slice.py
│   ├── matrix_city_group_continuous_frame.py
│   ├── meganerf2colmap.py
│   ├── merge_citygs_ckpts.py
│   ├── merge_distributed_ckpts.py
│   ├── merge_partitions_v2.py
│   ├── merge_ply.py
│   ├── mesh_post_process.py
│   ├── optimize_val_set_appearance_embeddings.py
│   ├── partition_citygs.py
│   ├── ply2ckpt.py
│   ├── polycam2ngp.py
│   ├── polycam2points.py
│   ├── prune_by_segany_mask.py
│   ├── prune_partitions_v2.py
│   ├── render_sls_masks.py
│   ├── requirements.txt
│   ├── run_depth_anything_v2.py
│   ├── scalable_param_configs
│   │   ├── appearance-depth_reg.yaml
│   │   ├── appearance-with_scheduler-depth_reg.yaml
│   │   ├── appearance-with_scheduler.yaml
│   │   ├── appearance.yaml
│   │   ├── depth_reg.yaml
│   │   └── scale_reg.yaml
│   ├── sd_feature_extraction.py
│   ├── show_cameras.py
│   ├── train_citygs_partitions.py
│   ├── train_colmap_partitions.py
│   ├── train_colmap_partitions_v2.py
│   ├── train_matrix_city_partitions_v2.py
│   ├── train_partitions.py
│   ├── trained_partition_utils.py
│   └── update_ckpt.py
└── viewer.py

/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
name: Close inactive issues
on:
  schedule:
    - cron: "30 1 * * *"

jobs:
  close-issues:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v9
        with:
          days-before-issue-stale: 30
          days-before-issue-close: 14
          stale-issue-label: "stale"
          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
          days-before-pr-stale: -1
          days-before-pr-close: -1
          repo-token: ${{ secrets.GITHUB_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
.vscode/
data
output*
camera_paths
renders
videos
edited
**/__pycache__
submodules/simple-knn
submodules/tiny-cuda-nn-fp32
submodules/diff-surfel-rasterization
submodules/diff-surfel-rasterization_debug
internal/dataparsers/colmap_cluster_dataparser.py
internal/utils/data_sampler.py
tools/merge_new_img.py
tools/render_clip.py
scripts/run.sh
scripts/run1.sh
scripts/run2.sh
scripts/run_citygs2d_building.sh
scripts/run_citygs2d_rubble.sh
scripts/post_proc.sh
scripts/lod_citygs_mc_aerial.sh
notebooks_dumped
clusters
segments
**/__pycache__

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "submodules/diff-gaussian-rasterization"]
	path = submodules/diff-gaussian-rasterization
	url = https://github.com/graphdeco-inria/diff-gaussian-rasterization
[submodule "submodules/simple-knn"]
	path = submodules/simple-knn
	url = https://gitlab.inria.fr/bkerbl/simple-knn.git
[submodule "diff-trim-gaussian-rasterization"]
	path = submodules/diff-trim-gaussian-rasterization
	url = https://github.com/Abyssaledge/diff-gaussian-rasterization
[submodule "submodules/diff-trim-surfel-rasterization"]
	path = submodules/diff-trim-surfel-rasterization
	url = https://github.com/YuxueYang1204/diff-surfel-rasterization

--------------------------------------------------------------------------------
/assets/analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/assets/analysis.png

--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/assets/demo.gif

--------------------------------------------------------------------------------
/blender/.gitignore:
--------------------------------------------------------------------------------
data
render_res
dumped
output
__pycache__

--------------------------------------------------------------------------------
/blender/cam_pose_utils/graphic_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import math

def getWorld2View(R, t):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))


class Virtual():
    pass
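Note: fov2focal and focal2fov above are exact inverses of each other (focal = pixels / (2 * tan(fov / 2)), with fov in radians). A quick round-trip check in plain Python, using hypothetical values:

import math

fov = math.pi / 2   # hypothetical: a 90-degree field of view
width = 1600        # hypothetical: image width in pixels

focal = width / (2 * math.tan(fov / 2))        # fov2focal -> 800.0 px
fov_back = 2 * math.atan(width / (2 * focal))  # focal2fov -> pi/2 again

assert abs(fov_back - fov) < 1e-12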
--------------------------------------------------------------------------------
/blender/generate_video.py:
--------------------------------------------------------------------------------
import imageio
import os
from tqdm import tqdm
from argparse import ArgumentParser

def generate_video(path, fps):
    # if not is_texture:
    #     write_dir = os.path.join(path, 'videos', 'mesh')
    #     load_dir = os.path.join(path, 'mesh')
    # else:
    #     write_dir = os.path.join(path, 'videos', 'texture')
    #     load_dir = os.path.join(path, 'texture')
    load_dir = path
    if not os.path.isdir(load_dir):
        # fail with an informative error instead of a bare `assert False`
        raise NotADirectoryError(f"load directory does not exist: {load_dir}")
    write_dir = os.path.join(path, 'videos')
    if not os.path.exists(write_dir):
        os.makedirs(write_dir)
    video = imageio.get_writer(os.path.join(write_dir, 'video.mp4'), fps=fps)
    image_list = sorted(os.listdir(load_dir))
    for i in tqdm(range(len(image_list)), desc="Creating video"):
        path = os.path.join(load_dir, image_list[i])
        if os.path.isdir(path):
            continue
        image = imageio.imread(path)
        video.append_data(image)
    video.close()

if __name__ == "__main__":

    parser = ArgumentParser(description='video generator arg parser')
    parser.add_argument('--load_dir', type=str, default="render_res")
    parser.add_argument("--is_texture", action="store_true")
    parser.add_argument("--fps", type=int, default=60)
    args = parser.parse_args()
    generate_video(path=args.load_dir, fps=args.fps)

--------------------------------------------------------------------------------
/blender/render_cfgs/dtu/background.json:
--------------------------------------------------------------------------------
3.8

--------------------------------------------------------------------------------
/blender/render_cfgs/dtu/background.py:
--------------------------------------------------------------------------------
import json

back_sphere_radius = 3.8

with open("background.json", "w") as f:
    json.dump(back_sphere_radius, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/dtu/light.json:
--------------------------------------------------------------------------------
{"mesh": {"pose": [[1, 0, -2], [0, 1, -2], [0, -1, -2]], "energy": [12.0, 8.0, 7.0]}, "texture": {"pose": [[1, 0, -2], [0, 1, -2], [0, -1, -2]], "energy": [70.0, 45.0, 45.0]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/dtu/light.py:
--------------------------------------------------------------------------------
import json
light_cfg = {'mesh': {'pose': ((1, 0, -2), (0, 1, -2), (0, -1, -2)), 'energy': (12.0, 8.0, 7.0)}, 'texture': {'pose': ((1, 0, -2), (0, 1, -2), (0, -1, -2)), 'energy': (70.0, 45.0, 45.0)}}

with open("light.json", "w") as f:
    json.dump(light_cfg, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/gauu/background.json:
--------------------------------------------------------------------------------
7.5

--------------------------------------------------------------------------------
/blender/render_cfgs/gauu/light.json:
--------------------------------------------------------------------------------
{
  "mesh": {
    "pose": [[0, 0, -5]],
    "energy": [2.5],
    "type": ["SUN"],
    "scale": [1],
    "rotation": [[3.142, 0.0, 1.571]]},
  "texture": {
    "pose": [[0, 0, -5]],
    "energy": [2],
    "type": ["SUN"],
    "scale": [1],
    "rotation": [[3.142, 0.0, 1.57]]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_aerial/background.json:
--------------------------------------------------------------------------------
7.5

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_aerial/background.py:
--------------------------------------------------------------------------------
import json

back_sphere_radius = 7.5

with open("background.json", "w") as f:
    json.dump(back_sphere_radius, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_aerial/light.json:
--------------------------------------------------------------------------------
{"mesh": {"pose": [[-5, 0, 5]], "energy": [2.5], "type": ["SUN"], "scale": [1], "rotation": [[0.0, 0.0, 1.571]]}, "texture": {"pose": [[-5, 0, 5]], "energy": [5], "type": ["SUN"], "scale": [1], "rotation": [[0.0, 0.0, 1.571]]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_aerial/light.py:
--------------------------------------------------------------------------------
import json
light_cfg = {"mesh": {"pose": [[-5, 0, 5], ],
                      "energy": [2.5, ],
                      "type": ["SUN", ],
                      "scale": [1, ],
                      "rotation": [[0.0, 0.0, 1.571], ]},
             "texture": {"pose": [[-5, 0, 5], ],
                         "energy": [5, ],
                         "type": ["SUN", ],
                         "scale": [1, ],
                         "rotation": [[0.0, 0.0, 1.571], ]}}

with open("light.json", "w") as f:
    json.dump(light_cfg, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_street/background.json:
--------------------------------------------------------------------------------
500

--------------------------------------------------------------------------------
/blender/render_cfgs/matrixcity_street/light.json:
--------------------------------------------------------------------------------
{
  "mesh": {
    "pose": [[0, 100, 100]],
    "energy": [2.5],
    "type": ["SUN"],
    "scale": [1],
    "rotation": [[0, 0.0, 1.571]]},
  "texture": {
    "pose": [[0, 100, 100]],
    "energy": [5],
    "type": ["SUN"],
    "scale": [1],
    "rotation": [[0, 0.0, 1.571]]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/mip/background.json:
--------------------------------------------------------------------------------
7.5

--------------------------------------------------------------------------------
/blender/render_cfgs/mip/background.py:
--------------------------------------------------------------------------------
import json

back_sphere_radius = 7.5

with open("background.json", "w") as f:
    json.dump(back_sphere_radius, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/mip/light.json:
--------------------------------------------------------------------------------
{"mesh": {"pose": [[2.0, -3.5, 0], [-1, -3, 1], [0, -3, -1]], "energy": [60.0, 50.0, 30.0]}, "texture": {"pose": [[2.0, -3.5, 0], [-1, -3, 1], [0, -3, -1]], "energy": [200.0, 150.0, 150.0]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/mip/light.py:
--------------------------------------------------------------------------------
import json
light_cfg = {'mesh': {'pose': ((2., -3.5, 0), (-1, -3, 1), (0, -3, -1)), 'energy': (60.0, 50.0, 30.0)}, 'texture': {'pose': ((2., -3.5, 0), (-1, -3, 1), (0, -3, -1)), 'energy': (200.0, 150.0, 150.0)}}

with open("light.json", "w") as f:
    json.dump(light_cfg, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/pixsfm/background.json:
--------------------------------------------------------------------------------
3.8

--------------------------------------------------------------------------------
/blender/render_cfgs/pixsfm/background.py:
--------------------------------------------------------------------------------
import json

back_sphere_radius = 3.8

with open("background.json", "w") as f:
    json.dump(back_sphere_radius, f)

--------------------------------------------------------------------------------
/blender/render_cfgs/pixsfm/light.json:
--------------------------------------------------------------------------------
{"mesh": {"pose": [[20, -150, -50]], "energy": [2.5], "type": ["SUN"], "scale": [1], "rotation": [[1.571, 0.0, 0.0]]}, "texture": {"pose": [[20, -150, -50]], "energy": [5], "type": ["SUN"], "scale": [1], "rotation": [[1.571, 0.0, 0.0]]}}

--------------------------------------------------------------------------------
/blender/render_cfgs/pixsfm/light.py:
--------------------------------------------------------------------------------
import json
light_cfg = {"mesh": {"pose": [[20, -150, -50], ],
                      "energy": [0.9, ],
                      "type": ["SUN", ],
                      "scale": [1, ],
                      "rotation": [[1.571, 0.0, 0.0], ]},
             "texture": {"pose": [[20, -150, -50], ],
                         "energy": [5, ],
                         "type": ["SUN", ],
                         "scale": [1, ],
                         "rotation": [[1.571, 0.0, 0.0], ]}}

with open("light.json", "w") as f:
    json.dump(light_cfg, f)
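Note: the light.json files above all share one schema: parallel lists of pose and energy, optionally with type, scale, and rotation, one entry per light. The render scripts that consume them are not included in this section; the following is only a minimal sketch, assuming standard bpy calls, of how such a config could be turned into Blender lights (the helper name load_lights is hypothetical):

import json
import bpy

def load_lights(cfg_path, mode="mesh"):
    # `mode` selects the "mesh" or "texture" variant of the config.
    with open(cfg_path) as f:
        cfg = json.load(f)[mode]
    n = len(cfg["pose"])
    for i in range(n):
        # Configs without a "type" list (e.g. dtu, mip) presumably use
        # point lights; SUN entries additionally carry a rotation.
        light_type = cfg.get("type", ["POINT"] * n)[i]
        bpy.ops.object.light_add(type=light_type, location=cfg["pose"][i])
        light = bpy.context.active_object
        light.data.energy = cfg["energy"][i]
        if "rotation" in cfg:
            light.rotation_euler = cfg["rotation"][i]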
nodes["Diffuse BSDF"].inputs[0].default_value = (r, g, b, 1) 30 | else: 31 | assert False 32 | links.new(shader.outputs[0], output.inputs[0]) 33 | 34 | return mat 35 | 36 | 37 | def draw_background(is_texture): 38 | 39 | if is_texture: 40 | mat = newShader("Texture", "diffuse", *color_texture) 41 | else: 42 | mat = newShader("Mesh", "diffuse", *color_mesh) 43 | bpy.ops.surface.primitive_nurbs_surface_sphere_add() 44 | bpy.context.active_object.data.materials.append(mat) 45 | -------------------------------------------------------------------------------- /blender/render_utils/texture_allocator.py: -------------------------------------------------------------------------------- 1 | 2 | class TextureAllocator: 3 | 4 | def __init__(self, bpy, texture_name='texture_material'): 5 | self.bpy = bpy 6 | self.texture_name = texture_name 7 | # self.init_texture() 8 | 9 | def init_texture(self): 10 | bpy = self.bpy 11 | texture_name = self.texture_name 12 | mat = bpy.data.materials.new(name=texture_name) 13 | mat.use_nodes = True 14 | if mat.node_tree: 15 | mat.node_tree.links.clear() 16 | mat.node_tree.nodes.clear() 17 | 18 | nodes = mat.node_tree.nodes 19 | links = mat.node_tree.links 20 | output = nodes.new(type='ShaderNodeOutputMaterial') 21 | # shader = nodes.new(type='ShaderNodeBsdfDiffuse') 22 | shader = nodes.new(type='ShaderNodeBsdfPrincipled') 23 | links.new(shader.outputs[0], output.inputs[0]) 24 | 25 | input_attribute = nodes.new(type='ShaderNodeAttribute') 26 | input_attribute.attribute_name = 'Col' 27 | links.new(input_attribute.outputs[0], shader.inputs[0]) 28 | # return mat 29 | 30 | def set_texture(self): 31 | bpy = self.bpy 32 | bpy.context.active_object.data.materials.append(bpy.data.materials[self.texture_name]) 33 | -------------------------------------------------------------------------------- /configs/appearance.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: AppearanceMLPRenderer -------------------------------------------------------------------------------- /configs/appearance_embedding_renderer/distributed-view_independent-128_dims-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth-scale_reg-mip.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | strategy: 3 | class_path: internal.mp_strategy.MPStrategy 4 | devices: -1 5 | model: 6 | gaussian: 7 | class_path: internal.models.appearance_mip_gaussian.AppearanceMipGaussian 8 | init_args: 9 | sh_degree: 0 10 | appearance_feature_dims: 64 11 | optimization: 12 | appearance_feature_lr_init: 0.005 13 | appearance_feature_lr_scheduler: 14 | class_path: ExponentialDecayScheduler 15 | init_args: 16 | lr_final: 0.00025 17 | renderer: 18 | class_path: internal.renderers.gsplat_distributed_appearance_embedding_renderer.GSplatDistributedAppearanceMipRenderer 19 | init_args: 20 | appearance: 21 | n_appearance_embedding_dims: 128 22 | n_appearances: 2560 # must be greater than the number of images 23 | is_view_dependent: false 24 | appearance_optimization: 25 | embedding_lr_init: 5e-3 26 | embedding_lr_final_factor: 0.05 27 | warm_up: 0 28 | density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController 29 | metric: 30 | class_path: internal.metrics.scale_regularization_metrics.ScaleRegularizationWithDepthMetrics 31 | init_args: 32 | depth_output_key: hard_inverse_depth 33 | renderer_output_types: 34 | - rgb 35 | - hard_inverse_depth 36 | data: 37 | 
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/distributed-view_independent-128_dims-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth-scale_reg-mip.yaml:
--------------------------------------------------------------------------------
trainer:
  strategy:
    class_path: internal.mp_strategy.MPStrategy
  devices: -1
model:
  gaussian:
    class_path: internal.models.appearance_mip_gaussian.AppearanceMipGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 0.005
        appearance_feature_lr_scheduler:
          class_path: ExponentialDecayScheduler
          init_args:
            lr_final: 0.00025
  renderer:
    class_path: internal.renderers.gsplat_distributed_appearance_embedding_renderer.GSplatDistributedAppearanceMipRenderer
    init_args:
      appearance:
        n_appearance_embedding_dims: 128
        n_appearances: 2560  # must be greater than the number of images
        is_view_dependent: false
      appearance_optimization:
        embedding_lr_init: 5e-3
        embedding_lr_final_factor: 0.05
        warm_up: 0
  density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController
  metric:
    class_path: internal.metrics.scale_regularization_metrics.ScaleRegularizationWithDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  distributed: true
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-128_dims-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 0.005
        appearance_feature_lr_scheduler:
          class_path: ExponentialDecayScheduler
          init_args:
            lr_final: 0.00025
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        n_appearance_embedding_dims: 128
        is_view_dependent: false
      optimization:
        embedding_lr_init: 5e-3
        embedding_lr_final_factor: 0.05
        warm_up: 1000
  metric:
    class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-128_dims-lr_0.005-with_scheduler.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 0.005
        appearance_feature_lr_scheduler:
          class_path: ExponentialDecayScheduler
          init_args:
            lr_final: 0.00025
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        n_appearance_embedding_dims: 128
        is_view_dependent: false
      optimization:
        embedding_lr_init: 5e-3
        embedding_lr_final_factor: 0.05
        warm_up: 1000
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"
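Note: several of these configs pair appearance_feature_lr_init: 0.005 with an ExponentialDecayScheduler whose lr_final is 0.00025. A schedule of this kind interpolates log-linearly between the initial and final rates; a sketch of the usual formula (assuming decay over max_steps, as in typical exponential-decay schedulers):

def exponential_decay_lr(step, lr_init=0.005, lr_final=0.00025, max_steps=30000):
    # log-linear interpolation: lr(0) == lr_init, lr(max_steps) == lr_final
    t = min(max(step / max_steps, 0.0), 1.0)
    return lr_init * (lr_final / lr_init) ** t

# e.g. halfway through training the rate is ~0.0011
print(exponential_decay_lr(15000))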
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-accel.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
        normalize: false
      optimization:
        warm_up: 1000
      tile_based_culling: true
  metric:
    fused_ssim: true
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: reconstruction
      appearance_groups: appearance_groups-image_dedicated

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-estimated_depth_reg-hard_depth.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        warm_up: 1000
  metric:
    class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"
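Note: the estimated_depth_reg-hard_depth variants here and below swap in HasInverseDepthMetrics and request a hard_inverse_depth output from the renderer, so that a monocular depth prior (prepared by the EstimatedDepthColmap dataparser's preprocessing) can regularize geometry. The core of such a loss is an L1 term between the rendered inverse depth and the scale-aligned estimated inverse depth; a minimal sketch only, with an illustrative least-squares alignment step (real pipelines often precompute per-image scale/shift instead):

import torch

def inverse_depth_l1(rendered_inv_depth, estimated_inv_depth, weight=1.0):
    # Monocular depth is only defined up to scale and shift, so first
    # align the prior to the current rendering with least squares.
    x = estimated_inv_depth.flatten()
    y = rendered_inv_depth.detach().flatten()
    A = torch.stack([x, torch.ones_like(x)], dim=-1)
    scale_shift = torch.linalg.lstsq(A, y.unsqueeze(-1)).solution.squeeze(-1)
    aligned = estimated_inv_depth * scale_shift[0] + scale_shift[1]
    return weight * (rendered_inv_depth - aligned).abs().mean()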
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 0.005
        appearance_feature_lr_scheduler:
          class_path: ExponentialDecayScheduler
          init_args:
            lr_final: 0.00025
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        embedding_lr_init: 5e-3
        embedding_lr_final_factor: 0.05
        warm_up: 1000
  metric:
    class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-lr_0.005-with_scheduler.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 0.005
        appearance_feature_lr_scheduler:
          class_path: ExponentialDecayScheduler
          init_args:
            lr_final: 0.00025
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        embedding_lr_init: 5e-3
        embedding_lr_final_factor: 0.05
        warm_up: 1000
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-mip.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_mip_gaussian.AppearanceMipGaussian  # MipSplatting model
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingMipRenderer  # MipSplatting version
    init_args:
      model:
        is_view_dependent: false
      optimization:
        warm_up: 1000
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-normalized-accel.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
        normalize: true
      optimization:
        warm_up: 1000
      tile_based_culling: true
  metric:
    fused_ssim: true
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: reconstruction
      appearance_groups: appearance_groups-image_dedicated

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent-normalized-sim_reg-accel.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
        normalize: true
      optimization:
        warm_up: 1000
      tile_based_culling: true
  metric:
    class_path: internal.metrics.appearance_feature_similarity_regularization_metrics.VanillaMetricsWithSimilarityRegularization
    init_args:
      fused_ssim: true
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: reconstruction
      appearance_groups: appearance_groups-image_dedicated
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/sh_view_dependent.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 3
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        warm_up: 1000
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/snippets/.gitignore:
--------------------------------------------------------------------------------
!tcnn.yaml

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/snippets/tcnn.yaml:
--------------------------------------------------------------------------------
model:
  renderer:
    model:
      tcnn: true
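Note: snippets/tcnn.yaml is not a standalone config; it only sets model.renderer.model.tcnn: true and appears to be meant to be layered on top of one of the full configs above. With LightningCLI-style config handling, passing several config files deep-merges them in order, later files winning. A plain-Python sketch of that merge rule (hypothetical helper, not the repo's actual CLI code):

def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`, later values winning."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

base = {"model": {"renderer": {"model": {"n_appearance_embedding_dims": 128}}}}
snippet = {"model": {"renderer": {"model": {"tcnn": True}}}}
print(deep_merge(base, snippet))
# {'model': {'renderer': {'model': {'n_appearance_embedding_dims': 128, 'tcnn': True}}}}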
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_dependent-distributed.yaml:
--------------------------------------------------------------------------------
trainer:
  strategy:
    class_path: internal.mp_strategy.MPStrategy
  devices: -1
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_distributed_appearance_embedding_renderer.GSplatDistributedAppearanceEmbeddingRenderer
    init_args:
      appearance:
        n_appearances: 1024
        is_view_dependent: true
      appearance_optimization:
        warm_up: 0
  density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_dependent-estimated_depth_reg-hard_depth.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: true
      optimization:
        warm_up: 1000
  metric:
    class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_dependent-estimated_depth_reg.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: true
      optimization:
        warm_up: 1000
  metric: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
  renderer_output_types:
    - rgb
    - inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_dependent.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: true
      optimization:
        warm_up: 1000
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_independent-2dgs.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_gs2d.AppearanceGS2D
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.appearance_2dgs_renderer.Appearance2DGSRenderer
    init_args:
      model:
        is_view_dependent: false
        normalize: true
      optimization:
        warm_up: 1000
  metric: internal.metrics.gs2d_metrics.GS2DMetrics
  density:
    class_path: internal.density_controllers.gs2d_density_controller.GS2DDensityController
    init_args:
      cull_opacity_threshold: 0.05
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"
--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_independent-distributed.yaml:
--------------------------------------------------------------------------------
trainer:
  strategy:
    class_path: internal.mp_strategy.MPStrategy
  devices: -1
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_distributed_appearance_embedding_renderer.GSplatDistributedAppearanceEmbeddingRenderer
    init_args:
      appearance:
        n_appearances: 1024
        is_view_dependent: false
      appearance_optimization:
        warm_up: 0
  density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: Colmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_independent-estimated_depth_reg-hard_depth.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        warm_up: 1000
  metric:
    class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
    init_args:
      depth_output_key: hard_inverse_depth
  renderer_output_types:
    - rgb
    - hard_inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/appearance_embedding_renderer/view_independent-estimated_depth_reg.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer
    init_args:
      model:
        is_view_dependent: false
      optimization:
        warm_up: 1000
  metric: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics
  renderer_output_types:
    - rgb
    - inverse_depth
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap
    init_args:
      split_mode: "reconstruction"
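Note: the two phototourism configs that follow enable SpotLessSplats-style robust training: transient occluders (pedestrians, cars) in tourist photos are down-weighted by masking pixels whose residuals look like outliers, and an opacity_reg term discourages leftover floaters. A rough sketch of the two ingredients only; the thresholds and shapes are illustrative, and the actual logic (which also uses semantic features, per the parser's semantic_feature flag) lives in internal/metrics/spotless_metrics.py:

import torch

def robust_photometric_loss(pred, gt, opacities, opacity_reg=0.01,
                            lower_bound=0.3, upper_bound=0.8):
    # Per-pixel residuals; unusually high errors are likely transient
    # occluders and get masked out of the loss.
    residual = (pred - gt).abs().mean(dim=0)  # (H, W), pred/gt are (3, H, W)
    lo = torch.quantile(residual, lower_bound)
    hi = torch.quantile(residual, upper_bound)
    # keep clear inliers, drop clear outliers, soft weight in between
    weight = 1.0 - ((residual - lo) / (hi - lo)).clamp(0.0, 1.0)
    photometric = (weight * residual).mean()
    # opacity regularization: pushes Gaussians toward sparse, confident opacities
    return photometric + opacity_reg * opacities.abs().mean()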
"reconstruction" 29 | semantic_feature: true -------------------------------------------------------------------------------- /configs/appearance_embedding_renderer/view_independent-phototourism-sls_high_occlusion-opacity_reg_0.01.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian 4 | init_args: 5 | sh_degree: 0 6 | appearance_feature_dims: 64 7 | optimization: 8 | appearance_feature_lr_init: 2e-3 9 | renderer: 10 | class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer 11 | init_args: 12 | model: 13 | is_view_dependent: false 14 | optimization: 15 | warm_up: 1000 16 | density: 17 | opacity_reset_interval: 999999999 # no reset 18 | metric: 19 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 20 | init_args: 21 | lower_bound: 0.3 22 | upper_bound: 0.8 23 | opacity_reg: 0.01 24 | data: 25 | val_max_num_images_to_cache: -1 26 | test_max_num_images_to_cache: -1 27 | parser: 28 | class_path: PhotoTourism 29 | init_args: 30 | split_mode: "reconstruction" 31 | semantic_feature: true -------------------------------------------------------------------------------- /configs/appearance_embedding_renderer/view_independent.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian 4 | init_args: 5 | sh_degree: 0 6 | appearance_feature_dims: 64 7 | optimization: 8 | appearance_feature_lr_init: 2e-3 9 | renderer: 10 | class_path: internal.renderers.gsplat_appearance_embedding_renderer.GSplatAppearanceEmbeddingRenderer 11 | init_args: 12 | model: 13 | is_view_dependent: false 14 | optimization: 15 | warm_up: 1000 16 | data: 17 | val_max_num_images_to_cache: -1 18 | test_max_num_images_to_cache: -1 19 | parser: 20 | class_path: Colmap 21 | init_args: 22 | split_mode: "reconstruction" -------------------------------------------------------------------------------- /configs/appearance_embedding_visibility_map_renderer/view_independent-2x_ds-exp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian 4 | init_args: 5 | sh_degree: 0 6 | appearance_feature_dims: 64 7 | optimization: 8 | appearance_feature_lr_init: 2e-3 9 | renderer: 10 | class_path: internal.renderers.gsplat_appearance_embedding_visibility_map_renderer.GSplatAppearanceEmbeddingVisibilityMapRenderer 11 | init_args: 12 | model: 13 | is_view_dependent: false 14 | metric: 15 | class_path: internal.metrics.visibility_map_metrics.VisibilityMapMetrics 16 | data: 17 | val_max_num_images_to_cache: -1 18 | test_max_num_images_to_cache: -1 19 | parser: 20 | class_path: PhotoTourism 21 | init_args: 22 | down_sample_factor: 2 23 | split_mode: "experiment" -------------------------------------------------------------------------------- /configs/appearance_embedding_visibility_map_renderer/view_independent-2x_ds.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian 4 | init_args: 5 | sh_degree: 0 6 | appearance_feature_dims: 64 7 | optimization: 8 | appearance_feature_lr_init: 2e-3 9 | renderer: 10 | class_path: 
--------------------------------------------------------------------------------
/configs/appearance_embedding_visibility_map_renderer/view_independent-2x_ds.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    class_path: internal.models.appearance_feature_gaussian.AppearanceFeatureGaussian
    init_args:
      sh_degree: 0
      appearance_feature_dims: 64
      optimization:
        appearance_feature_lr_init: 2e-3
  renderer:
    class_path: internal.renderers.gsplat_appearance_embedding_visibility_map_renderer.GSplatAppearanceEmbeddingVisibilityMapRenderer
    init_args:
      model:
        is_view_dependent: false
  metric:
    class_path: internal.metrics.visibility_map_metrics.VisibilityMapMetrics
data:
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  parser:
    class_path: PhotoTourism
    init_args:
      down_sample_factor: 2
      split_mode: "reconstruction"

--------------------------------------------------------------------------------
/configs/background_sphere-cameras_center-random_background.yaml:
--------------------------------------------------------------------------------
data:
  add_background_sphere: true
  background_sphere_center: "cameras"
  background_sphere_distance: 1
  background_sphere_points: 204800
model:
  random_background: true

--------------------------------------------------------------------------------
/configs/background_sphere.yaml:
--------------------------------------------------------------------------------
data:
  add_background_sphere: true
  background_sphere_distance: 2.2
  background_sphere_points: 204800

--------------------------------------------------------------------------------
/configs/blender.yaml:
--------------------------------------------------------------------------------
pbar_rate: 10
trainer:
  check_val_every_n_epoch: 10
data:
  parser: Blender
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  image_on_cpu: false

--------------------------------------------------------------------------------
/configs/blender_gsplat.yaml:
--------------------------------------------------------------------------------
pbar_rate: 10
trainer:
  check_val_every_n_epoch: 10
model:
  renderer: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer
data:
  parser: Blender
  val_max_num_images_to_cache: -1
  test_max_num_images_to_cache: -1
  image_on_cpu: false

--------------------------------------------------------------------------------
/configs/citygs_lfls_coarse_sh2.yaml:
--------------------------------------------------------------------------------
model:
  gaussian:
    init_args:
      sh_degree: 2
  renderer:
    class_path: internal.renderers.vanilla_trim_renderer.VanillaTrimRenderer
    init_args:
      diable_trimming: true
trainer:
  check_val_every_n_epoch: 20
  max_steps: 30000
data:
  path: data/GauU_Scene/LFLS
  parser:
    class_path: EstimatedDepthBlockColmap
    init_args:
      split_mode: experiment
      eval_image_select_mode: ratio
      eval_ratio: 0.1
      down_sample_factor: 3.4175
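Note: citygs_lfls_coarse_sh2.yaml above and the *_trim config below form CityGaussian's two-stage pipeline: a coarse global model is trained first, then its checkpoint (initialize_from) seeds fine-tuning, with the scene split on a block_dim grid (here 4 x 2 in xy; the z dimension is dropped, as the config comment notes). A sketch of how such a grid assignment could work; the helper is hypothetical, and the real partitioning lives in internal/utils/citygs_partitioning_utils.py:

import numpy as np

def assign_blocks(xyz, block_dim=(4, 2)):
    """Map each point (Gaussian or camera position) to a block id on an xy grid."""
    xy = xyz[:, :2]
    lo, hi = xy.min(axis=0), xy.max(axis=0)
    # normalize into [0, 1) per axis, then bucket into block_dim cells
    cell = ((xy - lo) / (hi - lo + 1e-8) * np.array(block_dim)).astype(int)
    cell = np.minimum(cell, np.array(block_dim) - 1)
    return cell[:, 0] * block_dim[1] + cell[:, 1]

# e.g. 10k random points -> block ids in [0, 8) for a 4x2 grid
ids = assign_blocks(np.random.rand(10000, 3))
assert ids.min() >= 0 and ids.max() < 8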
block_dim: # removed z dimension 24 | - 4 25 | - 2 -------------------------------------------------------------------------------- /configs/citygsv2_lfls_coarse_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.gaussian_2d.Gaussian2D 4 | init_args: 5 | sh_degree: 2 6 | metric: 7 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 8 | init_args: 9 | lambda_normal: 0.0125 10 | depth_loss_type: l1+ssim 11 | depth_loss_ssim_weight: 1.0 12 | depth_loss_weight: 13 | init: 0.5 14 | final_factor: 0.005 15 | renderer: 16 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 17 | init_args: 18 | depth_ratio: 1.0 19 | diable_trimming: true 20 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 21 | trainer: 22 | check_val_every_n_epoch: 20 23 | max_steps: 30000 24 | data: 25 | path: data/GauU_Scene/LFLS 26 | parser: 27 | class_path: EstimatedDepthBlockColmap 28 | init_args: 29 | split_mode: experiment 30 | eval_image_select_mode: ratio 31 | eval_ratio: 0.1 32 | down_sample_factor: 3.4175 -------------------------------------------------------------------------------- /configs/citygsv2_lfls_sh2_trim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | initialize_from: outputs/citygsv2_lfls_coarse_sh2/checkpoints/epoch=32-step=30000.ckpt 3 | overwrite_config: False 4 | gaussian: 5 | class_path: internal.models.gaussian_2d.Gaussian2D 6 | init_args: 7 | sh_degree: 2 8 | optimization: 9 | means_lr_init: 0.000064 10 | means_lr_scheduler: 11 | lr_final: 0.00000064 12 | scales_lr: 0.004 13 | metric: 14 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 15 | init_args: 16 | lambda_normal: 0.0125 17 | normal_regularization_from_iter: 0 18 | depth_loss_type: l1+ssim 19 | depth_loss_ssim_weight: 1.0 20 | depth_loss_weight: 21 | init: 0.5 22 | final_factor: 0.05 23 | renderer: 24 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 25 | init_args: 26 | depth_ratio: 1.0 27 | prune_ratio: 0.05 28 | density: 29 | class_path: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 30 | init_args: 31 | densify_grad_threshold: 0.0003 32 | trainer: 33 | check_val_every_n_epoch: 20 34 | max_steps: 30000 35 | data: 36 | path: data/GauU_Scene/LFLS 37 | parser: 38 | class_path: EstimatedDepthBlockColmap 39 | init_args: 40 | down_sample_factor: 3.4175 41 | content_threshold: 0.05 42 | block_dim: # removed z dimension 43 | - 4 44 | - 2 -------------------------------------------------------------------------------- /configs/citygsv2_mc_aerial_coarse_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.gaussian_2d.Gaussian2D 4 | init_args: 5 | sh_degree: 2 6 | metric: 7 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 8 | init_args: 9 | lambda_normal: 0.0125 10 | depth_loss_type: l1+ssim 11 | depth_loss_ssim_weight: 1.0 12 | depth_loss_weight: 13 | init: 0.5 14 | final_factor: 0.005 15 | renderer: 16 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 17 | init_args: 18 | depth_ratio: 1.0 19 | diable_trimming: true 20 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 21 | trainer: 22 | check_val_every_n_epoch: 20 23 | max_steps: 30000 24 
| data: 25 | path: data/matrix_city/aerial/train/block_all 26 | parser: 27 | class_path: EstimatedDepthBlockColmap 28 | init_args: 29 | down_sample_factor: 1.2 -------------------------------------------------------------------------------- /configs/citygsv2_mc_aerial_sh2_trim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | initialize_from: outputs/citygsv2_mc_aerial_coarse_sh2/checkpoints/epoch=6-step=30000.ckpt 3 | overwrite_config: False 4 | gaussian: 5 | class_path: internal.models.gaussian_2d.Gaussian2D 6 | init_args: 7 | sh_degree: 2 8 | optimization: 9 | means_lr_init: 0.000064 10 | means_lr_scheduler: 11 | lr_final: 0.00000064 12 | max_steps: 60_000 13 | scales_lr: 0.004 14 | metric: 15 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 16 | init_args: 17 | lambda_normal: 0.0125 18 | normal_regularization_from_iter: 0 19 | depth_loss_type: l1+ssim 20 | depth_loss_ssim_weight: 1.0 21 | depth_loss_weight: 22 | init: 0.5 23 | final_factor: 0.05 24 | max_steps: 60_000 25 | renderer: 26 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 27 | init_args: 28 | depth_ratio: 1.0 29 | density: 30 | class_path: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 31 | init_args: 32 | densification_interval: 200 33 | opacity_reset_interval: 6000 34 | densify_from_iter: 1000 35 | densify_until_iter: 30_000 36 | trainer: 37 | check_val_every_n_epoch: 20 38 | max_steps: 60000 39 | data: 40 | path: data/matrix_city/aerial/train/block_all 41 | parser: 42 | class_path: EstimatedDepthBlockColmap 43 | init_args: 44 | down_sample_factor: 1.2 45 | content_threshold: 0.05 46 | block_dim: 47 | - 4 48 | - 4 49 | save_iterations: 50 | - 30000 51 | - 60000 -------------------------------------------------------------------------------- /configs/citygsv2_mc_street_coarse_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.gaussian_2d.Gaussian2D 4 | init_args: 5 | sh_degree: 2 6 | optimization: 7 | means_lr_init: 1.6e-5 8 | means_lr_scheduler: 9 | lr_final: 1.6e-6 10 | max_steps: 30_000 11 | scales_lr: 0.001 12 | metric: 13 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 14 | init_args: 15 | lambda_normal: 0.0125 16 | depth_normalized: true 17 | depth_loss_weight: 18 | init: 1.0 19 | final_factor: 0.1 20 | renderer: 21 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 22 | init_args: 23 | depth_ratio: 1.0 24 | diable_trimming: true 25 | density: 26 | class_path: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 27 | init_args: 28 | densification_interval: 1000 29 | opacity_reset_interval: 20000 30 | densify_from_iter: 4000 31 | densify_grad_threshold: 0.00005 32 | trainer: 33 | check_val_every_n_epoch: 20 34 | max_steps: 30000 35 | data: 36 | path: data/matrix_city/street/train/block_A 37 | parser: 38 | class_path: EstimatedDepthBlockColmap 39 | init_args: 40 | down_sample_factor: 1 41 | depth_scale_lower_bound: 0.01 42 | depth_scale_upper_bound: 50.0 -------------------------------------------------------------------------------- /configs/citygsv2_mc_street_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | initialize_from: outputs/citygsv2_mc_street_coarse_sh2/checkpoints/epoch=8-step=30000.ckpt 3 | overwrite_config: False 4 | gaussian: 5 | 
class_path: internal.models.gaussian_2d.Gaussian2D 6 | init_args: 7 | sh_degree: 2 8 | optimization: 9 | means_lr_init: 0.8e-5 10 | means_lr_scheduler: 11 | lr_final: 0.8e-6 12 | max_steps: 60_000 13 | scales_lr: 0.0025 14 | metric: 15 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 16 | init_args: 17 | lambda_normal: 0.0125 18 | normal_regularization_from_iter: 0 19 | depth_normalized: true 20 | depth_loss_type: l1+ssim 21 | depth_loss_ssim_weight: 1.0 22 | depth_loss_weight: 23 | init: 0.5 24 | final_factor: 0.05 25 | max_steps: 60_000 26 | renderer: 27 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 28 | init_args: 29 | depth_ratio: 1.0 30 | diable_trimming: true 31 | density: 32 | class_path: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 33 | init_args: 34 | densification_interval: 1000 35 | opacity_reset_interval: 30_000 36 | densify_from_iter: 500 37 | densify_until_iter: 30_000 38 | densify_grad_threshold: 0.000075 39 | trainer: 40 | check_val_every_n_epoch: 20 41 | max_steps: 60000 42 | data: 43 | path: data/matrix_city/street/train/block_A 44 | parser: 45 | class_path: EstimatedDepthBlockColmap 46 | init_args: 47 | down_sample_factor: 1 48 | depth_scale_lower_bound: 0.01 49 | depth_scale_upper_bound: 50.0 50 | content_threshold: 0.01 51 | block_dim: 52 | - 5 53 | - 4 54 | aabb: 55 | - -600 56 | - -400 57 | - -300 58 | - -200 59 | save_iterations: 60 | - 30000 61 | - 60000 -------------------------------------------------------------------------------- /configs/citygsv2_smbu_coarse_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.gaussian_2d.Gaussian2D 4 | init_args: 5 | sh_degree: 2 6 | metric: 7 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 8 | init_args: 9 | lambda_normal: 0.0125 10 | depth_loss_type: l1+ssim 11 | depth_loss_ssim_weight: 1.0 12 | depth_loss_weight: 13 | init: 0.5 14 | final_factor: 0.005 15 | renderer: 16 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 17 | init_args: 18 | depth_ratio: 1.0 19 | diable_trimming: true 20 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 21 | trainer: 22 | check_val_every_n_epoch: 20 23 | max_steps: 30000 24 | data: 25 | path: data/GauU_Scene/SMBU 26 | parser: 27 | class_path: EstimatedDepthBlockColmap 28 | init_args: 29 | split_mode: experiment 30 | eval_image_select_mode: ratio 31 | eval_ratio: 0.1 32 | down_sample_factor: 3.4175 -------------------------------------------------------------------------------- /configs/citygsv2_smbu_sh2_trim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | initialize_from: outputs/citygsv2_smbu_coarse_sh2/checkpoints/epoch=60-step=30000.ckpt 3 | overwrite_config: False 4 | gaussian: 5 | class_path: internal.models.gaussian_2d.Gaussian2D 6 | init_args: 7 | sh_degree: 2 8 | optimization: 9 | means_lr_init: 0.000064 10 | means_lr_scheduler: 11 | lr_final: 0.00000064 12 | scales_lr: 0.004 13 | metric: 14 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 15 | init_args: 16 | lambda_normal: 0.0125 17 | normal_regularization_from_iter: 0 18 | depth_loss_type: l1+ssim 19 | depth_loss_ssim_weight: 1.0 20 | depth_loss_weight: 21 | init: 0.5 22 | final_factor: 0.05 23 | renderer: 24 | class_path: 
internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 25 | init_args: 26 | depth_ratio: 1.0 27 | prune_ratio: 0.05 28 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 29 | trainer: 30 | check_val_every_n_epoch: 20 31 | max_steps: 30000 32 | data: 33 | path: data/GauU_Scene/SMBU 34 | parser: 35 | class_path: EstimatedDepthBlockColmap 36 | init_args: 37 | down_sample_factor: 3.4175 38 | content_threshold: 0.05 39 | block_dim: # removed z dimension 40 | - 3 41 | - 3 -------------------------------------------------------------------------------- /configs/citygsv2_upper_coarse_sh2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.gaussian_2d.Gaussian2D 4 | init_args: 5 | sh_degree: 2 6 | metric: 7 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 8 | init_args: 9 | lambda_normal: 0.0125 10 | depth_loss_type: l1+ssim 11 | depth_loss_ssim_weight: 1.0 12 | depth_loss_weight: 13 | init: 0.5 14 | final_factor: 0.005 15 | renderer: 16 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 17 | init_args: 18 | depth_ratio: 1.0 19 | diable_trimming: true 20 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 21 | trainer: 22 | check_val_every_n_epoch: 20 23 | max_steps: 30000 24 | data: 25 | path: data/GauU_Scene/CUHK_UPPER_COLMAP 26 | parser: 27 | class_path: EstimatedDepthBlockColmap 28 | init_args: 29 | split_mode: experiment 30 | eval_image_select_mode: ratio 31 | eval_ratio: 0.1 32 | down_sample_factor: 3.4175 -------------------------------------------------------------------------------- /configs/citygsv2_upper_sh2_trim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | initialize_from: outputs/citygsv2_upper_coarse_sh2/checkpoints/epoch=48-step=30000.ckpt 3 | overwrite_config: False 4 | gaussian: 5 | class_path: internal.models.gaussian_2d.Gaussian2D 6 | init_args: 7 | sh_degree: 2 8 | optimization: 9 | means_lr_init: 0.000064 10 | means_lr_scheduler: 11 | lr_final: 0.00000064 12 | scales_lr: 0.004 13 | metric: 14 | class_path: internal.metrics.citygsv2_metrics.CityGSV2Metrics 15 | init_args: 16 | lambda_normal: 0.0125 17 | normal_regularization_from_iter: 0 18 | depth_loss_type: l1+ssim 19 | depth_loss_ssim_weight: 1.0 20 | depth_loss_weight: 21 | init: 0.5 22 | final_factor: 0.05 23 | renderer: 24 | class_path: internal.renderers.sep_depth_trim_2dgs_renderer.SepDepthTrim2DGSRenderer 25 | init_args: 26 | depth_ratio: 1.0 27 | prune_ratio: 0.05 28 | density: internal.density_controllers.citygsv2_density_controller.CityGSV2DensityController 29 | trainer: 30 | check_val_every_n_epoch: 20 31 | max_steps: 30000 32 | data: 33 | path: data/GauU_Scene/CUHK_UPPER_COLMAP 34 | parser: 35 | class_path: EstimatedDepthBlockColmap 36 | init_args: 37 | down_sample_factor: 3.4175 38 | content_threshold: 0.05 39 | block_dim: # removed z dimension 40 | - 3 41 | - 3 -------------------------------------------------------------------------------- /configs/colmap_exp.yaml: -------------------------------------------------------------------------------- 1 | logger: wandb 2 | data: 3 | parser: 4 | class_path: Colmap 5 | init_args: 6 | split_mode: "experiment" -------------------------------------------------------------------------------- /configs/ddp.yaml: 
-------------------------------------------------------------------------------- 1 | strategy: 2 | class_path: lightning.pytorch.strategies.ddp.DDPStrategy 3 | init_args: 4 | find_unused_parameters: true 5 | devices: -1 6 | -------------------------------------------------------------------------------- /configs/ddp_not_find_unused.yaml: -------------------------------------------------------------------------------- 1 | strategy: 2 | class_path: lightning.pytorch.strategies.ddp.DDPStrategy 3 | init_args: 4 | find_unused_parameters: false 5 | devices: -1 6 | -------------------------------------------------------------------------------- /configs/deformable_6dof_blender.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 40000 2 | trainer: 3 | check_val_every_n_epoch: 10 4 | data: 5 | val_max_num_images_to_cache: -1 6 | test_max_num_images_to_cache: -1 7 | model: 8 | renderer: 9 | class_path: internal.renderers.deformable_renderer.DeformableRenderer 10 | init_args: 11 | deform_network: 12 | tcnn: false 13 | is_6dof: true 14 | time_encoding: 15 | n_frequencies: 6 16 | n_layers: 2 17 | n_neurons: 256 18 | optimization: 19 | enable_ast: false 20 | gaussian: 21 | optimization: 22 | spatial_lr_scale: 5 -------------------------------------------------------------------------------- /configs/deformable_6dof_real.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 40000 2 | data: 3 | val_max_num_images_to_cache: -1 4 | test_max_num_images_to_cache: -1 5 | model: 6 | renderer: 7 | class_path: internal.renderers.deformable_renderer.DeformableRenderer 8 | init_args: 9 | deform_network: 10 | tcnn: false 11 | is_6dof: true 12 | time_encoding: 13 | n_frequencies: 10 14 | n_layers: 0 15 | n_neurons: 0 16 | optimization: 17 | enable_ast: true 18 | gaussian: 19 | optimization: 20 | spatial_lr_scale: 5 -------------------------------------------------------------------------------- /configs/deformable_blender.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 40000 2 | trainer: 3 | check_val_every_n_epoch: 10 4 | data: 5 | val_max_num_images_to_cache: -1 6 | test_max_num_images_to_cache: -1 7 | model: 8 | renderer: 9 | class_path: internal.renderers.deformable_renderer.DeformableRenderer 10 | init_args: 11 | deform_network: 12 | tcnn: false 13 | time_encoding: 14 | n_frequencies: 6 15 | n_layers: 2 16 | n_neurons: 256 17 | optimization: 18 | enable_ast: false 19 | gaussian: 20 | optimization: 21 | spatial_lr_scale: 5 -------------------------------------------------------------------------------- /configs/deformable_blender_rotate_xyz.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 40000 2 | trainer: 3 | check_val_every_n_epoch: 10 4 | data: 5 | val_max_num_images_to_cache: -1 6 | test_max_num_images_to_cache: -1 7 | model: 8 | renderer: 9 | class_path: internal.renderers.deformable_renderer.DeformableRenderer 10 | init_args: 11 | deform_network: 12 | tcnn: false 13 | rotate_xyz: true 14 | time_encoding: 15 | n_frequencies: 6 16 | n_layers: 2 17 | n_neurons: 256 18 | optimization: 19 | enable_ast: false 20 | gaussian: 21 | optimization: 22 | spatial_lr_scale: 5 -------------------------------------------------------------------------------- /configs/deformable_real.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 40000 2 | data: 3 | 
val_max_num_images_to_cache: -1 4 | test_max_num_images_to_cache: -1 5 | model: 6 | renderer: 7 | class_path: internal.renderers.deformable_renderer.DeformableRenderer 8 | init_args: 9 | deform_network: 10 | tcnn: false 11 | #chunk: 16384 # avoid CUDA OOM 12 | time_encoding: 13 | n_frequencies: 10 14 | n_layers: 0 15 | n_neurons: 0 16 | optimization: 17 | enable_ast: true 18 | gaussian: 19 | optimization: 20 | spatial_lr_scale: 5 -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-alternating-l1-accel.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | fused_ssim: true 6 | depth_output_key: inv_depth_alt 7 | renderer: 8 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 9 | init_args: 10 | separate_sh: true 11 | tile_based_culling: true 12 | renderer_output_types: 13 | - rgb 14 | - inv_depth_alt 15 | data: 16 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 17 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-hard_depth-l1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | depth_output_key: hard_inverse_depth 6 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 7 | renderer_output_types: 8 | - rgb 9 | - hard_inverse_depth 10 | data: 11 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-hard_depth-l1_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | depth_loss_type: l1+ssim 6 | depth_output_key: hard_inverse_depth 7 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 8 | renderer_output_types: 9 | - rgb 10 | - hard_inverse_depth 11 | data: 12 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 13 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-l1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 3 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 4 | renderer_output_types: 5 | - rgb 6 | - inverse_depth 7 | data: 8 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 9 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-l1_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | depth_loss_type: l1+ssim 6 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 7 | renderer_output_types: 8 | - rgb 9 
| - inverse_depth 10 | data: 11 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-l2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | depth_loss_type: l2 6 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 7 | renderer_output_types: 8 | - rgb 9 | - inverse_depth 10 | data: 11 | parser: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/depth_regularization/estimated_inverse_depth-normalized-l1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 4 | init_args: 5 | depth_normalized: true 6 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 7 | renderer_output_types: 8 | - rgb 9 | - inverse_depth 10 | data: 11 | parser: 12 | class_path: internal.dataparsers.estimated_depth_colmap_dataparser.EstimatedDepthColmap 13 | init_args: 14 | depth_rescaling: false 15 | cache_all_images: true -------------------------------------------------------------------------------- /configs/distributed-accel.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | strategy: 3 | class_path: internal.mp_strategy.MPStrategy 4 | devices: -1 5 | model: 6 | renderer: 7 | class_path: internal.renderers.gsplat_distributed_renderer.GSplatDistributedRenderer 8 | init_args: 9 | tile_based_culling: true 10 | density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController 11 | metric: 12 | fused_ssim: true -------------------------------------------------------------------------------- /configs/distributed.yaml: -------------------------------------------------------------------------------- 1 | # A simplified version of https://daohanlu.github.io/scaling-up-3dgs/. 2 | # Gaussians are stored, projected, and their colors computed in a distributed manner. 3 | # Rasterization is done locally. 4 | # No Pixel-wise Distribution. 
5 | 6 | trainer: 7 | strategy: 8 | class_path: internal.mp_strategy.MPStrategy 9 | devices: -1 10 | model: 11 | renderer: internal.renderers.gsplat_distributed_renderer.GSplatDistributedRenderer 12 | density: internal.density_controllers.distributed_vanilla_density_controller.DistributedVanillaDensityController -------------------------------------------------------------------------------- /configs/feature_3dgs/lseg-speedup.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 10000 2 | model: 3 | renderer: 4 | class_path: internal.renderers.feature_3dgs_renderer.Feature3DGSRenderer 5 | init_args: 6 | speedup: true 7 | n_feature_dims: 512 8 | metric: internal.metrics.feature_3dgs_metrics.Feature3DGSMetrics 9 | density: internal.density_controllers.static_density_controller.StaticDensityController 10 | data: 11 | parser: 12 | class_path: Feature3DGSColmap 13 | init_args: 14 | feature_dir: "rgb_feature_langseg" 15 | filename_suffix: "_fmap_CxHxW" 16 | filename_include_image_ext: false -------------------------------------------------------------------------------- /configs/feature_3dgs/lseg.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 10000 2 | model: 3 | renderer: 4 | class_path: internal.renderers.feature_3dgs_renderer.Feature3DGSRenderer 5 | init_args: 6 | speedup: false 7 | n_feature_dims: 512 8 | metric: internal.metrics.feature_3dgs_metrics.Feature3DGSMetrics 9 | density: internal.density_controllers.static_density_controller.StaticDensityController 10 | data: 11 | parser: 12 | class_path: Feature3DGSColmap 13 | init_args: 14 | feature_dir: "rgb_feature_langseg" 15 | filename_suffix: "_fmap_CxHxW" 16 | filename_include_image_ext: false -------------------------------------------------------------------------------- /configs/feature_3dgs/sam-speedup.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 10000 2 | model: 3 | renderer: 4 | class_path: internal.renderers.feature_3dgs_renderer.Feature3DGSRenderer 5 | init_args: 6 | speedup: true 7 | n_feature_dims: 256 8 | metric: internal.metrics.feature_3dgs_metrics.Feature3DGSMetrics 9 | density: internal.density_controllers.static_density_controller.StaticDensityController 10 | data: 11 | parser: Feature3DGSColmap -------------------------------------------------------------------------------- /configs/feature_3dgs/sam.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 10000 2 | model: 3 | renderer: 4 | class_path: internal.renderers.feature_3dgs_renderer.Feature3DGSRenderer 5 | init_args: 6 | speedup: false 7 | n_feature_dims: 256 8 | metric: internal.metrics.feature_3dgs_metrics.Feature3DGSMetrics 9 | density: internal.density_controllers.static_density_controller.StaticDensityController 10 | data: 11 | parser: Feature3DGSColmap -------------------------------------------------------------------------------- /configs/fused_ssim.yaml: -------------------------------------------------------------------------------- 1 | # pip install git+https://github.com/rahul-goel/fused-ssim.git@d99e3d27513fa3563d98f74fcd40fd429e9e9b0e 2 | 3 | model: 4 | metric: 5 | init_args: 6 | fused_ssim: true 7 | -------------------------------------------------------------------------------- /configs/gsplat-absgrad-experiment.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | val_max_num_images_to_cache: -1 3 | 
test_max_num_images_to_cache: -1 4 | parser: 5 | class_path: Colmap 6 | init_args: 7 | split_mode: "experiment" 8 | model: 9 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 10 | density: 11 | densify_grad_threshold: 0.0006 12 | absgrad: true -------------------------------------------------------------------------------- /configs/gsplat-absgrad.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | val_max_num_images_to_cache: -1 3 | test_max_num_images_to_cache: -1 4 | model: 5 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 6 | density: 7 | densify_grad_threshold: 0.0006 8 | absgrad: true -------------------------------------------------------------------------------- /configs/gsplat-mcmc.yaml: -------------------------------------------------------------------------------- 1 | # 3D Gaussian Splatting as Markov Chain Monte Carlo 2 | # https://ubc-vision.github.io/3dgs-mcmc/ 3 | model: 4 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 5 | metric: internal.metrics.mcmc_metrics.MCMCMetrics 6 | density: internal.density_controllers.mcmc_density_controller.MCMCDensityController 7 | data: 8 | val_max_num_images_to_cache: -1 9 | test_max_num_images_to_cache: -1 -------------------------------------------------------------------------------- /configs/gsplat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer -------------------------------------------------------------------------------- /configs/gsplat_v1-accel-steerable.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: 3 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 4 | init_args: 5 | separate_sh: true 6 | tile_based_culling: true 7 | metric: 8 | init_args: 9 | fused_ssim: true 10 | density: internal.density_controllers.taming_3dgs_density_controller.Taming3DGSDensityController 11 | -------------------------------------------------------------------------------- /configs/gsplat_v1-accel.yaml: -------------------------------------------------------------------------------- 1 | # acceleration with competitive quality 2 | model: 3 | renderer: 4 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 5 | init_args: 6 | separate_sh: true 7 | tile_based_culling: true 8 | metric: 9 | init_args: 10 | fused_ssim: true -------------------------------------------------------------------------------- /configs/gsplat_v1-accel_more.yaml: -------------------------------------------------------------------------------- 1 | # more acceleration, but slightly lower quality 2 | model: 3 | gaussian: 4 | init_args: 5 | optimization: 6 | optimizer: SelectiveAdam 7 | renderer: 8 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 9 | init_args: 10 | separate_sh: true 11 | tile_based_culling: true 12 | metric: 13 | init_args: 14 | fused_ssim: true -------------------------------------------------------------------------------- /configs/gsplat_v1-tile_based_culling-selective_adam.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | init_args: 4 | optimization: 5 | optimizer: SelectiveAdam 6 | renderer: 7 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 8 | init_args: 9 | tile_based_culling: true 10 | -------------------------------------------------------------------------------- 
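The configs above, like most single-purpose files in `configs/`, are override fragments rather than complete experiment definitions. A minimal launch sketch is shown below; it assumes the repository exposes a LightningCLI-style entry point named `main.py` whose `fit` subcommand accepts repeated `--config` files with later ones taking precedence, and `data/your_scene` is a placeholder dataset path — none of this is verified against the repository.

```bash
# Hedged sketch: compose a renderer/optimizer fragment with a metric fragment.
# `main.py fit`, the override order, and the paths below are assumptions.
python main.py fit \
    --config configs/gsplat_v1-tile_based_culling-selective_adam.yaml \
    --config configs/fused_ssim.yaml \
    --data.path data/your_scene
```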
/configs/gsplat_v1-tile_based_culling.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: 3 | class_path: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer 4 | init_args: 5 | tile_based_culling: true -------------------------------------------------------------------------------- /configs/gsplat_v1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_v1_renderer.GSplatV1Renderer -------------------------------------------------------------------------------- /configs/image_on_gpu-uint8.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | image_on_cpu: false 3 | image_uint8: true -------------------------------------------------------------------------------- /configs/image_on_gpu.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | image_on_cpu: false -------------------------------------------------------------------------------- /configs/larger_dataset.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | parser: 3 | class_path: Colmap 4 | init_args: 5 | split_mode: "reconstruction" 6 | eval_image_select_mode: "ratio" 7 | eval_ratio: 0.01 8 | # https://github.com/graphdeco-inria/gaussian-splatting#faq 9 | # How can I use this for a much larger dataset, like a city district? The current method was not designed for these, but given enough memory, it should work out. However, the approach can struggle in multi-scale detail scenes (extreme close-ups, mixed with far-away shots). This is usually the case in, e.g., driving data sets (cars close up, buildings far away). For such scenes, you can lower the --position_lr_init, --position_lr_final and --scaling_lr (x0.3, x0.1, ...). The more extensive the scene, the lower these values should be. Below, the default learning rates are overridden with --position_lr_init 0.000016 --scaling_lr 0.001. 10 | model: 11 | save_val_output: true 12 | max_save_val_output: 8 13 | gaussian: 14 | optimization: 15 | position_lr_init: 0.000016 16 | scaling_lr: 0.001 -------------------------------------------------------------------------------- /configs/light_gaussian/prune_finetune-gsplat-experiment.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 35_000 2 | logger: wandb 3 | data: 4 | val_max_num_images_to_cache: -1 5 | test_max_num_images_to_cache: -1 6 | parser: 7 | class_path: Colmap 8 | init_args: 9 | split_mode: "experiment" 10 | model: 11 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 12 | gaussian: 13 | optimization: 14 | means_lr_scheduler: 15 | init_args: 16 | max_steps: 35_000 17 | light_gaussian: 18 | prune_steps: 19 | - 30_001 -------------------------------------------------------------------------------- /configs/light_gaussian/prune_finetune-gsplat.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 35_000 2 | data: 3 | val_max_num_images_to_cache: -1 4 | test_max_num_images_to_cache: -1 5 | model: 6 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 7 | gaussian: 8 | optimization: 9 | means_lr_scheduler: 10 | init_args: 11 | max_steps: 35_000 12 | light_gaussian: 13 | prune_steps: 14 | - 30_001 -------------------------------------------------------------------------------- /configs/light_gaussian/prune_finetune.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 35_000 2 | model: 3 | gaussian: 4 | optimization: 5 | means_lr_scheduler: 6 | init_args: 7 | max_steps: 35_000 8 | light_gaussian: 9 | prune_steps: 10 | - 30_001 11 | -------------------------------------------------------------------------------- /configs/light_gaussian/train_densify_prune-gsplat-experiment.yaml: -------------------------------------------------------------------------------- 1 | logger: wandb 2 | data: 3 | val_max_num_images_to_cache: -1 4 | test_max_num_images_to_cache: -1 5 | parser: 6 | class_path: Colmap 7 | init_args: 8 | split_mode: "experiment" 9 | model: 10 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 11 | light_gaussian: 12 | prune_decay: 0.6 13 | prune_percent: 0.6 14 | prune_steps: 15 | - 16_000 16 | - 24_000 -------------------------------------------------------------------------------- /configs/light_gaussian/train_densify_prune-gsplat.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | val_max_num_images_to_cache: -1 3 | test_max_num_images_to_cache: -1 4 | model: 5 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 6 | light_gaussian: 7 | prune_decay: 0.6 8 | prune_percent: 0.6 9 | prune_steps: 10 | - 16_000 11 | - 24_000 -------------------------------------------------------------------------------- /configs/matrixcity/depth-up_background_sphere.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | add_background_sphere: true 3 | background_sphere_distance: 1 4 | background_sphere_points: 204800 5 | background_sphere_color: white 6 | background_sphere_min_altitude: 0. 
7 | parser: 8 | class_path: MatrixCity 9 | init_args: 10 | use_depth: true 11 | model: 12 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 13 | metric: 14 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 15 | renderer_output_types: 16 | - rgb 17 | - inverse_depth 18 | cache_all_images: true -------------------------------------------------------------------------------- /configs/matrixcity/depth.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | parser: 3 | class_path: MatrixCity 4 | init_args: 5 | use_depth: true 6 | model: 7 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 8 | metric: 9 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 10 | renderer_output_types: 11 | - rgb 12 | - inverse_depth 13 | cache_all_images: true -------------------------------------------------------------------------------- /configs/matrixcity/gsplat-aerial.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 60_000 2 | data: 3 | train_max_num_images_to_cache: 4096 # avoid OOM 4 | parser: 5 | class_path: MatrixCity 6 | init_args: 7 | train: 8 | - block_1/transforms_origin.json 9 | - block_2/transforms_origin.json 10 | - block_3/transforms_origin.json 11 | - block_4/transforms_origin.json 12 | - block_5/transforms_origin.json 13 | - block_6/transforms_origin.json 14 | - block_7/transforms_origin.json 15 | - block_8/transforms_origin.json 16 | - block_9/transforms_origin.json 17 | - block_10/transforms_origin.json 18 | test: 19 | - block_1_test/transforms_origin.json 20 | depth_read_step: 4 21 | model: 22 | gaussian: 23 | optimization: 24 | spatial_lr_scale: 0.2 # avoid large xyz learning rate 25 | sh_degree: 0 # avoid CUDA OOM 26 | density: 27 | densify_grad_threshold: 0.0006 # avoid CUDA OOM 28 | densify_until_iter: 30_000 29 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer -------------------------------------------------------------------------------- /configs/matrixcity/gsplat-aerial_street-depth_reg-example.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 60_000 2 | trainer: 3 | limit_val_batches: 64 4 | data: 5 | add_background_sphere: true 6 | background_sphere_distance: 2 7 | background_sphere_points: 204800 8 | background_sphere_color: white 9 | train_max_num_images_to_cache: 4096 # avoid OOM 10 | val_max_num_images_to_cache: 64 11 | parser: 12 | class_path: MatrixCity 13 | init_args: 14 | train: 15 | - aerial/block_1/transforms_origin.json 16 | - street_without_water/small_city_road_down_dense/transforms-0_145.json 17 | test: 18 | - aerial/block_1_test/transforms_origin.json 19 | - street_without_water/small_city_road_down_test/transforms-0_2.json 20 | scale: 0.1 # default is 0.01, remember to rescale `spatial_lr_scale` below by the same factor if you changed this 21 | depth_read_step: 2 22 | use_depth: true # load depth maps into training batches 23 | model: 24 | gaussian: 25 | optimization: 26 | spatial_lr_scale: 2 # avoid large xyz learning rate 27 | density: 28 | densification_interval: 200 29 | densify_until_iter: 30_000 30 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 31 | metric: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics # depth regularization metrics 32 | renderer_output_types: 33 | - rgb 34 | - inverse_depth # predict depth map -------------------------------------------------------------------------------- 
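The `scale` comment in the config above couples the dataset rescaling to `spatial_lr_scale`. A hedged command-line sketch of that coupling follows, assuming the same `main.py` LightningCLI-style entry point and that dotted overrides mirror the YAML layout (both unverified assumptions): halving `scale` from 0.1 to 0.05 means halving `spatial_lr_scale` from 2 to 1.

```bash
# Hedged sketch: keep `spatial_lr_scale` proportional to the dataset `scale`.
# Here both are halved together (scale 0.1 -> 0.05, spatial_lr_scale 2 -> 1).
python main.py fit \
    --config configs/matrixcity/gsplat-aerial_street-depth_reg-example.yaml \
    --data.parser.init_args.scale 0.05 \
    --model.gaussian.optimization.spatial_lr_scale 1
```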
/configs/matrixcity/gsplat-aerial_street-example.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 60_000 2 | data: 3 | train_max_num_images_to_cache: 4096 # avoid OOM 4 | parser: 5 | class_path: MatrixCity 6 | init_args: 7 | train: 8 | - aerial/block_1/transforms_origin.json 9 | - street/small_city_road_down/transforms-0_59.json 10 | test: 11 | - aerial/block_1_test/transforms_origin.json 12 | - street/small_city_road_down_test/transforms-0_2.json 13 | scale: 0.1 # default is 0.01, remember to rescale `spatial_lr_scale` below by the same factor if you change this 14 | depth_read_step: 4 15 | model: 16 | gaussian: 17 | optimization: 18 | spatial_lr_scale: 2 # avoid large xyz learning rate 19 | sh_degree: 0 # avoid CUDA OOM 20 | density: 21 | densify_grad_threshold: 0.0006 # avoid CUDA OOM 22 | densify_until_iter: 30_000 23 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer -------------------------------------------------------------------------------- /configs/matrixcity/hard_depth.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | parser: 3 | class_path: MatrixCity 4 | init_args: 5 | use_depth: true 6 | model: 7 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 8 | metric: 9 | class_path: internal.metrics.inverse_depth_metrics.HasInverseDepthMetrics 10 | init_args: 11 | depth_output_key: hard_inverse_depth 12 | renderer_output_types: 13 | - rgb 14 | - hard_inverse_depth 15 | cache_all_images: true -------------------------------------------------------------------------------- /configs/mcmc.yaml: -------------------------------------------------------------------------------- 1 | # 3D Gaussian Splatting as Markov Chain Monte Carlo 2 | # https://ubc-vision.github.io/3dgs-mcmc/ 3 | model: 4 | metric: internal.metrics.mcmc_metrics.MCMCMetrics 5 | density: internal.density_controllers.mcmc_density_controller.MCMCDensityController 6 | data: 7 | val_max_num_images_to_cache: -1 8 | test_max_num_images_to_cache: -1 -------------------------------------------------------------------------------- /configs/mip_splatting_gsplat_v2-blender.yaml: -------------------------------------------------------------------------------- 1 | # NOTE: the quality of v2 and v1 is identical 2 | trainer: 3 | check_val_every_n_epoch: 10 4 | data: 5 | val_max_num_images_to_cache: -1 6 | test_max_num_images_to_cache: -1 7 | image_on_cpu: false 8 | model: 9 | gaussian: internal.models.mip_splatting.MipSplatting 10 | renderer: internal.renderers.gsplat_mip_splatting_renderer_v2.GSplatMipSplattingRendererV2 -------------------------------------------------------------------------------- /configs/mip_splatting_gsplat_v2.yaml: -------------------------------------------------------------------------------- 1 | # NOTE: the quality of v2 and v1 is identical 2 | model: 3 | gaussian: internal.models.mip_splatting.MipSplatting 4 | renderer: internal.renderers.gsplat_mip_splatting_renderer_v2.GSplatMipSplattingRendererV2 -------------------------------------------------------------------------------- /configs/pvg_dynamic.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.periodic_vibration_gaussian.PeriodicVibrationGaussian 4 | init_args: 5 | optimization: 6 | opacities_lr: 0.005 7 | renderer: 8 | class_path: internal.renderers.periodic_vibration_gaussian_renderer.PeriodicVibrationGaussianRenderer 9 | init_args: 10 
| env_map_res: -1 11 | lambda_self_supervision: -1 12 | metric: internal.metrics.pvg_dynamic_metrics.PVGDynamicMetrics -------------------------------------------------------------------------------- /configs/pypreprocess_gsplat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.pypreprocess_gsplat_renderer.PythonPreprocessGSplatRenderer -------------------------------------------------------------------------------- /configs/random_background.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | random_background: true -------------------------------------------------------------------------------- /configs/reorient.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | parser: 3 | class_path: Colmap 4 | init_args: 5 | reorient: true -------------------------------------------------------------------------------- /configs/scale_reg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.scale_regularization_metrics.ScaleRegularizationMetrics 4 | init_args: 5 | scale_reg_from: 3300 6 | max_scale: -1 7 | max_scale_ratio: 10. 8 | -------------------------------------------------------------------------------- /configs/segany_splatting.yaml: -------------------------------------------------------------------------------- 1 | max_steps: 10000 2 | data: 3 | parser: SegAnyColmap 4 | val_max_num_images_to_cache: -1 5 | test_max_num_images_to_cache: -1 -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-cluster.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | lambda_dssim: 0. 9 | cluster: true 10 | data: 11 | parser: 12 | class_path: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 13 | init_args: 14 | cluster: true 15 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-high_occlusion.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | lower_bound: 0.3 9 | upper_bound: 0.8 10 | lambda_dssim: 0. 
11 | data: 12 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 13 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-mask_size_400-with_ssim-opacity_reg_0.01.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | max_mlp_mask_size: 400 9 | opacity_reg: 0.01 10 | data: 11 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-mask_size_400-with_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | max_mlp_mask_size: 400 9 | data: 10 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 11 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-mask_size_400.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | max_mlp_mask_size: 400 9 | lambda_dssim: 0. 10 | data: 11 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-opacity_reg_0.01.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | lambda_dssim: 0. 
9 | opacity_reg: 0.01 10 | data: 11 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 12 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp-with_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | data: 8 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 9 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/gsplat-mlp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.gsplat_renderer.GSPlatRenderer 3 | density: 4 | opacity_reset_interval: 999999999 # no reset 5 | metric: 6 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 7 | init_args: 8 | lambda_dssim: 0. 9 | data: 10 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 11 | cache_all_images: true -------------------------------------------------------------------------------- /configs/spot_less_splats/mlp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | density: 3 | opacity_reset_interval: 999999999 # no reset 4 | metric: 5 | class_path: internal.metrics.spotless_metrics.SpotLessMetrics 6 | init_args: 7 | lambda_dssim: 0. 8 | data: 9 | parser: internal.dataparsers.spotless_colmap_dataparser.SpotLessColmap 10 | cache_all_images: true -------------------------------------------------------------------------------- /configs/stp/baseline.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: internal.renderers.stp_renderer.STPRenderer 3 | cache_all_images: true -------------------------------------------------------------------------------- /configs/swag_baseline.yaml: -------------------------------------------------------------------------------- 1 | # SWAG: Splatting in the Wild images with Appearance-conditioned Gaussians 2 | # [NOTE] This is not an official implementation and cannot reach the metrics reported in the paper 3 | data: 4 | parser: 5 | class_path: PhotoTourism 6 | init_args: 7 | split_mode: experiment 8 | down_sample_factor: 2 9 | model: 10 | renderer: 11 | class_path: internal.renderers.swag_renderer.SWAGRenderer 12 | init_args: 13 | embedding: 14 | num_embeddings: 1536 # make sure this value is larger than max_image_id+1 (or max_image_num+1) -------------------------------------------------------------------------------- /configs/taming_3dgs/fused_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | metric: 3 | class_path: internal.metrics.vanilla_with_fused_ssim_metrics.VanillaWithFusedSSIMMetrics 4 | -------------------------------------------------------------------------------- /configs/taming_3dgs/rasterizer-fused_ssim-aa.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: 3 | class_path: internal.renderers.taming_3dgs_renderer.Taming3DGSRenderer 4 | init_args: 5 | anti_aliased: true 6 | metric: 7 | class_path: internal.metrics.vanilla_with_fused_ssim_metrics.VanillaWithFusedSSIMMetrics 8 | 
-------------------------------------------------------------------------------- /configs/taming_3dgs/rasterizer-fused_ssim-sparse_adam-aa.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.sparse_adam_gaussian.VanillaGaussianWithSparseAdam 4 | renderer: 5 | class_path: internal.renderers.taming_3dgs_renderer.Taming3DGSRenderer 6 | init_args: 7 | anti_aliased: true 8 | metric: 9 | class_path: internal.metrics.vanilla_with_fused_ssim_metrics.VanillaWithFusedSSIMMetrics 10 | -------------------------------------------------------------------------------- /configs/taming_3dgs/rasterizer-fused_ssim-sparse_adam.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: 3 | class_path: internal.models.sparse_adam_gaussian.VanillaGaussianWithSparseAdam 4 | renderer: 5 | class_path: internal.renderers.taming_3dgs_renderer.Taming3DGSRenderer 6 | metric: 7 | class_path: internal.metrics.vanilla_with_fused_ssim_metrics.VanillaWithFusedSSIMMetrics 8 | -------------------------------------------------------------------------------- /configs/taming_3dgs/rasterizer-fused_ssim.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: 3 | class_path: internal.renderers.taming_3dgs_renderer.Taming3DGSRenderer 4 | metric: 5 | class_path: internal.metrics.vanilla_with_fused_ssim_metrics.VanillaWithFusedSSIMMetrics 6 | -------------------------------------------------------------------------------- /configs/taming_3dgs/rasterizer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | renderer: 3 | class_path: internal.renderers.taming_3dgs_renderer.Taming3DGSRenderer 4 | -------------------------------------------------------------------------------- /configs/vanilla_2dgs.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | gaussian: internal.models.gaussian_2d.Gaussian2D 3 | renderer: internal.renderers.vanilla_2dgs_renderer.Vanilla2DGSRenderer 4 | metric: internal.metrics.gs2d_metrics.GS2DMetrics 5 | density: 6 | class_path: internal.density_controllers.gs2d_density_controller.GS2DDensityController 7 | init_args: 8 | cull_opacity_threshold: 0.05 9 | data: 10 | val_max_num_images_to_cache: -1 11 | test_max_num_images_to_cache: -1 -------------------------------------------------------------------------------- /doc/installation.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | ### A. Clone repository 3 | 4 | ```bash 5 | # clone repository 6 | git clone https://github.com/DekuLiuTesla/CityGaussian.git 7 | cd CityGaussian 8 | ``` 9 | 10 | ### B. Create virtual environment 11 | 12 | ```bash 13 | # create virtual environment 14 | conda create -yn gspl python=3.9 pip 15 | conda activate gspl 16 | ``` 17 | 18 | ### C. Install PyTorch 19 | * Tested on `PyTorch==2.0.1` 20 | * You must install the build that matches your nvcc version (`nvcc --version`) 21 | * For CUDA 11.8 22 | 23 | ```bash 24 | pip install -r requirements/pyt201_cu118.txt 25 | ``` 26 | 27 | ### D. Install requirements 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ### E. 
Install additional packages for CityGaussian 34 | 35 | ```bash 36 | pip install -r requirements/CityGS.txt 37 | ``` 38 | Note that we use a modified version of the Trim2DGS rasterizer to resolve the [impulse noise problem](https://github.com/hbb1/2d-gaussian-splatting/issues/174) under street views. This version also avoids interference from out-of-view surfels. -------------------------------------------------------------------------------- /doc/render_video.md: -------------------------------------------------------------------------------- 1 | ## Render Video 2 | ### A. Generate trajectory, filter out floaters, and render a video 3 | ```bash 4 | python tools/render_traj.py --output_path outputs/$NAME --filter --train 5 | ``` 6 | The script will generate and save an elliptical trajectory based on the training cameras. `--filter` filters out floaters in each BEV pillar according to their spatial distribution, and `--train` uses the training cameras to generate the trajectory. Please refer to the script for more control options. 7 | 8 | ### B. Render mesh on the appointed trajectory with Blender 9 | First, follow the instructions [here](blender/README.md) to install the Blender environment. For video rendering, use the following command: 10 | ```bash 11 | cd blender 12 | python render_run.py --load_path --traj_path --config_dir 13 | ``` 14 | By setting `traj_path`, you can apply the trajectory generated in the previous step. The `config_dir` presets for GauU-Scene and MatrixCity are provided in `blender/render_cfgs`. -------------------------------------------------------------------------------- /internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/__init__.py -------------------------------------------------------------------------------- /internal/cameras/__init__.py: -------------------------------------------------------------------------------- 1 | from .cameras import Camera, Cameras 2 | -------------------------------------------------------------------------------- /internal/configs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !__init__.py 4 | !appearance.py 5 | !dataset.py 6 | !model.py 7 | !optimization.py 8 | !tcnn_encoding_config.py 9 | !light_gaussian.py 10 | !segany_splatting.py 11 | !instantiate_config.py -------------------------------------------------------------------------------- /internal/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/configs/__init__.py -------------------------------------------------------------------------------- /internal/configs/appearance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class AppearanceModelOptimizationParams: 6 | lr: float = 1e-3 7 | eps: float = 1e-15 8 | gamma: float = 1 9 | max_steps: int = 30_000 10 | 11 | 12 | @dataclass 13 | class AppearanceModelParams: 14 | optimization: AppearanceModelOptimizationParams 15 | 16 | n_grayscale_factors: int = 3 17 | n_gammas: int = 3 18 | n_neurons: int = 32 19 | n_hidden_layers: int = 2 20 | n_frequencies: int = 4 21 | grayscale_factors_activation: str = "Sigmoid" 22 | gamma_activation: str = "Softplus" 23 | 24 
@dataclass 25 | class SwagAppearanceModelParams: 26 | optimization: AppearanceModelOptimizationParams 27 | 28 | n_appearance_count: int = 6000 29 | n_appearance_dims: int = 24 30 | n_input_dims: int = 30 31 | n_neurons: int = 64 32 | n_hidden_layers: int = 3 33 | color_activation: str = "Sigmoid" 34 | 35 | @dataclass 36 | class VastAppearanceModelParams: 37 | optimization: AppearanceModelOptimizationParams 38 | 39 | n_appearance_count: int = 6000 40 | n_appearance_dims: int = 64 41 | n_rgb_dims: int = 3 42 | std: float = 1e-4 -------------------------------------------------------------------------------- /internal/configs/instantiate_config.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Any 2 | 3 | 4 | class InstantiatableConfig: 5 | def instantiate(self, *args, **kwargs) -> Any: 6 | pass 7 | -------------------------------------------------------------------------------- /internal/configs/light_gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class LightGaussian: 7 | prune_steps: List[int] = field(default_factory=lambda: []) 8 | prune_decay: float = 1. 9 | prune_percent: float = 0.66 10 | prune_type: Literal["v_important_score"] = "v_important_score" 11 | v_pow: float = 0.1 12 | -------------------------------------------------------------------------------- /internal/configs/model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from dataclasses import dataclass 3 | from internal.configs.optimization import OptimizationParams 4 | 5 | 6 | @dataclass 7 | class ModelParams: 8 | optimization: OptimizationParams 9 | sh_degree: int = 3 10 | extra_feature_dims: int = 0 11 | -------------------------------------------------------------------------------- /internal/configs/optimization.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Literal 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class OptimizationParams: 7 | position_lr_init: float = 0.00016 8 | position_lr_final: float = 0.0000016 9 | position_lr_delay_mult: float = 0.01 10 | position_lr_max_steps: float = 30_000 11 | feature_lr: float = 0.0025 12 | feature_rest_lr_init: float = 0.0025 / 20. 13 | feature_rest_lr_final_factor: float = 0.1 14 | feature_rest_lr_max_steps: int = -1 15 | feature_extra_lr_init: float = 1e-3 16 | feature_extra_lr_final_factor: float = 0.1 17 | feature_extra_lr_max_steps: int = 30_000 18 | opacity_lr: float = 0.05 19 | scaling_lr: float = 0.005 20 | rotation_lr: float = 0.001 21 | 22 | spatial_lr_scale: float = -1 # auto-calculated from camera poses if <= 0 23 | -------------------------------------------------------------------------------- /internal/configs/segany_splatting.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Optimization: 6 | lr: float = 0.0025 7 | lr_final_factor: float = 1.
8 | -------------------------------------------------------------------------------- /internal/dataparsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataparser import DataParserConfig, DataParser, ImageSet, DataParserOutputs 2 | 3 | # imported to allow using shorter names for `--data.parser` 4 | from .colmap_dataparser import Colmap 5 | from .blender_dataparser import Blender 6 | from .nsvf_dataparser import NSVF 7 | from .nerfies_dataparser import Nerfies 8 | from .matrix_city_dataparser import MatrixCity 9 | from .phototourism_dataparser import PhotoTourism 10 | from .segany_colmap_dataparser import SegAnyColmap 11 | from .feature_3dgs_dataparser import Feature3DGSColmap 12 | from .colmap_block_dataparser import ColmapBlock 13 | from .estimated_depth_colmap_block_dataparser import EstimatedDepthBlockColmap 14 | -------------------------------------------------------------------------------- /internal/dataparsers/feature_3dgs_dataparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from . import DataParserOutputs, DataParser 6 | from .colmap_dataparser import Colmap, ColmapDataParser 7 | 8 | 9 | @dataclass 10 | class Feature3DGSColmap(Colmap): 11 | feature_dir: str = "semantic/sam_features" 12 | 13 | filename_suffix: str = "" 14 | 15 | filename_include_image_ext: bool = True 16 | 17 | def instantiate(self, path: str, output_path: str, global_rank: int) -> DataParser: 18 | return Feature3DGSColmapDataParser(path=path, output_path=output_path, global_rank=global_rank, params=self) 19 | 20 | 21 | class Feature3DGSColmapDataParser(ColmapDataParser): 22 | def __init__(self, path: str, output_path: str, global_rank: int, params: Feature3DGSColmap) -> None: 23 | super().__init__(path, output_path, global_rank, params) 24 | 25 | def get_outputs(self) -> DataParserOutputs: 26 | dataparser_outputs = super().get_outputs() 27 | 28 | # val_set and test_set are the same object 29 | for image_set in [dataparser_outputs.train_set, dataparser_outputs.val_set]: 30 | for idx, image_name in enumerate(image_set.image_names): 31 | if self.params.filename_include_image_ext is False: 32 | image_name = image_name[:image_name.rfind(".")] 33 | semantic_file_name = f"{image_name}{self.params.filename_suffix}.pt" 34 | image_set.extra_data[idx] = os.path.join(self.path, self.params.feature_dir, semantic_file_name) 35 | image_set.extra_data_processor = Feature3DGSColmapDataParser.read_semantic_data 36 | 37 | # remove image paths to avoid caching 38 | for i in [dataparser_outputs.train_set, dataparser_outputs.val_set, dataparser_outputs.test_set]: 39 | for j in range(len(i.image_paths)): 40 | i.image_paths[j] = None 41 | 42 | return dataparser_outputs 43 | 44 | @staticmethod 45 | def read_semantic_data(path): 46 | return torch.load(path, map_location="cpu") 47 | -------------------------------------------------------------------------------- /internal/dataparsers/segany_colmap_dataparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from dataclasses import dataclass 5 | from .
import DataParserOutputs, DataParser 6 | from .colmap_dataparser import Colmap, ColmapDataParser 7 | 8 | 9 | @dataclass 10 | class SegAnyColmap(Colmap): 11 | semantic_mask_dir: str = "semantic/masks" 12 | 13 | semantic_scale_dir: str = "semantic/scales" 14 | 15 | def instantiate(self, path: str, output_path: str, global_rank: int) -> DataParser: 16 | return SegAnyColmapDataParser(path, output_path, global_rank, self) 17 | 18 | 19 | class SegAnyColmapDataParser(ColmapDataParser): 20 | def __init__(self, path: str, output_path: str, global_rank: int, params: SegAnyColmap) -> None: 21 | super().__init__(path, output_path, global_rank, params) 22 | 23 | def get_outputs(self) -> DataParserOutputs: 24 | dataparser_outputs = super().get_outputs() 25 | 26 | # val_set and test_set are the same object 27 | for image_set in [dataparser_outputs.train_set, dataparser_outputs.val_set]: 28 | for idx, image_name in enumerate(image_set.image_names): 29 | semantic_file_name = f"{image_name}.pt" 30 | image_set.extra_data[idx] = ( 31 | os.path.join(self.path, self.params.semantic_mask_dir, semantic_file_name), 32 | os.path.join(self.path, self.params.semantic_scale_dir, semantic_file_name), 33 | ) 34 | image_set.extra_data_processor = SegAnyColmapDataParser.read_semantic_data 35 | 36 | # remove image paths to avoid caching 37 | for i in [dataparser_outputs.train_set, dataparser_outputs.val_set, dataparser_outputs.test_set]: 38 | for j in range(len(i.image_paths)): 39 | i.image_paths[j] = None 40 | 41 | return dataparser_outputs 42 | 43 | @staticmethod 44 | def read_semantic_data(paths): 45 | mask_path, scale_path = paths 46 | return torch.load(mask_path, map_location="cpu"), torch.load(scale_path, map_location="cpu") 47 | -------------------------------------------------------------------------------- /internal/density_controllers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/density_controllers/__init__.py -------------------------------------------------------------------------------- /internal/density_controllers/accurate_visibility_filter_density_controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | pip install git+https://github.com/yzslab/gsplat.git@accurate_visibility_filter 3 | 4 | Get visibility filter from rasterization instead of projection. 5 | This filter is more accurate, and can improve the evaluation metrics a little bit.
6 | """ 7 | 8 | from dataclasses import dataclass 9 | import torch 10 | from .vanilla_density_controller import VanillaDensityController, VanillaDensityControllerImpl 11 | 12 | 13 | @dataclass 14 | class AccurateVisibilityFilterDensityController(VanillaDensityController): 15 | def instantiate(self, *args, **kwargs) -> "AccurateVisibilityFilterDensityControllerModule": 16 | return AccurateVisibilityFilterDensityControllerModule(self) 17 | 18 | 19 | class AccurateVisibilityFilterDensityControllerModule(VanillaDensityControllerImpl): 20 | def update_states(self, outputs): 21 | viewspace_point_tensor, radii = outputs["viewspace_points"], outputs["radii"] 22 | visibility_filter = viewspace_point_tensor.has_hit_any_pixels 23 | # retrieve viewspace_points_grad_scale if provided 24 | viewspace_points_grad_scale = outputs.get("viewspace_points_grad_scale", None) 25 | 26 | # update states 27 | self.max_radii2D[visibility_filter] = torch.max( 28 | self.max_radii2D[visibility_filter], 29 | radii[visibility_filter] 30 | ) 31 | xys_grad = viewspace_point_tensor.grad 32 | if self.config.absgrad is True: 33 | xys_grad = viewspace_point_tensor.absgrad 34 | self._add_densification_stats(xys_grad, visibility_filter, scale=viewspace_points_grad_scale) 35 | -------------------------------------------------------------------------------- /internal/density_controllers/distributed_vanilla_density_controller.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Union, List 2 | from dataclasses import dataclass 3 | import torch 4 | from lightning import LightningModule 5 | 6 | from .vanilla_density_controller import VanillaDensityController, VanillaDensityControllerImpl 7 | 8 | 9 | @dataclass 10 | class DistributedVanillaDensityController(VanillaDensityController): 11 | def instantiate(self, *args, **kwargs) -> "DistributedVanillaDensityControllerImpl": 12 | return DistributedVanillaDensityControllerImpl(self) 13 | 14 | 15 | class DistributedVanillaDensityControllerImpl(VanillaDensityControllerImpl): 16 | def before_backward(self, outputs: dict, batch, gaussian_model, optimizers: List, global_step: int, pl_module: LightningModule) -> None: 17 | if global_step >= self.config.densify_until_iter: 18 | return 19 | 20 | for i in outputs["projection_results_list"]: 21 | i[1].retain_grad() 22 | 23 | def update_states(self, outputs): 24 | cameras = outputs["cameras"] 25 | projection_results_list = outputs["projection_results_list"] 26 | visible_mask_list = outputs["visible_mask_list"] 27 | # processing for each projection results 28 | for i in range(len(projection_results_list)): 29 | # retrieve data 30 | camera = cameras[i] 31 | radii, xys = projection_results_list[i][0], projection_results_list[i][1] 32 | 33 | viewspace_point_tensor = xys 34 | visibility_filter = visible_mask_list[i] 35 | viewspace_points_grad_scale = torch.ones((2,), dtype=torch.float, device=xys.device) 36 | if outputs["xys_grad_scale_required"] is True: 37 | viewspace_points_grad_scale = 0.5 * torch.tensor([[camera.width, camera.height]], dtype=torch.float, device=xys.device) 38 | 39 | # update states 40 | self.max_radii2D[visibility_filter] = torch.max( 41 | self.max_radii2D[visibility_filter], 42 | radii[visibility_filter] 43 | ) 44 | xys_grad = viewspace_point_tensor.grad 45 | if self.config.absgrad is True: 46 | xys_grad = viewspace_point_tensor.absgrad 47 | self._add_densification_stats(xys_grad, visibility_filter, scale=viewspace_points_grad_scale) 48 | 
-------------------------------------------------------------------------------- /internal/density_controllers/gs2d_density_controller.py: -------------------------------------------------------------------------------- 1 | from .vanilla_density_controller import VanillaDensityController, VanillaDensityControllerImpl, build_rotation 2 | import torch 3 | 4 | 5 | class GS2DDensityController(VanillaDensityController): 6 | def instantiate(self, *args, **kwargs) -> "GS2DDensityControllerModule": 7 | return GS2DDensityControllerModule(self) 8 | 9 | 10 | class GS2DDensityControllerModule(VanillaDensityControllerImpl): 11 | def _split_means_and_scales(self, gaussian_model, selected_pts_mask, N): 12 | scales = gaussian_model.get_scales() 13 | device = scales.device 14 | 15 | stds = scales[selected_pts_mask].repeat(N, 1) 16 | stds = torch.cat([stds, 0 * torch.ones_like(stds[:, :1])], dim=-1) 17 | means = torch.zeros((stds.size(0), 3), device=device) 18 | samples = torch.normal(mean=means, std=stds) 19 | rots = build_rotation(gaussian_model.get_property("rotations")[selected_pts_mask]).repeat(N, 1, 1) 20 | # Split means and scales, they are a little bit different 21 | new_means = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + gaussian_model.get_means()[selected_pts_mask].repeat(N, 1) 22 | new_scales = gaussian_model.scale_inverse_activation(scales[selected_pts_mask].repeat(N, 1) / (0.8 * N)) 23 | 24 | new_properties = { 25 | "means": new_means, 26 | "scales": new_scales, 27 | } 28 | 29 | return new_properties 30 | -------------------------------------------------------------------------------- /internal/density_controllers/logger_mixin.py: -------------------------------------------------------------------------------- 1 | class LoggerMixin: 2 | def setup(self, stage: str, pl_module: "LightningModule") -> None: 3 | super().setup(stage, pl_module) 4 | 5 | self.avoid_state_dict = {"pl": pl_module} 6 | 7 | def log_metric(self, name, value): 8 | self.avoid_state_dict["pl"].logger.log_metrics( 9 | { 10 | "density/{}".format(name): value, 11 | }, 12 | step=self.avoid_state_dict["pl"].trainer.global_step, 13 | ) 14 | -------------------------------------------------------------------------------- /internal/density_controllers/static_density_controller.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from lightning import LightningModule 3 | from .density_controller import DensityController, DensityControllerImpl 4 | 5 | 6 | @dataclass 7 | class StaticDensityController(DensityController): 8 | def instantiate(self, *args, **kwargs) -> DensityControllerImpl: 9 | return StaticDensityControllerImpl(self) 10 | 11 | 12 | class StaticDensityControllerImpl(DensityControllerImpl): 13 | pass 14 | -------------------------------------------------------------------------------- /internal/encodings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/encodings/__init__.py -------------------------------------------------------------------------------- /internal/encodings/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PositionalEncoding(torch.nn.Module): 5 | def __init__(self, input_channels: int, n_frequencies: int, log_sampling: bool = True): 6 | """ 7 | Defines a function that embeds x to (x, sin(2^k 
x), cos(2^k x), ...) 8 | in_channels: number of input channels (3 for both xyz and direction) 9 | """ 10 | super().__init__() 11 | self.n_frequencies = n_frequencies 12 | self.input_channels = input_channels 13 | self.funcs = [torch.sin, torch.cos] 14 | self.output_channels = input_channels * (len(self.funcs) * n_frequencies + 1) 15 | 16 | max_frequencies = n_frequencies - 1 17 | if log_sampling: 18 | self.freq_bands = 2. ** torch.linspace(0., max_frequencies, steps=n_frequencies) 19 | 20 | else: 21 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_frequencies, steps=n_frequencies) 22 | 23 | def forward(self, x): 24 | """ 25 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 26 | Different from the paper, "x" is also in the output 27 | See https://github.com/bmild/nerf/issues/12 28 | 29 | Inputs: 30 | x: (B, self.in_channels) 31 | 32 | Outputs: 33 | out: (B, self.out_channels) 34 | """ 35 | out = [x] 36 | for freq in self.freq_bands: 37 | for func in self.funcs: 38 | out += [func(freq * x)] 39 | 40 | return torch.cat(out, -1) 41 | 42 | def get_output_n_channels(self) -> int: 43 | return self.output_channels 44 | -------------------------------------------------------------------------------- /internal/entrypoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/entrypoints/__init__.py -------------------------------------------------------------------------------- /internal/entrypoints/gspl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from internal.cli import CLI 3 | from jsonargparse import lazy_instance 4 | from lightning.pytorch.cli import ArgsType 5 | 6 | from internal.gaussian_splatting import GaussianSplatting 7 | from internal.dataset import DataModule 8 | from internal.callbacks import SaveGaussian, KeepRunningIfWebViewerEnabled, StopImageSavingThreads, ProgressBar, ValidateOnTrainEnd, StopDataLoaderCacheThread 9 | 10 | 11 | def cli(args: ArgsType = None): 12 | CLI( 13 | GaussianSplatting, 14 | DataModule, 15 | seed_everything_default=42, 16 | auto_configure_optimizers=False, 17 | trainer_defaults={ 18 | "accelerator": "gpu", 19 | "strategy": "auto", 20 | "devices": 1, 21 | # "logger": "TensorBoardLogger", 22 | "num_sanity_val_steps": 1, 23 | # "max_epochs": -1, 24 | "max_steps": 30_000, 25 | "use_distributed_sampler": False, # use custom ddp sampler 26 | "enable_checkpointing": False, 27 | "callbacks": [ 28 | lazy_instance(SaveGaussian), 29 | lazy_instance(ValidateOnTrainEnd), 30 | lazy_instance(KeepRunningIfWebViewerEnabled), 31 | lazy_instance(StopImageSavingThreads), 32 | lazy_instance(ProgressBar), 33 | lazy_instance(StopDataLoaderCacheThread), 34 | ], 35 | }, 36 | save_config_kwargs={"overwrite": True}, 37 | args=args, 38 | ) 39 | # note: don't call fit!! 
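# Hedged usage note (added): because `cli()` accepts an `ArgsType`, it can be
# invoked programmatically as well as through the console scripts, e.g.:
#
#   cli(["fit", "--config", "configs/blender.yaml"])
#
# The `cli_*` wrappers below achieve the same effect by inserting the
# subcommand into `sys.argv` before delegating to `cli()`.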
40 | 41 | 42 | def cli_with_subcommand(subcommand: str): 43 | sys.argv.insert(1, subcommand) 44 | cli() 45 | 46 | 47 | def cli_fit(): 48 | cli_with_subcommand("fit") 49 | 50 | 51 | def cli_val(): 52 | cli_with_subcommand("validate") 53 | 54 | 55 | def cli_test(): 56 | cli_with_subcommand("test") 57 | 58 | 59 | def cli_predict(): 60 | cli_with_subcommand("predict") 61 | -------------------------------------------------------------------------------- /internal/entrypoints/seganygs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import lightning 3 | from internal.cli import CLI 4 | from jsonargparse import lazy_instance 5 | 6 | from internal.segany_splatting import SegAnySplatting 7 | from internal.dataset import DataModule 8 | from internal.callbacks import SaveCheckpoint, ProgressBar, StopDataLoaderCacheThread 9 | import lightning.pytorch.loggers 10 | 11 | 12 | def cli(): 13 | CLI( 14 | SegAnySplatting, 15 | DataModule, 16 | seed_everything_default=42, 17 | trainer_defaults={ 18 | "accelerator": "gpu", 19 | "strategy": "auto", 20 | "devices": 1, 21 | # "logger": "TensorBoardLogger", 22 | "num_sanity_val_steps": 1, 23 | # "max_epochs": 100, 24 | "max_steps": 30_000, 25 | "use_distributed_sampler": False, # use custom ddp sampler 26 | "enable_checkpointing": False, 27 | "callbacks": [ 28 | lazy_instance(SaveCheckpoint), 29 | lazy_instance(ProgressBar), 30 | lazy_instance(StopDataLoaderCacheThread), 31 | ], 32 | }, 33 | save_config_kwargs={"overwrite": True}, 34 | ) 35 | # note: don't call fit!! 36 | 37 | 38 | def cli_with_subcommand(subcommand: str): 39 | sys.argv.insert(1, subcommand) 40 | cli() 41 | 42 | 43 | def cli_fit(): 44 | cli_with_subcommand("fit") 45 | 46 | 47 | def cli_val(): 48 | cli_with_subcommand("validate") 49 | 50 | 51 | def cli_test(): 52 | cli_with_subcommand("test") 53 | 54 | 55 | def cli_predict(): 56 | cli_with_subcommand("predict") 57 | -------------------------------------------------------------------------------- /internal/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/metrics/__init__.py -------------------------------------------------------------------------------- /internal/metrics/feature_3dgs_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple, Dict, Any 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .metric import Metric, MetricImpl 7 | 8 | 9 | @dataclass 10 | class Feature3DGSMetrics(Metric): 11 | def instantiate(self, *args, **kwargs) -> MetricImpl: 12 | return Feature3DGSMetricImpl(self) 13 | 14 | 15 | class Feature3DGSMetricImpl(MetricImpl): 16 | def get_validate_metrics(self, pl_module, gaussian_model, batch, outputs) -> Tuple[Dict[str, float], Dict[str, bool]]: 17 | metrics = {} 18 | metrics_pbar = {} 19 | 20 | _, _, gt_feature_map = batch 21 | 22 | feature_map = outputs["features"] 23 | feature_map = F.interpolate(feature_map.unsqueeze(0), size=(gt_feature_map.shape[1], gt_feature_map.shape[2]), mode='bilinear', align_corners=True).squeeze(0) 24 | 25 | l1_loss = torch.abs((feature_map - gt_feature_map)).mean() 26 | 27 | metrics["loss"] = l1_loss 28 | metrics_pbar["loss"] = True 29 | 30 | return metrics, metrics_pbar 31 | -------------------------------------------------------------------------------- 
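For clarity, the validation loss in `Feature3DGSMetricImpl` above can be reproduced standalone: the rendered feature map is bilinearly resized to the ground-truth resolution before taking the mean absolute error. Channel count and resolutions below are made up.

```python
import torch
import torch.nn.functional as F

# Made-up shapes: a [C, H, W] rendered feature map and a larger [C, H', W']
# ground-truth feature map from a frozen 2D feature extractor.
feature_map = torch.randn(256, 120, 160)
gt_feature_map = torch.randn(256, 480, 640)

resized = F.interpolate(
    feature_map.unsqueeze(0),  # add a batch dimension for interpolate
    size=(gt_feature_map.shape[1], gt_feature_map.shape[2]),
    mode="bilinear",
    align_corners=True,
).squeeze(0)

l1_loss = torch.abs(resized - gt_feature_map).mean()  # logged as "loss"
```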
/internal/metrics/mcmc_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D Gaussian Splatting as Markov Chain Monte Carlo 3 | https://ubc-vision.github.io/3dgs-mcmc/ 4 | 5 | Most of the code is copied from https://github.com/ubc-vision/3dgs-mcmc 6 | """ 7 | 8 | from typing import Tuple, Dict, Any 9 | import torch 10 | 11 | from .metric import MetricImpl 12 | from .vanilla_metrics import VanillaMetrics, VanillaMetricsImpl 13 | 14 | 15 | class MCMCMetrics(VanillaMetrics): 16 | opacity_reg: float = 0.01 17 | 18 | scale_reg: float = 0.01 19 | 20 | def instantiate(self, *args, **kwargs) -> MetricImpl: 21 | return MCMCMetricsImpl(self) 22 | 23 | 24 | class MCMCMetricsImpl(VanillaMetricsImpl): 25 | def reg_loss(self, gaussian_model, basic_metrics: Tuple[Dict[str, Any], Dict[str, bool]]): 26 | opacity_reg_loss = self.config.opacity_reg * torch.abs(gaussian_model.get_opacity).mean() 27 | scale_reg_loss = self.config.scale_reg * torch.abs(gaussian_model.get_scaling).mean() 28 | 29 | basic_metrics[0]["loss"] = basic_metrics[0]["loss"] + opacity_reg_loss + scale_reg_loss 30 | basic_metrics[0]["o_reg"] = opacity_reg_loss 31 | basic_metrics[0]["s_reg"] = scale_reg_loss 32 | 33 | basic_metrics[1]["o_reg"] = True 34 | basic_metrics[1]["s_reg"] = True 35 | 36 | return basic_metrics 37 | 38 | def get_train_metrics(self, pl_module, gaussian_model, step: int, batch, outputs) -> Tuple[Dict[str, Any], Dict[str, bool]]: 39 | basic_metrics = super().get_train_metrics(pl_module, gaussian_model, step, batch, outputs) 40 | return self.reg_loss(gaussian_model, basic_metrics) 41 | 42 | def get_validate_metrics(self, pl_module, gaussian_model, batch, outputs) -> Tuple[Dict[str, Any], Dict[str, bool]]: 43 | basic_metrics = super().get_validate_metrics(pl_module, gaussian_model, batch, outputs) 44 | return self.reg_loss(gaussian_model, basic_metrics) 45 | -------------------------------------------------------------------------------- /internal/metrics/metric.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Any 2 | import torch 3 | from internal.configs.instantiate_config import InstantiatableConfig 4 | 5 | 6 | class MetricModule(torch.nn.Module): 7 | def __init__(self, config, *args, **kwargs) -> None: 8 | super().__init__() 9 | self.config = config 10 | 11 | def setup(self, stage: str, pl_module): 12 | pass 13 | 14 | def get_train_metrics(self, pl_module, gaussian_model, step: int, batch, outputs) -> Tuple[Dict[str, Any], Dict[str, bool]]: 15 | """ 16 | :return: 17 | The first dict: contains the metric values. 18 | `backward()` will only be invoked for the one with key `loss`. 19 | All other values are only for logging.
20 | The second dict: indicates whether the metric value should be shown on progress bar 21 | """ 22 | 23 | return self.get_validate_metrics( 24 | pl_module=pl_module, 25 | gaussian_model=gaussian_model, 26 | batch=batch, 27 | outputs=outputs, 28 | ) 29 | 30 | def training_setup(self, pl_module) -> Tuple: 31 | return [], [] 32 | 33 | def get_validate_metrics(self, pl_module, gaussian_model, batch, outputs) -> Tuple[Dict[str, float], Dict[str, bool]]: 34 | pass 35 | 36 | def on_parameter_move(self, *args, **kwargs): 37 | pass 38 | 39 | 40 | class MetricImpl(MetricModule): 41 | pass 42 | 43 | 44 | class Metric(InstantiatableConfig): 45 | def instantiate(self, *args, **kwargs) -> MetricModule: 46 | pass 47 | -------------------------------------------------------------------------------- /internal/metrics/pvg_dynamic_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple, Dict, Any 3 | import torch 4 | 5 | from .vanilla_metrics import VanillaMetrics, VanillaMetricsImpl 6 | 7 | 8 | @dataclass 9 | class PVGDynamicMetrics(VanillaMetrics): 10 | velocity_reg: float = 0.001 11 | t_reg: float = 0. 12 | opacity_entropy_reg: float = 0. 13 | 14 | def instantiate(self, *args, **kwargs) -> "PVGDynamicMetricsModule": 15 | return PVGDynamicMetricsModule(self) 16 | 17 | 18 | class PVGDynamicMetricsModule(VanillaMetricsImpl): 19 | def _get_basic_metrics(self, pl_module, gaussian_model, batch, outputs): 20 | basic_metrics, pbar = super()._get_basic_metrics(pl_module, gaussian_model, batch, outputs) 21 | 22 | # sparse velocity loss 23 | if self.config.velocity_reg > 0: 24 | velocity_map = outputs["average_velocity"] / outputs["alpha"].detach().clamp_min(1e-5) 25 | v_reg_loss = torch.abs(velocity_map).mean() * self.config.velocity_reg 26 | basic_metrics["loss"] = basic_metrics["loss"] + v_reg_loss 27 | basic_metrics["v_reg"] = v_reg_loss 28 | pbar["v_reg"] = True 29 | 30 | if self.config.t_reg > 0: 31 | t_reg_loss = -torch.abs(outputs["scale_t"] / outputs["alpha"].detach().clamp_min(1e-5)).mean() * self.config.t_reg 32 | basic_metrics["loss"] = basic_metrics["loss"] + t_reg_loss 33 | basic_metrics["t_reg"] = t_reg_loss 34 | pbar["t_reg"] = True 35 | 36 | if self.config.opacity_entropy_reg > 0: 37 | alpha = outputs["alpha"].detach() 38 | o = alpha.clamp(1e-6, 1 - 1e-6) 39 | loss_opacity_entropy = -(o * torch.log(o)).mean() * self.config.opacity_entropy_reg 40 | basic_metrics["loss"] = basic_metrics["loss"] + loss_opacity_entropy 41 | basic_metrics["o_reg"] = loss_opacity_entropy 42 | pbar["o_reg"] = True 43 | 44 | return basic_metrics, pbar 45 | -------------------------------------------------------------------------------- /internal/metrics/vanilla_with_fused_ssim_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from .vanilla_metrics import VanillaMetrics, VanillaMetricsImpl 3 | 4 | 5 | @dataclass 6 | class VanillaWithFusedSSIMMetrics(VanillaMetrics): 7 | fused_ssim: bool = True 8 | 9 | def instantiate(self, *args, **kwargs) -> "VanillaWithFusedSSIMMetricsModule": 10 | return VanillaWithFusedSSIMMetricsModule(self) 11 | 12 | 13 | class VanillaWithFusedSSIMMetricsModule(VanillaMetricsImpl): 14 | pass 15 | -------------------------------------------------------------------------------- /internal/metrics/visibility_map_metrics.py: -------------------------------------------------------------------------------- 1 | from typing 
import Tuple, Dict, Any 2 | from .vanilla_metrics import VanillaMetrics, VanillaMetricsImpl 3 | 4 | 5 | class VisibilityMapMetrics(VanillaMetrics): 6 | vis_reg_factor: float = 0.2 7 | 8 | def instantiate(self, *args, **kwargs): 9 | return VisibilityMapMetricsImpl(self) 10 | 11 | 12 | class VisibilityMapMetricsImpl(VanillaMetricsImpl): 13 | def get_train_metrics(self, pl_module, gaussian_model, step: int, batch, outputs) -> Tuple[Dict[str, Any], Dict[str, bool]]: 14 | camera, image_info, extra_data = batch 15 | image_name, gt_image, masked_pixels = image_info 16 | image = outputs["render"] 17 | 18 | visibility_map = outputs["visibility"] 19 | vis_masked_image = image * visibility_map 20 | vis_masked_gt_image = gt_image * visibility_map 21 | 22 | metrics, pbar = super().get_train_metrics( 23 | pl_module, 24 | gaussian_model, 25 | step, 26 | (camera, (image_name, vis_masked_gt_image, masked_pixels), extra_data), 27 | { 28 | "render": vis_masked_image, 29 | }, 30 | ) 31 | 32 | vis_reg = ((1. - visibility_map) ** 2).mean() * self.config.vis_reg_factor 33 | 34 | metrics["loss"] = metrics["loss"] + vis_reg 35 | metrics["vis_reg"] = vis_reg 36 | pbar["vis_reg"] = True 37 | 38 | return metrics, pbar 39 | -------------------------------------------------------------------------------- /internal/model_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/model_components/__init__.py -------------------------------------------------------------------------------- /internal/model_components/envlight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import nvdiffrast.torch as dr 3 | 4 | 5 | class EnvLight(torch.nn.Module): 6 | def __init__(self, resolution=1024): 7 | super().__init__() 8 | self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda") 9 | self.base = torch.nn.Parameter( 10 | 0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True), 11 | ) 12 | 13 | def forward(self, l): 14 | l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape) 15 | l = l.contiguous() 16 | prefix = l.shape[:-1] 17 | if len(prefix) != 3: # reshape to [B, H, W, -1] 18 | l = l.reshape(1, 1, -1, l.shape[-1]) 19 | 20 | light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube') 21 | light = light.view(*prefix, -1) 22 | 23 | return light 24 | -------------------------------------------------------------------------------- /internal/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !__init__.py 4 | !appearance_model.py 5 | !gaussian_model.py 6 | !gaussian_model_simplified.py 7 | !simplified_gaussian_model_manager.py 8 | !defromable_model.py 9 | !deform_model.py 10 | !vanilla_deform_model.py 11 | !swag_model.py 12 | !gaussian.py 13 | !vanilla_gaussian.py 14 | !appearance_feature_gaussian.py 15 | !appearance_mip_gaussian.py 16 | !gaussian_2d.py 17 | !periodic_vibration_gaussian.py 18 | !mip_splatting.py 19 | !appearance_gs2d.py 20 | !sparse_adam_gaussian.py 21 | -------------------------------------------------------------------------------- /internal/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/models/__init__.py -------------------------------------------------------------------------------- /internal/models/appearance_gs2d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from .appearance_feature_gaussian import AppearanceFeatureGaussian, AppearanceFeatureGaussianModel 3 | from .gaussian_2d import Gaussian2D, Gaussian2DModelMixin 4 | 5 | 6 | @dataclass 7 | class AppearanceGS2D(AppearanceFeatureGaussian): 8 | def instantiate(self, *args, **kwargs) -> "AppearanceGS2dModel": 9 | return AppearanceGS2dModel(self) 10 | 11 | 12 | class AppearanceGS2dModel(Gaussian2DModelMixin, AppearanceFeatureGaussianModel): 13 | pass 14 | -------------------------------------------------------------------------------- /internal/models/appearance_mip_gaussian.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .appearance_feature_gaussian import AppearanceFeatureGaussian, AppearanceFeatureGaussianModel 4 | from .mip_splatting import MipSplattingConfigMixin, MipSplattingModelMixin 5 | 6 | 7 | @dataclass 8 | class AppearanceMipGaussian(MipSplattingConfigMixin, AppearanceFeatureGaussian): 9 | def instantiate(self, *args, **kwargs) -> "AppearanceMipGaussianModel": 10 | return AppearanceMipGaussianModel(self) 11 | 12 | 13 | class AppearanceMipGaussianModel(MipSplattingModelMixin, AppearanceFeatureGaussianModel): 14 | config: AppearanceMipGaussian 15 | -------------------------------------------------------------------------------- /internal/models/gaussian_2d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | 4 | import torch 5 | 6 | from .vanilla_gaussian import VanillaGaussian, VanillaGaussianModel 7 | 8 | 9 | @dataclass 10 | class Gaussian2D(VanillaGaussian): 11 | def instantiate(self, *args, **kwargs) -> "Gaussian2DModel": 12 | return Gaussian2DModel(self) 13 | 14 | 15 | class Gaussian2DModelMixin: 16 | def before_setup_set_properties_from_pcd(self, xyz: torch.Tensor, rgb: torch.Tensor, property_dict: Dict[str, torch.Tensor], *args, **kwargs): 17 | super().before_setup_set_properties_from_pcd( 18 | xyz=xyz, 19 | rgb=rgb, 20 | property_dict=property_dict, 21 | *args, 22 | **kwargs, 23 | ) 24 | with torch.no_grad(): 25 | property_dict["scales"] = property_dict["scales"][..., :2] 26 | # key to a quality comparable to hbb1/2d-gaussian-splatting 27 | property_dict["rotations"].copy_(torch.rand_like(property_dict["rotations"])) 28 | 29 | def before_setup_set_properties_from_number(self, n: int, property_dict: Dict[str, torch.Tensor], *args, **kwargs): 30 | super().before_setup_set_properties_from_number( 31 | n=n, 32 | property_dict=property_dict, 33 | *args, 34 | **kwargs, 35 | ) 36 | property_dict["scales"] = property_dict["scales"][..., :2] 37 | 38 | 39 | class Gaussian2DModel(Gaussian2DModelMixin, VanillaGaussianModel): 40 | pass 41 | -------------------------------------------------------------------------------- /internal/models/sparse_adam_gaussian.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from .vanilla_gaussian import VanillaGaussian, VanillaGaussianModel, OptimizationConfig 3 | 4 | 5 | @dataclass 6 | class 
VanillaGaussianWithSparseAdam(VanillaGaussian): 7 | optimization: OptimizationConfig = field(default_factory=lambda: OptimizationConfig( 8 | optimizer={"class_path": "SparseGaussianAdam"} 9 | )) 10 | 11 | def instantiate(self, *args, **kwargs) -> "VanillaGaussianModel": 12 | return VanillaGaussianModel(self) 13 | -------------------------------------------------------------------------------- /internal/renderers/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !__init__.py 4 | !renderer.py 5 | !vanilla_renderer.py 6 | !appearance_mlp_renderer.py 7 | !appearance_swag_renderer.py 8 | !rgb_mlp_renderer.py 9 | !deformable_renderer.py 10 | !vanilla_deformable_renderer.py 11 | !vanilla_gs4d_renderer.py 12 | !gsplat_renderer.py 13 | !gsplat_v1_renderer.py 14 | !pypreprocess_gsplat_renderer.py 15 | !swag_renderer.py 16 | !mip_splatting_gsplat_renderer.py 17 | !gsplat_mip_splatting_renderer_v2.py 18 | !gsplat_hit_pixel_count_renderer.py 19 | !vanilla_2dgs_renderer.py 20 | !sep_depth_trim_2dgs_renderer.py 21 | !vanilla_trim_renderer.py 22 | !seganygs_renderer.py 23 | !gsplat_appearance_embedding_renderer.py 24 | !feature_3dgs_renderer.py 25 | !gsplat_appearance_embedding_visibility_map_renderer.py 26 | !gsplat_distributed_renderer.py 27 | !gsplat_distributed_appearance_embedding_renderer.py 28 | !periodic_vibration_gaussian_renderer.py 29 | !partition_lod_renderer.py 30 | !stp_renderer.py 31 | !appearance_2dgs_renderer.py 32 | !taming_3dgs_renderer.py 33 | -------------------------------------------------------------------------------- /internal/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | from .renderer import RendererOutputTypes, RendererOutputVisualizer, RendererOutputInfo, Renderer, RendererConfig 2 | from .vanilla_renderer import VanillaRenderer 3 | -------------------------------------------------------------------------------- /internal/renderers/gsplat_hit_pixel_count_renderer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from .renderer import Renderer 4 | from .gsplat_renderer import GSPlatRenderer 5 | from gsplat.hit_pixel_count import hit_pixel_count 6 | 7 | 8 | class GSplatHitPixelCountRenderer(Renderer): 9 | @staticmethod 10 | def hit_pixel_count( 11 | means3D: torch.Tensor, # xyz 12 | opacities: torch.Tensor, 13 | scales: Optional[torch.Tensor], 14 | rotations: Optional[torch.Tensor], # remember to normalize them yourself 15 | viewpoint_camera, 16 | scaling_modifier=1.0, 17 | anti_aliased: bool = True, 18 | block_size: int = 16, 19 | extra_projection_kwargs: dict = None, 20 | ): 21 | xys, depths, radii, conics, comp, num_tiles_hit, cov3d = GSPlatRenderer.project( 22 | means3D=means3D, 23 | scales=scales, 24 | rotations=rotations, 25 | viewpoint_camera=viewpoint_camera, 26 | scaling_modifier=scaling_modifier, 27 | block_size=block_size, 28 | extra_projection_kwargs=extra_projection_kwargs, 29 | ) 30 | 31 | if anti_aliased is True: 32 | opacities = opacities * comp[:, None] 33 | 34 | count, opacity_score, alpha_score, visibility_score = hit_pixel_count( 35 | xys, 36 | depths, 37 | radii, 38 | conics, 39 | num_tiles_hit, 40 | opacities, 41 | img_height=int(viewpoint_camera.height.item()), 42 | img_width=int(viewpoint_camera.width.item()), 43 | block_width=block_size, 44 | ) 45 | 46 | return count, opacity_score, alpha_score, visibility_score 47 | 
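One detail of `hit_pixel_count` above worth isolating is the anti-aliasing branch: when `anti_aliased` is true, each Gaussian's opacity is multiplied by the per-Gaussian compensation factor `comp` returned by the projection. A tiny self-contained sketch with made-up values:

```python
import torch

# Three Gaussians with base opacities [N, 1] and made-up compensation
# factors [N] in [0, 1], as returned by the projection step.
opacities = torch.tensor([[0.9], [0.5], [0.8]])
comp = torch.tensor([1.0, 0.7, 0.3])

# Mirrors `opacities = opacities * comp[:, None]` in the renderer above:
# Gaussians with near-pixel-sized footprints get their opacity damped.
anti_aliased_opacities = opacities * comp[:, None]
print(anti_aliased_opacities)  # tensor([[0.9000], [0.3500], [0.2400]])
```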
-------------------------------------------------------------------------------- /internal/renderers/gsplat_mip_splatting_renderer_v2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from .gsplat_v1_renderer import GSplatV1Renderer, GSplatV1RendererModule 6 | from ..models.mip_splatting import MipSplattingModel 7 | 8 | 9 | @dataclass 10 | class GSplatMipSplattingRendererV2(GSplatV1Renderer): 11 | filter_2d_kernel_size: float = 0.1 12 | 13 | def instantiate(self, *args, **kwargs) -> "GSplatMipSplattingRendererV2Module": 14 | return GSplatMipSplattingRendererV2Module(self) 15 | 16 | 17 | class MipSplattingRendererMixin: 18 | def get_scales(self, camera, gaussian_model: MipSplattingModel, **kwargs): 19 | opacities, scales = gaussian_model.get_3d_filtered_scales_and_opacities() 20 | 21 | return scales, opacities.squeeze(-1) 22 | 23 | def get_opacities(self, camera, gaussian_model: MipSplattingModel, projections, visibility_filter, status: torch.Any, **kwargs): 24 | return status, None 25 | 26 | 27 | class GSplatMipSplattingRendererV2Module(MipSplattingRendererMixin, GSplatV1RendererModule): 28 | pass 29 | -------------------------------------------------------------------------------- /internal/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linketic/CityGaussian/db21484dc262a446d12995633ac1b80bba44d4c9/internal/utils/__init__.py -------------------------------------------------------------------------------- /internal/utils/common.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | 4 | def parse_cfg_args(path) -> Namespace: 5 | with open(path, "r") as f: 6 | cfg_args = f.read() 7 | return eval(cfg_args) 8 | 9 | def parse_cfg_yaml(data): 10 | data = Namespace(**data) 11 | for arg in vars(data): 12 | if isinstance(getattr(data, arg), dict): 13 | setattr(data, arg, parse_cfg_yaml(getattr(data, arg))) 14 | return data -------------------------------------------------------------------------------- /internal/utils/fix_lightning_save_hyperparameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/Lightning-AI/pytorch-lightning/pull/18105 3 | """ 4 | 5 | from contextlib import contextmanager 6 | import lightning.pytorch.core.mixins.hparams_mixin 7 | 8 | if hasattr(lightning.pytorch.core.mixins.hparams_mixin, "_given_hyperparameters_context"): 9 | @contextmanager 10 | def fix_save_hyperparameters(*args, **kwargs): 11 | yield 12 | 13 | lightning.pytorch.core.mixins.hparams_mixin._given_hyperparameters_context = fix_save_hyperparameters 14 | -------------------------------------------------------------------------------- /internal/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torchvision 5 | 6 | 7 | def read_image_as_tensor(path: str) -> torch.Tensor: 8 | return torchvision.io.read_image(path) # [C, H, W] 9 | 10 | 11 | def save_tensor_image(path: str, image: torch.Tensor): 12 | torchvision.utils.save_image(image, path) 13 | 14 | 15 | def read_image(path: str) -> np.ndarray: 16 | pil_image = Image.open(path) 17 | return np.array(pil_image) # [H, W, C] 18 | 19 | 20 | def save_image(path: str, image: np.ndarray): 21 | pil_image = 
Image.fromarray(image) 22 | pil_image.save(path, subsampling=0, quality=100) 23 | 24 | 25 | def rgba2rgb(rgba: np.ndarray, background: np.ndarray = None) -> np.ndarray: 26 | if background is None: 27 | background = np.array([0, 0, 0], dtype=np.float64) 28 | normalized_rgba = rgba / 255. 29 | rgb = normalized_rgba[:, :, :3] * normalized_rgba[:, :, 3:4] + background * (1 - normalized_rgba[:, :, 3:4]) 30 | return np.asarray(rgb * 255, dtype=np.uint8) 31 | -------------------------------------------------------------------------------- /internal/utils/las_utils.py: -------------------------------------------------------------------------------- 1 | import laspy 2 | import warnings 3 | import numpy as np 4 | 5 | def read_las_fit(filename, attrs=None): 6 | """ 7 | Args: 8 | filename: las file path 9 | attrs: additional attributes to read from the las file 10 | 11 | Returns: 12 | xyz, rgb, attr_dict 13 | """ 14 | if attrs is None: 15 | attrs = [] 16 | 17 | attrs = list(set(attrs + ["scales", "offsets"])) 18 | 19 | inFile = laspy.read(filename) 20 | N_points = len(inFile) 21 | x = np.reshape(inFile.x, (N_points, 1)) 22 | y = np.reshape(inFile.y, (N_points, 1)) 23 | z = np.reshape(inFile.z, (N_points, 1)) 24 | xyz = np.hstack((x, y, z)) 25 | 26 | rgb = np.zeros((N_points, 3), dtype=np.uint16) 27 | if hasattr(inFile, "red") and hasattr(inFile, "green") and hasattr(inFile, "blue"): 28 | r = np.reshape(inFile.red, (N_points, 1)) 29 | g = np.reshape(inFile.green, (N_points, 1)) 30 | b = np.reshape(inFile.blue, (N_points, 1)) 31 | # i = np.reshape(inFile.Reflectance, (N_points, 1)) 32 | rgb = np.float32(np.hstack((r, g, b))) / 65535 33 | else: 34 | print(f"{filename.split('/')[-1]} has no RGB information!") 35 | 36 | attr_dict = {} 37 | for attr in attrs: 38 | value = None 39 | if hasattr(inFile, attr): 40 | value = getattr(inFile, attr) 41 | elif hasattr(inFile.header, attr): 42 | value = getattr(inFile.header, attr) 43 | else: 44 | warnings.warn(f"{filename.split('/')[-1]} has no information for {attr}!") 45 | 46 | if hasattr(value, "array"): 47 | attr_dict[attr] = np.array(value) 48 | else: 49 | attr_dict[attr] = value 50 | 51 | return xyz, rgb, attr_dict -------------------------------------------------------------------------------- /internal/utils/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lpipsPyTorch.modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /internal/utils/lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 
11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /internal/utils/lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /internal/utils/psnr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def color_correct(img, ref, num_iters=5, eps=0.5 / 255): 4 | """Warp `img` to match the colors in `ref_img`.""" 5 | if img.shape[-1] != ref.shape[-1]: 6 | raise ValueError( 7 | f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match' 8 | ) 9 | 10 | num_channels = img.shape[-1] 11 | img_mat = img.reshape(-1, num_channels) 12 | ref_mat = ref.reshape(-1, num_channels) 13 | 14 | def is_unclipped(z): 15 | return (z >= eps) & (z <= (1 - eps)) # z ∈ [eps, 1-eps]. 16 | 17 | mask0 = is_unclipped(img_mat) 18 | 19 | for _ in range(num_iters): 20 | a_mat = [] 21 | for c in range(num_channels): 22 | a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:]) # Quadratic term. 23 | a_mat.append(img_mat) # Linear term. 24 | a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. 25 | a_mat = torch.cat(a_mat, dim=-1) 26 | 27 | warp = [] 28 | for c in range(num_channels): 29 | b = ref_mat[:, c] 30 | mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) 31 | ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) 32 | mb = torch.where(mask, b, torch.zeros_like(b)) 33 | w = torch.linalg.lstsq(ma_mat, mb, rcond=-1).solution # Solve the linear system. 
34 | assert torch.all(torch.isfinite(w)) 35 | warp.append(w.squeeze()) 36 | 37 | warp = torch.stack(warp, dim=-1) 38 | img_mat = torch.clamp(torch.matmul(a_mat, warp), 0, 1) 39 | 40 | corrected_img = img_mat.reshape(img.shape) 41 | return corrected_img -------------------------------------------------------------------------------- /internal/utils/rotation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rotation_matrix(a, b): 5 | """Compute the rotation matrix that rotates vector a to vector b. 6 | 7 | Args: 8 | a: The vector to rotate. 9 | b: The vector to rotate to. 10 | Returns: 11 | The rotation matrix. 12 | """ 13 | a = a / torch.linalg.norm(a) 14 | b = b / torch.linalg.norm(b) 15 | v = torch.cross(a, b) 16 | c = torch.dot(a, b) 17 | # If vectors are exactly opposite, we add a little noise to one of them 18 | if c < -1 + 1e-8: 19 | eps = (torch.rand(3) - 0.5) * 0.01 20 | return rotation_matrix(a + eps, b) 21 | s = torch.linalg.norm(v) 22 | skew_sym_mat = torch.Tensor( 23 | [ 24 | [0, -v[2], v[1]], 25 | [v[2], 0, -v[0]], 26 | [-v[1], v[0], 0], 27 | ] 28 | ) 29 | return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s ** 2 + 1e-8)) 30 | 31 | 32 | def qvec2rot(q): 33 | R = torch.zeros((q.size(0), 3, 3), device=q.device) 34 | 35 | r = q[:, 0] 36 | x = q[:, 1] 37 | y = q[:, 2] 38 | z = q[:, 3] 39 | 40 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 41 | R[:, 0, 1] = 2 * (x * y - r * z) 42 | R[:, 0, 2] = 2 * (x * z + r * y) 43 | R[:, 1, 0] = 2 * (x * y + r * z) 44 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 45 | R[:, 1, 2] = 2 * (y * z - r * x) 46 | R[:, 2, 0] = 2 * (x * z - r * y) 47 | R[:, 2, 1] = 2 * (y * z + r * x) 48 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 49 | return R 50 | -------------------------------------------------------------------------------- /internal/viewer/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import ClientThread 2 | from .renderer import ViewerRenderer 3 | -------------------------------------------------------------------------------- /internal/viewer/ui/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform_panel import TransformPanel 2 | from .edit_panel import EditPanel 3 | from .render_panel import populate_render_tab -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from internal.entrypoints.gspl import cli 2 | 3 | if __name__ == "__main__": 4 | cli() 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools.packages.find] 6 | include = ["internal*", "utils*"] 7 | 8 | [project] 9 | name = "gaussian-splatting-lightning" 10 | dynamic = ["version"] 11 | requires-python = ">=3.8" 12 | 13 | [project.scripts] 14 | gs-fit = "internal.entrypoints.gspl:cli_fit" 15 | gs-val = "internal.entrypoints.gspl:cli_val" 16 | gs-test = "internal.entrypoints.gspl:cli_test" 17 | gs-predict = "internal.entrypoints.gspl:cli_predict" 18 | 19 | segany-fit = "internal.entrypoints.seganygs:cli_fit" 20 | segany-val = "internal.entrypoints.seganygs:cli_val" 21 | segany-test = "internal.entrypoints.seganygs:cli_test" 
22 | segany-predict = "internal.entrypoints.seganygs:cli_predict" 23 | 24 | gs-viewer = "internal.entrypoints.viewer:cli" 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/lightning23.txt 2 | -------------------------------------------------------------------------------- /requirements/2DGS.txt: -------------------------------------------------------------------------------- 1 | open3d==0.18.0 2 | scikit-image==0.24.0 3 | trimesh==4.4.3 4 | -r diff-surfel-rasterization.txt 5 | -------------------------------------------------------------------------------- /requirements/CityGS.txt: -------------------------------------------------------------------------------- 1 | open3d==0.18.0 2 | scikit-image==0.24.0 3 | trimesh==4.4.3 4 | imageio==2.36.0 5 | imageio-ffmpeg==0.5.1 6 | py3nvml 7 | torch_scatter 8 | git+https://github.com/DekuLiuTesla/diff-surfel-rasterization.git@9eefc03858a30bb3e5f98eccc56f077420ee2aaf 9 | git+https://github.com/DekuLiuTesla/diff-gaussian-rasterization.git@82c31d7ad780f7ef74c4f9485985895e9da2c93a -------------------------------------------------------------------------------- /requirements/SpotLessSplats.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.27.2 2 | transformers==4.40.1 3 | scikit-learn 4 | -------------------------------------------------------------------------------- /requirements/StopThePop.txt: -------------------------------------------------------------------------------- 1 | dacite 2 | git+https://github.com/yzslab/StopThePop-Rasterization.git 3 | -------------------------------------------------------------------------------- /requirements/common.txt: -------------------------------------------------------------------------------- 1 | splines==0.3.0 2 | plyfile==0.8.1 3 | tensorboard 4 | wandb 5 | tqdm 6 | einops 7 | joblib 8 | open3d 9 | viser==0.2.3 10 | opencv-python-headless==4.10.* 11 | matplotlib 12 | mediapy==1.2.2 13 | git+https://github.com/graphdeco-inria/diff-gaussian-rasterization.git@59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d 14 | git+https://github.com/yzslab/simple-knn.git@44f764299fa305faf6ec5ebd99939e0508331503 15 | -------------------------------------------------------------------------------- /requirements/diff-accel-rasterization.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/yzslab/diff-gaussian-rasterization.git@b403ab6c5cfb4ed89265a9759bd4766f9c4b56de 2 | -------------------------------------------------------------------------------- /requirements/diff-surfel-rasterization.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/hbb1/diff-surfel-rasterization.git@e0ed0207b3e0669960cfad70852200a4a5847f61 2 | -------------------------------------------------------------------------------- /requirements/fused-ssim.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/rahul-goel/fused-ssim.git@d99e3d27513fa3563d98f74fcd40fd429e9e9b0e 2 | -------------------------------------------------------------------------------- /requirements/gsplat.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/yzslab/gsplat.git@58f3772541b6fb55e3219b36cd2b64be0584645c 2 | 
-------------------------------------------------------------------------------- /requirements/lightning23.txt: -------------------------------------------------------------------------------- 1 | lightning[pytorch-extra]==2.3.* 2 | pytorch-lightning==2.3.* 3 | -r common.txt 4 | -------------------------------------------------------------------------------- /requirements/lightning25.txt: -------------------------------------------------------------------------------- 1 | lightning[pytorch-extra]==2.5.* 2 | pytorch-lightning==2.5.* 3 | -r common.txt 4 | -------------------------------------------------------------------------------- /requirements/pyt201_cu118.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/cu118 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | torchaudio==2.0.2 5 | -------------------------------------------------------------------------------- /requirements/pyt251_cu124.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/cu124 2 | torch==2.5.1 3 | torchvision==0.20.* 4 | torchaudio==2.5.* 5 | -------------------------------------------------------------------------------- /requirements/pytorch3d-compile.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/facebookresearch/pytorch3d.git@stable 2 | -------------------------------------------------------------------------------- /requirements/pytorch3d-pre.txt: -------------------------------------------------------------------------------- 1 | fvcore 2 | iopath 3 | -------------------------------------------------------------------------------- /requirements/pytorch3d-py39_cu118_pyt201.txt: -------------------------------------------------------------------------------- 1 | -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu118_pyt201/download.html 2 | pytorch3d 3 | -------------------------------------------------------------------------------- /requirements/sam.txt: -------------------------------------------------------------------------------- 1 | hdbscan 2 | scikit-learn==1.3.2 3 | git+https://github.com/facebookresearch/segment-anything.git 4 | -------------------------------------------------------------------------------- /requirements/tcnn.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 2 | -------------------------------------------------------------------------------- /scripts/data_proc_mc_scratch.sh: -------------------------------------------------------------------------------- 1 | # MatrixCity, Aerial View, block All 2 | mkdir data/matrix_city/aerial/train/block_all 3 | mkdir data/matrix_city/aerial/test/block_all_test 4 | mkdir data/matrix_city/aerial/train/block_all/input 5 | mkdir data/matrix_city/aerial/test/block_all_test/input 6 | cp data/matrix_city/aerial/pose/block_all/transforms_train.json data/matrix_city/aerial/train/block_all/transforms.json 7 | cp data/matrix_city/aerial/pose/block_all/transforms_test.json data/matrix_city/aerial/test/block_all_test/transforms.json 8 | 9 | python tools/transform_json2txt_mc_aerial.py --source_path data/matrix_city/aerial/train/block_all 10 | python tools/transform_json2txt_mc_aerial.py --source_path data/matrix_city/aerial/test/block_all_test 11 | python tools/convert_cam.py -s data/matrix_city/aerial/train/block_all
12 | python tools/convert_cam.py -s data/matrix_city/aerial/test/block_all_test 13 | 14 | # Street View 15 | mkdir data/matrix_city/street/train/block_A 16 | mkdir data/matrix_city/street/test/block_A_test 17 | mkdir data/matrix_city/street/train/block_A/input 18 | mkdir data/matrix_city/street/test/block_A_test/input 19 | cp data/matrix_city/street/pose/block_A/transforms_train.json data/matrix_city/street/train/block_A/transforms.json 20 | cp data/matrix_city/street/pose/block_A/transforms_test.json data/matrix_city/street/test/block_A_test/transforms.json 21 | 22 | python tools/transform_json2txt_mc_street.py --source_path data/matrix_city/street/train/block_A --intrinsic_path data/matrix_city/street/pose/block_A/transforms_train.json 23 | python tools/transform_json2txt_mc_street.py --source_path data/matrix_city/street/test/block_A_test --intrinsic_path data/matrix_city/street/pose/block_A/transforms_test.json 24 | python tools/convert_cam.py -s data/matrix_city/street/train/block_A 25 | python tools/convert_cam.py -s data/matrix_city/street/test/block_A_test -------------------------------------------------------------------------------- /scripts/data_proc_mill19.sh: -------------------------------------------------------------------------------- 1 | # Mill19, Building, Rubble 2 | ln -s data/mill19/building-pixsfm/train/rgbs data/mill19/building-pixsfm/train/images 3 | ln -s data/mill19/building-pixsfm/val/rgbs data/mill19/building-pixsfm/val/images 4 | 5 | ln -s data/mill19/rubble-pixsfm/train/rgbs data/mill19/rubble-pixsfm/train/images 6 | ln -s data/mill19/rubble-pixsfm/val/rgbs data/mill19/rubble-pixsfm/val/images 7 | 8 | mv data/colmap_results/building/train/sparse data/mill19/building-pixsfm/train 9 | mv data/colmap_results/building/val/sparse data/mill19/building-pixsfm/val 10 | 11 | mv data/colmap_results/rubble/train/sparse data/mill19/rubble-pixsfm/train 12 | mv data/colmap_results/rubble/val/sparse data/mill19/rubble-pixsfm/val 13 | 14 | -------------------------------------------------------------------------------- /scripts/data_proc_mill19_scratch.sh: -------------------------------------------------------------------------------- 1 | # Mill19, Building, Rubble 2 | ln -s data/mill19/building-pixsfm/train/rgbs data/mill19/building-pixsfm/train/input 3 | ln -s data/mill19/building-pixsfm/val/rgbs data/mill19/building-pixsfm/val/input 4 | 5 | ln -s data/mill19/rubble-pixsfm/train/rgbs data/mill19/rubble-pixsfm/train/input 6 | ln -s data/mill19/rubble-pixsfm/val/rgbs data/mill19/rubble-pixsfm/val/input 7 | 8 | rm -rf data/mill19/building-pixsfm/train/sparse 9 | rm -rf data/mill19/building-pixsfm/val/sparse 10 | python tools/transform_pt2txt.py --source_path data/mill19/building-pixsfm 11 | python tools/convert_cam.py -s data/mill19/building-pixsfm/train 12 | python tools/convert_cam.py -s data/mill19/building-pixsfm/val 13 | 14 | rm -rf data/mill19/rubble-pixsfm/train/sparse 15 | rm -rf data/mill19/rubble-pixsfm/val/sparse 16 | python tools/transform_pt2txt.py --source_path data/mill19/rubble-pixsfm 17 | python tools/convert_cam.py -s data/mill19/rubble-pixsfm/train 18 | python tools/convert_cam.py -s data/mill19/rubble-pixsfm/val 19 | 20 | -------------------------------------------------------------------------------- /scripts/data_proc_us3d.sh: -------------------------------------------------------------------------------- 1 | # UrbanScene3D, Residence, Sci-Art 2 | python tools/copy_images.py --image_path data/urban_scene_3d/Residence/photos --dataset_path
data/urban_scene_3d/residence-pixsfm 3 | python tools/copy_images.py --image_path data/urban_scene_3d/Sci-Art/photos --dataset_path data/urban_scene_3d/sci-art-pixsfm 4 | 5 | mv data/colmap_results/residence/train/sparse data/urban_scene_3d/residence-pixsfm/train 6 | mv data/colmap_results/residence/val/sparse data/urban_scene_3d/residence-pixsfm/val 7 | 8 | mv data/colmap_results/sciart/train/sparse data/urban_scene_3d/sci-art-pixsfm/train 9 | mv data/colmap_results/sciart/val/sparse data/urban_scene_3d/sci-art-pixsfm/val 10 | -------------------------------------------------------------------------------- /scripts/data_proc_us3d_scratch.sh: -------------------------------------------------------------------------------- 1 | # UrbanScene3D, Residence, Sci-Art 2 | python tools/copy_images.py --image_path data/urban_scene_3d/Residence/photos --dataset_path data/urban_scene_3d/residence-pixsfm 3 | python tools/copy_images.py --image_path data/urban_scene_3d/Sci-Art/photos --dataset_path data/urban_scene_3d/sci-art-pixsfm 4 | 5 | rm -rf data/urban_scene_3d/residence-pixsfm/train/sparse 6 | rm -rf data/urban_scene_3d/residence-pixsfm/val/sparse 7 | python tools/transform_pt2txt.py --source_path data/urban_scene_3d/residence-pixsfm 8 | python tools/convert_cam.py -s data/urban_scene_3d/residence-pixsfm/train 9 | python tools/convert_cam.py -s data/urban_scene_3d/residence-pixsfm/val 10 | 11 | rm -rf data/urban_scene_3d/sci-art-pixsfm/train/sparse 12 | rm -rf data/urban_scene_3d/sci-art-pixsfm/val/sparse 13 | python tools/transform_pt2txt.py --source_path data/urban_scene_3d/sci-art-pixsfm 14 | python tools/convert_cam.py -s data/urban_scene_3d/sci-art-pixsfm/train 15 | python tools/convert_cam.py -s data/urban_scene_3d/sci-art-pixsfm/val 16 | -------------------------------------------------------------------------------- /scripts/estimate_dataset_depths.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 nodes, 2 GPUs per node: sbatch --nodes=4 --gres=gpu:2 --ntasks-per-node=2 scripts/estimate_dataset_depths.slurm DATASET_DIR --preview 4 | 5 | if [ "${1}" == "" ]; then 6 | echo "dataset directory is required" 7 | exit 1 8 | fi 9 | 10 | IMAGE_DIR="${1}/images" 11 | 12 | srun python utils/run_depth_anything_v2.py "${IMAGE_DIR}" "${@:2}" 13 | srun --gres=gpu:1 --nodes=1 -n1 --ntasks-per-node=1 python utils/get_depth_scales.py "${1}" -------------------------------------------------------------------------------- /scripts/get_sam_masks.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "${1}" == "" ] || [ "${2}" == "" ]; then 4 | echo "Usage: get_sam_masks.slurm IMAGE_DIR TRAINED_MODEL_DIR" 5 | exit 1 6 | fi 7 | 8 | srun python utils/get_sam_masks.py "${1}" --preview 9 | srun --gres=gpu:1 --nodes=1 -n1 --ntasks-per-node=1 python utils/get_sam_mask_scales.py "${2}" --preview -------------------------------------------------------------------------------- /scripts/sd_feature_extraction.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 nodes, 2 GPUs per node: sbatch --nodes=4 --gres=gpu:2 --ntasks-per-node=2 scripts/sd_feature_extraction.slurm IMAGE_DIR 4 | 5 | if [ "${1}" == "" ]; then 6 | echo "image directory is required" 7 | exit 1 8 | fi 9 | 10 | srun python utils/sd_feature_extraction.py "${1}" "${@:2}" -------------------------------------------------------------------------------- /scripts/train-meganerf_rubble-partitions.slurm:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PARTITION_DATA_PATH=~/dataset/Mega-NeRF/rubble-pixsfm/colmap/dense_max_1600/partitions-size_2.9-enlarge_0.1-visibility_0.9_0.25 4 | PROJECT_NAME=MegaNeRF-rubble-view_independent-hard_depth 5 | 6 | # NOTES 7 | # * `srun` is invoked by the python script 8 | # * these options are vital for the correctness of job scheduling 9 | # * --gpus=1 10 | # * --nodes=1 11 | # * --ntasks=1 12 | # * --exclusive 13 | # 14 | # submitting: sbatch --gpus=6 --cpus-per-gpu=32 scripts/train-meganerf_rubble-partitions.slurm 15 | # remember to specify the value of `--cpus-per-gpu=`, or only one cpu core will be assigned to each process even with the `DefCpuPerGPU` set 16 | python3 utils/train_colmap_partitions_v2.py \ 17 | ${PARTITION_DATA_PATH} \ 18 | -p ${PROJECT_NAME} \ 19 | --scalable-config utils/scalable_param_configs/appearance-with_scheduler-depth_reg.yaml \ 20 | --config configs/appearance_embedding_renderer/view_independent-lr_0.005-with_scheduler-estimated_depth_reg-hard_depth.yaml \ 21 | -- \ 22 | --data.parser.appearance_groups appearance_image_dedicated \ 23 | --model.gaussian.optimization.spatial_lr_scale 1.5 \ 24 | -- \ 25 | --gpus=1 \ 26 | --nodes=1 \ 27 | --ntasks=1 \ 28 | --exclusive -------------------------------------------------------------------------------- /scripts/untar_matrixcity_test.sh: -------------------------------------------------------------------------------- 1 | # Aerial 2 | for num in {1..10} 3 | do 4 | mkdir block_${num}_test/input 5 | tar -xvf block_${num}_test.tar 6 | mv block_${num}_test/*.png block_${num}_test/input 7 | done 8 | 9 | # Street 10 | mkdir small_city_road_down_test/input 11 | tar -xvf small_city_road_down_test.tar 12 | mv small_city_road_down_test/*.png small_city_road_down_test/input 13 | 14 | mkdir small_city_road_horizon_test/input 15 | tar -xvf small_city_road_horizon_test.tar 16 | mv small_city_road_horizon_test/*.png small_city_road_horizon_test/input 17 | 18 | mkdir small_city_road_outside_test/input 19 | tar -xvf small_city_road_outside_test.tar 20 | mv small_city_road_outside_test/*.png small_city_road_outside_test/input 21 | 22 | mkdir small_city_road_vertical_test/input 23 | tar -xvf small_city_road_vertical_test.tar 24 | mv small_city_road_vertical_test/*.png small_city_road_vertical_test/input -------------------------------------------------------------------------------- /scripts/untar_matrixcity_train.sh: -------------------------------------------------------------------------------- 1 | # Aerial 2 | for num in {1..10} 3 | do 4 | mkdir block_$num/input 5 | tar -xvf block_$num.tar 6 | mv block_$num/*.png block_$num/input 7 | done 8 | 9 | # Street 10 | mkdir small_city_road_down/input 11 | tar -xvf small_city_road_down.tar 12 | mv small_city_road_down/*.png small_city_road_down/input 13 | 14 | mkdir small_city_road_horizon/input 15 | tar -xvf small_city_road_horizon.tar 16 | mv small_city_road_horizon/*.png small_city_road_horizon/input 17 | 18 | mkdir small_city_road_outside/input 19 | tar -xvf small_city_road_outside.tar 20 | mv small_city_road_outside/*.png small_city_road_outside/input 21 | 22 | mkdir small_city_road_vertical/input 23 | tar -xvf small_city_road_vertical.tar 24 | mv small_city_road_vertical/*.png small_city_road_vertical/input -------------------------------------------------------------------------------- /seganygs.py: -------------------------------------------------------------------------------- 1 | from 
internal.entrypoints.seganygs import cli 2 | 3 | if __name__ == "__main__": 4 | cli() 5 | -------------------------------------------------------------------------------- /submodules/README.md: -------------------------------------------------------------------------------- 1 | [NOTE] Submodules are licensed separately. -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !dataset 3 | !.gitignore 4 | !positional_encoding_test.py 5 | !network_factory_test.py 6 | !deformable_model_test.py 7 | !gaussian_projection_test.py 8 | !vanilla_gaussian_model_test.py 9 | !density_controller_utils_test.py -------------------------------------------------------------------------------- /tests/dataset/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !blender_dataparser_test.py 4 | !nerfies_dataparser_test.py -------------------------------------------------------------------------------- /tests/dataset/blender_dataparser_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import unittest 3 | import json 4 | import torch 5 | from internal.dataparsers.blender_dataparser import Blender 6 | 7 | 8 | class BlenderDataparserTestCase(unittest.TestCase): 9 | def test_blender_dataparser(self): 10 | dataset_path = os.path.expanduser("~/data/nerf/nerf_synthetic/lego") 11 | 12 | gt_camera_sets = [] 13 | for i in ["train", "val", "test"]: 14 | with open(os.path.join(dataset_path, "transforms_{}.json".format(i)), "r") as f: 15 | gt_camera_sets.append(json.load(f)) 16 | 17 | dataparser = Blender().instantiate(dataset_path, os.getcwd(), 0) 18 | dataparser_outputs = dataparser.get_outputs() 19 | 20 | for parsed, gt_set in zip([dataparser_outputs.train_set, dataparser_outputs.val_set, dataparser_outputs.test_set], gt_camera_sets): 21 | gt_c2w_list = [] 22 | for frame in gt_set["frames"]: 23 | gt_c2w_list.append(frame["transform_matrix"]) 24 | gt_c2w = torch.tensor(gt_c2w_list) 25 | 26 | w2c = parsed.cameras.world_to_camera.transpose(1, 2) 27 | c2w = torch.linalg.inv(w2c) 28 | c2w[:, :3, 1:3] *= -1 29 | 30 | self.assertTrue(torch.allclose(gt_c2w, c2w, atol=3e-7)) 31 | 32 | self.assertTrue(torch.all(torch.isclose(parsed.cameras.fov_x, torch.tensor(gt_set["camera_angle_x"])))) 33 | self.assertTrue(torch.all(torch.isclose(parsed.cameras.fov_y, torch.tensor(gt_set["camera_angle_x"])))) 34 | 35 | 36 | if __name__ == '__main__': 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /tests/dataset/colmap_dataparser_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import unittest 3 | from internal.dataparsers.colmap_dataparser import Colmap 4 | 5 | 6 | class ColmapDataparserTestCase(unittest.TestCase): 7 | def test_eval_list(self): 8 | eval_list = os.path.expanduser("~/data/Mega-NeRF/rubble-pixsfm/val_set.txt") 9 | 10 | eval_set = {} 11 | with open(eval_list, "r") as f: 12 | for row in f: 13 | eval_set[row.rstrip("\n")] = True 14 | 15 | dataparser = Colmap( 16 | split_mode="experiment", 17 | eval_image_select_mode="list", 18 | eval_list=os.path.expanduser(eval_list), 19 | ).instantiate(os.path.expanduser("~/data/Mega-NeRF/rubble-pixsfm/colmap/"), os.getcwd(), 0) 20 | dataparser_outputs = dataparser.get_outputs() 21 | for i in
dataparser_outputs.train_set.image_names: 22 | self.assertTrue(i not in eval_set) 23 | for i in dataparser_outputs.val_set.image_names: 24 | self.assertTrue(i in eval_set) 25 | 26 | dataparser = Colmap( 27 | split_mode="reconstruction", 28 | eval_image_select_mode="list", 29 | eval_list=os.path.expanduser(eval_list), 30 | ).instantiate(os.path.expanduser("~/data/Mega-NeRF/rubble-pixsfm/colmap/"), os.getcwd(), 0) 31 | dataparser_outputs = dataparser.get_outputs() 32 | for i in eval_set: 33 | dataparser_outputs.train_set.image_names.index(i) 34 | for i in dataparser_outputs.val_set.image_names: 35 | self.assertTrue(i in eval_set) 36 | 37 | 38 | if __name__ == '__main__': 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /tests/dataset/matrix_city_dataparser_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import unittest 3 | from internal.configs.dataset import MatrixCityParams 4 | from internal.dataparsers.matrix_city_dataparser import MatrixCityDataParser 5 | 6 | 7 | class MatrixCityDataparserTestCase(unittest.TestCase): 8 | def test_dataparser(self): 9 | dataparser = MatrixCityDataParser( 10 | os.path.expanduser("~/data/matrixcity/aerial/"), 11 | ".", 12 | 0, 13 | MatrixCityParams( 14 | train=["aerial_train/transforms.json"], 15 | test=["aerial_test/transforms.json"], 16 | ) 17 | ) 18 | dataparser.get_outputs() 19 | dataparser = MatrixCityDataParser( 20 | os.path.expanduser("~/data/matrixcity/street/"), 21 | ".", 22 | 0, 23 | MatrixCityParams( 24 | train=["small_city_road_vertical/transforms.json"], 25 | test=["small_city_road_vertical_test/transforms.json"], 26 | depth_read_step=16, 27 | ) 28 | ) 29 | dataparser.get_outputs() 30 | 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /tests/dataset/nerfies_dataparser_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import unittest 3 | import torch 4 | from internal.configs.dataset import NerfiesParams 5 | from internal.dataparsers.nerfies_dataparser import NerfiesDataparser 6 | 7 | 8 | class NerfiesDataparserTestCase(unittest.TestCase): 9 | def test_nerfies_dataparser(self): 10 | dataparser = NerfiesDataparser( 11 | path=os.path.expanduser("~/data/DynamicDatasets/HyperNeRF/espresso"), 12 | output_path="/tmp/HyperNeRF", 13 | global_rank=0, 14 | params=NerfiesParams() 15 | ) 16 | outputs_1x = dataparser.get_outputs() 17 | 18 | dataparser = NerfiesDataparser( 19 | path=os.path.expanduser("~/data/DynamicDatasets/HyperNeRF/espresso"), 20 | output_path="/tmp/HyperNeRF", 21 | global_rank=0, 22 | params=NerfiesParams(down_sample_factor=2) 23 | ) 24 | outputs_2x = dataparser.get_outputs() 25 | self.assertTrue(torch.allclose(outputs_1x.train_set.cameras.fx / 2., outputs_2x.train_set.cameras.fx)) 26 | self.assertTrue(torch.allclose(outputs_1x.train_set.cameras.fy / 2., outputs_2x.train_set.cameras.fy)) 27 | self.assertTrue(torch.allclose(outputs_1x.train_set.cameras.cx / 2., outputs_2x.train_set.cameras.cx)) 28 | self.assertTrue(torch.allclose(outputs_1x.train_set.cameras.cy / 2., outputs_2x.train_set.cameras.cy)) 29 | self.assertTrue(torch.all(outputs_2x.train_set.cameras.time <= 1.)) 30 | self.assertTrue(torch.all(outputs_2x.val_set.cameras.time <= 1.)) 31 | 32 | for i in outputs_2x.train_set.image_paths: 33 | self.assertTrue("rgb/2x/" in i) 34 | 35 |
self.assertEqual(outputs_2x.val_set.image_names[0], "000001.png") 36 | print(outputs_2x.val_set.cameras[0]) 37 | 38 | w2c = torch.eye(4) 39 | w2c[:3, :3] = outputs_2x.val_set.cameras[0].R 40 | w2c[:3, 3] = outputs_2x.val_set.cameras[0].T 41 | c2w = torch.linalg.inv(w2c) 42 | self.assertTrue(torch.allclose(c2w[:3, 3], torch.tensor([ 43 | 0.008652675425308145, 44 | -0.00921293454554443, 45 | -0.70470877132779 46 | ]))) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /tests/positional_encoding_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from internal.encodings.positional_encoding import PositionalEncoding 6 | from internal.models.vanilla_deform_model import get_embedder 7 | 8 | 9 | class PositionalEncodingTestCase(unittest.TestCase): 10 | def test_positional_encoding(self): 11 | pe1 = PositionalEncoding(3, 10, True) 12 | pe2, _ = get_embedder(10, 3) 13 | 14 | input = torch.arange(3 * 100).reshape((-1, 3)) 15 | self.assertTrue(torch.all(pe1(input) == pe2(input))) 16 | input = torch.randn((100, 3)) 17 | self.assertTrue(torch.all(pe1(input) == pe2(input))) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /tools/add_pypath.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 5 | -------------------------------------------------------------------------------- /tools/block_wandb_sync.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | import torch 5 | import numpy as np 6 | import subprocess 7 | from tqdm import tqdm 8 | from argparse import ArgumentParser, Namespace 9 | from concurrent.futures import ProcessPoolExecutor 10 | 11 | 12 | def sync(wandb_path): 13 | cmds = [ 14 | f"wandb sync {wandb_path}", 15 | ] 16 | 17 | for cmd in cmds: 18 | print(cmd) 19 | subprocess.run(cmd, shell=True, check=True) 20 | return True 21 | 22 | def main(): 23 | parser = ArgumentParser(description="Training script parameters") 24 | parser.add_argument('--output_path', type=str, help='path of output folder', default=None) 25 | args = parser.parse_args(sys.argv[1:]) 26 | 27 | blocks_path = os.path.join(args.output_path, 'blocks') 28 | jobs = [f for f in os.listdir(blocks_path) if os.path.isdir(os.path.join(blocks_path, f))] 29 | 30 | with ProcessPoolExecutor(max_workers=len(jobs)) as executor: 31 | futures = [executor.submit(sync, os.path.join(blocks_path, block, 'wandb/latest-run')) for block in jobs] 32 | 33 | for future in futures: 34 | try: 35 | result = future.result() 36 | print(f"Finished job with result: {result}\n") 37 | except Exception as e: 38 | print(f"Failed job with exception: {e}\n") 39 | 40 | if __name__ == "__main__": 41 | main() -------------------------------------------------------------------------------- /tools/clean_outputs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | from argparse import ArgumentParser 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = ArgumentParser(description="clean mesh folder under the output path") 10 | 
parser.add_argument("--output_dir", type=str, default="./outputs", help="Path to the output folder") 11 | 12 | args = parser.parse_args(sys.argv[1:]) 13 | 14 | # remove outputs/*/checkpoints/*6999-xyz_rgb.ply and outputs/*/mesh/fuse.ply 15 | for root, dirs, files in os.walk(args.output_dir): 16 | for file in files: 17 | if file.endswith("xyz_rgb.ply") or file.endswith("fuse.ply") or file.endswith("=6999.ckpt"): 18 | os.remove(os.path.join(root, file)) 19 | print(f"Removed {os.path.join(root, file)}") -------------------------------------------------------------------------------- /tools/copy_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | 11 | def _get_images_opts(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--image_path', type=str, required=True) 15 | parser.add_argument('--dataset_path', type=str, required=True) 16 | 17 | return parser.parse_args() 18 | 19 | 20 | def main(hparams: Namespace) -> None: 21 | image_path = Path(hparams.image_path) 22 | dataset_path = Path(hparams.dataset_path) 23 | if not (dataset_path / 'train' / 'rgbs').exists(): 24 | (dataset_path / 'train' / 'rgbs').mkdir() 25 | if not (dataset_path / 'val' / 'rgbs').exists(): 26 | (dataset_path / 'val' / 'rgbs').mkdir() 27 | 28 | with (Path(hparams.dataset_path) / 'mappings.txt').open() as f: 29 | for line in tqdm(f): 30 | image_name, metadata_name = line.strip().split(',') 31 | metadata_path = dataset_path / 'train' / 'metadata' / metadata_name 32 | if not metadata_path.exists(): 33 | metadata_path = dataset_path / 'val' / 'metadata' / metadata_name 34 | assert metadata_path.exists() 35 | 36 | distorted = cv2.imread(str(image_path / image_name)) 37 | metadata = torch.load(metadata_path, map_location='cpu') 38 | intrinsics = metadata['intrinsics'] 39 | camera_matrix = np.array([[intrinsics[0], 0, intrinsics[2]], 40 | [0, intrinsics[1], intrinsics[3]], 41 | [0, 0, 1]]) 42 | 43 | undistorted = cv2.undistort(distorted, camera_matrix, metadata['distortion'].numpy()) 44 | assert undistorted.shape[0] == metadata['H'] 45 | assert undistorted.shape[1] == metadata['W'] 46 | 47 | cv2.imwrite(str(metadata_path.parent.parent / 'rgbs' / '{}.{}'.format(metadata_path.stem, 48 | image_name.split('.')[ 49 | -1])), 50 | undistorted) 51 | 52 | 53 | if __name__ == '__main__': 54 | main(_get_images_opts()) -------------------------------------------------------------------------------- /tools/eval_tnt/config.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- 2 | # - TanksAndTemples Website Toolbox - 3 | # - http://www.tanksandtemples.org - 4 | # ---------------------------------------------------------------------------- 5 | # The MIT License (MIT) 6 | # 7 | # Copyright (c) 2017 8 | # Arno Knapitsch 9 | # Jaesik Park 10 | # Qian-Yi Zhou 11 | # Vladlen Koltun 12 | # 13 | # Permission is hereby granted, free of charge, to any person obtaining a copy 14 | # of this software and associated documentation files (the "Software"), to deal 15 | # in the Software without restriction, including without limitation the rights 16 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | # copies of the Software, and to permit persons to whom the Software is 18 | # 
furnished to do so, subject to the following conditions: 19 | # 20 | # The above copyright notice and this permission notice shall be included in 21 | # all copies or substantial portions of the Software. 22 | # 23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 29 | # THE SOFTWARE. 30 | # ---------------------------------------------------------------------------- 31 | 32 | # some global parameters - do not modify 33 | scenes_tau_dict = { 34 | "Barn": 0.01, 35 | "Caterpillar": 0.005, 36 | "Church": 0.025, 37 | "Courthouse": 0.025, 38 | "Ignatius": 0.003, 39 | "Meetingroom": 0.01, 40 | "Truck": 0.005, 41 | } 42 | -------------------------------------------------------------------------------- /tools/eval_tnt/evaluate_single_scene.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 5 | import numpy as np 6 | import os 7 | import glob 8 | from skimage.morphology import binary_dilation, disk 9 | import argparse 10 | 11 | import trimesh 12 | from pathlib import Path 13 | import subprocess 14 | import sys 15 | import json 16 | 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser( 21 | description='Arguments to evaluate the mesh.' 22 | ) 23 | 24 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be evaluated') 25 | parser.add_argument('--scene', type=str, help='scene name of the input mesh') 26 | parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder') 27 | parser.add_argument('--TNT', type=str, default='Offical_DTU_Dataset', help='path to the GT Tanks and Temples point clouds') 28 | args = parser.parse_args() 29 | 30 | 31 | TNT_Dataset = args.TNT 32 | out_dir = args.output_dir 33 | Path(out_dir).mkdir(parents=True, exist_ok=True) 34 | scene = args.scene 35 | ply_file = args.input_mesh 36 | result_mesh_file = os.path.join(out_dir, "culled_mesh.ply") 37 | # run the TanksAndTemples evaluation script on the input mesh 38 | subprocess.run(f"python run.py --dataset-dir {ply_file} --traj-path {TNT_Dataset}/{scene}/{scene}_COLMAP_SfM.log --ply-path {TNT_Dataset}/{scene}/{scene}_COLMAP.ply", shell=True, check=True) -------------------------------------------------------------------------------- /tools/eval_tnt/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=1.3 2 | open3d==0.10 3 | -------------------------------------------------------------------------------- /tools/eval_tnt/trajectory_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | 5 | class CameraPose: 6 | 7 | def __init__(self, meta, mat): 8 | self.metadata = meta 9 | self.pose = mat 10 | 11 | def __str__(self): 12 | return ("Metadata : " + " ".join(map(str, self.metadata)) + "\n" + 13 | "Pose : " + "\n" + np.array_str(self.pose)) 14 | 15 | 16 | def convert_trajectory_to_pointcloud(traj): 17 | pcd = o3d.geometry.PointCloud() 18 | for t in traj: 19 | pcd.points.append(t.pose[:3, 3]) 20 | return pcd 21 | 22 | 23 | def
read_trajectory(filename): 24 | traj = [] 25 | with open(filename, "r") as f: 26 | metastr = f.readline() 27 | while metastr: 28 | metadata = map(int, metastr.split()) 29 | mat = np.zeros(shape=(4, 4)) 30 | for i in range(4): 31 | matstr = f.readline() 32 | mat[i, :] = np.fromstring(matstr, dtype=float, sep=" \t") 33 | traj.append(CameraPose(metadata, mat)) 34 | metastr = f.readline() 35 | return traj 36 | 37 | 38 | def write_trajectory(traj, filename): 39 | with open(filename, "w") as f: 40 | for x in traj: 41 | p = x.pose.tolist() 42 | f.write(" ".join(map(str, x.metadata)) + "\n") 43 | f.write("\n".join( 44 | " ".join(map("{0:.12f}".format, p[i])) for i in range(4))) 45 | f.write("\n") 46 | -------------------------------------------------------------------------------- /tools/eval_tnt/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_dir(path): 5 | if not os.path.exists(path): 6 | os.makedirs(path) 7 | -------------------------------------------------------------------------------- /utils/PolyCam.md: -------------------------------------------------------------------------------- 1 | # PolyCam's raw data 2 | 3 | * Only raw data captured in LiDAR or Room mode is supported 4 | 5 | * Take a look at NeRFStudio's quick start for how to export the raw data 6 | 7 | ## Using the raw data in this repo. 8 | 1. Unzip the raw data file 9 | 10 | 2. Convert it to NGP's `transforms.json` and generate a point cloud from the depth maps 11 | ```bash 12 | python utils/polycam2ngp.py RAW_DATA_DIR 13 | ``` 14 | The `RAW_DATA_DIR` must contain the directory named `keyframes`. 15 | 16 | 3. Start training 17 | ```bash 18 | python main.py fit \ 19 | --config configs/gsplat.yaml \ 20 | --data.path RAW_DATA_DIR \ 21 | --data.parser internal.dataparsers.ngp_dataparser.NGP \ 22 | ...
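# the trailing "..." stands for any further fit options; for illustration only, two
# flags that appear elsewhere in this repo: -n EXPERIMENT_NAME --cache_all_images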
23 | ``` 24 | -------------------------------------------------------------------------------- /utils/add_pypath.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 5 | -------------------------------------------------------------------------------- /utils/argparser_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | STOP_INDICATOR = "--" 4 | 5 | 6 | def split_stoppable_args(argv: list): 7 | for idx, v in enumerate(argv + [STOP_INDICATOR]): 8 | if v == STOP_INDICATOR: 9 | break 10 | 11 | return argv[:idx], argv[idx + 1:] 12 | 13 | 14 | def parser_stoppable_args(parser): 15 | argvs = split_stoppable_args(sys.argv[1:]) 16 | 17 | return parser.parse_args(argvs[0]), argvs[1] 18 | 19 | 20 | def test_split_stoppable_args(): 21 | assert split_stoppable_args(["-a", "1", "-b", "2"]) == (["-a", "1", "-b", "2"], []) 22 | assert split_stoppable_args(["-a", "-b", "--", "-c"]) == (["-a", "-b"], ["-c"]) 23 | -------------------------------------------------------------------------------- /utils/ckpt2ply.py: -------------------------------------------------------------------------------- 1 | import add_pypath 2 | import os 3 | import argparse 4 | import lightning 5 | import torch 6 | from internal.utils.gaussian_utils import GaussianPlyUtils 7 | from internal.utils.gaussian_model_loader import GaussianModelLoader 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("input") 11 | parser.add_argument("--output", "-o", required=False, default=None) 12 | parser.add_argument("--colored", "-c", action="store_true", default=False) 13 | args = parser.parse_args() 14 | 15 | # search input file 16 | print("Searching checkpoint file...") 17 | load_file = GaussianModelLoader.search_load_file(args.input) 18 | assert load_file.endswith(".ckpt"), f"No valid ckpt file can be found in '{args.input}'" 19 | 20 | # auto select output path if not provided 21 | if args.output is None: 22 | args.output = load_file[:load_file.rfind(".")] + ".ply" 23 | # if provided input path is a directory, write output file to `PROVIDED_PATH/point_cloud/iteration_.../point_cloud.ply` 24 | if os.path.isdir(args.input) is True: 25 | try: 26 | iteration = load_file[load_file.rfind("=") + 1:load_file.rfind(".")] 27 | if len(iteration) > 0: 28 | args.output = os.path.join(args.input, "point_cloud", f"iteration_{iteration}", "point_cloud.ply") 29 | except: 30 | pass 31 | 32 | assert os.path.exists(args.output) is False, f"Output file already exists, please remove it first: '{args.output}'" 33 | 34 | print(f"Loading checkpoint '{load_file}'...") 35 | ckpt = torch.load(load_file) 36 | print("Converting...") 37 | model = GaussianPlyUtils.load_from_state_dict(ckpt["state_dict"]).to_ply_format().save_to_ply(args.output, args.colored) 38 | print(f"Saved to '{args.output}'") 39 | 40 | size = os.path.getsize(args.output) 41 | size_MB = size / 1024.0 / 1024.0 42 | print("Size = {:.2f} MB".format(size_MB)) -------------------------------------------------------------------------------- /utils/depths_downsample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import cv2 4 | import os 5 | from os.path import isfile, join 6 | from concurrent.futures import ThreadPoolExecutor 7 | from tqdm import tqdm 8 | 9 | args_src = None 10 | args_dst = None 11 | downsample_factor =
None 12 | 13 | def process_file(f): 14 | depth_map = np.load(join(args_src, f)) 15 | height, width = depth_map.shape 16 | downsampled = cv2.resize(depth_map, (int(width // downsample_factor), int(height // downsample_factor)), interpolation=cv2.INTER_CUBIC) 17 | np.save(join(args_dst, f), downsampled) 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("src") 22 | parser.add_argument("--dst", default=None) 23 | parser.add_argument("--factor", type=float, default=2) 24 | args = parser.parse_args() 25 | 26 | max_threads = min(32, (os.cpu_count() or 1) + 4) 27 | 28 | assert args.src != args.dst 29 | 30 | if args.dst is None: 31 | args.dst = "{}_{}".format(args.src, args.factor) 32 | 33 | args_src = args.src 34 | args_dst = args.dst 35 | downsample_factor = args.factor 36 | 37 | print(args.dst) 38 | 39 | os.makedirs(args.dst) 40 | 41 | depth_maps = [f for f in os.listdir(args.src) if (isfile(join(args.src, f)) and f.endswith('.npy'))] 42 | 43 | 44 | with ThreadPoolExecutor(max_workers=max_threads) as executor: 45 | list(tqdm(executor.map(process_file, depth_maps), total=len(depth_maps), desc="downsampling")) -------------------------------------------------------------------------------- /utils/downsample_pcd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import add_pypath 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | from argparse import ArgumentParser 8 | 9 | if __name__ == "__main__": 10 | # Set up command line argument parser 11 | parser = ArgumentParser(description="Point cloud downsampling parameters") 12 | parser.add_argument('--file_dir', '-f', type=str, help='path to target point cloud', required=True) 13 | parser.add_argument('--vox_size', '-v', type=float, help='downsampling voxel size', default=0.008) 14 | parser.add_argument('--scaling_factor', '-s', type=float, help='position scaling factor', default=1.0) 15 | args = parser.parse_args(sys.argv[1:]) 16 | 17 | pcd = o3d.io.read_point_cloud(args.file_dir) 18 | print(f"{args.file_dir} has {len(pcd.points)} points") 19 | 20 | ds_pcd = pcd.voxel_down_sample(voxel_size=args.vox_size).scale(args.scaling_factor, center=(0, 0, 0)) 21 | print("Downsampled has", len(ds_pcd.points), "points") 22 | print("Average distance of downsampled point cloud: ", np.mean(ds_pcd.compute_nearest_neighbor_distance())) 23 | 24 | save_dir = args.file_dir.replace(".ply", "_ds.ply") 25 | print("Downsampled point cloud is saved to ", save_dir) 26 | o3d.io.write_point_cloud(save_dir, ds_pcd) -------------------------------------------------------------------------------- /utils/dump_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import sys 4 | import yaml 5 | import os 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | if __name__=='__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("path", help="Path to the .ckpt file") 12 | args = parser.parse_args() 13 | ckpt = torch.load(args.path, map_location="cpu") 14 | 15 | def tuple_representer(dumper, data): 16 | return dumper.represent_list(data) 17 | yaml.add_representer(tuple, tuple_representer) 18 | 19 | print(yaml.dump(ckpt['hyper_parameters'], default_flow_style=False)) -------------------------------------------------------------------------------- /utils/edit_with_histories.py: -------------------------------------------------------------------------------- 1 | import os 2 | import add_pypath 3 |
import lightning 4 | import argparse 5 | import torch 6 | from tqdm.auto import tqdm 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("ckpt") 10 | parser.add_argument("history_file") 11 | parser.add_argument("output") 12 | args = parser.parse_args() 13 | 14 | assert os.path.exists(args.output) is False 15 | device = torch.device("cuda") 16 | ckpt = torch.load(args.ckpt, map_location=device) 17 | histories = torch.load(args.history_file) 18 | 19 | if "gaussian_model._xyz" in ckpt["state_dict"]: 20 | dict_key_prefix = "gaussian_model._" 21 | xyz = ckpt["state_dict"]["gaussian_model._xyz"] 22 | else: 23 | dict_key_prefix = "gaussian_model.gaussians." 24 | xyz = ckpt["state_dict"]["gaussian_model.gaussians.means"] 25 | preserve_mask = torch.ones((xyz.shape[0],), dtype=torch.bool, device=device) 26 | for operation in tqdm(histories, total=len(histories)): 27 | is_gaussian_selected = torch.ones(xyz.shape[0], device=xyz.device, dtype=torch.bool) 28 | for item in operation: 29 | se3, grid_size = item 30 | se3 = se3.to(device) 31 | new_xyz = torch.matmul(xyz, se3[:3, :3].T) + se3[:3, 3] 32 | x_mask = torch.abs(new_xyz[:, 0]) < grid_size[0] / 2 33 | y_mask = torch.abs(new_xyz[:, 1]) < grid_size[1] / 2 34 | z_mask = new_xyz[:, 2] > 0 35 | # update mask 36 | is_gaussian_selected = torch.bitwise_and(is_gaussian_selected, x_mask) 37 | is_gaussian_selected = torch.bitwise_and(is_gaussian_selected, y_mask) 38 | is_gaussian_selected = torch.bitwise_and(is_gaussian_selected, z_mask) 39 | preserve_mask = torch.bitwise_and(preserve_mask, torch.bitwise_not(is_gaussian_selected)) 40 | 41 | for i in ckpt["state_dict"]: 42 | if i.startswith(dict_key_prefix): 43 | ckpt["state_dict"][i] = ckpt["state_dict"][i][preserve_mask] 44 | # TODO: prune density controller and optimizer states too 45 | torch.save(ckpt, args.output) 46 | print(f"{preserve_mask.sum().item()} of {preserve_mask.shape[0]} Gaussians saved to {args.output}") 47 | -------------------------------------------------------------------------------- /utils/estimate_dataset_depths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("dataset_path") 7 | parser.add_argument("--image_dir", type=str, default="images") 8 | parser.add_argument("--preview", action="store_true", default=False) 9 | parser.add_argument("--downsample_factor", "-d", type=float, default=1) 10 | args = parser.parse_args() 11 | 12 | assert subprocess.call( 13 | args=[ 14 | "python", 15 | "-u", 16 | os.path.join(os.path.dirname(__file__), "run_depth_anything_v2.py"), 17 | os.path.join(args.dataset_path, args.image_dir), 18 | ] + (["--preview"] if args.preview else []) + (["--downsample_factor", str(args.downsample_factor)] if args.downsample_factor != 1 else []), 19 | shell=False, 20 | ) == 0 21 | 22 | assert subprocess.call( 23 | args=[ 24 | "python", 25 | "-u", 26 | os.path.join(os.path.dirname(__file__), "get_depth_scales.py"), 27 | args.dataset_path, 28 | ], 29 | shell=False, 30 | ) == 0 31 | -------------------------------------------------------------------------------- /utils/eval_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("path") 8 | parser.add_argument("--config", "-c", default=None) 9 | 
parser.add_argument("--project", "-p", default="Blender") 10 | args = parser.parse_args() 11 | 12 | # find scenes 13 | scenes = [] 14 | for i in list(os.listdir(args.path)): 15 | if os.path.exists(os.path.join(args.path, i, "transforms_train.json")) is False: 16 | continue 17 | scenes.append(i) 18 | print(scenes) 19 | 20 | 21 | def start(command: str, scene: str, extra_args: list = None): 22 | arg_list = [ 23 | "python", 24 | "main.py", 25 | command, 26 | "--data.path", os.path.join(args.path, scene), 27 | "--trainer.check_val_every_n_epoch", "10", 28 | "--cache_all_images", 29 | "--logger", "wandb", 30 | "--output", os.path.join("outputs", args.project), 31 | "--project", args.project, 32 | "-n", scene, 33 | ] 34 | if args.config is not None: 35 | arg_list += ["--config", args.config] 36 | if extra_args is not None: 37 | arg_list += extra_args 38 | 39 | subprocess.call(arg_list) 40 | 41 | 42 | with tqdm(scenes) as t: 43 | for i in t: 44 | t.set_description(i) 45 | start("fit", i) 46 | start("validate", i, extra_args=["--save_val"]) 47 | start("test", i, extra_args=["--save_val"]) 48 | -------------------------------------------------------------------------------- /utils/eval_mipnerf360.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | import distibuted_tasks 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("path") 9 | parser.add_argument("--config", "-c", default=None) 10 | parser.add_argument("--down_sample_factor", "--down-sample-facotr", "-d", type=int, default=4) 11 | parser.add_argument("--project", "-p", default="MipNeRF360") 12 | distibuted_tasks.configure_arg_parser(parser) 13 | args, fitting_args = parser.parse_known_args() 14 | print(fitting_args) 15 | 16 | # find scenes 17 | scenes = [] 18 | for i in list(os.listdir(args.path)): 19 | if os.path.isdir(os.path.join(args.path, i, "sparse")) is False: 20 | continue 21 | scenes.append(i) 22 | scenes.sort() 23 | scenes = distibuted_tasks.get_task_list_with_args(args, scenes) 24 | print(scenes) 25 | 26 | 27 | def start(command: str, scene: str, extra_args: list = None): 28 | arg_list = [ 29 | "python", 30 | "main.py", 31 | command, 32 | "--data.parser", "Colmap", # this can be overridden by config file or args latter 33 | ] 34 | if args.config is not None: 35 | arg_list += ["--config", args.config] 36 | if extra_args is not None: 37 | arg_list += extra_args 38 | arg_list += [ 39 | "--data.path", os.path.join(args.path, scene), 40 | "--data.parser.down_sample_factor", "{}".format(args.down_sample_factor), 41 | "--data.parser.split_mode", "experiment", 42 | "--data.parser.down_sample_rounding_mode", "round_half_up", 43 | "--cache_all_images", 44 | "--logger", "wandb", 45 | "--output", os.path.join("outputs", args.project), 46 | "--project", args.project, 47 | "-n", scene, 48 | ] 49 | 50 | return subprocess.call(arg_list) 51 | 52 | 53 | with tqdm(scenes) as t: 54 | for i in t: 55 | t.set_description(i) 56 | start("fit", i, extra_args=fitting_args) 57 | start("validate", i, extra_args=fitting_args + ["--save_val"]) 58 | -------------------------------------------------------------------------------- /utils/finetune_pruned_partitions_v2.py: -------------------------------------------------------------------------------- 1 | import add_pypath 2 | import os 3 | import argparse 4 | from typing import Dict, Any 5 | from dataclasses import dataclass 6 | from train_partitions import PartitionTrainingConfig, 
PartitionTraining 7 | 8 | 9 | @dataclass 10 | class PartitionFinetuningConfig(PartitionTrainingConfig): 11 | prune_percent: float = 0.6 12 | trained_project: str = None 13 | 14 | @classmethod 15 | def get_extra_init_kwargs(cls, args) -> Dict[str, Any]: 16 | return { 17 | "prune_percent": args.prune_percent, 18 | "trained_project": args.trained_project, 19 | } 20 | 21 | 22 | class PartitionFinetuning(PartitionTraining): 23 | def get_overridable_partition_specific_args(self, partition_idx: int) -> list[str]: 24 | return super().get_overridable_partition_specific_args(partition_idx) + ["--config={}".format(os.path.join( 25 | self.get_project_output_dir_by_name(self.config.trained_project), 26 | self.get_partition_id_str(partition_idx), 27 | "config.yaml", 28 | ))] 29 | 30 | def get_partition_specific_args(self, partition_idx: int) -> list[str]: 31 | return super().get_partition_specific_args(partition_idx) + ["--ckpt_path={}".format(os.path.join( 32 | self.get_project_output_dir_by_name(self.config.trained_project), 33 | self.get_partition_id_str(partition_idx), 34 | "pruned_checkpoints", 35 | "latest-opacity_pruned-{}.ckpt".format(self.config.prune_percent) 36 | ))] 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | PartitionTrainingConfig.configure_argparser(parser, extra_epoches=30) 42 | parser.add_argument("--prune-percent", type=float, default=0.6) 43 | parser.add_argument("--trained-project", "-t", type=str, required=True) 44 | 45 | PartitionFinetuning.start_with_configured_argparser(parser, config_cls=PartitionFinetuningConfig) 46 | 47 | 48 | main() 49 | -------------------------------------------------------------------------------- /utils/generate_image_apperance_groups.py: -------------------------------------------------------------------------------- 1 | """ 2 | For colmap dataset 3 | """ 4 | 5 | import add_pypath 6 | import os 7 | import json 8 | import argparse 9 | from tqdm import tqdm 10 | from internal.utils.colmap import read_images_binary 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("dir") 14 | parser.add_argument("--dirname", action="store_true", default=False, 15 | help="Share same appearance group for every directory") 16 | parser.add_argument("--camera", action="store_true", default=False, 17 | help="Share same appearance group for every camera") 18 | parser.add_argument("--image", action="store_true", default=False, 19 | help="Every image has different appearance") 20 | parser.add_argument("--name", type=str, default=None, 21 | help="output filename without extension") 22 | args = parser.parse_args() 23 | 24 | images_bin_path = os.path.join(args.dir, "sparse", "images.bin") 25 | if os.path.exists(images_bin_path) is False: 26 | images_bin_path = os.path.join(args.dir, "sparse", "0", "images.bin") 27 | 28 | print("reading {}".format(images_bin_path)) 29 | images = read_images_binary(images_bin_path) 30 | image_group = {} 31 | for i in tqdm(images, desc="reading image information"): 32 | image = images[i] 33 | 34 | if args.dirname is True: 35 | key = os.path.dirname(image.name) 36 | elif args.camera is True: 37 | key = image.camera_id 38 | elif args.image is True: 39 | key = image.name 40 | else: 41 | raise ValueError("unsupported group type") 42 | 43 | if key not in image_group: 44 | image_group[key] = [] 45 | image_group[key].append(image.name) 46 | 47 | for i in image_group: 48 | image_group[i].sort() 49 | 50 | save_path = os.path.join(args.dir, "appearance_groups.json" if args.name is None else 
"{}.json".format(args.name)) 51 | with open(save_path, "w") as f: 52 | json.dump(image_group, f, indent=4, ensure_ascii=False) 53 | print(save_path) 54 | -------------------------------------------------------------------------------- /utils/generate_image_apperance_groups_by_exposure.py: -------------------------------------------------------------------------------- 1 | """ 2 | For HDR-NeRF dataset: https://xhuangcv.github.io/hdr-nerf/ 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("dir") 11 | parser.add_argument("--exposure", type=str, required=True, 12 | help="path to exposure json file") 13 | parser.add_argument("--name", type=str, default=None, 14 | help="output filename without extension") 15 | args = parser.parse_args() 16 | 17 | with open(args.exposure, "r") as f: 18 | exposure_key_by_image_name = json.load(f) 19 | 20 | image_group_by_exposure = {} 21 | for image_name in exposure_key_by_image_name: 22 | exposure = exposure_key_by_image_name[image_name] 23 | if exposure not in image_group_by_exposure: 24 | image_group_by_exposure[exposure] = [] 25 | image_group_by_exposure[exposure].append(image_name) 26 | 27 | save_path = os.path.join(args.dir, "appearance_groups.json" if args.name is None else "{}.json".format(args.name)) 28 | with open(save_path, "w") as f: 29 | json.dump(image_group_by_exposure, f, indent=4, ensure_ascii=False) 30 | print(save_path) 31 | -------------------------------------------------------------------------------- /utils/gs2d_mesh_extraction.py: -------------------------------------------------------------------------------- 1 | import add_pypath 2 | import os 3 | if "OMP_NUM_THREADS" not in os.environ: 4 | os.environ["OMP_NUM_THREADS"] = "4" 5 | from internal.entrypoints.gs2d_mesh_extraction import main 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /utils/image_downsample.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import argparse 4 | from concurrent.futures import ThreadPoolExecutor 5 | from glob import glob 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | 10 | def find_images(path: str, extensions: list) -> list: 11 | image_list = [] 12 | for extension in extensions: 13 | image_list += list(glob(os.path.join(path, "**", "*.{}".format(extension)), recursive=True)) 14 | 15 | # convert to relative path 16 | path_length = len(path) 17 | image_list = [i[path_length:].lstrip("/\\") for i in image_list] 18 | 19 | return image_list 20 | 21 | 22 | def resize_image(image, factor): 23 | width, height = image.size 24 | resized_width, resized_height = round(width / factor), round(height / factor) 25 | return image.resize((resized_width, resized_height)) 26 | 27 | 28 | def process_task(src: str, dst: str, image_name: str, factor: float): 29 | image = Image.open(os.path.join(src, image_name)) 30 | image = resize_image(image, factor) 31 | 32 | output_path = os.path.join(dst, image_name) 33 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 34 | 35 | image.save(output_path, quality=100) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("src") 41 | parser.add_argument("--dst", default=None) 42 | parser.add_argument("--factor", type=float, default=2) 43 | parser.add_argument("--extensions", nargs="+", default=[ 44 | "jpg", 45 | "JPG", 46 | 
"jpeg", 47 | "JPEG", 48 | "png", 49 | "PNG", 50 | ]) 51 | args = parser.parse_args() 52 | 53 | assert args.src != args.dst 54 | 55 | if args.dst is None: 56 | args.dst = "{}_{}".format(args.src, args.factor) 57 | 58 | image_list = find_images(args.src, args.extensions) 59 | 60 | with ThreadPoolExecutor() as tpe: 61 | future_list = [] 62 | for i in image_list: 63 | future_list.append(tpe.submit( 64 | process_task, 65 | args.src, 66 | args.dst, 67 | i, 68 | args.factor, 69 | )) 70 | 71 | for _ in tqdm(concurrent.futures.as_completed(future_list), total=len(future_list)): 72 | pass 73 | 74 | print("images saved to {}".format(args.dst)) 75 | -------------------------------------------------------------------------------- /utils/matrix_city_frame_group_slice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("input") 7 | parser.add_argument("output") 8 | parser.add_argument("--slice", type=str, nargs="+", 9 | help="start,count") 10 | parser.add_argument("--camera-count", "-c", type=int, default=6) 11 | args = parser.parse_args() 12 | assert args.input != args.output 13 | 14 | with open(args.input, "r") as f: 15 | transforms = json.load(f) 16 | 17 | assert len(transforms["frames"]) % args.camera_count == 0 18 | 19 | selected_frames = [] 20 | for i in args.slice: 21 | start_count = i.split(",") 22 | start = int(start_count[0]) 23 | count = int(start_count[1]) 24 | 25 | index_left = start * args.camera_count 26 | index_right = index_left + args.camera_count * count 27 | selected_frames += transforms["frames"][index_left:index_right] 28 | 29 | # check overlap 30 | selected_frame_index = {} 31 | for i in selected_frames: 32 | frame_index = i["frame_index"] 33 | assert frame_index not in selected_frame_index 34 | selected_frame_index[frame_index] = True 35 | 36 | transforms["frames"] = selected_frames 37 | with open(args.output, "w") as f: 38 | json.dump(transforms, f, indent=4, ensure_ascii=False) 39 | -------------------------------------------------------------------------------- /utils/merge_ply.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | from tqdm import tqdm 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--input", type=str, nargs="+") 10 | parser.add_argument("--output", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | xyz_list = [] 14 | rgb_list = [] 15 | with tqdm(args.input) as t: 16 | for ply in t: 17 | t.set_description("Loading {}".format(ply)) 18 | point_cloud = o3d.io.read_point_cloud(ply) 19 | xyz, rgb = np.asarray(point_cloud.points), (np.asarray(point_cloud.colors)) 20 | xyz_list.append(xyz) 21 | rgb_list.append(rgb) 22 | xyz = np.concatenate(xyz_list, axis=0) 23 | rgb = np.concatenate(rgb_list, axis=0) 24 | final_pcd = o3d.geometry.PointCloud() 25 | final_pcd.points = o3d.utility.Vector3dVector(xyz) 26 | final_pcd.colors = o3d.utility.Vector3dVector(rgb) 27 | o3d.io.write_point_cloud(args.output, final_pcd) 28 | -------------------------------------------------------------------------------- /utils/mesh_post_process.py: -------------------------------------------------------------------------------- 1 | import add_pypath 2 | import os 3 | import argparse 4 | import open3d as o3d 5 | from internal.utils.gs2d_mesh_utils import post_process_mesh 6 | 7 | parser = 
--------------------------------------------------------------------------------
/utils/prune_by_segany_mask.py:
--------------------------------------------------------------------------------
import add_pypath
import os
import argparse
import torch
from internal.utils.gaussian_model_loader import GaussianModelLoader
from internal.utils.gaussian_utils import GaussianPlyUtils

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--mask", type=str, required=True)
parser.add_argument("--output", type=str, required=True)
args = parser.parse_args()

assert args.model != args.output
assert args.output.endswith(".ply")
assert not os.path.exists(args.output)

model, _ = GaussianModelLoader.search_and_load(args.model, device="cpu", eval_mode=True, pre_activate=False)
mask = torch.load(args.mask, map_location="cpu")["mask"]

# keep only the Gaussians selected by the boolean mask, then save them as a standard PLY
properties = {key: value[mask] for key, value in model.properties.items()}
GaussianPlyUtils.load_from_model_properties(properties).to_ply_format().save_to_ply(args.output)
print(f"Saved to '{args.output}'")
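A hypothetical invocation (all three paths are assumptions; the mask file is expected to be a torch-saved dict holding a boolean "mask" tensor, as loaded above):

    python utils/prune_by_segany_mask.py \
        --model outputs/scene \
        --mask outputs/scene/mask.pt \
        --output outputs/scene/pruned.ply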
--------------------------------------------------------------------------------
/utils/requirements.txt:
--------------------------------------------------------------------------------
open3d==0.18.*

--------------------------------------------------------------------------------
/utils/scalable_param_configs/appearance-depth_reg.yaml:
--------------------------------------------------------------------------------
scalable:
  model.renderer.optimization.max_steps: 30000
  model.metric.depth_loss_weight.max_steps: 30000

--------------------------------------------------------------------------------
/utils/scalable_param_configs/appearance-with_scheduler-depth_reg.yaml:
--------------------------------------------------------------------------------
scalable:
  model.gaussian.optimization.appearance_feature_lr_scheduler.max_steps: 30000
  model.metric.depth_loss_weight.max_steps: 30000
  model.renderer.optimization.max_steps: 30000

--------------------------------------------------------------------------------
/utils/scalable_param_configs/appearance-with_scheduler.yaml:
--------------------------------------------------------------------------------
scalable:
  model.gaussian.optimization.appearance_feature_lr_scheduler.max_steps: 30000
  model.renderer.optimization.max_steps: 30000

--------------------------------------------------------------------------------
/utils/scalable_param_configs/appearance.yaml:
--------------------------------------------------------------------------------
scalable:
  model.renderer.optimization.max_steps: 30000

--------------------------------------------------------------------------------
/utils/scalable_param_configs/depth_reg.yaml:
--------------------------------------------------------------------------------
scalable:
  model.metric.depth_loss_weight.max_steps: 30000

--------------------------------------------------------------------------------
/utils/scalable_param_configs/scale_reg.yaml:
--------------------------------------------------------------------------------
scalable:
  model.metric.scale_reg_from: 3300

--------------------------------------------------------------------------------
/utils/train_colmap_partitions_v2.py:
--------------------------------------------------------------------------------
import os
from dataclasses import dataclass
from train_partitions import PartitionTrainingConfig, PartitionTraining
import argparse


@dataclass
class ColmapPartitionTrainingConfig(PartitionTrainingConfig):
    eval: bool = False

    @classmethod
    def get_extra_init_kwargs(cls, args):
        return {
            "eval": args.eval,
        }

    @staticmethod
    def configure_argparser(parser, extra_epoches: int = 0):
        PartitionTrainingConfig.configure_argparser(parser, extra_epoches)
        parser.add_argument("--eval", action="store_true", default=False)


class ColmapPartitionTraining(PartitionTraining):
    def get_default_dataparser_name(self) -> str:
        return "Colmap"

    def get_dataset_specific_args(self, partition_idx: int) -> list[str]:
        return [
            "--data.parser.image_list={}".format(os.path.join(
                self.path,
                "{}.txt".format(self.get_partition_id_str(partition_idx)),
            )),
            "--data.parser.split_mode={}".format("experiment" if self.config.eval else "reconstruction"),
            "--data.parser.eval_step=64",
        ]


def main():
    parser = argparse.ArgumentParser()
    ColmapPartitionTrainingConfig.configure_argparser(parser)
    ColmapPartitionTraining.start_with_configured_argparser(parser, config_cls=ColmapPartitionTrainingConfig)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/utils/train_matrix_city_partitions_v2.py:
--------------------------------------------------------------------------------
import os
from train_partitions import PartitionTrainingConfig, PartitionTraining
import argparse


class MatrixCityPartitionTraining(PartitionTraining):
    def get_default_dataparser_name(self) -> str:
        return "MatrixCity"

    def get_dataset_specific_args(self, partition_idx: int) -> list[str]:
        return [
            "--data.parser.train={}".format([os.path.join(
                self.path,
                "partition-{}.json".format(self.get_partition_id_str(partition_idx)),
            )]),
            "--data.parser.test={}".format([os.path.join(
                self.path,
                "partition-{}-test.json".format(self.get_partition_id_str(partition_idx)),
            )]),
        ]


def main():
    parser = argparse.ArgumentParser()
    PartitionTrainingConfig.configure_argparser(parser)
    MatrixCityPartitionTraining.start_with_configured_argparser(parser)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/viewer.py:
--------------------------------------------------------------------------------
from internal.entrypoints.viewer import cli

if __name__ == "__main__":
    cli()

--------------------------------------------------------------------------------
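A hypothetical invocation of the viewer (the output directory is an assumption; the accepted arguments are defined in internal.entrypoints.viewer, which is not part of this section):

    python viewer.py outputs/scene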