├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── evaluation_index_acid.json
│   ├── evaluation_index_acid_video.json
│   ├── evaluation_index_dtu_nctx2.json
│   ├── evaluation_index_dtu_nctx3.json
│   ├── evaluation_index_re10k.json
│   └── evaluation_index_re10k_video.json
├── config
│   ├── compute_metrics.yaml
│   ├── dataset
│   │   ├── re10k.yaml
│   │   ├── view_sampler
│   │   │   ├── all.yaml
│   │   │   ├── arbitrary.yaml
│   │   │   ├── bounded.yaml
│   │   │   └── evaluation.yaml
│   │   └── view_sampler_dataset_specific_config
│   │       ├── bounded_re10k.yaml
│   │       └── evaluation_re10k.yaml
│   ├── evaluation
│   │   ├── ablation.yaml
│   │   ├── acid.yaml
│   │   ├── acid_video.yaml
│   │   ├── re10k.yaml
│   │   └── re10k_video.yaml
│   ├── experiment
│   │   ├── acid.yaml
│   │   ├── dtu.yaml
│   │   └── re10k.yaml
│   ├── generate_evaluation_index.yaml
│   ├── loss
│   │   ├── depth.yaml
│   │   ├── lpips.yaml
│   │   └── mse.yaml
│   ├── main.yaml
│   └── model
│       ├── decoder
│       │   └── splatting_cuda.yaml
│       └── encoder
│           └── costvolume.yaml
├── more_commands.sh
├── requirements.txt
├── requirements_w_version.txt
└── src
    ├── config.py
    ├── dataset
    │   ├── __init__.py
    │   ├── data_module.py
    │   ├── dataset.py
    │   ├── dataset_re10k.py
    │   ├── shims
    │   │   ├── augmentation_shim.py
    │   │   ├── bounds_shim.py
    │   │   ├── crop_shim.py
    │   │   └── patch_shim.py
    │   ├── types.py
    │   ├── validation_wrapper.py
    │   └── view_sampler
    │       ├── __init__.py
    │       ├── view_sampler.py
    │       ├── view_sampler_all.py
    │       ├── view_sampler_arbitrary.py
    │       ├── view_sampler_bounded.py
    │       └── view_sampler_evaluation.py
    ├── evaluation
    │   ├── evaluation_cfg.py
    │   ├── evaluation_index_generator.py
    │   ├── metric_computer.py
    │   └── metrics.py
    ├── geometry
    │   ├── epipolar_lines.py
    │   └── projection.py
    ├── global_cfg.py
    ├── loss
    │   ├── __init__.py
    │   ├── loss.py
    │   ├── loss_depth.py
    │   ├── loss_lpips.py
    │   └── loss_mse.py
    ├── main.py
    ├── misc
    │   ├── LocalLogger.py
    │   ├── benchmarker.py
    │   ├── collation.py
    │   ├── discrete_probability_distribution.py
    │   ├── heterogeneous_pairings.py
    │   ├── image_io.py
    │   ├── nn_module_tools.py
    │   ├── sh_rotation.py
    │   ├── step_tracker.py
    │   └── wandb_tools.py
    ├── model
    │   ├── decoder
    │   │   ├── __init__.py
    │   │   ├── cuda_splatting.py
    │   │   ├── decoder.py
    │   │   └── decoder_splatting_cuda.py
    │   ├── encoder
    │   │   ├── __init__.py
    │   │   ├── backbone
    │   │   │   ├── __init__.py
    │   │   │   ├── backbone_multiview.py
    │   │   │   ├── multiview_transformer.py
    │   │   │   └── unimatch
    │   │   │       ├── __init__.py
    │   │   │       ├── attention.py
    │   │   │       ├── backbone.py
    │   │   │       ├── geometry.py
    │   │   │       ├── matching.py
    │   │   │       ├── position.py
    │   │   │       ├── reg_refine.py
    │   │   │       ├── transformer.py
    │   │   │       ├── trident_conv.py
    │   │   │       ├── unimatch.py
    │   │   │       └── utils.py
    │   │   ├── common
    │   │   │   ├── gaussian_adapter.py
    │   │   │   ├── gaussians.py
    │   │   │   └── sampler.py
    │   │   ├── costvolume
    │   │   │   ├── conversions.py
    │   │   │   ├── depth_predictor_multiview.py
    │   │   │   └── ldm_unet
    │   │   │       ├── __init__.py
    │   │   │       ├── attention.py
    │   │   │       ├── unet.py
    │   │   │       └── util.py
    │   │   ├── encoder.py
    │   │   ├── encoder_costvolume.py
    │   │   ├── epipolar
    │   │   │   └── epipolar_sampler.py
    │   │   └── visualization
    │   │       ├── encoder_visualizer.py
    │   │       ├── encoder_visualizer_costvolume.py
    │   │       └── encoder_visualizer_costvolume_cfg.py
    │   ├── encodings
    │   │   └── positional_encoding.py
    │   ├── model_wrapper.py
    │   ├── ply_export.py
    │   └── types.py
    ├── paper
    │   ├── generate_point_cloud_figure_mvsplat.py
    │   └── generate_point_cloud_figure_mvsplat_teaser.py
    ├── scripts
    │   ├── compute_metrics.py
    │   ├── convert_dtu.py
    │   ├── dump_launch_configs.py
    │   ├── generate_dtu_evaluation_index.py
    │   ├── generate_evaluation_index.py
    │   ├── generate_video_evaluation_index.py
    │   ├── test_splatter.py
    │   └── visualize_epipolar_lines.py
    └── visualization
        ├── annotation.py
        ├── camera_trajectory
        │   ├── interpolation.py
        │   ├── spin.py
        │   └── wobble.py
        ├── color_map.py
        ├── colors.py
        ├── drawing
        │   ├── cameras.py
        │   ├── coordinate_conversion.py
        │   ├── lines.py
        │   ├── points.py
        │   ├── rendering.py
        │   └── types.py
        ├── layout.py
        ├── validation_in_3d.py
        └── vis_depth.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | .vscode 162 | 163 | /datasets 164 | /dataset_cache 165 | 166 | # Outputs 167 | /outputs 168 | /lightning_logs 169 | /checkpoints 170 | 171 | .bashrc 172 | /launcher_venv 173 | /slurm_logs 174 | *.torch 175 | *.ckpt 176 | table.tex 177 | /baselines 178 | /test/* 179 | point_clouds* 180 | assets/*_3views.json 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yuedong Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
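The assets/evaluation_index_*.json files that follow pin the benchmark splits: each entry maps a scene key to the exact context view indices and target view indices used at test time, and the evaluation view sampler is pointed at one of these files through dataset.view_sampler.index_path (see config/dataset/view_sampler/evaluation.yaml and the overrides in more_commands.sh). As a minimal sketch only, with a made-up helper name that is not part of the codebase, such an index could be inspected like this:

```python
# Hypothetical standalone helper (not part of the repository): load one of the
# evaluation index files and print the fixed context/target split per scene.
import json
from pathlib import Path


def summarize_evaluation_index(path: Path) -> None:
    index = json.loads(path.read_text())
    for scene, entry in index.items():
        if entry is None:  # defensively skip scenes without an assigned split
            continue
        print(f"{scene}: context={entry['context']} target={entry['target']}")
    print(f"{len(index)} scenes in {path.name}")


if __name__ == "__main__":
    summarize_evaluation_index(Path("assets/evaluation_index_dtu_nctx2.json"))
```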
22 | -------------------------------------------------------------------------------- /assets/evaluation_index_dtu_nctx2.json: -------------------------------------------------------------------------------- 1 | {"scan1_train_00": {"context": [33, 31], "target": [32]}, "scan1_train_01": {"context": [33, 31], "target": [24]}, "scan1_train_02": {"context": [33, 31], "target": [23]}, "scan1_train_03": {"context": [33, 31], "target": [44]}, "scan8_train_00": {"context": [33, 31], "target": [32]}, "scan8_train_01": {"context": [33, 31], "target": [24]}, "scan8_train_02": {"context": [33, 31], "target": [23]}, "scan8_train_03": {"context": [33, 31], "target": [44]}, "scan21_train_00": {"context": [33, 31], "target": [32]}, "scan21_train_01": {"context": [33, 31], "target": [24]}, "scan21_train_02": {"context": [33, 31], "target": [23]}, "scan21_train_03": {"context": [33, 31], "target": [44]}, "scan30_train_00": {"context": [33, 31], "target": [32]}, "scan30_train_01": {"context": [33, 31], "target": [24]}, "scan30_train_02": {"context": [33, 31], "target": [23]}, "scan30_train_03": {"context": [33, 31], "target": [44]}, "scan31_train_00": {"context": [33, 31], "target": [32]}, "scan31_train_01": {"context": [33, 31], "target": [24]}, "scan31_train_02": {"context": [33, 31], "target": [23]}, "scan31_train_03": {"context": [33, 31], "target": [44]}, "scan34_train_00": {"context": [33, 31], "target": [32]}, "scan34_train_01": {"context": [33, 31], "target": [24]}, "scan34_train_02": {"context": [33, 31], "target": [23]}, "scan34_train_03": {"context": [33, 31], "target": [44]}, "scan38_train_00": {"context": [33, 31], "target": [32]}, "scan38_train_01": {"context": [33, 31], "target": [24]}, "scan38_train_02": {"context": [33, 31], "target": [23]}, "scan38_train_03": {"context": [33, 31], "target": [44]}, "scan40_train_00": {"context": [33, 31], "target": [32]}, "scan40_train_01": {"context": [33, 31], "target": [24]}, "scan40_train_02": {"context": [33, 31], "target": [23]}, "scan40_train_03": {"context": [33, 31], "target": [44]}, "scan41_train_00": {"context": [33, 31], "target": [32]}, "scan41_train_01": {"context": [33, 31], "target": [24]}, "scan41_train_02": {"context": [33, 31], "target": [23]}, "scan41_train_03": {"context": [33, 31], "target": [44]}, "scan45_train_00": {"context": [33, 31], "target": [32]}, "scan45_train_01": {"context": [33, 31], "target": [24]}, "scan45_train_02": {"context": [33, 31], "target": [23]}, "scan45_train_03": {"context": [33, 31], "target": [44]}, "scan55_train_00": {"context": [33, 31], "target": [32]}, "scan55_train_01": {"context": [33, 31], "target": [24]}, "scan55_train_02": {"context": [33, 31], "target": [23]}, "scan55_train_03": {"context": [33, 31], "target": [44]}, "scan63_train_00": {"context": [33, 31], "target": [32]}, "scan63_train_01": {"context": [33, 31], "target": [24]}, "scan63_train_02": {"context": [33, 31], "target": [23]}, "scan63_train_03": {"context": [33, 31], "target": [44]}, "scan82_train_00": {"context": [33, 31], "target": [32]}, "scan82_train_01": {"context": [33, 31], "target": [24]}, "scan82_train_02": {"context": [33, 31], "target": [23]}, "scan82_train_03": {"context": [33, 31], "target": [44]}, "scan103_train_00": {"context": [33, 31], "target": [32]}, "scan103_train_01": {"context": [33, 31], "target": [24]}, "scan103_train_02": {"context": [33, 31], "target": [23]}, "scan103_train_03": {"context": [33, 31], "target": [44]}, "scan110_train_00": {"context": [33, 31], "target": [32]}, "scan110_train_01": {"context": [33, 
31], "target": [24]}, "scan110_train_02": {"context": [33, 31], "target": [23]}, "scan110_train_03": {"context": [33, 31], "target": [44]}, "scan114_train_00": {"context": [33, 31], "target": [32]}, "scan114_train_01": {"context": [33, 31], "target": [24]}, "scan114_train_02": {"context": [33, 31], "target": [23]}, "scan114_train_03": {"context": [33, 31], "target": [44]}} -------------------------------------------------------------------------------- /assets/evaluation_index_dtu_nctx3.json: -------------------------------------------------------------------------------- 1 | {"scan1_train_00": {"context": [33, 31, 43], "target": [32]}, "scan1_train_01": {"context": [33, 31, 43], "target": [24]}, "scan1_train_02": {"context": [33, 31, 43], "target": [23]}, "scan1_train_03": {"context": [33, 31, 43], "target": [44]}, "scan8_train_00": {"context": [33, 31, 43], "target": [32]}, "scan8_train_01": {"context": [33, 31, 43], "target": [24]}, "scan8_train_02": {"context": [33, 31, 43], "target": [23]}, "scan8_train_03": {"context": [33, 31, 43], "target": [44]}, "scan21_train_00": {"context": [33, 31, 43], "target": [32]}, "scan21_train_01": {"context": [33, 31, 43], "target": [24]}, "scan21_train_02": {"context": [33, 31, 43], "target": [23]}, "scan21_train_03": {"context": [33, 31, 43], "target": [44]}, "scan30_train_00": {"context": [33, 31, 43], "target": [32]}, "scan30_train_01": {"context": [33, 31, 43], "target": [24]}, "scan30_train_02": {"context": [33, 31, 43], "target": [23]}, "scan30_train_03": {"context": [33, 31, 43], "target": [44]}, "scan31_train_00": {"context": [33, 31, 43], "target": [32]}, "scan31_train_01": {"context": [33, 31, 43], "target": [24]}, "scan31_train_02": {"context": [33, 31, 43], "target": [23]}, "scan31_train_03": {"context": [33, 31, 43], "target": [44]}, "scan34_train_00": {"context": [33, 31, 43], "target": [32]}, "scan34_train_01": {"context": [33, 31, 43], "target": [24]}, "scan34_train_02": {"context": [33, 31, 43], "target": [23]}, "scan34_train_03": {"context": [33, 31, 43], "target": [44]}, "scan38_train_00": {"context": [33, 31, 43], "target": [32]}, "scan38_train_01": {"context": [33, 31, 43], "target": [24]}, "scan38_train_02": {"context": [33, 31, 43], "target": [23]}, "scan38_train_03": {"context": [33, 31, 43], "target": [44]}, "scan40_train_00": {"context": [33, 31, 43], "target": [32]}, "scan40_train_01": {"context": [33, 31, 43], "target": [24]}, "scan40_train_02": {"context": [33, 31, 43], "target": [23]}, "scan40_train_03": {"context": [33, 31, 43], "target": [44]}, "scan41_train_00": {"context": [33, 31, 43], "target": [32]}, "scan41_train_01": {"context": [33, 31, 43], "target": [24]}, "scan41_train_02": {"context": [33, 31, 43], "target": [23]}, "scan41_train_03": {"context": [33, 31, 43], "target": [44]}, "scan45_train_00": {"context": [33, 31, 43], "target": [32]}, "scan45_train_01": {"context": [33, 31, 43], "target": [24]}, "scan45_train_02": {"context": [33, 31, 43], "target": [23]}, "scan45_train_03": {"context": [33, 31, 43], "target": [44]}, "scan55_train_00": {"context": [33, 31, 43], "target": [32]}, "scan55_train_01": {"context": [33, 31, 43], "target": [24]}, "scan55_train_02": {"context": [33, 31, 43], "target": [23]}, "scan55_train_03": {"context": [33, 31, 43], "target": [44]}, "scan63_train_00": {"context": [33, 31, 43], "target": [32]}, "scan63_train_01": {"context": [33, 31, 43], "target": [24]}, "scan63_train_02": {"context": [33, 31, 43], "target": [23]}, "scan63_train_03": {"context": [33, 31, 43], "target": [44]}, 
"scan82_train_00": {"context": [33, 31, 43], "target": [32]}, "scan82_train_01": {"context": [33, 31, 43], "target": [24]}, "scan82_train_02": {"context": [33, 31, 43], "target": [23]}, "scan82_train_03": {"context": [33, 31, 43], "target": [44]}, "scan103_train_00": {"context": [33, 31, 43], "target": [32]}, "scan103_train_01": {"context": [33, 31, 43], "target": [24]}, "scan103_train_02": {"context": [33, 31, 43], "target": [23]}, "scan103_train_03": {"context": [33, 31, 43], "target": [44]}, "scan110_train_00": {"context": [33, 31, 43], "target": [32]}, "scan110_train_01": {"context": [33, 31, 43], "target": [24]}, "scan110_train_02": {"context": [33, 31, 43], "target": [23]}, "scan110_train_03": {"context": [33, 31, 43], "target": [44]}, "scan114_train_00": {"context": [33, 31, 43], "target": [32]}, "scan114_train_01": {"context": [33, 31, 43], "target": [24]}, "scan114_train_02": {"context": [33, 31, 43], "target": [23]}, "scan114_train_03": {"context": [33, 31, 43], "target": [44]}} -------------------------------------------------------------------------------- /config/compute_metrics.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - model/encoder: epipolar 4 | - loss: [] 5 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 6 | - override dataset/view_sampler: evaluation 7 | 8 | data_loader: 9 | train: 10 | num_workers: 0 11 | persistent_workers: true 12 | batch_size: 1 13 | seed: 1234 14 | test: 15 | num_workers: 4 16 | persistent_workers: false 17 | batch_size: 1 18 | seed: 2345 19 | val: 20 | num_workers: 0 21 | persistent_workers: true 22 | batch_size: 1 23 | seed: 3456 24 | 25 | seed: 111123 26 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: re10k 5 | roots: [datasets/re10k] 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | test_times_per_scene: 1 24 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here. 
7 | context_views: null 8 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/bounded.yaml: -------------------------------------------------------------------------------- 1 | name: bounded 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | min_distance_to_context_views: 0 9 | 10 | warm_up_steps: 0 11 | initial_min_distance_between_context_views: 2 12 | initial_max_distance_between_context_views: 6 -------------------------------------------------------------------------------- /config/dataset/view_sampler/evaluation.yaml: -------------------------------------------------------------------------------- 1 | name: evaluation 2 | 3 | index_path: assets/evaluation_index_re10k_video.json 4 | num_context_views: 2 5 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | min_distance_between_context_views: 45 6 | max_distance_between_context_views: 192 7 | min_distance_to_context_views: 0 8 | warm_up_steps: 150_000 9 | initial_min_distance_between_context_views: 25 10 | initial_max_distance_between_context_views: 45 11 | num_target_views: 4 12 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | -------------------------------------------------------------------------------- /config/evaluation/ablation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/re10k/ours/frames 12 | - name: No Epipolar 13 | key: no_epipolar 14 | path: baselines/re10k/re10k_ablation_no_epipolar_transformer/frames 15 | - name: No Sampling 16 | key: no_sampling 17 | path: baselines/re10k/re10k_ablation_no_probabilistic_sampling/frames 18 | - name: No Depth Enc. 19 | key: no_depth_encoding 20 | path: baselines/re10k/re10k_ablation_no_depth_encoding/frames 21 | - name: Depth Reg. 
22 | key: depth_regularization 23 | path: baselines/re10k/re10k_depth_loss/frames 24 | 25 | side_by_side_path: null 26 | animate_side_by_side: false 27 | highlighted: 28 | - scene: 67a69088a2695987 29 | target_index: 74 30 | - scene: e4f4574df7938f37 31 | target_index: 26 32 | - scene: 29e0bfbad00f0d5e 33 | target_index: 89 34 | 35 | output_metrics_path: baselines/re10k/evaluation_metrics_ablation.json 36 | -------------------------------------------------------------------------------- /config/evaluation/acid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_acid.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/acid/ours/frames 12 | - name: Du et al.~\cite{du2023cross} 13 | key: du2023 14 | path: baselines/acid/yilun/frames 15 | - name: GPNR~\cite{suhail2022generalizable} 16 | key: gpnr 17 | path: baselines/acid/gpnr/frames 18 | - name: pixelNeRF~\cite{pixelnerf} 19 | key: pixelnerf 20 | path: baselines/acid/pixelnerf/frames 21 | 22 | side_by_side_path: null 23 | animate_side_by_side: false 24 | highlighted: 25 | # Main Paper 26 | # - scene: 405dcfc20f9ba5cb 27 | # target_index: 60 28 | # - scene: 23bb6c5d93c40e13 29 | # target_index: 38 30 | # - scene: a5ab552ede3e70a0 31 | # target_index: 52 32 | # Supplemental 33 | - scene: b896ebcbe8d18553 34 | target_index: 61 35 | - scene: bfbb8ec94454b84f 36 | target_index: 89 37 | - scene: c124e820e8123518 38 | target_index: 111 39 | - scene: c06467006ab8705a 40 | target_index: 24 41 | - scene: d4c545fb5fa1a84a 42 | target_index: 23 43 | - scene: d8f7d9ef3a5a4527 44 | target_index: 138 45 | - scene: 4139c1649bcce55a 46 | target_index: 140 47 | # No Space 48 | # - scene: 20c3d098e3b28058 49 | # target_index: 360 50 | 51 | output_metrics_path: baselines/acid/evaluation_metrics.json -------------------------------------------------------------------------------- /config/evaluation/acid_video.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_acid_video.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/acid/ours/frames_video 12 | - name: Du et al. 
13 | key: du2023 14 | path: baselines/acid/yilun/frames_video 15 | - name: GPNR 16 | key: gpnr 17 | path: baselines/acid/gpnr/frames_video 18 | - name: pixelNeRF 19 | key: pixelnerf 20 | path: baselines/acid/pixelnerf/frames_video 21 | 22 | side_by_side_path: outputs/video/acid 23 | animate_side_by_side: true 24 | highlighted: [] 25 | 26 | output_metrics_path: outputs/video/acid/evaluation_metrics.json 27 | -------------------------------------------------------------------------------- /config/evaluation/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/re10k/ours/frames 12 | - name: Du et al.~\cite{du2023cross} 13 | key: du2023 14 | path: baselines/re10k/yilun/frames 15 | - name: GPNR~\cite{suhail2022generalizable} 16 | key: gpnr 17 | path: baselines/re10k/gpnr/frames 18 | - name: pixelNeRF~\cite{pixelnerf} 19 | key: pixelnerf 20 | path: baselines/re10k/pixelnerf/frames 21 | 22 | side_by_side_path: null 23 | animate_side_by_side: false 24 | highlighted: 25 | # Main Paper 26 | - scene: 5be4f1f46b408d68 27 | target_index: 136 28 | - scene: 800ea72b6988f63e 29 | target_index: 167 30 | - scene: d3a01038c5f21473 31 | target_index: 201 32 | # Supplemental 33 | # - scene: 9e585ebbacb3e94c 34 | # target_index: 80 35 | # - scene: 7a00ed342b630d31 36 | # target_index: 101 37 | # - scene: 6f243139ca86b4e5 38 | # target_index: 54 39 | # - scene: 6e77ac6af5163f5b 40 | # target_index: 153 41 | # - scene: 5fbace6c6ca56228 42 | # target_index: 33 43 | # - scene: 7a34348316608aee 44 | # target_index: 80 45 | # - scene: 7a911883348688e9 46 | # target_index: 88 47 | 48 | 49 | output_metrics_path: baselines/re10k/evaluation_metrics.json -------------------------------------------------------------------------------- /config/evaluation/re10k_video.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k_video.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/re10k/ours/frames_video 12 | - name: Du et al. 
13 | key: du2023 14 | path: baselines/re10k/yilun/frames_video 15 | - name: GPNR 16 | key: gpnr 17 | path: baselines/re10k/gpnr/frames_video 18 | - name: pixelNeRF 19 | key: pixelnerf 20 | path: baselines/re10k/pixelnerf/frames_video 21 | 22 | side_by_side_path: outputs/video/re10k 23 | animate_side_by_side: true 24 | highlighted: [] 25 | 26 | output_metrics_path: outputs/video/re10k/evaluation_metrics.json 27 | -------------------------------------------------------------------------------- /config/experiment/acid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: costvolume 6 | - override /loss: [mse, lpips] 7 | 8 | wandb: 9 | name: acid 10 | tags: [acid, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/acid] 40 | near: 1. 41 | far: 100. 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | 45 | test: 46 | eval_time_skip_steps: 5 47 | compute_scores: true 48 | -------------------------------------------------------------------------------- /config/experiment/dtu.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: costvolume 6 | - override /loss: [mse, lpips] 7 | 8 | wandb: 9 | name: dtu/views2 10 | tags: [dtu, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/dtu] 40 | near: 2.125 41 | far: 4.525 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | test_times_per_scene: 4 45 | skip_bad_shape: false 46 | 47 | test: 48 | eval_time_skip_steps: 5 49 | compute_scores: true 50 | -------------------------------------------------------------------------------- /config/experiment/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: costvolume 6 | - override /loss: [mse, lpips] 7 | 8 | wandb: 9 | name: re10k 10 | tags: [re10k, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | 19 | # ----- Additional params for default best model customization 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | 
costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/re10k] 40 | near: 1. 41 | far: 100. 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | 45 | test: 46 | eval_time_skip_steps: 5 47 | compute_scores: true 48 | -------------------------------------------------------------------------------- /config/generate_evaluation_index.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - override dataset/view_sampler: all 5 | 6 | dataset: 7 | overfit_to_scene: null 8 | 9 | data_loader: 10 | train: 11 | num_workers: 0 12 | persistent_workers: true 13 | batch_size: 1 14 | seed: 1234 15 | test: 16 | num_workers: 8 17 | persistent_workers: false 18 | batch_size: 1 19 | seed: 2345 20 | val: 21 | num_workers: 0 22 | persistent_workers: true 23 | batch_size: 1 24 | seed: 3456 25 | 26 | index_generator: 27 | num_target_views: 3 28 | min_overlap: 0.6 29 | max_overlap: 1.0 30 | min_distance: 45 31 | max_distance: 135 32 | output_path: outputs/evaluation_index_re10k 33 | save_previews: false 34 | seed: 123 35 | 36 | seed: 456 37 | -------------------------------------------------------------------------------- /config/loss/depth.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | weight: 0.25 3 | sigma_image: null 4 | use_second_derivative: false 5 | -------------------------------------------------------------------------------- /config/loss/lpips.yaml: -------------------------------------------------------------------------------- 1 | lpips: 2 | weight: 0.05 3 | apply_after_step: 150_000 4 | -------------------------------------------------------------------------------- /config/loss/mse.yaml: -------------------------------------------------------------------------------- 1 | mse: 2 | weight: 1.0 3 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - model/encoder: costvolume 5 | - model/decoder: splatting_cuda 6 | - loss: [mse] 7 | 8 | wandb: 9 | project: mvsplat 10 | entity: placeholder 11 | name: placeholder 12 | mode: disabled 13 | id: null 14 | 15 | mode: train 16 | 17 | dataset: 18 | overfit_to_scene: null 19 | 20 | data_loader: 21 | # Avoid having to spin up new processes to print out visualizations. 
22 | train: 23 | num_workers: 10 24 | persistent_workers: true 25 | batch_size: 4 26 | seed: 1234 27 | test: 28 | num_workers: 4 29 | persistent_workers: false 30 | batch_size: 1 31 | seed: 2345 32 | val: 33 | num_workers: 1 34 | persistent_workers: true 35 | batch_size: 1 36 | seed: 3456 37 | 38 | optimizer: 39 | lr: 2.e-4 40 | warm_up_steps: 2000 41 | cosine_lr: true 42 | 43 | checkpointing: 44 | load: null 45 | every_n_train_steps: 20000 # 5000 46 | save_top_k: -1 47 | pretrained_model: null 48 | resume: true 49 | 50 | train: 51 | depth_mode: null 52 | extended_visualization: false 53 | print_log_every_n_steps: 1 54 | 55 | test: 56 | output_path: outputs/test 57 | compute_scores: false 58 | eval_time_skip_steps: 0 59 | save_image: true 60 | save_video: false 61 | 62 | seed: 111123 63 | 64 | trainer: 65 | max_steps: -1 66 | val_check_interval: 0.5 67 | gradient_clip_val: 0.5 68 | num_sanity_val_steps: 2 69 | num_nodes: 1 70 | 71 | output_dir: null 72 | -------------------------------------------------------------------------------- /config/model/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | -------------------------------------------------------------------------------- /config/model/encoder/costvolume.yaml: -------------------------------------------------------------------------------- 1 | name: costvolume 2 | 3 | opacity_mapping: 4 | initial: 0.0 5 | final: 0.0 6 | warm_up: 1 7 | 8 | num_depth_candidates: 32 9 | num_surfaces: 1 10 | 11 | gaussians_per_pixel: 1 12 | 13 | gaussian_adapter: 14 | gaussian_scale_min: 0.5 15 | gaussian_scale_max: 15.0 16 | sh_degree: 4 17 | 18 | d_feature: 128 19 | 20 | visualizer: 21 | num_samples: 8 22 | min_resolution: 256 23 | export_ply: false 24 | 25 | # params for multi-view depth predictor 26 | unimatch_weights_path: "checkpoints/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" 27 | multiview_trans_attn_split: 2 28 | costvolume_unet_feat_dim: 128 29 | costvolume_unet_channel_mult: [1,1,1] 30 | costvolume_unet_attn_res: [] 31 | depth_unet_feat_dim: 64 32 | depth_unet_attn_res: [] 33 | depth_unet_channel_mult: [1, 1, 1] 34 | downscale_factor: 4 35 | shim_patch_size: 4 36 | 37 | # below are ablation settings, keep them as false for default model 38 | wo_depth_refine: false # Table 3: base 39 | wo_cost_volume: false # Table 3: w/o cost volume 40 | wo_backbone_cross_attn: false # Table 3: w/o cross-view attention 41 | wo_cost_volume_refine: false # Table 3: w/o U-Net 42 | use_epipolar_trans: false # Table B: w/ Epipolar Transformer 43 | -------------------------------------------------------------------------------- /more_commands.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # In this file, we provide commands to get the quantitative results presented in the MVSplat paper. 4 | # Commands are provided by following the order of Tables appearing in the paper. 
5 | 6 | 7 | # --------------- Default Final Models --------------- 8 | 9 | # Table 1: re10k 10 | python -m src.main +experiment=re10k \ 11 | checkpointing.load=checkpoints/re10k.ckpt \ 12 | mode=test \ 13 | dataset/view_sampler=evaluation \ 14 | test.compute_scores=true 15 | 16 | # Table 1: acid 17 | python -m src.main +experiment=acid \ 18 | checkpointing.load=checkpoints/acid.ckpt \ 19 | mode=test \ 20 | dataset/view_sampler=evaluation \ 21 | dataset.view_sampler.index_path=assets/evaluation_index_acid.json \ 22 | test.compute_scores=true 23 | 24 | # generate video 25 | python -m src.main +experiment=re10k \ 26 | checkpointing.load=checkpoints/re10k.ckpt \ 27 | mode=test \ 28 | dataset/view_sampler=evaluation \ 29 | dataset.view_sampler.index_path=assets/evaluation_index_re10k_video.json \ 30 | test.save_video=true \ 31 | test.save_image=false \ 32 | test.compute_scores=false 33 | 34 | 35 | # --------------- Cross-Dataset Generalization --------------- 36 | 37 | # Table 2: RealEstate10K -> ACID 38 | python -m src.main +experiment=acid \ 39 | checkpointing.load=checkpoints/re10k.ckpt \ 40 | mode=test \ 41 | dataset/view_sampler=evaluation \ 42 | dataset.view_sampler.index_path=assets/evaluation_index_acid.json \ 43 | test.compute_scores=true 44 | 45 | # Table 2: RealEstate10K -> DTU (2 context views) 46 | python -m src.main +experiment=dtu \ 47 | checkpointing.load=checkpoints/re10k.ckpt \ 48 | mode=test \ 49 | dataset/view_sampler=evaluation \ 50 | dataset.view_sampler.index_path=assets/evaluation_index_dtu_nctx2.json \ 51 | test.compute_scores=true 52 | 53 | # RealEstate10K -> DTU (3 context views) 54 | python -m src.main +experiment=dtu \ 55 | checkpointing.load=checkpoints/re10k.ckpt \ 56 | mode=test \ 57 | dataset/view_sampler=evaluation \ 58 | dataset.view_sampler.index_path=assets/evaluation_index_dtu_nctx3.json \ 59 | dataset.view_sampler.num_context_views=3 \ 60 | wandb.name=dtu/views3 \ 61 | test.compute_scores=true 62 | 63 | 64 | # --------------- Ablation Models --------------- 65 | 66 | # Table 3: base 67 | python -m src.main +experiment=re10k \ 68 | checkpointing.load=checkpoints/ablations/re10k_worefine.ckpt \ 69 | mode=test \ 70 | dataset/view_sampler=evaluation \ 71 | test.compute_scores=true \ 72 | wandb.name=abl/re10k_base \ 73 | model.encoder.wo_depth_refine=true 74 | 75 | # Table 3: w/o cost volume 76 | python -m src.main +experiment=re10k \ 77 | checkpointing.load=checkpoints/ablations/re10k_worefine_wocv.ckpt \ 78 | mode=test \ 79 | dataset/view_sampler=evaluation \ 80 | test.compute_scores=true \ 81 | wandb.name=abl/re10k_wocv \ 82 | model.encoder.wo_depth_refine=true \ 83 | model.encoder.wo_cost_volume=true 84 | 85 | # Table 3: w/o cross-view attention 86 | python -m src.main +experiment=re10k \ 87 | checkpointing.load=checkpoints/ablations/re10k_worefine_wobbcrossattn_best.ckpt \ 88 | mode=test \ 89 | dataset/view_sampler=evaluation \ 90 | test.compute_scores=true \ 91 | wandb.name=abl/re10k_wo_backbone_cross_attn \ 92 | model.encoder.wo_depth_refine=true \ 93 | model.encoder.wo_backbone_cross_attn=true 94 | 95 | # Table 3: w/o U-Net 96 | python -m src.main +experiment=re10k \ 97 | checkpointing.load=checkpoints/ablations/re10k_worefine_wounet.ckpt \ 98 | mode=test \ 99 | dataset/view_sampler=evaluation \ 100 | test.compute_scores=true \ 101 | wandb.name=abl/re10k_wo_unet \ 102 | model.encoder.wo_depth_refine=true \ 103 | model.encoder.wo_cost_volume_refine=true 104 | 105 | # Table B: w/ Epipolar Transformer 106 | python -m src.main +experiment=re10k \ 107 | 
checkpointing.load=checkpoints/ablations/re10k_worefine_wepitrans.ckpt \ 108 | mode=test \ 109 | dataset/view_sampler=evaluation \ 110 | test.compute_scores=true \ 111 | wandb.name=abl/re10k_w_epipolar_trans \ 112 | model.encoder.wo_depth_refine=true \ 113 | model.encoder.use_epipolar_trans=true 114 | 115 | # Table C: 3 Gaussians per pixel 116 | python -m src.main +experiment=re10k \ 117 | checkpointing.load=checkpoints/ablations/re10k_gpp3.ckpt \ 118 | mode=test \ 119 | dataset/view_sampler=evaluation \ 120 | test.compute_scores=true \ 121 | wandb.name=abl/re10k_gpp3 \ 122 | model.encoder.gaussians_per_pixel=3 123 | 124 | # Table D: w/ random init (300K) 125 | python -m src.main +experiment=re10k \ 126 | checkpointing.load=checkpoints/ablations/re10k_wopretrained.ckpt \ 127 | mode=test \ 128 | dataset/view_sampler=evaluation \ 129 | test.compute_scores=true \ 130 | wandb.name=abl/re10k_wo_pretrained 131 | 132 | # Table D: w/ random init (450K) 133 | python -m src.main +experiment=re10k \ 134 | checkpointing.load=checkpoints/ablations/re10k_wopretrained_450k.ckpt \ 135 | mode=test \ 136 | dataset/view_sampler=evaluation \ 137 | test.compute_scores=true \ 138 | wandb.name=abl/re10k_wo_pretrained_450k 139 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel 2 | tqdm 3 | pytorch_lightning 4 | black 5 | ruff 6 | hydra-core 7 | jaxtyping 8 | beartype 9 | wandb 10 | einops 11 | colorama 12 | scikit-image 13 | colorspacious 14 | matplotlib 15 | moviepy 16 | imageio 17 | git+https://github.com/dcharatan/diff-gaussian-rasterization-modified 18 | timm 19 | dacite 20 | lpips 21 | e3nn 22 | plyfile 23 | tabulate 24 | svg.py 25 | opencv-python==4.6.0.66 26 | sk-video 27 | -------------------------------------------------------------------------------- /requirements_w_version.txt: -------------------------------------------------------------------------------- 1 | wheel==0.43.0 2 | tqdm==4.66.4 3 | pytorch_lightning==2.2.4 4 | black==24.4.2 5 | ruff==0.4.4 6 | hydra-core==1.3.2 7 | jaxtyping==0.2.28 8 | beartype==0.18.5 9 | wandb==0.17.0 10 | einops==0.8.0 11 | colorama==0.4.6 12 | scikit-image==0.23.2 13 | colorspacious==1.1.2 14 | matplotlib==3.8.4 15 | moviepy==1.0.3 16 | imageio==2.34.1 17 | git+https://github.com/dcharatan/diff-gaussian-rasterization-modified 18 | timm==0.9.16 19 | dacite==1.8.1 20 | lpips==0.1.4 21 | e3nn==0.5.1 22 | plyfile==1.0.3 23 | tabulate==0.9.0 24 | svg.py==1.4.3 25 | opencv-python==4.6.0.66 26 | sk-video==1.1.10 27 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Literal, Optional, Type, TypeVar 4 | 5 | from dacite import Config, from_dict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .dataset.data_module import DataLoaderCfg, DatasetCfg 9 | from .loss import LossCfgWrapper 10 | from .model.decoder import DecoderCfg 11 | from .model.encoder import EncoderCfg 12 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg 13 | 14 | 15 | @dataclass 16 | class CheckpointingCfg: 17 | load: Optional[str] # Not a path, since it could be something like wandb://... 
18 | every_n_train_steps: int 19 | save_top_k: int 20 | pretrained_model: Optional[str] 21 | resume: Optional[bool] = True 22 | 23 | 24 | @dataclass 25 | class ModelCfg: 26 | decoder: DecoderCfg 27 | encoder: EncoderCfg 28 | 29 | 30 | @dataclass 31 | class TrainerCfg: 32 | max_steps: int 33 | val_check_interval: int | float | None 34 | gradient_clip_val: int | float | None 35 | num_sanity_val_steps: int 36 | num_nodes: Optional[int] = 1 37 | 38 | 39 | @dataclass 40 | class RootCfg: 41 | wandb: dict 42 | mode: Literal["train", "test"] 43 | dataset: DatasetCfg 44 | data_loader: DataLoaderCfg 45 | model: ModelCfg 46 | optimizer: OptimizerCfg 47 | checkpointing: CheckpointingCfg 48 | trainer: TrainerCfg 49 | loss: list[LossCfgWrapper] 50 | test: TestCfg 51 | train: TrainCfg 52 | seed: int 53 | 54 | 55 | TYPE_HOOKS = { 56 | Path: Path, 57 | } 58 | 59 | 60 | T = TypeVar("T") 61 | 62 | 63 | def load_typed_config( 64 | cfg: DictConfig, 65 | data_class: Type[T], 66 | extra_type_hooks: dict = {}, 67 | ) -> T: 68 | return from_dict( 69 | data_class, 70 | OmegaConf.to_container(cfg), 71 | config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}), 72 | ) 73 | 74 | 75 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]: 76 | # The dummy allows the union to be converted. 77 | @dataclass 78 | class Dummy: 79 | dummy: LossCfgWrapper 80 | 81 | return [ 82 | load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy 83 | for k, v in joined.items() 84 | ] 85 | 86 | 87 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 88 | return load_typed_config( 89 | cfg, 90 | RootCfg, 91 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers}, 92 | ) 93 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from ..misc.step_tracker import StepTracker 4 | from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg 5 | from .types import Stage 6 | from .view_sampler import get_view_sampler 7 | 8 | DATASETS: dict[str, Dataset] = { 9 | "re10k": DatasetRE10k, 10 | } 11 | 12 | 13 | DatasetCfg = DatasetRE10kCfg 14 | 15 | 16 | def get_dataset( 17 | cfg: DatasetCfg, 18 | stage: Stage, 19 | step_tracker: StepTracker | None, 20 | ) -> Dataset: 21 | view_sampler = get_view_sampler( 22 | cfg.view_sampler, 23 | stage, 24 | cfg.overfit_to_scene is not None, 25 | cfg.cameras_are_circular, 26 | step_tracker, 27 | ) 28 | return DATASETS[cfg.name](cfg, stage, view_sampler) 29 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | from pytorch_lightning import LightningDataModule 8 | from torch import Generator, nn 9 | from torch.utils.data import DataLoader, Dataset, IterableDataset 10 | 11 | from ..misc.step_tracker import StepTracker 12 | from . import DatasetCfg, get_dataset 13 | from .types import DataShim, Stage 14 | from .validation_wrapper import ValidationWrapper 15 | 16 | 17 | def get_data_shim(encoder: nn.Module) -> DataShim: 18 | """Get functions that modify the batch. 
It's sometimes necessary to modify batches 19 | outside the data loader because GPU computations are required to modify the batch or 20 | because the modification depends on something outside the data loader. 21 | """ 22 | 23 | shims: list[DataShim] = [] 24 | if hasattr(encoder, "get_data_shim"): 25 | shims.append(encoder.get_data_shim()) 26 | 27 | def combined_shim(batch): 28 | for shim in shims: 29 | batch = shim(batch) 30 | return batch 31 | 32 | return combined_shim 33 | 34 | 35 | @dataclass 36 | class DataLoaderStageCfg: 37 | batch_size: int 38 | num_workers: int 39 | persistent_workers: bool 40 | seed: int | None 41 | 42 | 43 | @dataclass 44 | class DataLoaderCfg: 45 | train: DataLoaderStageCfg 46 | test: DataLoaderStageCfg 47 | val: DataLoaderStageCfg 48 | 49 | 50 | DatasetShim = Callable[[Dataset, Stage], Dataset] 51 | 52 | 53 | def worker_init_fn(worker_id: int) -> None: 54 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 55 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 56 | 57 | 58 | class DataModule(LightningDataModule): 59 | dataset_cfg: DatasetCfg 60 | data_loader_cfg: DataLoaderCfg 61 | step_tracker: StepTracker | None 62 | dataset_shim: DatasetShim 63 | global_rank: int 64 | 65 | def __init__( 66 | self, 67 | dataset_cfg: DatasetCfg, 68 | data_loader_cfg: DataLoaderCfg, 69 | step_tracker: StepTracker | None = None, 70 | dataset_shim: DatasetShim = lambda dataset, _: dataset, 71 | global_rank: int = 0, 72 | ) -> None: 73 | super().__init__() 74 | self.dataset_cfg = dataset_cfg 75 | self.data_loader_cfg = data_loader_cfg 76 | self.step_tracker = step_tracker 77 | self.dataset_shim = dataset_shim 78 | self.global_rank = global_rank 79 | 80 | def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: 81 | return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers 82 | 83 | def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: 84 | if loader_cfg.seed is None: 85 | return None 86 | generator = Generator() 87 | generator.manual_seed(loader_cfg.seed + self.global_rank) 88 | return generator 89 | 90 | def train_dataloader(self): 91 | dataset = get_dataset(self.dataset_cfg, "train", self.step_tracker) 92 | dataset = self.dataset_shim(dataset, "train") 93 | return DataLoader( 94 | dataset, 95 | self.data_loader_cfg.train.batch_size, 96 | shuffle=not isinstance(dataset, IterableDataset), 97 | num_workers=self.data_loader_cfg.train.num_workers, 98 | generator=self.get_generator(self.data_loader_cfg.train), 99 | worker_init_fn=worker_init_fn, 100 | persistent_workers=self.get_persistent(self.data_loader_cfg.train), 101 | ) 102 | 103 | def val_dataloader(self): 104 | dataset = get_dataset(self.dataset_cfg, "val", self.step_tracker) 105 | dataset = self.dataset_shim(dataset, "val") 106 | return DataLoader( 107 | ValidationWrapper(dataset, 1), 108 | self.data_loader_cfg.val.batch_size, 109 | num_workers=self.data_loader_cfg.val.num_workers, 110 | generator=self.get_generator(self.data_loader_cfg.val), 111 | worker_init_fn=worker_init_fn, 112 | persistent_workers=self.get_persistent(self.data_loader_cfg.val), 113 | ) 114 | 115 | def test_dataloader(self, dataset_cfg=None): 116 | dataset = get_dataset( 117 | self.dataset_cfg if dataset_cfg is None else dataset_cfg, 118 | "test", 119 | self.step_tracker, 120 | ) 121 | dataset = self.dataset_shim(dataset, "test") 122 | return DataLoader( 123 | dataset, 124 | self.data_loader_cfg.test.batch_size, 125 | 
num_workers=self.data_loader_cfg.test.num_workers, 126 | generator=self.get_generator(self.data_loader_cfg.test), 127 | worker_init_fn=worker_init_fn, 128 | persistent_workers=self.get_persistent(self.data_loader_cfg.test), 129 | shuffle=False, 130 | ) 131 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 
35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 
44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded. 70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 
33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 
10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ...misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | 11 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 12 | "all": ViewSamplerAll, 13 | "arbitrary": ViewSamplerArbitrary, 14 | "bounded": ViewSamplerBounded, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = ( 19 | ViewSamplerArbitraryCfg 20 | | ViewSamplerBoundedCfg 21 | | ViewSamplerEvaluationCfg 22 | | ViewSamplerAllCfg 23 | ) 24 | 25 | 26 | def get_view_sampler( 27 | cfg: ViewSamplerCfg, 28 | stage: Stage, 29 | overfit: bool, 30 | cameras_are_circular: bool, 31 | step_tracker: StepTracker | None, 32 | ) -> ViewSampler[Any]: 33 | return VIEW_SAMPLERS[cfg.name]( 34 | cfg, 35 | stage, 36 | overfit, 37 | cameras_are_circular, 38 | step_tracker, 39 | ) 40 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | **kwargs, 43 | ) -> tuple[ 44 | Int64[Tensor, " context_view"], # indices for context views 45 | Int64[Tensor, " target_view"], # indices for target views 46 | ]: 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def num_target_views(self) -> int: 52 | pass 53 | 54 | 
@property 55 | @abstractmethod 56 | def num_context_views(self) -> int: 57 | pass 58 | 59 | @property 60 | def global_step(self) -> int: 61 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 62 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerArbitraryCfg: 13 | name: Literal["arbitrary"] 14 | num_context_views: int 15 | num_target_views: int 16 | context_views: list[int] | None 17 | target_views: list[int] | None 18 | 19 | 20 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 21 | def sample( 22 | self, 23 | scene: str, 24 | extrinsics: Float[Tensor, "view 4 4"], 25 | intrinsics: Float[Tensor, "view 3 3"], 26 | device: torch.device = torch.device("cpu"), 27 | ) -> tuple[ 28 | Int64[Tensor, " context_view"], # indices for context views 29 | Int64[Tensor, " target_view"], # indices for target views 30 | ]: 31 | """Arbitrarily sample context and target views.""" 32 | num_views, _, _ = extrinsics.shape 33 | 34 | index_context = torch.randint( 35 | 0, 36 | num_views, 37 | size=(self.cfg.num_context_views,), 38 | device=device, 39 | ) 40 | 41 | # Allow the context views to be fixed. 42 | if self.cfg.context_views is not None: 43 | assert len(self.cfg.context_views) == self.cfg.num_context_views 44 | index_context = torch.tensor( 45 | self.cfg.context_views, dtype=torch.int64, device=device 46 | ) 47 | 48 | index_target = torch.randint( 49 | 0, 50 | num_views, 51 | size=(self.cfg.num_target_views,), 52 | device=device, 53 | ) 54 | 55 | # Allow the target views to be fixed. 
56 | if self.cfg.target_views is not None: 57 | assert len(self.cfg.target_views) == self.cfg.num_target_views 58 | index_target = torch.tensor( 59 | self.cfg.target_views, dtype=torch.int64, device=device 60 | ) 61 | 62 | return index_context, index_target 63 | 64 | @property 65 | def num_context_views(self) -> int: 66 | return self.cfg.num_context_views 67 | 68 | @property 69 | def num_target_views(self) -> int: 70 | return self.cfg.num_target_views 71 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | name: Literal["bounded"] 14 | num_context_views: int 15 | num_target_views: int 16 | min_distance_between_context_views: int 17 | max_distance_between_context_views: int 18 | min_distance_to_context_views: int 19 | warm_up_steps: int 20 | initial_min_distance_between_context_views: int 21 | initial_max_distance_between_context_views: int 22 | 23 | 24 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 25 | def schedule(self, initial: int, final: int) -> int: 26 | fraction = self.global_step / self.cfg.warm_up_steps 27 | return min(initial + int((final - initial) * fraction), final) 28 | 29 | def sample( 30 | self, 31 | scene: str, 32 | extrinsics: Float[Tensor, "view 4 4"], 33 | intrinsics: Float[Tensor, "view 3 3"], 34 | device: torch.device = torch.device("cpu"), 35 | ) -> tuple[ 36 | Int64[Tensor, " context_view"], # indices for context views 37 | Int64[Tensor, " target_view"], # indices for target views 38 | ]: 39 | num_views, _, _ = extrinsics.shape 40 | 41 | # Compute the context view spacing based on the current global step. 42 | if self.stage == "test": 43 | # When testing, always use the full gap. 44 | max_gap = self.cfg.max_distance_between_context_views 45 | min_gap = self.cfg.max_distance_between_context_views 46 | elif self.cfg.warm_up_steps > 0: 47 | max_gap = self.schedule( 48 | self.cfg.initial_max_distance_between_context_views, 49 | self.cfg.max_distance_between_context_views, 50 | ) 51 | min_gap = self.schedule( 52 | self.cfg.initial_min_distance_between_context_views, 53 | self.cfg.min_distance_between_context_views, 54 | ) 55 | else: 56 | max_gap = self.cfg.max_distance_between_context_views 57 | min_gap = self.cfg.min_distance_between_context_views 58 | 59 | # Pick the gap between the context views. 60 | # NOTE: we keep the bug untouched to follow initial pixelsplat cfgs 61 | if not self.cameras_are_circular: 62 | max_gap = min(num_views - 1, min_gap) 63 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 64 | if max_gap < min_gap: 65 | raise ValueError("Example does not have enough frames!") 66 | context_gap = torch.randint( 67 | min_gap, 68 | max_gap + 1, 69 | size=tuple(), 70 | device=device, 71 | ).item() 72 | 73 | # Pick the left and right context indices. 
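        # The left context index is drawn uniformly at random (over all frames for
        # circular trajectories, otherwise only over starting frames that leave room
        # for the gap). At test time it is pinned to frame 0 below, and the right
        # context is always exactly `context_gap` frames after the left one.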
74 | index_context_left = torch.randint( 75 | num_views if self.cameras_are_circular else num_views - context_gap, 76 | size=tuple(), 77 | device=device, 78 | ).item() 79 | if self.stage == "test": 80 | index_context_left = index_context_left * 0 81 | index_context_right = index_context_left + context_gap 82 | 83 | if self.is_overfitting: 84 | index_context_left *= 0 85 | index_context_right *= 0 86 | index_context_right += max_gap 87 | 88 | # Pick the target view indices. 89 | if self.stage == "test": 90 | # When testing, pick all. 91 | index_target = torch.arange( 92 | index_context_left, 93 | index_context_right + 1, 94 | device=device, 95 | ) 96 | else: 97 | # When training or validating (visualizing), pick at random. 98 | index_target = torch.randint( 99 | index_context_left + self.cfg.min_distance_to_context_views, 100 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 101 | size=(self.cfg.num_target_views,), 102 | device=device, 103 | ) 104 | 105 | # Apply modulo for circular datasets. 106 | if self.cameras_are_circular: 107 | index_target %= num_views 108 | index_context_right %= num_views 109 | 110 | return ( 111 | torch.tensor((index_context_left, index_context_right)), 112 | index_target, 113 | ) 114 | 115 | @property 116 | def num_context_views(self) -> int: 117 | return 2 118 | 119 | @property 120 | def num_target_views(self) -> int: 121 | return self.cfg.num_target_views 122 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...misc.step_tracker import StepTracker 13 | from ..types import Stage 14 | from .view_sampler import ViewSampler 15 | 16 | 17 | @dataclass 18 | class ViewSamplerEvaluationCfg: 19 | name: Literal["evaluation"] 20 | index_path: Path 21 | num_context_views: int 22 | 23 | 24 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 25 | index: dict[str, IndexEntry | None] 26 | 27 | def __init__( 28 | self, 29 | cfg: ViewSamplerEvaluationCfg, 30 | stage: Stage, 31 | is_overfitting: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> None: 35 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 36 | 37 | dacite_config = Config(cast=[tuple]) 38 | with cfg.index_path.open("r") as f: 39 | self.index = { 40 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 41 | for k, v in json.load(f).items() 42 | } 43 | 44 | def sample( 45 | self, 46 | scene: str, 47 | extrinsics: Float[Tensor, "view 4 4"], 48 | intrinsics: Float[Tensor, "view 3 3"], 49 | device: torch.device = torch.device("cpu"), 50 | ) -> tuple[ 51 | Int64[Tensor, " context_view"], # indices for context views 52 | Int64[Tensor, " target_view"], # indices for target views 53 | ]: 54 | entry = self.index.get(scene) 55 | if entry is None: 56 | raise ValueError(f"No indices available for scene {scene}.") 57 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 58 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 59 | return context_indices, target_indices 60 | 
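    # The index JSON maps each scene name to either null or an object matching
    # IndexEntry, mirroring what EvaluationIndexGenerator.save_index writes.
    # An illustrative entry (scene names here are made up):
    #   {"scene_a": {"context": [53, 97], "target": [60, 72, 85]}, "scene_b": null}
    # Scenes mapped to null have no valid pair, and sample() raises for them above.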
61 | @property 62 | def num_context_views(self) -> int: 63 | return 0 64 | 65 | @property 66 | def num_target_views(self) -> int: 67 | return 0 68 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | animate_side_by_side: bool 23 | highlighted: list[SceneCfg] 24 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_index_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | 5 | import torch 6 | from einops import rearrange 7 | from pytorch_lightning import LightningModule 8 | from tqdm import tqdm 9 | 10 | from ..geometry.epipolar_lines import project_rays 11 | from ..geometry.projection import get_world_rays, sample_image_grid 12 | from ..misc.image_io import save_image 13 | from ..visualization.annotation import add_label 14 | from ..visualization.layout import add_border, hcat 15 | 16 | 17 | @dataclass 18 | class EvaluationIndexGeneratorCfg: 19 | num_target_views: int 20 | min_distance: int 21 | max_distance: int 22 | min_overlap: float 23 | max_overlap: float 24 | output_path: Path 25 | save_previews: bool 26 | seed: int 27 | 28 | 29 | @dataclass 30 | class IndexEntry: 31 | context: tuple[int, ...] 32 | target: tuple[int, ...] 33 | 34 | 35 | class EvaluationIndexGenerator(LightningModule): 36 | generator: torch.Generator 37 | cfg: EvaluationIndexGeneratorCfg 38 | index: dict[str, IndexEntry | None] 39 | 40 | def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None: 41 | super().__init__() 42 | self.cfg = cfg 43 | self.generator = torch.Generator() 44 | self.generator.manual_seed(cfg.seed) 45 | self.index = {} 46 | 47 | def test_step(self, batch, batch_idx): 48 | b, v, _, h, w = batch["target"]["image"].shape 49 | assert b == 1 50 | extrinsics = batch["target"]["extrinsics"][0] 51 | intrinsics = batch["target"]["intrinsics"][0] 52 | scene = batch["scene"][0] 53 | 54 | context_indices = torch.randperm(v, generator=self.generator) 55 | for context_index in tqdm(context_indices, "Finding context pair"): 56 | xy, _ = sample_image_grid((h, w), self.device) 57 | context_origins, context_directions = get_world_rays( 58 | rearrange(xy, "h w xy -> (h w) xy"), 59 | extrinsics[context_index], 60 | intrinsics[context_index], 61 | ) 62 | 63 | # Step away from context view until the minimum overlap threshold is met. 64 | valid_indices = [] 65 | for step in (1, -1): 66 | min_distance = self.cfg.min_distance 67 | max_distance = self.cfg.max_distance 68 | current_index = context_index + step * min_distance 69 | 70 | while 0 <= current_index.item() < v: 71 | # Compute overlap. 
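                # Overlap is measured symmetrically: rays are cast through every
                # pixel of one view and projected into the other, and each score is
                # the fraction of those rays that land inside the other view's
                # image. A candidate frame is accepted only when the smaller of the
                # two fractions lies within [min_overlap, max_overlap].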
72 | current_origins, current_directions = get_world_rays( 73 | rearrange(xy, "h w xy -> (h w) xy"), 74 | extrinsics[current_index], 75 | intrinsics[current_index], 76 | ) 77 | projection_onto_current = project_rays( 78 | context_origins, 79 | context_directions, 80 | extrinsics[current_index], 81 | intrinsics[current_index], 82 | ) 83 | projection_onto_context = project_rays( 84 | current_origins, 85 | current_directions, 86 | extrinsics[context_index], 87 | intrinsics[context_index], 88 | ) 89 | overlap_a = projection_onto_context["overlaps_image"].float().mean() 90 | overlap_b = projection_onto_current["overlaps_image"].float().mean() 91 | 92 | overlap = min(overlap_a, overlap_b) 93 | delta = (current_index - context_index).abs() 94 | 95 | min_overlap = self.cfg.min_overlap 96 | max_overlap = self.cfg.max_overlap 97 | if min_overlap <= overlap <= max_overlap: 98 | valid_indices.append( 99 | (current_index.item(), overlap_a, overlap_b) 100 | ) 101 | 102 | # Stop once the camera has panned away too much. 103 | if overlap < min_overlap or delta > max_distance: 104 | break 105 | 106 | current_index += step 107 | 108 | if valid_indices: 109 | # Pick a random valid view. Index the resulting views. 110 | num_options = len(valid_indices) 111 | chosen = torch.randint( 112 | 0, num_options, size=tuple(), generator=self.generator 113 | ) 114 | chosen, overlap_a, overlap_b = valid_indices[chosen] 115 | 116 | context_left = min(chosen, context_index.item()) 117 | context_right = max(chosen, context_index.item()) 118 | delta = context_right - context_left 119 | 120 | # Pick non-repeated random target views. 121 | while True: 122 | target_views = torch.randint( 123 | context_left, 124 | context_right + 1, 125 | (self.cfg.num_target_views,), 126 | generator=self.generator, 127 | ) 128 | if (target_views.unique(return_counts=True)[1] == 1).all(): 129 | break 130 | 131 | target = tuple(sorted(target_views.tolist())) 132 | self.index[scene] = IndexEntry( 133 | context=(context_left, context_right), 134 | target=target, 135 | ) 136 | 137 | # Optionally, save a preview. 138 | if self.cfg.save_previews: 139 | preview_path = self.cfg.output_path / "previews" 140 | preview_path.mkdir(exist_ok=True, parents=True) 141 | a = batch["target"]["image"][0, chosen] 142 | a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%") 143 | b = batch["target"]["image"][0, context_index] 144 | b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%") 145 | vis = add_border(add_border(hcat(a, b)), 1, 0) 146 | vis = add_label(vis, f"Distance: {delta} frames") 147 | save_image(add_border(vis), preview_path / f"{scene}.png") 148 | break 149 | else: 150 | # This happens if no starting frame produces a valid evaluation example. 
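            # (This is the `else` clause of the `for` loop over candidate context
            # frames: it only runs when the loop finished without hitting `break`,
            # i.e. no candidate produced a valid context pair.)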
151 | self.index[scene] = None 152 | 153 | def save_index(self) -> None: 154 | self.cfg.output_path.mkdir(exist_ok=True, parents=True) 155 | with (self.cfg.output_path / "evaluation_index.json").open("w") as f: 156 | json.dump( 157 | {k: None if v is None else asdict(v) for k, v in self.index.items()}, f 158 | ) 159 | -------------------------------------------------------------------------------- /src/evaluation/metric_computer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from tabulate import tabulate 7 | 8 | from ..misc.image_io import load_image, save_image 9 | from ..visualization.annotation import add_label 10 | from ..visualization.layout import add_border, hcat 11 | from .evaluation_cfg import EvaluationCfg 12 | from .metrics import compute_lpips, compute_psnr, compute_ssim 13 | 14 | 15 | class MetricComputer(LightningModule): 16 | cfg: EvaluationCfg 17 | 18 | def __init__(self, cfg: EvaluationCfg) -> None: 19 | super().__init__() 20 | self.cfg = cfg 21 | 22 | def test_step(self, batch, batch_idx): 23 | scene = batch["scene"][0] 24 | b, cv, _, _, _ = batch["context"]["image"].shape 25 | assert b == 1 and cv == 2 26 | _, v, _, _, _ = batch["target"]["image"].shape 27 | 28 | # Skip scenes. 29 | for method in self.cfg.methods: 30 | if not (method.path / scene).exists(): 31 | print(f'Skipping "{scene}".') 32 | return 33 | 34 | # Load the images. 35 | all_images = {} 36 | try: 37 | for method in self.cfg.methods: 38 | images = [ 39 | load_image(method.path / scene / f"color/{index.item():0>6}.png") 40 | for index in batch["target"]["index"][0] 41 | ] 42 | all_images[method.key] = torch.stack(images).to(self.device) 43 | except FileNotFoundError: 44 | print(f'Skipping "{scene}".') 45 | return 46 | 47 | # Compute metrics. 48 | all_metrics = {} 49 | rgb_gt = batch["target"]["image"][0] 50 | for key, images in all_images.items(): 51 | all_metrics = { 52 | **all_metrics, 53 | f"lpips_{key}": compute_lpips(rgb_gt, images).mean(), 54 | f"ssim_{key}": compute_ssim(rgb_gt, images).mean(), 55 | f"psnr_{key}": compute_psnr(rgb_gt, images).mean(), 56 | } 57 | self.log_dict(all_metrics) 58 | self.print_preview_metrics(all_metrics) 59 | 60 | # Skip the rest if no side-by-side is needed. 61 | if self.cfg.side_by_side_path is None: 62 | return 63 | 64 | # Create side-by-side. 65 | scene_key = f"{batch_idx:0>6}_{scene}" 66 | for i in range(v): 67 | true_index = batch["target"]["index"][0, i] 68 | row = [add_label(batch["target"]["image"][0, i], "Ground Truth")] 69 | for method in self.cfg.methods: 70 | image = all_images[method.key][i] 71 | image = add_label(image, method.name) 72 | row.append(image) 73 | start_frame = batch["target"]["index"][0, 0] 74 | end_frame = batch["target"]["index"][0, -1] 75 | label = f"Scene {batch['scene'][0]} (frames {start_frame} to {end_frame})" 76 | row = add_border(add_label(hcat(*row), label, font_size=16)) 77 | save_image( 78 | row, 79 | self.cfg.side_by_side_path / scene_key / f"{true_index:0>6}.png", 80 | ) 81 | 82 | # Create an animation. 
83 | if self.cfg.animate_side_by_side: 84 | (self.cfg.side_by_side_path / "videos").mkdir(exist_ok=True, parents=True) 85 | command = ( 86 | 'ffmpeg -y -framerate 30 -pattern_type glob -i "*.png" -c:v libx264 ' 87 | '-pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"' 88 | ) 89 | os.system( 90 | f"cd {self.cfg.side_by_side_path / scene_key} && {command} " 91 | f"{Path.cwd()}/{self.cfg.side_by_side_path}/videos/{scene_key}.mp4" 92 | ) 93 | 94 | def print_preview_metrics(self, metrics: dict[str, float]) -> None: 95 | if getattr(self, "running_metrics", None) is None: 96 | self.running_metrics = metrics 97 | self.running_metric_steps = 1 98 | else: 99 | s = self.running_metric_steps 100 | self.running_metrics = { 101 | k: ((s * v) + metrics[k]) / (s + 1) 102 | for k, v in self.running_metrics.items() 103 | } 104 | self.running_metric_steps += 1 105 | 106 | table = [] 107 | for method in self.cfg.methods: 108 | row = [ 109 | f"{self.running_metrics[f'{metric}_{method.key}']:.3f}" 110 | for metric in ("psnr", "lpips", "ssim") 111 | ] 112 | table.append((method.key, *row)) 113 | 114 | table = tabulate(table, ["Method", "PSNR (dB)", "LPIPS", "SSIM"]) 115 | print(table) 116 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 | channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | 
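# --- Illustrative usage sketch (added; not part of the original file) --------
# src/main.py calls set_cfg exactly once with the raw Hydra DictConfig; other
# modules then read it lazily via get_cfg / get_seed. Run this module directly
# to see the round trip:
if __name__ == "__main__":
    from omegaconf import OmegaConf

    set_cfg(OmegaConf.create({"seed": 111, "mode": "train"}))
    assert get_seed() == 111
    print(get_cfg().mode)  # -> train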
-------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_depth import LossDepth, LossDepthCfgWrapper 3 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 4 | from .loss_mse import LossMse, LossMseCfgWrapper 5 | 6 | LOSSES = { 7 | LossDepthCfgWrapper: LossDepth, 8 | LossLpipsCfgWrapper: LossLpips, 9 | LossMseCfgWrapper: LossMse, 10 | } 11 | 12 | LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper 13 | 14 | 15 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 16 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 17 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians: Gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | -------------------------------------------------------------------------------- /src/loss/loss_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | from .loss import Loss 12 | 13 | 14 | @dataclass 15 | class LossDepthCfg: 16 | weight: float 17 | sigma_image: float | None 18 | use_second_derivative: bool 19 | 20 | 21 | @dataclass 22 | class LossDepthCfgWrapper: 23 | depth: LossDepthCfg 24 | 25 | 26 | class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]): 27 | def forward( 28 | self, 29 | prediction: DecoderOutput, 30 | batch: BatchedExample, 31 | gaussians: Gaussians, 32 | global_step: int, 33 | ) -> Float[Tensor, ""]: 34 | # Scale the depth between the near and far planes. 35 | near = batch["target"]["near"][..., None, None].log() 36 | far = batch["target"]["far"][..., None, None].log() 37 | depth = prediction.depth.minimum(far).maximum(near) 38 | depth = (depth - near) / (far - near) 39 | 40 | # Compute the difference between neighboring pixels in each direction. 41 | depth_dx = depth.diff(dim=-1) 42 | depth_dy = depth.diff(dim=-2) 43 | 44 | # If desired, compute a 2nd derivative. 45 | if self.cfg.use_second_derivative: 46 | depth_dx = depth_dx.diff(dim=-1) 47 | depth_dy = depth_dy.diff(dim=-2) 48 | 49 | # If desired, add bilateral filtering. 
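        # Edge-aware ("bilateral") weighting: each depth-gradient term is scaled by
        # exp(-color_gradient * sigma_image), where the color gradient is the
        # per-pixel channel-wise maximum of the ground-truth image differences.
        # The smoothness penalty is therefore relaxed across strong image edges and
        # acts mostly in textureless regions.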
50 | if self.cfg.sigma_image is not None: 51 | color_gt = batch["target"]["image"] 52 | color_dx = reduce(color_gt.diff(dim=-1), "b v c h w -> b v h w", "max") 53 | color_dy = reduce(color_gt.diff(dim=-2), "b v c h w -> b v h w", "max") 54 | if self.cfg.use_second_derivative: 55 | color_dx = color_dx[..., :, 1:].maximum(color_dx[..., :, :-1]) 56 | color_dy = color_dy[..., 1:, :].maximum(color_dy[..., :-1, :]) 57 | depth_dx = depth_dx * torch.exp(-color_dx * self.cfg.sigma_image) 58 | depth_dy = depth_dy * torch.exp(-color_dy * self.cfg.sigma_image) 59 | 60 | return self.cfg.weight * (depth_dx.abs().mean() + depth_dy.abs().mean()) 61 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians: Gaussians, 41 | global_step: int, 42 | ) -> Float[Tensor, ""]: 43 | image = batch["target"]["image"] 44 | 45 | # Before the specified step, don't apply the loss. 
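        # A zero tensor (rather than a Python float) is returned on the image's
        # device so downstream loss summation and logging keep working; it carries
        # no gradient, so LPIPS has no effect until `apply_after_step` is reached.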
46 | if global_step < self.cfg.apply_after_step: 47 | return torch.tensor(0, dtype=torch.float32, device=image.device) 48 | 49 | loss = self.lpips.forward( 50 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 51 | rearrange(image, "b v c h w -> (b v) c h w"), 52 | normalize=True, 53 | ) 54 | return self.cfg.weight * loss.mean() 55 | -------------------------------------------------------------------------------- /src/loss/loss_mse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..dataset.types import BatchedExample 7 | from ..model.decoder.decoder import DecoderOutput 8 | from ..model.types import Gaussians 9 | from .loss import Loss 10 | 11 | 12 | @dataclass 13 | class LossMseCfg: 14 | weight: float 15 | 16 | 17 | @dataclass 18 | class LossMseCfgWrapper: 19 | mse: LossMseCfg 20 | 21 | 22 | class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): 23 | def forward( 24 | self, 25 | prediction: DecoderOutput, 26 | batch: BatchedExample, 27 | gaussians: Gaussians, 28 | global_step: int, 29 | ) -> Float[Tensor, ""]: 30 | delta = prediction.color - batch["target"]["image"] 31 | return self.cfg.weight * (delta**2).mean() 32 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import warnings 4 | 5 | import hydra 6 | import torch 7 | import wandb 8 | from colorama import Fore 9 | from jaxtyping import install_import_hook 10 | from omegaconf import DictConfig, OmegaConf 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning.callbacks import ( 13 | LearningRateMonitor, 14 | ModelCheckpoint, 15 | ) 16 | from pytorch_lightning.loggers.wandb import WandbLogger 17 | 18 | # Configure beartype and jaxtyping. 19 | with install_import_hook( 20 | ("src",), 21 | ("beartype", "beartype"), 22 | ): 23 | from src.config import load_typed_root_config 24 | from src.dataset.data_module import DataModule 25 | from src.global_cfg import set_cfg 26 | from src.loss import get_losses 27 | from src.misc.LocalLogger import LocalLogger 28 | from src.misc.step_tracker import StepTracker 29 | from src.misc.wandb_tools import update_checkpoint_path 30 | from src.model.decoder import get_decoder 31 | from src.model.encoder import get_encoder 32 | from src.model.model_wrapper import ModelWrapper 33 | 34 | 35 | def cyan(text: str) -> str: 36 | return f"{Fore.CYAN}{text}{Fore.RESET}" 37 | 38 | 39 | @hydra.main( 40 | version_base=None, 41 | config_path="../config", 42 | config_name="main", 43 | ) 44 | def train(cfg_dict: DictConfig): 45 | cfg = load_typed_root_config(cfg_dict) 46 | set_cfg(cfg_dict) 47 | 48 | # Set up the output directory. 49 | if cfg_dict.output_dir is None: 50 | output_dir = Path( 51 | hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"] 52 | ) 53 | else: # for resuming 54 | output_dir = Path(cfg_dict.output_dir) 55 | os.makedirs(output_dir, exist_ok=True) 56 | print(cyan(f"Saving outputs to {output_dir}.")) 57 | latest_run = output_dir.parents[1] / "latest-run" 58 | os.system(f"rm {latest_run}") 59 | os.system(f"ln -s {output_dir} {latest_run}") 60 | 61 | # Set up logging with wandb. 
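    # If a wandb run id is supplied, that exact run is resumed (resume="must");
    # otherwise a fresh run is created. With wandb disabled, the LocalLogger from
    # src/misc/LocalLogger.py is used instead, and logged images end up under
    # outputs/local.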
62 | callbacks = [] 63 | if cfg_dict.wandb.mode != "disabled": 64 | wandb_extra_kwargs = {} 65 | if cfg_dict.wandb.id is not None: 66 | wandb_extra_kwargs.update({'id': cfg_dict.wandb.id, 67 | 'resume': "must"}) 68 | logger = WandbLogger( 69 | entity=cfg_dict.wandb.entity, 70 | project=cfg_dict.wandb.project, 71 | mode=cfg_dict.wandb.mode, 72 | name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})", 73 | tags=cfg_dict.wandb.get("tags", None), 74 | log_model=False, 75 | save_dir=output_dir, 76 | config=OmegaConf.to_container(cfg_dict), 77 | **wandb_extra_kwargs, 78 | ) 79 | callbacks.append(LearningRateMonitor("step", True)) 80 | 81 | # On rank != 0, wandb.run is None. 82 | if wandb.run is not None: 83 | wandb.run.log_code("src") 84 | else: 85 | logger = LocalLogger() 86 | 87 | # Set up checkpointing. 88 | callbacks.append( 89 | ModelCheckpoint( 90 | output_dir / "checkpoints", 91 | every_n_train_steps=cfg.checkpointing.every_n_train_steps, 92 | save_top_k=cfg.checkpointing.save_top_k, 93 | monitor="info/global_step", 94 | mode="max", # save the lastest k ckpt, can do offline test later 95 | ) 96 | ) 97 | for cb in callbacks: 98 | cb.CHECKPOINT_EQUALS_CHAR = '_' 99 | 100 | # Prepare the checkpoint for loading. 101 | checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb) 102 | 103 | # This allows the current step to be shared with the data loader processes. 104 | step_tracker = StepTracker() 105 | 106 | trainer = Trainer( 107 | max_epochs=-1, 108 | accelerator="gpu", 109 | logger=logger, 110 | devices="auto", 111 | num_nodes=cfg.trainer.num_nodes, 112 | strategy="ddp" if torch.cuda.device_count() > 1 else "auto", 113 | callbacks=callbacks, 114 | val_check_interval=cfg.trainer.val_check_interval, 115 | enable_progress_bar=cfg.mode == "test", 116 | gradient_clip_val=cfg.trainer.gradient_clip_val, 117 | max_steps=cfg.trainer.max_steps, 118 | num_sanity_val_steps=cfg.trainer.num_sanity_val_steps, 119 | ) 120 | torch.manual_seed(cfg_dict.seed + trainer.global_rank) 121 | 122 | encoder, encoder_visualizer = get_encoder(cfg.model.encoder) 123 | 124 | model_kwargs = { 125 | "optimizer_cfg": cfg.optimizer, 126 | "test_cfg": cfg.test, 127 | "train_cfg": cfg.train, 128 | "encoder": encoder, 129 | "encoder_visualizer": encoder_visualizer, 130 | "decoder": get_decoder(cfg.model.decoder, cfg.dataset), 131 | "losses": get_losses(cfg.loss), 132 | "step_tracker": step_tracker, 133 | } 134 | if cfg.mode == "train" and checkpoint_path is not None and not cfg.checkpointing.resume: 135 | # Just load model weights, without optimizer states 136 | # e.g., fine-tune from the released weights on other datasets 137 | model_wrapper = ModelWrapper.load_from_checkpoint( 138 | checkpoint_path, **model_kwargs, strict=True) 139 | print(cyan(f"Loaded weigths from {checkpoint_path}.")) 140 | else: 141 | model_wrapper = ModelWrapper(**model_kwargs) 142 | 143 | data_module = DataModule( 144 | cfg.dataset, 145 | cfg.data_loader, 146 | step_tracker, 147 | global_rank=trainer.global_rank, 148 | ) 149 | 150 | if cfg.mode == "train": 151 | trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=( 152 | checkpoint_path if cfg.checkpointing.resume else None)) 153 | else: 154 | trainer.test( 155 | model_wrapper, 156 | datamodule=data_module, 157 | ckpt_path=checkpoint_path, 158 | ) 159 | 160 | 161 | if __name__ == "__main__": 162 | warnings.filterwarnings("ignore") 163 | torch.set_float32_matmul_precision('high') 164 | 165 | train() 166 | 
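# --- Illustrative launch commands (added) -------------------------------------
# The exact Hydra override names below are assumptions based on the config tree
# (config/experiment/re10k.yaml, etc.); see more_commands.sh / README.md for the
# official invocations.
#
#   # train with the re10k experiment config
#   python -m src.main +experiment=re10k mode=train
#
#   # offline evaluation from a checkpoint
#   python -m src.main +experiment=re10k mode=test \
#       checkpointing.load=path/to/checkpoint.ckpt
#
# checkpointing.load also accepts a "wandb://run_id[:version]" URI, resolved by
# update_checkpoint_path in src/misc/wandb_tools.py.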
-------------------------------------------------------------------------------- /src/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | from PIL import Image 6 | from pytorch_lightning.loggers.logger import Logger 7 | from pytorch_lightning.utilities import rank_zero_only 8 | 9 | LOG_PATH = Path("outputs/local") 10 | 11 | 12 | class LocalLogger(Logger): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.experiment = None 16 | os.system(f"rm -r {LOG_PATH}") 17 | 18 | @property 19 | def name(self): 20 | return "LocalLogger" 21 | 22 | @property 23 | def version(self): 24 | return 0 25 | 26 | @rank_zero_only 27 | def log_hyperparams(self, params): 28 | pass 29 | 30 | @rank_zero_only 31 | def log_metrics(self, metrics, step): 32 | pass 33 | 34 | @rank_zero_only 35 | def log_image( 36 | self, 37 | key: str, 38 | images: list[Any], 39 | step: Optional[int] = None, 40 | **kwargs, 41 | ): 42 | # The function signature is the same as the wandb logger's, but the step is 43 | # actually required. 44 | assert step is not None 45 | for index, image in enumerate(images): 46 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 47 | path.parent.mkdir(exist_ok=True, parents=True) 48 | Image.fromarray(image).save(path) 49 | -------------------------------------------------------------------------------- /src/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. 
{np.mean(times)} seconds per call") 38 | 39 | def clear_history(self) -> None: 40 | self.execution_times = defaultdict(list) 41 | -------------------------------------------------------------------------------- /src/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return { 14 | key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0] 15 | } 16 | -------------------------------------------------------------------------------- /src/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 
33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | import skvideo.io 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as tf 9 | from einops import rearrange, repeat 10 | from jaxtyping import Float, UInt8 11 | from matplotlib.figure import Figure 12 | from PIL import Image 13 | from torch import Tensor 14 | 15 | FloatImage = Union[ 16 | Float[Tensor, "height width"], 17 | Float[Tensor, "channel height width"], 18 | Float[Tensor, "batch channel height width"], 19 | ] 20 | 21 | 22 | def fig_to_image( 23 | fig: Figure, 24 | dpi: int = 100, 25 | device: torch.device = torch.device("cpu"), 26 | ) -> Float[Tensor, "3 height width"]: 27 | buffer = io.BytesIO() 28 | fig.savefig(buffer, format="raw", dpi=dpi) 29 | buffer.seek(0) 30 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 31 | h = int(fig.bbox.bounds[3]) 32 | w = int(fig.bbox.bounds[2]) 33 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 34 | buffer.close() 35 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 36 | 37 | 38 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 39 | # Handle batched images. 40 | if image.ndim == 4: 41 | image = rearrange(image, "b c h w -> c h (b w)") 42 | 43 | # Handle single-channel images. 44 | if image.ndim == 2: 45 | image = rearrange(image, "h w -> () h w") 46 | 47 | # Ensure that there are 3 or 4 channels. 48 | channel, _, _ = image.shape 49 | if channel == 1: 50 | image = repeat(image, "() h w -> c h w", c=3) 51 | assert image.shape[0] in (3, 4) 52 | 53 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 54 | return rearrange(image, "c h w -> h w c").cpu().numpy() 55 | 56 | 57 | def save_image( 58 | image: FloatImage, 59 | path: Union[Path, str], 60 | ) -> None: 61 | """Save an image. Assumed to be in range 0-1.""" 62 | 63 | # Create the parent directory if it doesn't already exist. 64 | path = Path(path) 65 | path.parent.mkdir(exist_ok=True, parents=True) 66 | 67 | # Save the image. 68 | Image.fromarray(prep_image(image)).save(path) 69 | 70 | 71 | def load_image( 72 | path: Union[Path, str], 73 | ) -> Float[Tensor, "3 height width"]: 74 | return tf.ToTensor()(Image.open(path))[:3] 75 | 76 | 77 | def save_video( 78 | images: list[FloatImage], 79 | path: Union[Path, str], 80 | ) -> None: 81 | """Save an image. Assumed to be in range 0-1.""" 82 | 83 | # Create the parent directory if it doesn't already exist. 84 | path = Path(path) 85 | path.parent.mkdir(exist_ok=True, parents=True) 86 | 87 | # Save the image. 
88 | # Image.fromarray(prep_image(image)).save(path) 89 | frames = [] 90 | for image in images: 91 | frames.append(prep_image(image)) 92 | 93 | writer = skvideo.io.FFmpegWriter(path, 94 | outputdict={'-pix_fmt': 'yuv420p', '-crf': '21', 95 | '-vf': f'setpts=1.*PTS'}) 96 | for frame in frames: 97 | writer.writeFrame(frame) 98 | writer.close() 99 | -------------------------------------------------------------------------------- /src/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /src/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | 3 | import torch 4 | from e3nn.o3 import matrix_to_angles, wigner_D 5 | from einops import einsum 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | 10 | def rotate_sh( 11 | sh_coefficients: Float[Tensor, "*#batch n"], 12 | rotations: Float[Tensor, "*#batch 3 3"], 13 | ) -> Float[Tensor, "*batch n"]: 14 | device = sh_coefficients.device 15 | dtype = sh_coefficients.dtype 16 | 17 | *_, n = sh_coefficients.shape 18 | alpha, beta, gamma = matrix_to_angles(rotations) 19 | result = [] 20 | for degree in range(isqrt(n)): 21 | with torch.device(device): 22 | sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) 23 | sh_rotated = einsum( 24 | sh_rotations, 25 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 26 | "... i j, ... j -> ... i", 27 | ) 28 | result.append(sh_rotated) 29 | 30 | return torch.cat(result, dim=-1) 31 | 32 | 33 | if __name__ == "__main__": 34 | from pathlib import Path 35 | 36 | import matplotlib.pyplot as plt 37 | from e3nn.o3 import spherical_harmonics 38 | from matplotlib import cm 39 | from scipy.spatial.transform.rotation import Rotation as R 40 | 41 | device = torch.device("cuda") 42 | 43 | # Generate random spherical harmonics coefficients. 44 | degree = 4 45 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 46 | 47 | def plot_sh(sh_coefficients, path: Path) -> None: 48 | phi = torch.linspace(0, torch.pi, 100, device=device) 49 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 50 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 51 | x = torch.sin(phi) * torch.cos(theta) 52 | y = torch.sin(phi) * torch.sin(theta) 53 | z = torch.cos(phi) 54 | xyz = torch.stack([x, y, z], dim=-1) 55 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 56 | result = einsum(sh, sh_coefficients, "... 
n, n -> ...") 57 | result = (result - result.min()) / (result.max() - result.min()) 58 | 59 | # Set the aspect ratio to 1 so our sphere looks spherical 60 | fig = plt.figure(figsize=plt.figaspect(1.0)) 61 | ax = fig.add_subplot(111, projection="3d") 62 | ax.plot_surface( 63 | x.cpu().numpy(), 64 | y.cpu().numpy(), 65 | z.cpu().numpy(), 66 | rstride=1, 67 | cstride=1, 68 | facecolors=cm.seismic(result.cpu().numpy()), 69 | ) 70 | # Turn off the axis planes 71 | ax.set_axis_off() 72 | path.parent.mkdir(exist_ok=True, parents=True) 73 | plt.savefig(path) 74 | 75 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 76 | rotation = torch.tensor( 77 | R.from_euler("x", angle.item()).as_matrix(), device=device 78 | ) 79 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 80 | 81 | print("Done!") 82 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 
36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from ...dataset import DatasetCfg 2 | from .decoder import Decoder 3 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 4 | 5 | DECODERS = { 6 | "splatting_cuda": DecoderSplattingCUDA, 7 | } 8 | 9 | DecoderCfg = DecoderSplattingCUDACfg 10 | 11 | 12 | def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder: 13 | return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg) 14 | -------------------------------------------------------------------------------- /src/model/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ...dataset import DatasetCfg 9 | from ..types import Gaussians 10 | 11 | DepthRenderingMode = Literal[ 12 | "depth", 13 | "log", 14 | "disparity", 15 | "relative_disparity", 16 | ] 17 | 18 | 19 | @dataclass 20 | class DecoderOutput: 21 | color: Float[Tensor, "batch view 3 height width"] 22 | depth: Float[Tensor, "batch view height width"] | None 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | class Decoder(nn.Module, ABC, Generic[T]): 29 | cfg: T 30 | dataset_cfg: DatasetCfg 31 | 32 | def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None: 33 | super().__init__() 34 | self.cfg = cfg 35 | self.dataset_cfg = dataset_cfg 36 | 37 | @abstractmethod 38 | def forward( 39 | self, 40 | gaussians: Gaussians, 41 | extrinsics: Float[Tensor, "batch view 4 4"], 42 | intrinsics: Float[Tensor, "batch view 3 3"], 43 | near: Float[Tensor, "batch view"], 44 | far: Float[Tensor, "batch view"], 45 | image_shape: tuple[int, int], 46 | depth_mode: DepthRenderingMode | None = None, 47 | ) -> DecoderOutput: 48 | pass 49 | -------------------------------------------------------------------------------- /src/model/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda 12 | from .decoder import Decoder, DecoderOutput 13 | 14 | 15 | @dataclass 16 | class DecoderSplattingCUDACfg: 17 | name: Literal["splatting_cuda"] 18 | 19 | 20 | class 
DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 21 | background_color: Float[Tensor, "3"] 22 | 23 | def __init__( 24 | self, 25 | cfg: DecoderSplattingCUDACfg, 26 | dataset_cfg: DatasetCfg, 27 | ) -> None: 28 | super().__init__(cfg, dataset_cfg) 29 | self.register_buffer( 30 | "background_color", 31 | torch.tensor(dataset_cfg.background_color, dtype=torch.float32), 32 | persistent=False, 33 | ) 34 | 35 | def forward( 36 | self, 37 | gaussians: Gaussians, 38 | extrinsics: Float[Tensor, "batch view 4 4"], 39 | intrinsics: Float[Tensor, "batch view 3 3"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | image_shape: tuple[int, int], 43 | depth_mode: DepthRenderingMode | None = None, 44 | ) -> DecoderOutput: 45 | b, v, _, _ = extrinsics.shape 46 | color = render_cuda( 47 | rearrange(extrinsics, "b v i j -> (b v) i j"), 48 | rearrange(intrinsics, "b v i j -> (b v) i j"), 49 | rearrange(near, "b v -> (b v)"), 50 | rearrange(far, "b v -> (b v)"), 51 | image_shape, 52 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 53 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 54 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 55 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 56 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 57 | ) 58 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 59 | 60 | return DecoderOutput( 61 | color, 62 | None 63 | if depth_mode is None 64 | else self.render_depth( 65 | gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode 66 | ), 67 | ) 68 | 69 | def render_depth( 70 | self, 71 | gaussians: Gaussians, 72 | extrinsics: Float[Tensor, "batch view 4 4"], 73 | intrinsics: Float[Tensor, "batch view 3 3"], 74 | near: Float[Tensor, "batch view"], 75 | far: Float[Tensor, "batch view"], 76 | image_shape: tuple[int, int], 77 | mode: DepthRenderingMode = "depth", 78 | ) -> Float[Tensor, "batch view height width"]: 79 | b, v, _, _ = extrinsics.shape 80 | result = render_depth_cuda( 81 | rearrange(extrinsics, "b v i j -> (b v) i j"), 82 | rearrange(intrinsics, "b v i j -> (b v) i j"), 83 | rearrange(near, "b v -> (b v)"), 84 | rearrange(far, "b v -> (b v)"), 85 | image_shape, 86 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 87 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 88 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 89 | mode=mode, 90 | ) 91 | return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) 92 | -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_costvolume import EncoderCostVolume, EncoderCostVolumeCfg 5 | from .visualization.encoder_visualizer import EncoderVisualizer 6 | from .visualization.encoder_visualizer_costvolume import EncoderVisualizerCostVolume 7 | 8 | ENCODERS = { 9 | "costvolume": (EncoderCostVolume, EncoderVisualizerCostVolume), 10 | } 11 | 12 | EncoderCfg = EncoderCostVolumeCfg 13 | 14 | 15 | def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]: 16 | encoder, visualizer = ENCODERS[cfg.name] 17 | encoder = encoder(cfg) 18 | if visualizer is not None: 19 | visualizer = visualizer(cfg.visualizer, encoder) 20 | return encoder, visualizer 21 | -------------------------------------------------------------------------------- 
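DecoderSplattingCUDA.forward above flattens the batch and view axes before invoking the CUDA rasterizer and folds them back afterwards. Below is a small, self-contained sketch of that einops pattern with toy shapes and no actual rendering; the stand-in tensors are invented purely for illustration.

```python
import torch
from einops import rearrange, repeat

b, v, g = 2, 3, 5                                   # batch, views, Gaussians per scene
extrinsics = torch.eye(4).expand(b, v, 4, 4)        # (2, 3, 4, 4) target camera poses
means = torch.randn(b, g, 3)                        # (2, 5, 3) per-scene Gaussian centres

# Per-view inputs are flattened into a single "(b v)" axis...
flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")   # (6, 4, 4)
# ...while per-scene Gaussians are repeated once per rendered view.
flat_means = repeat(means, "b g xyz -> (b v) g xyz", v=v)         # (6, 5, 3)

# render_cuda(...) would consume the flattened tensors here; use a stand-in output.
color = torch.rand(b * v, 3, 64, 64)
color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v)    # (2, 3, 3, 64, 64)
print(flat_extrinsics.shape, flat_means.shape, color.shape)
```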
/src/model/encoder/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone_multiview import BackboneMultiview 2 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone_multiview.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | from .unimatch.backbone import CNNEncoder 5 | from .multiview_transformer import MultiViewFeatureTransformer 6 | from .unimatch.utils import split_feature, merge_splits 7 | from .unimatch.position import PositionEmbeddingSine 8 | 9 | from ..costvolume.conversions import depth_to_relative_disparity 10 | from ....geometry.epipolar_lines import get_depth 11 | 12 | 13 | def feature_add_position_list(features_list, attn_splits, feature_channels): 14 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 15 | 16 | if attn_splits > 1: # add position in splited window 17 | features_splits = [ 18 | split_feature(x, num_splits=attn_splits) for x in features_list 19 | ] 20 | 21 | position = pos_enc(features_splits[0]) 22 | features_splits = [x + position for x in features_splits] 23 | 24 | out_features_list = [ 25 | merge_splits(x, num_splits=attn_splits) for x in features_splits 26 | ] 27 | 28 | else: 29 | position = pos_enc(features_list[0]) 30 | 31 | out_features_list = [x + position for x in features_list] 32 | 33 | return out_features_list 34 | 35 | 36 | class BackboneMultiview(torch.nn.Module): 37 | """docstring for BackboneMultiview.""" 38 | 39 | def __init__( 40 | self, 41 | feature_channels=128, 42 | num_transformer_layers=6, 43 | ffn_dim_expansion=4, 44 | no_self_attn=False, 45 | no_cross_attn=False, 46 | num_head=1, 47 | no_split_still_shift=False, 48 | no_ffn=False, 49 | global_attn_fast=True, 50 | downscale_factor=8, 51 | use_epipolar_trans=False, 52 | ): 53 | super(BackboneMultiview, self).__init__() 54 | self.feature_channels = feature_channels 55 | # Table 3: w/o cross-view attention 56 | self.no_cross_attn = no_cross_attn 57 | # Table B: w/ Epipolar Transformer 58 | self.use_epipolar_trans = use_epipolar_trans 59 | 60 | # NOTE: '0' here hack to get 1/4 features 61 | self.backbone = CNNEncoder( 62 | output_dim=feature_channels, 63 | num_output_scales=1 if downscale_factor == 8 else 0, 64 | ) 65 | 66 | self.transformer = MultiViewFeatureTransformer( 67 | num_layers=num_transformer_layers, 68 | d_model=feature_channels, 69 | nhead=num_head, 70 | ffn_dim_expansion=ffn_dim_expansion, 71 | no_cross_attn=no_cross_attn, 72 | ) 73 | 74 | def normalize_images(self, images): 75 | '''Normalize image to match the pretrained GMFlow backbone. 
76 | images: (B, N_Views, C, H, W) 77 | ''' 78 | shape = [*[1]*(images.dim() - 3), 3, 1, 1] 79 | mean = torch.tensor([0.485, 0.456, 0.406]).reshape( 80 | *shape).to(images.device) 81 | std = torch.tensor([0.229, 0.224, 0.225]).reshape( 82 | *shape).to(images.device) 83 | 84 | return (images - mean) / std 85 | 86 | def extract_feature(self, images): 87 | b, v = images.shape[:2] 88 | concat = rearrange(images, "b v c h w -> (b v) c h w") 89 | 90 | # list of [nB, C, H, W], resolution from high to low 91 | features = self.backbone(concat) 92 | if not isinstance(features, list): 93 | features = [features] 94 | # reverse: resolution from low to high 95 | features = features[::-1] 96 | 97 | features_list = [[] for _ in range(v)] 98 | for feature in features: 99 | feature = rearrange(feature, "(b v) c h w -> b v c h w", b=b, v=v) 100 | for idx in range(v): 101 | features_list[idx].append(feature[:, idx]) 102 | 103 | return features_list 104 | 105 | def forward( 106 | self, 107 | images, 108 | attn_splits=2, 109 | return_cnn_features=False, 110 | epipolar_kwargs=None, 111 | ): 112 | ''' images: (B, N_Views, C, H, W), range [0, 1] ''' 113 | # resolution low to high 114 | features_list = self.extract_feature( 115 | self.normalize_images(images)) # list of features 116 | 117 | cur_features_list = [x[0] for x in features_list] 118 | 119 | if return_cnn_features: 120 | cnn_features = torch.stack(cur_features_list, dim=1) # [B, V, C, H, W] 121 | 122 | if self.use_epipolar_trans: 123 | # NOTE: Epipolar Transformer, only for ablation used 124 | # we only abalate Epipolar Transformer under 2 views setting 125 | assert ( 126 | epipolar_kwargs is not None 127 | ), "must provide camera params to apply epipolar transformer" 128 | assert len(cur_features_list) == 2, "only use 2 views input for Epipolar Transformer ablation" 129 | feature0, feature1 = cur_features_list 130 | epipolar_sampler = epipolar_kwargs["epipolar_sampler"] 131 | depth_encoding = epipolar_kwargs["depth_encoding"] 132 | 133 | features = torch.stack((feature0, feature1), dim=1) # [B, V, C, H, W] 134 | extrinsics = epipolar_kwargs["extrinsics"] 135 | intrinsics = epipolar_kwargs["intrinsics"] 136 | near = epipolar_kwargs["near"] 137 | far = epipolar_kwargs["far"] 138 | # Get the samples used for epipolar attention. 139 | sampling = epipolar_sampler.forward( 140 | features, extrinsics, intrinsics, near, far 141 | ) 142 | # similar to pixelsplat, use camera distance as position encoding 143 | # Compute positionally encoded depths for the features. 144 | collect = epipolar_sampler.collect 145 | depths = get_depth( 146 | rearrange(sampling.origins, "b v r xyz -> b v () r () xyz"), 147 | rearrange(sampling.directions, "b v r xyz -> b v () r () xyz"), 148 | sampling.xy_sample, 149 | rearrange(collect(extrinsics), "b v ov i j -> b v ov () () i j"), 150 | rearrange(collect(intrinsics), "b v ov i j -> b v ov () () i j"), 151 | ) 152 | 153 | # Clip the depths. This is necessary for edge cases where the context views 154 | # are extremely close together (or possibly oriented the same way). 
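# The sampled depths are clamped to [near, far] below and then mapped to relative
# disparity in [0, 1] via 1 - (1/d - 1/far) / (1/near - 1/far)
# (see depth_to_relative_disparity in src/model/encoder/costvolume/conversions.py),
# so the depth positional encoding receives a bounded, per-view-normalized input.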
155 | depths = depths.maximum(near[..., None, None, None]) 156 | depths = depths.minimum(far[..., None, None, None]) 157 | depths = depth_to_relative_disparity( 158 | depths, 159 | rearrange(near, "b v -> b v () () ()"), 160 | rearrange(far, "b v -> b v () () ()"), 161 | ) 162 | depths = depth_encoding(depths[..., None]) 163 | target = sampling.features + depths 164 | source = features 165 | 166 | features = self.transformer((source, target), attn_type="epipolar") 167 | else: 168 | # add position to features 169 | cur_features_list = feature_add_position_list( 170 | cur_features_list, attn_splits, self.feature_channels) 171 | 172 | # Transformer 173 | cur_features_list = self.transformer( 174 | cur_features_list, attn_num_splits=attn_splits) 175 | 176 | features = torch.stack(cur_features_list, dim=1) # [B, V, C, H, W] 177 | 178 | if return_cnn_features: 179 | out_lists = [features, cnn_features] 180 | else: 181 | out_lists = [features, None] 182 | 183 | return out_lists 184 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/mvsplat/01f9a28edb5eb68416e7e63b01f8d90c3bdfbf01/src/model/encoder/backbone/unimatch/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | **kwargs, 44 | ): 45 | super(CNNEncoder, self).__init__() 46 | self.num_branch = num_output_scales 47 | 48 | feature_dims = [64, 96, 128] 49 | 50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 51 | self.norm1 = norm_layer(feature_dims[0]) 52 | self.relu1 = nn.ReLU(inplace=True) 53 | 54 | self.in_planes = feature_dims[0] 55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 57 | 58 | # highest resolution 1/4 or 1/8 59 | stride = 2 if num_output_scales == 1 else 1 60 | self.layer3 = 
self._make_layer(feature_dims[2], stride=stride, 61 | norm_layer=norm_layer, 62 | ) # 1/4 or 1/8 63 | 64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 65 | 66 | if self.num_branch > 1: 67 | if self.num_branch == 4: 68 | strides = (1, 2, 4, 8) 69 | elif self.num_branch == 3: 70 | strides = (1, 2, 4) 71 | elif self.num_branch == 2: 72 | strides = (1, 2) 73 | else: 74 | raise ValueError 75 | 76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 77 | kernel_size=3, 78 | strides=strides, 79 | paddings=1, 80 | num_branch=self.num_branch, 81 | ) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 95 | 96 | layers = (layer1, layer2) 97 | 98 | self.in_planes = dim 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.norm1(x) 104 | x = self.relu1(x) 105 | 106 | x = self.layer1(x) # 1/2 107 | x = self.layer2(x) # 1/4 108 | x = self.layer3(x) # 1/8 or 1/4 109 | 110 | x = self.conv2(x) 111 | 112 | if self.num_branch > 1: 113 | out = self.trident_conv([x] * self.num_branch) # high to low res 114 | else: 115 | out = [x] 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None): 6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 7 | 8 | stacks = [x, y] 9 | 10 | if homogeneous: 11 | ones = torch.ones_like(x) # [H, W] 12 | stacks.append(ones) 13 | 14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 25 | assert device is not None 26 | 27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 28 | torch.linspace(h_min, h_max, len_h, device=device)], 29 | ) 30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 31 | 32 | return grid 33 | 34 | 35 | def normalize_coords(coords, h, w): 36 | # coords: [B, H, W, 2] 37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 38 | return (coords - c) / c # [-1, 1] 39 | 40 | 41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 42 | # img: [B, C, H, W] 43 | # sample_coords: [B, 2, H, W] in image scale 44 | if sample_coords.size(1) != 2: # [B, H, W, 2] 45 | sample_coords = sample_coords.permute(0, 3, 1, 2) 46 | 47 | b, _, h, w = sample_coords.shape 48 | 49 | # Normalize to [-1, 1] 50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 52 | 53 | grid = 
torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 54 | 55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 56 | 57 | if return_mask: 58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 59 | 60 | return img, mask 61 | 62 | return img 63 | 64 | 65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 66 | b, c, h, w = feature.size() 67 | assert flow.size(1) == 2 68 | 69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 70 | 71 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 72 | return_mask=mask) 73 | 74 | 75 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 76 | alpha=0.01, 77 | beta=0.5 78 | ): 79 | # fwd_flow, bwd_flow: [B, 2, H, W] 80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 84 | 85 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 87 | 88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 90 | 91 | threshold = alpha * flow_mag + beta 92 | 93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 94 | bwd_occ = (diff_bwd > threshold).float() 95 | 96 | return fwd_occ, bwd_occ 97 | 98 | 99 | def back_project(depth, intrinsics): 100 | # Back project 2D pixel coords to 3D points 101 | # depth: [B, H, W] 102 | # intrinsics: [B, 3, 3] 103 | b, h, w = depth.shape 104 | grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] 105 | 106 | intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] 107 | 108 | points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] 109 | 110 | return points 111 | 112 | 113 | def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): 114 | # Transform 3D points from reference camera to target camera 115 | # points_ref: [B, 3, H, W] 116 | # extrinsics_ref: [B, 4, 4] 117 | # extrinsics_tgt: [B, 4, 4] 118 | # extrinsics_rel: [B, 4, 4], relative pose transform 119 | b, _, h, w = points_ref.shape 120 | 121 | if extrinsics_rel is None: 122 | extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] 123 | 124 | points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], 125 | points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] 126 | 127 | points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] 128 | 129 | return points_tgt 130 | 131 | 132 | def reproject(points_tgt, intrinsics, return_mask=False): 133 | # reproject to target view 134 | # points_tgt: [B, 3, H, W] 135 | # intrinsics: [B, 3, 3] 136 | 137 | b, _, h, w = points_tgt.shape 138 | 139 | proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] 140 | 141 | X = proj_points[:, 0] 142 | Y = proj_points[:, 1] 143 | Z = proj_points[:, 2].clamp(min=1e-3) 144 | 145 | pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale 146 | 147 | if return_mask: 148 | # valid mask in pixel space 149 | mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( 150 | pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] 151 | 152 | return pixel_coords, 
mask 153 | 154 | return pixel_coords 155 | 156 | 157 | def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, 158 | return_mask=False): 159 | # Compute reprojection sample coords 160 | points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] 161 | points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) 162 | 163 | if return_mask: 164 | reproj_coords, mask = reproject(points_tgt, intrinsics, 165 | return_mask=return_mask) # [B, 2, H, W] in image scale 166 | 167 | return reproj_coords, mask 168 | 169 | reproj_coords = reproject(points_tgt, intrinsics, 170 | return_mask=return_mask) # [B, 2, H, W] in image scale 171 | 172 | return reproj_coords 173 | 174 | 175 | def compute_flow_with_depth_pose(depth_ref, intrinsics, 176 | extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, 177 | return_mask=False): 178 | b, h, w = depth_ref.shape 179 | coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] 180 | 181 | if return_mask: 182 | reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, 183 | extrinsics_rel=extrinsics_rel, 184 | return_mask=return_mask) # [B, 2, H, W] 185 | rigid_flow = reproj_coords - coords_init 186 | 187 | return rigid_flow, mask 188 | 189 | reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, 190 | extrinsics_rel=extrinsics_rel, 191 | return_mask=return_mask) # [B, 2, H, W] 192 | 193 | rigid_flow = reproj_coords - coords_init 194 | 195 | return rigid_flow 196 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 
13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/reg_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, 8 | out_dim=2, 9 | ): 10 | super(FlowHead, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 13 | self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | out = self.conv2(self.relu(self.conv1(x))) 18 | 19 | return out 20 | 21 | 22 | class SepConvGRU(nn.Module): 23 | def __init__(self, hidden_dim=128, input_dim=192 + 128, 24 | kernel_size=5, 25 | ): 26 | padding = (kernel_size - 1) // 2 27 | 28 | super(SepConvGRU, self).__init__() 29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 32 | 33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 36 | 37 | def forward(self, h, x): 38 | # horizontal 39 | hx = torch.cat([h, x], dim=1) 40 | z = torch.sigmoid(self.convz1(hx)) 41 | r = torch.sigmoid(self.convr1(hx)) 42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 43 | h = (1 - z) * h + z * q 44 | 45 | # vertical 46 | hx = torch.cat([h, x], dim=1) 47 | z = torch.sigmoid(self.convz2(hx)) 48 | r = torch.sigmoid(self.convr2(hx)) 49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 50 | h = (1 - z) * h + z * q 51 | 52 | 
return h 53 | 54 | 55 | class BasicMotionEncoder(nn.Module): 56 | def __init__(self, corr_channels=324, 57 | flow_channels=2, 58 | ): 59 | super(BasicMotionEncoder, self).__init__() 60 | 61 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) 62 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 63 | self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) 64 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 65 | self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) 66 | 67 | def forward(self, flow, corr): 68 | cor = F.relu(self.convc1(corr)) 69 | cor = F.relu(self.convc2(cor)) 70 | flo = F.relu(self.convf1(flow)) 71 | flo = F.relu(self.convf2(flo)) 72 | 73 | cor_flo = torch.cat([cor, flo], dim=1) 74 | out = F.relu(self.conv(cor_flo)) 75 | return torch.cat([out, flow], dim=1) 76 | 77 | 78 | class BasicUpdateBlock(nn.Module): 79 | def __init__(self, corr_channels=324, 80 | hidden_dim=128, 81 | context_dim=128, 82 | downsample_factor=8, 83 | flow_dim=2, 84 | bilinear_up=False, 85 | ): 86 | super(BasicUpdateBlock, self).__init__() 87 | 88 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels, 89 | flow_channels=flow_dim, 90 | ) 91 | 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) 93 | 94 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256, 95 | out_dim=flow_dim, 96 | ) 97 | 98 | if bilinear_up: 99 | self.mask = None 100 | else: 101 | self.mask = nn.Sequential( 102 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | 109 | inp = torch.cat([inp, motion_features], dim=1) 110 | 111 | net = self.gru(net, inp) 112 | delta_flow = self.flow_head(net) 113 | 114 | if self.mask is not None: 115 | mask = self.mask(net) 116 | else: 117 | mask = None 118 | 119 | return net, mask, delta_flow 120 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/unimatch/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 91 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussian_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import einsum, rearrange 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ....geometry.projection import get_world_rays 9 | from ....misc.sh_rotation import rotate_sh 10 | from .gaussians import build_covariance 11 | 12 | 13 | @dataclass 14 | class Gaussians: 15 | means: Float[Tensor, "*batch 3"] 16 | covariances: Float[Tensor, "*batch 3 3"] 17 | scales: Float[Tensor, "*batch 3"] 18 | rotations: Float[Tensor, "*batch 4"] 
19 | harmonics: Float[Tensor, "*batch 3 _"] 20 | opacities: Float[Tensor, " *batch"] 21 | 22 | 23 | @dataclass 24 | class GaussianAdapterCfg: 25 | gaussian_scale_min: float 26 | gaussian_scale_max: float 27 | sh_degree: int 28 | 29 | 30 | class GaussianAdapter(nn.Module): 31 | cfg: GaussianAdapterCfg 32 | 33 | def __init__(self, cfg: GaussianAdapterCfg): 34 | super().__init__() 35 | self.cfg = cfg 36 | 37 | # Create a mask for the spherical harmonics coefficients. This ensures that at 38 | # initialization, the coefficients are biased towards having a large DC 39 | # component and small view-dependent components. 40 | self.register_buffer( 41 | "sh_mask", 42 | torch.ones((self.d_sh,), dtype=torch.float32), 43 | persistent=False, 44 | ) 45 | for degree in range(1, self.cfg.sh_degree + 1): 46 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 47 | 48 | def forward( 49 | self, 50 | extrinsics: Float[Tensor, "*#batch 4 4"], 51 | intrinsics: Float[Tensor, "*#batch 3 3"], 52 | coordinates: Float[Tensor, "*#batch 2"], 53 | depths: Float[Tensor, "*#batch"], 54 | opacities: Float[Tensor, "*#batch"], 55 | raw_gaussians: Float[Tensor, "*#batch _"], 56 | image_shape: tuple[int, int], 57 | eps: float = 1e-8, 58 | ) -> Gaussians: 59 | device = extrinsics.device 60 | scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) 61 | 62 | # Map scale features to valid scale range. 63 | scale_min = self.cfg.gaussian_scale_min 64 | scale_max = self.cfg.gaussian_scale_max 65 | scales = scale_min + (scale_max - scale_min) * scales.sigmoid() 66 | h, w = image_shape 67 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device) 68 | multiplier = self.get_scale_multiplier(intrinsics, pixel_size) 69 | scales = scales * depths[..., None] * multiplier[..., None] 70 | 71 | # Normalize the quaternion features to yield a valid quaternion. 72 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) 73 | 74 | # Apply sigmoid to get valid colors. 75 | sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) 76 | sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask 77 | 78 | # Create world-space covariance matrices. 79 | covariances = build_covariance(scales, rotations) 80 | c2w_rotations = extrinsics[..., :3, :3] 81 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) 82 | 83 | # Compute Gaussian means. 84 | origins, directions = get_world_rays(coordinates, extrinsics, intrinsics) 85 | means = origins + directions * depths[..., None] 86 | 87 | return Gaussians( 88 | means=means, 89 | covariances=covariances, 90 | harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]), 91 | opacities=opacities, 92 | # NOTE: These aren't yet rotated into world space, but they're only used for 93 | # exporting Gaussians to ply files. This needs to be fixed... 94 | scales=scales, 95 | rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 96 | ) 97 | 98 | def get_scale_multiplier( 99 | self, 100 | intrinsics: Float[Tensor, "*#batch 3 3"], 101 | pixel_size: Float[Tensor, "*#batch 2"], 102 | multiplier: float = 0.1, 103 | ) -> Float[Tensor, " *batch"]: 104 | xy_multipliers = multiplier * einsum( 105 | intrinsics[..., :2, :2].inverse(), 106 | pixel_size, 107 | "... i j, j -> ... 
i", 108 | ) 109 | return xy_multipliers.sum(dim=-1) 110 | 111 | @property 112 | def d_sh(self) -> int: 113 | return (self.cfg.sh_degree + 1) ** 2 114 | 115 | @property 116 | def d_in(self) -> int: 117 | return 7 + 3 * self.d_sh 118 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 8 | def quaternion_to_matrix( 9 | quaternions: Float[Tensor, "*batch 4"], 10 | eps: float = 1e-8, 11 | ) -> Float[Tensor, "*batch 3 3"]: 12 | # Order changed to match scipy format! 13 | i, j, k, r = torch.unbind(quaternions, dim=-1) 14 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 15 | 16 | o = torch.stack( 17 | ( 18 | 1 - two_s * (j * j + k * k), 19 | two_s * (i * j - k * r), 20 | two_s * (i * k + j * r), 21 | two_s * (i * j + k * r), 22 | 1 - two_s * (i * i + k * k), 23 | two_s * (j * k - i * r), 24 | two_s * (i * k - j * r), 25 | two_s * (j * k + i * r), 26 | 1 - two_s * (i * i + j * j), 27 | ), 28 | -1, 29 | ) 30 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 31 | 32 | 33 | def build_covariance( 34 | scale: Float[Tensor, "*#batch 3"], 35 | rotation_xyzw: Float[Tensor, "*#batch 4"], 36 | ) -> Float[Tensor, "*batch 3 3"]: 37 | scale = scale.diag_embed() 38 | rotation = quaternion_to_matrix(rotation_xyzw) 39 | return ( 40 | rotation 41 | @ scale 42 | @ rearrange(scale, "... i j -> ... j i") 43 | @ rearrange(rotation, "... i j -> ... j i") 44 | ) 45 | -------------------------------------------------------------------------------- /src/model/encoder/common/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor, nn 3 | 4 | from ....misc.discrete_probability_distribution import ( 5 | gather_discrete_topk, 6 | sample_discrete_distribution, 7 | ) 8 | 9 | 10 | class Sampler(nn.Module): 11 | def forward( 12 | self, 13 | probabilities: Float[Tensor, "*batch bucket"], 14 | num_samples: int, 15 | deterministic: bool, 16 | ) -> tuple[ 17 | Int64[Tensor, "*batch 1"], # index 18 | Float[Tensor, "*batch 1"], # probability density 19 | ]: 20 | return ( 21 | gather_discrete_topk(probabilities, num_samples) 22 | if deterministic 23 | else sample_discrete_distribution(probabilities, num_samples) 24 | ) 25 | 26 | def gather( 27 | self, 28 | index: Int64[Tensor, "*batch sample"], 29 | target: Shaped[Tensor, "..."], # *batch bucket *shape 30 | ) -> Shaped[Tensor, "..."]: # *batch sample *shape 31 | """Gather from the target according to the specified index. Handle the 32 | broadcasting needed for the gather to work. See the comments for the actual 33 | expected input/output shapes since jaxtyping doesn't support multiple variadic 34 | lengths in annotations. 
35 | """ 36 | bucket_dim = index.ndim - 1 37 | while len(index.shape) < len(target.shape): 38 | index = index[..., None] 39 | broadcasted_index_shape = list(target.shape) 40 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 41 | index = index.broadcast_to(broadcasted_index_shape) 42 | return target.gather(dim=bucket_dim, index=index) 43 | -------------------------------------------------------------------------------- /src/model/encoder/costvolume/conversions.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float 2 | from torch import Tensor 3 | 4 | 5 | def relative_disparity_to_depth( 6 | relative_disparity: Float[Tensor, "*#batch"], 7 | near: Float[Tensor, "*#batch"], 8 | far: Float[Tensor, "*#batch"], 9 | eps: float = 1e-10, 10 | ) -> Float[Tensor, " *batch"]: 11 | """Convert relative disparity, where 0 is near and 1 is far, to depth.""" 12 | disp_near = 1 / (near + eps) 13 | disp_far = 1 / (far + eps) 14 | return 1 / ((1 - relative_disparity) * (disp_near - disp_far) + disp_far + eps) 15 | 16 | 17 | def depth_to_relative_disparity( 18 | depth: Float[Tensor, "*#batch"], 19 | near: Float[Tensor, "*#batch"], 20 | far: Float[Tensor, "*#batch"], 21 | eps: float = 1e-10, 22 | ) -> Float[Tensor, " *batch"]: 23 | """Convert depth to relative disparity, where 0 is near and 1 is far""" 24 | disp_near = 1 / (near + eps) 25 | disp_far = 1 / (far + eps) 26 | disp = 1 / (depth + eps) 27 | return 1 - (disp - disp_far) / (disp_near - disp_far + eps) 28 | -------------------------------------------------------------------------------- /src/model/encoder/costvolume/ldm_unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/mvsplat/01f9a28edb5eb68416e7e63b01f8d90c3bdfbf01/src/model/encoder/costvolume/ldm_unet/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import nn 5 | 6 | from ...dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | deterministic: bool, 24 | ) -> Gaussians: 25 | pass 26 | 27 | def get_data_shim(self) -> DataShim: 28 | """The default shim doesn't modify the batch.""" 29 | return lambda x: x 30 | -------------------------------------------------------------------------------- /src/model/encoder/epipolar/epipolar_sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | from jaxtyping import Bool, Float, Shaped 7 | from torch import Tensor, nn 8 | 9 | from ....geometry.epipolar_lines import project_rays 10 | from ....geometry.projection import get_world_rays, sample_image_grid 11 | from ....misc.heterogeneous_pairings import ( 12 | Index, 13 | generate_heterogeneous_index, 14 | generate_heterogeneous_index_transpose, 15 | ) 16 | 17 | 18 | @dataclass 19 | class EpipolarSampling: 20 | features: Float[Tensor, "batch 
view other_view ray sample channel"] 21 | valid: Bool[Tensor, "batch view other_view ray"] 22 | xy_ray: Float[Tensor, "batch view ray 2"] 23 | xy_sample: Float[Tensor, "batch view other_view ray sample 2"] 24 | xy_sample_near: Float[Tensor, "batch view other_view ray sample 2"] 25 | xy_sample_far: Float[Tensor, "batch view other_view ray sample 2"] 26 | origins: Float[Tensor, "batch view ray 3"] 27 | directions: Float[Tensor, "batch view ray 3"] 28 | 29 | 30 | class EpipolarSampler(nn.Module): 31 | num_samples: int 32 | index_v: Index 33 | transpose_v: Index 34 | transpose_ov: Index 35 | 36 | def __init__( 37 | self, 38 | num_views: int, 39 | num_samples: int, 40 | ) -> None: 41 | super().__init__() 42 | self.num_samples = num_samples 43 | 44 | # Generate indices needed to sample only other views. 45 | _, index_v = generate_heterogeneous_index(num_views) 46 | t_v, t_ov = generate_heterogeneous_index_transpose(num_views) 47 | self.register_buffer("index_v", index_v, persistent=False) 48 | self.register_buffer("transpose_v", t_v, persistent=False) 49 | self.register_buffer("transpose_ov", t_ov, persistent=False) 50 | 51 | def forward( 52 | self, 53 | images: Float[Tensor, "batch view channel height width"], 54 | extrinsics: Float[Tensor, "batch view 4 4"], 55 | intrinsics: Float[Tensor, "batch view 3 3"], 56 | near: Float[Tensor, "batch view"], 57 | far: Float[Tensor, "batch view"], 58 | ) -> EpipolarSampling: 59 | device = images.device 60 | b, v, _, _, _ = images.shape 61 | 62 | # Generate the rays that are projected onto other views. 63 | xy_ray, origins, directions = self.generate_image_rays( 64 | images, extrinsics, intrinsics 65 | ) 66 | 67 | # Select the camera extrinsics and intrinsics to project onto. For each context 68 | # view, this means all other context views in the batch. 69 | projection = project_rays( 70 | rearrange(origins, "b v r xyz -> b v () r xyz"), 71 | rearrange(directions, "b v r xyz -> b v () r xyz"), 72 | rearrange(self.collect(extrinsics), "b v ov i j -> b v ov () i j"), 73 | rearrange(self.collect(intrinsics), "b v ov i j -> b v ov () i j"), 74 | rearrange(near, "b v -> b v () ()"), 75 | rearrange(far, "b v -> b v () ()"), 76 | ) 77 | 78 | 79 | # Generate sample points. 80 | s = self.num_samples 81 | sample_depth = (torch.arange(s, device=device) + 0.5) / s 82 | sample_depth = rearrange(sample_depth, "s -> s ()") 83 | xy_min = projection["xy_min"].nan_to_num(posinf=0, neginf=0) 84 | xy_min = xy_min * projection["overlaps_image"][..., None] 85 | xy_min = rearrange(xy_min, "b v ov r xy -> b v ov r () xy") 86 | xy_max = projection["xy_max"].nan_to_num(posinf=0, neginf=0) 87 | xy_max = xy_max * projection["overlaps_image"][..., None] 88 | xy_max = rearrange(xy_max, "b v ov r xy -> b v ov r () xy") 89 | xy_sample = xy_min + sample_depth * (xy_max - xy_min) 90 | 91 | # The samples' shape is (batch, view, other_view, ...). However, before the 92 | # transpose, the view dimension refers to the view from which the ray is cast, 93 | # not the view from which samples are drawn. Thus, we need to transpose the 94 | # samples so that the view dimension refers to the view from which samples are 95 | # drawn. If the diagonal weren't removed for efficiency, this would be a literal 96 | # transpose. In our case, it's as if the diagonal were re-added, the transpose 97 | # were taken, and the diagonal were then removed again. 
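# Concretely, the (view, other_view) axes enumerate every ordered pair of distinct
# context views, so e.g. (0, 2) and (2, 0) occupy separate slots. transpose() swaps
# the two roles: the slot "rays cast from view 0, sampled in view 2" moves to the
# slot indexed by view 2, which is what lets grid_sample below read pixels from the
# view that actually provides the samples.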
98 | samples = self.transpose(xy_sample) 99 | samples = F.grid_sample( 100 | rearrange(images, "b v c h w -> (b v) c h w"), 101 | rearrange(2 * samples - 1, "b v ov r s xy -> (b v) (ov r s) () xy"), 102 | mode="bilinear", 103 | padding_mode="zeros", 104 | align_corners=False, 105 | ) 106 | samples = rearrange( 107 | samples, "(b v) c (ov r s) () -> b v ov r s c", b=b, v=v, ov=v - 1, s=s 108 | ) 109 | samples = self.transpose(samples) 110 | 111 | # Zero out invalid samples. 112 | samples = samples * projection["overlaps_image"][..., None, None] 113 | 114 | half_span = 0.5 / s 115 | return EpipolarSampling( 116 | features=samples, 117 | valid=projection["overlaps_image"], 118 | xy_ray=xy_ray, 119 | xy_sample=xy_sample, 120 | xy_sample_near=xy_min + (sample_depth - half_span) * (xy_max - xy_min), 121 | xy_sample_far=xy_min + (sample_depth + half_span) * (xy_max - xy_min), 122 | origins=origins, 123 | directions=directions, 124 | ) 125 | 126 | def generate_image_rays( 127 | self, 128 | images: Float[Tensor, "batch view channel height width"], 129 | extrinsics: Float[Tensor, "batch view 4 4"], 130 | intrinsics: Float[Tensor, "batch view 3 3"], 131 | ) -> tuple[ 132 | Float[Tensor, "batch view ray 2"], # xy 133 | Float[Tensor, "batch view ray 3"], # origins 134 | Float[Tensor, "batch view ray 3"], # directions 135 | ]: 136 | """Generate the rays along which Gaussians are defined. For now, these rays are 137 | simply arranged in a grid. 138 | """ 139 | b, v, _, h, w = images.shape 140 | xy, _ = sample_image_grid((h, w), device=images.device) 141 | origins, directions = get_world_rays( 142 | rearrange(xy, "h w xy -> (h w) xy"), 143 | rearrange(extrinsics, "b v i j -> b v () i j"), 144 | rearrange(intrinsics, "b v i j -> b v () i j"), 145 | ) 146 | return repeat(xy, "h w xy -> b v (h w) xy", b=b, v=v), origins, directions 147 | 148 | def transpose( 149 | self, 150 | x: Shaped[Tensor, "batch view other_view *rest"], 151 | ) -> Shaped[Tensor, "batch view other_view *rest"]: 152 | b, v, ov, *_ = x.shape 153 | t_b = torch.arange(b, device=x.device) 154 | t_b = repeat(t_b, "b -> b v ov", v=v, ov=ov) 155 | t_v = repeat(self.transpose_v, "v ov -> b v ov", b=b) 156 | t_ov = repeat(self.transpose_ov, "v ov -> b v ov", b=b) 157 | return x[t_b, t_v, t_ov] 158 | 159 | def collect( 160 | self, 161 | target: Shaped[Tensor, "batch view ..."], 162 | ) -> Shaped[Tensor, "batch view view-1 ..."]: 163 | b, v, *_ = target.shape 164 | index_b = torch.arange(b, device=target.device) 165 | index_b = repeat(index_b, "b -> b v ov", v=v, ov=v - 1) 166 | index_v = repeat(self.index_v, "v ov -> b v ov", b=b) 167 | return target[index_b, index_v] 168 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | T_cfg = TypeVar("T_cfg") 8 | T_encoder = TypeVar("T_encoder") 9 | 10 | 11 | class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]): 12 | cfg: T_cfg 13 | encoder: T_encoder 14 | 15 | def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None: 16 | self.cfg = cfg 17 | self.encoder = encoder 18 | 19 | @abstractmethod 20 | def visualize( 21 | self, 22 | context: dict, 23 | global_step: int, 24 | ) -> dict[str, Float[Tensor, "3 _ _"]]: 25 | pass 26 | 
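A new visualizer only has to subclass this interface and return a mapping from name to a (3, H, W) image tensor. The following is a minimal hypothetical sketch, written as if it lived next to the interface in src/model/encoder/visualization/; the config class and the placeholder output are invented for illustration and are not part of the repository.

```python
from dataclasses import dataclass

import torch
from jaxtyping import Float
from torch import Tensor

from .encoder_visualizer import EncoderVisualizer


@dataclass
class EncoderVisualizerDummyCfg:  # hypothetical config, not part of the repo
    resolution: int


class EncoderVisualizerDummy(EncoderVisualizer[EncoderVisualizerDummyCfg, object]):
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        # Return a single solid-gray placeholder image keyed by name; a real
        # visualizer would render depth maps, Gaussians, etc. from self.encoder.
        r = self.cfg.resolution
        return {"placeholder": torch.full((3, r, r), 0.5)}
```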
-------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer_costvolume_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # This is in a separate file to avoid circular imports. 4 | 5 | 6 | @dataclass 7 | class EncoderVisualizerCostVolumeCfg: 8 | num_samples: int 9 | min_resolution: int 10 | export_ply: bool 11 | -------------------------------------------------------------------------------- /src/model/encodings/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum, rearrange, repeat 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """For the sake of simplicity, this encodes values in the range [0, 1].""" 10 | 11 | frequencies: Float[Tensor, "frequency phase"] 12 | phases: Float[Tensor, "frequency phase"] 13 | 14 | def __init__(self, num_octaves: int): 15 | super().__init__() 16 | octaves = torch.arange(num_octaves).float() 17 | 18 | # The lowest frequency has a period of 1. 19 | frequencies = 2 * torch.pi * 2**octaves 20 | frequencies = repeat(frequencies, "f -> f p", p=2) 21 | self.register_buffer("frequencies", frequencies, persistent=False) 22 | 23 | # Choose the phases to match sine and cosine. 24 | phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32) 25 | phases = repeat(phases, "p -> f p", f=num_octaves) 26 | self.register_buffer("phases", phases, persistent=False) 27 | 28 | def forward( 29 | self, 30 | samples: Float[Tensor, "*batch dim"], 31 | ) -> Float[Tensor, "*batch embedded_dim"]: 32 | samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p") 33 | return rearrange(torch.sin(samples + self.phases), "... d f p -> ... (d f p)") 34 | 35 | def d_out(self, dimensionality: int): 36 | return self.frequencies.numel() * dimensionality 37 | -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | # Shift the scene so that the median Gaussian is at the origin. 36 | means = means - means.median(dim=0).values 37 | 38 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 
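# Both the Gaussian centres and the per-axis scales are divided by the same factor
# below, so each splat keeps its size relative to the scene; only the global units
# change.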
39 | scale_factor = means.abs().quantile(0.95, dim=0).max() 40 | means = means / scale_factor 41 | scales = scales / scale_factor 42 | 43 | # Define a rotation that makes +Z be the world up vector. 44 | rotation = [ 45 | [0, 0, 1], 46 | [-1, 0, 0], 47 | [0, -1, 0], 48 | ] 49 | rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) 50 | 51 | # The Polycam viewer seems to start at a 45 degree angle. Since we want to be 52 | # looking directly at the object, we compose a 45 degree rotation onto the above 53 | # rotation. 54 | adjustment = torch.tensor( 55 | R.from_rotvec([0, 0, -45], True).as_matrix(), 56 | dtype=torch.float32, 57 | device=means.device, 58 | ) 59 | rotation = adjustment @ rotation 60 | 61 | # We also want to see the scene in camera space (as the default view). We therefore 62 | # compose the w2c rotation onto the above rotation. 63 | rotation = rotation @ extrinsics[:3, :3].inverse() 64 | 65 | # Apply the rotation to the means (Gaussian positions). 66 | means = einsum(rotation, means, "i j, ... j -> ... i") 67 | 68 | # Apply the rotation to the Gaussian rotations. 69 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 70 | rotations = rotation.detach().cpu().numpy() @ rotations 71 | rotations = R.from_matrix(rotations).as_quat() 72 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 73 | rotations = np.stack((w, x, y, z), axis=-1) 74 | 75 | # Since our axes are swizzled for the spherical harmonics, we only export the DC 76 | # band. 77 | harmonics_view_invariant = harmonics[..., 0] 78 | 79 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] 80 | elements = np.empty(means.shape[0], dtype=dtype_full) 81 | attributes = ( 82 | means.detach().cpu().numpy(), 83 | torch.zeros_like(means).detach().cpu().numpy(), 84 | harmonics_view_invariant.detach().cpu().contiguous().numpy(), 85 | opacities[..., None].detach().cpu().numpy(), 86 | scales.log().detach().cpu().numpy(), 87 | rotations, 88 | ) 89 | attributes = np.concatenate(attributes, axis=1) 90 | elements[:] = list(map(tuple, attributes)) 91 | path.parent.mkdir(exist_ok=True, parents=True) 92 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 93 | -------------------------------------------------------------------------------- /src/model/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] 13 | -------------------------------------------------------------------------------- /src/scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import hydra 6 | import torch 7 | from jaxtyping import install_import_hook 8 | from omegaconf import DictConfig 9 | from pytorch_lightning import Trainer 10 | 11 | # Configure beartype and jaxtyping. 
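# install_import_hook instruments every module imported under "src" while this block
# is active, applying beartype as the runtime type checker so that the jaxtyping
# shape annotations (e.g. Float[Tensor, "batch view 4 4"]) are actually validated
# when the decorated functions are called.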
12 | with install_import_hook( 13 | ("src",), 14 | ("beartype", "beartype"), 15 | ): 16 | from src.config import load_typed_config 17 | from src.dataset.data_module import DataLoaderCfg, DataModule, DatasetCfg 18 | from src.evaluation.evaluation_cfg import EvaluationCfg 19 | from src.evaluation.metric_computer import MetricComputer 20 | from src.global_cfg import set_cfg 21 | 22 | 23 | @dataclass 24 | class RootCfg: 25 | evaluation: EvaluationCfg 26 | dataset: DatasetCfg 27 | data_loader: DataLoaderCfg 28 | seed: int 29 | output_metrics_path: Path 30 | 31 | 32 | @hydra.main( 33 | version_base=None, 34 | config_path="../../config", 35 | config_name="compute_metrics", 36 | ) 37 | def evaluate(cfg_dict: DictConfig): 38 | cfg = load_typed_config(cfg_dict, RootCfg) 39 | set_cfg(cfg_dict) 40 | torch.manual_seed(cfg.seed) 41 | trainer = Trainer(max_epochs=-1, accelerator="gpu") 42 | computer = MetricComputer(cfg.evaluation) 43 | data_module = DataModule(cfg.dataset, cfg.data_loader) 44 | metrics = trainer.test(computer, datamodule=data_module) 45 | cfg.output_metrics_path.parent.mkdir(exist_ok=True, parents=True) 46 | with cfg.output_metrics_path.open("w") as f: 47 | json.dump(metrics[0], f) 48 | 49 | 50 | if __name__ == "__main__": 51 | evaluate() 52 | -------------------------------------------------------------------------------- /src/scripts/dump_launch_configs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | from colorama import Fore 5 | 6 | if __name__ == "__main__": 7 | # Go to the repo directory. 8 | x = Path.cwd() 9 | while not (x / ".git").exists(): 10 | x = x.parent 11 | 12 | # Hackily load JSON with comments and trailing commas. 13 | with (x / ".vscode/launch.json").open("r") as f: 14 | launch_with_comments = f.readlines() 15 | launch = [ 16 | line for line in launch_with_comments if not line.strip().startswith("//") 17 | ] 18 | launch = "".join(launch) 19 | launch = yaml.safe_load(launch) 20 | 21 | for cfg in launch["configurations"]: 22 | print(f"{Fore.CYAN}{cfg['name']}{Fore.RESET}") 23 | 24 | arg_str = " ".join(cfg.get("args", [])) 25 | if "env" in cfg: 26 | env_str = " ".join([f"{key}={value}" for key, value in cfg["env"].items()]) 27 | else: 28 | env_str = "" 29 | 30 | command = f"{env_str} python3 -m {cfg['module']} {arg_str}".strip() 31 | print(f"{command}\n") 32 | -------------------------------------------------------------------------------- /src/scripts/generate_dtu_evaluation_index.py: -------------------------------------------------------------------------------- 1 | ''' Build upon: https://github.com/autonomousvision/murf/blob/main/datasets/dtu.py 2 | ''' 3 | 4 | import torch 5 | import os 6 | from glob import glob 7 | import argparse 8 | from einops import rearrange, repeat 9 | from dataclasses import asdict, dataclass 10 | import json 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | 15 | @dataclass 16 | class IndexEntry: 17 | context: tuple[int, ...] 18 | target: tuple[int, ...] 
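Each key in the dumped evaluation index maps a scene (plus a sequence suffix) to an `IndexEntry` serialized with `dataclasses.asdict`. A minimal sketch of the resulting JSON shape follows; the scene key and view indices are invented for illustration.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class IndexEntry:  # same fields as the IndexEntry defined above
    context: tuple[int, ...]
    target: tuple[int, ...]


entry = IndexEntry(context=(25, 34), target=(32,))
print(json.dumps({"scan110_00": asdict(entry)}))
# {"scan110_00": {"context": [25, 34], "target": [32]}}
```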
19 | 20 | 21 | def sorted_test_src_views_fixed(cam2worlds_dict, test_views, train_views): 22 | # use fixed src views for testing, instead of using different src views for different test views 23 | cam_pos_trains = np.stack([cam2worlds_dict[x] for x in train_views])[ 24 | :, :3, 3 25 | ] # [V, 3], V src views 26 | cam_pos_target = np.stack([cam2worlds_dict[x] for x in test_views])[ 27 | :, :3, 3 28 | ] # [N, 3], N test views in total 29 | dis = np.sum( 30 | np.abs(cam_pos_trains[:, None] - cam_pos_target[None]), axis=(1, 2)) 31 | src_idx = np.argsort(dis) 32 | src_idx = [train_views[x] for x in src_idx] 33 | 34 | return src_idx 35 | 36 | 37 | def main(args): 38 | data_dir = os.path.join("datasets", args.dataset_name, "test") 39 | 40 | # load view pairs 41 | # Adapted from: https://github.com/autonomousvision/murf/blob/main/datasets/dtu.py#L95 42 | test_views = [32, 24, 23, 44] 43 | train_views = [i for i in range(49) if i not in test_views] 44 | 45 | index = {} 46 | for torch_file in tqdm(sorted(glob(os.path.join(data_dir, "*.torch")))): 47 | scene_datas = torch.load(torch_file) 48 | for scene_data in scene_datas: 49 | cameras = scene_data["cameras"] 50 | scene_name = scene_data["key"] 51 | 52 | # calculate nearest camera index 53 | w2c = repeat( 54 | torch.eye(4, dtype=torch.float32), 55 | "h w -> b h w", 56 | b=cameras.shape[0], 57 | ).clone() 58 | w2c[:, :3] = rearrange( 59 | cameras[:, 6:], "b (h w) -> b h w", h=3, w=4) 60 | opencv_c2ws = w2c.inverse() # .unsqueeze(0) 61 | xyzs = opencv_c2ws[:, :3, -1].unsqueeze(0) # 1, N, 3 62 | cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2) 63 | cameras_dist_index = torch.argsort( 64 | cameras_dist_matrix, dim=-1).squeeze(0) 65 | 66 | cam2worlds_dict = {k: v for k, v in enumerate(opencv_c2ws)} 67 | nearest_fixed_views = sorted_test_src_views_fixed( 68 | cam2worlds_dict, test_views, train_views 69 | ) 70 | 71 | selected_pts = test_views 72 | for seq_idx, cur_mid in enumerate(selected_pts): 73 | cur_nn_index = nearest_fixed_views 74 | contexts = tuple([int(x) 75 | for x in cur_nn_index[: args.n_contexts]]) 76 | targets = (cur_mid,) 77 | index[f"{scene_name}_{seq_idx:02d}"] = IndexEntry( 78 | context=contexts, 79 | target=targets, 80 | ) 81 | # save index to files 82 | out_path = f"assets/evaluation_index_{args.dataset_name}_nctx{args.n_contexts}.json" 83 | with open(out_path, "w") as f: 84 | json.dump({k: None if v is None else asdict(v) 85 | for k, v in index.items()}, f) 86 | print(f"Dumped index to: {out_path}") 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--n_contexts", type=int, 93 | default=2, help="number of context views") 94 | parser.add_argument("--dataset_name", type=str, default="dtu") 95 | 96 | params = parser.parse_args() 97 | 98 | main(params) 99 | -------------------------------------------------------------------------------- /src/scripts/generate_evaluation_index.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import hydra 4 | import torch 5 | from jaxtyping import install_import_hook 6 | from omegaconf import DictConfig 7 | from pytorch_lightning import Trainer 8 | 9 | # Configure beartype and jaxtyping.
10 | with install_import_hook( 11 | ("src",), 12 | ("beartype", "beartype"), 13 | ): 14 | from src.config import load_typed_config 15 | from src.dataset import DatasetCfg 16 | from src.dataset.data_module import DataLoaderCfg, DataModule 17 | from src.evaluation.evaluation_index_generator import ( 18 | EvaluationIndexGenerator, 19 | EvaluationIndexGeneratorCfg, 20 | ) 21 | from src.global_cfg import set_cfg 22 | 23 | 24 | @dataclass 25 | class RootCfg: 26 | dataset: DatasetCfg 27 | data_loader: DataLoaderCfg 28 | index_generator: EvaluationIndexGeneratorCfg 29 | seed: int 30 | 31 | 32 | @hydra.main( 33 | version_base=None, 34 | config_path="../../config", 35 | config_name="generate_evaluation_index", 36 | ) 37 | def train(cfg_dict: DictConfig): 38 | cfg = load_typed_config(cfg_dict, RootCfg) 39 | set_cfg(cfg_dict) 40 | torch.manual_seed(cfg.seed) 41 | trainer = Trainer(max_epochs=1, accelerator="gpu", devices="auto", strategy="auto") 42 | data_module = DataModule(cfg.dataset, cfg.data_loader, None) 43 | evaluation_index_generator = EvaluationIndexGenerator(cfg.index_generator) 44 | trainer.test(evaluation_index_generator, datamodule=data_module) 45 | evaluation_index_generator.save_index() 46 | 47 | 48 | if __name__ == "__main__": 49 | train() 50 | -------------------------------------------------------------------------------- /src/scripts/generate_video_evaluation_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--index_input', type=str, help='path to the input evaluation index JSON') 8 | parser.add_argument('--index_output', type=str, help='path to the output video evaluation index JSON') 9 | args = parser.parse_args() 10 | 11 | 12 | # INDEX_INPUT = Path("assets/evaluation_index_re10k.json") 13 | # INDEX_OUTPUT = Path("assets/evaluation_index_re10k_video.json") 14 | INDEX_INPUT = Path(args.index_input) 15 | INDEX_OUTPUT = Path(args.index_output) 16 | 17 | 18 | if __name__ == "__main__": 19 | with INDEX_INPUT.open("r") as f: 20 | index_input = json.load(f) 21 | 22 | index_output = {} 23 | for scene, scene_index_input in index_input.items(): 24 | # Handle scenes for which there's no index. 25 | if scene_index_input is None: 26 | index_output[scene] = None 27 | continue 28 | 29 | # Add all intermediate frames as target frames. 30 | a, b = scene_index_input["context"] 31 | index_output[scene] = { 32 | "context": [a, b], 33 | "target": list(range(a, b + 1)), 34 | } 35 | 36 | with INDEX_OUTPUT.open("w") as f: 37 | json.dump(index_output, f) 38 | -------------------------------------------------------------------------------- /src/scripts/test_splatter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from time import time 3 | 4 | import torch 5 | from einops import einsum, repeat 6 | from jaxtyping import install_import_hook 7 | from scipy.spatial.transform import Rotation as R 8 | from tqdm import tqdm 9 | 10 | # Configure beartype and jaxtyping.
11 | with install_import_hook( 12 | ("src",), 13 | ("beartype", "beartype"), 14 | ): 15 | from src.misc.image_io import save_image 16 | from src.misc.sh_rotation import rotate_sh 17 | from src.model.decoder.cuda_splatting import render_cuda 18 | from src.visualization.camera_trajectory.spin import generate_spin 19 | 20 | 21 | if __name__ == "__main__": 22 | NUM_FRAMES = 60 23 | NUM_GAUSSIANS = 1 24 | DEGREE = 4 25 | IMAGE_SHAPE = (512, 512) 26 | RESULT_PATH = Path("outputs/test_splatter") 27 | 28 | device = torch.device("cuda:0") 29 | 30 | # Generate camera parameters. 31 | extrinsics = generate_spin(60, device, 0.0, 10.0) 32 | intrinsics = torch.eye(3, dtype=torch.float32, device=device) 33 | intrinsics[:2, 2] = 0.5 34 | intrinsics[:2, :2] *= 0.5 35 | intrinsics = repeat(intrinsics, "i j -> b i j", b=NUM_FRAMES) 36 | 37 | # Generate Gaussians. 38 | means = torch.randn((NUM_GAUSSIANS, 3), dtype=torch.float32, device=device) * 0 39 | scales = torch.rand((NUM_GAUSSIANS, 3), dtype=torch.float32, device=device) * 0 + 1 40 | rotations = R.random(NUM_GAUSSIANS).as_matrix() 41 | rotations = torch.tensor(rotations, dtype=torch.float32, device=device) 42 | covariances = rotations @ scales.diag_embed() 43 | covariances = einsum(covariances, covariances, "b i j, b k j -> b i k") 44 | sh_coefficients = torch.randn( 45 | (NUM_GAUSSIANS, 3, (DEGREE + 1) ** 2), dtype=torch.float32, device=device 46 | ) 47 | 48 | # https://en.wikipedia.org/wiki/Spherical_harmonics#/media/File:Spherical_Harmonics.png 49 | # we are rolling forward, rotation-wise 50 | # red is blue, blue is yellow 51 | # we are rotating about the y axis 52 | 53 | sh_coefficients[:] = 0 54 | 55 | # sh_coefficients[:, 0, 1] = 10 # rotation does not change this // -1 56 | # sh_coefficients[:, 0, 2] = 10 # in/out // 0 57 | # sh_coefficients[:, 0, 3] = 10 # sides // 1 58 | 59 | sh_coefficients[:, 0, 4] = 10 # rotation does not change this // -2 60 | sh_coefficients[:, 0, 5] = 10 # rotation does not change this // -1 61 | sh_coefficients[:, 0, 6] = 10 # BRBR // 0 62 | sh_coefficients[:, 0, 7] = 10 # BRBR // 1 63 | sh_coefficients[:, 0, 8] = 10 # 2x red // 2 64 | 65 | opacities = torch.rand(NUM_GAUSSIANS, dtype=torch.float32, device=device) * 0 + 1 66 | 67 | # rotate_sh(sh_coefficients, extrinsics[0, :3, :3].inverse()) 68 | 69 | # Render images using the CUDA splatter. 
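    # --- Editor's sketch (illustration only, not part of the original file) --------
    # The covariances passed to render_cuda below were assembled above as
    # (R S)(R S)^T with S = diag(scales), i.e. Sigma = R S S^T R^T. A tiny numerical
    # check of that identity, using throwaway names unrelated to the Gaussians above:
    _rot = torch.linalg.qr(torch.randn(3, 3)).Q      # an orthonormal 3x3 matrix
    _scale = torch.diag(torch.rand(3) + 0.5)
    assert torch.allclose(
        (_rot @ _scale) @ (_rot @ _scale).T,
        _rot @ _scale @ _scale.T @ _rot.T,
        atol=1e-5,
    )
    # --------------------------------------------------------------------------------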
70 | start_time = time() 71 | rendered_cuda = [ 72 | render_cuda( 73 | c2w[None], 74 | k[None], 75 | torch.tensor([0.1], dtype=torch.float32, device=device), 76 | torch.tensor([20.0], dtype=torch.float32, device=device), 77 | IMAGE_SHAPE, 78 | torch.zeros((1, 3), dtype=torch.float32, device=device), 79 | means[None], 80 | covariances[None], 81 | rotate_sh(sh_coefficients, c2w[:3, :3])[None], 82 | # sh_coefficients[None], 83 | opacities[None], 84 | )[0] 85 | for c2w, k in zip(tqdm(extrinsics, desc="Rendering"), intrinsics) 86 | ] 87 | print(f"CUDA rendering took {time() - start_time:.2f} seconds.") 88 | 89 | RESULT_PATH.mkdir(exist_ok=True, parents=True) 90 | for index, frame in enumerate(tqdm(rendered_cuda, "Saving images")): 91 | save_image(frame, RESULT_PATH / f"frame_{index:0>3}.png") 92 | 93 | import os 94 | 95 | cmd = ( 96 | 'ffmpeg -y -framerate 30 -pattern_type glob -i "*.png" -c:v libx264 -pix_fmt ' 97 | 'yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" animation.mp4' 98 | ) 99 | os.system(f"cd {RESULT_PATH} && {cmd}") 100 | 101 | a = 1 102 | -------------------------------------------------------------------------------- /src/scripts/visualize_epipolar_lines.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from random import randrange 3 | 4 | import hydra 5 | import torch 6 | from jaxtyping import install_import_hook 7 | from lightning_fabric.utilities.apply_func import move_data_to_device 8 | from omegaconf import DictConfig 9 | 10 | # Configure beartype and jaxtyping. 11 | with install_import_hook( 12 | ("src",), 13 | ("beartype", "beartype"), 14 | ): 15 | from src.config import load_typed_root_config 16 | from src.dataset.data_module import DataModule 17 | from src.geometry.epipolar_lines import project_rays 18 | from src.geometry.projection import get_world_rays 19 | from src.global_cfg import set_cfg 20 | from src.misc.image_io import save_image 21 | from src.misc.step_tracker import StepTracker 22 | from src.visualization.annotation import add_label 23 | from src.visualization.drawing.lines import draw_lines 24 | from src.visualization.drawing.points import draw_points 25 | from src.visualization.layout import add_border, hcat 26 | 27 | 28 | @hydra.main( 29 | version_base=None, 30 | config_path="../../config", 31 | config_name="main", 32 | ) 33 | def visualize_epipolar_lines(cfg_dict: DictConfig): 34 | device = torch.device("cuda:0") 35 | num_lines = 5 36 | 37 | # Boilerplate configuration stuff like in the main file... 38 | cfg = load_typed_root_config(cfg_dict) 39 | set_cfg(cfg_dict) 40 | torch.manual_seed(cfg_dict.seed) 41 | data_module = DataModule(cfg.dataset, cfg.data_loader, StepTracker()) 42 | # dataset = iter(data_module.train_dataloader()) 43 | dataset = iter(data_module.test_dataloader()) 44 | 45 | cur_radius = ( 46 | str(cfg.dataset.view_sampler.index_path.stem).split(".")[0].split("_")[-1] 47 | ) 48 | 49 | for e_idx, example in enumerate(dataset): 50 | # example = next(dataset) 51 | if e_idx > 10: 52 | break 53 | print(f"Drawing scene {example['scene'][0]}") 54 | 55 | example = move_data_to_device(example, device) 56 | 57 | # Plot a few different examples to try to get an interesting line. 58 | for i in range(num_lines): 59 | # Get a single example from the dataset. 60 | # example = next(dataset) 61 | # example = move_data_to_device(example, device) 62 | 63 | # Pick a random pixel to visualize the epipolar line for. 
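        # --- Editor's note (descriptive, not part of the original file) ------------
        # The pixel below is expressed in normalized image coordinates in [0, 1),
        # matching the x_range=(0, 1) / y_range=(0, 1) convention of the drawing
        # helpers used later. For a hypothetical 640x480 image, pixel (320, 120)
        # becomes xy = (320 / 640, 120 / 480) = (0.5, 0.25).
        # ----------------------------------------------------------------------------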
64 | _, v, _, h, w = example["context"]["image"].shape 65 | assert v >= 2 66 | x = randrange(0, w) 67 | y = randrange(0, h) 68 | xy = torch.tensor((x / w, y / h), dtype=torch.float32, device=device) 69 | 70 | # Generate the ray that corresponds to the point. 71 | source_extrinsics = example["context"]["extrinsics"][0, 0] 72 | source_intrinsics = example["context"]["intrinsics"][0, 0] 73 | origin, direction = get_world_rays(xy, source_extrinsics, source_intrinsics) 74 | target_extrinsics = example["context"]["extrinsics"][0, 1] 75 | target_intrinsics = example["context"]["intrinsics"][0, 1] 76 | projection = project_rays( 77 | origin, 78 | direction, 79 | target_extrinsics, 80 | target_intrinsics, 81 | near=example["context"]["near"][0, 0], 82 | far=example["context"]["far"][0, 0], 83 | ) 84 | 85 | # Draw the point (ray) onto the source view. 86 | source_image = example["context"]["image"][0, 0] 87 | source_image = draw_points( 88 | source_image, xy, (1, 0, 0), 4, x_range=(0, 1), y_range=(0, 1) 89 | ) 90 | 91 | # Draw the epipolar line onto the target view. 92 | target_image = example["context"]["image"][0, 1] 93 | target_image = draw_lines( 94 | target_image, 95 | projection["xy_min"], 96 | projection["xy_max"], 97 | (1, 0, 0), 98 | 4, 99 | x_range=(0, 1), 100 | y_range=(0, 1), 101 | ) 102 | 103 | # Put the images side by side. 104 | source_image = add_label(source_image, "Source") 105 | target_image = add_label(target_image, "Target") 106 | together = add_border(hcat(source_image, target_image)) 107 | save_image( 108 | together, 109 | Path(f"epipolar_lines/{cur_radius}/{example['scene'][0]}_{i:0>2}.png"), 110 | ) 111 | 112 | 113 | if __name__ == "__main__": 114 | visualize_epipolar_lines() 115 | -------------------------------------------------------------------------------- /src/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text, font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- 
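Editor's note (illustrative usage, not part of the repository dump): `add_label` stacks a rendered text label above an image tensor, falling back to PIL's default font when the bundled `assets/Inter-Regular.otf` cannot be loaded. A minimal sketch, assuming the package is importable:

    import torch
    from src.visualization.annotation import add_label

    image = torch.rand(3, 128, 128)        # any 3 x H x W float tensor in [0, 1]
    labeled = add_label(image, "Source")   # returns a taller 3 x H' x W' tensor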
/src/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... 
() i j") @ tf 33 | -------------------------------------------------------------------------------- /src/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height width"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB.
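    # --- Editor's note (descriptive, not part of the original file) ----------------
    # At this point `interpolated` holds one CIELab triple per element:
    #     Lab(x, y) = y * (x * red + (1 - x) * blue) + (1 - y) * white,
    # i.e. x blends between blue (x = 0) and red (x = 1) while y fades the result
    # toward white (y = 0), before the conversion back to sRGB below.
    # --------------------------------------------------------------------------------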
47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /src/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /src/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 | xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /src/visualization/drawing/lines.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 
height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 
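        # --- Editor's note (descriptive, not part of the original file) ------------
        # `arrangement` holds, per sample, the line index wherever the sample lies
        # inside that line and 0 elsewhere, so argmax over the line dimension picks
        # the highest-indexed line covering the sample (later lines are drawn on
        # top). The alpha channel is simply "covered by any line".
        # ----------------------------------------------------------------------------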
73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /src/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 
49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /src/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
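    # --- Editor's note (descriptive, not part of the original file) ----------------
    # The color function is evaluated on the flattened sample grid in chunks of
    # `batch_size` points to bound peak memory, then reassembled into a
    # (batch, 4, height, width) image. Pixels that differ from a neighbour are then
    # re-sampled on a finer sub-grid for `remaining_passes` further MSAA passes.
    # --------------------------------------------------------------------------------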
86 | b, h, w, _ = xy.shape 87 | color = [ 88 | color_function(batch) 89 | for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) 90 | ] 91 | color = torch.cat(color, dim=0) 92 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 93 | 94 | # If any MSAA passes remain, subdivide. 95 | if remaining_passes > 0: 96 | mask = detect_msaa_pixels(color) 97 | batch_index, row_index, col_index = torch.where(mask) 98 | xy = xy[batch_index, row_index, col_index] 99 | 100 | offsets = generate_sample_grid((subdivision, subdivision), device) 101 | offsets = (offsets / subdivision - 0.5) * scale 102 | 103 | color_fine = run_msaa_pass( 104 | xy[:, None, None] + offsets, 105 | color_function, 106 | scale / subdivision, 107 | subdivision, 108 | remaining_passes - 1, 109 | device, 110 | batch_size=batch_size, 111 | ) 112 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 113 | 114 | return color 115 | 116 | 117 | @torch.no_grad() 118 | def render( 119 | shape: tuple[int, int], 120 | color_function: ColorFunction, 121 | device: torch.device, 122 | subdivision: int = 8, 123 | num_passes: int = 2, 124 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 125 | xy = generate_sample_grid(shape, device) 126 | return run_msaa_pass( 127 | xy[None], 128 | color_function, 129 | 1.0, 130 | subdivision, 131 | num_passes, 132 | device, 133 | )[0] 134 | 135 | 136 | def render_over_image( 137 | image: Float[Tensor, "3 height width"], 138 | color_function: ColorFunction, 139 | device: torch.device, 140 | subdivision: int = 8, 141 | num_passes: int = 1, 142 | ) -> Float[Tensor, "3 height width"]: 143 | _, h, w = image.shape 144 | overlay = render( 145 | (h, w), 146 | color_function, 147 | device, 148 | subdivision=subdivision, 149 | num_passes=num_passes, 150 | ) 151 | color, alpha = overlay.split((3, 1), dim=0) 152 | return image * (1 - alpha) + color * alpha 153 | -------------------------------------------------------------------------------- /src/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /src/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [ 16 | torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device) 17 | for x in images 18 | ] 19 | for image, result in zip(images, results): 20 | slices = [slice(0, x) for x in image.shape] 21 | result[slices] = image[slices] 22 | return results 23 | 24 | 25 | def render_projections( 26 | gaussians: Gaussians, 27 | resolution: int, 28 | margin: float = 0.1, 29 | draw_label: bool = True, 30 | extra_label: str = "", 31 | ) -> Float[Tensor, "batch 3 3 height width"]: 32 | device = gaussians.means.device 33 | b, _, _ = gaussians.means.shape 34 | 35 | # Compute the minima and maxima of the scene. 36 | minima = gaussians.means.min(dim=1).values 37 | maxima = gaussians.means.max(dim=1).values 38 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 39 | minima, maxima, margin=margin 40 | ) 41 | 42 | projections = [] 43 | for look_axis in range(3): 44 | right_axis = (look_axis + 1) % 3 45 | down_axis = (look_axis + 2) % 3 46 | 47 | # Define the extrinsics for rendering. 48 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 49 | extrinsics[:, right_axis, 0] = 1 50 | extrinsics[:, down_axis, 1] = 1 51 | extrinsics[:, look_axis, 2] = 1 52 | extrinsics[:, right_axis, 3] = 0.5 * ( 53 | scene_minima[:, right_axis] + scene_maxima[:, right_axis] 54 | ) 55 | extrinsics[:, down_axis, 3] = 0.5 * ( 56 | scene_minima[:, down_axis] + scene_maxima[:, down_axis] 57 | ) 58 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 59 | extrinsics[:, 3, 3] = 1 60 | 61 | # Define the intrinsics for rendering. 
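        # --- Editor's note (descriptive, not part of the original file) ------------
        # The orthographic frustum is sized to the scene's padded bounding box: the
        # image plane spans the two in-plane axes, and the near/far planes run from
        # the box face where the camera sits (near = 0) to the opposite face
        # (far = the box extent along the look axis).
        # ----------------------------------------------------------------------------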
62 | extents = scene_maxima - scene_minima 63 | far = extents[:, look_axis] 64 | near = torch.zeros_like(far) 65 | width = extents[:, right_axis] 66 | height = extents[:, down_axis] 67 | 68 | projection = render_cuda_orthographic( 69 | extrinsics, 70 | width, 71 | height, 72 | near, 73 | far, 74 | (resolution, resolution), 75 | torch.zeros((b, 3), dtype=torch.float32, device=device), 76 | gaussians.means, 77 | gaussians.covariances, 78 | gaussians.harmonics, 79 | gaussians.opacities, 80 | fov_degrees=10.0, 81 | ) 82 | if draw_label: 83 | right_axis_name = "XYZ"[right_axis] 84 | down_axis_name = "XYZ"[down_axis] 85 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 86 | projection = torch.stack([add_label(x, label) for x in projection]) 87 | 88 | projections.append(projection) 89 | 90 | return torch.stack(pad(projections), dim=1) 91 | 92 | 93 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 94 | # Define colors for context and target views. 95 | num_context_views = batch["context"]["extrinsics"].shape[1] 96 | num_target_views = batch["target"]["extrinsics"].shape[1] 97 | color = torch.ones( 98 | (num_target_views + num_context_views, 3), 99 | dtype=torch.float32, 100 | device=batch["target"]["extrinsics"].device, 101 | ) 102 | color[num_context_views:, 1:] = 0 103 | 104 | return draw_cameras( 105 | resolution, 106 | torch.cat( 107 | (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) 108 | ), 109 | torch.cat( 110 | (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) 111 | ), 112 | color, 113 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 114 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 115 | ) 116 | -------------------------------------------------------------------------------- /src/visualization/vis_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import torchvision.utils as vutils 5 | import cv2 6 | from matplotlib.cm import get_cmap 7 | import matplotlib as mpl 8 | import matplotlib.cm as cm 9 | 10 | 11 | # https://github.com/autonomousvision/unimatch/blob/master/utils/visualization.py 12 | 13 | 14 | def vis_disparity(disp): 15 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 16 | disp_vis = disp_vis.astype("uint8") 17 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) 18 | 19 | return disp_vis 20 | 21 | 22 | def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'): 23 | # visualize inverse depth 24 | assert isinstance(disp, torch.Tensor) 25 | 26 | disp = disp.numpy() 27 | vmax = np.percentile(disp, 95) 28 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) 29 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 30 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3] 31 | 32 | if return_numpy: 33 | return colormapped_im 34 | 35 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] 36 | 37 | return viz 38 | --------------------------------------------------------------------------------
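Editor's note (illustrative usage, not part of the repository dump): `viz_depth_tensor` takes a CPU float tensor of inverse depth and returns either a 3 x H x W uint8 tensor or, with `return_numpy=True`, an H x W x 3 uint8 array. A minimal sketch, assuming the package is importable and Pillow is installed:

    import torch
    from PIL import Image
    from src.visualization.vis_depth import viz_depth_tensor

    inv_depth = 1.0 / (torch.rand(240, 320) + 0.5)            # hypothetical inverse-depth map
    colored = viz_depth_tensor(inv_depth, return_numpy=True)  # (240, 320, 3) uint8
    Image.fromarray(colored).save("depth_vis.png")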