├── README.md ├── assets ├── Inter-Regular.otf ├── evaluation_index_acid.json ├── evaluation_index_dtu_nctx2.json ├── evaluation_index_dtu_nctx3.json ├── evaluation_index_re10k.json ├── evaluation_index_replica_nctx2.json ├── evaluation_index_replica_nctx3.json └── readme_fig │ └── framework.jpg ├── config ├── dataset │ ├── dtu.yaml │ ├── re10k.yaml │ ├── replica.yaml │ ├── view_sampler │ │ ├── all.yaml │ │ ├── arbitrary.yaml │ │ ├── bounded.yaml │ │ └── evaluation.yaml │ └── view_sampler_dataset_specific_config │ │ ├── bounded_re10k.yaml │ │ └── evaluation_re10k.yaml ├── experiment │ ├── acid.yaml │ ├── dtu.yaml │ ├── re10k.yaml │ └── replica.yaml ├── loss │ ├── depth.yaml │ ├── lpips.yaml │ └── mse.yaml ├── main.yaml └── model │ ├── decoder │ └── splatting_cuda.yaml │ └── encoder │ ├── backbone │ ├── dino.yaml │ └── resnet.yaml │ ├── costvolume.yaml │ └── costvolume_pyramid.yaml ├── demo.py ├── demo └── demo_example.tar ├── requirements.txt └── src ├── __init__.py ├── config.py ├── dataset ├── __init__.py ├── data_module.py ├── dataset.py ├── dataset_re10k.py ├── shims │ ├── augmentation_shim.py │ ├── bounds_shim.py │ ├── crop_shim.py │ └── patch_shim.py ├── types.py ├── validation_wrapper.py └── view_sampler │ ├── __init__.py │ ├── view_sampler.py │ ├── view_sampler_all.py │ ├── view_sampler_arbitrary.py │ ├── view_sampler_bounded.py │ └── view_sampler_evaluation.py ├── evaluation ├── evaluation_cfg.py ├── evaluation_index_generator.py ├── metric_computer.py └── metrics.py ├── geometry ├── epipolar_lines.py └── projection.py ├── global_cfg.py ├── loss ├── __init__.py ├── loss.py ├── loss_depth.py ├── loss_lpips.py └── loss_mse.py ├── main.py ├── misc ├── LocalLogger.py ├── benchmarker.py ├── collation.py ├── discrete_probability_distribution.py ├── heterogeneous_pairings.py ├── image_io.py ├── nn_module_tools.py ├── sh_rotation.py ├── step_tracker.py └── wandb_tools.py ├── model ├── decoder │ ├── __init__.py │ ├── cuda_splatting.py │ ├── decoder.py │ └── decoder_splatting_cuda.py ├── encoder │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── backbone.py │ │ ├── backbone_dino.py │ │ ├── backbone_multiview.py │ │ ├── backbone_pyramid.py │ │ ├── backbone_resnet.py │ │ ├── multiview_transformer.py │ │ ├── mvsformer_module │ │ │ ├── FMT.py │ │ │ ├── __init__.py │ │ │ ├── cost_volume.py │ │ │ ├── dino │ │ │ │ ├── __init__.py │ │ │ │ ├── dinov2.py │ │ │ │ └── layers │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── block.py │ │ │ │ │ ├── dino_head.py │ │ │ │ │ ├── drop_path.py │ │ │ │ │ ├── layer_scale.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── patch_embed.py │ │ │ │ │ └── swiglu_ffn.py │ │ │ ├── losses.py │ │ │ ├── lr_decay.py │ │ │ ├── module.py │ │ │ ├── networks │ │ │ │ ├── DINOv2_mvsformer_model.py │ │ │ │ └── casmvs_model.py │ │ │ ├── position_encoding.py │ │ │ ├── utils.py │ │ │ └── warping.py │ │ └── unimatch │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── backbone.py │ │ │ ├── geometry.py │ │ │ ├── matching.py │ │ │ ├── position.py │ │ │ ├── reg_refine.py │ │ │ ├── transformer.py │ │ │ ├── trident_conv.py │ │ │ ├── unimatch.py │ │ │ └── utils.py │ ├── common │ │ ├── gaussian_adapter.py │ │ ├── gaussians.py │ │ └── sampler.py │ ├── costvolume │ │ ├── conversions.py │ │ ├── depth_predictor_multiview.py │ │ └── ldm_unet │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── unet.py │ │ │ └── util.py │ ├── encoder.py │ ├── encoder_costvolume_pyramid.py │ └── visualization │ │ ├── encoder_visualizer.py │ │ ├── encoder_visualizer_costvolume.py │ │ └── 
encoder_visualizer_costvolume_cfg.py ├── encodings │ └── positional_encoding.py ├── model_wrapper.py ├── ply_export.py ├── transformer │ ├── attention.py │ ├── feed_forward.py │ ├── pre_norm.py │ └── transformer.py └── types.py ├── scripts ├── convert_dtu.py └── convert_replica.py ├── utils ├── fusion.py └── my_utils.py └── visualization ├── annotation.py ├── camera_trajectory ├── interpolation.py ├── spin.py └── wobble.py ├── color_map.py ├── colors.py ├── drawing ├── cameras.py ├── coordinate_conversion.py ├── lines.py ├── points.py ├── rendering.py └── types.py ├── layout.py ├── validation_in_3d.py └── vis_depth.py /assets/Inter-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/assets/Inter-Regular.otf -------------------------------------------------------------------------------- /assets/evaluation_index_dtu_nctx2.json: -------------------------------------------------------------------------------- 1 | {"scan1_train_00": {"context": [33, 31], "target": [32]}, "scan1_train_01": {"context": [33, 31], "target": [24]}, "scan1_train_02": {"context": [33, 31], "target": [23]}, "scan1_train_03": {"context": [33, 31], "target": [44]}, "scan8_train_00": {"context": [33, 31], "target": [32]}, "scan8_train_01": {"context": [33, 31], "target": [24]}, "scan8_train_02": {"context": [33, 31], "target": [23]}, "scan8_train_03": {"context": [33, 31], "target": [44]}, "scan21_train_00": {"context": [33, 31], "target": [32]}, "scan21_train_01": {"context": [33, 31], "target": [24]}, "scan21_train_02": {"context": [33, 31], "target": [23]}, "scan21_train_03": {"context": [33, 31], "target": [44]}, "scan30_train_00": {"context": [33, 31], "target": [32]}, "scan30_train_01": {"context": [33, 31], "target": [24]}, "scan30_train_02": {"context": [33, 31], "target": [23]}, "scan30_train_03": {"context": [33, 31], "target": [44]}, "scan31_train_00": {"context": [33, 31], "target": [32]}, "scan31_train_01": {"context": [33, 31], "target": [24]}, "scan31_train_02": {"context": [33, 31], "target": [23]}, "scan31_train_03": {"context": [33, 31], "target": [44]}, "scan34_train_00": {"context": [33, 31], "target": [32]}, "scan34_train_01": {"context": [33, 31], "target": [24]}, "scan34_train_02": {"context": [33, 31], "target": [23]}, "scan34_train_03": {"context": [33, 31], "target": [44]}, "scan38_train_00": {"context": [33, 31], "target": [32]}, "scan38_train_01": {"context": [33, 31], "target": [24]}, "scan38_train_02": {"context": [33, 31], "target": [23]}, "scan38_train_03": {"context": [33, 31], "target": [44]}, "scan40_train_00": {"context": [33, 31], "target": [32]}, "scan40_train_01": {"context": [33, 31], "target": [24]}, "scan40_train_02": {"context": [33, 31], "target": [23]}, "scan40_train_03": {"context": [33, 31], "target": [44]}, "scan41_train_00": {"context": [33, 31], "target": [32]}, "scan41_train_01": {"context": [33, 31], "target": [24]}, "scan41_train_02": {"context": [33, 31], "target": [23]}, "scan41_train_03": {"context": [33, 31], "target": [44]}, "scan45_train_00": {"context": [33, 31], "target": [32]}, "scan45_train_01": {"context": [33, 31], "target": [24]}, "scan45_train_02": {"context": [33, 31], "target": [23]}, "scan45_train_03": {"context": [33, 31], "target": [44]}, "scan55_train_00": {"context": [33, 31], "target": [32]}, "scan55_train_01": {"context": [33, 31], "target": [24]}, "scan55_train_02": {"context": [33, 31], "target": [23]}, 
"scan55_train_03": {"context": [33, 31], "target": [44]}, "scan63_train_00": {"context": [33, 31], "target": [32]}, "scan63_train_01": {"context": [33, 31], "target": [24]}, "scan63_train_02": {"context": [33, 31], "target": [23]}, "scan63_train_03": {"context": [33, 31], "target": [44]}, "scan82_train_00": {"context": [33, 31], "target": [32]}, "scan82_train_01": {"context": [33, 31], "target": [24]}, "scan82_train_02": {"context": [33, 31], "target": [23]}, "scan82_train_03": {"context": [33, 31], "target": [44]}, "scan103_train_00": {"context": [33, 31], "target": [32]}, "scan103_train_01": {"context": [33, 31], "target": [24]}, "scan103_train_02": {"context": [33, 31], "target": [23]}, "scan103_train_03": {"context": [33, 31], "target": [44]}, "scan110_train_00": {"context": [33, 31], "target": [32]}, "scan110_train_01": {"context": [33, 31], "target": [24]}, "scan110_train_02": {"context": [33, 31], "target": [23]}, "scan110_train_03": {"context": [33, 31], "target": [44]}, "scan114_train_00": {"context": [33, 31], "target": [32]}, "scan114_train_01": {"context": [33, 31], "target": [24]}, "scan114_train_02": {"context": [33, 31], "target": [23]}, "scan114_train_03": {"context": [33, 31], "target": [44]}} -------------------------------------------------------------------------------- /assets/evaluation_index_dtu_nctx3.json: -------------------------------------------------------------------------------- 1 | {"scan1_train_00": {"context": [33, 31, 43], "target": [32]}, "scan1_train_01": {"context": [33, 31, 43], "target": [24]}, "scan1_train_02": {"context": [33, 31, 43], "target": [23]}, "scan1_train_03": {"context": [33, 31, 43], "target": [44]}, "scan8_train_00": {"context": [33, 31, 43], "target": [32]}, "scan8_train_01": {"context": [33, 31, 43], "target": [24]}, "scan8_train_02": {"context": [33, 31, 43], "target": [23]}, "scan8_train_03": {"context": [33, 31, 43], "target": [44]}, "scan21_train_00": {"context": [33, 31, 43], "target": [32]}, "scan21_train_01": {"context": [33, 31, 43], "target": [24]}, "scan21_train_02": {"context": [33, 31, 43], "target": [23]}, "scan21_train_03": {"context": [33, 31, 43], "target": [44]}, "scan30_train_00": {"context": [33, 31, 43], "target": [32]}, "scan30_train_01": {"context": [33, 31, 43], "target": [24]}, "scan30_train_02": {"context": [33, 31, 43], "target": [23]}, "scan30_train_03": {"context": [33, 31, 43], "target": [44]}, "scan31_train_00": {"context": [33, 31, 43], "target": [32]}, "scan31_train_01": {"context": [33, 31, 43], "target": [24]}, "scan31_train_02": {"context": [33, 31, 43], "target": [23]}, "scan31_train_03": {"context": [33, 31, 43], "target": [44]}, "scan34_train_00": {"context": [33, 31, 43], "target": [32]}, "scan34_train_01": {"context": [33, 31, 43], "target": [24]}, "scan34_train_02": {"context": [33, 31, 43], "target": [23]}, "scan34_train_03": {"context": [33, 31, 43], "target": [44]}, "scan38_train_00": {"context": [33, 31, 43], "target": [32]}, "scan38_train_01": {"context": [33, 31, 43], "target": [24]}, "scan38_train_02": {"context": [33, 31, 43], "target": [23]}, "scan38_train_03": {"context": [33, 31, 43], "target": [44]}, "scan40_train_00": {"context": [33, 31, 43], "target": [32]}, "scan40_train_01": {"context": [33, 31, 43], "target": [24]}, "scan40_train_02": {"context": [33, 31, 43], "target": [23]}, "scan40_train_03": {"context": [33, 31, 43], "target": [44]}, "scan41_train_00": {"context": [33, 31, 43], "target": [32]}, "scan41_train_01": {"context": [33, 31, 43], "target": [24]}, 
"scan41_train_02": {"context": [33, 31, 43], "target": [23]}, "scan41_train_03": {"context": [33, 31, 43], "target": [44]}, "scan45_train_00": {"context": [33, 31, 43], "target": [32]}, "scan45_train_01": {"context": [33, 31, 43], "target": [24]}, "scan45_train_02": {"context": [33, 31, 43], "target": [23]}, "scan45_train_03": {"context": [33, 31, 43], "target": [44]}, "scan55_train_00": {"context": [33, 31, 43], "target": [32]}, "scan55_train_01": {"context": [33, 31, 43], "target": [24]}, "scan55_train_02": {"context": [33, 31, 43], "target": [23]}, "scan55_train_03": {"context": [33, 31, 43], "target": [44]}, "scan63_train_00": {"context": [33, 31, 43], "target": [32]}, "scan63_train_01": {"context": [33, 31, 43], "target": [24]}, "scan63_train_02": {"context": [33, 31, 43], "target": [23]}, "scan63_train_03": {"context": [33, 31, 43], "target": [44]}, "scan82_train_00": {"context": [33, 31, 43], "target": [32]}, "scan82_train_01": {"context": [33, 31, 43], "target": [24]}, "scan82_train_02": {"context": [33, 31, 43], "target": [23]}, "scan82_train_03": {"context": [33, 31, 43], "target": [44]}, "scan103_train_00": {"context": [33, 31, 43], "target": [32]}, "scan103_train_01": {"context": [33, 31, 43], "target": [24]}, "scan103_train_02": {"context": [33, 31, 43], "target": [23]}, "scan103_train_03": {"context": [33, 31, 43], "target": [44]}, "scan110_train_00": {"context": [33, 31, 43], "target": [32]}, "scan110_train_01": {"context": [33, 31, 43], "target": [24]}, "scan110_train_02": {"context": [33, 31, 43], "target": [23]}, "scan110_train_03": {"context": [33, 31, 43], "target": [44]}, "scan114_train_00": {"context": [33, 31, 43], "target": [32]}, "scan114_train_01": {"context": [33, 31, 43], "target": [24]}, "scan114_train_02": {"context": [33, 31, 43], "target": [23]}, "scan114_train_03": {"context": [33, 31, 43], "target": [44]}} -------------------------------------------------------------------------------- /assets/readme_fig/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/assets/readme_fig/framework.jpg -------------------------------------------------------------------------------- /config/dataset/dtu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: dtu 5 | roots: [datasets/re10k] 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | test_times_per_scene: 1 24 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: re10k 5 | roots: [datasets/re10k] 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 
19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | test_times_per_scene: 1 24 | -------------------------------------------------------------------------------- /config/dataset/replica.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: replica 5 | roots: [datasets/replica] 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | test_times_per_scene: 1 24 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here. 7 | context_views: null 8 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/bounded.yaml: -------------------------------------------------------------------------------- 1 | name: bounded 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | min_distance_to_context_views: 0 9 | 10 | warm_up_steps: 0 11 | initial_min_distance_between_context_views: 2 12 | initial_max_distance_between_context_views: 6 13 | special_way: null -------------------------------------------------------------------------------- /config/dataset/view_sampler/evaluation.yaml: -------------------------------------------------------------------------------- 1 | name: evaluation 2 | 3 | index_path: assets/evaluation_index_re10k_video.json 4 | num_context_views: 2 5 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | min_distance_between_context_views: 45 6 | max_distance_between_context_views: 192 7 | min_distance_to_context_views: 0 8 | warm_up_steps: 150_000 9 | initial_min_distance_between_context_views: 25 10 | initial_max_distance_between_context_views: 45 11 | num_target_views: 4 12 | special_way: null 13 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | -------------------------------------------------------------------------------- /config/experiment/acid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: costvolume_pyramid 6 | - 
override /loss: [mse, lpips] 7 | precision: 32 8 | wandb: 9 | name: acid 10 | tags: [acid, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/acid] 40 | near: 1. 41 | far: 100. 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | 45 | test: 46 | eval_time_skip_steps: 5 47 | compute_scores: true 48 | -------------------------------------------------------------------------------- /config/experiment/dtu.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dtu 5 | - override /model/encoder: costvolume_pyramid 6 | - override /loss: [mse, lpips] 7 | precision: 32 8 | wandb: 9 | name: dtu/views2 10 | tags: [dtu, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/dtu] 40 | near: 2.125 41 | far: 4.525 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | test_times_per_scene: 4 45 | skip_bad_shape: false 46 | 47 | test: 48 | eval_time_skip_steps: 5 49 | compute_scores: true 50 | -------------------------------------------------------------------------------- /config/experiment/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: costvolume_pyramid 6 | - override /loss: [mse, lpips] 7 | precision: 32 8 | wandb: 9 | name: re10k 10 | tags: [re10k, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/re10k] 40 | near: 1. 41 | far: 100. 
42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | 45 | test: 46 | eval_time_skip_steps: 5 47 | compute_scores: true 48 | -------------------------------------------------------------------------------- /config/experiment/replica.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: replica 5 | - override /model/encoder: costvolume_pyramid 6 | - override /loss: [mse, lpips] 7 | precision: 32 8 | wandb: 9 | name: replica 10 | tags: [dtu, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/replica] 40 | near: 0.5 41 | far: 15.0 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | test_times_per_scene: 36 45 | skip_bad_shape: false 46 | 47 | test: 48 | eval_time_skip_steps: 5 49 | compute_scores: true 50 | -------------------------------------------------------------------------------- /config/loss/depth.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | weight: 0.25 3 | sigma_image: null 4 | use_second_derivative: false 5 | -------------------------------------------------------------------------------- /config/loss/lpips.yaml: -------------------------------------------------------------------------------- 1 | lpips: 2 | weight: 0.05 3 | apply_after_step: 150_000 4 | -------------------------------------------------------------------------------- /config/loss/mse.yaml: -------------------------------------------------------------------------------- 1 | mse: 2 | weight: 1.0 3 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - model/encoder: costvolume 5 | - model/decoder: splatting_cuda 6 | - loss: [mse] 7 | 8 | wandb: 9 | project: hisplat 10 | entity: placeholder 11 | name: placeholder 12 | mode: offline #disabled 13 | id: null 14 | 15 | use_tensorboard: null 16 | use_xy_sin: true 17 | mode: train 18 | device: auto 19 | method: hisplat 20 | 21 | output_dir: null 22 | 23 | dataset: 24 | overfit_to_scene: null 25 | 26 | data_loader: 27 | # Avoid having to spin up new processes to print out visualizations. 
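# The loader settings below are repo-wide defaults; experiment configs override
# train.batch_size (e.g. 14 in config/experiment/re10k.yaml).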
28 | train: 29 | num_workers: 10 30 | persistent_workers: true 31 | batch_size: 4 32 | seed: 1234 33 | test: 34 | num_workers: 4 35 | persistent_workers: false 36 | batch_size: 1 37 | seed: 2345 38 | val: 39 | num_workers: 1 40 | persistent_workers: true 41 | batch_size: 1 42 | seed: 3456 43 | 44 | optimizer: 45 | lr: 2.e-4 46 | warm_up_steps: 2000 47 | cosine_lr: true 48 | 49 | checkpointing: 50 | load: null 51 | # 15 checkpoints 52 | every_n_train_steps: 10000 # 5000 53 | save_top_k: -1 54 | pretrained_model: null 55 | 56 | train: 57 | depth_mode: null 58 | extended_visualization: false 59 | print_log_every_n_steps: 1 60 | # whether to use alignment, false for disable, float for loss coefficient 61 | align_2d: false 62 | align_3d: false 63 | align_depth: false 64 | normal_norm: true 65 | 66 | test: 67 | output_path: outputs/test 68 | compute_scores: true 69 | eval_time_skip_steps: 0 70 | save_image: false 71 | save_video: false 72 | test_all_ckpt: false 73 | 74 | seed: 111123 75 | 76 | trainer: 77 | max_steps: -1 78 | # val 100 times 79 | val_check_interval: 3000 80 | gradient_clip_val: 0.5 81 | num_sanity_val_steps: 0 82 | -------------------------------------------------------------------------------- /config/model/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | -------------------------------------------------------------------------------- /config/model/encoder/backbone/dino.yaml: -------------------------------------------------------------------------------- 1 | name: dino 2 | 3 | model: dino_vitb8 4 | d_out: 512 5 | -------------------------------------------------------------------------------- /config/model/encoder/backbone/resnet.yaml: -------------------------------------------------------------------------------- 1 | name: resnet 2 | 3 | model: resnet50 4 | num_layers: 5 5 | use_first_pool: false 6 | d_out: 512 7 | -------------------------------------------------------------------------------- /config/model/encoder/costvolume.yaml: -------------------------------------------------------------------------------- 1 | name: costvolume 2 | 3 | opacity_mapping: 4 | initial: 0.0 5 | final: 0.0 6 | warm_up: 1 7 | 8 | num_depth_candidates: 32 9 | num_surfaces: 1 10 | 11 | gaussians_per_pixel: 1 12 | 13 | gaussian_adapter: 14 | gaussian_scale_min: 0.5 15 | gaussian_scale_max: 15.0 16 | sh_degree: 4 17 | 18 | d_feature: 128 19 | 20 | visualizer: 21 | num_samples: 8 22 | min_resolution: 256 23 | export_ply: false 24 | 25 | # params for multi-view depth predictor 26 | unimatch_weights_path: "checkpoints/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" 27 | multiview_trans_attn_split: 2 28 | costvolume_unet_feat_dim: 128 29 | costvolume_unet_channel_mult: [1,1,1] 30 | costvolume_unet_attn_res: [] 31 | depth_unet_feat_dim: 64 32 | depth_unet_attn_res: [] 33 | depth_unet_channel_mult: [1, 1, 1] 34 | downscale_factor: 4 35 | shim_patch_size: 4 36 | -------------------------------------------------------------------------------- /config/model/encoder/costvolume_pyramid.yaml: -------------------------------------------------------------------------------- 1 | name: costvolume_pyramid 2 | 3 | opacity_mapping: 4 | initial: 0.0 5 | final: 0.0 6 | warm_up: 1 7 | 8 | num_depth_candidates: 32 9 | num_surfaces: 1 10 | 11 | gaussians_per_pixel: 1 12 | 13 | gaussian_adapter: 14 | gaussian_scale_min: 0.5 15 | gaussian_scale_max: 15.0 16 | sh_degree: 4 17 | 18 | d_feature: 128 19 | 20 | visualizer: 21 | 
num_samples: 8 22 | min_resolution: 256 23 | export_ply: false 24 | 25 | # params for multi-view depth predictor 26 | unimatch_weights_path: "checkpoints/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" 27 | multiview_trans_attn_split: 2 28 | costvolume_unet_feat_dim: 128 29 | costvolume_unet_channel_mult: [1,1,1] 30 | costvolume_unet_attn_res: [] 31 | depth_unet_feat_dim: 64 32 | depth_unet_attn_res: [] 33 | depth_unet_channel_mult: [1, 1, 1] 34 | downscale_factor: 4 35 | shim_patch_size: 4 36 | -------------------------------------------------------------------------------- /demo/demo_example.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/demo/demo_example.tar -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel==0.43.0 2 | tqdm==4.66.4 3 | pytorch_lightning==2.3.0 4 | black==24.4.2 5 | ruff==0.4.8 6 | hydra-core==1.3.2 7 | jaxtyping==0.2.30 8 | beartype==0.18.5 9 | wandb==0.17.1 10 | einops==0.8.0 11 | colorama==0.4.6 12 | scikit-image==0.23.2 13 | colorspacious==1.1.2 14 | matplotlib==3.5.2 15 | moviepy==1.0.3 16 | imageio==2.34.1 17 | git+https://github.com/dcharatan/diff-gaussian-rasterization-modified 18 | timm==1.0.3 19 | dacite==1.8.1 20 | lpips==0.1.4 21 | e3nn==0.5.1 22 | plyfile==1.0.3 23 | tabulate==0.9.0 24 | svg.py==1.4.3 25 | opencv-python==4.6.0.66 26 | sk-video==1.1.10 27 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Literal, Optional, Type, TypeVar 4 | 5 | from dacite import Config, from_dict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .dataset.data_module import DataLoaderCfg, DatasetCfg 9 | from .loss import LossCfgWrapper 10 | from .model.decoder import DecoderCfg 11 | from .model.encoder import EncoderCfg 12 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg 13 | 14 | 15 | @dataclass 16 | class CheckpointingCfg: 17 | load: Optional[str] # Not a path, since it could be something like wandb://... 
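    # Checkpoint frequency/retention; presumably forwarded to PyTorch Lightning's
    # ModelCheckpoint (every_n_train_steps, save_top_k) by the trainer setup.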
18 | every_n_train_steps: int 19 | save_top_k: int 20 | pretrained_model: Optional[str] 21 | 22 | 23 | @dataclass 24 | class ModelCfg: 25 | decoder: DecoderCfg 26 | encoder: EncoderCfg 27 | 28 | 29 | @dataclass 30 | class TrainerCfg: 31 | max_steps: int 32 | val_check_interval: int | float | None 33 | gradient_clip_val: int | float | None 34 | num_sanity_val_steps: int 35 | 36 | 37 | @dataclass 38 | class RootCfg: 39 | use_tensorboard: bool | None 40 | wandb: dict 41 | mode: Literal["train", "test"] 42 | dataset: DatasetCfg 43 | data_loader: DataLoaderCfg 44 | model: ModelCfg 45 | optimizer: OptimizerCfg 46 | checkpointing: CheckpointingCfg 47 | trainer: TrainerCfg 48 | loss: list[LossCfgWrapper] 49 | test: TestCfg 50 | train: TrainCfg 51 | seed: int 52 | device: int | list[int] | str 53 | output_dir: str | None 54 | method: str | None 55 | precision: int 56 | 57 | 58 | TYPE_HOOKS = { 59 | Path: Path, 60 | } 61 | 62 | 63 | T = TypeVar("T") 64 | 65 | 66 | def load_typed_config( 67 | cfg: DictConfig, 68 | data_class: Type[T], 69 | extra_type_hooks: dict = {}, 70 | ) -> T: 71 | return from_dict( 72 | data_class, 73 | OmegaConf.to_container(cfg), 74 | config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}), 75 | ) 76 | 77 | 78 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]: 79 | # The dummy allows the union to be converted. 80 | @dataclass 81 | class Dummy: 82 | dummy: LossCfgWrapper 83 | 84 | return [load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy for k, v in joined.items()] 85 | 86 | 87 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 88 | return load_typed_config( 89 | cfg, 90 | RootCfg, 91 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers}, 92 | ) 93 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from ..misc.step_tracker import StepTracker 4 | from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg 5 | from .types import Stage 6 | from .view_sampler import get_view_sampler 7 | 8 | DATASETS: dict[str, Dataset] = { 9 | "re10k": DatasetRE10k, 10 | } 11 | 12 | 13 | DatasetCfg = DatasetRE10kCfg 14 | 15 | 16 | def get_dataset( 17 | cfg: DatasetCfg, 18 | stage: Stage, 19 | step_tracker: StepTracker | None, 20 | ) -> Dataset: 21 | view_sampler = get_view_sampler( 22 | cfg.view_sampler, 23 | stage, 24 | cfg.overfit_to_scene is not None, 25 | cfg.cameras_are_circular, 26 | step_tracker, 27 | ) 28 | try: 29 | return DATASETS[cfg.name](cfg, stage, view_sampler) 30 | except: 31 | return DATASETS["re10k"](cfg, stage, view_sampler) 32 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from pytorch_lightning import LightningDataModule 9 | from torch import Generator, nn 10 | from torch.utils.data import DataLoader, Dataset, IterableDataset 11 | 12 | from ..misc.step_tracker import StepTracker 13 | from . import DatasetCfg, get_dataset 14 | from .types import DataShim, Stage 15 | from .validation_wrapper import ValidationWrapper 16 | 17 | 18 | def get_data_shim(encoder: nn.Module) -> DataShim: 19 | """Get functions that modify the batch. 
It's sometimes necessary to modify batches 20 | outside the data loader because GPU computations are required to modify the batch or 21 | because the modification depends on something outside the data loader. 22 | """ 23 | 24 | shims: list[DataShim] = [] 25 | if hasattr(encoder, "get_data_shim"): 26 | shims.append(encoder.get_data_shim()) 27 | 28 | def combined_shim(batch): 29 | for shim in shims: 30 | batch = shim(batch) 31 | return batch 32 | 33 | return combined_shim 34 | 35 | 36 | @dataclass 37 | class DataLoaderStageCfg: 38 | batch_size: int 39 | num_workers: int 40 | persistent_workers: bool 41 | seed: int | None 42 | 43 | 44 | @dataclass 45 | class DataLoaderCfg: 46 | train: DataLoaderStageCfg 47 | test: DataLoaderStageCfg 48 | val: DataLoaderStageCfg 49 | 50 | 51 | DatasetShim = Callable[[Dataset, Stage], Dataset] 52 | 53 | 54 | def worker_init_fn(worker_id: int) -> None: 55 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 56 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 57 | 58 | 59 | class DataModule(LightningDataModule): 60 | dataset_cfg: DatasetCfg 61 | data_loader_cfg: DataLoaderCfg 62 | step_tracker: StepTracker | None 63 | dataset_shim: DatasetShim 64 | global_rank: int 65 | 66 | def __init__( 67 | self, 68 | dataset_cfg: DatasetCfg, 69 | data_loader_cfg: DataLoaderCfg, 70 | step_tracker: StepTracker | None = None, 71 | dataset_shim: DatasetShim = lambda dataset, _: dataset, 72 | global_rank: int = 0, 73 | ) -> None: 74 | super().__init__() 75 | self.dataset_cfg = dataset_cfg 76 | self.data_loader_cfg = data_loader_cfg 77 | self.step_tracker = step_tracker 78 | self.dataset_shim = dataset_shim 79 | self.global_rank = global_rank 80 | 81 | def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: 82 | return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers 83 | 84 | def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: 85 | if loader_cfg.seed is None: 86 | return None 87 | generator = Generator() 88 | generator.manual_seed(loader_cfg.seed + self.global_rank) 89 | return generator 90 | 91 | def train_dataloader(self): 92 | dataset = get_dataset(self.dataset_cfg, "train", self.step_tracker) 93 | dataset = self.dataset_shim(dataset, "train") 94 | return DataLoader( 95 | dataset, 96 | self.data_loader_cfg.train.batch_size, 97 | shuffle=not isinstance(dataset, IterableDataset), 98 | num_workers=self.data_loader_cfg.train.num_workers, 99 | generator=self.get_generator(self.data_loader_cfg.train), 100 | worker_init_fn=worker_init_fn, 101 | persistent_workers=self.get_persistent(self.data_loader_cfg.train), 102 | ) 103 | 104 | def val_dataloader(self): 105 | # dataset = get_dataset(self.dataset_cfg, "val", self.step_tracker) 106 | dataset = get_dataset(self.dataset_cfg, "test", self.step_tracker) 107 | # dataset = self.dataset_shim(dataset, "val") 108 | dataset = self.dataset_shim(dataset, "test") 109 | try: 110 | world_size = dist.get_world_size() 111 | except: 112 | world_size = 1 113 | return DataLoader( 114 | ValidationWrapper(dataset, 100 * world_size), 115 | self.data_loader_cfg.val.batch_size, 116 | num_workers=self.data_loader_cfg.val.num_workers, 117 | generator=self.get_generator(self.data_loader_cfg.val), 118 | worker_init_fn=worker_init_fn, 119 | persistent_workers=self.get_persistent(self.data_loader_cfg.val), 120 | shuffle=False, 121 | ) 122 | 123 | def test_dataloader(self, dataset_cfg=None): 124 | dataset = get_dataset( 125 | self.dataset_cfg if 
dataset_cfg is None else dataset_cfg, 126 | "test", 127 | self.step_tracker, 128 | ) 129 | dataset = self.dataset_shim(dataset, "test") 130 | return DataLoader( 131 | dataset, 132 | self.data_loader_cfg.test.batch_size, 133 | num_workers=self.data_loader_cfg.test.num_workers, 134 | generator=self.get_generator(self.data_loader_cfg.test), 135 | worker_init_fn=worker_init_fn, 136 | persistent_workers=self.get_persistent(self.data_loader_cfg.test), 137 | shuffle=False, 138 | ) 139 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 
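    # With intrinsics normalized by image size, K^-1 @ (1/w, 1/h) gives the extent of one
    # pixel at unit depth, so the value returned below is baseline / (disparity * pixel size):
    # the depth at which the widest camera baseline spans `disparity` pixels of parallax.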
27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum(intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i") 30 | 31 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 32 | # don't make much sense anyway. 33 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 34 | 35 | return baselines / (disparity * mean_pixel_size) 36 | 37 | 38 | def apply_bounds_shim( 39 | batch: BatchedExample, 40 | near_disparity: float, 41 | far_disparity: float, 42 | ) -> BatchedExample: 43 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 44 | assumes that all of an example's views are of roughly the same thing. 45 | """ 46 | 47 | context = batch["context"] 48 | _, cv, _, h, w = context["image"].shape 49 | 50 | # Compute near and far planes using the context views. 51 | near = compute_depth_for_disparity( 52 | context["extrinsics"], 53 | context["intrinsics"], 54 | (h, w), 55 | near_disparity, 56 | ) 57 | far = compute_depth_for_disparity( 58 | context["extrinsics"], 59 | context["intrinsics"], 60 | (h, w), 61 | far_disparity, 62 | ) 63 | 64 | target = batch["target"] 65 | _, tv, _, _, _ = target["image"].shape 66 | return { 67 | **batch, 68 | "context": { 69 | **context, 70 | "near": repeat(near, "b -> b v", v=cv), 71 | "far": repeat(far, "b -> b v", v=cv), 72 | }, 73 | "target": { 74 | **target, 75 | "near": repeat(near, "b -> b v", v=tv), 76 | "far": repeat(far, "b -> b v", v=tv), 77 | }, 78 | } 79 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 
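    # The intrinsics are normalized by image size, so cropping from (h_in, w_in) to
    # (h_out, w_out) scales the normalized focal lengths by w_in/w_out and h_in/h_out;
    # the principal point is assumed to remain at the image center under a center crop.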
44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded. 70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 
33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch.utils.data import Dataset, IterableDataset 6 | 7 | 8 | class ValidationWrapper(Dataset): 9 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 10 | visualization step. 
11 | """ 12 | 13 | dataset: Dataset 14 | dataset_iterator: Optional[Iterator] 15 | length: int 16 | 17 | def __init__(self, dataset: Dataset, length: int) -> None: 18 | super().__init__() 19 | self.dataset = dataset 20 | self.length = length 21 | self.dataset_iterator = None 22 | self.iter_num = 0 23 | 24 | def __len__(self): 25 | return self.length 26 | 27 | def __getitem__(self, index: int): 28 | if isinstance(self.dataset, IterableDataset): 29 | # TODO: Very dangerous, may cause leaking 30 | try: 31 | world_size = dist.get_world_size() 32 | except: 33 | world_size = 1 34 | if self.dataset_iterator is None or self.iter_num >= self.length / world_size: 35 | self.iter_num = 0 36 | self.dataset_iterator = iter(self.dataset) 37 | self.iter_num += 1 38 | return next(self.dataset_iterator) 39 | 40 | random_index = torch.randint(0, len(self.dataset), tuple()) 41 | return self.dataset[random_index.item()] 42 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from ...global_cfg import get_cfg 5 | from ...misc.step_tracker import StepTracker 6 | from ..types import Stage 7 | from .view_sampler import ViewSampler 8 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 9 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 10 | from .view_sampler_bounded import ( 11 | ViewSamplerBounded, 12 | ViewSamplerBoundedCfg, 13 | ViewSamplerBoundedDTU, 14 | ) 15 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 16 | 17 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 18 | "all": ViewSamplerAll, 19 | "arbitrary": ViewSamplerArbitrary, 20 | "bounded": ViewSamplerBounded, 21 | "evaluation": ViewSamplerEvaluation, 22 | "bounded_dtu": ViewSamplerBoundedDTU, 23 | } 24 | 25 | ViewSamplerCfg = ViewSamplerArbitraryCfg | ViewSamplerBoundedCfg | ViewSamplerEvaluationCfg | ViewSamplerAllCfg 26 | 27 | 28 | def get_view_sampler( 29 | cfg: ViewSamplerCfg, 30 | stage: Stage, 31 | overfit: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> ViewSampler[Any]: 35 | # TODO: only a temporary fix, need to support cfg input 36 | if not stage == "train": 37 | dataset_name = get_cfg().dataset.roots[0].split("/")[-1] 38 | if dataset_name == "dtu": 39 | index_path = f'assets/evaluation_index_dtu_nctx{get_cfg().dataset.view_sampler.num_context_views}.json' 40 | elif dataset_name == "re10k": 41 | index_path = "assets/evaluation_index_re10k.json" 42 | elif dataset_name == "acid": 43 | index_path = "assets/evaluation_index_acid.json" 44 | elif dataset_name == "replica": 45 | index_path = f"assets/evaluation_index_replica_nctx{get_cfg().dataset.view_sampler.num_context_views}.json" 46 | else: 47 | index_path = None 48 | cfg = ViewSamplerEvaluationCfg(name="evaluation", index_path=Path(index_path), num_context_views=2) 49 | 50 | return VIEW_SAMPLERS[cfg.name]( 51 | cfg, 52 | stage, 53 | overfit, 54 | cameras_are_circular, 55 | step_tracker, 56 | ) 57 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import 
Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | **kwargs, 43 | ) -> tuple[ 44 | Int64[Tensor, " context_view"], # indices for context views 45 | Int64[Tensor, " target_view"], # indices for target views 46 | ]: 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def num_target_views(self) -> int: 52 | pass 53 | 54 | @property 55 | @abstractmethod 56 | def num_context_views(self) -> int: 57 | pass 58 | 59 | @property 60 | def global_step(self) -> int: 61 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 62 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerArbitraryCfg: 13 | name: Literal["arbitrary"] 14 | num_context_views: int 15 | num_target_views: int 16 | context_views: list[int] | None 17 | target_views: list[int] | None 18 | 19 | 20 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 21 | def sample( 22 | self, 23 | scene: str, 24 | extrinsics: Float[Tensor, "view 4 4"], 25 | intrinsics: Float[Tensor, "view 3 3"], 26 | device: torch.device = torch.device("cpu"), 27 | ) -> tuple[ 28 | Int64[Tensor, " context_view"], # indices for context views 29 | Int64[Tensor, " target_view"], # 
indices for target views 30 | ]: 31 | """Arbitrarily sample context and target views.""" 32 | num_views, _, _ = extrinsics.shape 33 | 34 | index_context = torch.randint( 35 | 0, 36 | num_views, 37 | size=(self.cfg.num_context_views,), 38 | device=device, 39 | ) 40 | 41 | # Allow the context views to be fixed. 42 | if self.cfg.context_views is not None: 43 | assert len(self.cfg.context_views) == self.cfg.num_context_views 44 | index_context = torch.tensor(self.cfg.context_views, dtype=torch.int64, device=device) 45 | 46 | index_target = torch.randint( 47 | 0, 48 | num_views, 49 | size=(self.cfg.num_target_views,), 50 | device=device, 51 | ) 52 | 53 | # Allow the target views to be fixed. 54 | if self.cfg.target_views is not None: 55 | assert len(self.cfg.target_views) == self.cfg.num_target_views 56 | index_target = torch.tensor(self.cfg.target_views, dtype=torch.int64, device=device) 57 | 58 | return index_context, index_target 59 | 60 | @property 61 | def num_context_views(self) -> int: 62 | return self.cfg.num_context_views 63 | 64 | @property 65 | def num_target_views(self) -> int: 66 | return self.cfg.num_target_views 67 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...misc.step_tracker import StepTracker 13 | from ..types import Stage 14 | from .view_sampler import ViewSampler 15 | 16 | 17 | @dataclass 18 | class ViewSamplerEvaluationCfg: 19 | name: Literal["evaluation"] 20 | index_path: Path 21 | num_context_views: int 22 | 23 | 24 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 25 | index: dict[str, IndexEntry | None] 26 | 27 | def __init__( 28 | self, 29 | cfg: ViewSamplerEvaluationCfg, 30 | stage: Stage, 31 | is_overfitting: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> None: 35 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 36 | 37 | dacite_config = Config(cast=[tuple]) 38 | with cfg.index_path.open("r") as f: 39 | self.index = { 40 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) for k, v in json.load(f).items() 41 | } 42 | 43 | def sample( 44 | self, 45 | scene: str, 46 | extrinsics: Float[Tensor, "view 4 4"], 47 | intrinsics: Float[Tensor, "view 3 3"], 48 | device: torch.device = torch.device("cpu"), 49 | ) -> tuple[ 50 | Int64[Tensor, " context_view"], # indices for context views 51 | Int64[Tensor, " target_view"], # indices for target views 52 | ]: 53 | entry = self.index.get(scene) 54 | if entry is None: 55 | raise ValueError(f"No indices available for scene {scene}.") 56 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 57 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 58 | return context_indices, target_indices 59 | 60 | @property 61 | def num_context_views(self) -> int: 62 | return 0 63 | 64 | @property 65 | def num_target_views(self) -> int: 66 | return 0 67 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | animate_side_by_side: bool 23 | highlighted: list[SceneCfg] 24 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_index_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | 5 | import torch 6 | from einops import rearrange 7 | from pytorch_lightning import LightningModule 8 | from tqdm import tqdm 9 | 10 | from ..geometry.epipolar_lines import project_rays 11 | from ..geometry.projection import get_world_rays, sample_image_grid 12 | from ..misc.image_io import save_image 13 | from ..visualization.annotation import add_label 14 | from ..visualization.layout import add_border, hcat 15 | 16 | 17 | @dataclass 18 | class EvaluationIndexGeneratorCfg: 19 | num_target_views: int 20 | min_distance: int 21 | max_distance: int 22 | min_overlap: float 23 | max_overlap: float 24 | output_path: Path 25 | save_previews: bool 26 | seed: int 27 | 28 | 29 | @dataclass 30 | class IndexEntry: 31 | context: tuple[int, ...] 32 | target: tuple[int, ...] 33 | 34 | 35 | class EvaluationIndexGenerator(LightningModule): 36 | generator: torch.Generator 37 | cfg: EvaluationIndexGeneratorCfg 38 | index: dict[str, IndexEntry | None] 39 | 40 | def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None: 41 | super().__init__() 42 | self.cfg = cfg 43 | self.generator = torch.Generator() 44 | self.generator.manual_seed(cfg.seed) 45 | self.index = {} 46 | 47 | def test_step(self, batch, batch_idx): 48 | b, v, _, h, w = batch["target"]["image"].shape 49 | assert b == 1 50 | extrinsics = batch["target"]["extrinsics"][0] 51 | intrinsics = batch["target"]["intrinsics"][0] 52 | scene = batch["scene"][0] 53 | 54 | context_indices = torch.randperm(v, generator=self.generator) 55 | for context_index in tqdm(context_indices, "Finding context pair"): 56 | xy, _ = sample_image_grid((h, w), self.device) 57 | context_origins, context_directions = get_world_rays( 58 | rearrange(xy, "h w xy -> (h w) xy"), 59 | extrinsics[context_index], 60 | intrinsics[context_index], 61 | ) 62 | 63 | # Step away from context view until the minimum overlap threshold is met. 64 | valid_indices = [] 65 | for step in (1, -1): 66 | min_distance = self.cfg.min_distance 67 | max_distance = self.cfg.max_distance 68 | current_index = context_index + step * min_distance 69 | 70 | while 0 <= current_index.item() < v: 71 | # Compute overlap. 
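# Overlap is measured symmetrically: the candidate view's pixel rays are projected
# into the context camera and the context view's rays into the candidate camera,
# and the fraction of rays that land inside the other image is computed for both
# directions. The smaller of the two fractions is then compared against the
# configured overlap range below.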
72 | current_origins, current_directions = get_world_rays( 73 | rearrange(xy, "h w xy -> (h w) xy"), 74 | extrinsics[current_index], 75 | intrinsics[current_index], 76 | ) 77 | projection_onto_current = project_rays( 78 | context_origins, 79 | context_directions, 80 | extrinsics[current_index], 81 | intrinsics[current_index], 82 | ) 83 | projection_onto_context = project_rays( 84 | current_origins, 85 | current_directions, 86 | extrinsics[context_index], 87 | intrinsics[context_index], 88 | ) 89 | overlap_a = projection_onto_context["overlaps_image"].float().mean() 90 | overlap_b = projection_onto_current["overlaps_image"].float().mean() 91 | 92 | overlap = min(overlap_a, overlap_b) 93 | delta = (current_index - context_index).abs() 94 | 95 | min_overlap = self.cfg.min_overlap 96 | max_overlap = self.cfg.max_overlap 97 | if min_overlap <= overlap <= max_overlap: 98 | valid_indices.append((current_index.item(), overlap_a, overlap_b)) 99 | 100 | # Stop once the camera has panned away too much. 101 | if overlap < min_overlap or delta > max_distance: 102 | break 103 | 104 | current_index += step 105 | 106 | if valid_indices: 107 | # Pick a random valid view. Index the resulting views. 108 | num_options = len(valid_indices) 109 | chosen = torch.randint(0, num_options, size=tuple(), generator=self.generator) 110 | chosen, overlap_a, overlap_b = valid_indices[chosen] 111 | 112 | context_left = min(chosen, context_index.item()) 113 | context_right = max(chosen, context_index.item()) 114 | delta = context_right - context_left 115 | 116 | # Pick non-repeated random target views. 117 | while True: 118 | target_views = torch.randint( 119 | context_left, 120 | context_right + 1, 121 | (self.cfg.num_target_views,), 122 | generator=self.generator, 123 | ) 124 | if (target_views.unique(return_counts=True)[1] == 1).all(): 125 | break 126 | 127 | target = tuple(sorted(target_views.tolist())) 128 | self.index[scene] = IndexEntry( 129 | context=(context_left, context_right), 130 | target=target, 131 | ) 132 | 133 | # Optionally, save a preview. 134 | if self.cfg.save_previews: 135 | preview_path = self.cfg.output_path / "previews" 136 | preview_path.mkdir(exist_ok=True, parents=True) 137 | a = batch["target"]["image"][0, chosen] 138 | a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%") 139 | b = batch["target"]["image"][0, context_index] 140 | b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%") 141 | vis = add_border(add_border(hcat(a, b)), 1, 0) 142 | vis = add_label(vis, f"Distance: {delta} frames") 143 | save_image(add_border(vis), preview_path / f"{scene}.png") 144 | break 145 | else: 146 | # This happens if no starting frame produces a valid evaluation example. 
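# A null entry is stored so the exported index still lists every scene;
# ViewSamplerEvaluation raises an error when asked to sample such a scene.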
147 | self.index[scene] = None 148 | 149 | def save_index(self) -> None: 150 | self.cfg.output_path.mkdir(exist_ok=True, parents=True) 151 | with (self.cfg.output_path / "evaluation_index.json").open("w") as f: 152 | json.dump({k: None if v is None else asdict(v) for k, v in self.index.items()}, f) 153 | -------------------------------------------------------------------------------- /src/evaluation/metric_computer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from tabulate import tabulate 7 | 8 | from ..misc.image_io import load_image, save_image 9 | from ..visualization.annotation import add_label 10 | from ..visualization.layout import add_border, hcat 11 | from .evaluation_cfg import EvaluationCfg 12 | from .metrics import compute_lpips, compute_psnr, compute_ssim 13 | 14 | 15 | class MetricComputer(LightningModule): 16 | cfg: EvaluationCfg 17 | 18 | def __init__(self, cfg: EvaluationCfg) -> None: 19 | super().__init__() 20 | self.cfg = cfg 21 | 22 | def test_step(self, batch, batch_idx): 23 | scene = batch["scene"][0] 24 | b, cv, _, _, _ = batch["context"]["image"].shape 25 | assert b == 1 and cv == 2 26 | _, v, _, _, _ = batch["target"]["image"].shape 27 | 28 | # Skip scenes. 29 | for method in self.cfg.methods: 30 | if not (method.path / scene).exists(): 31 | print(f'Skipping "{scene}".') 32 | return 33 | 34 | # Load the images. 35 | all_images = {} 36 | try: 37 | for method in self.cfg.methods: 38 | images = [ 39 | load_image(method.path / scene / f"color/{index.item():0>6}.png") 40 | for index in batch["target"]["index"][0] 41 | ] 42 | all_images[method.key] = torch.stack(images).to(self.device) 43 | except FileNotFoundError: 44 | print(f'Skipping "{scene}".') 45 | return 46 | 47 | # Compute metrics. 48 | all_metrics = {} 49 | rgb_gt = batch["target"]["image"][0] 50 | for key, images in all_images.items(): 51 | all_metrics = { 52 | **all_metrics, 53 | f"lpips_{key}": compute_lpips(rgb_gt, images).mean(), 54 | f"ssim_{key}": compute_ssim(rgb_gt, images).mean(), 55 | f"psnr_{key}": compute_psnr(rgb_gt, images).mean(), 56 | } 57 | self.log_dict(all_metrics) 58 | self.print_preview_metrics(all_metrics) 59 | 60 | # Skip the rest if no side-by-side is needed. 61 | if self.cfg.side_by_side_path is None: 62 | return 63 | 64 | # Create side-by-side. 65 | scene_key = f"{batch_idx:0>6}_{scene}" 66 | for i in range(v): 67 | true_index = batch["target"]["index"][0, i] 68 | row = [add_label(batch["target"]["image"][0, i], "Ground Truth")] 69 | for method in self.cfg.methods: 70 | image = all_images[method.key][i] 71 | image = add_label(image, method.name) 72 | row.append(image) 73 | start_frame = batch["target"]["index"][0, 0] 74 | end_frame = batch["target"]["index"][0, -1] 75 | label = f"Scene {batch['scene'][0]} (frames {start_frame} to {end_frame})" 76 | row = add_border(add_label(hcat(*row), label, font_size=16)) 77 | save_image( 78 | row, 79 | self.cfg.side_by_side_path / scene_key / f"{true_index:0>6}.png", 80 | ) 81 | 82 | # Create an animation. 
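# ffmpeg stitches this scene's side-by-side PNGs into an mp4 under videos/; the
# pad filter rounds both dimensions up to even values, as required by libx264
# with yuv420p output.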
83 | if self.cfg.animate_side_by_side: 84 | (self.cfg.side_by_side_path / "videos").mkdir(exist_ok=True, parents=True) 85 | command = ( 86 | 'ffmpeg -y -framerate 30 -pattern_type glob -i "*.png" -c:v libx264 ' 87 | '-pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"' 88 | ) 89 | os.system( 90 | f"cd {self.cfg.side_by_side_path / scene_key} && {command} " 91 | f"{Path.cwd()}/{self.cfg.side_by_side_path}/videos/{scene_key}.mp4" 92 | ) 93 | 94 | def print_preview_metrics(self, metrics: dict[str, float]) -> None: 95 | if getattr(self, "running_metrics", None) is None: 96 | self.running_metrics = metrics 97 | self.running_metric_steps = 1 98 | else: 99 | s = self.running_metric_steps 100 | self.running_metrics = {k: ((s * v) + metrics[k]) / (s + 1) for k, v in self.running_metrics.items()} 101 | self.running_metric_steps += 1 102 | 103 | table = [] 104 | for method in self.cfg.methods: 105 | row = [f"{self.running_metrics[f'{metric}_{method.key}']:.3f}" for metric in ("psnr", "lpips", "ssim")] 106 | table.append((method.key, *row)) 107 | 108 | table = tabulate(table, ["Method", "PSNR (dB)", "LPIPS", "SSIM"]) 109 | print(table) 110 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 | channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | 
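The three metrics defined in src/evaluation/metrics.py above share the same batched interface. As a rough usage sketch (assuming the repository root is on PYTHONPATH and the lpips and scikit-image packages used above are installed; the random tensors merely stand in for rendered and ground-truth views):

import torch

from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim

# Stand-in ground-truth and predicted images in [0, 1], shaped (batch, channel, height, width).
rgb_gt = torch.rand(4, 3, 256, 256)
rgb_hat = torch.rand(4, 3, 256, 256)

# Each metric returns one value per batch element; taking the mean mirrors how
# MetricComputer.test_step aggregates them per method.
print("psnr:", compute_psnr(rgb_gt, rgb_hat).mean().item())
print("ssim:", compute_ssim(rgb_gt, rgb_hat).mean().item())
print("lpips:", compute_lpips(rgb_gt, rgb_hat).mean().item())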
-------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_depth import LossDepth, LossDepthCfgWrapper 3 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 4 | from .loss_mse import LossMse, LossMseCfgWrapper 5 | 6 | LOSSES = { 7 | LossDepthCfgWrapper: LossDepth, 8 | LossLpipsCfgWrapper: LossLpips, 9 | LossMseCfgWrapper: LossMse, 10 | } 11 | 12 | LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper 13 | 14 | 15 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 16 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 17 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians: Gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | 38 | @abstractmethod 39 | def dynamic_forward( 40 | self, prediction: DecoderOutput, gt_image: Tensor, global_step: int, weight: float | None = None 41 | ) -> Float[Tensor, ""]: 42 | pass 43 | -------------------------------------------------------------------------------- /src/loss/loss_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | from .loss import Loss 12 | 13 | 14 | @dataclass 15 | class LossDepthCfg: 16 | weight: float 17 | sigma_image: float | None 18 | use_second_derivative: bool 19 | 20 | 21 | @dataclass 22 | class LossDepthCfgWrapper: 23 | depth: LossDepthCfg 24 | 25 | 26 | class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]): 27 | def forward( 28 | self, 29 | prediction: DecoderOutput, 30 | batch: BatchedExample, 31 | gaussians: Gaussians, 32 | global_step: int, 33 | ) -> Float[Tensor, ""]: 34 | # Scale the depth between the near and far planes. 35 | near = batch["target"]["near"][..., None, None].log() 36 | far = batch["target"]["far"][..., None, None].log() 37 | depth = prediction.depth.minimum(far).maximum(near) 38 | depth = (depth - near) / (far - near) 39 | 40 | # Compute the difference between neighboring pixels in each direction. 41 | depth_dx = depth.diff(dim=-1) 42 | depth_dy = depth.diff(dim=-2) 43 | 44 | # If desired, compute a 2nd derivative. 
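# Taking a second difference penalizes the curvature of the normalized depth
# rather than its slope, so smoothly sloping (planar) surfaces are not penalized.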
45 | if self.cfg.use_second_derivative: 46 | depth_dx = depth_dx.diff(dim=-1) 47 | depth_dy = depth_dy.diff(dim=-2) 48 | 49 | # If desired, add bilateral filtering. 50 | if self.cfg.sigma_image is not None: 51 | color_gt = batch["target"]["image"] 52 | color_dx = reduce(color_gt.diff(dim=-1), "b v c h w -> b v h w", "max") 53 | color_dy = reduce(color_gt.diff(dim=-2), "b v c h w -> b v h w", "max") 54 | if self.cfg.use_second_derivative: 55 | color_dx = color_dx[..., :, 1:].maximum(color_dx[..., :, :-1]) 56 | color_dy = color_dy[..., 1:, :].maximum(color_dy[..., :-1, :]) 57 | depth_dx = depth_dx * torch.exp(-color_dx * self.cfg.sigma_image) 58 | depth_dy = depth_dy * torch.exp(-color_dy * self.cfg.sigma_image) 59 | 60 | return self.cfg.weight * (depth_dx.abs().mean() + depth_dy.abs().mean()) 61 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians: Gaussians, 41 | global_step: int, 42 | ) -> Float[Tensor, ""]: 43 | image = batch["target"]["image"] 44 | 45 | # Before the specified step, don't apply the loss. 46 | if global_step < self.cfg.apply_after_step: 47 | return torch.tensor(0, dtype=torch.float32, device=image.device) 48 | 49 | loss = self.lpips.forward( 50 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 51 | rearrange(image, "b v c h w -> (b v) c h w"), 52 | normalize=True, 53 | ) 54 | return self.cfg.weight * loss.mean() 55 | 56 | def dynamic_forward( 57 | self, prediction: DecoderOutput, gt_image: Tensor, global_step: int, weight: float | None = None 58 | ) -> Float[Tensor, ""]: 59 | image = gt_image 60 | weight = self.cfg.weight if weight is None else weight 61 | # Before the specified step, don't apply the loss. 
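# (Same warm-up behaviour as forward above: returning a zero scalar on the
# image's device keeps the loss term well-defined until LPIPS is enabled.)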
62 | if global_step < self.cfg.apply_after_step: 63 | return torch.tensor(0, dtype=torch.float32, device=image.device) 64 | 65 | loss = self.lpips.forward( 66 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 67 | rearrange(image, "b v c h w -> (b v) c h w"), 68 | normalize=True, 69 | ) 70 | return weight * loss.mean() 71 | -------------------------------------------------------------------------------- /src/loss/loss_mse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..dataset.types import BatchedExample 7 | from ..model.decoder.decoder import DecoderOutput 8 | from ..model.types import Gaussians 9 | from .loss import Loss 10 | 11 | 12 | @dataclass 13 | class LossMseCfg: 14 | weight: float 15 | 16 | 17 | @dataclass 18 | class LossMseCfgWrapper: 19 | mse: LossMseCfg 20 | 21 | 22 | class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): 23 | def forward( 24 | self, 25 | prediction: DecoderOutput, 26 | batch: BatchedExample, 27 | gaussians: Gaussians, 28 | global_step: int, 29 | ) -> Float[Tensor, ""]: 30 | delta = prediction.color - batch["target"]["image"] 31 | return self.cfg.weight * (delta**2).mean() 32 | 33 | def dynamic_forward( 34 | self, prediction: DecoderOutput, gt_image: Tensor, global_step: int, weight: float | None = None 35 | ) -> Float[Tensor, ""]: 36 | weight = self.cfg.weight if weight is None else weight 37 | delta = prediction.color - gt_image 38 | return weight * (delta**2).mean() 39 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from pathlib import Path 4 | 5 | import hydra 6 | import torch 7 | import wandb 8 | from colorama import Fore 9 | from jaxtyping import install_import_hook 10 | from omegaconf import DictConfig, OmegaConf 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning.callbacks import ( 13 | LearningRateMonitor, 14 | ModelCheckpoint, 15 | ) 16 | from pytorch_lightning.loggers import TensorBoardLogger 17 | from pytorch_lightning.loggers.wandb import WandbLogger 18 | from pytorch_lightning.plugins.layer_sync import TorchSyncBatchNorm 19 | 20 | # Configure beartype and jaxtyping. 21 | with install_import_hook( 22 | ("src",), 23 | ("beartype", "beartype"), 24 | ): 25 | from src.config import load_typed_root_config 26 | from src.dataset.data_module import DataModule 27 | from src.global_cfg import set_cfg 28 | from src.loss import get_losses 29 | from src.misc.LocalLogger import LocalLogger 30 | from src.misc.step_tracker import StepTracker 31 | from src.misc.wandb_tools import update_checkpoint_path 32 | from src.model.decoder import get_decoder 33 | from src.model.encoder import get_encoder 34 | from src.model.model_wrapper import ModelWrapper 35 | 36 | 37 | def cyan(text: str) -> str: 38 | return f"{Fore.CYAN}{text}{Fore.RESET}" 39 | 40 | 41 | @hydra.main( 42 | version_base=None, 43 | config_path="../config", 44 | config_name="main", 45 | ) 46 | def train(cfg_dict: DictConfig): 47 | cfg_dict["test"]["output_path"] = os.path.join("outputs", cfg_dict["output_dir"], "test") 48 | cfg = load_typed_root_config(cfg_dict) 49 | set_cfg(cfg_dict) 50 | # Set up the output directory. 
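# If no output_dir override is given, Hydra's per-run output directory is used;
# otherwise a fixed directory under outputs/ is reused so that a run can be
# resumed. The latest-run symlink is then re-pointed at the chosen directory.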
51 | if cfg_dict.output_dir is None: 52 | output_dir = Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) 53 | else: # for resuming 54 | output_dir = Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["cwd"]) / "outputs" / cfg_dict.output_dir 55 | os.makedirs(output_dir, exist_ok=True) 56 | print(cyan(f"Saving outputs to {output_dir}.")) 57 | latest_run = output_dir.parents[1] / "latest-run" 58 | os.system(f"rm {latest_run}") 59 | os.system(f"ln -s {output_dir} {latest_run}") 60 | 61 | # Set up logging with wandb. 62 | callbacks = [] 63 | if cfg_dict.wandb.mode != "disabled": 64 | wandb_extra_kwargs = {} 65 | if cfg_dict.wandb.id is not None: 66 | wandb_extra_kwargs.update({"id": cfg_dict.wandb.id, "resume": "must"}) 67 | logger = WandbLogger( 68 | entity=cfg_dict.wandb.entity, 69 | project=cfg_dict.wandb.project, 70 | mode=cfg_dict.wandb.mode, 71 | name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})", 72 | tags=cfg_dict.wandb.get("tags", None), 73 | log_model=False, 74 | save_dir=output_dir, 75 | config=OmegaConf.to_container(cfg_dict), 76 | **wandb_extra_kwargs, 77 | ) 78 | callbacks.append(LearningRateMonitor("step", True)) 79 | 80 | # On rank != 0, wandb.run is None. 81 | if wandb.run is not None: 82 | wandb.run.log_code("src") 83 | elif cfg_dict.use_tensorboard is not None: 84 | tensorboard_dir = output_dir / "tensorboard" 85 | tensorboard_dir.mkdir(exist_ok=True, parents=True) 86 | logger = TensorBoardLogger(save_dir=output_dir) 87 | callbacks.append(LearningRateMonitor("step", True)) 88 | else: 89 | logger = LocalLogger() 90 | 91 | # Set up checkpointing. 92 | callbacks.append( 93 | ModelCheckpoint( 94 | output_dir / "checkpoints", 95 | every_n_train_steps=cfg.checkpointing.every_n_train_steps, 96 | save_top_k=cfg.checkpointing.save_top_k, 97 | monitor="info/global_step", 98 | mode="max", # save the lastest k ckpt, can do offline test later 99 | ) 100 | ) 101 | for cb in callbacks: 102 | cb.CHECKPOINT_EQUALS_CHAR = "_" 103 | 104 | # Prepare the checkpoint for loading. 105 | checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb) 106 | 107 | # This allows the current step to be shared with the data loader processes. 
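# StepTracker keeps the step in a shared-memory tensor guarded by a
# multiprocessing lock (see src/misc/step_tracker.py below), so dataloader
# workers can read the trainer's global step, e.g. for step-dependent view sampling.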
108 | step_tracker = StepTracker() 109 | 110 | trainer = Trainer( 111 | max_epochs=-1, 112 | accelerator="gpu", 113 | logger=logger, 114 | devices=cfg.device, 115 | strategy="ddp", 116 | callbacks=callbacks, 117 | val_check_interval=cfg.trainer.val_check_interval, 118 | enable_progress_bar=cfg.mode == "test", 119 | gradient_clip_val=cfg.trainer.gradient_clip_val, 120 | max_steps=cfg.trainer.max_steps, 121 | num_sanity_val_steps=cfg.trainer.num_sanity_val_steps, 122 | precision=cfg.precision, 123 | ) 124 | print(f"GPU number is {torch.cuda.device_count()}") 125 | torch.manual_seed(cfg_dict.seed + trainer.global_rank) 126 | decoder = get_decoder(cfg.model.decoder, cfg.dataset) 127 | encoder, encoder_visualizer = get_encoder(cfg.model.encoder, decoder) 128 | if cfg.mode == "train" and checkpoint_path is not None: 129 | ckpt = torch.load(checkpoint_path)["state_dict"] 130 | ckpt = {".".join(k.split(".")[1:]): v for k, v in ckpt.items()} 131 | encoder.load_state_dict(ckpt) 132 | model_wrapper = ModelWrapper( 133 | cfg.optimizer, 134 | cfg.test, 135 | cfg.train, 136 | encoder, 137 | encoder_visualizer, 138 | get_decoder(cfg.model.decoder, cfg.dataset), 139 | get_losses(cfg.loss), 140 | step_tracker, 141 | ) 142 | model_wrapper = TorchSyncBatchNorm().apply(model_wrapper) 143 | 144 | data_module = DataModule( 145 | cfg.dataset, 146 | cfg.data_loader, 147 | step_tracker, 148 | global_rank=trainer.global_rank, 149 | ) 150 | 151 | if cfg.mode == "train": 152 | print("begin to train fit!") 153 | try: 154 | print(f"resume from {checkpoint_path}!!!") 155 | trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path) 156 | except: 157 | print(f"start from scratch!!!") 158 | trainer.fit(model_wrapper, datamodule=data_module) 159 | elif cfg.mode == "test": 160 | trainer.test( 161 | model_wrapper, 162 | datamodule=data_module, 163 | ckpt_path=checkpoint_path, 164 | ) 165 | else: 166 | raise NotImplementedError(f"The {cfg.mode} mode is not implemented!") 167 | 168 | 169 | if __name__ == "__main__": 170 | warnings.filterwarnings("ignore") 171 | torch.set_float32_matmul_precision("high") 172 | 173 | train() 174 | -------------------------------------------------------------------------------- /src/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | from PIL import Image 6 | from pytorch_lightning.loggers.logger import Logger 7 | from pytorch_lightning.utilities import rank_zero_only 8 | 9 | LOG_PATH = Path("outputs/local") 10 | 11 | 12 | class LocalLogger(Logger): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.experiment = None 16 | os.system(f"rm -r {LOG_PATH}") 17 | 18 | @property 19 | def name(self): 20 | return "LocalLogger" 21 | 22 | @property 23 | def version(self): 24 | return 0 25 | 26 | @rank_zero_only 27 | def log_hyperparams(self, params): 28 | pass 29 | 30 | @rank_zero_only 31 | def log_metrics(self, metrics, step): 32 | pass 33 | 34 | @rank_zero_only 35 | def log_image( 36 | self, 37 | key: str, 38 | images: list[Any], 39 | step: Optional[int] = None, 40 | **kwargs, 41 | ): 42 | # The function signature is the same as the wandb logger's, but the step is 43 | # actually required. 
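# Hence the assertion below; each image is written to
# outputs/local/<key>/<index>_<step>.png with zero-padded index and step.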
44 | assert step is not None 45 | for index, image in enumerate(images): 46 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 47 | path.parent.mkdir(exist_ok=True, parents=True) 48 | Image.fromarray(image).save(path) 49 | -------------------------------------------------------------------------------- /src/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") 38 | 39 | def clear_history(self) -> None: 40 | self.execution_times = defaultdict(list) 41 | -------------------------------------------------------------------------------- /src/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return {key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0]} 14 | -------------------------------------------------------------------------------- /src/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... 
()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 
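For example, with n = 3 the entry holding the pair (0, 1) is exchanged with the
entry holding (1, 0), and likewise for every other ordered pair.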
33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import cv2 6 | import numpy as np 7 | import skvideo.io 8 | import torch 9 | import torchvision.transforms as tf 10 | from einops import rearrange, repeat 11 | from jaxtyping import Float, UInt8 12 | from matplotlib.figure import Figure 13 | from PIL import Image 14 | from torch import Tensor 15 | 16 | FloatImage = Union[ 17 | Float[Tensor, "height width"], 18 | Float[Tensor, "channel height width"], 19 | Float[Tensor, "batch channel height width"], 20 | ] 21 | 22 | 23 | def fig_to_image( 24 | fig: Figure, 25 | dpi: int = 100, 26 | device: torch.device = torch.device("cpu"), 27 | ) -> Float[Tensor, "3 height width"]: 28 | buffer = io.BytesIO() 29 | fig.savefig(buffer, format="raw", dpi=dpi) 30 | buffer.seek(0) 31 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 32 | h = int(fig.bbox.bounds[3]) 33 | w = int(fig.bbox.bounds[2]) 34 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 35 | buffer.close() 36 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 37 | 38 | 39 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 40 | # Handle batched images. 41 | if image.ndim == 4: 42 | image = rearrange(image, "b c h w -> c h (b w)") 43 | 44 | # Handle single-channel images. 45 | if image.ndim == 2: 46 | image = rearrange(image, "h w -> () h w") 47 | 48 | # Ensure that there are 3 or 4 channels. 49 | channel, _, _ = image.shape 50 | if channel == 1: 51 | image = repeat(image, "() h w -> c h w", c=3) 52 | assert image.shape[0] in (3, 4) 53 | 54 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 55 | return rearrange(image, "c h w -> h w c").cpu().numpy() 56 | 57 | 58 | def save_image( 59 | image: FloatImage, 60 | path: Union[Path, str], 61 | ) -> None: 62 | """Save an image. Assumed to be in range 0-1.""" 63 | 64 | # Create the parent directory if it doesn't already exist. 65 | path = Path(path) 66 | path.parent.mkdir(exist_ok=True, parents=True) 67 | 68 | # Save the image. 69 | Image.fromarray(prep_image(image)).save(path) 70 | 71 | 72 | def save_depth_image(image: FloatImage, path: Union[Path, str], depth_max=10) -> None: 73 | """Save an image. Assumed to be in range 0-1.""" 74 | path = Path(path) 75 | path.parent.mkdir(exist_ok=True, parents=True) 76 | image = image.detach().cpu() 77 | image = image.clip(max=depth_max) / depth_max 78 | image = (image.numpy() * 255).clip(0, 255).astype(np.uint8) 79 | image = cv2.applyColorMap(image, cv2.COLORMAP_JET) 80 | cv2.imwrite(str(path), image) 81 | 82 | 83 | def save_batch_images(images: FloatImage, path: str, depth_max=None, depth_min=0.0) -> None: 84 | """Save a batch of images. 
Assumed to be in range 0-1.""" 85 | images = images.detach().cpu() 86 | if depth_max is not None: 87 | if depth_max != 1.0: 88 | images = images.clip(max=depth_max, min=depth_min) / (depth_max - depth_min) 89 | else: 90 | images = (images - images.min()) / (images.max() - images.min()) 91 | for i, image in enumerate(images): 92 | postfix = path.split(".")[1] 93 | img_name = path.split(".")[0] 94 | path_i = f"{img_name}_{i}.{postfix}" 95 | image = (image.numpy() * 255).clip(0, 255).astype(np.uint8) 96 | # image = cv2.applyColorMap(image, cv2.COLORMAP_JET) 97 | image = cv2.applyColorMap(image, cv2.COLORMAP_HOT) 98 | cv2.imwrite(path_i, image) 99 | else: 100 | for i, image in enumerate(images): 101 | postfix = path.split(".")[1] 102 | img_name = path.split(".")[0] 103 | path_i = f"{img_name}_{i}.{postfix}" 104 | save_image(image, path_i) 105 | 106 | 107 | def unnormalize_images(images): 108 | shape = [*[1] * (images.dim() - 3), 3, 1, 1] 109 | mean = torch.tensor([0.485, 0.456, 0.406]).reshape(*shape).to(images.device) 110 | std = torch.tensor([0.229, 0.224, 0.225]).reshape(*shape).to(images.device) 111 | 112 | return images * std + mean 113 | 114 | 115 | def load_image( 116 | path: Union[Path, str], 117 | ) -> Float[Tensor, "3 height width"]: 118 | return tf.ToTensor()(Image.open(path))[:3] 119 | 120 | 121 | def save_video( 122 | images: list[FloatImage], 123 | path: Union[Path, str], 124 | ) -> None: 125 | """Save an image. Assumed to be in range 0-1.""" 126 | 127 | # Create the parent directory if it doesn't already exist. 128 | path = Path(path) 129 | path.parent.mkdir(exist_ok=True, parents=True) 130 | 131 | # Save the image. 132 | # Image.fromarray(prep_image(image)).save(path) 133 | frames = [] 134 | for image in images: 135 | frames.append(prep_image(image)) 136 | 137 | writer = skvideo.io.FFmpegWriter(path, outputdict={"-pix_fmt": "yuv420p", "-crf": "21", "-vf": f"setpts=1.*PTS"}) 138 | for frame in frames: 139 | writer.writeFrame(frame) 140 | writer.close() 141 | -------------------------------------------------------------------------------- /src/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /src/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import isqrt 3 | 4 | import torch 5 | 6 | # from e3nn.o3 import matrix_to_angles, so3_generators, xyz_to_angles, angles_to_matrix 7 | from e3nn.o3 import angles_to_matrix, so3_generators, xyz_to_angles 8 | from einops import einsum 9 | from jaxtyping import Float 10 | from torch import Tensor 11 | 12 | """ Compatibility for torch < 2.0""" 13 | 14 | 15 | def wigner_D(l, alpha, beta, gamma, device=torch.device("cpu")): 16 | r"""Wigner D matrix representation of :math:`SO(3)`. 
17 | 18 | It satisfies the following properties: 19 | 20 | * :math:`D(\text{identity rotation}) = \text{identity matrix}` 21 | * :math:`D(R_1 \circ R_2) = D(R_1) \circ D(R_2)` 22 | * :math:`D(R^{-1}) = D(R)^{-1} = D(R)^T` 23 | * :math:`D(\text{rotation around Y axis})` has some property that allows us to use FFT in `ToS2Grid` 24 | 25 | Parameters 26 | ---------- 27 | l : int 28 | :math:`l` 29 | 30 | alpha : `torch.Tensor` 31 | tensor of shape :math:`(...)` 32 | Rotation :math:`\alpha` around Y axis, applied third. 33 | 34 | beta : `torch.Tensor` 35 | tensor of shape :math:`(...)` 36 | Rotation :math:`\beta` around X axis, applied second. 37 | 38 | gamma : `torch.Tensor` 39 | tensor of shape :math:`(...)` 40 | Rotation :math:`\gamma` around Y axis, applied first. 41 | 42 | Returns 43 | ------- 44 | `torch.Tensor` 45 | tensor :math:`D^l(\alpha, \beta, \gamma)` of shape :math:`(2l+1, 2l+1)` 46 | """ 47 | alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) 48 | alpha = alpha[..., None, None] % (2 * math.pi) 49 | beta = beta[..., None, None] % (2 * math.pi) 50 | gamma = gamma[..., None, None] % (2 * math.pi) 51 | X = so3_generators(l) 52 | X = X.to(device) 53 | return torch.matrix_exp(alpha * X[1]) @ torch.matrix_exp(beta * X[0]) @ torch.matrix_exp(gamma * X[1]) 54 | 55 | 56 | def matrix_to_angles(R): 57 | r"""conversion from matrix to angles 58 | 59 | Parameters 60 | ---------- 61 | R : `torch.Tensor` 62 | matrices of shape :math:`(..., 3, 3)` 63 | 64 | Returns 65 | ------- 66 | alpha : `torch.Tensor` 67 | tensor of shape :math:`(...)` 68 | 69 | beta : `torch.Tensor` 70 | tensor of shape :math:`(...)` 71 | 72 | gamma : `torch.Tensor` 73 | tensor of shape :math:`(...)` 74 | """ 75 | # assert torch.allclose(torch.det(R), R.new_tensor(1)) 76 | x = R @ R.new_tensor([0.0, 1.0, 0.0]) 77 | a, b = xyz_to_angles(x) 78 | R = angles_to_matrix(a, b, torch.zeros_like(a)).transpose(-1, -2) @ R 79 | c = torch.atan2(R[..., 0, 2], R[..., 0, 0]) 80 | return a, b, c 81 | 82 | 83 | def rotate_sh( 84 | sh_coefficients: Float[Tensor, "*#batch n"], 85 | rotations: Float[Tensor, "*#batch 3 3"], 86 | ) -> Float[Tensor, "*batch n"]: 87 | device = sh_coefficients.device 88 | dtype = sh_coefficients.dtype 89 | 90 | *_, n = sh_coefficients.shape 91 | alpha, beta, gamma = matrix_to_angles(rotations) 92 | result = [] 93 | for degree in range(isqrt(n)): 94 | # with torch.device(device): 95 | sh_rotations = wigner_D(degree, alpha, beta, gamma, device=device).type(dtype) 96 | sh_rotated = einsum( 97 | sh_rotations, 98 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 99 | "... i j, ... j -> ... i", 100 | ) 101 | result.append(sh_rotated) 102 | 103 | return torch.cat(result, dim=-1) 104 | 105 | 106 | if __name__ == "__main__": 107 | from pathlib import Path 108 | 109 | import matplotlib.pyplot as plt 110 | from e3nn.o3 import spherical_harmonics 111 | from matplotlib import cm 112 | from scipy.spatial.transform.rotation import Rotation as R 113 | 114 | device = torch.device("cuda") 115 | 116 | # Generate random spherical harmonics coefficients. 
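# A degree-l expansion has (l + 1)^2 coefficients, so degree 4 yields 25 of them,
# matching the 25-coefficient colour harmonics seen in the decoder's shape comments.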
117 | degree = 4 118 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 119 | 120 | def plot_sh(sh_coefficients, path: Path) -> None: 121 | phi = torch.linspace(0, torch.pi, 100, device=device) 122 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 123 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 124 | x = torch.sin(phi) * torch.cos(theta) 125 | y = torch.sin(phi) * torch.sin(theta) 126 | z = torch.cos(phi) 127 | xyz = torch.stack([x, y, z], dim=-1) 128 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 129 | result = einsum(sh, sh_coefficients, "... n, n -> ...") 130 | result = (result - result.min()) / (result.max() - result.min()) 131 | 132 | # Set the aspect ratio to 1 so our sphere looks spherical 133 | fig = plt.figure(figsize=plt.figaspect(1.0)) 134 | ax = fig.add_subplot(111, projection="3d") 135 | ax.plot_surface( 136 | x.cpu().numpy(), 137 | y.cpu().numpy(), 138 | z.cpu().numpy(), 139 | rstride=1, 140 | cstride=1, 141 | facecolors=cm.seismic(result.cpu().numpy()), 142 | ) 143 | # Turn off the axis planes 144 | ax.set_axis_off() 145 | path.parent.mkdir(exist_ok=True, parents=True) 146 | plt.savefig(path) 147 | 148 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 149 | rotation = torch.tensor(R.from_euler("x", angle.item()).as_matrix(), device=device) 150 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 151 | 152 | print("Done!") 153 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 
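# The chosen artifact (the newest COMMITTED model unless a specific version vN
# was requested) is downloaded under the requested download directory
# (checkpoints/ when called through update_checkpoint_path), and the path of the
# contained model.ckpt is returned.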
36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from ...dataset import DatasetCfg 2 | from .decoder import Decoder 3 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 4 | 5 | DECODERS = { 6 | "splatting_cuda": DecoderSplattingCUDA, 7 | } 8 | 9 | DecoderCfg = DecoderSplattingCUDACfg 10 | 11 | 12 | def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder: 13 | return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg) 14 | -------------------------------------------------------------------------------- /src/model/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ...dataset import DatasetCfg 9 | from ..types import Gaussians 10 | 11 | DepthRenderingMode = Literal[ 12 | "depth", 13 | "log", 14 | "disparity", 15 | "relative_disparity", 16 | ] 17 | 18 | 19 | @dataclass 20 | class DecoderOutput: 21 | color: Float[Tensor, "batch view 3 height width"] 22 | depth: Float[Tensor, "batch view height width"] | None 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | class Decoder(nn.Module, ABC, Generic[T]): 29 | cfg: T 30 | dataset_cfg: DatasetCfg 31 | 32 | def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None: 33 | super().__init__() 34 | self.cfg = cfg 35 | self.dataset_cfg = dataset_cfg 36 | 37 | @abstractmethod 38 | def forward( 39 | self, 40 | gaussians: Gaussians, 41 | extrinsics: Float[Tensor, "batch view 4 4"], 42 | intrinsics: Float[Tensor, "batch view 3 3"], 43 | near: Float[Tensor, "batch view"], 44 | far: Float[Tensor, "batch view"], 45 | image_shape: tuple[int, int], 46 | depth_mode: DepthRenderingMode | None = None, 47 | ) -> DecoderOutput: 48 | pass 49 | -------------------------------------------------------------------------------- /src/model/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda 12 | from .decoder import Decoder, DecoderOutput 13 | 14 | 15 | @dataclass 16 | class DecoderSplattingCUDACfg: 17 | name: Literal["splatting_cuda"] 18 | 19 | 20 | class 
DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 21 | background_color: Float[Tensor, "3"] 22 | 23 | def __init__( 24 | self, 25 | cfg: DecoderSplattingCUDACfg, 26 | dataset_cfg: DatasetCfg, 27 | ) -> None: 28 | super().__init__(cfg, dataset_cfg) 29 | self.register_buffer( 30 | "background_color", 31 | torch.tensor(dataset_cfg.background_color, dtype=torch.float32), 32 | persistent=False, 33 | ) 34 | 35 | def forward( 36 | self, 37 | gaussians: Gaussians, 38 | extrinsics: Float[Tensor, "batch view 4 4"], 39 | intrinsics: Float[Tensor, "batch view 3 3"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | image_shape: tuple[int, int], 43 | depth_mode: DepthRenderingMode | None = None, 44 | ) -> DecoderOutput: 45 | b, v, _, _ = extrinsics.shape 46 | color = render_cuda( 47 | rearrange(extrinsics, "b v i j -> (b v) i j"), # [4,4,4,4] 48 | rearrange(intrinsics, "b v i j -> (b v) i j"), # [4,4,3,3] 49 | rearrange(near, "b v -> (b v)"), # [4, 4] 50 | rearrange(far, "b v -> (b v)"), # [4, 4] 51 | image_shape, # [256,256] 52 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), # [3] 53 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), # [4,131072,3] 54 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), # [4,131072,3,3] 55 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), # [4,131072,3,25] 56 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), # [4,131072] 57 | ) 58 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) # [16, 3, 256, 256] 59 | 60 | return DecoderOutput( 61 | color, 62 | None 63 | if depth_mode is None 64 | else self.render_depth(gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode), 65 | ) 66 | 67 | def render_depth( 68 | self, 69 | gaussians: Gaussians, 70 | extrinsics: Float[Tensor, "batch view 4 4"], 71 | intrinsics: Float[Tensor, "batch view 3 3"], 72 | near: Float[Tensor, "batch view"], 73 | far: Float[Tensor, "batch view"], 74 | image_shape: tuple[int, int], 75 | mode: DepthRenderingMode = "depth", 76 | ) -> Float[Tensor, "batch view height width"]: 77 | b, v, _, _ = extrinsics.shape 78 | result = render_depth_cuda( 79 | rearrange(extrinsics, "b v i j -> (b v) i j"), 80 | rearrange(intrinsics, "b v i j -> (b v) i j"), 81 | rearrange(near, "b v -> (b v)"), 82 | rearrange(far, "b v -> (b v)"), 83 | image_shape, 84 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 85 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 86 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 87 | mode=mode, 88 | ) 89 | return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) 90 | 91 | 92 | class DecoderGeneralSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 93 | background_color: Float[Tensor, "3"] 94 | 95 | def __init__( 96 | self, 97 | cfg: DecoderSplattingCUDACfg, 98 | dataset_cfg: DatasetCfg, 99 | ) -> None: 100 | super().__init__(cfg, dataset_cfg) 101 | self.register_buffer( 102 | "background_color", 103 | torch.tensor(dataset_cfg.background_color, dtype=torch.float32), 104 | persistent=False, 105 | ) 106 | 107 | def forward( 108 | self, 109 | gaussians: Gaussians, 110 | extrinsics: Float[Tensor, "batch view 4 4"], 111 | intrinsics: Float[Tensor, "batch view 3 3"], 112 | near: Float[Tensor, "batch view"], 113 | far: Float[Tensor, "batch view"], 114 | image_shape: tuple[int, int], 115 | depth_mode: DepthRenderingMode | None = None, 116 | ) -> DecoderOutput: 117 | b, v, _, _ = extrinsics.shape 118 | color = render_cuda( 119 | rearrange(extrinsics, "b v i j -> 
(b v) i j"), 120 | rearrange(intrinsics, "b v i j -> (b v) i j"), 121 | rearrange(near, "b v -> (b v)"), 122 | rearrange(far, "b v -> (b v)"), 123 | image_shape, 124 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 125 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 126 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 127 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 128 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 129 | ) 130 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 131 | 132 | return DecoderOutput( 133 | color, 134 | None 135 | if depth_mode is None 136 | else self.render_depth(gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode), 137 | ) 138 | 139 | def render_depth( 140 | self, 141 | gaussians: Gaussians, 142 | extrinsics: Float[Tensor, "batch view 4 4"], 143 | intrinsics: Float[Tensor, "batch view 3 3"], 144 | near: Float[Tensor, "batch view"], 145 | far: Float[Tensor, "batch view"], 146 | image_shape: tuple[int, int], 147 | mode: DepthRenderingMode = "depth", 148 | ) -> Float[Tensor, "batch view height width"]: 149 | b, v, _, _ = extrinsics.shape 150 | result = render_depth_cuda( 151 | rearrange(extrinsics, "b v i j -> (b v) i j"), 152 | rearrange(intrinsics, "b v i j -> (b v) i j"), 153 | rearrange(near, "b v -> (b v)"), 154 | rearrange(far, "b v -> (b v)"), 155 | image_shape, 156 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 157 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 158 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 159 | mode=mode, 160 | ) 161 | return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) 162 | 163 | def render_norm( 164 | self, 165 | gaussians: Gaussians, 166 | extrinsics: Float[Tensor, "batch view 4 4"], 167 | intrinsics: Float[Tensor, "batch view 3 3"], 168 | near: Float[Tensor, "batch view"], 169 | far: Float[Tensor, "batch view"], 170 | image_shape: tuple[int, int], 171 | mode: DepthRenderingMode = "depth", 172 | ) -> Float[Tensor, "batch view height width"]: ... 
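Both decoder classes above follow the same einops pattern: per-batch Gaussian parameters are broadcast across the target views and the (batch, view) axes are flattened into one camera axis before the CUDA rasterizer is called, then the rendered images are unflattened again. A minimal, renderer-free sketch of that reshaping with made-up shapes (only torch and einops are needed; no repository code or GPU is involved):

import torch
from einops import rearrange, repeat

b, v, g = 2, 3, 8  # batch size, target views, Gaussians per batch element

means = torch.rand(b, g, 3)                   # one Gaussian cloud per batch element
extrinsics = torch.eye(4).expand(b, v, 4, 4)  # one camera per (batch, view)

# Broadcast the Gaussians over the views and flatten (batch, view) into a single
# axis, as DecoderSplattingCUDA.forward does before calling render_cuda.
means_bv = repeat(means, "b g xyz -> (b v) g xyz", v=v)        # (b * v, g, 3)
extrinsics_bv = rearrange(extrinsics, "b v i j -> (b v) i j")  # (b * v, 4, 4)

# Stand-in for the rasterizer output: one RGB image per flattened camera.
color_bv = torch.rand(b * v, 3, 64, 64)

# Restore the (batch, view) structure, as done for DecoderOutput.color.
color = rearrange(color_bv, "(b v) c h w -> b v c h w", b=b, v=v)
print(means_bv.shape, extrinsics_bv.shape, color.shape)  # (6, 8, 3), (6, 4, 4), (2, 3, 3, 64, 64)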
173 | -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_costvolume_pyramid import ( 5 | EncoderCostVolumeCfgPyramid, 6 | EncoderCostVolumePyramid, 7 | ) 8 | from .visualization.encoder_visualizer import EncoderVisualizer 9 | from .visualization.encoder_visualizer_costvolume import EncoderVisualizerCostVolume 10 | 11 | ENCODERS = { 12 | "costvolume_pyramid": (EncoderCostVolumePyramid, EncoderVisualizerCostVolume), 13 | } 14 | 15 | EncoderCfg = EncoderCostVolumeCfgPyramid 16 | 17 | 18 | def get_encoder(cfg: EncoderCfg, decoder) -> tuple[Encoder, Optional[EncoderVisualizer]]: 19 | encoder, visualizer = ENCODERS[cfg.name] 20 | encoder = encoder(cfg) 21 | encoder.decoder = decoder 22 | if visualizer is not None: 23 | visualizer = visualizer(cfg.visualizer, encoder) 24 | return encoder, visualizer 25 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .backbone import Backbone 4 | from .backbone_dino import BackboneDino, BackboneDinoCfg 5 | from .backbone_resnet import BackboneResnet, BackboneResnetCfg 6 | 7 | BACKBONES: dict[str, Backbone[Any]] = { 8 | "resnet": BackboneResnet, 9 | "dino": BackboneDino, 10 | } 11 | BackboneCfg = BackboneResnetCfg | BackboneDinoCfg 12 | 13 | 14 | def get_backbone(cfg: BackboneCfg, d_in: int) -> Backbone[Any]: 15 | return BACKBONES[cfg.name](cfg, d_in) 16 | 17 | 18 | from .backbone_multiview import BackboneMultiview 19 | from .backbone_pyramid import BackbonePyramid 20 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor, nn 6 | 7 | from ....dataset.types import BatchedViews 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Backbone(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | ) -> Float[Tensor, "batch view d_out height width"]: 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def d_out(self) -> int: 29 | pass 30 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone_dino.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor, nn 8 | 9 | from ....dataset.types import BatchedViews 10 | from .backbone import Backbone 11 | from .backbone_resnet import BackboneResnet, BackboneResnetCfg 12 | 13 | 14 | @dataclass 15 | class BackboneDinoCfg: 16 | name: Literal["dino"] 17 | model: Literal["dino_vits16", "dino_vits8", "dino_vitb16", "dino_vitb8"] 18 | d_out: int 19 | 20 | 21 | class BackboneDino(Backbone[BackboneDinoCfg]): 22 | def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None: 23 | 
super().__init__(cfg) 24 | assert d_in == 3 25 | self.dino = torch.hub.load("facebookresearch/dino:main", cfg.model) 26 | self.resnet_backbone = BackboneResnet( 27 | BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out), 28 | d_in, 29 | ) 30 | self.global_token_mlp = nn.Sequential( 31 | nn.Linear(768, 768), 32 | nn.ReLU(), 33 | nn.Linear(768, cfg.d_out), 34 | ) 35 | self.local_token_mlp = nn.Sequential( 36 | nn.Linear(768, 768), 37 | nn.ReLU(), 38 | nn.Linear(768, cfg.d_out), 39 | ) 40 | 41 | def forward( 42 | self, 43 | context: BatchedViews, 44 | ) -> Float[Tensor, "batch view d_out height width"]: 45 | # Compute features from the DINO-pretrained resnet50. 46 | resnet_features = self.resnet_backbone(context) 47 | 48 | # Compute features from the DINO-pretrained ViT. 49 | b, v, _, h, w = context["image"].shape 50 | assert h % self.patch_size == 0 and w % self.patch_size == 0 51 | tokens = rearrange(context["image"], "b v c h w -> (b v) c h w") 52 | tokens = self.dino.get_intermediate_layers(tokens)[0] 53 | global_token = self.global_token_mlp(tokens[:, 0]) 54 | local_tokens = self.local_token_mlp(tokens[:, 1:]) 55 | 56 | # Repeat the global token to match the image shape. 57 | global_token = repeat(global_token, "(b v) c -> b v c h w", b=b, v=v, h=h, w=w) 58 | 59 | # Repeat the local tokens to match the image shape. 60 | local_tokens = repeat( 61 | local_tokens, 62 | "(b v) (h w) c -> b v c (h hps) (w wps)", 63 | b=b, 64 | v=v, 65 | h=h // self.patch_size, 66 | hps=self.patch_size, 67 | w=w // self.patch_size, 68 | wps=self.patch_size, 69 | ) 70 | 71 | return resnet_features + local_tokens + global_token 72 | 73 | @property 74 | def patch_size(self) -> int: 75 | return int("".join(filter(str.isdigit, self.cfg.model))) 76 | 77 | @property 78 | def d_out(self) -> int: 79 | return self.cfg.d_out 80 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone_resnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Literal 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from einops import rearrange 9 | from jaxtyping import Float 10 | from torch import Tensor, nn 11 | from torchvision.models import ResNet 12 | 13 | from ....dataset.types import BatchedViews 14 | from .backbone import Backbone 15 | 16 | 17 | @dataclass 18 | class BackboneResnetCfg: 19 | name: Literal["resnet"] 20 | model: Literal["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50"] 21 | num_layers: int 22 | use_first_pool: bool 23 | d_out: int 24 | 25 | 26 | class BackboneResnet(Backbone[BackboneResnetCfg]): 27 | model: ResNet 28 | 29 | def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None: 30 | super().__init__(cfg) 31 | 32 | assert d_in == 3 33 | 34 | norm_layer = functools.partial( 35 | nn.InstanceNorm2d, 36 | affine=False, 37 | track_running_stats=False, 38 | ) 39 | 40 | if cfg.model == "dino_resnet50": 41 | self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50") 42 | else: 43 | self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer) 44 | 45 | # Set up projections 46 | self.projections = nn.ModuleDict({}) 47 | for index in range(1, cfg.num_layers): 48 | key = f"layer{index}" 49 | block = getattr(self.model, key) 50 | conv_index = 1 51 | try: 52 | while True: 53 | d_layer_out = getattr(block[-1], f"conv{conv_index}").out_channels 54 
| conv_index += 1 55 | except AttributeError: 56 | pass 57 | self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1) 58 | 59 | # Add a projection for the first layer. 60 | self.projections["layer0"] = nn.Conv2d(self.model.conv1.out_channels, cfg.d_out, 1) 61 | 62 | def forward( 63 | self, 64 | context: BatchedViews, 65 | ) -> Float[Tensor, "batch view d_out height width"]: 66 | # Merge the batch dimensions. 67 | b, v, _, h, w = context["image"].shape 68 | x = rearrange(context["image"], "b v c h w -> (b v) c h w") 69 | 70 | # Run the images through the resnet. 71 | x = self.model.conv1(x) 72 | x = self.model.bn1(x) 73 | x = self.model.relu(x) 74 | features = [self.projections["layer0"](x)] 75 | 76 | # Propagate the input through the resnet's layers. 77 | for index in range(1, self.cfg.num_layers): 78 | key = f"layer{index}" 79 | if index == 0 and self.cfg.use_first_pool: 80 | x = self.model.maxpool(x) 81 | x = getattr(self.model, key)(x) 82 | features.append(self.projections[key](x)) 83 | 84 | # Upscale the features. 85 | features = [F.interpolate(f, (h, w), mode="bilinear", align_corners=True) for f in features] 86 | features = torch.stack(features).sum(dim=0) 87 | 88 | # Separate batch dimensions. 89 | return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v) 90 | 91 | @property 92 | def d_out(self) -> int: 93 | return self.cfg.d_out 94 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/src/model/encoder/backbone/mvsformer_module/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/cost_volume.py: -------------------------------------------------------------------------------- 1 | # from models.swin.block import SwinTransformerCostReg 2 | 3 | from .module import * 4 | from .warping import homo_warping_3D_with_mask 5 | 6 | 7 | class identity_with(object): 8 | def __init__(self, enabled=True): 9 | self._enabled = enabled 10 | 11 | def __enter__(self): 12 | pass 13 | 14 | def __exit__(self, *args): 15 | pass 16 | 17 | 18 | autocast = torch.cuda.amp.autocast if torch.__version__ >= "1.6.0" else identity_with 19 | 20 | 21 | class StageNet(nn.Module): 22 | def __init__(self, args, ndepth, stage_idx): 23 | super(StageNet, self).__init__() 24 | self.args = args 25 | self.fusion_type = args.get("fusion_type", "cnn") 26 | self.ndepth = ndepth 27 | self.stage_idx = stage_idx 28 | self.cost_reg_type = args.get("cost_reg_type", ["Normal", "Normal", "Normal", "Normal"])[stage_idx] 29 | self.depth_type = args["depth_type"] 30 | if type(self.depth_type) == list: 31 | self.depth_type = self.depth_type[stage_idx] 32 | 33 | in_channels = args["base_ch"] 34 | if type(in_channels) == list: 35 | in_channels = in_channels[stage_idx] 36 | if self.fusion_type == "cnn": 37 | self.vis = nn.Sequential( 38 | ConvBnReLU(1, 16), ConvBnReLU(16, 16), ConvBnReLU(16, 8), nn.Conv2d(8, 1, 1), nn.Sigmoid() 39 | ) 40 | else: 41 | raise NotImplementedError(f"Not implemented fusion type: {self.fusion_type}.") 42 | 43 | if self.cost_reg_type == "PureTransformerCostReg": 44 | args["transformer_config"][stage_idx]["base_channel"] = in_channels 45 | self.cost_reg = PureTransformerCostReg(in_channels, **args["transformer_config"][stage_idx]) 46 | else: 47 | model_th = 
args.get("model_th", 8) 48 | if ndepth <= model_th: # do not downsample in depth range 49 | self.cost_reg = CostRegNet3D(in_channels, in_channels) 50 | else: 51 | self.cost_reg = CostRegNet(in_channels, in_channels) 52 | 53 | def forward(self, features, proj_matrices, depth_values, tmp, position3d=None): 54 | ref_feat = features[:, 0] 55 | src_feats = features[:, 1:] 56 | src_feats = torch.unbind(src_feats, dim=1) 57 | proj_matrices = torch.unbind(proj_matrices, 1) 58 | assert len(src_feats) == len(proj_matrices) - 1, "Different number of images and projection matrices" 59 | 60 | # step 1. feature extraction 61 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 62 | # step 2. differentiable homograph, build cost volume 63 | volume_sum = 0.0 64 | vis_sum = 0.0 65 | similarities = [] 66 | with autocast(enabled=False): 67 | for src_feat, src_proj in zip(src_feats, src_projs): 68 | # warpped features 69 | src_feat = src_feat.to(torch.float32) 70 | src_proj_new = src_proj[:, 0].clone() 71 | src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) 72 | ref_proj_new = ref_proj[:, 0].clone() 73 | ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) 74 | warped_volume, proj_mask = homo_warping_3D_with_mask(src_feat, src_proj_new, ref_proj_new, depth_values) 75 | 76 | B, C, D, H, W = warped_volume.shape 77 | G = self.args["base_ch"] 78 | if type(G) == list: 79 | G = G[self.stage_idx] 80 | 81 | if G < C: 82 | warped_volume = warped_volume.view(B, G, C // G, D, H, W) 83 | ref_volume = ref_feat.view(B, G, C // G, 1, H, W).repeat(1, 1, 1, D, 1, 1).to(torch.float32) 84 | in_prod_vol = (ref_volume * warped_volume).mean(dim=2) # [B,G,D,H,W] 85 | elif G == C: 86 | ref_volume = ref_feat.view(B, G, 1, H, W).to(torch.float32) 87 | in_prod_vol = ref_volume * warped_volume # [B,C(G),D,H,W] 88 | else: 89 | raise AssertionError("G must <= C!") 90 | 91 | if self.fusion_type == "cnn": 92 | sim_vol = in_prod_vol.sum(dim=1) # [B,D,H,W] 93 | sim_vol_norm = F.softmax(sim_vol.detach(), dim=1) 94 | entropy = (-sim_vol_norm * torch.log(sim_vol_norm + 1e-7)).sum(dim=1, keepdim=True) 95 | vis_weight = self.vis(entropy) 96 | else: 97 | raise NotImplementedError 98 | 99 | volume_sum = volume_sum + in_prod_vol * vis_weight.unsqueeze(1) 100 | vis_sum = vis_sum + vis_weight 101 | 102 | # aggregate multiple feature volumes by variance 103 | volume_mean = volume_sum / (vis_sum.unsqueeze(1) + 1e-6) # volume_sum / (num_views - 1) 104 | 105 | cost_reg = self.cost_reg(volume_mean, position3d) 106 | 107 | prob_volume_pre = cost_reg.squeeze(1) 108 | prob_volume = F.softmax(prob_volume_pre, dim=1) 109 | 110 | if self.depth_type == "ce": 111 | if self.training: 112 | _, idx = torch.max(prob_volume, dim=1) 113 | # vanilla argmax 114 | depth = torch.gather(depth_values, dim=1, index=idx.unsqueeze(1)).squeeze(1) 115 | else: 116 | # regression (t) 117 | depth = depth_regression(F.softmax(prob_volume_pre * tmp, dim=1), depth_values=depth_values) 118 | # # regression (t) 119 | # depth = depth_regression(F.softmax(prob_volume_pre * tmp, dim=1), depth_values=depth_values) 120 | # conf 121 | photometric_confidence = prob_volume.max(1)[0] # [B,H,W] 122 | 123 | else: 124 | depth = depth_regression(prob_volume, depth_values=depth_values) 125 | if self.ndepth >= 32: 126 | photometric_confidence = conf_regression(prob_volume, n=4) 127 | elif self.ndepth == 16: 128 | photometric_confidence = conf_regression(prob_volume, n=3) 129 | elif self.ndepth == 8: 130 | photometric_confidence 
= conf_regression(prob_volume, n=2) 131 | else: # D == 4 132 | photometric_confidence = prob_volume.max(1)[0] # [B,H,W] 133 | 134 | outputs = { 135 | "depth": depth, 136 | "prob_volume": prob_volume, 137 | "photometric_confidence": photometric_confidence.detach(), 138 | "depth_values": depth_values, 139 | "prob_volume_pre": prob_volume_pre, 140 | } 141 | 142 | return outputs 143 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/src/model/encoder/backbone/mvsformer_module/dino/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .attention import MemEffAttention 8 | from .block import NestedTensorBlock 9 | from .dino_head import DINOHead 10 | from .mlp import Mlp 11 | from .patch_embed import PatchEmbed 12 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 13 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
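Before the DINO layer modules below, a brief illustrative aside on the depth regression used by StageNet above: the softmax probability volume over D depth hypotheses is collapsed to a depth map by taking its expectation, and the per-pixel maximum probability serves as a simple confidence proxy (the real depth_regression and conf_regression helpers are imported via `from .module import *`). The sketch is not repository code; the shapes and the shared, broadcastable depth_values are simplifying assumptions.

import torch
import torch.nn.functional as F

B, D, H, W = 1, 8, 4, 4
prob_volume = F.softmax(torch.randn(B, D, H, W), dim=1)       # per-pixel distribution over D depth hypotheses
depth_values = torch.linspace(1.0, 8.0, D).view(1, D, 1, 1)   # hypothesised depths (simplified: shared per pixel)

depth = (prob_volume * depth_values).sum(dim=1)               # soft-argmax expectation, shape (B, H, W)
confidence = prob_volume.max(dim=1).values                    # crude confidence, as in the D == 4 branch above
print(depth.shape, confidence.shape)                          # torch.Size([1, 4, 4]) printed twice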
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor, nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | import torch.nn as nn 14 | from torch import Tensor 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/dino/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | import torch.nn.functional as F 10 | from torch import Tensor, nn 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
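As an aside before the layer-wise learning-rate decay helpers below, here is an illustrative sketch (not repository code) of the gating computed by SwiGLUFFN above: a single linear layer produces two halves, one half is passed through SiLU and multiplies the other, and the gated result is projected back to the output width. The toy sizes are assumptions.

import torch
import torch.nn.functional as F
from torch import nn

in_features, hidden_features = 8, 16
w12 = nn.Linear(in_features, 2 * hidden_features)   # joint projection to both halves
w3 = nn.Linear(hidden_features, in_features)        # output projection

x = torch.randn(2, 4, in_features)                  # (batch, tokens, features)
x1, x2 = w12(x).chunk(2, dim=-1)
out = w3(F.silu(x1) * x2)                           # SwiGLU: SiLU-gated linear unit
print(out.shape)                                    # torch.Size([2, 4, 8])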
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | 13 | def param_groups_lrd(model, vit_lr, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75): 14 | """ 15 | Parameter groups for layer-wise lr decay 16 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 17 | """ 18 | param_group_names = {} 19 | param_groups = {} 20 | 21 | num_layers = len(model.blocks) + 1 22 | 23 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 24 | 25 | for n, p in model.named_parameters(): 26 | if not p.requires_grad: 27 | continue 28 | 29 | # no decay: all 1D parameters and model specific ones 30 | if p.ndim == 1 or n in no_weight_decay_list: 31 | g_decay = "no_decay" 32 | this_decay = 0.0 33 | else: 34 | g_decay = "decay" 35 | this_decay = weight_decay 36 | 37 | layer_id = get_layer_id_for_vit(n, num_layers) 38 | group_name = "layer_%d_%s" % (layer_id, g_decay) 39 | 40 | if group_name not in param_group_names: 41 | if layer_id == -1: 42 | this_scale = 10.0 43 | else: 44 | this_scale = layer_scales[layer_id] 45 | 46 | param_group_names[group_name] = { 47 | "lr": vit_lr, 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | "vit_param": True, 52 | } 53 | param_groups[group_name] = { 54 | "lr": vit_lr, 55 | "lr_scale": this_scale, 56 | "weight_decay": this_decay, 57 | "params": [], 58 | "vit_param": True, 59 | } 60 | 61 | param_group_names[group_name]["params"].append(n) 62 | param_groups[group_name]["params"].append(p) 63 | 64 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 65 | 66 | return list(param_groups.values()) 67 | 68 | 69 | def get_layer_id_for_vit(name, num_layers): 70 | """ 71 | Assign a parameter with its layer id 72 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 73 | """ 74 | if name in ["cls_token", "pos_embed"]: 75 | return 0 76 | elif name.startswith("patch_embed"): 77 | return 0 78 | elif name.startswith("cross_blocks"): 79 | return -1 80 | elif name.startswith("blocks"): 81 | return int(name.split(".")[1]) + 1 82 | else: 83 | return num_layers 84 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/mvsformer_module/networks/casmvs_model.py: -------------------------------------------------------------------------------- 1 | from ..cost_volume import * 2 | from ..position_encoding import get_position_3d 3 | 4 | Align_Corners_Range = False 5 | 6 | 7 | class identity_with(object): 8 | def __init__(self, enabled=True): 9 | self._enabled = enabled 10 | 11 | def __enter__(self): 12 | pass 13 | 14 | def __exit__(self, *args): 15 | pass 16 | 17 | 18 | autocast = torch.cuda.amp.autocast if torch.__version__ >= "1.6.0" else identity_with 19 | 20 | 21 | class CasMVSNet(nn.Module): 22 | def __init__(self, args): 23 | super(CasMVSNet, self).__init__() 24 | self.args = args 25 | self.ndepths = args["ndepths"] 26 | self.depth_interals_ratio = args["depth_interals_ratio"] 27 | self.inverse_depth = args.get("inverse_depth", False) 28 | self.use_pe3d = args.get("use_pe3d", False) 29 | self.cost_reg_type = args.get("cost_reg_type", ["Normal", "Normal", "Normal", "Normal"]) 30 | 31 | self.encoder = FPNEncoder(feat_chs=args["feat_chs"]) 32 | self.decoder = 
FPNDecoder(feat_chs=args["feat_chs"]) 33 | 34 | self.fusions = nn.ModuleList([StageNet(args, self.ndepths[i], i) for i in range(len(self.ndepths))]) 35 | 36 | def forward(self, imgs, proj_matrices, depth_values, tmp=[5.0, 5.0, 5.0, 1.0]): 37 | B, V, H, W = imgs.shape[0], imgs.shape[1], imgs.shape[3], imgs.shape[4] 38 | depth_interval = depth_values[:, 1] - depth_values[:, 0] 39 | 40 | # feature encode 41 | if not self.training: 42 | feat1s, feat2s, feat3s, feat4s = [], [], [], [] 43 | for vi in range(V): 44 | img_v = imgs[:, vi] 45 | conv01, conv11, conv21, conv31 = self.encoder(img_v) 46 | feat1, feat2, feat3, feat4 = self.decoder.forward(conv01, conv11, conv21, conv31) 47 | feat1s.append(feat1) 48 | feat2s.append(feat2) 49 | feat3s.append(feat3) 50 | feat4s.append(feat4) 51 | 52 | features = { 53 | "stage1": torch.stack(feat1s, dim=1), 54 | "stage2": torch.stack(feat2s, dim=1), 55 | "stage3": torch.stack(feat3s, dim=1), 56 | "stage4": torch.stack(feat4s, dim=1), 57 | } 58 | else: 59 | imgs = imgs.reshape(B * V, 3, H, W) 60 | conv01, conv11, conv21, conv31 = self.encoder(imgs) 61 | feat1, feat2, feat3, feat4 = self.decoder.forward(conv01, conv11, conv21, conv31) 62 | 63 | features = { 64 | "stage1": feat1.reshape(B, V, feat1.shape[1], feat1.shape[2], feat1.shape[3]), 65 | "stage2": feat2.reshape(B, V, feat2.shape[1], feat2.shape[2], feat2.shape[3]), 66 | "stage3": feat3.reshape(B, V, feat3.shape[1], feat3.shape[2], feat3.shape[3]), 67 | "stage4": feat4.reshape(B, V, feat4.shape[1], feat4.shape[2], feat4.shape[3]), 68 | } 69 | 70 | outputs = {} 71 | outputs_stage = {} 72 | height_min, height_max = None, None 73 | width_min, width_max = None, None 74 | 75 | prob_maps = torch.zeros([B, H, W], dtype=torch.float32, device=imgs.device) 76 | valid_count = 0 77 | for stage_idx in range(len(self.ndepths)): 78 | proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)] # [B,V,2,4,4] 79 | features_stage = features["stage{}".format(stage_idx + 1)] 80 | B, V, C, H, W = features_stage.shape 81 | # init range 82 | if stage_idx == 0: 83 | if self.inverse_depth: 84 | depth_samples = init_inverse_range( 85 | depth_values, self.ndepths[stage_idx], imgs.device, imgs.dtype, H, W 86 | ) 87 | else: 88 | depth_samples = init_inverse_range( 89 | depth_values, self.ndepths[stage_idx], imgs.device, imgs.dtype, H, W 90 | ) 91 | else: 92 | if self.inverse_depth: 93 | depth_samples = schedule_inverse_range( 94 | outputs_stage["depth"].detach(), 95 | outputs_stage["depth_values"], 96 | self.ndepths[stage_idx], 97 | self.depth_interals_ratio[stage_idx], 98 | H, 99 | W, 100 | ) # B D H W 101 | else: 102 | depth_samples = schedule_range( 103 | outputs_stage["depth"].detach(), 104 | self.ndepths[stage_idx], 105 | self.depth_interals_ratio[stage_idx] * depth_interval, 106 | H, 107 | W, 108 | ) 109 | # get 3D PE 110 | if self.cost_reg_type[stage_idx] != "Normal" and self.use_pe3d: 111 | K = proj_matrices_stage[ 112 | :, 0, 1, :3, :3 113 | ] # [B,3,3] stage1 obtains the global min/max h and w; later stages are normalized w.r.t. stage1's global space 114 | position3d, height_min, height_max, width_min, width_max = get_position_3d( 115 | B, 116 | H, 117 | W, 118 | K, 119 | depth_samples, 120 | depth_min=depth_values.min(), 121 | depth_max=depth_values.max(), 122 | height_min=height_min, 123 | height_max=height_max, 124 | width_min=width_min, 125 | width_max=width_max, 126 | normalize=True, 127 | ) 128 | else: 129 | position3d = None 130 | 131 | outputs_stage = self.fusions[stage_idx].forward( 132 | features_stage, proj_matrices_stage, depth_samples, tmp=tmp[stage_idx],
position3d=position3d 133 | ) 134 | outputs["stage{}".format(stage_idx + 1)] = outputs_stage 135 | 136 | depth_conf = outputs_stage["photometric_confidence"] 137 | if depth_conf.shape[1] != prob_maps.shape[1] or depth_conf.shape[2] != prob_maps.shape[2]: 138 | depth_conf = F.interpolate( 139 | depth_conf.unsqueeze(1), [prob_maps.shape[1], prob_maps.shape[2]], mode="nearest" 140 | ).squeeze(1) 141 | prob_maps += depth_conf 142 | valid_count += 1 143 | outputs.update(outputs_stage) 144 | 145 | outputs["refined_depth"] = outputs_stage["depth"] 146 | print(f"depth min {outputs_stage['depth'].min()}, max {outputs_stage['depth'].max()}") 147 | 148 | if valid_count > 0: 149 | outputs["photometric_confidence"] = prob_maps / valid_count 150 | 151 | return outputs 152 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/src/model/encoder/backbone/unimatch/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class PositionEmbeddingSine(nn.Module): 11 | """ 12 | This is a more standard version of the position embedding, very similar to the one 13 | used by the Attention is all you need paper, generalized to work on images. 
14 | """ 15 | 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, x): 28 | # x = tensor_list.tensors # [B, C, H, W] 29 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 30 | b, c, h, w = x.size() 31 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 32 | y_embed = mask.cumsum(1, dtype=torch.float32) 33 | x_embed = mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos 48 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/reg_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__( 8 | self, 9 | input_dim=128, 10 | hidden_dim=256, 11 | out_dim=2, 12 | ): 13 | super(FlowHead, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 16 | self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | out = self.conv2(self.relu(self.conv1(x))) 21 | 22 | return out 23 | 24 | 25 | class SepConvGRU(nn.Module): 26 | def __init__( 27 | self, 28 | hidden_dim=128, 29 | input_dim=192 + 128, 30 | kernel_size=5, 31 | ): 32 | padding = (kernel_size - 1) // 2 33 | 34 | super(SepConvGRU, self).__init__() 35 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 36 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 37 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 38 | 39 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 40 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 41 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 42 | 43 | def forward(self, h, x): 44 | # horizontal 45 | hx = torch.cat([h, x], dim=1) 46 | z = torch.sigmoid(self.convz1(hx)) 47 | r = torch.sigmoid(self.convr1(hx)) 48 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 49 | h = (1 - z) * h + z * q 50 | 51 | # vertical 52 | hx = torch.cat([h, x], dim=1) 53 | z = torch.sigmoid(self.convz2(hx)) 54 | r = torch.sigmoid(self.convr2(hx)) 55 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 56 | h = 
(1 - z) * h + z * q 57 | 58 | return h 59 | 60 | 61 | class BasicMotionEncoder(nn.Module): 62 | def __init__( 63 | self, 64 | corr_channels=324, 65 | flow_channels=2, 66 | ): 67 | super(BasicMotionEncoder, self).__init__() 68 | 69 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) 70 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 71 | self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) 72 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 73 | self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) 74 | 75 | def forward(self, flow, corr): 76 | cor = F.relu(self.convc1(corr)) 77 | cor = F.relu(self.convc2(cor)) 78 | flo = F.relu(self.convf1(flow)) 79 | flo = F.relu(self.convf2(flo)) 80 | 81 | cor_flo = torch.cat([cor, flo], dim=1) 82 | out = F.relu(self.conv(cor_flo)) 83 | return torch.cat([out, flow], dim=1) 84 | 85 | 86 | class BasicUpdateBlock(nn.Module): 87 | def __init__( 88 | self, 89 | corr_channels=324, 90 | hidden_dim=128, 91 | context_dim=128, 92 | downsample_factor=8, 93 | flow_dim=2, 94 | bilinear_up=False, 95 | ): 96 | super(BasicUpdateBlock, self).__init__() 97 | 98 | self.encoder = BasicMotionEncoder( 99 | corr_channels=corr_channels, 100 | flow_channels=flow_dim, 101 | ) 102 | 103 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) 104 | 105 | self.flow_head = FlowHead( 106 | hidden_dim, 107 | hidden_dim=256, 108 | out_dim=flow_dim, 109 | ) 110 | 111 | if bilinear_up: 112 | self.mask = None 113 | else: 114 | self.mask = nn.Sequential( 115 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 116 | nn.ReLU(inplace=True), 117 | nn.Conv2d(256, downsample_factor**2 * 9, 1, padding=0), 118 | ) 119 | 120 | def forward(self, net, inp, corr, flow): 121 | motion_features = self.encoder(flow, corr) 122 | 123 | inp = torch.cat([inp, motion_features], dim=1) 124 | 125 | net = self.gru(net, inp) 126 | delta_flow = self.flow_head(net) 127 | 128 | if self.mask is not None: 129 | mask = self.mask(net) 130 | else: 131 | mask = None 132 | 133 | return net, mask, delta_flow 134 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 53 | if bias: 54 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 55 | else: 56 | self.bias = None 57 | 58 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 59 | if self.bias is not None: 60 | nn.init.constant_(self.bias, 0) 61 | 62 | def forward(self, inputs): 63 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 64 | assert len(inputs) == num_branch 65 | 66 | if self.training or self.test_branch_idx == -1: 67 | outputs = [ 68 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 69 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 70 | ] 71 | else: 72 | outputs = [ 73 | F.conv2d( 74 | inputs[0], 75 | self.weight, 76 | self.bias, 77 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 78 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 79 | self.dilation, 80 | self.groups, 81 | ) 82 | ] 83 | 84 | if self.norm is not None: 85 | outputs = [self.norm(x) for x in outputs] 86 | if self.activation is not None: 87 | outputs = [self.activation(x) for x in outputs] 88 | return outputs 89 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussian_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import einsum, rearrange 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ....geometry.projection import get_world_rays_and_cam_xy_sin 9 | from ....global_cfg import get_cfg 10 | from ....misc.sh_rotation import rotate_sh 11 | from .gaussians import build_covariance 12 | 13 | 14 | @dataclass 15 | class Gaussians: 16 | means: Float[Tensor, "*batch 3"] 17 | covariances: Float[Tensor, "*batch 3 3"] 18 | scales: Float[Tensor, "*batch 3"] 
19 | rotations: Float[Tensor, "*batch 4"] 20 | harmonics: Float[Tensor, "*batch 3 _"] 21 | opacities: Float[Tensor, " *batch"] 22 | # normal: Float[Tensor, "*batch 3"] | None 23 | # distance: Float[Tensor, "*batch 3"] | None 24 | 25 | 26 | @dataclass 27 | class GaussianAdapterCfg: 28 | gaussian_scale_min: float 29 | gaussian_scale_max: float 30 | sh_degree: int 31 | 32 | 33 | class GaussianAdapter(nn.Module): 34 | cfg: GaussianAdapterCfg 35 | 36 | def __init__(self, cfg: GaussianAdapterCfg): 37 | super().__init__() 38 | self.cfg = cfg 39 | 40 | # Create a mask for the spherical harmonics coefficients. This ensures that at 41 | # initialization, the coefficients are biased towards having a large DC 42 | # component and small view-dependent components. 43 | self.register_buffer( 44 | "sh_mask", 45 | torch.ones((self.d_sh,), dtype=torch.float32), 46 | persistent=False, 47 | ) 48 | for degree in range(1, self.cfg.sh_degree + 1): 49 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 50 | 51 | def forward( 52 | self, 53 | extrinsics: Float[Tensor, "*#batch 4 4"], 54 | intrinsics: Float[Tensor, "*#batch 3 3"], 55 | coordinates: Float[Tensor, "*#batch 2"], 56 | depths: Float[Tensor, "*#batch"], 57 | opacities: Float[Tensor, "*#batch"], 58 | raw_gaussians: Float[Tensor, "*#batch _"], 59 | image_shape: tuple[int, int], 60 | eps: float = 1e-8, 61 | stage_id: int = 0, 62 | ): 63 | device = extrinsics.device 64 | scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) 65 | 66 | # Map scale features to valid scale range. 67 | scale_min = self.cfg.gaussian_scale_min 68 | scale_max = self.cfg.gaussian_scale_max 69 | scales = scale_min + (scale_max - scale_min) * scales.sigmoid() 70 | h, w = image_shape 71 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device) 72 | multiplier = self.get_scale_multiplier(intrinsics, pixel_size) 73 | decouple_scales = scales * depths[..., None].detach() * multiplier[..., None] 74 | scales = scales * depths[..., None] * multiplier[..., None] # TODO: why? the far points need big scale? 75 | 76 | # Normalize the quaternion features to yield a valid quaternion. 77 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) 78 | 79 | # Apply sigmoid to get valid colors. 80 | sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) 81 | sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask 82 | 83 | # Create world-space covariance matrices. 84 | covariances = build_covariance(scales, rotations) 85 | c2w_rotations = extrinsics[..., :3, :3] 86 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) 87 | 88 | # Compute Gaussian means. 89 | # origins, directions = get_world_rays(coordinates, extrinsics, intrinsics) 90 | origins, directions, xy_sin = get_world_rays_and_cam_xy_sin(coordinates, extrinsics, intrinsics) 91 | # We fix it and use the correct distance 92 | # TODO: whether to use xy_sin 93 | if get_cfg()["use_xy_sin"]: 94 | means = origins + directions * depths[..., None] / (xy_sin + 1e-8) 95 | else: 96 | means = origins + directions * depths[..., None] 97 | # means = origins + directions * depths[..., None] 98 | return Gaussians( 99 | means=means, 100 | covariances=covariances, 101 | harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]), 102 | opacities=opacities, 103 | # NOTE: These aren't yet rotated into world space, but they're only used for 104 | # exporting Gaussians to ply files. This needs to be fixed.. 
105 | scales=scales, 106 | rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 107 | ), decouple_scales 108 | 109 | def get_scale_multiplier( 110 | self, 111 | intrinsics: Float[Tensor, "*#batch 3 3"], 112 | pixel_size: Float[Tensor, "*#batch 2"], 113 | multiplier: float = 0.1, 114 | ) -> Float[Tensor, " *batch"]: 115 | xy_multipliers = multiplier * einsum( 116 | intrinsics[..., :2, :2].inverse(), 117 | pixel_size, 118 | "... i j, j -> ... i", 119 | ) 120 | return xy_multipliers.sum(dim=-1) 121 | 122 | @property 123 | def d_sh(self) -> int: 124 | return (self.cfg.sh_degree + 1) ** 2 125 | 126 | @property 127 | def d_in(self) -> int: 128 | return 7 + 3 * self.d_sh 129 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 9 | def quaternion_to_matrix( 10 | quaternions: Float[Tensor, "*batch 4"], 11 | eps: float = 1e-8, 12 | ) -> Float[Tensor, "*batch 3 3"]: 13 | # Order changed to match scipy format! 14 | i, j, k, r = torch.unbind(quaternions, dim=-1) 15 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 16 | 17 | o = torch.stack( 18 | ( 19 | 1 - two_s * (j * j + k * k), 20 | two_s * (i * j - k * r), 21 | two_s * (i * k + j * r), 22 | two_s * (i * j + k * r), 23 | 1 - two_s * (i * i + k * k), 24 | two_s * (j * k - i * r), 25 | two_s * (i * k - j * r), 26 | two_s * (j * k + i * r), 27 | 1 - two_s * (i * i + j * j), 28 | ), 29 | -1, 30 | ) 31 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 32 | 33 | 34 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 35 | """ 36 | Returns torch.sqrt(torch.max(0, x)) 37 | but with a zero subgradient where x is 0. 38 | """ 39 | ret = torch.zeros_like(x) 40 | positive_mask = x > 0 41 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 42 | return ret 43 | 44 | 45 | def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: 46 | """ 47 | Convert a unit quaternion to a standard form: one in which the real 48 | part is non negative. 49 | 50 | Args: 51 | quaternions: Quaternions with real part first, 52 | as tensor of shape (..., 4). 53 | 54 | Returns: 55 | Standardized quaternions as tensor of shape (..., 4). 56 | """ 57 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 58 | 59 | 60 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 61 | """ 62 | Convert rotations given as rotation matrices to quaternions. 63 | 64 | Args: 65 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 66 | 67 | Returns: 68 | quaternions with real part first, as tensor of shape (..., 4). 
69 | """ 70 | 71 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 72 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 73 | 74 | batch_dim = matrix.shape[:-2] 75 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) 76 | 77 | q_abs = _sqrt_positive_part( 78 | torch.stack( 79 | [ 80 | 1.0 + m00 + m11 + m22, 81 | 1.0 + m00 - m11 - m22, 82 | 1.0 - m00 + m11 - m22, 83 | 1.0 - m00 - m11 + m22, 84 | ], 85 | dim=-1, 86 | ) 87 | ) 88 | 89 | # we produce the desired quaternion multiplied by each of r, i, j, k 90 | quat_by_rijk = torch.stack( 91 | [ 92 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 93 | # `int`. 94 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 95 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 96 | # `int`. 97 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 98 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 99 | # `int`. 100 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 101 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 102 | # `int`. 103 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 104 | ], 105 | dim=-2, 106 | ) 107 | 108 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 109 | # the candidate won't be picked. 110 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 111 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 112 | 113 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 114 | # forall i; we pick the best-conditioned one (with the largest denominator) 115 | out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) 116 | out = standardize_quaternion(out) 117 | out = torch.cat([out[..., 1:], out[..., 0:1]], dim=-1) 118 | return out 119 | 120 | 121 | def build_covariance( 122 | scale: Float[Tensor, "*#batch 3"], 123 | rotation_xyzw: Float[Tensor, "*#batch 4"], 124 | ) -> Float[Tensor, "*batch 3 3"]: 125 | scale = scale.diag_embed() 126 | rotation = quaternion_to_matrix(rotation_xyzw) 127 | return rotation @ scale @ rearrange(scale, "... i j -> ... j i") @ rearrange(rotation, "... i j -> ... j i") 128 | -------------------------------------------------------------------------------- /src/model/encoder/common/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor, nn 3 | 4 | from ....misc.discrete_probability_distribution import ( 5 | gather_discrete_topk, 6 | sample_discrete_distribution, 7 | ) 8 | 9 | 10 | class Sampler(nn.Module): 11 | def forward( 12 | self, 13 | probabilities: Float[Tensor, "*batch bucket"], 14 | num_samples: int, 15 | deterministic: bool, 16 | ) -> tuple[ 17 | Int64[Tensor, "*batch 1"], # index 18 | Float[Tensor, "*batch 1"], # probability density 19 | ]: 20 | return ( 21 | gather_discrete_topk(probabilities, num_samples) 22 | if deterministic 23 | else sample_discrete_distribution(probabilities, num_samples) 24 | ) 25 | 26 | def gather( 27 | self, 28 | index: Int64[Tensor, "*batch sample"], 29 | target: Shaped[Tensor, "..."], # *batch bucket *shape 30 | ) -> Shaped[Tensor, "..."]: # *batch sample *shape 31 | """Gather from the target according to the specified index. 
Handle the 32 | broadcasting needed for the gather to work. See the comments for the actual 33 | expected input/output shapes since jaxtyping doesn't support multiple variadic 34 | lengths in annotations. 35 | """ 36 | bucket_dim = index.ndim - 1 37 | while len(index.shape) < len(target.shape): 38 | index = index[..., None] 39 | broadcasted_index_shape = list(target.shape) 40 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 41 | index = index.broadcast_to(broadcasted_index_shape) 42 | return target.gather(dim=bucket_dim, index=index) 43 | -------------------------------------------------------------------------------- /src/model/encoder/costvolume/conversions.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float 2 | from torch import Tensor 3 | 4 | 5 | def relative_disparity_to_depth( 6 | relative_disparity: Float[Tensor, "*#batch"], 7 | near: Float[Tensor, "*#batch"], 8 | far: Float[Tensor, "*#batch"], 9 | eps: float = 1e-10, 10 | ) -> Float[Tensor, " *batch"]: 11 | """Convert relative disparity, where 0 is near and 1 is far, to depth.""" 12 | disp_near = 1 / (near + eps) 13 | disp_far = 1 / (far + eps) 14 | return 1 / ((1 - relative_disparity) * (disp_near - disp_far) + disp_far + eps) 15 | 16 | 17 | def depth_to_relative_disparity( 18 | depth: Float[Tensor, "*#batch"], 19 | near: Float[Tensor, "*#batch"], 20 | far: Float[Tensor, "*#batch"], 21 | eps: float = 1e-10, 22 | ) -> Float[Tensor, " *batch"]: 23 | """Convert depth to relative disparity, where 0 is near and 1 is far""" 24 | disp_near = 1 / (near + eps) 25 | disp_far = 1 / (far + eps) 26 | disp = 1 / (depth + eps) 27 | return 1 - (disp - disp_far) / (disp_near - disp_far + eps) 28 | -------------------------------------------------------------------------------- /src/model/encoder/costvolume/ldm_unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Open3DVLab/HiSplat/1d08e76998825d2b0fac48ef617243c1c1784cb1/src/model/encoder/costvolume/ldm_unet/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import nn 5 | 6 | from ...dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | deterministic: bool, 24 | ) -> Gaussians: 25 | pass 26 | 27 | def get_data_shim(self) -> DataShim: 28 | """The default shim doesn't modify the batch.""" 29 | return lambda x: x 30 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | T_cfg = TypeVar("T_cfg") 8 | T_encoder = TypeVar("T_encoder") 9 | 10 | 11 | class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]): 12 | cfg: T_cfg 13 | encoder: T_encoder 14 | 15 | def __init__(self, cfg: T_cfg, encoder: 
T_encoder) -> None: 16 | self.cfg = cfg 17 | self.encoder = encoder 18 | 19 | @abstractmethod 20 | def visualize( 21 | self, 22 | context: dict, 23 | global_step: int, 24 | ) -> dict[str, Float[Tensor, "3 _ _"]]: 25 | pass 26 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer_costvolume_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # This is in a separate file to avoid circular imports. 4 | 5 | 6 | @dataclass 7 | class EncoderVisualizerCostVolumeCfg: 8 | num_samples: int 9 | min_resolution: int 10 | export_ply: bool 11 | -------------------------------------------------------------------------------- /src/model/encodings/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum, rearrange, repeat 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """For the sake of simplicity, this encodes values in the range [0, 1].""" 10 | 11 | frequencies: Float[Tensor, "frequency phase"] 12 | phases: Float[Tensor, "frequency phase"] 13 | 14 | def __init__(self, num_octaves: int): 15 | super().__init__() 16 | octaves = torch.arange(num_octaves).float() 17 | 18 | # The lowest frequency has a period of 1. 19 | frequencies = 2 * torch.pi * 2**octaves 20 | frequencies = repeat(frequencies, "f -> f p", p=2) 21 | self.register_buffer("frequencies", frequencies, persistent=False) 22 | 23 | # Choose the phases to match sine and cosine. 24 | phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32) 25 | phases = repeat(phases, "p -> f p", f=num_octaves) 26 | self.register_buffer("phases", phases, persistent=False) 27 | 28 | def forward( 29 | self, 30 | samples: Float[Tensor, "*batch dim"], 31 | ) -> Float[Tensor, "*batch embedded_dim"]: 32 | samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p") 33 | return rearrange(torch.sin(samples + self.phases), "... d f p -> ... (d f p)") 34 | 35 | def d_out(self, dimensionality: int): 36 | return self.frequencies.numel() * dimensionality 37 | -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | # Shift the scene so that the median Gaussian is at the origin. 
36 | means = means - means.median(dim=0).values 37 | 38 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 39 | scale_factor = means.abs().quantile(0.95, dim=0).max() 40 | means = means / scale_factor 41 | scales = scales / scale_factor 42 | 43 | # Define a rotation that makes +Z be the world up vector. 44 | rotation = [ 45 | [0, 0, 1], 46 | [-1, 0, 0], 47 | [0, -1, 0], 48 | ] 49 | rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) 50 | 51 | # The Polycam viewer seems to start at a 45 degree angle. Since we want to be 52 | # looking directly at the object, we compose a 45 degree rotation onto the above 53 | # rotation. 54 | adjustment = torch.tensor( 55 | R.from_rotvec([0, 0, -45], True).as_matrix(), 56 | dtype=torch.float32, 57 | device=means.device, 58 | ) 59 | rotation = adjustment @ rotation 60 | 61 | # We also want to see the scene in camera space (as the default view). We therefore 62 | # compose the w2c rotation onto the above rotation. 63 | rotation = rotation @ extrinsics[:3, :3].inverse() 64 | 65 | # Apply the rotation to the means (Gaussian positions). 66 | means = einsum(rotation, means, "i j, ... j -> ... i") 67 | 68 | # Apply the rotation to the Gaussian rotations. 69 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 70 | rotations = rotation.detach().cpu().numpy() @ rotations 71 | rotations = R.from_matrix(rotations).as_quat() 72 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 73 | rotations = np.stack((w, x, y, z), axis=-1) 74 | 75 | # Since our axes are swizzled for the spherical harmonics, we only export the DC 76 | # band. 77 | harmonics_view_invariant = harmonics[..., 0] 78 | 79 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] 80 | elements = np.empty(means.shape[0], dtype=dtype_full) 81 | attributes = ( 82 | means.detach().cpu().numpy(), 83 | torch.zeros_like(means).detach().cpu().numpy(), 84 | harmonics_view_invariant.detach().cpu().contiguous().numpy(), 85 | opacities[..., None].detach().cpu().numpy(), 86 | scales.log().detach().cpu().numpy(), 87 | rotations, 88 | ) 89 | attributes = np.concatenate(attributes, axis=1) 90 | elements[:] = list(map(tuple, attributes)) 91 | path.parent.mkdir(exist_ok=True, parents=True) 92 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 93 | -------------------------------------------------------------------------------- /src/model/transformer/attention.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | import torch 26 | from einops import rearrange 27 | from torch import nn 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None): 32 | super().__init__() 33 | inner_dim = dim_head * heads 34 | project_out = not (heads == 1 and dim_head == dim) 35 | 36 | self.heads = heads 37 | self.scale = dim_head**-0.5 38 | 39 | self.attend = nn.Softmax(dim=-1) 40 | if selfatt: 41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 42 | else: 43 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 44 | self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) 45 | 46 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity() 47 | 48 | def forward(self, x, z=None): 49 | if z is None: 50 | qkv = self.to_qkv(x).chunk(3, dim=-1) 51 | else: 52 | q = self.to_q(x) 53 | k, v = self.to_kv(z).chunk(2, dim=-1) 54 | qkv = (q, k, v) 55 | 56 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | 62 | out = torch.matmul(attn, v) 63 | out = rearrange(out, "b h n d -> b n (h d)") 64 | return self.to_out(out) 65 | -------------------------------------------------------------------------------- /src/model/transformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout=0.0): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Linear(hidden_dim, dim), 36 | nn.Dropout(dropout), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | -------------------------------------------------------------------------------- /src/model/transformer/pre_norm.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class PreNorm(nn.Module): 29 | def __init__(self, dim, fn): 30 | super().__init__() 31 | self.norm = nn.LayerNorm(dim) 32 | self.fn = fn 33 | 34 | def forward(self, x, **kwargs): 35 | return self.fn(self.norm(x), **kwargs) 36 | -------------------------------------------------------------------------------- /src/model/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | from .attention import Attention 28 | from .feed_forward import FeedForward 29 | from .pre_norm import PreNorm 30 | 31 | 32 | class Transformer(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | depth, 37 | heads, 38 | dim_head, 39 | mlp_dim, 40 | dropout=0.0, 41 | selfatt=True, 42 | kv_dim=None, 43 | feed_forward_layer=FeedForward, 44 | ): 45 | super().__init__() 46 | self.layers = nn.ModuleList([]) 47 | for _ in range(depth): 48 | self.layers.append( 49 | nn.ModuleList( 50 | [ 51 | PreNorm( 52 | dim, 53 | Attention( 54 | dim, 55 | heads=heads, 56 | dim_head=dim_head, 57 | dropout=dropout, 58 | selfatt=selfatt, 59 | kv_dim=kv_dim, 60 | ), 61 | ), 62 | PreNorm(dim, feed_forward_layer(dim, mlp_dim, dropout=dropout)), 63 | ] 64 | ) 65 | ) 66 | 67 | def forward(self, x, z=None, **kwargs): 68 | for attn, ff in self.layers: 69 | x = attn(x, z=z) + x 70 | x = ff(x, **kwargs) + x 71 | return x 72 | -------------------------------------------------------------------------------- /src/model/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] | Float[Tensor, "batch view gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] | Float[Tensor, "batch view gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] | Float[Tensor, "batch view gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] | Float[Tensor, "batch view gaussian"] 13 | -------------------------------------------------------------------------------- /src/utils/my_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.sum = 0.0 13 | self.val = 0.0 14 | self.avg = 0.0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def format_duration(seconds): 24 | duration = datetime.timedelta(seconds=seconds) 25 | days = duration.days 26 | hours, remainder = divmod(duration.seconds, 3600) 27 | minutes, seconds = divmod(remainder, 60) 28 | 29 | parts = [] 30 | if days > 0: 31 | parts.append(f"{days}d") 32 | if hours > 0: 33 | parts.append(f"{hours}h") 34 | if minutes > 0: 35 | parts.append(f"{minutes}m") 36 | if seconds > 0: 37 | parts.append(f"{seconds}s") 38 | 39 | if len(parts) == 0: 40 | return "0s" 41 | return ":".join(parts) 42 | 43 | 44 | def filter_low_opacities(gaussians, threshold=0.01): 45 | gaussians.opacities[gaussians.opacities < threshold] = 0 46 | 47 | 48 | def plot_and_save_hist(x, path): 49 | import matplotlib.pyplot as plt 50 | 51 | plt.cla() 52 | x = x.detach().cpu().numpy() 53 | plt.hist(x[0], bins=20, range=(0, 1)) 54 | plt.savefig(path) 55 | 
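A minimal usage sketch for the training helpers above (not part of the repository): AverageMeter keeps a running average of a scalar such as a loss, and format_duration renders elapsed seconds compactly. The loss values are placeholders, and the import assumes the repository root is on PYTHONPATH.

import time

from src.utils.my_utils import AverageMeter, format_duration

loss_meter = AverageMeter()
start = time.time()
for step, loss in enumerate([0.90, 0.72, 0.65, 0.61]):  # placeholder loss values
    loss_meter.update(loss)  # n=1: unweighted running average
    print(f"step {step}: loss={loss_meter.val:.3f} avg={loss_meter.avg:.3f}")
print("elapsed:", format_duration(time.time() - start))  # parts are joined with ':', e.g. "2m:13s"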
-------------------------------------------------------------------------------- /src/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text, font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 
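# (Despite its name, deg_elevation below holds the elevation converted to radians; R.from_rotvec expects radians when degrees is not specified.)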
31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... () i j") @ tf 33 | -------------------------------------------------------------------------------- /src/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height with"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB. 
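# (cspace_convert maps the interpolated CIELab values back to sRGB in [0, 1]; the clip below guards against slightly out-of-gamut results from interpolating in Lab space.)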
47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /src/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /src/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 | xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /src/visualization/drawing/lines.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 
height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 
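# (inside_line is [line, sample]; multiplying by the line index and taking argmax over lines picks the highest-indexed line covering each sample, whose color is then gathered. Samples covered by no line get alpha 0 via the any() below.)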
73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /src/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 
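# (Same selection trick as in draw_lines, but with mask shaped [sample, point], so the argmax over dim=1 picks the highest-indexed point covering each sample.)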
49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /src/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
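# (The sample grid is flattened to a list of points and evaluated through color_function in chunks of batch_size to bound peak memory, then reshaped back into a [batch, 4, height, width] RGBA image.)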
86 | b, h, w, _ = xy.shape 87 | color = [color_function(batch) for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size)] 88 | color = torch.cat(color, dim=0) 89 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 90 | 91 | # If any MSAA passes remain, subdivide. 92 | if remaining_passes > 0: 93 | mask = detect_msaa_pixels(color) 94 | batch_index, row_index, col_index = torch.where(mask) 95 | xy = xy[batch_index, row_index, col_index] 96 | 97 | offsets = generate_sample_grid((subdivision, subdivision), device) 98 | offsets = (offsets / subdivision - 0.5) * scale 99 | 100 | color_fine = run_msaa_pass( 101 | xy[:, None, None] + offsets, 102 | color_function, 103 | scale / subdivision, 104 | subdivision, 105 | remaining_passes - 1, 106 | device, 107 | batch_size=batch_size, 108 | ) 109 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 110 | 111 | return color 112 | 113 | 114 | @torch.no_grad() 115 | def render( 116 | shape: tuple[int, int], 117 | color_function: ColorFunction, 118 | device: torch.device, 119 | subdivision: int = 8, 120 | num_passes: int = 2, 121 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 122 | xy = generate_sample_grid(shape, device) 123 | return run_msaa_pass( 124 | xy[None], 125 | color_function, 126 | 1.0, 127 | subdivision, 128 | num_passes, 129 | device, 130 | )[0] 131 | 132 | 133 | def render_over_image( 134 | image: Float[Tensor, "3 height width"], 135 | color_function: ColorFunction, 136 | device: torch.device, 137 | subdivision: int = 8, 138 | num_passes: int = 1, 139 | ) -> Float[Tensor, "3 height width"]: 140 | _, h, w = image.shape 141 | overlay = render( 142 | (h, w), 143 | color_function, 144 | device, 145 | subdivision=subdivision, 146 | num_passes=num_passes, 147 | ) 148 | color, alpha = overlay.split((3, 1), dim=0) 149 | return image * (1 - alpha) + color * alpha 150 | -------------------------------------------------------------------------------- /src/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /src/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device) for x in images] 16 | for image, result in zip(images, results): 17 | slices = [slice(0, x) for x in image.shape] 18 | result[slices] = image[slices] 19 | return results 20 | 21 | 22 | def render_projections( 23 | gaussians: Gaussians, 24 | resolution: int, 25 | margin: float = 0.1, 26 | draw_label: bool = True, 27 | extra_label: str = "", 28 | ) -> Float[Tensor, "batch 3 3 height width"]: 29 | device = gaussians.means.device 30 | b, _, _ = gaussians.means.shape 31 | 32 | # Compute the minima and maxima of the scene. 33 | minima = gaussians.means.min(dim=1).values 34 | maxima = gaussians.means.max(dim=1).values 35 | scene_minima, scene_maxima = compute_equal_aabb_with_margin(minima, maxima, margin=margin) 36 | 37 | projections = [] 38 | for look_axis in range(3): 39 | right_axis = (look_axis + 1) % 3 40 | down_axis = (look_axis + 2) % 3 41 | 42 | # Define the extrinsics for rendering. 43 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 44 | extrinsics[:, right_axis, 0] = 1 45 | extrinsics[:, down_axis, 1] = 1 46 | extrinsics[:, look_axis, 2] = 1 47 | extrinsics[:, right_axis, 3] = 0.5 * (scene_minima[:, right_axis] + scene_maxima[:, right_axis]) 48 | extrinsics[:, down_axis, 3] = 0.5 * (scene_minima[:, down_axis] + scene_maxima[:, down_axis]) 49 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 50 | extrinsics[:, 3, 3] = 1 51 | 52 | # Define the intrinsics for rendering. 
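# (The orthographic frustum is sized from the padded scene bounding box: width and height span the two in-plane axes, near is zero at the camera, which sits on the box's minimum face along the look axis, and far spans to the opposite face.)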
53 | extents = scene_maxima - scene_minima 54 | far = extents[:, look_axis] 55 | near = torch.zeros_like(far) 56 | width = extents[:, right_axis] 57 | height = extents[:, down_axis] 58 | 59 | projection = render_cuda_orthographic( 60 | extrinsics, 61 | width, 62 | height, 63 | near, 64 | far, 65 | (resolution, resolution), 66 | torch.zeros((b, 3), dtype=torch.float32, device=device), 67 | gaussians.means, 68 | gaussians.covariances, 69 | gaussians.harmonics, 70 | gaussians.opacities, 71 | fov_degrees=10.0, 72 | ) 73 | if draw_label: 74 | right_axis_name = "XYZ"[right_axis] 75 | down_axis_name = "XYZ"[down_axis] 76 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 77 | projection = torch.stack([add_label(x, label) for x in projection]) 78 | 79 | projections.append(projection) 80 | 81 | return torch.stack(pad(projections), dim=1) 82 | 83 | 84 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 85 | # Define colors for context and target views. 86 | num_context_views = batch["context"]["extrinsics"].shape[1] 87 | num_target_views = batch["target"]["extrinsics"].shape[1] 88 | color = torch.ones( 89 | (num_target_views + num_context_views, 3), 90 | dtype=torch.float32, 91 | device=batch["target"]["extrinsics"].device, 92 | ) 93 | color[num_context_views:, 1:] = 0 94 | 95 | return draw_cameras( 96 | resolution, 97 | torch.cat((batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0])), 98 | torch.cat((batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0])), 99 | color, 100 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 101 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 102 | ) 103 | -------------------------------------------------------------------------------- /src/visualization/vis_depth.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib as mpl 3 | import matplotlib.cm as cm 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | # https://github.com/autonomousvision/unimatch/blob/master/utils/visualization.py 9 | 10 | 11 | def vis_disparity(disp): 12 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 13 | disp_vis = disp_vis.astype("uint8") 14 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) 15 | 16 | return disp_vis 17 | 18 | 19 | def viz_depth_tensor(disp, return_numpy=False, colormap="plasma"): 20 | # visualize inverse depth 21 | assert isinstance(disp, torch.Tensor) 22 | 23 | disp = disp.numpy() 24 | vmax = np.percentile(disp, 95) 25 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) 26 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 27 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3] 28 | 29 | if return_numpy: 30 | return colormapped_im 31 | 32 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] 33 | 34 | return viz 35 | --------------------------------------------------------------------------------
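A minimal usage sketch for the depth visualizers above (not part of the repository): viz_depth_tensor color-maps an inverse-depth tensor with Matplotlib, while vis_disparity uses OpenCV's inferno color map. The depth values are random placeholders standing in for a real prediction, and the import assumes the repository root is on PYTHONPATH.

import cv2
import torch

from src.visualization.vis_depth import vis_disparity, viz_depth_tensor

depth = torch.rand(256, 256) * 10 + 0.5  # placeholder depth map in [0.5, 10.5]
inverse_depth = 1.0 / depth              # both helpers visualize inverse depth / disparity

colored = viz_depth_tensor(inverse_depth, return_numpy=True)  # [H, W, 3] uint8, RGB
cv2.imwrite("depth_vis.png", cv2.cvtColor(colored, cv2.COLOR_RGB2BGR))

disparity_vis = vis_disparity(inverse_depth.numpy())  # [H, W, 3] uint8, BGR
cv2.imwrite("disparity_vis.png", disparity_vis)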