├── .gitignore ├── DATASETS.md ├── LICENSE ├── README.md ├── assets ├── Inter-Regular.otf ├── evaluation_index_acid.json ├── evaluation_index_dtu.json ├── evaluation_index_re10k.json └── evaluation_index_scannetpp.json ├── config ├── compute_metrics.yaml ├── dataset │ ├── base_dataset.yaml │ ├── dl3dv.yaml │ ├── re10k.yaml │ ├── scannet.yaml │ ├── scannetpp.yaml │ ├── view_sampler │ │ ├── all.yaml │ │ ├── arbitrary.yaml │ │ ├── bounded.yaml │ │ └── evaluation.yaml │ └── view_sampler_dataset_specific_config │ │ ├── base_view_sampler.yaml │ │ ├── bounded_dl3dv.yaml │ │ ├── bounded_re10k.yaml │ │ ├── bounded_scannetpp.yaml │ │ └── evaluation_re10k.yaml ├── evaluation │ ├── acid.yaml │ ├── eval_pose.yaml │ └── re10k.yaml ├── experiment │ ├── acid.yaml │ ├── dl3dv.yaml │ ├── re10k.yaml │ ├── re10k_1x8.yaml │ ├── re10k_3view.yaml │ ├── re10k_dl3dv.yaml │ ├── re10k_dl3dv_512x512.yaml │ └── scannet_pose.yaml ├── generate_evaluation_index.yaml ├── loss │ ├── depth.yaml │ ├── lpips.yaml │ └── mse.yaml ├── main.yaml └── model │ ├── decoder │ └── splatting_cuda.yaml │ └── encoder │ ├── backbone │ └── croco.yaml │ └── noposplat.yaml ├── pyproject.toml ├── requirements.txt └── src ├── config.py ├── dataset ├── __init__.py ├── data_module.py ├── dataset.py ├── dataset_re10k.py ├── dataset_scannet_pose.py ├── shims │ ├── augmentation_shim.py │ ├── bounds_shim.py │ ├── crop_shim.py │ ├── normalize_shim.py │ └── patch_shim.py ├── types.py ├── validation_wrapper.py └── view_sampler │ ├── __init__.py │ ├── additional_view_hack.py │ ├── view_sampler.py │ ├── view_sampler_all.py │ ├── view_sampler_arbitrary.py │ ├── view_sampler_bounded.py │ └── view_sampler_evaluation.py ├── eval_pose.py ├── evaluation ├── evaluation_cfg.py ├── evaluation_index_generator.py ├── metric_computer.py ├── metrics.py └── pose_evaluator.py ├── geometry ├── camera_emb.py ├── epipolar_lines.py ├── projection.py └── ptc_geometry.py ├── global_cfg.py ├── loss ├── __init__.py ├── loss.py ├── loss_depth.py ├── loss_lpips.py ├── loss_mse.py ├── loss_point.py └── loss_ssim.py ├── main.py ├── misc ├── LocalLogger.py ├── benchmarker.py ├── cam_utils.py ├── collation.py ├── discrete_probability_distribution.py ├── heterogeneous_pairings.py ├── image_io.py ├── nn_module_tools.py ├── sh_rotation.py ├── sht.py ├── step_tracker.py ├── utils.py ├── wandb_tools.py └── weight_modify.py ├── model ├── decoder │ ├── __init__.py │ ├── cuda_splatting.py │ ├── decoder.py │ └── decoder_splatting_cuda.py ├── distiller │ ├── __init__.py │ └── dust3d_backbone.py ├── encoder │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── backbone.py │ │ ├── backbone_croco.py │ │ ├── backbone_croco_multiview.py │ │ ├── backbone_dino.py │ │ ├── backbone_resnet.py │ │ └── croco │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ ├── croco.py │ │ │ ├── curope │ │ │ ├── __init__.py │ │ │ ├── curope.cpp │ │ │ ├── curope2d.py │ │ │ ├── kernels.cu │ │ │ └── setup.py │ │ │ ├── masking.py │ │ │ ├── misc.py │ │ │ ├── patch_embed.py │ │ │ └── pos_embed.py │ ├── common │ │ ├── gaussian_adapter.py │ │ └── gaussians.py │ ├── encoder.py │ ├── encoder_noposplat.py │ ├── encoder_noposplat_multi.py │ ├── heads │ │ ├── __init__.py │ │ ├── dpt_block.py │ │ ├── dpt_gs_head.py │ │ ├── dpt_head.py │ │ ├── head_modules.py │ │ ├── linear_head.py │ │ └── postprocess.py │ └── visualization │ │ ├── encoder_visualizer.py │ │ ├── encoder_visualizer_epipolar.py │ │ └── encoder_visualizer_epipolar_cfg.py ├── encodings │ └── positional_encoding.py ├── model_wrapper.py ├── 
ply_export.py ├── transformer │ ├── attention.py │ ├── feed_forward.py │ ├── pre_norm.py │ └── transformer.py └── types.py ├── scripts ├── compute_metrics.py └── convert_dl3dv.py └── visualization ├── annotation.py ├── camera_trajectory ├── interpolation.py ├── spin.py └── wobble.py ├── color_map.py ├── colors.py ├── drawing ├── cameras.py ├── coordinate_conversion.py ├── lines.py ├── points.py ├── rendering.py └── types.py ├── layout.py ├── validation_in_3d.py └── video_render.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | /datasets 163 | /dataset_cache 164 | 165 | # Outputs 166 | /outputs 167 | /lightning_logs 168 | /checkpoints 169 | 170 | .bashrc 171 | /launcher_venv 172 | /slurm_logs 173 | *.torch 174 | *.ckpt 175 | table.tex 176 | /baselines 177 | /test/* -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | For training, we mainly use [RealEstate10K](https://google.github.io/realestate10k/index.html), [DL3DV](https://github.com/DL3DV-10K/Dataset), and [ACID](https://infinite-nature.github.io/) datasets. We provide the data processing scripts to convert the original datasets to pytorch chunk files which can be directly loaded with this codebase. 4 | 5 | Expected folder structure: 6 | 7 | ``` 8 | ├── datasets 9 | │ ├── re10k 10 | │ ├── ├── train 11 | │ ├── ├── ├── 000000.torch 12 | │ ├── ├── ├── ... 13 | │ ├── ├── ├── index.json 14 | │ ├── ├── test 15 | │ ├── ├── ├── 000000.torch 16 | │ ├── ├── ├── ... 17 | │ ├── ├── ├── index.json 18 | │ ├── dl3dv 19 | │ ├── ├── train 20 | │ ├── ├── ├── 000000.torch 21 | │ ├── ├── ├── ... 22 | │ ├── ├── ├── index.json 23 | │ ├── ├── test 24 | │ ├── ├── ├── 000000.torch 25 | │ ├── ├── ├── ... 26 | │ ├── ├── ├── index.json 27 | ``` 28 | 29 | By default, we assume the datasets are placed in `datasets/re10k`, `datasets/dl3dv`, and `datasets/acid`. Otherwise you will need to specify your dataset path with `dataset.DATASET_NAME.roots=[YOUR_DATASET_PATH]` in the running script. 30 | 31 | We also provide instructions to convert additional datasets to the desired format. 32 | 33 | 34 | 35 | ## RealEstate10K 36 | 37 | For experiments on RealEstate10K, we primarily follow [pixelSplat](https://github.com/dcharatan/pixelsplat) and [MVSplat](https://github.com/donydchen/mvsplat) to train and evaluate on 256x256 resolution. 38 | 39 | Please refer to [here](https://github.com/dcharatan/pixelsplat?tab=readme-ov-file#acquiring-datasets) for acquiring the processed 360p dataset (360x640 resolution). 40 | 41 | If you would like to train and evaluate on the high-resolution RealEstate10K dataset, you will need to download the 720p (720x1280) version. Please refer to [here](https://github.com/yilundu/cross_attention_renderer/tree/master/data_download) for the downloading script. 
Note that the script downloads the 360p videos by default; you will need to change `360p` to `720p` in [this line of code](https://github.com/yilundu/cross_attention_renderer/blob/master/data_download/generate_realestate.py#L137) to download the 720p videos. 42 | 43 | After downloading the 720p dataset, you can use the scripts [here](https://github.com/dcharatan/real_estate_10k_tools/tree/main/src) to convert the dataset to the format expected by this codebase. 44 | 45 | 46 | 47 | ## DL3DV 48 | 49 | In the DL3DV experiments, we train together with RealEstate10K at 256x256, 512x512, and 368x640 resolutions. 50 | 51 | For the training set, we use the [DL3DV-480p](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-480P) dataset (270x480 resolution); the 140 scenes in the test set are excluded when processing the training set. After downloading the [DL3DV-480p](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-480P) dataset, you can then use the script [src/scripts/convert_dl3dv.py](src/scripts/convert_dl3dv.py) to convert the training set. 52 | 53 | Please note that you will need to update the dataset paths in the aforementioned processing scripts. 54 | 55 | If you would like to train on the high-resolution DL3DV dataset, you will need to download the [DL3DV-960P](https://huggingface.co/datasets/DL3DV/DL3DV-ALL-960P) version (540x960 resolution). Simply follow the same procedure for data processing (use the `images_4` folder instead of `images_8`). 56 | 57 | 58 | 59 | ## Test-Only Datasets 60 | 61 | ### DTU 62 | 63 | You can download the processed DTU dataset from [here](https://drive.google.com/file/d/1Bd9il_O1jjom6Lk9NP77K8aeQP0d9eEt/view?usp=drive_link). 64 | 65 | ### ScanNet++ 66 | 67 | You can download the processed ScanNet++ dataset from [here](https://drive.google.com/file/d/1bmkNjXuWLhAOkf-6liS0yyARCybZyQSE/view?usp=sharing). 68 | 69 | ### ScanNet-1500 70 | 71 | For ScanNet-1500, you need to download `test.npz` from [here](https://github.com/zju3dv/LoFTR/blob/master/assets/scannet_test_1500/test.npz) and the corresponding test dataset from [here](https://drive.google.com/file/d/1wtl-mNicxGlXZ-UQJxFnKuWPvvssQBwd/view). 72 | 73 | ## Custom Datasets 74 | 75 | If you would like to train and/or evaluate on additional datasets, simply modify the [data processing scripts](src/scripts) to convert the dataset format. Please note the [camera conventions](https://github.com/cvg/NoPoSplat/tree/main?tab=readme-ov-file#camera-conventions) used in this codebase. 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Botao Ye 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/Inter-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/NoPoSplat/9b04307ebe179d610c04208db8d69f7c3106d03b/assets/Inter-Regular.otf -------------------------------------------------------------------------------- /assets/evaluation_index_dtu.json: -------------------------------------------------------------------------------- 1 | { 2 | "scan1_train": { 3 | "context": [ 4 | 33, 5 | 31 6 | ], 7 | "target": [ 8 | 32, 9 | 24, 10 | 44, 11 | 23 12 | ] 13 | }, 14 | "scan8_train": { 15 | "context": [ 16 | 33, 17 | 31 18 | ], 19 | "target": [ 20 | 32, 21 | 24, 22 | 44, 23 | 23 24 | ] 25 | }, 26 | "scan21_train": { 27 | "context": [ 28 | 33, 29 | 31 30 | ], 31 | "target": [ 32 | 32, 33 | 24, 34 | 44, 35 | 23 36 | ] 37 | }, 38 | "scan30_train": { 39 | "context": [ 40 | 33, 41 | 31 42 | ], 43 | "target": [ 44 | 32, 45 | 24, 46 | 44, 47 | 23 48 | ] 49 | }, 50 | "scan31_train": { 51 | "context": [ 52 | 33, 53 | 31 54 | ], 55 | "target": [ 56 | 32, 57 | 24, 58 | 44, 59 | 23 60 | ] 61 | }, 62 | "scan34_train": { 63 | "context": [ 64 | 33, 65 | 31 66 | ], 67 | "target": [ 68 | 32, 69 | 24, 70 | 44, 71 | 23 72 | ] 73 | }, 74 | "scan38_train": { 75 | "context": [ 76 | 33, 77 | 31 78 | ], 79 | "target": [ 80 | 32, 81 | 24, 82 | 44, 83 | 23 84 | ] 85 | }, 86 | "scan40_train": { 87 | "context": [ 88 | 33, 89 | 31 90 | ], 91 | "target": [ 92 | 32, 93 | 24, 94 | 44, 95 | 23 96 | ] 97 | }, 98 | "scan41_train": { 99 | "context": [ 100 | 33, 101 | 31 102 | ], 103 | "target": [ 104 | 32, 105 | 24, 106 | 44, 107 | 23 108 | ] 109 | }, 110 | "scan45_train": { 111 | "context": [ 112 | 33, 113 | 31 114 | ], 115 | "target": [ 116 | 32, 117 | 24, 118 | 44, 119 | 23 120 | ] 121 | }, 122 | "scan55_train": { 123 | "context": [ 124 | 33, 125 | 31 126 | ], 127 | "target": [ 128 | 32, 129 | 24, 130 | 44, 131 | 23 132 | ] 133 | }, 134 | "scan63_train": { 135 | "context": [ 136 | 33, 137 | 31 138 | ], 139 | "target": [ 140 | 32, 141 | 24, 142 | 44, 143 | 23 144 | ] 145 | }, 146 | "scan82_train": { 147 | "context": [ 148 | 33, 149 | 31 150 | ], 151 | "target": [ 152 | 32, 153 | 24, 154 | 44, 155 | 23 156 | ] 157 | }, 158 | "scan103_train": { 159 | "context": [ 160 | 33, 161 | 31 162 | ], 163 | "target": [ 164 | 32, 165 | 24, 166 | 44, 167 | 23 168 | ] 169 | }, 170 | "scan110_train": { 171 | "context": [ 172 | 33, 173 | 31 174 | ], 175 | "target": [ 176 | 32, 177 | 24, 178 | 44, 179 | 23 180 | ] 181 | }, 182 | "scan114_train": { 183 | "context": [ 184 | 33, 185 | 31 186 | ], 187 | "target": [ 188 | 32, 189 | 24, 190 | 44, 191 | 23 192 | ] 193 | } 194 | } -------------------------------------------------------------------------------- /assets/evaluation_index_scannetpp.json: -------------------------------------------------------------------------------- 1 | {"be2e10f16a_iphone": {"context": [184, 192], "target": [184, 187, 190], "overlap": -1.0}, "497588b572_iphone": {"context": [0, 16], "target": [1, 6, 12], "overlap": -1.0}, "02455b3d20_iphone": {"context": [766, 772], "target": [769, 770, 771], "overlap": -1.0}, "3e6ceea56c_iphone": {"context": [709, 727], "target": [711, 719, 726], 
"overlap": -1.0}, "f7a60ba2a2_iphone": {"context": [357, 365], "target": [357, 360, 363], "overlap": -1.0}, "2a1a3afad9_iphone": {"context": [256, 273], "target": [258, 263, 271], "overlap": -1.0}, "3a161a857d_iphone": {"context": [889, 897], "target": [891, 892, 895], "overlap": -1.0}, "0529d56cce_iphone": {"context": [21, 27], "target": [22, 24, 25], "overlap": -1.0}, "6ebe30292e_iphone": {"context": [224, 231], "target": [224, 225, 229], "overlap": -1.0}, "7dfdff1b7d_iphone": {"context": [601, 608], "target": [601, 603, 606], "overlap": -1.0}, "be66c57b92_iphone": {"context": [215, 221], "target": [215, 216, 218], "overlap": -1.0}, "be6205d016_iphone": {"context": [160, 167], "target": [160, 166, 167], "overlap": -1.0}, "569f99f881_iphone": {"context": [101, 114], "target": [105, 108, 112], "overlap": -1.0}, "da8043d54e_iphone": {"context": [561, 570], "target": [564, 565, 569], "overlap": -1.0}, "06a3d79b68_iphone": {"context": [250, 256], "target": [252, 254, 256], "overlap": -1.0}, "f94c225e84_iphone": {"context": [472, 488], "target": [480, 481, 482], "overlap": -1.0}, "a24858e51e_iphone": {"context": [375, 382], "target": [377, 381, 382], "overlap": -1.0}, "cd2994fcc1_iphone": {"context": [39, 45], "target": [39, 43, 44], "overlap": -1.0}, "a897272241_iphone": {"context": [581, 591], "target": [581, 582, 585], "overlap": -1.0}, "285efbc7cf_iphone": {"context": [370, 382], "target": [370, 377, 380], "overlap": -1.0}, "2e67a32314_iphone": {"context": [86, 92], "target": [88, 89, 92], "overlap": -1.0}, "dffce1cf9a_iphone": {"context": [320, 327], "target": [321, 322, 327], "overlap": -1.0}, "e6afbe3753_iphone": {"context": [65, 71], "target": [66, 68, 71], "overlap": -1.0}, "d070e22e3b_iphone": {"context": [254, 271], "target": [262, 264, 269], "overlap": -1.0}, "9f7641ce94_iphone": {"context": [111, 125], "target": [112, 116, 119], "overlap": -1.0}, "aa6e508f0c_iphone": {"context": [244, 261], "target": [248, 249, 260], "overlap": -1.0}, "5371eff4f9_iphone": {"context": [202, 214], "target": [204, 207, 212], "overlap": -1.0}, "cd88899edb_iphone": {"context": [194, 200], "target": [195, 196, 198], "overlap": -1.0}, "21532e059d_iphone": {"context": [85, 101], "target": [86, 91, 100], "overlap": -1.0}, "47eb87b5bb_iphone": {"context": [442, 455], "target": [445, 446, 452], "overlap": -1.0}, "d662592b54_iphone": {"context": [84, 94], "target": [85, 87, 92], "overlap": -1.0}, "68739bdf1f_iphone": {"context": [426, 433], "target": [427, 428, 429], "overlap": -1.0}, "accad58571_iphone": {"context": [76, 90], "target": [78, 86, 90], "overlap": -1.0}, "4ef75031e3_iphone": {"context": [899, 905], "target": [899, 900, 902], "overlap": -1.0}, "1be2c31cac_iphone": {"context": [1324, 1337], "target": [1325, 1329, 1333], "overlap": -1.0}, "9663292843_iphone": {"context": [406, 412], "target": [407, 410, 412], "overlap": -1.0}, "7ffc86edf4_iphone": {"context": [406, 413], "target": [408, 411, 413], "overlap": -1.0}, "54bca9597e_iphone": {"context": [131, 137], "target": [132, 136, 137], "overlap": -1.0}, "728daff2a3_iphone": {"context": [2, 31], "target": [4, 21, 26], "overlap": -1.0}, "18fd041970_iphone": {"context": [631, 640], "target": [633, 637, 638], "overlap": -1.0}, "eeddfe67f5_iphone": {"context": [288, 294], "target": [289, 292, 293], "overlap": -1.0}, "b09b119466_iphone": {"context": [178, 190], "target": [179, 181, 189], "overlap": -1.0}, "ab252b28c0_iphone": {"context": [244, 256], "target": [246, 247, 253], "overlap": -1.0}, "cc49215f67_iphone": {"context": [211, 226], "target": [214, 
216, 224], "overlap": -1.0}, "154c3e10d9_iphone": {"context": [211, 220], "target": [211, 213, 219], "overlap": -1.0}, "29b607c6d5_iphone": {"context": [883, 889], "target": [883, 885, 889], "overlap": -1.0}, "ba414a3e6f_iphone": {"context": [474, 488], "target": [479, 481, 488], "overlap": -1.0}, "954633ea01_iphone": {"context": [422, 433], "target": [429, 430, 431], "overlap": -1.0}, "6f12492455_iphone": {"context": [365, 372], "target": [366, 370, 372], "overlap": -1.0}, "ea15f3457c_iphone": {"context": [285, 302], "target": [285, 294, 301], "overlap": -1.0}} -------------------------------------------------------------------------------- /config/compute_metrics.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model/encoder: noposplat 3 | - loss: [] 4 | - override dataset/view_sampler@dataset.re10k.view_sampler: evaluation 5 | 6 | dataset: 7 | re10k: 8 | view_sampler: 9 | index_path: assets/evaluation_index_re10k.json 10 | 11 | data_loader: 12 | train: 13 | num_workers: 0 14 | persistent_workers: true 15 | batch_size: 1 16 | seed: 1234 17 | test: 18 | num_workers: 4 19 | persistent_workers: false 20 | batch_size: 1 21 | seed: 2345 22 | val: 23 | num_workers: 0 24 | persistent_workers: true 25 | batch_size: 1 26 | seed: 3456 27 | 28 | seed: 111123 29 | -------------------------------------------------------------------------------- /config/dataset/base_dataset.yaml: -------------------------------------------------------------------------------- 1 | make_baseline_1: true 2 | relative_pose: true 3 | augment: true 4 | background_color: [0.0, 0.0, 0.0] 5 | overfit_to_scene: null 6 | skip_bad_shape: true 7 | -------------------------------------------------------------------------------- /config/dataset/dl3dv.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_dataset 3 | - view_sampler: bounded 4 | - optional view_sampler_dataset_specific_config@view_sampler: bounded_dl3dv 5 | 6 | name: dl3dv 7 | roots: [datasets/dl3dv] 8 | 9 | input_image_shape: [256, 256] 10 | original_image_shape: [270, 480] 11 | cameras_are_circular: false 12 | 13 | baseline_min: 1e-3 14 | baseline_max: 1e2 15 | max_fov: 100.0 16 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_dataset 3 | - view_sampler: bounded 4 | - optional view_sampler_dataset_specific_config@view_sampler: bounded_re10k 5 | 6 | name: re10k 7 | roots: [datasets/re10k] 8 | 9 | input_image_shape: [256, 256] 10 | original_image_shape: [360, 640] 11 | cameras_are_circular: false 12 | 13 | baseline_min: 1e-3 14 | baseline_max: 1e10 15 | max_fov: 100.0 16 | -------------------------------------------------------------------------------- /config/dataset/scannet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_dataset 3 | - view_sampler: bounded 4 | 5 | name: scannet_pose 6 | roots: [datasets/scannet_pose_test_1500] 7 | make_baseline_1: true 8 | augment: true 9 | 10 | input_image_shape: [256, 256] 11 | original_image_shape: [720, 960] 12 | cameras_are_circular: false 13 | 14 | baseline_min: 1e-3 15 | baseline_max: 1e2 16 | max_fov: 120.0 17 | 18 | skip_bad_shape: false 19 | -------------------------------------------------------------------------------- /config/dataset/scannetpp.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - base_dataset 3 | - view_sampler: bounded 4 | - optional view_sampler_dataset_specific_config@view_sampler: bounded_scannetpp 5 | 6 | name: scannetpp 7 | roots: [datasets/scannetpp] 8 | 9 | input_image_shape: [256, 256] 10 | original_image_shape: [720, 960] 11 | cameras_are_circular: false 12 | 13 | baseline_min: 1e-3 14 | baseline_max: 1e2 15 | max_fov: 120.0 16 | 17 | skip_bad_shape: false 18 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here. 7 | context_views: null 8 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/bounded.yaml: -------------------------------------------------------------------------------- 1 | name: bounded 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | min_distance_to_context_views: 0 9 | 10 | warm_up_steps: 0 11 | initial_min_distance_between_context_views: 2 12 | initial_max_distance_between_context_views: 6 -------------------------------------------------------------------------------- /config/dataset/view_sampler/evaluation.yaml: -------------------------------------------------------------------------------- 1 | name: evaluation 2 | 3 | index_path: assets/evaluation_index_re10k.json 4 | num_context_views: 2 5 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/base_view_sampler.yaml: -------------------------------------------------------------------------------- 1 | num_target_views: 4 2 | warm_up_steps: 150_000 3 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_dl3dv.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_view_sampler 3 | 4 | min_distance_between_context_views: 8 5 | max_distance_between_context_views: 22 6 | min_distance_to_context_views: 0 7 | initial_min_distance_between_context_views: 5 8 | initial_max_distance_between_context_views: 7 9 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_view_sampler 3 | 4 | min_distance_between_context_views: 45 5 | max_distance_between_context_views: 90 6 | min_distance_to_context_views: 0 7 | initial_min_distance_between_context_views: 25 8 | initial_max_distance_between_context_views: 25 9 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_scannetpp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_view_sampler 3 | 4 | min_distance_between_context_views: 
5 5 | max_distance_between_context_views: 15 6 | min_distance_to_context_views: 0 7 | initial_min_distance_between_context_views: 5 8 | initial_max_distance_between_context_views: 5 9 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | re10k: 5 | view_sampler: 6 | index_path: assets/evaluation_index_re10k.json 7 | -------------------------------------------------------------------------------- /config/evaluation/acid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_acid.json 6 | 7 | evaluation: 8 | methods: 9 | - name: Ours 10 | key: ours 11 | path: baselines/acid/ours/frames 12 | - name: Du et al.~\cite{du2023cross} 13 | key: du2023 14 | path: baselines/acid/yilun/frames 15 | - name: GPNR~\cite{suhail2022generalizable} 16 | key: gpnr 17 | path: baselines/acid/gpnr/frames 18 | - name: pixelNeRF~\cite{pixelnerf} 19 | key: pixelnerf 20 | path: baselines/acid/pixelnerf/frames 21 | 22 | side_by_side_path: null 23 | animate_side_by_side: false 24 | highlighted: 25 | # Main Paper 26 | # - scene: 405dcfc20f9ba5cb 27 | # target_index: 60 28 | # - scene: 23bb6c5d93c40e13 29 | # target_index: 38 30 | # - scene: a5ab552ede3e70a0 31 | # target_index: 52 32 | # Supplemental 33 | - scene: b896ebcbe8d18553 34 | target_index: 61 35 | - scene: bfbb8ec94454b84f 36 | target_index: 89 37 | - scene: c124e820e8123518 38 | target_index: 111 39 | - scene: c06467006ab8705a 40 | target_index: 24 41 | - scene: d4c545fb5fa1a84a 42 | target_index: 23 43 | - scene: d8f7d9ef3a5a4527 44 | target_index: 138 45 | - scene: 4139c1649bcce55a 46 | target_index: 140 47 | # No Space 48 | # - scene: 20c3d098e3b28058 49 | # target_index: 360 50 | 51 | output_metrics_path: baselines/acid/evaluation_metrics.json -------------------------------------------------------------------------------- /config/evaluation/eval_pose.yaml: -------------------------------------------------------------------------------- 1 | methods: 2 | - name: ours 3 | key: ours 4 | path: '' 5 | 6 | side_by_side_path: null 7 | animate_side_by_side: false 8 | highlighted: 9 | # Main Paper 10 | - scene: 5be4f1f46b408d68 11 | target_index: 136 12 | - scene: 800ea72b6988f63e 13 | target_index: 167 14 | - scene: d3a01038c5f21473 15 | target_index: 201 16 | 17 | output_metrics_path: baselines/re10k/evaluation_metrics.json -------------------------------------------------------------------------------- /config/evaluation/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | re10k: 5 | view_sampler: 6 | index_path: assets/evaluation_index_re10k.json 7 | 8 | evaluation: 9 | methods: 10 | - name: Ours 11 | key: ours 12 | path: outputs/test/ours_public_rerun 13 | - name: Ours_old 14 | key: ours_old 15 | path: /local/home/botaye/Codes/pixelsplat/outputs/test/ablation_IntrinToken_lr5_nossim 16 | 17 | side_by_side_path: null 18 | animate_side_by_side: false 19 | highlighted: 20 | # Main Paper 21 | - scene: 32ce9b303717a29d 22 | target_index: 216 23 | - scene: 45e6fa48ddd00e87 24 | target_index: 127 25 | - scene: 7a13c5debea977a8 26 | target_index: 37 27 | 28 | # # Supplemental 29 | # # Large 30 | # - scene: f04b433b91f569c3 31 | # 
target_index: 52 32 | # - scene: c3eaaa4be79355b7 33 | # target_index: 27 34 | # - scene: 7d5bbbfe59fb6d85 35 | # target_index: 30 36 | # - scene: 9a44145b51b162ee 37 | # target_index: 17 38 | # - scene: 043c48135c5e8cc2 39 | # target_index: 31 40 | # - scene: 40f6d540b9b16531 41 | # target_index: 72 42 | # - scene: c28856e5ddb0a22c 43 | # target_index: 97 44 | # - scene: f673068196024955 45 | # target_index: 80 46 | # # Medium 47 | # - scene: 750ddf09bd6d1eab 48 | # target_index: 32 49 | # - scene: ce827f0c64f90c22 50 | # target_index: 143 51 | # - scene: 7a13c5debea977a8 52 | # target_index: 37 53 | # - scene: 18c6473be3bd827a 54 | # target_index: 31 55 | # - scene: 9a44145b51b162ee 56 | # target_index: 17 57 | # - scene: 96a2338040f14ccc 58 | # target_index: 47 59 | # - scene: 7c1333f2b74b067b 60 | # target_index: 43 61 | # - scene: 3ea8d9787998f70a 62 | # target_index: 188 63 | # # Small 64 | # - scene: 515ff56d91647ed8 65 | # target_index: 53 66 | # - scene: e40ca395753837ce 67 | # target_index: 11 68 | # - scene: 33a3fc21efdc8547 69 | # target_index: 114 70 | # - scene: 1910e79a60d57aa7 71 | # target_index: 93 72 | # - scene: 04ed1812719e05f0 73 | # target_index: 80 74 | # - scene: 01fa6190cd47d125 75 | # target_index: 124 76 | # - scene: d69dc6c3720a2b8b 77 | # target_index: 128 78 | # - scene: 627be3fb033b8cc0 79 | # target_index: 84 80 | 81 | output_metrics_path: baselines/re10k/evaluation_metrics.json -------------------------------------------------------------------------------- /config/experiment/acid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: acid 11 | tags: [acid, 256x256] 12 | 13 | model: 14 | encoder: 15 | gs_params_head_type: dpt_gs 16 | pose_free: true 17 | intrinsics_embed_loc: encoder 18 | intrinsics_embed_type: token 19 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 20 | decoder: 21 | make_scale_invariant: true 22 | 23 | dataset: 24 | re10k: 25 | roots: [datasets/acid] 26 | view_sampler: 27 | warm_up_steps: 9375 28 | 29 | optimizer: 30 | lr: 2e-4 31 | warm_up_steps: 125 32 | backbone_lr_multiplier: 0.1 33 | 34 | data_loader: 35 | train: 36 | batch_size: 16 37 | 38 | trainer: 39 | max_steps: 18751 40 | val_check_interval: 500 41 | 42 | checkpointing: 43 | every_n_train_steps: 9375 -------------------------------------------------------------------------------- /config/experiment/dl3dv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.dl3dv: dl3dv 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: dl3dv 11 | tags: [dl3dv, 256x256] 12 | 13 | model: 14 | encoder: 15 | gs_params_head_type: dpt_gs 16 | pose_free: true 17 | intrinsics_embed_loc: encoder 18 | intrinsics_embed_type: token 19 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 20 | 21 | dataset: 22 | dl3dv: 23 | view_sampler: 24 | warm_up_steps: 9375 25 | 26 | optimizer: 27 | lr: 2e-4 28 | warm_up_steps: 125 29 | backbone_lr_multiplier: 0.1 30 | 31 | data_loader: 32 | train: 33 | batch_size: 16 34 | 35 | trainer: 36 | max_steps: 18751 37 | 
val_check_interval: 500 38 | 39 | checkpointing: 40 | every_n_train_steps: 9375 -------------------------------------------------------------------------------- /config/experiment/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: re10k 11 | tags: [re10k, 256x256] 12 | 13 | model: 14 | encoder: 15 | gs_params_head_type: dpt_gs 16 | pose_free: true 17 | intrinsics_embed_loc: encoder 18 | intrinsics_embed_type: token 19 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 20 | decoder: 21 | make_scale_invariant: true 22 | 23 | dataset: 24 | re10k: 25 | view_sampler: 26 | warm_up_steps: 9375 27 | 28 | optimizer: 29 | lr: 2e-4 30 | warm_up_steps: 125 31 | backbone_lr_multiplier: 0.1 32 | 33 | data_loader: 34 | train: 35 | batch_size: 16 36 | 37 | trainer: 38 | max_steps: 18751 39 | val_check_interval: 500 40 | 41 | checkpointing: 42 | every_n_train_steps: 9375 -------------------------------------------------------------------------------- /config/experiment/re10k_1x8.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: re10k 11 | tags: [re10k, 256x256] 12 | 13 | model: 14 | encoder: 15 | gs_params_head_type: dpt_gs 16 | pose_free: true 17 | intrinsics_embed_loc: encoder 18 | intrinsics_embed_type: token 19 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 20 | decoder: 21 | make_scale_invariant: true 22 | 23 | dataset: 24 | re10k: 25 | view_sampler: 26 | warm_up_steps: 150_000 27 | 28 | optimizer: 29 | lr: 1e-4 30 | warm_up_steps: 2000 31 | backbone_lr_multiplier: 0.1 32 | 33 | data_loader: 34 | train: 35 | batch_size: 8 36 | 37 | trainer: 38 | max_steps: 300_001 39 | val_check_interval: 2000 40 | 41 | checkpointing: 42 | every_n_train_steps: 150_000 -------------------------------------------------------------------------------- /config/experiment/re10k_3view.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: re10k 11 | tags: [re10k, 256x256] 12 | 13 | model: 14 | encoder: 15 | backbone: 16 | name: croco_multi 17 | name: noposplat_multi 18 | gs_params_head_type: dpt_gs 19 | pose_free: true 20 | intrinsics_embed_loc: encoder 21 | intrinsics_embed_type: token 22 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 23 | decoder: 24 | make_scale_invariant: true 25 | 26 | dataset: 27 | re10k: 28 | view_sampler: 29 | warm_up_steps: 9375 30 | num_context_views: 3 31 | 32 | optimizer: 33 | lr: 2e-4 34 | warm_up_steps: 125 35 | backbone_lr_multiplier: 0.1 36 | 37 | data_loader: 38 | train: 39 | batch_size: 8 40 | 41 | trainer: 42 | max_steps: 18751 43 | val_check_interval: 500 44 | 45 | checkpointing: 46 | every_n_train_steps: 9375 
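The experiment files under `config/experiment/` are Hydra overlays (`# @package _global_`) on top of `config/main.yaml`. As a rough sketch of how they combine with the dataset-root override described in DATASETS.md, the commands below assume a pixelSplat-style entry point at `src/main.py` invoked as `python -m src.main` and the `+experiment=` append syntax; both are assumptions, so defer to the project README for the authoritative invocation.

```bash
# Sketch only: the entry point and flags are assumed, not taken from this repository's docs.
# Train the 3-view RealEstate10K variant defined in config/experiment/re10k_3view.yaml,
# pointing the dataset somewhere other than the default datasets/re10k:
python -m src.main +experiment=re10k_3view dataset.re10k.roots=[/path/to/re10k]

# Run in test mode with a trained checkpoint; `mode` and `checkpointing.load`
# are keys defined in config/main.yaml:
python -m src.main +experiment=re10k mode=test checkpointing.load=/path/to/model.ckpt
```

Note that for the mixed-dataset experiments that follow (re10k_dl3dv and re10k_dl3dv_512x512), `data_loader.train.batch_size` is the per-dataset batch size, as the comments in those files point out.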
-------------------------------------------------------------------------------- /config/experiment/re10k_dl3dv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - /dataset@_group_.dl3dv: dl3dv 6 | - override /model/encoder: noposplat 7 | - override /model/encoder/backbone: croco 8 | - override /loss: [mse, lpips] 9 | 10 | wandb: 11 | name: re10k_dl3dv 12 | tags: [re10k_dl3dv, 256x256] 13 | 14 | model: 15 | encoder: 16 | gs_params_head_type: dpt_gs 17 | pose_free: true 18 | intrinsics_embed_loc: encoder 19 | intrinsics_embed_type: token 20 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 21 | 22 | dataset: 23 | re10k: 24 | view_sampler: 25 | warm_up_steps: 9375 26 | dl3dv: 27 | view_sampler: 28 | warm_up_steps: 9375 29 | 30 | optimizer: 31 | lr: 2e-4 32 | warm_up_steps: 125 33 | backbone_lr_multiplier: 0.1 34 | 35 | data_loader: 36 | train: 37 | batch_size: 8 # 8 for each dataset, since we have 2 datasets, the total batch size is 16 38 | 39 | trainer: 40 | max_steps: 18751 41 | val_check_interval: 500 42 | 43 | checkpointing: 44 | every_n_train_steps: 9375 45 | -------------------------------------------------------------------------------- /config/experiment/re10k_dl3dv_512x512.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.re10k: re10k 5 | - /dataset@_group_.dl3dv: dl3dv 6 | - override /model/encoder: noposplat 7 | - override /model/encoder/backbone: croco 8 | - override /loss: [mse, lpips] 9 | 10 | wandb: 11 | name: re10k_dl3dv 12 | tags: [re10k_dl3dv, 512x512] 13 | 14 | model: 15 | encoder: 16 | gs_params_head_type: dpt_gs 17 | pose_free: true 18 | intrinsics_embed_loc: encoder 19 | intrinsics_embed_type: token 20 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 21 | 22 | dataset: 23 | re10k: 24 | roots: [ datasets/re10k_720p ] 25 | skip_bad_shape: true 26 | input_image_shape: [ 512, 512 ] 27 | original_image_shape: [ 720, 1280 ] 28 | view_sampler: 29 | warm_up_steps: 9375 30 | dl3dv: 31 | roots: [ datasets/dl3dv_960p ] 32 | skip_bad_shape: true 33 | input_image_shape: [ 512, 512 ] 34 | original_image_shape: [ 540, 960 ] 35 | view_sampler: 36 | warm_up_steps: 9375 37 | 38 | optimizer: 39 | lr: 2e-4 40 | warm_up_steps: 125 41 | backbone_lr_multiplier: 0.1 42 | 43 | data_loader: 44 | train: 45 | batch_size: 2 # 2 for each dataset, since we have 2 datasets, the total batch size is 4 46 | 47 | trainer: 48 | max_steps: 18751 49 | val_check_interval: 500 50 | 51 | checkpointing: 52 | every_n_train_steps: 9375 53 | -------------------------------------------------------------------------------- /config/experiment/scannet_pose.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /dataset@_group_.scannet_pose: scannet_pose 5 | - override /model/encoder: noposplat 6 | - override /model/encoder/backbone: croco 7 | - override /loss: [mse, lpips] 8 | 9 | wandb: 10 | name: re10k 11 | tags: [re10k, 256x256] 12 | 13 | model: 14 | encoder: 15 | gs_params_head_type: dpt_gs 16 | pose_free: true 17 | intrinsics_embed_loc: encoder 18 | intrinsics_embed_type: token 19 | pretrained_weights: './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 20 | 21 | dataset: 22 | scannet_pose: 23 | 
view_sampler: 24 | warm_up_steps: 9375 25 | 26 | optimizer: 27 | lr: 2e-4 28 | warm_up_steps: 125 29 | backbone_lr_multiplier: 0.1 30 | 31 | data_loader: 32 | train: 33 | batch_size: 16 34 | 35 | trainer: 36 | max_steps: 18751 37 | val_check_interval: 500 38 | 39 | checkpointing: 40 | every_n_train_steps: 9375 -------------------------------------------------------------------------------- /config/generate_evaluation_index.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - override dataset/view_sampler: all 5 | 6 | dataset: 7 | overfit_to_scene: null 8 | 9 | data_loader: 10 | train: 11 | num_workers: 0 12 | persistent_workers: true 13 | batch_size: 1 14 | seed: 1234 15 | test: 16 | num_workers: 8 17 | persistent_workers: false 18 | batch_size: 1 19 | seed: 2345 20 | val: 21 | num_workers: 0 22 | persistent_workers: true 23 | batch_size: 1 24 | seed: 3456 25 | 26 | index_generator: 27 | num_target_views: 3 28 | min_overlap: 0.6 29 | max_overlap: 1.0 30 | min_distance: 45 31 | max_distance: 135 32 | output_path: outputs/evaluation_index_re10k 33 | save_previews: false 34 | seed: 123 35 | 36 | seed: 456 37 | -------------------------------------------------------------------------------- /config/loss/depth.yaml: -------------------------------------------------------------------------------- 1 | depth: 2 | weight: 0.25 3 | sigma_image: null 4 | use_second_derivative: false 5 | -------------------------------------------------------------------------------- /config/loss/lpips.yaml: -------------------------------------------------------------------------------- 1 | lpips: 2 | weight: 0.05 3 | apply_after_step: 0 4 | -------------------------------------------------------------------------------- /config/loss/mse.yaml: -------------------------------------------------------------------------------- 1 | mse: 2 | weight: 1.0 3 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model/encoder: noposplat 3 | - model/decoder: splatting_cuda 4 | - loss: [mse] 5 | 6 | wandb: 7 | project: noposplat 8 | entity: scene-representation-group 9 | name: placeholder 10 | mode: disabled 11 | 12 | mode: train 13 | 14 | #dataset: 15 | # overfit_to_scene: null 16 | 17 | data_loader: 18 | # Avoid having to spin up new processes to print out visualizations. 
19 | train: 20 | num_workers: 16 21 | persistent_workers: true 22 | batch_size: 4 23 | seed: 1234 24 | test: 25 | num_workers: 4 26 | persistent_workers: false 27 | batch_size: 1 28 | seed: 2345 29 | val: 30 | num_workers: 1 31 | persistent_workers: true 32 | batch_size: 1 33 | seed: 3456 34 | 35 | optimizer: 36 | lr: 1.5e-4 37 | warm_up_steps: 2000 38 | backbone_lr_multiplier: 0.1 39 | 40 | checkpointing: 41 | load: null 42 | every_n_train_steps: 5000 43 | save_top_k: 1 44 | save_weights_only: true 45 | 46 | train: 47 | depth_mode: null 48 | extended_visualization: false 49 | print_log_every_n_steps: 10 50 | distiller: '' 51 | distill_max_steps: 1000000 52 | 53 | test: 54 | output_path: outputs/test 55 | align_pose: true 56 | pose_align_steps: 100 57 | rot_opt_lr: 0.005 58 | trans_opt_lr: 0.005 59 | compute_scores: true 60 | save_image: true 61 | save_video: false 62 | save_compare: false 63 | 64 | seed: 111123 65 | 66 | trainer: 67 | max_steps: -1 68 | val_check_interval: 250 69 | gradient_clip_val: 0.5 70 | num_nodes: 1 71 | 72 | hydra: 73 | run: 74 | dir: outputs/exp_${wandb.name}/${now:%Y-%m-%d_%H-%M-%S} 75 | -------------------------------------------------------------------------------- /config/model/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | background_color: [0.0, 0.0, 0.0] 3 | make_scale_invariant: false 4 | -------------------------------------------------------------------------------- /config/model/encoder/backbone/croco.yaml: -------------------------------------------------------------------------------- 1 | name: croco 2 | 3 | model: ViTLarge_BaseDecoder 4 | patch_embed_cls: PatchEmbedDust3R 5 | asymmetry_decoder: true 6 | 7 | intrinsics_embed_loc: 'encoder' 8 | intrinsics_embed_degree: 4 9 | intrinsics_embed_type: 'token' -------------------------------------------------------------------------------- /config/model/encoder/noposplat.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - backbone: croco 3 | 4 | name: noposplat 5 | 6 | opacity_mapping: 7 | initial: 0.0 8 | final: 0.0 9 | warm_up: 1 10 | 11 | num_monocular_samples: 32 12 | num_surfaces: 1 13 | predict_opacity: false 14 | 15 | gaussians_per_pixel: 1 16 | 17 | gaussian_adapter: 18 | gaussian_scale_min: 0.5 19 | gaussian_scale_max: 15.0 20 | sh_degree: 4 21 | 22 | d_feature: 128 23 | 24 | visualizer: 25 | num_samples: 8 26 | min_resolution: 256 27 | export_ply: false 28 | 29 | apply_bounds_shim: true 30 | 31 | gs_params_head_type: dpt_gs 32 | pose_free: true 33 | pretrained_weights: "" -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Enable Pyflakes `E` and `F` codes by default. 3 | select = ["E", "F", "I"] 4 | ignore = ["F722"] # Ignore F722 for jaxtyping compatibility. 5 | 6 | # Allow autofix for all enabled rules (when `--fix`) is provided. 7 | fixable = ["A", "B", "C", "D", "E", "F", "I"] 8 | unfixable = [] 9 | 10 | # Exclude a variety of commonly ignored directories. 
11 | exclude = [ 12 | ".bzr", 13 | ".direnv", 14 | ".eggs", 15 | ".git", 16 | ".hg", 17 | ".mypy_cache", 18 | ".nox", 19 | ".pants.d", 20 | ".ruff_cache", 21 | ".svn", 22 | ".tox", 23 | ".venv", 24 | "__pypackages__", 25 | "_build", 26 | "buck-out", 27 | "build", 28 | "dist", 29 | "node_modules", 30 | "venv", 31 | ] 32 | per-file-ignores = {} 33 | 34 | # Same as Black. 35 | line-length = 88 36 | 37 | # Allow unused variables when underscore-prefixed. 38 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 39 | 40 | # Assume Python 3.10. 41 | target-version = "py310" 42 | 43 | [tool.ruff.mccabe] 44 | # Unlike Flake8, default to a complexity level of 10. 45 | max-complexity = 10 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel 2 | tqdm 3 | lightning 4 | black 5 | ruff 6 | hydra-core 7 | jaxtyping 8 | beartype 9 | wandb 10 | einops 11 | colorama 12 | scikit-image 13 | colorspacious 14 | matplotlib 15 | moviepy 16 | imageio 17 | git+https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git 18 | timm 19 | dacite 20 | lpips 21 | e3nn 22 | plyfile 23 | tabulate 24 | svg.py 25 | scikit-video 26 | opencv-python 27 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Literal, Optional, Type, TypeVar 4 | 5 | from dacite import Config, from_dict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .dataset import DatasetCfgWrapper 9 | from .dataset.data_module import DataLoaderCfg 10 | from .loss import LossCfgWrapper 11 | from .model.decoder import DecoderCfg 12 | from .model.encoder import EncoderCfg 13 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg 14 | 15 | 16 | @dataclass 17 | class CheckpointingCfg: 18 | load: Optional[str] # Not a path, since it could be something like wandb://... 19 | every_n_train_steps: int 20 | save_top_k: int 21 | save_weights_only: bool 22 | 23 | 24 | @dataclass 25 | class ModelCfg: 26 | decoder: DecoderCfg 27 | encoder: EncoderCfg 28 | 29 | 30 | @dataclass 31 | class TrainerCfg: 32 | max_steps: int 33 | val_check_interval: int | float | None 34 | gradient_clip_val: int | float | None 35 | num_nodes: int = 1 36 | 37 | 38 | @dataclass 39 | class RootCfg: 40 | wandb: dict 41 | mode: Literal["train", "test"] 42 | dataset: list[DatasetCfgWrapper] 43 | data_loader: DataLoaderCfg 44 | model: ModelCfg 45 | optimizer: OptimizerCfg 46 | checkpointing: CheckpointingCfg 47 | trainer: TrainerCfg 48 | loss: list[LossCfgWrapper] 49 | test: TestCfg 50 | train: TrainCfg 51 | seed: int 52 | 53 | 54 | TYPE_HOOKS = { 55 | Path: Path, 56 | } 57 | 58 | 59 | T = TypeVar("T") 60 | 61 | 62 | def load_typed_config( 63 | cfg: DictConfig, 64 | data_class: Type[T], 65 | extra_type_hooks: dict = {}, 66 | ) -> T: 67 | return from_dict( 68 | data_class, 69 | OmegaConf.to_container(cfg), 70 | config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}), 71 | ) 72 | 73 | 74 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]: 75 | # The dummy allows the union to be converted. 
76 | @dataclass 77 | class Dummy: 78 | dummy: LossCfgWrapper 79 | 80 | return [ 81 | load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy 82 | for k, v in joined.items() 83 | ] 84 | 85 | 86 | def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]: 87 | # The dummy allows the union to be converted. 88 | @dataclass 89 | class Dummy: 90 | dummy: DatasetCfgWrapper 91 | 92 | return [ 93 | load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy 94 | for k, v in joined.items() 95 | ] 96 | 97 | 98 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 99 | return load_typed_config( 100 | cfg, 101 | RootCfg, 102 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers, 103 | list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers}, 104 | ) 105 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields 2 | 3 | from torch.utils.data import Dataset 4 | 5 | from .dataset_scannet_pose import DatasetScannetPose, DatasetScannetPoseCfgWrapper 6 | from ..misc.step_tracker import StepTracker 7 | from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg, DatasetRE10kCfgWrapper, DatasetDL3DVCfgWrapper, \ 8 | DatasetScannetppCfgWrapper 9 | from .types import Stage 10 | from .view_sampler import get_view_sampler 11 | 12 | DATASETS: dict[str, Dataset] = { 13 | "re10k": DatasetRE10k, 14 | "dl3dv": DatasetRE10k, 15 | "scannetpp": DatasetRE10k, 16 | "scannet_pose": DatasetScannetPose, 17 | } 18 | 19 | 20 | DatasetCfgWrapper = DatasetRE10kCfgWrapper | DatasetDL3DVCfgWrapper | DatasetScannetppCfgWrapper | DatasetScannetPoseCfgWrapper 21 | DatasetCfg = DatasetRE10kCfg 22 | 23 | 24 | def get_dataset( 25 | cfgs: list[DatasetCfgWrapper], 26 | stage: Stage, 27 | step_tracker: StepTracker | None, 28 | ) -> list[Dataset]: 29 | datasets = [] 30 | for cfg in cfgs: 31 | (field,) = fields(type(cfg)) 32 | cfg = getattr(cfg, field.name) 33 | 34 | view_sampler = get_view_sampler( 35 | cfg.view_sampler, 36 | stage, 37 | cfg.overfit_to_scene is not None, 38 | cfg.cameras_are_circular, 39 | step_tracker, 40 | ) 41 | dataset = DATASETS[cfg.name](cfg, stage, view_sampler) 42 | datasets.append(dataset) 43 | 44 | return datasets 45 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | from lightning.pytorch import LightningDataModule 8 | from torch import Generator, nn 9 | from torch.utils.data import DataLoader, Dataset, IterableDataset 10 | 11 | from ..misc.step_tracker import StepTracker 12 | from . import DatasetCfgWrapper, get_dataset 13 | from .types import DataShim, Stage 14 | from .validation_wrapper import ValidationWrapper 15 | 16 | 17 | def get_data_shim(encoder: nn.Module) -> DataShim: 18 | """Get functions that modify the batch. It's sometimes necessary to modify batches 19 | outside the data loader because GPU computations are required to modify the batch or 20 | because the modification depends on something outside the data loader. 
21 | """ 22 | 23 | shims: list[DataShim] = [] 24 | if hasattr(encoder, "get_data_shim"): 25 | shims.append(encoder.get_data_shim()) 26 | 27 | def combined_shim(batch): 28 | for shim in shims: 29 | batch = shim(batch) 30 | return batch 31 | 32 | return combined_shim 33 | 34 | 35 | @dataclass 36 | class DataLoaderStageCfg: 37 | batch_size: int 38 | num_workers: int 39 | persistent_workers: bool 40 | seed: int | None 41 | 42 | 43 | @dataclass 44 | class DataLoaderCfg: 45 | train: DataLoaderStageCfg 46 | test: DataLoaderStageCfg 47 | val: DataLoaderStageCfg 48 | 49 | 50 | DatasetShim = Callable[[Dataset, Stage], Dataset] 51 | 52 | 53 | def worker_init_fn(worker_id: int) -> None: 54 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 55 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 56 | 57 | 58 | class DataModule(LightningDataModule): 59 | dataset_cfgs: list[DatasetCfgWrapper] 60 | data_loader_cfg: DataLoaderCfg 61 | step_tracker: StepTracker | None 62 | dataset_shim: DatasetShim 63 | global_rank: int 64 | 65 | def __init__( 66 | self, 67 | dataset_cfgs: list[DatasetCfgWrapper], 68 | data_loader_cfg: DataLoaderCfg, 69 | step_tracker: StepTracker | None = None, 70 | dataset_shim: DatasetShim = lambda dataset, _: dataset, 71 | global_rank: int = 0, 72 | ) -> None: 73 | super().__init__() 74 | self.dataset_cfgs = dataset_cfgs 75 | self.data_loader_cfg = data_loader_cfg 76 | self.step_tracker = step_tracker 77 | self.dataset_shim = dataset_shim 78 | self.global_rank = global_rank 79 | 80 | def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: 81 | return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers 82 | 83 | def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: 84 | if loader_cfg.seed is None: 85 | return None 86 | generator = Generator() 87 | generator.manual_seed(loader_cfg.seed + self.global_rank) 88 | return generator 89 | 90 | def train_dataloader(self): 91 | datasets = get_dataset(self.dataset_cfgs, "train", self.step_tracker) 92 | data_loaders = [] 93 | for dataset in datasets: 94 | dataset = self.dataset_shim(dataset, "train") 95 | data_loaders.append( 96 | DataLoader( 97 | dataset, 98 | self.data_loader_cfg.train.batch_size, 99 | shuffle=not isinstance(dataset, IterableDataset), 100 | num_workers=self.data_loader_cfg.train.num_workers, 101 | generator=self.get_generator(self.data_loader_cfg.train), 102 | worker_init_fn=worker_init_fn, 103 | persistent_workers=self.get_persistent(self.data_loader_cfg.train), 104 | ) 105 | ) 106 | return data_loaders if len(data_loaders) > 1 else data_loaders[0] 107 | 108 | def val_dataloader(self): 109 | datasets = get_dataset(self.dataset_cfgs, "val", self.step_tracker) 110 | data_loaders = [] 111 | for dataset in datasets: 112 | dataset = self.dataset_shim(dataset, "val") 113 | data_loaders.append( 114 | DataLoader( 115 | ValidationWrapper(dataset, 1), 116 | self.data_loader_cfg.val.batch_size, 117 | num_workers=self.data_loader_cfg.val.num_workers, 118 | generator=self.get_generator(self.data_loader_cfg.val), 119 | worker_init_fn=worker_init_fn, 120 | persistent_workers=self.get_persistent(self.data_loader_cfg.val), 121 | ) 122 | ) 123 | return data_loaders if len(data_loaders) > 1 else data_loaders[0] 124 | 125 | def test_dataloader(self): 126 | datasets = get_dataset(self.dataset_cfgs, "test", self.step_tracker) 127 | data_loaders = [] 128 | for dataset in datasets: 129 | dataset = self.dataset_shim(dataset, "test") 130 | 
data_loaders.append( 131 | DataLoader( 132 | dataset, 133 | self.data_loader_cfg.test.batch_size, 134 | num_workers=self.data_loader_cfg.test.num_workers, 135 | generator=self.get_generator(self.data_loader_cfg.test), 136 | worker_init_fn=worker_init_fn, 137 | persistent_workers=self.get_persistent(self.data_loader_cfg.test), 138 | ) 139 | ) 140 | return data_loaders if len(data_loaders) > 1 else data_loaders[0] 141 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | original_image_shape: list[int] 9 | input_image_shape: list[int] 10 | background_color: list[float] 11 | cameras_are_circular: bool 12 | overfit_to_scene: str | None 13 | view_sampler: ViewSamplerCfg 14 | -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 
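# NOTE (illustrative sketch, not part of the original computation): the intrinsics in
# this codebase appear to be normalized by the image size (e.g. get_pnp_pose in
# misc/cam_utils.py rescales K by W and H), so one pixel spans 1/w (or 1/h) in
# normalized image coordinates, and multiplying by K^-1 below re-expresses that span
# on the z=1 camera plane. Writing f_px for the focal length in pixels (shorthand,
# not a variable in this file), the pixel size at depth 1 is 1 / f_px, and the value
# returned below is baseline / (disparity / f_px), i.e. the familiar stereo relation
# depth = f_px * baseline / disparity.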
27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 
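# NOTE (illustrative sketch, assuming intrinsics normalized to [0, 1] by the image
# size): cropping leaves the focal length in pixels unchanged, so re-normalizing by
# the new width gives fx_new = fx_px / w_out = fx_norm * w_in / w_out, which is the
# scale factor applied below (and likewise for fy with the heights). For example,
# fx_norm = 0.9 at w_in = 360 cropped to w_out = 256 becomes 0.9 * 360 / 256
# = 1.265625. The principal point is assumed to remain at the normalized image
# center, so it is left untouched.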
44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded. 70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /src/dataset/shims/normalize_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): 10 | mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 11 | std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 12 | return tensor * std + mean 13 | 14 | 15 | def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): 16 | mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 17 | std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 18 | return (tensor - mean) / std 19 | 20 | 21 | def apply_normalize_shim( 22 | batch: BatchedExample, 23 | mean: tuple[float, float, float] = (0.5, 0.5, 0.5), 24 | std: tuple[float, float, float] = (0.5, 0.5, 0.5), 25 | ) -> BatchedExample: 26 | batch["context"]["image"] = normalize_image(batch["context"]["image"], mean, std) 27 | return batch 28 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause 
misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | overlap: Float[Tensor, "batch _"] # batch view 24 | 25 | 26 | class BatchedExample(TypedDict, total=False): 27 | target: BatchedViews 28 | context: BatchedViews 29 | scene: list[str] 30 | 31 | 32 | class UnbatchedViews(TypedDict, total=False): 33 | extrinsics: Float[Tensor, "_ 4 4"] 34 | intrinsics: Float[Tensor, "_ 3 3"] 35 | image: Float[Tensor, "_ 3 height width"] 36 | near: Float[Tensor, " _"] 37 | far: Float[Tensor, " _"] 38 | index: Int64[Tensor, " _"] 39 | 40 | 41 | class UnbatchedExample(TypedDict, total=False): 42 | target: UnbatchedViews 43 | context: UnbatchedViews 44 | scene: str 45 | 46 | 47 | # A data shim modifies the example after it's been returned from the data loader. 48 | DataShim = Callable[[BatchedExample], BatchedExample] 49 | 50 | AnyExample = BatchedExample | UnbatchedExample 51 | AnyViews = BatchedViews | UnbatchedViews 52 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 
10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ...misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | 11 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 12 | "all": ViewSamplerAll, 13 | "arbitrary": ViewSamplerArbitrary, 14 | "bounded": ViewSamplerBounded, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = ( 19 | ViewSamplerArbitraryCfg 20 | | ViewSamplerBoundedCfg 21 | | ViewSamplerEvaluationCfg 22 | | ViewSamplerAllCfg 23 | ) 24 | 25 | 26 | def get_view_sampler( 27 | cfg: ViewSamplerCfg, 28 | stage: Stage, 29 | overfit: bool, 30 | cameras_are_circular: bool, 31 | step_tracker: StepTracker | None, 32 | ) -> ViewSampler[Any]: 33 | return VIEW_SAMPLERS[cfg.name]( 34 | cfg, 35 | stage, 36 | overfit, 37 | cameras_are_circular, 38 | step_tracker, 39 | ) 40 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/additional_view_hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Int 3 | from torch import Tensor 4 | 5 | 6 | def add_addtional_context_index( 7 | indices: Int[Tensor, "*batch 2"], 8 | number_of_context_views: int, 9 | ) -> Int[Tensor, "*batch view"]: 10 | left, right = indices.unbind(dim=-1) 11 | # evenly distribute the additional context views between the left and right views 12 | ctx_indices = torch.stack( 13 | [ 14 | torch.linspace(left, right, number_of_context_views).long() 15 | ], 16 | dim=-1, 17 | ).squeeze(-1) 18 | return ctx_indices 19 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | 
cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | ) -> tuple[ 43 | Int64[Tensor, " context_view"], # indices for context views 44 | Int64[Tensor, " target_view"], # indices for target views 45 | Float[Tensor, " overlap"], # overlap 46 | ]: 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def num_target_views(self) -> int: 52 | pass 53 | 54 | @property 55 | @abstractmethod 56 | def num_context_views(self) -> int: 57 | pass 58 | 59 | @property 60 | def global_step(self) -> int: 61 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 62 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .additional_view_hack import add_addtional_context_index 9 | from .view_sampler import ViewSampler 10 | 11 | 12 | @dataclass 13 | class ViewSamplerArbitraryCfg: 14 | name: Literal["arbitrary"] 15 | num_context_views: int 16 | num_target_views: int 17 | context_views: list[int] | None 18 | target_views: list[int] | None 19 | 20 | 21 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 22 | def sample( 23 | self, 24 | scene: str, 25 | extrinsics: Float[Tensor, "view 4 4"], 26 | intrinsics: Float[Tensor, "view 3 3"], 27 | device: torch.device = torch.device("cpu"), 28 | ) -> tuple[ 29 | Int64[Tensor, " context_view"], # indices for context views 30 | Int64[Tensor, " target_view"], # indices for target views 31 | Float[Tensor, " overlap"], # overlap 32 | ]: 33 | """Arbitrarily sample context and target views.""" 34 | num_views, _, _ = extrinsics.shape 35 | 36 | index_context = torch.randint( 37 | 0, 38 | num_views, 39 | size=(self.cfg.num_context_views,), 40 | device=device, 41 | 
) 42 | 43 | # Allow the context views to be fixed. 44 | if self.cfg.context_views is not None: 45 | index_context = torch.tensor( 46 | self.cfg.context_views, dtype=torch.int64, device=device 47 | ) 48 | 49 | if self.cfg.num_context_views >= 3 and len(self.cfg.context_views) == 2: 50 | index_context = add_addtional_context_index(index_context, self.cfg.num_context_views) 51 | else: 52 | assert len(self.cfg.context_views) == self.cfg.num_context_views 53 | index_target = torch.randint( 54 | 0, 55 | num_views, 56 | size=(self.cfg.num_target_views,), 57 | device=device, 58 | ) 59 | 60 | # Allow the target views to be fixed. 61 | if self.cfg.target_views is not None: 62 | assert len(self.cfg.target_views) == self.cfg.num_target_views 63 | index_target = torch.tensor( 64 | self.cfg.target_views, dtype=torch.int64, device=device 65 | ) 66 | 67 | overlap = torch.tensor([0.5], dtype=torch.float32, device=device) # dummy 68 | 69 | return index_context, index_target, overlap 70 | 71 | @property 72 | def num_context_views(self) -> int: 73 | return self.cfg.num_context_views 74 | 75 | @property 76 | def num_target_views(self) -> int: 77 | return self.cfg.num_target_views 78 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | name: Literal["bounded"] 14 | num_context_views: int 15 | num_target_views: int 16 | min_distance_between_context_views: int 17 | max_distance_between_context_views: int 18 | min_distance_to_context_views: int 19 | warm_up_steps: int 20 | initial_min_distance_between_context_views: int 21 | initial_max_distance_between_context_views: int 22 | 23 | 24 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 25 | def schedule(self, initial: int, final: int) -> int: 26 | fraction = self.global_step / self.cfg.warm_up_steps 27 | return min(initial + int((final - initial) * fraction), final) 28 | 29 | def sample( 30 | self, 31 | scene: str, 32 | extrinsics: Float[Tensor, "view 4 4"], 33 | intrinsics: Float[Tensor, "view 3 3"], 34 | device: torch.device = torch.device("cpu"), 35 | ) -> tuple[ 36 | Int64[Tensor, " context_view"], # indices for context views 37 | Int64[Tensor, " target_view"], # indices for target views 38 | Float[Tensor, " overlap"], # overlap 39 | ]: 40 | num_views, _, _ = extrinsics.shape 41 | 42 | # Compute the context view spacing based on the current global step. 43 | if self.stage == "test": 44 | # When testing, always use the full gap. 45 | max_gap = self.cfg.max_distance_between_context_views 46 | min_gap = self.cfg.max_distance_between_context_views 47 | elif self.cfg.warm_up_steps > 0: 48 | max_gap = self.schedule( 49 | self.cfg.initial_max_distance_between_context_views, 50 | self.cfg.max_distance_between_context_views, 51 | ) 52 | min_gap = self.schedule( 53 | self.cfg.initial_min_distance_between_context_views, 54 | self.cfg.min_distance_between_context_views, 55 | ) 56 | else: 57 | max_gap = self.cfg.max_distance_between_context_views 58 | min_gap = self.cfg.min_distance_between_context_views 59 | 60 | # Pick the gap between the context views. 
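# NOTE (worked example with made-up numbers, for intuition only): with
# initial_max_distance_between_context_views = 5, max_distance_between_context_views = 45
# and warm_up_steps = 10000, at global_step = 2500 the schedule above returns
# min(5 + int((45 - 5) * 0.25), 45) = 15. The clamps below then keep the sampled gap
# feasible: for a non-circular capture it cannot exceed num_views - 1, and it must be
# at least 2 * min_distance_to_context_views so that a target view can sit at least
# that far from both context views.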
61 | if not self.cameras_are_circular: 62 | max_gap = min(num_views - 1, max_gap) 63 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 64 | if max_gap < min_gap: 65 | raise ValueError("Example does not have enough frames!") 66 | context_gap = torch.randint( 67 | min_gap, 68 | max_gap + 1, 69 | size=tuple(), 70 | device=device, 71 | ).item() 72 | 73 | # Pick the left and right context indices. 74 | index_context_left = torch.randint( 75 | num_views if self.cameras_are_circular else num_views - context_gap, 76 | size=tuple(), 77 | device=device, 78 | ).item() 79 | if self.stage == "test": 80 | index_context_left = index_context_left * 0 81 | index_context_right = index_context_left + context_gap 82 | 83 | if self.is_overfitting: 84 | index_context_left *= 0 85 | index_context_right *= 0 86 | index_context_right += max_gap 87 | 88 | # Pick the target view indices. 89 | if self.stage == "test": 90 | # When testing, pick all. 91 | index_target = torch.arange( 92 | index_context_left, 93 | index_context_right + 1, 94 | device=device, 95 | ) 96 | else: 97 | # When training or validating (visualizing), pick at random. 98 | index_target = torch.randint( 99 | index_context_left + self.cfg.min_distance_to_context_views, 100 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 101 | size=(self.cfg.num_target_views,), 102 | device=device, 103 | ) 104 | 105 | # Apply modulo for circular datasets. 106 | if self.cameras_are_circular: 107 | index_target %= num_views 108 | index_context_right %= num_views 109 | 110 | # If more than two context views are desired, pick extra context views between 111 | # the left and right ones. 112 | if self.cfg.num_context_views > 2: 113 | num_extra_views = self.cfg.num_context_views - 2 114 | extra_views = [] 115 | while len(set(extra_views)) != num_extra_views: 116 | extra_views = torch.randint( 117 | index_context_left + 1, 118 | index_context_right, 119 | (num_extra_views,), 120 | ).tolist() 121 | else: 122 | extra_views = [] 123 | 124 | overlap = torch.tensor([0.5], dtype=torch.float32, device=device) # dummy 125 | 126 | return ( 127 | torch.tensor((index_context_left, *extra_views, index_context_right)), 128 | index_target, 129 | overlap 130 | ) 131 | 132 | @property 133 | def num_context_views(self) -> int: 134 | return self.cfg.num_context_views 135 | 136 | @property 137 | def num_target_views(self) -> int: 138 | return self.cfg.num_target_views 139 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...global_cfg import get_cfg 13 | from ...misc.step_tracker import StepTracker 14 | from ..types import Stage 15 | from .additional_view_hack import add_addtional_context_index 16 | from .view_sampler import ViewSampler 17 | 18 | 19 | @dataclass 20 | class ViewSamplerEvaluationCfg: 21 | name: Literal["evaluation"] 22 | index_path: Path 23 | num_context_views: int 24 | 25 | 26 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 27 | index: dict[str, IndexEntry | None] 28 | 29 | def __init__( 30 | self, 31 | cfg: 
ViewSamplerEvaluationCfg, 32 | stage: Stage, 33 | is_overfitting: bool, 34 | cameras_are_circular: bool, 35 | step_tracker: StepTracker | None, 36 | ) -> None: 37 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 38 | 39 | self.cfg = cfg 40 | 41 | dacite_config = Config(cast=[tuple]) 42 | with cfg.index_path.open("r") as f: 43 | self.index = { 44 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 45 | for k, v in json.load(f).items() 46 | } 47 | 48 | def sample( 49 | self, 50 | scene: str, 51 | extrinsics: Float[Tensor, "view 4 4"], 52 | intrinsics: Float[Tensor, "view 3 3"], 53 | device: torch.device = torch.device("cpu"), 54 | ) -> tuple[ 55 | Int64[Tensor, " context_view"], # indices for context views 56 | Int64[Tensor, " target_view"], # indices for target views 57 | Float[Tensor, " overlap"], # overlap 58 | ]: 59 | entry = self.index.get(scene) 60 | if entry is None: 61 | raise ValueError(f"No indices available for scene {scene}.") 62 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 63 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 64 | 65 | overlap = entry.overlap if isinstance(entry.overlap, float) else 0.75 if entry.overlap == "large" else 0.25 66 | overlap = torch.tensor([overlap], dtype=torch.float32, device=device) 67 | 68 | # Handle 2-view index for more views. 69 | v = self.num_context_views 70 | if v >= 3 and v > len(context_indices): 71 | context_indices = add_addtional_context_index(context_indices, v) 72 | 73 | return context_indices, target_indices, overlap 74 | 75 | @property 76 | def num_context_views(self) -> int: 77 | return self.cfg.num_context_views 78 | 79 | @property 80 | def num_target_views(self) -> int: 81 | return 0 82 | -------------------------------------------------------------------------------- /src/eval_pose.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import hydra 6 | import torch 7 | from jaxtyping import install_import_hook 8 | from omegaconf import DictConfig 9 | from lightning import Trainer 10 | 11 | from src.evaluation.pose_evaluator import PoseEvaluator 12 | from src.loss import get_losses, LossCfgWrapper 13 | from src.misc.wandb_tools import update_checkpoint_path 14 | from src.model.decoder import get_decoder 15 | from src.model.encoder import get_encoder 16 | 17 | # Configure beartype and jaxtyping. 
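# NOTE: the hook below is jaxtyping's install_import_hook. Modules imported under the
# "src" package while it is active have their jaxtyping shape/dtype annotations
# (e.g. Float[Tensor, "batch 4 4"]) checked at runtime by beartype, which is why the
# project imports are placed inside the `with` block.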
18 | with install_import_hook( 19 | ("src",), 20 | ("beartype", "beartype"), 21 | ): 22 | from src.config import load_typed_config, ModelCfg, CheckpointingCfg, separate_loss_cfg_wrappers, \ 23 | separate_dataset_cfg_wrappers 24 | from src.dataset.data_module import DataLoaderCfg, DataModule, DatasetCfgWrapper 25 | from src.evaluation.evaluation_cfg import EvaluationCfg 26 | from src.global_cfg import set_cfg 27 | 28 | 29 | @dataclass 30 | class RootCfg: 31 | evaluation: EvaluationCfg 32 | dataset: list[DatasetCfgWrapper] 33 | data_loader: DataLoaderCfg 34 | model: ModelCfg 35 | checkpointing: CheckpointingCfg 36 | loss: list[LossCfgWrapper] 37 | seed: int 38 | 39 | 40 | @hydra.main( 41 | version_base=None, 42 | config_path="../config", 43 | config_name="main", 44 | ) 45 | def evaluate(cfg_dict: DictConfig): 46 | cfg = load_typed_config(cfg_dict, RootCfg, 47 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers, 48 | list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers},) 49 | set_cfg(cfg_dict) 50 | torch.manual_seed(cfg.seed) 51 | 52 | encoder, encoder_visualizer = get_encoder(cfg.model.encoder) 53 | ckpt_weights = torch.load(cfg.checkpointing.load, map_location='cpu')['state_dict'] 54 | # remove the prefix "encoder.", need to judge if is at start of key 55 | ckpt_weights = {k[8:] if k.startswith("encoder.") else k: v for k, v in ckpt_weights.items()} 56 | missing_keys, unexpected_keys = encoder.load_state_dict(ckpt_weights, strict=True) 57 | 58 | trainer = Trainer(max_epochs=-1, accelerator="gpu", inference_mode=False) 59 | pose_evaluator = PoseEvaluator(cfg.evaluation, 60 | encoder, 61 | get_decoder(cfg.model.decoder), 62 | get_losses(cfg.loss)) 63 | data_module = DataModule( 64 | cfg.dataset, 65 | cfg.data_loader, 66 | ) 67 | 68 | metrics = trainer.test(pose_evaluator, datamodule=data_module) 69 | 70 | cfg.evaluation.output_metrics_path.parent.mkdir(exist_ok=True, parents=True) 71 | with cfg.evaluation.output_metrics_path.open("w") as f: 72 | json.dump(metrics[0], f) 73 | 74 | 75 | if __name__ == "__main__": 76 | evaluate() 77 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | output_metrics_path: Path 23 | animate_side_by_side: bool 24 | highlighted: list[SceneCfg] 25 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * 
mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 | channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | 54 | 55 | def compute_geodesic_distance_from_two_matrices(m1, m2): 56 | batch = m1.shape[0] 57 | m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 58 | 59 | cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 60 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).to(m1.device))) 61 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).to(m1.device)) * -1) 62 | 63 | theta = torch.acos(cos) 64 | 65 | # theta = torch.min(theta, 2*np.pi - theta) 66 | 67 | return theta 68 | 69 | 70 | def angle_error_mat(R1, R2): 71 | cos = (torch.trace(torch.mm(R1.T, R2)) - 1) / 2 72 | cos = torch.clamp(cos, -1.0, 1.0) # numerical errors can make it out of bounds 73 | return torch.rad2deg(torch.abs(torch.acos(cos))) 74 | 75 | 76 | def angle_error_vec(v1, v2): 77 | n = torch.norm(v1) * torch.norm(v2) 78 | cos_theta = torch.dot(v1, v2) / n 79 | cos_theta = torch.clamp(cos_theta, -1.0, 1.0) # numerical errors can make it out of bounds 80 | return torch.rad2deg(torch.acos(cos_theta)) 81 | 82 | 83 | def compute_translation_error(t1, t2): 84 | return torch.norm(t1 - t2) 85 | 86 | 87 | @torch.no_grad() 88 | def compute_pose_error(pose_gt, pose_pred): 89 | R_gt = pose_gt[:3, :3] 90 | t_gt = pose_gt[:3, 3] 91 | 92 | R = pose_pred[:3, :3] 93 | t = pose_pred[:3, 3] 94 | 95 | error_t = angle_error_vec(t, t_gt) 96 | error_t = torch.minimum(error_t, 180 - error_t) # ambiguity of E estimation 97 | error_t_scale = compute_translation_error(t, t_gt) 98 | error_R = angle_error_mat(R, R_gt) 99 | return error_t, error_t_scale, error_R 100 | -------------------------------------------------------------------------------- /src/geometry/camera_emb.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | 3 | from .projection import sample_image_grid, get_local_rays 4 | from ..misc.sht import rsh_cart_2, rsh_cart_4, rsh_cart_6, rsh_cart_8 5 | 6 | 7 | def get_intrinsic_embedding(context, degree=0, downsample=1, merge_hw=False): 8 | assert degree in [0, 2, 4, 8] 9 | 10 | b, v, _, h, w = context["image"].shape 11 | device = context["image"].device 12 | tgt_h, tgt_w = h // downsample, w // downsample 13 | xy_ray, _ = sample_image_grid((tgt_h, tgt_w), device) 14 | xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # [b, v, h, w, 2] 15 | directions = get_local_rays(xy_ray, rearrange(context["intrinsics"], "b v i j -> b v () () i j"),) 16 | 17 | if degree == 2: 18 | directions = rsh_cart_2(directions) 19 | elif degree == 4: 20 | directions = 
rsh_cart_4(directions) 21 | elif degree == 8: 22 | directions = rsh_cart_8(directions) 23 | 24 | if merge_hw: 25 | directions = rearrange(directions, "b v h w d -> b v (h w) d") 26 | else: 27 | directions = rearrange(directions, "b v h w d -> b v d h w") 28 | 29 | return directions 30 | -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_depth import LossDepth, LossDepthCfgWrapper 3 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 4 | from .loss_mse import LossMse, LossMseCfgWrapper 5 | 6 | LOSSES = { 7 | LossDepthCfgWrapper: LossDepth, 8 | LossLpipsCfgWrapper: LossLpips, 9 | LossMseCfgWrapper: LossMse, 10 | } 11 | 12 | LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper 13 | 14 | 15 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 16 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 17 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 
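# NOTE (illustrative): every *CfgWrapper is a dataclass with exactly one field whose
# name doubles as the loss name, e.g. LossMseCfgWrapper(mse=LossMseCfg(weight=1.0))
# yields self.cfg = LossMseCfg(weight=1.0) and self.name = "mse". The one-element
# tuple unpacking below also acts as an assertion that the wrapper really has a
# single field.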
24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians: Gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | -------------------------------------------------------------------------------- /src/loss/loss_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | from .loss import Loss 12 | 13 | 14 | @dataclass 15 | class LossDepthCfg: 16 | weight: float 17 | sigma_image: float | None 18 | use_second_derivative: bool 19 | 20 | 21 | @dataclass 22 | class LossDepthCfgWrapper: 23 | depth: LossDepthCfg 24 | 25 | 26 | class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]): 27 | def forward( 28 | self, 29 | prediction: DecoderOutput, 30 | batch: BatchedExample, 31 | gaussians: Gaussians, 32 | global_step: int, 33 | ) -> Float[Tensor, ""]: 34 | # Scale the depth between the near and far planes. 35 | near = batch["target"]["near"][..., None, None].log() 36 | far = batch["target"]["far"][..., None, None].log() 37 | depth = prediction.depth.minimum(far).maximum(near) 38 | depth = (depth - near) / (far - near) 39 | 40 | # Compute the difference between neighboring pixels in each direction. 41 | depth_dx = depth.diff(dim=-1) 42 | depth_dy = depth.diff(dim=-2) 43 | 44 | # If desired, compute a 2nd derivative. 45 | if self.cfg.use_second_derivative: 46 | depth_dx = depth_dx.diff(dim=-1) 47 | depth_dy = depth_dy.diff(dim=-2) 48 | 49 | # If desired, add bilateral filtering. 
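# NOTE (sketch of the weighting applied below): each depth difference is multiplied
# by exp(-sigma_image * c), where c is the channel-wise maximum of the neighbouring
# color difference. Where the image changes sharply the weight decays toward zero,
# so depth edges that coincide with image edges are penalized less; in flat regions
# the weight stays near one and smoothness is enforced.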
50 | if self.cfg.sigma_image is not None: 51 | color_gt = batch["target"]["image"] 52 | color_dx = reduce(color_gt.diff(dim=-1), "b v c h w -> b v h w", "max") 53 | color_dy = reduce(color_gt.diff(dim=-2), "b v c h w -> b v h w", "max") 54 | if self.cfg.use_second_derivative: 55 | color_dx = color_dx[..., :, 1:].maximum(color_dx[..., :, :-1]) 56 | color_dy = color_dy[..., 1:, :].maximum(color_dy[..., :-1, :]) 57 | depth_dx = depth_dx * torch.exp(-color_dx * self.cfg.sigma_image) 58 | depth_dy = depth_dy * torch.exp(-color_dy * self.cfg.sigma_image) 59 | 60 | return self.cfg.weight * (depth_dx.abs().mean() + depth_dy.abs().mean()) 61 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians: Gaussians, 41 | global_step: int, 42 | ) -> Float[Tensor, ""]: 43 | image = batch["target"]["image"] 44 | 45 | # Before the specified step, don't apply the loss. 
46 | if global_step < self.cfg.apply_after_step: 47 | return torch.tensor(0, dtype=torch.float32, device=image.device) 48 | 49 | loss = self.lpips.forward( 50 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 51 | rearrange(image, "b v c h w -> (b v) c h w"), 52 | normalize=True, 53 | ) 54 | return self.cfg.weight * loss.mean() 55 | -------------------------------------------------------------------------------- /src/loss/loss_mse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..dataset.types import BatchedExample 7 | from ..model.decoder.decoder import DecoderOutput 8 | from ..model.types import Gaussians 9 | from .loss import Loss 10 | 11 | 12 | @dataclass 13 | class LossMseCfg: 14 | weight: float 15 | 16 | 17 | @dataclass 18 | class LossMseCfgWrapper: 19 | mse: LossMseCfg 20 | 21 | 22 | class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): 23 | def forward( 24 | self, 25 | prediction: DecoderOutput, 26 | batch: BatchedExample, 27 | gaussians: Gaussians, 28 | global_step: int, 29 | ) -> Float[Tensor, ""]: 30 | delta = prediction.color - batch["target"]["image"] 31 | return self.cfg.weight * (delta**2).mean() 32 | -------------------------------------------------------------------------------- /src/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | from lightning.pytorch.loggers.logger import Logger 6 | from lightning.pytorch.utilities import rank_zero_only 7 | from PIL import Image 8 | 9 | LOG_PATH = Path("outputs/local") 10 | 11 | 12 | class LocalLogger(Logger): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.experiment = None 16 | os.system(f"rm -r {LOG_PATH}") 17 | 18 | @property 19 | def name(self): 20 | return "LocalLogger" 21 | 22 | @property 23 | def version(self): 24 | return 0 25 | 26 | @rank_zero_only 27 | def log_hyperparams(self, params): 28 | pass 29 | 30 | @rank_zero_only 31 | def log_metrics(self, metrics, step): 32 | pass 33 | 34 | @rank_zero_only 35 | def log_image( 36 | self, 37 | key: str, 38 | images: list[Any], 39 | step: Optional[int] = None, 40 | **kwargs, 41 | ): 42 | # The function signature is the same as the wandb logger's, but the step is 43 | # actually required. 
44 | assert step is not None 45 | for index, image in enumerate(images): 46 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 47 | path.parent.mkdir(exist_ok=True, parents=True) 48 | Image.fromarray(image).save(path) 49 | -------------------------------------------------------------------------------- /src/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") 38 | -------------------------------------------------------------------------------- /src/misc/cam_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | def decompose_extrinsic_RT(E: torch.Tensor): 9 | """ 10 | Decompose the standard extrinsic matrix into RT. 11 | Batched I/O. 12 | """ 13 | return E[:, :3, :] 14 | 15 | 16 | def compose_extrinsic_RT(RT: torch.Tensor): 17 | """ 18 | Compose the standard form extrinsic matrix from RT. 19 | Batched I/O. 
20 | """ 21 | return torch.cat([ 22 | RT, 23 | torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1) 24 | ], dim=1) 25 | 26 | 27 | def camera_normalization(pivotal_pose: torch.Tensor, poses: torch.Tensor): 28 | # [1, 4, 4], [N, 4, 4] 29 | 30 | canonical_camera_extrinsics = torch.tensor([[ 31 | [1, 0, 0, 0], 32 | [0, 1, 0, 0], 33 | [0, 0, 1, 0], 34 | [0, 0, 0, 1], 35 | ]], dtype=torch.float32, device=pivotal_pose.device) 36 | pivotal_pose_inv = torch.inverse(pivotal_pose) 37 | camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv) 38 | 39 | # normalize all views 40 | poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses) 41 | 42 | return poses 43 | 44 | 45 | ####### Pose update from delta 46 | 47 | def rt2mat(R, T): 48 | mat = np.eye(4) 49 | mat[0:3, 0:3] = R 50 | mat[0:3, 3] = T 51 | return mat 52 | 53 | 54 | def skew_sym_mat(x): 55 | device = x.device 56 | dtype = x.dtype 57 | ssm = torch.zeros(3, 3, device=device, dtype=dtype) 58 | ssm[0, 1] = -x[2] 59 | ssm[0, 2] = x[1] 60 | ssm[1, 0] = x[2] 61 | ssm[1, 2] = -x[0] 62 | ssm[2, 0] = -x[1] 63 | ssm[2, 1] = x[0] 64 | return ssm 65 | 66 | 67 | def SO3_exp(theta): 68 | device = theta.device 69 | dtype = theta.dtype 70 | 71 | W = skew_sym_mat(theta) 72 | W2 = W @ W 73 | angle = torch.norm(theta) 74 | I = torch.eye(3, device=device, dtype=dtype) 75 | if angle < 1e-5: 76 | return I + W + 0.5 * W2 77 | else: 78 | return ( 79 | I 80 | + (torch.sin(angle) / angle) * W 81 | + ((1 - torch.cos(angle)) / (angle**2)) * W2 82 | ) 83 | 84 | 85 | def V(theta): 86 | dtype = theta.dtype 87 | device = theta.device 88 | I = torch.eye(3, device=device, dtype=dtype) 89 | W = skew_sym_mat(theta) 90 | W2 = W @ W 91 | angle = torch.norm(theta) 92 | if angle < 1e-5: 93 | V = I + 0.5 * W + (1.0 / 6.0) * W2 94 | else: 95 | V = ( 96 | I 97 | + W * ((1.0 - torch.cos(angle)) / (angle**2)) 98 | + W2 * ((angle - torch.sin(angle)) / (angle**3)) 99 | ) 100 | return V 101 | 102 | 103 | def SE3_exp(tau): 104 | dtype = tau.dtype 105 | device = tau.device 106 | 107 | rho = tau[:3] 108 | theta = tau[3:] 109 | R = SO3_exp(theta) 110 | t = V(theta) @ rho 111 | 112 | T = torch.eye(4, device=device, dtype=dtype) 113 | T[:3, :3] = R 114 | T[:3, 3] = t 115 | return T 116 | 117 | 118 | def update_pose(cam_trans_delta: Float[Tensor, "batch 3"], 119 | cam_rot_delta: Float[Tensor, "batch 3"], 120 | extrinsics: Float[Tensor, "batch 4 4"], 121 | # original_rot: Float[Tensor, "batch 3 3"], 122 | # original_trans: Float[Tensor, "batch 3"], 123 | # converged_threshold: float = 1e-4 124 | ): 125 | # extrinsics is c2w, here we need w2c as input, so we need to invert it 126 | bs = cam_trans_delta.shape[0] 127 | 128 | tau = torch.cat([cam_trans_delta, cam_rot_delta], dim=-1) 129 | T_w2c = extrinsics.inverse() 130 | 131 | new_w2c_list = [] 132 | for i in range(bs): 133 | new_w2c = SE3_exp(tau[i]) @ T_w2c[i] 134 | new_w2c_list.append(new_w2c) 135 | 136 | new_w2c = torch.stack(new_w2c_list, dim=0) 137 | return new_w2c.inverse() 138 | 139 | # converged = tau.norm() < converged_threshold 140 | # camera.update_RT(new_R, new_T) 141 | # 142 | # camera.cam_rot_delta.data.fill_(0) 143 | # camera.cam_trans_delta.data.fill_(0) 144 | # return converged 145 | 146 | 147 | ####### Pose estimation 148 | def inv(mat): 149 | """ Invert a torch or numpy matrix 150 | """ 151 | if isinstance(mat, torch.Tensor): 152 | return torch.linalg.inv(mat) 153 | if isinstance(mat, np.ndarray): 154 | return np.linalg.inv(mat) 155 | raise ValueError(f'bad 
matrix type = {type(mat)}') 156 | 157 | 158 | def get_pnp_pose(pts3d, opacity, K, H, W, opacity_threshold=0.3): 159 | pixels = np.mgrid[:W, :H].T.astype(np.float32) 160 | pts3d = pts3d.cpu().numpy() 161 | opacity = opacity.cpu().numpy() 162 | K = K.cpu().numpy() 163 | 164 | K[0, :] = K[0, :] * W 165 | K[1, :] = K[1, :] * H 166 | 167 | mask = opacity > opacity_threshold 168 | 169 | res = cv2.solvePnPRansac(pts3d[mask], pixels[mask], K, None, 170 | iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) 171 | success, R, T, inliers = res 172 | 173 | assert success 174 | 175 | R = cv2.Rodrigues(R)[0] # world to cam 176 | pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world 177 | 178 | return torch.from_numpy(pose.astype(np.float32)) 179 | 180 | 181 | def pose_auc(errors, thresholds): 182 | sort_idx = np.argsort(errors) 183 | errors = np.array(errors.copy())[sort_idx] 184 | recall = (np.arange(len(errors)) + 1) / len(errors) 185 | errors = np.r_[0.0, errors] 186 | recall = np.r_[0.0, recall] 187 | aucs = [] 188 | for t in thresholds: 189 | last_index = np.searchsorted(errors, t) 190 | r = np.r_[recall[:last_index], recall[last_index - 1]] 191 | e = np.r_[errors[:last_index], t] 192 | aucs.append(np.trapz(r, x=e) / t) 193 | return aucs 194 | -------------------------------------------------------------------------------- /src/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return { 14 | key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0] 15 | } 16 | -------------------------------------------------------------------------------- /src/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... 
()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import numpy as np 6 | import skvideo.io 7 | import torch 8 | import torchvision.transforms as tf 9 | from einops import rearrange, repeat 10 | from jaxtyping import Float, UInt8 11 | from matplotlib.figure import Figure 12 | from PIL import Image 13 | from torch import Tensor 14 | 15 | FloatImage = Union[ 16 | Float[Tensor, "height width"], 17 | Float[Tensor, "channel height width"], 18 | Float[Tensor, "batch channel height width"], 19 | ] 20 | 21 | 22 | def fig_to_image( 23 | fig: Figure, 24 | dpi: int = 100, 25 | device: torch.device = torch.device("cpu"), 26 | ) -> Float[Tensor, "3 height width"]: 27 | buffer = io.BytesIO() 28 | fig.savefig(buffer, format="raw", dpi=dpi) 29 | buffer.seek(0) 30 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 31 | h = int(fig.bbox.bounds[3]) 32 | w = int(fig.bbox.bounds[2]) 33 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 34 | buffer.close() 35 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 36 | 37 | 38 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 39 | # Handle batched images. 40 | if image.ndim == 4: 41 | image = rearrange(image, "b c h w -> c h (b w)") 42 | 43 | # Handle single-channel images. 44 | if image.ndim == 2: 45 | image = rearrange(image, "h w -> () h w") 46 | 47 | # Ensure that there are 3 or 4 channels. 
48 | channel, _, _ = image.shape 49 | if channel == 1: 50 | image = repeat(image, "() h w -> c h w", c=3) 51 | assert image.shape[0] in (3, 4) 52 | 53 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 54 | return rearrange(image, "c h w -> h w c").cpu().numpy() 55 | 56 | 57 | def save_image( 58 | image: FloatImage, 59 | path: Union[Path, str], 60 | ) -> None: 61 | """Save an image. Assumed to be in range 0-1.""" 62 | 63 | # Create the parent directory if it doesn't already exist. 64 | path = Path(path) 65 | path.parent.mkdir(exist_ok=True, parents=True) 66 | 67 | # Save the image. 68 | Image.fromarray(prep_image(image)).save(path) 69 | 70 | 71 | def load_image( 72 | path: Union[Path, str], 73 | ) -> Float[Tensor, "3 height width"]: 74 | return tf.ToTensor()(Image.open(path))[:3] 75 | 76 | 77 | def save_video( 78 | images: list[FloatImage], 79 | path: Union[Path, str], 80 | ) -> None: 81 | """Save an image. Assumed to be in range 0-1.""" 82 | 83 | # Create the parent directory if it doesn't already exist. 84 | path = Path(path) 85 | path.parent.mkdir(exist_ok=True, parents=True) 86 | 87 | # Save the image. 88 | # Image.fromarray(prep_image(image)).save(path) 89 | frames = [] 90 | for image in images: 91 | frames.append(prep_image(image)) 92 | 93 | writer = skvideo.io.FFmpegWriter(path, 94 | outputdict={'-pix_fmt': 'yuv420p', '-crf': '21', 95 | '-vf': f'setpts=1.*PTS'}) 96 | for frame in frames: 97 | writer.writeFrame(frame) 98 | writer.close() 99 | -------------------------------------------------------------------------------- /src/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 
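# Each parameter/buffer is detached, removed from the module, and re-registered as a buffer, so it no longer receives gradients and follows the requested persistence.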
10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /src/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | 3 | import torch 4 | from e3nn.o3 import matrix_to_angles, wigner_D 5 | from einops import einsum 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | 10 | def rotate_sh( 11 | sh_coefficients: Float[Tensor, "*#batch n"], 12 | rotations: Float[Tensor, "*#batch 3 3"], 13 | ) -> Float[Tensor, "*batch n"]: 14 | device = sh_coefficients.device 15 | dtype = sh_coefficients.dtype 16 | 17 | # change the basis from YZX -> XYZ to fit the convention of e3nn 18 | P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], 19 | dtype=sh_coefficients.dtype, device=sh_coefficients.device) 20 | inversed_P = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0], ], 21 | dtype=sh_coefficients.dtype, device=sh_coefficients.device) 22 | permuted_rotation_matrix = inversed_P @ rotations @ P 23 | 24 | *_, n = sh_coefficients.shape 25 | alpha, beta, gamma = matrix_to_angles(permuted_rotation_matrix) 26 | result = [] 27 | for degree in range(isqrt(n)): 28 | with torch.device(device): 29 | sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype) 30 | sh_rotated = einsum( 31 | sh_rotations, 32 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 33 | "... i j, ... j -> ... i", 34 | ) 35 | result.append(sh_rotated) 36 | 37 | return torch.cat(result, dim=-1) 38 | 39 | 40 | # def rotate_sh( 41 | # sh_coefficients: Float[Tensor, "*#batch n"], 42 | # rotations: Float[Tensor, "*#batch 3 3"], 43 | # ) -> Float[Tensor, "*batch n"]: 44 | # device = sh_coefficients.device 45 | # dtype = sh_coefficients.dtype 46 | # 47 | # *_, n = sh_coefficients.shape 48 | # alpha, beta, gamma = matrix_to_angles(rotations) 49 | # result = [] 50 | # for degree in range(isqrt(n)): 51 | # with torch.device(device): 52 | # sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) 53 | # sh_rotated = einsum( 54 | # sh_rotations, 55 | # sh_coefficients[..., degree**2 : (degree + 1) ** 2], 56 | # "... i j, ... j -> ... i", 57 | # ) 58 | # result.append(sh_rotated) 59 | # 60 | # return torch.cat(result, dim=-1) 61 | 62 | 63 | if __name__ == "__main__": 64 | from pathlib import Path 65 | 66 | import matplotlib.pyplot as plt 67 | from e3nn.o3 import spherical_harmonics 68 | from matplotlib import cm 69 | from scipy.spatial.transform.rotation import Rotation as R 70 | 71 | device = torch.device("cuda") 72 | 73 | # Generate random spherical harmonics coefficients. 74 | degree = 4 75 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 76 | 77 | def plot_sh(sh_coefficients, path: Path) -> None: 78 | phi = torch.linspace(0, torch.pi, 100, device=device) 79 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 80 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 81 | x = torch.sin(phi) * torch.cos(theta) 82 | y = torch.sin(phi) * torch.sin(theta) 83 | z = torch.cos(phi) 84 | xyz = torch.stack([x, y, z], dim=-1) 85 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 86 | result = einsum(sh, sh_coefficients, "... 
n, n -> ...") 87 | result = (result - result.min()) / (result.max() - result.min()) 88 | 89 | # Set the aspect ratio to 1 so our sphere looks spherical 90 | fig = plt.figure(figsize=plt.figaspect(1.0)) 91 | ax = fig.add_subplot(111, projection="3d") 92 | ax.plot_surface( 93 | x.cpu().numpy(), 94 | y.cpu().numpy(), 95 | z.cpu().numpy(), 96 | rstride=1, 97 | cstride=1, 98 | facecolors=cm.seismic(result.cpu().numpy()), 99 | ) 100 | # Turn off the axis planes 101 | ax.set_axis_off() 102 | path.parent.mkdir(exist_ok=True, parents=True) 103 | plt.savefig(path) 104 | 105 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 106 | rotation = torch.tensor( 107 | R.from_euler("x", angle.item()).as_matrix(), device=device 108 | ) 109 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 110 | 111 | print("Done!") 112 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.visualization.color_map import apply_color_map_to_image 4 | 5 | 6 | def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): 7 | mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 8 | std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1) 9 | return tensor.mul(std).add(mean) 10 | 11 | 12 | # Color-map the result. 
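# vis_depth_map picks near/far from roughly the 1st/99th depth percentiles in log space, normalizes the log-depth into [0, 1], and maps it through the "turbo" colormap.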
13 | def vis_depth_map(result): 14 | far = result.view(-1)[:16_000_000].quantile(0.99).log() 15 | try: 16 | near = result[result > 0][:16_000_000].quantile(0.01).log() 17 | except: 18 | print("No valid depth values found.") 19 | near = torch.zeros_like(far) 20 | result = result.log() 21 | result = 1 - (result - near) / (far - near) 22 | return apply_color_map_to_image(result, "turbo") 23 | 24 | 25 | def confidence_map(result): 26 | # far = result.view(-1)[:16_000_000].quantile(0.99).log() 27 | # try: 28 | # near = result[result > 0][:16_000_000].quantile(0.01).log() 29 | # except: 30 | # print("No valid depth values found.") 31 | # near = torch.zeros_like(far) 32 | # result = result.log() 33 | # result = 1 - (result - near) / (far - near) 34 | result = result / result.view(-1).max() 35 | return apply_color_map_to_image(result, "magma") 36 | 37 | 38 | def get_overlap_tag(overlap): 39 | if 0.05 <= overlap <= 0.3: 40 | overlap_tag = "small" 41 | elif overlap <= 0.55: 42 | overlap_tag = "medium" 43 | elif overlap <= 0.8: 44 | overlap_tag = "large" 45 | else: 46 | overlap_tag = "ignore" 47 | 48 | return overlap_tag 49 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 
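# The chosen artifact is downloaded into download_dir / run_id, and the path to its model.ckpt inside that folder is returned.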
36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import Decoder 2 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 3 | 4 | DECODERS = { 5 | "splatting_cuda": DecoderSplattingCUDA, 6 | } 7 | 8 | DecoderCfg = DecoderSplattingCUDACfg 9 | 10 | 11 | def get_decoder(decoder_cfg: DecoderCfg) -> Decoder: 12 | return DECODERS[decoder_cfg.name](decoder_cfg) 13 | -------------------------------------------------------------------------------- /src/model/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..types import Gaussians 9 | 10 | DepthRenderingMode = Literal[ 11 | "depth", 12 | "log", 13 | "disparity", 14 | "relative_disparity", 15 | ] 16 | 17 | 18 | @dataclass 19 | class DecoderOutput: 20 | color: Float[Tensor, "batch view 3 height width"] 21 | depth: Float[Tensor, "batch view height width"] | None 22 | 23 | 24 | T = TypeVar("T") 25 | 26 | 27 | class Decoder(nn.Module, ABC, Generic[T]): 28 | cfg: T 29 | 30 | def __init__(self, cfg: T) -> None: 31 | super().__init__() 32 | self.cfg = cfg 33 | 34 | @abstractmethod 35 | def forward( 36 | self, 37 | gaussians: Gaussians, 38 | extrinsics: Float[Tensor, "batch view 4 4"], 39 | intrinsics: Float[Tensor, "batch view 3 3"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | image_shape: tuple[int, int], 43 | depth_mode: DepthRenderingMode | None = None, 44 | ) -> DecoderOutput: 45 | pass 46 | -------------------------------------------------------------------------------- /src/model/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | from .cuda_splatting import DepthRenderingMode, render_cuda 12 | from .decoder import Decoder, DecoderOutput 13 | 14 | 15 | @dataclass 16 | class DecoderSplattingCUDACfg: 17 | name: Literal["splatting_cuda"] 18 | background_color: list[float] 19 | make_scale_invariant: bool 20 | 21 | 22 | class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 23 | background_color: Float[Tensor, "3"] 24 | 25 | def __init__( 26 | self, 27 | cfg: DecoderSplattingCUDACfg, 28 | ) -> None: 29 | 
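# The background color is kept as a non-persistent buffer: it moves with the module across devices but is not written to checkpoints.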
super().__init__(cfg) 30 | self.make_scale_invariant = cfg.make_scale_invariant 31 | self.register_buffer( 32 | "background_color", 33 | torch.tensor(cfg.background_color, dtype=torch.float32), 34 | persistent=False, 35 | ) 36 | 37 | def forward( 38 | self, 39 | gaussians: Gaussians, 40 | extrinsics: Float[Tensor, "batch view 4 4"], 41 | intrinsics: Float[Tensor, "batch view 3 3"], 42 | near: Float[Tensor, "batch view"], 43 | far: Float[Tensor, "batch view"], 44 | image_shape: tuple[int, int], 45 | depth_mode: DepthRenderingMode | None = None, 46 | cam_rot_delta: Float[Tensor, "batch view 3"] | None = None, 47 | cam_trans_delta: Float[Tensor, "batch view 3"] | None = None, 48 | ) -> DecoderOutput: 49 | b, v, _, _ = extrinsics.shape 50 | color, depth = render_cuda( 51 | rearrange(extrinsics, "b v i j -> (b v) i j"), 52 | rearrange(intrinsics, "b v i j -> (b v) i j"), 53 | rearrange(near, "b v -> (b v)"), 54 | rearrange(far, "b v -> (b v)"), 55 | image_shape, 56 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 57 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 58 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 59 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 60 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 61 | scale_invariant=self.make_scale_invariant, 62 | cam_rot_delta=rearrange(cam_rot_delta, "b v i -> (b v) i") if cam_rot_delta is not None else None, 63 | cam_trans_delta=rearrange(cam_trans_delta, "b v i -> (b v) i") if cam_trans_delta is not None else None, 64 | ) 65 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 66 | 67 | depth = rearrange(depth, "(b v) h w -> b v h w", b=b, v=v) 68 | return DecoderOutput(color, depth) 69 | -------------------------------------------------------------------------------- /src/model/distiller/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .dust3d_backbone import Dust3R 4 | 5 | 6 | inf = float('inf') 7 | 8 | 9 | def get_distiller(name): 10 | assert name == 'dust3r' or name == 'mast3r', f"unexpected name={name}" 11 | distiller = Dust3R(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf)) 12 | distiller = distiller.eval() 13 | 14 | if name == 'dust3r': 15 | weight_path = './pretrained_weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth' 16 | elif name == 'mast3r': 17 | weight_path = './pretrained_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth' 18 | else: 19 | raise NotImplementedError(f"unexpected {name=}") 20 | ckpt_weights = torch.load(weight_path, map_location='cpu')['model'] 21 | missing_keys, unexpected_keys = distiller.load_state_dict(ckpt_weights, strict=False if name == 'mast3r' else True) 22 | 23 | return distiller 24 | -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_noposplat import EncoderNoPoSplatCfg, EncoderNoPoSplat 5 | from .encoder_noposplat_multi import EncoderNoPoSplatMulti 6 | from .visualization.encoder_visualizer import EncoderVisualizer 7 | 8 | ENCODERS = { 9 | "noposplat": (EncoderNoPoSplat, None), 10 | 
"noposplat_multi": (EncoderNoPoSplatMulti, None), 11 | } 12 | 13 | EncoderCfg = EncoderNoPoSplatCfg 14 | 15 | 16 | def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]: 17 | encoder, visualizer = ENCODERS[cfg.name] 18 | encoder = encoder(cfg) 19 | if visualizer is not None: 20 | visualizer = visualizer(cfg.visualizer, encoder) 21 | return encoder, visualizer 22 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch.nn as nn 3 | 4 | from .backbone import Backbone 5 | from .backbone_croco_multiview import AsymmetricCroCoMulti 6 | from .backbone_dino import BackboneDino, BackboneDinoCfg 7 | from .backbone_resnet import BackboneResnet, BackboneResnetCfg 8 | from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg 9 | 10 | BACKBONES: dict[str, Backbone[Any]] = { 11 | "resnet": BackboneResnet, 12 | "dino": BackboneDino, 13 | "croco": AsymmetricCroCo, 14 | "croco_multi": AsymmetricCroCoMulti, 15 | } 16 | 17 | BackboneCfg = BackboneResnetCfg | BackboneDinoCfg | BackboneCrocoCfg 18 | 19 | 20 | def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module: 21 | return BACKBONES[cfg.name](cfg, d_in) 22 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor, nn 6 | 7 | from ....dataset.types import BatchedViews 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Backbone(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | ) -> Float[Tensor, "batch view d_out height width"]: 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def d_out(self) -> int: 29 | pass 30 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone_dino.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor, nn 8 | 9 | from ....dataset.types import BatchedViews 10 | from .backbone import Backbone 11 | from .backbone_resnet import BackboneResnet, BackboneResnetCfg 12 | 13 | 14 | @dataclass 15 | class BackboneDinoCfg: 16 | name: Literal["dino"] 17 | model: Literal["dino_vits16", "dino_vits8", "dino_vitb16", "dino_vitb8"] 18 | d_out: int 19 | 20 | 21 | class BackboneDino(Backbone[BackboneDinoCfg]): 22 | def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None: 23 | super().__init__(cfg) 24 | assert d_in == 3 25 | self.dino = torch.hub.load("facebookresearch/dino:main", cfg.model) 26 | self.resnet_backbone = BackboneResnet( 27 | BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out), 28 | d_in, 29 | ) 30 | self.global_token_mlp = nn.Sequential( 31 | nn.Linear(768, 768), 32 | nn.ReLU(), 33 | nn.Linear(768, cfg.d_out), 34 | ) 35 | self.local_token_mlp = nn.Sequential( 36 | nn.Linear(768, 768), 37 | nn.ReLU(), 38 | nn.Linear(768, cfg.d_out), 39 | ) 40 | 41 | def forward( 
42 | self, 43 | context: BatchedViews, 44 | ) -> Float[Tensor, "batch view d_out height width"]: 45 | # Compute features from the DINO-pretrained resnet50. 46 | resnet_features = self.resnet_backbone(context) 47 | 48 | # Compute features from the DINO-pretrained ViT. 49 | b, v, _, h, w = context["image"].shape 50 | assert h % self.patch_size == 0 and w % self.patch_size == 0 51 | tokens = rearrange(context["image"], "b v c h w -> (b v) c h w") 52 | tokens = self.dino.get_intermediate_layers(tokens)[0] 53 | global_token = self.global_token_mlp(tokens[:, 0]) 54 | local_tokens = self.local_token_mlp(tokens[:, 1:]) 55 | 56 | # Repeat the global token to match the image shape. 57 | global_token = repeat(global_token, "(b v) c -> b v c h w", b=b, v=v, h=h, w=w) 58 | 59 | # Repeat the local tokens to match the image shape. 60 | local_tokens = repeat( 61 | local_tokens, 62 | "(b v) (h w) c -> b v c (h hps) (w wps)", 63 | b=b, 64 | v=v, 65 | h=h // self.patch_size, 66 | hps=self.patch_size, 67 | w=w // self.patch_size, 68 | wps=self.patch_size, 69 | ) 70 | 71 | return resnet_features + local_tokens + global_token 72 | 73 | @property 74 | def patch_size(self) -> int: 75 | return int("".join(filter(str.isdigit, self.cfg.model))) 76 | 77 | @property 78 | def d_out(self) -> int: 79 | return self.cfg.d_out 80 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/backbone_resnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Literal 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from einops import rearrange 9 | from jaxtyping import Float 10 | from torch import Tensor, nn 11 | from torchvision.models import ResNet 12 | 13 | from ....dataset.types import BatchedViews 14 | from .backbone import Backbone 15 | 16 | 17 | @dataclass 18 | class BackboneResnetCfg: 19 | name: Literal["resnet"] 20 | model: Literal[ 21 | "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50" 22 | ] 23 | num_layers: int 24 | use_first_pool: bool 25 | d_out: int 26 | 27 | 28 | class BackboneResnet(Backbone[BackboneResnetCfg]): 29 | model: ResNet 30 | 31 | def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None: 32 | super().__init__(cfg) 33 | 34 | assert d_in == 3 35 | 36 | norm_layer = functools.partial( 37 | nn.InstanceNorm2d, 38 | affine=False, 39 | track_running_stats=False, 40 | ) 41 | 42 | if cfg.model == "dino_resnet50": 43 | self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50") 44 | else: 45 | self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer) 46 | 47 | # Set up projections 48 | self.projections = nn.ModuleDict({}) 49 | for index in range(1, cfg.num_layers): 50 | key = f"layer{index}" 51 | block = getattr(self.model, key) 52 | conv_index = 1 53 | try: 54 | while True: 55 | d_layer_out = getattr(block[-1], f"conv{conv_index}").out_channels 56 | conv_index += 1 57 | except AttributeError: 58 | pass 59 | self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1) 60 | 61 | # Add a projection for the first layer. 62 | self.projections["layer0"] = nn.Conv2d( 63 | self.model.conv1.out_channels, cfg.d_out, 1 64 | ) 65 | 66 | def forward( 67 | self, 68 | context: BatchedViews, 69 | ) -> Float[Tensor, "batch view d_out height width"]: 70 | # Merge the batch dimensions. 
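# Batch and view dimensions are flattened so every view goes through the ResNet as one large batch; they are separated again before returning.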
71 | b, v, _, h, w = context["image"].shape 72 | x = rearrange(context["image"], "b v c h w -> (b v) c h w") 73 | 74 | # Run the images through the resnet. 75 | x = self.model.conv1(x) 76 | x = self.model.bn1(x) 77 | x = self.model.relu(x) 78 | features = [self.projections["layer0"](x)] 79 | 80 | # Propagate the input through the resnet's layers. 81 | for index in range(1, self.cfg.num_layers): 82 | key = f"layer{index}" 83 | if index == 0 and self.cfg.use_first_pool: 84 | x = self.model.maxpool(x) 85 | x = getattr(self.model, key)(x) 86 | features.append(self.projections[key](x)) 87 | 88 | # Upscale the features. 89 | features = [ 90 | F.interpolate(f, (h, w), mode="bilinear", align_corners=True) 91 | for f in features 92 | ] 93 | features = torch.stack(features).sum(dim=0) 94 | 95 | # Separate batch dimensions. 96 | return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v) 97 | 98 | @property 99 | def d_out(self) -> int: 100 | return self.cfg.d_out 101 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/README.md: -------------------------------------------------------------------------------- 1 | Most of the code under src/model/encoder/backbone/croco/ is from the original CROCO implementation. 2 | The code is not modified in any way except the relative module path. 3 | The original code can be found at [croco Github Repo](https://github.com/naver/croco/tree/743ee71a2a9bf57cea6832a9064a70a0597fcfcb/models). 4 | 5 | 6 | Except: 7 | - 'misc.py', 'patch_embed.py' is from DUSt3R. -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/NoPoSplat/9b04307ebe179d610c04208db8d69f7c3106d03b/src/model/encoder/backbone/croco/__init__.py -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .curope2d import cuRoPE2D 5 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
4 | */ 5 | 6 | #include <torch/extension.h> 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor<float, 4>(); 19 | auto pos = positions.accessor<int64_t, 3>(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from .
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 40 | return tokens -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include <torch/extension.h> 7 | #include <cuda.h> 8 | #include <cuda_runtime.h> 9 | #include <vector> 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ?
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> ( 103 | //tokens.data_ptr<scalar_t>(), 104 | tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(), 105 | pos.data_ptr<int64_t>(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) 35 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/masking.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | 5 | # -------------------------------------------------------- 6 | # Masking utils 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | class RandomMask(nn.Module): 13 | """ 14 | random masking 15 | """ 16 | 17 | def __init__(self, num_patches, mask_ratio): 18 | super().__init__() 19 | self.num_patches = num_patches 20 | self.num_mask = int(mask_ratio * self.num_patches) 21 | 22 | def __call__(self, x): 23 | noise = torch.rand(x.size(0), self.num_patches, device=x.device) 24 | argsort = torch.argsort(noise, dim=1) 25 | return argsort < self.num_mask 26 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def fill_default_args(kwargs, func): 11 | import inspect # a bit hacky but it works reliably 12 | signature = inspect.signature(func) 13 | 14 | for k, v in signature.parameters.items(): 15 | if v.default is inspect.Parameter.empty: 16 | continue 17 | kwargs.setdefault(k, v.default) 18 | 19 | return kwargs 20 | 21 | 22 | def freeze_all_params(modules): 23 | for module in modules: 24 | try: 25 | for n, param in module.named_parameters(): 26 | param.requires_grad = False 27 | except AttributeError: 28 | # module is directly a parameter 29 | module.requires_grad = False 30 | 31 | 32 | def is_symmetrized(gt1, gt2): 33 | x = gt1['instance'] 34 | y = gt2['instance'] 35 | if len(x) == len(y) and len(x) == 1: 36 | return False # special case of batchsize 1 37 | ok = True 38 | for i in range(0, len(x), 2): 39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i]) 40 | return ok 41 | 42 | 43 | def flip(tensor): 44 | """ flip so that tensor[0::2] <=> tensor[1::2] """ 45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) 46 | 47 | 48 | def interleave(tensor1, tensor2): 49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) 50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) 51 | return res1, res2 52 | 53 | 54 | def _interleave_imgs(img1, img2): 55 | res = {} 56 | for key, value1 in img1.items(): 57 | value2 = img2[key] 58 | if isinstance(value1, torch.Tensor): 59 | value = torch.stack((value1, value2), dim=1).flatten(0, 1) 60 | else: 61 | value = [x for pair in zip(value1, value2) for x in pair] 62 | res[key] = value 63 | return res 64 | 65 | 66 | def make_batch_symmetric(view1, view2): 67 | view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) 68 | return view1, view2 69 | 70 | 71 | def transpose_to_landscape(head, activate=True): 72 | """ Predict in the correct aspect-ratio, 73 | then transpose the result in landscape 74 | and stack everything back together. 
75 | """ 76 | def wrapper_no(decout, true_shape, ray_embedding=None): 77 | B = len(true_shape) 78 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' 79 | H, W = true_shape[0].cpu().tolist() 80 | res = head(decout, (H, W), ray_embedding=ray_embedding) 81 | return res 82 | 83 | def wrapper_yes(decout, true_shape, ray_embedding=None): 84 | B = len(true_shape) 85 | # by definition, the batch is in landscape mode so W >= H 86 | H, W = int(true_shape.min()), int(true_shape.max()) 87 | 88 | height, width = true_shape.T 89 | is_landscape = (width >= height) 90 | is_portrait = ~is_landscape 91 | 92 | # true_shape = true_shape.cpu() 93 | if is_landscape.all(): 94 | return head(decout, (H, W), ray_embedding=ray_embedding) 95 | if is_portrait.all(): 96 | return transposed(head(decout, (W, H), ray_embedding=ray_embedding)) 97 | 98 | # batch is a mix of both portraint & landscape 99 | def selout(ar): return [d[ar] for d in decout] 100 | l_result = head(selout(is_landscape), (H, W), ray_embedding=ray_embedding) 101 | p_result = transposed(head(selout(is_portrait), (W, H), ray_embedding=ray_embedding)) 102 | 103 | # allocate full result 104 | result = {} 105 | for k in l_result | p_result: 106 | x = l_result[k].new(B, *l_result[k].shape[1:]) 107 | x[is_landscape] = l_result[k] 108 | x[is_portrait] = p_result[k] 109 | result[k] = x 110 | 111 | return result 112 | 113 | return wrapper_yes if activate else wrapper_no 114 | 115 | 116 | def transposed(dic): 117 | return {k: v.swapaxes(1, 2) for k, v in dic.items()} 118 | 119 | 120 | def invalid_to_nans(arr, valid_mask, ndim=999): 121 | if valid_mask is not None: 122 | arr = arr.clone() 123 | arr[~valid_mask] = float('nan') 124 | if arr.ndim > ndim: 125 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 126 | return arr 127 | 128 | 129 | def invalid_to_zeros(arr, valid_mask, ndim=999): 130 | if valid_mask is not None: 131 | arr = arr.clone() 132 | arr[~valid_mask] = 0 133 | nnz = valid_mask.view(len(valid_mask), -1).sum(1) 134 | else: 135 | nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image 136 | if arr.ndim > ndim: 137 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 138 | return arr, nnz 139 | -------------------------------------------------------------------------------- /src/model/encoder/backbone/croco/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # PatchEmbed implementation for DUST3R, 6 | # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio 7 | # -------------------------------------------------------- 8 | import torch 9 | 10 | from .blocks import PatchEmbed 11 | 12 | 13 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3): 14 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] 15 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, in_chans, enc_embed_dim) 16 | return patch_embed 17 | 18 | 19 | class PatchEmbedDust3R(PatchEmbed): 20 | def forward(self, x, **kw): 21 | B, C, H, W = x.shape 22 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 23 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
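# The convolutional projection produces a grid of patch tokens, and position_getter returns each token's (y, x) grid index, which is used downstream as the 2D position for RoPE.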
24 | x = self.proj(x) 25 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 26 | if self.flatten: 27 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 28 | x = self.norm(x) 29 | return x, pos 30 | 31 | 32 | class ManyAR_PatchEmbed (PatchEmbed): 33 | """ Handle images with non-square aspect ratio. 34 | All images in the same batch have the same aspect ratio. 35 | true_shape = [(height, width) ...] indicates the actual shape of each image. 36 | """ 37 | 38 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 39 | self.embed_dim = embed_dim 40 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) 41 | 42 | def forward(self, img, true_shape): 43 | B, C, H, W = img.shape 44 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' 45 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 46 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 47 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" 48 | 49 | # size expressed in tokens 50 | W //= self.patch_size[0] 51 | H //= self.patch_size[1] 52 | n_tokens = H * W 53 | 54 | height, width = true_shape.T 55 | is_landscape = (width >= height) 56 | is_portrait = ~is_landscape 57 | 58 | # allocate result 59 | x = img.new_zeros((B, n_tokens, self.embed_dim)) 60 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) 61 | 62 | # linear projection, transposed if necessary 63 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() 64 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() 65 | 66 | pos[is_landscape] = self.position_getter(1, H, W, pos.device) 67 | pos[is_portrait] = self.position_getter(1, W, H, pos.device) 68 | 69 | x = self.norm(x) 70 | return x, pos 71 | -------------------------------------------------------------------------------- /src/model/encoder/common/gaussians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 8 | def quaternion_to_matrix( 9 | quaternions: Float[Tensor, "*batch 4"], 10 | eps: float = 1e-8, 11 | ) -> Float[Tensor, "*batch 3 3"]: 12 | # Order changed to match scipy format! 13 | i, j, k, r = torch.unbind(quaternions, dim=-1) 14 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 15 | 16 | o = torch.stack( 17 | ( 18 | 1 - two_s * (j * j + k * k), 19 | two_s * (i * j - k * r), 20 | two_s * (i * k + j * r), 21 | two_s * (i * j + k * r), 22 | 1 - two_s * (i * i + k * k), 23 | two_s * (j * k - i * r), 24 | two_s * (i * k - j * r), 25 | two_s * (j * k + i * r), 26 | 1 - two_s * (i * i + j * j), 27 | ), 28 | -1, 29 | ) 30 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 31 | 32 | 33 | def build_covariance( 34 | scale: Float[Tensor, "*#batch 3"], 35 | rotation_xyzw: Float[Tensor, "*#batch 4"], 36 | ) -> Float[Tensor, "*batch 3 3"]: 37 | scale = scale.diag_embed() 38 | rotation = quaternion_to_matrix(rotation_xyzw) 39 | return ( 40 | rotation 41 | @ scale 42 | @ rearrange(scale, "... i j -> ... j i") 43 | @ rearrange(rotation, "... i j -> ... 
j i") 44 | ) 45 | -------------------------------------------------------------------------------- /src/model/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import nn 5 | 6 | from ...dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | ) -> Gaussians: 24 | pass 25 | 26 | def get_data_shim(self) -> DataShim: 27 | """The default shim doesn't modify the batch.""" 28 | return lambda x: x 29 | -------------------------------------------------------------------------------- /src/model/encoder/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # head factory 6 | # -------------------------------------------------------- 7 | from .dpt_gs_head import create_gs_dpt_head 8 | from .linear_head import LinearPts3d 9 | from .dpt_head import create_dpt_head 10 | 11 | 12 | def head_factory(head_type, output_mode, net, has_conf=False, out_nchan=3): 13 | """" build a prediction head for the decoder 14 | """ 15 | if head_type == 'linear' and output_mode == 'pts3d': 16 | return LinearPts3d(net, has_conf) 17 | elif head_type == 'dpt' and output_mode == 'pts3d': 18 | return create_dpt_head(net, has_conf=has_conf) 19 | elif head_type == 'dpt' and output_mode == 'gs_params': 20 | return create_dpt_head(net, has_conf=False, out_nchan=out_nchan, postprocess_func=None) 21 | elif head_type == 'dpt_gs' and output_mode == 'gs_params': 22 | return create_gs_dpt_head(net, has_conf=False, out_nchan=out_nchan, postprocess_func=None) 23 | else: 24 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") 25 | -------------------------------------------------------------------------------- /src/model/encoder/heads/dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionnary img_info with key "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | # import dust3r.utils.path_to_croco 17 | from .dpt_block import DPTOutputAdapter 18 | from .postprocess import postprocess 19 | 20 | 21 | class DPTOutputAdapter_fix(DPTOutputAdapter): 22 | """ 23 | Adapt croco's DPTOutputAdapter implementation for dust3r: 24 | remove duplicated weigths, and fix forward for dust3r 25 | """ 26 | 27 | def init(self, dim_tokens_enc=768): 28 | super().init(dim_tokens_enc) 29 | # these are duplicated weights 30 | del self.act_1_postprocess 31 | del self.act_2_postprocess 32 | del self.act_3_postprocess 33 | del self.act_4_postprocess 34 | 35 | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None, ray_embedding=None): 36 | assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' 37 | # H, W = input_info['image_size'] 38 | image_size = self.image_size if image_size is None else image_size 39 | H, W = image_size 40 | # Number of patches in height and width 41 | N_H = H // (self.stride_level * self.P_H) 42 | N_W = W // (self.stride_level * self.P_W) 43 | 44 | # Hook decoder onto 4 layers from specified ViT layers 45 | layers = [encoder_tokens[hook] for hook in self.hooks] 46 | 47 | # Extract only task-relevant tokens and ignore global tokens. 
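# The hooked token sets are reshaped into 2D feature maps, projected to a common feature width, and fused coarse-to-fine by the refinenet stages before the output head.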
48 | layers = [self.adapt_tokens(l) for l in layers] 49 | 50 | # Reshape tokens to spatial representation 51 | layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] 52 | 53 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 54 | # Project layers to chosen feature dim 55 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 56 | 57 | # Fuse layers using refinement stages 58 | path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] 59 | path_3 = self.scratch.refinenet3(path_4, layers[2]) 60 | path_2 = self.scratch.refinenet2(path_3, layers[1]) 61 | path_1 = self.scratch.refinenet1(path_2, layers[0]) 62 | 63 | # if ray_embedding is not None: 64 | # ray_embedding = F.interpolate(ray_embedding, size=(path_1.shape[2], path_1.shape[3]), mode='bilinear') 65 | # path_1 = torch.cat([path_1, ray_embedding], dim=1) 66 | 67 | # Output head 68 | out = self.head(path_1) 69 | 70 | return out 71 | 72 | 73 | class PixelwiseTaskWithDPT(nn.Module): 74 | """ DPT module for dust3r, can return 3D points + confidence for all pixels""" 75 | 76 | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, 77 | output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): 78 | super(PixelwiseTaskWithDPT, self).__init__() 79 | self.return_all_layers = True # backbone needs to return all layers 80 | self.postprocess = postprocess 81 | self.depth_mode = depth_mode 82 | self.conf_mode = conf_mode 83 | 84 | assert n_cls_token == 0, "Not implemented" 85 | dpt_args = dict(output_width_ratio=output_width_ratio, 86 | num_channels=num_channels, 87 | **kwargs) 88 | if hooks_idx is not None: 89 | dpt_args.update(hooks=hooks_idx) 90 | self.dpt = DPTOutputAdapter_fix(**dpt_args) 91 | dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} 92 | self.dpt.init(**dpt_init_args) 93 | 94 | def forward(self, x, img_info, ray_embedding=None): 95 | out = self.dpt(x, image_size=(img_info[0], img_info[1]), ray_embedding=ray_embedding) 96 | if self.postprocess: 97 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 98 | return out 99 | 100 | 101 | def create_dpt_head(net, has_conf=False, out_nchan=3, postprocess_func=postprocess): 102 | """ 103 | return PixelwiseTaskWithDPT for given net params 104 | """ 105 | assert net.dec_depth > 9 106 | l2 = net.dec_depth 107 | feature_dim = 256 108 | last_dim = feature_dim//2 109 | ed = net.enc_embed_dim 110 | dd = net.dec_embed_dim 111 | return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, 112 | feature_dim=feature_dim, 113 | last_dim=last_dim, 114 | hooks_idx=[0, l2*2//4, l2*3//4, l2], 115 | dim_tokens=[ed, dd, dd, dd], 116 | postprocess=postprocess_func, 117 | depth_mode=net.depth_mode, 118 | conf_mode=net.conf_mode, 119 | head_type='regression') 120 | -------------------------------------------------------------------------------- /src/model/encoder/heads/head_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = 
planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not (stride == 1 and in_planes == planes): 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not (stride == 1 and in_planes == planes): 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not (stride == 1 and in_planes == planes): 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not (stride == 1 and in_planes == planes): 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1 and in_planes == planes: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.conv1(y) 50 | y = self.norm1(y) 51 | y = self.relu(y) 52 | y = self.conv2(y) 53 | y = self.norm2(y) 54 | y = self.relu(y) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x + y) 60 | 61 | 62 | class UnetExtractor(nn.Module): 63 | def __init__(self, in_channel=3, encoder_dim=[256, 256, 256], norm_fn='group'): 64 | super().__init__() 65 | self.in_ds = nn.Sequential( 66 | nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3), 67 | nn.GroupNorm(num_groups=8, num_channels=64), 68 | nn.ReLU(inplace=True) 69 | ) 70 | 71 | self.res1 = nn.Sequential( 72 | ResidualBlock(64, encoder_dim[0], stride=2, norm_fn=norm_fn), 73 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn) 74 | ) 75 | self.res2 = nn.Sequential( 76 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn), 77 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn) 78 | ) 79 | self.res3 = nn.Sequential( 80 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn), 81 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn), 82 | ) 83 | 84 | def forward(self, x): 85 | x = self.in_ds(x) 86 | x1 = self.res1(x) 87 | x2 = self.res2(x1) 88 | x3 = self.res3(x2) 89 | 90 | return x1, x2, x3 91 | 92 | 93 | class MultiBasicEncoder(nn.Module): 94 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]): 95 | super(MultiBasicEncoder, self).__init__() 96 | 97 | # output convolution for feature 98 | self.conv2 = nn.Sequential( 99 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 100 | nn.Conv2d(encoder_dim[2], encoder_dim[2] * 2, 3, padding=1)) 101 | 102 | # output convolution for context 103 | output_list = [] 104 | for dim in output_dim: 105 | conv_out = nn.Sequential( 106 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 107 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1)) 108 | output_list.append(conv_out) 109 | 110 | self.outputs08 = nn.ModuleList(output_list) 111 | 112 | def forward(self, x): 113 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2) 114 | 115 | outputs08 = [f(x) for f in self.outputs08] 116 | return outputs08, feat1, feat2 117 | 118 | 119 | if __name__ == '__main__': 120 | data = torch.ones((1, 3, 1024, 1024)) 121 | 122 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128]) 
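# Quick smoke test: each stage halves the resolution, so for a 1024x1024 input the printed shapes should come out as (1, 64, 256, 256), (1, 96, 128, 128) and (1, 128, 64, 64).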
123 | 124 | x1, x2, x3 = model(data) 125 | print(x1.shape, x2.shape, x3.shape) 126 | -------------------------------------------------------------------------------- /src/model/encoder/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # linear head implementation for DUST3R 6 | # -------------------------------------------------------- 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .postprocess import postprocess 10 | 11 | 12 | class LinearPts3d(nn.Module): 13 | """ 14 | Linear head for dust3r 15 | Each token outputs: - 16x16 3D points (+ confidence) 16 | """ 17 | 18 | def __init__(self, net, has_conf=False): 19 | super().__init__() 20 | self.patch_size = net.patch_embed.patch_size[0] 21 | self.depth_mode = net.depth_mode 22 | self.conf_mode = net.conf_mode 23 | self.has_conf = has_conf 24 | 25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) 26 | 27 | def setup(self, croconet): 28 | pass 29 | 30 | def forward(self, decout, img_shape): 31 | H, W = img_shape 32 | tokens = decout[-1] 33 | B, S, D = tokens.shape 34 | 35 | # extract 3D points 36 | feat = self.proj(tokens) # B, S, (3 + has_conf) * patch_size**2 37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 38 | feat = F.pixel_shuffle(feat, self.patch_size) # B, 3 (+ conf), H, W 39 | 40 | # permute + norm depth 41 | return postprocess(feat, self.depth_mode, self.conf_mode) 42 | 43 | 44 | class LinearGS(nn.Module): 45 | """ 46 | Linear head for GS parameter prediction 47 | Each token outputs: - 16x16 Gaussian parameter sets (2 values for the xy offset, 1 for opacity, plus the Gaussian adapter's raw inputs per pixel) 48 | """ 49 | 50 | def __init__(self, net, has_conf=False): 51 | super().__init__() 52 | self.patch_size = net.patch_embed.patch_size[0] 53 | self.depth_mode = net.depth_mode 54 | self.conf_mode = net.conf_mode 55 | self.has_conf = has_conf 56 | 57 | self.proj = nn.Linear(net.dec_embed_dim, (2 + 1 + net.gaussian_adapter.d_in)*self.patch_size**2) # 2 for xy offset, 1 for opacity 58 | 59 | def setup(self, croconet): 60 | pass 61 | 62 | def forward(self, decout, img_shape): 63 | H, W = img_shape 64 | tokens = decout[-1] 65 | B, S, D = tokens.shape 66 | 67 | # extract Gaussian parameters 68 | feat = self.proj(tokens) # B, S, (2 + 1 + d_in) * patch_size**2 69 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 70 | feat = F.pixel_shuffle(feat, self.patch_size) # B, C, H, W 71 | 72 | # permute + norm depth 73 | return postprocess(feat, self.depth_mode, self.conf_mode) 74 | -------------------------------------------------------------------------------- /src/model/encoder/heads/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | # 4 | # -------------------------------------------------------- 5 | # post process function for all heads: extract 3D points/confidence from output 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def postprocess(out, depth_mode, conf_mode): 11 | """ 12 | extract 3D points/confidence from prediction head output 13 | """ 14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3 15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) 16 | 17 | if conf_mode is not None: 18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) 19 | return res 20 | 21 | 22 | def reg_dense_depth(xyz, mode): 23 | """ 24 | extract 3D points from prediction head output 25 | """ 26 | mode, vmin, vmax = mode 27 | 28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) 29 | # assert no_bounds 30 | 31 | if mode == 'range': 32 | xyz = xyz.sigmoid() 33 | xyz = (1 - xyz) * vmin + xyz * vmax 34 | return xyz 35 | 36 | if mode == 'linear': 37 | if no_bounds: 38 | return xyz # [-inf, +inf] 39 | return xyz.clip(min=vmin, max=vmax) 40 | 41 | if mode == 'exp_direct': 42 | xyz = xyz.expm1() 43 | return xyz.clip(min=vmin, max=vmax) 44 | 45 | # distance to origin 46 | d = xyz.norm(dim=-1, keepdim=True) 47 | xyz = xyz / d.clip(min=1e-8) 48 | 49 | if mode == 'square': 50 | return xyz * d.square() 51 | 52 | if mode == 'exp': 53 | exp_d = d.expm1() 54 | if not no_bounds: 55 | exp_d = exp_d.clip(min=vmin, max=vmax) 56 | xyz = xyz * exp_d 57 | # if not no_bounds: 58 | # # xyz = xyz.clip(min=vmin, max=vmax) 59 | # depth = xyz.clone()[..., 2].clip(min=vmin, max=vmax) 60 | # xyz = torch.cat([xyz[..., :2], depth.unsqueeze(-1)], dim=-1) 61 | return xyz 62 | 63 | raise ValueError(f'bad {mode=}') 64 | 65 | 66 | def reg_dense_conf(x, mode): 67 | """ 68 | extract confidence from prediction head output 69 | """ 70 | mode, vmin, vmax = mode 71 | if mode == 'opacity': 72 | return x.sigmoid() 73 | if mode == 'exp': 74 | return vmin + x.exp().clip(max=vmax-vmin) 75 | if mode == 'sigmoid': 76 | return (vmax - vmin) * torch.sigmoid(x) + vmin 77 | raise ValueError(f'bad {mode=}') 78 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | T_cfg = TypeVar("T_cfg") 8 | T_encoder = TypeVar("T_encoder") 9 | 10 | 11 | class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]): 12 | cfg: T_cfg 13 | encoder: T_encoder 14 | 15 | def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None: 16 | self.cfg = cfg 17 | self.encoder = encoder 18 | 19 | @abstractmethod 20 | def visualize( 21 | self, 22 | context: dict, 23 | global_step: int, 24 | ) -> dict[str, Float[Tensor, "3 _ _"]]: 25 | pass 26 | -------------------------------------------------------------------------------- /src/model/encoder/visualization/encoder_visualizer_epipolar_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # This is in a separate file to avoid circular imports. 
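# A hypothetical instantiation for reference (the field values are illustrative only, not taken
# from the repo's configs):
#   EncoderVisualizerEpipolarCfg(num_samples=8, min_resolution=256, export_ply=False)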
4 | 5 | 6 | @dataclass 7 | class EncoderVisualizerEpipolarCfg: 8 | num_samples: int 9 | min_resolution: int 10 | export_ply: bool 11 | -------------------------------------------------------------------------------- /src/model/encodings/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum, rearrange, repeat 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """For the sake of simplicity, this encodes values in the range [0, 1].""" 10 | 11 | frequencies: Float[Tensor, "frequency phase"] 12 | phases: Float[Tensor, "frequency phase"] 13 | 14 | def __init__(self, num_octaves: int): 15 | super().__init__() 16 | octaves = torch.arange(num_octaves).float() 17 | 18 | # The lowest frequency has a period of 1. 19 | frequencies = 2 * torch.pi * 2**octaves 20 | frequencies = repeat(frequencies, "f -> f p", p=2) 21 | self.register_buffer("frequencies", frequencies, persistent=False) 22 | 23 | # Choose the phases to match sine and cosine. 24 | phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32) 25 | phases = repeat(phases, "p -> f p", f=num_octaves) 26 | self.register_buffer("phases", phases, persistent=False) 27 | 28 | def forward( 29 | self, 30 | samples: Float[Tensor, "*batch dim"], 31 | ) -> Float[Tensor, "*batch embedded_dim"]: 32 | samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p") 33 | return rearrange(torch.sin(samples + self.phases), "... d f p -> ... (d f p)") 34 | 35 | def d_out(self, dimensionality: int): 36 | return self.frequencies.numel() * dimensionality 37 | -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | means: Float[Tensor, "gaussian 3"], 28 | scales: Float[Tensor, "gaussian 3"], 29 | rotations: Float[Tensor, "gaussian 4"], 30 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 31 | opacities: Float[Tensor, " gaussian"], 32 | path: Path, 33 | shift_and_scale: bool = False, 34 | save_sh_dc_only: bool = True, 35 | ): 36 | if shift_and_scale: 37 | # Shift the scene so that the median Gaussian is at the origin. 38 | means = means - means.median(dim=0).values 39 | 40 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 41 | scale_factor = means.abs().quantile(0.95, dim=0).max() 42 | means = means / scale_factor 43 | scales = scales / scale_factor 44 | 45 | # Apply the rotation to the Gaussian rotations. 
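# The round trip through scipy below normalizes the quaternions, and the reordering converts
# scipy's (x, y, z, w) convention into the (w, x, y, z) layout that standard Gaussian-splatting
# PLY viewers expect for rot_0..rot_3.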
46 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 47 | rotations = R.from_matrix(rotations).as_quat() 48 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 49 | rotations = np.stack((w, x, y, z), axis=-1) 50 | 51 | # Since the current model uses SH_degree = 4, which is expensive to store, 52 | # we can save only the DC band to reduce memory and file size. 53 | f_dc = harmonics[..., 0] 54 | f_rest = harmonics[..., 1:].flatten(start_dim=1) 55 | 56 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0 if save_sh_dc_only else f_rest.shape[1])] 57 | elements = np.empty(means.shape[0], dtype=dtype_full) 58 | attributes = [ 59 | means.detach().cpu().numpy(), 60 | torch.zeros_like(means).detach().cpu().numpy(), 61 | f_dc.detach().cpu().contiguous().numpy(), 62 | f_rest.detach().cpu().contiguous().numpy(), 63 | opacities[..., None].detach().cpu().numpy(), 64 | scales.log().detach().cpu().numpy(), 65 | rotations, 66 | ] 67 | if save_sh_dc_only: 68 | # remove f_rest from attributes 69 | attributes.pop(3) 70 | 71 | attributes = np.concatenate(attributes, axis=1) 72 | elements[:] = list(map(tuple, attributes)) 73 | path.parent.mkdir(exist_ok=True, parents=True) 74 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 75 | -------------------------------------------------------------------------------- /src/model/transformer/attention.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE.
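# A minimal usage sketch for the Attention module defined below (the dimensions are
# illustrative, not taken from this repo's configs):
#   attn = Attention(dim=128, heads=4, dim_head=32)
#   out = attn(torch.randn(2, 16, 128))  # self-attention, output shape (2, 16, 128)
#   cross = Attention(dim=128, heads=4, dim_head=32, selfatt=False, kv_dim=64)
#   out = cross(torch.randn(2, 16, 128), z=torch.randn(2, 10, 64))  # output shape (2, 16, 128)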
22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | import torch 26 | from einops import rearrange 27 | from torch import nn 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__( 32 | self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None 33 | ): 34 | super().__init__() 35 | inner_dim = dim_head * heads 36 | project_out = not (heads == 1 and dim_head == dim) 37 | 38 | self.heads = heads 39 | self.scale = dim_head**-0.5 40 | 41 | self.attend = nn.Softmax(dim=-1) 42 | if selfatt: 43 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 44 | else: 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) 47 | 48 | self.to_out = ( 49 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 50 | if project_out 51 | else nn.Identity() 52 | ) 53 | 54 | def forward(self, x, z=None): 55 | if z is None: 56 | qkv = self.to_qkv(x).chunk(3, dim=-1) 57 | else: 58 | q = self.to_q(x) 59 | k, v = self.to_kv(z).chunk(2, dim=-1) 60 | qkv = (q, k, v) 61 | 62 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) 63 | 64 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 65 | 66 | attn = self.attend(dots) 67 | 68 | out = torch.matmul(attn, v) 69 | out = rearrange(out, "b h n d -> b n (h d)") 70 | return self.to_out(out) 71 | -------------------------------------------------------------------------------- /src/model/transformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
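# Usage sketch for the FeedForward block defined below (sizes are illustrative):
#   ff = FeedForward(dim=128, hidden_dim=256, dropout=0.1)
#   y = ff(torch.randn(2, 16, 128))  # shape is preserved: (2, 16, 128)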
22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout=0.0): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Linear(hidden_dim, dim), 36 | nn.Dropout(dropout), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | -------------------------------------------------------------------------------- /src/model/transformer/pre_norm.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class PreNorm(nn.Module): 29 | def __init__(self, dim, fn): 30 | super().__init__() 31 | self.norm = nn.LayerNorm(dim) 32 | self.fn = fn 33 | 34 | def forward(self, x, **kwargs): 35 | return self.fn(self.norm(x), **kwargs) 36 | -------------------------------------------------------------------------------- /src/model/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | from .attention import Attention 28 | from .feed_forward import FeedForward 29 | from .pre_norm import PreNorm 30 | 31 | 32 | class Transformer(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | depth, 37 | heads, 38 | dim_head, 39 | mlp_dim, 40 | dropout=0.0, 41 | selfatt=True, 42 | kv_dim=None, 43 | feed_forward_layer=FeedForward, 44 | ): 45 | super().__init__() 46 | self.layers = nn.ModuleList([]) 47 | for _ in range(depth): 48 | self.layers.append( 49 | nn.ModuleList( 50 | [ 51 | PreNorm( 52 | dim, 53 | Attention( 54 | dim, 55 | heads=heads, 56 | dim_head=dim_head, 57 | dropout=dropout, 58 | selfatt=selfatt, 59 | kv_dim=kv_dim, 60 | ), 61 | ), 62 | PreNorm(dim, feed_forward_layer(dim, mlp_dim, dropout=dropout)), 63 | ] 64 | ) 65 | ) 66 | 67 | def forward(self, x, z=None, **kwargs): 68 | for attn, ff in self.layers: 69 | x = attn(x, z=z) + x 70 | x = ff(x, **kwargs) + x 71 | return x 72 | -------------------------------------------------------------------------------- /src/model/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] 13 | -------------------------------------------------------------------------------- /src/scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | 5 | import hydra 6 | import torch 7 | from jaxtyping import install_import_hook 8 | from lightning.pytorch import Trainer 9 | from omegaconf import DictConfig 10 | 11 | # Configure beartype and jaxtyping. 
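# The import hook wraps everything imported from the `src` package with beartype, so the
# jaxtyping shape annotations used throughout the code are checked at runtime.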
12 | with install_import_hook( 13 | ("src",), 14 | ("beartype", "beartype"), 15 | ): 16 | from src.config import load_typed_config, separate_dataset_cfg_wrappers 17 | from src.dataset.data_module import DataLoaderCfg, DataModule, DatasetCfgWrapper 18 | from src.evaluation.evaluation_cfg import EvaluationCfg 19 | from src.evaluation.metric_computer import MetricComputer 20 | from src.global_cfg import set_cfg 21 | 22 | 23 | @dataclass 24 | class RootCfg: 25 | evaluation: EvaluationCfg 26 | dataset: list[DatasetCfgWrapper] 27 | data_loader: DataLoaderCfg 28 | seed: int 29 | 30 | 31 | @hydra.main( 32 | version_base=None, 33 | config_path="../../config", 34 | config_name="compute_metrics", 35 | ) 36 | def evaluate(cfg_dict: DictConfig): 37 | cfg = load_typed_config(cfg_dict, RootCfg, {list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers},) 38 | set_cfg(cfg_dict) 39 | torch.manual_seed(cfg.seed) 40 | trainer = Trainer(max_epochs=-1, accelerator="gpu") 41 | computer = MetricComputer(cfg.evaluation) 42 | data_module = DataModule(cfg.dataset, cfg.data_loader) 43 | metrics = trainer.test(computer, datamodule=data_module) 44 | cfg.evaluation.output_metrics_path.parent.mkdir(exist_ok=True, parents=True) 45 | with cfg.evaluation.output_metrics_path.open("w") as f: 46 | json.dump(metrics[0], f) 47 | 48 | 49 | if __name__ == "__main__": 50 | evaluate() 51 | -------------------------------------------------------------------------------- /src/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text, font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 
11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... () i j") @ tf 33 | -------------------------------------------------------------------------------- /src/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 
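# For example, apply_color_map(torch.rand(64, 64)) returns a (64, 64, 3) tensor of RGB values in [0, 1].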
19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height with"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB. 47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /src/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /src/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 
| xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /src/visualization/drawing/lines.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 
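# Multiplying the coverage mask by the line index and taking the argmax picks, for each
# sample, the highest-indexed line that covers it, so lines later in the input list are
# drawn on top; the alpha channel is 1 wherever any line covers the sample.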
73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /src/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 
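# Same selection trick as in lines.py: argmax over (mask * index) keeps the highest-indexed
# point covering each sample, so later points are drawn on top.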
49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /src/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
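# The (b, h, w) grid of sample positions is flattened, evaluated in chunks of `batch_size`
# samples to bound peak memory, and then reshaped back into an RGBA image.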
86 | b, h, w, _ = xy.shape 87 | color = [ 88 | color_function(batch) 89 | for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) 90 | ] 91 | color = torch.cat(color, dim=0) 92 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 93 | 94 | # If any MSAA passes remain, subdivide. 95 | if remaining_passes > 0: 96 | mask = detect_msaa_pixels(color) 97 | batch_index, row_index, col_index = torch.where(mask) 98 | xy = xy[batch_index, row_index, col_index] 99 | 100 | offsets = generate_sample_grid((subdivision, subdivision), device) 101 | offsets = (offsets / subdivision - 0.5) * scale 102 | 103 | color_fine = run_msaa_pass( 104 | xy[:, None, None] + offsets, 105 | color_function, 106 | scale / subdivision, 107 | subdivision, 108 | remaining_passes - 1, 109 | device, 110 | batch_size=batch_size, 111 | ) 112 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 113 | 114 | return color 115 | 116 | 117 | @torch.no_grad() 118 | def render( 119 | shape: tuple[int, int], 120 | color_function: ColorFunction, 121 | device: torch.device, 122 | subdivision: int = 8, 123 | num_passes: int = 2, 124 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 125 | xy = generate_sample_grid(shape, device) 126 | return run_msaa_pass( 127 | xy[None], 128 | color_function, 129 | 1.0, 130 | subdivision, 131 | num_passes, 132 | device, 133 | )[0] 134 | 135 | 136 | def render_over_image( 137 | image: Float[Tensor, "3 height width"], 138 | color_function: ColorFunction, 139 | device: torch.device, 140 | subdivision: int = 8, 141 | num_passes: int = 1, 142 | ) -> Float[Tensor, "3 height width"]: 143 | _, h, w = image.shape 144 | overlay = render( 145 | (h, w), 146 | color_function, 147 | device, 148 | subdivision=subdivision, 149 | num_passes=num_passes, 150 | ) 151 | color, alpha = overlay.split((3, 1), dim=0) 152 | return image * (1 - alpha) + color * alpha 153 | -------------------------------------------------------------------------------- /src/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /src/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [ 16 | torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device) 17 | for x in images 18 | ] 19 | for image, result in zip(images, results): 20 | slices = [slice(0, x) for x in image.shape] 21 | result[slices] = image[slices] 22 | return results 23 | 24 | 25 | def render_projections( 26 | gaussians: Gaussians, 27 | resolution: int, 28 | margin: float = 0.1, 29 | draw_label: bool = True, 30 | extra_label: str = "", 31 | ) -> Float[Tensor, "batch 3 3 height width"]: 32 | device = gaussians.means.device 33 | b, _, _ = gaussians.means.shape 34 | 35 | # Compute the minima and maxima of the scene. 36 | minima = gaussians.means.min(dim=1).values 37 | maxima = gaussians.means.max(dim=1).values 38 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 39 | minima, maxima, margin=margin 40 | ) 41 | 42 | projections = [] 43 | for look_axis in range(3): 44 | right_axis = (look_axis + 1) % 3 45 | down_axis = (look_axis + 2) % 3 46 | 47 | # Define the extrinsics for rendering. 48 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 49 | extrinsics[:, right_axis, 0] = 1 50 | extrinsics[:, down_axis, 1] = 1 51 | extrinsics[:, look_axis, 2] = 1 52 | extrinsics[:, right_axis, 3] = 0.5 * ( 53 | scene_minima[:, right_axis] + scene_maxima[:, right_axis] 54 | ) 55 | extrinsics[:, down_axis, 3] = 0.5 * ( 56 | scene_minima[:, down_axis] + scene_maxima[:, down_axis] 57 | ) 58 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 59 | extrinsics[:, 3, 3] = 1 60 | 61 | # Define the intrinsics for rendering. 
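# The orthographic frustum is fitted to the padded scene bounding box: the near plane sits on
# the box face where the camera was placed, the far plane on the opposite face, and the
# width/height span the box along the right and down axes.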
62 | extents = scene_maxima - scene_minima 63 | far = extents[:, look_axis] 64 | near = torch.zeros_like(far) 65 | width = extents[:, right_axis] 66 | height = extents[:, down_axis] 67 | 68 | projection = render_cuda_orthographic( 69 | extrinsics, 70 | width, 71 | height, 72 | near, 73 | far, 74 | (resolution, resolution), 75 | torch.zeros((b, 3), dtype=torch.float32, device=device), 76 | gaussians.means, 77 | gaussians.covariances, 78 | gaussians.harmonics, 79 | gaussians.opacities, 80 | fov_degrees=10.0, 81 | ) 82 | if draw_label: 83 | right_axis_name = "XYZ"[right_axis] 84 | down_axis_name = "XYZ"[down_axis] 85 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 86 | projection = torch.stack([add_label(x, label) for x in projection]) 87 | 88 | projections.append(projection) 89 | 90 | return torch.stack(pad(projections), dim=1) 91 | 92 | 93 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 94 | # Define colors for context and target views. 95 | num_context_views = batch["context"]["extrinsics"].shape[1] 96 | num_target_views = batch["target"]["extrinsics"].shape[1] 97 | color = torch.ones( 98 | (num_target_views + num_context_views, 3), 99 | dtype=torch.float32, 100 | device=batch["target"]["extrinsics"].device, 101 | ) 102 | color[num_context_views:, 1:] = 0 103 | 104 | return draw_cameras( 105 | resolution, 106 | torch.cat( 107 | (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) 108 | ), 109 | torch.cat( 110 | (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) 111 | ), 112 | color, 113 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 114 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 115 | ) 116 | --------------------------------------------------------------------------------