├── .gitignore ├── DATASETS.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── assets ├── dl3dv_start_0_distance_100_ctx_12v_video.json ├── dl3dv_start_0_distance_10_ctx_2v_tgt_4v.json ├── dl3dv_start_0_distance_50_ctx_2v_video_0_50.json ├── dl3dv_start_0_distance_50_ctx_4v_video_0_50.json ├── dl3dv_start_0_distance_50_ctx_6v_video_0_50.json ├── evaluation_index_acid.json ├── evaluation_index_re10k.json ├── evaluation_index_re10k_video.json └── re10k_ctx_6v_video.json ├── config ├── dataset │ ├── dl3dv.yaml │ ├── re10k.yaml │ ├── view_sampler │ │ ├── all.yaml │ │ ├── arbitrary.yaml │ │ ├── bounded.yaml │ │ ├── boundedv2.yaml │ │ ├── boundedv2_360.yaml │ │ └── evaluation.yaml │ └── view_sampler_dataset_specific_config │ │ ├── bounded_re10k.yaml │ │ ├── boundedv2_dl3dv.yaml │ │ ├── evaluation_dl3dv.yaml │ │ └── evaluation_re10k.yaml ├── experiment │ ├── dl3dv.yaml │ └── re10k.yaml ├── loss │ ├── lpips.yaml │ └── mse.yaml ├── main.yaml └── model │ ├── decoder │ └── splatting_cuda.yaml │ └── encoder │ └── depthsplat.yaml ├── requirements.txt ├── scripts ├── dl3dv_depthsplat_train.sh ├── inference_depth.sh └── re10k_depthsplat_train.sh └── src ├── config.py ├── dataset ├── __init__.py ├── data_module.py ├── dataset.py ├── dataset_dl3dv.py ├── dataset_re10k.py ├── shims │ ├── augmentation_shim.py │ ├── bounds_shim.py │ ├── crop_shim.py │ └── patch_shim.py ├── types.py ├── validation_wrapper.py └── view_sampler │ ├── __init__.py │ ├── view_sampler.py │ ├── view_sampler_all.py │ ├── view_sampler_arbitrary.py │ ├── view_sampler_bounded.py │ ├── view_sampler_bounded_v2.py │ └── view_sampler_evaluation.py ├── evaluation ├── evaluation_cfg.py ├── evaluation_index_generator.py ├── metric_computer.py └── metrics.py ├── geometry ├── epipolar_lines.py └── projection.py ├── global_cfg.py ├── loss ├── __init__.py ├── loss.py ├── loss_lpips.py └── loss_mse.py ├── main.py ├── misc ├── LocalLogger.py ├── benchmarker.py ├── collation.py ├── discrete_probability_distribution.py ├── heterogeneous_pairings.py ├── image_io.py ├── nn_module_tools.py ├── render_utils.py ├── resume_ckpt.py ├── sh_rotation.py ├── stablize_camera.py ├── step_tracker.py └── wandb_tools.py ├── model ├── decoder │ ├── __init__.py │ ├── cuda_splatting.py │ ├── decoder.py │ └── decoder_splatting_cuda.py ├── encoder │ ├── __init__.py │ ├── common │ │ ├── gaussian_adapter.py │ │ ├── gaussians.py │ │ └── sampler.py │ ├── encoder.py │ ├── encoder_depthsplat.py │ ├── unimatch │ │ ├── backbone.py │ │ ├── dpt_head.py │ │ ├── feature_upsampler.py │ │ ├── ldm_unet │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── cross_attention.py │ │ │ ├── unet.py │ │ │ └── util.py │ │ ├── matching.py │ │ ├── mv_transformer.py │ │ ├── mv_unimatch.py │ │ ├── position.py │ │ ├── utils.py │ │ └── vit_fpn.py │ └── visualization │ │ ├── encoder_visualizer.py │ │ ├── encoder_visualizer_depthsplat.py │ │ └── encoder_visualizer_depthsplat_cfg.py ├── model_wrapper.py ├── ply_export.py └── types.py ├── scripts ├── convert_dl3dv_test.py ├── convert_dl3dv_train.py └── generate_dl3dv_index.py └── visualization ├── annotation.py ├── camera_trajectory ├── interpolation.py ├── spin.py └── wobble.py ├── color_map.py ├── colors.py ├── drawing ├── cameras.py ├── coordinate_conversion.py ├── lines.py ├── points.py ├── rendering.py └── types.py ├── layout.py ├── validation_in_3d.py └── vis_depth.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 
| 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | datasets 165 | pretrained 166 | outputs 167 | output 168 | checkpoints 169 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | For view synthesis experiments with Gaussian splatting, we mainly use [RealEstate10K](https://google.github.io/realestate10k/index.html) and [DL3DV](https://github.com/DL3DV-10K/Dataset) datasets. We provide the data processing scripts to convert the original datasets to pytorch chunk files which can be directly loaded with this codebase. 4 | 5 | Expected folder structure: 6 | 7 | ``` 8 | ├── datasets 9 | │ ├── re10k 10 | │ ├── ├── train 11 | │ ├── ├── ├── 000000.torch 12 | │ ├── ├── ├── ... 13 | │ ├── ├── ├── index.json 14 | │ ├── ├── test 15 | │ ├── ├── ├── 000000.torch 16 | │ ├── ├── ├── ... 17 | │ ├── ├── ├── index.json 18 | │ ├── dl3dv 19 | │ ├── ├── train 20 | │ ├── ├── ├── 000000.torch 21 | │ ├── ├── ├── ... 22 | │ ├── ├── ├── index.json 23 | │ ├── ├── test 24 | │ ├── ├── ├── 000000.torch 25 | │ ├── ├── ├── ... 26 | │ ├── ├── ├── index.json 27 | ``` 28 | 29 | It's recommended to create a symbolic link from `YOUR_DATASET_PATH` to `datasets` using 30 | ``` 31 | ln -s YOUR_DATASET_PATH datasets 32 | ``` 33 | 34 | Or you can specify your dataset path with `dataset.roots=[YOUR_DATASET_PATH]/re10k` and `dataset.roots=[YOUR_DATASET_PATH]/dl3dv` in the config. 35 | 36 | We also provide instructions to convert additional datasets to the desired format. 37 | 38 | 39 | ## RealEstate10K 40 | 41 | For experiments on RealEstate10K, we primarily follow [pixelSplat](https://github.com/dcharatan/pixelsplat) and [MVSplat](https://github.com/donydchen/mvsplat) to train and evaluate on the 256x256 resolution. 42 | 43 | Please refer to [pixelSplat repo](https://github.com/dcharatan/pixelsplat?tab=readme-ov-file#acquiring-datasets) for acquiring the processed 360p (360x640) dataset. 44 | 45 | If you would like to train and evaluate on the high-resolution RealEstate10K dataset, you will need to download the 720p (720x1280) version. Please refer to [here](https://github.com/yilundu/cross_attention_renderer/tree/master/data_download) for the downloading script. 
Note that the script downloads the 360p videos by default; you will need to change `360p` to `720p` in [this line of code](https://github.com/yilundu/cross_attention_renderer/blob/master/data_download/generate_realestate.py#L137) to download the 720p videos. 46 | 47 | After downloading the 720p dataset, you can use the scripts [here](https://github.com/dcharatan/real_estate_10k_tools/tree/main/src) to convert the dataset to the desired format for this codebase. 48 | 49 | Because the full 720p dataset is quite large and may take time to download and process, we provide a preprocessed subset in `.torch` chunks ([download](https://huggingface.co/datasets/haofeixu/depthsplat/resolve/main/re10k_720p_test_subset.zip)) containing two test scenes to quickly run inference with our model. 50 | 51 | ## DL3DV 52 | 53 | For experiments on DL3DV, we primarily train and evaluate at a resolution of 256×448. Additionally, we train high-resolution models (448×768) for qualitative results. 54 | 55 | For the test set, we use the [DL3DV-Benchmark](https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark) split, which contains 140 scenes for evaluation. You can first use the script [src/scripts/convert_dl3dv_test.py](src/scripts/convert_dl3dv_test.py) to convert the test set, and then run [src/scripts/generate_dl3dv_index.py](src/scripts/generate_dl3dv_index.py) to generate the `index.json` file for the test set. 56 | 57 | For the training set, we use the [DL3DV-480p](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-480P) dataset (270x480 resolution); the 140 test-set scenes are excluded when processing the training set. After downloading the [DL3DV-480p](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-480P) dataset, you can first use the script [src/scripts/convert_dl3dv_train.py](src/scripts/convert_dl3dv_train.py) to convert the training set, and then run [src/scripts/generate_dl3dv_index.py](src/scripts/generate_dl3dv_index.py) to generate the `index.json` file for the training set. 58 | 59 | Please note that you will need to update the dataset paths in the aforementioned processing scripts. 60 | 61 | If you would like to train and evaluate on the high-resolution DL3DV dataset, you will need to download the [DL3DV-960P](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-960P) version (540x960 resolution). Simply follow the same data processing procedure, but change the image folder from `images_8` to `images_4`. 62 | 63 | Please follow the [DL3DV license](https://github.com/DL3DV-10K/Dataset/blob/main/License.md) if you use this dataset in your project and kindly [reference the DL3DV paper](https://github.com/DL3DV-10K/Dataset?tab=readme-ov-file#bibtex). 64 | 65 | Because the full 960p dataset is quite large and may take time to download and process, we provide a preprocessed subset in `.torch` chunks ([download](https://huggingface.co/datasets/haofeixu/depthsplat/resolve/main/dl3dv_960p_test_subset.zip)) containing two test scenes to quickly run inference with our model. Please note that this released subset is intended solely for research purposes. We disclaim any responsibility for the misuse, inappropriate use, or unethical application of the dataset by individuals or entities who download or access it. We kindly ask users to adhere to the [DL3DV license](https://github.com/DL3DV-10K/Dataset/blob/main/License.md). 66 | 67 | 68 | ## ACID 69 | 70 | 71 | We also evaluate the generalization of our model on the [ACID](https://infinite-nature.github.io/) dataset.
Note that we do not use the training set; you only need to [download the test set](http://schadenfreude.csail.mit.edu:8000/re10k_test_only.zip) (provided by [pixelSplat repo](https://github.com/dcharatan/pixelsplat?tab=readme-ov-file#acquiring-datasets)) for evaluation. 72 | 73 | 74 | 75 | ## Additional Datasets 76 | 77 | If you would like to train and/or evaluate on additional datasets, just modify the [data processing scripts](src/scripts) to convert the dataset format. Kindly note the [camera conventions](https://github.com/cvg/depthsplat/tree/main?tab=readme-ov-file#camera-conventions) used in this codebase. 78 | 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Haofei Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | - We provide pre-trained models for view synthesis with 3D Gaussian splatting and scale-consistent depth estimation from multi-view posed images. 4 | 5 | - We assume that the downloaded weights are stored in the `pretrained` directory. It's recommended to create a symbolic link from `YOUR_MODEL_PATH` to `pretrained` using 6 | ``` 7 | ln -s YOUR_MODEL_PATH pretrained 8 | ``` 9 | 10 | - To verify the integrity of downloaded files, each model on this page includes its [sha256sum](https://sha256sum.com/) prefix in the file name, which can be checked using the command `sha256sum filename`. 11 | 12 | 13 | ## Gaussian Splatting 14 | 15 | - The models are trained on RealEstate10K (re10k) and/or DL3DV (dl3dv) datasets at resolutions of 256x256, 256x448, and 448x768. The number of training views ranges from 2 to 10. 16 | 17 | - The "→" symbol indicates that the models are trained in two stages. For example, "re10k → (re10k+dl3dv)" means the model is firstly trained on the RealEstate10K dataset and then fine-tuned using a combination of the RealEstate10K and DL3DV datasets. 
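
To make the checksum note above actionable, here is a minimal Python sketch (illustrative only; `sha256_prefix_matches` is a hypothetical helper, not part of this codebase) that checks whether a downloaded checkpoint's sha256 digest starts with the 8-character prefix embedded in its file name:

```python
import hashlib
from pathlib import Path


def sha256_prefix_matches(path: str) -> bool:
    # "depthsplat-gs-base-re10k-256x256-view2-ca7b6795.pth" -> expected prefix "ca7b6795"
    expected = Path(path).stem.rsplit("-", 1)[-1]
    digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
    return digest.startswith(expected)


# sha256_prefix_matches("pretrained/depthsplat-gs-base-re10k-256x256-view2-ca7b6795.pth")
```
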
18 | 19 | 20 | | Model | Training Data | Training Resolution | Training Views | Params (M) | Download | 21 | | ------------------------------------------------------------ | :------------------------: | :-------------------: | :------------: | :--------: | :----------------------------------------------------------: | 22 | | depthsplat-gs-small-re10k-256x256-view2-cfeab6b1.pth | re10k | 256x256 | 2 | 37 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-small-re10k-256x256-view2-cfeab6b1.pth) | 23 | | depthsplat-gs-base-re10k-256x256-view2-ca7b6795.pth | re10k | 256x256 | 2 | 117 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-base-re10k-256x256-view2-ca7b6795.pth) | 24 | | depthsplat-gs-large-re10k-256x256-view2-e0f0f27a.pth | re10k | 256x256 | 2 | 360 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-large-re10k-256x256-view2-e0f0f27a.pth) | 25 | | depthsplat-gs-base-re10k-256x448-view2-fea94f65.pth | re10k | 256x448 | 2 | 117 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-base-re10k-256x448-view2-fea94f65.pth) | 26 | | depthsplat-gs-base-dl3dv-256x448-randview2-6-02c7b19d.pth | re10k → dl3dv | 256x448 | 2-6 | 117 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-base-dl3dv-256x448-randview2-6-02c7b19d.pth) | 27 | | depthsplat-gs-small-re10kdl3dv-448x768-randview4-10-c08188db.pth | re10k → (re10k+dl3dv) | 256x448 →448x768 | 4-10 | 37 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-small-re10kdl3dv-448x768-randview4-10-c08188db.pth) | 28 | | depthsplat-gs-base-re10kdl3dv-448x768-randview2-6-f8ddd845.pth | re10k → (re10k+dl3dv) | 256x448 →448x768 | 2-6 | 117 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-gs-base-re10kdl3dv-448x768-randview2-6-f8ddd845.pth) | 29 | 30 | 31 | 32 | ## Depth Prediction 33 | 34 | - The depth models are trained with the following procedure: 35 | - Initialize the monocular feature with Depth Anything V2 and the multi-view Transformer with UniMatch. 36 | - Train the full DepthSplat model end-to-end on the mixed RealEstate10K and DL3DV datasets. 37 | - Fine-tune the pre-trained depth model on the depth datasets with ground truth depth supervision. The depth datasets used for fine-tuning include ScanNet, TartanAir, and VKITTI2. 38 | - The depth models are fine-tuned with random numbers (2-8) of input images, and the training image resolution is 352x640. 39 | - The scale of the predicted depth is aligned with the scale of camera pose's translation. 
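
To make "scale-consistent" concrete: if all camera translations were rescaled by a factor s, a depth map consistent with those poses would have to be rescaled by s as well, so it is the ratio of depth to camera baseline that the model pins down. A tiny illustrative sketch (not from this repo, arbitrary numbers):

```python
import numpy as np

s = 2.0                          # hypothetical global rescaling of the scene
t_a = np.array([0.0, 0.0, 0.0])  # camera center of view A
t_b = np.array([0.3, 0.0, 0.0])  # camera center of view B
depth = 4.2                      # predicted depth of some pixel in view A

baseline = np.linalg.norm(t_b - t_a)
# Rescaling poses and depth together leaves the depth/baseline ratio unchanged.
assert np.isclose(depth / baseline, (s * depth) / (s * baseline))
```
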
40 | 41 | | Model | Training Data | Training Resolution | Training Views | Params (M) | Download | 42 | | ------------------------------------------------------- | :----------------------------------------------: | :--------------------: | :------------: | :--------: | :----------------------------------------------------------: | 43 | | depthsplat-depth-small-352x640-randview2-8-e807bd82.pth | (re10k+dl3dv) → (scannet+tartanair+vkitti2) | 448x768 → 352x640 | 2-8 | 36 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-depth-small-352x640-randview2-8-e807bd82.pth) | 44 | | depthsplat-depth-base-352x640-randview2-8-65a892c5.pth | (re10k+dl3dv) → (scannet+tartanair+vkitti2) | 448x768 → 352x640 | 2-8 | 111 | [download](https://huggingface.co/haofeixu/depthsplat/resolve/main/depthsplat-depth-base-352x640-randview2-8-65a892c5.pth) | 45 | 46 | 47 | -------------------------------------------------------------------------------- /config/dataset/dl3dv.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: boundedv2_360 3 | 4 | name: dl3dv 5 | roots: [datasets/dl3dv] 6 | make_baseline_1: false 7 | augment: true 8 | 9 | image_shape: [270, 480] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: false 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | sort_target_index: true 24 | sort_context_index: true 25 | 26 | train_times_per_scene: 1 27 | test_times_per_scene: 1 28 | ori_image_shape: [270, 480] 29 | overfit_max_views: 148 30 | use_index_to_load_chunk: false 31 | 32 | mix_tartanair: false 33 | no_mix_test_set: true 34 | load_depth: false 35 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: re10k 5 | roots: [datasets/re10k] 6 | make_baseline_1: false 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | 16 | skip_bad_shape: true 17 | near: -1. 18 | far: -1. 19 | baseline_scale_bounds: true 20 | shuffle_val: true 21 | test_len: -1 22 | test_chunk_interval: 1 23 | 24 | use_index_to_load_chunk: false -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here. 
7 | context_views: null 8 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/bounded.yaml: -------------------------------------------------------------------------------- 1 | name: bounded 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | min_distance_to_context_views: 0 9 | 10 | warm_up_steps: 0 11 | initial_min_distance_between_context_views: 2 12 | initial_max_distance_between_context_views: 6 -------------------------------------------------------------------------------- /config/dataset/view_sampler/boundedv2.yaml: -------------------------------------------------------------------------------- 1 | name: boundedv2 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | max_distance_to_context_views: 0 9 | 10 | context_gap_warm_up_steps: 0 11 | target_gap_warm_up_steps: 0 12 | 13 | initial_min_distance_between_context_views: 2 14 | initial_max_distance_between_context_views: 6 15 | initial_max_distance_to_context_views: 0 16 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/boundedv2_360.yaml: -------------------------------------------------------------------------------- 1 | name: boundedv2 2 | 3 | num_target_views: 4 4 | num_context_views: 4 5 | 6 | min_distance_between_context_views: 20 7 | max_distance_between_context_views: 50 8 | max_distance_to_context_views: 0 9 | 10 | context_gap_warm_up_steps: 10000 11 | target_gap_warm_up_steps: 0 12 | 13 | initial_min_distance_between_context_views: 15 14 | initial_max_distance_between_context_views: 30 15 | initial_max_distance_to_context_views: 0 16 | extra_views_sampling_strategy: farthest_point 17 | target_views_replace_sample: false 18 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/evaluation.yaml: -------------------------------------------------------------------------------- 1 | name: evaluation 2 | 3 | index_path: assets/evaluation_index_re10k_video.json 4 | num_context_views: 2 5 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | min_distance_between_context_views: 45 6 | max_distance_between_context_views: 135 7 | min_distance_to_context_views: 0 8 | warm_up_steps: 30000 9 | initial_min_distance_between_context_views: 25 10 | initial_max_distance_between_context_views: 45 11 | num_target_views: 4 12 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | min_distance_between_context_views: 20 6 | max_distance_between_context_views: 50 7 | max_distance_to_context_views: 0 8 | context_gap_warm_up_steps: 10000 9 | target_gap_warm_up_steps: 0 10 | initial_min_distance_between_context_views: 15 11 | initial_max_distance_between_context_views: 30 12 | initial_max_distance_to_context_views: 0 13 | extra_views_sampling_strategy: farthest_point 14 | num_target_views: 4 15 | 
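# For reference (summarizing the view-sampler options above, as understood from their
# names): the min/max distances bound the frame-index gap between sampled context views,
# the initial_* values apply at the beginning of training, and the warm-up step counts
# control how the bounds anneal from the initial to the final values.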
-------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/dl3dv_360_v5.json 6 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | -------------------------------------------------------------------------------- /config/experiment/dl3dv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dl3dv 5 | - override /model/encoder: depthsplat 6 | - override /loss: [mse, lpips] 7 | - override /dataset/view_sampler: boundedv2_360 8 | 9 | wandb: 10 | name: dl3dv 11 | tags: [dl3dv, 270x480] 12 | 13 | data_loader: 14 | train: 15 | batch_size: 1 16 | 17 | trainer: 18 | max_steps: 300_001 19 | num_nodes: 1 20 | 21 | model: 22 | encoder: 23 | num_depth_candidates: 128 24 | costvolume_unet_feat_dim: 128 25 | costvolume_unet_channel_mult: [1,1,1] 26 | costvolume_unet_attn_res: [4] 27 | gaussians_per_pixel: 1 28 | depth_unet_feat_dim: 32 29 | depth_unet_attn_res: [16] 30 | depth_unet_channel_mult: [1,1,1,1,1] 31 | shim_patch_size: 16 32 | 33 | # lpips loss 34 | loss: 35 | lpips: 36 | apply_after_step: 0 37 | weight: 0.05 38 | 39 | dataset: 40 | near: 0.5 41 | far: 200. 42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | min_views: 0 45 | max_views: 0 46 | highres: false 47 | 48 | test: 49 | eval_time_skip_steps: 0 50 | compute_scores: true 51 | dec_chunk_size: 30 52 | -------------------------------------------------------------------------------- /config/experiment/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: depthsplat 6 | - override /loss: [mse, lpips] 7 | 8 | wandb: 9 | name: re10k 10 | tags: [re10k, 256x256] 11 | 12 | data_loader: 13 | train: 14 | batch_size: 14 15 | 16 | trainer: 17 | max_steps: 300_001 18 | num_nodes: 1 19 | 20 | model: 21 | encoder: 22 | num_depth_candidates: 128 23 | costvolume_unet_feat_dim: 128 24 | costvolume_unet_channel_mult: [1,1,1] 25 | costvolume_unet_attn_res: [4] 26 | gaussians_per_pixel: 1 27 | depth_unet_feat_dim: 32 28 | depth_unet_attn_res: [16] 29 | depth_unet_channel_mult: [1,1,1,1,1] 30 | 31 | # lpips loss 32 | loss: 33 | lpips: 34 | apply_after_step: 0 35 | weight: 0.05 36 | 37 | dataset: 38 | image_shape: [256, 256] 39 | roots: [datasets/re10k] 40 | near: 0.5 41 | far: 100. 
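  # Explicit near/far depth bounds for this experiment (the base re10k dataset config leaves both at -1).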
42 | baseline_scale_bounds: false 43 | make_baseline_1: false 44 | train_times_per_scene: 1 45 | highres: false 46 | 47 | test: 48 | eval_time_skip_steps: 5 49 | compute_scores: true 50 | -------------------------------------------------------------------------------- /config/loss/lpips.yaml: -------------------------------------------------------------------------------- 1 | lpips: 2 | weight: 0.05 3 | apply_after_step: 150_000 4 | -------------------------------------------------------------------------------- /config/loss/mse.yaml: -------------------------------------------------------------------------------- 1 | mse: 2 | weight: 1.0 3 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - model/encoder: depthsplat 5 | - model/decoder: splatting_cuda 6 | - loss: [mse] 7 | 8 | wandb: 9 | project: depthsplat 10 | entity: placeholder 11 | name: placeholder 12 | mode: online 13 | id: null 14 | 15 | mode: train 16 | 17 | dataset: 18 | overfit_to_scene: null 19 | 20 | data_loader: 21 | train: 22 | num_workers: 10 23 | persistent_workers: true 24 | batch_size: 4 25 | seed: 1234 26 | test: 27 | num_workers: 4 28 | persistent_workers: false 29 | batch_size: 1 30 | seed: 2345 31 | val: 32 | num_workers: 1 33 | persistent_workers: true 34 | batch_size: 1 35 | seed: 3456 36 | 37 | optimizer: 38 | lr: 2.e-4 39 | lr_monodepth: 2.e-6 40 | warm_up_steps: 2000 41 | weight_decay: 0.01 42 | 43 | checkpointing: 44 | load: null 45 | every_n_train_steps: 5000 46 | save_top_k: 5 47 | pretrained_model: null 48 | pretrained_monodepth: null 49 | pretrained_mvdepth: null 50 | pretrained_depth: null 51 | no_strict_load: false 52 | resume: false 53 | 54 | train: 55 | depth_mode: null 56 | extended_visualization: false 57 | print_log_every_n_steps: 100 58 | eval_model_every_n_val: 2 # quantitative evaluation every n val 59 | eval_data_length: 999999 60 | eval_deterministic: false 61 | eval_time_skip_steps: 3 62 | eval_save_model: true 63 | l1_loss: false 64 | intermediate_loss_weight: 0.9 65 | no_viz_video: false 66 | viz_depth: false 67 | forward_depth_only: false 68 | train_ignore_large_loss: 0. 
69 | no_log_projections: false 70 | 71 | test: 72 | output_path: outputs/test 73 | compute_scores: true 74 | eval_time_skip_steps: 0 75 | save_image: false 76 | save_video: false 77 | save_gt_image: false 78 | save_input_images: false 79 | save_depth: false 80 | save_depth_npy: false 81 | save_depth_concat_img: false 82 | save_gaussian: false 83 | render_chunk_size: null 84 | stablize_camera: false 85 | stab_camera_kernel: 50 86 | 87 | seed: 111123 88 | 89 | trainer: 90 | max_steps: -1 91 | val_check_interval: 0.5 92 | gradient_clip_val: 0.5 93 | num_sanity_val_steps: 2 94 | 95 | output_dir: outputs/tmp 96 | 97 | use_plugins: false 98 | -------------------------------------------------------------------------------- /config/model/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | -------------------------------------------------------------------------------- /config/model/encoder/depthsplat.yaml: -------------------------------------------------------------------------------- 1 | name: depthsplat 2 | 3 | num_depth_candidates: 128 4 | num_surfaces: 1 5 | 6 | gaussians_per_pixel: 1 7 | 8 | gaussian_adapter: 9 | gaussian_scale_min: 1e-10 10 | gaussian_scale_max: 3. 11 | sh_degree: 2 12 | 13 | d_feature: 128 14 | 15 | visualizer: 16 | num_samples: 8 17 | min_resolution: 256 18 | export_ply: false 19 | 20 | unimatch_weights_path: "pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" 21 | multiview_trans_attn_split: 2 22 | costvolume_unet_feat_dim: 128 23 | costvolume_unet_channel_mult: [1,1,1] 24 | costvolume_unet_attn_res: [] 25 | depth_unet_feat_dim: 64 26 | depth_unet_attn_res: [] 27 | depth_unet_channel_mult: [1, 1, 1] 28 | downscale_factor: 4 29 | shim_patch_size: 4 30 | 31 | local_mv_match: 2 32 | 33 | # monodepth 34 | monodepth_vit_type: vits 35 | 36 | # return depth 37 | supervise_intermediate_depth: true 38 | return_depth: true 39 | 40 | # mv_unimatch 41 | num_scales: 1 42 | upsample_factor: 4 43 | lowest_feature_resolution: 4 44 | depth_unet_channels: 128 45 | grid_sample_disable_cudnn: false 46 | 47 | # depthsplat color branch 48 | large_gaussian_head: false 49 | color_large_unet: false 50 | init_sh_input_img: true 51 | feature_upsampler_channels: 64 52 | gaussian_regressor_channels: 64 53 | 54 | # only depth 55 | train_depth_only: false -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beartype==0.18.5 2 | colorama==0.4.6 3 | colorspacious==1.1.2 4 | dacite==1.8.1 5 | e3nn==0.5.1 6 | einops==0.8.0 7 | hydra-core==1.3.2 8 | jaxtyping==0.2.33 9 | lpips==0.1.4 10 | matplotlib==3.9.1 11 | moviepy==1.0.3 12 | numpy==1.24.4 13 | opencv_python==4.11.0.86 14 | Pillow==10.4.0 15 | plyfile==1.1 16 | pytorch_lightning==2.4.0 17 | scikit-image==0.24.0 18 | sk-video==1.1.10 19 | tabulate==0.9.0 20 | tqdm==4.66.4 21 | wandb==0.17.7 22 | xformers==0.0.27.post2 23 | git+https://github.com/dcharatan/diff-gaussian-rasterization-modified 24 | -------------------------------------------------------------------------------- /scripts/dl3dv_depthsplat_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # base model 5 | # first train on re10k, 2 views, 256x448 6 | # train on 8x (8 nodes) 4x GPUs (>=80GB VRAM) for 150K steps, batch size 8 on each gpu 7 | python -m src.main +experiment=re10k \ 8 | 
data_loader.train.batch_size=8 \ 9 | dataset.test_chunk_interval=10 \ 10 | dataset.image_shape=[256,448] \ 11 | trainer.max_steps=150000 \ 12 | trainer.num_nodes=8 \ 13 | model.encoder.num_scales=2 \ 14 | model.encoder.upsample_factor=4 \ 15 | model.encoder.lowest_feature_resolution=8 \ 16 | model.encoder.monodepth_vit_type=vitb \ 17 | checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vitb.pth \ 18 | checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 19 | output_dir=checkpoints/re10k-256x448-depthsplat-base 20 | 21 | 22 | # finetune on dl3dv, random view 2-6 23 | # train on 8x GPUs (>=80GB VRAM) for 100K steps, batch size 1 on each gpu 24 | # resume from the previously pretrained model on re10k 25 | python -m src.main +experiment=dl3dv \ 26 | data_loader.train.batch_size=1 \ 27 | dataset.roots=[datasets/dl3dv] \ 28 | dataset.view_sampler.num_target_views=8 \ 29 | dataset.view_sampler.num_context_views=6 \ 30 | dataset.min_views=2 \ 31 | dataset.max_views=6 \ 32 | trainer.max_steps=100000 \ 33 | trainer.num_nodes=2 \ 34 | model.encoder.num_scales=2 \ 35 | model.encoder.upsample_factor=4 \ 36 | model.encoder.lowest_feature_resolution=8 \ 37 | model.encoder.monodepth_vit_type=vitb \ 38 | checkpointing.pretrained_model=pretrained/depthsplat-gs-base-re10k-256x448-view2-76a0605a.pth \ 39 | wandb.project=depthsplat \ 40 | output_dir=checkpoints/dl3dv-256x448-depthsplat-base-randview2-6 41 | 42 | 43 | -------------------------------------------------------------------------------- /scripts/inference_depth.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # base model: depth prediction on re10k: 2 input views 352x640 5 | CUDA_VISIBLE_DEVICES=0 python -m src.main +experiment=re10k \ 6 | dataset.test_chunk_interval=10 \ 7 | mode=test \ 8 | dataset/view_sampler=evaluation \ 9 | dataset.image_shape=[352,640] \ 10 | test.compute_scores=false \ 11 | dataset.view_sampler.num_context_views=2 \ 12 | model.encoder.num_scales=2 \ 13 | model.encoder.upsample_factor=4 \ 14 | model.encoder.lowest_feature_resolution=8 \ 15 | model.encoder.monodepth_vit_type=vitb \ 16 | train.forward_depth_only=true \ 17 | checkpointing.pretrained_depth=pretrained/depthsplat-depth-base-352x640-randview2-8-65a892c5.pth \ 18 | test.compute_scores=false \ 19 | test.save_depth=true \ 20 | test.save_depth_concat_img=true \ 21 | output_dir=outputs/depthsplat-depth-base-re10k 22 | 23 | 24 | # base model: depth prediction on re10k: 6 input views 352x640 25 | CUDA_VISIBLE_DEVICES=0 python -m src.main +experiment=dl3dv \ 26 | dataset.test_chunk_interval=10 \ 27 | mode=test \ 28 | dataset.roots=[datasets/re10k] \ 29 | dataset/view_sampler=evaluation \ 30 | dataset.view_sampler.num_context_views=6 \ 31 | dataset.view_sampler.index_path=assets/re10k_ctx_6v_video.json \ 32 | dataset.image_shape=[352,640] \ 33 | dataset.ori_image_shape=[360,640] \ 34 | model.encoder.num_scales=2 \ 35 | model.encoder.upsample_factor=4 \ 36 | model.encoder.lowest_feature_resolution=8 \ 37 | model.encoder.monodepth_vit_type=vitb \ 38 | train.forward_depth_only=true \ 39 | checkpointing.pretrained_depth=pretrained/depthsplat-depth-base-352x640-randview2-8-65a892c5.pth \ 40 | test.compute_scores=false \ 41 | test.save_depth=true \ 42 | test.save_depth_concat_img=true \ 43 | output_dir=outputs/depthsplat-depth-base-re10k-view6 44 | 45 | 46 | # base model: depth prediction on dl3dv: 12 input views 512x960 47 | CUDA_VISIBLE_DEVICES=0 python -m src.main 
+experiment=dl3dv \ 48 | dataset.test_chunk_interval=1 \ 49 | mode=test \ 50 | dataset.roots=[datasets/dl3dv_960p] \ 51 | dataset/view_sampler=evaluation \ 52 | dataset.image_shape=[512,960] \ 53 | dataset.ori_image_shape=[540,960] \ 54 | dataset.view_sampler.num_context_views=12 \ 55 | dataset.view_sampler.index_path=assets/dl3dv_start_0_distance_100_ctx_12v_video.json \ 56 | model.encoder.num_scales=2 \ 57 | model.encoder.upsample_factor=4 \ 58 | model.encoder.lowest_feature_resolution=8 \ 59 | model.encoder.monodepth_vit_type=vitb \ 60 | train.forward_depth_only=true \ 61 | checkpointing.pretrained_depth=pretrained/depthsplat-depth-base-352x640-randview2-8-65a892c5.pth \ 62 | test.compute_scores=false \ 63 | test.save_depth=true \ 64 | test.save_depth_concat_img=true \ 65 | output_dir=outputs/depthsplat-depth-base-dl3dv-view12 66 | 67 | -------------------------------------------------------------------------------- /scripts/re10k_depthsplat_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # small model 5 | # train on 4x GPUs (>=80GB VRAM) for 150K steps, batch size 8 on each gpu 6 | python -m src.main +experiment=re10k \ 7 | data_loader.train.batch_size=8 \ 8 | dataset.test_chunk_interval=10 \ 9 | trainer.max_steps=150000 \ 10 | model.encoder.upsample_factor=4 \ 11 | model.encoder.lowest_feature_resolution=4 \ 12 | checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vits.pth \ 13 | checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 14 | output_dir=checkpoints/re10k-256x256-depthsplat-small 15 | 16 | 17 | # or 18 | # 4x 4090 (24GB) for 300K steps, batch size 4 on each gpu 19 | # python -m src.main +experiment=re10k \ 20 | # data_loader.train.batch_size=4 \ 21 | # dataset.test_chunk_interval=10 \ 22 | # trainer.max_steps=300000 \ 23 | # model.encoder.upsample_factor=4 \ 24 | # model.encoder.lowest_feature_resolution=4 \ 25 | # checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vits.pth \ 26 | # checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 27 | # output_dir=checkpoints/re10k-256x256-depthsplat-small 28 | 29 | 30 | # or 31 | # a single A100 (80GB) for 600K steps, batch size 8 on each gpu 32 | # python -m src.main +experiment=re10k \ 33 | # data_loader.train.batch_size=8 \ 34 | # dataset.test_chunk_interval=10 \ 35 | # trainer.max_steps=600000 \ 36 | # model.encoder.upsample_factor=4 \ 37 | # model.encoder.lowest_feature_resolution=4 \ 38 | # checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vits.pth \ 39 | # checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 40 | # output_dir=checkpoints/re10k-256x256-depthsplat-small 41 | 42 | 43 | # how to resume if training crashes unexpectedly: 44 | # `checkpointing.resume=true`: find latest checkpoint and resume from it 45 | # `wandb.id=WANDB_ID`: continue logging to the same wandb run using the specified WANDB_ID 46 | # python -m src.main +experiment=re10k \ 47 | # data_loader.train.batch_size=8 \ 48 | # dataset.test_chunk_interval=10 \ 49 | # trainer.max_steps=150000 \ 50 | # model.encoder.upsample_factor=4 \ 51 | # model.encoder.lowest_feature_resolution=4 \ 52 | # checkpointing.resume=true \ 53 | # wandb.id=WANDB_ID \ 54 | # output_dir=checkpoints/re10k-256x256-depthsplat-small 55 | 56 | 57 | # base model 58 | # train on 4x GPUs (>=80GB VRAM) for 150K steps, batch size 8 on each gpu 59 | python -m src.main +experiment=re10k \ 60 | 
data_loader.train.batch_size=8 \ 61 | dataset.test_chunk_interval=10 \ 62 | trainer.max_steps=150000 \ 63 | model.encoder.num_scales=2 \ 64 | model.encoder.upsample_factor=2 \ 65 | model.encoder.lowest_feature_resolution=4 \ 66 | model.encoder.monodepth_vit_type=vitb \ 67 | checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vitb.pth \ 68 | checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 69 | output_dir=checkpoints/re10k-256x256-depthsplat-base 70 | 71 | 72 | # large model 73 | # train on 4x GPUs (>=80GB VRAM) for 150K steps, batch size 8 on each gpu 74 | python -m src.main +experiment=re10k \ 75 | data_loader.train.batch_size=8 \ 76 | dataset.test_chunk_interval=10 \ 77 | trainer.max_steps=150000 \ 78 | model.encoder.num_scales=2 \ 79 | model.encoder.upsample_factor=2 \ 80 | model.encoder.lowest_feature_resolution=4 \ 81 | model.encoder.monodepth_vit_type=vitl \ 82 | checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vitl.pth \ 83 | checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \ 84 | output_dir=checkpoints/re10k-256x256-depthsplat-large 85 | 86 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Literal, Optional, Type, TypeVar 4 | 5 | from dacite import Config, from_dict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .dataset.data_module import DataLoaderCfg, DatasetCfg 9 | from .loss import LossCfgWrapper 10 | from .model.decoder import DecoderCfg 11 | from .model.encoder import EncoderCfg 12 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg 13 | 14 | 15 | @dataclass 16 | class CheckpointingCfg: 17 | load: Optional[str] # Not a path, since it could be something like wandb://... 18 | every_n_train_steps: int 19 | save_top_k: int 20 | pretrained_model: Optional[str] 21 | pretrained_monodepth: Optional[str] 22 | pretrained_mvdepth: Optional[str] 23 | pretrained_depth: Optional[str] 24 | no_strict_load: bool 25 | resume: bool 26 | 27 | 28 | @dataclass 29 | class ModelCfg: 30 | decoder: DecoderCfg 31 | encoder: EncoderCfg 32 | 33 | 34 | @dataclass 35 | class TrainerCfg: 36 | max_steps: int 37 | val_check_interval: int | float | None 38 | gradient_clip_val: int | float | None 39 | num_sanity_val_steps: int 40 | num_nodes: int 41 | 42 | 43 | @dataclass 44 | class RootCfg: 45 | wandb: dict 46 | mode: Literal["train", "test"] 47 | dataset: DatasetCfg 48 | data_loader: DataLoaderCfg 49 | model: ModelCfg 50 | optimizer: OptimizerCfg 51 | checkpointing: CheckpointingCfg 52 | trainer: TrainerCfg 53 | loss: list[LossCfgWrapper] 54 | test: TestCfg 55 | train: TrainCfg 56 | seed: int 57 | use_plugins: bool 58 | 59 | 60 | TYPE_HOOKS = { 61 | Path: Path, 62 | } 63 | 64 | 65 | T = TypeVar("T") 66 | 67 | 68 | def load_typed_config( 69 | cfg: DictConfig, 70 | data_class: Type[T], 71 | extra_type_hooks: dict = {}, 72 | ) -> T: 73 | return from_dict( 74 | data_class, 75 | OmegaConf.to_container(cfg), 76 | config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}), 77 | ) 78 | 79 | 80 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]: 81 | # The dummy allows the union to be converted. 
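    # (dacite can only decide which member of the LossCfgWrapper union a given
    # {loss_name: cfg} entry corresponds to when the union appears as the declared
    # type of a dataclass field, hence the single-field wrapper below.)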
82 | @dataclass 83 | class Dummy: 84 | dummy: LossCfgWrapper 85 | 86 | return [ 87 | load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy 88 | for k, v in joined.items() 89 | ] 90 | 91 | 92 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 93 | return load_typed_config( 94 | cfg, 95 | RootCfg, 96 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers}, 97 | ) 98 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from ..misc.step_tracker import StepTracker 4 | from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg 5 | from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfg 6 | from .types import Stage 7 | from .view_sampler import get_view_sampler 8 | 9 | DATASETS: dict[str, Dataset] = { 10 | "re10k": DatasetRE10k, 11 | "dl3dv": DatasetDL3DV, 12 | } 13 | 14 | 15 | DatasetCfg = DatasetRE10kCfg | DatasetDL3DVCfg 16 | 17 | 18 | def get_dataset( 19 | cfg: DatasetCfg, 20 | stage: Stage, 21 | step_tracker: StepTracker | None, 22 | ) -> Dataset: 23 | view_sampler = get_view_sampler( 24 | cfg.view_sampler, 25 | stage, 26 | cfg.overfit_to_scene is not None, 27 | cfg.cameras_are_circular, 28 | step_tracker, 29 | ) 30 | return DATASETS[cfg.name](cfg, stage, view_sampler) 31 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | from pytorch_lightning import LightningDataModule 8 | from torch import Generator, nn 9 | from torch.utils.data import DataLoader, Dataset, IterableDataset 10 | 11 | from ..misc.step_tracker import StepTracker 12 | from . import DatasetCfg, get_dataset 13 | from .types import DataShim, Stage 14 | from .validation_wrapper import ValidationWrapper 15 | 16 | 17 | def get_data_shim(encoder: nn.Module) -> DataShim: 18 | """Get functions that modify the batch. It's sometimes necessary to modify batches 19 | outside the data loader because GPU computations are required to modify the batch or 20 | because the modification depends on something outside the data loader. 
21 | """ 22 | 23 | shims: list[DataShim] = [] 24 | if hasattr(encoder, "get_data_shim"): 25 | shims.append(encoder.get_data_shim()) 26 | 27 | def combined_shim(batch): 28 | for shim in shims: 29 | batch = shim(batch) 30 | return batch 31 | 32 | return combined_shim 33 | 34 | 35 | @dataclass 36 | class DataLoaderStageCfg: 37 | batch_size: int 38 | num_workers: int 39 | persistent_workers: bool 40 | seed: int | None 41 | 42 | 43 | @dataclass 44 | class DataLoaderCfg: 45 | train: DataLoaderStageCfg 46 | test: DataLoaderStageCfg 47 | val: DataLoaderStageCfg 48 | 49 | 50 | DatasetShim = Callable[[Dataset, Stage], Dataset] 51 | 52 | 53 | def worker_init_fn(worker_id: int) -> None: 54 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 55 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 56 | 57 | 58 | class DataModule(LightningDataModule): 59 | dataset_cfg: DatasetCfg 60 | data_loader_cfg: DataLoaderCfg 61 | step_tracker: StepTracker | None 62 | dataset_shim: DatasetShim 63 | global_rank: int 64 | 65 | def __init__( 66 | self, 67 | dataset_cfg: DatasetCfg, 68 | data_loader_cfg: DataLoaderCfg, 69 | step_tracker: StepTracker | None = None, 70 | dataset_shim: DatasetShim = lambda dataset, _: dataset, 71 | global_rank: int = 0, 72 | ) -> None: 73 | super().__init__() 74 | self.dataset_cfg = dataset_cfg 75 | self.data_loader_cfg = data_loader_cfg 76 | self.step_tracker = step_tracker 77 | self.dataset_shim = dataset_shim 78 | self.global_rank = global_rank 79 | 80 | def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: 81 | return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers 82 | 83 | def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: 84 | if loader_cfg.seed is None: 85 | return None 86 | generator = Generator() 87 | generator.manual_seed(loader_cfg.seed + self.global_rank) 88 | return generator 89 | 90 | def train_dataloader(self): 91 | dataset = get_dataset(self.dataset_cfg, "train", self.step_tracker) 92 | dataset = self.dataset_shim(dataset, "train") 93 | return DataLoader( 94 | dataset, 95 | self.data_loader_cfg.train.batch_size, 96 | shuffle=not isinstance(dataset, IterableDataset), 97 | num_workers=self.data_loader_cfg.train.num_workers, 98 | generator=self.get_generator(self.data_loader_cfg.train), 99 | worker_init_fn=worker_init_fn, 100 | persistent_workers=self.get_persistent(self.data_loader_cfg.train), 101 | ) 102 | 103 | def val_dataloader(self): 104 | dataset = get_dataset(self.dataset_cfg, "val", self.step_tracker) 105 | dataset = self.dataset_shim(dataset, "val") 106 | return DataLoader( 107 | ValidationWrapper(dataset, 1), 108 | self.data_loader_cfg.val.batch_size, 109 | num_workers=self.data_loader_cfg.val.num_workers, 110 | generator=self.get_generator(self.data_loader_cfg.val), 111 | worker_init_fn=worker_init_fn, 112 | persistent_workers=self.get_persistent(self.data_loader_cfg.val), 113 | ) 114 | 115 | def test_dataloader(self, dataset_cfg=None): 116 | dataset = get_dataset( 117 | self.dataset_cfg if dataset_cfg is None else dataset_cfg, 118 | "test", 119 | self.step_tracker, 120 | ) 121 | dataset = self.dataset_shim(dataset, "test") 122 | return DataLoader( 123 | dataset, 124 | self.data_loader_cfg.test.batch_size, 125 | num_workers=self.data_loader_cfg.test.num_workers, 126 | generator=self.get_generator(self.data_loader_cfg.test), 127 | worker_init_fn=worker_init_fn, 128 | persistent_workers=self.get_persistent(self.data_loader_cfg.test), 129 | 
shuffle=False, 130 | ) 131 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 
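    # Derivation: translating the camera by `baseline` shifts a point at depth d by
    # baseline / d on the normalized image plane, i.e. (baseline / d) / pixel_size
    # pixels of disparity; solving disparity = baseline / (d * pixel_size) for d
    # yields the expression returned below.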
35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | 9 | from ..types import AnyExample, AnyViews 10 | 11 | 12 | def rescale( 13 | image: Float[Tensor, "3 h_in w_in"], 14 | shape: tuple[int, int], 15 | ) -> Float[Tensor, "3 h_out w_out"]: 16 | h, w = shape 17 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 18 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 19 | image_new = Image.fromarray(image_new) 20 | image_new = image_new.resize((w, h), Image.LANCZOS) 21 | image_new = np.array(image_new) / 255 22 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 23 | return rearrange(image_new, "h w c -> c h w") 24 | 25 | 26 | def center_crop( 27 | images: Float[Tensor, "*#batch c h w"], 28 | intrinsics: Float[Tensor, "*#batch 3 3"], 29 | shape: tuple[int, int], 30 | depths: None | Float[Tensor, "*#batch h w"], 31 | ) -> ( 32 | tuple[ 33 | Float[Tensor, "*#batch c h_out w_out"], # updated images 34 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 35 | ] 36 | | tuple[ 37 | Float[Tensor, "*#batch c h_out w_out"], # updated images 38 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 39 | Float[Tensor, "*#batch h_out w_out"], # updated depths 40 | ] 41 | ): 42 | *_, h_in, w_in = images.shape 43 | h_out, w_out = shape 44 | 45 | # Note that odd input dimensions induce half-pixel misalignments. 46 | row = (h_in - h_out) // 2 47 | col = (w_in - w_out) // 2 48 | 49 | # Center-crop the image. 50 | images = images[..., :, row : row + h_out, col : col + w_out] 51 | 52 | # Adjust the intrinsics to account for the cropping. 
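    # The intrinsics are normalized by the image size, so a center crop from
    # (h_in, w_in) to (h_out, w_out) scales the normalized focal lengths by the size
    # ratio, while the principal point stays at the image center.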
53 | intrinsics = intrinsics.clone() 54 | intrinsics[..., 0, 0] *= w_in / w_out # fx 55 | intrinsics[..., 1, 1] *= h_in / h_out # fy 56 | 57 | if depths is not None: 58 | depths = depths[..., :, row : row + h_out, col : col + w_out] 59 | return images, intrinsics, depths 60 | 61 | return images, intrinsics 62 | 63 | 64 | def rescale_and_crop( 65 | images: Float[Tensor, "*#batch c h w"], 66 | intrinsics: Float[Tensor, "*#batch 3 3"], 67 | shape: tuple[int, int], 68 | depths: None | Float[Tensor, "*#batch h w"], 69 | ) -> ( 70 | tuple[ 71 | Float[Tensor, "*#batch c h_out w_out"], # updated images 72 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 73 | ] 74 | | tuple[ 75 | Float[Tensor, "*#batch c h_out w_out"], # updated images 76 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 77 | Float[Tensor, "*#batch h_out w_out"], # updated depths 78 | ] 79 | ): 80 | *_, h_in, w_in = images.shape 81 | h_out, w_out = shape 82 | assert h_out <= h_in and w_out <= w_in 83 | 84 | scale_factor = max(h_out / h_in, w_out / w_in) 85 | h_scaled = round(h_in * scale_factor) 86 | w_scaled = round(w_in * scale_factor) 87 | assert h_scaled == h_out or w_scaled == w_out 88 | 89 | # Reshape the images to the correct size. Assume we don't have to worry about 90 | # changing the intrinsics based on how the images are rounded. 91 | *batch, c, h, w = images.shape 92 | images = images.reshape(-1, c, h, w) 93 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 94 | images = images.reshape(*batch, c, h_scaled, w_scaled) 95 | 96 | # reshape and crop depth as well when available 97 | if depths is not None: 98 | depths = F.interpolate( 99 | depths.unsqueeze(1), 100 | size=(h_scaled, w_scaled), 101 | mode="bilinear", 102 | align_corners=True, 103 | ).squeeze(1) 104 | 105 | return center_crop(images, intrinsics, shape, depths=depths) 106 | 107 | 108 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 109 | images, intrinsics = rescale_and_crop( 110 | views["image"], views["intrinsics"], shape, depths=None 111 | ) 112 | return { 113 | **views, 114 | "image": images, 115 | "intrinsics": intrinsics, 116 | } 117 | 118 | 119 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 120 | """Crop images in the example.""" 121 | return { 122 | **example, 123 | "context": apply_crop_shim_to_views(example["context"], shape), 124 | "target": apply_crop_shim_to_views(example["target"], shape), 125 | } 126 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 
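# Same normalized-intrinsics convention as crop_shim.center_crop above. For
# example (illustrative sizes): with patch_size = 16, a 250 x 370 image is
# cropped to 240 x 368 (row offset 5, column offset 1), and fy / fx are scaled
# by 250/240 and 370/368 respectively.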
19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 
10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ...misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | from .view_sampler_bounded_v2 import ViewSamplerBoundedV2, ViewSamplerBoundedV2Cfg 11 | 12 | 13 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 14 | "all": ViewSamplerAll, 15 | "arbitrary": ViewSamplerArbitrary, 16 | "bounded": ViewSamplerBounded, 17 | "evaluation": ViewSamplerEvaluation, 18 | "boundedv2": ViewSamplerBoundedV2, 19 | } 20 | 21 | ViewSamplerCfg = ( 22 | ViewSamplerArbitraryCfg 23 | | ViewSamplerBoundedCfg 24 | | ViewSamplerEvaluationCfg 25 | | ViewSamplerAllCfg 26 | | ViewSamplerBoundedV2Cfg 27 | ) 28 | 29 | 30 | def get_view_sampler( 31 | cfg: ViewSamplerCfg, 32 | stage: Stage, 33 | overfit: bool, 34 | cameras_are_circular: bool, 35 | step_tracker: StepTracker | None, 36 | ) -> ViewSampler[Any]: 37 | return VIEW_SAMPLERS[cfg.name]( 38 | cfg, 39 | stage, 40 | overfit, 41 | cameras_are_circular, 42 | step_tracker, 43 | ) 44 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | **kwargs, 43 | ) -> tuple[ 44 | Int64[Tensor, " context_view"], # indices for context views 45 | Int64[Tensor, " 
target_view"], # indices for target views 46 | ]: 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def num_target_views(self) -> int: 52 | pass 53 | 54 | @property 55 | @abstractmethod 56 | def num_context_views(self) -> int: 57 | pass 58 | 59 | @property 60 | def global_step(self) -> int: 61 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 62 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | **kwargs, 24 | ) -> tuple[ 25 | Int64[Tensor, " context_view"], # indices for context views 26 | Int64[Tensor, " target_view"], # indices for target views 27 | ]: 28 | v, _, _ = extrinsics.shape 29 | all_frames = torch.arange(v, device=device) 30 | return all_frames, all_frames 31 | 32 | @property 33 | def num_context_views(self) -> int: 34 | return 0 35 | 36 | @property 37 | def num_target_views(self) -> int: 38 | return 0 39 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerArbitraryCfg: 13 | name: Literal["arbitrary"] 14 | num_context_views: int 15 | num_target_views: int 16 | context_views: list[int] | None 17 | target_views: list[int] | None 18 | 19 | 20 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 21 | def sample( 22 | self, 23 | scene: str, 24 | extrinsics: Float[Tensor, "view 4 4"], 25 | intrinsics: Float[Tensor, "view 3 3"], 26 | device: torch.device = torch.device("cpu"), 27 | **kwargs, 28 | ) -> tuple[ 29 | Int64[Tensor, " context_view"], # indices for context views 30 | Int64[Tensor, " target_view"], # indices for target views 31 | ]: 32 | """Arbitrarily sample context and target views.""" 33 | num_views, _, _ = extrinsics.shape 34 | 35 | index_context = torch.randint( 36 | 0, 37 | num_views, 38 | size=(self.cfg.num_context_views,), 39 | device=device, 40 | ) 41 | 42 | # Allow the context views to be fixed. 43 | if self.cfg.context_views is not None: 44 | assert len(self.cfg.context_views) == self.cfg.num_context_views 45 | index_context = torch.tensor( 46 | self.cfg.context_views, dtype=torch.int64, device=device 47 | ) 48 | 49 | index_target = torch.randint( 50 | 0, 51 | num_views, 52 | size=(self.cfg.num_target_views,), 53 | device=device, 54 | ) 55 | 56 | # Allow the target views to be fixed. 
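# Hypothetical configuration (illustrative values, not from the shipped
# configs): pinning both lists turns this sampler into a deterministic lookup.
#
#   ViewSamplerArbitraryCfg(
#       name="arbitrary",
#       num_context_views=2,
#       num_target_views=3,
#       context_views=[0, 30],
#       target_views=[5, 15, 25],
#   )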
57 | if self.cfg.target_views is not None: 58 | assert len(self.cfg.target_views) == self.cfg.num_target_views 59 | index_target = torch.tensor( 60 | self.cfg.target_views, dtype=torch.int64, device=device 61 | ) 62 | 63 | return index_context, index_target 64 | 65 | @property 66 | def num_context_views(self) -> int: 67 | return self.cfg.num_context_views 68 | 69 | @property 70 | def num_target_views(self) -> int: 71 | return self.cfg.num_target_views 72 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | name: Literal["bounded"] 14 | num_context_views: int 15 | num_target_views: int 16 | min_distance_between_context_views: int 17 | max_distance_between_context_views: int 18 | min_distance_to_context_views: int 19 | warm_up_steps: int 20 | initial_min_distance_between_context_views: int 21 | initial_max_distance_between_context_views: int 22 | 23 | 24 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 25 | def schedule(self, initial: int, final: int) -> int: 26 | fraction = self.global_step / self.cfg.warm_up_steps 27 | return min(initial + int((final - initial) * fraction), final) 28 | 29 | def sample( 30 | self, 31 | scene: str, 32 | extrinsics: Float[Tensor, "view 4 4"], 33 | intrinsics: Float[Tensor, "view 3 3"], 34 | device: torch.device = torch.device("cpu"), 35 | min_view_dist: int | None = None, 36 | max_view_dist: int | None = None, 37 | **kwargs, 38 | ) -> tuple[ 39 | Int64[Tensor, " context_view"], # indices for context views 40 | Int64[Tensor, " target_view"], # indices for target views 41 | ]: 42 | num_views, _, _ = extrinsics.shape 43 | 44 | # Compute the context view spacing based on the current global step. 45 | if self.stage == "test": 46 | # When testing, always use the full gap. 47 | max_gap = self.cfg.max_distance_between_context_views 48 | min_gap = self.cfg.max_distance_between_context_views 49 | elif self.cfg.warm_up_steps > 0: 50 | max_gap = self.schedule( 51 | self.cfg.initial_max_distance_between_context_views, 52 | self.cfg.max_distance_between_context_views, 53 | ) 54 | min_gap = self.schedule( 55 | self.cfg.initial_min_distance_between_context_views, 56 | self.cfg.min_distance_between_context_views, 57 | ) 58 | else: 59 | max_gap = self.cfg.max_distance_between_context_views 60 | min_gap = self.cfg.min_distance_between_context_views 61 | 62 | # Pick the gap between the context views. 63 | if not self.cameras_are_circular: 64 | max_gap = min(num_views - 1, max_gap) 65 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 66 | 67 | # overwrite min_gap and max_gap, useful for mixed dataset training 68 | # use different view distance for different dataset 69 | if min_view_dist is not None: 70 | min_gap = min_view_dist 71 | 72 | if max_view_dist is not None: 73 | max_gap = max_view_dist 74 | 75 | if max_gap < min_gap: 76 | raise ValueError("Example does not have enough frames!") 77 | context_gap = torch.randint( 78 | min_gap, 79 | max_gap + 1, 80 | size=tuple(), 81 | device=device, 82 | ).item() 83 | 84 | # Pick the left and right context indices. 
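# The left index is drawn uniformly over starting frames that leave room for
# the chosen gap (or over all frames when the cameras form a closed loop); at
# test time it is zeroed out below so that evaluation is deterministic.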
85 | index_context_left = torch.randint( 86 | num_views if self.cameras_are_circular else num_views - context_gap, 87 | size=tuple(), 88 | device=device, 89 | ).item() 90 | if self.stage == "test": 91 | index_context_left = index_context_left * 0 92 | index_context_right = index_context_left + context_gap 93 | 94 | if self.is_overfitting: 95 | index_context_left *= 0 96 | index_context_right *= 0 97 | index_context_right += max_gap 98 | 99 | # Pick the target view indices. 100 | if self.stage == "test": 101 | # When testing, pick all. 102 | index_target = torch.arange( 103 | index_context_left, 104 | index_context_right + 1, 105 | device=device, 106 | ) 107 | else: 108 | # When training or validating (visualizing), pick at random. 109 | index_target = torch.randint( 110 | index_context_left + self.cfg.min_distance_to_context_views, 111 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 112 | size=(self.cfg.num_target_views,), 113 | device=device, 114 | ) 115 | 116 | # Apply modulo for circular datasets. 117 | if self.cameras_are_circular: 118 | index_target %= num_views 119 | index_context_right %= num_views 120 | 121 | return ( 122 | torch.tensor((index_context_left, index_context_right)), 123 | index_target, 124 | ) 125 | 126 | @property 127 | def num_context_views(self) -> int: 128 | return 2 129 | 130 | @property 131 | def num_target_views(self) -> int: 132 | return self.cfg.num_target_views 133 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...misc.step_tracker import StepTracker 13 | from ..types import Stage 14 | from .view_sampler import ViewSampler 15 | 16 | 17 | @dataclass 18 | class ViewSamplerEvaluationCfg: 19 | name: Literal["evaluation"] 20 | index_path: Path 21 | num_context_views: int 22 | 23 | 24 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 25 | index: dict[str, IndexEntry | None] 26 | 27 | def __init__( 28 | self, 29 | cfg: ViewSamplerEvaluationCfg, 30 | stage: Stage, 31 | is_overfitting: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> None: 35 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 36 | 37 | dacite_config = Config(cast=[tuple]) 38 | with cfg.index_path.open("r") as f: 39 | self.index = { 40 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 41 | for k, v in json.load(f).items() 42 | } 43 | 44 | def sample( 45 | self, 46 | scene: str, 47 | extrinsics: Float[Tensor, "view 4 4"], 48 | intrinsics: Float[Tensor, "view 3 3"], 49 | device: torch.device = torch.device("cpu"), 50 | **kwargs, 51 | ) -> tuple[ 52 | Int64[Tensor, " context_view"], # indices for context views 53 | Int64[Tensor, " target_view"], # indices for target views 54 | ]: 55 | entry = self.index.get(scene) 56 | if entry is None: 57 | raise ValueError(f"No indices available for scene {scene}.") 58 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 59 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 60 | return 
context_indices, target_indices 61 | 62 | @property 63 | def num_context_views(self) -> int: 64 | return 0 65 | 66 | @property 67 | def num_target_views(self) -> int: 68 | return 0 69 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | animate_side_by_side: bool 23 | highlighted: list[SceneCfg] 24 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_index_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | 5 | import torch 6 | from einops import rearrange 7 | from pytorch_lightning import LightningModule 8 | from tqdm import tqdm 9 | 10 | from ..geometry.epipolar_lines import project_rays 11 | from ..geometry.projection import get_world_rays, sample_image_grid 12 | from ..misc.image_io import save_image 13 | from ..visualization.annotation import add_label 14 | from ..visualization.layout import add_border, hcat 15 | 16 | 17 | @dataclass 18 | class EvaluationIndexGeneratorCfg: 19 | num_target_views: int 20 | min_distance: int 21 | max_distance: int 22 | min_overlap: float 23 | max_overlap: float 24 | output_path: Path 25 | save_previews: bool 26 | seed: int 27 | 28 | 29 | @dataclass 30 | class IndexEntry: 31 | context: tuple[int, ...] 32 | target: tuple[int, ...] 33 | 34 | 35 | class EvaluationIndexGenerator(LightningModule): 36 | generator: torch.Generator 37 | cfg: EvaluationIndexGeneratorCfg 38 | index: dict[str, IndexEntry | None] 39 | 40 | def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None: 41 | super().__init__() 42 | self.cfg = cfg 43 | self.generator = torch.Generator() 44 | self.generator.manual_seed(cfg.seed) 45 | self.index = {} 46 | 47 | def test_step(self, batch, batch_idx): 48 | b, v, _, h, w = batch["target"]["image"].shape 49 | assert b == 1 50 | extrinsics = batch["target"]["extrinsics"][0] 51 | intrinsics = batch["target"]["intrinsics"][0] 52 | scene = batch["scene"][0] 53 | 54 | context_indices = torch.randperm(v, generator=self.generator) 55 | for context_index in tqdm(context_indices, "Finding context pair"): 56 | xy, _ = sample_image_grid((h, w), self.device) 57 | context_origins, context_directions = get_world_rays( 58 | rearrange(xy, "h w xy -> (h w) xy"), 59 | extrinsics[context_index], 60 | intrinsics[context_index], 61 | ) 62 | 63 | # Step away from context view until the minimum overlap threshold is met. 64 | valid_indices = [] 65 | for step in (1, -1): 66 | min_distance = self.cfg.min_distance 67 | max_distance = self.cfg.max_distance 68 | current_index = context_index + step * min_distance 69 | 70 | while 0 <= current_index.item() < v: 71 | # Compute overlap. 
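# Overlap is measured symmetrically: rays from the context view are projected
# into the candidate view and vice versa, and each direction's score is the
# fraction of rays that land inside the other image. The minimum of the two is
# used below, so both views must actually share visible content.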
72 | current_origins, current_directions = get_world_rays( 73 | rearrange(xy, "h w xy -> (h w) xy"), 74 | extrinsics[current_index], 75 | intrinsics[current_index], 76 | ) 77 | projection_onto_current = project_rays( 78 | context_origins, 79 | context_directions, 80 | extrinsics[current_index], 81 | intrinsics[current_index], 82 | ) 83 | projection_onto_context = project_rays( 84 | current_origins, 85 | current_directions, 86 | extrinsics[context_index], 87 | intrinsics[context_index], 88 | ) 89 | overlap_a = projection_onto_context["overlaps_image"].float().mean() 90 | overlap_b = projection_onto_current["overlaps_image"].float().mean() 91 | 92 | overlap = min(overlap_a, overlap_b) 93 | delta = (current_index - context_index).abs() 94 | 95 | min_overlap = self.cfg.min_overlap 96 | max_overlap = self.cfg.max_overlap 97 | if min_overlap <= overlap <= max_overlap: 98 | valid_indices.append( 99 | (current_index.item(), overlap_a, overlap_b) 100 | ) 101 | 102 | # Stop once the camera has panned away too much. 103 | if overlap < min_overlap or delta > max_distance: 104 | break 105 | 106 | current_index += step 107 | 108 | if valid_indices: 109 | # Pick a random valid view. Index the resulting views. 110 | num_options = len(valid_indices) 111 | chosen = torch.randint( 112 | 0, num_options, size=tuple(), generator=self.generator 113 | ) 114 | chosen, overlap_a, overlap_b = valid_indices[chosen] 115 | 116 | context_left = min(chosen, context_index.item()) 117 | context_right = max(chosen, context_index.item()) 118 | delta = context_right - context_left 119 | 120 | # Pick non-repeated random target views. 121 | while True: 122 | target_views = torch.randint( 123 | context_left, 124 | context_right + 1, 125 | (self.cfg.num_target_views,), 126 | generator=self.generator, 127 | ) 128 | if (target_views.unique(return_counts=True)[1] == 1).all(): 129 | break 130 | 131 | target = tuple(sorted(target_views.tolist())) 132 | self.index[scene] = IndexEntry( 133 | context=(context_left, context_right), 134 | target=target, 135 | ) 136 | 137 | # Optionally, save a preview. 138 | if self.cfg.save_previews: 139 | preview_path = self.cfg.output_path / "previews" 140 | preview_path.mkdir(exist_ok=True, parents=True) 141 | a = batch["target"]["image"][0, chosen] 142 | a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%") 143 | b = batch["target"]["image"][0, context_index] 144 | b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%") 145 | vis = add_border(add_border(hcat(a, b)), 1, 0) 146 | vis = add_label(vis, f"Distance: {delta} frames") 147 | save_image(add_border(vis), preview_path / f"{scene}.png") 148 | break 149 | else: 150 | # This happens if no starting frame produces a valid evaluation example. 
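# The scene is still recorded (serialized as null by save_index) so it remains
# visible in the index; ViewSamplerEvaluation raises a ValueError if such a
# scene is later requested.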
151 | self.index[scene] = None 152 | 153 | def save_index(self) -> None: 154 | self.cfg.output_path.mkdir(exist_ok=True, parents=True) 155 | with (self.cfg.output_path / "evaluation_index.json").open("w") as f: 156 | json.dump( 157 | {k: None if v is None else asdict(v) for k, v in self.index.items()}, f 158 | ) 159 | -------------------------------------------------------------------------------- /src/evaluation/metric_computer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from tabulate import tabulate 7 | 8 | from ..misc.image_io import load_image, save_image 9 | from ..visualization.annotation import add_label 10 | from ..visualization.layout import add_border, hcat 11 | from .evaluation_cfg import EvaluationCfg 12 | from .metrics import compute_lpips, compute_psnr, compute_ssim 13 | 14 | 15 | class MetricComputer(LightningModule): 16 | cfg: EvaluationCfg 17 | 18 | def __init__(self, cfg: EvaluationCfg) -> None: 19 | super().__init__() 20 | self.cfg = cfg 21 | 22 | def test_step(self, batch, batch_idx): 23 | scene = batch["scene"][0] 24 | b, cv, _, _, _ = batch["context"]["image"].shape 25 | assert b == 1 and cv == 2 26 | _, v, _, _, _ = batch["target"]["image"].shape 27 | 28 | # Skip scenes. 29 | for method in self.cfg.methods: 30 | if not (method.path / scene).exists(): 31 | print(f'Skipping "{scene}".') 32 | return 33 | 34 | # Load the images. 35 | all_images = {} 36 | try: 37 | for method in self.cfg.methods: 38 | images = [ 39 | load_image(method.path / scene / f"color/{index.item():0>6}.png") 40 | for index in batch["target"]["index"][0] 41 | ] 42 | all_images[method.key] = torch.stack(images).to(self.device) 43 | except FileNotFoundError: 44 | print(f'Skipping "{scene}".') 45 | return 46 | 47 | # Compute metrics. 48 | all_metrics = {} 49 | rgb_gt = batch["target"]["image"][0] 50 | for key, images in all_images.items(): 51 | all_metrics = { 52 | **all_metrics, 53 | f"lpips_{key}": compute_lpips(rgb_gt, images).mean(), 54 | f"ssim_{key}": compute_ssim(rgb_gt, images).mean(), 55 | f"psnr_{key}": compute_psnr(rgb_gt, images).mean(), 56 | } 57 | self.log_dict(all_metrics) 58 | self.print_preview_metrics(all_metrics) 59 | 60 | # Skip the rest if no side-by-side is needed. 61 | if self.cfg.side_by_side_path is None: 62 | return 63 | 64 | # Create side-by-side. 65 | scene_key = f"{batch_idx:0>6}_{scene}" 66 | for i in range(v): 67 | true_index = batch["target"]["index"][0, i] 68 | row = [add_label(batch["target"]["image"][0, i], "Ground Truth")] 69 | for method in self.cfg.methods: 70 | image = all_images[method.key][i] 71 | image = add_label(image, method.name) 72 | row.append(image) 73 | start_frame = batch["target"]["index"][0, 0] 74 | end_frame = batch["target"]["index"][0, -1] 75 | label = f"Scene {batch['scene'][0]} (frames {start_frame} to {end_frame})" 76 | row = add_border(add_label(hcat(*row), label, font_size=16)) 77 | save_image( 78 | row, 79 | self.cfg.side_by_side_path / scene_key / f"{true_index:0>6}.png", 80 | ) 81 | 82 | # Create an animation. 
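# The ffmpeg invocation below globs the per-frame PNGs written above into an
# H.264 MP4; the pad filter rounds both dimensions up to even values, which
# libx264 with yuv420p output requires.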
83 | if self.cfg.animate_side_by_side: 84 | (self.cfg.side_by_side_path / "videos").mkdir(exist_ok=True, parents=True) 85 | command = ( 86 | 'ffmpeg -y -framerate 30 -pattern_type glob -i "*.png" -c:v libx264 ' 87 | '-pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"' 88 | ) 89 | os.system( 90 | f"cd {self.cfg.side_by_side_path / scene_key} && {command} " 91 | f"{Path.cwd()}/{self.cfg.side_by_side_path}/videos/{scene_key}.mp4" 92 | ) 93 | 94 | def print_preview_metrics(self, metrics: dict[str, float]) -> None: 95 | if getattr(self, "running_metrics", None) is None: 96 | self.running_metrics = metrics 97 | self.running_metric_steps = 1 98 | else: 99 | s = self.running_metric_steps 100 | self.running_metrics = { 101 | k: ((s * v) + metrics[k]) / (s + 1) 102 | for k, v in self.running_metrics.items() 103 | } 104 | self.running_metric_steps += 1 105 | 106 | table = [] 107 | for method in self.cfg.methods: 108 | row = [ 109 | f"{self.running_metrics[f'{metric}_{method.key}']:.3f}" 110 | for metric in ("psnr", "lpips", "ssim") 111 | ] 112 | table.append((method.key, *row)) 113 | 114 | table = tabulate(table, ["Method", "PSNR (dB)", "LPIPS", "SSIM"]) 115 | print(table) 116 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 | channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | 
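# Usage sketch for global_cfg (hypothetical, not part of the repository;
# assumes the repository root is on PYTHONPATH): the training entry point is
# expected to call set_cfg() once with the resolved OmegaConf config, after
# which other modules can read settings via get_cfg() without the config
# object being threaded through every call.

from omegaconf import OmegaConf

from src.global_cfg import get_cfg, get_seed, set_cfg

set_cfg(OmegaConf.create({"seed": 42, "mode": "train"}))
assert get_seed() == 42
assert get_cfg().mode == "train"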
-------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 3 | from .loss_mse import LossMse, LossMseCfgWrapper 4 | 5 | LOSSES = { 6 | LossLpipsCfgWrapper: LossLpips, 7 | LossMseCfgWrapper: LossMse, 8 | } 9 | 10 | LossCfgWrapper = LossLpipsCfgWrapper | LossMseCfgWrapper 11 | 12 | 13 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 14 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 15 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians: Gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians: Gaussians, 41 | global_step: int, 42 | valid_depth_mask: Tensor | None 43 | ) -> Float[Tensor, ""]: 44 | image = batch["target"]["image"] 45 | 46 | # Before the specified step, don't apply the loss. 
47 | if global_step < self.cfg.apply_after_step: 48 | return torch.tensor(0, dtype=torch.float32, device=image.device) 49 | 50 | if valid_depth_mask is not None and valid_depth_mask.max() > 0.5: 51 | prediction.color[valid_depth_mask] = 0 52 | image[valid_depth_mask] = 0 53 | 54 | loss = self.lpips.forward( 55 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 56 | rearrange(image, "b v c h w -> (b v) c h w"), 57 | normalize=True, 58 | ) 59 | return self.cfg.weight * loss.mean() 60 | -------------------------------------------------------------------------------- /src/loss/loss_mse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..dataset.types import BatchedExample 7 | from ..model.decoder.decoder import DecoderOutput 8 | from ..model.types import Gaussians 9 | from .loss import Loss 10 | 11 | 12 | @dataclass 13 | class LossMseCfg: 14 | weight: float 15 | 16 | 17 | @dataclass 18 | class LossMseCfgWrapper: 19 | mse: LossMseCfg 20 | 21 | 22 | class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): 23 | def forward( 24 | self, 25 | prediction: DecoderOutput, 26 | batch: BatchedExample, 27 | gaussians: Gaussians, 28 | global_step: int, 29 | l1_loss: bool, 30 | clamp_large_error: float, 31 | valid_depth_mask: Tensor | None 32 | ) -> Float[Tensor, ""]: 33 | delta = prediction.color - batch["target"]["image"] 34 | 35 | if valid_depth_mask is not None and valid_depth_mask.max() > 0.5 and valid_depth_mask.min() < 0.5: 36 | delta = delta[~valid_depth_mask] 37 | 38 | if clamp_large_error > 0: 39 | valid_mask = (delta ** 2) < clamp_large_error 40 | delta = delta[valid_mask] 41 | 42 | if l1_loss: 43 | return self.cfg.weight * (delta.abs()).mean() 44 | return self.cfg.weight * (delta**2).mean() 45 | -------------------------------------------------------------------------------- /src/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from pathlib import Path 5 | from typing import Any, Optional 6 | 7 | from PIL import Image 8 | from pytorch_lightning.loggers.logger import Logger 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | LOG_PATH = Path("outputs/local") 12 | 13 | 14 | class LocalLogger(Logger): 15 | def __init__(self) -> None: 16 | super().__init__() 17 | self.experiment = None 18 | os.system(f"rm -r {LOG_PATH}") 19 | 20 | @property 21 | def name(self): 22 | return "LocalLogger" 23 | 24 | @property 25 | def version(self): 26 | return 0 27 | 28 | @rank_zero_only 29 | def log_hyperparams(self, params): 30 | pass 31 | 32 | @rank_zero_only 33 | def log_metrics(self, metrics, step): 34 | pass 35 | 36 | @rank_zero_only 37 | def log_image( 38 | self, 39 | key: str, 40 | images: list[Any], 41 | step: Optional[int] = None, 42 | **kwargs, 43 | ): 44 | # The function signature is the same as the wandb logger's, but the step is 45 | # actually required. 
46 | assert step is not None 47 | for index, image in enumerate(images): 48 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 49 | path.parent.mkdir(exist_ok=True, parents=True) 50 | if isinstance(image, torch.Tensor): 51 | Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8)).save(path) 52 | else: 53 | Image.fromarray(image).save(path) 54 | -------------------------------------------------------------------------------- /src/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") 38 | 39 | def clear_history(self) -> None: 40 | self.execution_times = defaultdict(list) 41 | -------------------------------------------------------------------------------- /src/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return { 14 | key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0] 15 | } 16 | -------------------------------------------------------------------------------- /src/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... 
()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 
33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | import skvideo.io 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as tf 9 | from einops import rearrange, repeat 10 | from jaxtyping import Float, UInt8 11 | from matplotlib.figure import Figure 12 | from PIL import Image 13 | from torch import Tensor 14 | 15 | FloatImage = Union[ 16 | Float[Tensor, "height width"], 17 | Float[Tensor, "channel height width"], 18 | Float[Tensor, "batch channel height width"], 19 | ] 20 | 21 | 22 | def fig_to_image( 23 | fig: Figure, 24 | dpi: int = 100, 25 | device: torch.device = torch.device("cpu"), 26 | ) -> Float[Tensor, "3 height width"]: 27 | buffer = io.BytesIO() 28 | fig.savefig(buffer, format="raw", dpi=dpi) 29 | buffer.seek(0) 30 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 31 | h = int(fig.bbox.bounds[3]) 32 | w = int(fig.bbox.bounds[2]) 33 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 34 | buffer.close() 35 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 36 | 37 | 38 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 39 | # Handle batched images. 40 | if image.ndim == 4: 41 | image = rearrange(image, "b c h w -> c h (b w)") 42 | 43 | # Handle single-channel images. 44 | if image.ndim == 2: 45 | image = rearrange(image, "h w -> () h w") 46 | 47 | # Ensure that there are 3 or 4 channels. 48 | channel, _, _ = image.shape 49 | if channel == 1: 50 | image = repeat(image, "() h w -> c h w", c=3) 51 | assert image.shape[0] in (3, 4) 52 | 53 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 54 | return rearrange(image, "c h w -> h w c").cpu().numpy() 55 | 56 | 57 | def save_image( 58 | image: FloatImage, 59 | path: Union[Path, str], 60 | ) -> None: 61 | """Save an image. Assumed to be in range 0-1.""" 62 | 63 | # Create the parent directory if it doesn't already exist. 64 | path = Path(path) 65 | path.parent.mkdir(exist_ok=True, parents=True) 66 | 67 | # Save the image. 68 | Image.fromarray(prep_image(image)).save(path) 69 | 70 | 71 | def load_image( 72 | path: Union[Path, str], 73 | ) -> Float[Tensor, "3 height width"]: 74 | return tf.ToTensor()(Image.open(path))[:3] 75 | 76 | 77 | def save_video( 78 | images: list[FloatImage], 79 | path: Union[Path, str], 80 | fps: None | int = None 81 | ) -> None: 82 | """Save an image. Assumed to be in range 0-1.""" 83 | 84 | # Create the parent directory if it doesn't already exist. 85 | path = Path(path) 86 | path.parent.mkdir(exist_ok=True, parents=True) 87 | 88 | # Save the image. 
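# (This writes a video rather than a single image: each frame is converted to
# a uint8 HWC array with prep_image and streamed to scikit-video's
# FFmpegWriter below.)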
89 | # Image.fromarray(prep_image(image)).save(path) 90 | frames = [] 91 | for image in images: 92 | frames.append(prep_image(image)) 93 | 94 | outputdict = {'-pix_fmt': 'yuv420p', '-crf': '23', 95 | '-vf': f'setpts=1.*PTS'} 96 | 97 | if fps is not None: 98 | outputdict.update({'-r': str(fps)}) 99 | 100 | writer = skvideo.io.FFmpegWriter(path, 101 | outputdict=outputdict) 102 | for frame in frames: 103 | writer.writeFrame(frame) 104 | writer.close() 105 | -------------------------------------------------------------------------------- /src/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /src/misc/resume_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | # Function to extract the step number from the filename 5 | def extract_step(file_name): 6 | step_str = file_name.split("-")[1].split("_")[1].replace(".ckpt", "") 7 | return int(step_str) 8 | 9 | 10 | def find_latest_ckpt(ckpt_dir): 11 | # List all files in the directory that end with .ckpt 12 | ckpt_files = [f for f in os.listdir(ckpt_dir) if f.endswith(".ckpt")] 13 | 14 | # Check if there are any .ckpt files in the directory 15 | if not ckpt_files: 16 | raise ValueError(f"No .ckpt files found in {ckpt_dir}.") 17 | else: 18 | # Find the file with the maximum step 19 | latest_ckpt_file = max(ckpt_files, key=extract_step) 20 | 21 | return ckpt_dir / latest_ckpt_file 22 | -------------------------------------------------------------------------------- /src/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | 3 | import torch 4 | from e3nn.o3 import matrix_to_angles, wigner_D 5 | from einops import einsum 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | 10 | def rotate_sh( 11 | sh_coefficients: Float[Tensor, "*#batch n"], 12 | rotations: Float[Tensor, "*#batch 3 3"], 13 | ) -> Float[Tensor, "*batch n"]: 14 | device = sh_coefficients.device 15 | dtype = sh_coefficients.dtype 16 | 17 | *_, n = sh_coefficients.shape 18 | alpha, beta, gamma = matrix_to_angles(rotations) 19 | result = [] 20 | for degree in range(isqrt(n)): 21 | with torch.device(device): 22 | sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) 23 | sh_rotated = einsum( 24 | sh_rotations, 25 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 26 | "... i j, ... j -> ... i", 27 | ) 28 | result.append(sh_rotated) 29 | 30 | return torch.cat(result, dim=-1) 31 | 32 | 33 | if __name__ == "__main__": 34 | from pathlib import Path 35 | 36 | import matplotlib.pyplot as plt 37 | from e3nn.o3 import spherical_harmonics 38 | from matplotlib import cm 39 | from scipy.spatial.transform.rotation import Rotation as R 40 | 41 | device = torch.device("cuda") 42 | 43 | # Generate random spherical harmonics coefficients. 
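# Bands 0..degree contribute (degree + 1)**2 coefficients in total, so the
# degree-4 test below draws 25 random coefficients.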
44 | degree = 4 45 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 46 | 47 | def plot_sh(sh_coefficients, path: Path) -> None: 48 | phi = torch.linspace(0, torch.pi, 100, device=device) 49 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 50 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 51 | x = torch.sin(phi) * torch.cos(theta) 52 | y = torch.sin(phi) * torch.sin(theta) 53 | z = torch.cos(phi) 54 | xyz = torch.stack([x, y, z], dim=-1) 55 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 56 | result = einsum(sh, sh_coefficients, "... n, n -> ...") 57 | result = (result - result.min()) / (result.max() - result.min()) 58 | 59 | # Set the aspect ratio to 1 so our sphere looks spherical 60 | fig = plt.figure(figsize=plt.figaspect(1.0)) 61 | ax = fig.add_subplot(111, projection="3d") 62 | ax.plot_surface( 63 | x.cpu().numpy(), 64 | y.cpu().numpy(), 65 | z.cpu().numpy(), 66 | rstride=1, 67 | cstride=1, 68 | facecolors=cm.seismic(result.cpu().numpy()), 69 | ) 70 | # Turn off the axis planes 71 | ax.set_axis_off() 72 | path.parent.mkdir(exist_ok=True, parents=True) 73 | plt.savefig(path) 74 | 75 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 76 | rotation = torch.tensor( 77 | R.from_euler("x", angle.item()).as_matrix(), device=device 78 | ) 79 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 80 | 81 | print("Done!") 82 | -------------------------------------------------------------------------------- /src/misc/stablize_camera.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/google/dynibar/blob/main/ibrnet/data_loaders/llff_data_utils.py 3 | """ 4 | 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | def render_stabilization_path(poses, k_size=45): 10 | """Rendering stablizaed camera path.""" 11 | 12 | # hwf = poses[0, :, 4:5] 13 | num_frames = poses.shape[0] 14 | output_poses = [] 15 | 16 | input_poses = [] 17 | 18 | for i in range(num_frames): 19 | input_poses.append( 20 | np.concatenate( 21 | [poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], axis=-1 22 | ) 23 | ) 24 | 25 | input_poses = np.array(input_poses) 26 | 27 | gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1) 28 | output_r1 = cv2.filter2D(input_poses[:, :, 0], -1, gaussian_kernel) 29 | output_r2 = cv2.filter2D(input_poses[:, :, 1], -1, gaussian_kernel) 30 | 31 | output_r1 = output_r1 / np.linalg.norm(output_r1, axis=-1, keepdims=True) 32 | output_r2 = output_r2 / np.linalg.norm(output_r2, axis=-1, keepdims=True) 33 | 34 | output_t = cv2.filter2D(input_poses[:, :, 2], -1, gaussian_kernel) 35 | 36 | for i in range(num_frames): 37 | output_r3 = np.cross(output_r1[i], output_r2[i]) 38 | 39 | render_pose = np.concatenate( 40 | [ 41 | output_r1[i, :, None], 42 | output_r2[i, :, None], 43 | output_r3[:, None], 44 | output_t[i, :, None], 45 | ], 46 | axis=-1, 47 | ) 48 | 49 | output_poses.append(render_pose[:3, :]) 50 | 51 | return output_poses 52 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | 
self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from ...dataset import DatasetCfg 2 | from .decoder import Decoder 3 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 4 | 5 | DECODERS = { 6 | "splatting_cuda": DecoderSplattingCUDA, 7 | } 8 | 9 | DecoderCfg = DecoderSplattingCUDACfg 10 | 11 | 12 | def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder: 13 | return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg) 14 | -------------------------------------------------------------------------------- /src/model/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ...dataset import DatasetCfg 9 | from ..types import Gaussians 10 | 11 | DepthRenderingMode = Literal[ 12 | "depth", 13 | "log", 14 | "disparity", 15 | "relative_disparity", 16 | ] 17 | 18 | 19 | @dataclass 20 | class DecoderOutput: 21 | color: Float[Tensor, "batch view 3 height width"] 22 | depth: Float[Tensor, "batch view height width"] | None 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | class 
Decoder(nn.Module, ABC, Generic[T]): 29 | cfg: T 30 | dataset_cfg: DatasetCfg 31 | 32 | def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None: 33 | super().__init__() 34 | self.cfg = cfg 35 | self.dataset_cfg = dataset_cfg 36 | 37 | @abstractmethod 38 | def forward( 39 | self, 40 | gaussians: Gaussians, 41 | extrinsics: Float[Tensor, "batch view 4 4"], 42 | intrinsics: Float[Tensor, "batch view 3 3"], 43 | near: Float[Tensor, "batch view"], 44 | far: Float[Tensor, "batch view"], 45 | image_shape: tuple[int, int], 46 | depth_mode: DepthRenderingMode | None = None, 47 | ) -> DecoderOutput: 48 | pass 49 | -------------------------------------------------------------------------------- /src/model/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda 12 | from .decoder import Decoder, DecoderOutput 13 | 14 | 15 | @dataclass 16 | class DecoderSplattingCUDACfg: 17 | name: Literal["splatting_cuda"] 18 | 19 | 20 | class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 21 | background_color: Float[Tensor, "3"] 22 | 23 | def __init__( 24 | self, 25 | cfg: DecoderSplattingCUDACfg, 26 | dataset_cfg: DatasetCfg, 27 | ) -> None: 28 | super().__init__(cfg, dataset_cfg) 29 | self.register_buffer( 30 | "background_color", 31 | torch.tensor(dataset_cfg.background_color, dtype=torch.float32), 32 | persistent=False, 33 | ) 34 | 35 | def forward( 36 | self, 37 | gaussians: Gaussians, 38 | extrinsics: Float[Tensor, "batch view 4 4"], 39 | intrinsics: Float[Tensor, "batch view 3 3"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | image_shape: tuple[int, int], 43 | depth_mode: DepthRenderingMode | None = None, 44 | ) -> DecoderOutput: 45 | b, v, _, _ = extrinsics.shape 46 | color = render_cuda( 47 | rearrange(extrinsics, "b v i j -> (b v) i j"), 48 | rearrange(intrinsics, "b v i j -> (b v) i j"), 49 | rearrange(near, "b v -> (b v)"), 50 | rearrange(far, "b v -> (b v)"), 51 | image_shape, 52 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 53 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 54 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 55 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 56 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 57 | ) 58 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 59 | 60 | return DecoderOutput( 61 | color, 62 | None 63 | if depth_mode is None 64 | else self.render_depth( 65 | gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode 66 | ), 67 | ) 68 | 69 | def render_depth( 70 | self, 71 | gaussians: Gaussians, 72 | extrinsics: Float[Tensor, "batch view 4 4"], 73 | intrinsics: Float[Tensor, "batch view 3 3"], 74 | near: Float[Tensor, "batch view"], 75 | far: Float[Tensor, "batch view"], 76 | image_shape: tuple[int, int], 77 | mode: DepthRenderingMode = "depth", 78 | ) -> Float[Tensor, "batch view height width"]: 79 | b, v, _, _ = extrinsics.shape 80 | result = render_depth_cuda( 81 | rearrange(extrinsics, "b v i j -> (b v) i j"), 82 | rearrange(intrinsics, "b v i j -> (b v) i j"), 83 | rearrange(near, "b v -> (b v)"), 84 | 
rearrange(far, "b v -> (b v)"), 85 | image_shape, 86 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 87 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 88 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 89 | mode=mode, 90 | ) 91 | return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) 92 | -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_depthsplat import EncoderDepthSplat, EncoderDepthSplatCfg 5 | from .visualization.encoder_visualizer import EncoderVisualizer 6 | from .visualization.encoder_visualizer_depthsplat import EncoderVisualizerDepthSplat 7 | 8 | ENCODERS = { 9 | "depthsplat": (EncoderDepthSplat, EncoderVisualizerDepthSplat), 10 | } 11 | 12 | EncoderCfg = EncoderDepthSplatCfg 13 | 14 | 15 | def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]: 16 | encoder, visualizer = ENCODERS[cfg.name] 17 | encoder = encoder(cfg) 18 | if visualizer is not None: 19 | visualizer = visualizer(cfg.visualizer, encoder) 20 | return encoder, visualizer 21 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussian_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import einsum, rearrange 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | import torch.nn.functional as F 8 | 9 | from ....geometry.projection import get_world_rays 10 | from ....misc.sh_rotation import rotate_sh 11 | from .gaussians import build_covariance 12 | 13 | 14 | @dataclass 15 | class Gaussians: 16 | means: Float[Tensor, "*batch 3"] 17 | covariances: Float[Tensor, "*batch 3 3"] 18 | scales: Float[Tensor, "*batch 3"] 19 | rotations: Float[Tensor, "*batch 4"] 20 | harmonics: Float[Tensor, "*batch 3 _"] 21 | opacities: Float[Tensor, " *batch"] 22 | 23 | 24 | @dataclass 25 | class GaussianAdapterCfg: 26 | gaussian_scale_min: float 27 | gaussian_scale_max: float 28 | sh_degree: int 29 | 30 | 31 | class GaussianAdapter(nn.Module): 32 | cfg: GaussianAdapterCfg 33 | 34 | def __init__(self, cfg: GaussianAdapterCfg): 35 | super().__init__() 36 | self.cfg = cfg 37 | 38 | # Create a mask for the spherical harmonics coefficients. This ensures that at 39 | # initialization, the coefficients are biased towards having a large DC 40 | # component and small view-dependent components. 
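# As a rough illustration, for sh_degree = 2 the mask built below is:
#   band 0 (DC, 1 term):  1.0
#   band 1 (3 terms):     0.1 * 0.25**1 = 0.025
#   band 2 (5 terms):     0.1 * 0.25**2 = 0.00625
# so the view-dependent bands start out strongly attenuated.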
41 | self.register_buffer( 42 | "sh_mask", 43 | torch.ones((self.d_sh,), dtype=torch.float32), 44 | persistent=False, 45 | ) 46 | for degree in range(1, self.cfg.sh_degree + 1): 47 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 48 | 49 | def forward( 50 | self, 51 | extrinsics: Float[Tensor, "*#batch 4 4"], 52 | intrinsics: Float[Tensor, "*#batch 3 3"] | None, 53 | coordinates: Float[Tensor, "*#batch 2"], 54 | depths: Float[Tensor, "*#batch"] | None, 55 | opacities: Float[Tensor, "*#batch"], 56 | raw_gaussians: Float[Tensor, "*#batch _"], 57 | image_shape: tuple[int, int], 58 | eps: float = 1e-8, 59 | point_cloud: Float[Tensor, "*#batch 3"] | None = None, 60 | input_images: Tensor | None = None, 61 | ) -> Gaussians: 62 | scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) 63 | 64 | scales = torch.clamp(F.softplus(scales - 4.), 65 | min=self.cfg.gaussian_scale_min, 66 | max=self.cfg.gaussian_scale_max, 67 | ) 68 | 69 | assert input_images is not None 70 | 71 | # Normalize the quaternion features to yield a valid quaternion. 72 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) 73 | 74 | # [2, 2, 65536, 1, 1, 3, 25] 75 | sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) 76 | sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask 77 | 78 | if input_images is not None: 79 | # [B, V, H*W, 1, 1, 3] 80 | imgs = rearrange(input_images, "b v c h w -> b v (h w) () () c") 81 | # init sh with input images 82 | sh[..., 0] = sh[..., 0] + RGB2SH(imgs) 83 | 84 | # Create world-space covariance matrices. 85 | covariances = build_covariance(scales, rotations) 86 | c2w_rotations = extrinsics[..., :3, :3] 87 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) 88 | 89 | # Compute Gaussian means. 90 | origins, directions = get_world_rays(coordinates, extrinsics, intrinsics) 91 | means = origins + directions * depths[..., None] 92 | 93 | return Gaussians( 94 | means=means, 95 | covariances=covariances, 96 | harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]), 97 | opacities=opacities, 98 | # NOTE: These aren't yet rotated into world space, but they're only used for 99 | # exporting Gaussians to ply files. This needs to be fixed... 100 | scales=scales, 101 | rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 102 | ) 103 | 104 | def get_scale_multiplier( 105 | self, 106 | intrinsics: Float[Tensor, "*#batch 3 3"], 107 | pixel_size: Float[Tensor, "*#batch 2"], 108 | multiplier: float = 0.1, 109 | ) -> Float[Tensor, " *batch"]: 110 | xy_multipliers = multiplier * einsum( 111 | intrinsics[..., :2, :2].inverse(), 112 | pixel_size, 113 | "... i j, j -> ... 
i", 114 | ) 115 | return xy_multipliers.sum(dim=-1) 116 | 117 | @property 118 | def d_sh(self) -> int: 119 | return (self.cfg.sh_degree + 1) ** 2 120 | 121 | @property 122 | def d_in(self) -> int: 123 | return 7 + 3 * self.d_sh 124 | 125 | 126 | def RGB2SH(rgb): 127 | C0 = 0.28209479177387814 128 | return (rgb - 0.5) / C0 129 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 8 | def quaternion_to_matrix( 9 | quaternions: Float[Tensor, "*batch 4"], 10 | eps: float = 1e-8, 11 | ) -> Float[Tensor, "*batch 3 3"]: 12 | # Order changed to match scipy format! 13 | i, j, k, r = torch.unbind(quaternions, dim=-1) 14 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 15 | 16 | o = torch.stack( 17 | ( 18 | 1 - two_s * (j * j + k * k), 19 | two_s * (i * j - k * r), 20 | two_s * (i * k + j * r), 21 | two_s * (i * j + k * r), 22 | 1 - two_s * (i * i + k * k), 23 | two_s * (j * k - i * r), 24 | two_s * (i * k - j * r), 25 | two_s * (j * k + i * r), 26 | 1 - two_s * (i * i + j * j), 27 | ), 28 | -1, 29 | ) 30 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 31 | 32 | 33 | def build_covariance( 34 | scale: Float[Tensor, "*#batch 3"], 35 | rotation_xyzw: Float[Tensor, "*#batch 4"], 36 | ) -> Float[Tensor, "*batch 3 3"]: 37 | scale = scale.diag_embed() 38 | rotation = quaternion_to_matrix(rotation_xyzw) 39 | return ( 40 | rotation 41 | @ scale 42 | @ rearrange(scale, "... i j -> ... j i") 43 | @ rearrange(rotation, "... i j -> ... j i") 44 | ) 45 | -------------------------------------------------------------------------------- /src/model/encoder/common/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor, nn 3 | 4 | from ....misc.discrete_probability_distribution import ( 5 | gather_discrete_topk, 6 | sample_discrete_distribution, 7 | ) 8 | 9 | 10 | class Sampler(nn.Module): 11 | def forward( 12 | self, 13 | probabilities: Float[Tensor, "*batch bucket"], 14 | num_samples: int, 15 | deterministic: bool, 16 | ) -> tuple[ 17 | Int64[Tensor, "*batch 1"], # index 18 | Float[Tensor, "*batch 1"], # probability density 19 | ]: 20 | return ( 21 | gather_discrete_topk(probabilities, num_samples) 22 | if deterministic 23 | else sample_discrete_distribution(probabilities, num_samples) 24 | ) 25 | 26 | def gather( 27 | self, 28 | index: Int64[Tensor, "*batch sample"], 29 | target: Shaped[Tensor, "..."], # *batch bucket *shape 30 | ) -> Shaped[Tensor, "..."]: # *batch sample *shape 31 | """Gather from the target according to the specified index. Handle the 32 | broadcasting needed for the gather to work. See the comments for the actual 33 | expected input/output shapes since jaxtyping doesn't support multiple variadic 34 | lengths in annotations. 
35 | """ 36 | bucket_dim = index.ndim - 1 37 | while len(index.shape) < len(target.shape): 38 | index = index[..., None] 39 | broadcasted_index_shape = list(target.shape) 40 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 41 | index = index.broadcast_to(broadcasted_index_shape) 42 | return target.gather(dim=bucket_dim, index=index) 43 | -------------------------------------------------------------------------------- /src/model/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import nn 5 | 6 | from ...dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | deterministic: bool, 24 | ) -> Gaussians: 25 | pass 26 | 27 | def get_data_shim(self) -> DataShim: 28 | """The default shim doesn't modify the batch.""" 29 | return lambda x: x 30 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ResidualBlock(nn.Module): 5 | def __init__( 6 | self, 7 | in_planes, 8 | planes, 9 | norm_layer=nn.InstanceNorm2d, 10 | stride=1, 11 | dilation=1, 12 | ): 13 | super(ResidualBlock, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d( 16 | in_planes, 17 | planes, 18 | kernel_size=3, 19 | dilation=dilation, 20 | padding=dilation, 21 | stride=stride, 22 | bias=False, 23 | ) 24 | self.conv2 = nn.Conv2d( 25 | planes, 26 | planes, 27 | kernel_size=3, 28 | dilation=dilation, 29 | padding=dilation, 30 | bias=False, 31 | ) 32 | self.relu = nn.ReLU(inplace=True) 33 | 34 | self.norm1 = norm_layer(planes) 35 | self.norm2 = norm_layer(planes) 36 | if not stride == 1 or in_planes != planes: 37 | self.norm3 = norm_layer(planes) 38 | 39 | if stride == 1 and in_planes == planes: 40 | self.downsample = None 41 | else: 42 | self.downsample = nn.Sequential( 43 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 44 | ) 45 | 46 | def forward(self, x): 47 | y = x 48 | y = self.relu(self.norm1(self.conv1(y))) 49 | y = self.relu(self.norm2(self.conv2(y))) 50 | 51 | if self.downsample is not None: 52 | x = self.downsample(x) 53 | 54 | return self.relu(x + y) 55 | 56 | 57 | class CNNEncoder(nn.Module): 58 | def __init__( 59 | self, 60 | output_dim=128, 61 | norm_layer=nn.InstanceNorm2d, 62 | num_output_scales=1, 63 | return_quarter=False, # return 1/4 resolution feature 64 | lowest_scale=8, # lowest resolution, 1/8 or 1/4 65 | return_all_scales=False, 66 | **kwargs, 67 | ): 68 | super(CNNEncoder, self).__init__() 69 | self.num_scales = num_output_scales 70 | self.return_quarter = return_quarter 71 | self.lowest_scale = lowest_scale 72 | self.return_all_scales = return_all_scales 73 | 74 | feature_dims = [64, 96, 128] 75 | 76 | self.conv1 = nn.Conv2d( 77 | 3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False 78 | ) # 1/2 79 | self.norm1 = norm_layer(feature_dims[0]) 80 | self.relu1 = nn.ReLU(inplace=True) 81 | 82 | self.in_planes = feature_dims[0] 83 | self.layer1 = self._make_layer( 84 | feature_dims[0], stride=1, norm_layer=norm_layer 85 | ) # 1/2 86 | 87 | 
if self.lowest_scale == 4: 88 | stride = 1 89 | else: 90 | stride = 2 91 | self.layer2 = self._make_layer( 92 | feature_dims[1], stride=stride, norm_layer=norm_layer 93 | ) # 1/2 or 1/4 94 | 95 | # lowest resolution 1/4 or 1/8 96 | self.layer3 = self._make_layer( 97 | feature_dims[2], 98 | stride=2, 99 | norm_layer=norm_layer, 100 | ) # 1/4 or 1/8 101 | 102 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 107 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 108 | if m.weight is not None: 109 | nn.init.constant_(m.weight, 1) 110 | if m.bias is not None: 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 114 | layer1 = ResidualBlock( 115 | self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation 116 | ) 117 | layer2 = ResidualBlock( 118 | dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation 119 | ) 120 | 121 | layers = (layer1, layer2) 122 | 123 | self.in_planes = dim 124 | return nn.Sequential(*layers) 125 | 126 | def forward(self, x): 127 | output_all_scales = [] 128 | output = [] 129 | x = self.conv1(x) 130 | x = self.norm1(x) 131 | x = self.relu1(x) 132 | 133 | x = self.layer1(x) # 1/2 134 | 135 | if self.return_all_scales: 136 | output_all_scales.append(x) 137 | 138 | if self.num_scales >= 3: 139 | output.append(x) 140 | 141 | x = self.layer2(x) # 1/2 or 1/4 142 | if self.return_quarter: 143 | output.append(x) 144 | 145 | if self.return_all_scales: 146 | output_all_scales.append(x) 147 | 148 | if self.num_scales >= 2: 149 | output.append(x) 150 | 151 | x = self.layer3(x) # 1/4 or 1/8 152 | x = self.conv2(x) 153 | 154 | if self.return_all_scales: 155 | output_all_scales.append(x) 156 | 157 | if self.return_all_scales: 158 | return output_all_scales 159 | 160 | if self.return_quarter: 161 | output.append(x) 162 | return output 163 | 164 | if self.num_scales >= 1: 165 | output.append(x) 166 | return output 167 | 168 | out = [x] 169 | 170 | return out 171 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/feature_upsampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | 8 | class ResizeConvFeatureUpsampler(nn.Module): 9 | """ 10 | https://distill.pub/2016/deconv-checkerboard/ 11 | """ 12 | 13 | def __init__(self, num_scales=1, 14 | lowest_feature_resolution=8, 15 | out_channels=128, 16 | vit_type='vits', 17 | no_mono_feature=False, 18 | gaussian_downsample=None, 19 | monodepth_backbone=False, 20 | ): 21 | super(ResizeConvFeatureUpsampler, self).__init__() 22 | 23 | self.num_scales = num_scales 24 | self.monodepth_backbone = monodepth_backbone 25 | 26 | self.upsampler = nn.ModuleList() 27 | 28 | vit_feature_channel_dict = { 29 | 'vits': 384, 30 | 'vitb': 768, 31 | 'vitl': 1024 32 | } 33 | 34 | vit_feature_channel = vit_feature_channel_dict[vit_type] 35 | 36 | if monodepth_backbone: 37 | vit_feature_channel = 384 38 | 39 | out_channels = out_channels // num_scales 40 | 41 | for i in range(num_scales): 42 | cnn_feature_channels = 128 - (32 * i) 43 | mv_transformer_feature_channels = 128 // (2 ** i) 44 | if no_mono_feature: 45 | mono_feature_channels = 0 46 | else: 47 | mono_feature_channels = 
vit_feature_channel // (2 ** i) 48 | 49 | in_channels = cnn_feature_channels + \ 50 | mv_transformer_feature_channels + mono_feature_channels 51 | 52 | if monodepth_backbone: 53 | in_channels = 384 54 | 55 | curr_upsample_factor = lowest_feature_resolution // (2 ** i) 56 | 57 | num_upsample = int(math.log(curr_upsample_factor, 2)) 58 | 59 | modules = [] 60 | if num_upsample == 1: 61 | curr_in_channels = out_channels * 2 62 | else: 63 | curr_in_channels = out_channels * 2 * (num_upsample - 1) 64 | modules.append(nn.Conv2d(in_channels, curr_in_channels, 1)) 65 | for i in range(num_upsample): 66 | modules.append(nn.Upsample(scale_factor=2, mode='nearest')) 67 | 68 | if i == num_upsample - 1: 69 | modules.append(nn.Conv2d(curr_in_channels, 70 | out_channels, 3, 1, 1, padding_mode='replicate')) 71 | else: 72 | modules.append(nn.Conv2d(curr_in_channels, 73 | curr_in_channels // 2, 3, 1, 1, padding_mode='replicate')) 74 | curr_in_channels = curr_in_channels // 2 75 | modules.append(nn.GELU()) 76 | 77 | if gaussian_downsample is not None: 78 | if gaussian_downsample == 2: 79 | del modules[-3:] 80 | elif gaussian_downsample == 4: 81 | del modules[-6:] 82 | else: 83 | raise NotImplementedError 84 | 85 | self.upsampler.append(nn.Sequential(*modules)) 86 | 87 | def forward(self, features_list_cnn, features_list_mv, features_list_mono=None): 88 | out = [] 89 | 90 | for i in range(self.num_scales): 91 | if self.monodepth_backbone: 92 | concat = features_list_cnn[i] 93 | elif features_list_mono is None: 94 | concat = torch.cat( 95 | (features_list_cnn[i], features_list_mv[i]), dim=1) 96 | else: 97 | concat = torch.cat( 98 | (features_list_cnn[i], features_list_mv[i], features_list_mono[i]), dim=1) 99 | concat = self.upsampler[i](concat) 100 | 101 | out.append(concat) 102 | 103 | out = torch.cat(out, dim=1) 104 | 105 | return out 106 | 107 | 108 | def _test(): 109 | device = torch.device('cuda:0') 110 | 111 | model = ResizeConvFeatureUpsampler(num_scales=2, 112 | lowest_feature_resolution=4, 113 | ).to(device) 114 | print(model) 115 | 116 | b, h, w = 2, 32, 64 117 | features_list_cnn = [torch.randn(b, 128, h, w).to(device)] 118 | features_list_mv = [torch.randn(b, 128, h, w).to(device)] 119 | features_list_mono = [torch.randn(b, 384, h, w).to(device)] 120 | 121 | # scale 2 122 | features_list_cnn.append(torch.randn(b, 96, h * 2, w * 2).to(device)) 123 | features_list_mv.append(torch.randn(b, 64, h * 2, w * 2).to(device)) 124 | features_list_mono.append(torch.randn(b, 192, h * 2, w * 2).to(device)) 125 | 126 | out = model(features_list_cnn, 127 | features_list_mv, features_list_mono) 128 | 129 | print(out.shape) 130 | 131 | 132 | if __name__ == '__main__': 133 | _test() 134 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/ldm_unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/depthsplat/1f5e5486f005e5b9975cca5cbed3d61acf465707/src/model/encoder/unimatch/ldm_unet/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/unimatch/ldm_unet/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import warnings 6 | 7 | 8 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 9 | try: 10 | if XFORMERS_ENABLED: 11 | from xformers.ops import 
memory_efficient_attention, unbind 12 | 13 | XFORMERS_AVAILABLE = True 14 | # warnings.warn("xFormers is available (Attention)") 15 | else: 16 | # warnings.warn("xFormers is disabled (Attention)") 17 | raise ImportError 18 | except ImportError: 19 | XFORMERS_AVAILABLE = False 20 | # warnings.warn("xFormers is not available (Attention)") 21 | 22 | 23 | class CrossAttention(nn.Module): 24 | def __init__( 25 | self, 26 | in_dim1, 27 | in_dim2, 28 | dim=128, 29 | out_dim=None, 30 | num_heads=4, 31 | qkv_bias=False, 32 | proj_bias=False, 33 | ): 34 | super().__init__() 35 | 36 | assert XFORMERS_AVAILABLE 37 | 38 | if out_dim is None: 39 | out_dim = in_dim1 40 | 41 | self.num_heads = num_heads 42 | self.dim = dim 43 | self.q = nn.Linear(in_dim1, dim, bias=qkv_bias) 44 | self.kv = nn.Linear(in_dim2, dim * 2, bias=qkv_bias) 45 | self.proj = nn.Linear(dim, out_dim, bias=proj_bias) 46 | 47 | def forward(self, x, y): 48 | c = self.dim 49 | b, n1, c1 = x.shape 50 | n2, c2 = y.shape[1:] 51 | 52 | q = self.q(x).reshape(b, n1, self.num_heads, c // self.num_heads) 53 | kv = self.kv(y).reshape(b, n2, 2, self.num_heads, c // self.num_heads) 54 | k, v = unbind(kv, 2) 55 | 56 | x = memory_efficient_attention(q, k, v) 57 | x = x.reshape(b, n1, c) 58 | 59 | x = self.proj(x) 60 | 61 | return x 62 | 63 | 64 | class UNetCrossAttentionBlock(nn.Module): 65 | def __init__(self, 66 | in_dim1, 67 | in_dim2, 68 | dim=128, 69 | out_dim=None, 70 | num_heads=4, 71 | qkv_bias=False, 72 | proj_bias=False, 73 | with_ffn=False, 74 | concat_cross_attn=False, 75 | concat_output=False, 76 | no_cross_attn=False, 77 | with_norm=False, 78 | concat_conv3x3=False, 79 | ): 80 | super().__init__() 81 | 82 | out_dim = out_dim or in_dim1 83 | 84 | self.no_cross_attn = no_cross_attn 85 | self.with_norm = with_norm 86 | 87 | if no_cross_attn: 88 | if concat_conv3x3: 89 | self.proj = nn.Conv2d(in_dim1 + in_dim2, out_dim, 3, 1, 1) 90 | else: 91 | self.proj = nn.Conv2d(in_dim1 + in_dim2, out_dim, 1) 92 | else: 93 | self.with_ffn = with_ffn 94 | self.concat_cross_attn = concat_cross_attn 95 | self.concat_output = concat_output 96 | 97 | self.cross_attn = CrossAttention( 98 | in_dim1=in_dim1, 99 | in_dim2=in_dim2, 100 | dim=dim, 101 | out_dim=out_dim, 102 | num_heads=num_heads, 103 | qkv_bias=qkv_bias, 104 | proj_bias=proj_bias, 105 | ) 106 | 107 | if with_norm: 108 | self.norm1 = nn.LayerNorm(out_dim) 109 | else: 110 | self.norm1 = nn.Identity() 111 | 112 | if with_ffn: 113 | in_channels = out_dim + in_dim1 if concat_cross_attn else in_dim1 114 | ffn_dim_expansion = 4 115 | self.mlp = nn.Sequential( 116 | nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), 117 | nn.GELU(), 118 | nn.Linear(in_channels * ffn_dim_expansion, in_dim1, bias=False), 119 | ) 120 | 121 | if with_norm: 122 | self.norm2 = nn.LayerNorm(in_dim1) 123 | else: 124 | self.norm2 = nn.Identity() 125 | 126 | if self.concat_output: 127 | self.out = nn.Linear(out_dim + in_dim1, in_dim1) 128 | 129 | def forward(self, x, y): 130 | # x: [B, C, H, W] 131 | # y: [B, N, C] or [B, C, H, W] 132 | 133 | if self.no_cross_attn: 134 | assert x.dim() == 4 and y.dim() == 4 135 | if y.shape[2:] != x.shape[2:]: 136 | y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=True) 137 | return self.proj(torch.cat((x, y), dim=1)) 138 | 139 | identity = x 140 | 141 | b, c, h, w = x.size() 142 | x = x.view(b, c, -1).permute(0, 2, 1) 143 | 144 | cross_attn = self.norm1(self.cross_attn(x, y)) 145 | 146 | if self.with_ffn: 147 | if self.concat_cross_attn: 148 | concat = 
torch.cat((x, cross_attn), dim=-1) 149 | else: 150 | concat = x + cross_attn 151 | 152 | cross_attn = self.norm2(self.mlp(concat)) 153 | 154 | if self.concat_output: 155 | return self.out(torch.cat((x, cross_attn), dim=-1)) 156 | 157 | # reshape back 158 | cross_attn = cross_attn.view(b, h, w, c).permute(0, 3, 1, 2) # [B, C, H, W] 159 | 160 | return identity + cross_attn 161 | 162 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None): 6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 7 | 8 | stacks = [x, y] 9 | 10 | if homogeneous: 11 | ones = torch.ones_like(x) # [H, W] 12 | stacks.append(ones) 13 | 14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def warp_with_pose_depth_candidates( 25 | feature1, 26 | intrinsics, 27 | pose, 28 | depth, 29 | clamp_min_depth=1e-3, 30 | grid_sample_disable_cudnn=False, 31 | ): 32 | """ 33 | feature1: [B, C, H, W] 34 | intrinsics: [B, 3, 3] 35 | pose: [B, 4, 4] 36 | depth: [B, D, H, W] 37 | """ 38 | 39 | assert intrinsics.size(1) == intrinsics.size(2) == 3 40 | assert pose.size(1) == pose.size(2) == 4 41 | assert depth.dim() == 4 42 | 43 | b, d, h, w = depth.size() 44 | c = feature1.size(1) 45 | 46 | with torch.no_grad(): 47 | # pixel coordinates 48 | grid = coords_grid( 49 | b, h, w, homogeneous=True, device=depth.device 50 | ) # [B, 3, H, W] 51 | # back project to 3D and transform viewpoint 52 | points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] 53 | points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( 54 | 1, 1, d, 1 55 | ) * depth.view( 56 | b, 1, d, h * w 57 | ) # [B, 3, D, H*W] 58 | points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] 59 | # reproject to 2D image plane 60 | points = torch.bmm(intrinsics, points.view(b, 3, -1)).view( 61 | b, 3, d, h * w 62 | ) # [B, 3, D, H*W] 63 | pixel_coords = points[:, :2] / points[:, -1:].clamp( 64 | min=clamp_min_depth 65 | ) # [B, 2, D, H*W] 66 | 67 | # normalize to [-1, 1] 68 | x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 69 | y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 70 | 71 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] 72 | 73 | # sample features 74 | # ref: https://github.com/pytorch/pytorch/issues/88380 75 | # print(feature1.shape, grid.shape) 76 | # hardcoded workaround 77 | if feature1.numel() > 1000000: 78 | grid_sample_disable_cudnn = True 79 | with torch.backends.cudnn.flags(enabled=not grid_sample_disable_cudnn): 80 | warped_feature = F.grid_sample( 81 | feature1, 82 | grid.view(b, d * h, w, 2), 83 | mode="bilinear", 84 | padding_mode="zeros", 85 | align_corners=True, 86 | ).view( 87 | b, c, d, h, w 88 | ) # [B, C, D, H, W] 89 | 90 | return warped_feature 91 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack( 44 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 45 | ).flatten(3) 46 | pos_y = torch.stack( 47 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 48 | ).flatten(3) 49 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 50 | return pos 51 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .position import PositionEmbeddingSine 4 | 5 | 6 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 7 | assert device is not None 8 | 9 | x, y = torch.meshgrid( 10 | [ 11 | torch.linspace(w_min, w_max, len_w, device=device), 12 | torch.linspace(h_min, h_max, len_h, device=device), 13 | ], 14 | ) 15 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 16 | 17 | return grid 18 | 19 | 20 | def normalize_coords(coords, h, w): 21 | # coords: [B, H, W, 2] 22 | c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device) 23 | return (coords - c) / c # [-1, 1] 24 | 25 | 26 | def normalize_img(img0, img1): 27 | # loaded images are in [0, 255] 28 | # normalize by ImageNet mean and std 29 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) 30 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) 31 | img0 = (img0 / 255.0 - mean) / std 32 | img1 = (img1 / 255.0 - mean) / std 33 | 34 | return img0, img1 35 | 36 | 37 | def split_feature( 38 | feature, 39 | num_splits=2, 40 | channel_last=False, 41 | ): 42 | if channel_last: # [B, H, W, C] 43 | b, h, w, c = feature.size() 44 | assert h % num_splits == 0 and w % num_splits == 0 45 | 46 | b_new = b * num_splits * num_splits 47 | h_new = h // num_splits 48 | w_new = w // num_splits 49 | 50 | feature = ( 51 | 
feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c) 52 | .permute(0, 1, 3, 2, 4, 5) 53 | .reshape(b_new, h_new, w_new, c) 54 | ) # [B*K*K, H/K, W/K, C] 55 | else: # [B, C, H, W] 56 | b, c, h, w = feature.size() 57 | assert h % num_splits == 0 and w % num_splits == 0 58 | 59 | b_new = b * num_splits * num_splits 60 | h_new = h // num_splits 61 | w_new = w // num_splits 62 | 63 | feature = ( 64 | feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits) 65 | .permute(0, 2, 4, 1, 3, 5) 66 | .reshape(b_new, c, h_new, w_new) 67 | ) # [B*K*K, C, H/K, W/K] 68 | 69 | return feature 70 | 71 | 72 | def merge_splits( 73 | splits, 74 | num_splits=2, 75 | channel_last=False, 76 | ): 77 | if channel_last: # [B*K*K, H/K, W/K, C] 78 | b, h, w, c = splits.size() 79 | new_b = b // num_splits // num_splits 80 | 81 | splits = splits.view(new_b, num_splits, num_splits, h, w, c) 82 | merge = ( 83 | splits.permute(0, 1, 3, 2, 4, 5) 84 | .contiguous() 85 | .view(new_b, num_splits * h, num_splits * w, c) 86 | ) # [B, H, W, C] 87 | else: # [B*K*K, C, H/K, W/K] 88 | b, c, h, w = splits.size() 89 | new_b = b // num_splits // num_splits 90 | 91 | splits = splits.view(new_b, num_splits, num_splits, c, h, w) 92 | merge = ( 93 | splits.permute(0, 3, 1, 4, 2, 5) 94 | .contiguous() 95 | .view(new_b, c, num_splits * h, num_splits * w) 96 | ) # [B, C, H, W] 97 | 98 | return merge 99 | 100 | 101 | def generate_shift_window_attn_mask( 102 | input_resolution, 103 | window_size_h, 104 | window_size_w, 105 | shift_size_h, 106 | shift_size_w, 107 | device=torch.device("cuda"), 108 | ): 109 | # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 110 | # calculate attention mask for SW-MSA 111 | h, w = input_resolution 112 | img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 113 | h_slices = ( 114 | slice(0, -window_size_h), 115 | slice(-window_size_h, -shift_size_h), 116 | slice(-shift_size_h, None), 117 | ) 118 | w_slices = ( 119 | slice(0, -window_size_w), 120 | slice(-window_size_w, -shift_size_w), 121 | slice(-shift_size_w, None), 122 | ) 123 | cnt = 0 124 | for h in h_slices: 125 | for w in w_slices: 126 | img_mask[:, h, w, :] = cnt 127 | cnt += 1 128 | 129 | mask_windows = split_feature( 130 | img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True 131 | ) 132 | 133 | mask_windows = mask_windows.view(-1, window_size_h * window_size_w) 134 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 135 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( 136 | attn_mask == 0, float(0.0) 137 | ) 138 | 139 | return attn_mask 140 | 141 | 142 | def feature_add_position(feature0, feature1, attn_splits, feature_channels): 143 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 144 | 145 | if attn_splits > 1: # add position in splited window 146 | feature0_splits = split_feature(feature0, num_splits=attn_splits) 147 | feature1_splits = split_feature(feature1, num_splits=attn_splits) 148 | 149 | position = pos_enc(feature0_splits) 150 | 151 | feature0_splits = feature0_splits + position 152 | feature1_splits = feature1_splits + position 153 | 154 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits) 155 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits) 156 | else: 157 | position = pos_enc(feature0) 158 | 159 | feature0 = feature0 + position 160 | feature1 = feature1 + position 161 | 162 | return feature0, feature1 163 | 164 | 165 | def 
mv_feature_add_position(features, attn_splits, feature_channels): 166 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 167 | 168 | assert features.dim() == 4 # [B*V, C, H, W] 169 | 170 | if attn_splits > 1: # add position in splited window 171 | features_splits = split_feature(features, num_splits=attn_splits) 172 | position = pos_enc(features_splits) 173 | features_splits = features_splits + position 174 | features = merge_splits(features_splits, num_splits=attn_splits) 175 | else: 176 | position = pos_enc(features) 177 | features = features + position 178 | 179 | return features 180 | -------------------------------------------------------------------------------- /src/model/encoder/unimatch/vit_fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Ref: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py#L363 7 | 8 | 9 | class ViTFeaturePyramid(nn.Module): 10 | """ 11 | This module implements SimpleFeaturePyramid in :paper:`vitdet`. 12 | It creates pyramid features built on top of the input feature map. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | in_channels, 18 | scale_factors, 19 | ): 20 | """ 21 | Args: 22 | scale_factors (list[float]): list of scaling factors to upsample or downsample 23 | the input features for creating pyramid features. 24 | """ 25 | super(ViTFeaturePyramid, self).__init__() 26 | 27 | self.scale_factors = scale_factors 28 | 29 | out_dim = dim = in_channels 30 | self.stages = nn.ModuleList() 31 | for idx, scale in enumerate(scale_factors): 32 | if scale == 4.0: 33 | layers = [ 34 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), 35 | nn.GELU(), 36 | nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), 37 | ] 38 | out_dim = dim // 4 39 | elif scale == 2.0: 40 | layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] 41 | out_dim = dim // 2 42 | elif scale == 1.0: 43 | layers = [] 44 | elif scale == 0.5: 45 | layers = [nn.MaxPool2d(kernel_size=2, stride=2)] 46 | else: 47 | raise NotImplementedError(f"scale_factor={scale} is not supported yet.") 48 | 49 | if scale != 1.0: 50 | layers.extend( 51 | [ 52 | nn.GELU(), 53 | nn.Conv2d(out_dim, out_dim, 3, 1, 1), 54 | ] 55 | ) 56 | layers = nn.Sequential(*layers) 57 | 58 | self.stages.append(layers) 59 | 60 | def forward(self, x): 61 | results = [] 62 | 63 | for stage in self.stages: 64 | results.append(stage(x)) 65 | 66 | return results 67 | 68 | 69 | def _test(): 70 | model = ViTFeaturePyramid( 71 | 384, 72 | scale_factors=[1, 2, 4], 73 | ).cuda() 74 | print(model) 75 | 76 | x = torch.randn(2, 384, 64, 96).cuda() 77 | 78 | out = model(x) 79 | 80 | for x in out: 81 | print(x.shape) 82 | 83 | 84 | if __name__ == "__main__": 85 | _test() 86 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | T_cfg = TypeVar("T_cfg") 8 | T_encoder = TypeVar("T_encoder") 9 | 10 | 11 | class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]): 12 | cfg: T_cfg 13 | encoder: T_encoder 14 | 15 | def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None: 16 | self.cfg = cfg 17 | self.encoder = encoder 18 | 
19 | @abstractmethod 20 | def visualize( 21 | self, 22 | context: dict, 23 | global_step: int, 24 | ) -> dict[str, Float[Tensor, "3 _ _"]]: 25 | pass 26 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer_depthsplat_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # This is in a separate file to avoid circular imports. 4 | 5 | 6 | @dataclass 7 | class EncoderVisualizerDepthSplatCfg: 8 | num_samples: int 9 | min_resolution: int 10 | export_ply: bool 11 | -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange, repeat 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | 36 | view_rotation = extrinsics[:3, :3].inverse() 37 | # Apply the rotation to the means (Gaussian positions). 38 | means = einsum(view_rotation, means, "i j, ... j -> ... i") 39 | 40 | # Apply the rotation to the Gaussian rotations. 41 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 42 | rotations = view_rotation.detach().cpu().numpy() @ rotations 43 | rotations = R.from_matrix(rotations).as_quat() 44 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 45 | rotations = np.stack((w, x, y, z), axis=-1) 46 | 47 | # Since our axes are swizzled for the spherical harmonics, we only export the DC band 48 | harmonics_view_invariant = harmonics[..., 0] 49 | 50 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] 51 | elements = np.empty(means.shape[0], dtype=dtype_full) 52 | attributes = ( 53 | means.detach().cpu().numpy(), 54 | torch.zeros_like(means).detach().cpu().numpy(), 55 | harmonics_view_invariant.detach().cpu().contiguous().numpy(), 56 | torch.logit(opacities[..., None]).detach().cpu().numpy(), 57 | scales.log().detach().cpu().numpy(), 58 | rotations, 59 | ) 60 | attributes = np.concatenate(attributes, axis=1) 61 | elements[:] = list(map(tuple, attributes)) 62 | path.parent.mkdir(exist_ok=True, parents=True) 63 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 64 | 65 | 66 | def save_gaussian_ply(gaussians, visualization_dump, example, save_path): 67 | 68 | v, _, h, w = example["context"]["image"].shape[1:] 69 | 70 | # Transform means into camera space. 
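# (Note: the rearrange below only reshapes the flattened (v h w spp) Gaussians
# into a per-view, per-pixel layout so border Gaussians can be trimmed with the
# mask defined just after it.)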
71 | means = rearrange( 72 | gaussians.means, "() (v h w spp) xyz -> h w spp v xyz", v=v, h=h, w=w 73 | ) 74 | 75 | # Create a mask to filter the Gaussians. throw away Gaussians at the 76 | # borders, since they're generally of lower quality. 77 | mask = torch.zeros_like(means[..., 0], dtype=torch.bool) 78 | GAUSSIAN_TRIM = 8 79 | mask[GAUSSIAN_TRIM:-GAUSSIAN_TRIM, GAUSSIAN_TRIM:-GAUSSIAN_TRIM, :, :] = 1 80 | 81 | def trim(element): 82 | element = rearrange( 83 | element, "() (v h w spp) ... -> h w spp v ...", v=v, h=h, w=w 84 | ) 85 | return element[mask][None] 86 | 87 | # convert the rotations from camera space to world space as required 88 | cam_rotations = trim(visualization_dump["rotations"])[0] 89 | c2w_mat = repeat( 90 | example["context"]["extrinsics"][0, :, :3, :3], 91 | "v a b -> h w spp v a b", 92 | h=h, 93 | w=w, 94 | spp=1, 95 | ) 96 | c2w_mat = c2w_mat[mask] # apply trim 97 | 98 | cam_rotations_np = R.from_quat( 99 | cam_rotations.detach().cpu().numpy() 100 | ).as_matrix() 101 | world_mat = c2w_mat.detach().cpu().numpy() @ cam_rotations_np 102 | world_rotations = R.from_matrix(world_mat).as_quat() 103 | world_rotations = torch.from_numpy(world_rotations).to( 104 | visualization_dump["scales"] 105 | ) 106 | 107 | export_ply( 108 | example["context"]["extrinsics"][0, 0], 109 | trim(gaussians.means)[0], 110 | trim(visualization_dump["scales"])[0], 111 | world_rotations, 112 | trim(gaussians.harmonics)[0], 113 | trim(gaussians.opacities)[0], 114 | save_path, 115 | ) 116 | 117 | 118 | -------------------------------------------------------------------------------- /src/model/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] 13 | -------------------------------------------------------------------------------- /src/scripts/convert_dl3dv_test.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | from pathlib import Path 4 | from typing import Literal, TypedDict 5 | from PIL import Image 6 | 7 | import numpy as np 8 | import torch 9 | from jaxtyping import Float, Int, UInt8 10 | from torch import Tensor 11 | from tqdm import tqdm 12 | import argparse 13 | import json 14 | import os 15 | 16 | from glob import glob 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--input_dir", type=str, help="original dataset directory") 21 | parser.add_argument("--output_dir", type=str, help="processed dataset directory") 22 | parser.add_argument( 23 | "--img_subdir", 24 | type=str, 25 | default="images_8", 26 | help="image directory name", 27 | choices=[ 28 | "images_4", 29 | "images_8", 30 | ], 31 | ) 32 | parser.add_argument("--n_test", type=int, default=10, help="test skip") 33 | parser.add_argument("--which_stage", type=str, default=None, help="dataset directory") 34 | parser.add_argument("--detect_overlap", action="store_true") 35 | 36 | args = parser.parse_args() 37 | 38 | 39 | INPUT_DIR = Path(args.input_dir) 40 | OUTPUT_DIR = Path(args.output_dir) 41 | 42 | 43 | # Target 200 MB per chunk. 
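# Illustrative invocation (the dataset paths below are placeholders):
#   python src/scripts/convert_dl3dv_test.py \
#       --input_dir /path/to/DL3DV/raw \
#       --output_dir /path/to/dl3dv_processed \
#       --img_subdir images_8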
44 | TARGET_BYTES_PER_CHUNK = int(2e8) 45 | 46 | 47 | def get_example_keys(stage: Literal["test", "train"]) -> list[str]: 48 | image_keys = set( 49 | example.name 50 | for example in tqdm(list((INPUT_DIR / stage).iterdir()), desc="Indexing scenes") 51 | if example.is_dir() and not example.name.startswith(".") 52 | ) 53 | # keys = image_keys & metadata_keys 54 | keys = image_keys 55 | # print(keys) 56 | print(f"Found {len(keys)} keys.") 57 | return sorted(list(keys)) 58 | 59 | 60 | def get_size(path: Path) -> int: 61 | """Get file or folder size in bytes.""" 62 | return int(subprocess.check_output(["du", "-b", path]).split()[0].decode("utf-8")) 63 | 64 | 65 | def load_raw(path: Path) -> UInt8[Tensor, " length"]: 66 | return torch.tensor(np.memmap(path, dtype="uint8", mode="r")) 67 | 68 | 69 | def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]: 70 | """Load JPG images as raw bytes (do not decode).""" 71 | 72 | return { 73 | int(path.stem.split("_")[-1]): load_raw(path) 74 | for path in example_path.iterdir() 75 | if path.suffix.lower() not in [".npz"] 76 | } 77 | 78 | 79 | class Metadata(TypedDict): 80 | url: str 81 | timestamps: Int[Tensor, " camera"] 82 | cameras: Float[Tensor, "camera entry"] 83 | 84 | 85 | class Example(Metadata): 86 | key: str 87 | images: list[UInt8[Tensor, "..."]] 88 | 89 | 90 | def load_metadata(example_path: Path) -> Metadata: 91 | blender2opencv = np.array( 92 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 93 | ) 94 | url = str(example_path).split("/")[-3] 95 | with open(example_path, "r") as f: 96 | meta_data = json.load(f) 97 | 98 | store_h, store_w = meta_data["h"], meta_data["w"] 99 | fx, fy, cx, cy = ( 100 | meta_data["fl_x"], 101 | meta_data["fl_y"], 102 | meta_data["cx"], 103 | meta_data["cy"], 104 | ) 105 | saved_fx = float(fx) / float(store_w) 106 | saved_fy = float(fy) / float(store_h) 107 | saved_cx = float(cx) / float(store_w) 108 | saved_cy = float(cy) / float(store_h) 109 | 110 | timestamps = [] 111 | cameras = [] 112 | opencv_c2ws = [] # will be used to calculate camera distance 113 | 114 | for frame in meta_data["frames"]: 115 | timestamps.append( 116 | int(os.path.basename(frame["file_path"]).split(".")[0].split("_")[-1]) 117 | ) 118 | camera = [saved_fx, saved_fy, saved_cx, saved_cy, 0.0, 0.0] 119 | # transform_matrix is in blender c2w, while we need to store opencv w2c matrix here 120 | opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv 121 | opencv_c2ws.append(opencv_c2w) 122 | camera.extend(np.linalg.inv(opencv_c2w)[:3].flatten().tolist()) 123 | cameras.append(np.array(camera)) 124 | 125 | # timestamp should be the one that match the above images keys, use for indexing 126 | timestamps = torch.tensor(timestamps, dtype=torch.int64) 127 | cameras = torch.tensor(np.stack(cameras), dtype=torch.float32) 128 | 129 | return {"url": url, "timestamps": timestamps, "cameras": cameras} 130 | 131 | 132 | def partition_train_test_splits(root_dir, n_test=10): 133 | sub_folders = sorted(glob(os.path.join(root_dir, "*/"))) 134 | test_list = sub_folders[::n_test] 135 | train_list = [x for x in sub_folders if x not in test_list] 136 | out_dict = {"train": train_list, "test": test_list} 137 | return out_dict 138 | 139 | 140 | def is_image_shape_matched(image_dir, target_shape): 141 | image_path = sorted(glob(str(image_dir / "*"))) 142 | if len(image_path) == 0: 143 | return False 144 | 145 | image_path = image_path[0] 146 | try: 147 | im = Image.open(image_path) 148 | except: 149 | return False 150 | w, h = im.size 
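# PIL's Image.size is (width, height), so swap to (h, w) before comparing
# against target_shape.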
151 | if (h, w) == target_shape: 152 | return True 153 | else: 154 | return False 155 | 156 | 157 | def legal_check_for_all_scenes(root_dir, target_shape): 158 | valid_folders = [] 159 | sub_folders = sorted(glob(os.path.join(root_dir, "*/nerfstudio"))) 160 | for sub_folder in tqdm(sub_folders, desc="checking scenes..."): 161 | img_dir = os.path.join(sub_folder, "images_8") # 270x480 162 | # img_dir = os.path.join(sub_folder, 'images_4') # 540x960 163 | if not is_image_shape_matched(Path(img_dir), target_shape): 164 | print(f"image shape does not match for {sub_folder}") 165 | continue 166 | pose_file = os.path.join(sub_folder, "transforms.json") 167 | if not os.path.isfile(pose_file): 168 | print(f"cannot find pose file for {sub_folder}") 169 | continue 170 | 171 | valid_folders.append(sub_folder) 172 | 173 | return valid_folders 174 | 175 | 176 | if __name__ == "__main__": 177 | if "images_8" in args.img_subdir: 178 | target_shape = (270, 480) # (h, w) 179 | elif "images_4" in args.img_subdir: 180 | target_shape = (540, 960) 181 | else: 182 | raise ValueError 183 | 184 | print("checking all scenes...") 185 | valid_scenes = legal_check_for_all_scenes(INPUT_DIR, target_shape) 186 | print("valid scenes:", len(valid_scenes)) 187 | 188 | for stage in ["test"]: 189 | 190 | error_logs = [] 191 | image_dirs = valid_scenes 192 | 193 | chunk_size = 0 194 | chunk_index = 0 195 | chunk: list[Example] = [] 196 | 197 | def save_chunk(): 198 | global chunk_size 199 | global chunk_index 200 | global chunk 201 | 202 | chunk_key = f"{chunk_index:0>6}" 203 | dir = OUTPUT_DIR / stage 204 | dir.mkdir(exist_ok=True, parents=True) 205 | torch.save(chunk, dir / f"{chunk_key}.torch") 206 | 207 | # Reset the chunk. 208 | chunk_size = 0 209 | chunk_index += 1 210 | chunk = [] 211 | 212 | for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"): 213 | key = os.path.basename(os.path.dirname(image_dir.strip("/"))) 214 | 215 | image_dir = Path(image_dir) / "images_8" # 270x480 216 | # image_dir = Path(image_dir) / 'images_4' # 540x960 217 | 218 | num_bytes = get_size(image_dir) 219 | 220 | # Read images and metadata. 221 | try: 222 | images = load_images(image_dir) 223 | except: 224 | print("image loading error") 225 | continue 226 | meta_path = image_dir.parent / "transforms.json" 227 | if not meta_path.is_file(): 228 | error_msg = f"---------> [ERROR] no meta file in {key}, skip." 229 | print(error_msg) 230 | error_logs.append(error_msg) 231 | continue 232 | example = load_metadata(meta_path) 233 | 234 | # Merge the images into the example. 235 | try: 236 | example["images"] = [ 237 | images[timestamp.item()] for timestamp in example["timestamps"] 238 | ] 239 | except: 240 | error_msg = f"---------> [ERROR] Some images missing in {key}, skip." 241 | print(error_msg) 242 | error_logs.append(error_msg) 243 | continue 244 | 245 | # Add the key to the example. 
246 | example["key"] = key 247 | 248 | chunk.append(example) 249 | chunk_size += num_bytes 250 | 251 | if chunk_size >= TARGET_BYTES_PER_CHUNK: 252 | save_chunk() 253 | 254 | if chunk_size > 0: 255 | save_chunk() 256 | -------------------------------------------------------------------------------- /src/scripts/convert_dl3dv_train.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | from pathlib import Path 4 | from typing import Literal, TypedDict 5 | from PIL import Image 6 | 7 | import numpy as np 8 | import torch 9 | from jaxtyping import Float, Int, UInt8 10 | from torch import Tensor 11 | from tqdm import tqdm 12 | import argparse 13 | import json 14 | import os 15 | 16 | from glob import glob 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--input_dir", type=str, help="original dataset directory") 21 | parser.add_argument("--output_dir", type=str, help="processed dataset directory") 22 | parser.add_argument( 23 | "--img_subdir", 24 | type=str, 25 | default="images_8", 26 | help="image directory name", 27 | choices=[ 28 | "images_4", 29 | "images_8", 30 | ], 31 | ) 32 | parser.add_argument("--n_test", type=int, default=10, help="test skip") 33 | parser.add_argument("--which_stage", type=str, default=None, help="dataset directory") 34 | parser.add_argument("--detect_overlap", action="store_true") 35 | 36 | args = parser.parse_args() 37 | 38 | 39 | INPUT_DIR = Path(args.input_dir) 40 | OUTPUT_DIR = Path(args.output_dir) 41 | 42 | 43 | # Target 200 MB per chunk. 44 | TARGET_BYTES_PER_CHUNK = int(2e8) 45 | 46 | 47 | def get_example_keys(stage: Literal["test", "train"]) -> list[str]: 48 | image_keys = set( 49 | example.name 50 | for example in tqdm(list((INPUT_DIR / stage).iterdir()), desc="Indexing scenes") 51 | if example.is_dir() and not example.name.startswith(".") 52 | ) 53 | # keys = image_keys & metadata_keys 54 | keys = image_keys 55 | # print(keys) 56 | # assert False 57 | print(f"Found {len(keys)} keys.") 58 | return sorted(list(keys)) 59 | 60 | 61 | def get_size(path: Path) -> int: 62 | """Get file or folder size in bytes.""" 63 | return int(subprocess.check_output(["du", "-b", path]).split()[0].decode("utf-8")) 64 | 65 | 66 | def load_raw(path: Path) -> UInt8[Tensor, " length"]: 67 | return torch.tensor(np.memmap(path, dtype="uint8", mode="r")) 68 | 69 | 70 | def load_images(example_path: Path) -> dict[int, UInt8[Tensor, "..."]]: 71 | """Load JPG images as raw bytes (do not decode).""" 72 | 73 | return { 74 | int(path.stem.split("_")[-1]): load_raw(path) 75 | for path in example_path.iterdir() 76 | if path.suffix.lower() not in [".npz"] 77 | } 78 | 79 | 80 | class Metadata(TypedDict): 81 | url: str 82 | timestamps: Int[Tensor, " camera"] 83 | cameras: Float[Tensor, "camera entry"] 84 | 85 | 86 | class Example(Metadata): 87 | key: str 88 | images: list[UInt8[Tensor, "..."]] 89 | 90 | 91 | def load_metadata(example_path: Path) -> Metadata: 92 | blender2opencv = np.array( 93 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 94 | ) 95 | url = str(example_path).split("/")[-3] 96 | with open(example_path, "r") as f: 97 | meta_data = json.load(f) 98 | 99 | store_h, store_w = meta_data["h"], meta_data["w"] 100 | fx, fy, cx, cy = ( 101 | meta_data["fl_x"], 102 | meta_data["fl_y"], 103 | meta_data["cx"], 104 | meta_data["cy"], 105 | ) 106 | saved_fx = float(fx) / float(store_w) 107 | saved_fy = float(fy) / float(store_h) 108 | saved_cx = float(cx) / float(store_w) 109 | 
saved_cy = float(cy) / float(store_h) 110 | 111 | timestamps = [] 112 | cameras = [] 113 | opencv_c2ws = [] # will be used to calculate camera distance 114 | 115 | for frame in meta_data["frames"]: 116 | timestamps.append( 117 | int(os.path.basename(frame["file_path"]).split(".")[0].split("_")[-1]) 118 | ) 119 | camera = [saved_fx, saved_fy, saved_cx, saved_cy, 0.0, 0.0] 120 | # transform_matrix is in blender c2w, while we need to store opencv w2c matrix here 121 | opencv_c2w = np.array(frame["transform_matrix"]) @ blender2opencv 122 | opencv_c2ws.append(opencv_c2w) 123 | camera.extend(np.linalg.inv(opencv_c2w)[:3].flatten().tolist()) 124 | cameras.append(np.array(camera)) 125 | 126 | # timestamps should be the ones that match the above image keys; used for indexing 127 | timestamps = torch.tensor(timestamps, dtype=torch.int64) 128 | cameras = torch.tensor(np.stack(cameras), dtype=torch.float32) 129 | 130 | return {"url": url, "timestamps": timestamps, "cameras": cameras} 131 | 132 | 133 | def partition_train_test_splits(root_dir, n_test=10): 134 | sub_folders = sorted(glob(os.path.join(root_dir, "*/"))) 135 | test_list = sub_folders[::n_test] 136 | train_list = [x for x in sub_folders if x not in test_list] 137 | out_dict = {"train": train_list, "test": test_list} 138 | return out_dict 139 | 140 | 141 | def is_image_shape_matched(image_dir, target_shape): 142 | image_path = sorted(glob(str(image_dir / "*"))) 143 | if len(image_path) == 0: 144 | return False 145 | 146 | image_path = image_path[0] 147 | try: 148 | im = Image.open(image_path) 149 | except Exception: 150 | return False 151 | w, h = im.size 152 | if (h, w) == target_shape: 153 | return True 154 | else: 155 | return False 156 | 157 | 158 | def legal_check_for_all_scenes(root_dir, target_shape): 159 | valid_folders = [] 160 | sub_folders = sorted(glob(os.path.join(root_dir, "*/*"))) 161 | for sub_folder in tqdm(sub_folders, desc="checking scenes..."): 162 | # check the same image subdirectory that target_shape is derived from 163 | img_dir = os.path.join(sub_folder, args.img_subdir) # images_8: 270x480, images_4: 540x960 164 | if not is_image_shape_matched(Path(img_dir), target_shape): 165 | print(f"image shape does not match for {sub_folder}") 166 | continue 167 | pose_file = os.path.join(sub_folder, "transforms.json") 168 | if not os.path.isfile(pose_file): 169 | print(f"cannot find pose file for {sub_folder}") 170 | continue 171 | 172 | valid_folders.append(sub_folder) 173 | 174 | return valid_folders 175 | 176 | 177 | if __name__ == "__main__": 178 | if "images_8" in args.img_subdir: 179 | target_shape = (270, 480) # (h, w) 180 | elif "images_4" in args.img_subdir: 181 | target_shape = (540, 960) 182 | else: 183 | raise ValueError(f"unsupported img_subdir: {args.img_subdir}") 184 | 185 | print("checking all scenes...") 186 | valid_scenes = legal_check_for_all_scenes(INPUT_DIR, target_shape) 187 | print("valid scenes:", len(valid_scenes)) 188 | 189 | # test scenes 190 | test_scenes = "your_test_set_index.json" 191 | with open(test_scenes, "r") as f: 192 | overlap_scenes = json.load(f) 193 | 194 | assert len(overlap_scenes) == 140, "test scenes should contain 140 scenes" 195 | 196 | for stage in ["train"]: 197 | 198 | error_logs = [] 199 | image_dirs = valid_scenes 200 | 201 | chunk_size = 0 202 | chunk_index = 0 203 | chunk: list[Example] = [] 204 | 205 | def save_chunk(): 206 | global chunk_size 207 | global chunk_index 208 | global chunk 209 | 210 | chunk_key = f"{chunk_index:0>6}" 211 | dir = OUTPUT_DIR / stage 212 | dir.mkdir(exist_ok=True, parents=True) 213 | torch.save(chunk, dir / f"{chunk_key}.torch") 214 | 215 | # Reset the chunk.
216 | chunk_size = 0 217 | chunk_index += 1 218 | chunk = [] 219 | 220 | for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"): 221 | key = os.path.basename(image_dir.strip("/")) 222 | # skip test scenes 223 | if key in overlap_scenes: 224 | print(f"scene {key} in benchmark, skip.") 225 | continue 226 | 227 | image_dir = Path(image_dir) / args.img_subdir 228 | # images_8: 270x480, images_4: 540x960 (must match the shape check above) 229 | 230 | num_bytes = get_size(image_dir) 231 | 232 | # Read images and metadata. 233 | try: 234 | images = load_images(image_dir) 235 | except Exception: 236 | print(f"image loading error in {key}, skip.") 237 | continue 238 | meta_path = image_dir.parent / "transforms.json" 239 | if not meta_path.is_file(): 240 | error_msg = f"---------> [ERROR] no meta file in {key}, skip." 241 | print(error_msg) 242 | error_logs.append(error_msg) 243 | continue 244 | example = load_metadata(meta_path) 245 | 246 | # Merge the images into the example. 247 | try: 248 | example["images"] = [ 249 | images[timestamp.item()] for timestamp in example["timestamps"] 250 | ] 251 | except Exception: 252 | error_msg = f"---------> [ERROR] Some images missing in {key}, skip." 253 | print(error_msg) 254 | error_logs.append(error_msg) 255 | continue 256 | 257 | # Add the key to the example. 258 | example["key"] = "dl3dv_" + key 259 | 260 | chunk.append(example) 261 | chunk_size += num_bytes 262 | 263 | if chunk_size >= TARGET_BYTES_PER_CHUNK: 264 | save_chunk() 265 | 266 | if chunk_size > 0: 267 | save_chunk() 268 | -------------------------------------------------------------------------------- /src/scripts/generate_dl3dv_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import torch 5 | from tqdm import tqdm 6 | 7 | DATASET_PATH = Path("your_dataset_path") 8 | 9 | if __name__ == "__main__": 10 | # "train" or "test" 11 | for stage in ["test"]: 12 | stage = DATASET_PATH / stage 13 | 14 | index = {} 15 | for chunk_path in tqdm( 16 | sorted(list(stage.iterdir())), desc=f"Indexing {stage.name}" 17 | ): 18 | if chunk_path.suffix == ".torch": 19 | chunk = torch.load(chunk_path) 20 | for example in chunk: 21 | index[example["key"]] = str(chunk_path.relative_to(stage)) 22 | with (stage / "index.json").open("w") as f: 23 | json.dump(index, f) 24 | -------------------------------------------------------------------------------- /src/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text,
font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... 
() i j") @ tf 33 | -------------------------------------------------------------------------------- /src/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height with"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB. 
47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /src/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /src/visualization/drawing/cameras.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import einsum, rearrange, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ...geometry.projection import unproject 9 | from ..annotation import add_label 10 | from .lines import draw_lines 11 | from .types import Scalar, sanitize_scalar 12 | 13 | 14 | def draw_cameras( 15 | resolution: int, 16 | extrinsics: Float[Tensor, "batch 4 4"], 17 | intrinsics: Float[Tensor, "batch 3 3"], 18 | color: Float[Tensor, "batch 3"], 19 | near: Optional[Scalar] = None, 20 | far: Optional[Scalar] = None, 21 | margin: float = 0.1, # relative to AABB 22 | frustum_scale: float = 0.05, # relative to image resolution 23 | ) -> Float[Tensor, "3 3 height width"]: 24 | device = extrinsics.device 25 | 26 | # Compute scene bounds. 27 | minima, maxima = compute_aabb(extrinsics, intrinsics, near, far) 28 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 29 | minima, maxima, margin=margin 30 | ) 31 | span = (scene_maxima - scene_minima).max() 32 | 33 | # Compute frustum locations. 34 | corner_depth = (span * frustum_scale)[None] 35 | frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth) 36 | if near is not None: 37 | near_corners = unproject_frustum_corners(extrinsics, intrinsics, near) 38 | if far is not None: 39 | far_corners = unproject_frustum_corners(extrinsics, intrinsics, far) 40 | 41 | # Project the cameras onto each axis-aligned plane. 42 | projections = [] 43 | for projected_axis in range(3): 44 | image = torch.zeros( 45 | (3, resolution, resolution), 46 | dtype=torch.float32, 47 | device=device, 48 | ) 49 | image_x_axis = (projected_axis + 1) % 3 50 | image_y_axis = (projected_axis + 2) % 3 51 | 52 | def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]: 53 | x = points[..., image_x_axis] 54 | y = points[..., image_y_axis] 55 | return torch.stack([x, y], dim=-1) 56 | 57 | x_range, y_range = torch.stack( 58 | (project(scene_minima), project(scene_maxima)), dim=-1 59 | ) 60 | 61 | # Draw near and far planes. 
62 | if near is not None: 63 | projected_near_corners = project(near_corners) 64 | image = draw_lines( 65 | image, 66 | rearrange(projected_near_corners, "b p xy -> (b p) xy"), 67 | rearrange(projected_near_corners.roll(1, 1), "b p xy -> (b p) xy"), 68 | color=0.25, 69 | width=2, 70 | x_range=x_range, 71 | y_range=y_range, 72 | ) 73 | if far is not None: 74 | projected_far_corners = project(far_corners) 75 | image = draw_lines( 76 | image, 77 | rearrange(projected_far_corners, "b p xy -> (b p) xy"), 78 | rearrange(projected_far_corners.roll(1, 1), "b p xy -> (b p) xy"), 79 | color=0.25, 80 | width=2, 81 | x_range=x_range, 82 | y_range=y_range, 83 | ) 84 | if near is not None and far is not None: 85 | image = draw_lines( 86 | image, 87 | rearrange(projected_near_corners, "b p xy -> (b p) xy"), 88 | rearrange(projected_far_corners, "b p xy -> (b p) xy"), 89 | color=0.25, 90 | width=2, 91 | x_range=x_range, 92 | y_range=y_range, 93 | ) 94 | 95 | # Draw the camera frustums themselves. 96 | projected_origins = project(extrinsics[:, :3, 3]) 97 | projected_frustum_corners = project(frustum_corners) 98 | start = [ 99 | repeat(projected_origins, "b xy -> (b p) xy", p=4), 100 | rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"), 101 | ] 102 | start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4) 103 | image = draw_lines( 104 | image, 105 | start, 106 | repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2), 107 | color=repeat(color, "b c -> (b r p) c", r=2, p=4), 108 | width=2, 109 | x_range=x_range, 110 | y_range=y_range, 111 | ) 112 | 113 | x_name = "XYZ"[image_x_axis] 114 | y_name = "XYZ"[image_y_axis] 115 | image = add_label(image, f"{x_name}{y_name} Projection") 116 | 117 | # TODO: Draw axis indicators. 118 | projections.append(image) 119 | 120 | return torch.stack(projections) 121 | 122 | 123 | def compute_aabb( 124 | extrinsics: Float[Tensor, "batch 4 4"], 125 | intrinsics: Float[Tensor, "batch 3 3"], 126 | near: Optional[Scalar] = None, 127 | far: Optional[Scalar] = None, 128 | ) -> tuple[ 129 | Float[Tensor, "3"], # minima of the scene 130 | Float[Tensor, "3"], # maxima of the scene 131 | ]: 132 | """Compute an axis-aligned bounding box for the camera frustums.""" 133 | 134 | device = extrinsics.device 135 | 136 | # These points are included in the AABB. 
137 | points = [extrinsics[:, :3, 3]] 138 | 139 | if near is not None: 140 | near = sanitize_scalar(near, device) 141 | corners = unproject_frustum_corners(extrinsics, intrinsics, near) 142 | points.append(rearrange(corners, "b p xyz -> (b p) xyz")) 143 | 144 | if far is not None: 145 | far = sanitize_scalar(far, device) 146 | corners = unproject_frustum_corners(extrinsics, intrinsics, far) 147 | points.append(rearrange(corners, "b p xyz -> (b p) xyz")) 148 | 149 | points = torch.cat(points, dim=0) 150 | return points.min(dim=0).values, points.max(dim=0).values 151 | 152 | 153 | def compute_equal_aabb_with_margin( 154 | minima: Float[Tensor, "*#batch 3"], 155 | maxima: Float[Tensor, "*#batch 3"], 156 | margin: float = 0.1, 157 | ) -> tuple[ 158 | Float[Tensor, "*batch 3"], # minima of the scene 159 | Float[Tensor, "*batch 3"], # maxima of the scene 160 | ]: 161 | midpoint = (maxima + minima) * 0.5 162 | span = (maxima - minima).max() * (1 + margin) 163 | scene_minima = midpoint - 0.5 * span 164 | scene_maxima = midpoint + 0.5 * span 165 | return scene_minima, scene_maxima 166 | 167 | 168 | def unproject_frustum_corners( 169 | extrinsics: Float[Tensor, "batch 4 4"], 170 | intrinsics: Float[Tensor, "batch 3 3"], 171 | depth: Float[Tensor, "#batch"], 172 | ) -> Float[Tensor, "batch 4 3"]: 173 | device = extrinsics.device 174 | 175 | # Get coordinates for the corners. Following them in a circle makes a rectangle. 176 | xy = torch.linspace(0, 1, 2, device=device) 177 | xy = torch.stack(torch.meshgrid(xy, xy, indexing="xy"), dim=-1) 178 | xy = rearrange(xy, "i j xy -> (i j) xy") 179 | xy = xy[torch.tensor([0, 1, 3, 2], device=device)] 180 | 181 | # Get ray directions in camera space. 182 | directions = unproject( 183 | xy, 184 | torch.ones(1, dtype=torch.float32, device=device), 185 | rearrange(intrinsics, "b i j -> b () i j"), 186 | ) 187 | 188 | # Divide by the z coordinate so that multiplying by depth will produce orthographic 189 | # depth (z depth) as opposed to Euclidean depth (distance from the camera). 
190 | directions = directions / directions[..., -1:] 191 | directions = einsum(extrinsics[..., :3, :3], directions, "b i j, b r j -> b r i") 192 | 193 | origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz") 194 | depth = rearrange(depth, "b -> b () ()") 195 | return origins + depth * directions 196 | -------------------------------------------------------------------------------- /src/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 | xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /src/visualization/drawing/lines.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 
37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /src/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 
44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /src/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
86 | b, h, w, _ = xy.shape 87 | color = [ 88 | color_function(batch) 89 | for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) 90 | ] 91 | color = torch.cat(color, dim=0) 92 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 93 | 94 | # If any MSAA passes remain, subdivide. 95 | if remaining_passes > 0: 96 | mask = detect_msaa_pixels(color) 97 | batch_index, row_index, col_index = torch.where(mask) 98 | xy = xy[batch_index, row_index, col_index] 99 | 100 | offsets = generate_sample_grid((subdivision, subdivision), device) 101 | offsets = (offsets / subdivision - 0.5) * scale 102 | 103 | color_fine = run_msaa_pass( 104 | xy[:, None, None] + offsets, 105 | color_function, 106 | scale / subdivision, 107 | subdivision, 108 | remaining_passes - 1, 109 | device, 110 | batch_size=batch_size, 111 | ) 112 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 113 | 114 | return color 115 | 116 | 117 | @torch.no_grad() 118 | def render( 119 | shape: tuple[int, int], 120 | color_function: ColorFunction, 121 | device: torch.device, 122 | subdivision: int = 8, 123 | num_passes: int = 2, 124 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 125 | xy = generate_sample_grid(shape, device) 126 | return run_msaa_pass( 127 | xy[None], 128 | color_function, 129 | 1.0, 130 | subdivision, 131 | num_passes, 132 | device, 133 | )[0] 134 | 135 | 136 | def render_over_image( 137 | image: Float[Tensor, "3 height width"], 138 | color_function: ColorFunction, 139 | device: torch.device, 140 | subdivision: int = 8, 141 | num_passes: int = 1, 142 | ) -> Float[Tensor, "3 height width"]: 143 | _, h, w = image.shape 144 | overlay = render( 145 | (h, w), 146 | color_function, 147 | device, 148 | subdivision=subdivision, 149 | num_passes=num_passes, 150 | ) 151 | color, alpha = overlay.split((3, 1), dim=0) 152 | return image * (1 - alpha) + color * alpha 153 | -------------------------------------------------------------------------------- /src/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /src/visualization/layout.py: -------------------------------------------------------------------------------- 1 | """This file contains useful layout utilities for images. They are: 2 | 3 | - add_border: Add a border to an image. 4 | - cat/hcat/vcat: Join images by arranging them in a line. If the images have different 5 | sizes, they are aligned as specified (start, end, center). Allows you to specify a gap 6 | between images. 7 | 8 | Images are assumed to be float32 tensors with shape (channel, height, width). 9 | """ 10 | 11 | from typing import Any, Generator, Iterable, Literal, Optional, Union 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from jaxtyping import Float 16 | from torch import Tensor 17 | 18 | Alignment = Literal["start", "center", "end"] 19 | Axis = Literal["horizontal", "vertical"] 20 | Color = Union[ 21 | int, 22 | float, 23 | Iterable[int], 24 | Iterable[float], 25 | Float[Tensor, "#channel"], 26 | Float[Tensor, ""], 27 | ] 28 | 29 | 30 | def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]: 31 | # Convert tensor to list (or individual item). 32 | if isinstance(color, torch.Tensor): 33 | color = color.tolist() 34 | 35 | # Turn iterators and individual items into lists. 
36 | if isinstance(color, Iterable): 37 | color = list(color) 38 | else: 39 | color = [color] 40 | 41 | return torch.tensor(color, dtype=torch.float32) 42 | 43 | 44 | def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: 45 | it = iter(iterable) 46 | yield next(it) 47 | for item in it: 48 | yield delimiter 49 | yield item 50 | 51 | 52 | def _get_main_dim(main_axis: Axis) -> int: 53 | return { 54 | "horizontal": 2, 55 | "vertical": 1, 56 | }[main_axis] 57 | 58 | 59 | def _get_cross_dim(main_axis: Axis) -> int: 60 | return { 61 | "horizontal": 1, 62 | "vertical": 2, 63 | }[main_axis] 64 | 65 | 66 | def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: 67 | assert base >= overlay 68 | offset = { 69 | "start": 0, 70 | "center": (base - overlay) // 2, 71 | "end": base - overlay, 72 | }[align] 73 | return slice(offset, offset + overlay) 74 | 75 | 76 | def overlay( 77 | base: Float[Tensor, "channel base_height base_width"], 78 | overlay: Float[Tensor, "channel overlay_height overlay_width"], 79 | main_axis: Axis, 80 | main_axis_alignment: Alignment, 81 | cross_axis_alignment: Alignment, 82 | ) -> Float[Tensor, "channel base_height base_width"]: 83 | # The overlay must be smaller than the base. 84 | _, base_height, base_width = base.shape 85 | _, overlay_height, overlay_width = overlay.shape 86 | assert base_height >= overlay_height and base_width >= overlay_width 87 | 88 | # Compute spacing on the main dimension. 89 | main_dim = _get_main_dim(main_axis) 90 | main_slice = _compute_offset( 91 | base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment 92 | ) 93 | 94 | # Compute spacing on the cross dimension. 95 | cross_dim = _get_cross_dim(main_axis) 96 | cross_slice = _compute_offset( 97 | base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment 98 | ) 99 | 100 | # Combine the slices and paste the overlay onto the base accordingly. 101 | selector = [..., None, None] 102 | selector[main_dim] = main_slice 103 | selector[cross_dim] = cross_slice 104 | result = base.clone() 105 | result[selector] = overlay 106 | return result 107 | 108 | 109 | def cat( 110 | main_axis: Axis, 111 | *images: Iterable[Float[Tensor, "channel _ _"]], 112 | align: Alignment = "center", 113 | gap: int = 8, 114 | gap_color: Color = 1, 115 | ) -> Float[Tensor, "channel height width"]: 116 | """Arrange images in a line. The interface resembles a CSS div with flexbox.""" 117 | device = images[0].device 118 | gap_color = _sanitize_color(gap_color).to(device) 119 | 120 | # Find the maximum image side length in the cross axis dimension. 121 | cross_dim = _get_cross_dim(main_axis) 122 | cross_axis_length = max(image.shape[cross_dim] for image in images) 123 | 124 | # Pad the images. 125 | padded_images = [] 126 | for image in images: 127 | # Create an empty image with the correct size. 128 | padded_shape = list(image.shape) 129 | padded_shape[cross_dim] = cross_axis_length 130 | base = torch.ones(padded_shape, dtype=torch.float32, device=device) 131 | base = base * gap_color[:, None, None] 132 | padded_images.append(overlay(base, image, main_axis, "start", align)) 133 | 134 | # Intersperse separators if necessary. 135 | if gap > 0: 136 | # Generate a separator. 
137 | c, _, _ = images[0].shape 138 | separator_size = [gap, gap] 139 | separator_size[cross_dim - 1] = cross_axis_length 140 | separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device) 141 | separator = separator * gap_color[:, None, None] 142 | 143 | # Intersperse the separator between the images. 144 | padded_images = list(_intersperse(padded_images, separator)) 145 | 146 | return torch.cat(padded_images, dim=_get_main_dim(main_axis)) 147 | 148 | 149 | def hcat( 150 | *images: Iterable[Float[Tensor, "channel _ _"]], 151 | align: Literal["start", "center", "end", "top", "bottom"] = "start", 152 | gap: int = 8, 153 | gap_color: Color = 1, 154 | ): 155 | """Shorthand for a horizontal linear concatenation.""" 156 | return cat( 157 | "horizontal", 158 | *images, 159 | align={ 160 | "start": "start", 161 | "center": "center", 162 | "end": "end", 163 | "top": "start", 164 | "bottom": "end", 165 | }[align], 166 | gap=gap, 167 | gap_color=gap_color, 168 | ) 169 | 170 | 171 | def vcat( 172 | *images: Iterable[Float[Tensor, "channel _ _"]], 173 | align: Literal["start", "center", "end", "left", "right"] = "start", 174 | gap: int = 8, 175 | gap_color: Color = 1, 176 | ): 177 | """Shorthand for a vertical linear concatenation.""" 178 | return cat( 179 | "vertical", 180 | *images, 181 | align={ 182 | "start": "start", 183 | "center": "center", 184 | "end": "end", 185 | "left": "start", 186 | "right": "end", 187 | }[align], 188 | gap=gap, 189 | gap_color=gap_color, 190 | ) 191 | 192 | 193 | def add_border( 194 | image: Float[Tensor, "channel height width"], 195 | border: int = 8, 196 | color: Color = 1, 197 | ) -> Float[Tensor, "channel new_height new_width"]: 198 | color = _sanitize_color(color).to(image) 199 | c, h, w = image.shape 200 | result = torch.empty( 201 | (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device 202 | ) 203 | result[:] = color[:, None, None] 204 | result[:, border : h + border, border : w + border] = image 205 | return result 206 | 207 | 208 | def resize( 209 | image: Float[Tensor, "channel height width"], 210 | shape: Optional[tuple[int, int]] = None, 211 | width: Optional[int] = None, 212 | height: Optional[int] = None, 213 | ) -> Float[Tensor, "channel new_height new_width"]: 214 | assert (shape is not None) + (width is not None) + (height is not None) == 1 215 | _, h, w = image.shape 216 | 217 | if width is not None: 218 | shape = (int(h * width / w), width) 219 | elif height is not None: 220 | shape = (height, int(w * height / h)) 221 | 222 | return F.interpolate( 223 | image[None], 224 | shape, 225 | mode="bilinear", 226 | align_corners=False, 227 | antialias=True, 228 | )[0] 229 | -------------------------------------------------------------------------------- /src/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [ 16 | torch.ones(padded_shape.tolist(),
dtype=x.dtype, device=x.device) 17 | for x in images 18 | ] 19 | for image, result in zip(images, results): 20 | slices = [slice(0, x) for x in image.shape] 21 | result[slices] = image[slices] 22 | return results 23 | 24 | 25 | def render_projections( 26 | gaussians: Gaussians, 27 | resolution: int, 28 | margin: float = 0.1, 29 | draw_label: bool = True, 30 | extra_label: str = "", 31 | ) -> Float[Tensor, "batch 3 3 height width"]: 32 | device = gaussians.means.device 33 | b, _, _ = gaussians.means.shape 34 | 35 | # Compute the minima and maxima of the scene. 36 | minima = gaussians.means.min(dim=1).values 37 | maxima = gaussians.means.max(dim=1).values 38 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 39 | minima, maxima, margin=margin 40 | ) 41 | 42 | projections = [] 43 | for look_axis in range(3): 44 | right_axis = (look_axis + 1) % 3 45 | down_axis = (look_axis + 2) % 3 46 | 47 | # Define the extrinsics for rendering. 48 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 49 | extrinsics[:, right_axis, 0] = 1 50 | extrinsics[:, down_axis, 1] = 1 51 | extrinsics[:, look_axis, 2] = 1 52 | extrinsics[:, right_axis, 3] = 0.5 * ( 53 | scene_minima[:, right_axis] + scene_maxima[:, right_axis] 54 | ) 55 | extrinsics[:, down_axis, 3] = 0.5 * ( 56 | scene_minima[:, down_axis] + scene_maxima[:, down_axis] 57 | ) 58 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 59 | extrinsics[:, 3, 3] = 1 60 | 61 | # Define the intrinsics for rendering. 62 | extents = scene_maxima - scene_minima 63 | far = extents[:, look_axis] 64 | near = torch.zeros_like(far) 65 | width = extents[:, right_axis] 66 | height = extents[:, down_axis] 67 | 68 | projection = render_cuda_orthographic( 69 | extrinsics, 70 | width, 71 | height, 72 | near, 73 | far, 74 | (resolution, resolution), 75 | torch.zeros((b, 3), dtype=torch.float32, device=device), 76 | gaussians.means, 77 | gaussians.covariances, 78 | gaussians.harmonics, 79 | gaussians.opacities, 80 | fov_degrees=10.0, 81 | ) 82 | if draw_label: 83 | right_axis_name = "XYZ"[right_axis] 84 | down_axis_name = "XYZ"[down_axis] 85 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 86 | projection = torch.stack([add_label(x, label) for x in projection]) 87 | 88 | projections.append(projection) 89 | 90 | return torch.stack(pad(projections), dim=1) 91 | 92 | 93 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 94 | # Define colors for context and target views. 
95 | num_context_views = batch["context"]["extrinsics"].shape[1] 96 | num_target_views = batch["target"]["extrinsics"].shape[1] 97 | color = torch.ones( 98 | (num_target_views + num_context_views, 3), 99 | dtype=torch.float32, 100 | device=batch["target"]["extrinsics"].device, 101 | ) 102 | color[num_context_views:, 1:] = 0 103 | 104 | return draw_cameras( 105 | resolution, 106 | torch.cat( 107 | (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) 108 | ), 109 | torch.cat( 110 | (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) 111 | ), 112 | color, 113 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 114 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 115 | ) 116 | -------------------------------------------------------------------------------- /src/visualization/vis_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import torchvision.utils as vutils 5 | import cv2 6 | from matplotlib.cm import get_cmap 7 | import matplotlib as mpl 8 | import matplotlib.cm as cm 9 | 10 | 11 | # https://github.com/autonomousvision/unimatch/blob/master/utils/visualization.py 12 | 13 | 14 | def vis_disparity(disp): 15 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 16 | disp_vis = disp_vis.astype("uint8") 17 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) 18 | 19 | return disp_vis 20 | 21 | 22 | def viz_depth_tensor(disp, return_numpy=False, colormap="plasma"): 23 | # visualize inverse depth 24 | assert isinstance(disp, torch.Tensor) 25 | 26 | disp = disp.numpy() 27 | vmax = np.percentile(disp, 95) 28 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) 29 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 30 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype( 31 | np.uint8 32 | ) # [H, W, 3] 33 | 34 | if return_numpy: 35 | return colormapped_im 36 | 37 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] 38 | 39 | return viz 40 | --------------------------------------------------------------------------------
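Usage note (not part of the repository): a minimal sketch of how the helpers in src/visualization/vis_depth.py above might be called, assuming the repository root is on PYTHONPATH, a CPU depth tensor of shape [H, W] with positive values stands in for a network prediction, and the output filenames are illustrative only. vis_disparity returns a BGR uint8 image from cv2.applyColorMap, while viz_depth_tensor(..., return_numpy=True) returns an RGB uint8 array from Matplotlib, so the RGB result is converted before writing with OpenCV.

import cv2
import torch

from src.visualization.vis_depth import vis_disparity, viz_depth_tensor

# Stand-in for a predicted depth map (assumption: [H, W], positive values, on the CPU).
depth = torch.rand(270, 480) * 10.0 + 0.1

# Colorize inverse depth with Matplotlib's "plasma" colormap (RGB, uint8, [H, W, 3]).
inv_depth_rgb = viz_depth_tensor(1.0 / depth, return_numpy=True)
cv2.imwrite("depth_plasma.png", cv2.cvtColor(inv_depth_rgb, cv2.COLOR_RGB2BGR))

# Colorize disparity with OpenCV's inferno colormap (BGR, uint8, [H, W, 3]).
disp_bgr = vis_disparity((1.0 / depth).numpy())
cv2.imwrite("disp_inferno.png", disp_bgr)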