├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── .gitignore ├── gs_garden.yaml ├── ingp_lego.yaml └── nerf_lego.yaml ├── dataset └── .gitignore ├── output └── .gitignore ├── resources ├── .gitignore ├── icg_logo.bmp ├── nerficg_banner.png ├── teaser_videos │ ├── dnpc.gif │ ├── dnpc.mp4 │ ├── inpc.gif │ └── inpc.mp4 └── train_demo.gif ├── scripts ├── colmap.py ├── condaEnv.sh ├── cutie.py ├── defaultConfig.py ├── generateTables.py ├── gui.py ├── inference.py ├── install.py ├── monocularDepth.py ├── raft.py ├── sequentialTrain.py ├── train.py ├── utils.py └── vggsfm.py └── src ├── Cameras ├── Base.py ├── Equirectangular.py ├── NDC.py ├── ODS.py ├── Perspective.py ├── PerspectiveStereo.py └── utils.py ├── Datasets ├── Base.py ├── Colmap.py ├── DNeRF.py ├── LLFF.py ├── MipNeRF360.py ├── NeRF.py ├── VGGSfM.py ├── iPhone.py └── utils.py ├── Framework.py ├── Implementations.py ├── Logging.py ├── Methods ├── Base │ ├── GuiTrainer.py │ ├── Model.py │ ├── Renderer.py │ ├── Trainer.py │ └── utils.py ├── GaussianSplatting │ ├── Loss.py │ ├── Model.py │ ├── Renderer.py │ ├── Trainer.py │ ├── __init__.py │ └── utils.py ├── HierarchicalNeRF │ ├── Loss.py │ ├── Model.py │ ├── Renderer.py │ ├── Trainer.py │ ├── __init__.py │ └── utils.py ├── InstantNGP │ ├── CudaExtensions │ │ └── VolumeRenderingV2 │ │ │ ├── __init__.py │ │ │ ├── csrc │ │ │ ├── binding.cpp │ │ │ ├── include │ │ │ │ ├── helper_math.h │ │ │ │ └── utils.h │ │ │ ├── intersection.cu │ │ │ ├── losses.cu │ │ │ ├── raymarching.cu │ │ │ ├── setup.py │ │ │ └── volumerendering.cu │ │ │ ├── custom_functions.py │ │ │ └── setup.py │ ├── Loss.py │ ├── Model.py │ ├── Renderer.py │ ├── Trainer.py │ ├── __init__.py │ └── utils.py └── NeRF │ ├── Loss.py │ ├── Model.py │ ├── Renderer.py │ ├── Trainer.py │ ├── __init__.py │ └── utils.py ├── Optim ├── AdamUtils.py ├── GradientScaling.py ├── Losses │ ├── BackgroundEntropy.py │ ├── Base.py │ ├── Charbonnier.py │ ├── DSSIM.py │ ├── DepthSmoothness.py │ ├── Distortion.py │ ├── FusedDSSIM.py │ ├── Magnitude.py │ ├── Robust.py │ ├── VGG.py │ └── utils.py ├── MaskedMetrics.py └── Samplers │ ├── DatasetSamplers.py │ ├── ImageSamplers.py │ ├── RaySamplers.py │ └── utils.py ├── Thirdparty ├── Apex.py ├── DiffGaussianRasterization.py ├── FusedSSIM.py ├── SimpleKNN.py ├── TinyCudaNN.py └── TorchScatter.py └── Visual ├── ColorMap.py ├── Trajectories ├── BulletTime.py ├── Ellipse.py ├── FancyZoom.py ├── NovelView.py ├── SpiralPath.py ├── StabilizedTrain.py ├── StabilizedView.py ├── __init__.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.pdf 3 | *.pyc 4 | *.aux 5 | *.yaml 6 | *.mp4 7 | *.txt 8 | *.egg 9 | *.so 10 | *.egg-info 11 | *.ipynb 12 | *.log 13 | !resources/* 14 | .vscode 15 | .DS_Store 16 | .idea/ 17 | build/ 18 | dist/ 19 | backup/ 20 | .cache/ 21 | .pytest_cache/ 22 | imgui.ini 23 | __pycache__/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/ICGui"] 2 | path = src/ICGui 3 | url = ../../nerficg-project/icgui 4 | [submodule "src/CudaUtils"] 5 | path = src/CudaUtils 6 | url = ../../nerficg-project/cuda-utils 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Moritz Kappel, Florian Hahlbohm 4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | ![Python](https://img.shields.io/static/v1?label=Python&message=3.11&color=success&logo=Python) ![PyTorch](https://img.shields.io/static/v1?label=Pytorch&message=2.5&color=success&logo=PyTorch) ![CUDA](https://img.shields.io/static/v1?label=CUDA&message=11.8&color=success&logo=NVIDIA) ![OS](https://img.shields.io/static/v1?label=OS&message=Linux&color=success&logo=Linux) [![License: MIT](https://img.shields.io/badge/License-MIT-success.svg)](https://opensource.org/licenses/MIT) 8 | 9 | A flexible Pytorch framework for simple and efficient implementation of neural radiance fields and rasterization-based view synthesis methods, including a GUI for interactive rendering. 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | ## Getting Started 22 | 23 | - This repository uses submodules; clone it using one of the following options: 24 | ```shell 25 | # HTTPS 26 | git clone https://github.com/nerficg-project/nerficg.git --recursive && cd nerficg 27 | ``` 28 | or 29 | ```shell 30 | # SSH 31 | git clone git@github.com:nerficg-project/nerficg.git --recursive && cd nerficg 32 | ``` 33 | 34 | - Install the global dependencies listed in `scripts/condaEnv.sh`, or automatically create a new conda environment by executing the script: 35 | ```shell 36 | ./scripts/condaEnv.sh && conda activate nerficg 37 | ``` 38 | 39 | - To install all additional dependencies for a specific method, run: 40 | ```shell 41 | ./scripts/install.py -m <method_name> 42 | ``` 43 | or use 44 | ```shell 45 | ./scripts/install.py -e <path/to/extension> 46 | ``` 47 | to only install a specific extension. 48 | 49 | - [optional] To use our training visualizations with [Weights & Biases](https://wandb.ai/site), run the following command and enter your account identifier: 50 | ```shell 51 | wandb login 52 | ``` 53 | 54 | 55 | ## Creating a Configuration File 56 | 57 | To create a configuration file for training, run 58 | ``` 59 | ./scripts/defaultConfig.py -m <method_name> -d <dataset_name> -o <config_name> 60 | ``` 61 | where `<method_name>` and `<dataset_name>` match one of the items in the `src/Methods` and `src/Datasets` directories, respectively. 62 | The resulting configuration file `<config_name>.yaml` will be available in the `configs` directory and can be customized as needed. 63 | To create a directory of configuration files for all scenes of a dataset, use the `-a` flag. This requires the full dataset to be available in the `dataset` directory. 64 | 65 | 66 | ## Training a New Model 67 | 68 | To train a new model from a configuration file, run: 69 | ``` 70 | ./scripts/train.py -c configs/<config_name>.yaml 71 | ``` 72 | The resulting images and model checkpoints will be saved to the `output` directory. 73 | 74 | To train multiple models from a directory or list of configuration files, use the `scripts/sequentialTrain.py` script with the `-d` or `-c` flag, respectively. 75 | 76 | 77 | ## Training on custom image sequences 78 | 79 | If you want to train on a custom image sequence, create a new directory with an `images` subdirectory containing all training images. 80 | Then you can prepare the image sequence using the provided [COLMAP](https://colmap.github.io) script, including various preprocessing options like monocular depth estimation, image segmentation and optical flow. 81 | Run 82 | ``` 83 | ./scripts/colmap.py -h 84 | ``` 85 | to see all available flags and options. Alternatively, you can try the `scripts/vggsfm.py` script if COLMAP’s SfM pipeline fails. 86 | 87 | After calibration, the custom dataset can be loaded by setting `GLOBAL.DATASET_TYPE` to `Colmap` or `VGGSfM` in the config file and pointing `DATASET.PATH` to the sequence directory, as shown below.
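For example, the relevant entries of such a config could look like this (a minimal sketch; the scene path is a placeholder, and all remaining keys keep their generated defaults):
```yaml
GLOBAL:
    DATASET_TYPE: Colmap  # or VGGSfM for sequences calibrated with scripts/vggsfm.py
DATASET:
    PATH: dataset/my_custom_scene  # base directory containing the images subdirectory and calibration files
```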
88 | 89 | 90 | ## Inference and evaluation 91 | 92 | We provide multiple scripts for easy model inference and performance evaluation after training. 93 | Use the `scripts/inference.py` script to render output images for individual subsets (train/test/eval) or custom camera trajectories defined in `src/Visual/Trajectories` using the `-s` option. 94 | Additional rendering performance benchmarking and metric calculation are available using the `-b` and `-m` flags, respectively. 95 | 96 | The `scripts/generateTables.py` script further enables consistent metric calculation over multiple pre-generated output image directories (e.g., to compare multiple methods against GT), and automatically generates LaTeX code for tables containing the resulting values. 97 | Use `-h` to see the available options for all scripts. 98 | 99 | 100 | ## Graphical User Interface 101 | 102 | To inspect a pretrained model in our GUI, make sure the GUI submodule is initialized, then run 103 | ``` 104 | ./scripts/gui.py 105 | ``` 106 | and select the generated output directory. 107 | 108 | Some methods support live GUI interaction during optimization. To enable live GUI support, activate the `TRAINING.GUI.ACTIVATE` flag in your config file. 109 | 110 | 111 | ## Frequently Asked Questions (FAQ) 112 | 113 | __Q:__ What coordinate system does the framework use internally? 114 | 115 | __A:__ The framework uses a left-handed coordinate system for all internal calculations: 116 | - World space: x-right, y-forward, z-down (left-handed) 117 | - Camera space: x-right, y-down, z-backward (left-handed) 118 | 119 | 126 | 127 | 128 | ## Acknowledgments 129 | 130 | We started working on this project in 2021. Over the years, many projects have inspired and helped us to develop this framework. 131 | Apart from any reference you might find in our source code, we would specifically like to thank all authors of the following projects for their great work: 132 | - [NeRF: Neural Radiance Fields](https://github.com/bmild/nerf) 133 | - [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch) 134 | - [Instant Neural Graphics Primitives](https://github.com/NVlabs/instant-ngp) 135 | - [ngp_pl](https://github.com/kwea123/ngp_pl.git) 136 | - [MultiNeRF: A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF](https://github.com/google-research/multinerf) 137 | - [torch_efficient_distloss](https://github.com/sunset1995/torch_efficient_distloss) 138 | - [3D Gaussian Splatting for Real-Time Radiance Field Rendering](https://github.com/graphdeco-inria/gaussian-splatting) 139 | - [CamP Zip-NeRF: A Code Release for CamP and Zip-NeRF](https://github.com/jonbarron/camp_zipnerf/) 140 | - [ADOP: Approximate Differentiable One-Pixel Point Rendering](https://github.com/darglein/ADOP) 141 | 142 | 143 | ## License and Citation 144 | 145 | This framework is licensed under the MIT license (see [LICENSE](LICENSE)).
146 | 147 | If you use it in your research projects, please consider a citation: 148 | ```bibtex 149 | @software{nerficg, 150 | author = {Kappel, Moritz and Hahlbohm, Florian and Scholz, Timon}, 151 | license = {MIT}, 152 | month = {1}, 153 | title = {NeRFICG}, 154 | url = {https://github.com/nerficg-project}, 155 | version = {1.0}, 156 | year = {2025} 157 | } 158 | ``` -------------------------------------------------------------------------------- /configs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !gs_garden.yaml 4 | !nerf_lego.yaml 5 | !ingp_lego.yaml -------------------------------------------------------------------------------- /configs/gs_garden.yaml: -------------------------------------------------------------------------------- 1 | GLOBAL: 2 | LOG_LEVEL: 2 3 | GPU_INDICES: 4 | - 0 5 | RANDOM_SEED: 1618033989 6 | ANOMALY_DETECTION: false 7 | FILTER_WARNINGS: true 8 | METHOD_TYPE: GaussianSplatting 9 | DATASET_TYPE: MipNeRF360 10 | MODEL: 11 | SH_DEGREE: 3 12 | RENDERER: 13 | USE_FUSED_COVARIANCE_COMPUTATION: true 14 | USE_FUSED_SH_CONVERSION: true 15 | SCALE_MODIFIER: 1.0 16 | DISABLE_SH0: false 17 | DISABLE_SH1: false 18 | DISABLE_SH2: false 19 | DISABLE_SH3: false 20 | USE_BAKED_COVARIANCE: false 21 | TRAINING: 22 | LOAD_CHECKPOINT: null 23 | MODEL_NAME: 3dgs_garden 24 | NUM_ITERATIONS: 30_000 25 | ACTIVATE_TIMING: false 26 | RUN_VALIDATION: false 27 | BACKUP: 28 | FINAL_CHECKPOINT: true 29 | RENDER_TESTSET: true 30 | RENDER_TRAINSET: false 31 | RENDER_VALSET: false 32 | VISUALIZE_ERRORS: false 33 | INTERVAL: -1 34 | TRAINING_STATE: false 35 | WANDB: 36 | ACTIVATE: false 37 | ENTITY: null 38 | PROJECT: nerficg 39 | LOG_IMAGES: true 40 | INDEX_VALIDATION: -1 41 | INDEX_TRAINING: -1 42 | INTERVAL: 1000 43 | SWEEP_MODE: 44 | ACTIVE: false 45 | START_ITERATION: 999 46 | ITERATION_STRIDE: 1000 47 | GUI: 48 | ACTIVATE: true 49 | RENDER_INTERVAL: 5 50 | GUI_STATUS_ENABLED: true 51 | GUI_STATUS_INTERVAL: 20 52 | SKIP_GUI_SETUP: false 53 | FPS_ROLLING_AVERAGE_SIZE: 100 54 | LEARNING_RATE_POSITION_INIT: 0.00016 55 | LEARNING_RATE_POSITION_FINAL: 1.6e-06 56 | LEARNING_RATE_POSITION_DELAY_MULT: 0.01 57 | LEARNING_RATE_POSITION_MAX_STEPS: 30_000 58 | LEARNING_RATE_FEATURE: 0.0025 59 | LEARNING_RATE_OPACITY: 0.05 60 | LEARNING_RATE_SCALING: 0.005 61 | LEARNING_RATE_ROTATION: 0.001 62 | PERCENT_DENSE: 0.01 63 | OPACITY_RESET_INTERVAL: 3_000 64 | DENSIFY_START_ITERATION: 500 65 | DENSIFY_END_ITERATION: 15_000 66 | DENSIFICATION_INTERVAL: 100 67 | DENSIFY_GRAD_THRESHOLD: 0.0002 68 | LOSS: 69 | LAMBDA_L1: 0.8 70 | LAMBDA_DSSIM: 0.2 71 | DATASET: 72 | PATH: dataset/mipnerf360/garden 73 | IMAGE_SCALE_FACTOR: 0.25 74 | NORMALIZE_CUBE: null 75 | NORMALIZE_RECENTER: false 76 | PRECOMPUTE_RAYS: false 77 | TO_DEVICE: true 78 | BACKGROUND_COLOR: 79 | - 0.0 80 | - 0.0 81 | - 0.0 82 | NEAR_PLANE: 0.2 83 | FAR_PLANE: 100.0 84 | TEST_STEP: 8 85 | APPLY_PCA: true 86 | APPLY_PCA_RESCALE: false 87 | USE_PRECOMPUTED_DOWNSCALING: true 88 | -------------------------------------------------------------------------------- /configs/ingp_lego.yaml: -------------------------------------------------------------------------------- 1 | GLOBAL: 2 | LOG_LEVEL: 2 3 | GPU_INDICES: 4 | - 0 5 | RANDOM_SEED: 1618033989 6 | ANOMALY_DETECTION: false 7 | FILTER_WARNINGS: true 8 | METHOD_TYPE: InstantNGP 9 | DATASET_TYPE: NeRF 10 | MODEL: 11 | SCALE: 0.5 12 | RESOLUTION: 128 13 | CENTER: 14 | - 0.0 15 | - 0.0 16 | - 0.0 17 | HASHMAP_NUM_LEVELS: 16 18 | 
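# added note (not part of the generated defaults): the HASHMAP_* keys below configure the Instant-NGP style
# multi-resolution hash encoding; each of the 16 levels stores at most 2^HASHMAP_LOG2_SIZE (= 2^19) feature
# vectors, and the per-level grid resolution grows geometrically from HASHMAP_BASE_RESOLUTION (16) up to
# HASHMAP_TARGET_RESOLUTION (2048).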
HASHMAP_NUM_FEATURES_PER_LEVEL: 2 19 | HASHMAP_LOG2_SIZE: 19 20 | HASHMAP_BASE_RESOLUTION: 16 21 | HASHMAP_TARGET_RESOLUTION: 2048 22 | NUM_DENSITY_OUTPUT_FEATURES: 16 23 | NUM_DENSITY_NEURONS: 64 24 | NUM_DENSITY_LAYERS: 1 25 | DIR_SH_ENCODING_DEGREE: 4 26 | NUM_COLOR_NEURONS: 64 27 | NUM_COLOR_LAYERS: 2 28 | RENDERER: 29 | MAX_SAMPLES: 1024 30 | EXPONENTIAL_STEPS: false 31 | DENSITY_THRESHOLD: 0.01 32 | TRAINING: 33 | LOAD_CHECKPOINT: null 34 | MODEL_NAME: ingp_lego 35 | NUM_ITERATIONS: 50_000 36 | ACTIVATE_TIMING: false 37 | RUN_VALIDATION: false 38 | BACKUP: 39 | FINAL_CHECKPOINT: true 40 | RENDER_TESTSET: true 41 | RENDER_TRAINSET: false 42 | RENDER_VALSET: false 43 | VISUALIZE_ERRORS: false 44 | INTERVAL: -1 45 | TRAINING_STATE: false 46 | WANDB: 47 | ACTIVATE: false 48 | ENTITY: null 49 | PROJECT: nerficg 50 | LOG_IMAGES: true 51 | INDEX_VALIDATION: -1 52 | INDEX_TRAINING: -1 53 | INTERVAL: 1000 54 | SWEEP_MODE: 55 | ACTIVE: false 56 | START_ITERATION: 999 57 | ITERATION_STRIDE: 1000 58 | RENDER_OCCUPANCY_GRIDS: false 59 | GUI: 60 | ACTIVATE: true 61 | RENDER_INTERVAL: 30 62 | GUI_STATUS_ENABLED: true 63 | GUI_STATUS_INTERVAL: 20 64 | SKIP_GUI_SETUP: false 65 | FPS_ROLLING_AVERAGE_SIZE: 100 66 | TARGET_BATCH_SIZE: 262_144 67 | WARMUP_STEPS: 256 68 | DENSITY_GRID_UPDATE_INTERVAL: 16 69 | LEARNING_RATE: 0.01 70 | LEARNING_RATE_DECAY_START: 20_000 71 | LEARNING_RATE_DECAY_INTERVAL: 10_000 72 | LEARNING_RATE_DECAY_BASE: 0.33 73 | ADAM_EPS: 1.0e-15 74 | USE_APEX: false 75 | DATASET: 76 | PATH: dataset/nerf_synthetic/lego 77 | IMAGE_SCALE_FACTOR: null 78 | NORMALIZE_CUBE: 2.66666666667 # NeRF synthetic cameras are inside [-4, 4]^3, geometry is only in the [-1.5, 1.5]^3 79 | NORMALIZE_RECENTER: false 80 | PRECOMPUTE_RAYS: true 81 | TO_DEVICE: true 82 | BACKGROUND_COLOR: 83 | - 0.0 84 | - 0.0 85 | - 0.0 86 | LOAD_TESTSET_DEPTHS: false 87 | -------------------------------------------------------------------------------- /configs/nerf_lego.yaml: -------------------------------------------------------------------------------- 1 | GLOBAL: 2 | LOG_LEVEL: 2 3 | GPU_INDICES: 4 | - 0 5 | RANDOM_SEED: 1618033989 6 | ANOMALY_DETECTION: false 7 | FILTER_WARNINGS: true 8 | METHOD_TYPE: NeRF 9 | DATASET_TYPE: NeRF 10 | MODEL: 11 | NUM_LAYERS: 8 12 | NUM_COLOR_LAYERS: 1 13 | NUM_FEATURES: 256 14 | ENCODING_LENGTH_POSITIONS: 10 15 | ENCODING_LENGTH_DIRECTIONS: 4 16 | ENCODING_APPEND_INPUT: true 17 | INPUT_SKIPS: 18 | - 5 19 | ACTIVATION_FUNCTION: relu 20 | RENDERER: 21 | RAY_BATCH_SIZE: 8_192 22 | NUM_SAMPLES: 256 23 | TRAINING: 24 | LOAD_CHECKPOINT: null 25 | MODEL_NAME: nerf_lego 26 | NUM_ITERATIONS: 500_000 27 | ACTIVATE_TIMING: false 28 | RUN_VALIDATION: false 29 | BACKUP: 30 | FINAL_CHECKPOINT: true 31 | RENDER_TESTSET: true 32 | RENDER_TRAINSET: false 33 | RENDER_VALSET: false 34 | VISUALIZE_ERRORS: false 35 | INTERVAL: -1 36 | TRAINING_STATE: false 37 | WANDB: 38 | ACTIVATE: false 39 | ENTITY: null 40 | PROJECT: nerficg 41 | LOG_IMAGES: true 42 | INDEX_VALIDATION: -1 43 | INDEX_TRAINING: -1 44 | INTERVAL: 10_000 45 | SWEEP_MODE: 46 | ACTIVE: false 47 | START_ITERATION: 999 48 | ITERATION_STRIDE: 1000 49 | BATCH_SIZE: 1024 50 | SAMPLE_SINGLE_IMAGE: true 51 | DENSITY_RANDOM_NOISE_STD: 0.0 52 | ADAM_BETA_1: 0.9 53 | ADAM_BETA_2: 0.999 54 | LEARNINGRATE: 0.0005 55 | LEARNINGRATE_DECAY_RATE: 0.1 56 | LEARNINGRATE_DECAY_STEPS: 500_000 57 | LAMBDA_COLOR_LOSS: 1.0 58 | LAMBDA_ALPHA_LOSS: 0.0 59 | DATASET: 60 | PATH: dataset/nerf_synthetic/lego 61 | IMAGE_SCALE_FACTOR: null 62 | NORMALIZE_CUBE: null 63 | 
NORMALIZE_RECENTER: true 64 | PRECOMPUTE_RAYS: true 65 | TO_DEVICE: true 66 | BACKGROUND_COLOR: 67 | - 1.0 68 | - 1.0 69 | - 1.0 70 | LOAD_TESTSET_DEPTHS: false 71 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | # ln -s /afs/cg.cs.tu-bs.de/cgdata/nerfdata/* dataset 2 | * 3 | !.gitignore -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /resources/.gitignore: -------------------------------------------------------------------------------- 1 | !* -------------------------------------------------------------------------------- /resources/icg_logo.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/icg_logo.bmp -------------------------------------------------------------------------------- /resources/nerficg_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/nerficg_banner.png -------------------------------------------------------------------------------- /resources/teaser_videos/dnpc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/dnpc.gif -------------------------------------------------------------------------------- /resources/teaser_videos/dnpc.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/dnpc.mp4 -------------------------------------------------------------------------------- /resources/teaser_videos/inpc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/inpc.gif -------------------------------------------------------------------------------- /resources/teaser_videos/inpc.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/inpc.mp4 -------------------------------------------------------------------------------- /resources/train_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/train_demo.gif -------------------------------------------------------------------------------- /scripts/condaEnv.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | ENV_NAME="nerficg" 4 | PYTHONVERSION="3.11" 5 | CUDAVERSION="cu118" 6 | HEADLESS=false 7 | 8 | # parse args 9 | while [[ $# -gt 0 ]]; do 10 | case $1 in 11 | -n|--name) 12 | ENV_NAME="$2" 13 | shift 2 14 | ;; 15 | -p|--python) 16 | PYTHONVERSION="$2" 17 | shift 2 18 | ;; 19 | -c|--cuda) 20 | CUDAVERSION="$2" 21 | shift 2 22 | ;; 23 | -h|--headless) 24 | HEADLESS=true 25 | shift 26 | ;; 27 | *) 28 | echo "Unknown argument: $1" 29 | echo "Usage: $0 [-n|--name ] [-p|--python ] [-c|--cuda ] [-h|--headless]" 30 | exit 1 31 | ;; 32 | esac 33 | done 34 | 35 | echo "Creating conda environment '$ENV_NAME' with Python $PYTHONVERSION and CUDA $CUDAVERSION" 36 | if [ "$HEADLESS" = true ]; then 37 | echo "Installing in headless mode (no GUI dependencies)" 38 | fi 39 | 40 | # create new conda environment 41 | conda create -y --name $ENV_NAME python=$PYTHONVERSION 42 | # install base dependencies 43 | conda install -y -n $ENV_NAME packaging 44 | conda run -n $ENV_NAME pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/${CUDAVERSION} 45 | conda run -n $ENV_NAME pip install numpy tqdm natsort GitPython av ffmpeg-python pyyaml munch tabulate wandb opencv-python kornia torchmetrics lpips einops setuptools plyfile matplotlib timm plotly pillow jax 46 | conda install -y -n $ENV_NAME -c conda-forge colmap 47 | # install gui dependencies 48 | if [ "$HEADLESS" = false ]; then 49 | conda run -n $ENV_NAME pip install imgui[sdl2] cuda-python platformdirs 50 | fi 51 | -------------------------------------------------------------------------------- /scripts/defaultConfig.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """defaultConfig.py: Creates a new config file with default values for a given method and dataset.""" 5 | 6 | import os 7 | from argparse import ArgumentParser 8 | import warnings 9 | import yaml 10 | from pathlib import Path 11 | import utils 12 | 13 | with utils.discoverSourcePath(): 14 | import Framework 15 | from Logging import Logger 16 | from Implementations import Methods as MI 17 | from Implementations import Datasets as DI 18 | 19 | 20 | def main(*, method_name: str, dataset_name: str, all_sequences: bool, output_filename: str) -> None: 21 | Logger.setMode(Logger.MODE_VERBOSE) 22 | # create config with global defaults 23 | Framework.config = Framework.ConfigWrapper(GLOBAL=Framework.getDefaultGlobalConfig()) 24 | Framework.config.GLOBAL.METHOD_TYPE = method_name 25 | Framework.config.GLOBAL.DATASET_TYPE = dataset_name 26 | # add renderer, model and training parameters 27 | method = MI.importMethod(method_name) 28 | Framework.config.MODEL = method.MODEL.getDefaultParameters() 29 | Framework.config.RENDERER = method.RENDERER.getDefaultParameters() 30 | Framework.config.TRAINING = method.TRAINING_INSTANCE.getDefaultParameters() 31 | # add dataset parameters 32 | dataset_class = DI.getDatasetClass(dataset_name) 33 | Framework.config.DATASET = dataset_class.getDefaultParameters() 34 | # dump config into file 35 | output_path = Path(__file__).resolve().parents[1] / 'configs' 36 | dataset_path = None 37 | if all_sequences: 38 | output_path = output_path / output_filename 39 | dataset_path = Path(Framework.config.DATASET.PATH).parents[0] 40 | os.makedirs(str(output_path), exist_ok=True) 41 | if not dataset_path.is_dir(): 42 | Logger.logError(f'failed to gather sequences from "{dataset_path}": directory not found') 43 | return 44 | config_file_names = 
[i.name for i in dataset_path.iterdir() if i.is_dir()] 45 | else: 46 | config_file_names = [output_filename] 47 | for config_file_name in config_file_names: 48 | config_file_path = output_path / f'{config_file_name}.yaml' 49 | try: 50 | Framework.config.TRAINING.MODEL_NAME = config_file_name 51 | if dataset_path is not None: 52 | Framework.config.DATASET.PATH = str(dataset_path / config_file_name) 53 | with open(config_file_path, 'w') as f: 54 | yaml.dump(Framework.ConfigParameterList.toDict(Framework.config), f, 55 | default_flow_style=False, indent=4, canonical=False, sort_keys=False) 56 | Logger.logInfo(f'configuration file successfully created: {config_file_path}') 57 | except IOError as e: 58 | Logger.logError(f'failed to create configuration file: "{e}"') 59 | 60 | 61 | if __name__ == '__main__': 62 | warnings.filterwarnings('ignore') 63 | # parse command line args 64 | parser: ArgumentParser = ArgumentParser(prog='defaultConfig') 65 | parser.add_argument( 66 | '-m', '--method', action='store', dest='method_name', 67 | metavar='method directory name', required=True, 68 | help='Name of the method you want to train. Name should match the directory in lib/methods.' 69 | ) 70 | parser.add_argument( 71 | '-d', '--dataset', action='store', dest='dataset_name', 72 | metavar='dataset name', required=True, 73 | help='Name of the dataset you want to train on. Name should match the python file in src/Datasets.' 74 | ) 75 | parser.add_argument( 76 | '-a', '--all', action='store_true', dest='all_sequences', 77 | help='If set, creates a directory containing a config file for each sequence in the dataset.' 78 | ) 79 | parser.add_argument( 80 | '-o', '--output', action='store', dest='output_filename', 81 | metavar='output config filename', required=True, 82 | help='Name of the generated config file.' 83 | ) 84 | args = parser.parse_args() 85 | if args.output_filename.endswith('.yaml'): 86 | args.output_filename = args.output_filename[:-5] 87 | 88 | # run main 89 | main( 90 | method_name=args.method_name, 91 | dataset_name=args.dataset_name, 92 | all_sequences=args.all_sequences, 93 | output_filename=args.output_filename 94 | ) 95 | -------------------------------------------------------------------------------- /scripts/gui.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """gui.py: opens graphical user interface.""" 5 | 6 | import sys 7 | 8 | import torch.multiprocessing as mp 9 | 10 | import utils 11 | with utils.discoverSourcePath(): 12 | from Logging import Logger 13 | 14 | if __name__ == '__main__': 15 | Logger.setMode(Logger.MODE_VERBOSE) 16 | with utils.discoverSourcePath(): 17 | try: 18 | from ICGui.launchViewer import main 19 | except ImportError as e: 20 | Logger.logError(f'Failed to open GUI: {e}\n' 21 | f'Make sure the icgui submodule is initialized and ' 22 | f'updated before running this script. ') 23 | sys.exit(0) 24 | mp.set_start_method('spawn') 25 | main() 26 | -------------------------------------------------------------------------------- /scripts/install.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """install.py: installs specific extensions or all extensions required by a given method.""" 5 | 6 | import os 7 | import subprocess 8 | import sys 9 | from argparse import ArgumentParser 10 | from pathlib import Path 11 | import importlib 12 | import warnings 13 | from types import ModuleType 14 | 15 | import utils 16 | 17 | with utils.discoverSourcePath(): 18 | import Framework 19 | from Implementations import Methods as MI 20 | from Logging import Logger 21 | 22 | 23 | def installExtension(install_name: str, install_command: list[str]) -> bool: 24 | """Installs a single extension.""" 25 | Logger.logInfo(f'Installing extension {install_name}...') 26 | result = subprocess.run(install_command, check=False) 27 | if result.returncode != 0: 28 | Logger.logError(f'Failed to install extension "{install_name}" with command: "{install_command if isinstance(install_command, str) else " ".join(install_command)}"') 29 | return result.returncode == 0 30 | 31 | 32 | def importExtension(extension_path: str) -> ModuleType: 33 | """Imports an extension module.""" 34 | extension_spec = Path(extension_path).resolve() 35 | if extension_spec.is_dir(): 36 | extension_spec = extension_spec / '__init__.py' 37 | extension_spec = importlib.util.spec_from_file_location(str(extension_spec.stem), str(extension_spec)) 38 | extension_module = importlib.util.module_from_spec(extension_spec) 39 | extension_spec.loader.exec_module(extension_module) 40 | return extension_module 41 | 42 | 43 | def main(extension_path: str, method_name: str) -> None: 44 | """Installs extensions required by a given method or a specific extension.""" 45 | Framework.setup() 46 | essential_modules = set(sys.modules.keys()) 47 | if extension_path is not None: 48 | try: 49 | extension_module = importExtension(extension_path) 50 | install_command = extension_module.__install_command__ 51 | install_name = extension_module.__extension_name__ 52 | except Framework.ExtensionError as e: 53 | install_name = e.__extension_name__ 54 | install_command = e.__install_command__ 55 | except FileNotFoundError: 56 | Logger.logError(f'Invalid extension path "{extension_path}": Module not found.') 57 | return 58 | except AttributeError as e: 59 | Logger.logError(f'Invalid extension module "{extension_path}": {e}') 60 | return 61 | if not installExtension(install_name, install_command): 62 | return 63 | if method_name is not None: 64 | if method_name not in MI.options: 65 | Logger.logError(f'Invalid method name "{method_name}".\nAvailable methods are: {MI.options}') 66 | return 67 | Logger.logInfo(f'Installing extensions for method "{method_name}"...') 68 | last_installed = None 69 | with utils.discoverSourcePath(): 70 | while True: 71 | try: 72 | all_modules = set(sys.modules.keys()) 73 | for module in all_modules.symmetric_difference(essential_modules): 74 | del sys.modules[module] 75 | MI._import(method_name) 76 | break 77 | except Framework.ExtensionError as e: 78 | if last_installed == e.__extension_name__: 79 | Logger.logError(f'Failed to install extension "{e.__extension_name__}" with command: "{e.__install_command__ if isinstance(e.__install_command__, str) else " ".join(e.__install_command__)}"') 80 | return 81 | last_installed = e.__extension_name__ 82 | if not installExtension(e.__extension_name__, e.__install_command__): 83 | return 84 | except Exception as e: 85 | Logger.logError(f'Unexpected error during method import: {e}') 86 | return 87 | Logger.logInfo('done') 88 | 89 | 90 | if __name__ == 
'__main__': 91 | # parse arguments 92 | parser: ArgumentParser = ArgumentParser(prog='Install') 93 | parser.add_argument( 94 | '-m', '--method', action='store', dest='method_name', default=None, 95 | metavar='method_name', required=False, 96 | help='Name of the method to install extensions for.' 97 | ) 98 | parser.add_argument( 99 | '-e', '--extension', action='store', dest='extension_path', default=None, 100 | metavar='extension_path', required=False, 101 | help='Path to extension to be installed.' 102 | ) 103 | args = parser.parse_args() 104 | # run 105 | Logger.setMode(Logger.MODE_VERBOSE) 106 | warnings.filterwarnings('ignore') 107 | os.environ['MKL_THREADING_LAYER'] = 'GNU' 108 | main(args.extension_path, args.method_name) 109 | -------------------------------------------------------------------------------- /scripts/raft.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """raft.py: Predicts optical flow for the given input image sequence using RAFT.""" 5 | 6 | import os 7 | from argparse import ArgumentParser 8 | from pathlib import Path 9 | import torch 10 | from torchvision.models.optical_flow import raft_large, Raft_Large_Weights 11 | 12 | import utils 13 | 14 | with utils.discoverSourcePath(): 15 | import Framework 16 | from Logging import Logger 17 | from Datasets.utils import list_sorted_files, list_sorted_directories, \ 18 | loadImagesParallel, saveOpticalFlowFile, flowToImage, saveImage 19 | 20 | 21 | def predictAndSave(*, filename: str, output_dir: Path, model: torch.nn.Module, color: bool, 22 | inputs1: torch.Tensor, inputs2: torch.Tensor, width: int, height: int) -> None: 23 | flows = model(inputs1, inputs2)[-1][:, :, :height, :width] 24 | saveOpticalFlowFile(output_dir / f'{filename}.flo', flows[0]) 25 | if color: 26 | saveImage(output_dir / f'{filename}.png', flowToImage(flows[0])) 27 | 28 | 29 | @torch.no_grad() 30 | def main(*, base_path: Path, output_path: Path | None, recursive: bool, backward: bool, color: bool, 31 | model: torch.nn.Module = None) -> bool: 32 | device = Framework.config.GLOBAL.DEFAULT_DEVICE 33 | # load model 34 | if model is None: 35 | Logger.logInfo('Loading RAFT model...') 36 | model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device).eval() 37 | # run on subdirectories 38 | if recursive: 39 | subdirs = [base_path / i for i in list_sorted_directories(base_path)] 40 | for subdir in subdirs: 41 | main( 42 | base_path=subdir, 43 | output_path=output_path if output_path is None else output_path / subdir.name, 44 | recursive=recursive, 45 | backward=backward, 46 | color=color, 47 | model=model 48 | ) 49 | # load images 50 | filenames = [i for i in list_sorted_files(base_path) if Path(i).suffix.lower() in ['.png', '.jpg', '.jpeg']] 51 | file_paths = [str(base_path / i) for i in filenames] 52 | if filenames: 53 | Logger.logInfo(f'Running RAFT on sequence directory: "{base_path}"...') 54 | Logger.logInfo(f'Found {len(filenames)} images') 55 | rgbs, _ = loadImagesParallel(file_paths, None, 4, 'loading image sequence') 56 | rgbs = (torch.stack(rgbs, dim=0).to(device) * 2.0) - 1.0 57 | # create output directory 58 | if output_path is None: 59 | output_dir = base_path / 'flow' 60 | else: 61 | output_dir = output_path 62 | os.makedirs(str(output_dir), exist_ok=True) 63 | # pad inputs to multiple of 8 64 | *_, h, w = rgbs.shape 65 | delta_h = (8 - (h % 8)) % 8 66 | delta_w = (8 - (w % 8)) % 8 67 | if delta_h != 0 or delta_w != 0: 68 | rgbs = 
torch.nn.functional.pad(rgbs, (0, delta_w, 0, delta_h, 0, 0, 0, 0), 'constant', 0) 69 | # predict and save flows 70 | for i in Logger.logProgressBar(range(len(rgbs) - 1), desc='image', leave=False): 71 | inputs1 = rgbs[i:i+1] 72 | inputs2 = rgbs[i+1:i+2] 73 | predictAndSave(filename=f'{filenames[i].split(".")[0]}_forward', output_dir=output_dir, model=model, 74 | color=color, inputs1=inputs1, inputs2=inputs2, width=w, height=h) 75 | if backward: 76 | predictAndSave(filename=f'{filenames[i+1].split(".")[0]}_backward', output_dir=output_dir, model=model, 77 | color=color, inputs1=inputs2, inputs2=inputs1, width=w, height=h) 78 | Logger.logInfo('done.') 79 | return True 80 | 81 | 82 | if __name__ == '__main__': 83 | # parse command line args 84 | parser: ArgumentParser = ArgumentParser( 85 | prog='raft.py', 86 | description='Predicts optical flow for the given input image sequence using RAFT.' 87 | ) 88 | parser.add_argument( 89 | '-i', '--input', action='store', dest='sequence_path', 90 | required=True, help='Path to the base directory containing the images.' 91 | ) 92 | parser.add_argument( 93 | '-r', '--recursive', action='store_true', dest='recursive', 94 | help='Scan for subdirectories and estimate flow.' 95 | ) 96 | parser.add_argument( 97 | '-b', '--backward', action='store_true', dest='backward', 98 | help='Generate flow predictions in backward direction.' 99 | ) 100 | parser.add_argument( 101 | '-o', '--output_path', action='store', dest='output_path', 102 | required=False, default=None, 103 | help='If set, the flow predictions will be stored in the given directory.' 104 | ) 105 | parser.add_argument( 106 | '-v', '--visualize', action='store_true', dest='color', 107 | help='Generate and store color visualizations of estimated flow.' 108 | ) 109 | args = parser.parse_args() 110 | # init Framework with defaults 111 | Framework.setup() 112 | # run main 113 | Logger.setMode(Logger.MODE_VERBOSE) 114 | main( 115 | base_path=Path(args.sequence_path), 116 | output_path=Path(args.output_path) if args.output_path is not None else None, 117 | recursive=args.recursive, 118 | backward=args.backward, 119 | color=args.color 120 | ) 121 | -------------------------------------------------------------------------------- /scripts/sequentialTrain.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """sequentialTrain.py: Sequentially runs model trainings for a list or directory of config files.""" 5 | 6 | import os 7 | import shutil 8 | from argparse import ArgumentParser 9 | from pathlib import Path 10 | import datetime 11 | from statistics import mean 12 | from tabulate import tabulate 13 | from multiprocessing import Process, Queue 14 | 15 | import utils 16 | with utils.discoverSourcePath(): 17 | import Framework 18 | import train 19 | from Logging import Logger 20 | from Datasets.utils import list_sorted_files 21 | 22 | 23 | def gatherConfigs() -> tuple[str, list[str]]: 24 | Logger.log('collecting config files') 25 | # parse arguments to retrieve config file locations 26 | parser: ArgumentParser = ArgumentParser(prog='SequentialTrain') 27 | parser.add_argument( 28 | '-c', '--configs', action='store', dest='config_paths', default=[], 29 | metavar='paths/to/config_files', required=False, nargs='+', 30 | help='Multiple whitespace separated training config files.'
31 | ) 32 | parser.add_argument( 33 | '-d', '--dir', action='store', dest='config_dir', default=None, 34 | metavar='path/to/configdir/', required=False, 35 | help='A directory containing training configuration files.' 36 | ) 37 | args, _ = parser.parse_known_args() 38 | # add all configs from -c flag 39 | config_paths: list[str] = args.config_paths 40 | # set output common directory 41 | output_directory = f'sequential_train_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}' 42 | # add all configs from config dir 43 | if args.config_dir is not None: 44 | config_dir_path = Path(args.config_dir) 45 | if not config_dir_path.is_dir(): 46 | raise Framework.TrainingError(f'config dir is not a valid directory: {config_dir_path}') 47 | directory_configs = [str(config_dir_path / i) for i in list_sorted_files(config_dir_path) if '.yaml' in i] 48 | config_paths = config_paths + directory_configs 49 | output_directory = f'{config_dir_path.name}_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}' 50 | # output directory and all configs for execution 51 | return output_directory, config_paths 52 | 53 | 54 | def writeOutputMetrics(metric_files: dict[str, Path], output_directory: Path) -> None: 55 | if metric_files: 56 | output_file = output_directory / 'summary.txt' 57 | Logger.log(f'Gathering final quality metrics for {len(metric_files)} runs in: {output_file}') 58 | parsed_values = [] 59 | for metric_file in metric_files.values(): 60 | with open(metric_file) as f: 61 | for line in f: 62 | pass 63 | parsed_values.append({i[0]: float(i[1]) for i in [j.split(':') for j in line.split(' ')]}) 64 | headers = ['Metric'] + list(metric_files.keys()) + ['Mean'] 65 | tab = [[metric_name] + [(run[metric_name]) for run in parsed_values] for metric_name in parsed_values[0].keys()] 66 | for row in tab: 67 | row.append(mean(row[1:])) 68 | with open(output_file, 'w') as f: 69 | f.write(tabulate(tabular_data=tab, headers=headers, floatfmt=".3f")) 70 | 71 | 72 | def training_process_helper(output, config): 73 | training_instance = train.main(config_path=config) 74 | output.put((training_instance.output_directory, training_instance.model.model_name, training_instance.NUM_ITERATIONS)) 75 | 76 | 77 | def main(): 78 | Logger.setMode(Logger.MODE_VERBOSE) 79 | # get list of training configs 80 | output_directory, config_paths = gatherConfigs() 81 | if not config_paths: 82 | raise Framework.TrainingError('no valid config file found') 83 | # create output directory 84 | output_directory = Path(__file__).resolve().parents[1] / 'output' / output_directory 85 | os.makedirs(str(output_directory), exist_ok=False) 86 | # run trainings 87 | Logger.log(f'running sequential training for {len(config_paths)} configurations') 88 | Logger.log(f'outputs will be available at: {output_directory}') 89 | success, failed = 0, 0 90 | metric_files = {} 91 | for config in config_paths: 92 | try: 93 | # run training for single config 94 | output = Queue() 95 | training_process = Process(target=training_process_helper, args=(output, config)) 96 | training_process.start() 97 | training_process.join() 98 | output_directory_run, model_name, num_iterations = output.get(block=False) 99 | shutil.move(output_directory_run, output_directory) 100 | metric_file: Path = output_directory / Path(output_directory_run).name / \ 101 | f'test_{num_iterations}' / 'metrics_8bit.txt' 102 | # detect output metric file 103 | if metric_file.is_file(): 104 | metric_files[model_name] = metric_file 105 | success += 1 106 | except Exception as e: 107 | Logger.logError(f'training failed for 
config file: "{config}" with error:\n{e}') 108 | failed += 1 109 | Logger.log(f'\nfinished sequential training ({success} successful, {failed} failed)') 110 | # combine training results in single table 111 | writeOutputMetrics(metric_files, output_directory) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -- coding: utf-8 -- 3 | 4 | """train.py: Trains a new model from config file.""" 5 | 6 | import utils 7 | 8 | with utils.discoverSourcePath(): 9 | import Framework 10 | from Implementations import Methods as MI 11 | from Implementations import Datasets as DI 12 | from Methods.Base.Trainer import BaseTrainer 13 | from Datasets.Base import BaseDataset 14 | 15 | 16 | def main(config_path: str = None): 17 | Framework.setup(config_path=config_path, require_custom_config=True) 18 | training_instance: BaseTrainer = MI.getTrainingInstance( 19 | method=Framework.config.GLOBAL.METHOD_TYPE, 20 | checkpoint=Framework.config.TRAINING.LOAD_CHECKPOINT 21 | ) 22 | dataset: BaseDataset = DI.getDataset( 23 | dataset_type=Framework.config.GLOBAL.DATASET_TYPE, 24 | path=Framework.config.DATASET.PATH 25 | ) 26 | training_instance.run(dataset) 27 | Framework.teardown() 28 | return training_instance 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """utils.py: utility code for script execution.""" 4 | 5 | import sys 6 | from pathlib import Path 7 | 8 | 9 | class discoverSourcePath: 10 | """A context class adding the source code location to the current python path.""" 11 | 12 | def __enter__(self): 13 | sys.path.insert(0, str(Path(__file__).resolve().parents[1] / 'src')) 14 | 15 | def __exit__(self, *_): 16 | sys.path.pop(0) 17 | 18 | 19 | def getCachePath() -> Path: 20 | """Returns the path to the framework .cache directory.""" 21 | return Path(__file__).resolve().parents[1] / '.cache' 22 | -------------------------------------------------------------------------------- /src/Cameras/Equirectangular.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Cameras/Equirectangular.py: Implements a 360-degree panorama camera model.""" 4 | 5 | import math 6 | from torch import Tensor 7 | import torch 8 | 9 | from Cameras.Base import BaseCamera 10 | 11 | 12 | class EquirectangularCamera(BaseCamera): 13 | """Defines a 360-degree panorama camera model for ray generation.""" 14 | 15 | def __init__(self, near_plane: float, far_plane: float) -> None: 16 | super(EquirectangularCamera, self).__init__(near_plane, far_plane) 17 | 18 | def getRayOrigins(self) -> Tensor: 19 | """Returns a tensor containing the origin of each ray.""" 20 | return self.properties.c2w[:3, -1].expand((self.properties.width * self.properties.height, 3)) 21 | 22 | def _getLocalRayDirections(self) -> Tensor: 23 | """Returns a tensor containing the direction of each ray.""" 24 | azimuth: Tensor = torch.linspace( 25 | start=(-math.pi + (math.pi / self.properties.width)), 26 | end=(math.pi - (math.pi / self.properties.width)), 27 | steps=self.properties.width 28 | )[None, :].expand((self.properties.height, self.properties.width)) 29 | 
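        # note: azimuth above covers the full 360° horizontally (one value per image column), while inclination
        # below covers -90° to 90° vertically (one value per row); both linspaces are shifted by half a pixel so
        # each ray direction corresponds to a pixel center before the spherical-to-Cartesian conversion below.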
inclination: Tensor = torch.linspace( 30 | start=((-math.pi / 2.0) + (0.5 * math.pi / self.properties.height)), 31 | end=((math.pi / 2.0) - (0.5 * math.pi / self.properties.height)), 32 | steps=self.properties.height 33 | )[:, None].expand((self.properties.height, self.properties.width)) 34 | x_direction: Tensor = torch.sin(azimuth) * torch.cos(inclination) 35 | y_direction: Tensor = torch.sin(inclination) 36 | z_direction: Tensor = -torch.cos(azimuth) * torch.cos(inclination) 37 | directions_camera_space: Tensor = torch.stack( 38 | (x_direction, y_direction, z_direction), dim=-1 39 | ).reshape(-1, 3) 40 | return directions_camera_space 41 | 42 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]: 43 | raise NotImplementedError('point projection not yet implemented for panorama camera') 44 | 45 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor: 46 | raise NotImplementedError('projection matrix not yet implemented for panorama camera') 47 | -------------------------------------------------------------------------------- /src/Cameras/NDC.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Cameras/NDC.py: A perspective RGB camera model generating rays in normalized device space.""" 4 | 5 | from torch import Tensor 6 | import torch 7 | 8 | from Cameras.Perspective import PerspectiveCamera 9 | from Cameras.utils import RayPropertySlice 10 | 11 | 12 | class NDCCamera(PerspectiveCamera): 13 | """Defines a perspective RGB camera model that generates rays in normalized device space.""" 14 | 15 | def __init__(self, cube_scale: float = 1.0, z_plane: float = -1.0, near_plane: float = 0.01, far_plane: float = 1.0) -> None: 16 | super(NDCCamera, self).__init__(near_plane, far_plane * cube_scale) 17 | self.cube_scale: float = cube_scale 18 | self.z_plane: float = z_plane 19 | 20 | def generateRays(self) -> Tensor: 21 | # generate perspective rays 22 | rays: Tensor = super().generateRays() 23 | origins: Tensor = rays[:, RayPropertySlice.origin] 24 | directions: Tensor = rays[:, RayPropertySlice.direction] 25 | # shift rays to near plane 26 | t: Tensor = -(1.0 + origins[..., 2]) / directions[..., 2] 27 | origins: Tensor = origins + t[..., None] * directions 28 | # precompute some intermediate results 29 | w2fx: float = self.properties.width / (2. * self.properties.focal_x) 30 | h2fy: float = self.properties.height / (2. * self.properties.focal_y) 31 | ox_oz: Tensor = origins[..., 0] / origins[..., 2] 32 | oy_oz: Tensor = origins[..., 1] / origins[..., 2] 33 | # projection to normalized device space 34 | o0: Tensor = -1. / w2fx * ox_oz 35 | o1: Tensor = -1. / h2fy * oy_oz 36 | o2: Tensor = 1. + 2. * 1.0 / origins[..., 2] 37 | d0: Tensor = -1. / w2fx * (directions[..., 0] / directions[..., 2] - ox_oz) 38 | d1: Tensor = -1. / h2fy * (directions[..., 1] / directions[..., 2] - oy_oz) 39 | d2: Tensor = 1. 
- o2 40 | ndc_rays: Tensor = torch.cat([ 41 | torch.stack([o0, o1, o2], -1) * self.cube_scale, 42 | torch.stack([d0, d1, d2], -1) * self.cube_scale, 43 | rays[:, RayPropertySlice.all_annotations_ndc], 44 | ], dim=-1) 45 | ndc_rays[:, 2] = self.z_plane 46 | return ndc_rays 47 | 48 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]: 49 | raise NotImplementedError('point projection not yet implemented for NDC camera') 50 | 51 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor: 52 | raise NotImplementedError('projection matrix not yet implemented for NDC camera') 53 | -------------------------------------------------------------------------------- /src/Cameras/ODS.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Cameras/ODS.py: An omni-directional stereo panorama camera model.""" 4 | 5 | import math 6 | from torch import Tensor 7 | import torch 8 | 9 | from Cameras.Equirectangular import EquirectangularCamera 10 | 11 | 12 | class ODSCamera(EquirectangularCamera): 13 | """ 14 | Defines an omnidirectional stereo panorama camera model for ray generation. 15 | Expects the stereo images to be vertically concatenated. 16 | """ 17 | 18 | def __init__(self, near_plane: float, far_plane: float, baseline: float = 0.065) -> None: 19 | super(ODSCamera, self).__init__(near_plane, far_plane) 20 | self.half_baseline: float = baseline / 2 21 | 22 | def getRayOrigins(self) -> Tensor: 23 | """Returns a tensor containing the origin of each ray.""" 24 | rotation_angle: Tensor = torch.linspace( 25 | start=(-math.pi + (math.pi / self.properties.width)), 26 | end=(math.pi - (math.pi / self.properties.width)), 27 | steps=self.properties.width 28 | )[None, :].expand((self.properties.height // 2, self.properties.width)) 29 | baseline_vector: Tensor = torch.stack( 30 | (torch.cos(rotation_angle), torch.zeros_like(rotation_angle), torch.sin(rotation_angle)), dim=-1 31 | ).reshape(-1, 3) 32 | origins_camera_space: Tensor = torch.cat([ 33 | -self.half_baseline * baseline_vector, 34 | self.half_baseline * baseline_vector 35 | ], dim=0) 36 | origins_world_space: Tensor = torch.matmul( 37 | self.properties.c2w, 38 | torch.cat([origins_camera_space, torch.ones((origins_camera_space.shape[0], 1))], dim=-1)[:, :, None] 39 | ).squeeze() 40 | return origins_world_space[:, :3] 41 | 42 | def _getLocalRayDirections(self) -> Tensor: 43 | """Returns a tensor containing the direction of each ray.""" 44 | # adjust height for vertical concatenation 45 | self.properties.height = self.properties.height // 2 46 | directions: Tensor = super().getLocalRayDirections() 47 | self.properties.height = self.properties.height * 2 48 | return torch.cat([directions, directions], dim=0) 49 | 50 | @property 51 | def baseline(self) -> float: 52 | """Returns the baseline of the stereo camera.""" 53 | return self.half_baseline * 2.0 54 | 55 | @baseline.setter 56 | def baseline(self, value: float) -> None: 57 | self.half_baseline = value / 2.0 58 | 59 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]: 60 | raise NotImplementedError('point projection not yet implemented for ODS camera') 61 | 62 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor: 63 | raise NotImplementedError('projection matrix not yet implemented for ODS camera') 64 | -------------------------------------------------------------------------------- /src/Cameras/Perspective.py: 
-------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Cameras/Perspective.py: Implementation of a perspective RGB camera model.""" 4 | 5 | from torch import Tensor 6 | import torch 7 | 8 | from Cameras.Base import BaseCamera 9 | 10 | 11 | class PerspectiveCamera(BaseCamera): 12 | """Defines a perspective RGB camera model for ray generation.""" 13 | 14 | def __init__(self, near_plane: float, far_plane: float) -> None: 15 | super(PerspectiveCamera, self).__init__(near_plane, far_plane) 16 | 17 | def getRayOrigins(self) -> Tensor: 18 | """Returns a tensor containing the origin of each ray.""" 19 | return self.properties.c2w[:3, -1].expand((self.properties.width * self.properties.height, 3)) 20 | 21 | def _getLocalRayDirections(self) -> Tensor: 22 | """Returns a tensor containing the direction of each ray.""" 23 | # calculate initial directions 24 | x_direction, y_direction = self.getPixelCoordinates() 25 | x_direction: Tensor = ((x_direction + 0.5) - (0.5 * self.properties.width + self.properties.principal_offset_x)) / self.properties.focal_x 26 | y_direction: Tensor = ((y_direction + 0.5) - (0.5 * self.properties.height + self.properties.principal_offset_y)) / self.properties.focal_y 27 | z_direction: Tensor = torch.full((self.properties.height, self.properties.width), fill_value=-1) 28 | directions_camera_space: Tensor = torch.stack( 29 | (x_direction, y_direction, z_direction), dim=-1 30 | ).reshape(-1, 3) 31 | return directions_camera_space 32 | 33 | def projectPoints(self, points: Tensor, eps: float = 1.0e-8) -> tuple[Tensor, Tensor, Tensor]: 34 | """projects points (Nx3) to image plane. returns xy image plane pixel coordinates, 35 | mask of points that hit sensor, and depth values.""" 36 | # points = torch.cat((points, torch.ones((points.shape[0], 1), device=points.device, dtype=points.dtype)), dim=1) 37 | # points = points @ self.properties.w2c.T 38 | points = self.pointsToCameraSpace(points) 39 | depths = -points[:, 2] 40 | valid_mask = ((depths > self.near_plane) & (depths < self.far_plane)) 41 | focal = torch.tensor((self.properties.focal_x, self.properties.focal_y), device=points.device, dtype=points.dtype) 42 | screen_size = torch.tensor((self.properties.width, self.properties.height), device=points.device, dtype=points.dtype) 43 | offset = screen_size * 0.5 + torch.tensor((self.properties.principal_offset_x, self.properties.principal_offset_y), device=points.device, dtype=points.dtype) 44 | points = self.properties.distortion_parameters.distort(points[:, :2] / (depths[:, None] + eps)) * focal + offset 45 | valid_mask &= ((points >= 0) & (points < screen_size - 1)).all(dim=-1) 46 | return points, valid_mask, depths 47 | 48 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor: 49 | """ 50 | Returns the projection matrix for this camera. 51 | After perspective division, x, y, and z will be in [-1, 1] (OpenGL convention). 52 | The z-axis can be inverted for camera coordinate systems where the camera looks along the negative z-axis. 53 | 54 | Args: 55 | invert_z: If True, the z-axis will be inverted. 56 | 57 | Returns: 58 | The projection matrix from camera space to clip space. 
59 | """ 60 | half_width = self.properties.width * 0.5 61 | half_height = self.properties.height * 0.5 62 | z_sign = -1.0 if invert_z else 1.0 63 | projection_matrix = torch.tensor([ 64 | [self.properties.focal_x / half_width, 0.0, z_sign * self.properties.principal_offset_x / half_width, 0.0], 65 | [0.0, self.properties.focal_y / half_height, z_sign * self.properties.principal_offset_y / half_height, 0.0], 66 | [0.0, 0.0, z_sign * (self.far_plane + self.near_plane) / (self.far_plane - self.near_plane), -2.0 * self.far_plane * self.near_plane / (self.far_plane - self.near_plane)], 67 | [0.0, 0.0, z_sign, 0.0] 68 | ], dtype=torch.float32, device=self.properties.c2w.device) 69 | return projection_matrix 70 | 71 | def getViewportTransform(self, pixel_centers_at_integer_coordinates: bool = True) -> Tensor: 72 | """ 73 | Returns the transformation matrix from NDC to screen space. 74 | if pixel_centers_at_integer_coordinates: 75 | [-1, 1]^3 -> [-0.5, width - 0.5] x [-0.5, height - 0.5] x [near_plane, far_plane] 76 | else: 77 | [-1, 1]^3 -> [0, width] x [0, height] x [near_plane, far_plane] 78 | 79 | Args: 80 | pixel_centers_at_integer_coordinates: If True, pixel centers will be at integer coordinates. 81 | 82 | Returns: 83 | The transformation matrix from NDC to screen space. 84 | """ 85 | offset = 0.5 if pixel_centers_at_integer_coordinates else 0.0 86 | center_x = self.properties.width * 0.5 87 | center_y = self.properties.height * 0.5 88 | viewport_transform = torch.tensor([ 89 | [center_x, 0.0, 0.0, center_x - offset], 90 | [0.0, center_y, 0.0, center_y - offset], 91 | [0.0, 0.0, 1.0, 0.0], 92 | [0.0, 0.0, 0.0, 1.0] 93 | ], dtype=torch.float32, device=self.properties.c2w.device) 94 | # the following lines add a nonlinear mapping of z from [-1, 1] to [near_plane, far_plane] 95 | # viewport_transform[2, 2] = (self.far_plane - self.near_plane) * 0.5 96 | # viewport_transform[2, 3] = (self.far_plane + self.near_plane) * 0.5 97 | return viewport_transform 98 | -------------------------------------------------------------------------------- /src/Cameras/PerspectiveStereo.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Cameras/PerspectiveStereo.py: Implementation of a perspective camera model generating stereo views.""" 4 | 5 | from torch import Tensor 6 | import torch 7 | 8 | from Cameras.Perspective import PerspectiveCamera 9 | 10 | 11 | class PerspectiveStereoCamera(PerspectiveCamera): 12 | """Defines a perspective camera model that generates rays of vertically concatenated stereo views.""" 13 | 14 | def __init__(self, near_plane: float, far_plane: float, baseline: float = 0.062) -> None: 15 | super(PerspectiveStereoCamera, self).__init__(near_plane, far_plane) 16 | self.half_baseline: float = baseline / 2.0 17 | 18 | def getRayOrigins(self) -> Tensor: 19 | """Returns a tensor containing the origin of each ray.""" 20 | x_axis_world_space: Tensor = self.properties.c2w[:3, 0][None] 21 | # adjust height for vertical concatenation 22 | self.properties.height = self.properties.height // 2 23 | origin_world_space: Tensor = super().getRayOrigins() 24 | self.properties.height = self.properties.height * 2 25 | origins_left: Tensor = origin_world_space - (x_axis_world_space * self.half_baseline) 26 | origins_right: Tensor = origin_world_space + (x_axis_world_space * self.half_baseline) 27 | return torch.cat([origins_left, origins_right], dim=0) 28 | 29 | def _getLocalRayDirections(self) -> Tensor: 30 | """Returns a tensor containing 
the direction of each ray.""" 31 | # adjust height for vertical concatenation 32 | self.properties.height = self.properties.height // 2 33 | directions: Tensor = super().getLocalRayDirections() 34 | self.properties.height = self.properties.height * 2 35 | return torch.cat([directions, directions], dim=0) 36 | 37 | @property 38 | def baseline(self) -> float: 39 | """Returns the baseline of the stereo camera.""" 40 | return self.half_baseline * 2.0 41 | 42 | @baseline.setter 43 | def baseline(self, value: float) -> None: 44 | self.half_baseline = value / 2.0 45 | 46 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]: 47 | raise NotImplementedError('point projection not yet implemented for PerspectiveStereoCamera') 48 | 49 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor: 50 | raise NotImplementedError('projection matrix not yet implemented for PerspectiveStereoCamera') 51 | -------------------------------------------------------------------------------- /src/Datasets/DNeRF.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Datasets/DNerf.py: Provides a dataset class for D-NeRF scenes. 5 | Data available at https://github.com/albertpumarola/D-NeRF (last accessed 2023-05-25). 6 | """ 7 | 8 | import json 9 | import math 10 | from pathlib import Path 11 | from typing import Any 12 | 13 | import torch 14 | 15 | import Framework 16 | from Cameras.Perspective import PerspectiveCamera 17 | from Cameras.utils import CameraProperties 18 | from Datasets.Base import BaseDataset 19 | from Datasets.utils import applyBGColor, loadImagesParallel, \ 20 | CameraCoordinateSystemsTransformations, WorldCoordinateSystemTransformations 21 | 22 | 23 | @Framework.Configurable.configure( 24 | PATH='dataset/d-nerf/standup', 25 | IMAGE_SCALE_FACTOR=0.5, 26 | NORMALIZE_CUBE=1.5, 27 | ) 28 | class CustomDataset(BaseDataset): 29 | """Dataset class for D-NeRF scenes.""" 30 | 31 | def __init__(self, path: str) -> None: 32 | super().__init__( 33 | path=path, 34 | camera=PerspectiveCamera(2.0, 6.0), 35 | camera_system=CameraCoordinateSystemsTransformations.RIGHT_HAND, 36 | world_system=WorldCoordinateSystemTransformations.XnYnZ 37 | ) 38 | 39 | def load(self) -> dict[str, list[CameraProperties]]: 40 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing.""" 41 | # set bb 42 | self._bounding_box = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]).cpu() 43 | # load data 44 | data: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets} 45 | for subset in self.subsets: 46 | metadata_filepath: Path = self.dataset_path / f'transforms_{subset}.json' 47 | try: 48 | with open(metadata_filepath, 'r') as f: 49 | metadata_file: dict[str, Any] = json.load(f) 50 | except IOError: 51 | raise Framework.DatasetError(f'Invalid dataset metadata file path "{metadata_filepath}"') 52 | opening_angle: float = float(metadata_file['camera_angle_x']) 53 | # load images 54 | image_filenames = [str(self.dataset_path / (frame['file_path'] + '.png')) for frame in metadata_file['frames']] 55 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc=subset) 56 | # create split CameraProperties objects 57 | for frame, rgb, alpha in zip(metadata_file['frames'], rgbs, alphas): 58 | # apply background color where alpha < 1 59 | rgb = applyBGColor(rgb, alpha, self.camera.background_color) 60 | # load camera extrinsics 61 | c2w = 
torch.as_tensor(frame['transform_matrix'], dtype=torch.float32) 62 | # load camera intrinsics 63 | focal_x: float = 0.5 * rgb.shape[2] / math.tan(0.5 * opening_angle) 64 | focal_y: float = 0.5 * rgb.shape[1] / math.tan(0.5 * opening_angle) 65 | # insert loaded values 66 | data[subset].append(CameraProperties( 67 | width=rgb.shape[2], 68 | height=rgb.shape[1], 69 | rgb=rgb, 70 | alpha=alpha, 71 | c2w=c2w, 72 | focal_x=focal_x, 73 | focal_y=focal_y, 74 | timestamp=frame['time'] 75 | )) 76 | return data 77 | -------------------------------------------------------------------------------- /src/Datasets/LLFF.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Datasets/LLFF.py: Provides a dataset class for Local Light Field Fusion (LLFF) scenes. 5 | Data available at https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 (last accessed 2023-05-25). 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | import Framework 12 | from Cameras.NDC import NDCCamera 13 | from Cameras.Perspective import PerspectiveCamera 14 | from Cameras.utils import CameraProperties 15 | from Datasets.Base import BaseDataset 16 | from Datasets.utils import list_sorted_files, recenterPoses, loadImagesParallel, \ 17 | CameraCoordinateSystemsTransformations 18 | 19 | 20 | @Framework.Configurable.configure( 21 | PATH='dataset/nerf_llff_data/fern', 22 | IMAGE_SCALE_FACTOR=0.25, 23 | DISABLE_NDC=False, 24 | TEST_STEP=8, 25 | WORLD_SCALING=0.75 26 | ) 27 | class CustomDataset(BaseDataset): 28 | """Dataset class for Local Light Field Fusion (LLFF) scenes.""" 29 | 30 | def __init__(self, path: str) -> None: 31 | super().__init__( 32 | path, 33 | NDCCamera(), 34 | CameraCoordinateSystemsTransformations.RIGHT_HAND 35 | ) 36 | 37 | def load(self) -> dict[str, list[CameraProperties] | None]: 38 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing.""" 39 | # TODO K-planes uses this bounding box for llff scenes! 
40 | # torch.tensor([[-3.0, -1.67, -1.2], [3.0, 1.67, 1.2]]) 41 | 42 | # load images 43 | images_path = self.dataset_path / 'images' 44 | image_filenames = [str(images_path / file) for file in list_sorted_files(images_path)] 45 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc='images') 46 | # load intrinsics and extrinsics 47 | colmap_poses = torch.as_tensor(np.load(str(self.dataset_path / 'poses_bounds.npy'))) 48 | view_matrices = colmap_poses[:, :-2].reshape([-1, 3, 5]) 49 | intrinsics = view_matrices[:, :, 4] 50 | focals = intrinsics[:, 2:3] 51 | if self.IMAGE_SCALE_FACTOR is not None: 52 | focals = focals * self.IMAGE_SCALE_FACTOR 53 | view_matrices = view_matrices[:, :, :-1] 54 | view_matrices = torch.cat( 55 | [view_matrices[:, :, 1:2], -view_matrices[:, :, 0:1], view_matrices[:, :, 2:]], dim=2 56 | ) 57 | view_matrices = torch.cat( 58 | (view_matrices, torch.broadcast_to(torch.tensor([0, 0, 0, 1]), (view_matrices.shape[0], 1, 4))), dim=1 59 | ) 60 | depth_min_max = colmap_poses[:, -2:] 61 | # rescale coordinates 62 | if self.WORLD_SCALING is not None: 63 | scaling = 1.0 / (depth_min_max.min() * self.WORLD_SCALING) 64 | view_matrices[:, :3, 3] *= scaling 65 | depth_min_max *= scaling 66 | # disable normalized device coordinates, if enabled 67 | if self.DISABLE_NDC: 68 | view_matrices = view_matrices.cpu() 69 | self.camera = PerspectiveCamera( 70 | near_plane=depth_min_max.min().item() * 0.9, 71 | far_plane=depth_min_max.max().item() 72 | ) 73 | self.camera.setBackgroundColor(*self.BACKGROUND_COLOR) 74 | else: 75 | # recenter coordinates cameras (only makes sense for forward facing scenes) 76 | view_matrices = recenterPoses(view_matrices).cpu() 77 | # insert data into target data structure 78 | data: list[CameraProperties] = [ 79 | CameraProperties( 80 | width=rgb.shape[2], 81 | height=rgb.shape[1], 82 | rgb=rgb, 83 | alpha=alpha, 84 | c2w=c2w, 85 | focal_x=focal.item(), 86 | focal_y=focal.item() 87 | ) 88 | for rgb, alpha, c2w, focal in zip(rgbs, alphas, view_matrices, focals) 89 | ] 90 | # perform test split 91 | train_data: list[CameraProperties] = [] 92 | test_data: list[CameraProperties] = [] 93 | if self.TEST_STEP > 0: 94 | for i in range(len(data)): 95 | if i % self.TEST_STEP == 0: 96 | test_data.append(data[i]) 97 | else: 98 | train_data.append(data[i]) 99 | else: 100 | train_data: list[CameraProperties] = data 101 | # return the dataset 102 | return { 103 | 'train': train_data, 104 | 'test': test_data, 105 | 'val': [] 106 | } 107 | -------------------------------------------------------------------------------- /src/Datasets/MipNeRF360.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Datasets/MipNeRF360.py: Provides a dataset class for scenes from the Mip-NeRF 360 dataset. 5 | Data available at https://storage.googleapis.com/gresearch/refraw360/360_v2.zip (last accessed 2024-02-01) and 6 | https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip (last accessed 2024-02-01). 7 | Will also work for any other scene in the same format as the Mip-NeRF 360 dataset. 
8 | """ 9 | 10 | import os 11 | import torch 12 | 13 | import Framework 14 | from Logging import Logger 15 | from Cameras.Perspective import PerspectiveCamera 16 | from Cameras.utils import CameraProperties 17 | from Datasets.Base import BaseDataset 18 | from Datasets.utils import CameraCoordinateSystemsTransformations, loadImagesParallel, WorldCoordinateSystemTransformations 19 | from Datasets.Colmap import quaternion_to_R, read_points3D_binary, storePly, fetchPly, read_extrinsics_binary, read_intrinsics_binary, transformPosesPCA 20 | 21 | @Framework.Configurable.configure( 22 | PATH='dataset/mipnerf360/garden', 23 | IMAGE_SCALE_FACTOR=0.25, 24 | BACKGROUND_COLOR=[0.0, 0.0, 0.0], 25 | NEAR_PLANE=0.01, 26 | FAR_PLANE=100.0, 27 | TEST_STEP=8, 28 | APPLY_PCA=True, 29 | APPLY_PCA_RESCALE=True, 30 | USE_PRECOMPUTED_DOWNSCALING=True, 31 | ) 32 | class CustomDataset(BaseDataset): 33 | """Dataset class for MipNeRF360 scenes.""" 34 | 35 | def __init__(self, path: str) -> None: 36 | super().__init__( 37 | path, 38 | PerspectiveCamera(0.01, 100.0), # mipnerf360 itself uses 0.2, 1e6 39 | CameraCoordinateSystemsTransformations.LEFT_HAND, 40 | WorldCoordinateSystemTransformations.XnZY, 41 | ) 42 | 43 | def load(self) -> dict[str, list[CameraProperties] | None]: 44 | """Loads the dataset into a dict containing lists of CameraProperties for training and testing.""" 45 | # set near and far plane to values from config 46 | self.camera.near_plane = self.NEAR_PLANE 47 | self.camera.far_plane = self.FAR_PLANE 48 | 49 | # load colmap data 50 | cam_extrinsics = read_extrinsics_binary(self.dataset_path / 'sparse' / '0' / 'images.bin') 51 | cam_intrinsics = read_intrinsics_binary(self.dataset_path / 'sparse' / '0' / 'cameras.bin') 52 | 53 | # create camera properties 54 | data: list[CameraProperties] = [] 55 | for cam_idx, cam_data in enumerate(cam_intrinsics.values()): 56 | # load images 57 | images = [data for data in cam_extrinsics.values() if data.camera_id == cam_data.id] 58 | images = sorted(images, key=lambda data: data.name) 59 | image_directory_name = 'images' 60 | image_scale_factor = self.IMAGE_SCALE_FACTOR 61 | # optionally use pre-downscaled images 62 | if self.USE_PRECOMPUTED_DOWNSCALING: 63 | match self.IMAGE_SCALE_FACTOR: 64 | case 0.5: 65 | image_directory_name = 'images_2' 66 | image_scale_factor = None 67 | case 0.25: 68 | image_directory_name = 'images_4' 69 | image_scale_factor = None 70 | case 0.125: 71 | image_directory_name = 'images_8' 72 | image_scale_factor = None 73 | case _: 74 | pass 75 | image_filenames = [str(self.dataset_path / image_directory_name / image.name) for image in images] 76 | rgbs, _ = loadImagesParallel(image_filenames, image_scale_factor, num_threads=4, desc=f'camera {cam_data.id}') 77 | for idx, (image, rgb) in enumerate(zip(images, rgbs)): 78 | # extract w2c matrix 79 | rotation_matrix = torch.from_numpy(quaternion_to_R(image.qvec)).float() 80 | translation_vector = torch.from_numpy(image.tvec).float() 81 | w2c = torch.eye(4, device=torch.device('cpu')) 82 | w2c[:3, :3] = rotation_matrix 83 | w2c[:3, 3] = translation_vector 84 | # intrinsics 85 | focal_x = cam_data.params[0] 86 | focal_y = cam_data.params[1] 87 | principal_offset_x = cam_data.params[2] - cam_data.width / 2 88 | principal_offset_y = cam_data.params[3] - cam_data.height / 2 89 | if self.IMAGE_SCALE_FACTOR is not None: 90 | scale_factor_intrinsics_x = rgb.shape[2] / cam_data.width 91 | scale_factor_intrinsics_y = rgb.shape[1] / cam_data.height 92 | focal_x *= scale_factor_intrinsics_x 93 | focal_y *= 
scale_factor_intrinsics_y 94 | principal_offset_x *= scale_factor_intrinsics_x 95 | principal_offset_y *= scale_factor_intrinsics_y 96 | # create and append camera properties object 97 | camera_properties = CameraProperties( 98 | width=rgb.shape[2], 99 | height=rgb.shape[1], 100 | rgb=rgb, 101 | focal_x=focal_x, 102 | focal_y=focal_y, 103 | principal_offset_x=principal_offset_x, 104 | principal_offset_y=principal_offset_y, 105 | timestamp=idx / (len(images) - 1), # TODO: rename this to id 106 | ) 107 | camera_properties.w2c = w2c 108 | data.append(camera_properties) 109 | 110 | # load point cloud 111 | ply_path = self.dataset_path / 'sparse' / '0' / 'points3D.ply' 112 | if not os.path.exists(ply_path): 113 | Logger.logInfo('Found new scene. Converting sparse SfM points to .ply format.') 114 | xyz, rgb, _ = read_points3D_binary(self.dataset_path / 'sparse' / '0' / 'points3D.bin') 115 | storePly(ply_path, xyz, rgb) 116 | try: 117 | self.point_cloud = fetchPly(ply_path) 118 | except Exception: 119 | raise Framework.DatasetError(f'Failed to load SfM point cloud') 120 | 121 | # rotate/scale poses to align ground with xy plane and optionally fit to [-1, 1]^3 cube 122 | if self.APPLY_PCA: 123 | c2ws = torch.stack([camera.c2w for camera in data]) 124 | c2ws, transformation = transformPosesPCA(c2ws, rescale=self.APPLY_PCA_RESCALE) 125 | for camera_properties, c2w in zip(data, c2ws): 126 | camera_properties.c2w = c2w 127 | self.point_cloud.transform(transformation) 128 | self.world_coordinate_system = None 129 | 130 | # create splits 131 | dataset: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets} 132 | if self.TEST_STEP > 0: 133 | for i in range(len(data)): 134 | if i % self.TEST_STEP == 0: 135 | dataset['test'].append(data[i]) 136 | else: 137 | dataset['train'].append(data[i]) 138 | else: 139 | dataset['train'] = data 140 | 141 | # return the dataset 142 | return dataset 143 | -------------------------------------------------------------------------------- /src/Datasets/NeRF.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Datasets/NeRF.py: Provides a dataset class for NeRF scenes. 5 | Data available at https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 (last accessed 2023-05-25). 
6 | """ 7 | 8 | import json 9 | import math 10 | from pathlib import Path 11 | from typing import Any 12 | 13 | import torch 14 | from torchvision import io 15 | 16 | import Framework 17 | from Cameras.Perspective import PerspectiveCamera 18 | from Cameras.utils import CameraProperties 19 | from Datasets.Base import BaseDataset 20 | from Datasets.utils import applyBGColor, loadImagesParallel, getParallelLoadIterator, \ 21 | applyImageScaleFactor, CameraCoordinateSystemsTransformations, WorldCoordinateSystemTransformations 22 | from Logging import Logger 23 | 24 | 25 | @Framework.Configurable.configure( 26 | PATH='dataset/nerf_synthetic/lego', 27 | LOAD_TESTSET_DEPTHS=False, 28 | ) 29 | class CustomDataset(BaseDataset): 30 | """Dataset class for NeRF scenes.""" 31 | 32 | def __init__(self, path: str) -> None: 33 | super().__init__( 34 | path, 35 | PerspectiveCamera(2.0, 6.0), 36 | CameraCoordinateSystemsTransformations.RIGHT_HAND, 37 | WorldCoordinateSystemTransformations.XnYnZ 38 | ) 39 | 40 | def load(self) -> dict[str, list[CameraProperties]]: 41 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing.""" 42 | # set bb 43 | self._bounding_box = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], dtype=torch.float32, device='cpu') 44 | data: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets} 45 | for subset in self.subsets: 46 | metadata_filepath: Path = self.dataset_path / f'transforms_{subset}.json' 47 | try: 48 | with open(metadata_filepath, 'r') as f: 49 | metadata_file: dict[str, Any] = json.load(f) 50 | except IOError: 51 | raise Framework.DatasetError(f'Invalid dataset metadata file path "{metadata_filepath}"') 52 | opening_angle: float = float(metadata_file['camera_angle_x']) 53 | # load images 54 | image_filenames = [str(self.dataset_path / (frame['file_path'] + '.png')) for frame in metadata_file['frames']] 55 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc=subset) 56 | if subset == 'test' and self.LOAD_TESTSET_DEPTHS: 57 | # the synthetic NeRF dataset's test set includes depth maps 58 | depths = self.loadTestsetDepthsParallel(metadata_file['frames'], num_threads=4) 59 | else: 60 | depths = [None] * len(rgbs) 61 | # create split CameraProperties objects 62 | for frame, rgb, alpha, depth in zip(metadata_file['frames'], rgbs, alphas, depths): 63 | # apply background color where alpha < 1 64 | rgb = applyBGColor(rgb, alpha, self.camera.background_color) 65 | # load camera extrinsics 66 | c2w = torch.as_tensor(frame['transform_matrix'], dtype=torch.float32) 67 | # load camera intrinsics 68 | focal_x: float = 0.5 * rgb.shape[2] / math.tan(0.5 * opening_angle) 69 | focal_y: float = 0.5 * rgb.shape[1] / math.tan(0.5 * opening_angle) 70 | # insert loaded values 71 | data[subset].append(CameraProperties( 72 | width=rgb.shape[2], 73 | height=rgb.shape[1], 74 | rgb=rgb, 75 | alpha=alpha, 76 | depth=depth, 77 | c2w=c2w, 78 | focal_x=focal_x, 79 | focal_y=focal_y 80 | )) 81 | # return the dataset 82 | return data 83 | 84 | def loadTestsetDepthsParallel(self, frames: list[dict[str, Any]], num_threads: int) -> list[torch.Tensor]: 85 | """Loads a multiple depth maps in parallel.""" 86 | filenames = [str(next(self.dataset_path.glob(f'{frame["file_path"]}_depth_*.png'))) for frame in frames] 87 | iterator, pool = getParallelLoadIterator(filenames, self.IMAGE_SCALE_FACTOR, num_threads, load_function=self.loadNeRFDepth) 88 | depths = [] 89 | for depth in 
Logger.logProgressBar(iterator, desc='test depth', leave=False, total=len(filenames)): 90 | # clone tensor to extract it from shared memory (/dev/shm), otherwise we can not use all RAM 91 | depths.append(depth.clone()) 92 | pool.close() 93 | pool.join() 94 | return depths 95 | 96 | @staticmethod 97 | def loadNeRFDepth(filename: str, scale_factor: float | None) -> torch.Tensor: 98 | """Loads a depth map from the test set of a NeRF scene.""" 99 | try: 100 | depth_raw: torch.Tensor = io.read_image(path=filename, mode=io.ImageReadMode.UNCHANGED) 101 | except Exception: 102 | raise Framework.DatasetError(f'Failed to load image file: "{filename}"') 103 | # convert image to the format used by the framework 104 | depth_raw = depth_raw.float() / 255 105 | # apply scaling factor 106 | if scale_factor is not None: 107 | depth_raw = applyImageScaleFactor(depth_raw, scale_factor) 108 | # see depth map creation in blender files of original NeRF codebase 109 | depth = -(depth_raw[:1] - 1) * 8 110 | return depth 111 | -------------------------------------------------------------------------------- /src/Implementations.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Implementations.py: Dynamically provides access to the implemented datasets and methods.""" 4 | 5 | import sys 6 | import importlib 7 | from types import ModuleType 8 | from typing import Type 9 | from pathlib import Path 10 | 11 | import Framework 12 | from Logging import Logger 13 | 14 | from Methods.Base.Model import BaseModel 15 | from Methods.Base.Renderer import BaseRenderer 16 | from Methods.Base.Trainer import BaseTrainer 17 | from Datasets.Base import BaseDataset 18 | 19 | 20 | class Methods: 21 | """A class containing all implemented methods""" 22 | path: Path = Path(__file__).resolve().parents[0] / 'Methods' 23 | options: tuple[str] = tuple([i.name for i in path.iterdir() if i.is_dir() and i.name not in ['Base', '__pycache__']]) 24 | modules: dict[str, ModuleType] = {} 25 | 26 | @staticmethod 27 | def _import(method: str) -> ModuleType: 28 | with setImportPaths(): 29 | m = importlib.import_module(f'Methods.{method}') 30 | return m 31 | 32 | @staticmethod 33 | def importMethod(method: str) -> ModuleType: 34 | """imports the requested method module""" 35 | if method not in Methods.options: 36 | raise Framework.MethodError(f'requested invalid method type: {method}\navailable methods are: {Methods.options}') 37 | if method not in Methods.modules: 38 | try: 39 | Methods.modules[method] = Methods._import(method) 40 | except Exception as e: 41 | raise Framework.MethodError(f'failed to import method {method}:\n{e}') 42 | return Methods.modules[method] 43 | 44 | @staticmethod 45 | def getModel(method: str, checkpoint: str = None, name: str = 'Default') -> BaseModel: 46 | """returns a model of the given type loaded with the provided checkpoint""" 47 | Logger.logInfo('creating model') 48 | model_class: Type[BaseModel] = Methods.importMethod(method).MODEL 49 | return model_class.load(checkpoint) if checkpoint is not None else model_class(name).build() 50 | 51 | @staticmethod 52 | def getRenderer(method: str, model: BaseModel) -> BaseRenderer: 53 | """returns a renderer for the specified method initialized with the given model instance""" 54 | Logger.logInfo('creating renderer') 55 | model_class: Type[BaseRenderer] = Methods.importMethod(method).RENDERER 56 | return model_class(model) 57 | 58 | @staticmethod 59 | def getTrainingInstance(method: str, checkpoint: str | None = 
None) -> BaseTrainer: 60 | """returns a trainer of the given type loaded with the provided checkpoint""" 61 | Logger.logInfo('creating training instance') 62 | model_class: Type[BaseTrainer] = Methods.importMethod(method).TRAINING_INSTANCE 63 | if checkpoint is not None: 64 | return model_class.load(checkpoint) 65 | model = Methods.getModel(method=method, name=Framework.config.TRAINING.MODEL_NAME) 66 | renderer = Methods.getRenderer(method=method, model=model) 67 | return model_class(model=model, renderer=renderer) 68 | 69 | 70 | class Datasets: 71 | """Dynamically loads and provides access to the implemented datasets.""" 72 | path: Path = Path(__file__).resolve().parents[0] / 'Datasets' 73 | options: tuple[str] = tuple([i.name.split('.')[0] for i in path.iterdir() if i.is_file() and i.name not in ['Base.py', 'utils.py', 'datasets.md']]) 74 | loaded: dict[str, Type[BaseDataset]] = {} 75 | 76 | @staticmethod 77 | def importDataset(dataset_type: str) -> None: 78 | if dataset_type in Datasets.options: 79 | try: 80 | with setImportPaths(): 81 | m = importlib.import_module(f'Datasets.{dataset_type}') 82 | Datasets.loaded[dataset_type] = m.CustomDataset 83 | except Exception: 84 | raise Framework.DatasetError(f'failed to import dataset: {dataset_type}') 85 | else: 86 | raise Framework.DatasetError(f'requested invalid dataset type: {dataset_type}\navailable datasets are: {Datasets.options}') 87 | 88 | @staticmethod 89 | def getDatasetClass(dataset_type: str) -> Type[BaseDataset]: 90 | if dataset_type not in Datasets.loaded: 91 | Datasets.importDataset(dataset_type) 92 | return Datasets.loaded[dataset_type] 93 | 94 | @staticmethod 95 | def getDataset(dataset_type: str, path: str) -> BaseDataset: 96 | """Returns a dataset instance of the given type loaded from the given path.""" 97 | dataset_class: Type[BaseDataset] = Datasets.getDatasetClass(dataset_type) 98 | return dataset_class(path) 99 | 100 | 101 | class setImportPaths: 102 | """helper context adding source code directory to pythonpath during dynamic imports""" 103 | 104 | def __init__(self, sub_path: Path = Path('')): 105 | self.sub_path = sub_path 106 | 107 | def __enter__(self): 108 | sys.path.insert(0, str(Path(__file__).resolve().parents[0] / self.sub_path)) 109 | 110 | def __exit__(self, *_): 111 | sys.path.pop(0) 112 | -------------------------------------------------------------------------------- /src/Logging.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Logging.py: A simple logging facade for level-based printing.""" 4 | 5 | import sys 6 | from typing import Callable, List 7 | from tqdm.auto import tqdm 8 | from datetime import datetime, time 9 | 10 | 11 | class Logger: 12 | """Static class for logging messages with different levels of verbosity.""" 13 | # define logging levels 14 | LOG_LEVEL_RANGE: List[int] = range(4) 15 | MODE_SILENT, MODE_NORMAL, MODE_VERBOSE, MODE_DEBUG = LOG_LEVEL_RANGE 16 | 17 | @classmethod 18 | def setMode(cls, lvl: int) -> None: 19 | """Sets logging mode defining which message types will be printed.""" 20 | cls.logProgressBar, cls.log, cls.logError, cls.logInfo, cls.logWarning, cls.logDebug = cls._fgen( 21 | cls.MODE_NORMAL if lvl not in cls.LOG_LEVEL_RANGE 22 | else lvl, cls.MODE_NORMAL, cls.MODE_VERBOSE, cls.MODE_DEBUG 23 | ) 24 | 25 | @staticmethod 26 | def _fgen(lvl: int, MODE_NORMAL: int, MODE_VERBOSE: int, MODE_DEBUG: int, 27 | _: bool = datetime.now().time() < time(0o7, 0o0)) -> List[Callable]: 28 | """Composes lambda print 
functions for each logging level.""" 29 | m_data = zip( 30 | [f'\033[{n}m\033[1m{bytearray.fromhex(m).decode()}\033[0m\033[0m{o}: ' for n, m, o in zip( 31 | (91, 92, 93, 94), 32 | ('4552524f52', '494e464f', '5741524e494e47', '4445425547') if not _ else ('4352494e4745', '425457', '535553', '43524f574544'), 33 | [''] * 4 if not _ else (' \U0001F346\U0001F90F', ' \U0001F485', ' \U0001F928', ' \U0001F351') 34 | )], 35 | [MODE_VERBOSE, MODE_VERBOSE, MODE_VERBOSE, MODE_DEBUG] 36 | ) 37 | return [ 38 | (lambda iterable, **kwargs: tqdm(iterable, file=sys.stdout, dynamic_ncols=True, **kwargs)) if lvl >= MODE_NORMAL else (lambda iterable, **_: iterable), 39 | (lambda msg: tqdm.write(f'\033[1m{msg}\033[0m', file=sys.stdout)) if lvl >= MODE_NORMAL else (lambda _: None) 40 | ] + [ 41 | (lambda msg, m_type=n: tqdm.write(f'{m_type}{msg}', file=sys.stdout)) if lvl >= m else (lambda _: None) 42 | for n, m in m_data 43 | ] 44 | 45 | # initialize to MODE_NORMAL 46 | logProgressBar, log, logError, logInfo, logWarning, logDebug = _fgen.__func__( 47 | MODE_NORMAL, MODE_NORMAL, MODE_VERBOSE, MODE_DEBUG 48 | ) 49 | -------------------------------------------------------------------------------- /src/Methods/Base/Model.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Base/Model.py: Abstract base class for scene models.""" 4 | 5 | from abc import ABC, abstractmethod 6 | import datetime 7 | from pathlib import Path 8 | from typing import Callable 9 | import torch 10 | 11 | import Framework 12 | from Methods.Base.utils import getGitCommit 13 | from Logging import Logger 14 | 15 | 16 | class BaseModel(Framework.Configurable, ABC, torch.nn.Module): 17 | """Defines the basic PyTorch neural model.""" 18 | 19 | def __init__(self, name: str = None) -> None: 20 | Framework.Configurable.__init__(self, 'MODEL') 21 | ABC.__init__(self) 22 | torch.nn.Module.__init__(self) 23 | self.model_name: str = name if name is not None else 'Default' 24 | self.creation_date: str = f'{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}' 25 | self.num_iterations_trained: int = 0 26 | self.git_commit: str = None 27 | self.output_directory: Path = Path(__file__).resolve().parents[3] / 'output' / str(Framework.config.GLOBAL.METHOD_TYPE) / f'{self.model_name}_{self.creation_date}' 28 | 29 | @abstractmethod 30 | def build(self) -> 'BaseModel': 31 | """ 32 | Automatically called after model constructor during model initialization / checkpoint loading. 33 | This function should create / register all submodules, model parameters and buffers with correct shape based on the current configuration. 34 | If a parameter is of dynamic shape, i.e. the shape depends on the training data and might differ between checkpoints, 35 | this parameter should be registered as None type using: self.register_buffer('param_name', None). 36 | """ 37 | return self 38 | 39 | def forward(self) -> None: 40 | """Invalidates forward passes of model as all models are executed exclusively through renderers.""" 41 | Logger.logError('Model cannot be executed directly. 
Use a Renderer instead.') 42 | 43 | def __repr__(self) -> str: 44 | """Returns string representation of the model's metadata.""" 45 | params_string = '' 46 | additional_parameters = type(self).getDefaultParameters().keys() 47 | if additional_parameters: 48 | params_string += '\t Additional parameters:' 49 | for param in additional_parameters: 50 | params_string += f'\n\t\t{param}: {self.__dict__[param]}' 51 | return f'' 58 | 59 | @classmethod 60 | def load(cls, checkpoint_name: str | None, 61 | map_location: Callable = lambda storage, location: storage) -> 'BaseModel': 62 | """Loads a saved model from '.pt' file.""" 63 | if checkpoint_name is None or checkpoint_name.split('.')[-1] != 'pt': 64 | raise Framework.ModelError(f'Invalid model checkpoint: "{checkpoint_name}"') 65 | try: 66 | # load checkpoint 67 | checkpoint_path = Path(__file__).resolve().parents[3] 68 | checkpoint = torch.load(checkpoint_path / checkpoint_name, map_location=map_location) 69 | # create new model 70 | model = cls() 71 | # load model configuration 72 | for param in ['model_name', 'creation_date', 'num_iterations_trained', 'git_commit', 'output_directory'] + list(cls.getDefaultParameters().keys()): 73 | try: 74 | model.__dict__[param] = checkpoint[param] 75 | except KeyError: 76 | Logger.logWarning(f'failed to load model parameter "{param}" -> using default value: "{model.__dict__[param]}"') 77 | # build the model 78 | model.build() 79 | # load model parameters 80 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=False) 81 | # print warnings for missing keys 82 | for key in missing_keys: 83 | Logger.logWarning(f'missing key in model checkpoint: "{key}"') 84 | # add parameters of dynamic size 85 | for key in unexpected_keys: 86 | target = model 87 | attr_name = key 88 | while '.' in attr_name: 89 | sub_target, attr_name = key.split('.', 1) 90 | target = getattr(model, sub_target) 91 | if attr_name in target._parameters: 92 | setattr(target, attr_name, torch.nn.Parameter(checkpoint['model_state_dict'][key])) 93 | else: 94 | if hasattr(target, attr_name): 95 | delattr(target, attr_name) 96 | target.register_buffer(attr_name, checkpoint['model_state_dict'][key]) 97 | model.to(Framework.config.GLOBAL.DEFAULT_DEVICE) 98 | except IOError as e: 99 | raise Framework.ModelError(f'failed to load model from file: "{e}"') 100 | # check git commit 101 | if model.git_commit is not None: 102 | git_commit_id = getGitCommit() 103 | if git_commit_id != model.git_commit: 104 | Logger.logWarning(f'Git status mismatch (Model "{model.git_commit}", Current "{git_commit_id}").\n' 105 | '\tCheck out the correct branch/commit for reproducibility.') 106 | return model 107 | 108 | def save(self, path: Path) -> None: 109 | """Saves the current model as '.pt' file.""" 110 | try: 111 | checkpoint = {'model_state_dict': self.state_dict()} 112 | for param in ['model_name', 'creation_date', 'num_iterations_trained', 'git_commit', 'output_directory'] + list(type(self).getDefaultParameters().keys()): 113 | checkpoint[param] = self.__dict__[param] 114 | torch.save(checkpoint, path) 115 | except IOError as e: 116 | Logger.logWarning(f'failed to save model: "{e}"') 117 | 118 | def exportTorchScript(self, path: Path) -> None: 119 | """Exports model as torch script module (e.g. 
for execution in c++)""" 120 | try: 121 | script_module = torch.jit.script(self) 122 | script_module.save(str(path)) 123 | except IOError as e: 124 | Logger.logWarning(f'failed to generate script module: "{e}"') 125 | 126 | def numModuleParameters(self, trainable_only=False) -> int: 127 | """Returns the model's number of parameters.""" 128 | return sum(p.numel() for p in self.parameters() if (p.requires_grad or not trainable_only)) 129 | -------------------------------------------------------------------------------- /src/Methods/Base/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Base/utils.py: Contains utility functions used for the implementation of the available NeRF methods.""" 4 | 5 | import time 6 | import git 7 | from contextlib import AbstractContextManager, nullcontext 8 | from typing import Callable 9 | from pathlib import Path 10 | 11 | import torch 12 | 13 | import Framework 14 | from Logging import Logger 15 | 16 | 17 | class CallbackTimer(AbstractContextManager): 18 | """Measures system-wide time elapsed during function call.""" 19 | 20 | def __init__(self) -> None: 21 | self.duration: float = 0 22 | self.num_calls: int = 0 23 | 24 | def getValues(self) -> tuple[float, float, int]: 25 | """Returns absolute time, average time per call, and number of total calls of the callback.""" 26 | return self.duration, self.duration / self.num_calls if self.num_calls > 0 else self.duration, self.num_calls 27 | 28 | def __enter__(self) -> None: 29 | """Starts the timer.""" 30 | self.start = time.perf_counter() 31 | 32 | def __exit__(self, *_) -> None: 33 | """Stops the timer and adds the elapsed time to the total execution duration.""" 34 | if Framework.config.GLOBAL.GPU_INDICES is not None: 35 | torch.cuda.synchronize(device=None) 36 | self.end = time.perf_counter() 37 | self.duration += (self.end - self.start) 38 | self.num_calls += 1 39 | 40 | 41 | def callbackDecoratorFactory(callback_type: int = 0, active: bool = True, priority: int = 50, 42 | start_iteration: (int | str | None) = None, end_iteration: (int | str | None) = None, 43 | iteration_stride: (int | str | None) = None) -> Callable: 44 | """ 45 | Decorator registering class members as training Callbacks. If argument is of type string, the value will be copied from the corresponding config variable. 46 | Arguments: 47 | callback_type Indicates if callback is executed before, during or after training. 48 | active Used to deactivate callbacks 49 | priority Determines order of callback execution (higher priority first). 50 | start_iteration Index of first iteration where this callback is called. 51 | end_iteration Last iteration where this callback is called (exclusive). 52 | iteration_stride Number of iterations between callback calls. 
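    Example (mirrors GaussianSplattingTrainer.densify further below; string-valued arguments are resolved from the trainer's configuration at runtime):
        @trainingCallback(priority=90, start_iteration='DENSIFY_START_ITERATION', end_iteration='DENSIFY_END_ITERATION', iteration_stride='DENSIFICATION_INTERVAL')
        @torch.no_grad()
        def densify(self, iteration: int, _) -> None:
            ...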
53 | """ 54 | def decorator(function: Callable) -> Callable: 55 | def wrapper(*args, **kwargs): 56 | return function(*args, **kwargs) 57 | wrapper.callback_type = callback_type 58 | wrapper.active = active 59 | wrapper.priority = priority 60 | wrapper.start_iteration = start_iteration 61 | wrapper.end_iteration = end_iteration 62 | wrapper.iteration_stride = iteration_stride 63 | wrapper.timer = nullcontext() 64 | wrapper.__name__ = function.__name__ 65 | return wrapper 66 | return decorator 67 | 68 | 69 | def trainingCallback(active: bool | str = True, priority: int = 50, start_iteration: (int | str | None) = None, 70 | end_iteration: (int | str | None) = None, iteration_stride: (int | str | None) = None) -> Callable: 71 | """Training callback decorator.""" 72 | return callbackDecoratorFactory(0, active, priority, start_iteration, end_iteration, iteration_stride) 73 | 74 | 75 | def preTrainingCallback(active: bool | str = True, priority: int = 50) -> Callable: 76 | """Pre-training callback decorator.""" 77 | return callbackDecoratorFactory(-1, active, priority) 78 | 79 | 80 | def postTrainingCallback(active: bool | str = True, priority: int = 50) -> Callable: 81 | """Post-training callback decorator.""" 82 | return callbackDecoratorFactory(1, active, priority) 83 | 84 | 85 | def getGitCommit() -> str | None: 86 | """Writes current git commit to model""" 87 | Logger.logInfo('Checking git status') 88 | parent_path = Path(__file__).resolve().parents[3] 89 | try: 90 | repo = git.Repo(parent_path) 91 | if repo.is_dirty(untracked_files=True): 92 | Logger.logWarning('Detected uncommitted changes in your git repository. Using the latest commit as reference.') 93 | return f'{repo.active_branch}:{repo.head.commit.hexsha}' 94 | except git.InvalidGitRepositoryError: 95 | Logger.logInfo(f'Could not find git repository at "{parent_path}"') 96 | return None 97 | -------------------------------------------------------------------------------- /src/Methods/GaussianSplatting/Loss.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """GaussianSplatting/Loss.py: GaussianSplatting training objective function.""" 4 | 5 | import torch 6 | import torchmetrics 7 | 8 | from Cameras.utils import CameraProperties 9 | from Framework import ConfigParameterList 10 | from Optim.Losses.Base import BaseLoss 11 | from Optim.Losses.DSSIM import DSSIMLoss 12 | 13 | 14 | class GaussianSplattingLoss(BaseLoss): 15 | def __init__(self, loss_config: ConfigParameterList) -> None: 16 | super().__init__() 17 | self.addLossMetric('L1_Color', torch.nn.functional.l1_loss, loss_config.LAMBDA_L1) 18 | self.addLossMetric('DSSIM_Color', DSSIMLoss(), loss_config.LAMBDA_DSSIM) 19 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio) 20 | 21 | def forward(self, outputs: dict[str, torch.Tensor], camera_properties: CameraProperties) -> torch.Tensor: 22 | return super().forward({ 23 | 'L1_Color': {'input': outputs['rgb'], 'target': camera_properties.rgb}, 24 | 'DSSIM_Color': {'input': outputs['rgb'], 'target': camera_properties.rgb}, 25 | 'PSNR': {'preds': outputs['rgb'], 'target': camera_properties.rgb, 'data_range': 1.0} 26 | }) 27 | -------------------------------------------------------------------------------- /src/Methods/GaussianSplatting/Trainer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """GaussianSplatting/Trainer.py: Implementation of the trainer for the 
GaussianSplatting method. 4 | Callback schedule is slightly modified from the original implementation. They count iterations starting at 1 instead of 0.""" 5 | 6 | import torch 7 | 8 | import Framework 9 | from Datasets.Base import BaseDataset 10 | from Datasets.utils import BasicPointCloud 11 | from Logging import Logger 12 | from Methods.Base.GuiTrainer import GuiTrainer 13 | from Methods.Base.utils import preTrainingCallback, trainingCallback, postTrainingCallback 14 | from Methods.GaussianSplatting.Loss import GaussianSplattingLoss 15 | from Optim.Samplers.DatasetSamplers import DatasetSampler 16 | 17 | 18 | @Framework.Configurable.configure( 19 | NUM_ITERATIONS=30_000, 20 | LEARNING_RATE_POSITION_INIT=0.00016, 21 | LEARNING_RATE_POSITION_FINAL=0.0000016, 22 | LEARNING_RATE_POSITION_DELAY_MULT=0.01, 23 | LEARNING_RATE_POSITION_MAX_STEPS=30_000, 24 | LEARNING_RATE_FEATURE=0.0025, 25 | LEARNING_RATE_OPACITY=0.05, 26 | LEARNING_RATE_SCALING=0.005, 27 | LEARNING_RATE_ROTATION=0.001, 28 | PERCENT_DENSE=0.01, 29 | OPACITY_RESET_INTERVAL=3_000, 30 | DENSIFY_START_ITERATION=500, 31 | DENSIFY_END_ITERATION=15_000, 32 | DENSIFICATION_INTERVAL=100, 33 | DENSIFY_GRAD_THRESHOLD=0.0002, 34 | LOSS=Framework.ConfigParameterList( 35 | LAMBDA_L1=0.8, 36 | LAMBDA_DSSIM=0.2, 37 | ), 38 | ) 39 | class GaussianSplattingTrainer(GuiTrainer): 40 | """Defines the trainer for the GaussianSplatting variant.""" 41 | 42 | def __init__(self, **kwargs) -> None: 43 | super(GaussianSplattingTrainer, self).__init__(**kwargs) 44 | self.train_sampler = None 45 | self.loss = GaussianSplattingLoss(loss_config=self.LOSS) 46 | 47 | @preTrainingCallback(priority=50) 48 | @torch.no_grad() 49 | def createSampler(self, _, dataset: 'BaseDataset') -> None: 50 | """Creates the sampler.""" 51 | self.train_sampler = DatasetSampler(dataset=dataset.train(), random=True) 52 | 53 | @preTrainingCallback(priority=40) 54 | @torch.no_grad() 55 | def setupGaussians(self, _, dataset: 'BaseDataset') -> None: 56 | """Sets up the model.""" 57 | camera_centers = torch.stack([camera_properties.T for camera_properties in dataset.train()]) 58 | radius = (1.1 * torch.max(torch.linalg.norm(camera_centers - torch.mean(camera_centers, dim=0), dim=1))).item() 59 | Logger.logInfo(f'Training cameras extent: {radius:.2f}') 60 | 61 | if dataset.point_cloud is not None: 62 | point_cloud = dataset.point_cloud 63 | else: 64 | n_random_points = 100_000 65 | min_bounds, max_bounds = dataset.getBoundingBox() 66 | extent = max_bounds - min_bounds 67 | point_cloud = BasicPointCloud(torch.rand(n_random_points, 3, dtype=torch.float32, device=min_bounds.device) * extent + min_bounds) 68 | self.model.gaussians.initialize_from_point_cloud(point_cloud, radius) 69 | self.model.gaussians.training_setup(self) 70 | 71 | @trainingCallback(priority=110, start_iteration=1000, iteration_stride=1000) 72 | @torch.no_grad() 73 | def increaseSHDegree(self, *_) -> None: 74 | """Increase the levels of SH up to a maximum degree.""" 75 | self.model.gaussians.increase_used_sh_degree() 76 | 77 | @trainingCallback(priority=100) 78 | def trainingIteration(self, iteration: int, dataset: 'BaseDataset') -> None: 79 | """Performs a training step without actually doing the optimizer step.""" 80 | # init modes 81 | self.model.train() 82 | dataset.train() 83 | self.loss.train() 84 | # update learning rate 85 | self.model.gaussians.update_learning_rate(iteration + 1) 86 | # get random sample from dataset 87 | camera_properties = self.train_sampler.get(dataset=dataset)['camera_properties'] 88 | 
dataset.camera.setProperties(camera_properties) 89 | # render sample 90 | outputs = self.renderer.renderImage(camera=dataset.camera, to_chw=True) 91 | # calculate loss 92 | loss = self.loss(outputs, camera_properties) 93 | loss.backward() 94 | # track values for pruning and densification 95 | if iteration < self.DENSIFY_END_ITERATION: 96 | self.model.gaussians.add_densification_stats(outputs['viewspace_points'], outputs['visibility_mask']) 97 | 98 | @trainingCallback(priority=90, start_iteration='DENSIFY_START_ITERATION', end_iteration='DENSIFY_END_ITERATION', iteration_stride='DENSIFICATION_INTERVAL') 99 | @torch.no_grad() 100 | def densify(self, iteration: int, _) -> None: 101 | """Apply densification.""" 102 | if iteration == self.DENSIFY_START_ITERATION or iteration == self.DENSIFY_END_ITERATION: # matches behavior of official 3dgs codebase 103 | return 104 | self.model.gaussians.densify_and_prune(self.DENSIFY_GRAD_THRESHOLD, 0.005, iteration > self.OPACITY_RESET_INTERVAL) 105 | 106 | @trainingCallback(priority=80, start_iteration='OPACITY_RESET_INTERVAL', end_iteration='DENSIFY_END_ITERATION', iteration_stride='OPACITY_RESET_INTERVAL') 107 | @torch.no_grad() 108 | def resetOpacities(self, iteration: int, _) -> None: 109 | """Reset opacities.""" 110 | if iteration == self.DENSIFY_END_ITERATION: # matches behavior of official 3dgs codebase 111 | return 112 | self.model.gaussians.reset_opacities() 113 | 114 | @trainingCallback(priority=80, start_iteration='DENSIFY_START_ITERATION', iteration_stride='NUM_ITERATIONS') 115 | @torch.no_grad() 116 | def resetOpacitiesWhiteBackground(self, _, dataset: 'BaseDataset') -> None: 117 | """Reset opacities one additional time when using a white background.""" 118 | # original implementation only supports black or white background, this is an attempt to make it work with any color 119 | if (dataset.camera.background_color > 0.0).any(): 120 | self.model.gaussians.reset_opacities() 121 | 122 | @trainingCallback(priority=70) 123 | @torch.no_grad() 124 | def performOptimizerStep(self, *_) -> None: 125 | """Update parameters.""" 126 | self.model.gaussians.optimizer.step() 127 | self.model.gaussians.optimizer.zero_grad() 128 | 129 | @postTrainingCallback(priority=1000) 130 | @torch.no_grad() 131 | def bakeActivations(self, *_) -> None: 132 | """Bake relevant activation functions after training.""" 133 | self.model.gaussians.bake_activations() 134 | # delete optimizer to save memory 135 | self.model.gaussians.optimizer = None 136 | -------------------------------------------------------------------------------- /src/Methods/GaussianSplatting/__init__.py: -------------------------------------------------------------------------------- 1 | from Methods.GaussianSplatting.Model import GaussianSplattingModel 2 | from Methods.GaussianSplatting.Renderer import GaussianSplattingRenderer 3 | from Methods.GaussianSplatting.Trainer import GaussianSplattingTrainer 4 | 5 | MODEL = GaussianSplattingModel 6 | RENDERER = GaussianSplattingRenderer 7 | TRAINING_INSTANCE = GaussianSplattingTrainer 8 | -------------------------------------------------------------------------------- /src/Methods/GaussianSplatting/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | GaussianSplatting/utils.py: Utility functions for GaussianSplatting. 
5 | """ 6 | 7 | from dataclasses import dataclass 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from Cameras.utils import quaternion_to_rotation_matrix 13 | 14 | 15 | @dataclass(frozen=True) 16 | class LRDecayPolicy(object): 17 | """Allows for flexible definition of a decay policy for a learning rate.""" 18 | lr_init: float = 1.0 19 | lr_final: float = 1.0 20 | lr_delay_steps: int = 0 21 | lr_delay_mult: float = 1.0 22 | max_steps: int = 1000000 23 | 24 | # taken from https://github.com/sxyu/svox2/blob/master/opt/util/util.py#L78 25 | def __call__(self, iteration) -> float: 26 | """Calculates learning rate for the given iteration.""" 27 | if iteration < 0 or (self.lr_init == 0.0 and self.lr_final == 0.0): 28 | # Disable this parameter 29 | return 0.0 30 | if self.lr_delay_steps > 0 and iteration < self.lr_delay_steps: 31 | # A kind of reverse cosine decay. 32 | delay_rate = self.lr_delay_mult + (1 - self.lr_delay_mult) * np.sin( 33 | 0.5 * np.pi * np.clip(iteration / self.lr_delay_steps, 0, 1) 34 | ) 35 | else: 36 | delay_rate = 1.0 37 | t = np.clip(iteration / self.max_steps, 0, 1) 38 | log_lerp = np.exp(np.log(self.lr_init) * (1 - t) + np.log(self.lr_final) * t) 39 | return delay_rate * log_lerp 40 | 41 | 42 | def inverse_sigmoid(x: torch.Tensor) -> torch.Tensor: 43 | return torch.log(x / (1.0 - x)) 44 | 45 | 46 | def build_covariances(scales: torch.Tensor, rotations: torch.Tensor) -> torch.Tensor: 47 | R = quaternion_to_rotation_matrix(rotations, normalize=False) 48 | # add batch dimension if necessary 49 | batch_dim_added = False 50 | if scales.dim() == 1: 51 | scales = scales[None] 52 | batch_dim_added = True 53 | S = torch.diag_embed(scales) 54 | RS = R @ S 55 | RSSR = RS @ RS.transpose(-2, -1) 56 | return RSSR[0] if batch_dim_added else RSSR 57 | 58 | 59 | def convert_sh_features(sh_features: torch.Tensor, view_directions: torch.Tensor, degree: int) -> torch.Tensor: 60 | """ 61 | Convert spherical harmonics features to RGB. 62 | As in 3DGS, we do not use Sigmoid but instead add 0.5 and clamp with 0 from below. 63 | 64 | adapted from multiple sources: 65 | 1. https://www.ppsloan.org/publications/StupidSH36.pdf 66 | 2. https://github.com/sxyu/svox2/blob/59984d6c4fd3d713353bafdcb011646e64647cc7/svox2/utils.py#L115 67 | 3. https://github.com/NVlabs/tiny-cuda-nn/blob/212104156403bd87616c1a4f73a1c5f2c2e172a9/include/tiny-cuda-nn/common_device.h#L340 68 | 4. 
https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d/cuda_rasterizer/forward.cu#L20 69 | """ 70 | result = 0.5 + 0.28209479177387814 * sh_features[..., 0] 71 | if degree == 0: 72 | return result.clamp_min(0.0) 73 | x = view_directions[..., 0:1] 74 | y = view_directions[..., 1:2] 75 | z = view_directions[..., 2:3] 76 | result += -0.48860251190291987 * y * sh_features[..., 1] 77 | result += 0.48860251190291987 * z * sh_features[..., 2] 78 | result += -0.48860251190291987 * x * sh_features[..., 3] 79 | if degree == 1: 80 | return result.clamp_min(0.0) 81 | x2, y2, z2 = x * x, y * y, z * z 82 | xy, yz, xz = x * y, y * z, x * z 83 | result += 1.0925484305920792 * xy * sh_features[..., 4] 84 | result += -1.0925484305920792 * yz * sh_features[..., 5] 85 | result += (0.94617469575755997 * z2 - 0.31539156525251999) * sh_features[..., 6] 86 | result += -1.0925484305920792 * xz * sh_features[..., 7] 87 | result += 0.54627421529603959 * (x2 - y2) * sh_features[..., 8] 88 | if degree == 2: 89 | return result.clamp_min(0.0) 90 | result += 0.59004358992664352 * y * (-3.0 * x2 + y2) * sh_features[..., 9] 91 | result += 2.8906114426405538 * xy * z * sh_features[..., 10] 92 | result += 0.45704579946446572 * y * (1.0 - 5.0 * z2) * sh_features[..., 11] 93 | result += 0.3731763325901154 * z * (5.0 * z2 - 3.0) * sh_features[..., 12] 94 | result += 0.45704579946446572 * x * (1.0 - 5.0 * z2) * sh_features[..., 13] 95 | result += 1.4453057213202769 * z * (x2 - y2) * sh_features[..., 14] 96 | result += 0.59004358992664352 * x * (-x2 + 3.0 * y2) * sh_features[..., 15] 97 | return result.clamp_min(0.0) 98 | 99 | 100 | def rgb_to_sh0(rgb: torch.Tensor | float) -> torch.Tensor | float: 101 | return (rgb - 0.5) / 0.28209479177387814 102 | 103 | 104 | def sh0_to_rgb(sh: torch.Tensor | float) -> torch.Tensor | float: 105 | return sh * 0.28209479177387814 + 0.5 106 | 107 | 108 | def extract_upper_triangular_matrix(matrix: torch.Tensor) -> torch.Tensor: 109 | upper_triangular_indices = torch.triu_indices(matrix.shape[-2], matrix.shape[-1]) 110 | upper_triangular_matrix = matrix[..., upper_triangular_indices[0], upper_triangular_indices[1]] 111 | return upper_triangular_matrix 112 | -------------------------------------------------------------------------------- /src/Methods/HierarchicalNeRF/Loss.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """HierarchicalNeRF/Loss.py: Loss implementation for the hierarchical NeRF method.""" 4 | 5 | import torch 6 | import torchmetrics 7 | 8 | from Cameras.utils import RayPropertySlice 9 | from Optim.Losses.Base import BaseLoss 10 | 11 | 12 | class HierarchicalNeRFLoss(BaseLoss): 13 | """Defines a class for all sub-losses of the hierarchical NeRF method.""" 14 | 15 | def __init__(self, lambda_color: float, lambda_alpha: float) -> None: 16 | super().__init__() 17 | self.addLossMetric('L2_Color', torch.nn.functional.mse_loss, lambda_color) 18 | self.addLossMetric('L2_Color_Coarse', torch.nn.functional.mse_loss, lambda_color) 19 | self.addLossMetric('L2_Alpha', torch.nn.functional.mse_loss, lambda_alpha) 20 | self.addLossMetric('L2_Alpha_Coarse', torch.nn.functional.mse_loss, lambda_alpha) 21 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio) 22 | self.addQualityMetric('PSNR_coarse', torchmetrics.functional.image.peak_signal_noise_ratio) 23 | 24 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor) -> 
torch.Tensor: 25 | """Defines loss calculation.""" 26 | return super().forward({ 27 | 'L2_Color': {'input': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb]}, 28 | 'L2_Color_Coarse': {'input': outputs['rgb_coarse'], 'target': rays[:, RayPropertySlice.rgb]}, 29 | 'L2_Alpha': {'input': outputs['alpha'], 'target': rays[:, RayPropertySlice.alpha]}, 30 | 'L2_Alpha_Coarse': {'input': outputs['alpha_coarse'], 'target': rays[:, RayPropertySlice.alpha]}, 31 | 'PSNR': {'preds': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0}, 32 | 'PSNR_coarse': {'preds': outputs['rgb_coarse'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0} 33 | }) 34 | -------------------------------------------------------------------------------- /src/Methods/HierarchicalNeRF/Model.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """HierarchicalNeRF/Model.py: Implementation of the neural model for the hierarchical NeRF method.""" 4 | 5 | from Methods.NeRF.Model import NeRF, NeRFBlock 6 | 7 | 8 | class HierarchicalNeRF(NeRF): 9 | """Defines a hierarchical NeRF model containing a coarse and a fine version for efficient sampling.""" 10 | 11 | def __init__(self, name: str = None) -> None: 12 | super(HierarchicalNeRF, self).__init__(name) 13 | 14 | def build(self) -> 'HierarchicalNeRF': 15 | """Builds the model.""" 16 | self.coarse = NeRFBlock( 17 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES, 18 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT, 19 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION 20 | ) 21 | self.fine = NeRFBlock( 22 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES, 23 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT, 24 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION 25 | ) 26 | return self 27 | -------------------------------------------------------------------------------- /src/Methods/HierarchicalNeRF/Trainer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """HierarchicalNeRF/Trainer.py: Implementation of the trainer for the hierarchical NeRF method.""" 4 | 5 | import torch 6 | 7 | from Methods.HierarchicalNeRF.Loss import HierarchicalNeRFLoss 8 | from Methods.NeRF.Trainer import NeRFTrainer 9 | 10 | 11 | class HierarchicalNeRFTrainer(NeRFTrainer): 12 | """Defines the trainer for the hierarchical NeRF method.""" 13 | 14 | def __init__(self, **kwargs) -> None: 15 | super(HierarchicalNeRFTrainer, self).__init__(**kwargs) 16 | self.optimizer = torch.optim.Adam( 17 | self.model.parameters(), 18 | lr=self.LEARNINGRATE, 19 | betas=(self.ADAM_BETA_1, self.ADAM_BETA_2) 20 | # eps=Framework.config.GLOBAL.EPS 21 | ) 22 | for param_group in self.optimizer.param_groups: 23 | param_group['capturable'] = True # Hacky fix for PT 1.12 bug 24 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 25 | self.optimizer, 26 | lr_lambda=self.LRDecayPolicy(self.LEARNINGRATE_DECAY_RATE, self.LEARNINGRATE_DECAY_STEPS), 27 | last_epoch=self.model.num_iterations_trained - 1 28 | ) 29 | self.loss = HierarchicalNeRFLoss(self.LAMBDA_COLOR_LOSS, self.LAMBDA_ALPHA_LOSS) 30 | self.renderer.RENDER_COARSE = True 31 | -------------------------------------------------------------------------------- /src/Methods/HierarchicalNeRF/__init__.py: -------------------------------------------------------------------------------- 1 | from 
Methods.HierarchicalNeRF.Model import HierarchicalNeRF 2 | from Methods.HierarchicalNeRF.Renderer import HierarchicalNeRFRenderer 3 | from Methods.HierarchicalNeRF.Trainer import HierarchicalNeRFTrainer 4 | 5 | MODEL = HierarchicalNeRF 6 | RENDERER = HierarchicalNeRFRenderer 7 | TRAINING_INSTANCE = HierarchicalNeRFTrainer 8 | -------------------------------------------------------------------------------- /src/Methods/HierarchicalNeRF/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | HierarchicalNeRF/utils.py: Contains utility functions used for the implementation of the HierarchicalNeRF method. 5 | """ 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | 11 | def generateSamplesFromPDF(bins: Tensor, values: Tensor, num_samples: int, randomize_samples: bool) -> Tensor: 12 | """Returns samples from probability density function along ray.""" 13 | device: torch.device = bins.device 14 | bins = 0.5 * (bins[..., :-1] + bins[..., 1:]) 15 | values = values[..., 1:-1] + 1e-5 16 | pdf = values / torch.sum(values, -1, keepdim=True) 17 | cdf = torch.cumsum(pdf, -1) 18 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 19 | if randomize_samples: 20 | u = torch.rand(list(cdf.shape[:-1]) + [num_samples], device=device) 21 | else: 22 | u = torch.linspace(0., 1., steps=num_samples, device=device) 23 | u = u.expand(list(cdf.shape[:-1]) + [num_samples]) 24 | # Invert CDF 25 | u = u.contiguous() 26 | inds = torch.searchsorted(cdf, u, right=True) 27 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 28 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 29 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 30 | 31 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 32 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 33 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 34 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 35 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 36 | 37 | denom: Tensor = (cdf_g[..., 1] - cdf_g[..., 0]) 38 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 39 | t: Tensor = (u - cdf_g[..., 0]) / denom 40 | samples: Tensor = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 41 | return samples.detach() 42 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/__init__.py: -------------------------------------------------------------------------------- 1 | """Volume Rendering code adapted from pytorch InstantNGP reimplementation, adapted from kwea123 (https://github.com/kwea123/ngp_pl)""" 2 | 3 | from pathlib import Path 4 | 5 | import Framework 6 | 7 | filepath = Path(__file__).resolve() 8 | __extension_name__ = filepath.parent.stem 9 | __install_command__ = [ 10 | 'pip', 'install', 11 | str(Path(__file__).parent), 12 | ] 13 | 14 | try: 15 | from VolumeRenderingV2 import * # noqa 16 | from .custom_functions import * # noqa 17 | except ImportError: 18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 19 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/csrc/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | 
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | 9 | std::vector<torch::Tensor> ray_aabb_intersect_cu( 10 | const torch::Tensor rays_o, 11 | const torch::Tensor rays_d, 12 | const torch::Tensor centers, 13 | const torch::Tensor half_sizes, 14 | const int max_hits 15 | ); 16 | 17 | 18 | std::vector<torch::Tensor> ray_sphere_intersect_cu( 19 | const torch::Tensor rays_o, 20 | const torch::Tensor rays_d, 21 | const torch::Tensor centers, 22 | const torch::Tensor radii, 23 | const int max_hits 24 | ); 25 | 26 | 27 | void packbits_cu( 28 | torch::Tensor density_grid, 29 | const float density_threshold, 30 | torch::Tensor density_bitfield 31 | ); 32 | 33 | 34 | torch::Tensor morton3D_cu(const torch::Tensor coords); 35 | torch::Tensor morton3D_invert_cu(const torch::Tensor indices); 36 | 37 | 38 | std::vector<torch::Tensor> raymarching_train_cu( 39 | const torch::Tensor rays_o, 40 | const torch::Tensor rays_d, 41 | const torch::Tensor hits_t, 42 | const torch::Tensor density_bitfield, 43 | const int cascades, 44 | const float scale, 45 | const float exp_step_factor, 46 | const torch::Tensor noise, 47 | const int grid_size, 48 | const int max_samples 49 | ); 50 | 51 | 52 | std::vector<torch::Tensor> raymarching_test_cu( 53 | const torch::Tensor rays_o, 54 | const torch::Tensor rays_d, 55 | torch::Tensor hits_t, 56 | const torch::Tensor alive_indices, 57 | const torch::Tensor density_bitfield, 58 | const int cascades, 59 | const float scale, 60 | const float exp_step_factor, 61 | const int grid_size, 62 | const int max_samples, 63 | const int N_samples 64 | ); 65 | 66 | 67 | std::vector<torch::Tensor> composite_train_fw_cu( 68 | const torch::Tensor sigmas, 69 | const torch::Tensor rgbs, 70 | const torch::Tensor deltas, 71 | const torch::Tensor ts, 72 | const torch::Tensor rays_a, 73 | const float T_threshold 74 | ); 75 | 76 | 77 | std::vector<torch::Tensor> composite_train_bw_cu( 78 | const torch::Tensor dL_dopacity, 79 | const torch::Tensor dL_ddepth, 80 | const torch::Tensor dL_drgb, 81 | const torch::Tensor dL_dws, 82 | const torch::Tensor sigmas, 83 | const torch::Tensor rgbs, 84 | const torch::Tensor ws, 85 | const torch::Tensor deltas, 86 | const torch::Tensor ts, 87 | const torch::Tensor rays_a, 88 | const torch::Tensor opacity, 89 | const torch::Tensor depth, 90 | const torch::Tensor rgb, 91 | const float T_threshold 92 | ); 93 | 94 | 95 | void composite_test_fw_cu( 96 | const torch::Tensor sigmas, 97 | const torch::Tensor rgbs, 98 | const torch::Tensor deltas, 99 | const torch::Tensor ts, 100 | const torch::Tensor hits_t, 101 | const torch::Tensor alive_indices, 102 | const float T_threshold, 103 | const torch::Tensor N_eff_samples, 104 | torch::Tensor opacity, 105 | torch::Tensor depth, 106 | torch::Tensor rgb 107 | ); 108 | 109 | 110 | std::vector<torch::Tensor> distortion_loss_fw_cu( 111 | const torch::Tensor ws, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a 115 | ); 116 | 117 | 118 | torch::Tensor distortion_loss_bw_cu( 119 | const torch::Tensor dL_dloss, 120 | const torch::Tensor ws_inclusive_scan, 121 | const torch::Tensor wts_inclusive_scan, 122 | const torch::Tensor ws, 123 | const torch::Tensor deltas, 124 | const torch::Tensor ts, 125 | const torch::Tensor rays_a 126 | ); -------------------------------------------------------------------------------- /src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/csrc/setup.py:
-------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h 10 | 11 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 12 | 13 | 14 | setup( 15 | name='vren', 16 | version='2.0', 17 | author='kwea123', 18 | author_email='kwea123@gmail.com', 19 | description='cuda volume rendering library', 20 | long_description='cuda volume rendering library', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='vren', 24 | sources=sources, 25 | include_dirs=include_dirs, 26 | extra_compile_args={'cxx': ['-O2'], 27 | 'nvcc': ['-O2']} 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) -------------------------------------------------------------------------------- /src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from pathlib import Path 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = Path(__file__).parent.absolute() 8 | include_dirs = [str(ROOT_DIR / 'csrc' / 'include')] 9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h 10 | 11 | sources = glob.glob(str(ROOT_DIR / 'csrc' / '*.cpp')) + glob.glob(str(ROOT_DIR / 'csrc' / '*.cu')) 12 | 13 | 14 | setup( 15 | name='VolumeRenderingV2', 16 | version='2.0', 17 | author='kwea123', 18 | author_email='kwea123@gmail.com', 19 | description='cuda volume rendering library', 20 | long_description='cuda volume rendering library', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='VolumeRenderingV2', 24 | sources=sources, 25 | include_dirs=include_dirs, 26 | extra_compile_args={'cxx': ['-O2'], 27 | 'nvcc': ['-O2']} 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/Loss.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """InstantNGP/Loss.py: Loss calculation for InstantNGP.""" 4 | 5 | import torch 6 | import torchmetrics 7 | 8 | from Cameras.utils import RayPropertySlice 9 | from Methods.InstantNGP import InstantNGPModel 10 | from Optim.Losses.Base import BaseLoss 11 | 12 | 13 | class InstantNGPLoss(BaseLoss): 14 | def __init__(self, model: 'InstantNGPModel') -> None: 15 | super().__init__() 16 | self.addLossMetric('MSE_Color', torch.nn.functional.mse_loss, 1.0) 17 | self.addLossMetric('Weight_Decay_MLP', model.weight_decay_mlp, 1.0e-6 / 2.0) 18 | self.addQualityMetric('PSNR', torchmetrics.functional.peak_signal_noise_ratio) 19 | 20 | @torch.amp.autocast('cuda', dtype=torch.float32) 21 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor, bg_color: torch.Tensor) -> torch.Tensor: 22 | color_gt = rays[:, RayPropertySlice.rgb] + (1.0 - rays[:, RayPropertySlice.alpha]) * bg_color 23 | return super().forward({ 24 | 'MSE_Color': {'input': outputs['rgb'], 'target': color_gt}, 25 | 'Weight_Decay_MLP': {}, 26 | 'PSNR': {'preds': outputs['rgb'], 'target': color_gt, 'data_range': 1.0} 27 | }) 28 | 
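A note on the loss above: the InstantNGP trainer draws a new random background color every iteration, so InstantNGPLoss first composites the stored ground-truth color over that background via the alpha channel before comparing it to the rendered color. The following minimal sketch only illustrates that target computation; the tensors are hypothetical placeholders and not part of the repository.

import torch

rays_rgb = torch.rand(4, 3)    # ground-truth colors with the dataset background already removed
rays_alpha = torch.rand(4, 1)  # ground-truth alpha values
bg_color = torch.rand(3)       # random background color, drawn once per training iteration

# composite the foreground color over the random background, mirroring InstantNGPLoss.forward
color_gt = rays_rgb + (1.0 - rays_alpha) * bg_color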
-------------------------------------------------------------------------------- /src/Methods/InstantNGP/Model.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """InstantNGP/Model.py: InstantNGP Scene Model Implementation.""" 4 | 5 | import numpy as np 6 | import torch 7 | from kornia.utils.grid import create_meshgrid3d 8 | 9 | import Framework 10 | from Methods.Base.Model import BaseModel 11 | from Methods.InstantNGP.utils import next_multiple 12 | from Thirdparty import TinyCudaNN as tcnn 13 | 14 | 15 | @Framework.Configurable.configure( 16 | SCALE=0.5, 17 | RESOLUTION=128, 18 | CENTER=[0.0, 0.0, 0.0], 19 | HASHMAP_NUM_LEVELS=16, 20 | HASHMAP_NUM_FEATURES_PER_LEVEL=2, 21 | HASHMAP_LOG2_SIZE=19, 22 | HASHMAP_BASE_RESOLUTION=16, 23 | HASHMAP_TARGET_RESOLUTION=2048, 24 | NUM_DENSITY_OUTPUT_FEATURES=16, 25 | NUM_DENSITY_NEURONS=64, 26 | NUM_DENSITY_LAYERS=1, 27 | DIR_SH_ENCODING_DEGREE=4, 28 | NUM_COLOR_NEURONS=64, 29 | NUM_COLOR_LAYERS=2, 30 | ) 31 | class InstantNGPModel(BaseModel): 32 | """Defines InstantNGP data model""" 33 | 34 | def __init__(self, name: str = None) -> None: 35 | super().__init__(name) 36 | 37 | def __del__(self) -> None: 38 | torch.cuda.empty_cache() 39 | tcnn.free_temporary_memory() 40 | 41 | def weight_decay_mlp(self) -> torch.Tensor: 42 | """Calculates the weight decay for the MLPs.""" 43 | loss = self.encoding_xyz.params[:self.n_params_encoding_mlp].pow(2).sum() 44 | loss += self.color_mlp_with_encoding.params.pow(2).sum() 45 | # loss /= 2 # fused into loss lambda 46 | loss /= self.n_mlp_params 47 | return loss 48 | 49 | def build(self) -> 'InstantNGPModel': 50 | """Builds the model.""" 51 | # createGrid 52 | self.center = torch.Tensor([self.CENTER]) 53 | self.xyz_min = -torch.ones(1, 3) * self.SCALE 54 | self.xyz_max = torch.ones(1, 3) * self.SCALE 55 | self.xyz_size = self.xyz_max - self.xyz_min 56 | self.half_size = (self.xyz_max - self.xyz_min) / 2 57 | self.cascades = max(1 + int(np.ceil(np.log2(2 * self.SCALE))), 1) 58 | self.register_buffer('density_grid', torch.zeros(self.cascades, self.RESOLUTION ** 3)) 59 | self.register_buffer('grid_coords', create_meshgrid3d(self.RESOLUTION, self.RESOLUTION, self.RESOLUTION, False, dtype=torch.int32, device=self.density_grid.device).reshape(-1, 3)) 60 | self.register_buffer('density_bitfield', torch.zeros(self.cascades*self.RESOLUTION ** 3 // 8, dtype=torch.uint8)) 61 | # create encodings and networks 62 | self.encoding_xyz = tcnn.NetworkWithInputEncoding( 63 | n_input_dims=3, 64 | n_output_dims=self.NUM_DENSITY_OUTPUT_FEATURES, 65 | encoding_config={ 66 | 'otype': 'Grid', 67 | 'type': 'Hash', 68 | 'n_levels': self.HASHMAP_NUM_LEVELS, 69 | 'n_features_per_level': self.HASHMAP_NUM_FEATURES_PER_LEVEL, 70 | 'log2_hashmap_size': self.HASHMAP_LOG2_SIZE, 71 | 'base_resolution': self.HASHMAP_BASE_RESOLUTION, 72 | 'per_level_scale': np.exp(np.log(self.HASHMAP_TARGET_RESOLUTION * (2 * self.SCALE) / self.HASHMAP_BASE_RESOLUTION) / (self.HASHMAP_NUM_LEVELS - 1)), 73 | 'interpolation': 'Linear' 74 | }, 75 | network_config={ 76 | 'otype': 'FullyFusedMLP', 77 | 'activation': 'ReLU', 78 | 'output_activation': 'None', 79 | 'n_neurons': self.NUM_DENSITY_NEURONS, 80 | 'n_hidden_layers': self.NUM_DENSITY_LAYERS, 81 | }, 82 | seed=Framework.config.GLOBAL.RANDOM_SEED 83 | ) 84 | # calculate number of parameters for the MLP part of the encoding for weight decay 85 | n_params_mlp = 0 86 | n_inputs = next_multiple(self.HASHMAP_NUM_FEATURES_PER_LEVEL * 
self.HASHMAP_NUM_LEVELS, 16) 87 | for i in range(self.NUM_DENSITY_LAYERS): 88 | n_params_layer = self.NUM_DENSITY_NEURONS * n_inputs 89 | n_params_layer = next_multiple(n_params_layer, 16) 90 | n_params_mlp += n_params_layer 91 | n_inputs = self.NUM_DENSITY_NEURONS 92 | n_params_mlp += next_multiple(self.encoding_xyz.n_output_dims, 16) * n_inputs 93 | self.n_params_encoding_mlp = n_params_mlp 94 | self.color_mlp_with_encoding = tcnn.NetworkWithInputEncoding( 95 | n_input_dims=3 + self.encoding_xyz.n_output_dims, 96 | n_output_dims=3, 97 | encoding_config={ 98 | 'otype': 'Composite', 99 | 'nested': [ 100 | { 101 | 'n_dims_to_encode': 3, 102 | 'otype': 'SphericalHarmonics', 103 | 'degree': self.DIR_SH_ENCODING_DEGREE, 104 | }, 105 | { 106 | "otype": "Identity" 107 | } 108 | ] 109 | 110 | }, 111 | network_config={ 112 | 'otype': 'FullyFusedMLP', 113 | 'activation': 'ReLU', 114 | 'output_activation': 'Sigmoid', 115 | 'n_neurons': self.NUM_COLOR_NEURONS, 116 | 'n_hidden_layers': self.NUM_COLOR_LAYERS, 117 | }, 118 | seed=Framework.config.GLOBAL.RANDOM_SEED 119 | ) 120 | self.n_mlp_params = len(self.color_mlp_with_encoding.params) + self.n_params_encoding_mlp 121 | return self 122 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/Trainer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """InstantNGP/Trainer.py: Implementation of the trainer for the InstantNGP method.""" 3 | 4 | import torch 5 | 6 | import Framework 7 | from Logging import Logger 8 | from Cameras.NDC import NDCCamera 9 | from Datasets.Base import BaseDataset 10 | from Methods.Base.GuiTrainer import GuiTrainer 11 | from Methods.Base.utils import preTrainingCallback, trainingCallback 12 | from Methods.InstantNGP.Loss import InstantNGPLoss 13 | from Methods.InstantNGP.utils import next_multiple, logOccupancyGrids 14 | from Optim.Samplers.DatasetSamplers import RayPoolSampler 15 | from Optim.Samplers.ImageSamplers import RandomImageSampler 16 | 17 | 18 | @Framework.Configurable.configure( 19 | NUM_ITERATIONS=50000, 20 | TARGET_BATCH_SIZE=262144, 21 | WARMUP_STEPS=256, 22 | DENSITY_GRID_UPDATE_INTERVAL=16, 23 | LEARNING_RATE=1.0e-2, 24 | LEARNING_RATE_DECAY_START=20000, 25 | LEARNING_RATE_DECAY_INTERVAL=10000, 26 | LEARNING_RATE_DECAY_BASE=0.33, 27 | ADAM_EPS=1e-15, 28 | USE_APEX=False, 29 | WANDB=Framework.ConfigParameterList( 30 | RENDER_OCCUPANCY_GRIDS=False, 31 | ) 32 | ) 33 | class InstantNGPTrainer(GuiTrainer): 34 | """Defines the trainer for the InstantNGP method.""" 35 | 36 | def __init__(self, **kwargs) -> None: 37 | super(InstantNGPTrainer, self).__init__(**kwargs) 38 | self.loss = InstantNGPLoss(self.model) 39 | try: 40 | if self.USE_APEX: 41 | from Thirdparty.Apex import FusedAdam # slightly faster than the PyTorch implementation 42 | self.optimizer = FusedAdam(self.model.parameters(), lr=self.LEARNING_RATE, eps=self.ADAM_EPS, betas=(0.9, 0.99), adam_w_mode=False) 43 | else: 44 | raise Exception 45 | except Exception: 46 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.LEARNING_RATE, eps=self.ADAM_EPS, betas=(0.9, 0.99), fused=True) 47 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 48 | optimizer=self.optimizer, 49 | milestones=[iteration for iteration in range( 50 | self.LEARNING_RATE_DECAY_START, 51 | self.NUM_ITERATIONS, 52 | self.LEARNING_RATE_DECAY_INTERVAL 53 | )], 54 | gamma=self.LEARNING_RATE_DECAY_BASE 55 | ) 56 | self.grad_scaler = 
torch.cuda.amp.GradScaler(init_scale=128.0, growth_interval=self.NUM_ITERATIONS + 1) 57 | self.rays_per_batch = 2 ** 12 58 | self.measured_batch_size = 0 59 | self.sampler_train = None 60 | self.sampler_val = None 61 | 62 | @preTrainingCallback(priority=10000) 63 | @torch.no_grad() 64 | def removeBackgroundColor(self, _, dataset: 'BaseDataset') -> None: 65 | if (dataset.camera.background_color == 0.0).all(): 66 | return 67 | # remove background color from training samples to allow training with random background colors 68 | for cam_properties in Logger.logProgressBar(dataset.data['train'], desc='Removing background color', leave=False): 69 | if cam_properties.alpha is not None: 70 | cam_properties.rgb.sub_((1.0 - cam_properties.alpha) * dataset.camera.background_color[:, None, None]).clamp_(0.0, 1.0) 71 | # recompute rays if necessary 72 | if dataset.PRECOMPUTE_RAYS: 73 | dataset.ray_collection['train'] = None 74 | dataset.precomputeRays(['train']) 75 | 76 | @preTrainingCallback(priority=1000) 77 | @torch.no_grad() 78 | def initSampler(self, _, dataset: 'BaseDataset') -> None: 79 | self.sampler_train = RayPoolSampler(dataset=dataset.train(), img_sampler_cls=RandomImageSampler) 80 | if self.RUN_VALIDATION: 81 | self.sampler_val = RayPoolSampler(dataset=dataset.eval(), img_sampler_cls=RandomImageSampler) 82 | 83 | @preTrainingCallback(priority=100) 84 | @torch.no_grad() 85 | def carveDensityGrid(self, _, dataset: 'BaseDataset') -> None: 86 | if not isinstance(dataset.camera, NDCCamera): 87 | self.renderer.carveDensityGrid(dataset.train(), subtractive=False, use_alpha=False) 88 | 89 | @trainingCallback(priority=1000, iteration_stride='DENSITY_GRID_UPDATE_INTERVAL') 90 | @torch.no_grad() 91 | def updateDensityGrid(self, iteration: int, _) -> None: 92 | self.renderer.updateDensityGrid(warmup=iteration < self.WARMUP_STEPS) 93 | 94 | @trainingCallback(priority=999, start_iteration='DENSITY_GRID_UPDATE_INTERVAL', iteration_stride='DENSITY_GRID_UPDATE_INTERVAL') 95 | @torch.no_grad() 96 | def updateBatchSize(self, *_) -> None: 97 | self.measured_batch_size /= self.DENSITY_GRID_UPDATE_INTERVAL 98 | self.rays_per_batch = min(next_multiple(self.rays_per_batch * self.TARGET_BATCH_SIZE / self.measured_batch_size, 256), self.TARGET_BATCH_SIZE) 99 | self.measured_batch_size = 0 100 | 101 | @trainingCallback(priority=50) 102 | def processTrainingSample(self, _, dataset: 'BaseDataset') -> None: 103 | """Performs a single training step.""" 104 | # prepare training iteration 105 | self.model.train() 106 | self.loss.train() 107 | dataset.train() 108 | # sample ray batch 109 | ray_batch: torch.Tensor = self.sampler_train.get(dataset=dataset, ray_batch_size=self.rays_per_batch)['ray_batch'] 110 | with torch.amp.autocast('cuda', enabled=True): 111 | # render and update 112 | bg_color = torch.rand(3) 113 | output = self.renderer.renderRays( 114 | rays=ray_batch, 115 | camera=dataset.camera, 116 | custom_bg_color=bg_color, 117 | train_mode=True) 118 | loss = self.loss(output, ray_batch, bg_color) 119 | self.optimizer.zero_grad() 120 | self.grad_scaler.scale(loss).backward() 121 | self.grad_scaler.step(self.optimizer) 122 | self.grad_scaler.update() 123 | self.lr_scheduler.step() 124 | self.measured_batch_size += output['rm_samples'].item() 125 | 126 | @trainingCallback(active='RUN_VALIDATION', priority=100) 127 | @torch.no_grad() 128 | def processValidationSample(self, _, dataset: 'BaseDataset') -> None: 129 | """Performs a single validation step.""" 130 | self.model.eval() 131 | self.loss.eval() 132 | 
dataset.eval() 133 | # sample ray batch 134 | ray_batch: torch.Tensor = self.sampler_val.get(dataset=dataset, ray_batch_size=self.rays_per_batch)['ray_batch'] 135 | with torch.amp.autocast('cuda', enabled=True): 136 | output = self.renderer.renderRays( 137 | rays=ray_batch, 138 | camera=dataset.camera, 139 | train_mode=True) 140 | self.loss(output, ray_batch, None) 141 | 142 | @trainingCallback(active='WANDB.ACTIVATE', priority=500, iteration_stride='WANDB.INTERVAL') 143 | @torch.no_grad() 144 | def logWandB(self, iteration: int, dataset: 'BaseDataset') -> None: 145 | """Logs all losses and visualizes training and validation samples using Weights & Biases.""" 146 | super().logWandB(iteration, dataset) 147 | # visualize scene and occupancy grid as point clouds 148 | if self.WANDB.RENDER_OCCUPANCY_GRIDS: 149 | logOccupancyGrids(self.renderer, iteration, dataset, 'occupancy grid') 150 | # commit current step 151 | Framework.wandb.log(data={}, commit=True) 152 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/__init__.py: -------------------------------------------------------------------------------- 1 | from Methods.InstantNGP.Model import InstantNGPModel 2 | from Methods.InstantNGP.Renderer import InstantNGPRenderer 3 | from Methods.InstantNGP.Trainer import InstantNGPTrainer 4 | 5 | MODEL = InstantNGPModel 6 | RENDERER = InstantNGPRenderer 7 | TRAINING_INSTANCE = InstantNGPTrainer 8 | -------------------------------------------------------------------------------- /src/Methods/InstantNGP/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | InstantNGP/utils.py: Utility functions for InstantNGP. 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import Framework 11 | from Datasets.Base import BaseDataset 12 | 13 | 14 | def next_multiple(value: int | float, multiple: int) -> int: 15 | return int(((value + multiple - 1) // multiple) * multiple) 16 | 17 | 18 | @torch.no_grad() 19 | def logOccupancyGrids(renderer, iteration: int, dataset: 'BaseDataset', log_string: str = 'occupancy grid') -> None: 20 | """visualize occupancy grid as pointcloud in wandb.""" 21 | dataset.train() 22 | cameras: list[np.array] = [] 23 | for i in range(len(dataset)): 24 | dataset.camera.setProperties(dataset[i]) 25 | data = dataset.camera.getPositionAndViewdir().cpu().numpy() 26 | # data[:, 2] *= -1 27 | cameras.append({"start": data[0].tolist(), "end": (data[0] + (0.1 * data[1])).tolist()}) 28 | cameras: np.ndarray = np.array(cameras) 29 | # gather boxes and active cells 30 | center = np.array([renderer.model.CENTER]) 31 | unit_cube_points = np.array([ 32 | [-1, -1, -1], 33 | [-1, +1, -1], 34 | [-1, -1, +1], 35 | [+1, -1, -1], 36 | [+1, +1, -1], 37 | [-1, +1, +1], 38 | [+1, -1, +1], 39 | [+1, +1, +1] 40 | ]) 41 | cells = renderer.getAllCells() 42 | boxes = [] 43 | active_cells = [(unit_cube_points * renderer.model.SCALE * 1.1) + center] 44 | thresh = min(renderer.model.density_grid[renderer.model.density_grid > 0].mean().item(), renderer.density_threshold) 45 | for c in range(renderer.model.cascades): 46 | indices, coords = cells[c] 47 | s = min(2 ** (c - 1), renderer.model.SCALE) 48 | boxes.append( 49 | { 50 | "corners": ((unit_cube_points * s) + center).tolist(), 51 | "label": f"Cascade {c + 1}", 52 | "color": [255, 0, 0], 53 | } 54 | ) 55 | half_grid_size = s / renderer.model.RESOLUTION 56 | points = (coords / (renderer.model.RESOLUTION - 1) * 2 - 1) * (s - 
half_grid_size) 57 | active_cells.append((points[renderer.model.density_grid[c, indices] > thresh]).cpu().numpy() + center) 58 | active_cells = np.concatenate(active_cells, axis=0) 59 | while active_cells.shape[0] > 200000: 60 | active_cells = active_cells[::2] 61 | scene = Framework.wandb.Object3D({ 62 | "type": "lidar/beta", 63 | "points": active_cells, 64 | "boxes": np.array(boxes), 65 | "vectors": cameras 66 | }) 67 | Framework.wandb.log( 68 | data={log_string: scene}, 69 | step=iteration 70 | ) 71 | -------------------------------------------------------------------------------- /src/Methods/NeRF/Loss.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """NeRF/Loss.py: Loss implementation for the NeRF method.""" 4 | 5 | import torch 6 | import torchmetrics 7 | 8 | from Cameras.utils import RayPropertySlice 9 | from Optim.Losses.Base import BaseLoss 10 | 11 | 12 | class NeRFLoss(BaseLoss): 13 | """Defines a class for all sub-losses of the NeRF method.""" 14 | 15 | def __init__(self, lambda_color: float, lambda_alpha: float) -> None: 16 | super().__init__() 17 | self.addLossMetric('L2_Color', torch.nn.functional.mse_loss, lambda_color) 18 | self.addLossMetric('L2_Alpha', torch.nn.functional.mse_loss, lambda_alpha) 19 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio) 20 | 21 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor) -> torch.Tensor: 22 | """Defines loss calculation.""" 23 | return super().forward({ 24 | 'L2_Color': {'input': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb]}, 25 | 'L2_Alpha': {'input': outputs['alpha'], 'target': rays[:, RayPropertySlice.alpha]}, 26 | 'PSNR': {'preds': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0}, 27 | }) 28 | -------------------------------------------------------------------------------- /src/Methods/NeRF/Model.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """NeRF/Model.py: Implementation of the neural model for the vanilla (i.e. 
original) NeRF method.""" 4 | 5 | import torch 6 | 7 | import Framework 8 | from Methods.Base.Model import BaseModel 9 | from Methods.NeRF.utils import getActivationFunction, FrequencyEncoding 10 | 11 | 12 | class NeRFBlock(torch.torch.nn.Module): 13 | """Defines a NeRF block (input: position, direction -> output: density, color).""" 14 | 15 | def __init__(self, num_layers: int, num_color_layers: int, num_features: int, 16 | encoding_length_position: int, encoding_length_direction: int, encoding_append_input: bool, 17 | input_skips: list[int], activation_function: str) -> None: 18 | super(NeRFBlock, self).__init__() 19 | # set parameters 20 | self.input_skips = input_skips # layer indices after which input is appended 21 | # get activation function (type, parameters, initial bias for last density layer) 22 | af_class, af_parameters, af_bias = getActivationFunction(activation_function) 23 | # embedding layers 24 | self.embedding_position = FrequencyEncoding(encoding_length_position, encoding_append_input) 25 | self.embedding_direction = FrequencyEncoding(encoding_length_direction, encoding_append_input) 26 | input_size_position = self.embedding_position.getOutputSize(3) 27 | input_size_direction = self.embedding_direction.getOutputSize(3) 28 | # initial linear layers 29 | self.initial_layers = [ 30 | torch.nn.Sequential(torch.nn.Linear(input_size_position, num_features, bias=True), af_class(*af_parameters)) 31 | ] 32 | for layer_index in range(1, num_layers): 33 | self.initial_layers.append(torch.nn.Sequential( 34 | torch.nn.Linear(num_features if layer_index not in input_skips else num_features + input_size_position, 35 | num_features, bias=True), af_class(*af_parameters) 36 | )) 37 | self.initial_layers = torch.nn.ModuleList(self.initial_layers) 38 | # intermediate feature and density layers 39 | self.feature_layer = torch.nn.Linear(num_features, num_features, bias=True) 40 | self.density_layer = torch.nn.Linear(num_features, 1, bias=True) 41 | self.density_activation = af_class(*af_parameters) 42 | # final color layer 43 | self.color_layers = torch.nn.Sequential(*( 44 | [torch.nn.Linear(num_features + input_size_direction, num_features // 2, bias=True), af_class(*af_parameters)] 45 | + [torch.nn.Sequential(torch.nn.Linear(num_features // 2, num_features // 2, bias=True), af_class(*af_parameters)) 46 | for _ in range(num_color_layers - 1)] 47 | + [torch.nn.Linear(num_features // 2, 3, bias=True), torch.nn.Sigmoid()] 48 | )) 49 | # initialize bias for density layer activation function (for better convergence, copied from pytorch3d examples) 50 | if af_bias is not None: 51 | self.density_layer.bias.data[0] = af_bias 52 | 53 | def forward(self, positions: torch.Tensor, directions: torch.Tensor, 54 | return_rgb: bool = False, random_noise_densities: float = 0.0) -> tuple[torch.Tensor, torch.Tensor]: 55 | # transform inputs to higher dimensional space 56 | positions_embedded: torch.Tensor = self.embedding_position(positions) 57 | # run initial layers 58 | x = positions_embedded 59 | for index, layer in enumerate(self.initial_layers): 60 | x = layer(x) 61 | if index + 1 in self.input_skips: 62 | x = torch.cat((x, positions_embedded), dim=-1) 63 | # extract density, add random noise before activation function 64 | density: torch.Tensor = self.density_layer(x) 65 | density = self.density_activation(density + (torch.randn(density.shape) * random_noise_densities)) 66 | # extract features, append view_directions and extract color 67 | color: torch.Tensor | None = None 68 | if return_rgb: 69 | 
directions_embedded: torch.Tensor = self.embedding_direction(directions) 70 | features: torch.Tensor = self.feature_layer(x) 71 | features = torch.cat((features, directions_embedded), dim=-1) 72 | color = self.color_layers(features) 73 | return density, color 74 | 75 | 76 | @Framework.Configurable.configure( 77 | NUM_LAYERS=8, 78 | NUM_COLOR_LAYERS=1, 79 | NUM_FEATURES=256, 80 | ENCODING_LENGTH_POSITIONS=10, 81 | ENCODING_LENGTH_DIRECTIONS=4, 82 | ENCODING_APPEND_INPUT=True, 83 | INPUT_SKIPS=[5], 84 | ACTIVATION_FUNCTION='relu' 85 | ) 86 | class NeRF(BaseModel): 87 | """Defines a plain NeRF with a single MLP.""" 88 | 89 | def __init__(self, name: str = None) -> None: 90 | super(NeRF, self).__init__(name) 91 | 92 | def build(self) -> 'NeRF': 93 | """Builds the model.""" 94 | self.net = NeRFBlock( 95 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES, 96 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT, 97 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION 98 | ) 99 | return self 100 | -------------------------------------------------------------------------------- /src/Methods/NeRF/Renderer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | NeRF/Renderer.py: Implementation of the renderer for the vanilla (i.e. original) NeRF. 5 | Borrows heavily from the PyTorch NeRF reimplementation of Yenchen Lin 6 | Source: https://github.com/yenchenlin/nerf-pytorch/ 7 | """ 8 | 9 | import torch 10 | from torch import Tensor 11 | 12 | import Framework 13 | from Cameras.utils import RayPropertySlice 14 | from Methods.Base.Renderer import BaseRenderingComponent, BaseRenderer 15 | from Methods.NeRF.Model import NeRF, NeRFBlock 16 | from Methods.NeRF.utils import generateSamples, integrateRaySamples 17 | from Methods.Base.Renderer import BaseModel 18 | from Cameras.Base import BaseCamera 19 | 20 | 21 | class NeRFRayRenderingComponent(BaseRenderingComponent): 22 | """Defines a NeRF ray rendering component used to access the NeRF model.""" 23 | 24 | def __init__(self, scene_function: 'NeRFBlock') -> None: 25 | super().__init__() 26 | self.scene_function = scene_function 27 | 28 | def forward(self, rays: Tensor, camera: 'BaseCamera', 29 | ray_batch_size: int, num_samples: int, return_samples: bool, randomize_samples: bool, 30 | random_noise_densities: float) -> dict[str, Tensor | None]: 31 | """Generates samples from the given rays and queries the NeRF model to produce the desired outputs.""" 32 | outputs = {'rgb': [], 'depth': [], 'depth_samples': [], 'alpha_weights': [], 'alpha': []} 33 | # split rays into chunks that fit into VRAM 34 | ray_batches: list[Tensor] = torch.split(rays, ray_batch_size, dim=0) 35 | background_color: Tensor = camera.background_color.to(rays.device) 36 | for ray_batch in ray_batches: 37 | depth_samples = generateSamples( 38 | ray_batch, num_samples, camera.near_plane, camera.far_plane, randomize_samples 39 | ) 40 | positions: Tensor = ray_batch[:, None, RayPropertySlice.origin] + ( 41 | ray_batch[:, None, RayPropertySlice.direction] * depth_samples[:, :, None]) 42 | directions: Tensor = ray_batch[:, None, RayPropertySlice.view_direction].expand(positions.shape) 43 | densities, rgb = self.scene_function( 44 | positions.reshape(-1, 3), directions.reshape(-1, 3), 45 | return_rgb=True, random_noise_densities=random_noise_densities 46 | ) 47 | final_rgb, final_depth, final_alpha, final_alpha_weights = integrateRaySamples( 48 | depth_samples, ray_batch[:, 
RayPropertySlice.direction], 49 | densities.reshape(-1, num_samples), rgb.reshape(-1, num_samples, 3), background_color 50 | ) 51 | # append outputs 52 | outputs['rgb'].append(final_rgb) 53 | outputs['depth'].append(final_depth) 54 | outputs['alpha'].append(final_alpha) 55 | if return_samples: 56 | outputs['depth_samples'].append(depth_samples) 57 | outputs['alpha_weights'].append(final_alpha_weights) 58 | # concat ray batches 59 | for key in outputs: 60 | outputs[key] = ( 61 | torch.cat(outputs[key], dim=0) if len(ray_batches) > 1 else outputs[key][0] 62 | ) if outputs[key] else None 63 | return outputs 64 | 65 | 66 | @Framework.Configurable.configure( 67 | RAY_BATCH_SIZE=8192, 68 | NUM_SAMPLES=256 69 | ) 70 | class NeRFRenderer(BaseRenderer): 71 | """Defines the renderer for the vanilla (i.e. original) NeRF method.""" 72 | 73 | def __init__(self, model: 'BaseModel') -> None: 74 | super().__init__(model, [NeRF]) 75 | self.ray_rendering_component = NeRFRayRenderingComponent.get(self.model.net) 76 | 77 | def renderRays(self, rays: Tensor, camera: 'BaseCamera', 78 | return_samples: bool = False, randomize_samples: bool = False, 79 | random_noise_densities: float = 0.0) -> dict[str, Tensor | None]: 80 | """Renders the given set of rays using the renderer's rendering component.""" 81 | return self.ray_rendering_component( 82 | rays, camera, 83 | self.RAY_BATCH_SIZE, self.NUM_SAMPLES, 84 | return_samples, randomize_samples, random_noise_densities 85 | ) 86 | 87 | def renderImage(self, camera: 'BaseCamera', to_chw: bool = False, benchmark: bool = False) -> dict[str, Tensor | None]: 88 | """Renders a complete image using the given camera.""" 89 | rays: Tensor = camera.generateRays() 90 | rendered_rays = self.renderRays( 91 | rays, camera, 92 | return_samples=False, randomize_samples=False, random_noise_densities=0.0 93 | ) 94 | # reshape rays to images 95 | for key in rendered_rays: 96 | if rendered_rays[key] is not None: 97 | rendered_rays[key] = rendered_rays[key].reshape(camera.properties.height, camera.properties.width, -1) 98 | if to_chw: 99 | rendered_rays[key] = rendered_rays[key].permute((2, 0, 1)) 100 | return rendered_rays 101 | -------------------------------------------------------------------------------- /src/Methods/NeRF/Trainer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """NeRF/Trainer.py: Implementation of the trainer for the vanilla (i.e. original) NeRF method.""" 4 | 5 | import torch 6 | 7 | import Framework 8 | from Datasets.Base import BaseDataset 9 | from Methods.Base.Trainer import BaseTrainer 10 | from Methods.Base.utils import preTrainingCallback, trainingCallback 11 | from Methods.NeRF.Loss import NeRFLoss 12 | from Optim.Samplers.DatasetSamplers import DatasetSampler, RayPoolSampler 13 | from Optim.Samplers.ImageSamplers import RandomImageSampler 14 | 15 | 16 | @Framework.Configurable.configure( 17 | NUM_ITERATIONS=500000, 18 | BATCH_SIZE=1024, 19 | SAMPLE_SINGLE_IMAGE=True, 20 | DENSITY_RANDOM_NOISE_STD=0.0, 21 | ADAM_BETA_1=0.9, 22 | ADAM_BETA_2=0.999, 23 | LEARNINGRATE=5.0e-04, 24 | LEARNINGRATE_DECAY_RATE=0.1, 25 | LEARNINGRATE_DECAY_STEPS=500000, 26 | LAMBDA_COLOR_LOSS=1.0, 27 | LAMBDA_ALPHA_LOSS=0.0, 28 | ) 29 | class NeRFTrainer(BaseTrainer): 30 | """Defines the trainer for the vanilla (i.e. 
original) NeRF method.""" 31 | 32 | def __init__(self, **kwargs) -> None: 33 | super(NeRFTrainer, self).__init__(**kwargs) 34 | self.optimizer = torch.optim.Adam( 35 | self.model.parameters(), 36 | lr=self.LEARNINGRATE, betas=(self.ADAM_BETA_1, self.ADAM_BETA_2), 37 | eps=1e-8 38 | ) 39 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 40 | self.optimizer, 41 | lr_lambda=self.LRDecayPolicy(self.LEARNINGRATE_DECAY_RATE, self.LEARNINGRATE_DECAY_STEPS), 42 | last_epoch=self.model.num_iterations_trained - 1 43 | ) 44 | self.loss = NeRFLoss(self.LAMBDA_COLOR_LOSS, self.LAMBDA_ALPHA_LOSS) 45 | 46 | class LRDecayPolicy(object): 47 | """Defines a decay policy for the learning rate.""" 48 | 49 | def __init__(self, ldr: float, lds: float) -> None: 50 | self.ldr: float = ldr 51 | self.lds: float = lds 52 | 53 | def __call__(self, iteration) -> float: 54 | """Calculates learning rate decay.""" 55 | return self.ldr ** (iteration / self.lds) 56 | 57 | @preTrainingCallback(priority=1000) 58 | @torch.no_grad() 59 | def initSampler(self, _, dataset: 'BaseDataset') -> None: 60 | sampler_cls = DatasetSampler if self.SAMPLE_SINGLE_IMAGE else RayPoolSampler 61 | self.sampler_train = sampler_cls(dataset=dataset.train(), random=True, img_sampler_cls=RandomImageSampler) 62 | if self.RUN_VALIDATION: 63 | self.sampler_val = sampler_cls(dataset=dataset.eval(), random=True, img_sampler_cls=RandomImageSampler) 64 | 65 | @trainingCallback(priority=50) 66 | def processTrainingSample(self, iteration: int, dataset: 'BaseDataset') -> None: 67 | """Defines a callback which is executed every iteration to process a training sample.""" 68 | # set modes 69 | self.model.train() 70 | self.loss.train() 71 | dataset.train() 72 | # sample ray batch 73 | ray_batch: torch.Tensor = self.sampler_train.get(dataset=dataset, ray_batch_size=self.BATCH_SIZE)['ray_batch'] 74 | # update model 75 | self.optimizer.zero_grad() 76 | outputs = self.renderer.renderRays( 77 | rays=ray_batch, 78 | camera=dataset.camera, 79 | return_samples=False, 80 | randomize_samples=True, 81 | random_noise_densities=self.DENSITY_RANDOM_NOISE_STD 82 | ) 83 | loss: torch.Tensor = self.loss(outputs, ray_batch) 84 | loss.backward() 85 | self.optimizer.step() 86 | # update learning rate 87 | self.lr_scheduler.step() 88 | 89 | @trainingCallback(priority=100, active='RUN_VALIDATION') 90 | @torch.no_grad() 91 | def processValidationSample(self, iteration: int, dataset: 'BaseDataset') -> None: 92 | """Defines a callback which is executed every iteration to process a validation sample.""" 93 | self.model.eval() 94 | self.loss.eval() 95 | dataset.eval() 96 | ray_batch: torch.Tensor = self.sampler_val.get(dataset=dataset, ray_batch_size=self.BATCH_SIZE)['ray_batch'] 97 | outputs = self.renderer.renderRays( 98 | rays=ray_batch, 99 | camera=dataset.camera, 100 | return_samples=False, 101 | randomize_samples=False, 102 | random_noise_densities=0.0 103 | ) 104 | self.loss(outputs, ray_batch) 105 | -------------------------------------------------------------------------------- /src/Methods/NeRF/__init__.py: -------------------------------------------------------------------------------- 1 | from Methods.NeRF.Model import NeRF 2 | from Methods.NeRF.Renderer import NeRFRenderer 3 | from Methods.NeRF.Trainer import NeRFTrainer 4 | 5 | MODEL = NeRF 6 | RENDERER = NeRFRenderer 7 | TRAINING_INSTANCE = NeRFTrainer 8 | -------------------------------------------------------------------------------- /src/Methods/NeRF/utils.py: 
-------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | NeRF/utils.py: Contains utility functions used for the implementation of the NeRF method. 5 | """ 6 | import torch 7 | from torch import Tensor 8 | 9 | import Framework 10 | from Logging import Logger 11 | 12 | 13 | class FrequencyEncoding(torch.nn.Module): 14 | """Defines a network layer that performs frequency encoding with linear coefficients.""" 15 | 16 | def __init__(self, encoding_length: int, append_input: bool): 17 | super().__init__() 18 | # calculate frequencies 19 | self.register_buffer('frequency_factors', ( 20 | 2 ** torch.linspace(start=0.0, end=encoding_length - 1.0, steps=encoding_length) 21 | )[None, None, :] # * math.pi 22 | ) 23 | self.append_input: bool = append_input 24 | 25 | def getOutputSize(self, num_inputs: int) -> int: 26 | """Returns the number of output nodes""" 27 | num_outputs: int = num_inputs * 2 * self.frequency_factors.numel() 28 | if self.append_input: 29 | num_outputs += num_inputs 30 | return num_outputs 31 | 32 | def forward(self, inputs: Tensor) -> Tensor: 33 | """Returns the input tensor after applying the periodic embedding to it.""" 34 | outputs: list[Tensor] = [] 35 | # append inputs original inputs if requested 36 | if self.append_input: 37 | outputs.append(inputs) 38 | frequencies: Tensor = (inputs[:, :, None] * self.frequency_factors).flatten(start_dim=1) 39 | # apply periodic functions over frequencies 40 | for periodic_function in (torch.cos, torch.sin): 41 | outputs.append(periodic_function(frequencies)) 42 | return torch.cat(outputs, dim=-1) 43 | 44 | 45 | # dictionary variable containing all available activation functions as well as their parameters and initial biases 46 | ACTIVATION_FUNCTION_OPTIONS: dict[str, tuple] = { 47 | 'relu': (torch.nn.ReLU, (True,), None), 48 | 'softplus': (torch.nn.Softplus, (10.0,), -1.5) 49 | } 50 | 51 | 52 | def getActivationFunction(type: str) -> tuple: 53 | """Returns the requested activation function, parameters and initial bias.""" 54 | # log error message and stop execution if requested key is invalid 55 | if type not in ACTIVATION_FUNCTION_OPTIONS: 56 | Logger.logError( 57 | f'requested invalid model activation function: {type} \n' 58 | f'available options are: {list(ACTIVATION_FUNCTION_OPTIONS.keys())}' 59 | ) 60 | raise Framework.ModelError(f'Invalid activation function "{type}"') 61 | # return requested model instance 62 | return ACTIVATION_FUNCTION_OPTIONS[type] 63 | 64 | 65 | def generateSamples(rays: Tensor, num_samples: int, near_plane: float, far_plane: float, 66 | randomize_samples: bool) -> Tensor: 67 | """Returns random samples (positions in space) for the given set of rays.""" 68 | device: torch.device = rays.device 69 | lin_steps: Tensor = torch.linspace(0., 1., steps=num_samples, device=device) 70 | lin_steps: Tensor = (near_plane * (1.0 - lin_steps)) + (far_plane * lin_steps) 71 | depth_samples: Tensor = lin_steps.expand([rays.shape[0], num_samples]) 72 | if randomize_samples: 73 | # use linear samples as interval borders for random samples 74 | mid_points: Tensor = 0.5 * (depth_samples[..., 1:] + depth_samples[..., :-1]) 75 | upper_border: Tensor = torch.cat([mid_points, depth_samples[..., -1:]], -1) 76 | lower_border: Tensor = torch.cat([depth_samples[..., :1], mid_points], -1) 77 | random_offsets: Tensor = torch.rand(depth_samples.shape, device=device) 78 | depth_samples: Tensor = lower_border + ((upper_border - lower_border) * random_offsets) 79 | return 
depth_samples 80 | 81 | 82 | def integrateRaySamples(depth_samples: Tensor, ray_directions: Tensor, densities: Tensor, rgb: Tensor | None, 83 | background_color: Tensor, final_distance: float=1e10) -> tuple[Tensor | None, Tensor | None, Tensor, Tensor]: 84 | """Estimates final color, depth, and alpha values from samples along ray.""" 85 | distances: Tensor = depth_samples[..., 1:] - depth_samples[..., :-1] 86 | distances: Tensor = torch.cat([distances, Tensor([final_distance]).expand(distances[..., :1].shape)], dim=-1) * torch.norm(ray_directions[..., None, :], dim=-1) 87 | alpha: Tensor = 1.0 - torch.exp(-densities * distances) 88 | alpha_weights: Tensor = alpha * torch.cumprod( 89 | torch.cat([torch.ones((alpha.shape[0], 1)), 1.0 - alpha + 1e-10], -1), -1 90 | )[:, :-1] 91 | alpha_final: Tensor = torch.sum(alpha_weights, dim=-1, keepdim=True) 92 | # render only if color is available 93 | final_rgb: Tensor | None 94 | final_depth: Tensor | None 95 | final_rgb = final_depth = None 96 | if rgb is not None: 97 | final_depth = torch.sum((alpha_weights / (alpha_final + 1e-8)) * depth_samples, -1) 98 | final_rgb = torch.sum(alpha_weights[..., None] * rgb, -2) 99 | if background_color is not None: 100 | final_rgb += ((1.0 - alpha_final) * background_color) 101 | return final_rgb, final_depth, alpha_final, alpha_weights 102 | -------------------------------------------------------------------------------- /src/Optim/AdamUtils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/AdamUtils.py: Provides various utility functions for the Adam optimizer and its variants.""" 4 | 5 | import torch 6 | 7 | 8 | def replace_param_group_data(optimizer: torch.optim.Optimizer, new_values: torch.Tensor, group_name: str, reset_state: bool = True) -> None: 9 | """Replaces the data of a parameter group with the given tensor.""" 10 | for group in optimizer.param_groups: 11 | if group['name'] == group_name: 12 | if len(group['params']) != 1: 13 | raise NotImplementedError('"replace_param_group_data" only implemented for single-parameter groups.') 14 | param = group['params'][0] 15 | param.data = new_values 16 | if reset_state: 17 | state = optimizer.state[param] 18 | if state: 19 | for val in ['exp_avg', 'exp_avg_sq']: 20 | state[val].zero_() 21 | 22 | 23 | def prune_param_groups(optimizer: torch.optim.Optimizer, mask: torch.Tensor, group_names: list[str] | None = None) -> dict[str, torch.Tensor]: 24 | """Removes parameter entries based on the given mask.""" 25 | new_params = {} 26 | for group in optimizer.param_groups: 27 | if group_names is not None and group['name'] not in group_names: 28 | continue 29 | if len(group['params']) != 1: 30 | raise NotImplementedError('"prune_param_groups" only implemented for single-parameter groups.') 31 | old_param = group['params'][0] 32 | state = optimizer.state[old_param] 33 | new_param = torch.nn.Parameter(old_param[mask]) 34 | if state: 35 | for val in ['exp_avg', 'exp_avg_sq']: 36 | state[val] = state[val][mask] 37 | optimizer.state.pop(old_param) 38 | optimizer.state[new_param] = state 39 | group['params'][0] = new_param 40 | new_params[group['name']] = new_param 41 | return new_params 42 | 43 | 44 | def extend_param_groups(optimizer: torch.optim.Optimizer, additional_params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 45 | """Extend existing parameters by concatenating the given tensors.""" 46 | new_params = {} 47 | for group in optimizer.param_groups: 48 | if len(group['params']) != 1: 49 
| raise NotImplementedError('"extend_param_groups" only implemented for single-parameter groups.') 50 | extension_tensor = additional_params.get(group['name'], None) 51 | if extension_tensor is None: 52 | continue 53 | old_param = group['params'][0] 54 | state = optimizer.state[old_param] 55 | new_param = torch.nn.Parameter(torch.cat((old_param, extension_tensor), dim=0)) 56 | if state: 57 | for val in ['exp_avg', 'exp_avg_sq']: 58 | state[val] = torch.cat((state[val], torch.zeros_like(extension_tensor)), dim=0) 59 | optimizer.state.pop(old_param) 60 | optimizer.state[new_param] = state 61 | group['params'][0] = new_param 62 | new_params[group['name']] = new_param 63 | return new_params -------------------------------------------------------------------------------- /src/Optim/GradientScaling.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/GradientScaling.py: Gradient scaling routines.""" 4 | 5 | from typing import Any 6 | import torch 7 | 8 | 9 | class _GradientScaler(torch.autograd.Function): 10 | """ 11 | Utility Autograd function scaling gradients by a given factor. 12 | Adapted from NerfStudio: https://docs.nerf.studio/en/latest/_modules/nerfstudio/model_components/losses.html 13 | """ 14 | 15 | @staticmethod 16 | def forward(ctx, value: torch.Tensor, scaling: Any) -> tuple[torch.Tensor, Any]: 17 | ctx.save_for_backward(scaling) 18 | return value, scaling 19 | 20 | @staticmethod 21 | def backward(ctx, output_grad: torch.Tensor, grad_scaling: Any) -> tuple[torch.Tensor, Any]: 22 | (scaling,) = ctx.saved_tensors 23 | return output_grad * scaling, grad_scaling 24 | 25 | 26 | def scaleGradient(*args: torch.Tensor, scaling: Any) -> tuple[torch.Tensor, ...] | torch.Tensor: 27 | """ 28 | Scale the gradient of the given tensors. 29 | """ 30 | output: list[torch.Tensor] = [_GradientScaler.apply(value, scaling)[0] for value in args] 31 | return tuple(output) if len(output) > 1 else output[0] 32 | 33 | 34 | def scaleGradientByDistance(*args: torch.Tensor, distances: torch.Tensor) -> tuple[torch.Tensor, ...]: 35 | """ 36 | Scale the gradient of the given tensor based on the normalized ray distance. 
37 | See: Radiance Field Gradient Scaling for Improved Near-Camera Training (https://gradient-scaling.github.io) 38 | """ 39 | return scaleGradient(*args, scaling=torch.square(distances).clamp(0, 1)) 40 | -------------------------------------------------------------------------------- /src/Optim/Losses/BackgroundEntropy.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/LossesBackgroundEntropy.py: Background entropy loss (encouraging alpha channel to be 0 or 1).""" 4 | 5 | from typing import Any 6 | 7 | import torch 8 | import torchmetrics 9 | 10 | 11 | def backgroundEntropy(input: torch.Tensor, symmetrical=False) -> torch.Tensor: 12 | """functional""" 13 | x = input.clamp(min=1e-6, max=1.0 - 1e-6) 14 | return -(x * torch.log(x) + (1 - x) * torch.log(1 - x)).mean() if symmetrical else (-x * torch.log(x)).mean() 15 | 16 | 17 | class BackgroundEntropyLoss(torchmetrics.Metric): 18 | """torchmetrics implementation""" 19 | is_differentiable = True 20 | higher_is_better = False 21 | full_state_update = False 22 | 23 | def __init__(self, symmetrical: bool = False, **kwargs: Any) -> None: 24 | super().__init__(**kwargs) 25 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 26 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 27 | self.symmetrical = symmetrical 28 | 29 | def update(self, preds: torch.Tensor) -> None: 30 | """Update state with current alpha predictions.""" 31 | x = preds.clamp(min=1e-6, max=1.0 - 1e-6) 32 | if self.symmetrical: 33 | y = -(x * torch.log(x) + (1 - x) * torch.log(1 - x)).sum() 34 | else: 35 | y = (-x * torch.log(x)).sum() 36 | self.running_sum += y 37 | self.total += preds.numel() 38 | 39 | def compute(self) -> torch.Tensor: 40 | """Computes background entropy over state.""" 41 | return (self.running_sum / self.total) if self.total > 0 else (torch.tensor(0.0)) 42 | -------------------------------------------------------------------------------- /src/Optim/Losses/Base.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/LossesBase.py: Base Loss class for accumulation and logging.""" 4 | 5 | from typing import Any, Callable 6 | 7 | import torch 8 | 9 | import Framework 10 | from Optim.Losses.utils import QualityMetricItem, LossMetricItem 11 | 12 | 13 | class BaseLoss(torch.nn.Module): 14 | """Simple configurable loss container for accumulation and wandb logging""" 15 | 16 | def __init__(self, 17 | loss_metrics: list[LossMetricItem] | None = None, 18 | quality_metrics: list[QualityMetricItem] | None = None) -> None: 19 | super().__init__() 20 | self.loss_metrics: list[LossMetricItem] = loss_metrics or [] 21 | self.quality_metrics: list[QualityMetricItem] = quality_metrics or [] 22 | self.activate_logging: bool = Framework.config.TRAINING.WANDB.ACTIVATE 23 | 24 | def addLossMetric(self, name: str, metric: Callable, weight: float = None) -> None: 25 | self.loss_metrics.append(LossMetricItem( 26 | name=name, 27 | metric_func=metric, 28 | weight=weight 29 | )) 30 | 31 | def addQualityMetric(self, name: str, metric: Callable) -> None: 32 | self.quality_metrics.append(QualityMetricItem( 33 | name=name, 34 | metric_func=metric, 35 | )) 36 | 37 | def reset(self) -> None: 38 | for item in self.loss_metrics + self.quality_metrics: 39 | item.reset() 40 | 41 | def log(self, iteration: int, log_validation: bool) -> None: 42 | if self.activate_logging: 43 | for item in 
self.loss_metrics + self.quality_metrics: 44 | val_train, val_eval = item.getAverage() 45 | data = {'train': val_train} 46 | if log_validation: 47 | data['eval'] = val_eval 48 | Framework.wandb.log({f'{item.name}': data}, step=iteration) 49 | 50 | def forward(self, configurations: dict[str, dict[str, Any]]) -> torch.Tensor: 51 | try: 52 | if self.activate_logging: 53 | with torch.no_grad(): 54 | for item in self.quality_metrics: 55 | item.apply(train=self.training, accumulate=True, kwargs=configurations[item.name]) 56 | sublosses = [] 57 | for item in self.loss_metrics: 58 | sublosses.append(item.apply(train=self.training, accumulate=True, kwargs=configurations[item.name])) 59 | return torch.stack(sublosses).sum() 60 | except KeyError: 61 | raise Framework.LossError(f'missing argument configuration for loss "{item.name}"') 62 | except TypeError: 63 | raise Framework.LossError(f'invalid argument configuration for loss "{item.name}"') 64 | except Exception as e: 65 | raise Framework.LossError(f'unexpected error occurred in loss "{item.name}": {e}') 66 | -------------------------------------------------------------------------------- /src/Optim/Losses/Charbonnier.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/Losses/Charbonnier.py: Charbonnier loss as in Mip-NeRF 360.""" 4 | 5 | import torch 6 | 7 | 8 | def charbonnier_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor: 9 | """Computes the Charbonnier loss as in Mip-NeRF 360.""" 10 | return (input - target).pow(2).add(eps).sqrt().mean() 11 | -------------------------------------------------------------------------------- /src/Optim/Losses/DSSIM.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/Losses/DSSIM.py: Loss based on the structural (dis-)similarity index measure (DSSIM = 1 - SSIM).""" 4 | 5 | from typing import Sequence, Literal 6 | 7 | import torch 8 | from torchmetrics.functional.image import structural_similarity_index_measure 9 | from functools import partial 10 | 11 | 12 | class DSSIMLoss: 13 | """SSIM-based perceptual loss called DSSIM (computed as DSSIM = 1 - SSIM).""" 14 | 15 | def __init__( 16 | self, 17 | gaussian_kernel: bool = True, 18 | sigma: float | Sequence[float] = 1.5, 19 | kernel_size: int | Sequence[int] = 11, 20 | reduction: Literal['elementwise_mean', 'sum', 'none', None] = 'elementwise_mean', 21 | data_range: float | tuple[float, float] | None = 1.0, # torchmetrics default is None 22 | k1: float = 0.01, 23 | k2: float = 0.03, 24 | return_full_image: bool = False, 25 | ) -> None: 26 | """Initialize loss function.""" 27 | self.return_full_image = return_full_image 28 | self.loss_function = partial( 29 | structural_similarity_index_measure, 30 | gaussian_kernel=gaussian_kernel, 31 | sigma=sigma, 32 | kernel_size=kernel_size, 33 | reduction=reduction, 34 | data_range=data_range, 35 | k1=k1, 36 | k2=k2, 37 | return_full_image=return_full_image, 38 | return_contrast_sensitivity=False, # not configurable for now 39 | ) 40 | 41 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 42 | """Calculate loss.""" 43 | if self.return_full_image: 44 | return 1.0 - self.loss_function(preds=input[None], target=target[None])[1][0] 45 | return 1.0 - self.loss_function(preds=input[None], target=target[None]) # should be (1.0 - SSIM) / 2.0 46 | --------------------------------------------------------------------------------
/src/Optim/Losses/DepthSmoothness.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/DepthSmoothness.py: smooth depth loss. adapted from https://kornia.readthedocs.io/en/v0.2.1/_modules/kornia/losses/depth_smooth.html.""" 4 | 5 | import torch 6 | 7 | 8 | def _gradient_x(img: torch.Tensor) -> torch.Tensor: 9 | return img[:, :, :, 1:-1] - img[:, :, :, 0:-2] 10 | 11 | 12 | def _gradient_y(img: torch.Tensor) -> torch.Tensor: 13 | return img[:, :, 1:-1, :] - img[:, :, 0:-2, :] 14 | 15 | 16 | def _laplace_x(img: torch.Tensor) -> torch.Tensor: 17 | mi = img[:, :, :, 1:-1] 18 | le = img[:, :, :, :-2] 19 | ri = img[:, :, :, 2:] 20 | return le + ri - (2 * mi) 21 | 22 | 23 | def _laplace_y(img: torch.Tensor) -> torch.Tensor: 24 | mi = img[:, :, 1:-1, :] 25 | le = img[:, :, :-2, :] 26 | ri = img[:, :, 2:, :] 27 | return le + ri - (2 * mi) 28 | 29 | 30 | def depthSmoothnessLoss( 31 | depth: torch.Tensor, 32 | image: torch.Tensor) -> torch.Tensor: 33 | 34 | # compute the gradients 35 | idepth_dx: torch.Tensor = _laplace_x(depth) 36 | idepth_dy: torch.Tensor = _laplace_y(depth) 37 | image_dx: torch.Tensor = _gradient_x(image) 38 | image_dy: torch.Tensor = _gradient_y(image) 39 | 40 | # compute image weights 41 | weights_x: torch.Tensor = torch.exp( 42 | -torch.mean(torch.abs(image_dx), dim=1, keepdim=True)) 43 | weights_y: torch.Tensor = torch.exp( 44 | -torch.mean(torch.abs(image_dy), dim=1, keepdim=True)) 45 | 46 | # apply image weights to depth 47 | smoothness_x: torch.Tensor = torch.abs(idepth_dx * weights_x) 48 | smoothness_y: torch.Tensor = torch.abs(idepth_dy * weights_y) 49 | return torch.mean(smoothness_x) + torch.mean(smoothness_y) 50 | -------------------------------------------------------------------------------- /src/Optim/Losses/Distortion.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/LossesDistortion.py: Distortion loss as introduced in MipNeRF360 / DVGOv2.""" 4 | 5 | from typing import Any 6 | import torch 7 | import torchmetrics 8 | from Methods.InstantNGP.CudaExtensions.VolumeRenderingV2 import DistortionLoss as DistortionLossAutogradFN 9 | 10 | 11 | class DistortionLoss(torch.nn.Module): 12 | 13 | def __init__(self) -> None: 14 | super().__init__() 15 | 16 | def __call__(self, ws: torch.Tensor, deltas: torch.Tensor, ts: torch.Tensor, rays_a: torch.Tensor) -> torch.Tensor: 17 | return DistortionLossAutogradFN.apply(ws, deltas, ts, rays_a).mean() 18 | 19 | 20 | class DistortionLossTorchMetrics(torchmetrics.Metric): 21 | """torchmetrics implementation of the efficient Distortion loss introduced in MipNeRF360 / DVGOv2""" 22 | is_differentiable = True 23 | higher_is_better = False 24 | full_state_update = False 25 | 26 | def __init__(self, **kwargs: Any) -> None: 27 | super().__init__(**kwargs) 28 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 29 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 30 | 31 | def update(self, ws: torch.Tensor, deltas: torch.Tensor, ts: torch.Tensor, rays_a: torch.Tensor) -> None: 32 | """Update state.""" 33 | x = DistortionLossAutogradFN.apply(ws, deltas, ts, rays_a) 34 | y = x.sum() 35 | self.running_sum = self.running_sum + y 36 | self.total += x.numel() 37 | 38 | def compute(self) -> torch.Tensor: 39 | """Computes distortion loss over state.""" 40 | return (self.running_sum / self.total) if self.total > 0 else torch.tensor(0.0) 41 | 
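For orientation, the fused kernel wrapped by DistortionLoss above computes the distortion regularizer from Mip-NeRF 360. The plain PyTorch sketch below spells out the per-ray quantity for readers without the CUDA extension; the function and argument names are illustrative only and do not match the packed-buffer API (ws, deltas, ts, rays_a) used by the extension.

import torch

def distortion_loss_reference(weights: torch.Tensor, midpoints: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """O(N^2) reference for a single ray: sum_ij w_i * w_j * |t_i - t_j| + (1/3) * sum_i w_i^2 * delta_i."""
    cross = (weights[:, None] * weights[None, :] * (midpoints[:, None] - midpoints[None, :]).abs()).sum()
    self_term = (weights.pow(2) * deltas).sum() / 3.0
    return cross + self_term

The CUDA implementation evaluates the same expression in linear time using inclusive prefix sums, which is consistent with the ws_inclusive_scan and wts_inclusive_scan buffers appearing in the backward declaration of the extension.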
-------------------------------------------------------------------------------- /src/Optim/Losses/FusedDSSIM.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Optim/Losses/FusedDSSIM.py: Loss based on the structural (dis-)similarity index measure (DSSIM = 1 - SSIM). 5 | """ 6 | 7 | import torch 8 | 9 | from Thirdparty.FusedSSIM import fused_ssim 10 | 11 | 12 | def fused_dssim(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 13 | """Calculate loss.""" 14 | return 1.0 - fused_ssim(input[None], target[None]) # should be (1.0 - SSIM) / 2.0 15 | -------------------------------------------------------------------------------- /src/Optim/Losses/Magnitude.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/LossesMagnitude.py: Mean 1-norm over given dim.""" 4 | 5 | from typing import Any 6 | 7 | import torch 8 | import torchmetrics 9 | 10 | 11 | def magnitudeLoss(input: torch.Tensor, dim: int = 1) -> torch.Tensor: 12 | """functional""" 13 | if input is None: 14 | return torch.tensor(0.0, requires_grad=False) 15 | return torch.norm(input, dim=dim, keepdim=True, p=1).mean() 16 | 17 | 18 | class MagnitudeLoss(torchmetrics.Metric): 19 | """torchmetrics implementation""" 20 | is_differentiable = True 21 | higher_is_better = False 22 | full_state_update = False 23 | 24 | def __init__(self, dim: int, **kwargs: Any) -> None: 25 | super().__init__(**kwargs) 26 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 27 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 28 | self.dim = dim 29 | 30 | def update(self, preds: torch.Tensor) -> None: 31 | """torchmetrics update override""" 32 | x = torch.norm(preds, dim=self.dim, keepdim=True, p=1) 33 | self.running_sum += x.sum() 34 | self.total += x.numel() 35 | 36 | def compute(self) -> torch.Tensor: 37 | """torchmetrics compute override""" 38 | return (self.running_sum / self.total) if self.total > 0 else (torch.tensor(0.0)) 39 | -------------------------------------------------------------------------------- /src/Optim/Losses/Robust.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/Robust.py: Robust loss function described in https://arxiv.org/abs/1701.03077.""" 4 | 5 | import torch 6 | 7 | 8 | class RobustLoss: 9 | """General and Adaptive Robust Loss Function.""" 10 | 11 | def __init__( 12 | self, 13 | alpha: float, 14 | c: float, 15 | min_alpha: float = -1000.0 # resembles -inf 16 | ) -> None: 17 | """Initialize loss function.""" 18 | c_reciprocal = 1.0 / c 19 | if alpha == 2.0: 20 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(0.5).mean() 21 | elif alpha == 0.0: 22 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(0.5).add(1.0).log().mean() 23 | elif alpha < min_alpha: 24 | self.loss_function = lambda x, y: (1.0 - (x - y).mul(c_reciprocal).pow(2).mul(-0.5).exp()).mean() 25 | else: 26 | factor = abs(alpha - 2.0) / alpha 27 | exponent = alpha / 2.0 28 | scale = 1.0 / abs(alpha - 2.0) 29 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(scale).add(1.0).pow(exponent).sub(1.0).mul(factor).mean() 30 | 31 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 32 | """Calculate loss.""" 33 | return self.loss_function(input, target) 34 | 
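The general robust loss above (see the linked paper) collapses to familiar penalties for special values of alpha: alpha = 2 gives a scaled L2, alpha = 1 a Charbonnier/pseudo-Huber, alpha = 0 a Cauchy/Lorentzian, and any alpha below min_alpha the Welsch loss. A minimal sketch (not part of the repository) exercising these branches on random data:

import torch

from Optim.Losses.Robust import RobustLoss

x = torch.randn(1024)
y = torch.randn(1024)
for alpha in (2.0, 1.0, 0.0, -1.0e6):  # -1.0e6 < min_alpha selects the Welsch branch
    print(alpha, RobustLoss(alpha=alpha, c=0.5)(x, y).item())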
-------------------------------------------------------------------------------- /src/Optim/Losses/VGG.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Optim/Losses/VGG.py: Perceptual loss based on VGG features. 5 | See https://arxiv.org/abs/1603.08155 or https://arxiv.org/abs/1609.04802 for details. 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Callable 10 | import torch 11 | from torchvision.models import VGG, vgg19, VGG16_Weights, VGG19_Weights 12 | 13 | 14 | @dataclass(frozen=True) 15 | class VGGLossConfig: 16 | """Configuration for the VGG loss.""" 17 | model_class: 'VGG' = vgg19 18 | used_weights: VGG16_Weights | VGG19_Weights = VGG19_Weights.IMAGENET1K_V1 19 | # used_blocks: tuple[slice] = (slice(0, 4), slice(4, 9), slice(9, 16), slice(16, 23)) # vgg16 20 | used_blocks: tuple[slice] = (slice(0, 4), slice(4, 9), slice(9, 18), slice(18, 27), slice(27, 36)) # vgg19 21 | loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.l1_loss 22 | 23 | 24 | class VGGLoss: 25 | """Perceptual loss based on VGG features.""" 26 | 27 | def __init__(self, config: VGGLossConfig = VGGLossConfig()) -> None: 28 | """Initialize the VGG loss.""" 29 | model = config.model_class(weights=config.used_weights).features.eval() 30 | for parameter in model.parameters(): 31 | parameter.requires_grad = False 32 | self.blocks = torch.nn.ModuleList([model[block] for block in config.used_blocks]) 33 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 34 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 35 | self.loss_function = config.loss_function 36 | 37 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 38 | """Calculate the VGG loss. 
Both input and target are expected to have shape (C, H, W) with C being RGB.""" 39 | input = (input[None] - self.mean) / self.std 40 | target = (target[None] - self.mean) / self.std 41 | loss = torch.tensor(0.0, requires_grad=False) 42 | for block in self.blocks: 43 | input = block(input) 44 | target = block(target) 45 | loss += self.loss_function(input, target) 46 | return loss 47 | -------------------------------------------------------------------------------- /src/Optim/Losses/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Optim/Lossesutils.py: Utilities for loss implementations.""" 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Any, Callable 7 | 8 | import Framework 9 | import torch 10 | import torchmetrics 11 | 12 | 13 | @dataclass 14 | class QualityMetricItem: 15 | """Used to store quality metrics in BaseLoss, only for training evaluation in wandb""" 16 | name: str 17 | metric_func: Callable 18 | 19 | _running_sum: list[torch.Tensor, torch.Tensor] = field(init=False) 20 | _num_iters: list[int, int] = field(init=False) 21 | 22 | def __post_init__(self): 23 | self.reset() 24 | 25 | def reset(self) -> None: 26 | self._running_sum = [torch.tensor(0.0, requires_grad=False), torch.tensor(0.0, requires_grad=False)] 27 | self._num_iters = [0, 0] 28 | 29 | def getAverage(self): 30 | return [(self._running_sum[i].item() / float(self._num_iters[i])) if self._num_iters[i] > 0 else 0.0 for i in range(2)] 31 | 32 | def _call_metric(self, kwargs: Any) -> torch.Tensor: 33 | return self.metric_func(**kwargs) 34 | 35 | def apply(self, train: bool, accumulate: bool, kwargs: Any) -> torch.Tensor: 36 | loss_val = self._call_metric(kwargs) 37 | if accumulate: 38 | idx = 0 if train else 1 39 | self._running_sum[idx] += loss_val.detach() 40 | self._num_iters[idx] += 1 41 | return loss_val 42 | 43 | 44 | @dataclass 45 | class LossMetricItem(QualityMetricItem): 46 | """Used to store individual loss terms in BaseLoss""" 47 | weight: float = 1.0 48 | 49 | def __post_init__(self): 50 | super().__post_init__() 51 | self.initial_weight = max(0.0, self.weight) if self.weight is not None else 0.0 52 | self.weight = self.initial_weight 53 | if isinstance(self.metric_func, torchmetrics.Metric) and not self.metric_func.is_differentiable: 54 | raise Framework.LossError(f'requested loss metric {self.name} (instance of {self.metric_func.__class__.__name__}) is not differentiable') 55 | 56 | def _call_metric(self, kwargs: Any) -> torch.Tensor: 57 | if self.weight > 0.0: 58 | return self.metric_func(**kwargs) * self.weight 59 | return torch.tensor(0.0) 60 | -------------------------------------------------------------------------------- /src/Optim/Samplers/DatasetSamplers.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Samplers/DatasetSamplers.py: Samplers returning a batch of rays from a dataset.""" 4 | import torch 5 | 6 | import Framework 7 | from Cameras.utils import CameraProperties 8 | from Datasets.Base import BaseDataset 9 | from Optim.Samplers.ImageSamplers import ImageSampler 10 | from Optim.Samplers.utils import IncrementalSequentialSampler, RandomSequentialSampler, SequentialSampler 11 | 12 | 13 | class DatasetSampler: 14 | 15 | def __init__(self, 16 | dataset: BaseDataset, 17 | random: bool = True, 18 | img_sampler_cls: type[ImageSampler] | None = None) -> None: 19 | self.mode = dataset.mode 20 | self.id_sampler = 
RandomSequentialSampler(num_elements=len(dataset)) if random else SequentialSampler(num_elements=len(dataset)) 21 | self.img_samplers = [img_sampler_cls(num_elements=(i.width * i.height)) for i in dataset] if img_sampler_cls else None 22 | 23 | def get(self, dataset: BaseDataset, ray_batch_size: int | None = None) -> dict[str, int | CameraProperties | None]: 24 | if dataset.mode != self.mode: 25 | raise Framework.SamplerError(f'DatasetSampler initialized for mode "{self.mode}" got dataset with active mode "{dataset.mode}"') 26 | sample_id = self.id_sampler.get(num_samples=1).item() 27 | camera_properties = dataset[sample_id] 28 | image_sampler = ray_ids = ray_batch = None 29 | if self.img_samplers and ray_batch_size is not None: 30 | image_sampler = self.img_samplers[sample_id] 31 | ray_ids = image_sampler.get(ray_batch_size).to(Framework.config.GLOBAL.DEFAULT_DEVICE) 32 | ray_batch = dataset.camera.setProperties(camera_properties).generateRays()[ray_ids] 33 | return { 34 | 'sample_id': sample_id, 35 | 'camera_properties': camera_properties, 36 | 'image_sampler': image_sampler, 37 | 'ray_ids': ray_ids, 38 | 'ray_batch': ray_batch 39 | } 40 | 41 | 42 | class RayPoolSampler: 43 | def __init__(self, 44 | dataset: BaseDataset, 45 | img_sampler_cls: type[ImageSampler]) -> None: 46 | self.mode = dataset.mode 47 | all_rays = dataset.getAllRays() 48 | self.image_sampler = img_sampler_cls(num_elements=all_rays.shape[0]) 49 | 50 | def get(self, dataset: BaseDataset, ray_batch_size: int) -> dict[str, None | ImageSampler | torch.Tensor]: 51 | if dataset.mode != self.mode: 52 | raise Framework.SamplerError(f'RayPoolSampler initialized for mode "{self.mode}" got dataset with active mode "{dataset.mode}"') 53 | sample_id = camera_properties = None 54 | rays_all = dataset.getAllRays() 55 | ray_ids = self.image_sampler.get(ray_batch_size).to(rays_all.device) 56 | ray_batch = rays_all[ray_ids].to(Framework.config.GLOBAL.DEFAULT_DEVICE) 57 | return { 58 | 'sample_id': sample_id, 59 | 'camera_properties': camera_properties, 60 | 'image_sampler': self.image_sampler, 61 | 'ray_ids': ray_ids, 62 | 'ray_batch': ray_batch 63 | } 64 | 65 | 66 | class IncrementalDatasetSampler(DatasetSampler): 67 | 68 | def __init__(self, 69 | dataset: BaseDataset, 70 | img_sampler_cls: type[ImageSampler] | None = None) -> None: 71 | super().__init__(dataset, False, img_sampler_cls) 72 | self.id_sampler = IncrementalSequentialSampler(num_elements=len(dataset)) 73 | -------------------------------------------------------------------------------- /src/Optim/Samplers/ImageSamplers.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Samplers/ImageSamplers.py: Samplers selceting a subset of rays from a given ray batch.""" 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import torch 8 | 9 | from Optim.Samplers.utils import RandomSequentialSampler, SequentialSampler 10 | 11 | 12 | class ImageSampler(ABC): 13 | """Abstract base class for image samplers.""" 14 | 15 | def __init__(self, num_elements: int) -> None: 16 | super().__init__() 17 | self.num_elements: int = num_elements 18 | 19 | @abstractmethod 20 | def get(self, ray_batch_size: int) -> torch.Tensor: 21 | pass 22 | 23 | def update(self, **_) -> None: 24 | pass 25 | 26 | 27 | class SequentialImageSampler(ImageSampler): 28 | 29 | def __init__(self, num_elements: int) -> None: 30 | super().__init__(num_elements) 31 | self.sampler = SequentialSampler(num_elements=self.num_elements) 32 | 33 | def get(self, 
ray_batch_size: int) -> torch.Tensor: 34 | return self.sampler.get(num_samples=ray_batch_size) 35 | 36 | 37 | class SequentialRandomImageSampler(SequentialImageSampler): 38 | 39 | def __init__(self, num_elements: int) -> None: 40 | super().__init__(num_elements) 41 | self.sampler = RandomSequentialSampler(num_elements=self.num_elements) 42 | 43 | 44 | class RandomImageSampler(ImageSampler): 45 | 46 | def get(self, ray_batch_size: int) -> torch.Tensor: 47 | return torch.randint(low=0, high=self.num_elements, size=(ray_batch_size,)) 48 | 49 | 50 | class MultinomialImageSampler(ImageSampler): 51 | 52 | def __init__(self, num_elements: int) -> None: 53 | super().__init__(num_elements) 54 | self.pdf = torch.ones(size=(self.num_elements,)) 55 | 56 | def get(self, ray_batch_size: int) -> torch.Tensor: 57 | return torch.multinomial(input=self.pdf, num_samples=ray_batch_size) 58 | 59 | @torch.no_grad() 60 | def update(self, ray_ids: torch.Tensor, weights: torch.Tensor, constant_addend: float = False) -> None: 61 | if constant_addend is not None: 62 | self.pdf += constant_addend 63 | self.pdf[ray_ids] = weights 64 | -------------------------------------------------------------------------------- /src/Optim/Samplers/RaySamplers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/src/Optim/Samplers/RaySamplers.py -------------------------------------------------------------------------------- /src/Optim/Samplers/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Samplers/utils.py: Utilities for image and ray sampling routines.""" 4 | 5 | import torch 6 | 7 | import Framework 8 | 9 | 10 | class SequentialSampler: 11 | def __init__(self, num_elements: int) -> None: 12 | self.num_elements: int = num_elements 13 | self.indices: torch.Tensor = torch.arange(self.num_elements) 14 | self.reset() 15 | 16 | def shuffle(self) -> None: 17 | pass 18 | 19 | def reset(self) -> None: 20 | self.current_id: int = 0 21 | self.shuffle() 22 | 23 | def get(self, num_samples: int) -> torch.Tensor: 24 | if num_samples > self.num_elements: 25 | raise Framework.SamplerError(f"cannot draw {num_samples} samples from {self.num_elements} elements") 26 | if self.current_id + num_samples > self.num_elements: 27 | self.reset() 28 | samples = self.indices[self.current_id:self.current_id + num_samples] 29 | self.current_id += num_samples 30 | return samples 31 | 32 | 33 | class RandomSequentialSampler(SequentialSampler): 34 | 35 | def shuffle(self) -> None: 36 | self.indices: torch.Tensor = self.indices[torch.randperm(self.num_elements)] 37 | 38 | 39 | class IncrementalSequentialSampler: 40 | 41 | def __init__(self, num_elements: int) -> None: 42 | self.num_elements: int = num_elements 43 | self.current_size: int = 0 44 | self.indices: torch.Tensor = torch.arange(self.num_elements) 45 | self.reset() 46 | 47 | def reset(self) -> None: 48 | self.current_size = min(self.current_size + 1, self.num_elements) 49 | self.current_indices: torch.Tensor = self.indices[:self.current_size] 50 | self.current_id: int = 0 51 | 52 | def get(self, num_samples: int) -> torch.Tensor: 53 | if num_samples > self.current_size: 54 | raise Framework.SamplerError(f"cannot draw {num_samples} samples from {self.current_size} elements") 55 | if self.current_id + num_samples > self.current_size: 56 | self.reset() 57 | samples = 
self.current_indices[self.current_id:self.current_id + num_samples] 58 | self.current_id += num_samples 59 | return samples 60 | -------------------------------------------------------------------------------- /src/Thirdparty/Apex.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Thirdparty/Apex.py: NVIDIA Apex (https://github.com/NVIDIA/apex). 5 | """ 6 | 7 | import Framework 8 | 9 | __extension_name__ = 'apex' 10 | __install_command__ = [ 11 | 'pip', 'install', 12 | '-v', '--disable-pip-version-check', '--no-cache-dir', '--no-build-isolation', 13 | '--config-settings', '--build-option=--cpp_ext', 14 | '--config-settings', '--build-option=--cuda_ext', 15 | 'git+https://github.com/NVIDIA/apex', 16 | ] 17 | 18 | try: 19 | from apex.optimizers import FusedAdam # noqa 20 | except ImportError: 21 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 22 | -------------------------------------------------------------------------------- /src/Thirdparty/DiffGaussianRasterization.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | from pathlib import Path 4 | 5 | import Framework 6 | 7 | __extension_name__ = 'DiffGaussianRasterization' 8 | REPO_URL = 'https://github.com/graphdeco-inria/diff-gaussian-rasterization' 9 | COMMIT_HASH = '59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d' 10 | __install_command__ = ['pip', 'install', f'git+{REPO_URL}@{COMMIT_HASH}'] 11 | 12 | try: 13 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer # noqa 14 | except ImportError: 15 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 16 | -------------------------------------------------------------------------------- /src/Thirdparty/FusedSSIM.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Thirdparty/FusedSSIM.py: fast cuda ssim implementation from https://github.com/rahul-goel/fused-ssim. 5 | """ 6 | 7 | import Framework 8 | 9 | __extension_name__ = 'FusedSSIM' 10 | __install_command__ = [ 11 | 'pip', 'install', 12 | 'git+https://github.com/rahul-goel/fused-ssim/', 13 | ] 14 | 15 | try: 16 | from fused_ssim import fused_ssim 17 | except ImportError: 18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 19 | -------------------------------------------------------------------------------- /src/Thirdparty/SimpleKNN.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | from pathlib import Path 4 | 5 | import Framework 6 | 7 | __extension_name__ = 'simple-knn' 8 | REPO_URL = 'https://github.com/camenduru/simple-knn' 9 | COMMIT_HASH = '44f764299fa305faf6ec5ebd99939e0508331503' 10 | __install_command__ = ['pip', 'install', f'git+{REPO_URL}@{COMMIT_HASH}'] 11 | 12 | try: 13 | from simple_knn._C import distCUDA2 # noqa 14 | except ImportError: 15 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 16 | -------------------------------------------------------------------------------- /src/Thirdparty/TinyCudaNN.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Thirdparty/TinyCudaNN.py: Fast fused MLP and input encoding from NVLabs (https://github.com/NVlabs/tiny-cuda-nn). 
5 | """ 6 | 7 | import Framework 8 | 9 | __extension_name__ = 'tinycudann' 10 | __install_command__ = [ 11 | 'pip', 'install', 12 | 'git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch', 13 | ] 14 | 15 | try: 16 | from tinycudann import * # noqa 17 | except ImportError: 18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 19 | -------------------------------------------------------------------------------- /src/Thirdparty/TorchScatter.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Thirdparty/TorchScatter.py: Torch Scatter lib. 5 | """ 6 | 7 | import Framework 8 | import torch 9 | 10 | __extension_name__ = 'torch-scatter' 11 | __install_command__ = [ 12 | 'pip', 'install', 13 | 'torch-scatter', 14 | '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html', 15 | ] 16 | 17 | try: 18 | from torch_scatter import segment_csr # noqa 19 | except ImportError: 20 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__) 21 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/BulletTime.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Visual/Trajectories/BulletTime.py: A visualization for dynamic scenes, following a lemniscate trajectory while replaying time. 5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main). 6 | """ 7 | 8 | import torch 9 | 10 | import Framework 11 | from Cameras.Base import BaseCamera 12 | from Cameras.utils import CameraProperties 13 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory 14 | 15 | 16 | class bullet_time(CameraTrajectory): 17 | """A visualization for dynamic scenes, following a lemniscate trajectory while replaying time.""" 18 | 19 | def __init__(self, reference_pose_rel_id: float = 0.5, custom_lookat: torch.Tensor | None = None, 20 | custom_up: torch.Tensor | None = None, num_frames_per_rotation: float = 90, degree: float = 10, num_repeats: int = 2) -> None: 21 | super().__init__() 22 | self.reference_pose_rel_id = reference_pose_rel_id 23 | self.custom_lookat = custom_lookat 24 | self.custom_up = custom_up 25 | self.num_frames_per_rotation = num_frames_per_rotation 26 | self.degree = degree 27 | self.num_repeats = num_repeats 28 | 29 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 30 | """A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame.""" 31 | data: list[CameraProperties] = [] 32 | reference_pose = reference_poses[int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses))] 33 | reference_camera.setProperties(reference_pose) 34 | lookat = self.custom_lookat 35 | if lookat is None: 36 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3] 37 | lookat = pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1]) 38 | up = self.custom_up if self.custom_up is not None else reference_camera.getUpVector()[:3] 39 | lemniscate_trajectory_c2ws = getLemniscateTrajectory( 40 | reference_pose, 41 | lookat=lookat.to(Framework.config.GLOBAL.DEFAULT_DEVICE), 42 | up=up.to(Framework.config.GLOBAL.DEFAULT_DEVICE), 43 | num_frames=self.num_frames_per_rotation, 44 | 
degree=self.degree, 45 | ) 46 | num_frames = self.num_frames_per_rotation * self.num_repeats 47 | for frame_idx in range(num_frames): 48 | data.append(CameraProperties( 49 | width=reference_pose.width, 50 | height=reference_pose.height, 51 | rgb=None, 52 | alpha=None, 53 | c2w=lemniscate_trajectory_c2ws[frame_idx % self.num_frames_per_rotation].clone(), 54 | focal_x=reference_pose.focal_x, 55 | focal_y=reference_pose.focal_y, 56 | distortion_parameters=reference_pose.distortion_parameters, 57 | principal_offset_x=0.0, 58 | principal_offset_y=0.0, 59 | timestamp=frame_idx / (num_frames - 1) 60 | )) 61 | return data 62 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/FancyZoom.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Visual/Trajectories/FancyZoom.py: A camera trajectory following the training poses, interrupted by zooms and spiral movements.""" 4 | 5 | import math 6 | import torch 7 | 8 | import Framework 9 | from Cameras.Base import BaseCamera 10 | from Cameras.utils import CameraProperties 11 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory 12 | 13 | 14 | class fancy_zoom(CameraTrajectory): 15 | """A camera trajectory following the training poses, interrupted by zooms and spiral movements.""" 16 | 17 | def __init__(self, num_breaks: int = 2, num_zoom_frames: int = 90, zoom_factor: float = 0.2, lemni_frames_per_rot: int = 60, lemni_degree: int = 3) -> None: 18 | super().__init__() 19 | self.num_breaks = num_breaks 20 | self.num_zoom_frames = num_zoom_frames 21 | self.zoom_factor = zoom_factor 22 | self.lemni_frames_per_rot = lemni_frames_per_rot 23 | self.lemni_degree = lemni_degree 24 | 25 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 26 | """Generates the camera trajectory using a list of reference poses.""" 27 | data: list[CameraProperties] = [] 28 | num_reference_frames = len(reference_poses) 29 | break_indices = torch.linspace(0, len(reference_poses), self.num_breaks + 2)[1:-1].int().tolist() 30 | for i in range(break_indices[0]): 31 | data.append(reference_poses[i].toSimple()) 32 | for j in range(len(break_indices)): 33 | reference = reference_poses[break_indices[j]] 34 | for i in range(self.num_zoom_frames): 35 | new = reference.toSimple() 36 | new.focal_x = new.focal_x + (new.focal_x * self.zoom_factor * math.sin((i / (self.num_zoom_frames - 1)) * 2 * math.pi)) 37 | new.focal_y = new.focal_y + (new.focal_y * self.zoom_factor * math.sin((i / (self.num_zoom_frames - 1)) * 2 * math.pi)) 38 | data.append(new) 39 | reference_camera.setProperties(reference) 40 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3] 41 | lookat = (pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1])).to(Framework.config.GLOBAL.DEFAULT_DEVICE) 42 | up = reference_camera.getUpVector().to(Framework.config.GLOBAL.DEFAULT_DEVICE) 43 | lemniscate_trajectory_c2ws = getLemniscateTrajectory(reference, lookat=lookat, up=up, num_frames=self.lemni_frames_per_rot, degree=self.lemni_degree) 44 | for i in lemniscate_trajectory_c2ws: 45 | new = reference.toSimple() 46 | new.c2w = i 47 | data.append(new) 48 | if j < len(break_indices) - 1: 49 | for i in range(break_indices[j], break_indices[j + 1]): 50 | data.append(reference_poses[i].toSimple()) 51 | for i in range(break_indices[-1], num_reference_frames): 52 | 
data.append(reference_poses[i].toSimple()) 53 | return data 54 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/NovelView.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Visual/Trajectories/NovelView.py: A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame. 5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main). 6 | """ 7 | 8 | import torch 9 | 10 | import Framework 11 | from Cameras.Base import BaseCamera 12 | from Cameras.utils import CameraProperties 13 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory 14 | 15 | 16 | class novel_view(CameraTrajectory): 17 | """A trajectory for dynamic scenes, replaying time for a single, fixed view.""" 18 | 19 | def __init__(self, reference_pose_rel_id: float = 0.5, custom_lookat: torch.Tensor | None = None, 20 | custom_up: torch.Tensor | None = None, num_frames_per_rotation: float = 90, degree: float = 30) -> None: 21 | super().__init__() 22 | self.reference_pose_rel_id = reference_pose_rel_id 23 | self.custom_lookat = custom_lookat 24 | self.custom_up = custom_up 25 | self.num_frames_per_rotation = num_frames_per_rotation 26 | self.degree = degree 27 | 28 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 29 | """A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame.""" 30 | data: list[CameraProperties] = [] 31 | reference_pose = reference_poses[int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses))] 32 | reference_camera.setProperties(reference_pose) 33 | lookat = self.custom_lookat 34 | if lookat is None: 35 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3] 36 | lookat = pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1]) 37 | up = self.custom_up if self.custom_up is not None else reference_camera.getUpVector()[:3] 38 | lemniscate_trajectory_c2ws = getLemniscateTrajectory( 39 | reference_pose, 40 | lookat=lookat.to(Framework.config.GLOBAL.DEFAULT_DEVICE), 41 | up=up.to(Framework.config.GLOBAL.DEFAULT_DEVICE), 42 | num_frames=self.num_frames_per_rotation, 43 | degree=self.degree, 44 | ) 45 | for c2w in lemniscate_trajectory_c2ws: 46 | data.append(CameraProperties( 47 | width=reference_pose.width, 48 | height=reference_pose.height, 49 | rgb=None, 50 | alpha=None, 51 | c2w=c2w.clone(), 52 | focal_x=reference_pose.focal_x, 53 | focal_y=reference_pose.focal_y, 54 | distortion_parameters=reference_pose.distortion_parameters, 55 | principal_offset_x=0.0, 56 | principal_offset_y=0.0, 57 | timestamp=reference_pose.timestamp, 58 | )) 59 | return data 60 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/SpiralPath.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Visual/Trajectories/SpiralPath.py: A spiral camera trajectory for forward facing scenes. 5 | Used by the original NeRF for LLFF visualizations. 
6 | """ 7 | 8 | import math 9 | import torch 10 | 11 | import Framework 12 | from Cameras.Base import BaseCamera 13 | from Cameras.utils import CameraProperties, createCameraMatrix, normalizeRays 14 | from Datasets.utils import getAveragePose 15 | from Visual.Trajectories.utils import CameraTrajectory 16 | 17 | 18 | class spiral_path(CameraTrajectory): 19 | """A spiral camera trajectory for forward facing scenes.""" 20 | 21 | def __init__(self, num_views: int = 120, num_rotations: int = 2, ) -> None: 22 | super().__init__() 23 | self.num_views: int = num_views 24 | self.num_rotations: int = num_rotations 25 | 26 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 27 | """Generates the camera trajectory using a list of reference poses.""" 28 | mean_focal_x: float = torch.tensor([c.focal_x for c in reference_poses]).mean().item() 29 | mean_focal_y: float = torch.tensor([c.focal_y for c in reference_poses]).mean().item() 30 | c2ws = torch.stack([c.c2w for c in reference_poses], dim=0).to(Framework.config.GLOBAL.DEFAULT_DEVICE) 31 | return createSpiralPath( 32 | view_matrices=c2ws, 33 | near_plane=reference_camera.near_plane, 34 | far_plane=reference_camera.far_plane, 35 | image_shape=(3, reference_poses[0].height, reference_poses[0].width), 36 | focal_x=mean_focal_x, 37 | focal_y=mean_focal_y, 38 | n_views=self.num_views, 39 | n_rots=self.num_rotations 40 | ) 41 | 42 | 43 | def createSpiralPath(view_matrices: torch.Tensor, near_plane: float, far_plane: float, image_shape: tuple[int, int, int], 44 | focal_x: float, focal_y: float, n_views: int, n_rots: int) -> list[CameraProperties]: 45 | """Creates views on spiral path, adapted from the original NeRF implementation.""" 46 | average_pose: torch.Tensor = getAveragePose(view_matrices) 47 | up: torch.Tensor = -normalizeRays(view_matrices[:, :3, 1].sum(0))[0] 48 | close_depth: float = near_plane * 0.9 49 | inf_depth: float = far_plane * 1.0 50 | dt: float = 0.75 51 | focal: float = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth) 52 | rads: torch.Tensor = 0.01 * torch.quantile(torch.abs(view_matrices[:, :3, 3]), q=0.9, dim=0) 53 | view_matrices_spiral: list[torch.Tensor] = [] 54 | rads: torch.Tensor = torch.tensor(list(rads) + [1.]) 55 | for theta in torch.linspace(0.0, 2.0 * math.pi * n_rots, n_views + 1)[:-1]: 56 | c: torch.Tensor = torch.mm( 57 | average_pose[:3, :4], 58 | (torch.tensor([torch.cos(theta), torch.sin(theta), torch.sin(theta * 0.5), 1.]) * rads)[:, None] 59 | ).squeeze() 60 | z: torch.Tensor = -normalizeRays(c - torch.mm(average_pose[:3, :4], torch.tensor([[0], [0], [-focal], [1.]])).squeeze())[0] 61 | view_matrices_spiral.append(createCameraMatrix(z, up, c)) 62 | return [ 63 | CameraProperties( 64 | width=image_shape[2], 65 | height=image_shape[1], 66 | rgb=None, 67 | alpha=None, 68 | c2w=c2w.float(), 69 | focal_x=focal_x, 70 | focal_y=focal_y 71 | ) 72 | for c2w in view_matrices_spiral 73 | ] 74 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/StabilizedTrain.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Visual/Trajectories/StabilizedTrain.py: A camera trajectory following the training poses, stabilizing poses over a window.""" 4 | 5 | import torch 6 | 7 | import Framework 8 | from Cameras.Base import BaseCamera 9 | from Cameras.utils import CameraProperties 10 | from Datasets.utils import getAveragePose 11 | from 
Visual.Trajectories.utils import CameraTrajectory 12 | 13 | 14 | class stabilized_train(CameraTrajectory): 15 | """A camera trajectory following the training poses, stabilizing poses over a window.""" 16 | 17 | def __init__(self, window: int = 5) -> None: 18 | super().__init__() 19 | if window % 2 == 0: 20 | raise Framework.VisualizationError('Window size must be an odd number.') 21 | self.half_window = window // 2 22 | 23 | def _generate(self, _: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 24 | """Generates the camera trajectory using a list of reference poses.""" 25 | data: list[CameraProperties] = [] 26 | poses_all = torch.stack([c.c2w for c in reference_poses], dim=0) 27 | for i, c in enumerate(reference_poses): 28 | c2w = getAveragePose(poses_all[max(0, i - self.half_window):min(len(reference_poses), i + self.half_window)]) 29 | prop = c.toSimple() 30 | prop.c2w = c2w 31 | data.append(prop) 32 | return data 33 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/StabilizedView.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """ 4 | Visual/Trajectories/StabilizedView.py: A trajectory for dynamic scenes, replaying time for a single, fixed view. 5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main). 6 | """ 7 | 8 | from Cameras.Base import BaseCamera 9 | from Cameras.utils import CameraProperties 10 | from Visual.Trajectories.utils import CameraTrajectory 11 | 12 | 13 | class stabilized_view(CameraTrajectory): 14 | """A trajectory for dynamic scenes, replaying time for a single, fixed view.""" 15 | 16 | def __init__(self, reference_pose_rel_id: float = 0.5) -> None: 17 | super().__init__() 18 | self.reference_pose_rel_id = reference_pose_rel_id 19 | 20 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 21 | """Generates the camera trajectory using a list of reference poses.""" 22 | data: list[CameraProperties] = [] 23 | reference_pose = reference_poses[int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses))] 24 | for camera_properties in reference_poses: 25 | data.append(CameraProperties( 26 | width=reference_pose.width, 27 | height=reference_pose.height, 28 | rgb=None, 29 | alpha=None, 30 | c2w=reference_pose.c2w.clone(), 31 | focal_x=reference_pose.focal_x, 32 | focal_y=reference_pose.focal_y, 33 | principal_offset_x=0.0, 34 | principal_offset_y=0.0, 35 | distortion_parameters=reference_pose.distortion_parameters, 36 | timestamp=camera_properties.timestamp 37 | )) 38 | return data 39 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual.Trajectories is a package for adding custom camera trajectories to a dataset, enabling convenient visualization in the gui and inference scripts. 
3 | """ 4 | 5 | import importlib 6 | from pathlib import Path 7 | from .utils import CameraTrajectory as CameraTrajectory 8 | 9 | base_path: Path = Path(__file__).resolve().parents[0] 10 | for module in [str(i.name)[:-3] for i in base_path.iterdir() if i.is_file() and i.name not in ['__init__.py', 'utils.py']]: 11 | importlib.import_module(f'Visual.Trajectories.{module}') 12 | del base_path, module 13 | -------------------------------------------------------------------------------- /src/Visual/Trajectories/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Visual/Trajectories/utils.py: Utilities for visualization tasks.""" 4 | 5 | from abc import ABC, abstractmethod 6 | import math 7 | from typing import Type 8 | 9 | import torch 10 | 11 | import Framework 12 | from Logging import Logger 13 | from Cameras.Base import BaseCamera 14 | from Cameras.utils import CameraProperties, createLookAtMatrix 15 | from Datasets.Base import BaseDataset 16 | 17 | 18 | class CameraTrajectory(ABC): 19 | 20 | _options: list[str] = [] 21 | 22 | def __init__(self) -> None: 23 | super().__init__() 24 | self._trajectory: list[CameraProperties] = [] 25 | self.name: str = self.__class__.__name__ 26 | 27 | @classmethod 28 | def listOptions(cls) -> list[str]: 29 | """Lists all available camera trajectories.""" 30 | if not cls._options: 31 | cls._options = [cls.__name__ for cls in CameraTrajectory.__subclasses__()] 32 | return cls._options 33 | 34 | @classmethod 35 | def get(cls, trajectory_name: str) -> Type['CameraTrajectory']: 36 | """Returns a camera trajectory class by its name.""" 37 | options = cls.listOptions() 38 | if trajectory_name not in options: 39 | raise Framework.VisualizationError(f'Unknown camera trajectory type: {trajectory_name}.\nAvailable options are: {options}') 40 | for trajectory in cls.__subclasses__(): 41 | if trajectory.__name__ == trajectory_name: 42 | return trajectory 43 | 44 | @abstractmethod 45 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]: 46 | """Generates the camera trajectory using a list of reference poses.""" 47 | pass 48 | 49 | def generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> None: 50 | """Generates the camera trajectory using a list of reference poses.""" 51 | Logger.logInfo(f'generating {self.name} trajectory...') 52 | self._trajectory = self._generate(reference_camera, reference_poses) 53 | 54 | def addTo(self, dataset: BaseDataset, reference_set: str | None = 'train') -> BaseDataset: 55 | """Adds the camera trajectory to a dataset.""" 56 | if self.name in dataset.subsets: 57 | Logger.logInfo(f'{self.name} trajectory already exists in dataset.') 58 | return dataset 59 | if not self._trajectory: 60 | if reference_set is None: 61 | reference_poses = [*dataset.data['train'], *dataset.data['val'], *dataset.data['test']] 62 | else: 63 | reference_poses = dataset.data[reference_set] 64 | self.generate(reference_camera=dataset.camera, reference_poses=reference_poses) 65 | dataset.subsets.append(self.name) 66 | dataset.data[self.name] = self._trajectory 67 | return dataset 68 | 69 | 70 | def getLemniscateTrajectory( 71 | reference_camera: CameraProperties, 72 | lookat: torch.Tensor, 73 | up: torch.Tensor, 74 | num_frames: int, 75 | degree: float, 76 | ) -> list[torch.Tensor]: 77 | reference_camera = reference_camera.toDefaultDevice() 78 | camera_position = reference_camera.T 79 | a = 
torch.norm(camera_position - lookat) * math.tan(degree / 360 * math.pi) 80 | # Lemniscate curve in camera space. Starting at the origin. 81 | positions = torch.stack([ 82 | torch.tensor([ 83 | a * math.cos(t) / (1 + math.sin(t) ** 2), 84 | a * math.cos(t) * math.sin(t) / (1 + math.sin(t) ** 2), 85 | 0, 86 | ]) for t in (torch.linspace(0, 2 * math.pi, num_frames) + math.pi / 2) 87 | ], dim=0) 88 | # Transform to world space. 89 | positions = torch.matmul(reference_camera.R.T, positions[..., None])[..., 0] + camera_position 90 | cameras = [createLookAtMatrix(p, lookat, up) for p in positions] 91 | return cameras 92 | -------------------------------------------------------------------------------- /src/Visual/utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | """Visual/utils.py: Utilities for visualization tasks.""" 4 | 5 | from pathlib import Path 6 | import av 7 | from typing import Any 8 | 9 | import torch 10 | from torch import Tensor 11 | 12 | from Logging import Logger 13 | from Visual.ColorMap import ColorMap 14 | 15 | 16 | def pseudoColorDepth(color_map: str, depth: Tensor, near_far: tuple[float, float] | None = None, alpha: Tensor | None = None, interpolate: bool = False) -> Tensor: 17 | """Produces a pseudo-colorized depth image for the given depth tensor.""" 18 | # correct depth based on alpha mask 19 | if alpha is not None: 20 | depth = depth.clone() 21 | depth[alpha > 1e-5] = depth[alpha > 1e-5] / alpha[alpha > 1e-5] 22 | # normalize depth to [0, 1] 23 | if near_far is None: 24 | # calculate near and far planes from depth map 25 | masked_depth = depth[alpha > 0.99] if alpha is not None else depth 26 | near_plane = masked_depth.min().item() if masked_depth.numel() > 0 else 0.0 27 | far_plane = masked_depth.max().item() if masked_depth.numel() > 0 else 1.0 28 | else: 29 | near_plane, far_plane = near_far 30 | depth: Tensor = torch.clamp((depth - near_plane) / (far_plane - near_plane), min=0.0, max=1.0) 31 | # apply color map 32 | depth = ColorMap.apply(depth, color_map, interpolate) 33 | # mask color with alpha 34 | if alpha is not None: 35 | depth *= alpha 36 | return depth 37 | 38 | 39 | class VideoWriter: 40 | """Wrapper class to facilitate creation of videos using PyAV.""" 41 | 42 | def __init__( 43 | self, 44 | video_paths: Path | list[Path], 45 | width: int, 46 | height: int, 47 | fps: int, 48 | bitrate: int, 49 | video_codec: str = 'libx264rgb', 50 | options: dict[str, Any] | None = None 51 | ) -> None: 52 | self.containers = [] 53 | self.streams = [] 54 | padding_right = width % 2 55 | padding_bottom = height % 2 56 | self.pad = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom)) 57 | if not isinstance(video_paths, list): 58 | video_paths = [video_paths] 59 | for video_path in video_paths: 60 | container = av.open(str(video_path), mode='w') 61 | stream = container.add_stream(video_codec, rate=fps) 62 | stream.codec_context.bit_rate = bitrate * 1e3 # kbps 63 | stream.width = width + padding_right 64 | stream.height = height + padding_bottom 65 | stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24' 66 | stream.options = options or {} 67 | self.containers.append(container) 68 | self.streams.append(stream) 69 | 70 | def addFrame(self, frames: Tensor | list[Tensor]) -> None: 71 | """Adds given frames to the internal video streams.""" 72 | if not isinstance(frames, list): 73 | frames = [frames] 74 | if len(self.streams) != len(frames): 75 | Logger.logWarning( 76 | f'number of 
frames does not match with number of video streams ({len(frames)} vs. {len(self.streams)})' 77 | ) 78 | for frame, stream, container in zip(frames, self.streams, self.containers): 79 | frame = self.pad((frame * 255)).byte().permute(1, 2, 0).cpu().numpy() 80 | frame = av.VideoFrame.from_ndarray(frame, format="rgb24") 81 | frame.pict_type = 0 82 | for packet in stream.encode(frame): 83 | container.mux(packet) 84 | 85 | def close(self) -> None: 86 | """Flushes all streams and closes all containers.""" 87 | for stream, container in zip(self.streams, self.containers): 88 | for packet in stream.encode(): 89 | container.mux(packet) 90 | container.close() 91 | --------------------------------------------------------------------------------
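A minimal usage sketch (not part of the repository) for the VideoWriter wrapper above: frames are float RGB tensors in [0, 1] with shape (3, H, W); the output path, resolution, and bitrate (in kbps) below are illustrative. Passing a list of paths instead of a single path writes several synchronized streams (e.g. color and depth) in one pass.

from pathlib import Path

import torch

from Visual.utils import VideoWriter

width, height = 320, 240
writer = VideoWriter(Path('output/test_clip.mp4'), width=width, height=height, fps=30, bitrate=8000)
for i in range(60):
    frame = torch.full((3, height, width), i / 59.0)  # simple fade from black to white
    writer.addFrame(frame)
writer.close()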