├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── configs
│   ├── .gitignore
│   ├── gs_garden.yaml
│   ├── ingp_lego.yaml
│   └── nerf_lego.yaml
├── dataset
│   └── .gitignore
├── output
│   └── .gitignore
├── resources
│   ├── .gitignore
│   ├── icg_logo.bmp
│   ├── nerficg_banner.png
│   ├── teaser_videos
│   │   ├── dnpc.gif
│   │   ├── dnpc.mp4
│   │   ├── inpc.gif
│   │   └── inpc.mp4
│   └── train_demo.gif
├── scripts
│   ├── colmap.py
│   ├── condaEnv.sh
│   ├── cutie.py
│   ├── defaultConfig.py
│   ├── generateTables.py
│   ├── gui.py
│   ├── inference.py
│   ├── install.py
│   ├── monocularDepth.py
│   ├── raft.py
│   ├── sequentialTrain.py
│   ├── train.py
│   ├── utils.py
│   └── vggsfm.py
└── src
    ├── Cameras
    │   ├── Base.py
    │   ├── Equirectangular.py
    │   ├── NDC.py
    │   ├── ODS.py
    │   ├── Perspective.py
    │   ├── PerspectiveStereo.py
    │   └── utils.py
    ├── Datasets
    │   ├── Base.py
    │   ├── Colmap.py
    │   ├── DNeRF.py
    │   ├── LLFF.py
    │   ├── MipNeRF360.py
    │   ├── NeRF.py
    │   ├── VGGSfM.py
    │   ├── iPhone.py
    │   └── utils.py
    ├── Framework.py
    ├── Implementations.py
    ├── Logging.py
    ├── Methods
    │   ├── Base
    │   │   ├── GuiTrainer.py
    │   │   ├── Model.py
    │   │   ├── Renderer.py
    │   │   ├── Trainer.py
    │   │   └── utils.py
    │   ├── GaussianSplatting
    │   │   ├── Loss.py
    │   │   ├── Model.py
    │   │   ├── Renderer.py
    │   │   ├── Trainer.py
    │   │   ├── __init__.py
    │   │   └── utils.py
    │   ├── HierarchicalNeRF
    │   │   ├── Loss.py
    │   │   ├── Model.py
    │   │   ├── Renderer.py
    │   │   ├── Trainer.py
    │   │   ├── __init__.py
    │   │   └── utils.py
    │   ├── InstantNGP
    │   │   ├── CudaExtensions
    │   │   │   └── VolumeRenderingV2
    │   │   │       ├── __init__.py
    │   │   │       ├── csrc
    │   │   │       │   ├── binding.cpp
    │   │   │       │   ├── include
    │   │   │       │   │   ├── helper_math.h
    │   │   │       │   │   └── utils.h
    │   │   │       │   ├── intersection.cu
    │   │   │       │   ├── losses.cu
    │   │   │       │   ├── raymarching.cu
    │   │   │       │   ├── setup.py
    │   │   │       │   └── volumerendering.cu
    │   │   │       ├── custom_functions.py
    │   │   │       └── setup.py
    │   │   ├── Loss.py
    │   │   ├── Model.py
    │   │   ├── Renderer.py
    │   │   ├── Trainer.py
    │   │   ├── __init__.py
    │   │   └── utils.py
    │   └── NeRF
    │       ├── Loss.py
    │       ├── Model.py
    │       ├── Renderer.py
    │       ├── Trainer.py
    │       ├── __init__.py
    │       └── utils.py
    ├── Optim
    │   ├── AdamUtils.py
    │   ├── GradientScaling.py
    │   ├── Losses
    │   │   ├── BackgroundEntropy.py
    │   │   ├── Base.py
    │   │   ├── Charbonnier.py
    │   │   ├── DSSIM.py
    │   │   ├── DepthSmoothness.py
    │   │   ├── Distortion.py
    │   │   ├── FusedDSSIM.py
    │   │   ├── Magnitude.py
    │   │   ├── Robust.py
    │   │   ├── VGG.py
    │   │   └── utils.py
    │   ├── MaskedMetrics.py
    │   └── Samplers
    │       ├── DatasetSamplers.py
    │       ├── ImageSamplers.py
    │       ├── RaySamplers.py
    │       └── utils.py
    ├── Thirdparty
    │   ├── Apex.py
    │   ├── DiffGaussianRasterization.py
    │   ├── FusedSSIM.py
    │   ├── SimpleKNN.py
    │   ├── TinyCudaNN.py
    │   └── TorchScatter.py
    └── Visual
        ├── ColorMap.py
        ├── Trajectories
        │   ├── BulletTime.py
        │   ├── Ellipse.py
        │   ├── FancyZoom.py
        │   ├── NovelView.py
        │   ├── SpiralPath.py
        │   ├── StabilizedTrain.py
        │   ├── StabilizedView.py
        │   ├── __init__.py
        │   └── utils.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 | *.pdf
3 | *.pyc
4 | *.aux
5 | *.yaml
6 | *.mp4
7 | *.txt
8 | *.egg
9 | *.so
10 | *.egg-info
11 | *.ipynb
12 | *.log
13 | !resources/*
14 | .vscode
15 | .DS_Store
16 | .idea/
17 | build/
18 | dist/
19 | backup/
20 | .cache/
21 | .pytest_cache/
22 | imgui.ini
23 | __pycache__/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "src/ICGui"]
2 | path = src/ICGui
3 | url = ../../nerficg-project/icgui
4 | [submodule "src/CudaUtils"]
5 | path = src/CudaUtils
6 | url = ../../nerficg-project/cuda-utils
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Moritz Kappel, Florian Hahlbohm
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
6 |
7 |     [License: MIT](https://opensource.org/licenses/MIT)
8 |
9 | A flexible PyTorch framework for simple and efficient implementation of neural radiance fields and rasterization-based view synthesis methods, including a GUI for interactive rendering.
10 |
11 |
19 |
20 |
21 | ## Getting Started
22 |
23 | - This repository uses submodules; clone it using one of the following options:
24 | ```shell
25 | # HTTPS
26 | git clone https://github.com/nerficg-project/nerficg.git --recursive && cd nerficg
27 | ```
28 | or
29 | ```shell
30 | # SSH
31 | git clone git@github.com:nerficg-project/nerficg.git --recursive && cd nerficg
32 | ```
33 |
34 | - Install the global dependencies listed in `scripts/condaEnv.sh`, or automatically create a new conda environment by executing the script:
35 | ```shell
36 | ./scripts/condaEnv.sh && conda activate nerficg
37 | ```
38 |
39 | - To install all additional dependencies for a specific method, run:
40 | ```shell
41 | ./scripts/install.py -m <method>
42 | ```
43 | or use
44 | ```shell
45 | ./scripts/install.py -e <path/to/extension>
46 | ```
47 | to only install a specific extension.
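For example, to install all extensions required by the `GaussianSplatting` method (one of the directories in `src/Methods`), run:
```shell
./scripts/install.py -m GaussianSplatting
```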
48 |
49 | - [optional] To use our training visualizations with [Weights & Biases](https://wandb.ai/site), run the following command and enter your account identifier:
50 | ```shell
51 | wandb login
52 | ```
53 |
54 |
55 | ## Creating a Configuration File
56 |
57 | To create a configuration file for training, run
58 | ```
59 | ./scripts/defaultConfig.py -m <method> -d <dataset> -o <config_name>
60 | ```
61 | where `<method>` and `<dataset>` match one of the items in the `src/Methods` and `src/Datasets` directories, respectively.
62 | The resulting configuration file `<config_name>.yaml` will be available in the `configs` directory and can be customized as needed.
63 | To create a directory of configuration files for all scenes of a dataset, use the `-a` flag. This requires the full dataset to be available in the `dataset` directory.
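As a concrete example, a configuration like the provided `configs/gs_garden.yaml` (3D Gaussian Splatting on the Mip-NeRF 360 garden scene) can be generated with:
```shell
./scripts/defaultConfig.py -m GaussianSplatting -d MipNeRF360 -o gs_garden
```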
64 |
65 |
66 | ## Training a New Model
67 |
68 | To train a new model from a configuration file, run:
69 | ```
70 | ./scripts/train.py -c configs/<config_name>.yaml
71 | ```
72 | The resulting images and model checkpoints will be saved to the `output` directory.
73 |
74 | To train multiple models from a directory or list of configuration files, use the `scripts/sequentialTrain.py` script with the `-d` or `-c` flag respectively.
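For example, assuming you generated a directory of per-scene configuration files (the directory name below is just a placeholder), all of them can be trained back to back with:
```shell
./scripts/sequentialTrain.py -d configs/mipnerf360/
```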
75 |
76 |
77 | ## Training on custom image sequences
78 |
79 | If you want to train on a custom image sequence, create a new directory with an `images` subdirectory containing all training images.
80 | Then you can prepare the image sequence using the provided [COLMAP](https://colmap.github.io) script, including various preprocessing options like monocular depth estimation, image segmentation and optical flow.
81 | Run
82 | ```
83 | ./scripts/colmap.py -h
84 | ```
85 | to see all available flags and options. Alternatively, you can try to use the `scripts/vggsfm.py` script if COLMAP’s SfM pipeline fails.
86 |
87 | After calibration, the custom dataset can be loaded by setting `GLOBAL.DATASET_TYPE` to `Colmap` or `VGGSfM` and entering the correct directory path under `DATASET.PATH` in the config file (see the sketch below).
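A minimal sketch of the expected data layout and the matching config entries; the scene directory name `my_scene` and the image filenames are placeholders:
```
dataset/my_scene/
└── images/
    ├── 0000.jpg
    ├── 0001.jpg
    └── ...
```
```yaml
GLOBAL:
    DATASET_TYPE: Colmap
DATASET:
    PATH: dataset/my_scene
```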
88 |
89 |
90 | ## Inference and evaluation
91 |
92 | We provide multiple scripts for easy model inference and performance evaluation after training.
93 | Use the `scripts/inference.py` script to render output images for individual subsets (train/test/eval) or custom camera trajectories defined in `src/Visual/Trajectories` using the `-s` option.
94 | Additional rendering performance benchmarking and metric calculation are available via the `-b` and `-m` flags, respectively.
95 |
96 | The `scripts/generateTables.py` script further enables consistent metric calculation over multiple pre-generated output image directories (e.g. to compare multiple methods against GT), and automatically generates LaTeX code for tables containing the resulting values.
97 | Use `-h` to see the available options for all scripts.
98 |
99 |
100 | ## Graphical User Interface
101 |
102 | To inspect a pretrained model in our GUI, make sure the GUI submodule is initialized, then run
103 | ```
104 | ./scripts/gui.py
105 | ```
106 | and select the generated output directory.
107 |
108 | Some methods support live GUI interaction during optimization. To enable live GUI support, activate the `TRAINING.GUI.ACTIVATE` flag in your config file.
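In the provided example configs, this corresponds to the following section (values taken from `configs/gs_garden.yaml`):
```yaml
TRAINING:
    GUI:
        ACTIVATE: true
        RENDER_INTERVAL: 5
```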
109 |
110 |
111 | ## Frequently Asked Questions (FAQ)
112 |
113 | __Q:__ What coordinate system does the framework use internally?
114 |
115 | __A:__ The framework uses a left-handed coordinate system for all internal calculations:
116 | - World space: x-right, y-forward, z-down (left-handed)
117 | - Camera space: x-right, y-down, z-backward (left-handed)
118 |
119 |
126 |
127 |
128 | ## Acknowledgments
129 |
130 | We started working on this project in 2021. Over the years many projects have inspired and helped us to develop this framework.
131 | Apart from any reference you might find in our source code, we would specifically like to thank all authors of the following projects for their great work:
132 | - [NeRF: Neural Radiance Fields](https://github.com/bmild/nerf)
133 | - [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch)
134 | - [Instant Neural Graphics Primitives](https://github.com/NVlabs/instant-ngp)
135 | - [ngp_pl](https://github.com/kwea123/ngp_pl.git)
136 | - [MultiNeRF: A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF](https://github.com/google-research/multinerf)
137 | - [torch_efficient_distloss](https://github.com/sunset1995/torch_efficient_distloss)
138 | - [3D Gaussian Splatting for Real-Time Radiance Field Rendering](https://github.com/graphdeco-inria/gaussian-splatting)
139 | - [CamP Zip-NeRF: A Code Release for CamP and Zip-NeRF](https://github.com/jonbarron/camp_zipnerf/)
140 | - [ADOP: Approximate Differentiable One-Pixel Point Rendering](https://github.com/darglein/ADOP)
141 |
142 |
143 | ## License and Citation
144 |
145 | This framework is licensed under the MIT license (see [LICENSE](LICENSE)).
146 |
147 | If you use it in your research projects, please consider citing it:
148 | ```bibtex
149 | @software{nerficg,
150 | author = {Kappel, Moritz and Hahlbohm, Florian and Scholz, Timon},
151 | license = {MIT},
152 | month = {1},
153 | title = {NeRFICG},
154 | url = {https://github.com/nerficg-project},
155 | version = {1.0},
156 | year = {2025}
157 | }
158 | ```
--------------------------------------------------------------------------------
/configs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !gs_garden.yaml
4 | !nerf_lego.yaml
5 | !ingp_lego.yaml
--------------------------------------------------------------------------------
/configs/gs_garden.yaml:
--------------------------------------------------------------------------------
1 | GLOBAL:
2 | LOG_LEVEL: 2
3 | GPU_INDICES:
4 | - 0
5 | RANDOM_SEED: 1618033989
6 | ANOMALY_DETECTION: false
7 | FILTER_WARNINGS: true
8 | METHOD_TYPE: GaussianSplatting
9 | DATASET_TYPE: MipNeRF360
10 | MODEL:
11 | SH_DEGREE: 3
12 | RENDERER:
13 | USE_FUSED_COVARIANCE_COMPUTATION: true
14 | USE_FUSED_SH_CONVERSION: true
15 | SCALE_MODIFIER: 1.0
16 | DISABLE_SH0: false
17 | DISABLE_SH1: false
18 | DISABLE_SH2: false
19 | DISABLE_SH3: false
20 | USE_BAKED_COVARIANCE: false
21 | TRAINING:
22 | LOAD_CHECKPOINT: null
23 | MODEL_NAME: 3dgs_garden
24 | NUM_ITERATIONS: 30_000
25 | ACTIVATE_TIMING: false
26 | RUN_VALIDATION: false
27 | BACKUP:
28 | FINAL_CHECKPOINT: true
29 | RENDER_TESTSET: true
30 | RENDER_TRAINSET: false
31 | RENDER_VALSET: false
32 | VISUALIZE_ERRORS: false
33 | INTERVAL: -1
34 | TRAINING_STATE: false
35 | WANDB:
36 | ACTIVATE: false
37 | ENTITY: null
38 | PROJECT: nerficg
39 | LOG_IMAGES: true
40 | INDEX_VALIDATION: -1
41 | INDEX_TRAINING: -1
42 | INTERVAL: 1000
43 | SWEEP_MODE:
44 | ACTIVE: false
45 | START_ITERATION: 999
46 | ITERATION_STRIDE: 1000
47 | GUI:
48 | ACTIVATE: true
49 | RENDER_INTERVAL: 5
50 | GUI_STATUS_ENABLED: true
51 | GUI_STATUS_INTERVAL: 20
52 | SKIP_GUI_SETUP: false
53 | FPS_ROLLING_AVERAGE_SIZE: 100
54 | LEARNING_RATE_POSITION_INIT: 0.00016
55 | LEARNING_RATE_POSITION_FINAL: 1.6e-06
56 | LEARNING_RATE_POSITION_DELAY_MULT: 0.01
57 | LEARNING_RATE_POSITION_MAX_STEPS: 30_000
58 | LEARNING_RATE_FEATURE: 0.0025
59 | LEARNING_RATE_OPACITY: 0.05
60 | LEARNING_RATE_SCALING: 0.005
61 | LEARNING_RATE_ROTATION: 0.001
62 | PERCENT_DENSE: 0.01
63 | OPACITY_RESET_INTERVAL: 3_000
64 | DENSIFY_START_ITERATION: 500
65 | DENSIFY_END_ITERATION: 15_000
66 | DENSIFICATION_INTERVAL: 100
67 | DENSIFY_GRAD_THRESHOLD: 0.0002
68 | LOSS:
69 | LAMBDA_L1: 0.8
70 | LAMBDA_DSSIM: 0.2
71 | DATASET:
72 | PATH: dataset/mipnerf360/garden
73 | IMAGE_SCALE_FACTOR: 0.25
74 | NORMALIZE_CUBE: null
75 | NORMALIZE_RECENTER: false
76 | PRECOMPUTE_RAYS: false
77 | TO_DEVICE: true
78 | BACKGROUND_COLOR:
79 | - 0.0
80 | - 0.0
81 | - 0.0
82 | NEAR_PLANE: 0.2
83 | FAR_PLANE: 100.0
84 | TEST_STEP: 8
85 | APPLY_PCA: true
86 | APPLY_PCA_RESCALE: false
87 | USE_PRECOMPUTED_DOWNSCALING: true
88 |
--------------------------------------------------------------------------------
/configs/ingp_lego.yaml:
--------------------------------------------------------------------------------
1 | GLOBAL:
2 | LOG_LEVEL: 2
3 | GPU_INDICES:
4 | - 0
5 | RANDOM_SEED: 1618033989
6 | ANOMALY_DETECTION: false
7 | FILTER_WARNINGS: true
8 | METHOD_TYPE: InstantNGP
9 | DATASET_TYPE: NeRF
10 | MODEL:
11 | SCALE: 0.5
12 | RESOLUTION: 128
13 | CENTER:
14 | - 0.0
15 | - 0.0
16 | - 0.0
17 | HASHMAP_NUM_LEVELS: 16
18 | HASHMAP_NUM_FEATURES_PER_LEVEL: 2
19 | HASHMAP_LOG2_SIZE: 19
20 | HASHMAP_BASE_RESOLUTION: 16
21 | HASHMAP_TARGET_RESOLUTION: 2048
22 | NUM_DENSITY_OUTPUT_FEATURES: 16
23 | NUM_DENSITY_NEURONS: 64
24 | NUM_DENSITY_LAYERS: 1
25 | DIR_SH_ENCODING_DEGREE: 4
26 | NUM_COLOR_NEURONS: 64
27 | NUM_COLOR_LAYERS: 2
28 | RENDERER:
29 | MAX_SAMPLES: 1024
30 | EXPONENTIAL_STEPS: false
31 | DENSITY_THRESHOLD: 0.01
32 | TRAINING:
33 | LOAD_CHECKPOINT: null
34 | MODEL_NAME: ingp_lego
35 | NUM_ITERATIONS: 50_000
36 | ACTIVATE_TIMING: false
37 | RUN_VALIDATION: false
38 | BACKUP:
39 | FINAL_CHECKPOINT: true
40 | RENDER_TESTSET: true
41 | RENDER_TRAINSET: false
42 | RENDER_VALSET: false
43 | VISUALIZE_ERRORS: false
44 | INTERVAL: -1
45 | TRAINING_STATE: false
46 | WANDB:
47 | ACTIVATE: false
48 | ENTITY: null
49 | PROJECT: nerficg
50 | LOG_IMAGES: true
51 | INDEX_VALIDATION: -1
52 | INDEX_TRAINING: -1
53 | INTERVAL: 1000
54 | SWEEP_MODE:
55 | ACTIVE: false
56 | START_ITERATION: 999
57 | ITERATION_STRIDE: 1000
58 | RENDER_OCCUPANCY_GRIDS: false
59 | GUI:
60 | ACTIVATE: true
61 | RENDER_INTERVAL: 30
62 | GUI_STATUS_ENABLED: true
63 | GUI_STATUS_INTERVAL: 20
64 | SKIP_GUI_SETUP: false
65 | FPS_ROLLING_AVERAGE_SIZE: 100
66 | TARGET_BATCH_SIZE: 262_144
67 | WARMUP_STEPS: 256
68 | DENSITY_GRID_UPDATE_INTERVAL: 16
69 | LEARNING_RATE: 0.01
70 | LEARNING_RATE_DECAY_START: 20_000
71 | LEARNING_RATE_DECAY_INTERVAL: 10_000
72 | LEARNING_RATE_DECAY_BASE: 0.33
73 | ADAM_EPS: 1.0e-15
74 | USE_APEX: false
75 | DATASET:
76 | PATH: dataset/nerf_synthetic/lego
77 | IMAGE_SCALE_FACTOR: null
78 | NORMALIZE_CUBE: 2.66666666667 # NeRF synthetic cameras are inside [-4, 4]^3, geometry is only inside [-1.5, 1.5]^3
79 | NORMALIZE_RECENTER: false
80 | PRECOMPUTE_RAYS: true
81 | TO_DEVICE: true
82 | BACKGROUND_COLOR:
83 | - 0.0
84 | - 0.0
85 | - 0.0
86 | LOAD_TESTSET_DEPTHS: false
87 |
--------------------------------------------------------------------------------
/configs/nerf_lego.yaml:
--------------------------------------------------------------------------------
1 | GLOBAL:
2 | LOG_LEVEL: 2
3 | GPU_INDICES:
4 | - 0
5 | RANDOM_SEED: 1618033989
6 | ANOMALY_DETECTION: false
7 | FILTER_WARNINGS: true
8 | METHOD_TYPE: NeRF
9 | DATASET_TYPE: NeRF
10 | MODEL:
11 | NUM_LAYERS: 8
12 | NUM_COLOR_LAYERS: 1
13 | NUM_FEATURES: 256
14 | ENCODING_LENGTH_POSITIONS: 10
15 | ENCODING_LENGTH_DIRECTIONS: 4
16 | ENCODING_APPEND_INPUT: true
17 | INPUT_SKIPS:
18 | - 5
19 | ACTIVATION_FUNCTION: relu
20 | RENDERER:
21 | RAY_BATCH_SIZE: 8_192
22 | NUM_SAMPLES: 256
23 | TRAINING:
24 | LOAD_CHECKPOINT: null
25 | MODEL_NAME: nerf_lego
26 | NUM_ITERATIONS: 500_000
27 | ACTIVATE_TIMING: false
28 | RUN_VALIDATION: false
29 | BACKUP:
30 | FINAL_CHECKPOINT: true
31 | RENDER_TESTSET: true
32 | RENDER_TRAINSET: false
33 | RENDER_VALSET: false
34 | VISUALIZE_ERRORS: false
35 | INTERVAL: -1
36 | TRAINING_STATE: false
37 | WANDB:
38 | ACTIVATE: false
39 | ENTITY: null
40 | PROJECT: nerficg
41 | LOG_IMAGES: true
42 | INDEX_VALIDATION: -1
43 | INDEX_TRAINING: -1
44 | INTERVAL: 10_000
45 | SWEEP_MODE:
46 | ACTIVE: false
47 | START_ITERATION: 999
48 | ITERATION_STRIDE: 1000
49 | BATCH_SIZE: 1024
50 | SAMPLE_SINGLE_IMAGE: true
51 | DENSITY_RANDOM_NOISE_STD: 0.0
52 | ADAM_BETA_1: 0.9
53 | ADAM_BETA_2: 0.999
54 | LEARNINGRATE: 0.0005
55 | LEARNINGRATE_DECAY_RATE: 0.1
56 | LEARNINGRATE_DECAY_STEPS: 500_000
57 | LAMBDA_COLOR_LOSS: 1.0
58 | LAMBDA_ALPHA_LOSS: 0.0
59 | DATASET:
60 | PATH: dataset/nerf_synthetic/lego
61 | IMAGE_SCALE_FACTOR: null
62 | NORMALIZE_CUBE: null
63 | NORMALIZE_RECENTER: true
64 | PRECOMPUTE_RAYS: true
65 | TO_DEVICE: true
66 | BACKGROUND_COLOR:
67 | - 1.0
68 | - 1.0
69 | - 1.0
70 | LOAD_TESTSET_DEPTHS: false
71 |
--------------------------------------------------------------------------------
/dataset/.gitignore:
--------------------------------------------------------------------------------
1 | # ln -s /afs/cg.cs.tu-bs.de/cgdata/nerfdata/* dataset
2 | *
3 | !.gitignore
--------------------------------------------------------------------------------
/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/resources/.gitignore:
--------------------------------------------------------------------------------
1 | !*
--------------------------------------------------------------------------------
/resources/icg_logo.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/icg_logo.bmp
--------------------------------------------------------------------------------
/resources/nerficg_banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/nerficg_banner.png
--------------------------------------------------------------------------------
/resources/teaser_videos/dnpc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/dnpc.gif
--------------------------------------------------------------------------------
/resources/teaser_videos/dnpc.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/dnpc.mp4
--------------------------------------------------------------------------------
/resources/teaser_videos/inpc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/inpc.gif
--------------------------------------------------------------------------------
/resources/teaser_videos/inpc.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/teaser_videos/inpc.mp4
--------------------------------------------------------------------------------
/resources/train_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/resources/train_demo.gif
--------------------------------------------------------------------------------
/scripts/condaEnv.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | ENV_NAME="nerficg"
4 | PYTHONVERSION="3.11"
5 | CUDAVERSION="cu118"
6 | HEADLESS=false
7 |
8 | # parse args
9 | while [[ $# -gt 0 ]]; do
10 | case $1 in
11 | -n|--name)
12 | ENV_NAME="$2"
13 | shift 2
14 | ;;
15 | -p|--python)
16 | PYTHONVERSION="$2"
17 | shift 2
18 | ;;
19 | -c|--cuda)
20 | CUDAVERSION="$2"
21 | shift 2
22 | ;;
23 | -h|--headless)
24 | HEADLESS=true
25 | shift
26 | ;;
27 | *)
28 | echo "Unknown argument: $1"
29 | echo "Usage: $0 [-n|--name ] [-p|--python ] [-c|--cuda ] [-h|--headless]"
30 | exit 1
31 | ;;
32 | esac
33 | done
34 |
35 | echo "Creating conda environment '$ENV_NAME' with Python $PYTHONVERSION and CUDA $CUDAVERSION"
36 | if [ "$HEADLESS" = true ]; then
37 | echo "Installing in headless mode (no GUI dependencies)"
38 | fi
39 |
40 | # create new conda environment
41 | conda create -y --name $ENV_NAME python=$PYTHONVERSION
42 | # install base dependencies
43 | conda install -y -n $ENV_NAME packaging
44 | conda run -n $ENV_NAME pip install torch==2.5.1 torchvision --index-url https://download.pytorch.org/whl/${CUDAVERSION}
45 | conda run -n $ENV_NAME pip install numpy tqdm natsort GitPython av ffmpeg-python pyyaml munch tabulate wandb opencv-python kornia torchmetrics lpips einops setuptools plyfile matplotlib timm plotly pillow jax
46 | conda install -y -n $ENV_NAME -c conda-forge colmap
47 | # install gui dependencies
48 | if [ "$HEADLESS" = false ]; then
49 | conda run -n $ENV_NAME pip install imgui[sdl2] cuda-python platformdirs
50 | fi
51 |
--------------------------------------------------------------------------------
/scripts/defaultConfig.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """defaultConfig.py: Creates a new config file with default values for a given method and dataset."""
5 |
6 | import os
7 | from argparse import ArgumentParser
8 | import warnings
9 | import yaml
10 | from pathlib import Path
11 | import utils
12 |
13 | with utils.discoverSourcePath():
14 | import Framework
15 | from Logging import Logger
16 | from Implementations import Methods as MI
17 | from Implementations import Datasets as DI
18 |
19 |
20 | def main(*, method_name: str, dataset_name: str, all_sequences: bool, output_filename: str) -> None:
21 | Logger.setMode(Logger.MODE_VERBOSE)
22 | # create config with global defaults
23 | Framework.config = Framework.ConfigWrapper(GLOBAL=Framework.getDefaultGlobalConfig())
24 | Framework.config.GLOBAL.METHOD_TYPE = method_name
25 | Framework.config.GLOBAL.DATASET_TYPE = dataset_name
26 | # add renderer, model and training parameters
27 | method = MI.importMethod(method_name)
28 | Framework.config.MODEL = method.MODEL.getDefaultParameters()
29 | Framework.config.RENDERER = method.RENDERER.getDefaultParameters()
30 | Framework.config.TRAINING = method.TRAINING_INSTANCE.getDefaultParameters()
31 | # add dataset parameters
32 | dataset_class = DI.getDatasetClass(dataset_name)
33 | Framework.config.DATASET = dataset_class.getDefaultParameters()
34 | # dump config into file
35 | output_path = Path(__file__).resolve().parents[1] / 'configs'
36 | dataset_path = None
37 | if all_sequences:
38 | output_path = output_path / output_filename
39 | dataset_path = Path(Framework.config.DATASET.PATH).parents[0]
40 | os.makedirs(str(output_path), exist_ok=True)
41 | if not dataset_path.is_dir():
42 | Logger.logError(f'failed to gather sequences from "{dataset_path}": directory not found')
43 | return
44 | config_file_names = [i.name for i in dataset_path.iterdir() if i.is_dir()]
45 | else:
46 | config_file_names = [output_filename]
47 | for config_file_name in config_file_names:
48 | config_file_path = output_path / f'{config_file_name}.yaml'
49 | try:
50 | Framework.config.TRAINING.MODEL_NAME = config_file_name
51 | if dataset_path is not None:
52 | Framework.config.DATASET.PATH = str(dataset_path / config_file_name)
53 | with open(config_file_path, 'w') as f:
54 | yaml.dump(Framework.ConfigParameterList.toDict(Framework.config), f,
55 | default_flow_style=False, indent=4, canonical=False, sort_keys=False)
56 | Logger.logInfo(f'configuration file successfully created: {config_file_path}')
57 | except IOError as e:
58 | Logger.logError(f'failed to create configuration file: "{e}"')
59 |
60 |
61 | if __name__ == '__main__':
62 | warnings.filterwarnings('ignore')
63 | # parse command line args
64 | parser: ArgumentParser = ArgumentParser(prog='defaultConfig')
65 | parser.add_argument(
66 | '-m', '--method', action='store', dest='method_name',
67 | metavar='method directory name', required=True,
68 | help='Name of the method you want to train. Name should match the directory in src/Methods.'
69 | )
70 | parser.add_argument(
71 | '-d', '--dataset', action='store', dest='dataset_name',
72 | metavar='dataset name', required=True,
73 | help='Name of the dataset you want to train on. Name should match the python file in src/Datasets.'
74 | )
75 | parser.add_argument(
76 | '-a', '--all', action='store_true', dest='all_sequences',
77 | help='If set, creates a directory containing a config file for each sequence in the dataset.'
78 | )
79 | parser.add_argument(
80 | '-o', '--output', action='store', dest='output_filename',
81 | metavar='output config filename', required=True,
82 | help='Name of the generated config file.'
83 | )
84 | args = parser.parse_args()
85 | if args.output_filename.endswith('.yaml'):
86 | args.output_filename = args.output_filename[:-5]
87 |
88 | # run main
89 | main(
90 | method_name=args.method_name,
91 | dataset_name=args.dataset_name,
92 | all_sequences=args.all_sequences,
93 | output_filename=args.output_filename
94 | )
95 |
--------------------------------------------------------------------------------
/scripts/gui.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """gui.py: opens graphical user interface."""
5 |
6 | import sys
7 |
8 | import torch.multiprocessing as mp
9 |
10 | import utils
11 | with utils.discoverSourcePath():
12 | from Logging import Logger
13 |
14 | if __name__ == '__main__':
15 | Logger.setMode(Logger.MODE_VERBOSE)
16 | with utils.discoverSourcePath():
17 | try:
18 | from ICGui.launchViewer import main
19 | except ImportError as e:
20 | Logger.logError(f'Failed to open GUI: {e}\n'
21 | f'Make sure the icgui submodule is initialized and '
22 | f'updated before running this script. ')
23 | sys.exit(0)
24 | mp.set_start_method('spawn')
25 | main()
26 |
--------------------------------------------------------------------------------
/scripts/install.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """install.py: installs specific extensions or all extensions required by a given method."""
5 |
6 | import os
7 | import subprocess
8 | import sys
9 | from argparse import ArgumentParser
10 | from pathlib import Path
11 | import importlib
12 | import warnings
13 | from types import ModuleType
14 |
15 | import utils
16 |
17 | with utils.discoverSourcePath():
18 | import Framework
19 | from Implementations import Methods as MI
20 | from Logging import Logger
21 |
22 |
23 | def installExtension(install_name: str, install_command: list[str]) -> bool:
24 | """Installs a single extension."""
25 | Logger.logInfo(f'Installing extension {install_name}...')
26 | result = subprocess.run(install_command, check=False)
27 | if result.returncode != 0:
28 | Logger.logError(f'Failed to install extension "{install_name}" with command: "{install_command if isinstance(install_command, str) else " ".join(install_command)}"')
29 | return result.returncode == 0
30 |
31 |
32 | def importExtension(extension_path: str) -> ModuleType:
33 | """Imports an extension module."""
34 | extension_spec = Path(extension_path).resolve()
35 | if extension_spec.is_dir():
36 | extension_spec = extension_spec / '__init__.py'
37 | extension_spec = importlib.util.spec_from_file_location(str(extension_spec.stem), str(extension_spec))
38 | extension_module = importlib.util.module_from_spec(extension_spec)
39 | extension_spec.loader.exec_module(extension_module)
40 | return extension_module
41 |
42 |
43 | def main(extension_path: str, method_name: str) -> None:
44 | """Installs extensions required by a given method or a specific extension."""
45 | Framework.setup()
46 | essential_modules = set(sys.modules.keys())
47 | if extension_path is not None:
48 | try:
49 | extension_module = importExtension(extension_path)
50 | install_command = extension_module.__install_command__
51 | install_name = extension_module.__extension_name__
52 | except Framework.ExtensionError as e:
53 | install_name = e.__extension_name__
54 | install_command = e.__install_command__
55 | except FileNotFoundError:
56 | Logger.logError(f'Invalid extension path "{extension_path}": Module not found.')
57 | return
58 | except AttributeError as e:
59 | Logger.logError(f'Invalid extension module "{extension_path}": {e}')
60 | return
61 | if not installExtension(install_name, install_command):
62 | return
63 | if method_name is not None:
64 | if method_name not in MI.options:
65 | Logger.logError(f'Invalid method name "{method_name}".\nAvailable methods are: {MI.options}')
66 | return
67 | Logger.logInfo(f'Installing extensions for method "{method_name}"...')
68 | last_installed = None
69 | with utils.discoverSourcePath():
70 | while True:
71 | try:
72 | all_modules = set(sys.modules.keys())
73 | for module in all_modules.symmetric_difference(essential_modules):
74 | del sys.modules[module]
75 | MI._import(method_name)
76 | break
77 | except Framework.ExtensionError as e:
78 | if last_installed == e.__extension_name__:
79 | Logger.logError(f'Failed to install extension "{e.__extension_name__}" with command: "{e.__install_command__ if isinstance(e.__install_command__, str) else " ".join(e.__install_command__)}"')
80 | return
81 | last_installed = e.__extension_name__
82 | if not installExtension(e.__extension_name__, e.__install_command__):
83 | return
84 | except Exception as e:
85 | Logger.logError(f'Unexpected error during method import: {e}')
86 | return
87 | Logger.logInfo('done')
88 |
89 |
90 | if __name__ == '__main__':
91 | # parse arguments
92 | parser: ArgumentParser = ArgumentParser(prog='Install')
93 | parser.add_argument(
94 | '-m', '--method', action='store', dest='method_name', default=None,
95 | metavar='method_name', required=False,
96 | help='Name of the method to install extensions for.'
97 | )
98 | parser.add_argument(
99 | '-e', '--extension', action='store', dest='extension_path', default=None,
100 | metavar='extension_path', required=False,
101 | help='Path to extension to be installed.'
102 | )
103 | args = parser.parse_args()
104 | # run
105 | Logger.setMode(Logger.MODE_VERBOSE)
106 | warnings.filterwarnings('ignore')
107 | os.environ['MKL_THREADING_LAYER'] = 'GNU'
108 | main(args.extension_path, args.method_name)
109 |
--------------------------------------------------------------------------------
/scripts/raft.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """raft.py: Predicts optical flow for the given input image sequence using RAFT."""
5 |
6 | import os
7 | from argparse import ArgumentParser
8 | from pathlib import Path
9 | import torch
10 | from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
11 |
12 | import utils
13 |
14 | with utils.discoverSourcePath():
15 | import Framework
16 | from Logging import Logger
17 | from Datasets.utils import list_sorted_files, list_sorted_directories, \
18 | loadImagesParallel, saveOpticalFlowFile, flowToImage, saveImage
19 |
20 |
21 | def predictAndSave(*, filename: str, output_dir: Path, model: torch.nn.Module, color: bool,
22 | inputs1: torch.Tensor, inputs2: torch.Tensor, width: int, height: int) -> None:
23 | flows = model(inputs1, inputs2)[-1][:, :, :height, :width]
24 | saveOpticalFlowFile(output_dir / f'{filename}.flo', flows[0])
25 | if color:
26 | saveImage(output_dir / f'{filename}.png', flowToImage(flows[0]))
27 |
28 |
29 | @torch.no_grad()
30 | def main(*, base_path: Path, output_path: Path | None, recursive: bool, backward: bool, color: bool,
31 | model: torch.nn.Module = None) -> bool:
32 | device = Framework.config.GLOBAL.DEFAULT_DEVICE
33 | # load model
34 | if model is None:
35 | Logger.logInfo('Loading RAFT model...')
36 | model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device).eval()
37 | # run on subdirectories
38 | if recursive:
39 | subdirs = [base_path / i for i in list_sorted_directories(base_path)]
40 | for subdir in subdirs:
41 | main(
42 | base_path=subdir,
43 | output_path=output_path if output_path is None else output_path / subdir.name,
44 | recursive=recursive,
45 | backward=backward,
46 | color=color,
47 | model=model
48 | )
49 | # load images
50 | filenames = [i for i in list_sorted_files(base_path) if Path(i).suffix.lower() in ['.png', '.jpg', '.jpeg']]
51 | file_paths = [str(base_path / i) for i in filenames]
52 | if filenames:
53 | Logger.logInfo(f'Running RAFT on sequence directory: "{base_path}"...')
54 | Logger.logInfo(f'Found {len(filenames)} images')
55 | rgbs, _ = loadImagesParallel(file_paths, None, 4, 'loading image sequence')
56 | rgbs = (torch.stack(rgbs, dim=0).to(device) * 2.0) - 1.0
57 | # create output directory
58 | if output_path is None:
59 | output_dir = base_path / 'flow'
60 | else:
61 | output_dir = output_path
62 | os.makedirs(str(output_dir), exist_ok=True)
63 | # pad inputs to multiple of 8
64 | *_, h, w = rgbs.shape
65 | delta_h = (8 - (h % 8)) % 8
66 | delta_w = (8 - (w % 8)) % 8
67 | if delta_h != 0 or delta_w != 0:
68 | rgbs = torch.nn.functional.pad(rgbs, (0, delta_w, 0, delta_h, 0, 0, 0, 0), 'constant', 0)
69 | # predict and save flows
70 | for i in Logger.logProgressBar(range(len(rgbs) - 1), desc='image', leave=False):
71 | inputs1 = rgbs[i:i+1]
72 | inputs2 = rgbs[i+1:i+2]
73 | predictAndSave(filename=f'{filenames[i].split(".")[0]}_forward', output_dir=output_dir, model=model,
74 | color=color, inputs1=inputs1, inputs2=inputs2, width=w, height=h)
75 | if backward:
76 | predictAndSave(filename=f'{filenames[i+1].split(".")[0]}_backward', output_dir=output_dir, model=model,
77 | color=color, inputs1=inputs2, inputs2=inputs1, width=w, height=h)
78 | Logger.logInfo('done.')
79 | return True
80 |
81 |
82 | if __name__ == '__main__':
83 | # parse command line args
84 | parser: ArgumentParser = ArgumentParser(
85 | prog='raft.py',
86 | description='Predicts optical flow for the given input image sequence using RAFT.'
87 | )
88 | parser.add_argument(
89 | '-i', '--input', action='store', dest='sequence_path',
90 | required=True, help='Path to the base directory containing the images.'
91 | )
92 | parser.add_argument(
93 | '-r', '--recursive', action='store_true', dest='recursive',
94 | help='Scan for subdirectories and estimate flow.'
95 | )
96 | parser.add_argument(
97 | '-b', '--backward', action='store_true', dest='backward',
98 | help='Generate flow predictions in backward direction.'
99 | )
100 | parser.add_argument(
101 | '-o', '--output_path', action='store', dest='output_path',
102 | required=False, default=None,
103 | help='If set, the flow predictions will be stored in the given directory.'
104 | )
105 | parser.add_argument(
106 | '-v', '--visualize', action='store_true', dest='color',
107 | help='Generate and store color visualizations of estimated flow.'
108 | )
109 | args = parser.parse_args()
110 | # init Framework with defaults
111 | Framework.setup()
112 | # run main
113 | Logger.setMode(Logger.MODE_VERBOSE)
114 | main(
115 | base_path=Path(args.sequence_path),
116 | output_path=Path(args.output_path) if args.output_path is not None else None,
117 | recursive=args.recursive,
118 | backward=args.backward,
119 | color=args.color
120 | )
121 |
--------------------------------------------------------------------------------
/scripts/sequentialTrain.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """sequentialTrain.py: Sequentially runs model trainings for a list or directory of config files."""
5 |
6 | import os
7 | import shutil
8 | from argparse import ArgumentParser
9 | from pathlib import Path
10 | import datetime
11 | from statistics import mean
12 | from tabulate import tabulate
13 | from multiprocessing import Process, Queue
14 |
15 | import utils
16 | with utils.discoverSourcePath():
17 | import Framework
18 | import train
19 | from Logging import Logger
20 | from Datasets.utils import list_sorted_files
21 |
22 |
23 | def gatherConfigs() -> tuple[list[str], Path]:
24 | Logger.log('collecting config files')
25 | # parse arguments to retrieve config file locations
26 | parser: ArgumentParser = ArgumentParser(prog='SequentialTrain')
27 | parser.add_argument(
28 | '-c', '--configs', action='store', dest='config_paths', default=[],
29 | metavar='paths/to/config_files', required=False, nargs='+',
30 | help='Multiple whitespace separated training config files.'
31 | )
32 | parser.add_argument(
33 | '-d', '--dir', action='store', dest='config_dir', default=None,
34 | metavar='path/to/configdir/', required=False,
35 | help='A directory containing training configuration files.'
36 | )
37 | args, _ = parser.parse_known_args()
38 | # add all configs from -c flag
39 | config_paths: list[str] = args.config_paths
40 | # set output common directory
41 | output_directory = f'sequential_train_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'
42 | # add all configs from config dir
43 | if args.config_dir is not None:
44 | config_dir_path = Path(args.config_dir)
45 | if not config_dir_path.is_dir():
46 | raise Framework.TrainingError(f'config dir is not a valid directory: {config_dir_path}')
47 | directory_configs = [str(config_dir_path / i) for i in list_sorted_files(config_dir_path) if '.yaml' in i]
48 | config_paths = config_paths + directory_configs
49 | output_directory = f'{config_dir_path.name}_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'
50 | # output directory and all configs for execution
51 | return output_directory, config_paths
52 |
53 |
54 | def writeOutputMetrics(metric_files: dict[str, Path], output_directory: Path) -> None:
55 | if metric_files:
56 | output_file = output_directory / 'summary.txt'
57 | Logger.log(f'Gathering final quality metrics for {len(metric_files)} runs in: {output_file}')
58 | parsed_values = []
59 | for metric_file in metric_files.values():
60 | with open(metric_file) as f:
61 | for line in f:
62 | pass
63 | parsed_values.append({i[0]: float(i[1]) for i in [j.split(':') for j in line.split(' ')]})
64 | headers = ['Metric'] + list(metric_files.keys()) + ['Mean']
65 | tab = [[metric_name] + [(run[metric_name]) for run in parsed_values] for metric_name in parsed_values[0].keys()]
66 | for row in tab:
67 | row.append(mean(row[1:]))
68 | with open(output_file, 'w') as f:
69 | f.write(tabulate(tabular_data=tab, headers=headers, floatfmt=".3f"))
70 |
71 |
72 | def training_process_helper(output, config):
73 | training_instance = train.main(config_path=config)
74 | output.put((training_instance.output_directory, training_instance.model.model_name, training_instance.NUM_ITERATIONS))
75 |
76 |
77 | def main():
78 | Logger.setMode(Logger.MODE_VERBOSE)
79 | # get list of training configs
80 | output_directory, config_paths = gatherConfigs()
81 | if not config_paths:
82 | raise Framework.TrainingError('no valid config file found')
83 | # create output directory
84 | output_directory = Path(__file__).resolve().parents[1] / 'output' / output_directory
85 | os.makedirs(str(output_directory), exist_ok=False)
86 | # run trainings
87 | Logger.log(f'running sequential training for {len(config_paths)} configurations')
88 | Logger.log(f'outputs will be available at: {output_directory}')
89 | success, failed = 0, 0
90 | metric_files = {}
91 | for config in config_paths:
92 | try:
93 | # run training for single config
94 | output = Queue()
95 | training_process = Process(target=training_process_helper, args=(output, config))
96 | training_process.start()
97 | training_process.join()
98 | output_directory_run, model_name, num_iterations = output.get(block=False)
99 | shutil.move(output_directory_run, output_directory)
100 | metric_file: Path = output_directory / Path(output_directory_run).name / \
101 | f'test_{num_iterations}' / 'metrics_8bit.txt'
102 | # detect output metric file
103 | if metric_file.is_file():
104 | metric_files[model_name] = metric_file
105 | success += 1
106 | except Exception as e:
107 | Logger.logError(f'training failed for config file: "{config}" with error:\n{e}')
108 | failed += 1
109 | Logger.log(f'\nfinished sequential training ({success} successful, {failed} failed)')
110 | # combine training results in single table
111 | writeOutputMetrics(metric_files, output_directory)
112 |
113 |
114 | if __name__ == '__main__':
115 | main()
116 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -- coding: utf-8 --
3 |
4 | """train.py: Trains a new model from config file."""
5 |
6 | import utils
7 |
8 | with utils.discoverSourcePath():
9 | import Framework
10 | from Implementations import Methods as MI
11 | from Implementations import Datasets as DI
12 | from Methods.Base.Trainer import BaseTrainer
13 | from Datasets.Base import BaseDataset
14 |
15 |
16 | def main(config_path: str = None):
17 | Framework.setup(config_path=config_path, require_custom_config=True)
18 | training_instance: BaseTrainer = MI.getTrainingInstance(
19 | method=Framework.config.GLOBAL.METHOD_TYPE,
20 | checkpoint=Framework.config.TRAINING.LOAD_CHECKPOINT
21 | )
22 | dataset: BaseDataset = DI.getDataset(
23 | dataset_type=Framework.config.GLOBAL.DATASET_TYPE,
24 | path=Framework.config.DATASET.PATH
25 | )
26 | training_instance.run(dataset)
27 | Framework.teardown()
28 | return training_instance
29 |
30 |
31 | if __name__ == '__main__':
32 | main()
33 |
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """utils.py: utility code for script execution."""
4 |
5 | import sys
6 | from pathlib import Path
7 |
8 |
9 | class discoverSourcePath:
10 | """A context class adding the source code location to the current python path."""
11 |
12 | def __enter__(self):
13 | sys.path.insert(0, str(Path(__file__).resolve().parents[1] / 'src'))
14 |
15 | def __exit__(self, *_):
16 | sys.path.pop(0)
17 |
18 |
19 | def getCachePath() -> Path:
20 | """Returns the path to the framework .cache directory."""
21 | return Path(__file__).resolve().parents[1] / '.cache'
22 |
--------------------------------------------------------------------------------
/src/Cameras/Equirectangular.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Cameras/Equirectangular.py: Implements a 360-degree panorama camera model."""
4 |
5 | import math
6 | from torch import Tensor
7 | import torch
8 |
9 | from Cameras.Base import BaseCamera
10 |
11 |
12 | class EquirectangularCamera(BaseCamera):
13 | """Defines a 360-degree panorama camera model for ray generation."""
14 |
15 | def __init__(self, near_plane: float, far_plane: float) -> None:
16 | super(EquirectangularCamera, self).__init__(near_plane, far_plane)
17 |
18 | def getRayOrigins(self) -> Tensor:
19 | """Returns a tensor containing the origin of each ray."""
20 | return self.properties.c2w[:3, -1].expand((self.properties.width * self.properties.height, 3))
21 |
22 | def _getLocalRayDirections(self) -> Tensor:
23 | """Returns a tensor containing the direction of each ray."""
24 | azimuth: Tensor = torch.linspace(
25 | start=(-math.pi + (math.pi / self.properties.width)),
26 | end=(math.pi - (math.pi / self.properties.width)),
27 | steps=self.properties.width
28 | )[None, :].expand((self.properties.height, self.properties.width))
29 | inclination: Tensor = torch.linspace(
30 | start=((-math.pi / 2.0) + (0.5 * math.pi / self.properties.height)),
31 | end=((math.pi / 2.0) - (0.5 * math.pi / self.properties.height)),
32 | steps=self.properties.height
33 | )[:, None].expand((self.properties.height, self.properties.width))
34 | x_direction: Tensor = torch.sin(azimuth) * torch.cos(inclination)
35 | y_direction: Tensor = torch.sin(inclination)
36 | z_direction: Tensor = -torch.cos(azimuth) * torch.cos(inclination)
37 | directions_camera_space: Tensor = torch.stack(
38 | (x_direction, y_direction, z_direction), dim=-1
39 | ).reshape(-1, 3)
40 | return directions_camera_space
41 |
42 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]:
43 | raise NotImplementedError('point projection not yet implemented for panorama camera')
44 |
45 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor:
46 | raise NotImplementedError('projection matrix not yet implemented for panorama camera')
47 |
--------------------------------------------------------------------------------
/src/Cameras/NDC.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Cameras/NDC.py: A perspective RGB camera model generating rays in normalized device space."""
4 |
5 | from torch import Tensor
6 | import torch
7 |
8 | from Cameras.Perspective import PerspectiveCamera
9 | from Cameras.utils import RayPropertySlice
10 |
11 |
12 | class NDCCamera(PerspectiveCamera):
13 | """Defines a perspective RGB camera model that generates rays in normalized device space."""
14 |
15 | def __init__(self, cube_scale: float = 1.0, z_plane: float = -1.0, near_plane: float = 0.01, far_plane: float = 1.0) -> None:
16 | super(NDCCamera, self).__init__(near_plane, far_plane * cube_scale)
17 | self.cube_scale: float = cube_scale
18 | self.z_plane: float = z_plane
19 |
20 | def generateRays(self) -> Tensor:
21 | # generate perspective rays
22 | rays: Tensor = super().generateRays()
23 | origins: Tensor = rays[:, RayPropertySlice.origin]
24 | directions: Tensor = rays[:, RayPropertySlice.direction]
25 | # shift rays to near plane
26 | t: Tensor = -(1.0 + origins[..., 2]) / directions[..., 2]
27 | origins: Tensor = origins + t[..., None] * directions
28 | # precompute some intermediate results
29 | w2fx: float = self.properties.width / (2. * self.properties.focal_x)
30 | h2fy: float = self.properties.height / (2. * self.properties.focal_y)
31 | ox_oz: Tensor = origins[..., 0] / origins[..., 2]
32 | oy_oz: Tensor = origins[..., 1] / origins[..., 2]
33 | # projection to normalized device space
34 | o0: Tensor = -1. / w2fx * ox_oz
35 | o1: Tensor = -1. / h2fy * oy_oz
36 | o2: Tensor = 1. + 2. * 1.0 / origins[..., 2]
37 | d0: Tensor = -1. / w2fx * (directions[..., 0] / directions[..., 2] - ox_oz)
38 | d1: Tensor = -1. / h2fy * (directions[..., 1] / directions[..., 2] - oy_oz)
39 | d2: Tensor = 1. - o2
40 | ndc_rays: Tensor = torch.cat([
41 | torch.stack([o0, o1, o2], -1) * self.cube_scale,
42 | torch.stack([d0, d1, d2], -1) * self.cube_scale,
43 | rays[:, RayPropertySlice.all_annotations_ndc],
44 | ], dim=-1)
45 | ndc_rays[:, 2] = self.z_plane
46 | return ndc_rays
47 |
48 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]:
49 | raise NotImplementedError('point projection not yet implemented for NDC camera')
50 |
51 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor:
52 | raise NotImplementedError('projection matrix not yet implemented for NDC camera')
53 |
--------------------------------------------------------------------------------
/src/Cameras/ODS.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Cameras/ODS.py: An omni-directional stereo panorama camera model."""
4 |
5 | import math
6 | from torch import Tensor
7 | import torch
8 |
9 | from Cameras.Equirectangular import EquirectangularCamera
10 |
11 |
12 | class ODSCamera(EquirectangularCamera):
13 | """
14 | Defines an omnidirectional stereo panorama camera model for ray generation.
15 | Expects the stereo images to be vertically concatenated.
16 | """
17 |
18 | def __init__(self, near_plane: float, far_plane: float, baseline: float = 0.065) -> None:
19 | super(ODSCamera, self).__init__(near_plane, far_plane)
20 | self.half_baseline: float = baseline / 2
21 |
22 | def getRayOrigins(self) -> Tensor:
23 | """Returns a tensor containing the origin of each ray."""
24 | rotation_angle: Tensor = torch.linspace(
25 | start=(-math.pi + (math.pi / self.properties.width)),
26 | end=(math.pi - (math.pi / self.properties.width)),
27 | steps=self.properties.width
28 | )[None, :].expand((self.properties.height // 2, self.properties.width))
29 | baseline_vector: Tensor = torch.stack(
30 | (torch.cos(rotation_angle), torch.zeros_like(rotation_angle), torch.sin(rotation_angle)), dim=-1
31 | ).reshape(-1, 3)
32 | origins_camera_space: Tensor = torch.cat([
33 | -self.half_baseline * baseline_vector,
34 | self.half_baseline * baseline_vector
35 | ], dim=0)
36 | origins_world_space: Tensor = torch.matmul(
37 | self.properties.c2w,
38 | torch.cat([origins_camera_space, torch.ones((origins_camera_space.shape[0], 1))], dim=-1)[:, :, None]
39 | ).squeeze()
40 | return origins_world_space[:, :3]
41 |
42 | def _getLocalRayDirections(self) -> Tensor:
43 | """Returns a tensor containing the direction of each ray."""
44 | # adjust height for vertical concatenation
45 | self.properties.height = self.properties.height // 2
46 | directions: Tensor = super()._getLocalRayDirections()
47 | self.properties.height = self.properties.height * 2
48 | return torch.cat([directions, directions], dim=0)
49 |
50 | @property
51 | def baseline(self) -> float:
52 | """Returns the baseline of the stereo camera."""
53 | return self.half_baseline * 2.0
54 |
55 | @baseline.setter
56 | def baseline(self, value: float) -> None:
57 | self.half_baseline = value / 2.0
58 |
59 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]:
60 | raise NotImplementedError('point projection not yet implemented for ODS camera')
61 |
62 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor:
63 | raise NotImplementedError('projection matrix not yet implemented for ODS camera')
64 |
--------------------------------------------------------------------------------
/src/Cameras/Perspective.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Cameras/Perspective.py: Implementation of a perspective RGB camera model."""
4 |
5 | from torch import Tensor
6 | import torch
7 |
8 | from Cameras.Base import BaseCamera
9 |
10 |
11 | class PerspectiveCamera(BaseCamera):
12 | """Defines a perspective RGB camera model for ray generation."""
13 |
14 | def __init__(self, near_plane: float, far_plane: float) -> None:
15 | super(PerspectiveCamera, self).__init__(near_plane, far_plane)
16 |
17 | def getRayOrigins(self) -> Tensor:
18 | """Returns a tensor containing the origin of each ray."""
19 | return self.properties.c2w[:3, -1].expand((self.properties.width * self.properties.height, 3))
20 |
21 | def _getLocalRayDirections(self) -> Tensor:
22 | """Returns a tensor containing the direction of each ray."""
23 | # calculate initial directions
24 | x_direction, y_direction = self.getPixelCoordinates()
25 | x_direction: Tensor = ((x_direction + 0.5) - (0.5 * self.properties.width + self.properties.principal_offset_x)) / self.properties.focal_x
26 | y_direction: Tensor = ((y_direction + 0.5) - (0.5 * self.properties.height + self.properties.principal_offset_y)) / self.properties.focal_y
27 | z_direction: Tensor = torch.full((self.properties.height, self.properties.width), fill_value=-1)
28 | directions_camera_space: Tensor = torch.stack(
29 | (x_direction, y_direction, z_direction), dim=-1
30 | ).reshape(-1, 3)
31 | return directions_camera_space
32 |
33 | def projectPoints(self, points: Tensor, eps: float = 1.0e-8) -> tuple[Tensor, Tensor, Tensor]:
34 | """projects points (Nx3) to image plane. returns xy image plane pixel coordinates,
35 | mask of points that hit sensor, and depth values."""
36 | # points = torch.cat((points, torch.ones((points.shape[0], 1), device=points.device, dtype=points.dtype)), dim=1)
37 | # points = points @ self.properties.w2c.T
38 | points = self.pointsToCameraSpace(points)
39 | depths = -points[:, 2]
40 | valid_mask = ((depths > self.near_plane) & (depths < self.far_plane))
41 | focal = torch.tensor((self.properties.focal_x, self.properties.focal_y), device=points.device, dtype=points.dtype)
42 | screen_size = torch.tensor((self.properties.width, self.properties.height), device=points.device, dtype=points.dtype)
43 | offset = screen_size * 0.5 + torch.tensor((self.properties.principal_offset_x, self.properties.principal_offset_y), device=points.device, dtype=points.dtype)
44 | points = self.properties.distortion_parameters.distort(points[:, :2] / (depths[:, None] + eps)) * focal + offset
45 | valid_mask &= ((points >= 0) & (points < screen_size - 1)).all(dim=-1)
46 | return points, valid_mask, depths
47 |
48 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor:
49 | """
50 | Returns the projection matrix for this camera.
51 | After perspective division, x, y, and z will be in [-1, 1] (OpenGL convention).
52 | The z-axis can be inverted for camera coordinate systems where the camera looks along the negative z-axis.
53 |
54 | Args:
55 | invert_z: If True, the z-axis will be inverted.
56 |
57 | Returns:
58 | The projection matrix from camera space to clip space.
59 | """
60 | half_width = self.properties.width * 0.5
61 | half_height = self.properties.height * 0.5
62 | z_sign = -1.0 if invert_z else 1.0
63 | projection_matrix = torch.tensor([
64 | [self.properties.focal_x / half_width, 0.0, z_sign * self.properties.principal_offset_x / half_width, 0.0],
65 | [0.0, self.properties.focal_y / half_height, z_sign * self.properties.principal_offset_y / half_height, 0.0],
66 | [0.0, 0.0, z_sign * (self.far_plane + self.near_plane) / (self.far_plane - self.near_plane), -2.0 * self.far_plane * self.near_plane / (self.far_plane - self.near_plane)],
67 | [0.0, 0.0, z_sign, 0.0]
68 | ], dtype=torch.float32, device=self.properties.c2w.device)
69 | return projection_matrix
70 |
71 | def getViewportTransform(self, pixel_centers_at_integer_coordinates: bool = True) -> Tensor:
72 | """
73 | Returns the transformation matrix from NDC to screen space.
74 | if pixel_centers_at_integer_coordinates:
75 | [-1, 1]^3 -> [-0.5, width - 0.5] x [-0.5, height - 0.5] x [near_plane, far_plane]
76 | else:
77 | [-1, 1]^3 -> [0, width] x [0, height] x [near_plane, far_plane]
78 |
79 | Args:
80 | pixel_centers_at_integer_coordinates: If True, pixel centers will be at integer coordinates.
81 |
82 | Returns:
83 | The transformation matrix from NDC to screen space.
84 | """
85 | offset = 0.5 if pixel_centers_at_integer_coordinates else 0.0
86 | center_x = self.properties.width * 0.5
87 | center_y = self.properties.height * 0.5
88 | viewport_transform = torch.tensor([
89 | [center_x, 0.0, 0.0, center_x - offset],
90 | [0.0, center_y, 0.0, center_y - offset],
91 | [0.0, 0.0, 1.0, 0.0],
92 | [0.0, 0.0, 0.0, 1.0]
93 | ], dtype=torch.float32, device=self.properties.c2w.device)
94 | # the following lines add a nonlinear mapping of z from [-1, 1] to [near_plane, far_plane]
95 | # viewport_transform[2, 2] = (self.far_plane - self.near_plane) * 0.5
96 | # viewport_transform[2, 3] = (self.far_plane + self.near_plane) * 0.5
97 | return viewport_transform
98 |
--------------------------------------------------------------------------------
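The getProjectionMatrix and getViewportTransform methods above can be chained with a perspective division to recover, up to lens distortion and the half-pixel convention, the pixel coordinates that projectPoints computes directly. A minimal sketch (not part of the repository), assuming `camera` is a configured PerspectiveCamera and `points_camera` holds Nx3 camera-space points with the camera looking along the negative z-axis:

import torch

def camera_to_screen(camera, points_camera: torch.Tensor) -> torch.Tensor:
    projection = camera.getProjectionMatrix(invert_z=True)  # camera space -> clip space
    viewport = camera.getViewportTransform()                 # NDC -> screen space
    ones = torch.ones_like(points_camera[:, :1])
    points_h = torch.cat((points_camera, ones), dim=1)       # homogeneous coordinates
    clip = points_h @ projection.T
    ndc = clip / clip[:, 3:4]                                 # perspective division (w holds the depth)
    screen = ndc @ viewport.T
    return screen[:, :2]                                      # xy pixel coordinates
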
/src/Cameras/PerspectiveStereo.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Cameras/PerspectiveStereo.py: Implementation of a perspective camera model generating stereo views."""
4 |
5 | from torch import Tensor
6 | import torch
7 |
8 | from Cameras.Perspective import PerspectiveCamera
9 |
10 |
11 | class PerspectiveStereoCamera(PerspectiveCamera):
12 | """Defines a perspective camera model that generates rays of vertically concatenated stereo views."""
13 |
14 | def __init__(self, near_plane: float, far_plane: float, baseline: float = 0.062) -> None:
15 | super(PerspectiveStereoCamera, self).__init__(near_plane, far_plane)
16 | self.half_baseline: float = baseline / 2.0
17 |
18 | def getRayOrigins(self) -> Tensor:
19 | """Returns a tensor containing the origin of each ray."""
20 | x_axis_world_space: Tensor = self.properties.c2w[:3, 0][None]
21 | # adjust height for vertical concatenation
22 | self.properties.height = self.properties.height // 2
23 | origin_world_space: Tensor = super().getRayOrigins()
24 | self.properties.height = self.properties.height * 2
25 | origins_left: Tensor = origin_world_space - (x_axis_world_space * self.half_baseline)
26 | origins_right: Tensor = origin_world_space + (x_axis_world_space * self.half_baseline)
27 | return torch.cat([origins_left, origins_right], dim=0)
28 |
29 | def _getLocalRayDirections(self) -> Tensor:
30 | """Returns a tensor containing the direction of each ray."""
31 | # adjust height for vertical concatenation
32 | self.properties.height = self.properties.height // 2
33 | directions: Tensor = super().getLocalRayDirections()
34 | self.properties.height = self.properties.height * 2
35 | return torch.cat([directions, directions], dim=0)
36 |
37 | @property
38 | def baseline(self) -> float:
39 | """Returns the baseline of the stereo camera."""
40 | return self.half_baseline * 2.0
41 |
42 | @baseline.setter
43 | def baseline(self, value: float) -> None:
44 | self.half_baseline = value / 2.0
45 |
46 | def projectPoints(self, points: Tensor) -> tuple[Tensor, Tensor, Tensor]:
47 | raise NotImplementedError('point projection not yet implemented for PerspectiveStereoCamera')
48 |
49 | def getProjectionMatrix(self, invert_z: bool = False) -> Tensor:
50 | raise NotImplementedError('projection matrix not yet implemented for PerspectiveStereoCamera')
51 |
--------------------------------------------------------------------------------
/src/Datasets/DNeRF.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Datasets/DNeRF.py: Provides a dataset class for D-NeRF scenes.
5 | Data available at https://github.com/albertpumarola/D-NeRF (last accessed 2023-05-25).
6 | """
7 |
8 | import json
9 | import math
10 | from pathlib import Path
11 | from typing import Any
12 |
13 | import torch
14 |
15 | import Framework
16 | from Cameras.Perspective import PerspectiveCamera
17 | from Cameras.utils import CameraProperties
18 | from Datasets.Base import BaseDataset
19 | from Datasets.utils import applyBGColor, loadImagesParallel, \
20 | CameraCoordinateSystemsTransformations, WorldCoordinateSystemTransformations
21 |
22 |
23 | @Framework.Configurable.configure(
24 | PATH='dataset/d-nerf/standup',
25 | IMAGE_SCALE_FACTOR=0.5,
26 | NORMALIZE_CUBE=1.5,
27 | )
28 | class CustomDataset(BaseDataset):
29 | """Dataset class for D-NeRF scenes."""
30 |
31 | def __init__(self, path: str) -> None:
32 | super().__init__(
33 | path=path,
34 | camera=PerspectiveCamera(2.0, 6.0),
35 | camera_system=CameraCoordinateSystemsTransformations.RIGHT_HAND,
36 | world_system=WorldCoordinateSystemTransformations.XnYnZ
37 | )
38 |
39 | def load(self) -> dict[str, list[CameraProperties]]:
40 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing."""
41 |         # set bounding box
42 | self._bounding_box = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]).cpu()
43 | # load data
44 | data: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets}
45 | for subset in self.subsets:
46 | metadata_filepath: Path = self.dataset_path / f'transforms_{subset}.json'
47 | try:
48 | with open(metadata_filepath, 'r') as f:
49 | metadata_file: dict[str, Any] = json.load(f)
50 | except IOError:
51 | raise Framework.DatasetError(f'Invalid dataset metadata file path "{metadata_filepath}"')
52 | opening_angle: float = float(metadata_file['camera_angle_x'])
53 | # load images
54 | image_filenames = [str(self.dataset_path / (frame['file_path'] + '.png')) for frame in metadata_file['frames']]
55 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc=subset)
56 | # create split CameraProperties objects
57 | for frame, rgb, alpha in zip(metadata_file['frames'], rgbs, alphas):
58 | # apply background color where alpha < 1
59 | rgb = applyBGColor(rgb, alpha, self.camera.background_color)
60 | # load camera extrinsics
61 | c2w = torch.as_tensor(frame['transform_matrix'], dtype=torch.float32)
62 | # load camera intrinsics
63 | focal_x: float = 0.5 * rgb.shape[2] / math.tan(0.5 * opening_angle)
64 | focal_y: float = 0.5 * rgb.shape[1] / math.tan(0.5 * opening_angle)
65 | # insert loaded values
66 | data[subset].append(CameraProperties(
67 | width=rgb.shape[2],
68 | height=rgb.shape[1],
69 | rgb=rgb,
70 | alpha=alpha,
71 | c2w=c2w,
72 | focal_x=focal_x,
73 | focal_y=focal_y,
74 | timestamp=frame['time']
75 | ))
76 | return data
77 |
--------------------------------------------------------------------------------
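Both this loader and the NeRF loader further below recover the pinhole focal lengths from the horizontal opening angle stored in the metadata. A small worked example with illustrative values:

import math

width = 800                                  # image width in pixels (illustrative)
camera_angle_x = 0.6911112070083618          # horizontal opening angle in radians (illustrative)
focal_x = 0.5 * width / math.tan(0.5 * camera_angle_x)  # roughly 1111.1 pixels
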
/src/Datasets/LLFF.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Datasets/LLFF.py: Provides a dataset class for Local Light Field Fusion (LLFF) scenes.
5 | Data available at https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 (last accessed 2023-05-25).
6 | """
7 |
8 | import numpy as np
9 | import torch
10 |
11 | import Framework
12 | from Cameras.NDC import NDCCamera
13 | from Cameras.Perspective import PerspectiveCamera
14 | from Cameras.utils import CameraProperties
15 | from Datasets.Base import BaseDataset
16 | from Datasets.utils import list_sorted_files, recenterPoses, loadImagesParallel, \
17 | CameraCoordinateSystemsTransformations
18 |
19 |
20 | @Framework.Configurable.configure(
21 | PATH='dataset/nerf_llff_data/fern',
22 | IMAGE_SCALE_FACTOR=0.25,
23 | DISABLE_NDC=False,
24 | TEST_STEP=8,
25 | WORLD_SCALING=0.75
26 | )
27 | class CustomDataset(BaseDataset):
28 | """Dataset class for Local Light Field Fusion (LLFF) scenes."""
29 |
30 | def __init__(self, path: str) -> None:
31 | super().__init__(
32 | path,
33 | NDCCamera(),
34 | CameraCoordinateSystemsTransformations.RIGHT_HAND
35 | )
36 |
37 | def load(self) -> dict[str, list[CameraProperties] | None]:
38 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing."""
39 | # TODO K-planes uses this bounding box for llff scenes!
40 | # torch.tensor([[-3.0, -1.67, -1.2], [3.0, 1.67, 1.2]])
41 |
42 | # load images
43 | images_path = self.dataset_path / 'images'
44 | image_filenames = [str(images_path / file) for file in list_sorted_files(images_path)]
45 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc='images')
46 | # load intrinsics and extrinsics
47 | colmap_poses = torch.as_tensor(np.load(str(self.dataset_path / 'poses_bounds.npy')))
48 | view_matrices = colmap_poses[:, :-2].reshape([-1, 3, 5])
49 | intrinsics = view_matrices[:, :, 4]
50 | focals = intrinsics[:, 2:3]
51 | if self.IMAGE_SCALE_FACTOR is not None:
52 | focals = focals * self.IMAGE_SCALE_FACTOR
53 | view_matrices = view_matrices[:, :, :-1]
54 | view_matrices = torch.cat(
55 | [view_matrices[:, :, 1:2], -view_matrices[:, :, 0:1], view_matrices[:, :, 2:]], dim=2
56 | )
57 | view_matrices = torch.cat(
58 | (view_matrices, torch.broadcast_to(torch.tensor([0, 0, 0, 1]), (view_matrices.shape[0], 1, 4))), dim=1
59 | )
60 | depth_min_max = colmap_poses[:, -2:]
61 | # rescale coordinates
62 | if self.WORLD_SCALING is not None:
63 | scaling = 1.0 / (depth_min_max.min() * self.WORLD_SCALING)
64 | view_matrices[:, :3, 3] *= scaling
65 | depth_min_max *= scaling
66 | # disable normalized device coordinates, if enabled
67 | if self.DISABLE_NDC:
68 | view_matrices = view_matrices.cpu()
69 | self.camera = PerspectiveCamera(
70 | near_plane=depth_min_max.min().item() * 0.9,
71 | far_plane=depth_min_max.max().item()
72 | )
73 | self.camera.setBackgroundColor(*self.BACKGROUND_COLOR)
74 | else:
75 |             # recenter camera poses (only makes sense for forward-facing scenes)
76 | view_matrices = recenterPoses(view_matrices).cpu()
77 | # insert data into target data structure
78 | data: list[CameraProperties] = [
79 | CameraProperties(
80 | width=rgb.shape[2],
81 | height=rgb.shape[1],
82 | rgb=rgb,
83 | alpha=alpha,
84 | c2w=c2w,
85 | focal_x=focal.item(),
86 | focal_y=focal.item()
87 | )
88 | for rgb, alpha, c2w, focal in zip(rgbs, alphas, view_matrices, focals)
89 | ]
90 | # perform test split
91 | train_data: list[CameraProperties] = []
92 | test_data: list[CameraProperties] = []
93 | if self.TEST_STEP > 0:
94 | for i in range(len(data)):
95 | if i % self.TEST_STEP == 0:
96 | test_data.append(data[i])
97 | else:
98 | train_data.append(data[i])
99 | else:
100 | train_data: list[CameraProperties] = data
101 | # return the dataset
102 | return {
103 | 'train': train_data,
104 | 'test': test_data,
105 | 'val': []
106 | }
107 |
--------------------------------------------------------------------------------
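The loader above unpacks poses_bounds.npy, where each row stores a flattened 3x5 matrix (a 3x4 camera-to-world pose plus an [image height, image width, focal length] column) followed by the scene's near and far depth bounds. A minimal sketch of that layout (the path is the config default and only serves as an illustration):

import numpy as np

poses_bounds = np.load('dataset/nerf_llff_data/fern/poses_bounds.npy')  # shape (N, 17)
row = poses_bounds[0]
pose_hwf = row[:15].reshape(3, 5)
c2w_3x4 = pose_hwf[:, :4]              # camera-to-world rotation and translation
height, width, focal = pose_hwf[:, 4]  # intrinsics column
near, far = row[15:]                   # depth bounds used for scaling and the NDC camera
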
/src/Datasets/MipNeRF360.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Datasets/MipNeRF360.py: Provides a dataset class for scenes from the Mip-NeRF 360 dataset.
5 | Data available at https://storage.googleapis.com/gresearch/refraw360/360_v2.zip (last accessed 2024-02-01) and
6 | https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip (last accessed 2024-02-01).
7 | Will also work for any other scene in the same format as the Mip-NeRF 360 dataset.
8 | """
9 |
10 | import os
11 | import torch
12 |
13 | import Framework
14 | from Logging import Logger
15 | from Cameras.Perspective import PerspectiveCamera
16 | from Cameras.utils import CameraProperties
17 | from Datasets.Base import BaseDataset
18 | from Datasets.utils import CameraCoordinateSystemsTransformations, loadImagesParallel, WorldCoordinateSystemTransformations
19 | from Datasets.Colmap import quaternion_to_R, read_points3D_binary, storePly, fetchPly, read_extrinsics_binary, read_intrinsics_binary, transformPosesPCA
20 |
21 | @Framework.Configurable.configure(
22 | PATH='dataset/mipnerf360/garden',
23 | IMAGE_SCALE_FACTOR=0.25,
24 | BACKGROUND_COLOR=[0.0, 0.0, 0.0],
25 | NEAR_PLANE=0.01,
26 | FAR_PLANE=100.0,
27 | TEST_STEP=8,
28 | APPLY_PCA=True,
29 | APPLY_PCA_RESCALE=True,
30 | USE_PRECOMPUTED_DOWNSCALING=True,
31 | )
32 | class CustomDataset(BaseDataset):
33 | """Dataset class for MipNeRF360 scenes."""
34 |
35 | def __init__(self, path: str) -> None:
36 | super().__init__(
37 | path,
38 | PerspectiveCamera(0.01, 100.0), # mipnerf360 itself uses 0.2, 1e6
39 | CameraCoordinateSystemsTransformations.LEFT_HAND,
40 | WorldCoordinateSystemTransformations.XnZY,
41 | )
42 |
43 | def load(self) -> dict[str, list[CameraProperties] | None]:
44 | """Loads the dataset into a dict containing lists of CameraProperties for training and testing."""
45 | # set near and far plane to values from config
46 | self.camera.near_plane = self.NEAR_PLANE
47 | self.camera.far_plane = self.FAR_PLANE
48 |
49 | # load colmap data
50 | cam_extrinsics = read_extrinsics_binary(self.dataset_path / 'sparse' / '0' / 'images.bin')
51 | cam_intrinsics = read_intrinsics_binary(self.dataset_path / 'sparse' / '0' / 'cameras.bin')
52 |
53 | # create camera properties
54 | data: list[CameraProperties] = []
55 | for cam_idx, cam_data in enumerate(cam_intrinsics.values()):
56 | # load images
57 | images = [data for data in cam_extrinsics.values() if data.camera_id == cam_data.id]
58 | images = sorted(images, key=lambda data: data.name)
59 | image_directory_name = 'images'
60 | image_scale_factor = self.IMAGE_SCALE_FACTOR
61 | # optionally use pre-downscaled images
62 | if self.USE_PRECOMPUTED_DOWNSCALING:
63 | match self.IMAGE_SCALE_FACTOR:
64 | case 0.5:
65 | image_directory_name = 'images_2'
66 | image_scale_factor = None
67 | case 0.25:
68 | image_directory_name = 'images_4'
69 | image_scale_factor = None
70 | case 0.125:
71 | image_directory_name = 'images_8'
72 | image_scale_factor = None
73 | case _:
74 | pass
75 | image_filenames = [str(self.dataset_path / image_directory_name / image.name) for image in images]
76 | rgbs, _ = loadImagesParallel(image_filenames, image_scale_factor, num_threads=4, desc=f'camera {cam_data.id}')
77 | for idx, (image, rgb) in enumerate(zip(images, rgbs)):
78 | # extract w2c matrix
79 | rotation_matrix = torch.from_numpy(quaternion_to_R(image.qvec)).float()
80 | translation_vector = torch.from_numpy(image.tvec).float()
81 | w2c = torch.eye(4, device=torch.device('cpu'))
82 | w2c[:3, :3] = rotation_matrix
83 | w2c[:3, 3] = translation_vector
84 | # intrinsics
85 | focal_x = cam_data.params[0]
86 | focal_y = cam_data.params[1]
87 | principal_offset_x = cam_data.params[2] - cam_data.width / 2
88 | principal_offset_y = cam_data.params[3] - cam_data.height / 2
89 | if self.IMAGE_SCALE_FACTOR is not None:
90 | scale_factor_intrinsics_x = rgb.shape[2] / cam_data.width
91 | scale_factor_intrinsics_y = rgb.shape[1] / cam_data.height
92 | focal_x *= scale_factor_intrinsics_x
93 | focal_y *= scale_factor_intrinsics_y
94 | principal_offset_x *= scale_factor_intrinsics_x
95 | principal_offset_y *= scale_factor_intrinsics_y
96 | # create and append camera properties object
97 | camera_properties = CameraProperties(
98 | width=rgb.shape[2],
99 | height=rgb.shape[1],
100 | rgb=rgb,
101 | focal_x=focal_x,
102 | focal_y=focal_y,
103 | principal_offset_x=principal_offset_x,
104 | principal_offset_y=principal_offset_y,
105 | timestamp=idx / (len(images) - 1), # TODO: rename this to id
106 | )
107 | camera_properties.w2c = w2c
108 | data.append(camera_properties)
109 |
110 | # load point cloud
111 | ply_path = self.dataset_path / 'sparse' / '0' / 'points3D.ply'
112 | if not os.path.exists(ply_path):
113 | Logger.logInfo('Found new scene. Converting sparse SfM points to .ply format.')
114 | xyz, rgb, _ = read_points3D_binary(self.dataset_path / 'sparse' / '0' / 'points3D.bin')
115 | storePly(ply_path, xyz, rgb)
116 | try:
117 | self.point_cloud = fetchPly(ply_path)
118 | except Exception:
119 | raise Framework.DatasetError(f'Failed to load SfM point cloud')
120 |
121 | # rotate/scale poses to align ground with xy plane and optionally fit to [-1, 1]^3 cube
122 | if self.APPLY_PCA:
123 | c2ws = torch.stack([camera.c2w for camera in data])
124 | c2ws, transformation = transformPosesPCA(c2ws, rescale=self.APPLY_PCA_RESCALE)
125 | for camera_properties, c2w in zip(data, c2ws):
126 | camera_properties.c2w = c2w
127 | self.point_cloud.transform(transformation)
128 | self.world_coordinate_system = None
129 |
130 | # create splits
131 | dataset: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets}
132 | if self.TEST_STEP > 0:
133 | for i in range(len(data)):
134 | if i % self.TEST_STEP == 0:
135 | dataset['test'].append(data[i])
136 | else:
137 | dataset['train'].append(data[i])
138 | else:
139 | dataset['train'] = data
140 |
141 | # return the dataset
142 | return dataset
143 |
--------------------------------------------------------------------------------
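When a pre-downscaled image folder is used, the loader above rescales the COLMAP intrinsics by the actual ratio between the loaded image size and the calibration size rather than by the nominal factor, which also absorbs any rounding of the image dimensions. A small worked example with illustrative numbers:

colmap_width, loaded_width = 1600, 400                     # calibration width vs. width of the image in images_4
scale_x = loaded_width / colmap_width                      # 0.25
focal_x = 1200.0 * scale_x                                 # 300.0
principal_offset_x = (812.0 - colmap_width / 2) * scale_x  # (812 - 800) * 0.25 = 3.0
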
/src/Datasets/NeRF.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Datasets/NeRF.py: Provides a dataset class for NeRF scenes.
5 | Data available at https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 (last accessed 2023-05-25).
6 | """
7 |
8 | import json
9 | import math
10 | from pathlib import Path
11 | from typing import Any
12 |
13 | import torch
14 | from torchvision import io
15 |
16 | import Framework
17 | from Cameras.Perspective import PerspectiveCamera
18 | from Cameras.utils import CameraProperties
19 | from Datasets.Base import BaseDataset
20 | from Datasets.utils import applyBGColor, loadImagesParallel, getParallelLoadIterator, \
21 | applyImageScaleFactor, CameraCoordinateSystemsTransformations, WorldCoordinateSystemTransformations
22 | from Logging import Logger
23 |
24 |
25 | @Framework.Configurable.configure(
26 | PATH='dataset/nerf_synthetic/lego',
27 | LOAD_TESTSET_DEPTHS=False,
28 | )
29 | class CustomDataset(BaseDataset):
30 | """Dataset class for NeRF scenes."""
31 |
32 | def __init__(self, path: str) -> None:
33 | super().__init__(
34 | path,
35 | PerspectiveCamera(2.0, 6.0),
36 | CameraCoordinateSystemsTransformations.RIGHT_HAND,
37 | WorldCoordinateSystemTransformations.XnYnZ
38 | )
39 |
40 | def load(self) -> dict[str, list[CameraProperties]]:
41 | """Loads the dataset into a dict containing lists of CameraProperties for training, evaluation, and testing."""
42 |         # set bounding box
43 | self._bounding_box = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]], dtype=torch.float32, device='cpu')
44 | data: dict[str, list[CameraProperties]] = {subset: [] for subset in self.subsets}
45 | for subset in self.subsets:
46 | metadata_filepath: Path = self.dataset_path / f'transforms_{subset}.json'
47 | try:
48 | with open(metadata_filepath, 'r') as f:
49 | metadata_file: dict[str, Any] = json.load(f)
50 | except IOError:
51 | raise Framework.DatasetError(f'Invalid dataset metadata file path "{metadata_filepath}"')
52 | opening_angle: float = float(metadata_file['camera_angle_x'])
53 | # load images
54 | image_filenames = [str(self.dataset_path / (frame['file_path'] + '.png')) for frame in metadata_file['frames']]
55 | rgbs, alphas = loadImagesParallel(image_filenames, self.IMAGE_SCALE_FACTOR, num_threads=4, desc=subset)
56 | if subset == 'test' and self.LOAD_TESTSET_DEPTHS:
57 | # the synthetic NeRF dataset's test set includes depth maps
58 | depths = self.loadTestsetDepthsParallel(metadata_file['frames'], num_threads=4)
59 | else:
60 | depths = [None] * len(rgbs)
61 | # create split CameraProperties objects
62 | for frame, rgb, alpha, depth in zip(metadata_file['frames'], rgbs, alphas, depths):
63 | # apply background color where alpha < 1
64 | rgb = applyBGColor(rgb, alpha, self.camera.background_color)
65 | # load camera extrinsics
66 | c2w = torch.as_tensor(frame['transform_matrix'], dtype=torch.float32)
67 | # load camera intrinsics
68 | focal_x: float = 0.5 * rgb.shape[2] / math.tan(0.5 * opening_angle)
69 | focal_y: float = 0.5 * rgb.shape[1] / math.tan(0.5 * opening_angle)
70 | # insert loaded values
71 | data[subset].append(CameraProperties(
72 | width=rgb.shape[2],
73 | height=rgb.shape[1],
74 | rgb=rgb,
75 | alpha=alpha,
76 | depth=depth,
77 | c2w=c2w,
78 | focal_x=focal_x,
79 | focal_y=focal_y
80 | ))
81 | # return the dataset
82 | return data
83 |
84 | def loadTestsetDepthsParallel(self, frames: list[dict[str, Any]], num_threads: int) -> list[torch.Tensor]:
85 | """Loads a multiple depth maps in parallel."""
86 | filenames = [str(next(self.dataset_path.glob(f'{frame["file_path"]}_depth_*.png'))) for frame in frames]
87 | iterator, pool = getParallelLoadIterator(filenames, self.IMAGE_SCALE_FACTOR, num_threads, load_function=self.loadNeRFDepth)
88 | depths = []
89 | for depth in Logger.logProgressBar(iterator, desc='test depth', leave=False, total=len(filenames)):
90 |             # clone tensor to extract it from shared memory (/dev/shm), otherwise we cannot use all RAM
91 | depths.append(depth.clone())
92 | pool.close()
93 | pool.join()
94 | return depths
95 |
96 | @staticmethod
97 | def loadNeRFDepth(filename: str, scale_factor: float | None) -> torch.Tensor:
98 | """Loads a depth map from the test set of a NeRF scene."""
99 | try:
100 | depth_raw: torch.Tensor = io.read_image(path=filename, mode=io.ImageReadMode.UNCHANGED)
101 | except Exception:
102 | raise Framework.DatasetError(f'Failed to load image file: "{filename}"')
103 | # convert image to the format used by the framework
104 | depth_raw = depth_raw.float() / 255
105 | # apply scaling factor
106 | if scale_factor is not None:
107 | depth_raw = applyImageScaleFactor(depth_raw, scale_factor)
108 | # see depth map creation in blender files of original NeRF codebase
109 | depth = -(depth_raw[:1] - 1) * 8
110 | return depth
111 |
--------------------------------------------------------------------------------
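The last step of loadNeRFDepth maps the normalized pixel values linearly to metric depth: a value of 1.0 decodes to depth 0 and a value of 0.0 to depth 8. A tiny worked example:

import torch

depth_raw = torch.tensor([1.0, 0.5, 0.0])  # normalized pixel values
depth = -(depth_raw - 1) * 8               # -> tensor([0., 4., 8.])
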
/src/Implementations.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Implementations.py: Dynamically provides access to the implemented datasets and methods."""
4 |
5 | import sys
6 | import importlib
7 | from types import ModuleType
8 | from typing import Type
9 | from pathlib import Path
10 |
11 | import Framework
12 | from Logging import Logger
13 |
14 | from Methods.Base.Model import BaseModel
15 | from Methods.Base.Renderer import BaseRenderer
16 | from Methods.Base.Trainer import BaseTrainer
17 | from Datasets.Base import BaseDataset
18 |
19 |
20 | class Methods:
21 | """A class containing all implemented methods"""
22 | path: Path = Path(__file__).resolve().parents[0] / 'Methods'
23 | options: tuple[str] = tuple([i.name for i in path.iterdir() if i.is_dir() and i.name not in ['Base', '__pycache__']])
24 | modules: dict[str, ModuleType] = {}
25 |
26 | @staticmethod
27 | def _import(method: str) -> ModuleType:
28 | with setImportPaths():
29 | m = importlib.import_module(f'Methods.{method}')
30 | return m
31 |
32 | @staticmethod
33 | def importMethod(method: str) -> ModuleType:
34 | """imports the requested method module"""
35 | if method not in Methods.options:
36 | raise Framework.MethodError(f'requested invalid method type: {method}\navailable methods are: {Methods.options}')
37 | if method not in Methods.modules:
38 | try:
39 | Methods.modules[method] = Methods._import(method)
40 | except Exception as e:
41 | raise Framework.MethodError(f'failed to import method {method}:\n{e}')
42 | return Methods.modules[method]
43 |
44 | @staticmethod
45 | def getModel(method: str, checkpoint: str = None, name: str = 'Default') -> BaseModel:
46 | """returns a model of the given type loaded with the provided checkpoint"""
47 | Logger.logInfo('creating model')
48 | model_class: Type[BaseModel] = Methods.importMethod(method).MODEL
49 | return model_class.load(checkpoint) if checkpoint is not None else model_class(name).build()
50 |
51 | @staticmethod
52 | def getRenderer(method: str, model: BaseModel) -> BaseRenderer:
53 | """returns a renderer for the specified method initialized with the given model instance"""
54 | Logger.logInfo('creating renderer')
55 | model_class: Type[BaseRenderer] = Methods.importMethod(method).RENDERER
56 | return model_class(model)
57 |
58 | @staticmethod
59 | def getTrainingInstance(method: str, checkpoint: str | None = None) -> BaseTrainer:
60 | """returns a trainer of the given type loaded with the provided checkpoint"""
61 | Logger.logInfo('creating training instance')
62 | model_class: Type[BaseTrainer] = Methods.importMethod(method).TRAINING_INSTANCE
63 | if checkpoint is not None:
64 | return model_class.load(checkpoint)
65 | model = Methods.getModel(method=method, name=Framework.config.TRAINING.MODEL_NAME)
66 | renderer = Methods.getRenderer(method=method, model=model)
67 | return model_class(model=model, renderer=renderer)
68 |
69 |
70 | class Datasets:
71 | """Dynamically loads and provides access to the implemented datasets."""
72 | path: Path = Path(__file__).resolve().parents[0] / 'Datasets'
73 | options: tuple[str] = tuple([i.name.split('.')[0] for i in path.iterdir() if i.is_file() and i.name not in ['Base.py', 'utils.py', 'datasets.md']])
74 | loaded: dict[str, Type[BaseDataset]] = {}
75 |
76 | @staticmethod
77 | def importDataset(dataset_type: str) -> None:
78 | if dataset_type in Datasets.options:
79 | try:
80 | with setImportPaths():
81 | m = importlib.import_module(f'Datasets.{dataset_type}')
82 | Datasets.loaded[dataset_type] = m.CustomDataset
83 | except Exception:
84 | raise Framework.DatasetError(f'failed to import dataset: {dataset_type}')
85 | else:
86 | raise Framework.DatasetError(f'requested invalid dataset type: {dataset_type}\navailable datasets are: {Datasets.options}')
87 |
88 | @staticmethod
89 | def getDatasetClass(dataset_type: str) -> Type[BaseDataset]:
90 | if dataset_type not in Datasets.loaded:
91 | Datasets.importDataset(dataset_type)
92 | return Datasets.loaded[dataset_type]
93 |
94 | @staticmethod
95 | def getDataset(dataset_type: str, path: str) -> BaseDataset:
96 | """Returns a dataset instance of the given type loaded from the given path."""
97 | dataset_class: Type[BaseDataset] = Datasets.getDatasetClass(dataset_type)
98 | return dataset_class(path)
99 |
100 |
101 | class setImportPaths:
102 | """helper context adding source code directory to pythonpath during dynamic imports"""
103 |
104 | def __init__(self, sub_path: Path = Path('')):
105 | self.sub_path = sub_path
106 |
107 | def __enter__(self):
108 | sys.path.insert(0, str(Path(__file__).resolve().parents[0] / self.sub_path))
109 |
110 | def __exit__(self, *_):
111 | sys.path.pop(0)
112 |
--------------------------------------------------------------------------------
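A minimal usage sketch of the two registries above, assuming the global Framework configuration has already been initialized (the entry points under scripts/ take care of this) and the referenced data exists; the method and dataset names are just examples:

from Implementations import Methods, Datasets

dataset = Datasets.getDataset('NeRF', 'dataset/nerf_synthetic/lego')  # resolves Datasets/NeRF.py
model = Methods.getModel('GaussianSplatting', checkpoint=None, name='Default')
renderer = Methods.getRenderer('GaussianSplatting', model)
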
/src/Logging.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Logging.py: A simple logging facade for level-based printing."""
4 |
5 | import sys
6 | from typing import Callable, List
7 | from tqdm.auto import tqdm
8 | from datetime import datetime, time
9 |
10 |
11 | class Logger:
12 | """Static class for logging messages with different levels of verbosity."""
13 | # define logging levels
14 | LOG_LEVEL_RANGE: List[int] = range(4)
15 | MODE_SILENT, MODE_NORMAL, MODE_VERBOSE, MODE_DEBUG = LOG_LEVEL_RANGE
16 |
17 | @classmethod
18 | def setMode(cls, lvl: int) -> None:
19 | """Sets logging mode defining which message types will be printed."""
20 | cls.logProgressBar, cls.log, cls.logError, cls.logInfo, cls.logWarning, cls.logDebug = cls._fgen(
21 | cls.MODE_NORMAL if lvl not in cls.LOG_LEVEL_RANGE
22 | else lvl, cls.MODE_NORMAL, cls.MODE_VERBOSE, cls.MODE_DEBUG
23 | )
24 |
25 | @staticmethod
26 | def _fgen(lvl: int, MODE_NORMAL: int, MODE_VERBOSE: int, MODE_DEBUG: int,
27 | _: bool = datetime.now().time() < time(0o7, 0o0)) -> List[Callable]:
28 | """Composes lambda print functions for each logging level."""
29 | m_data = zip(
30 | [f'\033[{n}m\033[1m{bytearray.fromhex(m).decode()}\033[0m\033[0m{o}: ' for n, m, o in zip(
31 | (91, 92, 93, 94),
32 | ('4552524f52', '494e464f', '5741524e494e47', '4445425547') if not _ else ('4352494e4745', '425457', '535553', '43524f574544'),
33 | [''] * 4 if not _ else (' \U0001F346\U0001F90F', ' \U0001F485', ' \U0001F928', ' \U0001F351')
34 | )],
35 | [MODE_VERBOSE, MODE_VERBOSE, MODE_VERBOSE, MODE_DEBUG]
36 | )
37 | return [
38 | (lambda iterable, **kwargs: tqdm(iterable, file=sys.stdout, dynamic_ncols=True, **kwargs)) if lvl >= MODE_NORMAL else (lambda iterable, **_: iterable),
39 | (lambda msg: tqdm.write(f'\033[1m{msg}\033[0m', file=sys.stdout)) if lvl >= MODE_NORMAL else (lambda _: None)
40 | ] + [
41 | (lambda msg, m_type=n: tqdm.write(f'{m_type}{msg}', file=sys.stdout)) if lvl >= m else (lambda _: None)
42 | for n, m in m_data
43 | ]
44 |
45 | # initialize to MODE_NORMAL
46 | logProgressBar, log, logError, logInfo, logWarning, logDebug = _fgen.__func__(
47 | MODE_NORMAL, MODE_NORMAL, MODE_VERBOSE, MODE_DEBUG
48 | )
49 |
--------------------------------------------------------------------------------
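A minimal sketch of how the facade above is used: the verbosity mode is set once, and each level-specific function silently becomes a no-op below the selected mode.

from Logging import Logger

Logger.setMode(Logger.MODE_VERBOSE)
Logger.log('starting run')             # printed from MODE_NORMAL upwards
Logger.logInfo('loading dataset')      # printed from MODE_VERBOSE upwards
Logger.logDebug('per-ray statistics')  # printed only in MODE_DEBUG
for _ in Logger.logProgressBar(range(100), desc='iterations'):
    pass                               # wrapped in tqdm unless running silently
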
/src/Methods/Base/Model.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Base/Model.py: Abstract base class for scene models."""
4 |
5 | from abc import ABC, abstractmethod
6 | import datetime
7 | from pathlib import Path
8 | from typing import Callable
9 | import torch
10 |
11 | import Framework
12 | from Methods.Base.utils import getGitCommit
13 | from Logging import Logger
14 |
15 |
16 | class BaseModel(Framework.Configurable, ABC, torch.nn.Module):
17 | """Defines the basic PyTorch neural model."""
18 |
19 | def __init__(self, name: str = None) -> None:
20 | Framework.Configurable.__init__(self, 'MODEL')
21 | ABC.__init__(self)
22 | torch.nn.Module.__init__(self)
23 | self.model_name: str = name if name is not None else 'Default'
24 | self.creation_date: str = f'{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'
25 | self.num_iterations_trained: int = 0
26 | self.git_commit: str = None
27 | self.output_directory: Path = Path(__file__).resolve().parents[3] / 'output' / str(Framework.config.GLOBAL.METHOD_TYPE) / f'{self.model_name}_{self.creation_date}'
28 |
29 | @abstractmethod
30 | def build(self) -> 'BaseModel':
31 | """
32 | Automatically called after model constructor during model initialization / checkpoint loading.
33 | This function should create / register all submodules, model parameters and buffers with correct shape based on the current configuration.
34 | If a parameter is of dynamic shape, i.e. the shape depends on the training data and might differ between checkpoints,
35 | this parameter should be registered as None type using: self.register_buffer('param_name', None).
36 | """
37 | return self
38 |
39 | def forward(self) -> None:
40 | """Invalidates forward passes of model as all models are executed exclusively through renderers."""
41 | Logger.logError('Model cannot be executed directly. Use a Renderer instead.')
42 |
43 | def __repr__(self) -> str:
44 | """Returns string representation of the model's metadata."""
45 | params_string = ''
46 | additional_parameters = type(self).getDefaultParameters().keys()
47 | if additional_parameters:
48 | params_string += '\t Additional parameters:'
49 | for param in additional_parameters:
50 | params_string += f'\n\t\t{param}: {self.__dict__[param]}'
51 | return f''
58 |
59 | @classmethod
60 | def load(cls, checkpoint_name: str | None,
61 | map_location: Callable = lambda storage, location: storage) -> 'BaseModel':
62 | """Loads a saved model from '.pt' file."""
63 | if checkpoint_name is None or checkpoint_name.split('.')[-1] != 'pt':
64 | raise Framework.ModelError(f'Invalid model checkpoint: "{checkpoint_name}"')
65 | try:
66 | # load checkpoint
67 | checkpoint_path = Path(__file__).resolve().parents[3]
68 | checkpoint = torch.load(checkpoint_path / checkpoint_name, map_location=map_location)
69 | # create new model
70 | model = cls()
71 | # load model configuration
72 | for param in ['model_name', 'creation_date', 'num_iterations_trained', 'git_commit', 'output_directory'] + list(cls.getDefaultParameters().keys()):
73 | try:
74 | model.__dict__[param] = checkpoint[param]
75 | except KeyError:
76 | Logger.logWarning(f'failed to load model parameter "{param}" -> using default value: "{model.__dict__[param]}"')
77 | # build the model
78 | model.build()
79 | # load model parameters
80 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
81 | # print warnings for missing keys
82 | for key in missing_keys:
83 | Logger.logWarning(f'missing key in model checkpoint: "{key}"')
84 | # add parameters of dynamic size
85 | for key in unexpected_keys:
86 | target = model
87 | attr_name = key
88 | while '.' in attr_name:
89 |                     sub_target, attr_name = attr_name.split('.', 1)
90 |                     target = getattr(target, sub_target)
91 | if attr_name in target._parameters:
92 | setattr(target, attr_name, torch.nn.Parameter(checkpoint['model_state_dict'][key]))
93 | else:
94 | if hasattr(target, attr_name):
95 | delattr(target, attr_name)
96 | target.register_buffer(attr_name, checkpoint['model_state_dict'][key])
97 | model.to(Framework.config.GLOBAL.DEFAULT_DEVICE)
98 | except IOError as e:
99 | raise Framework.ModelError(f'failed to load model from file: "{e}"')
100 | # check git commit
101 | if model.git_commit is not None:
102 | git_commit_id = getGitCommit()
103 | if git_commit_id != model.git_commit:
104 | Logger.logWarning(f'Git status mismatch (Model "{model.git_commit}", Current "{git_commit_id}").\n'
105 | '\tCheck out the correct branch/commit for reproducibility.')
106 | return model
107 |
108 | def save(self, path: Path) -> None:
109 | """Saves the current model as '.pt' file."""
110 | try:
111 | checkpoint = {'model_state_dict': self.state_dict()}
112 | for param in ['model_name', 'creation_date', 'num_iterations_trained', 'git_commit', 'output_directory'] + list(type(self).getDefaultParameters().keys()):
113 | checkpoint[param] = self.__dict__[param]
114 | torch.save(checkpoint, path)
115 | except IOError as e:
116 | Logger.logWarning(f'failed to save model: "{e}"')
117 |
118 | def exportTorchScript(self, path: Path) -> None:
119 | """Exports model as torch script module (e.g. for execution in c++)"""
120 | try:
121 | script_module = torch.jit.script(self)
122 | script_module.save(str(path))
123 | except IOError as e:
124 | Logger.logWarning(f'failed to generate script module: "{e}"')
125 |
126 | def numModuleParameters(self, trainable_only=False) -> int:
127 | """Returns the model's number of parameters."""
128 | return sum(p.numel() for p in self.parameters() if (p.requires_grad or not trainable_only))
129 |
--------------------------------------------------------------------------------
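A minimal sketch of loading a saved checkpoint through the classmethod above; the checkpoint path is a placeholder and is resolved relative to the repository root, and the global Framework configuration is assumed to be initialized (as done by the scripts, e.g. scripts/inference.py):

from Methods.GaussianSplatting.Model import GaussianSplattingModel

model = GaussianSplattingModel.load('output/GaussianSplatting/Default_2025-01-01-00-00-00/model.pt')
print(model.numModuleParameters(trainable_only=True))
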
/src/Methods/Base/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Base/utils.py: Contains utility functions used for the implementation of the available NeRF methods."""
4 |
5 | import time
6 | import git
7 | from contextlib import AbstractContextManager, nullcontext
8 | from typing import Callable
9 | from pathlib import Path
10 |
11 | import torch
12 |
13 | import Framework
14 | from Logging import Logger
15 |
16 |
17 | class CallbackTimer(AbstractContextManager):
18 | """Measures system-wide time elapsed during function call."""
19 |
20 | def __init__(self) -> None:
21 | self.duration: float = 0
22 | self.num_calls: int = 0
23 |
24 | def getValues(self) -> tuple[float, float, int]:
25 | """Returns absolute time, average time per call, and number of total calls of the callback."""
26 | return self.duration, self.duration / self.num_calls if self.num_calls > 0 else self.duration, self.num_calls
27 |
28 | def __enter__(self) -> None:
29 | """Starts the timer."""
30 | self.start = time.perf_counter()
31 |
32 | def __exit__(self, *_) -> None:
33 | """Stops the timer and adds the elapsed time to the total execution duration."""
34 | if Framework.config.GLOBAL.GPU_INDICES is not None:
35 | torch.cuda.synchronize(device=None)
36 | self.end = time.perf_counter()
37 | self.duration += (self.end - self.start)
38 | self.num_calls += 1
39 |
40 |
41 | def callbackDecoratorFactory(callback_type: int = 0, active: bool = True, priority: int = 50,
42 | start_iteration: (int | str | None) = None, end_iteration: (int | str | None) = None,
43 | iteration_stride: (int | str | None) = None) -> Callable:
44 | """
45 |     Decorator registering class members as training callbacks. If an argument is a string, its value is copied from the config variable of that name.
46 | Arguments:
47 | callback_type Indicates if callback is executed before, during or after training.
48 | active Used to deactivate callbacks
49 | priority Determines order of callback execution (higher priority first).
50 | start_iteration Index of first iteration where this callback is called.
51 | end_iteration Last iteration where this callback is called (exclusive).
52 | iteration_stride Number of iterations between callback calls.
53 | """
54 | def decorator(function: Callable) -> Callable:
55 | def wrapper(*args, **kwargs):
56 | return function(*args, **kwargs)
57 | wrapper.callback_type = callback_type
58 | wrapper.active = active
59 | wrapper.priority = priority
60 | wrapper.start_iteration = start_iteration
61 | wrapper.end_iteration = end_iteration
62 | wrapper.iteration_stride = iteration_stride
63 | wrapper.timer = nullcontext()
64 | wrapper.__name__ = function.__name__
65 | return wrapper
66 | return decorator
67 |
68 |
69 | def trainingCallback(active: bool | str = True, priority: int = 50, start_iteration: (int | str | None) = None,
70 | end_iteration: (int | str | None) = None, iteration_stride: (int | str | None) = None) -> Callable:
71 | """Training callback decorator."""
72 | return callbackDecoratorFactory(0, active, priority, start_iteration, end_iteration, iteration_stride)
73 |
74 |
75 | def preTrainingCallback(active: bool | str = True, priority: int = 50) -> Callable:
76 | """Pre-training callback decorator."""
77 | return callbackDecoratorFactory(-1, active, priority)
78 |
79 |
80 | def postTrainingCallback(active: bool | str = True, priority: int = 50) -> Callable:
81 | """Post-training callback decorator."""
82 | return callbackDecoratorFactory(1, active, priority)
83 |
84 |
85 | def getGitCommit() -> str | None:
86 | """Writes current git commit to model"""
87 | Logger.logInfo('Checking git status')
88 | parent_path = Path(__file__).resolve().parents[3]
89 | try:
90 | repo = git.Repo(parent_path)
91 | if repo.is_dirty(untracked_files=True):
92 | Logger.logWarning('Detected uncommitted changes in your git repository. Using the latest commit as reference.')
93 | return f'{repo.active_branch}:{repo.head.commit.hexsha}'
94 | except git.InvalidGitRepositoryError:
95 | Logger.logInfo(f'Could not find git repository at "{parent_path}"')
96 | return None
97 |
--------------------------------------------------------------------------------
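A minimal usage sketch of CallbackTimer (the framework configuration is assumed to be initialized, since the timer synchronizes CUDA when GPUs are configured; time.sleep stands in for the timed callback body):

import time
from Methods.Base.utils import CallbackTimer

timer = CallbackTimer()
for _ in range(10):
    with timer:
        time.sleep(0.01)
total_time, time_per_call, num_calls = timer.getValues()
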
/src/Methods/GaussianSplatting/Loss.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """GaussianSplatting/Loss.py: GaussianSplatting training objective function."""
4 |
5 | import torch
6 | import torchmetrics
7 |
8 | from Cameras.utils import CameraProperties
9 | from Framework import ConfigParameterList
10 | from Optim.Losses.Base import BaseLoss
11 | from Optim.Losses.DSSIM import DSSIMLoss
12 |
13 |
14 | class GaussianSplattingLoss(BaseLoss):
15 | def __init__(self, loss_config: ConfigParameterList) -> None:
16 | super().__init__()
17 | self.addLossMetric('L1_Color', torch.nn.functional.l1_loss, loss_config.LAMBDA_L1)
18 | self.addLossMetric('DSSIM_Color', DSSIMLoss(), loss_config.LAMBDA_DSSIM)
19 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio)
20 |
21 | def forward(self, outputs: dict[str, torch.Tensor], camera_properties: CameraProperties) -> torch.Tensor:
22 | return super().forward({
23 | 'L1_Color': {'input': outputs['rgb'], 'target': camera_properties.rgb},
24 | 'DSSIM_Color': {'input': outputs['rgb'], 'target': camera_properties.rgb},
25 | 'PSNR': {'preds': outputs['rgb'], 'target': camera_properties.rgb, 'data_range': 1.0}
26 | })
27 |
--------------------------------------------------------------------------------
/src/Methods/GaussianSplatting/Trainer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """GaussianSplatting/Trainer.py: Implementation of the trainer for the GaussianSplatting method.
4 | The callback schedule is slightly modified from the original implementation, which counts iterations starting at 1 instead of 0."""
5 |
6 | import torch
7 |
8 | import Framework
9 | from Datasets.Base import BaseDataset
10 | from Datasets.utils import BasicPointCloud
11 | from Logging import Logger
12 | from Methods.Base.GuiTrainer import GuiTrainer
13 | from Methods.Base.utils import preTrainingCallback, trainingCallback, postTrainingCallback
14 | from Methods.GaussianSplatting.Loss import GaussianSplattingLoss
15 | from Optim.Samplers.DatasetSamplers import DatasetSampler
16 |
17 |
18 | @Framework.Configurable.configure(
19 | NUM_ITERATIONS=30_000,
20 | LEARNING_RATE_POSITION_INIT=0.00016,
21 | LEARNING_RATE_POSITION_FINAL=0.0000016,
22 | LEARNING_RATE_POSITION_DELAY_MULT=0.01,
23 | LEARNING_RATE_POSITION_MAX_STEPS=30_000,
24 | LEARNING_RATE_FEATURE=0.0025,
25 | LEARNING_RATE_OPACITY=0.05,
26 | LEARNING_RATE_SCALING=0.005,
27 | LEARNING_RATE_ROTATION=0.001,
28 | PERCENT_DENSE=0.01,
29 | OPACITY_RESET_INTERVAL=3_000,
30 | DENSIFY_START_ITERATION=500,
31 | DENSIFY_END_ITERATION=15_000,
32 | DENSIFICATION_INTERVAL=100,
33 | DENSIFY_GRAD_THRESHOLD=0.0002,
34 | LOSS=Framework.ConfigParameterList(
35 | LAMBDA_L1=0.8,
36 | LAMBDA_DSSIM=0.2,
37 | ),
38 | )
39 | class GaussianSplattingTrainer(GuiTrainer):
40 | """Defines the trainer for the GaussianSplatting variant."""
41 |
42 | def __init__(self, **kwargs) -> None:
43 | super(GaussianSplattingTrainer, self).__init__(**kwargs)
44 | self.train_sampler = None
45 | self.loss = GaussianSplattingLoss(loss_config=self.LOSS)
46 |
47 | @preTrainingCallback(priority=50)
48 | @torch.no_grad()
49 | def createSampler(self, _, dataset: 'BaseDataset') -> None:
50 | """Creates the sampler."""
51 | self.train_sampler = DatasetSampler(dataset=dataset.train(), random=True)
52 |
53 | @preTrainingCallback(priority=40)
54 | @torch.no_grad()
55 | def setupGaussians(self, _, dataset: 'BaseDataset') -> None:
56 | """Sets up the model."""
57 | camera_centers = torch.stack([camera_properties.T for camera_properties in dataset.train()])
58 | radius = (1.1 * torch.max(torch.linalg.norm(camera_centers - torch.mean(camera_centers, dim=0), dim=1))).item()
59 | Logger.logInfo(f'Training cameras extent: {radius:.2f}')
60 |
61 | if dataset.point_cloud is not None:
62 | point_cloud = dataset.point_cloud
63 | else:
64 | n_random_points = 100_000
65 | min_bounds, max_bounds = dataset.getBoundingBox()
66 | extent = max_bounds - min_bounds
67 | point_cloud = BasicPointCloud(torch.rand(n_random_points, 3, dtype=torch.float32, device=min_bounds.device) * extent + min_bounds)
68 | self.model.gaussians.initialize_from_point_cloud(point_cloud, radius)
69 | self.model.gaussians.training_setup(self)
70 |
71 | @trainingCallback(priority=110, start_iteration=1000, iteration_stride=1000)
72 | @torch.no_grad()
73 | def increaseSHDegree(self, *_) -> None:
74 | """Increase the levels of SH up to a maximum degree."""
75 | self.model.gaussians.increase_used_sh_degree()
76 |
77 | @trainingCallback(priority=100)
78 | def trainingIteration(self, iteration: int, dataset: 'BaseDataset') -> None:
79 | """Performs a training step without actually doing the optimizer step."""
80 | # init modes
81 | self.model.train()
82 | dataset.train()
83 | self.loss.train()
84 | # update learning rate
85 | self.model.gaussians.update_learning_rate(iteration + 1)
86 | # get random sample from dataset
87 | camera_properties = self.train_sampler.get(dataset=dataset)['camera_properties']
88 | dataset.camera.setProperties(camera_properties)
89 | # render sample
90 | outputs = self.renderer.renderImage(camera=dataset.camera, to_chw=True)
91 | # calculate loss
92 | loss = self.loss(outputs, camera_properties)
93 | loss.backward()
94 | # track values for pruning and densification
95 | if iteration < self.DENSIFY_END_ITERATION:
96 | self.model.gaussians.add_densification_stats(outputs['viewspace_points'], outputs['visibility_mask'])
97 |
98 | @trainingCallback(priority=90, start_iteration='DENSIFY_START_ITERATION', end_iteration='DENSIFY_END_ITERATION', iteration_stride='DENSIFICATION_INTERVAL')
99 | @torch.no_grad()
100 | def densify(self, iteration: int, _) -> None:
101 | """Apply densification."""
102 | if iteration == self.DENSIFY_START_ITERATION or iteration == self.DENSIFY_END_ITERATION: # matches behavior of official 3dgs codebase
103 | return
104 | self.model.gaussians.densify_and_prune(self.DENSIFY_GRAD_THRESHOLD, 0.005, iteration > self.OPACITY_RESET_INTERVAL)
105 |
106 | @trainingCallback(priority=80, start_iteration='OPACITY_RESET_INTERVAL', end_iteration='DENSIFY_END_ITERATION', iteration_stride='OPACITY_RESET_INTERVAL')
107 | @torch.no_grad()
108 | def resetOpacities(self, iteration: int, _) -> None:
109 | """Reset opacities."""
110 | if iteration == self.DENSIFY_END_ITERATION: # matches behavior of official 3dgs codebase
111 | return
112 | self.model.gaussians.reset_opacities()
113 |
114 | @trainingCallback(priority=80, start_iteration='DENSIFY_START_ITERATION', iteration_stride='NUM_ITERATIONS')
115 | @torch.no_grad()
116 | def resetOpacitiesWhiteBackground(self, _, dataset: 'BaseDataset') -> None:
117 | """Reset opacities one additional time when using a white background."""
118 |         # the original implementation only supports a black or white background; this is an attempt to make it work with any color
119 | if (dataset.camera.background_color > 0.0).any():
120 | self.model.gaussians.reset_opacities()
121 |
122 | @trainingCallback(priority=70)
123 | @torch.no_grad()
124 | def performOptimizerStep(self, *_) -> None:
125 | """Update parameters."""
126 | self.model.gaussians.optimizer.step()
127 | self.model.gaussians.optimizer.zero_grad()
128 |
129 | @postTrainingCallback(priority=1000)
130 | @torch.no_grad()
131 | def bakeActivations(self, *_) -> None:
132 | """Bake relevant activation functions after training."""
133 | self.model.gaussians.bake_activations()
134 | # delete optimizer to save memory
135 | self.model.gaussians.optimizer = None
136 |
--------------------------------------------------------------------------------
/src/Methods/GaussianSplatting/__init__.py:
--------------------------------------------------------------------------------
1 | from Methods.GaussianSplatting.Model import GaussianSplattingModel
2 | from Methods.GaussianSplatting.Renderer import GaussianSplattingRenderer
3 | from Methods.GaussianSplatting.Trainer import GaussianSplattingTrainer
4 |
5 | MODEL = GaussianSplattingModel
6 | RENDERER = GaussianSplattingRenderer
7 | TRAINING_INSTANCE = GaussianSplattingTrainer
8 |
--------------------------------------------------------------------------------
/src/Methods/GaussianSplatting/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | GaussianSplatting/utils.py: Utility functions for GaussianSplatting.
5 | """
6 |
7 | from dataclasses import dataclass
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from Cameras.utils import quaternion_to_rotation_matrix
13 |
14 |
15 | @dataclass(frozen=True)
16 | class LRDecayPolicy(object):
17 | """Allows for flexible definition of a decay policy for a learning rate."""
18 | lr_init: float = 1.0
19 | lr_final: float = 1.0
20 | lr_delay_steps: int = 0
21 | lr_delay_mult: float = 1.0
22 | max_steps: int = 1000000
23 |
24 | # taken from https://github.com/sxyu/svox2/blob/master/opt/util/util.py#L78
25 | def __call__(self, iteration) -> float:
26 | """Calculates learning rate for the given iteration."""
27 | if iteration < 0 or (self.lr_init == 0.0 and self.lr_final == 0.0):
28 | # Disable this parameter
29 | return 0.0
30 | if self.lr_delay_steps > 0 and iteration < self.lr_delay_steps:
31 | # A kind of reverse cosine decay.
32 | delay_rate = self.lr_delay_mult + (1 - self.lr_delay_mult) * np.sin(
33 | 0.5 * np.pi * np.clip(iteration / self.lr_delay_steps, 0, 1)
34 | )
35 | else:
36 | delay_rate = 1.0
37 | t = np.clip(iteration / self.max_steps, 0, 1)
38 | log_lerp = np.exp(np.log(self.lr_init) * (1 - t) + np.log(self.lr_final) * t)
39 | return delay_rate * log_lerp
40 |
41 |
42 | def inverse_sigmoid(x: torch.Tensor) -> torch.Tensor:
43 | return torch.log(x / (1.0 - x))
44 |
45 |
46 | def build_covariances(scales: torch.Tensor, rotations: torch.Tensor) -> torch.Tensor:
47 | R = quaternion_to_rotation_matrix(rotations, normalize=False)
48 | # add batch dimension if necessary
49 | batch_dim_added = False
50 | if scales.dim() == 1:
51 | scales = scales[None]
52 | batch_dim_added = True
53 | S = torch.diag_embed(scales)
54 | RS = R @ S
55 | RSSR = RS @ RS.transpose(-2, -1)
56 | return RSSR[0] if batch_dim_added else RSSR
57 |
58 |
59 | def convert_sh_features(sh_features: torch.Tensor, view_directions: torch.Tensor, degree: int) -> torch.Tensor:
60 | """
61 | Convert spherical harmonics features to RGB.
62 | As in 3DGS, we do not use Sigmoid but instead add 0.5 and clamp with 0 from below.
63 |
64 | adapted from multiple sources:
65 | 1. https://www.ppsloan.org/publications/StupidSH36.pdf
66 | 2. https://github.com/sxyu/svox2/blob/59984d6c4fd3d713353bafdcb011646e64647cc7/svox2/utils.py#L115
67 | 3. https://github.com/NVlabs/tiny-cuda-nn/blob/212104156403bd87616c1a4f73a1c5f2c2e172a9/include/tiny-cuda-nn/common_device.h#L340
68 | 4. https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d/cuda_rasterizer/forward.cu#L20
69 | """
70 | result = 0.5 + 0.28209479177387814 * sh_features[..., 0]
71 | if degree == 0:
72 | return result.clamp_min(0.0)
73 | x = view_directions[..., 0:1]
74 | y = view_directions[..., 1:2]
75 | z = view_directions[..., 2:3]
76 | result += -0.48860251190291987 * y * sh_features[..., 1]
77 | result += 0.48860251190291987 * z * sh_features[..., 2]
78 | result += -0.48860251190291987 * x * sh_features[..., 3]
79 | if degree == 1:
80 | return result.clamp_min(0.0)
81 | x2, y2, z2 = x * x, y * y, z * z
82 | xy, yz, xz = x * y, y * z, x * z
83 | result += 1.0925484305920792 * xy * sh_features[..., 4]
84 | result += -1.0925484305920792 * yz * sh_features[..., 5]
85 | result += (0.94617469575755997 * z2 - 0.31539156525251999) * sh_features[..., 6]
86 | result += -1.0925484305920792 * xz * sh_features[..., 7]
87 | result += 0.54627421529603959 * (x2 - y2) * sh_features[..., 8]
88 | if degree == 2:
89 | return result.clamp_min(0.0)
90 | result += 0.59004358992664352 * y * (-3.0 * x2 + y2) * sh_features[..., 9]
91 | result += 2.8906114426405538 * xy * z * sh_features[..., 10]
92 | result += 0.45704579946446572 * y * (1.0 - 5.0 * z2) * sh_features[..., 11]
93 | result += 0.3731763325901154 * z * (5.0 * z2 - 3.0) * sh_features[..., 12]
94 | result += 0.45704579946446572 * x * (1.0 - 5.0 * z2) * sh_features[..., 13]
95 | result += 1.4453057213202769 * z * (x2 - y2) * sh_features[..., 14]
96 | result += 0.59004358992664352 * x * (-x2 + 3.0 * y2) * sh_features[..., 15]
97 | return result.clamp_min(0.0)
98 |
99 |
100 | def rgb_to_sh0(rgb: torch.Tensor | float) -> torch.Tensor | float:
101 | return (rgb - 0.5) / 0.28209479177387814
102 |
103 |
104 | def sh0_to_rgb(sh: torch.Tensor | float) -> torch.Tensor | float:
105 | return sh * 0.28209479177387814 + 0.5
106 |
107 |
108 | def extract_upper_triangular_matrix(matrix: torch.Tensor) -> torch.Tensor:
109 | upper_triangular_indices = torch.triu_indices(matrix.shape[-2], matrix.shape[-1])
110 | upper_triangular_matrix = matrix[..., upper_triangular_indices[0], upper_triangular_indices[1]]
111 | return upper_triangular_matrix
112 |
--------------------------------------------------------------------------------
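A small worked example of the decay policy above, using the position learning-rate values from the Gaussian splatting trainer defaults for illustration: without a delay phase the schedule is a pure log-linear interpolation, so the halfway point sits at the geometric mean of the initial and final rates.

from Methods.GaussianSplatting.utils import LRDecayPolicy

policy = LRDecayPolicy(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(policy(0))       # 1.6e-4
print(policy(15_000))  # 1.6e-5, the geometric mean of lr_init and lr_final
print(policy(30_000))  # 1.6e-6
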
/src/Methods/HierarchicalNeRF/Loss.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """HierarchicalNeRF/Loss.py: Loss implementation for the hierarchical NeRF method."""
4 |
5 | import torch
6 | import torchmetrics
7 |
8 | from Cameras.utils import RayPropertySlice
9 | from Optim.Losses.Base import BaseLoss
10 |
11 |
12 | class HierarchicalNeRFLoss(BaseLoss):
13 | """Defines a class for all sub-losses of the hierarchical NeRF method."""
14 |
15 | def __init__(self, lambda_color: float, lambda_alpha: float) -> None:
16 | super().__init__()
17 | self.addLossMetric('L2_Color', torch.nn.functional.mse_loss, lambda_color)
18 | self.addLossMetric('L2_Color_Coarse', torch.nn.functional.mse_loss, lambda_color)
19 | self.addLossMetric('L2_Alpha', torch.nn.functional.mse_loss, lambda_alpha)
20 | self.addLossMetric('L2_Alpha_Coarse', torch.nn.functional.mse_loss, lambda_alpha)
21 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio)
22 | self.addQualityMetric('PSNR_coarse', torchmetrics.functional.image.peak_signal_noise_ratio)
23 |
24 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor) -> torch.Tensor:
25 | """Defines loss calculation."""
26 | return super().forward({
27 | 'L2_Color': {'input': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb]},
28 | 'L2_Color_Coarse': {'input': outputs['rgb_coarse'], 'target': rays[:, RayPropertySlice.rgb]},
29 | 'L2_Alpha': {'input': outputs['alpha'], 'target': rays[:, RayPropertySlice.alpha]},
30 | 'L2_Alpha_Coarse': {'input': outputs['alpha_coarse'], 'target': rays[:, RayPropertySlice.alpha]},
31 | 'PSNR': {'preds': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0},
32 | 'PSNR_coarse': {'preds': outputs['rgb_coarse'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0}
33 | })
34 |
--------------------------------------------------------------------------------
/src/Methods/HierarchicalNeRF/Model.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """HierarchicalNeRF/Model.py: Implementation of the neural model for the hierarchical NeRF method."""
4 |
5 | from Methods.NeRF.Model import NeRF, NeRFBlock
6 |
7 |
8 | class HierarchicalNeRF(NeRF):
9 | """Defines a hierarchical NeRF model containing a coarse and a fine version for efficient sampling."""
10 |
11 | def __init__(self, name: str = None) -> None:
12 | super(HierarchicalNeRF, self).__init__(name)
13 |
14 | def build(self) -> 'HierarchicalNeRF':
15 | """Builds the model."""
16 | self.coarse = NeRFBlock(
17 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES,
18 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT,
19 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION
20 | )
21 | self.fine = NeRFBlock(
22 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES,
23 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT,
24 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION
25 | )
26 | return self
27 |
--------------------------------------------------------------------------------
/src/Methods/HierarchicalNeRF/Trainer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """HierarchicalNeRF/Trainer.py: Implementation of the trainer for the hierarchical NeRF method."""
4 |
5 | import torch
6 |
7 | from Methods.HierarchicalNeRF.Loss import HierarchicalNeRFLoss
8 | from Methods.NeRF.Trainer import NeRFTrainer
9 |
10 |
11 | class HierarchicalNeRFTrainer(NeRFTrainer):
12 | """Defines the trainer for the hierarchical NeRF method."""
13 |
14 | def __init__(self, **kwargs) -> None:
15 | super(HierarchicalNeRFTrainer, self).__init__(**kwargs)
16 | self.optimizer = torch.optim.Adam(
17 | self.model.parameters(),
18 | lr=self.LEARNINGRATE,
19 | betas=(self.ADAM_BETA_1, self.ADAM_BETA_2)
20 | # eps=Framework.config.GLOBAL.EPS
21 | )
22 | for param_group in self.optimizer.param_groups:
23 | param_group['capturable'] = True # Hacky fix for PT 1.12 bug
24 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
25 | self.optimizer,
26 | lr_lambda=self.LRDecayPolicy(self.LEARNINGRATE_DECAY_RATE, self.LEARNINGRATE_DECAY_STEPS),
27 | last_epoch=self.model.num_iterations_trained - 1
28 | )
29 | self.loss = HierarchicalNeRFLoss(self.LAMBDA_COLOR_LOSS, self.LAMBDA_ALPHA_LOSS)
30 | self.renderer.RENDER_COARSE = True
31 |
--------------------------------------------------------------------------------
/src/Methods/HierarchicalNeRF/__init__.py:
--------------------------------------------------------------------------------
1 | from Methods.HierarchicalNeRF.Model import HierarchicalNeRF
2 | from Methods.HierarchicalNeRF.Renderer import HierarchicalNeRFRenderer
3 | from Methods.HierarchicalNeRF.Trainer import HierarchicalNeRFTrainer
4 |
5 | MODEL = HierarchicalNeRF
6 | RENDERER = HierarchicalNeRFRenderer
7 | TRAINING_INSTANCE = HierarchicalNeRFTrainer
8 |
--------------------------------------------------------------------------------
/src/Methods/HierarchicalNeRF/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | HierarchicalNeRF/utils.py: Contains utility functions used for the implementation of the HierarchicalNeRF method.
5 | """
6 |
7 | import torch
8 | from torch import Tensor
9 |
10 |
11 | def generateSamplesFromPDF(bins: Tensor, values: Tensor, num_samples: int, randomize_samples: bool) -> Tensor:
12 | """Returns samples from probability density function along ray."""
13 | device: torch.device = bins.device
14 | bins = 0.5 * (bins[..., :-1] + bins[..., 1:])
15 | values = values[..., 1:-1] + 1e-5
16 | pdf = values / torch.sum(values, -1, keepdim=True)
17 | cdf = torch.cumsum(pdf, -1)
18 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
19 | if randomize_samples:
20 | u = torch.rand(list(cdf.shape[:-1]) + [num_samples], device=device)
21 | else:
22 | u = torch.linspace(0., 1., steps=num_samples, device=device)
23 | u = u.expand(list(cdf.shape[:-1]) + [num_samples])
24 | # Invert CDF
25 | u = u.contiguous()
26 | inds = torch.searchsorted(cdf, u, right=True)
27 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
28 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
29 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
30 |
31 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
32 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
33 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
34 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
35 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
36 |
37 | denom: Tensor = (cdf_g[..., 1] - cdf_g[..., 0])
38 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
39 | t: Tensor = (u - cdf_g[..., 0]) / denom
40 | samples: Tensor = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
41 | return samples.detach()
42 |
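A minimal sketch of how this inverse-transform sampling is typically driven by a coarse pass (the coarse-pass tensors below are random placeholders, not outputs of the actual renderer; `src/` is assumed to be on the import path):

    import torch
    from Methods.HierarchicalNeRF.utils import generateSamplesFromPDF

    # toy coarse pass: 4 rays with 8 coarse depth samples and their alpha weights
    coarse_depths = torch.linspace(2.0, 6.0, steps=8).expand(4, 8)
    coarse_weights = torch.rand(4, 8)
    # draw 16 extra fine samples per ray, concentrated where the coarse weights are large
    fine_depths = generateSamplesFromPDF(coarse_depths, coarse_weights, num_samples=16, randomize_samples=True)
    print(fine_depths.shape)  # torch.Size([4, 16])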
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/__init__.py:
--------------------------------------------------------------------------------
1 | """Volume Rendering code adapted from pytorch InstantNGP reimplementation, adapted from kwea123 (https://github.com/kwea123/ngp_pl)"""
2 |
3 | from pathlib import Path
4 |
5 | import Framework
6 |
7 | filepath = Path(__file__).resolve()
8 | __extension_name__ = filepath.parent.stem
9 | __install_command__ = [
10 | 'pip', 'install',
11 | str(Path(__file__).parent),
12 | ]
13 |
14 | try:
15 | from VolumeRenderingV2 import * # noqa
16 | from .custom_functions import * # noqa
17 | except ImportError:
18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
19 |
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/csrc/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
7 |
8 |
9 | std::vector<torch::Tensor> ray_aabb_intersect_cu(
10 | const torch::Tensor rays_o,
11 | const torch::Tensor rays_d,
12 | const torch::Tensor centers,
13 | const torch::Tensor half_sizes,
14 | const int max_hits
15 | );
16 |
17 |
18 | std::vector<torch::Tensor> ray_sphere_intersect_cu(
19 | const torch::Tensor rays_o,
20 | const torch::Tensor rays_d,
21 | const torch::Tensor centers,
22 | const torch::Tensor radii,
23 | const int max_hits
24 | );
25 |
26 |
27 | void packbits_cu(
28 | torch::Tensor density_grid,
29 | const float density_threshold,
30 | torch::Tensor density_bitfield
31 | );
32 |
33 |
34 | torch::Tensor morton3D_cu(const torch::Tensor coords);
35 | torch::Tensor morton3D_invert_cu(const torch::Tensor indices);
36 |
37 |
38 | std::vector<torch::Tensor> raymarching_train_cu(
39 | const torch::Tensor rays_o,
40 | const torch::Tensor rays_d,
41 | const torch::Tensor hits_t,
42 | const torch::Tensor density_bitfield,
43 | const int cascades,
44 | const float scale,
45 | const float exp_step_factor,
46 | const torch::Tensor noise,
47 | const int grid_size,
48 | const int max_samples
49 | );
50 |
51 |
52 | std::vector<torch::Tensor> raymarching_test_cu(
53 | const torch::Tensor rays_o,
54 | const torch::Tensor rays_d,
55 | torch::Tensor hits_t,
56 | const torch::Tensor alive_indices,
57 | const torch::Tensor density_bitfield,
58 | const int cascades,
59 | const float scale,
60 | const float exp_step_factor,
61 | const int grid_size,
62 | const int max_samples,
63 | const int N_samples
64 | );
65 |
66 |
67 | std::vector<torch::Tensor> composite_train_fw_cu(
68 | const torch::Tensor sigmas,
69 | const torch::Tensor rgbs,
70 | const torch::Tensor deltas,
71 | const torch::Tensor ts,
72 | const torch::Tensor rays_a,
73 | const float T_threshold
74 | );
75 |
76 |
77 | std::vector<torch::Tensor> composite_train_bw_cu(
78 | const torch::Tensor dL_dopacity,
79 | const torch::Tensor dL_ddepth,
80 | const torch::Tensor dL_drgb,
81 | const torch::Tensor dL_dws,
82 | const torch::Tensor sigmas,
83 | const torch::Tensor rgbs,
84 | const torch::Tensor ws,
85 | const torch::Tensor deltas,
86 | const torch::Tensor ts,
87 | const torch::Tensor rays_a,
88 | const torch::Tensor opacity,
89 | const torch::Tensor depth,
90 | const torch::Tensor rgb,
91 | const float T_threshold
92 | );
93 |
94 |
95 | void composite_test_fw_cu(
96 | const torch::Tensor sigmas,
97 | const torch::Tensor rgbs,
98 | const torch::Tensor deltas,
99 | const torch::Tensor ts,
100 | const torch::Tensor hits_t,
101 | const torch::Tensor alive_indices,
102 | const float T_threshold,
103 | const torch::Tensor N_eff_samples,
104 | torch::Tensor opacity,
105 | torch::Tensor depth,
106 | torch::Tensor rgb
107 | );
108 |
109 |
110 | std::vector<torch::Tensor> distortion_loss_fw_cu(
111 | const torch::Tensor ws,
112 | const torch::Tensor deltas,
113 | const torch::Tensor ts,
114 | const torch::Tensor rays_a
115 | );
116 |
117 |
118 | torch::Tensor distortion_loss_bw_cu(
119 | const torch::Tensor dL_dloss,
120 | const torch::Tensor ws_inclusive_scan,
121 | const torch::Tensor wts_inclusive_scan,
122 | const torch::Tensor ws,
123 | const torch::Tensor deltas,
124 | const torch::Tensor ts,
125 | const torch::Tensor rays_a
126 | );
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/csrc/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
5 |
6 |
7 | ROOT_DIR = osp.dirname(osp.abspath(__file__))
8 | include_dirs = [osp.join(ROOT_DIR, "include")]
9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h
10 |
11 | sources = glob.glob('*.cpp')+glob.glob('*.cu')
12 |
13 |
14 | setup(
15 | name='vren',
16 | version='2.0',
17 | author='kwea123',
18 | author_email='kwea123@gmail.com',
19 | description='cuda volume rendering library',
20 | long_description='cuda volume rendering library',
21 | ext_modules=[
22 | CUDAExtension(
23 | name='vren',
24 | sources=sources,
25 | include_dirs=include_dirs,
26 | extra_compile_args={'cxx': ['-O2'],
27 | 'nvcc': ['-O2']}
28 | )
29 | ],
30 | cmdclass={
31 | 'build_ext': BuildExtension
32 | }
33 | )
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/CudaExtensions/VolumeRenderingV2/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from pathlib import Path
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
5 |
6 |
7 | ROOT_DIR = Path(__file__).parent.absolute()
8 | include_dirs = [str(ROOT_DIR / 'csrc' / 'include')]
9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h
10 |
11 | sources = glob.glob(str(ROOT_DIR / 'csrc' / '*.cpp')) + glob.glob(str(ROOT_DIR / 'csrc' / '*.cu'))
12 |
13 |
14 | setup(
15 | name='VolumeRenderingV2',
16 | version='2.0',
17 | author='kwea123',
18 | author_email='kwea123@gmail.com',
19 | description='cuda volume rendering library',
20 | long_description='cuda volume rendering library',
21 | ext_modules=[
22 | CUDAExtension(
23 | name='VolumeRenderingV2',
24 | sources=sources,
25 | include_dirs=include_dirs,
26 | extra_compile_args={'cxx': ['-O2'],
27 | 'nvcc': ['-O2']}
28 | )
29 | ],
30 | cmdclass={
31 | 'build_ext': BuildExtension
32 | }
33 | )
34 |
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/Loss.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """InstantNGP/Loss.py: Loss calculation for InstantNGP."""
4 |
5 | import torch
6 | import torchmetrics
7 |
8 | from Cameras.utils import RayPropertySlice
9 | from Methods.InstantNGP import InstantNGPModel
10 | from Optim.Losses.Base import BaseLoss
11 |
12 |
13 | class InstantNGPLoss(BaseLoss):
14 | def __init__(self, model: 'InstantNGPModel') -> None:
15 | super().__init__()
16 | self.addLossMetric('MSE_Color', torch.nn.functional.mse_loss, 1.0)
17 | self.addLossMetric('Weight_Decay_MLP', model.weight_decay_mlp, 1.0e-6 / 2.0)
18 | self.addQualityMetric('PSNR', torchmetrics.functional.peak_signal_noise_ratio)
19 |
20 | @torch.amp.autocast('cuda', dtype=torch.float32)
21 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor, bg_color: torch.Tensor) -> torch.Tensor:
22 | color_gt = rays[:, RayPropertySlice.rgb] + (1.0 - rays[:, RayPropertySlice.alpha]) * bg_color
23 | return super().forward({
24 | 'MSE_Color': {'input': outputs['rgb'], 'target': color_gt},
25 | 'Weight_Decay_MLP': {},
26 | 'PSNR': {'preds': outputs['rgb'], 'target': color_gt, 'data_range': 1.0}
27 | })
28 |
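The ground-truth color is composited against the per-batch random background using the stored alpha before comparing it to the rendered color; a toy example of that compositing arithmetic (the values are made up):

    import torch

    rgb_no_bg = torch.tensor([[0.2, 0.2, 0.2]])   # training color with the original background removed
    alpha = torch.tensor([[0.5]])
    bg_color = torch.tensor([0.8, 0.0, 0.0])      # random background drawn per batch in the trainer
    color_gt = rgb_no_bg + (1.0 - alpha) * bg_color
    print(color_gt)  # tensor([[0.6000, 0.2000, 0.2000]])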
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/Model.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """InstantNGP/Model.py: InstantNGP Scene Model Implementation."""
4 |
5 | import numpy as np
6 | import torch
7 | from kornia.utils.grid import create_meshgrid3d
8 |
9 | import Framework
10 | from Methods.Base.Model import BaseModel
11 | from Methods.InstantNGP.utils import next_multiple
12 | from Thirdparty import TinyCudaNN as tcnn
13 |
14 |
15 | @Framework.Configurable.configure(
16 | SCALE=0.5,
17 | RESOLUTION=128,
18 | CENTER=[0.0, 0.0, 0.0],
19 | HASHMAP_NUM_LEVELS=16,
20 | HASHMAP_NUM_FEATURES_PER_LEVEL=2,
21 | HASHMAP_LOG2_SIZE=19,
22 | HASHMAP_BASE_RESOLUTION=16,
23 | HASHMAP_TARGET_RESOLUTION=2048,
24 | NUM_DENSITY_OUTPUT_FEATURES=16,
25 | NUM_DENSITY_NEURONS=64,
26 | NUM_DENSITY_LAYERS=1,
27 | DIR_SH_ENCODING_DEGREE=4,
28 | NUM_COLOR_NEURONS=64,
29 | NUM_COLOR_LAYERS=2,
30 | )
31 | class InstantNGPModel(BaseModel):
32 | """Defines InstantNGP data model"""
33 |
34 | def __init__(self, name: str = None) -> None:
35 | super().__init__(name)
36 |
37 | def __del__(self) -> None:
38 | torch.cuda.empty_cache()
39 | tcnn.free_temporary_memory()
40 |
41 | def weight_decay_mlp(self) -> torch.Tensor:
42 | """Calculates the weight decay for the MLPs."""
43 | loss = self.encoding_xyz.params[:self.n_params_encoding_mlp].pow(2).sum()
44 | loss += self.color_mlp_with_encoding.params.pow(2).sum()
45 | # loss /= 2 # fused into loss lambda
46 | loss /= self.n_mlp_params
47 | return loss
48 |
49 | def build(self) -> 'InstantNGPModel':
50 | """Builds the model."""
51 | # createGrid
52 | self.center = torch.Tensor([self.CENTER])
53 | self.xyz_min = -torch.ones(1, 3) * self.SCALE
54 | self.xyz_max = torch.ones(1, 3) * self.SCALE
55 | self.xyz_size = self.xyz_max - self.xyz_min
56 | self.half_size = (self.xyz_max - self.xyz_min) / 2
57 | self.cascades = max(1 + int(np.ceil(np.log2(2 * self.SCALE))), 1)
58 | self.register_buffer('density_grid', torch.zeros(self.cascades, self.RESOLUTION ** 3))
59 | self.register_buffer('grid_coords', create_meshgrid3d(self.RESOLUTION, self.RESOLUTION, self.RESOLUTION, False, dtype=torch.int32, device=self.density_grid.device).reshape(-1, 3))
60 | self.register_buffer('density_bitfield', torch.zeros(self.cascades*self.RESOLUTION ** 3 // 8, dtype=torch.uint8))
61 | # create encodings and networks
62 | self.encoding_xyz = tcnn.NetworkWithInputEncoding(
63 | n_input_dims=3,
64 | n_output_dims=self.NUM_DENSITY_OUTPUT_FEATURES,
65 | encoding_config={
66 | 'otype': 'Grid',
67 | 'type': 'Hash',
68 | 'n_levels': self.HASHMAP_NUM_LEVELS,
69 | 'n_features_per_level': self.HASHMAP_NUM_FEATURES_PER_LEVEL,
70 | 'log2_hashmap_size': self.HASHMAP_LOG2_SIZE,
71 | 'base_resolution': self.HASHMAP_BASE_RESOLUTION,
72 | 'per_level_scale': np.exp(np.log(self.HASHMAP_TARGET_RESOLUTION * (2 * self.SCALE) / self.HASHMAP_BASE_RESOLUTION) / (self.HASHMAP_NUM_LEVELS - 1)),
73 | 'interpolation': 'Linear'
74 | },
75 | network_config={
76 | 'otype': 'FullyFusedMLP',
77 | 'activation': 'ReLU',
78 | 'output_activation': 'None',
79 | 'n_neurons': self.NUM_DENSITY_NEURONS,
80 | 'n_hidden_layers': self.NUM_DENSITY_LAYERS,
81 | },
82 | seed=Framework.config.GLOBAL.RANDOM_SEED
83 | )
84 | # calculate number of parameters for the MLP part of the encoding for weight decay
85 | n_params_mlp = 0
86 | n_inputs = next_multiple(self.HASHMAP_NUM_FEATURES_PER_LEVEL * self.HASHMAP_NUM_LEVELS, 16)
87 | for i in range(self.NUM_DENSITY_LAYERS):
88 | n_params_layer = self.NUM_DENSITY_NEURONS * n_inputs
89 | n_params_layer = next_multiple(n_params_layer, 16)
90 | n_params_mlp += n_params_layer
91 | n_inputs = self.NUM_DENSITY_NEURONS
92 | n_params_mlp += next_multiple(self.encoding_xyz.n_output_dims, 16) * n_inputs
93 | self.n_params_encoding_mlp = n_params_mlp
94 | self.color_mlp_with_encoding = tcnn.NetworkWithInputEncoding(
95 | n_input_dims=3 + self.encoding_xyz.n_output_dims,
96 | n_output_dims=3,
97 | encoding_config={
98 | 'otype': 'Composite',
99 | 'nested': [
100 | {
101 | 'n_dims_to_encode': 3,
102 | 'otype': 'SphericalHarmonics',
103 | 'degree': self.DIR_SH_ENCODING_DEGREE,
104 | },
105 | {
106 | "otype": "Identity"
107 | }
108 | ]
109 |
110 | },
111 | network_config={
112 | 'otype': 'FullyFusedMLP',
113 | 'activation': 'ReLU',
114 | 'output_activation': 'Sigmoid',
115 | 'n_neurons': self.NUM_COLOR_NEURONS,
116 | 'n_hidden_layers': self.NUM_COLOR_LAYERS,
117 | },
118 | seed=Framework.config.GLOBAL.RANDOM_SEED
119 | )
120 | self.n_mlp_params = len(self.color_mlp_with_encoding.params) + self.n_params_encoding_mlp
121 | return self
122 |
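Two derived quantities of the default configuration above, reproduced as a standalone sketch (only numpy is needed; next_multiple is re-declared here so the snippet does not depend on the framework):

    import numpy as np

    SCALE, LEVELS, BASE_RES, TARGET_RES = 0.5, 16, 16, 2048
    per_level_scale = np.exp(np.log(TARGET_RES * (2 * SCALE) / BASE_RES) / (LEVELS - 1))
    print(round(float(per_level_scale), 4))  # 1.3819 -> each hash-grid level is ~1.38x finer than the previous one

    def next_multiple(value, multiple):
        return int(((value + multiple - 1) // multiple) * multiple)

    # parameter count of the density MLP (padded to multiples of 16, as tiny-cuda-nn does internally)
    n_inputs = next_multiple(2 * 16, 16)          # 32 encoded hash-grid features
    n_params = next_multiple(64 * n_inputs, 16)   # hidden layer: 64 neurons x 32 inputs
    n_params += next_multiple(16, 16) * 64        # output layer: 16 padded outputs x 64 neurons
    print(n_params)                               # 3072 parameters subject to weight decay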
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/Trainer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """InstantNGP/Trainer.py: Implementation of the trainer for the InstantNGP method."""
3 |
4 | import torch
5 |
6 | import Framework
7 | from Logging import Logger
8 | from Cameras.NDC import NDCCamera
9 | from Datasets.Base import BaseDataset
10 | from Methods.Base.GuiTrainer import GuiTrainer
11 | from Methods.Base.utils import preTrainingCallback, trainingCallback
12 | from Methods.InstantNGP.Loss import InstantNGPLoss
13 | from Methods.InstantNGP.utils import next_multiple, logOccupancyGrids
14 | from Optim.Samplers.DatasetSamplers import RayPoolSampler
15 | from Optim.Samplers.ImageSamplers import RandomImageSampler
16 |
17 |
18 | @Framework.Configurable.configure(
19 | NUM_ITERATIONS=50000,
20 | TARGET_BATCH_SIZE=262144,
21 | WARMUP_STEPS=256,
22 | DENSITY_GRID_UPDATE_INTERVAL=16,
23 | LEARNING_RATE=1.0e-2,
24 | LEARNING_RATE_DECAY_START=20000,
25 | LEARNING_RATE_DECAY_INTERVAL=10000,
26 | LEARNING_RATE_DECAY_BASE=0.33,
27 | ADAM_EPS=1e-15,
28 | USE_APEX=False,
29 | WANDB=Framework.ConfigParameterList(
30 | RENDER_OCCUPANCY_GRIDS=False,
31 | )
32 | )
33 | class InstantNGPTrainer(GuiTrainer):
34 | """Defines the trainer for the InstantNGP method."""
35 |
36 | def __init__(self, **kwargs) -> None:
37 | super(InstantNGPTrainer, self).__init__(**kwargs)
38 | self.loss = InstantNGPLoss(self.model)
39 | try:
40 | if self.USE_APEX:
41 | from Thirdparty.Apex import FusedAdam # slightly faster than the PyTorch implementation
42 | self.optimizer = FusedAdam(self.model.parameters(), lr=self.LEARNING_RATE, eps=self.ADAM_EPS, betas=(0.9, 0.99), adam_w_mode=False)
43 | else:
44 | raise Exception
45 | except Exception:
46 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.LEARNING_RATE, eps=self.ADAM_EPS, betas=(0.9, 0.99), fused=True)
47 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
48 | optimizer=self.optimizer,
49 | milestones=[iteration for iteration in range(
50 | self.LEARNING_RATE_DECAY_START,
51 | self.NUM_ITERATIONS,
52 | self.LEARNING_RATE_DECAY_INTERVAL
53 | )],
54 | gamma=self.LEARNING_RATE_DECAY_BASE
55 | )
56 | self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=128.0, growth_interval=self.NUM_ITERATIONS + 1)
57 | self.rays_per_batch = 2 ** 12
58 | self.measured_batch_size = 0
59 | self.sampler_train = None
60 | self.sampler_val = None
61 |
62 | @preTrainingCallback(priority=10000)
63 | @torch.no_grad()
64 | def removeBackgroundColor(self, _, dataset: 'BaseDataset') -> None:
65 | if (dataset.camera.background_color == 0.0).all():
66 | return
67 | # remove background color from training samples to allow training with random background colors
68 | for cam_properties in Logger.logProgressBar(dataset.data['train'], desc='Removing background color', leave=False):
69 | if cam_properties.alpha is not None:
70 | cam_properties.rgb.sub_((1.0 - cam_properties.alpha) * dataset.camera.background_color[:, None, None]).clamp_(0.0, 1.0)
71 | # recompute rays if necessary
72 | if dataset.PRECOMPUTE_RAYS:
73 | dataset.ray_collection['train'] = None
74 | dataset.precomputeRays(['train'])
75 |
76 | @preTrainingCallback(priority=1000)
77 | @torch.no_grad()
78 | def initSampler(self, _, dataset: 'BaseDataset') -> None:
79 | self.sampler_train = RayPoolSampler(dataset=dataset.train(), img_sampler_cls=RandomImageSampler)
80 | if self.RUN_VALIDATION:
81 | self.sampler_val = RayPoolSampler(dataset=dataset.eval(), img_sampler_cls=RandomImageSampler)
82 |
83 | @preTrainingCallback(priority=100)
84 | @torch.no_grad()
85 | def carveDensityGrid(self, _, dataset: 'BaseDataset') -> None:
86 | if not isinstance(dataset.camera, NDCCamera):
87 | self.renderer.carveDensityGrid(dataset.train(), subtractive=False, use_alpha=False)
88 |
89 | @trainingCallback(priority=1000, iteration_stride='DENSITY_GRID_UPDATE_INTERVAL')
90 | @torch.no_grad()
91 | def updateDensityGrid(self, iteration: int, _) -> None:
92 | self.renderer.updateDensityGrid(warmup=iteration < self.WARMUP_STEPS)
93 |
94 | @trainingCallback(priority=999, start_iteration='DENSITY_GRID_UPDATE_INTERVAL', iteration_stride='DENSITY_GRID_UPDATE_INTERVAL')
95 | @torch.no_grad()
96 | def updateBatchSize(self, *_) -> None:
97 | self.measured_batch_size /= self.DENSITY_GRID_UPDATE_INTERVAL
98 | self.rays_per_batch = min(next_multiple(self.rays_per_batch * self.TARGET_BATCH_SIZE / self.measured_batch_size, 256), self.TARGET_BATCH_SIZE)
99 | self.measured_batch_size = 0
100 |
101 | @trainingCallback(priority=50)
102 | def processTrainingSample(self, _, dataset: 'BaseDataset') -> None:
103 | """Performs a single training step."""
104 | # prepare training iteration
105 | self.model.train()
106 | self.loss.train()
107 | dataset.train()
108 | # sample ray batch
109 | ray_batch: torch.Tensor = self.sampler_train.get(dataset=dataset, ray_batch_size=self.rays_per_batch)['ray_batch']
110 | with torch.amp.autocast('cuda', enabled=True):
111 | # render and update
112 | bg_color = torch.rand(3)
113 | output = self.renderer.renderRays(
114 | rays=ray_batch,
115 | camera=dataset.camera,
116 | custom_bg_color=bg_color,
117 | train_mode=True)
118 | loss = self.loss(output, ray_batch, bg_color)
119 | self.optimizer.zero_grad()
120 | self.grad_scaler.scale(loss).backward()
121 | self.grad_scaler.step(self.optimizer)
122 | self.grad_scaler.update()
123 | self.lr_scheduler.step()
124 | self.measured_batch_size += output['rm_samples'].item()
125 |
126 | @trainingCallback(active='RUN_VALIDATION', priority=100)
127 | @torch.no_grad()
128 | def processValidationSample(self, _, dataset: 'BaseDataset') -> None:
129 | """Performs a single validation step."""
130 | self.model.eval()
131 | self.loss.eval()
132 | dataset.eval()
133 | # sample ray batch
134 | ray_batch: torch.Tensor = self.sampler_val.get(dataset=dataset, ray_batch_size=self.rays_per_batch)['ray_batch']
135 | with torch.amp.autocast('cuda', enabled=True):
136 | output = self.renderer.renderRays(
137 | rays=ray_batch,
138 | camera=dataset.camera,
139 | train_mode=True)
140 | self.loss(output, ray_batch, None)
141 |
142 | @trainingCallback(active='WANDB.ACTIVATE', priority=500, iteration_stride='WANDB.INTERVAL')
143 | @torch.no_grad()
144 | def logWandB(self, iteration: int, dataset: 'BaseDataset') -> None:
145 | """Logs all losses and visualizes training and validation samples using Weights & Biases."""
146 | super().logWandB(iteration, dataset)
147 | # visualize scene and occupancy grid as point clouds
148 | if self.WANDB.RENDER_OCCUPANCY_GRIDS:
149 | logOccupancyGrids(self.renderer, iteration, dataset, 'occupancy grid')
150 | # commit current step
151 | Framework.wandb.log(data={}, commit=True)
152 |
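A worked example of the dynamic batch-size update performed in updateBatchSize, using the defaults above and an illustrative measured sample count:

    def next_multiple(value, multiple):
        return int(((value + multiple - 1) // multiple) * multiple)

    TARGET_BATCH_SIZE = 262144
    rays_per_batch = 2 ** 12                   # 4096, the initial value
    measured_batch_size = 524288 / 16          # samples accumulated over 16 iterations, then averaged
    # each ray currently yields ~8 ray-marching samples, so ~8x more rays are needed to hit the target
    rays_per_batch = min(next_multiple(rays_per_batch * TARGET_BATCH_SIZE / measured_batch_size, 256), TARGET_BATCH_SIZE)
    print(rays_per_batch)  # 32768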
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/__init__.py:
--------------------------------------------------------------------------------
1 | from Methods.InstantNGP.Model import InstantNGPModel
2 | from Methods.InstantNGP.Renderer import InstantNGPRenderer
3 | from Methods.InstantNGP.Trainer import InstantNGPTrainer
4 |
5 | MODEL = InstantNGPModel
6 | RENDERER = InstantNGPRenderer
7 | TRAINING_INSTANCE = InstantNGPTrainer
8 |
--------------------------------------------------------------------------------
/src/Methods/InstantNGP/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | InstantNGP/utils.py: Utility functions for InstantNGP.
5 | """
6 |
7 | import numpy as np
8 | import torch
9 |
10 | import Framework
11 | from Datasets.Base import BaseDataset
12 |
13 |
14 | def next_multiple(value: int | float, multiple: int) -> int:
15 | return int(((value + multiple - 1) // multiple) * multiple)
16 |
17 |
18 | @torch.no_grad()
19 | def logOccupancyGrids(renderer, iteration: int, dataset: 'BaseDataset', log_string: str = 'occupancy grid') -> None:
20 | """visualize occupancy grid as pointcloud in wandb."""
21 | dataset.train()
22 | cameras: list[dict] = []
23 | for i in range(len(dataset)):
24 | dataset.camera.setProperties(dataset[i])
25 | data = dataset.camera.getPositionAndViewdir().cpu().numpy()
26 | # data[:, 2] *= -1
27 | cameras.append({"start": data[0].tolist(), "end": (data[0] + (0.1 * data[1])).tolist()})
28 | cameras: np.ndarray = np.array(cameras)
29 | # gather boxes and active cells
30 | center = np.array([renderer.model.CENTER])
31 | unit_cube_points = np.array([
32 | [-1, -1, -1],
33 | [-1, +1, -1],
34 | [-1, -1, +1],
35 | [+1, -1, -1],
36 | [+1, +1, -1],
37 | [-1, +1, +1],
38 | [+1, -1, +1],
39 | [+1, +1, +1]
40 | ])
41 | cells = renderer.getAllCells()
42 | boxes = []
43 | active_cells = [(unit_cube_points * renderer.model.SCALE * 1.1) + center]
44 | thresh = min(renderer.model.density_grid[renderer.model.density_grid > 0].mean().item(), renderer.density_threshold)
45 | for c in range(renderer.model.cascades):
46 | indices, coords = cells[c]
47 | s = min(2 ** (c - 1), renderer.model.SCALE)
48 | boxes.append(
49 | {
50 | "corners": ((unit_cube_points * s) + center).tolist(),
51 | "label": f"Cascade {c + 1}",
52 | "color": [255, 0, 0],
53 | }
54 | )
55 | half_grid_size = s / renderer.model.RESOLUTION
56 | points = (coords / (renderer.model.RESOLUTION - 1) * 2 - 1) * (s - half_grid_size)
57 | active_cells.append((points[renderer.model.density_grid[c, indices] > thresh]).cpu().numpy() + center)
58 | active_cells = np.concatenate(active_cells, axis=0)
59 | while active_cells.shape[0] > 200000:
60 | active_cells = active_cells[::2]
61 | scene = Framework.wandb.Object3D({
62 | "type": "lidar/beta",
63 | "points": active_cells,
64 | "boxes": np.array(boxes),
65 | "vectors": cameras
66 | })
67 | Framework.wandb.log(
68 | data={log_string: scene},
69 | step=iteration
70 | )
71 |
--------------------------------------------------------------------------------
/src/Methods/NeRF/Loss.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """NeRF/Loss.py: Loss implementation for the NeRF method."""
4 |
5 | import torch
6 | import torchmetrics
7 |
8 | from Cameras.utils import RayPropertySlice
9 | from Optim.Losses.Base import BaseLoss
10 |
11 |
12 | class NeRFLoss(BaseLoss):
13 | """Defines a class for all sub-losses of the NeRF method."""
14 |
15 | def __init__(self, lambda_color: float, lambda_alpha: float) -> None:
16 | super().__init__()
17 | self.addLossMetric('L2_Color', torch.nn.functional.mse_loss, lambda_color)
18 | self.addLossMetric('L2_Alpha', torch.nn.functional.mse_loss, lambda_alpha)
19 | self.addQualityMetric('PSNR', torchmetrics.functional.image.peak_signal_noise_ratio)
20 |
21 | def forward(self, outputs: dict[str, torch.Tensor | None], rays: torch.Tensor) -> torch.Tensor:
22 | """Defines loss calculation."""
23 | return super().forward({
24 | 'L2_Color': {'input': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb]},
25 | 'L2_Alpha': {'input': outputs['alpha'], 'target': rays[:, RayPropertySlice.alpha]},
26 | 'PSNR': {'preds': outputs['rgb'], 'target': rays[:, RayPropertySlice.rgb], 'data_range': 1.0},
27 | })
28 |
--------------------------------------------------------------------------------
/src/Methods/NeRF/Model.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """NeRF/Model.py: Implementation of the neural model for the vanilla (i.e. original) NeRF method."""
4 |
5 | import torch
6 |
7 | import Framework
8 | from Methods.Base.Model import BaseModel
9 | from Methods.NeRF.utils import getActivationFunction, FrequencyEncoding
10 |
11 |
12 | class NeRFBlock(torch.nn.Module):
13 | """Defines a NeRF block (input: position, direction -> output: density, color)."""
14 |
15 | def __init__(self, num_layers: int, num_color_layers: int, num_features: int,
16 | encoding_length_position: int, encoding_length_direction: int, encoding_append_input: bool,
17 | input_skips: list[int], activation_function: str) -> None:
18 | super(NeRFBlock, self).__init__()
19 | # set parameters
20 | self.input_skips = input_skips # layer indices after which input is appended
21 | # get activation function (type, parameters, initial bias for last density layer)
22 | af_class, af_parameters, af_bias = getActivationFunction(activation_function)
23 | # embedding layers
24 | self.embedding_position = FrequencyEncoding(encoding_length_position, encoding_append_input)
25 | self.embedding_direction = FrequencyEncoding(encoding_length_direction, encoding_append_input)
26 | input_size_position = self.embedding_position.getOutputSize(3)
27 | input_size_direction = self.embedding_direction.getOutputSize(3)
28 | # initial linear layers
29 | self.initial_layers = [
30 | torch.nn.Sequential(torch.nn.Linear(input_size_position, num_features, bias=True), af_class(*af_parameters))
31 | ]
32 | for layer_index in range(1, num_layers):
33 | self.initial_layers.append(torch.nn.Sequential(
34 | torch.nn.Linear(num_features if layer_index not in input_skips else num_features + input_size_position,
35 | num_features, bias=True), af_class(*af_parameters)
36 | ))
37 | self.initial_layers = torch.nn.ModuleList(self.initial_layers)
38 | # intermediate feature and density layers
39 | self.feature_layer = torch.nn.Linear(num_features, num_features, bias=True)
40 | self.density_layer = torch.nn.Linear(num_features, 1, bias=True)
41 | self.density_activation = af_class(*af_parameters)
42 | # final color layer
43 | self.color_layers = torch.nn.Sequential(*(
44 | [torch.nn.Linear(num_features + input_size_direction, num_features // 2, bias=True), af_class(*af_parameters)]
45 | + [torch.nn.Sequential(torch.nn.Linear(num_features // 2, num_features // 2, bias=True), af_class(*af_parameters))
46 | for _ in range(num_color_layers - 1)]
47 | + [torch.nn.Linear(num_features // 2, 3, bias=True), torch.nn.Sigmoid()]
48 | ))
49 | # initialize bias for density layer activation function (for better convergence, copied from pytorch3d examples)
50 | if af_bias is not None:
51 | self.density_layer.bias.data[0] = af_bias
52 |
53 | def forward(self, positions: torch.Tensor, directions: torch.Tensor,
54 | return_rgb: bool = False, random_noise_densities: float = 0.0) -> tuple[torch.Tensor, torch.Tensor]:
55 | # transform inputs to higher dimensional space
56 | positions_embedded: torch.Tensor = self.embedding_position(positions)
57 | # run initial layers
58 | x = positions_embedded
59 | for index, layer in enumerate(self.initial_layers):
60 | x = layer(x)
61 | if index + 1 in self.input_skips:
62 | x = torch.cat((x, positions_embedded), dim=-1)
63 | # extract density, add random noise before activation function
64 | density: torch.Tensor = self.density_layer(x)
65 | density = self.density_activation(density + (torch.randn(density.shape) * random_noise_densities))
66 | # extract features, append view_directions and extract color
67 | color: torch.Tensor | None = None
68 | if return_rgb:
69 | directions_embedded: torch.Tensor = self.embedding_direction(directions)
70 | features: torch.Tensor = self.feature_layer(x)
71 | features = torch.cat((features, directions_embedded), dim=-1)
72 | color = self.color_layers(features)
73 | return density, color
74 |
75 |
76 | @Framework.Configurable.configure(
77 | NUM_LAYERS=8,
78 | NUM_COLOR_LAYERS=1,
79 | NUM_FEATURES=256,
80 | ENCODING_LENGTH_POSITIONS=10,
81 | ENCODING_LENGTH_DIRECTIONS=4,
82 | ENCODING_APPEND_INPUT=True,
83 | INPUT_SKIPS=[5],
84 | ACTIVATION_FUNCTION='relu'
85 | )
86 | class NeRF(BaseModel):
87 | """Defines a plain NeRF with a single MLP."""
88 |
89 | def __init__(self, name: str = None) -> None:
90 | super(NeRF, self).__init__(name)
91 |
92 | def build(self) -> 'NeRF':
93 | """Builds the model."""
94 | self.net = NeRFBlock(
95 | self.NUM_LAYERS, self.NUM_COLOR_LAYERS, self.NUM_FEATURES,
96 | self.ENCODING_LENGTH_POSITIONS, self.ENCODING_LENGTH_DIRECTIONS, self.ENCODING_APPEND_INPUT,
97 | self.INPUT_SKIPS, self.ACTIVATION_FUNCTION
98 | )
99 | return self
100 |
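A standalone forward-pass sketch for NeRFBlock with the default configuration listed above (CPU tensors, random inputs; it assumes the Methods modules import without a fully initialized framework configuration, and inside the framework the block is driven by the renderer instead):

    import torch
    from Methods.NeRF.Model import NeRFBlock

    block = NeRFBlock(
        num_layers=8, num_color_layers=1, num_features=256,
        encoding_length_position=10, encoding_length_direction=4, encoding_append_input=True,
        input_skips=[5], activation_function='relu'
    )
    positions = torch.rand(1024, 3)
    directions = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
    density, color = block(positions, directions, return_rgb=True)
    print(density.shape, color.shape)  # torch.Size([1024, 1]) torch.Size([1024, 3])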
--------------------------------------------------------------------------------
/src/Methods/NeRF/Renderer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | NeRF/Renderer.py: Implementation of the renderer for the vanilla (i.e. original) NeRF.
5 | Borrows heavily from the PyTorch NeRF reimplementation of Yenchen Lin
6 | Source: https://github.com/yenchenlin/nerf-pytorch/
7 | """
8 |
9 | import torch
10 | from torch import Tensor
11 |
12 | import Framework
13 | from Cameras.utils import RayPropertySlice
14 | from Methods.Base.Renderer import BaseRenderingComponent, BaseRenderer
15 | from Methods.NeRF.Model import NeRF, NeRFBlock
16 | from Methods.NeRF.utils import generateSamples, integrateRaySamples
17 | from Methods.Base.Renderer import BaseModel
18 | from Cameras.Base import BaseCamera
19 |
20 |
21 | class NeRFRayRenderingComponent(BaseRenderingComponent):
22 | """Defines a NeRF ray rendering component used to access the NeRF model."""
23 |
24 | def __init__(self, scene_function: 'NeRFBlock') -> None:
25 | super().__init__()
26 | self.scene_function = scene_function
27 |
28 | def forward(self, rays: Tensor, camera: 'BaseCamera',
29 | ray_batch_size: int, num_samples: int, return_samples: bool, randomize_samples: bool,
30 | random_noise_densities: float) -> dict[str, Tensor | None]:
31 | """Generates samples from the given rays and queries the NeRF model to produce the desired outputs."""
32 | outputs = {'rgb': [], 'depth': [], 'depth_samples': [], 'alpha_weights': [], 'alpha': []}
33 | # split rays into chunks that fit into VRAM
34 | ray_batches: list[Tensor] = torch.split(rays, ray_batch_size, dim=0)
35 | background_color: Tensor = camera.background_color.to(rays.device)
36 | for ray_batch in ray_batches:
37 | depth_samples = generateSamples(
38 | ray_batch, num_samples, camera.near_plane, camera.far_plane, randomize_samples
39 | )
40 | positions: Tensor = ray_batch[:, None, RayPropertySlice.origin] + (
41 | ray_batch[:, None, RayPropertySlice.direction] * depth_samples[:, :, None])
42 | directions: Tensor = ray_batch[:, None, RayPropertySlice.view_direction].expand(positions.shape)
43 | densities, rgb = self.scene_function(
44 | positions.reshape(-1, 3), directions.reshape(-1, 3),
45 | return_rgb=True, random_noise_densities=random_noise_densities
46 | )
47 | final_rgb, final_depth, final_alpha, final_alpha_weights = integrateRaySamples(
48 | depth_samples, ray_batch[:, RayPropertySlice.direction],
49 | densities.reshape(-1, num_samples), rgb.reshape(-1, num_samples, 3), background_color
50 | )
51 | # append outputs
52 | outputs['rgb'].append(final_rgb)
53 | outputs['depth'].append(final_depth)
54 | outputs['alpha'].append(final_alpha)
55 | if return_samples:
56 | outputs['depth_samples'].append(depth_samples)
57 | outputs['alpha_weights'].append(final_alpha_weights)
58 | # concat ray batches
59 | for key in outputs:
60 | outputs[key] = (
61 | torch.cat(outputs[key], dim=0) if len(ray_batches) > 1 else outputs[key][0]
62 | ) if outputs[key] else None
63 | return outputs
64 |
65 |
66 | @Framework.Configurable.configure(
67 | RAY_BATCH_SIZE=8192,
68 | NUM_SAMPLES=256
69 | )
70 | class NeRFRenderer(BaseRenderer):
71 | """Defines the renderer for the vanilla (i.e. original) NeRF method."""
72 |
73 | def __init__(self, model: 'BaseModel') -> None:
74 | super().__init__(model, [NeRF])
75 | self.ray_rendering_component = NeRFRayRenderingComponent.get(self.model.net)
76 |
77 | def renderRays(self, rays: Tensor, camera: 'BaseCamera',
78 | return_samples: bool = False, randomize_samples: bool = False,
79 | random_noise_densities: float = 0.0) -> dict[str, Tensor | None]:
80 | """Renders the given set of rays using the renderer's rendering component."""
81 | return self.ray_rendering_component(
82 | rays, camera,
83 | self.RAY_BATCH_SIZE, self.NUM_SAMPLES,
84 | return_samples, randomize_samples, random_noise_densities
85 | )
86 |
87 | def renderImage(self, camera: 'BaseCamera', to_chw: bool = False, benchmark: bool = False) -> dict[str, Tensor | None]:
88 | """Renders a complete image using the given camera."""
89 | rays: Tensor = camera.generateRays()
90 | rendered_rays = self.renderRays(
91 | rays, camera,
92 | return_samples=False, randomize_samples=False, random_noise_densities=0.0
93 | )
94 | # reshape rays to images
95 | for key in rendered_rays:
96 | if rendered_rays[key] is not None:
97 | rendered_rays[key] = rendered_rays[key].reshape(camera.properties.height, camera.properties.width, -1)
98 | if to_chw:
99 | rendered_rays[key] = rendered_rays[key].permute((2, 0, 1))
100 | return rendered_rays
101 |
--------------------------------------------------------------------------------
/src/Methods/NeRF/Trainer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """NeRF/Trainer.py: Implementation of the trainer for the vanilla (i.e. original) NeRF method."""
4 |
5 | import torch
6 |
7 | import Framework
8 | from Datasets.Base import BaseDataset
9 | from Methods.Base.Trainer import BaseTrainer
10 | from Methods.Base.utils import preTrainingCallback, trainingCallback
11 | from Methods.NeRF.Loss import NeRFLoss
12 | from Optim.Samplers.DatasetSamplers import DatasetSampler, RayPoolSampler
13 | from Optim.Samplers.ImageSamplers import RandomImageSampler
14 |
15 |
16 | @Framework.Configurable.configure(
17 | NUM_ITERATIONS=500000,
18 | BATCH_SIZE=1024,
19 | SAMPLE_SINGLE_IMAGE=True,
20 | DENSITY_RANDOM_NOISE_STD=0.0,
21 | ADAM_BETA_1=0.9,
22 | ADAM_BETA_2=0.999,
23 | LEARNINGRATE=5.0e-04,
24 | LEARNINGRATE_DECAY_RATE=0.1,
25 | LEARNINGRATE_DECAY_STEPS=500000,
26 | LAMBDA_COLOR_LOSS=1.0,
27 | LAMBDA_ALPHA_LOSS=0.0,
28 | )
29 | class NeRFTrainer(BaseTrainer):
30 | """Defines the trainer for the vanilla (i.e. original) NeRF method."""
31 |
32 | def __init__(self, **kwargs) -> None:
33 | super(NeRFTrainer, self).__init__(**kwargs)
34 | self.optimizer = torch.optim.Adam(
35 | self.model.parameters(),
36 | lr=self.LEARNINGRATE, betas=(self.ADAM_BETA_1, self.ADAM_BETA_2),
37 | eps=1e-8
38 | )
39 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
40 | self.optimizer,
41 | lr_lambda=self.LRDecayPolicy(self.LEARNINGRATE_DECAY_RATE, self.LEARNINGRATE_DECAY_STEPS),
42 | last_epoch=self.model.num_iterations_trained - 1
43 | )
44 | self.loss = NeRFLoss(self.LAMBDA_COLOR_LOSS, self.LAMBDA_ALPHA_LOSS)
45 |
46 | class LRDecayPolicy(object):
47 | """Defines a decay policy for the learning rate."""
48 |
49 | def __init__(self, ldr: float, lds: float) -> None:
50 | self.ldr: float = ldr
51 | self.lds: float = lds
52 |
53 | def __call__(self, iteration) -> float:
54 | """Calculates learning rate decay."""
55 | return self.ldr ** (iteration / self.lds)
56 |
57 | @preTrainingCallback(priority=1000)
58 | @torch.no_grad()
59 | def initSampler(self, _, dataset: 'BaseDataset') -> None:
60 | sampler_cls = DatasetSampler if self.SAMPLE_SINGLE_IMAGE else RayPoolSampler
61 | self.sampler_train = sampler_cls(dataset=dataset.train(), random=True, img_sampler_cls=RandomImageSampler)
62 | if self.RUN_VALIDATION:
63 | self.sampler_val = sampler_cls(dataset=dataset.eval(), random=True, img_sampler_cls=RandomImageSampler)
64 |
65 | @trainingCallback(priority=50)
66 | def processTrainingSample(self, iteration: int, dataset: 'BaseDataset') -> None:
67 | """Defines a callback which is executed every iteration to process a training sample."""
68 | # set modes
69 | self.model.train()
70 | self.loss.train()
71 | dataset.train()
72 | # sample ray batch
73 | ray_batch: torch.Tensor = self.sampler_train.get(dataset=dataset, ray_batch_size=self.BATCH_SIZE)['ray_batch']
74 | # update model
75 | self.optimizer.zero_grad()
76 | outputs = self.renderer.renderRays(
77 | rays=ray_batch,
78 | camera=dataset.camera,
79 | return_samples=False,
80 | randomize_samples=True,
81 | random_noise_densities=self.DENSITY_RANDOM_NOISE_STD
82 | )
83 | loss: torch.Tensor = self.loss(outputs, ray_batch)
84 | loss.backward()
85 | self.optimizer.step()
86 | # update learning rate
87 | self.lr_scheduler.step()
88 |
89 | @trainingCallback(priority=100, active='RUN_VALIDATION')
90 | @torch.no_grad()
91 | def processValidationSample(self, iteration: int, dataset: 'BaseDataset') -> None:
92 | """Defines a callback which is executed every iteration to process a validation sample."""
93 | self.model.eval()
94 | self.loss.eval()
95 | dataset.eval()
96 | ray_batch: torch.Tensor = self.sampler_val.get(dataset=dataset, ray_batch_size=self.BATCH_SIZE)['ray_batch']
97 | outputs = self.renderer.renderRays(
98 | rays=ray_batch,
99 | camera=dataset.camera,
100 | return_samples=False,
101 | randomize_samples=False,
102 | random_noise_densities=0.0
103 | )
104 | self.loss(outputs, ray_batch)
105 |
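LRDecayPolicy implements a smooth exponential decay, scaling the learning rate by ldr ** (iteration / lds); with the defaults above (decay rate 0.1 over 500000 steps) the multiplier evolves like this:

    ldr, lds = 0.1, 500000
    for iteration in (0, 250000, 500000):
        print(iteration, ldr ** (iteration / lds))
    # 0       1.000  -> full learning rate (5.0e-04)
    # 250000  0.316  -> sqrt(0.1) after half the schedule (~1.6e-04)
    # 500000  0.100  -> one decade lower at the end (5.0e-05)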
--------------------------------------------------------------------------------
/src/Methods/NeRF/__init__.py:
--------------------------------------------------------------------------------
1 | from Methods.NeRF.Model import NeRF
2 | from Methods.NeRF.Renderer import NeRFRenderer
3 | from Methods.NeRF.Trainer import NeRFTrainer
4 |
5 | MODEL = NeRF
6 | RENDERER = NeRFRenderer
7 | TRAINING_INSTANCE = NeRFTrainer
8 |
--------------------------------------------------------------------------------
/src/Methods/NeRF/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | NeRF/utils.py: Contains utility functions used for the implementation of the NeRF method.
5 | """
6 | import torch
7 | from torch import Tensor
8 |
9 | import Framework
10 | from Logging import Logger
11 |
12 |
13 | class FrequencyEncoding(torch.nn.Module):
14 | """Defines a network layer that performs frequency encoding with linear coefficients."""
15 |
16 | def __init__(self, encoding_length: int, append_input: bool):
17 | super().__init__()
18 | # calculate frequencies
19 | self.register_buffer('frequency_factors', (
20 | 2 ** torch.linspace(start=0.0, end=encoding_length - 1.0, steps=encoding_length)
21 | )[None, None, :] # * math.pi
22 | )
23 | self.append_input: bool = append_input
24 |
25 | def getOutputSize(self, num_inputs: int) -> int:
26 | """Returns the number of output nodes"""
27 | num_outputs: int = num_inputs * 2 * self.frequency_factors.numel()
28 | if self.append_input:
29 | num_outputs += num_inputs
30 | return num_outputs
31 |
32 | def forward(self, inputs: Tensor) -> Tensor:
33 | """Returns the input tensor after applying the periodic embedding to it."""
34 | outputs: list[Tensor] = []
35 | # append the original inputs if requested
36 | if self.append_input:
37 | outputs.append(inputs)
38 | frequencies: Tensor = (inputs[:, :, None] * self.frequency_factors).flatten(start_dim=1)
39 | # apply periodic functions over frequencies
40 | for periodic_function in (torch.cos, torch.sin):
41 | outputs.append(periodic_function(frequencies))
42 | return torch.cat(outputs, dim=-1)
43 |
44 |
45 | # dictionary variable containing all available activation functions as well as their parameters and initial biases
46 | ACTIVATION_FUNCTION_OPTIONS: dict[str, tuple] = {
47 | 'relu': (torch.nn.ReLU, (True,), None),
48 | 'softplus': (torch.nn.Softplus, (10.0,), -1.5)
49 | }
50 |
51 |
52 | def getActivationFunction(type: str) -> tuple:
53 | """Returns the requested activation function, parameters and initial bias."""
54 | # log error message and stop execution if requested key is invalid
55 | if type not in ACTIVATION_FUNCTION_OPTIONS:
56 | Logger.logError(
57 | f'requested invalid model activation function: {type} \n'
58 | f'available options are: {list(ACTIVATION_FUNCTION_OPTIONS.keys())}'
59 | )
60 | raise Framework.ModelError(f'Invalid activation function "{type}"')
61 | # return requested model instance
62 | return ACTIVATION_FUNCTION_OPTIONS[type]
63 |
64 |
65 | def generateSamples(rays: Tensor, num_samples: int, near_plane: float, far_plane: float,
66 | randomize_samples: bool) -> Tensor:
67 | """Returns random samples (positions in space) for the given set of rays."""
68 | device: torch.device = rays.device
69 | lin_steps: Tensor = torch.linspace(0., 1., steps=num_samples, device=device)
70 | lin_steps: Tensor = (near_plane * (1.0 - lin_steps)) + (far_plane * lin_steps)
71 | depth_samples: Tensor = lin_steps.expand([rays.shape[0], num_samples])
72 | if randomize_samples:
73 | # use linear samples as interval borders for random samples
74 | mid_points: Tensor = 0.5 * (depth_samples[..., 1:] + depth_samples[..., :-1])
75 | upper_border: Tensor = torch.cat([mid_points, depth_samples[..., -1:]], -1)
76 | lower_border: Tensor = torch.cat([depth_samples[..., :1], mid_points], -1)
77 | random_offsets: Tensor = torch.rand(depth_samples.shape, device=device)
78 | depth_samples: Tensor = lower_border + ((upper_border - lower_border) * random_offsets)
79 | return depth_samples
80 |
81 |
82 | def integrateRaySamples(depth_samples: Tensor, ray_directions: Tensor, densities: Tensor, rgb: Tensor | None,
83 | background_color: Tensor, final_distance: float=1e10) -> tuple[Tensor | None, Tensor | None, Tensor, Tensor]:
84 | """Estimates final color, depth, and alpha values from samples along ray."""
85 | distances: Tensor = depth_samples[..., 1:] - depth_samples[..., :-1]
86 | distances: Tensor = torch.cat([distances, Tensor([final_distance]).expand(distances[..., :1].shape)], dim=-1) * torch.norm(ray_directions[..., None, :], dim=-1)
87 | alpha: Tensor = 1.0 - torch.exp(-densities * distances)
88 | alpha_weights: Tensor = alpha * torch.cumprod(
89 | torch.cat([torch.ones((alpha.shape[0], 1)), 1.0 - alpha + 1e-10], -1), -1
90 | )[:, :-1]
91 | alpha_final: Tensor = torch.sum(alpha_weights, dim=-1, keepdim=True)
92 | # render only if color is available
93 | final_rgb: Tensor | None
94 | final_depth: Tensor | None
95 | final_rgb = final_depth = None
96 | if rgb is not None:
97 | final_depth = torch.sum((alpha_weights / (alpha_final + 1e-8)) * depth_samples, -1)
98 | final_rgb = torch.sum(alpha_weights[..., None] * rgb, -2)
99 | if background_color is not None:
100 | final_rgb += ((1.0 - alpha_final) * background_color)
101 | return final_rgb, final_depth, alpha_final, alpha_weights
102 |
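A small sketch of the helpers above: the positional encoding size for the default settings, and the volume-rendering quadrature on two toy rays (CPU tensors, illustrative values; assumes the Framework module imports without a loaded configuration):

    import torch
    from Methods.NeRF.utils import FrequencyEncoding, generateSamples, integrateRaySamples

    # with 10 frequencies and the input appended, a 3D position becomes 3 + 3 * 2 * 10 = 63 features
    enc = FrequencyEncoding(encoding_length=10, append_input=True)
    print(enc.getOutputSize(3))  # 63

    # two toy rays pointing along +z; only the ray count matters for sample generation
    rays = torch.zeros(2, 6)
    rays[:, 3:] = torch.tensor([0.0, 0.0, 1.0])
    depth_samples = generateSamples(rays, num_samples=4, near_plane=2.0, far_plane=6.0, randomize_samples=False)
    densities = torch.full((2, 4), 10.0)   # optically thick medium everywhere
    rgb = torch.ones(2, 4, 3)
    final_rgb, final_depth, final_alpha, _ = integrateRaySamples(
        depth_samples, rays[:, 3:], densities, rgb, background_color=torch.zeros(3))
    print(final_alpha)  # ~1: the rays are fully absorbed before reaching the background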
--------------------------------------------------------------------------------
/src/Optim/AdamUtils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/AdamUtils.py: Provides various utility functions for the Adam optimizer and its variants."""
4 |
5 | import torch
6 |
7 |
8 | def replace_param_group_data(optimizer: torch.optim.Optimizer, new_values: torch.Tensor, group_name: str, reset_state: bool = True) -> None:
9 | """Replaces the data of a parameter group with the given tensor."""
10 | for group in optimizer.param_groups:
11 | if group['name'] == group_name:
12 | if len(group['params']) != 1:
13 | raise NotImplementedError('"replace_param_group_data" only implemented for single-parameter groups.')
14 | param = group['params'][0]
15 | param.data = new_values
16 | if reset_state:
17 | state = optimizer.state[param]
18 | if state:
19 | for val in ['exp_avg', 'exp_avg_sq']:
20 | state[val].zero_()
21 |
22 |
23 | def prune_param_groups(optimizer: torch.optim.Optimizer, mask: torch.Tensor, group_names: list[str] | None = None) -> dict[str, torch.Tensor]:
24 | """Removes parameter entries based on the given mask."""
25 | new_params = {}
26 | for group in optimizer.param_groups:
27 | if group_names is not None and group['name'] not in group_names:
28 | continue
29 | if len(group['params']) != 1:
30 | raise NotImplementedError('"prune_param_groups" only implemented for single-parameter groups.')
31 | old_param = group['params'][0]
32 | state = optimizer.state[old_param]
33 | new_param = torch.nn.Parameter(old_param[mask])
34 | if state:
35 | for val in ['exp_avg', 'exp_avg_sq']:
36 | state[val] = state[val][mask]
37 | optimizer.state.pop(old_param)
38 | optimizer.state[new_param] = state
39 | group['params'][0] = new_param
40 | new_params[group['name']] = new_param
41 | return new_params
42 |
43 |
44 | def extend_param_groups(optimizer: torch.optim.Optimizer, additional_params: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
45 | """Extend existing parameters by concatenating the given tensors."""
46 | new_params = {}
47 | for group in optimizer.param_groups:
48 | if len(group['params']) != 1:
49 | raise NotImplementedError('"extend_param_groups" only implemented for single-parameter groups.')
50 | extension_tensor = additional_params.get(group['name'], None)
51 | if extension_tensor is None:
52 | continue
53 | old_param = group['params'][0]
54 | state = optimizer.state[old_param]
55 | new_param = torch.nn.Parameter(torch.cat((old_param, extension_tensor), dim=0))
56 | if state:
57 | for val in ['exp_avg', 'exp_avg_sq']:
58 | state[val] = torch.cat((state[val], torch.zeros_like(extension_tensor)), dim=0)
59 | optimizer.state.pop(old_param)
60 | optimizer.state[new_param] = state
61 | group['params'][0] = new_param
62 | new_params[group['name']] = new_param
63 | return new_params
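A usage sketch of the two pruning/extension helpers on a toy parameter group (the 'name' entry in the group dict is required by these functions; everything else here is made up):

    import torch
    from Optim.AdamUtils import prune_param_groups, extend_param_groups

    points = torch.nn.Parameter(torch.rand(100, 3))
    optimizer = torch.optim.Adam([{'params': [points], 'name': 'points'}], lr=1e-3)
    points.sum().backward()
    optimizer.step()  # populates exp_avg / exp_avg_sq in the optimizer state

    # keep every second point and slice the optimizer state accordingly
    mask = torch.arange(100) % 2 == 0
    new_params = prune_param_groups(optimizer, mask, group_names=['points'])
    print(new_params['points'].shape)  # torch.Size([50, 3])

    # append 10 new points with zero-initialized optimizer state
    new_params = extend_param_groups(optimizer, {'points': torch.rand(10, 3)})
    print(new_params['points'].shape)  # torch.Size([60, 3])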
--------------------------------------------------------------------------------
/src/Optim/GradientScaling.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/GradientScaling.py: Gradient scaling routines."""
4 |
5 | from typing import Any
6 | import torch
7 |
8 |
9 | class _GradientScaler(torch.autograd.Function):
10 | """
11 | Utility Autograd function scaling gradients by a given factor.
12 | Adapted from NerfStudio: https://docs.nerf.studio/en/latest/_modules/nerfstudio/model_components/losses.html
13 | """
14 |
15 | @staticmethod
16 | def forward(ctx, value: torch.Tensor, scaling: Any) -> tuple[torch.Tensor, Any]:
17 | ctx.save_for_backward(scaling)
18 | return value, scaling
19 |
20 | @staticmethod
21 | def backward(ctx, output_grad: torch.Tensor, grad_scaling: Any) -> tuple[torch.Tensor, Any]:
22 | (scaling,) = ctx.saved_tensors
23 | return output_grad * scaling, grad_scaling
24 |
25 |
26 | def scaleGradient(*args: torch.Tensor, scaling: Any) -> tuple[torch.Tensor, ...] | torch.Tensor:
27 | """
28 | Scale the gradient of the given tensors.
29 | """
30 | output: list[torch.Tensor] = [_GradientScaler.apply(value, scaling)[0] for value in args]
31 | return tuple(output) if len(output) > 1 else output[0]
32 |
33 |
34 | def scaleGradientByDistance(*args: torch.Tensor, distances: torch.Tensor) -> tuple[torch.Tensor, ...]:
35 | """
36 | Scale the gradient of the given tensor based on the normalized ray distance.
37 | See: Radiance Field Gradient Scaling for Improved Near-Camera Training (https://gradient-scaling.github.io)
38 | """
39 | return scaleGradient(*args, scaling=torch.square(distances).clamp(0, 1))
40 |
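A usage sketch of the distance-based gradient scaling (toy tensors; in practice the distances come from the ray sampler):

    import torch
    from Optim.GradientScaling import scaleGradientByDistance

    rgb = torch.rand(8, 3, requires_grad=True)
    sigma = torch.rand(8, 1, requires_grad=True)
    distances = torch.linspace(0.1, 2.0, 8)[:, None]   # normalized distance of each sample from the camera
    rgb_scaled, sigma_scaled = scaleGradientByDistance(rgb, sigma, distances=distances)
    (rgb_scaled.sum() + sigma_scaled.sum()).backward()
    # gradients of samples closer than distance 1 are damped by distance^2, the rest pass through unchanged
    print(rgb.grad[0], rgb.grad[-1])  # tensor([0.0100, 0.0100, 0.0100]) tensor([1., 1., 1.])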
--------------------------------------------------------------------------------
/src/Optim/Losses/BackgroundEntropy.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/LossesBackgroundEntropy.py: Background entropy loss (encouraging alpha channel to be 0 or 1)."""
4 |
5 | from typing import Any
6 |
7 | import torch
8 | import torchmetrics
9 |
10 |
11 | def backgroundEntropy(input: torch.Tensor, symmetrical=False) -> torch.Tensor:
12 | """functional"""
13 | x = input.clamp(min=1e-6, max=1.0 - 1e-6)
14 | return -(x * torch.log(x) + (1 - x) * torch.log(1 - x)).mean() if symmetrical else (-x * torch.log(x)).mean()
15 |
16 |
17 | class BackgroundEntropyLoss(torchmetrics.Metric):
18 | """torchmetrics implementation"""
19 | is_differentiable = True
20 | higher_is_better = False
21 | full_state_update = False
22 |
23 | def __init__(self, symmetrical: bool = False, **kwargs: Any) -> None:
24 | super().__init__(**kwargs)
25 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
26 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
27 | self.symmetrical = symmetrical
28 |
29 | def update(self, preds: torch.Tensor) -> None:
30 | """Update state with current alpha predictions."""
31 | x = preds.clamp(min=1e-6, max=1.0 - 1e-6)
32 | if self.symmetrical:
33 | y = -(x * torch.log(x) + (1 - x) * torch.log(1 - x)).sum()
34 | else:
35 | y = (-x * torch.log(x)).sum()
36 | self.running_sum += y
37 | self.total += preds.numel()
38 |
39 | def compute(self) -> torch.Tensor:
40 | """Computes background entropy over state."""
41 | return (self.running_sum / self.total) if self.total > 0 else (torch.tensor(0.0))
42 |
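Two worked values showing what the functional form rewards and penalizes:

    import torch
    from Optim.Losses.BackgroundEntropy import backgroundEntropy

    confident = torch.tensor([0.001, 0.999])   # alpha close to fully transparent / fully opaque
    uncertain = torch.tensor([0.5, 0.5])
    print(backgroundEntropy(confident, symmetrical=True))  # ~0.008 -> barely penalized
    print(backgroundEntropy(uncertain, symmetrical=True))  # ~0.693 -> maximum penalty (ln 2)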
--------------------------------------------------------------------------------
/src/Optim/Losses/Base.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/LossesBase.py: Base Loss class for accumulation and logging."""
4 |
5 | from typing import Any, Callable
6 |
7 | import torch
8 |
9 | import Framework
10 | from Optim.Losses.utils import QualityMetricItem, LossMetricItem
11 |
12 |
13 | class BaseLoss(torch.nn.Module):
14 | """Simple configurable loss container for accumulation and wandb logging"""
15 |
16 | def __init__(self,
17 | loss_metrics: list[LossMetricItem] | None = None,
18 | quality_metrics: list[QualityMetricItem] | None = None) -> None:
19 | super().__init__()
20 | self.loss_metrics: list[LossMetricItem] = loss_metrics or []
21 | self.quality_metrics: list[QualityMetricItem] = quality_metrics or []
22 | self.activate_logging: bool = Framework.config.TRAINING.WANDB.ACTIVATE
23 |
24 | def addLossMetric(self, name: str, metric: Callable, weight: float = None) -> None:
25 | self.loss_metrics.append(LossMetricItem(
26 | name=name,
27 | metric_func=metric,
28 | weight=weight
29 | ))
30 |
31 | def addQualityMetric(self, name: str, metric: Callable) -> None:
32 | self.quality_metrics.append(QualityMetricItem(
33 | name=name,
34 | metric_func=metric,
35 | ))
36 |
37 | def reset(self) -> None:
38 | for item in self.loss_metrics + self.quality_metrics:
39 | item.reset()
40 |
41 | def log(self, iteration: int, log_validation: bool) -> None:
42 | if self.activate_logging:
43 | for item in self.loss_metrics + self.quality_metrics:
44 | val_train, val_eval = item.getAverage()
45 | data = {'train': val_train}
46 | if log_validation:
47 | data['eval'] = val_eval
48 | Framework.wandb.log({f'{item.name}': data}, step=iteration)
49 |
50 | def forward(self, configurations: dict[str, dict[str, Any]]) -> torch.Tensor:
51 | try:
52 | if self.activate_logging:
53 | with torch.no_grad():
54 | for item in self.quality_metrics:
55 | item.apply(train=self.training, accumulate=True, kwargs=configurations[item.name])
56 | sublosses = []
57 | for item in self.loss_metrics:
58 | sublosses.append(item.apply(train=self.training, accumulate=True, kwargs=configurations[item.name]))
59 | return torch.stack(sublosses).sum()
60 | except KeyError:
61 | raise Framework.LossError(f'missing argument configuration for loss "{item.name}"')
62 | except TypeError:
63 | raise Framework.LossError(f'invalid argument configuration for loss "{item.name}"')
64 | except Exception as e:
65 | raise Framework.LossError(f'unexpected error occurred in loss "{item.name}": {e}')
66 |
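The intended usage pattern, as mirrored by NeRFLoss and InstantNGPLoss above: a subclass registers its metrics once and then calls forward with one kwargs dict per registered name. A hypothetical subclass sketch (instantiating it requires an initialized Framework.config, so it is only defined here):

    import torch
    import torchmetrics
    from Optim.Losses.Base import BaseLoss

    class ExampleLoss(BaseLoss):
        def __init__(self) -> None:
            super().__init__()
            self.addLossMetric('L1_Color', torch.nn.functional.l1_loss, 1.0)
            self.addQualityMetric('PSNR', torchmetrics.functional.peak_signal_noise_ratio)

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # one kwargs dict per metric registered above; the weighted sub-losses are summed
            return super().forward({
                'L1_Color': {'input': pred, 'target': target},
                'PSNR': {'preds': pred, 'target': target, 'data_range': 1.0},
            })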
--------------------------------------------------------------------------------
/src/Optim/Losses/Charbonnier.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Charbonnier.py: Charbonnier loss as in Mip-NeRF 360."""
4 |
5 | import torch
6 |
7 |
8 | def charbonnier_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor:
9 | """Computes the Charbonnier loss as in Mip-NeRF 360."""
10 | return (input - target).pow(2).add(eps).sqrt().mean()
11 |
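A quick numeric check of the formula: the loss behaves like the absolute error away from zero and stays smooth (and slightly above zero, by eps) around it:

    import torch
    from Optim.Losses.Charbonnier import charbonnier_loss

    pred = torch.tensor([0.0, 0.5])
    target = torch.tensor([0.0, 0.0])
    print(charbonnier_loss(pred, target))  # ~0.2505: mean of sqrt(0 + 1e-6) and sqrt(0.25 + 1e-6)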
--------------------------------------------------------------------------------
/src/Optim/Losses/DSSIM.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Losses/DSSIM.py: Loss based on the structural (dis-)similarity index measure (DSSIM = 1 - SSIM)."""
4 |
5 | from typing import Sequence, Literal
6 |
7 | import torch
8 | from torchmetrics.functional.image import structural_similarity_index_measure
9 | from functools import partial
10 |
11 |
12 | class DSSIMLoss:
13 | """SSIM-based perceptual loss called DSSIM (computed as DSSIM = 1 - SSIM)."""
14 |
15 | def __init__(
16 | self,
17 | gaussian_kernel: bool = True,
18 | sigma: float | Sequence[float] = 1.5,
19 | kernel_size: int | Sequence[int] = 11,
20 | reduction: Literal['elementwise_mean', 'sum', 'none', None] = 'elementwise_mean',
21 | data_range: float | tuple[float, float] | None = 1.0, # torchmetrics default is None
22 | k1: float = 0.01,
23 | k2: float = 0.03,
24 | return_full_image: bool = False,
25 | ) -> None:
26 | """Initialize loss function."""
27 | self.return_full_image = return_full_image
28 | self.loss_function = partial(
29 | structural_similarity_index_measure,
30 | gaussian_kernel=gaussian_kernel,
31 | sigma=sigma,
32 | kernel_size=kernel_size,
33 | reduction=reduction,
34 | data_range=data_range,
35 | k1=k1,
36 | k2=k2,
37 | return_full_image=return_full_image,
38 | return_contrast_sensitivity=False, # not configurable for now
39 | )
40 |
41 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
42 | """Calculate loss."""
43 | if self.return_full_image:
44 | return 1.0 - self.loss_function(preds=input[None], target=target[None])[1][0]
45 | return 1.0 - self.loss_function(preds=input[None], target=target[None]) # should be (1.0 - SSIM) / 2.0
46 |
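A usage sketch of DSSIMLoss on (C, H, W) tensors (random images; the loss internally adds the batch dimension expected by torchmetrics):

    import torch
    from Optim.Losses.DSSIM import DSSIMLoss

    loss_fn = DSSIMLoss()
    pred, target = torch.rand(3, 64, 64), torch.rand(3, 64, 64)
    print(loss_fn(pred, target))    # scalar in [0, 2]
    print(loss_fn(target, target))  # ~0 for identical images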
--------------------------------------------------------------------------------
/src/Optim/Losses/DepthSmoothness.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/DepthSmoothness.py: smooth depth loss. adapted from https://kornia.readthedocs.io/en/v0.2.1/_modules/kornia/losses/depth_smooth.html."""
4 |
5 | import torch
6 |
7 |
8 | def _gradient_x(img: torch.Tensor) -> torch.Tensor:
9 | return img[:, :, :, 1:-1] - img[:, :, :, 0:-2]
10 |
11 |
12 | def _gradient_y(img: torch.Tensor) -> torch.Tensor:
13 | return img[:, :, 1:-1, :] - img[:, :, 0:-2, :]
14 |
15 |
16 | def _laplace_x(img: torch.Tensor) -> torch.Tensor:
17 | mi = img[:, :, :, 1:-1]
18 | le = img[:, :, :, :-2]
19 | ri = img[:, :, :, 2:]
20 | return le + ri - (2 * mi)
21 |
22 |
23 | def _laplace_y(img: torch.Tensor) -> torch.Tensor:
24 | mi = img[:, :, 1:-1, :]
25 | le = img[:, :, :-2, :]
26 | ri = img[:, :, 2:, :]
27 | return le + ri - (2 * mi)
28 |
29 |
30 | def depthSmoothnessLoss(
31 | depth: torch.Tensor,
32 | image: torch.Tensor) -> torch.Tensor:
33 |
34 | # compute the gradients
35 | idepth_dx: torch.Tensor = _laplace_x(depth)
36 | idepth_dy: torch.Tensor = _laplace_y(depth)
37 | image_dx: torch.Tensor = _gradient_x(image)
38 | image_dy: torch.Tensor = _gradient_y(image)
39 |
40 | # compute image weights
41 | weights_x: torch.Tensor = torch.exp(
42 | -torch.mean(torch.abs(image_dx), dim=1, keepdim=True))
43 | weights_y: torch.Tensor = torch.exp(
44 | -torch.mean(torch.abs(image_dy), dim=1, keepdim=True))
45 |
46 | # apply image weights to depth
47 | smoothness_x: torch.Tensor = torch.abs(idepth_dx * weights_x)
48 | smoothness_y: torch.Tensor = torch.abs(idepth_dy * weights_y)
49 | return torch.mean(smoothness_x) + torch.mean(smoothness_y)
50 |
--------------------------------------------------------------------------------
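A sketch of the expected tensor layout for depthSmoothnessLoss: both arguments are batched NCHW tensors, and the depth Laplacian is down-weighted where the image has strong gradients, so depth discontinuities are tolerated at image edges.

    import torch
    from Optim.Losses.DepthSmoothness import depthSmoothnessLoss

    depth = torch.rand(1, 1, 48, 64, requires_grad=True)  # N x 1 x H x W
    image = torch.rand(1, 3, 48, 64)                       # N x 3 x H x W
    loss = depthSmoothnessLoss(depth, image)
    loss.backward()
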
/src/Optim/Losses/Distortion.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Losses/Distortion.py: Distortion loss as introduced in MipNeRF360 / DVGOv2."""
4 |
5 | from typing import Any
6 | import torch
7 | import torchmetrics
8 | from Methods.InstantNGP.CudaExtensions.VolumeRenderingV2 import DistortionLoss as DistortionLossAutogradFN
9 |
10 |
11 | class DistortionLoss(torch.nn.Module):
12 |
13 | def __init__(self) -> None:
14 | super().__init__()
15 |
16 | def __call__(self, ws: torch.Tensor, deltas: torch.Tensor, ts: torch.Tensor, rays_a: torch.Tensor) -> torch.Tensor:
17 | return DistortionLossAutogradFN.apply(ws, deltas, ts, rays_a).mean()
18 |
19 |
20 | class DistortionLossTorchMetrics(torchmetrics.Metric):
21 | """torchmetrics implementation of the efficient Distortion loss introduced in MipNeRF360 / DVGOv2"""
22 | is_differentiable = True
23 | higher_is_better = False
24 | full_state_update = False
25 |
26 | def __init__(self, **kwargs: Any) -> None:
27 | super().__init__(**kwargs)
28 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
29 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
30 |
31 | def update(self, ws: torch.Tensor, deltas: torch.Tensor, ts: torch.Tensor, rays_a: torch.Tensor) -> None:
32 | """Update state."""
33 | x = DistortionLossAutogradFN.apply(ws, deltas, ts, rays_a)
34 | y = x.sum()
35 | self.running_sum = self.running_sum + y
36 | self.total += x.numel()
37 |
38 | def compute(self) -> torch.Tensor:
39 | """Computes distortion loss over state."""
40 | return (self.running_sum / self.total) if self.total > 0 else torch.tensor(0.0)
41 |
--------------------------------------------------------------------------------
/src/Optim/Losses/FusedDSSIM.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Optim/Losses/FusedDSSIM.py: Loss based on the structural (dis-)similarity index measure (DSSIM = 1 - SSIM).
5 | """
6 |
7 | import torch
8 |
9 | from Thirdparty.FusedSSIM import fused_ssim
10 |
11 |
12 | def fused_dssim(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
13 | """Calculate loss."""
14 |     return 1.0 - fused_ssim(input[None], target[None])  # note: strict DSSIM would be (1.0 - SSIM) / 2.0, the 0.5 factor is omitted here
15 |
--------------------------------------------------------------------------------
/src/Optim/Losses/Magnitude.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Losses/Magnitude.py: Mean 1-norm over given dim."""
4 |
5 | from typing import Any
6 |
7 | import torch
8 | import torchmetrics
9 |
10 |
11 | def magnitudeLoss(input: torch.Tensor | None, dim: int = 1) -> torch.Tensor:
12 |     """Functional variant of the magnitude loss; returns zero if input is None."""
13 | if input is None:
14 | return torch.tensor(0.0, requires_grad=False)
15 | return torch.norm(input, dim=dim, keepdim=True, p=1).mean()
16 |
17 |
18 | class MagnitudeLoss(torchmetrics.Metric):
19 | """torchmetrics implementation"""
20 | is_differentiable = True
21 | higher_is_better = False
22 | full_state_update = False
23 |
24 | def __init__(self, dim: int, **kwargs: Any) -> None:
25 | super().__init__(**kwargs)
26 | self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
27 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
28 | self.dim = dim
29 |
30 | def update(self, preds: torch.Tensor) -> None:
31 | """torchmetrics update override"""
32 | x = torch.norm(preds, dim=self.dim, keepdim=True, p=1)
33 | self.running_sum += x.sum()
34 | self.total += x.numel()
35 |
36 | def compute(self) -> torch.Tensor:
37 | """torchmetrics compute override"""
38 | return (self.running_sum / self.total) if self.total > 0 else (torch.tensor(0.0))
39 |
--------------------------------------------------------------------------------
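A quick check of the functional variant above; the 1-norm is taken over the given dimension (here the channel dimension of a hypothetical offset field) and then averaged.

    import torch
    from Optim.Losses.Magnitude import magnitudeLoss

    offsets = torch.randn(8, 2, 32, 32, requires_grad=True)  # e.g. a batch of 2D offset fields
    loss = magnitudeLoss(offsets, dim=1)                       # mean L1 norm over the channel dimension
    loss.backward()
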
/src/Optim/Losses/Robust.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Losses/Robust.py: Robust loss function described in https://arxiv.org/abs/1701.03077."""
4 |
5 | import torch
6 |
7 |
8 | class RobustLoss:
9 | """General and Adaptive Robust Loss Function."""
10 |
11 | def __init__(
12 | self,
13 | alpha: float,
14 | c: float,
15 | min_alpha: float = -1000.0 # resembles -inf
16 | ) -> None:
17 | """Initialize loss function."""
18 | c_reciprocal = 1.0 / c
19 | if alpha == 2.0:
20 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(0.5).mean()
21 | elif alpha == 0.0:
22 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(0.5).add(1.0).log().mean()
23 | elif alpha < min_alpha:
24 | self.loss_function = lambda x, y: (1.0 - (x - y).mul(c_reciprocal).pow(2).mul(-0.5).exp()).mean()
25 | else:
26 | factor = abs(alpha - 2.0) / alpha
27 | exponent = alpha / 2.0
28 | scale = 1.0 / abs(alpha - 2.0)
29 | self.loss_function = lambda x, y: (x - y).mul(c_reciprocal).pow(2).mul(scale).add(1.0).pow(exponent).sub(1.0).mul(factor).mean()
30 |
31 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
32 | """Calculate loss."""
33 | return self.loss_function(input, target)
34 |
--------------------------------------------------------------------------------
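A small sketch of how the branches above behave: alpha = 2 recovers a scaled L2 loss, alpha = 0 the Cauchy/Lorentzian case, values below min_alpha approximate the Welsch loss, and all other values use the general formula; c controls the scale at which residuals start to count as outliers.

    import torch
    from Optim.Losses.Robust import RobustLoss

    pred = torch.randn(1000, requires_grad=True)
    target = torch.randn(1000)
    for alpha in (2.0, 1.0, 0.0, -2.0):
        criterion = RobustLoss(alpha=alpha, c=0.5)
        print(alpha, criterion(pred, target).item())  # smaller alpha puts less weight on outliers
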
/src/Optim/Losses/VGG.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Optim/Losses/VGG.py: Perceptual loss based on VGG features.
5 | See https://arxiv.org/abs/1603.08155 or https://arxiv.org/abs/1609.04802 for details.
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Callable
10 | import torch
11 | from torchvision.models import VGG, vgg19, VGG16_Weights, VGG19_Weights
12 |
13 |
14 | @dataclass(frozen=True)
15 | class VGGLossConfig:
16 | """Configuration for the VGG loss."""
17 |     model_class: Callable[..., VGG] = vgg19
18 |     used_weights: VGG16_Weights | VGG19_Weights = VGG19_Weights.IMAGENET1K_V1
19 |     # used_blocks: tuple[slice, ...] = (slice(0, 4), slice(4, 9), slice(9, 16), slice(16, 23))  # vgg16
20 |     used_blocks: tuple[slice, ...] = (slice(0, 4), slice(4, 9), slice(9, 18), slice(18, 27), slice(27, 36))  # vgg19
21 | loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = torch.nn.functional.l1_loss
22 |
23 |
24 | class VGGLoss:
25 | """Perceptual loss based on VGG features."""
26 |
27 | def __init__(self, config: VGGLossConfig = VGGLossConfig()) -> None:
28 | """Initialize the VGG loss."""
29 | model = config.model_class(weights=config.used_weights).features.eval()
30 | for parameter in model.parameters():
31 | parameter.requires_grad = False
32 | self.blocks = torch.nn.ModuleList([model[block] for block in config.used_blocks])
33 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
34 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
35 | self.loss_function = config.loss_function
36 |
37 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
38 | """Calculate the VGG loss. Both input and target are expected to have shape (C, H, W) with C being RGB."""
39 | input = (input[None] - self.mean) / self.std
40 | target = (target[None] - self.mean) / self.std
41 | loss = torch.tensor(0.0, requires_grad=False)
42 | for block in self.blocks:
43 | input = block(input)
44 | target = block(target)
45 | loss += self.loss_function(input, target)
46 | return loss
47 |
--------------------------------------------------------------------------------
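A usage sketch for the VGG loss: the first call downloads the ImageNet weights via torchvision, inputs are single RGB images in [0, 1] of shape (3, H, W), and normalization with ImageNet statistics happens inside __call__.

    import torch
    from Optim.Losses.VGG import VGGLoss

    criterion = VGGLoss()                   # vgg19 features, L1 distance per feature block
    pred = torch.rand(3, 224, 224, requires_grad=True)
    target = torch.rand(3, 224, 224)
    loss = criterion(pred, target)
    loss.backward()
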
/src/Optim/Losses/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Optim/Losses/utils.py: Utilities for loss implementations."""
4 |
5 | from dataclasses import dataclass, field
6 | from typing import Any, Callable
7 |
8 | import Framework
9 | import torch
10 | import torchmetrics
11 |
12 |
13 | @dataclass
14 | class QualityMetricItem:
15 | """Used to store quality metrics in BaseLoss, only for training evaluation in wandb"""
16 | name: str
17 | metric_func: Callable
18 |
19 |     _running_sum: list[torch.Tensor] = field(init=False)  # accumulated values for [train, validation]
20 |     _num_iters: list[int] = field(init=False)  # iteration counts for [train, validation]
21 |
22 | def __post_init__(self):
23 | self.reset()
24 |
25 | def reset(self) -> None:
26 | self._running_sum = [torch.tensor(0.0, requires_grad=False), torch.tensor(0.0, requires_grad=False)]
27 | self._num_iters = [0, 0]
28 |
29 |     def getAverage(self) -> list[float]:
30 | return [(self._running_sum[i].item() / float(self._num_iters[i])) if self._num_iters[i] > 0 else 0.0 for i in range(2)]
31 |
32 | def _call_metric(self, kwargs: Any) -> torch.Tensor:
33 | return self.metric_func(**kwargs)
34 |
35 | def apply(self, train: bool, accumulate: bool, kwargs: Any) -> torch.Tensor:
36 | loss_val = self._call_metric(kwargs)
37 | if accumulate:
38 | idx = 0 if train else 1
39 | self._running_sum[idx] += loss_val.detach()
40 | self._num_iters[idx] += 1
41 | return loss_val
42 |
43 |
44 | @dataclass
45 | class LossMetricItem(QualityMetricItem):
46 | """Used to store individual loss terms in BaseLoss"""
47 | weight: float = 1.0
48 |
49 | def __post_init__(self):
50 | super().__post_init__()
51 | self.initial_weight = max(0.0, self.weight) if self.weight is not None else 0.0
52 | self.weight = self.initial_weight
53 | if isinstance(self.metric_func, torchmetrics.Metric) and not self.metric_func.is_differentiable:
54 | raise Framework.LossError(f'requested loss metric {self.name} (instance of {self.metric_func.__class__.__name__}) is not differentiable')
55 |
56 | def _call_metric(self, kwargs: Any) -> torch.Tensor:
57 | if self.weight > 0.0:
58 | return self.metric_func(**kwargs) * self.weight
59 | return torch.tensor(0.0)
60 |
--------------------------------------------------------------------------------
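To illustrate how the item types above are driven (normally BaseLoss does this bookkeeping, so this is just a sketch): a LossMetricItem scales its metric by weight, a weight of 0 short-circuits to a constant zero, and apply(..., accumulate=True) keeps separate running averages for the training and validation phases.

    import torch
    from Optim.Losses.utils import LossMetricItem

    item = LossMetricItem(name='l1', metric_func=torch.nn.functional.l1_loss, weight=0.5)
    pred, target = torch.rand(4), torch.rand(4)
    loss = item.apply(train=True, accumulate=True, kwargs={'input': pred, 'target': target})
    print(loss.item(), item.getAverage())  # getAverage returns [train average, validation average]
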
/src/Optim/Samplers/DatasetSamplers.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Samplers/DatasetSamplers.py: Samplers returning a batch of rays from a dataset."""
4 | import torch
5 |
6 | import Framework
7 | from Cameras.utils import CameraProperties
8 | from Datasets.Base import BaseDataset
9 | from Optim.Samplers.ImageSamplers import ImageSampler
10 | from Optim.Samplers.utils import IncrementalSequentialSampler, RandomSequentialSampler, SequentialSampler
11 |
12 |
13 | class DatasetSampler:
14 |
15 | def __init__(self,
16 | dataset: BaseDataset,
17 | random: bool = True,
18 | img_sampler_cls: type[ImageSampler] | None = None) -> None:
19 | self.mode = dataset.mode
20 | self.id_sampler = RandomSequentialSampler(num_elements=len(dataset)) if random else SequentialSampler(num_elements=len(dataset))
21 | self.img_samplers = [img_sampler_cls(num_elements=(i.width * i.height)) for i in dataset] if img_sampler_cls else None
22 |
23 | def get(self, dataset: BaseDataset, ray_batch_size: int | None = None) -> dict[str, int | CameraProperties | None]:
24 | if dataset.mode != self.mode:
25 | raise Framework.SamplerError(f'DatasetSampler initialized for mode "{self.mode}" got dataset with active mode "{dataset.mode}"')
26 | sample_id = self.id_sampler.get(num_samples=1).item()
27 | camera_properties = dataset[sample_id]
28 | image_sampler = ray_ids = ray_batch = None
29 | if self.img_samplers and ray_batch_size is not None:
30 | image_sampler = self.img_samplers[sample_id]
31 | ray_ids = image_sampler.get(ray_batch_size).to(Framework.config.GLOBAL.DEFAULT_DEVICE)
32 | ray_batch = dataset.camera.setProperties(camera_properties).generateRays()[ray_ids]
33 | return {
34 | 'sample_id': sample_id,
35 | 'camera_properties': camera_properties,
36 | 'image_sampler': image_sampler,
37 | 'ray_ids': ray_ids,
38 | 'ray_batch': ray_batch
39 | }
40 |
41 |
42 | class RayPoolSampler:
43 | def __init__(self,
44 | dataset: BaseDataset,
45 | img_sampler_cls: type[ImageSampler]) -> None:
46 | self.mode = dataset.mode
47 | all_rays = dataset.getAllRays()
48 | self.image_sampler = img_sampler_cls(num_elements=all_rays.shape[0])
49 |
50 | def get(self, dataset: BaseDataset, ray_batch_size: int) -> dict[str, None | ImageSampler | torch.Tensor]:
51 | if dataset.mode != self.mode:
52 | raise Framework.SamplerError(f'RayPoolSampler initialized for mode "{self.mode}" got dataset with active mode "{dataset.mode}"')
53 | sample_id = camera_properties = None
54 | rays_all = dataset.getAllRays()
55 | ray_ids = self.image_sampler.get(ray_batch_size).to(rays_all.device)
56 | ray_batch = rays_all[ray_ids].to(Framework.config.GLOBAL.DEFAULT_DEVICE)
57 | return {
58 | 'sample_id': sample_id,
59 | 'camera_properties': camera_properties,
60 | 'image_sampler': self.image_sampler,
61 | 'ray_ids': ray_ids,
62 | 'ray_batch': ray_batch
63 | }
64 |
65 |
66 | class IncrementalDatasetSampler(DatasetSampler):
67 |
68 | def __init__(self,
69 | dataset: BaseDataset,
70 | img_sampler_cls: type[ImageSampler] | None = None) -> None:
71 | super().__init__(dataset, False, img_sampler_cls)
72 | self.id_sampler = IncrementalSequentialSampler(num_elements=len(dataset))
73 |
--------------------------------------------------------------------------------
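A sketch of the per-iteration contract of DatasetSampler.get (a loaded BaseDataset in its training mode is assumed here as train_set): the returned dict always carries the same five keys, and the ray entries stay None when no image sampler class was passed.

    from Optim.Samplers.DatasetSamplers import DatasetSampler
    from Optim.Samplers.ImageSamplers import RandomImageSampler

    sampler = DatasetSampler(dataset=train_set, random=True, img_sampler_cls=RandomImageSampler)
    batch = sampler.get(train_set, ray_batch_size=4096)  # `train_set` is an assumed, already loaded BaseDataset
    camera_properties = batch['camera_properties']       # CameraProperties of the randomly drawn view
    ray_batch = batch['ray_batch']                       # 4096 rays of that view, or None without an image sampler
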
/src/Optim/Samplers/ImageSamplers.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Samplers/ImageSamplers.py: Samplers selecting a subset of rays from a given ray batch."""
4 |
5 | from abc import ABC, abstractmethod
6 |
7 | import torch
8 |
9 | from Optim.Samplers.utils import RandomSequentialSampler, SequentialSampler
10 |
11 |
12 | class ImageSampler(ABC):
13 | """Abstract base class for image samplers."""
14 |
15 | def __init__(self, num_elements: int) -> None:
16 | super().__init__()
17 | self.num_elements: int = num_elements
18 |
19 | @abstractmethod
20 | def get(self, ray_batch_size: int) -> torch.Tensor:
21 | pass
22 |
23 | def update(self, **_) -> None:
24 | pass
25 |
26 |
27 | class SequentialImageSampler(ImageSampler):
28 |
29 | def __init__(self, num_elements: int) -> None:
30 | super().__init__(num_elements)
31 | self.sampler = SequentialSampler(num_elements=self.num_elements)
32 |
33 | def get(self, ray_batch_size: int) -> torch.Tensor:
34 | return self.sampler.get(num_samples=ray_batch_size)
35 |
36 |
37 | class SequentialRandomImageSampler(SequentialImageSampler):
38 |
39 | def __init__(self, num_elements: int) -> None:
40 | super().__init__(num_elements)
41 | self.sampler = RandomSequentialSampler(num_elements=self.num_elements)
42 |
43 |
44 | class RandomImageSampler(ImageSampler):
45 |
46 | def get(self, ray_batch_size: int) -> torch.Tensor:
47 | return torch.randint(low=0, high=self.num_elements, size=(ray_batch_size,))
48 |
49 |
50 | class MultinomialImageSampler(ImageSampler):
51 |
52 | def __init__(self, num_elements: int) -> None:
53 | super().__init__(num_elements)
54 | self.pdf = torch.ones(size=(self.num_elements,))
55 |
56 | def get(self, ray_batch_size: int) -> torch.Tensor:
57 | return torch.multinomial(input=self.pdf, num_samples=ray_batch_size)
58 |
59 | @torch.no_grad()
60 |     def update(self, ray_ids: torch.Tensor, weights: torch.Tensor, constant_addend: float | None = None) -> None:
61 | if constant_addend is not None:
62 | self.pdf += constant_addend
63 | self.pdf[ray_ids] = weights
64 |
--------------------------------------------------------------------------------
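A sketch of how the importance-style sampler above can be driven: MultinomialImageSampler keeps one PDF entry per pixel, update() overwrites the entries at the sampled ray ids (e.g. with per-ray loss values), and the constant addend keeps rarely sampled pixels from starving.

    import torch
    from Optim.Samplers.ImageSamplers import MultinomialImageSampler

    sampler = MultinomialImageSampler(num_elements=64 * 64)  # one PDF entry per pixel
    ray_ids = sampler.get(ray_batch_size=1024)                # multinomial draw from the current PDF
    per_ray_loss = torch.rand(1024)                            # stand-in for actual per-ray losses
    sampler.update(ray_ids=ray_ids, weights=per_ray_loss, constant_addend=1e-3)
    ray_ids = sampler.get(ray_batch_size=1024)                 # now biased towards high-loss pixels
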
/src/Optim/Samplers/RaySamplers.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerficg-project/nerficg/4e625350cdb7558da9f4d9e0983b40ddf6a98add/src/Optim/Samplers/RaySamplers.py
--------------------------------------------------------------------------------
/src/Optim/Samplers/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Samplers/utils.py: Utilities for image and ray sampling routines."""
4 |
5 | import torch
6 |
7 | import Framework
8 |
9 |
10 | class SequentialSampler:
11 | def __init__(self, num_elements: int) -> None:
12 | self.num_elements: int = num_elements
13 | self.indices: torch.Tensor = torch.arange(self.num_elements)
14 | self.reset()
15 |
16 | def shuffle(self) -> None:
17 | pass
18 |
19 | def reset(self) -> None:
20 | self.current_id: int = 0
21 | self.shuffle()
22 |
23 | def get(self, num_samples: int) -> torch.Tensor:
24 | if num_samples > self.num_elements:
25 | raise Framework.SamplerError(f"cannot draw {num_samples} samples from {self.num_elements} elements")
26 | if self.current_id + num_samples > self.num_elements:
27 | self.reset()
28 | samples = self.indices[self.current_id:self.current_id + num_samples]
29 | self.current_id += num_samples
30 | return samples
31 |
32 |
33 | class RandomSequentialSampler(SequentialSampler):
34 |
35 | def shuffle(self) -> None:
36 | self.indices: torch.Tensor = self.indices[torch.randperm(self.num_elements)]
37 |
38 |
39 | class IncrementalSequentialSampler:
40 |
41 | def __init__(self, num_elements: int) -> None:
42 | self.num_elements: int = num_elements
43 | self.current_size: int = 0
44 | self.indices: torch.Tensor = torch.arange(self.num_elements)
45 | self.reset()
46 |
47 | def reset(self) -> None:
48 | self.current_size = min(self.current_size + 1, self.num_elements)
49 | self.current_indices: torch.Tensor = self.indices[:self.current_size]
50 | self.current_id: int = 0
51 |
52 | def get(self, num_samples: int) -> torch.Tensor:
53 | if num_samples > self.current_size:
54 | raise Framework.SamplerError(f"cannot draw {num_samples} samples from {self.current_size} elements")
55 | if self.current_id + num_samples > self.current_size:
56 | self.reset()
57 | samples = self.current_indices[self.current_id:self.current_id + num_samples]
58 | self.current_id += num_samples
59 | return samples
60 |
--------------------------------------------------------------------------------
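The difference between the three index samplers above, as a quick sketch: SequentialSampler walks 0..N-1 and wraps around, RandomSequentialSampler draws a fresh permutation on every wrap, and IncrementalSequentialSampler grows its accessible prefix by one element per wrap (useful for incrementally adding frames during training).

    from Optim.Samplers.utils import SequentialSampler, RandomSequentialSampler, IncrementalSequentialSampler

    seq = SequentialSampler(num_elements=5)
    print(seq.get(3), seq.get(3))  # tensor([0, 1, 2]) twice: the second call wraps around and restarts
    rnd = RandomSequentialSampler(num_elements=5)
    print(rnd.get(5), rnd.get(5))  # a random permutation of 0..4 each time
    inc = IncrementalSequentialSampler(num_elements=5)
    print(inc.get(1), inc.get(1))  # only index 0 is accessible at first; the prefix grows on wrap-around
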
/src/Thirdparty/Apex.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Thirdparty/Apex.py: NVIDIA Apex (https://github.com/NVIDIA/apex).
5 | """
6 |
7 | import Framework
8 |
9 | __extension_name__ = 'apex'
10 | __install_command__ = [
11 | 'pip', 'install',
12 | '-v', '--disable-pip-version-check', '--no-cache-dir', '--no-build-isolation',
13 | '--config-settings', '--build-option=--cpp_ext',
14 | '--config-settings', '--build-option=--cuda_ext',
15 | 'git+https://github.com/NVIDIA/apex',
16 | ]
17 |
18 | try:
19 | from apex.optimizers import FusedAdam # noqa
20 | except ImportError:
21 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
22 |
--------------------------------------------------------------------------------
/src/Thirdparty/DiffGaussianRasterization.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | from pathlib import Path
4 |
5 | import Framework
6 |
7 | __extension_name__ = 'DiffGaussianRasterization'
8 | REPO_URL = 'https://github.com/graphdeco-inria/diff-gaussian-rasterization'
9 | COMMIT_HASH = '59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d'
10 | __install_command__ = ['pip', 'install', f'git+{REPO_URL}@{COMMIT_HASH}']
11 |
12 | try:
13 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer # noqa
14 | except ImportError:
15 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
16 |
--------------------------------------------------------------------------------
/src/Thirdparty/FusedSSIM.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Thirdparty/FusedSSIM.py: fast cuda ssim implementation from https://github.com/rahul-goel/fused-ssim.
5 | """
6 |
7 | import Framework
8 |
9 | __extension_name__ = 'FusedSSIM'
10 | __install_command__ = [
11 | 'pip', 'install',
12 | 'git+https://github.com/rahul-goel/fused-ssim/',
13 | ]
14 |
15 | try:
16 | from fused_ssim import fused_ssim
17 | except ImportError:
18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
19 |
--------------------------------------------------------------------------------
/src/Thirdparty/SimpleKNN.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | from pathlib import Path
4 |
5 | import Framework
6 |
7 | __extension_name__ = 'simple-knn'
8 | REPO_URL = 'https://github.com/camenduru/simple-knn'
9 | COMMIT_HASH = '44f764299fa305faf6ec5ebd99939e0508331503'
10 | __install_command__ = ['pip', 'install', f'git+{REPO_URL}@{COMMIT_HASH}']
11 |
12 | try:
13 | from simple_knn._C import distCUDA2 # noqa
14 | except ImportError:
15 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
16 |
--------------------------------------------------------------------------------
/src/Thirdparty/TinyCudaNN.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Thirdparty/TinyCudaNN.py: Fast fused MLP and input encoding from NVLabs (https://github.com/NVlabs/tiny-cuda-nn).
5 | """
6 |
7 | import Framework
8 |
9 | __extension_name__ = 'tinycudann'
10 | __install_command__ = [
11 | 'pip', 'install',
12 | 'git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch',
13 | ]
14 |
15 | try:
16 | from tinycudann import * # noqa
17 | except ImportError:
18 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
19 |
--------------------------------------------------------------------------------
/src/Thirdparty/TorchScatter.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Thirdparty/TorchScatter.py: Torch Scatter lib.
5 | """
6 |
7 | import Framework
8 | import torch
9 |
10 | __extension_name__ = 'torch-scatter'
11 | __install_command__ = [
12 | 'pip', 'install',
13 | 'torch-scatter',
14 | '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html',
15 | ]
16 |
17 | try:
18 | from torch_scatter import segment_csr # noqa
19 | except ImportError:
20 | raise Framework.ExtensionError(name=__extension_name__, install_command=__install_command__)
21 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/BulletTime.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Visual/Trajectories/BulletTime.py: A visualization for dynamic scenes, following a lemniscate trajectory while replaying time.
5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main).
6 | """
7 |
8 | import torch
9 |
10 | import Framework
11 | from Cameras.Base import BaseCamera
12 | from Cameras.utils import CameraProperties
13 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory
14 |
15 |
16 | class bullet_time(CameraTrajectory):
17 | """A visualization for dynamic scenes, following a lemniscate trajectory while replaying time."""
18 |
19 | def __init__(self, reference_pose_rel_id: float = 0.5, custom_lookat: torch.Tensor | None = None,
20 | custom_up: torch.Tensor | None = None, num_frames_per_rotation: float = 90, degree: float = 10, num_repeats: int = 2) -> None:
21 | super().__init__()
22 | self.reference_pose_rel_id = reference_pose_rel_id
23 | self.custom_lookat = custom_lookat
24 | self.custom_up = custom_up
25 | self.num_frames_per_rotation = num_frames_per_rotation
26 | self.degree = degree
27 | self.num_repeats = num_repeats
28 |
29 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
30 |         """A visualization for dynamic scenes, following a lemniscate trajectory while replaying time."""
31 | data: list[CameraProperties] = []
32 |         reference_pose = reference_poses[min(int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses)), len(reference_poses) - 1)]  # clamped so that rel_id == 1.0 remains a valid index
33 | reference_camera.setProperties(reference_pose)
34 | lookat = self.custom_lookat
35 | if lookat is None:
36 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3]
37 | lookat = pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1])
38 | up = self.custom_up if self.custom_up is not None else reference_camera.getUpVector()[:3]
39 | lemniscate_trajectory_c2ws = getLemniscateTrajectory(
40 | reference_pose,
41 | lookat=lookat.to(Framework.config.GLOBAL.DEFAULT_DEVICE),
42 | up=up.to(Framework.config.GLOBAL.DEFAULT_DEVICE),
43 | num_frames=self.num_frames_per_rotation,
44 | degree=self.degree,
45 | )
46 | num_frames = self.num_frames_per_rotation * self.num_repeats
47 | for frame_idx in range(num_frames):
48 | data.append(CameraProperties(
49 | width=reference_pose.width,
50 | height=reference_pose.height,
51 | rgb=None,
52 | alpha=None,
53 | c2w=lemniscate_trajectory_c2ws[frame_idx % self.num_frames_per_rotation].clone(),
54 | focal_x=reference_pose.focal_x,
55 | focal_y=reference_pose.focal_y,
56 | distortion_parameters=reference_pose.distortion_parameters,
57 | principal_offset_x=0.0,
58 | principal_offset_y=0.0,
59 | timestamp=frame_idx / (num_frames - 1)
60 | ))
61 | return data
62 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/FancyZoom.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Visual/Trajectories/FancyZoom.py: A camera trajectory following the training poses, interrupted by zooms and spiral movements."""
4 |
5 | import math
6 | import torch
7 |
8 | import Framework
9 | from Cameras.Base import BaseCamera
10 | from Cameras.utils import CameraProperties
11 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory
12 |
13 |
14 | class fancy_zoom(CameraTrajectory):
15 | """A camera trajectory following the training poses, interrupted by zooms and spiral movements."""
16 |
17 | def __init__(self, num_breaks: int = 2, num_zoom_frames: int = 90, zoom_factor: float = 0.2, lemni_frames_per_rot: int = 60, lemni_degree: int = 3) -> None:
18 | super().__init__()
19 | self.num_breaks = num_breaks
20 | self.num_zoom_frames = num_zoom_frames
21 | self.zoom_factor = zoom_factor
22 | self.lemni_frames_per_rot = lemni_frames_per_rot
23 | self.lemni_degree = lemni_degree
24 |
25 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
26 | """Generates the camera trajectory using a list of reference poses."""
27 | data: list[CameraProperties] = []
28 | num_reference_frames = len(reference_poses)
29 | break_indices = torch.linspace(0, len(reference_poses), self.num_breaks + 2)[1:-1].int().tolist()
30 | for i in range(break_indices[0]):
31 | data.append(reference_poses[i].toSimple())
32 | for j in range(len(break_indices)):
33 | reference = reference_poses[break_indices[j]]
34 | for i in range(self.num_zoom_frames):
35 | new = reference.toSimple()
36 | new.focal_x = new.focal_x + (new.focal_x * self.zoom_factor * math.sin((i / (self.num_zoom_frames - 1)) * 2 * math.pi))
37 | new.focal_y = new.focal_y + (new.focal_y * self.zoom_factor * math.sin((i / (self.num_zoom_frames - 1)) * 2 * math.pi))
38 | data.append(new)
39 | reference_camera.setProperties(reference)
40 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3]
41 | lookat = (pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1])).to(Framework.config.GLOBAL.DEFAULT_DEVICE)
42 |             up = reference_camera.getUpVector()[:3].to(Framework.config.GLOBAL.DEFAULT_DEVICE)
43 | lemniscate_trajectory_c2ws = getLemniscateTrajectory(reference, lookat=lookat, up=up, num_frames=self.lemni_frames_per_rot, degree=self.lemni_degree)
44 | for i in lemniscate_trajectory_c2ws:
45 | new = reference.toSimple()
46 | new.c2w = i
47 | data.append(new)
48 | if j < len(break_indices) - 1:
49 | for i in range(break_indices[j], break_indices[j + 1]):
50 | data.append(reference_poses[i].toSimple())
51 | for i in range(break_indices[-1], num_reference_frames):
52 | data.append(reference_poses[i].toSimple())
53 | return data
54 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/NovelView.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Visual/Trajectories/NovelView.py: A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame.
5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main).
6 | """
7 |
8 | import torch
9 |
10 | import Framework
11 | from Cameras.Base import BaseCamera
12 | from Cameras.utils import CameraProperties
13 | from Visual.Trajectories.utils import getLemniscateTrajectory, CameraTrajectory
14 |
15 |
16 | class novel_view(CameraTrajectory):
17 |     """A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame."""
18 |
19 | def __init__(self, reference_pose_rel_id: float = 0.5, custom_lookat: torch.Tensor | None = None,
20 | custom_up: torch.Tensor | None = None, num_frames_per_rotation: float = 90, degree: float = 30) -> None:
21 | super().__init__()
22 | self.reference_pose_rel_id = reference_pose_rel_id
23 | self.custom_lookat = custom_lookat
24 | self.custom_up = custom_up
25 | self.num_frames_per_rotation = num_frames_per_rotation
26 | self.degree = degree
27 |
28 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
29 | """A visualization for dynamic scenes, following a lemniscate trajectory while freezing time at a reference frame."""
30 | data: list[CameraProperties] = []
31 |         reference_pose = reference_poses[min(int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses)), len(reference_poses) - 1)]  # clamped so that rel_id == 1.0 remains a valid index
32 | reference_camera.setProperties(reference_pose)
33 | lookat = self.custom_lookat
34 | if lookat is None:
35 | pos_viewdir = reference_camera.getPositionAndViewdir()[..., :3]
36 | lookat = pos_viewdir[0] + (((reference_camera.near_plane + reference_camera.far_plane) / 2) * pos_viewdir[1])
37 | up = self.custom_up if self.custom_up is not None else reference_camera.getUpVector()[:3]
38 | lemniscate_trajectory_c2ws = getLemniscateTrajectory(
39 | reference_pose,
40 | lookat=lookat.to(Framework.config.GLOBAL.DEFAULT_DEVICE),
41 | up=up.to(Framework.config.GLOBAL.DEFAULT_DEVICE),
42 | num_frames=self.num_frames_per_rotation,
43 | degree=self.degree,
44 | )
45 | for c2w in lemniscate_trajectory_c2ws:
46 | data.append(CameraProperties(
47 | width=reference_pose.width,
48 | height=reference_pose.height,
49 | rgb=None,
50 | alpha=None,
51 | c2w=c2w.clone(),
52 | focal_x=reference_pose.focal_x,
53 | focal_y=reference_pose.focal_y,
54 | distortion_parameters=reference_pose.distortion_parameters,
55 | principal_offset_x=0.0,
56 | principal_offset_y=0.0,
57 | timestamp=reference_pose.timestamp,
58 | ))
59 | return data
60 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/SpiralPath.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Visual/Trajectories/SpiralPath.py: A spiral camera trajectory for forward facing scenes.
5 | Used by the original NeRF for LLFF visualizations.
6 | """
7 |
8 | import math
9 | import torch
10 |
11 | import Framework
12 | from Cameras.Base import BaseCamera
13 | from Cameras.utils import CameraProperties, createCameraMatrix, normalizeRays
14 | from Datasets.utils import getAveragePose
15 | from Visual.Trajectories.utils import CameraTrajectory
16 |
17 |
18 | class spiral_path(CameraTrajectory):
19 | """A spiral camera trajectory for forward facing scenes."""
20 |
21 |     def __init__(self, num_views: int = 120, num_rotations: int = 2) -> None:
22 | super().__init__()
23 | self.num_views: int = num_views
24 | self.num_rotations: int = num_rotations
25 |
26 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
27 | """Generates the camera trajectory using a list of reference poses."""
28 | mean_focal_x: float = torch.tensor([c.focal_x for c in reference_poses]).mean().item()
29 | mean_focal_y: float = torch.tensor([c.focal_y for c in reference_poses]).mean().item()
30 | c2ws = torch.stack([c.c2w for c in reference_poses], dim=0).to(Framework.config.GLOBAL.DEFAULT_DEVICE)
31 | return createSpiralPath(
32 | view_matrices=c2ws,
33 | near_plane=reference_camera.near_plane,
34 | far_plane=reference_camera.far_plane,
35 | image_shape=(3, reference_poses[0].height, reference_poses[0].width),
36 | focal_x=mean_focal_x,
37 | focal_y=mean_focal_y,
38 | n_views=self.num_views,
39 | n_rots=self.num_rotations
40 | )
41 |
42 |
43 | def createSpiralPath(view_matrices: torch.Tensor, near_plane: float, far_plane: float, image_shape: tuple[int, int, int],
44 | focal_x: float, focal_y: float, n_views: int, n_rots: int) -> list[CameraProperties]:
45 | """Creates views on spiral path, adapted from the original NeRF implementation."""
46 | average_pose: torch.Tensor = getAveragePose(view_matrices)
47 | up: torch.Tensor = -normalizeRays(view_matrices[:, :3, 1].sum(0))[0]
48 | close_depth: float = near_plane * 0.9
49 | inf_depth: float = far_plane * 1.0
50 | dt: float = 0.75
51 | focal: float = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth)
52 | rads: torch.Tensor = 0.01 * torch.quantile(torch.abs(view_matrices[:, :3, 3]), q=0.9, dim=0)
53 | view_matrices_spiral: list[torch.Tensor] = []
54 | rads: torch.Tensor = torch.tensor(list(rads) + [1.])
55 | for theta in torch.linspace(0.0, 2.0 * math.pi * n_rots, n_views + 1)[:-1]:
56 | c: torch.Tensor = torch.mm(
57 | average_pose[:3, :4],
58 | (torch.tensor([torch.cos(theta), torch.sin(theta), torch.sin(theta * 0.5), 1.]) * rads)[:, None]
59 | ).squeeze()
60 | z: torch.Tensor = -normalizeRays(c - torch.mm(average_pose[:3, :4], torch.tensor([[0], [0], [-focal], [1.]])).squeeze())[0]
61 | view_matrices_spiral.append(createCameraMatrix(z, up, c))
62 | return [
63 | CameraProperties(
64 | width=image_shape[2],
65 | height=image_shape[1],
66 | rgb=None,
67 | alpha=None,
68 | c2w=c2w.float(),
69 | focal_x=focal_x,
70 | focal_y=focal_y
71 | )
72 | for c2w in view_matrices_spiral
73 | ]
74 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/StabilizedTrain.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Visual/Trajectories/StabilizedTrain.py: A camera trajectory following the training poses, stabilizing poses over a window."""
4 |
5 | import torch
6 |
7 | import Framework
8 | from Cameras.Base import BaseCamera
9 | from Cameras.utils import CameraProperties
10 | from Datasets.utils import getAveragePose
11 | from Visual.Trajectories.utils import CameraTrajectory
12 |
13 |
14 | class stabilized_train(CameraTrajectory):
15 | """A camera trajectory following the training poses, stabilizing poses over a window."""
16 |
17 | def __init__(self, window: int = 5) -> None:
18 | super().__init__()
19 | if window % 2 == 0:
20 | raise Framework.VisualizationError('Window size must be an odd number.')
21 | self.half_window = window // 2
22 |
23 | def _generate(self, _: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
24 | """Generates the camera trajectory using a list of reference poses."""
25 | data: list[CameraProperties] = []
26 | poses_all = torch.stack([c.c2w for c in reference_poses], dim=0)
27 | for i, c in enumerate(reference_poses):
28 |             c2w = getAveragePose(poses_all[max(0, i - self.half_window):min(len(reference_poses), i + self.half_window + 1)])  # inclusive window of size 2 * half_window + 1
29 | prop = c.toSimple()
30 | prop.c2w = c2w
31 | data.append(prop)
32 | return data
33 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/StabilizedView.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """
4 | Visual/Trajectories/StabilizedView.py: A trajectory for dynamic scenes, replaying time for a single, fixed view.
5 | Adapted from DyCheck (iPhone dataset) by Gao et al. 2022 (https://github.com/KAIR-BAIR/dycheck/blob/main).
6 | """
7 |
8 | from Cameras.Base import BaseCamera
9 | from Cameras.utils import CameraProperties
10 | from Visual.Trajectories.utils import CameraTrajectory
11 |
12 |
13 | class stabilized_view(CameraTrajectory):
14 | """A trajectory for dynamic scenes, replaying time for a single, fixed view."""
15 |
16 | def __init__(self, reference_pose_rel_id: float = 0.5) -> None:
17 | super().__init__()
18 | self.reference_pose_rel_id = reference_pose_rel_id
19 |
20 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
21 | """Generates the camera trajectory using a list of reference poses."""
22 | data: list[CameraProperties] = []
23 |         reference_pose = reference_poses[min(int(min(1.0, max(0.0, self.reference_pose_rel_id)) * len(reference_poses)), len(reference_poses) - 1)]  # clamped so that rel_id == 1.0 remains a valid index
24 | for camera_properties in reference_poses:
25 | data.append(CameraProperties(
26 | width=reference_pose.width,
27 | height=reference_pose.height,
28 | rgb=None,
29 | alpha=None,
30 | c2w=reference_pose.c2w.clone(),
31 | focal_x=reference_pose.focal_x,
32 | focal_y=reference_pose.focal_y,
33 | principal_offset_x=0.0,
34 | principal_offset_y=0.0,
35 | distortion_parameters=reference_pose.distortion_parameters,
36 | timestamp=camera_properties.timestamp
37 | ))
38 | return data
39 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Visual.Trajectories is a package for adding custom camera trajectories to a dataset, enabling convenient visualization in the gui and inference scripts.
3 | """
4 |
5 | import importlib
6 | from pathlib import Path
7 | from .utils import CameraTrajectory as CameraTrajectory
8 |
9 | base_path: Path = Path(__file__).resolve().parents[0]
10 | for module in [str(i.name)[:-3] for i in base_path.iterdir() if i.is_file() and i.name not in ['__init__.py', 'utils.py']]:
11 | importlib.import_module(f'Visual.Trajectories.{module}')
12 | del base_path, module
13 |
--------------------------------------------------------------------------------
/src/Visual/Trajectories/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Visual/Trajectories/utils.py: Utilities for visualization tasks."""
4 |
5 | from abc import ABC, abstractmethod
6 | import math
7 | from typing import Type
8 |
9 | import torch
10 |
11 | import Framework
12 | from Logging import Logger
13 | from Cameras.Base import BaseCamera
14 | from Cameras.utils import CameraProperties, createLookAtMatrix
15 | from Datasets.Base import BaseDataset
16 |
17 |
18 | class CameraTrajectory(ABC):
19 |
20 | _options: list[str] = []
21 |
22 | def __init__(self) -> None:
23 | super().__init__()
24 | self._trajectory: list[CameraProperties] = []
25 | self.name: str = self.__class__.__name__
26 |
27 | @classmethod
28 | def listOptions(cls) -> list[str]:
29 | """Lists all available camera trajectories."""
30 | if not cls._options:
31 |             cls._options = [subclass.__name__ for subclass in CameraTrajectory.__subclasses__()]
32 | return cls._options
33 |
34 | @classmethod
35 | def get(cls, trajectory_name: str) -> Type['CameraTrajectory']:
36 | """Returns a camera trajectory class by its name."""
37 | options = cls.listOptions()
38 | if trajectory_name not in options:
39 | raise Framework.VisualizationError(f'Unknown camera trajectory type: {trajectory_name}.\nAvailable options are: {options}')
40 | for trajectory in cls.__subclasses__():
41 | if trajectory.__name__ == trajectory_name:
42 | return trajectory
43 |
44 | @abstractmethod
45 | def _generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> list[CameraProperties]:
46 | """Generates the camera trajectory using a list of reference poses."""
47 | pass
48 |
49 | def generate(self, reference_camera: BaseCamera, reference_poses: list[CameraProperties]) -> None:
50 | """Generates the camera trajectory using a list of reference poses."""
51 | Logger.logInfo(f'generating {self.name} trajectory...')
52 | self._trajectory = self._generate(reference_camera, reference_poses)
53 |
54 | def addTo(self, dataset: BaseDataset, reference_set: str | None = 'train') -> BaseDataset:
55 | """Adds the camera trajectory to a dataset."""
56 | if self.name in dataset.subsets:
57 | Logger.logInfo(f'{self.name} trajectory already exists in dataset.')
58 | return dataset
59 | if not self._trajectory:
60 | if reference_set is None:
61 | reference_poses = [*dataset.data['train'], *dataset.data['val'], *dataset.data['test']]
62 | else:
63 | reference_poses = dataset.data[reference_set]
64 | self.generate(reference_camera=dataset.camera, reference_poses=reference_poses)
65 | dataset.subsets.append(self.name)
66 | dataset.data[self.name] = self._trajectory
67 | return dataset
68 |
69 |
70 | def getLemniscateTrajectory(
71 | reference_camera: CameraProperties,
72 | lookat: torch.Tensor,
73 | up: torch.Tensor,
74 | num_frames: int,
75 | degree: float,
76 | ) -> list[torch.Tensor]:
77 | reference_camera = reference_camera.toDefaultDevice()
78 | camera_position = reference_camera.T
79 | a = torch.norm(camera_position - lookat) * math.tan(degree / 360 * math.pi)
80 | # Lemniscate curve in camera space. Starting at the origin.
81 | positions = torch.stack([
82 | torch.tensor([
83 | a * math.cos(t) / (1 + math.sin(t) ** 2),
84 | a * math.cos(t) * math.sin(t) / (1 + math.sin(t) ** 2),
85 | 0,
86 | ]) for t in (torch.linspace(0, 2 * math.pi, num_frames) + math.pi / 2)
87 | ], dim=0)
88 | # Transform to world space.
89 | positions = torch.matmul(reference_camera.R.T, positions[..., None])[..., 0] + camera_position
90 | cameras = [createLookAtMatrix(p, lookat, up) for p in positions]
91 | return cameras
92 |
--------------------------------------------------------------------------------
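A sketch of the intended workflow (dataset below is an assumed, already loaded BaseDataset instance): importing the Visual.Trajectories package registers all trajectory classes, listOptions/get look them up by name, and addTo generates the poses and appends them as an extra dataset subset.

    from Visual.Trajectories import CameraTrajectory  # the package import registers all trajectory modules

    print(CameraTrajectory.listOptions())             # e.g. ['bullet_time', 'novel_view', 'spiral_path', ...]
    trajectory = CameraTrajectory.get('novel_view')(reference_pose_rel_id=0.5, degree=30)
    dataset = trajectory.addTo(dataset, reference_set='train')  # `dataset` is assumed to be loaded elsewhere
    novel_view_poses = dataset.data['novel_view']               # list[CameraProperties], ready for rendering
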
/src/Visual/utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 |
3 | """Visual/utils.py: Utilities for visualization tasks."""
4 |
5 | from pathlib import Path
6 | import av
7 | from typing import Any
8 |
9 | import torch
10 | from torch import Tensor
11 |
12 | from Logging import Logger
13 | from Visual.ColorMap import ColorMap
14 |
15 |
16 | def pseudoColorDepth(color_map: str, depth: Tensor, near_far: tuple[float, float] | None = None, alpha: Tensor | None = None, interpolate: bool = False) -> Tensor:
17 | """Produces a pseudo-colorized depth image for the given depth tensor."""
18 | # correct depth based on alpha mask
19 | if alpha is not None:
20 | depth = depth.clone()
21 | depth[alpha > 1e-5] = depth[alpha > 1e-5] / alpha[alpha > 1e-5]
22 | # normalize depth to [0, 1]
23 | if near_far is None:
24 | # calculate near and far planes from depth map
25 | masked_depth = depth[alpha > 0.99] if alpha is not None else depth
26 | near_plane = masked_depth.min().item() if masked_depth.numel() > 0 else 0.0
27 | far_plane = masked_depth.max().item() if masked_depth.numel() > 0 else 1.0
28 | else:
29 | near_plane, far_plane = near_far
30 | depth: Tensor = torch.clamp((depth - near_plane) / (far_plane - near_plane), min=0.0, max=1.0)
31 | # apply color map
32 | depth = ColorMap.apply(depth, color_map, interpolate)
33 | # mask color with alpha
34 | if alpha is not None:
35 | depth *= alpha
36 | return depth
37 |
38 |
39 | class VideoWriter:
40 | """Wrapper class to facilitate creation of videos using PyAV."""
41 |
42 | def __init__(
43 | self,
44 | video_paths: Path | list[Path],
45 | width: int,
46 | height: int,
47 | fps: int,
48 | bitrate: int,
49 | video_codec: str = 'libx264rgb',
50 | options: dict[str, Any] | None = None
51 | ) -> None:
52 | self.containers = []
53 | self.streams = []
54 | padding_right = width % 2
55 | padding_bottom = height % 2
56 | self.pad = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))
57 | if not isinstance(video_paths, list):
58 | video_paths = [video_paths]
59 | for video_path in video_paths:
60 | container = av.open(str(video_path), mode='w')
61 | stream = container.add_stream(video_codec, rate=fps)
62 |             stream.codec_context.bit_rate = int(bitrate * 1e3)  # bitrate is given in kbps, PyAV expects bits per second
63 | stream.width = width + padding_right
64 | stream.height = height + padding_bottom
65 | stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
66 | stream.options = options or {}
67 | self.containers.append(container)
68 | self.streams.append(stream)
69 |
70 | def addFrame(self, frames: Tensor | list[Tensor]) -> None:
71 | """Adds given frames to the internal video streams."""
72 | if not isinstance(frames, list):
73 | frames = [frames]
74 | if len(self.streams) != len(frames):
75 | Logger.logWarning(
76 | f'number of frames does not match with number of video streams ({len(frames)} vs. {len(self.streams)})'
77 | )
78 | for frame, stream, container in zip(frames, self.streams, self.containers):
79 | frame = self.pad((frame * 255)).byte().permute(1, 2, 0).cpu().numpy()
80 | frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
81 | frame.pict_type = 0
82 | for packet in stream.encode(frame):
83 | container.mux(packet)
84 |
85 | def close(self) -> None:
86 | """Flushes all streams and closes all containers."""
87 | for stream, container in zip(self.streams, self.containers):
88 | for packet in stream.encode():
89 | container.mux(packet)
90 | container.close()
91 |
--------------------------------------------------------------------------------
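To close, a small sketch of the VideoWriter above (an odd resolution is padded to even dimensions internally, and the bitrate argument is interpreted in kbps); pseudoColorDepth can be used in the same loop to turn a depth map into such an RGB frame before writing.

    import torch
    from pathlib import Path
    from Visual.utils import VideoWriter

    writer = VideoWriter(Path('preview.mp4'), width=321, height=240, fps=30, bitrate=8000)
    for _ in range(30):
        frame = torch.rand(3, 240, 321)  # a (3, H, W) float tensor in [0, 1]
        writer.addFrame(frame)
    writer.close()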