├── .gitignore ├── LICENSE ├── README.md ├── assets ├── loss_curve.png └── render_video.gif ├── config ├── compute_metrics.yaml ├── dataset │ ├── re10k.yaml │ ├── view_sampler │ │ ├── all.yaml │ │ ├── arbitrary.yaml │ │ ├── bounded.yaml │ │ └── evaluation.yaml │ └── view_sampler_dataset_specific_config │ │ ├── bounded_re10k.yaml │ │ └── evaluation_re10k.yaml ├── experiment │ └── re10k.yaml ├── generate_evaluation_index.yaml ├── loss │ ├── depth.yaml │ ├── lpips.yaml │ └── mse.yaml ├── main.yaml └── model │ ├── decoder │ └── splatting_cuda.yaml │ └── encoder │ └── lrm.yaml ├── requirements.txt └── src ├── config.py ├── dataset ├── __init__.py ├── data_module.py ├── dataset.py ├── dataset_re10k.py ├── shims │ ├── augmentation_shim.py │ ├── bounds_shim.py │ ├── crop_shim.py │ └── patch_shim.py ├── types.py ├── validation_wrapper.py └── view_sampler │ ├── __init__.py │ ├── three_view_hack.py │ ├── view_sampler.py │ ├── view_sampler_all.py │ ├── view_sampler_arbitrary.py │ ├── view_sampler_bounded.py │ └── view_sampler_evaluation.py ├── evaluation ├── evaluation_cfg.py ├── evaluation_index_generator.py ├── metric_computer.py └── metrics.py ├── geometry ├── pose_utils.py └── projection.py ├── global_cfg.py ├── loss ├── __init__.py ├── loss.py ├── loss_depth.py ├── loss_lpips.py └── loss_mse.py ├── main.py ├── misc ├── LocalLogger.py ├── benchmarker.py ├── collation.py ├── discrete_probability_distribution.py ├── heterogeneous_pairings.py ├── image_io.py ├── nn_module_tools.py ├── sh_rotation.py ├── step_tracker.py └── wandb_tools.py ├── model ├── decoder │ ├── __init__.py │ ├── cuda_splatting.py │ ├── decoder.py │ ├── decoder_splatting_cuda.py │ └── spaltting_example.py ├── encoder │ ├── __init__.py │ ├── encoder.py │ ├── encoder_lrm.py │ └── transformer_processor │ │ └── processor.py ├── model_wrapper.py ├── ply_export.py └── types.py └── visualization ├── annotation.py ├── camera_trajectory ├── interpolation.py ├── spin.py └── wobble.py ├── color_map.py ├── colors.py ├── drawing ├── cameras.py ├── coordinate_conversion.py ├── lines.py ├── points.py ├── rendering.py └── types.py ├── layout.py └── validation_in_3d.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | /datasets 163 | /dataset_cache 164 | 165 | # Outputs 166 | /outputs 167 | /lightning_logs 168 | /checkpoints 169 | 170 | .bashrc 171 | /launcher_venv 172 | /slurm_logs 173 | *.torch 174 | *.ckpt 175 | table.tex 176 | /baselines 177 | /test/* 178 | 179 | 180 | wandb/ 181 | **/.ipynb_checkpoints/ 182 | .vscode/ 183 | .idea 184 | 185 | # ignore these types 186 | *.pyc 187 | *.out 188 | *.log 189 | *.mexa64 190 | *.pdf 191 | *.tar 192 | *.out 193 | *.ipynb 194 | 195 | *.jpg 196 | *.out 197 | 198 | *.ply 199 | events.* 200 | checkpoint/* 201 | checkpoints/* 202 | 203 | submodule/* 204 | wandb/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 David Charatan, Sizhe Li, Andrea Tagliasacchi, and Vincent Sitzmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unofficial Implementation of GS-LRM 2 | 3 | This is the code for an unofficial implementation of **GS-LRM: Large Reconstruction Model for 3D Gaussian Splatting**. 4 | 5 | Check out the official [project website here](https://sai-bi.github.io/project/gs-lrm/). 6 | 7 | In this repository, we present a PyTorch implementation of GS-LRM, focusing specifically on the stage 1 training pipeline on scene-level datasets from the original paper. We welcome and encourage the community to extend this work by developing the stage 2 training process, inference pipeline, and evaluation frameworks based on our implementation. 8 | 9 | ## Installation 10 | 11 | To get started, create a virtual environment using Python 3.10+: 12 | 13 | ```bash 14 | conda create -n gs-lrm python=3.10 15 | conda activate gs-lrm 16 | # Install these first! We recommend PyTorch 2.5.0 with CUDA 12.4 by default. 17 | pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124 18 | pip install -r requirements.txt 19 | # We recommend xformers==0.0.28.post2 to match this version of PyTorch.
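# (Optional) After the xformers install below, you can sanity-check the environment with:
# python -c "import torch; print(torch.__version__, torch.cuda.is_available())"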
20 | pip3 install -U xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu124 21 | ``` 22 | 23 | Next, install `diff-gaussian-rasterization`. We recommend this third-party implementation because it supports depth rendering and backpropagation through the Gaussian rasterizer: 24 | 25 | ```bash 26 | mkdir submodules 27 | cd submodules 28 | git clone https://github.com/ashawkey/diff-gaussian-rasterization.git --recursive 29 | cd diff-gaussian-rasterization 30 | pip install -e . 31 | ``` 32 | 33 | ## Acquiring Datasets 34 | 35 | We follow the same data format as [pixelSplat](https://davidcharatan.com/pixelsplat/), which was trained using versions of the RealEstate10k dataset that were split into ~100 MB chunks for use on server cluster file systems. Small subsets of the RealEstate10k dataset in this format can be found [here](https://drive.google.com/drive/folders/1joiezNCyQK2BvWMnfwHJpm2V77c7iYGe?usp=sharing). To use them, simply unzip them into a newly created `datasets` folder in the project root directory. 36 | 37 | After downloading the dataset, change the dataset paths in `config/dataset/re10k.yaml` and `config/experiment/re10k.yaml` to your personal dataset path. 38 | 39 | ## Sanity Check: Training Diagnostics & Rendering Validation 40 | 41 | ### 1. Training Loss Dynamics 42 | ![Training Loss Dynamics](assets/loss_curve.png) 43 | 44 | 45 | ### 2. Rendering Quality Assessment 46 | 47 | ![render_video](assets/render_video.gif) 48 | 49 | **Evaluation Set Fidelity Metrics**: 50 | ```python 51 | PSNR: 28.42 dB # Peak Signal-to-Noise Ratio 52 | SSIM: 0.903 # Structural Similarity Index 53 | LPIPS: 0.052 # Learned Perceptual Image Patch Similarity 54 | ``` 55 | 56 | Please note that the evaluation setting (256x256 resolution, 2 context views, 4 target views) differs from the original paper's setting, and we did not evaluate on the full evaluation set, so these metrics are not directly comparable with the original paper. 57 | 58 | ## Running the Code 59 | 60 | ### Training 61 | 62 | The main entry point is `src/main.py`. Call it via: 63 | 64 | ```bash 65 | python3 -m src.main +experiment=re10k 66 | ``` 67 | 68 | This configuration requires a single GPU with 80 GB of VRAM (A100 or H100). To reduce memory usage, you can change the batch size as follows: 69 | 70 | ```bash 71 | python3 -m src.main +experiment=re10k data_loader.train.batch_size=1 72 | ``` 73 | 74 | Our code supports multi-GPU training. The batch size above is the per-GPU batch size. 75 | 76 | 77 | ## Acknowledgements 78 | 79 | We extend our sincere gratitude to the following research teams and open-source projects that have contributed foundational elements to this work: 80 | 81 | ### Codebase Acknowledgement 82 | 1. **pixelSplat: 3D Gaussian Splats from Image Pairs for Scalable Generalizable 3D Reconstruction** 83 | - **Contributors**: [David Charatan](https://davidcharatan.com/#/) et al. 84 | - **References**: [Project Page](https://davidcharatan.com/pixelsplat/) | [Code Repository](https://github.com/dcharatan/pixelsplat) 85 | - **Technical Support**: 86 | - Our implementation framework is built upon the pixelSplat architecture 87 | - Inherited compatibility with its data formatting conventions and training pipeline 88 | - Special recognition for the innovative code architecture design 89 | 90 | 2. 
**Long-LRM Self-Reimplementation** 91 | - **Maintainers**: [Arthur Hero](https://github.com/arthurhero) 92 | - **References**: [Code Repository](https://github.com/arthurhero/Long-LRM) 93 | - **Technical Adoption**: 94 | - Integrated core model implementation schemes from this repository 95 | - Adapted key modules for neural rendering optimization 96 | - Appreciation for maintaining code accessibility and documentation completeness 97 | 98 | ### Special Recognition 99 | We particularly acknowledge the open-source community's invaluable contributions in advancing 3D reconstruction research. The transparency and extensibility demonstrated in these projects have significantly accelerated our development process. 100 | 101 | 102 | -------------------------------------------------------------------------------- /assets/loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/gs-lrm-unofficial/6fe1104d5fe7176b866b877f3ff798b40849d0d0/assets/loss_curve.png -------------------------------------------------------------------------------- /assets/render_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/gs-lrm-unofficial/6fe1104d5fe7176b866b877f3ff798b40849d0d0/assets/render_video.gif -------------------------------------------------------------------------------- /config/compute_metrics.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - model/encoder: epipolar 4 | - loss: [] 5 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 6 | - override dataset/view_sampler: evaluation 7 | 8 | data_loader: 9 | train: 10 | num_workers: 0 11 | persistent_workers: true 12 | batch_size: 1 13 | seed: 1234 14 | test: 15 | num_workers: 4 16 | persistent_workers: false 17 | batch_size: 1 18 | seed: 2345 19 | val: 20 | num_workers: 0 21 | persistent_workers: true 22 | batch_size: 1 23 | seed: 3456 24 | 25 | seed: 111123 26 | -------------------------------------------------------------------------------- /config/dataset/re10k.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - view_sampler: bounded 3 | 4 | name: re10k 5 | roots: [data/re10k] # change here to your personal re10k dataset path 6 | make_baseline_1: true 7 | augment: true 8 | 9 | image_shape: [180, 320] 10 | background_color: [0.0, 0.0, 0.0] 11 | cameras_are_circular: false 12 | 13 | baseline_epsilon: 1e-3 14 | max_fov: 100.0 15 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/all.yaml: -------------------------------------------------------------------------------- 1 | name: all 2 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/arbitrary.yaml: -------------------------------------------------------------------------------- 1 | name: arbitrary 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | # If you want to hard-code context views, do so here.
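# e.g. (hypothetical frame indices): context_views: [45, 90]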
7 | context_views: null 8 | -------------------------------------------------------------------------------- /config/dataset/view_sampler/bounded.yaml: -------------------------------------------------------------------------------- 1 | name: bounded 2 | 3 | num_target_views: 1 4 | num_context_views: 2 5 | 6 | min_distance_between_context_views: 2 7 | max_distance_between_context_views: 6 8 | min_distance_to_context_views: 0 9 | 10 | warm_up_steps: 0 11 | initial_min_distance_between_context_views: 2 12 | initial_max_distance_between_context_views: 6 -------------------------------------------------------------------------------- /config/dataset/view_sampler/evaluation.yaml: -------------------------------------------------------------------------------- 1 | name: evaluation 2 | 3 | index_path: assets/evaluation_index_re10k_video.json 4 | num_context_views: 2 5 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | min_distance_between_context_views: 45 6 | max_distance_between_context_views: 45 7 | min_distance_to_context_views: 0 8 | warm_up_steps: 150_000 9 | initial_min_distance_between_context_views: 25 10 | initial_max_distance_between_context_views: 25 11 | num_target_views: 4 12 | -------------------------------------------------------------------------------- /config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dataset: 4 | view_sampler: 5 | index_path: assets/evaluation_index_re10k.json 6 | -------------------------------------------------------------------------------- /config/experiment/re10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: re10k 5 | - override /model/encoder: lrm 6 | - override /loss: [mse, lpips] 7 | 8 | wandb: 9 | name: gs-lrm-pretrain-run1 10 | tags: [re10k, 256x256] 11 | 12 | dataset: 13 | image_shape: [256, 256] 14 | roots: [data/re10k] # change here to your personal re10k dataset path 15 | 16 | data_loader: 17 | train: 18 | batch_size: 8 19 | 20 | trainer: 21 | max_steps: 800_001 22 | -------------------------------------------------------------------------------- /config/generate_evaluation_index.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - override dataset/view_sampler: all 5 | 6 | dataset: 7 | overfit_to_scene: null 8 | 9 | data_loader: 10 | train: 11 | num_workers: 0 12 | persistent_workers: true 13 | batch_size: 1 14 | seed: 1234 15 | test: 16 | num_workers: 8 17 | persistent_workers: false 18 | batch_size: 1 19 | seed: 2345 20 | val: 21 | num_workers: 0 22 | persistent_workers: true 23 | batch_size: 1 24 | seed: 3456 25 | 26 | index_generator: 27 | num_target_views: 3 28 | min_overlap: 0.6 29 | max_overlap: 1.0 30 | min_distance: 45 31 | max_distance: 135 32 | output_path: outputs/evaluation_index_re10k 33 | save_previews: false 34 | seed: 123 35 | 36 | seed: 456 37 | -------------------------------------------------------------------------------- /config/loss/depth.yaml:
-------------------------------------------------------------------------------- 1 | depth: 2 | weight: 0.25 3 | sigma_image: null 4 | use_second_derivative: false 5 | -------------------------------------------------------------------------------- /config/loss/lpips.yaml: -------------------------------------------------------------------------------- 1 | lpips: 2 | weight: 0.5 3 | # apply_after_step: 1 4 | -------------------------------------------------------------------------------- /config/loss/mse.yaml: -------------------------------------------------------------------------------- 1 | mse: 2 | weight: 1.0 3 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: re10k 3 | - optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset} 4 | - model/encoder: lrm 5 | - model/decoder: splatting_cuda 6 | - loss: [mse, lpips] 7 | 8 | wandb: 9 | project: gs-lrm-unofficial 10 | entity: scene-representation-group 11 | name: maoyucheng 12 | mode: online 13 | 14 | mode: train 15 | 16 | dataset: 17 | overfit_to_scene: null 18 | 19 | data_loader: 20 | # Avoid having to spin up new processes to print out visualizations. 21 | train: 22 | num_workers: 16 23 | persistent_workers: true 24 | batch_size: 4 25 | seed: 1234 26 | test: 27 | num_workers: 4 28 | persistent_workers: false 29 | batch_size: 1 30 | seed: 2345 31 | val: 32 | num_workers: 1 33 | persistent_workers: true 34 | batch_size: 1 35 | seed: 3456 36 | 37 | optimizer: 38 | lr: 2e-4 39 | warm_up_steps: 2000 40 | max_steps: 800_001 41 | 42 | checkpointing: 43 | load: null 44 | every_n_train_steps: 5000 45 | save_top_k: -1 46 | 47 | train: 48 | extended_visualization: false 49 | 50 | test: 51 | output_path: outputs/test 52 | 53 | seed: 111123 54 | 55 | trainer: 56 | max_steps: -1 57 | val_check_interval: 500 58 | gradient_clip_val: 0.5 59 | -------------------------------------------------------------------------------- /config/model/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | -------------------------------------------------------------------------------- /config/model/encoder/lrm.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: lrm 3 | patch_size: 8 4 | attn_dim: 1024 5 | 6 | transformer: 7 | head_dim: 16 8 | num_layers: 24 9 | 10 | gaussians_params: 11 | sh_degree: 0 12 | scale_bias: -2.3 13 | scale_max: 0.3 14 | opacity_bias: -2.0 15 | near_plane: 0.01 16 | far_plane: 500.0 17 | 18 | apply_bounds_shim: True 19 | near_disparity: 3.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wheel 2 | tqdm 3 | lightning 4 | black 5 | ruff 6 | hydra-core 7 | jaxtyping 8 | beartype 9 | wandb 10 | einops 11 | colorama 12 | scikit-image 13 | colorspacious 14 | matplotlib 15 | moviepy 16 | imageio 17 | 18 | timm 19 | dacite 20 | lpips 21 | e3nn 22 | plyfile 23 | tabulate 24 | svg.py 25 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Literal, Optional, Type, TypeVar 4 | 5 | from 
dacite import Config, from_dict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .dataset.data_module import DataLoaderCfg, DatasetCfg 9 | from .loss import LossCfgWrapper 10 | from .model.decoder import DecoderCfg 11 | from .model.encoder import EncoderCfg 12 | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg 13 | 14 | 15 | @dataclass 16 | class CheckpointingCfg: 17 | load: Optional[str] # Not a path, since it could be something like wandb://... 18 | every_n_train_steps: int 19 | save_top_k: int 20 | 21 | 22 | @dataclass 23 | class ModelCfg: 24 | decoder: DecoderCfg 25 | encoder: EncoderCfg 26 | 27 | 28 | @dataclass 29 | class TrainerCfg: 30 | max_steps: int 31 | val_check_interval: int | float | None 32 | gradient_clip_val: int | float | None 33 | 34 | 35 | @dataclass 36 | class RootCfg: 37 | wandb: dict 38 | mode: Literal["train", "test"] 39 | dataset: DatasetCfg 40 | data_loader: DataLoaderCfg 41 | model: ModelCfg 42 | optimizer: OptimizerCfg 43 | checkpointing: CheckpointingCfg 44 | trainer: TrainerCfg 45 | loss: list[LossCfgWrapper] 46 | test: TestCfg 47 | train: TrainCfg 48 | seed: int 49 | 50 | 51 | TYPE_HOOKS = { 52 | Path: Path, 53 | } 54 | 55 | 56 | T = TypeVar("T") 57 | 58 | 59 | def load_typed_config( 60 | cfg: DictConfig, 61 | data_class: Type[T], 62 | extra_type_hooks: dict = {}, 63 | ) -> T: 64 | return from_dict( 65 | data_class, 66 | OmegaConf.to_container(cfg), 67 | config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}), 68 | ) 69 | 70 | 71 | def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]: 72 | # The dummy allows the union to be converted. 73 | @dataclass 74 | class Dummy: 75 | dummy: LossCfgWrapper 76 | 77 | return [ 78 | load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy 79 | for k, v in joined.items() 80 | ] 81 | 82 | 83 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 84 | return load_typed_config( 85 | cfg, 86 | RootCfg, 87 | {list[LossCfgWrapper]: separate_loss_cfg_wrappers}, 88 | ) 89 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from ..misc.step_tracker import StepTracker 4 | from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg 5 | from .types import Stage 6 | from .view_sampler import get_view_sampler 7 | 8 | DATASETS: dict[str, Dataset] = { 9 | "re10k": DatasetRE10k, 10 | } 11 | 12 | 13 | DatasetCfg = DatasetRE10kCfg 14 | 15 | 16 | def get_dataset( 17 | cfg: DatasetCfg, 18 | stage: Stage, 19 | step_tracker: StepTracker | None, 20 | ) -> Dataset: 21 | view_sampler = get_view_sampler( 22 | cfg.view_sampler, 23 | stage, 24 | cfg.overfit_to_scene is not None, 25 | cfg.cameras_are_circular, 26 | step_tracker, 27 | ) 28 | return DATASETS[cfg.name](cfg, stage, view_sampler) 29 | -------------------------------------------------------------------------------- /src/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | from lightning.pytorch import LightningDataModule 8 | from torch import Generator, nn 9 | from torch.utils.data import DataLoader, Dataset, IterableDataset 10 | 11 | from ..misc.step_tracker import StepTracker 12 | from . 
import DatasetCfg, get_dataset 13 | from .types import DataShim, Stage 14 | from .validation_wrapper import ValidationWrapper 15 | 16 | 17 | def get_data_shim(encoder: nn.Module) -> DataShim: 18 | """Get functions that modify the batch. It's sometimes necessary to modify batches 19 | outside the data loader because GPU computations are required to modify the batch or 20 | because the modification depends on something outside the data loader. 21 | """ 22 | 23 | shims: list[DataShim] = [] 24 | if hasattr(encoder, "get_data_shim"): 25 | shims.append(encoder.get_data_shim()) 26 | 27 | def combined_shim(batch): 28 | for shim in shims: 29 | batch = shim(batch) 30 | return batch 31 | 32 | return combined_shim 33 | 34 | 35 | @dataclass 36 | class DataLoaderStageCfg: 37 | batch_size: int 38 | num_workers: int 39 | persistent_workers: bool 40 | seed: int | None 41 | 42 | 43 | @dataclass 44 | class DataLoaderCfg: 45 | train: DataLoaderStageCfg 46 | test: DataLoaderStageCfg 47 | val: DataLoaderStageCfg 48 | 49 | 50 | DatasetShim = Callable[[Dataset, Stage], Dataset] 51 | 52 | 53 | def worker_init_fn(worker_id: int) -> None: 54 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 55 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 56 | 57 | 58 | class DataModule(LightningDataModule): 59 | dataset_cfg: DatasetCfg 60 | data_loader_cfg: DataLoaderCfg 61 | step_tracker: StepTracker | None 62 | dataset_shim: DatasetShim 63 | global_rank: int 64 | 65 | def __init__( 66 | self, 67 | dataset_cfg: DatasetCfg, 68 | data_loader_cfg: DataLoaderCfg, 69 | step_tracker: StepTracker | None = None, 70 | dataset_shim: DatasetShim = lambda dataset, _: dataset, 71 | global_rank: int = 0, 72 | ) -> None: 73 | super().__init__() 74 | self.dataset_cfg = dataset_cfg 75 | self.data_loader_cfg = data_loader_cfg 76 | self.step_tracker = step_tracker 77 | self.dataset_shim = dataset_shim 78 | self.global_rank = global_rank 79 | 80 | def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: 81 | return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers 82 | 83 | def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: 84 | if loader_cfg.seed is None: 85 | return None 86 | generator = Generator() 87 | generator.manual_seed(loader_cfg.seed + self.global_rank) 88 | return generator 89 | 90 | def train_dataloader(self): 91 | dataset = get_dataset(self.dataset_cfg, "train", self.step_tracker) 92 | dataset = self.dataset_shim(dataset, "train") 93 | return DataLoader( 94 | dataset, 95 | self.data_loader_cfg.train.batch_size, 96 | shuffle=not isinstance(dataset, IterableDataset), 97 | num_workers=self.data_loader_cfg.train.num_workers, 98 | generator=self.get_generator(self.data_loader_cfg.train), 99 | worker_init_fn=worker_init_fn, 100 | persistent_workers=self.get_persistent(self.data_loader_cfg.train), 101 | ) 102 | 103 | def val_dataloader(self): 104 | dataset = get_dataset(self.dataset_cfg, "val", self.step_tracker) 105 | dataset = self.dataset_shim(dataset, "val") 106 | return DataLoader( 107 | ValidationWrapper(dataset, 1), 108 | self.data_loader_cfg.val.batch_size, 109 | num_workers=self.data_loader_cfg.val.num_workers, 110 | generator=self.get_generator(self.data_loader_cfg.val), 111 | worker_init_fn=worker_init_fn, 112 | persistent_workers=self.get_persistent(self.data_loader_cfg.val), 113 | ) 114 | 115 | def test_dataloader(self): 116 | dataset = get_dataset(self.dataset_cfg, "test", self.step_tracker) 117 | dataset = 
self.dataset_shim(dataset, "test") 118 | return DataLoader( 119 | dataset, 120 | self.data_loader_cfg.test.batch_size, 121 | num_workers=self.data_loader_cfg.test.num_workers, 122 | generator=self.get_generator(self.data_loader_cfg.test), 123 | worker_init_fn=worker_init_fn, 124 | persistent_workers=self.get_persistent(self.data_loader_cfg.test), 125 | ) 126 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /src/dataset/dataset_re10k.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from functools import cached_property 4 | from io import BytesIO 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | import torch 9 | import torchvision.transforms as tf 10 | from einops import rearrange, repeat 11 | from jaxtyping import Float, UInt8 12 | from PIL import Image 13 | from torch import Tensor 14 | from torch.utils.data import IterableDataset 15 | 16 | from ..geometry.projection import get_fov 17 | from ..geometry.pose_utils import centerize_scale_poses 18 | from .dataset import DatasetCfgCommon 19 | from .shims.augmentation_shim import apply_augmentation_shim 20 | from .shims.crop_shim import apply_crop_shim 21 | from .types import Stage 22 | from .view_sampler import ViewSampler 23 | 24 | 25 | @dataclass 26 | class DatasetRE10kCfg(DatasetCfgCommon): 27 | name: Literal["re10k"] 28 | roots: list[Path] 29 | baseline_epsilon: float 30 | max_fov: float 31 | make_baseline_1: bool 32 | augment: bool 33 | 34 | 35 | class DatasetRE10k(IterableDataset): 36 | cfg: DatasetRE10kCfg 37 | stage: Stage 38 | view_sampler: ViewSampler 39 | 40 | to_tensor: tf.ToTensor 41 | chunks: list[Path] 42 | near: float = 0.0 43 | far: float = 500.0 44 | 45 | def __init__( 46 | self, 47 | cfg: DatasetRE10kCfg, 48 | stage: Stage, 49 | view_sampler: ViewSampler, 50 | ) -> None: 51 | super().__init__() 52 | self.cfg = cfg 53 | self.stage = stage 54 | self.view_sampler = view_sampler 55 | self.to_tensor = tf.ToTensor() 56 | 57 | # Collect chunks. 58 | self.chunks = [] 59 | for root in cfg.roots: 60 | root = root / self.data_stage 61 | root_chunks = sorted( 62 | [path for path in root.iterdir() if path.suffix == ".torch"] 63 | ) 64 | self.chunks.extend(root_chunks) 65 | # breakpoint() 66 | if self.cfg.overfit_to_scene is not None: 67 | chunk_path = self.index[self.cfg.overfit_to_scene] 68 | self.chunks = [chunk_path] * len(self.chunks) 69 | 70 | def shuffle(self, lst: list) -> list: 71 | indices = torch.randperm(len(lst)) 72 | return [lst[x] for x in indices] 73 | 74 | def __iter__(self): 75 | # Chunks must be shuffled here (not inside __init__) for validation to show 76 | # random chunks. 77 | if self.stage in ("train", "val"): 78 | self.chunks = self.shuffle(self.chunks) 79 | 80 | # When testing, the data loaders alternate chunks. 
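# (Each test-time worker keeps only every num_workers-th chunk, so no chunk is loaded twice across data-loader workers.)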
81 | worker_info = torch.utils.data.get_worker_info() 82 | if self.stage == "test" and worker_info is not None: 83 | self.chunks = [ 84 | chunk 85 | for chunk_index, chunk in enumerate(self.chunks) 86 | if chunk_index % worker_info.num_workers == worker_info.id 87 | ] 88 | 89 | for chunk_path in self.chunks: 90 | # Load the chunk. 91 | chunk = torch.load(chunk_path) 92 | 93 | if self.cfg.overfit_to_scene is not None: 94 | item = [x for x in chunk if x["key"] == self.cfg.overfit_to_scene] 95 | assert len(item) == 1 96 | chunk = item * len(chunk) 97 | 98 | if self.stage in ("train", "val"): 99 | chunk = self.shuffle(chunk) 100 | 101 | for example in chunk: 102 | extrinsics, intrinsics = self.convert_poses(example["cameras"]) #c2w 103 | scene = example["key"] 104 | 105 | try: 106 | context_indices, target_indices = self.view_sampler.sample( 107 | scene, 108 | extrinsics, 109 | intrinsics, 110 | ) 111 | except ValueError: 112 | # Skip because the example doesn't have enough frames. 113 | continue 114 | 115 | # Skip the example if the field of view is too wide. 116 | if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any(): 117 | continue 118 | 119 | # Load the images. 120 | try: 121 | context_images = [ 122 | example["images"][index.item()] for index in context_indices 123 | ] 124 | context_images = self.convert_images(context_images) 125 | target_images = [ 126 | example["images"][index.item()] for index in target_indices 127 | ] 128 | target_images = self.convert_images(target_images) 129 | except IndexError: 130 | continue 131 | 132 | # Skip the example if the images don't have the right shape. 133 | context_image_invalid = context_images.shape[1:] != (3, 360, 640) 134 | target_image_invalid = target_images.shape[1:] != (3, 360, 640) 135 | if context_image_invalid or target_image_invalid: 136 | print( 137 | f"Skipped bad example {example['key']}. Context shape was " 138 | f"{context_images.shape} and target shape was " 139 | f"{target_images.shape}." 140 | ) 141 | continue 142 | 143 | # Resize the world to make the baseline 1. 
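# (The original baseline-1 rescaling is left commented out below; instead, the selected context and target poses are centered and rescaled together via centerize_scale_poses.)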
144 | using_extrinsics = extrinsics[torch.cat([context_indices, target_indices])] 145 | # context_extrinsics = extrinsics[context_indices] 146 | # if context_extrinsics.shape[0] == 2 and self.cfg.make_baseline_1: 147 | # a, b = context_extrinsics[:, :3, 3] 148 | # breakpoint() 149 | # scale = (a - b).norm() 150 | # if scale < self.cfg.baseline_epsilon: 151 | # print( 152 | # f"Skipped {scene} because of insufficient baseline " 153 | # f"{scale:.6f}" 154 | # ) 155 | # continue 156 | # extrinsics[:, :3, 3] /= scale 157 | # else: 158 | # scale = 1 159 | # breakpoint() 160 | # avg_c2w = poses_avg(using_extrinsics) 161 | using_extrinsics, scale = centerize_scale_poses(using_extrinsics) 162 | # breakpoint() 163 | 164 | 165 | example = { 166 | "context": { 167 | "extrinsics": using_extrinsics[:2], 168 | "intrinsics": intrinsics[context_indices], 169 | "image": context_images, 170 | "near": self.get_bound("near", len(context_indices)), 171 | "far": self.get_bound("far", len(context_indices)), 172 | "index": context_indices, 173 | }, 174 | "target": { 175 | "extrinsics": using_extrinsics[2:], 176 | "intrinsics": intrinsics[target_indices], 177 | "image": target_images, 178 | "near": self.get_bound("near", len(target_indices)), 179 | "far": self.get_bound("far", len(target_indices)), 180 | "index": target_indices, 181 | }, 182 | "scene": scene, 183 | } 184 | if self.stage == "train" and self.cfg.augment: 185 | example = apply_augmentation_shim(example) 186 | yield apply_crop_shim(example, tuple(self.cfg.image_shape)) 187 | 188 | def convert_poses( 189 | self, 190 | poses: Float[Tensor, "batch 18"], 191 | ) -> tuple[ 192 | Float[Tensor, "batch 4 4"], # extrinsics 193 | Float[Tensor, "batch 3 3"], # intrinsics 194 | ]: 195 | b, _ = poses.shape 196 | 197 | # Convert the intrinsics to a 3x3 normalized K matrix. 198 | intrinsics = torch.eye(3, dtype=torch.float32) 199 | intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() 200 | fx, fy, cx, cy = poses[:, :4].T 201 | intrinsics[:, 0, 0] = fx 202 | intrinsics[:, 1, 1] = fy 203 | intrinsics[:, 0, 2] = cx 204 | intrinsics[:, 1, 2] = cy 205 | 206 | # Convert the extrinsics to a 4x4 OpenCV-style W2C matrix. 207 | w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone() 208 | w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4) 209 | return w2c.inverse(), intrinsics.clone() 210 | 211 | def convert_images( 212 | self, 213 | images: list[UInt8[Tensor, "..."]], 214 | ) -> Float[Tensor, "batch 3 height width"]: 215 | torch_images = [] 216 | for image in images: 217 | image = Image.open(BytesIO(image.numpy().tobytes())) 218 | torch_images.append(self.to_tensor(image)) 219 | return torch.stack(torch_images) 220 | 221 | def get_bound( 222 | self, 223 | bound: Literal["near", "far"], 224 | num_views: int, 225 | ) -> Float[Tensor, " view"]: 226 | value = torch.tensor(getattr(self, bound), dtype=torch.float32) 227 | return repeat(value, "-> v", v=num_views) 228 | 229 | @property 230 | def data_stage(self) -> Stage: 231 | if self.cfg.overfit_to_scene is not None: 232 | return "test" 233 | if self.stage == "val": 234 | return "test" 235 | return self.stage 236 | 237 | @cached_property 238 | def index(self) -> dict[str, Path]: 239 | merged_index = {} 240 | data_stages = [self.data_stage] 241 | if self.cfg.overfit_to_scene is not None: 242 | data_stages = ("test", "train") 243 | for data_stage in data_stages: 244 | for root in self.cfg.roots: 245 | # Load the root's index. 
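# (index.json maps each scene key to the chunk file that contains that scene.)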
246 | with (root / data_stage / "index.json").open("r") as f: 247 | index = json.load(f) 248 | index = {k: Path(root / data_stage / v) for k, v in index.items()} 249 | 250 | # The constituent datasets should have unique keys. 251 | assert not (set(merged_index.keys()) & set(index.keys())) 252 | 253 | # Merge the root's index into the main index. 254 | merged_index = {**merged_index, **index} 255 | return merged_index 256 | 257 | def __len__(self) -> int: 258 | return len(self.index.keys()) 259 | 260 | 261 | -------------------------------------------------------------------------------- /src/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /src/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 
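# (Averaging the per-view x/y pixel extents gives one pixel size per batch element; the depth returned below is baseline / (disparity * pixel size), i.e. the depth at which the widest camera baseline spans `disparity` pixels.)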
35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /src/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 
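# (The intrinsics are normalized by image size, so cropping from w_in x h_in to w_out x h_out scales the normalized focal lengths by w_in / w_out and h_in / h_out; a principal point at the image center stays at 0.5 and needs no update.)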
44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded. 70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /src/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 
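For example, with patch_size=8, a 250x256 image is center-cropped to 248x256.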
33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /src/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /src/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 
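Only `length` examples are drawn per validation pass: sequentially from an IterableDataset, or at random indices otherwise.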
10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ...misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | 11 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 12 | "all": ViewSamplerAll, 13 | "arbitrary": ViewSamplerArbitrary, 14 | "bounded": ViewSamplerBounded, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = ( 19 | ViewSamplerArbitraryCfg 20 | | ViewSamplerBoundedCfg 21 | | ViewSamplerEvaluationCfg 22 | | ViewSamplerAllCfg 23 | ) 24 | 25 | 26 | def get_view_sampler( 27 | cfg: ViewSamplerCfg, 28 | stage: Stage, 29 | overfit: bool, 30 | cameras_are_circular: bool, 31 | step_tracker: StepTracker | None, 32 | ) -> ViewSampler[Any]: 33 | return VIEW_SAMPLERS[cfg.name]( 34 | cfg, 35 | stage, 36 | overfit, 37 | cameras_are_circular, 38 | step_tracker, 39 | ) 40 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/three_view_hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Int 3 | from torch import Tensor 4 | 5 | 6 | def add_third_context_index( 7 | indices: Int[Tensor, "*batch 2"] 8 | ) -> Int[Tensor, "*batch 3"]: 9 | left, right = indices.unbind(dim=-1) 10 | return torch.stack((left, (left + right) // 2, right), dim=-1) 11 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from ...misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | 
self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | ) -> tuple[ 43 | Int64[Tensor, " context_view"], # indices for context views 44 | Int64[Tensor, " target_view"], # indices for target views 45 | ]: 46 | pass 47 | 48 | @property 49 | @abstractmethod 50 | def num_target_views(self) -> int: 51 | pass 52 | 53 | @property 54 | @abstractmethod 55 | def num_context_views(self) -> int: 56 | pass 57 | 58 | @property 59 | def global_step(self) -> int: 60 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 61 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .three_view_hack import add_third_context_index 9 | from .view_sampler import ViewSampler 10 | 11 | 12 | @dataclass 13 | class ViewSamplerArbitraryCfg: 14 | name: Literal["arbitrary"] 15 | num_context_views: int 16 | num_target_views: int 17 | context_views: list[int] | None 18 | target_views: list[int] | None 19 | 20 | 21 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 22 | def sample( 23 | self, 24 | scene: str, 25 | extrinsics: Float[Tensor, "view 4 4"], 26 | intrinsics: Float[Tensor, "view 3 3"], 27 | device: torch.device = torch.device("cpu"), 28 | ) -> tuple[ 29 | Int64[Tensor, " context_view"], # indices for context views 30 | Int64[Tensor, " target_view"], # indices for target views 31 | ]: 32 | """Arbitrarily sample context and target views.""" 33 | num_views, _, _ = extrinsics.shape 34 | 35 | index_context = torch.randint( 36 | 0, 37 | num_views, 38 | size=(self.cfg.num_context_views,), 39 | device=device, 40 | ) 41 | 42 | # Allow the context views to be fixed. 
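# (If exactly two context views are hard-coded but three are requested, a middle index is synthesized below via add_third_context_index.)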
43 | if self.cfg.context_views is not None: 44 | index_context = torch.tensor( 45 | self.cfg.context_views, dtype=torch.int64, device=device 46 | ) 47 | 48 | if self.cfg.num_context_views == 3 and len(self.cfg.context_views) == 2: 49 | index_context = add_third_context_index(index_context) 50 | else: 51 | assert len(self.cfg.context_views) == self.cfg.num_context_views 52 | index_target = torch.randint( 53 | 0, 54 | num_views, 55 | size=(self.cfg.num_target_views,), 56 | device=device, 57 | ) 58 | 59 | # Allow the target views to be fixed. 60 | if self.cfg.target_views is not None: 61 | assert len(self.cfg.target_views) == self.cfg.num_target_views 62 | index_target = torch.tensor( 63 | self.cfg.target_views, dtype=torch.int64, device=device 64 | ) 65 | 66 | return index_context, index_target 67 | 68 | @property 69 | def num_context_views(self) -> int: 70 | return self.cfg.num_context_views 71 | 72 | @property 73 | def num_target_views(self) -> int: 74 | return self.cfg.num_target_views 75 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | name: Literal["bounded"] 14 | num_context_views: int 15 | num_target_views: int 16 | min_distance_between_context_views: int 17 | max_distance_between_context_views: int 18 | min_distance_to_context_views: int 19 | warm_up_steps: int 20 | initial_min_distance_between_context_views: int 21 | initial_max_distance_between_context_views: int 22 | 23 | 24 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 25 | def schedule(self, initial: int, final: int) -> int: 26 | fraction = self.global_step / self.cfg.warm_up_steps 27 | return min(initial + int((final - initial) * fraction), final) 28 | 29 | def sample( 30 | self, 31 | scene: str, 32 | extrinsics: Float[Tensor, "view 4 4"], 33 | intrinsics: Float[Tensor, "view 3 3"], 34 | device: torch.device = torch.device("cpu"), 35 | ) -> tuple[ 36 | Int64[Tensor, " context_view"], # indices for context views 37 | Int64[Tensor, " target_view"], # indices for target views 38 | ]: 39 | num_views, _, _ = extrinsics.shape 40 | 41 | # Compute the context view spacing based on the current global step. 42 | if self.stage == "test": 43 | # When testing, always use the full gap. 44 | max_gap = self.cfg.max_distance_between_context_views 45 | min_gap = self.cfg.max_distance_between_context_views 46 | elif self.cfg.warm_up_steps > 0: 47 | max_gap = self.schedule( 48 | self.cfg.initial_max_distance_between_context_views, 49 | self.cfg.max_distance_between_context_views, 50 | ) 51 | min_gap = self.schedule( 52 | self.cfg.initial_min_distance_between_context_views, 53 | self.cfg.min_distance_between_context_views, 54 | ) 55 | else: 56 | max_gap = self.cfg.max_distance_between_context_views 57 | min_gap = self.cfg.min_distance_between_context_views 58 | 59 | # Pick the gap between the context views. 
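        # The gap is drawn uniformly from [min_gap, max_gap]. As a worked
        # example with illustrative numbers: initial_max_distance = 25,
        # max_distance = 45, warm_up_steps = 1000 and global_step = 500 give
        # schedule(25, 45) = min(25 + int((45 - 25) * 0.5), 45) = 35.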
60 | if not self.cameras_are_circular: 61 | max_gap = min(num_views - 1, max_gap) 62 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 63 | if max_gap < min_gap: 64 | raise ValueError("Example does not have enough frames!") 65 | context_gap = torch.randint( 66 | min_gap, 67 | max_gap + 1, 68 | size=tuple(), 69 | device=device, 70 | ).item() 71 | 72 | # Pick the left and right context indices. 73 | index_context_left = torch.randint( 74 | num_views if self.cameras_are_circular else num_views - context_gap, 75 | size=tuple(), 76 | device=device, 77 | ).item() 78 | if self.stage == "test": 79 | index_context_left = index_context_left * 0 80 | index_context_right = index_context_left + context_gap 81 | 82 | if self.is_overfitting: 83 | index_context_left *= 0 84 | index_context_right *= 0 85 | index_context_right += max_gap 86 | 87 | # Pick the target view indices. 88 | if self.stage == "test": 89 | # When testing, pick all. 90 | index_target = torch.arange( 91 | index_context_left, 92 | index_context_right + 1, 93 | device=device, 94 | ) 95 | else: 96 | # When training or validating (visualizing), pick at random. 97 | index_target = torch.randint( 98 | index_context_left + self.cfg.min_distance_to_context_views, 99 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 100 | size=(self.cfg.num_target_views,), 101 | device=device, 102 | ) 103 | 104 | # Apply modulo for circular datasets. 105 | if self.cameras_are_circular: 106 | index_target %= num_views 107 | index_context_right %= num_views 108 | 109 | # If more than two context views are desired, pick extra context views between 110 | # the left and right ones. 111 | if self.cfg.num_context_views > 2: 112 | num_extra_views = self.cfg.num_context_views - 2 113 | extra_views = [] 114 | while len(set(extra_views)) != num_extra_views: 115 | extra_views = torch.randint( 116 | index_context_left + 1, 117 | index_context_right, 118 | (num_extra_views,), 119 | ).tolist() 120 | else: 121 | extra_views = [] 122 | 123 | return ( 124 | torch.tensor((index_context_left, *extra_views, index_context_right)), 125 | index_target, 126 | ) 127 | 128 | @property 129 | def num_context_views(self) -> int: 130 | return self.cfg.num_context_views 131 | 132 | @property 133 | def num_target_views(self) -> int: 134 | return self.cfg.num_target_views 135 | -------------------------------------------------------------------------------- /src/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from ...global_cfg import get_cfg 13 | from ...misc.step_tracker import StepTracker 14 | from ..types import Stage 15 | from .three_view_hack import add_third_context_index 16 | from .view_sampler import ViewSampler 17 | 18 | 19 | @dataclass 20 | class ViewSamplerEvaluationCfg: 21 | name: Literal["evaluation"] 22 | index_path: Path 23 | num_context_views: int 24 | 25 | 26 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 27 | index: dict[str, IndexEntry | None] 28 | 29 | def __init__( 30 | self, 31 | cfg: ViewSamplerEvaluationCfg, 32 | stage: Stage, 33 | is_overfitting: bool, 34 | cameras_are_circular: bool, 35 | step_tracker: 
StepTracker | None, 36 | ) -> None: 37 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 38 | 39 | dacite_config = Config(cast=[tuple]) 40 | with cfg.index_path.open("r") as f: 41 | self.index = { 42 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 43 | for k, v in json.load(f).items() 44 | } 45 | 46 | def sample( 47 | self, 48 | scene: str, 49 | extrinsics: Float[Tensor, "view 4 4"], 50 | intrinsics: Float[Tensor, "view 3 3"], 51 | device: torch.device = torch.device("cpu"), 52 | ) -> tuple[ 53 | Int64[Tensor, " context_view"], # indices for context views 54 | Int64[Tensor, " target_view"], # indices for target views 55 | ]: 56 | entry = self.index.get(scene) 57 | if entry is None: 58 | raise ValueError(f"No indices available for scene {scene}.") 59 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 60 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 61 | 62 | # Handle 2-view index for 3 views. 63 | v = get_cfg()["dataset"]["view_sampler"]["num_context_views"] 64 | if v > len(context_indices) and v == 3: 65 | context_indices = add_third_context_index(context_indices) 66 | 67 | return context_indices, target_indices 68 | 69 | @property 70 | def num_context_views(self) -> int: 71 | return 0 72 | 73 | @property 74 | def num_target_views(self) -> int: 75 | return 0 76 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | 4 | 5 | @dataclass 6 | class MethodCfg: 7 | name: str 8 | key: str 9 | path: Path 10 | 11 | 12 | @dataclass 13 | class SceneCfg: 14 | scene: str 15 | target_index: int 16 | 17 | 18 | @dataclass 19 | class EvaluationCfg: 20 | methods: list[MethodCfg] 21 | side_by_side_path: Path | None 22 | animate_side_by_side: bool 23 | highlighted: list[SceneCfg] 24 | -------------------------------------------------------------------------------- /src/evaluation/evaluation_index_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import asdict, dataclass 3 | from pathlib import Path 4 | 5 | import torch 6 | from einops import rearrange 7 | from lightning.pytorch import LightningModule 8 | from tqdm import tqdm 9 | 10 | from ..geometry.projection import get_world_rays, sample_image_grid 11 | from ..misc.image_io import save_image 12 | from ..visualization.annotation import add_label 13 | from ..visualization.layout import add_border, hcat 14 | 15 | 16 | @dataclass 17 | class EvaluationIndexGeneratorCfg: 18 | num_target_views: int 19 | min_distance: int 20 | max_distance: int 21 | min_overlap: float 22 | max_overlap: float 23 | output_path: Path 24 | save_previews: bool 25 | seed: int 26 | 27 | 28 | @dataclass 29 | class IndexEntry: 30 | context: tuple[int, int] 31 | target: tuple[int, ...] 
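    # Serialized with asdict() by EvaluationIndexGenerator.save_index below, so
    # the JSON index maps each scene key to either null (no valid pair found)
    # or an entry of the form {"context": [12, 57], "target": [20, 33, 45]}
    # (illustrative indices).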
32 | 33 | 34 | class EvaluationIndexGenerator(LightningModule): 35 | generator: torch.Generator 36 | cfg: EvaluationIndexGeneratorCfg 37 | index: dict[str, IndexEntry | None] 38 | 39 | def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None: 40 | super().__init__() 41 | self.cfg = cfg 42 | self.generator = torch.Generator() 43 | self.generator.manual_seed(cfg.seed) 44 | self.index = {} 45 | 46 | def test_step(self, batch, batch_idx): 47 | b, v, _, h, w = batch["target"]["image"].shape 48 | assert b == 1 49 | extrinsics = batch["target"]["extrinsics"][0] 50 | intrinsics = batch["target"]["intrinsics"][0] 51 | scene = batch["scene"][0] 52 | 53 | context_indices = torch.randperm(v, generator=self.generator) 54 | for context_index in tqdm(context_indices, "Finding context pair"): 55 | xy, _ = sample_image_grid((h, w), self.device) 56 | context_origins, context_directions = get_world_rays( 57 | rearrange(xy, "h w xy -> (h w) xy"), 58 | extrinsics[context_index], 59 | intrinsics[context_index], 60 | ) 61 | 62 | # Step away from context view until the minimum overlap threshold is met. 63 | valid_indices = [] 64 | for step in (1, -1): 65 | min_distance = self.cfg.min_distance 66 | max_distance = self.cfg.max_distance 67 | current_index = context_index + step * min_distance 68 | 69 | while 0 <= current_index.item() < v: 70 | # Compute overlap. 71 | current_origins, current_directions = get_world_rays( 72 | rearrange(xy, "h w xy -> (h w) xy"), 73 | extrinsics[current_index], 74 | intrinsics[current_index], 75 | ) 76 | projection_onto_current = project_rays( 77 | context_origins, 78 | context_directions, 79 | extrinsics[current_index], 80 | intrinsics[current_index], 81 | ) 82 | projection_onto_context = project_rays( 83 | current_origins, 84 | current_directions, 85 | extrinsics[context_index], 86 | intrinsics[context_index], 87 | ) 88 | overlap_a = projection_onto_context["overlaps_image"].float().mean() 89 | overlap_b = projection_onto_current["overlaps_image"].float().mean() 90 | 91 | overlap = min(overlap_a, overlap_b) 92 | delta = (current_index - context_index).abs() 93 | 94 | min_overlap = self.cfg.min_overlap 95 | max_overlap = self.cfg.max_overlap 96 | if min_overlap <= overlap <= max_overlap: 97 | valid_indices.append( 98 | (current_index.item(), overlap_a, overlap_b) 99 | ) 100 | 101 | # Stop once the camera has panned away too much. 102 | if overlap < min_overlap or delta > max_distance: 103 | break 104 | 105 | current_index += step 106 | 107 | if valid_indices: 108 | # Pick a random valid view. Index the resulting views. 109 | num_options = len(valid_indices) 110 | chosen = torch.randint( 111 | 0, num_options, size=tuple(), generator=self.generator 112 | ) 113 | chosen, overlap_a, overlap_b = valid_indices[chosen] 114 | 115 | context_left = min(chosen, context_index.item()) 116 | context_right = max(chosen, context_index.item()) 117 | delta = context_right - context_left 118 | 119 | # Pick non-repeated random target views. 120 | while True: 121 | target_views = torch.randint( 122 | context_left, 123 | context_right + 1, 124 | (self.cfg.num_target_views,), 125 | generator=self.generator, 126 | ) 127 | if (target_views.unique(return_counts=True)[1] == 1).all(): 128 | break 129 | 130 | target = tuple(sorted(target_views.tolist())) 131 | self.index[scene] = IndexEntry( 132 | context=(context_left, context_right), 133 | target=target, 134 | ) 135 | 136 | # Optionally, save a preview. 
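                # The preview horizontally concatenates the two chosen context
                # frames, labels each with its overlap percentage, captions the
                # pair with the frame distance, and writes the result to
                # <output_path>/previews/<scene>.png.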
137 | if self.cfg.save_previews: 138 | preview_path = self.cfg.output_path / "previews" 139 | preview_path.mkdir(exist_ok=True, parents=True) 140 | a = batch["target"]["image"][0, chosen] 141 | a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%") 142 | b = batch["target"]["image"][0, context_index] 143 | b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%") 144 | vis = add_border(add_border(hcat(a, b)), 1, 0) 145 | vis = add_label(vis, f"Distance: {delta} frames") 146 | save_image(add_border(vis), preview_path / f"{scene}.png") 147 | break 148 | else: 149 | # This happens if no starting frame produces a valid evaluation example. 150 | self.index[scene] = None 151 | 152 | def save_index(self) -> None: 153 | self.cfg.output_path.mkdir(exist_ok=True, parents=True) 154 | with (self.cfg.output_path / "evaluation_index.json").open("w") as f: 155 | json.dump( 156 | {k: None if v is None else asdict(v) for k, v in self.index.items()}, f 157 | ) 158 | -------------------------------------------------------------------------------- /src/evaluation/metric_computer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from lightning.pytorch import LightningModule 6 | from tabulate import tabulate 7 | 8 | from ..misc.image_io import load_image, save_image 9 | from ..visualization.annotation import add_label 10 | from ..visualization.layout import add_border, hcat 11 | from .evaluation_cfg import EvaluationCfg 12 | from .metrics import compute_lpips, compute_psnr, compute_ssim 13 | 14 | 15 | class MetricComputer(LightningModule): 16 | cfg: EvaluationCfg 17 | 18 | def __init__(self, cfg: EvaluationCfg) -> None: 19 | super().__init__() 20 | self.cfg = cfg 21 | 22 | def test_step(self, batch, batch_idx): 23 | scene = batch["scene"][0] 24 | b, cv, _, _, _ = batch["context"]["image"].shape 25 | assert b == 1 26 | _, v, _, _, _ = batch["target"]["image"].shape 27 | 28 | # Skip scenes. 29 | for method in self.cfg.methods: 30 | if not (method.path / scene).exists(): 31 | print(f'Skipping "{scene}".') 32 | return 33 | 34 | # Load the images. 35 | all_images = {} 36 | try: 37 | for method in self.cfg.methods: 38 | images = [ 39 | load_image(method.path / scene / f"color/{index.item():0>6}.png") 40 | for index in batch["target"]["index"][0] 41 | ] 42 | all_images[method.key] = torch.stack(images).to(self.device) 43 | except FileNotFoundError: 44 | print(f'Skipping "{scene}".') 45 | return 46 | 47 | # Compute metrics. 48 | all_metrics = {} 49 | rgb_gt = batch["target"]["image"][0] 50 | for key, images in all_images.items(): 51 | all_metrics = { 52 | **all_metrics, 53 | f"lpips_{key}": compute_lpips(rgb_gt, images).mean(), 54 | f"ssim_{key}": compute_ssim(rgb_gt, images).mean(), 55 | f"psnr_{key}": compute_psnr(rgb_gt, images).mean(), 56 | } 57 | self.log_dict(all_metrics) 58 | self.print_preview_metrics(all_metrics) 59 | 60 | # Skip the rest if no side-by-side is needed. 61 | if self.cfg.side_by_side_path is None: 62 | return 63 | 64 | # Create side-by-side. 
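        # One comparison image is written per target view, e.g.
        # <side_by_side_path>/000003_<scene>/000057.png (illustrative batch and
        # frame indices); with animation enabled, the frames are then encoded
        # into <side_by_side_path>/videos/000003_<scene>.mp4 via ffmpeg.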
65 | scene_key = f"{batch_idx:0>6}_{scene}" 66 | for i in range(v): 67 | true_index = batch["target"]["index"][0, i] 68 | row = [add_label(batch["target"]["image"][0, i], "Ground Truth")] 69 | for method in self.cfg.methods: 70 | image = all_images[method.key][i] 71 | image = add_label(image, method.name) 72 | row.append(image) 73 | start_frame = batch["target"]["index"][0, 0] 74 | end_frame = batch["target"]["index"][0, -1] 75 | label = f"Scene {batch['scene'][0]} (frames {start_frame} to {end_frame})" 76 | row = add_border(add_label(hcat(*row), label, font_size=16)) 77 | save_image( 78 | row, 79 | self.cfg.side_by_side_path / scene_key / f"{true_index:0>6}.png", 80 | ) 81 | 82 | # Create an animation. 83 | if self.cfg.animate_side_by_side: 84 | (self.cfg.side_by_side_path / "videos").mkdir(exist_ok=True, parents=True) 85 | command = ( 86 | 'ffmpeg -y -framerate 30 -pattern_type glob -i "*.png" -c:v libx264 ' 87 | '-pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"' 88 | ) 89 | os.system( 90 | f"cd {self.cfg.side_by_side_path / scene_key} && {command} " 91 | f"{Path.cwd()}/{self.cfg.side_by_side_path}/videos/{scene_key}.mp4" 92 | ) 93 | 94 | def print_preview_metrics(self, metrics: dict[str, float]) -> None: 95 | if getattr(self, "running_metrics", None) is None: 96 | self.running_metrics = metrics 97 | self.running_metric_steps = 1 98 | else: 99 | s = self.running_metric_steps 100 | self.running_metrics = { 101 | k: ((s * v) + metrics[k]) / (s + 1) 102 | for k, v in self.running_metrics.items() 103 | } 104 | self.running_metric_steps += 1 105 | 106 | table = [] 107 | for method in self.cfg.methods: 108 | row = [ 109 | f"{self.running_metrics[f'{metric}_{method.key}']:.3f}" 110 | for metric in ("psnr", "lpips", "ssim") 111 | ] 112 | table.append((method.key, *row)) 113 | 114 | table = tabulate(table, ["Method", "PSNR (dB)", "LPIPS", "SSIM"]) 115 | print(table) 116 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from skimage.metrics import structural_similarity 8 | from torch import Tensor 9 | 10 | 11 | @torch.no_grad() 12 | def compute_psnr( 13 | ground_truth: Float[Tensor, "batch channel height width"], 14 | predicted: Float[Tensor, "batch channel height width"], 15 | ) -> Float[Tensor, " batch"]: 16 | ground_truth = ground_truth.clip(min=0, max=1) 17 | predicted = predicted.clip(min=0, max=1) 18 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 19 | return -10 * mse.log10() 20 | 21 | 22 | @cache 23 | def get_lpips(device: torch.device) -> LPIPS: 24 | return LPIPS(net="vgg").to(device) 25 | 26 | 27 | @torch.no_grad() 28 | def compute_lpips( 29 | ground_truth: Float[Tensor, "batch channel height width"], 30 | predicted: Float[Tensor, "batch channel height width"], 31 | ) -> Float[Tensor, " batch"]: 32 | value = get_lpips(predicted.device).forward(ground_truth, predicted, normalize=True) 33 | return value[:, 0, 0, 0] 34 | 35 | 36 | @torch.no_grad() 37 | def compute_ssim( 38 | ground_truth: Float[Tensor, "batch channel height width"], 39 | predicted: Float[Tensor, "batch channel height width"], 40 | ) -> Float[Tensor, " batch"]: 41 | ssim = [ 42 | structural_similarity( 43 | gt.detach().cpu().numpy(), 44 | hat.detach().cpu().numpy(), 45 | win_size=11, 46 | gaussian_weights=True, 47 
| channel_axis=0, 48 | data_range=1.0, 49 | ) 50 | for gt, hat in zip(ground_truth, predicted) 51 | ] 52 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 53 | -------------------------------------------------------------------------------- /src/geometry/pose_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple 3 | 4 | def normalize(x): 5 | return x / torch.linalg.norm(x) 6 | 7 | def viewmatrix(z, up, pos): 8 | vec2 = normalize(z) 9 | vec1_avg = up 10 | vec0 = normalize(torch.cross(vec1_avg, vec2)) 11 | vec1 = normalize(torch.cross(vec2, vec0)) 12 | 13 | m = torch.stack([vec0, vec1, vec2, pos], dim=1) 14 | 15 | m_4x4 = torch.cat([m, torch.tensor([[0, 0, 0, 1]], dtype=pos.dtype)], dim=0) 16 | return m_4x4 17 | 18 | def poses_avg(poses): 19 | 20 | poses_3x4 = poses[:, :3, :4] 21 | center = poses_3x4[:, :, -1].mean(dim=0) 22 | 23 | vec2 = normalize(poses_3x4[:, :, 2].sum(dim=0)) 24 | # breakpoint() 25 | up = poses_3x4[:, :, 1].sum(dim=0) 26 | c2w = viewmatrix(vec2, up, center) 27 | 28 | return c2w 29 | 30 | def centerize_scale_poses( 31 | in_c2ws: torch.Tensor, 32 | frame_method: str = 'mean_cam', 33 | scale_range: Union[Tuple[float, float], None] = None, 34 | scene_scale_method: str = 'two_cam' 35 | ) -> Tuple[torch.Tensor, torch.Tensor]: 36 | 37 | in_c2ws = in_c2ws.clone() 38 | N, _, _ = in_c2ws.shape 39 | 40 | if frame_method == 'mean_cam': 41 | # bottom = torch.tensor([0, 0, 0, 1.0], device=in_c2ws.device).view(1, 4) 42 | apos = poses_avg(in_c2ws) 43 | # apos = torch.cat([apos, bottom], dim=0).unsqueeze(0) 44 | # breakpoint() 45 | in_c2ws = torch.matmul(torch.inverse(apos), in_c2ws) 46 | # TODO: the following two method dosen't support yet! 
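    # Roughly: 'first_cam' re-expresses every pose relative to the first
    # camera, while 'center' only shifts the translations so the bounding-box
    # midpoint of the camera positions sits at the origin.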
47 | elif frame_method == 'first_cam': 48 | first_c2w = in_c2ws[0] 49 | in_c2ws = torch.matmul(torch.inverse(first_c2w), in_c2ws) 50 | elif frame_method == 'center': 51 | scene_center = (torch.max(in_c2ws[:, :3, 3], dim=0).values + 52 | torch.min(in_c2ws[:, :3, 3], dim=0).values) / 2 53 | in_c2ws[:, :3, 3] = in_c2ws[:, :3, 3] - scene_center 54 | else: 55 | raise NotImplementedError(f"Unknown frame_method: {frame_method}") 56 | 57 | 58 | scene_scale = torch.max(torch.abs(in_c2ws[:, :3, 3])) 59 | 60 | if scene_scale_method == "two_cam": 61 | two_cam_dist = torch.linalg.norm(in_c2ws[0, :3, 3] - in_c2ws[1, :3, 3]) 62 | scene_scale = 1.0 / (two_cam_dist + 0.01) 63 | elif scene_scale_method == "fix_range": 64 | if scale_range is None: 65 | raise ValueError("scale_range must be provided when scene_scale_method is 'fix_range'") 66 | min_scale, max_scale = scale_range 67 | random_scale = torch.rand(1, device=in_c2ws.device)[0] * (max_scale - min_scale) + min_scale 68 | scene_scale *= random_scale 69 | else: 70 | raise NotImplementedError(f"Unknown scene_scale_method: {scene_scale_method}") 71 | 72 | in_c2ws[:, :3, 3] = in_c2ws[:, :3, 3] / scene_scale 73 | 74 | return in_c2ws, scene_scale -------------------------------------------------------------------------------- /src/geometry/projection.py: -------------------------------------------------------------------------------- 1 | from math import prod 2 | 3 | import torch 4 | from einops import einsum, rearrange, reduce, repeat 5 | from jaxtyping import Bool, Float, Int64 6 | from torch import Tensor 7 | 8 | 9 | def homogenize_points( 10 | points: Float[Tensor, "*batch dim"], 11 | ) -> Float[Tensor, "*batch dim+1"]: 12 | """Convert batched points (xyz) to (xyz1).""" 13 | return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) 14 | 15 | 16 | def homogenize_vectors( 17 | vectors: Float[Tensor, "*batch dim"], 18 | ) -> Float[Tensor, "*batch dim+1"]: 19 | """Convert batched vectors (xyz) to (xyz0).""" 20 | return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1) 21 | 22 | 23 | def transform_rigid( 24 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 25 | transformation: Float[Tensor, "*#batch dim dim"], 26 | ) -> Float[Tensor, "*batch dim"]: 27 | """Apply a rigid-body transformation to points or vectors.""" 28 | return einsum(transformation, homogeneous_coordinates, "... i j, ... j -> ... i") 29 | 30 | 31 | def transform_cam2world( 32 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 33 | extrinsics: Float[Tensor, "*#batch dim dim"], 34 | ) -> Float[Tensor, "*batch dim"]: 35 | """Transform points from 3D camera coordinates to 3D world coordinates.""" 36 | return transform_rigid(homogeneous_coordinates, extrinsics) 37 | 38 | 39 | def transform_world2cam( 40 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 41 | extrinsics: Float[Tensor, "*#batch dim dim"], 42 | ) -> Float[Tensor, "*batch dim"]: 43 | """Transform points from 3D world coordinates to 3D camera coordinates.""" 44 | return transform_rigid(homogeneous_coordinates, extrinsics.inverse()) 45 | 46 | 47 | def project_camera_space( 48 | points: Float[Tensor, "*#batch dim"], 49 | intrinsics: Float[Tensor, "*#batch dim dim"], 50 | epsilon: float = torch.finfo(torch.float32).eps, 51 | infinity: float = 1e8, 52 | ) -> Float[Tensor, "*batch dim-1"]: 53 | points = points / (points[..., -1:] + epsilon) 54 | points = points.nan_to_num(posinf=infinity, neginf=-infinity) 55 | points = einsum(intrinsics, points, "... i j, ... j -> ... 
i") 56 | return points[..., :-1] 57 | 58 | 59 | def project( 60 | points: Float[Tensor, "*#batch dim"], 61 | extrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 62 | intrinsics: Float[Tensor, "*#batch dim dim"], 63 | epsilon: float = torch.finfo(torch.float32).eps, 64 | ) -> tuple[ 65 | Float[Tensor, "*batch dim-1"], # xy coordinates 66 | Bool[Tensor, " *batch"], # whether points are in front of the camera 67 | ]: 68 | points = homogenize_points(points) 69 | points = transform_world2cam(points, extrinsics)[..., :-1] 70 | in_front_of_camera = points[..., -1] >= 0 71 | return project_camera_space(points, intrinsics, epsilon=epsilon), in_front_of_camera 72 | 73 | 74 | def unproject( 75 | coordinates: Float[Tensor, "*#batch dim"], 76 | z: Float[Tensor, "*#batch"], 77 | intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 78 | ) -> Float[Tensor, "*batch dim+1"]: 79 | """Unproject 2D camera coordinates with the given Z values.""" 80 | 81 | # Apply the inverse intrinsics to the coordinates. 82 | coordinates = homogenize_points(coordinates) 83 | ray_directions = einsum( 84 | intrinsics.inverse(), coordinates, "... i j, ... j -> ... i" 85 | ) 86 | 87 | # Apply the supplied depth values. 88 | return ray_directions * z[..., None] 89 | 90 | 91 | def get_world_rays( 92 | coordinates: Float[Tensor, "*#batch dim"], 93 | extrinsics: Float[Tensor, "*#batch dim+2 dim+2"], 94 | intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 95 | ) -> tuple[ 96 | Float[Tensor, "*batch dim+1"], # origins 97 | Float[Tensor, "*batch dim+1"], # directions 98 | ]: 99 | # Get camera-space ray directions. 100 | directions = unproject( 101 | coordinates, 102 | torch.ones_like(coordinates[..., 0]), 103 | intrinsics, 104 | ) 105 | directions = directions / directions.norm(dim=-1, keepdim=True) 106 | 107 | # Transform ray directions to world coordinates. 108 | directions = homogenize_vectors(directions) 109 | directions = transform_cam2world(directions, extrinsics)[..., :-1] 110 | 111 | # Tile the ray origins to have the same shape as the ray directions. 112 | origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape) 113 | 114 | return origins, directions 115 | 116 | 117 | def sample_image_grid( 118 | shape: tuple[int, ...], 119 | device: torch.device = torch.device("cpu"), 120 | ) -> tuple[ 121 | Float[Tensor, "*shape dim"], # float coordinates (xy indexing) 122 | Int64[Tensor, "*shape dim"], # integer indices (ij indexing) 123 | ]: 124 | """Get normalized (range 0 to 1) coordinates and integer indices for an image.""" 125 | 126 | # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a 127 | # (row, col) coordinate. 128 | indices = [torch.arange(length, device=device) for length in shape] 129 | stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1) 130 | 131 | # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case, 132 | # each entry is an (x, y) coordinate. 
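    # For a (height, width) = (2, 4) image, for example, the pixel centers are
    # x = (0.125, 0.375, 0.625, 0.875) and y = (0.25, 0.75).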
133 | coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)] 134 | coordinates = reversed(coordinates) 135 | coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1) 136 | 137 | return coordinates, stacked_indices 138 | 139 | 140 | def sample_training_rays( 141 | image: Float[Tensor, "batch view channel ..."], 142 | intrinsics: Float[Tensor, "batch view dim dim"], 143 | extrinsics: Float[Tensor, "batch view dim+1 dim+1"], 144 | num_rays: int, 145 | ) -> tuple[ 146 | Float[Tensor, "batch ray dim"], # origins 147 | Float[Tensor, "batch ray dim"], # directions 148 | Float[Tensor, "batch ray 3"], # sampled color 149 | ]: 150 | device = extrinsics.device 151 | b, v, _, *grid_shape = image.shape 152 | 153 | # Generate all possible target rays. 154 | xy, _ = sample_image_grid(tuple(grid_shape), device) 155 | origins, directions = get_world_rays( 156 | rearrange(xy, "... d -> ... () () d"), 157 | extrinsics, 158 | intrinsics, 159 | ) 160 | origins = rearrange(origins, "... b v xy -> b (v ...) xy", b=b, v=v) 161 | directions = rearrange(directions, "... b v xy -> b (v ...) xy", b=b, v=v) 162 | pixels = rearrange(image, "b v c ... -> b (v ...) c") 163 | 164 | # Sample random rays. 165 | num_possible_rays = v * prod(grid_shape) 166 | ray_indices = torch.randint(num_possible_rays, (b, num_rays), device=device) 167 | batch_indices = repeat(torch.arange(b, device=device), "b -> b n", n=num_rays) 168 | 169 | return ( 170 | origins[batch_indices, ray_indices], 171 | directions[batch_indices, ray_indices], 172 | pixels[batch_indices, ray_indices], 173 | ) 174 | 175 | 176 | def intersect_rays( 177 | origins_x: Float[Tensor, "*#batch 3"], 178 | directions_x: Float[Tensor, "*#batch 3"], 179 | origins_y: Float[Tensor, "*#batch 3"], 180 | directions_y: Float[Tensor, "*#batch 3"], 181 | eps: float = 1e-5, 182 | inf: float = 1e10, 183 | ) -> Float[Tensor, "*batch 3"]: 184 | """Compute the least-squares intersection of rays. Uses the math from here: 185 | https://math.stackexchange.com/a/1762491/286022 186 | """ 187 | 188 | # Broadcast the rays so their shapes match. 189 | shape = torch.broadcast_shapes( 190 | origins_x.shape, 191 | directions_x.shape, 192 | origins_y.shape, 193 | directions_y.shape, 194 | ) 195 | origins_x = origins_x.broadcast_to(shape) 196 | directions_x = directions_x.broadcast_to(shape) 197 | origins_y = origins_y.broadcast_to(shape) 198 | directions_y = directions_y.broadcast_to(shape) 199 | 200 | # Detect and remove batch elements where the directions are parallel. 201 | parallel = einsum(directions_x, directions_y, "... xyz, ... xyz -> ...") > 1 - eps 202 | origins_x = origins_x[~parallel] 203 | directions_x = directions_x[~parallel] 204 | origins_y = origins_y[~parallel] 205 | directions_y = directions_y[~parallel] 206 | 207 | # Stack the rays into (2, *shape). 208 | origins = torch.stack([origins_x, origins_y], dim=0) 209 | directions = torch.stack([directions_x, directions_y], dim=0) 210 | dtype = origins.dtype 211 | device = origins.device 212 | 213 | # Compute n_i * n_i^T - eye(3) from the equation. 214 | n = einsum(directions, directions, "r b i, r b j -> r b i j") 215 | n = n - torch.eye(3, dtype=dtype, device=device).broadcast_to((2, 1, 3, 3)) 216 | 217 | # Compute the left-hand side of the equation. 218 | lhs = reduce(n, "r b i j -> b i j", "sum") 219 | 220 | # Compute the right-hand side of the equation. 
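    # Per the reference above, the least-squares point p solves
    #   sum_i (d_i d_i^T - I) p = sum_i (d_i d_i^T - I) o_i,
    # where d_i are the unit ray directions and o_i the ray origins.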
221 | rhs = einsum(n, origins, "r b i j, r b j -> r b i") 222 | rhs = reduce(rhs, "r b i -> b i", "sum") 223 | 224 | # Left-matrix-multiply both sides by the pseudo-inverse of lhs to find p. 225 | result = torch.linalg.lstsq(lhs, rhs).solution 226 | 227 | # Handle the case of parallel lines by setting depth to infinity. 228 | result_all = torch.ones(shape, dtype=dtype, device=device) * inf 229 | result_all[~parallel] = result 230 | return result_all 231 | 232 | 233 | def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]: 234 | intrinsics_inv = intrinsics.inverse() 235 | 236 | def process_vector(vector): 237 | vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device) 238 | vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") 239 | return vector / vector.norm(dim=-1, keepdim=True) 240 | 241 | left = process_vector([0, 0.5, 1]) 242 | right = process_vector([1, 0.5, 1]) 243 | top = process_vector([0.5, 0, 1]) 244 | bottom = process_vector([0.5, 1, 1]) 245 | fov_x = (left * right).sum(dim=-1).acos() 246 | fov_y = (top * bottom).sum(dim=-1).acos() 247 | return torch.stack((fov_x, fov_y), dim=-1) 248 | -------------------------------------------------------------------------------- /src/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import Loss 2 | from .loss_depth import LossDepth, LossDepthCfgWrapper 3 | from .loss_lpips import LossLpips, LossLpipsCfgWrapper 4 | from .loss_mse import LossMse, LossMseCfgWrapper 5 | 6 | LOSSES = { 7 | LossDepthCfgWrapper: LossDepth, 8 | LossLpipsCfgWrapper: LossLpips, 9 | LossMseCfgWrapper: LossMse, 10 | } 11 | 12 | LossCfgWrapper = LossDepthCfgWrapper | LossLpipsCfgWrapper | LossMseCfgWrapper 13 | 14 | 15 | def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: 16 | return [LOSSES[type(cfg)](cfg) for cfg in cfgs] 17 | -------------------------------------------------------------------------------- /src/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import fields 3 | from typing import Generic, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | 12 | T_cfg = TypeVar("T_cfg") 13 | T_wrapper = TypeVar("T_wrapper") 14 | 15 | 16 | class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): 17 | cfg: T_cfg 18 | name: str 19 | 20 | def __init__(self, cfg: T_wrapper) -> None: 21 | super().__init__() 22 | 23 | # Extract the configuration from the wrapper. 
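        # Each wrapper dataclass has exactly one field, named after the loss.
        # For example, LossMseCfgWrapper(mse=LossMseCfg(weight=1.0)) yields
        # self.cfg = LossMseCfg(weight=1.0) and self.name = "mse".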
24 | (field,) = fields(type(cfg)) 25 | self.cfg = getattr(cfg, field.name) 26 | self.name = field.name 27 | 28 | @abstractmethod 29 | def forward( 30 | self, 31 | prediction: DecoderOutput, 32 | batch: BatchedExample, 33 | gaussians: Gaussians, 34 | global_step: int, 35 | ) -> Float[Tensor, ""]: 36 | pass 37 | -------------------------------------------------------------------------------- /src/loss/loss_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import reduce 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ..dataset.types import BatchedExample 9 | from ..model.decoder.decoder import DecoderOutput 10 | from ..model.types import Gaussians 11 | from .loss import Loss 12 | 13 | 14 | @dataclass 15 | class LossDepthCfg: 16 | weight: float 17 | sigma_image: float | None 18 | use_second_derivative: bool 19 | 20 | 21 | @dataclass 22 | class LossDepthCfgWrapper: 23 | depth: LossDepthCfg 24 | 25 | 26 | class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]): 27 | def forward( 28 | self, 29 | prediction: DecoderOutput, 30 | batch: BatchedExample, 31 | gaussians: Gaussians, 32 | global_step: int, 33 | ) -> Float[Tensor, ""]: 34 | # Scale the depth between the near and far planes. 35 | near = batch["target"]["near"][..., None, None].log() 36 | far = batch["target"]["far"][..., None, None].log() 37 | depth = prediction.depth.minimum(far).maximum(near) 38 | depth = (depth - near) / (far - near) 39 | 40 | # Compute the difference between neighboring pixels in each direction. 41 | depth_dx = depth.diff(dim=-1) 42 | depth_dy = depth.diff(dim=-2) 43 | 44 | # If desired, compute a 2nd derivative. 45 | if self.cfg.use_second_derivative: 46 | depth_dx = depth_dx.diff(dim=-1) 47 | depth_dy = depth_dy.diff(dim=-2) 48 | 49 | # If desired, add bilateral filtering. 
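        # Depth gradients are down-weighted where the ground-truth image itself
        # changes quickly: each is multiplied by exp(-sigma_image * color_d),
        # where color_d is the channel-wise maximum of the neighboring-pixel
        # color difference in the same direction.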
50 | if self.cfg.sigma_image is not None: 51 | color_gt = batch["target"]["image"] 52 | color_dx = reduce(color_gt.diff(dim=-1), "b v c h w -> b v h w", "max") 53 | color_dy = reduce(color_gt.diff(dim=-2), "b v c h w -> b v h w", "max") 54 | if self.cfg.use_second_derivative: 55 | color_dx = color_dx[..., :, 1:].maximum(color_dx[..., :, :-1]) 56 | color_dy = color_dy[..., 1:, :].maximum(color_dy[..., :-1, :]) 57 | depth_dx = depth_dx * torch.exp(-color_dx * self.cfg.sigma_image) 58 | depth_dy = depth_dy * torch.exp(-color_dy * self.cfg.sigma_image) 59 | 60 | return self.cfg.weight * (depth_dx.abs().mean() + depth_dy.abs().mean()) 61 | -------------------------------------------------------------------------------- /src/loss/loss_lpips.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from jaxtyping import Float 6 | from lpips import LPIPS 7 | from torch import Tensor 8 | 9 | from ..dataset.types import BatchedExample 10 | from ..misc.nn_module_tools import convert_to_buffer 11 | from ..model.decoder.decoder import DecoderOutput 12 | from ..model.types import Gaussians 13 | from .loss import Loss 14 | 15 | 16 | @dataclass 17 | class LossLpipsCfg: 18 | weight: float 19 | # apply_after_step: int 20 | 21 | 22 | @dataclass 23 | class LossLpipsCfgWrapper: 24 | lpips: LossLpipsCfg 25 | 26 | 27 | class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): 28 | lpips: LPIPS 29 | 30 | def __init__(self, cfg: LossLpipsCfgWrapper) -> None: 31 | super().__init__(cfg) 32 | 33 | self.lpips = LPIPS(net="vgg") 34 | convert_to_buffer(self.lpips, persistent=False) 35 | 36 | def forward( 37 | self, 38 | prediction: DecoderOutput, 39 | batch: BatchedExample, 40 | gaussians: Gaussians, 41 | global_step: int, 42 | ) -> Float[Tensor, ""]: 43 | image = batch["target"]["image"] 44 | 45 | # Before the specified step, don't apply the loss. 
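        # (Note that the gating is currently disabled, since apply_after_step
        # is commented out in LossLpipsCfg above, so the LPIPS term applies
        # from step 0.)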
46 | # if global_step < self.cfg.apply_after_step: 47 | # return torch.tensor(0, dtype=torch.float32, device=image.device) 48 | 49 | loss = self.lpips.forward( 50 | rearrange(prediction.color, "b v c h w -> (b v) c h w"), 51 | rearrange(image, "b v c h w -> (b v) c h w"), 52 | normalize=True, 53 | ) 54 | return self.cfg.weight * loss.mean() 55 | -------------------------------------------------------------------------------- /src/loss/loss_mse.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..dataset.types import BatchedExample 7 | from ..model.decoder.decoder import DecoderOutput 8 | from ..model.types import Gaussians 9 | from .loss import Loss 10 | 11 | 12 | @dataclass 13 | class LossMseCfg: 14 | weight: float 15 | 16 | 17 | @dataclass 18 | class LossMseCfgWrapper: 19 | mse: LossMseCfg 20 | 21 | 22 | class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): 23 | def forward( 24 | self, 25 | prediction: DecoderOutput, 26 | batch: BatchedExample, 27 | gaussians: Gaussians, 28 | global_step: int, 29 | ) -> Float[Tensor, ""]: 30 | delta = prediction.color - batch["target"]["image"] 31 | return self.cfg.weight * (delta**2).mean() 32 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | import torch 6 | import wandb 7 | from colorama import Fore 8 | from jaxtyping import install_import_hook 9 | from lightning.pytorch import Trainer 10 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 11 | from lightning.pytorch.loggers.wandb import WandbLogger 12 | from lightning.pytorch.plugins.environments import SLURMEnvironment 13 | from omegaconf import DictConfig, OmegaConf 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | # Configure beartype and jaxtyping. 19 | with install_import_hook( 20 | ("src",), 21 | ("beartype", "beartype"), 22 | ): 23 | from src.config import load_typed_root_config 24 | from src.dataset.data_module import DataModule 25 | from src.global_cfg import set_cfg 26 | from src.loss import get_losses 27 | from src.misc.LocalLogger import LocalLogger 28 | from src.misc.step_tracker import StepTracker 29 | from src.misc.wandb_tools import update_checkpoint_path 30 | from src.model.decoder import get_decoder 31 | from src.model.encoder import get_encoder 32 | from src.model.model_wrapper import ModelWrapper 33 | 34 | 35 | def cyan(text: str) -> str: 36 | return f"{Fore.CYAN}{text}{Fore.RESET}" 37 | 38 | 39 | @hydra.main( 40 | version_base=None, 41 | config_path="../config", 42 | config_name="main", 43 | ) 44 | def train(cfg_dict: DictConfig): 45 | cfg = load_typed_root_config(cfg_dict) 46 | set_cfg(cfg_dict) 47 | 48 | # Set up the output directory. 49 | output_dir = Path( 50 | hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"] 51 | ) 52 | print(cyan(f"Saving outputs to {output_dir}.")) 53 | latest_run = output_dir.parents[1] / "latest-run" 54 | os.system(f"rm {latest_run}") 55 | os.system(f"ln -s {output_dir} {latest_run}") 56 | 57 | # Set up logging with wandb. 
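    # When wandb is enabled, the run is logged to the configured project and
    # the src directory is snapshotted on rank 0; otherwise a LocalLogger
    # writes images under outputs/local instead.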
58 | callbacks = [] 59 | if cfg_dict.wandb.mode != "disabled": 60 | # breakpoint() 61 | logger = WandbLogger( 62 | project=cfg_dict.wandb.project, 63 | mode=cfg_dict.wandb.mode, 64 | name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})", 65 | tags=cfg_dict.wandb.get("tags", None), 66 | log_model="all", 67 | save_dir=output_dir, 68 | config=OmegaConf.to_container(cfg_dict), 69 | ) 70 | callbacks.append(LearningRateMonitor("step", True)) 71 | 72 | # On rank != 0, wandb.run is None. 73 | if wandb.run is not None: 74 | wandb.run.log_code("src") 75 | else: 76 | logger = LocalLogger() 77 | 78 | # Set up checkpointing. 79 | callbacks.append( 80 | ModelCheckpoint( 81 | output_dir / "checkpoints", 82 | every_n_train_steps=cfg.checkpointing.every_n_train_steps, 83 | save_top_k=cfg.checkpointing.save_top_k, 84 | ) 85 | ) 86 | 87 | # Prepare the checkpoint for loading. 88 | checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb) 89 | 90 | # This allows the current step to be shared with the data loader processes. 91 | step_tracker = StepTracker() 92 | 93 | trainer = Trainer( 94 | max_epochs=-1, 95 | accelerator="gpu", 96 | # precision="bf16-mixed", 97 | num_nodes=1, 98 | logger=logger, 99 | devices="auto", 100 | strategy=( 101 | "ddp_find_unused_parameters_true" 102 | if torch.cuda.device_count() > 1 103 | else "auto" 104 | ), 105 | callbacks=callbacks, 106 | val_check_interval=cfg.trainer.val_check_interval, 107 | enable_progress_bar=False, 108 | gradient_clip_val=cfg.trainer.gradient_clip_val, 109 | max_steps=cfg.trainer.max_steps, 110 | # plugins=[SLURMEnvironment(auto_requeue=False)], 111 | plugins=[], 112 | ) 113 | torch.manual_seed(cfg_dict.seed + trainer.global_rank) 114 | 115 | encoder = get_encoder(cfg.model.encoder) 116 | 117 | model_wrapper = ModelWrapper( 118 | cfg.optimizer, 119 | cfg.test, 120 | cfg.train, 121 | encoder, 122 | get_decoder(cfg.model.decoder, cfg.dataset), 123 | get_losses(cfg.loss), 124 | step_tracker, 125 | ) 126 | data_module = DataModule( 127 | cfg.dataset, 128 | cfg.data_loader, 129 | step_tracker, 130 | global_rank=trainer.global_rank, 131 | ) 132 | 133 | trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path) 134 | 135 | 136 | if __name__ == "__main__": 137 | train() 138 | -------------------------------------------------------------------------------- /src/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | from lightning.pytorch.loggers.logger import Logger 6 | from lightning.pytorch.utilities import rank_zero_only 7 | from PIL import Image 8 | 9 | LOG_PATH = Path("outputs/local") 10 | 11 | 12 | class LocalLogger(Logger): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.experiment = None 16 | os.system(f"rm -r {LOG_PATH}") 17 | 18 | @property 19 | def name(self): 20 | return "LocalLogger" 21 | 22 | @property 23 | def version(self): 24 | return 0 25 | 26 | @rank_zero_only 27 | def log_hyperparams(self, params): 28 | pass 29 | 30 | @rank_zero_only 31 | def log_metrics(self, metrics, step): 32 | pass 33 | 34 | @rank_zero_only 35 | def log_image( 36 | self, 37 | key: str, 38 | images: list[Any], 39 | step: Optional[int] = None, 40 | **kwargs, 41 | ): 42 | # The function signature is the same as the wandb logger's, but the step is 43 | # actually required. 
44 | assert step is not None 45 | for index, image in enumerate(images): 46 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 47 | path.parent.mkdir(exist_ok=True, parents=True) 48 | Image.fromarray(image).save(path) 49 | -------------------------------------------------------------------------------- /src/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") 38 | -------------------------------------------------------------------------------- /src/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return { 14 | key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0] 15 | } 16 | -------------------------------------------------------------------------------- /src/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... 
()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /src/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /src/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as tf 8 | from einops import rearrange, repeat 9 | from jaxtyping import Float, UInt8 10 | from matplotlib.figure import Figure 11 | from PIL import Image 12 | from torch import Tensor 13 | 14 | FloatImage = Union[ 15 | Float[Tensor, "height width"], 16 | Float[Tensor, "channel height width"], 17 | Float[Tensor, "batch channel height width"], 18 | ] 19 | 20 | 21 | def fig_to_image( 22 | fig: Figure, 23 | dpi: int = 100, 24 | device: torch.device = torch.device("cpu"), 25 | ) -> Float[Tensor, "3 height width"]: 26 | buffer = io.BytesIO() 27 | fig.savefig(buffer, format="raw", dpi=dpi) 28 | buffer.seek(0) 29 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 30 | h = int(fig.bbox.bounds[3]) 31 | w = int(fig.bbox.bounds[2]) 32 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 33 | buffer.close() 34 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 35 | 36 | 37 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 38 | # Handle batched images. 39 | if image.ndim == 4: 40 | image = rearrange(image, "b c h w -> c h (b w)") 41 | 42 | # Handle single-channel images. 43 | if image.ndim == 2: 44 | image = rearrange(image, "h w -> () h w") 45 | 46 | # Ensure that there are 3 or 4 channels. 
47 | channel, _, _ = image.shape 48 | if channel == 1: 49 | image = repeat(image, "() h w -> c h w", c=3) 50 | assert image.shape[0] in (3, 4) 51 | 52 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 53 | return rearrange(image, "c h w -> h w c").cpu().numpy() 54 | 55 | 56 | def save_image( 57 | image: FloatImage, 58 | path: Union[Path, str], 59 | ) -> None: 60 | """Save an image. Assumed to be in range 0-1.""" 61 | 62 | # Create the parent directory if it doesn't already exist. 63 | path = Path(path) 64 | path.parent.mkdir(exist_ok=True, parents=True) 65 | 66 | # Save the image. 67 | Image.fromarray(prep_image(image)).save(path) 68 | 69 | 70 | def load_image( 71 | path: Union[Path, str], 72 | ) -> Float[Tensor, "3 height width"]: 73 | return tf.ToTensor()(Image.open(path))[:3] 74 | -------------------------------------------------------------------------------- /src/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /src/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | 3 | import torch 4 | from e3nn.o3 import matrix_to_angles, wigner_D 5 | from einops import einsum 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | 10 | def rotate_sh( 11 | sh_coefficients: Float[Tensor, "*#batch n"], 12 | rotations: Float[Tensor, "*#batch 3 3"], 13 | ) -> Float[Tensor, "*batch n"]: 14 | device = sh_coefficients.device 15 | dtype = sh_coefficients.dtype 16 | 17 | *_, n = sh_coefficients.shape 18 | alpha, beta, gamma = matrix_to_angles(rotations) 19 | result = [] 20 | for degree in range(isqrt(n)): 21 | with torch.device(device): 22 | sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) 23 | sh_rotated = einsum( 24 | sh_rotations, 25 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 26 | "... i j, ... j -> ... i", 27 | ) 28 | result.append(sh_rotated) 29 | 30 | return torch.cat(result, dim=-1) 31 | 32 | 33 | if __name__ == "__main__": 34 | from pathlib import Path 35 | 36 | import matplotlib.pyplot as plt 37 | from e3nn.o3 import spherical_harmonics 38 | from matplotlib import cm 39 | from scipy.spatial.transform.rotation import Rotation as R 40 | 41 | device = torch.device("cuda") 42 | 43 | # Generate random spherical harmonics coefficients. 
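    # A degree-4 expansion has (4 + 1) ** 2 = 25 coefficients in total
    # (2 * l + 1 per band l = 0..4).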
44 | degree = 4 45 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 46 | 47 | def plot_sh(sh_coefficients, path: Path) -> None: 48 | phi = torch.linspace(0, torch.pi, 100, device=device) 49 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 50 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 51 | x = torch.sin(phi) * torch.cos(theta) 52 | y = torch.sin(phi) * torch.sin(theta) 53 | z = torch.cos(phi) 54 | xyz = torch.stack([x, y, z], dim=-1) 55 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 56 | result = einsum(sh, sh_coefficients, "... n, n -> ...") 57 | result = (result - result.min()) / (result.max() - result.min()) 58 | 59 | # Set the aspect ratio to 1 so our sphere looks spherical 60 | fig = plt.figure(figsize=plt.figaspect(1.0)) 61 | ax = fig.add_subplot(111, projection="3d") 62 | ax.plot_surface( 63 | x.cpu().numpy(), 64 | y.cpu().numpy(), 65 | z.cpu().numpy(), 66 | rstride=1, 67 | cstride=1, 68 | facecolors=cm.seismic(result.cpu().numpy()), 69 | ) 70 | # Turn off the axis planes 71 | ax.set_axis_off() 72 | path.parent.mkdir(exist_ok=True, parents=True) 73 | plt.savefig(path) 74 | 75 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 76 | rotation = torch.tensor( 77 | R.from_euler("x", angle.item()).as_matrix(), device=device 78 | ) 79 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 80 | 81 | print("Done!") 82 | -------------------------------------------------------------------------------- /src/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /src/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 
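    # The artifact is unpacked into download_dir / run_id, so a call from
    # update_checkpoint_path below resolves to, for example,
    # checkpoints/<project>/<run id>/model.ckpt.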
36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /src/model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from ...dataset import DatasetCfg 2 | from .decoder import Decoder 3 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 4 | 5 | DECODERS = { 6 | "splatting_cuda": DecoderSplattingCUDA, 7 | } 8 | 9 | DecoderCfg = DecoderSplattingCUDACfg 10 | 11 | 12 | def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder: 13 | return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg) 14 | -------------------------------------------------------------------------------- /src/model/decoder/cuda_splatting.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | from typing import Literal 3 | 4 | import torch 5 | from diff_gaussian_rasterization import ( 6 | GaussianRasterizationSettings, 7 | GaussianRasterizer, 8 | ) 9 | from einops import einsum, rearrange, repeat 10 | from typing import Tuple 11 | from jaxtyping import Float 12 | from torch import Tensor 13 | 14 | from ...geometry.projection import get_fov, homogenize_points 15 | 16 | def get_projection_matrix( 17 | near: Float[Tensor, " batch"], 18 | far: Float[Tensor, " batch"], 19 | fov_x: Float[Tensor, " batch"], 20 | fov_y: Float[Tensor, " batch"], 21 | ) -> Float[Tensor, "batch 4 4"]: 22 | """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z 23 | axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after 24 | transformation and that Z is flipped. 
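    Concretely, with r = near * tan(fov_x / 2), t = near * tan(fov_y / 2), l = -r,
    and b = -t as computed below, the assembled matrix is

        [[ 2n/(r-l),     0,      (r+l)/(r-l),      0      ],
         [     0,     2n/(t-b),  (t+b)/(t-b),      0      ],
         [     0,         0,        f/(f-n),  -f*n/(f-n)  ],
         [     0,         0,           1,          0      ]]

    (the (0, 2) and (1, 2) entries vanish for the symmetric frustum built here, and
    the last row copies view-space depth into w).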
25 | """ 26 | tan_fov_x = (0.5 * fov_x).tan() 27 | tan_fov_y = (0.5 * fov_y).tan() 28 | 29 | top = tan_fov_y * near 30 | bottom = -top 31 | right = tan_fov_x * near 32 | left = -right 33 | 34 | (b,) = near.shape 35 | result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device) 36 | result[:, 0, 0] = 2 * near / (right - left) 37 | result[:, 1, 1] = 2 * near / (top - bottom) 38 | result[:, 0, 2] = (right + left) / (right - left) 39 | result[:, 1, 2] = (top + bottom) / (top - bottom) 40 | result[:, 3, 2] = 1 41 | result[:, 2, 2] = far / (far - near) 42 | result[:, 2, 3] = -(far * near) / (far - near) 43 | return result 44 | 45 | 46 | def render_cuda( 47 | extrinsics: Float[Tensor, "batch 4 4"], 48 | intrinsics: Float[Tensor, "batch 3 3"], 49 | near: Float[Tensor, " batch"], 50 | far: Float[Tensor, " batch"], 51 | image_shape: tuple[int, int], 52 | background_color: Float[Tensor, "batch 3"], 53 | gaussian_means: Float[Tensor, "batch gaussian 3"], 54 | gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], 55 | gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], 56 | gaussian_opacities: Float[Tensor, "batch gaussian"], 57 | scale_invariant: bool = True, 58 | use_sh: bool = True, 59 | ) -> Tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch 1 height width"]]: 60 | assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 61 | 62 | # Make sure everything is in a range where numerical issues don't appear. 63 | if scale_invariant: 64 | scale = 1 / near 65 | extrinsics = extrinsics.clone() 66 | extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None] 67 | gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2) 68 | gaussian_means = gaussian_means * scale[:, None, None] 69 | near = near * scale 70 | far = far * scale 71 | 72 | _, _, _, n = gaussian_sh_coefficients.shape 73 | degree = isqrt(n) - 1 74 | shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() 75 | 76 | b, _, _ = extrinsics.shape 77 | h, w = image_shape 78 | 79 | fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) 80 | tan_fov_x = (0.5 * fov_x).tan() 81 | tan_fov_y = (0.5 * fov_y).tan() 82 | 83 | projection_matrix = get_projection_matrix(near, far, fov_x, fov_y) 84 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i") 85 | view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") # W2C 86 | full_projection = view_matrix @ projection_matrix 87 | 88 | all_images = [] 89 | all_radii = [] 90 | all_depth = [] 91 | for i in range(b): 92 | # Set up a tensor for the gradients of the screen-space means. 93 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) 94 | try: 95 | mean_gradients.retain_grad() 96 | except Exception: 97 | pass 98 | 99 | settings = GaussianRasterizationSettings( 100 | image_height=h, 101 | image_width=w, 102 | tanfovx=tan_fov_x[i].item(), 103 | tanfovy=tan_fov_y[i].item(), 104 | bg=background_color[i], 105 | scale_modifier=1.0, 106 | viewmatrix=view_matrix[i], 107 | projmatrix=full_projection[i], 108 | sh_degree=degree, 109 | campos=extrinsics[i, :3, 3], 110 | prefiltered=False, # This matches the original usage. 
111 | debug=False, 112 | ) 113 | rasterizer = GaussianRasterizer(settings) 114 | 115 | row, col = torch.triu_indices(3, 3) 116 | 117 | image, radii, depth, alpha = rasterizer( 118 | means3D=gaussian_means[i], 119 | means2D=mean_gradients, 120 | shs=shs[i] if use_sh else None, 121 | colors_precomp=None if use_sh else shs[i, :, 0, :], 122 | opacities=gaussian_opacities[i, ..., None], 123 | cov3D_precomp=gaussian_covariances[i, :, row, col], 124 | ) 125 | all_images.append(image) 126 | all_radii.append(radii) 127 | all_depth.append(depth) 128 | return torch.stack(all_images), torch.stack(all_depth) 129 | 130 | 131 | def render_cuda_orthographic( 132 | extrinsics: Float[Tensor, "batch 4 4"], 133 | width: Float[Tensor, " batch"], 134 | height: Float[Tensor, " batch"], 135 | near: Float[Tensor, " batch"], 136 | far: Float[Tensor, " batch"], 137 | image_shape: tuple[int, int], 138 | background_color: Float[Tensor, "batch 3"], 139 | gaussian_means: Float[Tensor, "batch gaussian 3"], 140 | gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], 141 | gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], 142 | gaussian_opacities: Float[Tensor, "batch gaussian"], 143 | fov_degrees: float = 0.1, 144 | use_sh: bool = True, 145 | dump: dict | None = None, 146 | ) -> Float[Tensor, "batch 3 height width"]: 147 | b, _, _ = extrinsics.shape 148 | h, w = image_shape 149 | assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 150 | # with torch.autocast(device_type="cuda", enabled=False): 151 | 152 | _, _, _, n = gaussian_sh_coefficients.shape 153 | degree = isqrt(n) - 1 154 | shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() 155 | 156 | # Create fake "orthographic" projection by moving the camera back and picking a 157 | # small field of view. 158 | fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad() 159 | tan_fov_x = (0.5 * fov_x).tan() 160 | distance_to_near = (0.5 * width) / tan_fov_x 161 | tan_fov_y = 0.5 * height / distance_to_near 162 | fov_y = (2 * tan_fov_y).atan() 163 | near = near + distance_to_near 164 | far = far + distance_to_near 165 | move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 166 | move_back[2, 3] = -distance_to_near 167 | extrinsics = extrinsics @ move_back 168 | 169 | # Escape hatch for visualization/figures. 170 | if dump is not None: 171 | dump["extrinsics"] = extrinsics 172 | dump["fov_x"] = fov_x 173 | dump["fov_y"] = fov_y 174 | dump["near"] = near 175 | dump["far"] = far 176 | 177 | projection_matrix = get_projection_matrix( 178 | near, far, repeat(fov_x, "-> b", b=b), fov_y 179 | ) 180 | projection_matrix = rearrange(projection_matrix.float(), "b i j -> b j i") 181 | view_matrix = rearrange(extrinsics.float().inverse(), "b i j -> b j i") 182 | full_projection = view_matrix @ projection_matrix 183 | 184 | all_images = [] 185 | all_radii = [] 186 | for i in range(b): 187 | # Set up a tensor for the gradients of the screen-space means. 
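        # The rasterizer call requires a `means2D` argument; in the original 3D
        # Gaussian Splatting training loop its gradient drives densification, but in
        # this feed-forward setting it is only retained, never read back in this file.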
188 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) 189 | try: 190 | mean_gradients.retain_grad() 191 | except Exception: 192 | pass 193 | 194 | settings = GaussianRasterizationSettings( 195 | image_height=h, 196 | image_width=w, 197 | tanfovx=tan_fov_x, 198 | tanfovy=tan_fov_y, 199 | bg=background_color[i].float(), 200 | scale_modifier=1.0, 201 | viewmatrix=view_matrix[i].float(), 202 | projmatrix=full_projection[i].float(), 203 | sh_degree=degree, 204 | campos=extrinsics[i, :3, 3].float(), 205 | prefiltered=False, # This matches the original usage. 206 | debug=False, 207 | ) 208 | rasterizer = GaussianRasterizer(settings) 209 | 210 | row, col = torch.triu_indices(3, 3) 211 | with torch.cuda.amp.autocast(enabled=False): 212 | image, radii, depth, alpha = rasterizer( 213 | means3D=gaussian_means[i], 214 | means2D=mean_gradients, 215 | shs=shs[i] if use_sh else None, 216 | colors_precomp=None if use_sh else shs[i, :, 0, :], 217 | opacities=gaussian_opacities[i, ..., None], 218 | cov3D_precomp=gaussian_covariances[i, :, row, col], 219 | ) 220 | all_images.append(image) 221 | all_radii.append(radii) 222 | # del mean_gradients 223 | # torch.cuda.empty_cache() 224 | return torch.stack(all_images) 225 | 226 | -------------------------------------------------------------------------------- /src/model/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ...dataset import DatasetCfg 9 | from ..types import Gaussians 10 | 11 | DepthRenderingMode = Literal[ 12 | "depth", 13 | "log", 14 | "disparity", 15 | "relative_disparity", 16 | ] 17 | 18 | 19 | @dataclass 20 | class DecoderOutput: 21 | color: Float[Tensor, "batch view 3 height width"] 22 | depth: Float[Tensor, "batch view height width"] | None 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | class Decoder(nn.Module, ABC, Generic[T]): 29 | cfg: T 30 | dataset_cfg: DatasetCfg 31 | 32 | def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None: 33 | super().__init__() 34 | self.cfg = cfg 35 | self.dataset_cfg = dataset_cfg 36 | 37 | @abstractmethod 38 | def forward( 39 | self, 40 | gaussians: Gaussians, 41 | extrinsics: Float[Tensor, "batch view 4 4"], 42 | intrinsics: Float[Tensor, "batch view 3 3"], 43 | near: Float[Tensor, "batch view"], 44 | far: Float[Tensor, "batch view"], 45 | image_shape: tuple[int, int], 46 | depth_mode: DepthRenderingMode | None = None, 47 | ) -> DecoderOutput: 48 | pass 49 | -------------------------------------------------------------------------------- /src/model/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | from .cuda_splatting import render_cuda 12 | from .decoder import Decoder, DecoderOutput 13 | 14 | 15 | @dataclass 16 | class DecoderSplattingCUDACfg: 17 | name: Literal["splatting_cuda"] 18 | 19 | 20 | class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 21 | background_color: Float[Tensor, "3"] 22 | 23 | def __init__( 24 | self, 25 | cfg: DecoderSplattingCUDACfg, 26 | dataset_cfg: 
DatasetCfg, 27 | ) -> None: 28 | super().__init__(cfg, dataset_cfg) 29 | self.register_buffer( 30 | "background_color", 31 | torch.tensor(dataset_cfg.background_color, dtype=torch.float32), 32 | persistent=False, 33 | ) 34 | 35 | def forward( 36 | self, 37 | gaussians: Gaussians, 38 | extrinsics: Float[Tensor, "batch view 4 4"], 39 | intrinsics: Float[Tensor, "batch view 3 3"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | image_shape: tuple[int, int], 43 | ) -> DecoderOutput: 44 | b, v, _, _ = extrinsics.shape 45 | color, depth = render_cuda( 46 | rearrange(extrinsics, "b v i j -> (b v) i j"), 47 | rearrange(intrinsics, "b v i j -> (b v) i j"), 48 | rearrange(near, "b v -> (b v)"), 49 | rearrange(far, "b v -> (b v)"), 50 | image_shape, 51 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 52 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 53 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 54 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 55 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 56 | ) 57 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 58 | depth = rearrange(depth, "(b v) c h w -> b v c h w", b=b, v=v).squeeze(2) 59 | 60 | return DecoderOutput( 61 | color, 62 | depth, 63 | ) 64 | -------------------------------------------------------------------------------- /src/model/decoder/spaltting_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from simple_knn._C import distCUDA2 7 | from .utils import get_covariance 8 | 9 | from diff_gaussian_rasterization import ( 10 | GaussianRasterizationSettings, 11 | GaussianRasterizer, 12 | ) 13 | 14 | from core.options import Options 15 | 16 | import kiui 17 | 18 | class GaussianRenderer: 19 | def __init__(self, opt: Options): 20 | 21 | self.opt = opt 22 | self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") 23 | 24 | # intrinsics 25 | 26 | self.tan_half_fov = np.tan(0.5 * self.opt.FoVy) 27 | 28 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=0.5): 29 | device = gaussians['position'].device 30 | B, V = cam_view.shape[:2] 31 | # loop of loop... 
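        # i.e. two nested Python loops: the outer one walks the batch dimension B and,
        # for each batch element, the inner one rasterizes every one of the V views.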
32 | images = [] 33 | alphas = [] 34 | 35 | position_batch = gaussians['position'] 36 | opacity_batch = gaussians['opacity'] 37 | scale_batch = gaussians['scale'] 38 | cov3D_batch = gaussians['cov3d'] 39 | rgb_batch = gaussians['rgb'] 40 | 41 | for b in range(B): 42 | # pos, opacity, scale, rotation, shs 43 | means3D = position_batch[b].contiguous().float() 44 | opacity = opacity_batch[b].contiguous().float() 45 | scales = scale_batch[b].contiguous().float() 46 | cov3D = cov3D_batch[b].contiguous().float() 47 | rgbs = rgb_batch[b].contiguous().float() # [N, 3] 48 | 49 | dist2 = torch.clamp_min(distCUDA2(means3D), 0.0000001) 50 | scales_ = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() 51 | scale = (scales+1)*scales_ 52 | cov3D = get_covariance(scale, cov3D).reshape(-1, 6) 53 | 54 | for v in range(V): 55 | 56 | # render novel views 57 | view_matrix = cam_view[b, v].float() 58 | view_proj_matrix = cam_view_proj[b, v].float() 59 | campos = cam_pos[b, v].float() 60 | 61 | raster_settings = GaussianRasterizationSettings( 62 | image_height=self.opt.output_size_h, 63 | image_width=self.opt.output_size_w, 64 | tanfovx=self.tan_half_fov, 65 | tanfovy=self.tan_half_fov, 66 | bg=self.bg_color if bg_color is None else bg_color, 67 | scale_modifier=scale_modifier, 68 | viewmatrix=view_matrix, 69 | projmatrix=view_proj_matrix, 70 | sh_degree=0, 71 | campos=campos, 72 | prefiltered=False, 73 | debug=False, 74 | ) 75 | 76 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 77 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 78 | with torch.cuda.amp.autocast(enabled=True): 79 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 80 | means3D=means3D, 81 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), 82 | shs=None, 83 | colors_precomp=rgbs, 84 | opacities=opacity, 85 | # scales=None, 86 | # rotations=None, 87 | cov3D_precomp=cov3D, 88 | ) 89 | rendered_image = rendered_image.clamp(0, 1) 90 | #torch.cuda.synchronize() 91 | # images.append(rendered_image.cpu()) 92 | # alphas.append(rendered_alpha.cpu()) 93 | images.append(rendered_image) 94 | alphas.append(rendered_alpha) 95 | 96 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size_h, self.opt.output_size_w) 97 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size_h, self.opt.output_size_w) 98 | 99 | return { 100 | "image": images.to(device), # [B, V, 3, H, W] 101 | "alpha": alphas.to(device), # [B, V, 1, H, W] 102 | } 103 | 104 | 105 | def save_ply(self, gaussians, path, compatible=True): 106 | # gaussians: [B, N, 14] 107 | # compatible: save pre-activated gaussians as in the original paper 108 | 109 | assert gaussians.shape[0] == 1, 'only support batch size 1' 110 | 111 | from plyfile import PlyData, PlyElement 112 | 113 | means3D = gaussians[0, :, 0:3].contiguous().float() 114 | opacity = gaussians[0, :, 3:4].contiguous().float() 115 | scales = gaussians[0, :, 4:7].contiguous().float() 116 | rotations = gaussians[0, :, 7:11].contiguous().float() 117 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] 118 | 119 | # prune by opacity 120 | mask = opacity.squeeze(-1) >= 0.005 121 | means3D = means3D[mask] 122 | opacity = opacity[mask] 123 | scales = scales[mask] 124 | rotations = rotations[mask] 125 | shs = shs[mask] 126 | 127 | # invert activation to make it compatible with the original ply format 128 | if compatible: 129 | opacity = kiui.op.inverse_sigmoid(opacity) 130 | scales = torch.log(scales + 1e-8) 131 | shs 
= (shs - 0.5) / 0.28209479177387814 132 | 133 | xyzs = means3D.detach().cpu().numpy() 134 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 135 | opacities = opacity.detach().cpu().numpy() 136 | scales = scales.detach().cpu().numpy() 137 | rotations = rotations.detach().cpu().numpy() 138 | 139 | l = ['x', 'y', 'z'] 140 | # All channels except the 3 DC 141 | for i in range(f_dc.shape[1]): 142 | l.append('f_dc_{}'.format(i)) 143 | l.append('opacity') 144 | for i in range(scales.shape[1]): 145 | l.append('scale_{}'.format(i)) 146 | for i in range(rotations.shape[1]): 147 | l.append('rot_{}'.format(i)) 148 | 149 | dtype_full = [(attribute, 'f4') for attribute in l] 150 | 151 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 152 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 153 | elements[:] = list(map(tuple, attributes)) 154 | el = PlyElement.describe(elements, 'vertex') 155 | 156 | PlyData([el]).write(path) 157 | 158 | def load_ply(self, path, compatible=True): 159 | 160 | from plyfile import PlyData, PlyElement 161 | 162 | plydata = PlyData.read(path) 163 | 164 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 165 | np.asarray(plydata.elements[0]["y"]), 166 | np.asarray(plydata.elements[0]["z"])), axis=1) 167 | print("Number of points at loading : ", xyz.shape[0]) 168 | 169 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 170 | 171 | shs = np.zeros((xyz.shape[0], 3)) 172 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 173 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) 174 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) 175 | 176 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 177 | scales = np.zeros((xyz.shape[0], len(scale_names))) 178 | for idx, attr_name in enumerate(scale_names): 179 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 180 | 181 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] 182 | rots = np.zeros((xyz.shape[0], len(rot_names))) 183 | for idx, attr_name in enumerate(rot_names): 184 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 185 | 186 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) 187 | gaussians = torch.from_numpy(gaussians).float() # cpu 188 | 189 | if compatible: 190 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) 191 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) 192 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 193 | 194 | return gaussians 195 | 196 | def load_gaussians_from_ply(self, path): 197 | 198 | from plyfile import PlyData, PlyElement 199 | 200 | plydata = PlyData.read(path) 201 | 202 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 203 | np.asarray(plydata.elements[0]["y"]), 204 | np.asarray(plydata.elements[0]["z"])), axis = 1) 205 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 206 | 207 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 208 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 209 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 210 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) 211 | 212 | extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] 213 | extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) 214 | # assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 
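        # The reshape a few lines below assumes a degree-3 expansion: each color
        # channel carries (3 + 1) ** 2 - 1 = 15 rest coefficients, i.e. 45 `f_rest_*`
        # properties in total, matching the standard 3DGS PLY layout.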
215 | features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) 216 | for idx, attr_name in enumerate(extra_f_names): 217 | features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) 218 | # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) 219 | features_extra = features_extra.reshape((features_extra.shape[0], 3, (3 + 1) ** 2 - 1)) 220 | 221 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 222 | scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) 223 | scales = np.zeros((xyz.shape[0], len(scale_names))) 224 | for idx, attr_name in enumerate(scale_names): 225 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 226 | 227 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 228 | rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) 229 | rots = np.zeros((xyz.shape[0], len(rot_names))) 230 | for idx, attr_name in enumerate(rot_names): 231 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 232 | 233 | positions = torch.tensor(xyz, dtype = torch.float) 234 | colors = torch.tensor(SH2RGB(features_dc)[:, [2, 1, 0]], dtype = torch.float).squeeze(-1) 235 | opacity = torch.sigmoid(torch.tensor(opacities, dtype = torch.float)) 236 | scales = torch.exp(torch.tensor(scales, dtype = torch.float)) 237 | rotations = torch.nn.functional.normalize(torch.tensor(rots, dtype = torch.float)) 238 | 239 | gaussians = torch.cat([positions, opacity, scales, rotations, colors], dim=1) 240 | 241 | return gaussians 242 | 243 | C0 = 0.28209479177387814 244 | def SH2RGB(sh): 245 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /src/model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_lrm import EncoderLRM, EncoderLRMCfg 5 | 6 | 7 | ENCODERS = { 8 | "lrm": (EncoderLRM), 9 | } 10 | 11 | EncoderCfg = EncoderLRMCfg 12 | 13 | 14 | def get_encoder(cfg: EncoderCfg) -> Encoder: 15 | 16 | encoder = ENCODERS[cfg.name] 17 | encoder = encoder(cfg) 18 | return encoder -------------------------------------------------------------------------------- /src/model/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import nn 5 | 6 | from ...dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | deterministic: bool, 24 | ) -> Gaussians: 25 | pass 26 | 27 | def get_data_shim(self) -> DataShim: 28 | """The default shim doesn't modify the batch.""" 29 | return lambda x: x 30 | -------------------------------------------------------------------------------- /src/model/encoder/encoder_lrm.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from typing import Tuple 8 | from jaxtyping import Float 9 | from torch import Tensor, nn 10 | 11 | from 
.transformer_processor.processor import Processor 12 | 13 | from ...dataset.shims.bounds_shim import apply_bounds_shim 14 | from ...dataset.shims.patch_shim import apply_patch_shim 15 | from ...dataset.types import BatchedExample, DataShim 16 | from ..types import Gaussians 17 | 18 | from .encoder import Encoder 19 | 20 | 21 | EPS = 1e-8 22 | 23 | def _init_weights(m): 24 | if isinstance(m, nn.Linear): 25 | nn.init.normal_(m.weight, std=.02) 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | 29 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 30 | def quaternion_to_matrix( 31 | quaternions, 32 | eps = 1e-8, 33 | ): 34 | # Order changed to match scipy format! 35 | i, j, k, r = torch.unbind(quaternions, dim=-1) 36 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 37 | 38 | o = torch.stack( 39 | ( 40 | 1 - two_s * (j * j + k * k), 41 | two_s * (i * j - k * r), 42 | two_s * (i * k + j * r), 43 | two_s * (i * j + k * r), 44 | 1 - two_s * (i * i + k * k), 45 | two_s * (j * k - i * r), 46 | two_s * (i * k - j * r), 47 | two_s * (j * k + i * r), 48 | 1 - two_s * (i * i + j * j), 49 | ), 50 | -1, 51 | ) 52 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 53 | 54 | def build_covariance( 55 | scale, 56 | rotation_xyzw, 57 | ): 58 | scale = scale.diag_embed() 59 | rotation = quaternion_to_matrix(rotation_xyzw) 60 | return ( 61 | rotation 62 | @ scale 63 | @ rearrange(scale, "... i j -> ... j i") 64 | @ rearrange(rotation, "... i j -> ... j i") 65 | ) 66 | 67 | 68 | 69 | @dataclass 70 | class TransformerCfg: 71 | head_dim: int 72 | num_layers: int 73 | 74 | @dataclass 75 | class GaussianCfg: 76 | sh_degree: int 77 | scale_bias: float 78 | scale_max: float 79 | opacity_bias: float 80 | near_plane: float 81 | far_plane: float 82 | 83 | @dataclass 84 | class OpacityMappingCfg: 85 | initial: float 86 | final: float 87 | warm_up: int 88 | 89 | @dataclass 90 | class EncoderLRMCfg: 91 | name: Literal["lrm"] 92 | patch_size: int 93 | attn_dim: int 94 | transformer: TransformerCfg 95 | gaussians_params: GaussianCfg 96 | apply_bounds_shim: bool 97 | near_disparity: float 98 | 99 | class EncoderLRM(Encoder[EncoderLRMCfg]): 100 | tokenizer: nn.Sequential 101 | attn_processor: Processor 102 | token_decoder: nn.Sequential 103 | 104 | 105 | def __init__(self, cfg: EncoderLRMCfg) -> None: 106 | super().__init__(cfg) 107 | input_dim = 9 # RGB + plucker ray 108 | self.patch_size = cfg.patch_size 109 | self.attn_dim = cfg.attn_dim 110 | self.tokenizer = nn.Sequential( 111 | nn.Linear(input_dim * self.patch_size ** 2, self.attn_dim, bias=False), 112 | ) 113 | self.tokenizer.apply(_init_weights) 114 | 115 | self.attn_processor = Processor(cfg) 116 | 117 | self.token_decoder = nn.Sequential( 118 | nn.LayerNorm(self.attn_dim, bias=False), 119 | nn.Linear( 120 | self.attn_dim, (1 + (cfg.gaussians_params.sh_degree + 1) ** 2 * 3 + 3 + 4 + 1) * self.patch_size ** 2, 121 | bias=False, 122 | ) 123 | ) 124 | self.token_decoder.apply(_init_weights) 125 | 126 | 127 | 128 | def map_pdf_to_opacity( 129 | self, 130 | pdf: Float[Tensor, " *batch"], 131 | global_step: int, 132 | ) -> Float[Tensor, " *batch"]: 133 | # https://www.desmos.com/calculator/opvwti3ba9 134 | 135 | # Figure out the exponent. 136 | cfg = self.cfg.opacity_mapping 137 | x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial) 138 | exponent = 2**x 139 | 140 | # Map the probability density to an opacity. 
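        # Blend of two warped curves, 1 - (1 - pdf)^k and pdf^(1/k) with k = 2^x:
        # both reduce to the identity at k = 1, lift opacities above pdf for k > 1,
        # and suppress them for k < 1, so the warm-up schedule above smoothly moves
        # the mapping between its `initial` and `final` shapes.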
141 | return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent)) 142 | 143 | def get_camera_ray_dir( 144 | self, 145 | context: dict, 146 | ) -> Tuple[Tensor, Tensor]: 147 | 148 | dtype = context["image"].dtype 149 | device = context["image"].device 150 | B, V, _, H, W = context["image"].shape 151 | input_c2ws, input_intr_raw = context["extrinsics"], context["intrinsics"] #W2C 152 | 153 | # Reshape the intrinsics 154 | fx = input_intr_raw[..., 0, 0].unsqueeze(-1) * W # (B, V, 1) 155 | fy = input_intr_raw[..., 1, 1].unsqueeze(-1) * H # (B, V, 1) 156 | cx = input_intr_raw[..., 0, 2].unsqueeze(-1) * W # (B, V, 1) 157 | cy = input_intr_raw[..., 1, 2].unsqueeze(-1) * H # (B, V, 1) 158 | 159 | input_intr = torch.cat([fx, fy, cx, cy], dim=-1) 160 | 161 | # Embed camera info 162 | ray_o = input_c2ws[:, :, :3, 3].unsqueeze(2).expand(-1, -1, H * W, -1).float() # (B, V, H*W, 3) # camera origin 163 | x, y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy") 164 | x = (x.to(dtype) + 0.5).view(1, 1, -1).expand(B, V, -1).to(device).contiguous() 165 | y = (y.to(dtype) + 0.5).view(1, 1, -1).expand(B, V, -1).to(device).contiguous() 166 | # unproject to camera space 167 | # breakpoint() 168 | x = (x - input_intr[:, :, 2:3]) / input_intr[:, :, 0:1] # 169 | y = (y - input_intr[:, :, 3:4]) / input_intr[:, :, 1:2] # 170 | ray_d = torch.stack([x, y, torch.ones_like(x)], dim=-1).float() # (B, V, H*W, 3) 171 | ray_d = F.normalize(ray_d, p=2, dim=-1) 172 | ray_d = ray_d @ input_c2ws[:, :, :3, :3].transpose(-1, -2).contiguous() # (B, V, H*W, 3) 173 | return ray_o, ray_d 174 | 175 | def feat2gaussian( 176 | self, 177 | gaussian_params: dict, 178 | ) -> Gaussians: 179 | means, scales, rotations, sh_feature, opacities = gaussian_params['xyz'], gaussian_params['scale'], gaussian_params['rotation'], gaussian_params['sh_feature'], gaussian_params['opacity'] 180 | covariances = build_covariance(scales, rotations) 181 | 182 | # breakpoint() 183 | return Gaussians( 184 | means=means.float(), 185 | covariances=covariances.float(), 186 | harmonics=sh_feature.permute(0, 1, 3, 2).contiguous().float(), 187 | opacities=opacities.squeeze(-1).float(), 188 | # Note: These aren't yet rotated into world space, but they're only used for 189 | # exporting Gaussians to ply files. This needs to be fixed... 
190 | # scales=scales, 191 | # rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 192 | ) 193 | 194 | 195 | def forward( 196 | self, 197 | context: dict, 198 | ) -> Gaussians: 199 | device = context["image"].device 200 | b, v, _, h, w = context["image"].shape 201 | 202 | ray_o, ray_d = self.get_camera_ray_dir(context) 203 | 204 | input_image_cam = torch.cat([context["image"].view(b, v, 3, -1).permute(0, 1, 3, 2).contiguous() * 2 - 1, 205 | torch.cross(ray_o, ray_d, dim=-1), 206 | ray_d], dim=-1) # (B, V, H*W, 9) 207 | 208 | # Pachify 209 | patch_size = self.patch_size 210 | hh = h // patch_size 211 | ww = w // patch_size 212 | input_image_cam = rearrange(input_image_cam, 213 | "b v (hh ph ww pw) d -> b (v hh ww) (ph pw d)", 214 | hh=hh, ww=ww, ph=patch_size, pw=patch_size) 215 | 216 | # Tokenize the input images 217 | image_tokens = self.tokenizer(input_image_cam) # (B, V*hh*ww, D) 218 | with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): 219 | image_tokens = self.attn_processor(image_tokens, use_checkpoint=True) 220 | 221 | # Decode token to gaussians 222 | gaussians = self.token_decoder(image_tokens) 223 | gaussians = rearrange(gaussians, "b (v hh ww) (ph pw d) -> b (v hh ph ww pw) d", v=v, hh=hh, ww=ww, ph=patch_size, pw=patch_size) 224 | 225 | dist, feature, scale, rotation, opacity = torch.split(gaussians, [1, (self.cfg.gaussians_params.sh_degree + 1) ** 2 * 3, 3, 4, 1], dim=-1) 226 | feature = feature.view(b, v*h*w, (self.cfg.gaussians_params.sh_degree + 1) ** 2, 3).contiguous() 227 | 228 | 229 | # Activate gaussian parameters 230 | w = rearrange(dist.sigmoid(), "b (v n) c -> b v n c", v=v) 231 | xyz = ray_o + ray_d * (context['near'].view(b, v, 1, 1) * (1 - w) + context['far'].view(b, v, 1, 1) * (w)) 232 | 233 | 234 | scale = torch.exp(scale + self.cfg.gaussians_params.scale_bias).clamp(max=self.cfg.gaussians_params.scale_max) 235 | 236 | opacity = (opacity + self.cfg.gaussians_params.opacity_bias).sigmoid() # 237 | 238 | rotation = rotation / (rotation.norm(dim=-1,keepdim=True) + EPS) 239 | 240 | gaussian_params = dict(xyz=xyz.flatten(1, 2), scale=scale, opacity=opacity, rotation=rotation, sh_feature=feature) 241 | 242 | 243 | gaussians = self.feat2gaussian(gaussian_params) 244 | gaussian_params.update({"gaussians": gaussians}) 245 | 246 | return self.feat2gaussian(gaussian_params) 247 | 248 | 249 | def get_data_shim(self) -> DataShim: 250 | def data_shim(batch: BatchedExample) -> BatchedExample: 251 | batch = apply_patch_shim( 252 | batch, 253 | patch_size=16, # Hard code patch size for now 254 | ) 255 | 256 | if self.cfg.apply_bounds_shim: 257 | _, _, _, h, w = batch["context"]["image"].shape 258 | near_disparity = self.cfg.near_disparity * min(h, w) 259 | batch = apply_bounds_shim(batch, near_disparity, 0.5) 260 | 261 | return batch 262 | 263 | return data_shim 264 | 265 | @property 266 | def sampler(self): 267 | # hack to make the visualizer work 268 | return self.epipolar_transformer.epipolar_sampler 269 | -------------------------------------------------------------------------------- /src/model/encoder/transformer_processor/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import xformers.ops as xops 4 | from einops import rearrange 5 | 6 | 7 | def _init_weights(m): 8 | if isinstance(m, nn.Linear): 9 | nn.init.normal_(m.weight, std=.02) 10 | if m.bias is not None: 11 | nn.init.constant_(m.bias, 0) 12 | 13 | class Mlp(nn.Module): 14 | def 
__init__(self, in_features, mlp_ratio=4., mlp_bias=False, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = int(in_features * mlp_ratio) 18 | self.fc1 = nn.Linear(in_features, hidden_features, bias=mlp_bias) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_bias) 21 | self.drop = nn.Dropout(drop) 22 | 23 | def forward(self, x): 24 | """ 25 | x: (B, L, D) 26 | Returns: same shape as input 27 | """ 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class SelfAttention(nn.Module): 36 | def __init__(self, dim, head_dim=64, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., use_flashatt_v2=True): 37 | super().__init__() 38 | assert dim % head_dim == 0, 'dim must be divisible by head_dim' 39 | self.num_heads = dim // head_dim 40 | self.scale = qk_scale or head_dim ** -0.5 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.attn_drop_p = attn_drop 44 | self.attn_drop = nn.Dropout(attn_drop) 45 | self.proj = nn.Linear(dim, dim, bias=False) 46 | self.proj_drop = nn.Dropout(proj_drop) 47 | 48 | self.use_flashatt_v2 = use_flashatt_v2 49 | 50 | def forward(self, x): 51 | """ 52 | x: (B, L, D) 53 | Returns: same shape as input 54 | """ 55 | B, N, C = x.shape 56 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 57 | 58 | if self.use_flashatt_v2: 59 | qkv = qkv.permute(2, 0, 1, 3, 4) 60 | q, k, v = qkv[0], qkv[1], qkv[2] # (B, N, H, C) 61 | x = xops.memory_efficient_attention(q, k, v, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), p=self.attn_drop_p) 62 | x = rearrange(x, 'b n h d -> b n (h d)') 63 | else: 64 | qkv = qkv.permute(2, 0, 3, 1, 4) 65 | q, k, v = qkv[0], qkv[1], qkv[2] # (B, H, N, C) 66 | attn = (q @ k.transpose(-2, -1)) * self.scale 67 | attn = attn.softmax(dim=-1) 68 | attn = self.attn_drop(attn) 69 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 70 | 71 | x = self.proj(x) 72 | x = self.proj_drop(x) 73 | return x 74 | 75 | class TransformerBlock(nn.Module): 76 | def __init__(self, dim, head_dim, mlp_ratio=4., mlp_bias=False, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 77 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flashatt_v2=True): 78 | super().__init__() 79 | self.norm1 = norm_layer(dim, bias=False) 80 | self.attn = SelfAttention( 81 | dim, head_dim=head_dim, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, use_flashatt_v2=use_flashatt_v2) 82 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 83 | self.norm2 = norm_layer(dim, bias=False) 84 | self.mlp = Mlp(in_features=dim, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, act_layer=act_layer, drop=drop) 85 | 86 | def forward(self, x, return_attention=False): 87 | """ 88 | x: (B, L, D) 89 | Returns: same shape as input 90 | """ 91 | with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, device_type='cuda'): 92 | y = self.attn(self.norm1(x)) 93 | x = x + self.drop_path(y) 94 | x = x + self.drop_path(self.mlp(self.norm2(x))) 95 | return x 96 | 97 | class Processor(nn.Module): 98 | def __init__(self, config): 99 | super().__init__() 100 | self.config = config 101 | self.num_layers = config.transformer.num_layers 102 | self.attn_dim = config.attn_dim 103 | 104 | 105 | self.blocks = nn.ModuleList() 106 | 107 | for _ in range(self.num_layers): 108 | self.blocks.append(TransformerBlock(self.attn_dim, config.transformer.head_dim)) 109 | self.blocks[-1].apply(_init_weights) 110 | 111 | def forward(self, x, use_checkpoint=True): 112 | """ 113 | x: (B, L, D) 114 | Returns: B and D remain the same, L might change if there are merge layers 115 | """ 116 | 117 | for i in range(self.num_layers): 118 | if use_checkpoint: 119 | x = torch.utils.checkpoint.checkpoint(self.blocks[i], x, use_reentrant=False) 120 | else: 121 | x = self.blocks(i)(x) 122 | 123 | return x -------------------------------------------------------------------------------- /src/model/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | # Shift the scene so that the median Gaussian is at the origin. 36 | means = means - means.median(dim=0).values 37 | 38 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 39 | scale_factor = means.abs().quantile(0.95, dim=0).max() 40 | means = means / scale_factor 41 | scales = scales / scale_factor 42 | 43 | # Define a rotation that makes +Z be the world up vector. 44 | rotation = [ 45 | [0, 0, 1], 46 | [-1, 0, 0], 47 | [0, -1, 0], 48 | ] 49 | rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) 50 | 51 | # The Polycam viewer seems to start at a 45 degree angle. Since we want to be 52 | # looking directly at the object, we compose a 45 degree rotation onto the above 53 | # rotation. 
54 | adjustment = torch.tensor( 55 | R.from_rotvec([0, 0, -45], True).as_matrix(), 56 | dtype=torch.float32, 57 | device=means.device, 58 | ) 59 | rotation = adjustment @ rotation 60 | 61 | # We also want to see the scene in camera space (as the default view). We therefore 62 | # compose the w2c rotation onto the above rotation. 63 | rotation = rotation @ extrinsics[:3, :3].inverse() 64 | 65 | # Apply the rotation to the means (Gaussian positions). 66 | means = einsum(rotation, means, "i j, ... j -> ... i") 67 | 68 | # Apply the rotation to the Gaussian rotations. 69 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 70 | rotations = rotation.detach().cpu().numpy() @ rotations 71 | rotations = R.from_matrix(rotations).as_quat() 72 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 73 | rotations = np.stack((w, x, y, z), axis=-1) 74 | 75 | # Since our axes are swizzled for the spherical harmonics, we only export the DC 76 | # band. 77 | harmonics_view_invariant = harmonics[..., 0] 78 | 79 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] 80 | elements = np.empty(means.shape[0], dtype=dtype_full) 81 | attributes = ( 82 | means.detach().cpu().numpy(), 83 | torch.zeros_like(means).detach().cpu().numpy(), 84 | harmonics_view_invariant.detach().cpu().contiguous().numpy(), 85 | opacities[..., None].detach().cpu().numpy(), 86 | scales.log().detach().cpu().numpy(), 87 | rotations, 88 | ) 89 | attributes = np.concatenate(attributes, axis=1) 90 | elements[:] = list(map(tuple, attributes)) 91 | path.parent.mkdir(exist_ok=True, parents=True) 92 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 93 | -------------------------------------------------------------------------------- /src/model/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] 13 | -------------------------------------------------------------------------------- /src/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text, font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, 
device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, rearrange, reduce 3 | from jaxtyping import Float 4 | from scipy.spatial.transform import Rotation as R 5 | from torch import Tensor 6 | 7 | 8 | def interpolate_intrinsics( 9 | initial: Float[Tensor, "*#batch 3 3"], 10 | final: Float[Tensor, "*#batch 3 3"], 11 | t: Float[Tensor, " time_step"], 12 | ) -> Float[Tensor, "*batch time_step 3 3"]: 13 | initial = rearrange(initial, "... i j -> ... () i j") 14 | final = rearrange(final, "... i j -> ... () i j") 15 | t = rearrange(t, "t -> t () ()") 16 | return initial + (final - initial) * t 17 | 18 | 19 | def intersect_rays( 20 | a_origins: Float[Tensor, "*#batch dim"], 21 | a_directions: Float[Tensor, "*#batch dim"], 22 | b_origins: Float[Tensor, "*#batch dim"], 23 | b_directions: Float[Tensor, "*#batch dim"], 24 | ) -> Float[Tensor, "*batch dim"]: 25 | """Compute the least-squares intersection of rays. Uses the math from here: 26 | https://math.stackexchange.com/a/1762491/286022 27 | """ 28 | 29 | # Broadcast and stack the tensors. 30 | a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors( 31 | a_origins, a_directions, b_origins, b_directions 32 | ) 33 | origins = torch.stack((a_origins, b_origins), dim=-2) 34 | directions = torch.stack((a_directions, b_directions), dim=-2) 35 | 36 | # Compute n_i * n_i^T - eye(3) from the equation. 37 | n = einsum(directions, directions, "... n i, ... n j -> ... n i j") 38 | n = n - torch.eye(3, dtype=origins.dtype, device=origins.device) 39 | 40 | # Compute the left-hand side of the equation. 41 | lhs = reduce(n, "... n i j -> ... i j", "sum") 42 | 43 | # Compute the right-hand side of the equation. 44 | rhs = einsum(n, origins, "... n i j, ... n j -> ... n i") 45 | rhs = reduce(rhs, "... n i -> ... i", "sum") 46 | 47 | # Left-matrix-multiply both sides by the inverse of lhs to find p. 48 | return torch.linalg.lstsq(lhs, rhs).solution 49 | 50 | 51 | def normalize(a: Float[Tensor, "*#batch dim"]) -> Float[Tensor, "*#batch dim"]: 52 | return a / a.norm(dim=-1, keepdim=True) 53 | 54 | 55 | def generate_coordinate_frame( 56 | y: Float[Tensor, "*#batch 3"], 57 | z: Float[Tensor, "*#batch 3"], 58 | ) -> Float[Tensor, "*batch 3 3"]: 59 | """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.""" 60 | y, z = torch.broadcast_tensors(y, z) 61 | return torch.stack([y.cross(z), y, z], dim=-1) 62 | 63 | 64 | def generate_rotation_coordinate_frame( 65 | a: Float[Tensor, "*#batch 3"], 66 | b: Float[Tensor, "*#batch 3"], 67 | eps: float = 1e-4, 68 | ) -> Float[Tensor, "*batch 3 3"]: 69 | """Generate a coordinate frame where the Y direction is normal to the plane defined 70 | by unit vectors a and b. The other axes are arbitrary.""" 71 | device = a.device 72 | 73 | # Replace every entry in b that's parallel to the corresponding entry in a with an 74 | # arbitrary vector. 
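    # Two fallbacks are applied in sequence: entries of b parallel to a are first
    # replaced with the +Z axis, and any that are still parallel afterwards (i.e. a
    # itself is ±Z) are replaced with the +Y axis, so the cross product below is
    # never degenerate.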
75 | b = b.detach().clone() 76 | parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps 77 | b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device) 78 | parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps 79 | b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device) 80 | 81 | # Generate the coordinate frame. The initial cross product defines the plane. 82 | return generate_coordinate_frame(normalize(a.cross(b)), a) 83 | 84 | 85 | def matrix_to_euler( 86 | rotations: Float[Tensor, "*batch 3 3"], 87 | pattern: str, 88 | ) -> Float[Tensor, "*batch 3"]: 89 | *batch, _, _ = rotations.shape 90 | rotations = rotations.reshape(-1, 3, 3) 91 | angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern) 92 | rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device) 93 | return rotations.reshape(*batch, 3) 94 | 95 | 96 | def euler_to_matrix( 97 | rotations: Float[Tensor, "*batch 3"], 98 | pattern: str, 99 | ) -> Float[Tensor, "*batch 3 3"]: 100 | *batch, _ = rotations.shape 101 | rotations = rotations.reshape(-1, 3) 102 | matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix() 103 | rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device) 104 | return rotations.reshape(*batch, 3, 3) 105 | 106 | 107 | def extrinsics_to_pivot_parameters( 108 | extrinsics: Float[Tensor, "*#batch 4 4"], 109 | pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"], 110 | pivot_point: Float[Tensor, "*#batch 3"], 111 | ) -> Float[Tensor, "*batch 5"]: 112 | """Convert the extrinsics to a representation with 5 degrees of freedom: 113 | 1. Distance from pivot point in the "X" (look cross pivot axis) direction. 114 | 2. Distance from pivot point in the "Y" (pivot axis) direction. 115 | 3. Distance from pivot point in the Z (look) direction 116 | 4. Angle in plane 117 | 5. Twist (rotation not in plane) 118 | """ 119 | 120 | # The pivot coordinate frame's Z axis is normal to the plane. 121 | pivot_axis = pivot_coordinate_frame[..., :, 1] 122 | 123 | # Compute the translation elements of the pivot parametrization. 124 | translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2]) 125 | origin = extrinsics[..., :3, 3] 126 | delta = pivot_point - origin 127 | translation = einsum(translation_frame, delta, "... i j, ... i -> ... j") 128 | 129 | # Add the rotation elements of the pivot parametrization. 130 | inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3] 131 | y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1) 132 | 133 | return torch.cat([translation, y[..., None], z[..., None]], dim=-1) 134 | 135 | 136 | def pivot_parameters_to_extrinsics( 137 | parameters: Float[Tensor, "*#batch 5"], 138 | pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"], 139 | pivot_point: Float[Tensor, "*#batch 3"], 140 | ) -> Float[Tensor, "*batch 4 4"]: 141 | translation, y, z = parameters.split((3, 1, 1), dim=-1) 142 | 143 | euler = torch.cat((y, torch.zeros_like(y), z), dim=-1) 144 | rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ") 145 | 146 | # The pivot coordinate frame's Z axis is normal to the plane. 147 | pivot_axis = pivot_coordinate_frame[..., :, 1] 148 | 149 | translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2]) 150 | delta = einsum(translation_frame, translation, "... i j, ... j -> ... 
i") 151 | origin = pivot_point - delta 152 | 153 | *batch, _ = origin.shape 154 | extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device) 155 | extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone() 156 | extrinsics[..., 3, 3] = 1 157 | extrinsics[..., :3, :3] = rotation 158 | extrinsics[..., :3, 3] = origin 159 | return extrinsics 160 | 161 | 162 | def interpolate_circular( 163 | a: Float[Tensor, "*#batch"], 164 | b: Float[Tensor, "*#batch"], 165 | t: Float[Tensor, "*#batch"], 166 | ) -> Float[Tensor, " *batch"]: 167 | a, b, t = torch.broadcast_tensors(a, b, t) 168 | 169 | tau = 2 * torch.pi 170 | a = a % tau 171 | b = b % tau 172 | 173 | # Consider piecewise edge cases. 174 | d = (b - a).abs() 175 | a_left = a - tau 176 | d_left = (b - a_left).abs() 177 | a_right = a + tau 178 | d_right = (b - a_right).abs() 179 | use_d = (d < d_left) & (d < d_right) 180 | use_d_left = (d_left < d_right) & (~use_d) 181 | use_d_right = (~use_d) & (~use_d_left) 182 | 183 | result = a + (b - a) * t 184 | result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left] 185 | result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right] 186 | 187 | return result 188 | 189 | 190 | def interpolate_pivot_parameters( 191 | initial: Float[Tensor, "*#batch 5"], 192 | final: Float[Tensor, "*#batch 5"], 193 | t: Float[Tensor, " time_step"], 194 | ) -> Float[Tensor, "*batch time_step 5"]: 195 | initial = rearrange(initial, "... d -> ... () d") 196 | final = rearrange(final, "... d -> ... () d") 197 | t = rearrange(t, "t -> t ()") 198 | ti, ri = initial.split((3, 2), dim=-1) 199 | tf, rf = final.split((3, 2), dim=-1) 200 | 201 | t_lerp = ti + (tf - ti) * t 202 | r_lerp = interpolate_circular(ri, rf, t) 203 | 204 | return torch.cat((t_lerp, r_lerp), dim=-1) 205 | 206 | 207 | @torch.no_grad() 208 | def interpolate_extrinsics( 209 | initial: Float[Tensor, "*#batch 4 4"], 210 | final: Float[Tensor, "*#batch 4 4"], 211 | t: Float[Tensor, " time_step"], 212 | eps: float = 1e-4, 213 | ) -> Float[Tensor, "*batch time_step 4 4"]: 214 | """Interpolate extrinsics by rotating around their "focus point," which is the 215 | least-squares intersection between the look vectors of the initial and final 216 | extrinsics. 217 | """ 218 | 219 | initial = initial.type(torch.float64) 220 | final = final.type(torch.float64) 221 | t = t.type(torch.float64) 222 | 223 | # Based on the dot product between the look vectors, pick from one of two cases: 224 | # 1. Look vectors are parallel: interpolate about their origins' midpoint. 225 | # 3. Look vectors aren't parallel: interpolate about their focus point. 226 | initial_look = initial[..., :3, 2] 227 | final_look = final[..., :3, 2] 228 | dot_products = einsum(initial_look, final_look, "... i, ... i -> ...") 229 | parallel_mask = (dot_products.abs() - 1).abs() < eps 230 | 231 | # Pick focus points. 232 | initial_origin = initial[..., :3, 3] 233 | final_origin = final[..., :3, 3] 234 | pivot_point = 0.5 * (initial_origin + final_origin) 235 | pivot_point[~parallel_mask] = intersect_rays( 236 | initial_origin[~parallel_mask], 237 | initial_look[~parallel_mask], 238 | final_origin[~parallel_mask], 239 | final_look[~parallel_mask], 240 | ) 241 | 242 | # Convert to pivot parameters. 
243 | pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps) 244 | initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point) 245 | final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point) 246 | 247 | # Interpolate the pivot parameters. 248 | interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t) 249 | 250 | # Convert back. 251 | return pivot_parameters_to_extrinsics( 252 | interpolated_params.type(torch.float32), 253 | rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32), 254 | rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32), 255 | ) 256 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /src/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 
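    # The camera is offset along a circle in its own XY (image) plane: sin/cos of
    # 2 * pi * num_rotations * t trace the circle, and with scale_radius_with_t the
    # radius grows linearly in t so the trajectory starts at the unmodified pose.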
15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... () i j") @ tf 33 | -------------------------------------------------------------------------------- /src/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height width"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB.
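# The blending above was done in CIELab so that the gradient is roughly perceptually uniform;
# converting back to sRGB can yield values slightly outside [0, 1] for out-of-gamut colors,
# hence the clip below.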
47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /src/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /src/visualization/drawing/cameras.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import einsum, rearrange, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from ...geometry.projection import unproject 9 | from ..annotation import add_label 10 | from .lines import draw_lines 11 | from .types import Scalar, sanitize_scalar 12 | 13 | 14 | def draw_cameras( 15 | resolution: int, 16 | extrinsics: Float[Tensor, "batch 4 4"], 17 | intrinsics: Float[Tensor, "batch 3 3"], 18 | color: Float[Tensor, "batch 3"], 19 | near: Optional[Scalar] = None, 20 | far: Optional[Scalar] = None, 21 | margin: float = 0.1, # relative to AABB 22 | frustum_scale: float = 0.05, # relative to image resolution 23 | ) -> Float[Tensor, "3 3 height width"]: 24 | device = extrinsics.device 25 | 26 | # Compute scene bounds. 27 | minima, maxima = compute_aabb(extrinsics, intrinsics, near, far) 28 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 29 | minima, maxima, margin=margin 30 | ) 31 | span = (scene_maxima - scene_minima).max() 32 | 33 | # Compute frustum locations. 34 | corner_depth = (span * frustum_scale)[None] 35 | frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth) 36 | if near is not None: 37 | near_corners = unproject_frustum_corners(extrinsics, intrinsics, near) 38 | if far is not None: 39 | far_corners = unproject_frustum_corners(extrinsics, intrinsics, far) 40 | 41 | # Project the cameras onto each axis-aligned plane. 42 | projections = [] 43 | for projected_axis in range(3): 44 | image = torch.zeros( 45 | (3, resolution, resolution), 46 | dtype=torch.float32, 47 | device=device, 48 | ) 49 | image_x_axis = (projected_axis + 1) % 3 50 | image_y_axis = (projected_axis + 2) % 3 51 | 52 | def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]: 53 | x = points[..., image_x_axis] 54 | y = points[..., image_y_axis] 55 | return torch.stack([x, y], dim=-1) 56 | 57 | x_range, y_range = torch.stack( 58 | (project(scene_minima), project(scene_maxima)), dim=-1 59 | ) 60 | 61 | # Draw near and far planes. 
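# Each plane is drawn as a quadrilateral: roll(1, 1) pairs every projected corner with its
# neighbor so that all four edges are drawn, and when both near and far are given, matching
# corners are also joined to outline the sides of the frustum.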
62 | if near is not None: 63 | projected_near_corners = project(near_corners) 64 | image = draw_lines( 65 | image, 66 | rearrange(projected_near_corners, "b p xy -> (b p) xy"), 67 | rearrange(projected_near_corners.roll(1, 1), "b p xy -> (b p) xy"), 68 | color=0.25, 69 | width=2, 70 | x_range=x_range, 71 | y_range=y_range, 72 | ) 73 | if far is not None: 74 | projected_far_corners = project(far_corners) 75 | image = draw_lines( 76 | image, 77 | rearrange(projected_far_corners, "b p xy -> (b p) xy"), 78 | rearrange(projected_far_corners.roll(1, 1), "b p xy -> (b p) xy"), 79 | color=0.25, 80 | width=2, 81 | x_range=x_range, 82 | y_range=y_range, 83 | ) 84 | if near is not None and far is not None: 85 | image = draw_lines( 86 | image, 87 | rearrange(projected_near_corners, "b p xy -> (b p) xy"), 88 | rearrange(projected_far_corners, "b p xy -> (b p) xy"), 89 | color=0.25, 90 | width=2, 91 | x_range=x_range, 92 | y_range=y_range, 93 | ) 94 | 95 | # Draw the camera frustums themselves. 96 | projected_origins = project(extrinsics[:, :3, 3]) 97 | projected_frustum_corners = project(frustum_corners) 98 | start = [ 99 | repeat(projected_origins, "b xy -> (b p) xy", p=4), 100 | rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"), 101 | ] 102 | start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4) 103 | image = draw_lines( 104 | image, 105 | start, 106 | repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2), 107 | color=repeat(color, "b c -> (b r p) c", r=2, p=4), 108 | width=2, 109 | x_range=x_range, 110 | y_range=y_range, 111 | ) 112 | 113 | x_name = "XYZ"[image_x_axis] 114 | y_name = "XYZ"[image_y_axis] 115 | image = add_label(image, f"{x_name}{y_name} Projection") 116 | 117 | # TODO: Draw axis indicators. 118 | projections.append(image) 119 | 120 | return torch.stack(projections) 121 | 122 | 123 | def compute_aabb( 124 | extrinsics: Float[Tensor, "batch 4 4"], 125 | intrinsics: Float[Tensor, "batch 3 3"], 126 | near: Optional[Scalar] = None, 127 | far: Optional[Scalar] = None, 128 | ) -> tuple[ 129 | Float[Tensor, "3"], # minima of the scene 130 | Float[Tensor, "3"], # maxima of the scene 131 | ]: 132 | """Compute an axis-aligned bounding box for the camera frustums.""" 133 | 134 | device = extrinsics.device 135 | 136 | # These points are included in the AABB. 
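# Specifically: every camera origin, plus the unprojected frustum corners at the near and/or
# far depths whenever those are provided.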
137 | points = [extrinsics[:, :3, 3]] 138 | 139 | if near is not None: 140 | near = sanitize_scalar(near, device) 141 | corners = unproject_frustum_corners(extrinsics, intrinsics, near) 142 | points.append(rearrange(corners, "b p xyz -> (b p) xyz")) 143 | 144 | if far is not None: 145 | far = sanitize_scalar(far, device) 146 | corners = unproject_frustum_corners(extrinsics, intrinsics, far) 147 | points.append(rearrange(corners, "b p xyz -> (b p) xyz")) 148 | 149 | points = torch.cat(points, dim=0) 150 | return points.min(dim=0).values, points.max(dim=0).values 151 | 152 | 153 | def compute_equal_aabb_with_margin( 154 | minima: Float[Tensor, "*#batch 3"], 155 | maxima: Float[Tensor, "*#batch 3"], 156 | margin: float = 0.1, 157 | ) -> tuple[ 158 | Float[Tensor, "*batch 3"], # minima of the scene 159 | Float[Tensor, "*batch 3"], # maxima of the scene 160 | ]: 161 | midpoint = (maxima + minima) * 0.5 162 | span = (maxima - minima).max() * (1 + margin) 163 | scene_minima = midpoint - 0.5 * span 164 | scene_maxima = midpoint + 0.5 * span 165 | return scene_minima, scene_maxima 166 | 167 | 168 | def unproject_frustum_corners( 169 | extrinsics: Float[Tensor, "batch 4 4"], 170 | intrinsics: Float[Tensor, "batch 3 3"], 171 | depth: Float[Tensor, "#batch"], 172 | ) -> Float[Tensor, "batch 4 3"]: 173 | device = extrinsics.device 174 | 175 | # Get coordinates for the corners. Following them in a circle makes a rectangle. 176 | xy = torch.linspace(0, 1, 2, device=device) 177 | xy = torch.stack(torch.meshgrid(xy, xy, indexing="xy"), dim=-1) 178 | xy = rearrange(xy, "i j xy -> (i j) xy") 179 | xy = xy[torch.tensor([0, 1, 3, 2], device=device)] 180 | 181 | # Get ray directions in camera space. 182 | directions = unproject( 183 | xy, 184 | torch.ones(1, dtype=torch.float32, device=device), 185 | rearrange(intrinsics, "b i j -> b () i j"), 186 | ) 187 | 188 | # Divide by the z coordinate so that multiplying by depth will produce orthographic 189 | # depth (z depth) as opposed to Euclidean depth (distance from the camera). 
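# Equivalently, each direction is rescaled to have unit z in camera space, so that
# origin + depth * direction lands on the plane whose camera-space z equals depth.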
190 | directions = directions / directions[..., -1:] 191 | directions = einsum(extrinsics[..., :3, :3], directions, "b i j, b r j -> b r i") 192 | 193 | origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz") 194 | depth = rearrange(depth, "b -> b () ()") 195 | return origins + depth * directions 196 | -------------------------------------------------------------------------------- /src/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 | xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /src/visualization/drawing/lines.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 
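# generate_conversions maps the supplied x_range/y_range linearly onto [0, w] x [0, h];
# if no ranges are given, the endpoints are assumed to already be in pixel coordinates.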
37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /src/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 
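# Here delta has shape (sample, point, 2): the vector from each drawn point to each sample
# location. A sample is covered when its distance to some point lies between inner_radius and
# radius, so a nonzero inner_radius draws rings rather than filled discs.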
44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /src/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
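# The sample grid is flattened to (b*h*w, 2) and evaluated through color_function in chunks of
# batch_size to bound peak memory, then reassembled into a (b, 4, h, w) RGBA image.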
86 | b, h, w, _ = xy.shape 87 | color = [ 88 | color_function(batch) 89 | for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) 90 | ] 91 | color = torch.cat(color, dim=0) 92 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 93 | 94 | # If any MSAA passes remain, subdivide. 95 | if remaining_passes > 0: 96 | mask = detect_msaa_pixels(color) 97 | batch_index, row_index, col_index = torch.where(mask) 98 | xy = xy[batch_index, row_index, col_index] 99 | 100 | offsets = generate_sample_grid((subdivision, subdivision), device) 101 | offsets = (offsets / subdivision - 0.5) * scale 102 | 103 | color_fine = run_msaa_pass( 104 | xy[:, None, None] + offsets, 105 | color_function, 106 | scale / subdivision, 107 | subdivision, 108 | remaining_passes - 1, 109 | device, 110 | batch_size=batch_size, 111 | ) 112 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 113 | 114 | return color 115 | 116 | 117 | @torch.no_grad() 118 | def render( 119 | shape: tuple[int, int], 120 | color_function: ColorFunction, 121 | device: torch.device, 122 | subdivision: int = 8, 123 | num_passes: int = 2, 124 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 125 | xy = generate_sample_grid(shape, device) 126 | return run_msaa_pass( 127 | xy[None], 128 | color_function, 129 | 1.0, 130 | subdivision, 131 | num_passes, 132 | device, 133 | )[0] 134 | 135 | 136 | def render_over_image( 137 | image: Float[Tensor, "3 height width"], 138 | color_function: ColorFunction, 139 | device: torch.device, 140 | subdivision: int = 8, 141 | num_passes: int = 1, 142 | ) -> Float[Tensor, "3 height width"]: 143 | _, h, w = image.shape 144 | overlay = render( 145 | (h, w), 146 | color_function, 147 | device, 148 | subdivision=subdivision, 149 | num_passes=num_passes, 150 | ) 151 | color, alpha = overlay.split((3, 1), dim=0) 152 | return image * (1 - alpha) + color * alpha 153 | -------------------------------------------------------------------------------- /src/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /src/visualization/layout.py: -------------------------------------------------------------------------------- 1 | """This file contains useful layout utilities for images. They are: 2 | 3 | - add_border: Add a border to an image. 4 | - cat/hcat/vcat: Join images by arranging them in a line. If the images have different 5 | sizes, they are aligned as specified (start, end, center). Allows you to specify a gap 6 | between images. 7 | 8 | Images are assumed to be float32 tensors with shape (channel, height, width). 9 | """ 10 | 11 | from typing import Any, Generator, Iterable, Literal, Optional, Union 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from jaxtyping import Float 16 | from torch import Tensor 17 | 18 | Alignment = Literal["start", "center", "end"] 19 | Axis = Literal["horizontal", "vertical"] 20 | Color = Union[ 21 | int, 22 | float, 23 | Iterable[int], 24 | Iterable[float], 25 | Float[Tensor, "#channel"], 26 | Float[Tensor, ""], 27 | ] 28 | 29 | 30 | def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]: 31 | # Convert tensor to list (or individual item). 32 | if isinstance(color, torch.Tensor): 33 | color = color.tolist() 34 | 35 | # Turn iterators and individual items into lists. 
36 | if isinstance(color, Iterable): 37 | color = list(color) 38 | else: 39 | color = [color] 40 | 41 | return torch.tensor(color, dtype=torch.float32) 42 | 43 | 44 | def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: 45 | it = iter(iterable) 46 | yield next(it) 47 | for item in it: 48 | yield delimiter 49 | yield item 50 | 51 | 52 | def _get_main_dim(main_axis: Axis) -> int: 53 | return { 54 | "horizontal": 2, 55 | "vertical": 1, 56 | }[main_axis] 57 | 58 | 59 | def _get_cross_dim(main_axis: Axis) -> int: 60 | return { 61 | "horizontal": 1, 62 | "vertical": 2, 63 | }[main_axis] 64 | 65 | 66 | def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: 67 | assert base >= overlay 68 | offset = { 69 | "start": 0, 70 | "center": (base - overlay) // 2, 71 | "end": base - overlay, 72 | }[align] 73 | return slice(offset, offset + overlay) 74 | 75 | 76 | def overlay( 77 | base: Float[Tensor, "channel base_height base_width"], 78 | overlay: Float[Tensor, "channel overlay_height overlay_width"], 79 | main_axis: Axis, 80 | main_axis_alignment: Alignment, 81 | cross_axis_alignment: Alignment, 82 | ) -> Float[Tensor, "channel base_height base_width"]: 83 | # The overlay must be smaller than the base. 84 | _, base_height, base_width = base.shape 85 | _, overlay_height, overlay_width = overlay.shape 86 | assert base_height >= overlay_height and base_width >= overlay_width 87 | 88 | # Compute spacing on the main dimension. 89 | main_dim = _get_main_dim(main_axis) 90 | main_slice = _compute_offset( 91 | base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment 92 | ) 93 | 94 | # Compute spacing on the cross dimension. 95 | cross_dim = _get_cross_dim(main_axis) 96 | cross_slice = _compute_offset( 97 | base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment 98 | ) 99 | 100 | # Combine the slices and paste the overlay onto the base accordingly. 101 | selector = [..., None, None] 102 | selector[main_dim] = main_slice 103 | selector[cross_dim] = cross_slice 104 | result = base.clone() 105 | result[selector] = overlay 106 | return result 107 | 108 | 109 | def cat( 110 | main_axis: Axis, 111 | *images: Iterable[Float[Tensor, "channel _ _"]], 112 | align: Alignment = "center", 113 | gap: int = 8, 114 | gap_color: Color = 1, 115 | ) -> Float[Tensor, "channel height width"]: 116 | """Arrange images in a line. The interface resembles a CSS div with flexbox.""" 117 | device = images[0].device 118 | gap_color = _sanitize_color(gap_color).to(device) 119 | 120 | # Find the maximum image side length in the cross axis dimension. 121 | cross_dim = _get_cross_dim(main_axis) 122 | cross_axis_length = max(image.shape[cross_dim] for image in images) 123 | 124 | # Pad the images. 125 | padded_images = [] 126 | for image in images: 127 | # Create an empty image with the correct size. 128 | padded_shape = list(image.shape) 129 | padded_shape[cross_dim] = cross_axis_length 130 | base = torch.ones(padded_shape, dtype=torch.float32, device=device) 131 | base = base * gap_color[:, None, None] 132 | padded_images.append(overlay(base, image, main_axis, "start", align)) 133 | 134 | # Intersperse separators if necessary. 135 | if gap > 0: 136 | # Generate a separator. 
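# The separator is gap pixels thick along the main axis, spans the padded cross-axis length,
# and is filled with gap_color; _intersperse then places one copy between every pair of images.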
137 | c, _, _ = images[0].shape 138 | separator_size = [gap, gap] 139 | separator_size[cross_dim - 1] = cross_axis_length 140 | separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device) 141 | separator = separator * gap_color[:, None, None] 142 | 143 | # Intersperse the separator between the images. 144 | padded_images = list(_intersperse(padded_images, separator)) 145 | 146 | return torch.cat(padded_images, dim=_get_main_dim(main_axis)) 147 | 148 | 149 | def hcat( 150 | *images: Iterable[Float[Tensor, "channel _ _"]], 151 | align: Literal["start", "center", "end", "top", "bottom"] = "start", 152 | gap: int = 8, 153 | gap_color: Color = 1, 154 | ): 155 | """Shorthand for a horizontal linear concatenation.""" 156 | return cat( 157 | "horizontal", 158 | *images, 159 | align={ 160 | "start": "start", 161 | "center": "center", 162 | "end": "end", 163 | "top": "start", 164 | "bottom": "end", 165 | }[align], 166 | gap=gap, 167 | gap_color=gap_color, 168 | ) 169 | 170 | 171 | def vcat( 172 | *images: Iterable[Float[Tensor, "channel _ _"]], 173 | align: Literal["start", "center", "end", "left", "right"] = "start", 174 | gap: int = 8, 175 | gap_color: Color = 1, 176 | ): 177 | """Shorthand for a vertical linear concatenation.""" 178 | return cat( 179 | "vertical", 180 | *images, 181 | align={ 182 | "start": "start", 183 | "center": "center", 184 | "end": "end", 185 | "left": "start", 186 | "right": "end", 187 | }[align], 188 | gap=gap, 189 | gap_color=gap_color, 190 | ) 191 | 192 | 193 | def add_border( 194 | image: Float[Tensor, "channel height width"], 195 | border: int = 8, 196 | color: Color = 1, 197 | ) -> Float[Tensor, "channel new_height new_width"]: 198 | color = _sanitize_color(color).to(image) 199 | c, h, w = image.shape 200 | result = torch.empty( 201 | (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device 202 | ) 203 | result[:] = color[:, None, None] 204 | result[:, border : h + border, border : w + border] = image 205 | return result 206 | 207 | 208 | def resize( 209 | image: Float[Tensor, "channel height width"], 210 | shape: Optional[tuple[int, int]] = None, 211 | width: Optional[int] = None, 212 | height: Optional[int] = None, 213 | ) -> Float[Tensor, "channel new_height new_width"]: 214 | assert (shape is not None) + (width is not None) + (height is not None) == 1 215 | _, h, w = image.shape 216 | 217 | if width is not None: 218 | shape = (int(h * width / w), width) 219 | elif height is not None: 220 | shape = (height, int(w * height / h)) 221 | 222 | return F.interpolate( 223 | image[None], 224 | shape, 225 | mode="bilinear", 226 | align_corners=False, 227 | antialias=True, 228 | )[0] 229 | -------------------------------------------------------------------------------- /src/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [ 16 | torch.ones(padded_shape.tolist(),
dtype=x.dtype, device=x.device) 17 | for x in images 18 | ] 19 | for image, result in zip(images, results): 20 | slices = [slice(0, x) for x in image.shape] 21 | result[slices] = image[slices] 22 | return results 23 | 24 | 25 | def render_projections( 26 | gaussians: Gaussians, 27 | resolution: int, 28 | margin: float = 0.1, 29 | draw_label: bool = True, 30 | extra_label: str = "", 31 | ) -> Float[Tensor, "batch 3 3 height width"]: 32 | device = gaussians.means.device 33 | b, _, _ = gaussians.means.shape 34 | 35 | # Compute the minima and maxima of the scene. 36 | minima = gaussians.means.min(dim=1).values 37 | maxima = gaussians.means.max(dim=1).values 38 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 39 | minima, maxima, margin=margin 40 | ) 41 | 42 | projections = [] 43 | for look_axis in range(3): 44 | right_axis = (look_axis + 1) % 3 45 | down_axis = (look_axis + 2) % 3 46 | 47 | # Define the extrinsics for rendering. 48 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 49 | extrinsics[:, right_axis, 0] = 1 50 | extrinsics[:, down_axis, 1] = 1 51 | extrinsics[:, look_axis, 2] = 1 52 | extrinsics[:, right_axis, 3] = 0.5 * ( 53 | scene_minima[:, right_axis] + scene_maxima[:, right_axis] 54 | ) 55 | extrinsics[:, down_axis, 3] = 0.5 * ( 56 | scene_minima[:, down_axis] + scene_maxima[:, down_axis] 57 | ) 58 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 59 | extrinsics[:, 3, 3] = 1 60 | 61 | # Define the intrinsics for rendering. 62 | extents = scene_maxima - scene_minima 63 | far = extents[:, look_axis] 64 | near = torch.zeros_like(far) 65 | width = extents[:, right_axis] 66 | height = extents[:, down_axis] 67 | 68 | projection = render_cuda_orthographic( 69 | extrinsics, 70 | width, 71 | height, 72 | near, 73 | far, 74 | (resolution, resolution), 75 | torch.zeros((b, 3), dtype=torch.float32, device=device), 76 | gaussians.means, 77 | gaussians.covariances, 78 | gaussians.harmonics, 79 | gaussians.opacities, 80 | fov_degrees=10.0, 81 | ) 82 | if draw_label: 83 | right_axis_name = "XYZ"[right_axis] 84 | down_axis_name = "XYZ"[down_axis] 85 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 86 | projection = torch.stack([add_label(x, label) for x in projection]) 87 | 88 | projections.append(projection) 89 | 90 | return torch.stack(pad(projections), dim=1) 91 | 92 | 93 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 94 | # Define colors for context and target views. 95 | num_context_views = batch["context"]["extrinsics"].shape[1] 96 | num_target_views = batch["target"]["extrinsics"].shape[1] 97 | color = torch.ones( 98 | (num_target_views + num_context_views, 3), 99 | dtype=torch.float32, 100 | device=batch["target"]["extrinsics"].device, 101 | ) 102 | color[num_context_views:, 1:] = 0 103 | 104 | return draw_cameras( 105 | resolution, 106 | torch.cat( 107 | (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) 108 | ), 109 | torch.cat( 110 | (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) 111 | ), 112 | color, 113 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 114 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 115 | ) 116 | --------------------------------------------------------------------------------