├── .github └── FUNDING.yml ├── .gitignore ├── GALLERY.md ├── LICENSE ├── README.md ├── assets └── icon.png ├── benchmarking ├── benchmark_blendedmvs.sh ├── benchmark_mipnerf360.sh ├── benchmark_nerfpp.sh ├── benchmark_rtmv.sh ├── benchmark_synthetic_nerf.sh ├── benchmark_synthetic_nsvf.sh └── benchmark_tat.sh ├── datasets ├── __init__.py ├── base.py ├── colmap.py ├── colmap_utils.py ├── color_utils.py ├── depth_utils.py ├── nerf.py ├── nerfpp.py ├── nsvf.py ├── ray_utils.py └── rtmv.py ├── losses.py ├── metrics.py ├── misc └── prepare_rtmv.py ├── models ├── __init__.py ├── csrc │ ├── binding.cpp │ ├── include │ │ ├── helper_math.h │ │ └── utils.h │ ├── intersection.cu │ ├── losses.cu │ ├── raymarching.cu │ ├── setup.py │ └── volumerendering.cu ├── custom_functions.py ├── networks.py └── rendering.py ├── opt.py ├── requirements.txt ├── show_gui.py ├── test.ipynb ├── train.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: kwea123 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.vscode 2 | /.VSCodeCounter 3 | ckpts/ 4 | logs/ 5 | results/ 6 | *.mp4 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /GALLERY.md: -------------------------------------------------------------------------------- 1 | https://user-images.githubusercontent.com/11364490/177025079-cb92a399-2600-4e10-94e0-7cbe09f32a6f.mp4 2 | 3 | https://user-images.githubusercontent.com/11364490/176821462-83078563-28e1-4563-8e7a-5613b505e54a.mp4 4 | 5 | https://user-images.githubusercontent.com/11364490/180640362-9e63da7c-4268-43ce-874a-3219c7bd778c.mp4 6 | 7 | https://user-images.githubusercontent.com/11364490/181415396-7e378d9e-1d74-43c2-82b0-0f86adfb294d.mp4 8 | 9 | https://user-images.githubusercontent.com/11364490/181415347-1db46d8d-5276-404b-8f7c-65cb0b040976.mp4 10 | 11 | https://user-images.githubusercontent.com/11364490/182150753-28c423f6-8ea8-424c-ade5-0ec3dcbe4987.mp4 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AI葵 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ngp_pl 2 | 3 | ### Advertisement: Check out the latest integrated project [nerfstudio](https://github.com/nerfstudio-project/nerfstudio)! There are a lot of recent improvements on nerf related methods, including instant-ngp! 4 | 5 | 10 | 11 | Instant-ngp (only NeRF) in pytorch+cuda trained with pytorch-lightning (**high quality with high speed**). This repo aims at providing a concise pytorch interface to facilitate future research, and am grateful if you can share it (and a citation is highly appreciated)! 12 | 13 | * [Official CUDA implementation](https://github.com/NVlabs/instant-ngp/tree/master) 14 | * [torch-ngp](https://github.com/ashawkey/torch-ngp) another pytorch implementation that I highly referenced. 15 | 16 | # :paintbrush: Gallery 17 | 18 | https://user-images.githubusercontent.com/11364490/181671484-d5e154c8-6cea-4d52-94b5-1e5dd92955f2.mp4 19 | 20 | Other representative videos are in [GALLERY.md](GALLERY.md) 21 | 22 | # :computer: Installation 23 | 24 | This implementation has **strict** requirements due to dependencies on other libraries, if you encounter installation problem due to hardware/software mismatch, I'm afraid there is **no intention** to support different platforms (you are welcomed to contribute). 25 | 26 | ## Hardware 27 | 28 | * OS: Ubuntu 20.04 29 | * NVIDIA GPU with Compute Compatibility >= 75 and memory > 6GB (Tested with RTX 2080 Ti), CUDA 11.3 (might work with older version) 30 | * 32GB RAM (in order to load full size images) 31 | 32 | ## Software 33 | 34 | * Clone this repo by `git clone https://github.com/kwea123/ngp_pl` 35 | * Python>=3.8 (installation via [anaconda](https://www.anaconda.com/distribution/) is recommended, use `conda create -n ngp_pl python=3.8` to create a conda environment and activate it by `conda activate ngp_pl`) 36 | * Python libraries 37 | * Install pytorch by `pip install torch==1.11.0 --extra-index-url https://download.pytorch.org/whl/cu113` 38 | * Install `torch-scatter` following their [instruction](https://github.com/rusty1s/pytorch_scatter#installation) 39 | * Install `tinycudann` following their [instruction](https://github.com/NVlabs/tiny-cuda-nn#pytorch-extension) (pytorch extension) 40 | * Install `apex` following their [instruction](https://github.com/NVIDIA/apex#linux) 41 | * Install core requirements by `pip install -r requirements.txt` 42 | 43 | * Cuda extension: Upgrade `pip` to >= 22.1 and run `pip install models/csrc/` (please run this each time you `pull` the code) 44 | 45 | # :books: Supported Datasets 46 | 47 | 1. NSVF data 48 | 49 | Download preprocessed datasets (`Synthetic_NeRF`, `Synthetic_NSVF`, `BlendedMVS`, `TanksAndTemples`) from [NSVF](https://github.com/facebookresearch/NSVF#dataset). **Do not change the folder names** since there is some hard-coded fix in my dataloader. 50 | 51 | 2. 
NeRF++ data 52 | 53 | Download data from [here](https://github.com/Kai-46/nerfplusplus#data). 54 | 55 | 3. Colmap data 56 | 57 | For custom data, run `colmap` and get a folder `sparse/0` under which there are `cameras.bin`, `images.bin` and `points3D.bin`. The following data with colmap format are also supported: 58 | 59 | * [nerf_llff_data](https://drive.google.com/file/d/16VnMcF1KJYxN9QId6TClMsZRahHNMW5g/view?usp=sharing) 60 | * [mipnerf360 data](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) 61 | * [HDR-NeRF data](https://drive.google.com/drive/folders/1OTDLLH8ydKX1DcaNpbQ46LlP0dKx6E-I). Additionally, download my colmap pose estimation from [here](https://drive.google.com/file/d/1TXxgf_ZxNB4o67FVD_r0aBUIZVRgZYMX/view?usp=sharing) and extract to the same location. 62 | 63 | 4. RTMV data 64 | 65 | Download data from [here](http://www.cs.umd.edu/~mmeshry/projects/rtmv/). To convert the hdr images into ldr images for training, run `python misc/prepare_rtmv.py `, it will create `images/` folder under each scene folder, and will use these images to train (and test). 66 | 67 | # :key: Training 68 | 69 | Quickstart: `python train.py --root_dir --exp_name Lego` 70 | 71 | It will train the Lego scene for 30k steps (each step with 8192 rays), and perform one testing at the end. The training process should finish within about 5 minutes (saving testing image is slow, add `--no_save_test` to disable). Testing PSNR will be shown at the end. 72 | 73 | More options can be found in [opt.py](opt.py). 74 | 75 | For other public dataset training, please refer to the scripts under `benchmarking`. 76 | 77 | # :mag_right: Testing 78 | 79 | Use `test.ipynb` to generate images. Lego pretrained model is available [here](https://github.com/kwea123/ngp_pl/releases/tag/v1.0) 80 | 81 | GUI usage: run `python show_gui.py` followed by the **same** hyperparameters used in training (`dataset_name`, `root_dir`, etc) and **add the checkpoint path** with `--ckpt_path ` 82 | 83 | # Comparison with torch-ngp and the paper 84 | 85 | I compared the quality (average testing PSNR on `Synthetic-NeRF`) and the inference speed (on `Lego` scene) v.s. the concurrent work torch-ngp (default settings) and the paper, all trained for about 5 minutes: 86 | 87 | | Method | avg PSNR | FPS | GPU | 88 | | :---: | :---: | :---: | :---: | 89 | | torch-ngp | 31.46 | 18.2 | 2080 Ti | 90 | | mine | 32.96 | 36.2 | 2080 Ti | 91 | | instant-ngp paper | **33.18** | **60** | 3090 | 92 | 93 | As for quality, mine is slightly better than torch-ngp, but the result might fluctuate across different runs. 94 | 95 | As for speed, mine is faster than torch-ngp, but is still only half fast as instant-ngp. Speed is dependent on the scene (if most of the scene is empty, speed will be faster). 96 | 97 |
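To make the training quickstart above concrete, here is a minimal sketch (not the author's `train.py`) of how the pieces under `datasets/` fit together to produce one batch of rays. The scene path is a placeholder, and `batch_size` / `ray_sampling_strategy` are attributes that `train.py` normally sets on the dataset; they are assigned by hand here purely for illustration.

```python
# Minimal sketch only: assemble one training batch of rays from an NSVF-format scene.
# The path is a placeholder; batch_size / ray_sampling_strategy are normally set by train.py.
from datasets import dataset_dict              # {'nerf', 'nsvf', 'colmap', 'nerfpp', 'rtmv'}
from datasets.ray_utils import get_rays

dataset = dataset_dict['nsvf']('/path/to/Synthetic_NeRF/Lego', split='train')
dataset.batch_size = 8192                      # rays per step, as in the quickstart
dataset.ray_sampling_strategy = 'all_images'   # or 'same_image'

sample = dataset[0]                            # {'img_idxs', 'pix_idxs', 'rgb'}
c2w = dataset.poses[sample['img_idxs']]        # (B, 3, 4) camera-to-world poses
dirs = dataset.directions[sample['pix_idxs']]  # (B, 3) camera-space ray directions
rays_o, rays_d = get_rays(dirs, c2w)           # (B, 3) world-space origins and directions
print(rays_o.shape, rays_d.shape, sample['rgb'].shape)
```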
98 | [comparison images, shown side by side. Left: torch-ngp. Right: mine.] 105 |
106 | 107 | # :chart: Benchmarks 108 | 109 | To run benchmarks, use the scripts under `benchmarking`. 110 | 111 | The following are my results, trained using one RTX 2080 Ti (qualitative results [here](https://github.com/kwea123/ngp_pl/issues/7)): 112 | 113 |
114 | Synthetic-NeRF 115 | 116 | | | Mic | Ficus | Chair | Hotdog | Materials | Drums | Ship | Lego | AVG | 117 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 118 | | PSNR | 35.59 | 34.13 | 35.28 | 37.35 | 29.46 | 25.81 | 30.32 | 35.76 | 32.96 | 119 | | SSIM | 0.988 | 0.982 | 0.984 | 0.980 | 0.944 | 0.933 | 0.890 | 0.979 | 0.960 | 120 | | LPIPS | 0.017 | 0.024 | 0.025 | 0.038 | 0.070 | 0.076 | 0.133 | 0.022 | 0.051 | 121 | | FPS | 40.81 | 34.02 | 49.80 | 25.06 | 20.08 | 37.77 | 15.77 | 36.20 | 32.44 | 122 | | Training time | 3m9s | 3m12s | 4m17s | 5m53s | 4m55s | 4m7s | 9m20s | 5m5s | 5m00s | 123 | 124 |
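The AVG column is the plain arithmetic mean over the eight scenes; for PSNR it reproduces the 32.96 quoted in the comparison table earlier. A quick check:

```python
# The AVG PSNR is the mean of the per-scene values (Mic ... Lego) in the table above.
psnr = [35.59, 34.13, 35.28, 37.35, 29.46, 25.81, 30.32, 35.76]
print(round(sum(psnr) / len(psnr), 2))  # 32.96
```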
125 | 126 |
127 | Synthetic-NSVF 128 | 129 | | | Wineholder | Steamtrain | Toad | Robot | Bike | Palace | Spaceship | Lifestyle | AVG | 130 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 131 | | PSNR | 31.64 | 36.47 | 35.57 | 37.10 | 37.87 | 37.41 | 35.58 | 34.76 | 35.80 | 132 | | SSIM | 0.962 | 0.987 | 0.980 | 0.994 | 0.990 | 0.977 | 0.980 | 0.967 | 0.980 | 133 | | LPIPS | 0.047 | 0.023 | 0.024 | 0.010 | 0.015 | 0.021 | 0.029 | 0.044 | 0.027 | 134 | | FPS | 47.07 | 75.17 | 50.42 | 64.87 | 66.88 | 28.62 | 35.55 | 22.84 | 48.93 | 135 | | Training time | 3m58s | 3m44s | 7m22s | 3m25s | 3m11s | 6m45s | 3m25s | 4m56s | 4m36s | 136 | 137 |
138 | 139 |
140 | Tanks and Temples 141 | 142 | | | Ignatius | Truck | Barn | Caterpillar | Family | AVG | 143 | |:---: | :---: | :---: | :---: | :---: | :---: | :---: | 144 | | PSNR | 28.30 | 27.67 | 28.00 | 26.16 | 34.27 | 28.78 | 145 | | *FPS | 10.04 | 7.99 | 16.14 | 10.91 | 6.16 | 10.25 | 146 | 147 | *Evaluated on `test-traj` 148 | 149 |
150 | 151 |
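The `test-traj` used for the FPS rows is the predefined render trajectory shipped with these NSVF-format scenes, loaded by `datasets/nsvf.py` when `split='test_traj'`. The sketch below only shows the loading step (the path is a placeholder); the actual loader additionally flips the pose convention and recenters/rescales the poses into [-0.5, 0.5] using the scene's `bbox.txt`.

```python
# Sketch of the test-traj poses used for the FPS rows (see NSVFDataset.read_meta, split='test_traj').
import numpy as np

poses = np.loadtxt('/path/to/TanksAndTemple/Truck/test_traj.txt').reshape(-1, 4, 4)
c2ws = poses[:, :3]  # (N_frames, 3, 4) camera-to-world matrices along the predefined render path
```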
152 | BlendedMVS 153 | 154 | | | *Jade | *Fountain | Character | Statues | AVG | 155 | |:---: | :---: | :---: | :---: | :---: | :---: | 156 | | PSNR | 25.43 | 26.82 | 30.43 | 26.79 | 27.38 | 157 | | **FPS | 26.02 | 21.24 | 35.99 | 19.22 | 25.61 | 158 | | Training time | 6m31s | 7m15s | 4m50s | 5m57s | 6m48s | 159 | 160 | *I manually switch the background from black to white, so the number isn't directly comparable to that in the papers. 161 | 162 | **Evaluated on `test-traj` 163 | 164 |
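The background switch mentioned in the first footnote is done in `datasets/nsvf.py`, which detects near-black pixels in the Jade and Fountain images and sets them to white. A self-contained illustration (the function name is mine; the masking line mirrors the loader):

```python
import torch

def black_background_to_white(img: torch.Tensor) -> torch.Tensor:
    """img: (h*w, 3) RGB in [0, 1]; mirrors the masking in datasets/nsvf.py."""
    img = img.clone()
    img[torch.all(img <= 0.1, dim=-1)] = 1.0  # near-black pixels become white background
    return img
```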
165 | 166 | # TODO 167 | 168 | - [ ] use super resolution in GUI to improve FPS 169 | - [ ] multi-sphere images as background 170 | -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwea123/ngp_pl/1b49af1856a276b236e0f17539814134ed329860/assets/icon.png -------------------------------------------------------------------------------- /benchmarking/benchmark_blendedmvs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/BlendedMVS 4 | 5 | python train.py \ 6 | --root_dir $ROOT_DIR/Jade \ 7 | --exp_name Jade --no_save_test \ 8 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 9 | 10 | python train.py \ 11 | --root_dir $ROOT_DIR/Fountain \ 12 | --exp_name Fountain --no_save_test \ 13 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 14 | 15 | python train.py \ 16 | --root_dir $ROOT_DIR/Character \ 17 | --exp_name Character --no_save_test \ 18 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 19 | 20 | python train.py \ 21 | --root_dir $ROOT_DIR/Statues \ 22 | --exp_name Statues --no_save_test \ 23 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 24 | -------------------------------------------------------------------------------- /benchmarking/benchmark_mipnerf360.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/360_v2 4 | export DOWNSAMPLE=0.25 # to avoid OOM 5 | 6 | python train.py \ 7 | --root_dir $ROOT_DIR/bicycle --dataset_name colmap \ 8 | --exp_name bicycle --downsample $DOWNSAMPLE --no_save_test \ 9 | --num_epochs 20 --batch_size 4096 --scale 16.0 --eval_lpips 10 | 11 | python train.py \ 12 | --root_dir $ROOT_DIR/bonsai --dataset_name colmap \ 13 | --exp_name bonsai --downsample $DOWNSAMPLE --no_save_test \ 14 | --num_epochs 20 --batch_size 4096 --scale 16.0 --eval_lpips 15 | 16 | python train.py \ 17 | --root_dir $ROOT_DIR/counter --dataset_name colmap \ 18 | --exp_name counter --downsample $DOWNSAMPLE --no_save_test \ 19 | --num_epochs 20 --scale 16.0 --eval_lpips 20 | 21 | python train.py \ 22 | --root_dir $ROOT_DIR/garden --dataset_name colmap \ 23 | --exp_name garden --downsample $DOWNSAMPLE --no_save_test \ 24 | --num_epochs 20 --scale 16.0 --eval_lpips 25 | 26 | python train.py \ 27 | --root_dir $ROOT_DIR/kitchen --dataset_name colmap \ 28 | --exp_name kitchen --downsample $DOWNSAMPLE --no_save_test \ 29 | --num_epochs 20 --scale 4.0 --eval_lpips 30 | 31 | python train.py \ 32 | --root_dir $ROOT_DIR/room --dataset_name colmap \ 33 | --exp_name room --downsample $DOWNSAMPLE --no_save_test \ 34 | --num_epochs 20 --scale 4.0 --eval_lpips 35 | 36 | python train.py \ 37 | --root_dir $ROOT_DIR/stump --dataset_name colmap \ 38 | --exp_name stump --downsample $DOWNSAMPLE --no_save_test \ 39 | --num_epochs 20 --batch_size 4096 --scale 64.0 --eval_lpips 40 | 41 | python train.py \ 42 | --root_dir $ROOT_DIR/flowers --dataset_name colmap \ 43 | --exp_name bicycle --downsample $DOWNSAMPLE --no_save_test \ 44 | --num_epochs 20 --batch_size 4096 --scale 16.0 --eval_lpips 45 | 46 | python train.py \ 47 | --root_dir $ROOT_DIR/treehill --dataset_name colmap \ 48 | --exp_name bicycle --downsample $DOWNSAMPLE --no_save_test \ 49 | --num_epochs 20 --batch_size 4096 --scale 64.0 --eval_lpips 
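The `DOWNSAMPLE=0.25` exported in the script above exists only to keep memory in check; the resizing itself happens in `datasets/colmap.py`, which scales the image size and intrinsics by the same factor and, for mipnerf360 (`360_v2`) data, reads the pre-downsampled `images_4/` folder. A rough sketch with made-up numbers:

```python
# Illustration of how ColmapDataset applies --downsample (camera values are hypothetical).
downsample = 0.25                  # as exported in the script above
width, height = 4000, 3000         # hypothetical full-resolution COLMAP camera
fx = fy = 3200.0                   # hypothetical focal length in pixels

img_wh = (int(width * downsample), int(height * downsample))  # images get resized to this
fx, fy = fx * downsample, fy * downsample                      # intrinsics scaled to match
folder = f'images_{int(1 / downsample)}'                       # 'images_4' for 360_v2 scenes
print(img_wh, (fx, fy), folder)    # (1000, 750) (800.0, 800.0) images_4
```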
-------------------------------------------------------------------------------- /benchmarking/benchmark_nerfpp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/tanks_and_temples 4 | 5 | python train.py \ 6 | --root_dir $ROOT_DIR/tat_intermediate_M60 --dataset_name nerfpp \ 7 | --exp_name tat_intermediate_M60 --no_save_test \ 8 | --num_epochs 20 --scale 4.0 9 | 10 | python train.py \ 11 | --root_dir $ROOT_DIR/tat_intermediate_Playground --dataset_name nerfpp \ 12 | --exp_name tat_intermediate_Playground --no_save_test \ 13 | --num_epochs 20 --scale 4.0 14 | 15 | python train.py \ 16 | --root_dir $ROOT_DIR/tat_intermediate_Train --dataset_name nerfpp \ 17 | --exp_name tat_intermediate_Train --no_save_test \ 18 | --num_epochs 20 --scale 16.0 --batch_size 4096 19 | 20 | python train.py \ 21 | --root_dir $ROOT_DIR/tat_training_Truck --dataset_name nerfpp \ 22 | --exp_name tat_training_Truck --no_save_test \ 23 | --num_epochs 20 --scale 16.0 --batch_size 4096 24 | 25 | export ROOT_DIR=/home/ubuntu/data/nerf_data/lf_data 26 | 27 | python train.py \ 28 | --root_dir $ROOT_DIR/africa --dataset_name nerfpp \ 29 | --exp_name africa --no_save_test \ 30 | --num_epochs 20 --scale 16.0 --eval_lpips 31 | 32 | # basket fails for some unknown reason (black stripes appear in test image) 33 | # python train.py \ 34 | # --root_dir $ROOT_DIR/basket --dataset_name nerfpp \ 35 | # --exp_name basket --no_save_test \ 36 | # --num_epochs 20 --scale 16.0 --eval_lpips 37 | 38 | python train.py \ 39 | --root_dir $ROOT_DIR/ship --dataset_name nerfpp \ 40 | --exp_name ship --no_save_test \ 41 | --num_epochs 20 --scale 8.0 --eval_lpips 42 | 43 | python train.py \ 44 | --root_dir $ROOT_DIR/statue --dataset_name nerfpp \ 45 | --exp_name statue --no_save_test \ 46 | --num_epochs 20 --scale 16.0 --eval_lpips 47 | 48 | python train.py \ 49 | --root_dir $ROOT_DIR/torch --dataset_name nerfpp \ 50 | --exp_name torch --no_save_test \ 51 | --num_epochs 20 --scale 32.0 --eval_lpips -------------------------------------------------------------------------------- /benchmarking/benchmark_rtmv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # for other environments, change the paths accordingly 4 | # for amazon_berkely, set scale=1.0 5 | export ROOT_DIR=/home/ubuntu/hdd/data/RTMV/bricks 6 | 7 | python train.py \ 8 | --root_dir $ROOT_DIR/4_Privet_Drive \ 9 | --exp_name 4_Privet_Drive --no_save_test \ 10 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 11 | 12 | python train.py \ 13 | --root_dir $ROOT_DIR/Action_Comics_#1_Superman \ 14 | --exp_name Action_Comics_#1_Superman --no_save_test \ 15 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 16 | 17 | python train.py \ 18 | --root_dir $ROOT_DIR/Buried_Treasure! \ 19 | --exp_name Buried_Treasure! 
--no_save_test \ 20 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 21 | 22 | python train.py \ 23 | --root_dir $ROOT_DIR/Fire_temple \ 24 | --exp_name Fire_temple --no_save_test \ 25 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 26 | 27 | python train.py \ 28 | --root_dir $ROOT_DIR/First_Order_Star_Destroyer \ 29 | --exp_name First_Order_Star_Destroyer --no_save_test \ 30 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 31 | 32 | python train.py \ 33 | --root_dir $ROOT_DIR/Four_Weapons_Blacksmith \ 34 | --exp_name Four_Weapons_Blacksmith --no_save_test \ 35 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 36 | 37 | python train.py \ 38 | --root_dir $ROOT_DIR/NASA_apollo_lunar_excursion_module \ 39 | --exp_name NASA_apollo_lunar_excursion_module --no_save_test \ 40 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 41 | 42 | python train.py \ 43 | --root_dir $ROOT_DIR/Night_Fury_Dragon_-_Lego_Elves_Style \ 44 | --exp_name Night_Fury_Dragon_-_Lego_Elves_Style --no_save_test \ 45 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 46 | 47 | python train.py \ 48 | --root_dir $ROOT_DIR/Oak_Tree \ 49 | --exp_name Oak_Tree --no_save_test \ 50 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 51 | 52 | python train.py \ 53 | --root_dir $ROOT_DIR/V8 \ 54 | --exp_name V8 --no_save_test \ 55 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips -------------------------------------------------------------------------------- /benchmarking/benchmark_synthetic_nerf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/Synthetic_NeRF 4 | 5 | python train.py \ 6 | --root_dir $ROOT_DIR/Chair \ 7 | --exp_name Chair --no_save_test \ 8 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 9 | 10 | python train.py \ 11 | --root_dir $ROOT_DIR/Drums \ 12 | --exp_name Drums --no_save_test \ 13 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 14 | 15 | python train.py \ 16 | --root_dir $ROOT_DIR/Ficus \ 17 | --exp_name Ficus --no_save_test \ 18 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 19 | 20 | python train.py \ 21 | --root_dir $ROOT_DIR/Hotdog \ 22 | --exp_name Hotdog --no_save_test \ 23 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 24 | 25 | python train.py \ 26 | --root_dir $ROOT_DIR/Lego \ 27 | --exp_name Lego --no_save_test \ 28 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 29 | 30 | python train.py \ 31 | --root_dir $ROOT_DIR/Materials \ 32 | --exp_name Materials --no_save_test \ 33 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 34 | 35 | python train.py \ 36 | --root_dir $ROOT_DIR/Mic \ 37 | --exp_name Mic --no_save_test \ 38 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 39 | 40 | python train.py \ 41 | --root_dir $ROOT_DIR/Ship \ 42 | --exp_name Ship --no_save_test \ 43 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips -------------------------------------------------------------------------------- /benchmarking/benchmark_synthetic_nsvf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/Synthetic_NSVF 4 | 5 | python train.py \ 6 | --root_dir $ROOT_DIR/Wineholder \ 7 | --exp_name Wineholder --no_save_test \ 8 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 9 | 10 | python train.py \ 11 | --root_dir $ROOT_DIR/Steamtrain 
\ 12 | --exp_name Steamtrain --no_save_test \ 13 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 14 | 15 | python train.py \ 16 | --root_dir $ROOT_DIR/Toad \ 17 | --exp_name Toad --no_save_test \ 18 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 19 | 20 | python train.py \ 21 | --root_dir $ROOT_DIR/Robot \ 22 | --exp_name Robot --no_save_test \ 23 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 24 | 25 | python train.py \ 26 | --root_dir $ROOT_DIR/Bike \ 27 | --exp_name Bike --no_save_test \ 28 | --num_epochs 20 --batch_size 16384 --lr 1e-2 --eval_lpips 29 | 30 | python train.py \ 31 | --root_dir $ROOT_DIR/Palace \ 32 | --exp_name Palace --no_save_test \ 33 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 34 | 35 | python train.py \ 36 | --root_dir $ROOT_DIR/Spaceship \ 37 | --exp_name Spaceship --no_save_test \ 38 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 39 | 40 | python train.py \ 41 | --root_dir $ROOT_DIR/Lifestyle \ 42 | --exp_name Lifestyle --no_save_test \ 43 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips -------------------------------------------------------------------------------- /benchmarking/benchmark_tat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_DIR=/home/ubuntu/data/nerf_data/TanksAndTemple 4 | export DOWNSAMPLE=0.5 # to avoid OOM 5 | 6 | python train.py \ 7 | --root_dir $ROOT_DIR/Ignatius \ 8 | --exp_name Ignatius --downsample $DOWNSAMPLE --no_save_test \ 9 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 10 | 11 | python train.py \ 12 | --root_dir $ROOT_DIR/Truck \ 13 | --exp_name Truck --downsample $DOWNSAMPLE --no_save_test \ 14 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 15 | 16 | python train.py \ 17 | --root_dir $ROOT_DIR/Barn \ 18 | --exp_name Barn --downsample $DOWNSAMPLE --no_save_test \ 19 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 20 | 21 | python train.py \ 22 | --root_dir $ROOT_DIR/Caterpillar \ 23 | --exp_name Caterpillar --downsample $DOWNSAMPLE --no_save_test \ 24 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 25 | 26 | python train.py \ 27 | --root_dir $ROOT_DIR/Family \ 28 | --exp_name Family --downsample $DOWNSAMPLE --no_save_test \ 29 | --num_epochs 20 --batch_size 16384 --lr 2e-2 --eval_lpips 30 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nerf import NeRFDataset 2 | from .nsvf import NSVFDataset 3 | from .colmap import ColmapDataset 4 | from .nerfpp import NeRFPPDataset 5 | from .rtmv import RTMVDataset 6 | 7 | 8 | dataset_dict = {'nerf': NeRFDataset, 9 | 'nsvf': NSVFDataset, 10 | 'colmap': ColmapDataset, 11 | 'nerfpp': NeRFPPDataset, 12 | 'rtmv': RTMVDataset} -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(Dataset): 6 | """ 7 | Define length and sampling method 8 | """ 9 | def __init__(self, root_dir, split='train', downsample=1.0): 10 | self.root_dir = root_dir 11 | self.split = split 12 | self.downsample = downsample 13 | 14 | def read_intrinsics(self): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | if self.split.startswith('train'): 19 | return 
1000 20 | return len(self.poses) 21 | 22 | def __getitem__(self, idx): 23 | if self.split.startswith('train'): 24 | # training pose is retrieved in train.py 25 | if self.ray_sampling_strategy == 'all_images': # randomly select images 26 | img_idxs = np.random.choice(len(self.poses), self.batch_size) 27 | elif self.ray_sampling_strategy == 'same_image': # randomly select ONE image 28 | img_idxs = np.random.choice(len(self.poses), 1)[0] 29 | # randomly select pixels 30 | pix_idxs = np.random.choice(self.img_wh[0]*self.img_wh[1], self.batch_size) 31 | rays = self.rays[img_idxs, pix_idxs] 32 | sample = {'img_idxs': img_idxs, 'pix_idxs': pix_idxs, 33 | 'rgb': rays[:, :3]} 34 | if self.rays.shape[-1] == 4: # HDR-NeRF data 35 | sample['exposure'] = rays[:, 3:] 36 | else: 37 | sample = {'pose': self.poses[idx], 'img_idxs': idx} 38 | if len(self.rays)>0: # if ground truth available 39 | rays = self.rays[idx] 40 | sample['rgb'] = rays[:, :3] 41 | if rays.shape[1] == 4: # HDR-NeRF data 42 | sample['exposure'] = rays[0, 3] # same exposure for all rays 43 | 44 | return sample -------------------------------------------------------------------------------- /datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import glob 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import * 8 | from .color_utils import read_image 9 | from .colmap_utils import \ 10 | read_cameras_binary, read_images_binary, read_points3d_binary 11 | 12 | from .base import BaseDataset 13 | 14 | 15 | class ColmapDataset(BaseDataset): 16 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 17 | super().__init__(root_dir, split, downsample) 18 | 19 | self.read_intrinsics() 20 | 21 | if kwargs.get('read_meta', True): 22 | self.read_meta(split, **kwargs) 23 | 24 | def read_intrinsics(self): 25 | # Step 1: read and scale intrinsics (same for all images) 26 | camdata = read_cameras_binary(os.path.join(self.root_dir, 'sparse/0/cameras.bin')) 27 | h = int(camdata[1].height*self.downsample) 28 | w = int(camdata[1].width*self.downsample) 29 | self.img_wh = (w, h) 30 | 31 | if camdata[1].model == 'SIMPLE_RADIAL': 32 | fx = fy = camdata[1].params[0]*self.downsample 33 | cx = camdata[1].params[1]*self.downsample 34 | cy = camdata[1].params[2]*self.downsample 35 | elif camdata[1].model in ['PINHOLE', 'OPENCV']: 36 | fx = camdata[1].params[0]*self.downsample 37 | fy = camdata[1].params[1]*self.downsample 38 | cx = camdata[1].params[2]*self.downsample 39 | cy = camdata[1].params[3]*self.downsample 40 | else: 41 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") 42 | self.K = torch.FloatTensor([[fx, 0, cx], 43 | [0, fy, cy], 44 | [0, 0, 1]]) 45 | self.directions = get_ray_directions(h, w, self.K) 46 | 47 | def read_meta(self, split, **kwargs): 48 | # Step 2: correct poses 49 | # read extrinsics (of successfully reconstructed images) 50 | imdata = read_images_binary(os.path.join(self.root_dir, 'sparse/0/images.bin')) 51 | img_names = [imdata[k].name for k in imdata] 52 | perm = np.argsort(img_names) 53 | if '360_v2' in self.root_dir and self.downsample<1: # mipnerf360 data 54 | folder = f'images_{int(1/self.downsample)}' 55 | else: 56 | folder = 'images' 57 | # read successfully reconstructed images and ignore others 58 | img_paths = [os.path.join(self.root_dir, folder, name) 59 | for name in sorted(img_names)] 60 | w2c_mats = [] 61 | bottom = np.array([[0, 0, 0, 1.]]) 62 | for k in imdata: 
63 | im = imdata[k] 64 | R = im.qvec2rotmat(); t = im.tvec.reshape(3, 1) 65 | w2c_mats += [np.concatenate([np.concatenate([R, t], 1), bottom], 0)] 66 | w2c_mats = np.stack(w2c_mats, 0) 67 | poses = np.linalg.inv(w2c_mats)[perm, :3] # (N_images, 3, 4) cam2world matrices 68 | 69 | pts3d = read_points3d_binary(os.path.join(self.root_dir, 'sparse/0/points3D.bin')) 70 | pts3d = np.array([pts3d[k].xyz for k in pts3d]) # (N, 3) 71 | 72 | self.poses, self.pts3d = center_poses(poses, pts3d) 73 | 74 | scale = np.linalg.norm(self.poses[..., 3], axis=-1).min() 75 | self.poses[..., 3] /= scale 76 | self.pts3d /= scale 77 | 78 | self.rays = [] 79 | if split == 'test_traj': # use precomputed test poses 80 | self.poses = create_spheric_poses(1.2, self.poses[:, 1, 3].mean()) 81 | self.poses = torch.FloatTensor(self.poses) 82 | return 83 | 84 | if 'HDR-NeRF' in self.root_dir: # HDR-NeRF data 85 | if 'syndata' in self.root_dir: # synthetic 86 | # first 17 are test, last 18 are train 87 | self.unit_exposure_rgb = 0.73 88 | if split=='train': 89 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 90 | f'train/*[024].png'))) 91 | self.poses = np.repeat(self.poses[-18:], 3, 0) 92 | elif split=='test': 93 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 94 | f'test/*[13].png'))) 95 | self.poses = np.repeat(self.poses[:17], 2, 0) 96 | else: 97 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 98 | else: # real 99 | self.unit_exposure_rgb = 0.5 100 | # even numbers are train, odd numbers are test 101 | if split=='train': 102 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 103 | f'input_images/*0.jpg')))[::2] 104 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 105 | f'input_images/*2.jpg')))[::2] 106 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 107 | f'input_images/*4.jpg')))[::2] 108 | self.poses = np.tile(self.poses[::2], (3, 1, 1)) 109 | elif split=='test': 110 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 111 | f'input_images/*1.jpg')))[1::2] 112 | img_paths+= sorted(glob.glob(os.path.join(self.root_dir, 113 | f'input_images/*3.jpg')))[1::2] 114 | self.poses = np.tile(self.poses[1::2], (2, 1, 1)) 115 | else: 116 | raise ValueError(f"split {split} is invalid for HDR-NeRF!") 117 | else: 118 | # use every 8th image as test set 119 | if split=='train': 120 | img_paths = [x for i, x in enumerate(img_paths) if i%8!=0] 121 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8!=0]) 122 | elif split=='test': 123 | img_paths = [x for i, x in enumerate(img_paths) if i%8==0] 124 | self.poses = np.array([x for i, x in enumerate(self.poses) if i%8==0]) 125 | 126 | print(f'Loading {len(img_paths)} {split} images ...') 127 | for img_path in tqdm(img_paths): 128 | buf = [] # buffer for ray attributes: rgb, etc 129 | 130 | img = read_image(img_path, self.img_wh, blend_a=False) 131 | img = torch.FloatTensor(img) 132 | buf += [img] 133 | 134 | if 'HDR-NeRF' in self.root_dir: # get exposure 135 | folder = self.root_dir.split('/') 136 | scene = folder[-1] if folder[-1] != '' else folder[-2] 137 | if scene in ['bathroom', 'bear', 'chair', 'desk']: 138 | e_dict = {e: 1/8*4**e for e in range(5)} 139 | elif scene in ['diningroom', 'dog']: 140 | e_dict = {e: 1/16*4**e for e in range(5)} 141 | elif scene in ['sofa']: 142 | e_dict = {0:0.25, 1:1, 2:2, 3:4, 4:16} 143 | elif scene in ['sponza']: 144 | e_dict = {0:0.5, 1:2, 2:4, 3:8, 4:32} 145 | elif scene in ['box']: 146 | e_dict = {0:2/3, 1:1/3, 2:1/6, 3:0.1, 4:0.05} 147 | elif scene in ['computer']: 
148 | e_dict = {0:1/3, 1:1/8, 2:1/15, 3:1/30, 4:1/60} 149 | elif scene in ['flower']: 150 | e_dict = {0:1/3, 1:1/6, 2:0.1, 3:0.05, 4:1/45} 151 | elif scene in ['luckycat']: 152 | e_dict = {0:2, 1:1, 2:0.5, 3:0.25, 4:0.125} 153 | e = int(img_path.split('.')[0][-1]) 154 | buf += [e_dict[e]*torch.ones_like(img[:, :1])] 155 | 156 | self.rays += [torch.cat(buf, 1)] 157 | 158 | self.rays = torch.stack(self.rays) # (N_images, hw, ?) 159 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) -------------------------------------------------------------------------------- /datasets/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec -------------------------------------------------------------------------------- /datasets/color_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from einops import rearrange 3 | import imageio 4 | import numpy as np 5 | 6 | 7 | def srgb_to_linear(img): 8 | limit = 0.04045 9 | return np.where(img>limit, ((img+0.055)/1.055)**2.4, img/12.92) 10 | 11 | 12 | def linear_to_srgb(img): 13 | limit = 0.0031308 14 | img = np.where(img>limit, 1.055*img**(1/2.4)-0.055, 12.92*img) 15 | img[img>1] = 1 # "clamp" tonemapper 16 | return img 17 | 18 | 19 | def read_image(img_path, img_wh, blend_a=True): 20 | img = imageio.imread(img_path).astype(np.float32)/255.0 21 | # img[..., :3] = srgb_to_linear(img[..., :3]) 22 | if img.shape[2] == 4: # blend A to RGB 23 | if blend_a: 24 | img = img[..., :3]*img[..., -1:]+(1-img[..., -1:]) 25 | else: 26 | img = img[..., :3]*img[..., -1:] 27 | 28 | img = cv2.resize(img, img_wh) 29 | img = rearrange(img, 'h w c -> (h w) c') 30 | 31 | return img -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | 5 | def read_pfm(path): 6 | """Read pfm file. 
7 | 8 | Args: 9 | path (str): path to file 10 | 11 | Returns: 12 | tuple: (data, scale) 13 | """ 14 | with open(path, "rb") as file: 15 | 16 | color = None 17 | width = None 18 | height = None 19 | scale = None 20 | endian = None 21 | 22 | header = file.readline().rstrip() 23 | if header.decode("ascii") == "PF": 24 | color = True 25 | elif header.decode("ascii") == "Pf": 26 | color = False 27 | else: 28 | raise Exception("Not a PFM file: " + path) 29 | 30 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 31 | if dim_match: 32 | width, height = list(map(int, dim_match.groups())) 33 | else: 34 | raise Exception("Malformed PFM header.") 35 | 36 | scale = float(file.readline().decode("ascii").rstrip()) 37 | if scale < 0: 38 | # little-endian 39 | endian = "<" 40 | scale = -scale 41 | else: 42 | # big-endian 43 | endian = ">" 44 | 45 | data = np.fromfile(file, endian + "f") 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | 51 | return data, scale -------------------------------------------------------------------------------- /datasets/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NeRFDataset(BaseDataset): 14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | self.read_meta(split) 21 | 22 | def read_intrinsics(self): 23 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 24 | meta = json.load(f) 25 | 26 | w = h = int(800*self.downsample) 27 | fx = fy = 0.5*800/np.tan(0.5*meta['camera_angle_x'])*self.downsample 28 | 29 | K = np.float32([[fx, 0, w/2], 30 | [0, fy, h/2], 31 | [0, 0, 1]]) 32 | 33 | self.K = torch.FloatTensor(K) 34 | self.directions = get_ray_directions(h, w, self.K) 35 | self.img_wh = (w, h) 36 | 37 | def read_meta(self, split): 38 | self.rays = [] 39 | self.poses = [] 40 | 41 | if split == 'trainval': 42 | with open(os.path.join(self.root_dir, "transforms_train.json"), 'r') as f: 43 | frames = json.load(f)["frames"] 44 | with open(os.path.join(self.root_dir, "transforms_val.json"), 'r') as f: 45 | frames+= json.load(f)["frames"] 46 | else: 47 | with open(os.path.join(self.root_dir, f"transforms_{split}.json"), 'r') as f: 48 | frames = json.load(f)["frames"] 49 | 50 | print(f'Loading {len(frames)} {split} images ...') 51 | for frame in tqdm(frames): 52 | c2w = np.array(frame['transform_matrix'])[:3, :4] 53 | 54 | # determine scale 55 | if 'Jrender_Dataset' in self.root_dir: 56 | c2w[:, :2] *= -1 # [left up front] to [right down front] 57 | folder = self.root_dir.split('/') 58 | scene = folder[-1] if folder[-1] != '' else folder[-2] 59 | if scene=='Easyship': 60 | pose_radius_scale = 1.2 61 | elif scene=='Scar': 62 | pose_radius_scale = 1.8 63 | elif scene=='Coffee': 64 | pose_radius_scale = 2.5 65 | elif scene=='Car': 66 | pose_radius_scale = 0.8 67 | else: 68 | pose_radius_scale = 1.5 69 | else: 70 | c2w[:, 1:3] *= -1 # [right up back] to [right down front] 71 | pose_radius_scale = 1.5 72 | c2w[:, 3] /= np.linalg.norm(c2w[:, 3])/pose_radius_scale 73 | 74 | # add shift 75 | if 'Jrender_Dataset' in 
self.root_dir: 76 | if scene=='Coffee': 77 | c2w[1, 3] -= 0.4465 78 | elif scene=='Car': 79 | c2w[0, 3] -= 0.7 80 | self.poses += [c2w] 81 | 82 | try: 83 | img_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 84 | img = read_image(img_path, self.img_wh) 85 | self.rays += [img] 86 | except: pass 87 | 88 | if len(self.rays)>0: 89 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 90 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 91 | -------------------------------------------------------------------------------- /datasets/nerfpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from .ray_utils import get_ray_directions 9 | from .color_utils import read_image 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class NeRFPPDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | K = np.loadtxt(glob.glob(os.path.join(self.root_dir, 'train/intrinsics/*.txt'))[0], 25 | dtype=np.float32).reshape(4, 4)[:3, :3] 26 | K[:2] *= self.downsample 27 | w, h = Image.open(glob.glob(os.path.join(self.root_dir, 'train/rgb/*'))[0]).size 28 | w, h = int(w*self.downsample), int(h*self.downsample) 29 | self.K = torch.FloatTensor(K) 30 | self.directions = get_ray_directions(h, w, self.K) 31 | self.img_wh = (w, h) 32 | 33 | def read_meta(self, split): 34 | self.rays = [] 35 | self.poses = [] 36 | 37 | if split == 'test_traj': 38 | poses_path = \ 39 | sorted(glob.glob(os.path.join(self.root_dir, 'camera_path/pose/*.txt'))) 40 | self.poses = [np.loadtxt(p).reshape(4, 4)[:3] for p in poses_path] 41 | else: 42 | if split=='trainval': 43 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'train/rgb/*')))+\ 44 | sorted(glob.glob(os.path.join(self.root_dir, 'val/rgb/*'))) 45 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'train/pose/*.txt')))+\ 46 | sorted(glob.glob(os.path.join(self.root_dir, 'val/pose/*.txt'))) 47 | else: 48 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, split, 'rgb/*'))) 49 | poses = sorted(glob.glob(os.path.join(self.root_dir, split, 'pose/*.txt'))) 50 | 51 | print(f'Loading {len(img_paths)} {split} images ...') 52 | for img_path, pose in tqdm(zip(img_paths, poses)): 53 | self.poses += [np.loadtxt(pose).reshape(4, 4)[:3]] 54 | 55 | img = read_image(img_path, self.img_wh) 56 | self.rays += [img] 57 | 58 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 
59 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 60 | -------------------------------------------------------------------------------- /datasets/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import numpy as np 4 | import os 5 | from tqdm import tqdm 6 | 7 | from .ray_utils import get_ray_directions 8 | from .color_utils import read_image 9 | 10 | from .base import BaseDataset 11 | 12 | 13 | class NSVFDataset(BaseDataset): 14 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 15 | super().__init__(root_dir, split, downsample) 16 | 17 | self.read_intrinsics() 18 | 19 | if kwargs.get('read_meta', True): 20 | xyz_min, xyz_max = \ 21 | np.loadtxt(os.path.join(root_dir, 'bbox.txt'))[:6].reshape(2, 3) 22 | self.shift = (xyz_max+xyz_min)/2 23 | self.scale = (xyz_max-xyz_min).max()/2 * 1.05 # enlarge a little 24 | 25 | # hard-code fix the bound error for some scenes... 26 | if 'Mic' in self.root_dir: self.scale *= 1.2 27 | elif 'Lego' in self.root_dir: self.scale *= 1.1 28 | 29 | self.read_meta(split) 30 | 31 | def read_intrinsics(self): 32 | if 'Synthetic' in self.root_dir or 'Ignatius' in self.root_dir: 33 | with open(os.path.join(self.root_dir, 'intrinsics.txt')) as f: 34 | fx = fy = float(f.readline().split()[0]) * self.downsample 35 | if 'Synthetic' in self.root_dir: 36 | w = h = int(800*self.downsample) 37 | else: 38 | w, h = int(1920*self.downsample), int(1080*self.downsample) 39 | 40 | K = np.float32([[fx, 0, w/2], 41 | [0, fy, h/2], 42 | [0, 0, 1]]) 43 | else: 44 | K = np.loadtxt(os.path.join(self.root_dir, 'intrinsics.txt'), 45 | dtype=np.float32)[:3, :3] 46 | if 'BlendedMVS' in self.root_dir: 47 | w, h = int(768*self.downsample), int(576*self.downsample) 48 | elif 'Tanks' in self.root_dir: 49 | w, h = int(1920*self.downsample), int(1080*self.downsample) 50 | K[:2] *= self.downsample 51 | 52 | self.K = torch.FloatTensor(K) 53 | self.directions = get_ray_directions(h, w, self.K) 54 | self.img_wh = (w, h) 55 | 56 | def read_meta(self, split): 57 | self.rays = [] 58 | self.poses = [] 59 | 60 | if split == 'test_traj': # BlendedMVS and TanksAndTemple 61 | if 'Ignatius' in self.root_dir: 62 | poses_path = \ 63 | sorted(glob.glob(os.path.join(self.root_dir, 'test_pose/*.txt'))) 64 | poses = [np.loadtxt(p) for p in poses_path] 65 | else: 66 | poses = np.loadtxt(os.path.join(self.root_dir, 'test_traj.txt')) 67 | poses = poses.reshape(-1, 4, 4) 68 | for pose in poses: 69 | c2w = pose[:3] 70 | c2w[:, 0] *= -1 # [left down front] to [right down front] 71 | c2w[:, 3] -= self.shift 72 | c2w[:, 3] /= 2*self.scale # to bound the scene inside [-0.5, 0.5] 73 | self.poses += [c2w] 74 | else: 75 | if split == 'train': prefix = '0_' 76 | elif split == 'trainval': prefix = '[0-1]_' 77 | elif split == 'trainvaltest': prefix = '[0-2]_' 78 | elif split == 'val': prefix = '1_' 79 | elif 'Synthetic' in self.root_dir: prefix = '2_' # test set for synthetic scenes 80 | elif split == 'test': prefix = '1_' # test set for real scenes 81 | else: raise ValueError(f'{split} split not recognized!') 82 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'rgb', prefix+'*.png'))) 83 | poses = sorted(glob.glob(os.path.join(self.root_dir, 'pose', prefix+'*.txt'))) 84 | 85 | print(f'Loading {len(img_paths)} {split} images ...') 86 | for img_path, pose in tqdm(zip(img_paths, poses)): 87 | c2w = np.loadtxt(pose)[:3] 88 | c2w[:, 3] -= self.shift 89 | c2w[:, 3] /= 2*self.scale # to bound the scene inside 
[-0.5, 0.5] 90 | self.poses += [c2w] 91 | 92 | img = read_image(img_path, self.img_wh) 93 | if 'Jade' in self.root_dir or 'Fountain' in self.root_dir: 94 | # these scenes have black background, changing to white 95 | img[torch.all(img<=0.1, dim=-1)] = 1.0 96 | 97 | self.rays += [img] 98 | 99 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 100 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 101 | -------------------------------------------------------------------------------- /datasets/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from kornia import create_meshgrid 4 | from einops import rearrange 5 | 6 | 7 | @torch.cuda.amp.autocast(dtype=torch.float32) 8 | def get_ray_directions(H, W, K, device='cpu', random=False, return_uv=False, flatten=True): 9 | """ 10 | Get ray directions for all pixels in camera coordinate [right down front]. 11 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 12 | ray-tracing-generating-camera-rays/standard-coordinate-systems 13 | 14 | Inputs: 15 | H, W: image height and width 16 | K: (3, 3) camera intrinsics 17 | random: whether the ray passes randomly inside the pixel 18 | return_uv: whether to return uv image coordinates 19 | 20 | Outputs: (shape depends on @flatten) 21 | directions: (H, W, 3) or (H*W, 3), the direction of the rays in camera coordinate 22 | uv: (H, W, 2) or (H*W, 2) image coordinates 23 | """ 24 | grid = create_meshgrid(H, W, False, device=device)[0] # (H, W, 2) 25 | u, v = grid.unbind(-1) 26 | 27 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 28 | if random: 29 | directions = \ 30 | torch.stack([(u-cx+torch.rand_like(u))/fx, 31 | (v-cy+torch.rand_like(v))/fy, 32 | torch.ones_like(u)], -1) 33 | else: # pass by the center 34 | directions = \ 35 | torch.stack([(u-cx+0.5)/fx, (v-cy+0.5)/fy, torch.ones_like(u)], -1) 36 | if flatten: 37 | directions = directions.reshape(-1, 3) 38 | grid = grid.reshape(-1, 2) 39 | 40 | if return_uv: 41 | return directions, grid 42 | return directions 43 | 44 | 45 | @torch.cuda.amp.autocast(dtype=torch.float32) 46 | def get_rays(directions, c2w): 47 | """ 48 | Get ray origin and directions in world coordinate for all pixels in one image. 
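The rotation block of c2w maps the camera-frame directions to world space, and its
last column gives the ray origins (the camera center in world coordinates).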
49 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 50 | ray-tracing-generating-camera-rays/standard-coordinate-systems 51 | 52 | Inputs: 53 | directions: (N, 3) ray directions in camera coordinate 54 | c2w: (3, 4) or (N, 3, 4) transformation matrix from camera coordinate to world coordinate 55 | 56 | Outputs: 57 | rays_o: (N, 3), the origin of the rays in world coordinate 58 | rays_d: (N, 3), the direction of the rays in world coordinate 59 | """ 60 | if c2w.ndim==2: 61 | # Rotate ray directions from camera coordinate to the world coordinate 62 | rays_d = directions @ c2w[:, :3].T 63 | else: 64 | rays_d = rearrange(directions, 'n c -> n 1 c') @ \ 65 | rearrange(c2w[..., :3], 'n a b -> n b a') 66 | rays_d = rearrange(rays_d, 'n 1 c -> n c') 67 | # The origin of all rays is the camera origin in world coordinate 68 | rays_o = c2w[..., 3].expand_as(rays_d) 69 | 70 | return rays_o, rays_d 71 | 72 | 73 | @torch.cuda.amp.autocast(dtype=torch.float32) 74 | def axisangle_to_R(v): 75 | """ 76 | Convert an axis-angle vector to rotation matrix 77 | from https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py#L47 78 | 79 | Inputs: 80 | v: (3) or (B, 3) 81 | 82 | Outputs: 83 | R: (3, 3) or (B, 3, 3) 84 | """ 85 | v_ndim = v.ndim 86 | if v_ndim==1: 87 | v = rearrange(v, 'c -> 1 c') 88 | zero = torch.zeros_like(v[:, :1]) # (B, 1) 89 | skew_v0 = torch.cat([zero, -v[:, 2:3], v[:, 1:2]], 1) # (B, 3) 90 | skew_v1 = torch.cat([v[:, 2:3], zero, -v[:, 0:1]], 1) 91 | skew_v2 = torch.cat([-v[:, 1:2], v[:, 0:1], zero], 1) 92 | skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=1) # (B, 3, 3) 93 | 94 | norm_v = rearrange(torch.norm(v, dim=1)+1e-7, 'b -> b 1 1') 95 | eye = torch.eye(3, device=v.device) 96 | R = eye + (torch.sin(norm_v)/norm_v)*skew_v + \ 97 | ((1-torch.cos(norm_v))/norm_v**2)*(skew_v@skew_v) 98 | if v_ndim==1: 99 | R = rearrange(R, '1 c d -> c d') 100 | return R 101 | 102 | 103 | def normalize(v): 104 | """Normalize a vector.""" 105 | return v/np.linalg.norm(v) 106 | 107 | 108 | def average_poses(poses, pts3d=None): 109 | """ 110 | Calculate the average pose, which is then used to center all poses 111 | using @center_poses. Its computation is as follows: 112 | 1. Compute the center: the average of 3d point cloud (if None, center of cameras). 113 | 2. Compute the z axis: the normalized average z axis. 114 | 3. Compute axis y': the average y axis. 115 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 116 | 5. Compute the y axis: z cross product x. 117 | 118 | Note that at step 3, we cannot directly use y' as y axis since it's 119 | not necessarily orthogonal to z axis. We need to pass from x to y. 120 | Inputs: 121 | poses: (N_images, 3, 4) 122 | pts3d: (N, 3) 123 | 124 | Outputs: 125 | pose_avg: (3, 4) the average pose 126 | """ 127 | # 1. Compute the center 128 | if pts3d is not None: 129 | center = pts3d.mean(0) 130 | else: 131 | center = poses[..., 3].mean(0) 132 | 133 | # 2. Compute the z axis 134 | z = normalize(poses[..., 2].mean(0)) # (3) 135 | 136 | # 3. Compute axis y' (no need to normalize as it's not the final output) 137 | y_ = poses[..., 1].mean(0) # (3) 138 | 139 | # 4. Compute the x axis 140 | x = normalize(np.cross(y_, z)) # (3) 141 | 142 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 143 | y = np.cross(z, x) # (3) 144 | 145 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 146 | 147 | return pose_avg 148 | 149 | 150 | def center_poses(poses, pts3d=None): 151 | """ 152 | See https://github.com/bmild/nerf/issues/34 153 | Inputs: 154 | poses: (N_images, 3, 4) 155 | pts3d: (N, 3) reconstructed point cloud 156 | 157 | Outputs: 158 | poses_centered: (N_images, 3, 4) the centered poses 159 | pts3d_centered: (N, 3) centered point cloud 160 | """ 161 | 162 | pose_avg = average_poses(poses, pts3d) # (3, 4) 163 | pose_avg_homo = np.eye(4) 164 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 165 | # by simply adding 0, 0, 0, 1 as the last row 166 | pose_avg_inv = np.linalg.inv(pose_avg_homo) 167 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 168 | poses_homo = \ 169 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 170 | 171 | poses_centered = pose_avg_inv @ poses_homo # (N_images, 4, 4) 172 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 173 | 174 | if pts3d is not None: 175 | pts3d_centered = pts3d @ pose_avg_inv[:, :3].T + pose_avg_inv[:, 3:].T 176 | return poses_centered, pts3d_centered 177 | 178 | return poses_centered 179 | 180 | def create_spheric_poses(radius, mean_h, n_poses=120): 181 | """ 182 | Create circular poses around z axis. 183 | Inputs: 184 | radius: the (negative) height and the radius of the circle. 185 | mean_h: mean camera height 186 | Outputs: 187 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 188 | """ 189 | def spheric_pose(theta, phi, radius): 190 | trans_t = lambda t : np.array([ 191 | [1,0,0,0], 192 | [0,1,0,2*mean_h], 193 | [0,0,1,-t] 194 | ]) 195 | 196 | rot_phi = lambda phi : np.array([ 197 | [1,0,0], 198 | [0,np.cos(phi),-np.sin(phi)], 199 | [0,np.sin(phi), np.cos(phi)] 200 | ]) 201 | 202 | rot_theta = lambda th : np.array([ 203 | [np.cos(th),0,-np.sin(th)], 204 | [0,1,0], 205 | [np.sin(th),0, np.cos(th)] 206 | ]) 207 | 208 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 209 | c2w = np.array([[-1,0,0],[0,0,1],[0,1,0]]) @ c2w 210 | return c2w 211 | 212 | spheric_poses = [] 213 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: 214 | spheric_poses += [spheric_pose(th, -np.pi/12, radius)] 215 | return np.stack(spheric_poses, 0) -------------------------------------------------------------------------------- /datasets/rtmv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import json 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | 8 | from .ray_utils import get_ray_directions 9 | from .color_utils import read_image 10 | 11 | from .base import BaseDataset 12 | 13 | 14 | class RTMVDataset(BaseDataset): 15 | def __init__(self, root_dir, split='train', downsample=1.0, **kwargs): 16 | super().__init__(root_dir, split, downsample) 17 | 18 | self.read_intrinsics() 19 | 20 | if kwargs.get('read_meta', True): 21 | self.read_meta(split) 22 | 23 | def read_intrinsics(self): 24 | with open(os.path.join(self.root_dir, '00000.json'), 'r') as f: 25 | meta = json.load(f)['camera_data'] 26 | 27 | self.shift = np.array(meta['scene_center_3d_box']) 28 | self.scale = (np.array(meta['scene_max_3d_box'])- 29 | np.array(meta['scene_min_3d_box'])).max()/2 * 1.05 # enlarge a little 30 | 31 | fx = meta['intrinsics']['fx'] * self.downsample 32 | fy = 
meta['intrinsics']['fy'] * self.downsample 33 | cx = meta['intrinsics']['cx'] * self.downsample 34 | cy = meta['intrinsics']['cy'] * self.downsample 35 | w = int(meta['width']*self.downsample) 36 | h = int(meta['height']*self.downsample) 37 | K = np.float32([[fx, 0, cx], 38 | [0, fy, cy], 39 | [0, 0, 1]]) 40 | self.K = torch.FloatTensor(K) 41 | self.directions = get_ray_directions(h, w, self.K) 42 | self.img_wh = (w, h) 43 | 44 | def read_meta(self, split): 45 | self.rays = [] 46 | self.poses = [] 47 | 48 | if split == 'train': start_idx, end_idx = 0, 100 49 | elif split == 'trainval': start_idx, end_idx = 0, 105 50 | elif split == 'test': start_idx, end_idx = 105, 150 51 | else: start_idx, end_idx = 0, 150 52 | img_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*')))[start_idx:end_idx] 53 | poses = sorted(glob.glob(os.path.join(self.root_dir, '*.json')))[start_idx:end_idx] 54 | 55 | print(f'Loading {len(img_paths)} {split} images ...') 56 | for img_path, pose in tqdm(zip(img_paths, poses)): 57 | with open(pose, 'r') as f: 58 | p = json.load(f)['camera_data'] 59 | c2w = np.array(p['cam2world']).T[:3] 60 | c2w[:, 1:3] *= -1 61 | if 'bricks' in self.root_dir: 62 | c2w[:, 3] -= self.shift 63 | c2w[:, 3] /= 2*self.scale # bound in [-0.5, 0.5] 64 | self.poses += [c2w] 65 | 66 | img = read_image(img_path, self.img_wh) 67 | self.rays += [img] 68 | 69 | self.rays = torch.FloatTensor(np.stack(self.rays)) # (N_images, hw, ?) 70 | self.poses = torch.FloatTensor(self.poses) # (N_images, 3, 4) 71 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import vren 4 | 5 | 6 | class DistortionLoss(torch.autograd.Function): 7 | """ 8 | Distortion loss proposed in Mip-NeRF 360 (https://arxiv.org/pdf/2111.12077.pdf) 9 | Implementation is based on DVGO-v2 (https://arxiv.org/pdf/2206.05085.pdf) 10 | 11 | Inputs: 12 | ws: (N) sample point weights 13 | deltas: (N) considered as intervals 14 | ts: (N) considered as midpoints 15 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 16 | meaning each entry corresponds to the @ray_idx th ray, 17 | whose samples are [start_idx:start_idx+N_samples] 18 | 19 | Outputs: 20 | loss: (N_rays) 21 | """ 22 | @staticmethod 23 | def forward(ctx, ws, deltas, ts, rays_a): 24 | loss, ws_inclusive_scan, wts_inclusive_scan = \ 25 | vren.distortion_loss_fw(ws, deltas, ts, rays_a) 26 | ctx.save_for_backward(ws_inclusive_scan, wts_inclusive_scan, 27 | ws, deltas, ts, rays_a) 28 | return loss 29 | 30 | @staticmethod 31 | def backward(ctx, dL_dloss): 32 | (ws_inclusive_scan, wts_inclusive_scan, 33 | ws, deltas, ts, rays_a) = ctx.saved_tensors 34 | dL_dws = vren.distortion_loss_bw(dL_dloss, ws_inclusive_scan, 35 | wts_inclusive_scan, 36 | ws, deltas, ts, rays_a) 37 | return dL_dws, None, None, None 38 | 39 | 40 | class NeRFLoss(nn.Module): 41 | def __init__(self, lambda_opacity=1e-3, lambda_distortion=1e-3): 42 | super().__init__() 43 | 44 | self.lambda_opacity = lambda_opacity 45 | self.lambda_distortion = lambda_distortion 46 | 47 | def forward(self, results, target, **kwargs): 48 | d = {} 49 | d['rgb'] = (results['rgb']-target['rgb'])**2 50 | 51 | o = results['opacity']+1e-10 52 | # encourage opacity to be either 0 or 1 to avoid floater 53 | d['opacity'] = self.lambda_opacity*(-o*torch.log(o)) 54 | 55 | if self.lambda_distortion > 0: 56 | d['distortion'] = self.lambda_distortion * \ 57 | 
DistortionLoss.apply(results['ws'], results['deltas'], 58 | results['ts'], results['rays_a']) 59 | 60 | return d 61 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | 13 | @torch.no_grad() 14 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 15 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 16 | -------------------------------------------------------------------------------- /misc/prepare_rtmv.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import glob 3 | import sys 4 | from tqdm import tqdm 5 | import os 6 | import numpy as np 7 | sys.path.append('datasets') 8 | from color_utils import linear_to_srgb 9 | 10 | import warnings; warnings.filterwarnings("ignore") 11 | 12 | 13 | if __name__ == '__main__': 14 | # convert hdr images to ldr by applying linear_to_srgb and clamping tone-mapping 15 | # and save into images/ folder to accelerate reading 16 | root_dir = sys.argv[1] 17 | envs = sorted(os.listdir(root_dir)) 18 | print('Generating ldr images from hdr images ...') 19 | for env in tqdm(envs): 20 | for scene in tqdm(sorted(os.listdir(os.path.join(root_dir, env)))): 21 | os.makedirs(os.path.join(root_dir, env, scene, 'images'), exist_ok=True) 22 | for i, img_p in enumerate(tqdm(sorted(glob.glob(os.path.join(root_dir, env, scene, '*[0-9].exr'))))): 23 | img = imageio.imread(img_p) # hdr 24 | img[..., :3] = linear_to_srgb(img[..., :3]) 25 | img = (255*img).astype(np.uint8) 26 | imageio.imsave(os.path.join(root_dir, env, scene, f'images/{i:05d}.png'), img) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwea123/ngp_pl/1b49af1856a276b236e0f17539814134ed329860/models/__init__.py -------------------------------------------------------------------------------- /models/csrc/binding.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | 4 | std::vector ray_aabb_intersect( 5 | const torch::Tensor rays_o, 6 | const torch::Tensor rays_d, 7 | const torch::Tensor centers, 8 | const torch::Tensor half_sizes, 9 | const int max_hits 10 | ){ 11 | CHECK_INPUT(rays_o); 12 | CHECK_INPUT(rays_d); 13 | CHECK_INPUT(centers); 14 | CHECK_INPUT(half_sizes); 15 | return ray_aabb_intersect_cu(rays_o, rays_d, centers, half_sizes, max_hits); 16 | } 17 | 18 | 19 | std::vector ray_sphere_intersect( 20 | const torch::Tensor rays_o, 21 | const torch::Tensor rays_d, 22 | const torch::Tensor centers, 23 | const torch::Tensor radii, 24 | const int max_hits 25 | ){ 26 | CHECK_INPUT(rays_o); 27 | CHECK_INPUT(rays_d); 28 | CHECK_INPUT(centers); 29 | CHECK_INPUT(radii); 30 | return ray_sphere_intersect_cu(rays_o, rays_d, centers, radii, max_hits); 31 | } 32 | 33 | 34 | void packbits( 35 | torch::Tensor density_grid, 36 | const float density_threshold, 37 | torch::Tensor density_bitfield 38 | ){ 39 | CHECK_INPUT(density_grid); 40 | CHECK_INPUT(density_bitfield); 41 | 42 | return packbits_cu(density_grid, 
density_threshold, density_bitfield); 43 | } 44 | 45 | 46 | torch::Tensor morton3D(const torch::Tensor coords){ 47 | CHECK_INPUT(coords); 48 | 49 | return morton3D_cu(coords); 50 | } 51 | 52 | 53 | torch::Tensor morton3D_invert(const torch::Tensor indices){ 54 | CHECK_INPUT(indices); 55 | 56 | return morton3D_invert_cu(indices); 57 | } 58 | 59 | 60 | std::vector raymarching_train( 61 | const torch::Tensor rays_o, 62 | const torch::Tensor rays_d, 63 | const torch::Tensor hits_t, 64 | const torch::Tensor density_bitfield, 65 | const int cascades, 66 | const float scale, 67 | const float exp_step_factor, 68 | const torch::Tensor noise, 69 | const int grid_size, 70 | const int max_samples 71 | ){ 72 | CHECK_INPUT(rays_o); 73 | CHECK_INPUT(rays_d); 74 | CHECK_INPUT(hits_t); 75 | CHECK_INPUT(density_bitfield); 76 | CHECK_INPUT(noise); 77 | 78 | return raymarching_train_cu( 79 | rays_o, rays_d, hits_t, density_bitfield, cascades, 80 | scale, exp_step_factor, noise, grid_size, max_samples); 81 | } 82 | 83 | 84 | std::vector raymarching_test( 85 | const torch::Tensor rays_o, 86 | const torch::Tensor rays_d, 87 | torch::Tensor hits_t, 88 | const torch::Tensor alive_indices, 89 | const torch::Tensor density_bitfield, 90 | const int cascades, 91 | const float scale, 92 | const float exp_step_factor, 93 | const int grid_size, 94 | const int max_samples, 95 | const int N_samples 96 | ){ 97 | CHECK_INPUT(rays_o); 98 | CHECK_INPUT(rays_d); 99 | CHECK_INPUT(hits_t); 100 | CHECK_INPUT(alive_indices); 101 | CHECK_INPUT(density_bitfield); 102 | 103 | return raymarching_test_cu( 104 | rays_o, rays_d, hits_t, alive_indices, density_bitfield, cascades, 105 | scale, exp_step_factor, grid_size, max_samples, N_samples); 106 | } 107 | 108 | 109 | std::vector composite_train_fw( 110 | const torch::Tensor sigmas, 111 | const torch::Tensor rgbs, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a, 115 | const float opacity_threshold 116 | ){ 117 | CHECK_INPUT(sigmas); 118 | CHECK_INPUT(rgbs); 119 | CHECK_INPUT(deltas); 120 | CHECK_INPUT(ts); 121 | CHECK_INPUT(rays_a); 122 | 123 | return composite_train_fw_cu( 124 | sigmas, rgbs, deltas, ts, 125 | rays_a, opacity_threshold); 126 | } 127 | 128 | 129 | std::vector composite_train_bw( 130 | const torch::Tensor dL_dopacity, 131 | const torch::Tensor dL_ddepth, 132 | const torch::Tensor dL_drgb, 133 | const torch::Tensor dL_dws, 134 | const torch::Tensor sigmas, 135 | const torch::Tensor rgbs, 136 | const torch::Tensor ws, 137 | const torch::Tensor deltas, 138 | const torch::Tensor ts, 139 | const torch::Tensor rays_a, 140 | const torch::Tensor opacity, 141 | const torch::Tensor depth, 142 | const torch::Tensor rgb, 143 | const float opacity_threshold 144 | ){ 145 | CHECK_INPUT(dL_dopacity); 146 | CHECK_INPUT(dL_ddepth); 147 | CHECK_INPUT(dL_drgb); 148 | CHECK_INPUT(dL_dws); 149 | CHECK_INPUT(sigmas); 150 | CHECK_INPUT(rgbs); 151 | CHECK_INPUT(ws); 152 | CHECK_INPUT(deltas); 153 | CHECK_INPUT(ts); 154 | CHECK_INPUT(rays_a); 155 | CHECK_INPUT(opacity); 156 | CHECK_INPUT(depth); 157 | CHECK_INPUT(rgb); 158 | 159 | return composite_train_bw_cu( 160 | dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 161 | sigmas, rgbs, ws, deltas, ts, rays_a, 162 | opacity, depth, rgb, opacity_threshold); 163 | } 164 | 165 | 166 | void composite_test_fw( 167 | const torch::Tensor sigmas, 168 | const torch::Tensor rgbs, 169 | const torch::Tensor deltas, 170 | const torch::Tensor ts, 171 | const torch::Tensor hits_t, 172 | const torch::Tensor alive_indices, 173 
| const float T_threshold, 174 | const torch::Tensor N_eff_samples, 175 | torch::Tensor opacity, 176 | torch::Tensor depth, 177 | torch::Tensor rgb 178 | ){ 179 | CHECK_INPUT(sigmas); 180 | CHECK_INPUT(rgbs); 181 | CHECK_INPUT(deltas); 182 | CHECK_INPUT(ts); 183 | CHECK_INPUT(hits_t); 184 | CHECK_INPUT(alive_indices); 185 | CHECK_INPUT(N_eff_samples); 186 | CHECK_INPUT(opacity); 187 | CHECK_INPUT(depth); 188 | CHECK_INPUT(rgb); 189 | 190 | composite_test_fw_cu( 191 | sigmas, rgbs, deltas, ts, hits_t, alive_indices, 192 | T_threshold, N_eff_samples, 193 | opacity, depth, rgb); 194 | } 195 | 196 | 197 | std::vector distortion_loss_fw( 198 | const torch::Tensor ws, 199 | const torch::Tensor deltas, 200 | const torch::Tensor ts, 201 | const torch::Tensor rays_a 202 | ){ 203 | CHECK_INPUT(ws); 204 | CHECK_INPUT(deltas); 205 | CHECK_INPUT(ts); 206 | CHECK_INPUT(rays_a); 207 | 208 | return distortion_loss_fw_cu(ws, deltas, ts, rays_a); 209 | } 210 | 211 | 212 | torch::Tensor distortion_loss_bw( 213 | const torch::Tensor dL_dloss, 214 | const torch::Tensor ws_inclusive_scan, 215 | const torch::Tensor wts_inclusive_scan, 216 | const torch::Tensor ws, 217 | const torch::Tensor deltas, 218 | const torch::Tensor ts, 219 | const torch::Tensor rays_a 220 | ){ 221 | CHECK_INPUT(dL_dloss); 222 | CHECK_INPUT(ws_inclusive_scan); 223 | CHECK_INPUT(wts_inclusive_scan); 224 | CHECK_INPUT(ws); 225 | CHECK_INPUT(deltas); 226 | CHECK_INPUT(ts); 227 | CHECK_INPUT(rays_a); 228 | 229 | return distortion_loss_bw_cu(dL_dloss, ws_inclusive_scan, wts_inclusive_scan, 230 | ws, deltas, ts, rays_a); 231 | } 232 | 233 | 234 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 235 | m.def("ray_aabb_intersect", &ray_aabb_intersect); 236 | m.def("ray_sphere_intersect", &ray_sphere_intersect); 237 | 238 | m.def("morton3D", &morton3D); 239 | m.def("morton3D_invert", &morton3D_invert); 240 | m.def("packbits", &packbits); 241 | 242 | m.def("raymarching_train", &raymarching_train); 243 | m.def("raymarching_test", &raymarching_test); 244 | m.def("composite_train_fw", &composite_train_fw); 245 | m.def("composite_train_bw", &composite_train_bw); 246 | m.def("composite_test_fw", &composite_test_fw); 247 | 248 | m.def("distortion_loss_fw", &distortion_loss_fw); 249 | m.def("distortion_loss_bw", &distortion_loss_bw); 250 | 251 | } -------------------------------------------------------------------------------- /models/csrc/include/helper_math.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | * 3 | * Redistribution and use in source and binary forms, with or without 4 | * modification, are permitted provided that the following conditions 5 | * are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of NVIDIA CORPORATION nor the names of its 12 | * contributors may be used to endorse or promote products derived 13 | * from this software without specific prior written permission. 
14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | /* 29 | * This file implements common mathematical operations on vector types 30 | * (float3, float4 etc.) since these are not provided as standard by CUDA. 31 | * 32 | * The syntax is modeled on the Cg standard library. 33 | * 34 | * This is part of the Helper library includes 35 | * 36 | * Thanks to Linh Hah for additions and fixes. 37 | */ 38 | 39 | #ifndef HELPER_MATH_H 40 | #define HELPER_MATH_H 41 | 42 | #include "cuda_runtime.h" 43 | 44 | typedef unsigned int uint; 45 | typedef unsigned short ushort; 46 | 47 | #ifndef EXIT_WAIVED 48 | #define EXIT_WAIVED 2 49 | #endif 50 | 51 | #ifndef __CUDACC__ 52 | #include 53 | 54 | //////////////////////////////////////////////////////////////////////////////// 55 | // host implementations of CUDA functions 56 | //////////////////////////////////////////////////////////////////////////////// 57 | 58 | inline float fminf(float a, float b) 59 | { 60 | return a < b ? a : b; 61 | } 62 | 63 | inline float fmaxf(float a, float b) 64 | { 65 | return a > b ? a : b; 66 | } 67 | 68 | inline int max(int a, int b) 69 | { 70 | return a > b ? a : b; 71 | } 72 | 73 | inline int min(int a, int b) 74 | { 75 | return a < b ? 
a : b; 76 | } 77 | 78 | inline float rsqrtf(float x) 79 | { 80 | return 1.0f / sqrtf(x); 81 | } 82 | #endif 83 | 84 | //////////////////////////////////////////////////////////////////////////////// 85 | // constructors 86 | //////////////////////////////////////////////////////////////////////////////// 87 | 88 | inline __host__ __device__ float2 make_float2(float s) 89 | { 90 | return make_float2(s, s); 91 | } 92 | inline __host__ __device__ float2 make_float2(float3 a) 93 | { 94 | return make_float2(a.x, a.y); 95 | } 96 | inline __host__ __device__ float3 make_float3(float s) 97 | { 98 | return make_float3(s, s, s); 99 | } 100 | inline __host__ __device__ float3 make_float3(float2 a) 101 | { 102 | return make_float3(a.x, a.y, 0.0f); 103 | } 104 | inline __host__ __device__ float3 make_float3(float2 a, float s) 105 | { 106 | return make_float3(a.x, a.y, s); 107 | } 108 | 109 | //////////////////////////////////////////////////////////////////////////////// 110 | // negate 111 | //////////////////////////////////////////////////////////////////////////////// 112 | 113 | inline __host__ __device__ float3 operator-(float3 &a) 114 | { 115 | return make_float3(-a.x, -a.y, -a.z); 116 | } 117 | 118 | //////////////////////////////////////////////////////////////////////////////// 119 | // addition 120 | //////////////////////////////////////////////////////////////////////////////// 121 | 122 | inline __host__ __device__ float3 operator+(float3 a, float3 b) 123 | { 124 | return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); 125 | } 126 | inline __host__ __device__ void operator+=(float3 &a, float3 b) 127 | { 128 | a.x += b.x; 129 | a.y += b.y; 130 | a.z += b.z; 131 | } 132 | inline __host__ __device__ float3 operator+(float3 a, float b) 133 | { 134 | return make_float3(a.x + b, a.y + b, a.z + b); 135 | } 136 | inline __host__ __device__ void operator+=(float3 &a, float b) 137 | { 138 | a.x += b; 139 | a.y += b; 140 | a.z += b; 141 | } 142 | inline __host__ __device__ float3 operator+(float b, float3 a) 143 | { 144 | return make_float3(a.x + b, a.y + b, a.z + b); 145 | } 146 | 147 | //////////////////////////////////////////////////////////////////////////////// 148 | // subtract 149 | //////////////////////////////////////////////////////////////////////////////// 150 | 151 | inline __host__ __device__ float3 operator-(float3 a, float3 b) 152 | { 153 | return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); 154 | } 155 | inline __host__ __device__ void operator-=(float3 &a, float3 b) 156 | { 157 | a.x -= b.x; 158 | a.y -= b.y; 159 | a.z -= b.z; 160 | } 161 | inline __host__ __device__ float3 operator-(float3 a, float b) 162 | { 163 | return make_float3(a.x - b, a.y - b, a.z - b); 164 | } 165 | inline __host__ __device__ float3 operator-(float b, float3 a) 166 | { 167 | return make_float3(b - a.x, b - a.y, b - a.z); 168 | } 169 | inline __host__ __device__ void operator-=(float3 &a, float b) 170 | { 171 | a.x -= b; 172 | a.y -= b; 173 | a.z -= b; 174 | } 175 | 176 | //////////////////////////////////////////////////////////////////////////////// 177 | // multiply 178 | //////////////////////////////////////////////////////////////////////////////// 179 | 180 | inline __host__ __device__ float3 operator*(float3 a, float3 b) 181 | { 182 | return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); 183 | } 184 | inline __host__ __device__ void operator*=(float3 &a, float3 b) 185 | { 186 | a.x *= b.x; 187 | a.y *= b.y; 188 | a.z *= b.z; 189 | } 190 | inline __host__ __device__ float3 operator*(float3 a, float 
b) 191 | { 192 | return make_float3(a.x * b, a.y * b, a.z * b); 193 | } 194 | inline __host__ __device__ float3 operator*(float b, float3 a) 195 | { 196 | return make_float3(b * a.x, b * a.y, b * a.z); 197 | } 198 | inline __host__ __device__ void operator*=(float3 &a, float b) 199 | { 200 | a.x *= b; 201 | a.y *= b; 202 | a.z *= b; 203 | } 204 | 205 | //////////////////////////////////////////////////////////////////////////////// 206 | // divide 207 | //////////////////////////////////////////////////////////////////////////////// 208 | 209 | inline __host__ __device__ float2 operator/(float2 a, float2 b) 210 | { 211 | return make_float2(a.x / b.x, a.y / b.y); 212 | } 213 | inline __host__ __device__ void operator/=(float2 &a, float2 b) 214 | { 215 | a.x /= b.x; 216 | a.y /= b.y; 217 | } 218 | inline __host__ __device__ float2 operator/(float2 a, float b) 219 | { 220 | return make_float2(a.x / b, a.y / b); 221 | } 222 | inline __host__ __device__ void operator/=(float2 &a, float b) 223 | { 224 | a.x /= b; 225 | a.y /= b; 226 | } 227 | inline __host__ __device__ float2 operator/(float b, float2 a) 228 | { 229 | return make_float2(b / a.x, b / a.y); 230 | } 231 | 232 | inline __host__ __device__ float3 operator/(float3 a, float3 b) 233 | { 234 | return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); 235 | } 236 | inline __host__ __device__ void operator/=(float3 &a, float3 b) 237 | { 238 | a.x /= b.x; 239 | a.y /= b.y; 240 | a.z /= b.z; 241 | } 242 | inline __host__ __device__ float3 operator/(float3 a, float b) 243 | { 244 | return make_float3(a.x / b, a.y / b, a.z / b); 245 | } 246 | inline __host__ __device__ void operator/=(float3 &a, float b) 247 | { 248 | a.x /= b; 249 | a.y /= b; 250 | a.z /= b; 251 | } 252 | inline __host__ __device__ float3 operator/(float b, float3 a) 253 | { 254 | return make_float3(b / a.x, b / a.y, b / a.z); 255 | } 256 | 257 | //////////////////////////////////////////////////////////////////////////////// 258 | // min 259 | //////////////////////////////////////////////////////////////////////////////// 260 | 261 | inline __host__ __device__ float3 fminf(float3 a, float3 b) 262 | { 263 | return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); 264 | } 265 | 266 | //////////////////////////////////////////////////////////////////////////////// 267 | // max 268 | //////////////////////////////////////////////////////////////////////////////// 269 | 270 | inline __host__ __device__ float3 fmaxf(float3 a, float3 b) 271 | { 272 | return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); 273 | } 274 | 275 | //////////////////////////////////////////////////////////////////////////////// 276 | // clamp 277 | // - clamp the value v to be in the range [a, b] 278 | //////////////////////////////////////////////////////////////////////////////// 279 | 280 | inline __device__ __host__ float clamp(float f, float a, float b) 281 | { 282 | return fmaxf(a, fminf(f, b)); 283 | } 284 | inline __device__ __host__ int clamp(int f, int a, int b) 285 | { 286 | return max(a, min(f, b)); 287 | } 288 | inline __device__ __host__ uint clamp(uint f, uint a, uint b) 289 | { 290 | return max(a, min(f, b)); 291 | } 292 | 293 | inline __device__ __host__ float3 clamp(float3 v, float a, float b) 294 | { 295 | return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); 296 | } 297 | inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) 298 | { 299 | return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); 300 | } 
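// The element-wise float3 helpers above (arithmetic, fminf/fmaxf, clamp) are the
// building blocks used by the CUDA kernels in intersection.cu and raymarching.cu;
// for example, the ray-AABB slab test takes the per-component min/max of
// (center - half_size - ray_o) * inv_d and (center + half_size - ray_o) * inv_d.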
301 | 302 | //////////////////////////////////////////////////////////////////////////////// 303 | // dot product 304 | //////////////////////////////////////////////////////////////////////////////// 305 | 306 | inline __host__ __device__ float dot(float3 a, float3 b) 307 | { 308 | return a.x * b.x + a.y * b.y + a.z * b.z; 309 | } 310 | 311 | //////////////////////////////////////////////////////////////////////////////// 312 | // length 313 | //////////////////////////////////////////////////////////////////////////////// 314 | 315 | inline __host__ __device__ float length(float3 v) 316 | { 317 | return sqrtf(dot(v, v)); 318 | } 319 | 320 | //////////////////////////////////////////////////////////////////////////////// 321 | // normalize 322 | //////////////////////////////////////////////////////////////////////////////// 323 | 324 | inline __host__ __device__ float3 normalize(float3 v) 325 | { 326 | float invLen = rsqrtf(dot(v, v)); 327 | return v * invLen; 328 | } 329 | 330 | //////////////////////////////////////////////////////////////////////////////// 331 | // reflect 332 | // - returns reflection of incident ray I around surface normal N 333 | // - N should be normalized, reflected vector's length is equal to length of I 334 | //////////////////////////////////////////////////////////////////////////////// 335 | 336 | inline __host__ __device__ float3 reflect(float3 i, float3 n) 337 | { 338 | return i - 2.0f * n * dot(n,i); 339 | } 340 | 341 | //////////////////////////////////////////////////////////////////////////////// 342 | // cross product 343 | //////////////////////////////////////////////////////////////////////////////// 344 | 345 | inline __host__ __device__ float3 cross(float3 a, float3 b) 346 | { 347 | return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); 348 | } 349 | 350 | //////////////////////////////////////////////////////////////////////////////// 351 | // smoothstep 352 | // - returns 0 if x < a 353 | // - returns 1 if x > b 354 | // - otherwise returns smooth interpolation between 0 and 1 based on x 355 | //////////////////////////////////////////////////////////////////////////////// 356 | 357 | inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) 358 | { 359 | float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); 360 | return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y))); 361 | } 362 | 363 | #endif 364 | -------------------------------------------------------------------------------- /models/csrc/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 7 | 8 | 9 | std::vector ray_aabb_intersect_cu( 10 | const torch::Tensor rays_o, 11 | const torch::Tensor rays_d, 12 | const torch::Tensor centers, 13 | const torch::Tensor half_sizes, 14 | const int max_hits 15 | ); 16 | 17 | 18 | std::vector ray_sphere_intersect_cu( 19 | const torch::Tensor rays_o, 20 | const torch::Tensor rays_d, 21 | const torch::Tensor centers, 22 | const torch::Tensor radii, 23 | const int max_hits 24 | ); 25 | 26 | 27 | void packbits_cu( 28 | torch::Tensor density_grid, 29 | const float density_threshold, 30 | torch::Tensor density_bitfield 31 | ); 32 | 33 | 34 | torch::Tensor morton3D_cu(const torch::Tensor coords); 35 | torch::Tensor 
morton3D_invert_cu(const torch::Tensor indices); 36 | 37 | 38 | std::vector raymarching_train_cu( 39 | const torch::Tensor rays_o, 40 | const torch::Tensor rays_d, 41 | const torch::Tensor hits_t, 42 | const torch::Tensor density_bitfield, 43 | const int cascades, 44 | const float scale, 45 | const float exp_step_factor, 46 | const torch::Tensor noise, 47 | const int grid_size, 48 | const int max_samples 49 | ); 50 | 51 | 52 | std::vector raymarching_test_cu( 53 | const torch::Tensor rays_o, 54 | const torch::Tensor rays_d, 55 | torch::Tensor hits_t, 56 | const torch::Tensor alive_indices, 57 | const torch::Tensor density_bitfield, 58 | const int cascades, 59 | const float scale, 60 | const float exp_step_factor, 61 | const int grid_size, 62 | const int max_samples, 63 | const int N_samples 64 | ); 65 | 66 | 67 | std::vector composite_train_fw_cu( 68 | const torch::Tensor sigmas, 69 | const torch::Tensor rgbs, 70 | const torch::Tensor deltas, 71 | const torch::Tensor ts, 72 | const torch::Tensor rays_a, 73 | const float T_threshold 74 | ); 75 | 76 | 77 | std::vector composite_train_bw_cu( 78 | const torch::Tensor dL_dopacity, 79 | const torch::Tensor dL_ddepth, 80 | const torch::Tensor dL_drgb, 81 | const torch::Tensor dL_dws, 82 | const torch::Tensor sigmas, 83 | const torch::Tensor rgbs, 84 | const torch::Tensor ws, 85 | const torch::Tensor deltas, 86 | const torch::Tensor ts, 87 | const torch::Tensor rays_a, 88 | const torch::Tensor opacity, 89 | const torch::Tensor depth, 90 | const torch::Tensor rgb, 91 | const float T_threshold 92 | ); 93 | 94 | 95 | void composite_test_fw_cu( 96 | const torch::Tensor sigmas, 97 | const torch::Tensor rgbs, 98 | const torch::Tensor deltas, 99 | const torch::Tensor ts, 100 | const torch::Tensor hits_t, 101 | const torch::Tensor alive_indices, 102 | const float T_threshold, 103 | const torch::Tensor N_eff_samples, 104 | torch::Tensor opacity, 105 | torch::Tensor depth, 106 | torch::Tensor rgb 107 | ); 108 | 109 | 110 | std::vector distortion_loss_fw_cu( 111 | const torch::Tensor ws, 112 | const torch::Tensor deltas, 113 | const torch::Tensor ts, 114 | const torch::Tensor rays_a 115 | ); 116 | 117 | 118 | torch::Tensor distortion_loss_bw_cu( 119 | const torch::Tensor dL_dloss, 120 | const torch::Tensor ws_inclusive_scan, 121 | const torch::Tensor wts_inclusive_scan, 122 | const torch::Tensor ws, 123 | const torch::Tensor deltas, 124 | const torch::Tensor ts, 125 | const torch::Tensor rays_a 126 | ); -------------------------------------------------------------------------------- /models/csrc/intersection.cu: -------------------------------------------------------------------------------- 1 | #include "helper_math.h" 2 | #include "utils.h" 3 | 4 | 5 | __device__ __forceinline__ float2 _ray_aabb_intersect( 6 | const float3 ray_o, 7 | const float3 inv_d, 8 | const float3 center, 9 | const float3 half_size 10 | ){ 11 | 12 | const float3 t_min = (center-half_size-ray_o)*inv_d; 13 | const float3 t_max = (center+half_size-ray_o)*inv_d; 14 | 15 | const float3 _t1 = fminf(t_min, t_max); 16 | const float3 _t2 = fmaxf(t_min, t_max); 17 | const float t1 = fmaxf(fmaxf(_t1.x, _t1.y), _t1.z); 18 | const float t2 = fminf(fminf(_t2.x, _t2.y), _t2.z); 19 | 20 | if (t1 > t2) return make_float2(-1.0f); // no intersection 21 | return make_float2(t1, t2); 22 | } 23 | 24 | 25 | __global__ void ray_aabb_intersect_kernel( 26 | const torch::PackedTensorAccessor32 rays_o, 27 | const torch::PackedTensorAccessor32 rays_d, 28 | const torch::PackedTensorAccessor32 centers, 29 | const 
torch::PackedTensorAccessor32 half_sizes, 30 | const int max_hits, 31 | int* __restrict__ hit_cnt, 32 | torch::PackedTensorAccessor32 hits_t, 33 | torch::PackedTensorAccessor64 hits_voxel_idx 34 | ){ 35 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 36 | const int v = blockIdx.y * blockDim.y + threadIdx.y; 37 | 38 | if (v>=centers.size(0) || r>=rays_o.size(0)) return; 39 | 40 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 41 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 42 | const float3 inv_d = 1.0f/ray_d; 43 | 44 | const float3 center = make_float3(centers[v][0], centers[v][1], centers[v][2]); 45 | const float3 half_size = make_float3(half_sizes[v][0], half_sizes[v][1], half_sizes[v][2]); 46 | const float2 t1t2 = _ray_aabb_intersect(ray_o, inv_d, center, half_size); 47 | 48 | if (t1t2.y > 0){ // if ray hits the voxel 49 | const int cnt = atomicAdd(&hit_cnt[r], 1); 50 | if (cnt < max_hits){ 51 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 52 | hits_t[r][cnt][1] = t1t2.y; 53 | hits_voxel_idx[r][cnt] = v; 54 | } 55 | } 56 | } 57 | 58 | 59 | std::vector ray_aabb_intersect_cu( 60 | const torch::Tensor rays_o, 61 | const torch::Tensor rays_d, 62 | const torch::Tensor centers, 63 | const torch::Tensor half_sizes, 64 | const int max_hits 65 | ){ 66 | 67 | const int N_rays = rays_o.size(0), N_voxels = centers.size(0); 68 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 69 | auto hits_voxel_idx = 70 | torch::zeros({N_rays, max_hits}, 71 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 72 | auto hit_cnt = 73 | torch::zeros({N_rays}, 74 | torch::dtype(torch::kInt32).device(rays_o.device())); 75 | 76 | const dim3 threads(256, 1); 77 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 78 | (N_voxels+threads.y-1)/threads.y); 79 | 80 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_aabb_intersect_cu", 81 | ([&] { 82 | ray_aabb_intersect_kernel<<>>( 83 | rays_o.packed_accessor32(), 84 | rays_d.packed_accessor32(), 85 | centers.packed_accessor32(), 86 | half_sizes.packed_accessor32(), 87 | max_hits, 88 | hit_cnt.data_ptr(), 89 | hits_t.packed_accessor32(), 90 | hits_voxel_idx.packed_accessor64() 91 | ); 92 | })); 93 | 94 | // sort intersections from near to far based on t1 95 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 96 | hits_voxel_idx = torch::gather(hits_voxel_idx, 1, hits_order); 97 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 98 | 99 | return {hit_cnt, hits_t, hits_voxel_idx}; 100 | } 101 | 102 | 103 | __device__ __forceinline__ float2 _ray_sphere_intersect( 104 | const float3 ray_o, 105 | const float3 ray_d, 106 | const float3 center, 107 | const float radius 108 | ){ 109 | const float3 co = ray_o-center; 110 | 111 | const float a = dot(ray_d, ray_d); 112 | const float half_b = dot(ray_d, co); 113 | const float c = dot(co, co)-radius*radius; 114 | 115 | const float discriminant = half_b*half_b-a*c; 116 | 117 | if (discriminant < 0) return make_float2(-1.0f); // no intersection 118 | 119 | const float disc_sqrt = sqrtf(discriminant); 120 | return make_float2(-half_b-disc_sqrt, -half_b+disc_sqrt)/a; 121 | } 122 | 123 | 124 | __global__ void ray_sphere_intersect_kernel( 125 | const torch::PackedTensorAccessor32 rays_o, 126 | const torch::PackedTensorAccessor32 rays_d, 127 | const torch::PackedTensorAccessor32 centers, 128 | const torch::PackedTensorAccessor32 radii, 129 | const int max_hits, 130 | int* __restrict__ hit_cnt, 131 | 
torch::PackedTensorAccessor32 hits_t, 132 | torch::PackedTensorAccessor64 hits_sphere_idx 133 | ){ 134 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 135 | const int s = blockIdx.y * blockDim.y + threadIdx.y; 136 | 137 | if (s>=centers.size(0) || r>=rays_o.size(0)) return; 138 | 139 | const float3 ray_o = make_float3(rays_o[r][0], rays_o[r][1], rays_o[r][2]); 140 | const float3 ray_d = make_float3(rays_d[r][0], rays_d[r][1], rays_d[r][2]); 141 | const float3 center = make_float3(centers[s][0], centers[s][1], centers[s][2]); 142 | 143 | const float2 t1t2 = _ray_sphere_intersect(ray_o, ray_d, center, radii[s]); 144 | 145 | if (t1t2.y > 0){ // if ray hits the sphere 146 | const int cnt = atomicAdd(&hit_cnt[r], 1); 147 | if (cnt < max_hits){ 148 | hits_t[r][cnt][0] = fmaxf(t1t2.x, 0.0f); 149 | hits_t[r][cnt][1] = t1t2.y; 150 | hits_sphere_idx[r][cnt] = s; 151 | } 152 | } 153 | } 154 | 155 | 156 | std::vector ray_sphere_intersect_cu( 157 | const torch::Tensor rays_o, 158 | const torch::Tensor rays_d, 159 | const torch::Tensor centers, 160 | const torch::Tensor radii, 161 | const int max_hits 162 | ){ 163 | 164 | const int N_rays = rays_o.size(0), N_spheres = centers.size(0); 165 | auto hits_t = torch::zeros({N_rays, max_hits, 2}, rays_o.options())-1; 166 | auto hits_sphere_idx = 167 | torch::zeros({N_rays, max_hits}, 168 | torch::dtype(torch::kLong).device(rays_o.device()))-1; 169 | auto hit_cnt = 170 | torch::zeros({N_rays}, 171 | torch::dtype(torch::kInt32).device(rays_o.device())); 172 | 173 | const dim3 threads(256, 1); 174 | const dim3 blocks((N_rays+threads.x-1)/threads.x, 175 | (N_spheres+threads.y-1)/threads.y); 176 | 177 | AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "ray_sphere_intersect_cu", 178 | ([&] { 179 | ray_sphere_intersect_kernel<<>>( 180 | rays_o.packed_accessor32(), 181 | rays_d.packed_accessor32(), 182 | centers.packed_accessor32(), 183 | radii.packed_accessor32(), 184 | max_hits, 185 | hit_cnt.data_ptr(), 186 | hits_t.packed_accessor32(), 187 | hits_sphere_idx.packed_accessor64() 188 | ); 189 | })); 190 | 191 | // sort intersections from near to far based on t1 192 | auto hits_order = std::get<1>(torch::sort(hits_t.index({"...", 0}))); 193 | hits_sphere_idx = torch::gather(hits_sphere_idx, 1, hits_order); 194 | hits_t = torch::gather(hits_t, 1, hits_order.unsqueeze(-1).tile({1, 1, 2})); 195 | 196 | return {hit_cnt, hits_t, hits_sphere_idx}; 197 | } -------------------------------------------------------------------------------- /models/csrc/losses.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | // for details of the formulae, please see https://arxiv.org/pdf/2206.05085.pdf 8 | 9 | template 10 | __global__ void prefix_sums_kernel( 11 | const scalar_t* __restrict__ ws, 12 | const scalar_t* __restrict__ wts, 13 | const torch::PackedTensorAccessor64 rays_a, 14 | scalar_t* __restrict__ ws_inclusive_scan, 15 | scalar_t* __restrict__ ws_exclusive_scan, 16 | scalar_t* __restrict__ wts_inclusive_scan, 17 | scalar_t* __restrict__ wts_exclusive_scan 18 | ){ 19 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 20 | if (n >= rays_a.size(0)) return; 21 | 22 | const int start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 23 | 24 | // compute prefix sum of ws and ws*ts 25 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 
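// Each ray owns the contiguous slice [start_idx, start_idx+N_samples), so the four
// scans below are independent per ray. They let the pairwise term of the distortion
// loss, sum_{i,j} w_i*w_j*|t_i - t_j|, be accumulated in a single pass: for sample i,
// 2*w_i*sum_{j<i} w_j*(t_i - t_j) falls out of the inclusive/exclusive scans of w and
// w*t (see the _loss expression in distortion_loss_fw_cu below).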
26 | thrust::inclusive_scan(thrust::device, 27 | ws+start_idx, 28 | ws+start_idx+N_samples, 29 | ws_inclusive_scan+start_idx); 30 | thrust::inclusive_scan(thrust::device, 31 | wts+start_idx, 32 | wts+start_idx+N_samples, 33 | wts_inclusive_scan+start_idx); 34 | // [a0, a1, a2, a3, ...] -> [0, a0, a0+a1, a0+a1+a2, ...] 35 | thrust::exclusive_scan(thrust::device, 36 | ws+start_idx, 37 | ws+start_idx+N_samples, 38 | ws_exclusive_scan+start_idx); 39 | thrust::exclusive_scan(thrust::device, 40 | wts+start_idx, 41 | wts+start_idx+N_samples, 42 | wts_exclusive_scan+start_idx); 43 | } 44 | 45 | 46 | template 47 | __global__ void distortion_loss_fw_kernel( 48 | const scalar_t* __restrict__ _loss, 49 | const torch::PackedTensorAccessor64 rays_a, 50 | torch::PackedTensorAccessor loss 51 | ){ 52 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 53 | if (n >= rays_a.size(0)) return; 54 | 55 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 56 | 57 | loss[ray_idx] = thrust::reduce(thrust::device, 58 | _loss+start_idx, 59 | _loss+start_idx+N_samples, 60 | (scalar_t)0); 61 | } 62 | 63 | 64 | std::vector distortion_loss_fw_cu( 65 | const torch::Tensor ws, 66 | const torch::Tensor deltas, 67 | const torch::Tensor ts, 68 | const torch::Tensor rays_a 69 | ){ 70 | const int N_rays = rays_a.size(0), N = ws.size(0); 71 | 72 | auto wts = ws * ts; 73 | 74 | auto ws_inclusive_scan = torch::zeros({N}, ws.options()); 75 | auto ws_exclusive_scan = torch::zeros({N}, ws.options()); 76 | auto wts_inclusive_scan = torch::zeros({N}, ws.options()); 77 | auto wts_exclusive_scan = torch::zeros({N}, ws.options()); 78 | 79 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 80 | 81 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu_prefix_sums", 82 | ([&] { 83 | prefix_sums_kernel<<>>( 84 | ws.data_ptr(), 85 | wts.data_ptr(), 86 | rays_a.packed_accessor64(), 87 | ws_inclusive_scan.data_ptr(), 88 | ws_exclusive_scan.data_ptr(), 89 | wts_inclusive_scan.data_ptr(), 90 | wts_exclusive_scan.data_ptr() 91 | ); 92 | })); 93 | 94 | auto _loss = 2*(wts_inclusive_scan*ws_exclusive_scan- 95 | ws_inclusive_scan*wts_exclusive_scan) + 1.0f/3*ws*ws*deltas; 96 | 97 | auto loss = torch::zeros({N_rays}, ws.options()); 98 | 99 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_fw_cu", 100 | ([&] { 101 | distortion_loss_fw_kernel<<>>( 102 | _loss.data_ptr(), 103 | rays_a.packed_accessor64(), 104 | loss.packed_accessor() 105 | ); 106 | })); 107 | 108 | return {loss, ws_inclusive_scan, wts_inclusive_scan}; 109 | } 110 | 111 | 112 | template 113 | __global__ void distortion_loss_bw_kernel( 114 | const torch::PackedTensorAccessor dL_dloss, 115 | const torch::PackedTensorAccessor ws_inclusive_scan, 116 | const torch::PackedTensorAccessor wts_inclusive_scan, 117 | const torch::PackedTensorAccessor ws, 118 | const torch::PackedTensorAccessor deltas, 119 | const torch::PackedTensorAccessor ts, 120 | const torch::PackedTensorAccessor64 rays_a, 121 | torch::PackedTensorAccessor dL_dws 122 | ){ 123 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 124 | if (n >= rays_a.size(0)) return; 125 | 126 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 127 | const int end_idx = start_idx+N_samples-1; 128 | 129 | const scalar_t ws_sum = ws_inclusive_scan[end_idx]; 130 | const scalar_t wts_sum = wts_inclusive_scan[end_idx]; 131 | // fill in dL_dws from start_idx to end_idx 132 | for (int s=start_idx; s<=end_idx; s++){ 133 | dL_dws[s] = 
dL_dloss[ray_idx] * 2 * ( 134 | (s==start_idx? 135 | (scalar_t)0: 136 | (ts[s]*ws_inclusive_scan[s-1]-wts_inclusive_scan[s-1]) 137 | ) + 138 | (wts_sum-wts_inclusive_scan[s]-ts[s]*(ws_sum-ws_inclusive_scan[s])) 139 | ); 140 | dL_dws[s] += dL_dloss[ray_idx] * (scalar_t)2/3*ws[s]*deltas[s]; 141 | } 142 | } 143 | 144 | 145 | torch::Tensor distortion_loss_bw_cu( 146 | const torch::Tensor dL_dloss, 147 | const torch::Tensor ws_inclusive_scan, 148 | const torch::Tensor wts_inclusive_scan, 149 | const torch::Tensor ws, 150 | const torch::Tensor deltas, 151 | const torch::Tensor ts, 152 | const torch::Tensor rays_a 153 | ){ 154 | const int N_rays = rays_a.size(0), N = ws.size(0); 155 | 156 | auto dL_dws = torch::zeros({N}, dL_dloss.options()); 157 | 158 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 159 | 160 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(ws.type(), "distortion_loss_bw_cu", 161 | ([&] { 162 | distortion_loss_bw_kernel<<>>( 163 | dL_dloss.packed_accessor(), 164 | ws_inclusive_scan.packed_accessor(), 165 | wts_inclusive_scan.packed_accessor(), 166 | ws.packed_accessor(), 167 | deltas.packed_accessor(), 168 | ts.packed_accessor(), 169 | rays_a.packed_accessor64(), 170 | dL_dws.packed_accessor() 171 | ); 172 | })); 173 | 174 | return dL_dws; 175 | } -------------------------------------------------------------------------------- /models/csrc/raymarching.cu: -------------------------------------------------------------------------------- 1 | #include "helper_math.h" 2 | #include "utils.h" 3 | 4 | #define SQRT3 1.73205080757f 5 | 6 | 7 | inline __host__ __device__ float signf(const float x) { return copysignf(1.0f, x); } 8 | 9 | // exponentially step t if exp_step_factor>0 (larger step size when sample moves away from the camera) 10 | // default exp_step_factor is 0 for synthetic scene, 1/256 for real scene 11 | inline __host__ __device__ float calc_dt(float t, float exp_step_factor, int max_samples, int grid_size, float scale){ 12 | return clamp(t*exp_step_factor, SQRT3/max_samples, SQRT3*2*scale/grid_size); 13 | } 14 | 15 | // Example input range of |xyz| and return value of this function 16 | // [0, 0.5) -> 0 17 | // [0.5, 1) -> 1 18 | // [1, 2) -> 2 19 | inline __device__ int mip_from_pos(const float x, const float y, const float z, const int cascades) { 20 | const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z))); 21 | int exponent; frexpf(mx, &exponent); 22 | return min(cascades-1, max(0, exponent+1)); 23 | } 24 | 25 | // Example input range of dt and return value of this function 26 | // [0, 1/grid_size) -> 0 27 | // [1/grid_size, 2/grid_size) -> 1 28 | // [2/grid_size, 4/grid_size) -> 2 29 | inline __device__ int mip_from_dt(float dt, int grid_size, int cascades) { 30 | int exponent; frexpf(dt*grid_size, &exponent); 31 | return min(cascades-1, max(0, exponent)); 32 | } 33 | 34 | // morton utils 35 | inline __host__ __device__ uint32_t __expand_bits(uint32_t v) 36 | { 37 | v = (v * 0x00010001u) & 0xFF0000FFu; 38 | v = (v * 0x00000101u) & 0x0F00F00Fu; 39 | v = (v * 0x00000011u) & 0xC30C30C3u; 40 | v = (v * 0x00000005u) & 0x49249249u; 41 | return v; 42 | } 43 | 44 | inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) 45 | { 46 | uint32_t xx = __expand_bits(x); 47 | uint32_t yy = __expand_bits(y); 48 | uint32_t zz = __expand_bits(z); 49 | return xx | (yy << 1) | (zz << 2); 50 | } 51 | 52 | inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) 53 | { 54 | x = x & 0x49249249; 55 | x = (x | (x >> 2)) & 0xc30c30c3; 56 | x = (x | (x 
>> 4)) & 0x0f00f00f; 57 | x = (x | (x >> 8)) & 0xff0000ff; 58 | x = (x | (x >> 16)) & 0x0000ffff; 59 | return x; 60 | } 61 | 62 | __global__ void morton3D_kernel( 63 | const torch::PackedTensorAccessor32 coords, 64 | torch::PackedTensorAccessor32 indices 65 | ){ 66 | const int n = threadIdx.x + blockIdx.x * blockDim.x; 67 | if (n >= coords.size(0)) return; 68 | 69 | indices[n] = __morton3D(coords[n][0], coords[n][1], coords[n][2]); 70 | } 71 | 72 | torch::Tensor morton3D_cu(const torch::Tensor coords){ 73 | int N = coords.size(0); 74 | 75 | auto indices = torch::zeros({N}, coords.options()); 76 | 77 | const int threads = 256, blocks = (N+threads-1)/threads; 78 | 79 | AT_DISPATCH_INTEGRAL_TYPES(coords.type(), "morton3D_cu", 80 | ([&] { 81 | morton3D_kernel<<>>( 82 | coords.packed_accessor32(), 83 | indices.packed_accessor32() 84 | ); 85 | })); 86 | 87 | return indices; 88 | } 89 | 90 | __global__ void morton3D_invert_kernel( 91 | const torch::PackedTensorAccessor32 indices, 92 | torch::PackedTensorAccessor32 coords 93 | ){ 94 | const int n = threadIdx.x + blockIdx.x * blockDim.x; 95 | if (n >= coords.size(0)) return; 96 | 97 | const int ind = indices[n]; 98 | coords[n][0] = __morton3D_invert(ind >> 0); 99 | coords[n][1] = __morton3D_invert(ind >> 1); 100 | coords[n][2] = __morton3D_invert(ind >> 2); 101 | } 102 | 103 | torch::Tensor morton3D_invert_cu(const torch::Tensor indices){ 104 | int N = indices.size(0); 105 | 106 | auto coords = torch::zeros({N, 3}, indices.options()); 107 | 108 | const int threads = 256, blocks = (N+threads-1)/threads; 109 | 110 | AT_DISPATCH_INTEGRAL_TYPES(indices.type(), "morton3D_invert_cu", 111 | ([&] { 112 | morton3D_invert_kernel<<>>( 113 | indices.packed_accessor32(), 114 | coords.packed_accessor32() 115 | ); 116 | })); 117 | 118 | return coords; 119 | } 120 | 121 | // packbits utils 122 | template 123 | __global__ void packbits_kernel( 124 | const scalar_t* __restrict__ density_grid, 125 | const int N, 126 | const float density_threshold, 127 | uint8_t* __restrict__ density_bitfield 128 | ){ 129 | // parallel per byte 130 | const int n = threadIdx.x + blockIdx.x * blockDim.x; 131 | if (n >= N) return; 132 | 133 | uint8_t bits = 0; 134 | 135 | #pragma unroll 8 136 | for (uint8_t i = 0; i < 8; i++) { 137 | bits |= (density_grid[8*n+i]>density_threshold) ? 
((uint8_t)1<<<>>( 155 | density_grid.data_ptr(), 156 | N, 157 | density_threshold, 158 | density_bitfield.data_ptr() 159 | ); 160 | })); 161 | } 162 | 163 | 164 | // ray marching utils 165 | // below code is based on https://github.com/ashawkey/torch-ngp/blob/main/raymarching/src/raymarching.cu 166 | __global__ void raymarching_train_kernel( 167 | const torch::PackedTensorAccessor32 rays_o, 168 | const torch::PackedTensorAccessor32 rays_d, 169 | const torch::PackedTensorAccessor32 hits_t, 170 | const uint8_t* __restrict__ density_bitfield, 171 | const int cascades, 172 | const int grid_size, 173 | const float scale, 174 | const float exp_step_factor, 175 | const torch::PackedTensorAccessor32 noise, 176 | const int max_samples, 177 | int* __restrict__ counter, 178 | torch::PackedTensorAccessor64 rays_a, 179 | torch::PackedTensorAccessor32 xyzs, 180 | torch::PackedTensorAccessor32 dirs, 181 | torch::PackedTensorAccessor32 deltas, 182 | torch::PackedTensorAccessor32 ts 183 | ){ 184 | const int r = blockIdx.x * blockDim.x + threadIdx.x; 185 | if (r >= rays_o.size(0)) return; 186 | 187 | const uint32_t grid_size3 = grid_size*grid_size*grid_size; 188 | const float grid_size_inv = 1.0f/grid_size; 189 | 190 | const float ox = rays_o[r][0], oy = rays_o[r][1], oz = rays_o[r][2]; 191 | const float dx = rays_d[r][0], dy = rays_d[r][1], dz = rays_d[r][2]; 192 | const float dx_inv = 1.0f/dx, dy_inv = 1.0f/dy, dz_inv = 1.0f/dz; 193 | float t1 = hits_t[r][0], t2 = hits_t[r][1]; 194 | 195 | if (t1>=0) { // only perturb the starting t 196 | const float dt = calc_dt(t1, exp_step_factor, max_samples, grid_size, scale); 197 | t1 += dt*noise[r]; 198 | } 199 | 200 | // first pass: compute the number of samples on the ray 201 | float t = t1; int N_samples = 0; 202 | 203 | // if t1 < 0 (no hit) this loop will be skipped (N_samples will be 0) 204 | while (0<=t && t raymarching_train_cu( 284 | const torch::Tensor rays_o, 285 | const torch::Tensor rays_d, 286 | const torch::Tensor hits_t, 287 | const torch::Tensor density_bitfield, 288 | const int cascades, 289 | const float scale, 290 | const float exp_step_factor, 291 | const torch::Tensor noise, 292 | const int grid_size, 293 | const int max_samples 294 | ){ 295 | const int N_rays = rays_o.size(0); 296 | 297 | // count the number of samples and the number of rays processed 298 | auto counter = torch::zeros({2}, torch::dtype(torch::kInt32).device(rays_o.device())); 299 | // ray attributes: ray_idx, start_idx, N_samples 300 | auto rays_a = torch::zeros({N_rays, 3}, 301 | torch::dtype(torch::kLong).device(rays_o.device())); 302 | auto xyzs = torch::zeros({N_rays*max_samples, 3}, rays_o.options()); 303 | auto dirs = torch::zeros({N_rays*max_samples, 3}, rays_o.options()); 304 | auto deltas = torch::zeros({N_rays*max_samples}, rays_o.options()); 305 | auto ts = torch::zeros({N_rays*max_samples}, rays_o.options()); 306 | 307 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 308 | 309 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rays_o.type(), "raymarching_train_cu", 310 | ([&] { 311 | raymarching_train_kernel<<>>( 312 | rays_o.packed_accessor32(), 313 | rays_d.packed_accessor32(), 314 | hits_t.packed_accessor32(), 315 | density_bitfield.data_ptr(), 316 | cascades, 317 | grid_size, 318 | scale, 319 | exp_step_factor, 320 | noise.packed_accessor32(), 321 | max_samples, 322 | counter.data_ptr(), 323 | rays_a.packed_accessor64(), 324 | xyzs.packed_accessor32(), 325 | dirs.packed_accessor32(), 326 | deltas.packed_accessor32(), 327 | ts.packed_accessor32() 328 | ); 329 
| })); 330 | 331 | return {rays_a, xyzs, dirs, deltas, ts, counter}; 332 | } 333 | 334 | 335 | __global__ void raymarching_test_kernel( 336 | const torch::PackedTensorAccessor32 rays_o, 337 | const torch::PackedTensorAccessor32 rays_d, 338 | torch::PackedTensorAccessor32 hits_t, 339 | const torch::PackedTensorAccessor64 alive_indices, 340 | const uint8_t* __restrict__ density_bitfield, 341 | const int cascades, 342 | const int grid_size, 343 | const float scale, 344 | const float exp_step_factor, 345 | const int N_samples, 346 | const int max_samples, 347 | torch::PackedTensorAccessor32 xyzs, 348 | torch::PackedTensorAccessor32 dirs, 349 | torch::PackedTensorAccessor32 deltas, 350 | torch::PackedTensorAccessor32 ts, 351 | torch::PackedTensorAccessor32 N_eff_samples 352 | ){ 353 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 354 | if (n >= alive_indices.size(0)) return; 355 | 356 | const size_t r = alive_indices[n]; // ray index 357 | const uint32_t grid_size3 = grid_size*grid_size*grid_size; 358 | const float grid_size_inv = 1.0f/grid_size; 359 | 360 | const float ox = rays_o[r][0], oy = rays_o[r][1], oz = rays_o[r][2]; 361 | const float dx = rays_d[r][0], dy = rays_d[r][1], dz = rays_d[r][2]; 362 | const float dx_inv = 1.0f/dx, dy_inv = 1.0f/dy, dz_inv = 1.0f/dz; 363 | 364 | float t = hits_t[r][0], t2 = hits_t[r][1]; 365 | int s = 0; 366 | 367 | while (t raymarching_test_cu( 408 | const torch::Tensor rays_o, 409 | const torch::Tensor rays_d, 410 | torch::Tensor hits_t, 411 | const torch::Tensor alive_indices, 412 | const torch::Tensor density_bitfield, 413 | const int cascades, 414 | const float scale, 415 | const float exp_step_factor, 416 | const int grid_size, 417 | const int max_samples, 418 | const int N_samples 419 | ){ 420 | const int N_rays = alive_indices.size(0); 421 | 422 | auto xyzs = torch::zeros({N_rays, N_samples, 3}, rays_o.options()); 423 | auto dirs = torch::zeros({N_rays, N_samples, 3}, rays_o.options()); 424 | auto deltas = torch::zeros({N_rays, N_samples}, rays_o.options()); 425 | auto ts = torch::zeros({N_rays, N_samples}, rays_o.options()); 426 | auto N_eff_samples = torch::zeros({N_rays}, 427 | torch::dtype(torch::kInt32).device(rays_o.device())); 428 | 429 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 430 | 431 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rays_o.type(), "raymarching_test_cu", 432 | ([&] { 433 | raymarching_test_kernel<<>>( 434 | rays_o.packed_accessor32(), 435 | rays_d.packed_accessor32(), 436 | hits_t.packed_accessor32(), 437 | alive_indices.packed_accessor64(), 438 | density_bitfield.data_ptr(), 439 | cascades, 440 | grid_size, 441 | scale, 442 | exp_step_factor, 443 | N_samples, 444 | max_samples, 445 | xyzs.packed_accessor32(), 446 | dirs.packed_accessor32(), 447 | deltas.packed_accessor32(), 448 | ts.packed_accessor32(), 449 | N_eff_samples.packed_accessor32() 450 | ); 451 | })); 452 | 453 | return {xyzs, dirs, deltas, ts, N_eff_samples}; 454 | } 455 | -------------------------------------------------------------------------------- /models/csrc/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | # "helper_math.h" is copied from https://github.com/NVIDIA/cuda-samples/blob/master/Common/helper_math.h 10 | 11 | sources = 
glob.glob('*.cpp')+glob.glob('*.cu') 12 | 13 | 14 | setup( 15 | name='vren', 16 | version='2.0', 17 | author='kwea123', 18 | author_email='kwea123@gmail.com', 19 | description='cuda volume rendering library', 20 | long_description='cuda volume rendering library', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='vren', 24 | sources=sources, 25 | include_dirs=include_dirs, 26 | extra_compile_args={'cxx': ['-O2'], 27 | 'nvcc': ['-O2']} 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) -------------------------------------------------------------------------------- /models/csrc/volumerendering.cu: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | #include 4 | 5 | 6 | template 7 | __global__ void composite_train_fw_kernel( 8 | const torch::PackedTensorAccessor sigmas, 9 | const torch::PackedTensorAccessor rgbs, 10 | const torch::PackedTensorAccessor deltas, 11 | const torch::PackedTensorAccessor ts, 12 | const torch::PackedTensorAccessor64 rays_a, 13 | const scalar_t T_threshold, 14 | torch::PackedTensorAccessor64 total_samples, 15 | torch::PackedTensorAccessor opacity, 16 | torch::PackedTensorAccessor depth, 17 | torch::PackedTensorAccessor rgb, 18 | torch::PackedTensorAccessor ws 19 | ){ 20 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 21 | if (n >= opacity.size(0)) return; 22 | 23 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 24 | 25 | // front to back compositing 26 | int samples = 0; scalar_t T = 1.0f; 27 | 28 | while (samples < N_samples) { 29 | const int s = start_idx + samples; 30 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]); 31 | const scalar_t w = a * T; // weight of the sample point 32 | 33 | rgb[ray_idx][0] += w*rgbs[s][0]; 34 | rgb[ray_idx][1] += w*rgbs[s][1]; 35 | rgb[ray_idx][2] += w*rgbs[s][2]; 36 | depth[ray_idx] += w*ts[s]; 37 | opacity[ray_idx] += w; 38 | ws[s] = w; 39 | T *= 1.0f-a; 40 | 41 | if (T <= T_threshold) break; // ray has enough opacity 42 | samples++; 43 | } 44 | total_samples[ray_idx] = samples; 45 | } 46 | 47 | 48 | std::vector composite_train_fw_cu( 49 | const torch::Tensor sigmas, 50 | const torch::Tensor rgbs, 51 | const torch::Tensor deltas, 52 | const torch::Tensor ts, 53 | const torch::Tensor rays_a, 54 | const float T_threshold 55 | ){ 56 | const int N_rays = rays_a.size(0), N = sigmas.size(0); 57 | 58 | auto opacity = torch::zeros({N_rays}, sigmas.options()); 59 | auto depth = torch::zeros({N_rays}, sigmas.options()); 60 | auto rgb = torch::zeros({N_rays, 3}, sigmas.options()); 61 | auto ws = torch::zeros({N}, sigmas.options()); 62 | auto total_samples = torch::zeros({N_rays}, torch::dtype(torch::kLong).device(sigmas.device())); 63 | 64 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 65 | 66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_fw_cu", 67 | ([&] { 68 | composite_train_fw_kernel<<>>( 69 | sigmas.packed_accessor(), 70 | rgbs.packed_accessor(), 71 | deltas.packed_accessor(), 72 | ts.packed_accessor(), 73 | rays_a.packed_accessor64(), 74 | T_threshold, 75 | total_samples.packed_accessor64(), 76 | opacity.packed_accessor(), 77 | depth.packed_accessor(), 78 | rgb.packed_accessor(), 79 | ws.packed_accessor() 80 | ); 81 | })); 82 | 83 | return {total_samples, opacity, depth, rgb, ws}; 84 | } 85 | 86 | 87 | template 88 | __global__ void composite_train_bw_kernel( 89 | const torch::PackedTensorAccessor dL_dopacity, 90 | const torch::PackedTensorAccessor 
dL_ddepth, 91 | const torch::PackedTensorAccessor dL_drgb, 92 | const torch::PackedTensorAccessor dL_dws, 93 | scalar_t* __restrict__ dL_dws_times_ws, 94 | const torch::PackedTensorAccessor sigmas, 95 | const torch::PackedTensorAccessor rgbs, 96 | const torch::PackedTensorAccessor deltas, 97 | const torch::PackedTensorAccessor ts, 98 | const torch::PackedTensorAccessor64 rays_a, 99 | const torch::PackedTensorAccessor opacity, 100 | const torch::PackedTensorAccessor depth, 101 | const torch::PackedTensorAccessor rgb, 102 | const scalar_t T_threshold, 103 | torch::PackedTensorAccessor dL_dsigmas, 104 | torch::PackedTensorAccessor dL_drgbs 105 | ){ 106 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 107 | if (n >= opacity.size(0)) return; 108 | 109 | const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2]; 110 | 111 | // front to back compositing 112 | int samples = 0; 113 | scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2]; 114 | scalar_t O = opacity[ray_idx], D = depth[ray_idx]; 115 | scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f; 116 | 117 | // compute prefix sum of dL_dws * ws 118 | // [a0, a1, a2, a3, ...] -> [a0, a0+a1, a0+a1+a2, a0+a1+a2+a3, ...] 119 | thrust::inclusive_scan(thrust::device, 120 | dL_dws_times_ws+start_idx, 121 | dL_dws_times_ws+start_idx+N_samples, 122 | dL_dws_times_ws+start_idx); 123 | scalar_t dL_dws_times_ws_sum = dL_dws_times_ws[start_idx+N_samples-1]; 124 | 125 | while (samples < N_samples) { 126 | const int s = start_idx + samples; 127 | const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]); 128 | const scalar_t w = a * T; 129 | 130 | r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2]; 131 | d += w*ts[s]; 132 | T *= 1.0f-a; 133 | 134 | // compute gradients by math... 
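        // Closed-form gradients of the composited outputs w.r.t. sigma_s and rgb_s:
        // with a_s = 1-exp(-sigma_s*delta_s), w_s = a_s*T_s and T_s = prod_{j<s}(1-a_j),
        //   dw_s/dsigma_s = delta_s*(T_s - w_s) = delta_s*T_{s+1}   and
        //   dw_j/dsigma_s = -delta_s*w_j  for j > s,
        // so any composited quantity X = sum_j w_j*x_j has
        //   dX/dsigma_s = delta_s*(x_s*T_{s+1} - sum_{j>s} w_j*x_j).
        // Below, T already holds T_{s+1}; (R-r), (G-g), (B-b) and (D-d) are exactly the
        // trailing sums for rgb/depth, the opacity term collapses to (1-O) since x_j=1
        // there, and dL/drgb_s is simply w_s times the incoming rgb gradient of this ray.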
135 | dL_drgbs[s][0] = dL_drgb[ray_idx][0]*w; 136 | dL_drgbs[s][1] = dL_drgb[ray_idx][1]*w; 137 | dL_drgbs[s][2] = dL_drgb[ray_idx][2]*w; 138 | 139 | dL_dsigmas[s] = deltas[s] * ( 140 | dL_drgb[ray_idx][0]*(rgbs[s][0]*T-(R-r)) + 141 | dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) + 142 | dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) + // gradients from rgb 143 | dL_dopacity[ray_idx]*(1-O) + // gradient from opacity 144 | dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + // gradient from depth 145 | T*dL_dws[s]-(dL_dws_times_ws_sum-dL_dws_times_ws[s]) // gradient from ws 146 | ); 147 | 148 | if (T <= T_threshold) break; // ray has enough opacity 149 | samples++; 150 | } 151 | } 152 | 153 | 154 | std::vector composite_train_bw_cu( 155 | const torch::Tensor dL_dopacity, 156 | const torch::Tensor dL_ddepth, 157 | const torch::Tensor dL_drgb, 158 | const torch::Tensor dL_dws, 159 | const torch::Tensor sigmas, 160 | const torch::Tensor rgbs, 161 | const torch::Tensor ws, 162 | const torch::Tensor deltas, 163 | const torch::Tensor ts, 164 | const torch::Tensor rays_a, 165 | const torch::Tensor opacity, 166 | const torch::Tensor depth, 167 | const torch::Tensor rgb, 168 | const float T_threshold 169 | ){ 170 | const int N = sigmas.size(0), N_rays = rays_a.size(0); 171 | 172 | auto dL_dsigmas = torch::zeros({N}, sigmas.options()); 173 | auto dL_drgbs = torch::zeros({N, 3}, sigmas.options()); 174 | 175 | auto dL_dws_times_ws = dL_dws * ws; // auxiliary input 176 | 177 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 178 | 179 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_train_bw_cu", 180 | ([&] { 181 | composite_train_bw_kernel<<>>( 182 | dL_dopacity.packed_accessor(), 183 | dL_ddepth.packed_accessor(), 184 | dL_drgb.packed_accessor(), 185 | dL_dws.packed_accessor(), 186 | dL_dws_times_ws.data_ptr(), 187 | sigmas.packed_accessor(), 188 | rgbs.packed_accessor(), 189 | deltas.packed_accessor(), 190 | ts.packed_accessor(), 191 | rays_a.packed_accessor64(), 192 | opacity.packed_accessor(), 193 | depth.packed_accessor(), 194 | rgb.packed_accessor(), 195 | T_threshold, 196 | dL_dsigmas.packed_accessor(), 197 | dL_drgbs.packed_accessor() 198 | ); 199 | })); 200 | 201 | return {dL_dsigmas, dL_drgbs}; 202 | } 203 | 204 | 205 | template 206 | __global__ void composite_test_fw_kernel( 207 | const torch::PackedTensorAccessor sigmas, 208 | const torch::PackedTensorAccessor rgbs, 209 | const torch::PackedTensorAccessor deltas, 210 | const torch::PackedTensorAccessor ts, 211 | const torch::PackedTensorAccessor hits_t, 212 | torch::PackedTensorAccessor64 alive_indices, 213 | const scalar_t T_threshold, 214 | const torch::PackedTensorAccessor32 N_eff_samples, 215 | torch::PackedTensorAccessor opacity, 216 | torch::PackedTensorAccessor depth, 217 | torch::PackedTensorAccessor rgb 218 | ){ 219 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 220 | if (n >= alive_indices.size(0)) return; 221 | 222 | if (N_eff_samples[n]==0){ // no hit 223 | alive_indices[n] = -1; 224 | return; 225 | } 226 | 227 | const size_t r = alive_indices[n]; // ray index 228 | 229 | // front to back compositing 230 | int s = 0; scalar_t T = 1-opacity[r]; 231 | 232 | while (s < N_eff_samples[n]) { 233 | const scalar_t a = 1.0f - __expf(-sigmas[n][s]*deltas[n][s]); 234 | const scalar_t w = a * T; 235 | 236 | rgb[r][0] += w*rgbs[n][s][0]; 237 | rgb[r][1] += w*rgbs[n][s][1]; 238 | rgb[r][2] += w*rgbs[n][s][2]; 239 | depth[r] += w*ts[n][s]; 240 | opacity[r] += w; 241 | T *= 1.0f-a; 242 | 243 | if (T <= T_threshold){ // ray has enough 
opacity 244 | alive_indices[n] = -1; 245 | break; 246 | } 247 | s++; 248 | } 249 | } 250 | 251 | 252 | void composite_test_fw_cu( 253 | const torch::Tensor sigmas, 254 | const torch::Tensor rgbs, 255 | const torch::Tensor deltas, 256 | const torch::Tensor ts, 257 | const torch::Tensor hits_t, 258 | torch::Tensor alive_indices, 259 | const float T_threshold, 260 | const torch::Tensor N_eff_samples, 261 | torch::Tensor opacity, 262 | torch::Tensor depth, 263 | torch::Tensor rgb 264 | ){ 265 | const int N_rays = alive_indices.size(0); 266 | 267 | const int threads = 256, blocks = (N_rays+threads-1)/threads; 268 | 269 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(sigmas.type(), "composite_test_fw_cu", 270 | ([&] { 271 | composite_test_fw_kernel<<>>( 272 | sigmas.packed_accessor(), 273 | rgbs.packed_accessor(), 274 | deltas.packed_accessor(), 275 | ts.packed_accessor(), 276 | hits_t.packed_accessor(), 277 | alive_indices.packed_accessor64(), 278 | T_threshold, 279 | N_eff_samples.packed_accessor32(), 280 | opacity.packed_accessor(), 281 | depth.packed_accessor(), 282 | rgb.packed_accessor() 283 | ); 284 | })); 285 | } -------------------------------------------------------------------------------- /models/custom_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vren 3 | from torch.cuda.amp import custom_fwd, custom_bwd 4 | from torch_scatter import segment_csr 5 | from einops import rearrange 6 | 7 | 8 | class RayAABBIntersector(torch.autograd.Function): 9 | """ 10 | Computes the intersections of rays and axis-aligned voxels. 11 | 12 | Inputs: 13 | rays_o: (N_rays, 3) ray origins 14 | rays_d: (N_rays, 3) ray directions 15 | centers: (N_voxels, 3) voxel centers 16 | half_sizes: (N_voxels, 3) voxel half sizes 17 | max_hits: maximum number of intersected voxels to keep for one ray 18 | (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) 19 | 20 | Outputs: 21 | hits_cnt: (N_rays) number of hits for each ray 22 | (followings are from near to far) 23 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 24 | hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) 25 | """ 26 | @staticmethod 27 | @custom_fwd(cast_inputs=torch.float32) 28 | def forward(ctx, rays_o, rays_d, center, half_size, max_hits): 29 | return vren.ray_aabb_intersect(rays_o, rays_d, center, half_size, max_hits) 30 | 31 | 32 | class RaySphereIntersector(torch.autograd.Function): 33 | """ 34 | Computes the intersections of rays and spheres. 35 | 36 | Inputs: 37 | rays_o: (N_rays, 3) ray origins 38 | rays_d: (N_rays, 3) ray directions 39 | centers: (N_spheres, 3) sphere centers 40 | radii: (N_spheres, 3) radii 41 | max_hits: maximum number of intersected spheres to keep for one ray 42 | 43 | Outputs: 44 | hits_cnt: (N_rays) number of hits for each ray 45 | (followings are from near to far) 46 | hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) 47 | hits_sphere_idx: (N_rays, max_hits) hit sphere indices (-1 if no hit) 48 | """ 49 | @staticmethod 50 | @custom_fwd(cast_inputs=torch.float32) 51 | def forward(ctx, rays_o, rays_d, center, radii, max_hits): 52 | return vren.ray_sphere_intersect(rays_o, rays_d, center, radii, max_hits) 53 | 54 | 55 | class RayMarcher(torch.autograd.Function): 56 | """ 57 | March the rays to get sample point positions and directions. 
58 | 59 | Inputs: 60 | rays_o: (N_rays, 3) ray origins 61 | rays_d: (N_rays, 3) normalized ray directions 62 | hits_t: (N_rays, 2) near and far bounds from aabb intersection 63 | density_bitfield: (C*G**3//8) 64 | cascades: int 65 | scale: float 66 | exp_step_factor: the exponential factor to scale the steps 67 | grid_size: int 68 | max_samples: int 69 | 70 | Outputs: 71 | rays_a: (N_rays) ray_idx, start_idx, N_samples 72 | xyzs: (N, 3) sample positions 73 | dirs: (N, 3) sample view directions 74 | deltas: (N) dt for integration 75 | ts: (N) sample ts 76 | """ 77 | @staticmethod 78 | @custom_fwd(cast_inputs=torch.float32) 79 | def forward(ctx, rays_o, rays_d, hits_t, 80 | density_bitfield, cascades, scale, exp_step_factor, 81 | grid_size, max_samples): 82 | # noise to perturb the first sample of each ray 83 | noise = torch.rand_like(rays_o[:, 0]) 84 | 85 | rays_a, xyzs, dirs, deltas, ts, counter = \ 86 | vren.raymarching_train( 87 | rays_o, rays_d, hits_t, 88 | density_bitfield, cascades, scale, 89 | exp_step_factor, noise, grid_size, max_samples) 90 | 91 | total_samples = counter[0] # total samples for all rays 92 | # remove redundant output 93 | xyzs = xyzs[:total_samples] 94 | dirs = dirs[:total_samples] 95 | deltas = deltas[:total_samples] 96 | ts = ts[:total_samples] 97 | 98 | ctx.save_for_backward(rays_a, ts) 99 | 100 | return rays_a, xyzs, dirs, deltas, ts, total_samples 101 | 102 | @staticmethod 103 | @custom_bwd 104 | def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, 105 | dL_ddeltas, dL_dts, dL_dtotal_samples): 106 | rays_a, ts = ctx.saved_tensors 107 | segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1]+rays_a[-1:, 2]]) 108 | dL_drays_o = segment_csr(dL_dxyzs, segments) 109 | dL_drays_d = \ 110 | segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments) 111 | 112 | return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None 113 | 114 | 115 | class VolumeRenderer(torch.autograd.Function): 116 | """ 117 | Volume rendering with different number of samples per ray 118 | Used in training only 119 | 120 | Inputs: 121 | sigmas: (N) 122 | rgbs: (N, 3) 123 | deltas: (N) 124 | ts: (N) 125 | rays_a: (N_rays, 3) ray_idx, start_idx, N_samples 126 | meaning each entry corresponds to the @ray_idx th ray, 127 | whose samples are [start_idx:start_idx+N_samples] 128 | T_threshold: float, stop the ray if the transmittance is below it 129 | 130 | Outputs: 131 | total_samples: int, total effective samples 132 | opacity: (N_rays) 133 | depth: (N_rays) 134 | rgb: (N_rays, 3) 135 | ws: (N) sample point weights 136 | """ 137 | @staticmethod 138 | @custom_fwd(cast_inputs=torch.float32) 139 | def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold): 140 | total_samples, opacity, depth, rgb, ws = \ 141 | vren.composite_train_fw(sigmas, rgbs, deltas, ts, 142 | rays_a, T_threshold) 143 | ctx.save_for_backward(sigmas, rgbs, deltas, ts, rays_a, 144 | opacity, depth, rgb, ws) 145 | ctx.T_threshold = T_threshold 146 | return total_samples.sum(), opacity, depth, rgb, ws 147 | 148 | @staticmethod 149 | @custom_bwd 150 | def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth, dL_drgb, dL_dws): 151 | sigmas, rgbs, deltas, ts, rays_a, \ 152 | opacity, depth, rgb, ws = ctx.saved_tensors 153 | dL_dsigmas, dL_drgbs = \ 154 | vren.composite_train_bw(dL_dopacity, dL_ddepth, dL_drgb, dL_dws, 155 | sigmas, rgbs, ws, deltas, ts, 156 | rays_a, 157 | opacity, depth, rgb, 158 | ctx.T_threshold) 159 | return dL_dsigmas, dL_drgbs, None, None, None, None 160 | 161 | 162 | class 
TruncExp(torch.autograd.Function): 163 | @staticmethod 164 | @custom_fwd(cast_inputs=torch.float32) 165 | def forward(ctx, x): 166 | ctx.save_for_backward(x) 167 | return torch.exp(x) 168 | 169 | @staticmethod 170 | @custom_bwd 171 | def backward(ctx, dL_dout): 172 | x = ctx.saved_tensors[0] 173 | return dL_dout * torch.exp(x.clamp(-15, 15)) 174 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import tinycudann as tcnn 4 | import vren 5 | from einops import rearrange 6 | from .custom_functions import TruncExp 7 | import numpy as np 8 | 9 | from .rendering import NEAR_DISTANCE 10 | 11 | 12 | class NGP(nn.Module): 13 | def __init__(self, scale, rgb_act='Sigmoid'): 14 | super().__init__() 15 | 16 | self.rgb_act = rgb_act 17 | 18 | # scene bounding box 19 | self.scale = scale 20 | self.register_buffer('center', torch.zeros(1, 3)) 21 | self.register_buffer('xyz_min', -torch.ones(1, 3)*scale) 22 | self.register_buffer('xyz_max', torch.ones(1, 3)*scale) 23 | self.register_buffer('half_size', (self.xyz_max-self.xyz_min)/2) 24 | 25 | # each density grid covers [-2^(k-1), 2^(k-1)]^3 for k in [0, C-1] 26 | self.cascades = max(1+int(np.ceil(np.log2(2*scale))), 1) 27 | self.grid_size = 128 28 | self.register_buffer('density_bitfield', 29 | torch.zeros(self.cascades*self.grid_size**3//8, dtype=torch.uint8)) 30 | 31 | # constants 32 | L = 16; F = 2; log2_T = 19; N_min = 16 33 | b = np.exp(np.log(2048*scale/N_min)/(L-1)) 34 | print(f'GridEncoding: Nmin={N_min} b={b:.5f} F={F} T=2^{log2_T} L={L}') 35 | 36 | self.xyz_encoder = \ 37 | tcnn.NetworkWithInputEncoding( 38 | n_input_dims=3, n_output_dims=16, 39 | encoding_config={ 40 | "otype": "Grid", 41 | "type": "Hash", 42 | "n_levels": L, 43 | "n_features_per_level": F, 44 | "log2_hashmap_size": log2_T, 45 | "base_resolution": N_min, 46 | "per_level_scale": b, 47 | "interpolation": "Linear" 48 | }, 49 | network_config={ 50 | "otype": "FullyFusedMLP", 51 | "activation": "ReLU", 52 | "output_activation": "None", 53 | "n_neurons": 64, 54 | "n_hidden_layers": 1, 55 | } 56 | ) 57 | 58 | self.dir_encoder = \ 59 | tcnn.Encoding( 60 | n_input_dims=3, 61 | encoding_config={ 62 | "otype": "SphericalHarmonics", 63 | "degree": 4, 64 | }, 65 | ) 66 | 67 | self.rgb_net = \ 68 | tcnn.Network( 69 | n_input_dims=32, n_output_dims=3, 70 | network_config={ 71 | "otype": "FullyFusedMLP", 72 | "activation": "ReLU", 73 | "output_activation": self.rgb_act, 74 | "n_neurons": 64, 75 | "n_hidden_layers": 2, 76 | } 77 | ) 78 | 79 | if self.rgb_act == 'None': # rgb_net output is log-radiance 80 | for i in range(3): # independent tonemappers for r,g,b 81 | tonemapper_net = \ 82 | tcnn.Network( 83 | n_input_dims=1, n_output_dims=1, 84 | network_config={ 85 | "otype": "FullyFusedMLP", 86 | "activation": "ReLU", 87 | "output_activation": "Sigmoid", 88 | "n_neurons": 64, 89 | "n_hidden_layers": 1, 90 | } 91 | ) 92 | setattr(self, f'tonemapper_net_{i}', tonemapper_net) 93 | 94 | def density(self, x, return_feat=False): 95 | """ 96 | Inputs: 97 | x: (N, 3) xyz in [-scale, scale] 98 | return_feat: whether to return intermediate feature 99 | 100 | Outputs: 101 | sigmas: (N) 102 | """ 103 | x = (x-self.xyz_min)/(self.xyz_max-self.xyz_min) 104 | h = self.xyz_encoder(x) 105 | sigmas = TruncExp.apply(h[:, 0]) 106 | if return_feat: return sigmas, h 107 | return sigmas 108 | 109 | def log_radiance_to_rgb(self, 
log_radiances, **kwargs): 110 | """ 111 | Convert log-radiance to rgb as the setting in HDR-NeRF. 112 | Called only when self.rgb_act == 'None' (with exposure) 113 | 114 | Inputs: 115 | log_radiances: (N, 3) 116 | 117 | Outputs: 118 | rgbs: (N, 3) 119 | """ 120 | if 'exposure' in kwargs: 121 | log_exposure = torch.log(kwargs['exposure']) 122 | else: # unit exposure by default 123 | log_exposure = 0 124 | 125 | out = [] 126 | for i in range(3): 127 | inp = log_radiances[:, i:i+1]+log_exposure 128 | out += [getattr(self, f'tonemapper_net_{i}')(inp)] 129 | rgbs = torch.cat(out, 1) 130 | return rgbs 131 | 132 | def forward(self, x, d, **kwargs): 133 | """ 134 | Inputs: 135 | x: (N, 3) xyz in [-scale, scale] 136 | d: (N, 3) directions 137 | 138 | Outputs: 139 | sigmas: (N) 140 | rgbs: (N, 3) 141 | """ 142 | sigmas, h = self.density(x, return_feat=True) 143 | d = d/torch.norm(d, dim=1, keepdim=True) 144 | d = self.dir_encoder((d+1)/2) 145 | rgbs = self.rgb_net(torch.cat([d, h], 1)) 146 | 147 | if self.rgb_act == 'None': # rgbs is log-radiance 148 | if kwargs.get('output_radiance', False): # output HDR map 149 | rgbs = TruncExp.apply(rgbs) 150 | else: # convert to LDR using tonemapper networks 151 | rgbs = self.log_radiance_to_rgb(rgbs, **kwargs) 152 | 153 | return sigmas, rgbs 154 | 155 | @torch.no_grad() 156 | def get_all_cells(self): 157 | """ 158 | Get all cells from the density grid. 159 | 160 | Outputs: 161 | cells: list (of length self.cascades) of indices and coords 162 | selected at each cascade 163 | """ 164 | indices = vren.morton3D(self.grid_coords).long() 165 | cells = [(indices, self.grid_coords)] * self.cascades 166 | 167 | return cells 168 | 169 | @torch.no_grad() 170 | def sample_uniform_and_occupied_cells(self, M, density_threshold): 171 | """ 172 | Sample both M uniform and occupied cells (per cascade) 173 | occupied cells are sample from cells with density > @density_threshold 174 | 175 | Outputs: 176 | cells: list (of length self.cascades) of indices and coords 177 | selected at each cascade 178 | """ 179 | cells = [] 180 | for c in range(self.cascades): 181 | # uniform cells 182 | coords1 = torch.randint(self.grid_size, (M, 3), dtype=torch.int32, 183 | device=self.density_grid.device) 184 | indices1 = vren.morton3D(coords1).long() 185 | # occupied cells 186 | indices2 = torch.nonzero(self.density_grid[c]>density_threshold)[:, 0] 187 | if len(indices2)>0: 188 | rand_idx = torch.randint(len(indices2), (M,), 189 | device=self.density_grid.device) 190 | indices2 = indices2[rand_idx] 191 | coords2 = vren.morton3D_invert(indices2.int()) 192 | # concatenate 193 | cells += [(torch.cat([indices1, indices2]), torch.cat([coords1, coords2]))] 194 | 195 | return cells 196 | 197 | @torch.no_grad() 198 | def mark_invisible_cells(self, K, poses, img_wh, chunk=64**3): 199 | """ 200 | mark the cells that aren't covered by the cameras with density -1 201 | only executed once before training starts 202 | 203 | Inputs: 204 | K: (3, 3) camera intrinsics 205 | poses: (N, 3, 4) camera to world poses 206 | img_wh: image width and height 207 | chunk: the chunk size to split the cells (to avoid OOM) 208 | """ 209 | N_cams = poses.shape[0] 210 | self.count_grid = torch.zeros_like(self.density_grid) 211 | w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3) 212 | w2c_T = -w2c_R@poses[:, :3, 3:] # (N_cams, 3, 1) 213 | cells = self.get_all_cells() 214 | for c in range(self.cascades): 215 | indices, coords = cells[c] 216 | for i in range(0, len(indices), chunk): 217 | xyzs = 
coords[i:i+chunk]/(self.grid_size-1)*2-1 218 | s = min(2**(c-1), self.scale) 219 | half_grid_size = s/self.grid_size 220 | xyzs_w = (xyzs*(s-half_grid_size)).T # (3, chunk) 221 | xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk) 222 | uvd = K @ xyzs_c # (N_cams, 3, chunk) 223 | uv = uvd[:, :2]/uvd[:, 2:] # (N_cams, 2, chunk) 224 | in_image = (uvd[:, 2]>=0)& \ 225 | (uv[:, 0]>=0)&(uv[:, 0]=0)&(uv[:, 1]=NEAR_DISTANCE)&in_image # (N_cams, chunk) 228 | # if the cell is visible by at least one camera 229 | self.count_grid[c, indices[i:i+chunk]] = \ 230 | count = covered_by_cam.sum(0)/N_cams 231 | 232 | too_near_to_cam = (uvd[:, 2]0)&(~too_near_to_any_cam) 237 | self.density_grid[c, indices[i:i+chunk]] = \ 238 | torch.where(valid_mask, 0., -1.) 239 | 240 | @torch.no_grad() 241 | def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False): 242 | density_grid_tmp = torch.zeros_like(self.density_grid) 243 | if warmup: # during the first steps 244 | cells = self.get_all_cells() 245 | else: 246 | cells = self.sample_uniform_and_occupied_cells(self.grid_size**3//4, 247 | density_threshold) 248 | # infer sigmas 249 | for c in range(self.cascades): 250 | indices, coords = cells[c] 251 | s = min(2**(c-1), self.scale) 252 | half_grid_size = s/self.grid_size 253 | xyzs_w = (coords/(self.grid_size-1)*2-1)*(s-half_grid_size) 254 | # pick random position in the cell by adding noise in [-hgs, hgs] 255 | xyzs_w += (torch.rand_like(xyzs_w)*2-1) * half_grid_size 256 | density_grid_tmp[c, indices] = self.density(xyzs_w) 257 | 258 | if erode: 259 | # My own logic. decay more the cells that are visible to few cameras 260 | decay = torch.clamp(decay**(1/self.count_grid), 0.1, 0.95) 261 | self.density_grid = \ 262 | torch.where(self.density_grid<0, 263 | self.density_grid, 264 | torch.maximum(self.density_grid*decay, density_grid_tmp)) 265 | 266 | mean_density = self.density_grid[self.density_grid>0].mean().item() 267 | 268 | vren.packbits(self.density_grid, min(mean_density, density_threshold), 269 | self.density_bitfield) 270 | -------------------------------------------------------------------------------- /models/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .custom_functions import \ 3 | RayAABBIntersector, RayMarcher, VolumeRenderer 4 | from einops import rearrange 5 | import vren 6 | 7 | MAX_SAMPLES = 1024 8 | NEAR_DISTANCE = 0.01 9 | 10 | 11 | @torch.cuda.amp.autocast() 12 | def render(model, rays_o, rays_d, **kwargs): 13 | """ 14 | Render rays by 15 | 1. Compute the intersection of the rays with the scene bounding box 16 | 2. 
Follow the process in @render_func (different for train/test) 17 | 18 | Inputs: 19 | model: NGP 20 | rays_o: (N_rays, 3) ray origins 21 | rays_d: (N_rays, 3) ray directions 22 | 23 | Outputs: 24 | result: dictionary containing final rgb and depth 25 | """ 26 | rays_o = rays_o.contiguous(); rays_d = rays_d.contiguous() 27 | _, hits_t, _ = \ 28 | RayAABBIntersector.apply(rays_o, rays_d, model.center, model.half_size, 1) 29 | hits_t[(hits_t[:, 0, 0]>=0)&(hits_t[:, 0, 0] (n1 n2) c') 90 | dirs = rearrange(dirs, 'n1 n2 c -> (n1 n2) c') 91 | valid_mask = ~torch.all(dirs==0, dim=1) 92 | if valid_mask.sum()==0: break 93 | 94 | sigmas = torch.zeros(len(xyzs), device=device) 95 | rgbs = torch.zeros(len(xyzs), 3, device=device) 96 | sigmas[valid_mask], _rgbs = model(xyzs[valid_mask], dirs[valid_mask], **kwargs) 97 | rgbs[valid_mask] = _rgbs.float() 98 | sigmas = rearrange(sigmas, '(n1 n2) -> n1 n2', n2=N_samples) 99 | rgbs = rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=N_samples) 100 | 101 | vren.composite_test_fw( 102 | sigmas, rgbs, deltas, ts, 103 | hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4), 104 | N_eff_samples, opacity, depth, rgb) 105 | alive_indices = alive_indices[alive_indices>=0] # remove converged rays 106 | 107 | results['opacity'] = opacity 108 | results['depth'] = depth 109 | results['rgb'] = rgb 110 | results['total_samples'] = total_samples # total samples for all rays 111 | 112 | if exp_step_factor==0: # synthetic 113 | rgb_bg = torch.ones(3, device=device) 114 | else: # real 115 | rgb_bg = torch.zeros(3, device=device) 116 | results['rgb'] += rgb_bg*rearrange(1-opacity, 'n -> n 1') 117 | 118 | return results 119 | 120 | 121 | def __render_rays_train(model, rays_o, rays_d, hits_t, **kwargs): 122 | """ 123 | Render rays by 124 | 1. March the rays along their directions, querying @density_bitfield 125 | to skip empty space, and get the effective sample points (where 126 | there is object) 127 | 2. Infer the NN at these positions and view directions to get properties 128 | (currently sigmas and rgbs) 129 | 3. Use volume rendering to combine the result (front to back compositing 130 | and early stop the ray if its transmittance is below a threshold) 131 | """ 132 | exp_step_factor = kwargs.get('exp_step_factor', 0.) 
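    # exp_step_factor==0 keeps near-uniform marching steps (used for bounded synthetic
    # scenes); real/unbounded scenes pass 1/256 (see train.py and show_gui.py) so that
    # calc_dt in raymarching.cu scales the step with the current t, i.e. samples are
    # spaced more coarsely far from the camera.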
133 | results = {} 134 | 135 | (rays_a, xyzs, dirs, 136 | results['deltas'], results['ts'], results['rm_samples']) = \ 137 | RayMarcher.apply( 138 | rays_o, rays_d, hits_t[:, 0], model.density_bitfield, 139 | model.cascades, model.scale, 140 | exp_step_factor, model.grid_size, MAX_SAMPLES) 141 | 142 | for k, v in kwargs.items(): # supply additional inputs, repeated per ray 143 | if isinstance(v, torch.Tensor): 144 | kwargs[k] = torch.repeat_interleave(v[rays_a[:, 0]], rays_a[:, 2], 0) 145 | sigmas, rgbs = model(xyzs, dirs, **kwargs) 146 | 147 | (results['vr_samples'], results['opacity'], 148 | results['depth'], results['rgb'], results['ws']) = \ 149 | VolumeRenderer.apply(sigmas, rgbs.contiguous(), results['deltas'], results['ts'], 150 | rays_a, kwargs.get('T_threshold', 1e-4)) 151 | results['rays_a'] = rays_a 152 | 153 | if exp_step_factor==0: # synthetic 154 | rgb_bg = torch.ones(3, device=rays_o.device) 155 | else: # real 156 | if kwargs.get('random_bg', False): 157 | rgb_bg = torch.rand(3, device=rays_o.device) 158 | else: 159 | rgb_bg = torch.zeros(3, device=rays_o.device) 160 | results['rgb'] = results['rgb'] + \ 161 | rgb_bg*rearrange(1-results['opacity'], 'n -> n 1') 162 | 163 | return results 164 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # dataset parameters 7 | parser.add_argument('--root_dir', type=str, required=True, 8 | help='root directory of dataset') 9 | parser.add_argument('--dataset_name', type=str, default='nsvf', 10 | choices=['nerf', 'nsvf', 'colmap', 'nerfpp', 'rtmv'], 11 | help='which dataset to train/test') 12 | parser.add_argument('--split', type=str, default='train', 13 | choices=['train', 'trainval', 'trainvaltest'], 14 | help='use which split to train') 15 | parser.add_argument('--downsample', type=float, default=1.0, 16 | help='downsample factor (<=1.0) for the images') 17 | 18 | # model parameters 19 | parser.add_argument('--scale', type=float, default=0.5, 20 | help='scene scale (whole scene must lie in [-scale, scale]^3') 21 | parser.add_argument('--use_exposure', action='store_true', default=False, 22 | help='whether to train in HDR-NeRF setting') 23 | 24 | # loss parameters 25 | parser.add_argument('--distortion_loss_w', type=float, default=0, 26 | help='''weight of distortion loss (see losses.py), 27 | 0 to disable (default), to enable, 28 | a good value is 1e-3 for real scene and 1e-2 for synthetic scene 29 | ''') 30 | 31 | # training options 32 | parser.add_argument('--batch_size', type=int, default=8192, 33 | help='number of rays in a batch') 34 | parser.add_argument('--ray_sampling_strategy', type=str, default='all_images', 35 | choices=['all_images', 'same_image'], 36 | help=''' 37 | all_images: uniformly from all pixels of ALL images 38 | same_image: uniformly from all pixels of a SAME image 39 | ''') 40 | parser.add_argument('--num_epochs', type=int, default=30, 41 | help='number of training epochs') 42 | parser.add_argument('--num_gpus', type=int, default=1, 43 | help='number of gpus') 44 | parser.add_argument('--lr', type=float, default=1e-2, 45 | help='learning rate') 46 | # experimental training options 47 | parser.add_argument('--optimize_ext', action='store_true', default=False, 48 | help='whether to optimize extrinsics') 49 | parser.add_argument('--random_bg', action='store_true', default=False, 50 | help='''whether to 
train with random bg color (real scene only) 51 | to avoid objects with black color to be predicted as transparent 52 | ''') 53 | 54 | # validation options 55 | parser.add_argument('--eval_lpips', action='store_true', default=False, 56 | help='evaluate lpips metric (consumes more VRAM)') 57 | parser.add_argument('--val_only', action='store_true', default=False, 58 | help='run only validation (need to provide ckpt_path)') 59 | parser.add_argument('--no_save_test', action='store_true', default=False, 60 | help='whether to save test image and video') 61 | 62 | # misc 63 | parser.add_argument('--exp_name', type=str, default='exp', 64 | help='experiment name') 65 | parser.add_argument('--ckpt_path', type=str, default=None, 66 | help='pretrained checkpoint to load (including optimizers, etc)') 67 | parser.add_argument('--weight_path', type=str, default=None, 68 | help='pretrained checkpoint to load (excluding optimizers, etc)') 69 | 70 | return parser.parse_args() 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | kornia==0.6.5 3 | pytorch-lightning==1.7.7 4 | matplotlib==3.5.2 5 | opencv-python==4.6.0.66 6 | lpips 7 | imageio 8 | imageio-ffmpeg 9 | jupyter 10 | scipy 11 | pymcubes 12 | trimesh 13 | dearpygui -------------------------------------------------------------------------------- /show_gui.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from opt import get_opts 3 | import numpy as np 4 | from einops import rearrange 5 | import dearpygui.dearpygui as dpg 6 | from scipy.spatial.transform import Rotation as R 7 | import time 8 | 9 | from datasets import dataset_dict 10 | from datasets.ray_utils import get_ray_directions, get_rays 11 | from models.networks import NGP 12 | from models.rendering import render 13 | from train import depth2img 14 | from utils import load_ckpt 15 | 16 | import warnings; warnings.filterwarnings("ignore") 17 | 18 | 19 | class OrbitCamera: 20 | def __init__(self, K, img_wh, r): 21 | self.K = K 22 | self.W, self.H = img_wh 23 | self.radius = r 24 | self.center = np.zeros(3) 25 | self.rot = np.eye(3) 26 | 27 | @property 28 | def pose(self): 29 | # first move camera to radius 30 | res = np.eye(4) 31 | res[2, 3] -= self.radius 32 | # rotate 33 | rot = np.eye(4) 34 | rot[:3, :3] = self.rot 35 | res = rot @ res 36 | # translate 37 | res[:3, 3] -= self.center 38 | return res 39 | 40 | def orbit(self, dx, dy): 41 | rotvec_x = self.rot[:, 1] * np.radians(0.05 * dx) 42 | rotvec_y = self.rot[:, 0] * np.radians(-0.05 * dy) 43 | self.rot = R.from_rotvec(rotvec_y).as_matrix() @ \ 44 | R.from_rotvec(rotvec_x).as_matrix() @ \ 45 | self.rot 46 | 47 | def scale(self, delta): 48 | self.radius *= 1.1 ** (-delta) 49 | 50 | def pan(self, dx, dy, dz=0): 51 | self.center += 1e-4 * self.rot @ np.array([dx, dy, dz]) 52 | 53 | 54 | class NGPGUI: 55 | def __init__(self, hparams, K, img_wh, radius=2.5): 56 | self.hparams = hparams 57 | rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid' 58 | self.model = NGP(scale=hparams.scale, rgb_act=rgb_act).cuda() 59 | load_ckpt(self.model, hparams.ckpt_path) 60 | 61 | self.cam = OrbitCamera(K, img_wh, r=radius) 62 | self.W, self.H = img_wh 63 | self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32) 64 | 65 | # placeholders 66 | self.dt = 0 67 | self.mean_samples = 0 68 | self.img_mode = 0 69 | 70 | self.register_dpg() 71 
| 72 | def render_cam(self, cam): 73 | t = time.time() 74 | directions = get_ray_directions(cam.H, cam.W, cam.K, device='cuda') 75 | rays_o, rays_d = get_rays(directions, torch.cuda.FloatTensor(cam.pose)) 76 | 77 | # TODO: set these attributes by gui 78 | if self.hparams.dataset_name in ['colmap', 'nerfpp']: 79 | exp_step_factor = 1/256 80 | else: exp_step_factor = 0 81 | 82 | results = render(self.model, rays_o, rays_d, 83 | **{'test_time': True, 84 | 'to_cpu': True, 'to_numpy': True, 85 | 'T_threshold': 1e-2, 86 | 'exposure': torch.cuda.FloatTensor([dpg.get_value('_exposure')]), 87 | 'max_samples': 100, 88 | 'exp_step_factor': exp_step_factor}) 89 | 90 | rgb = rearrange(results["rgb"], "(h w) c -> h w c", h=self.H) 91 | depth = rearrange(results["depth"], "(h w) -> h w", h=self.H) 92 | torch.cuda.synchronize() 93 | self.dt = time.time()-t 94 | self.mean_samples = results['total_samples']/len(rays_o) 95 | 96 | if self.img_mode == 0: 97 | return rgb 98 | elif self.img_mode == 1: 99 | return depth2img(depth).astype(np.float32)/255.0 100 | 101 | def register_dpg(self): 102 | dpg.create_context() 103 | dpg.create_viewport(title="ngp_pl", width=self.W, height=self.H, resizable=False) 104 | 105 | ## register texture ## 106 | with dpg.texture_registry(show=False): 107 | dpg.add_raw_texture( 108 | self.W, 109 | self.H, 110 | self.render_buffer, 111 | format=dpg.mvFormat_Float_rgb, 112 | tag="_texture") 113 | 114 | ## register window ## 115 | with dpg.window(tag="_primary_window", width=self.W, height=self.H): 116 | dpg.add_image("_texture") 117 | dpg.set_primary_window("_primary_window", True) 118 | 119 | def callback_depth(sender, app_data): 120 | self.img_mode = 1-self.img_mode 121 | 122 | ## control window ## 123 | with dpg.window(label="Control", tag="_control_window", width=200, height=150): 124 | dpg.add_slider_float(label="exposure", default_value=0.2, 125 | min_value=1/60, max_value=32, tag="_exposure") 126 | dpg.add_button(label="show depth", tag="_button_depth", 127 | callback=callback_depth) 128 | dpg.add_separator() 129 | dpg.add_text('no data', tag="_log_time") 130 | dpg.add_text('no data', tag="_samples_per_ray") 131 | 132 | ## register camera handler ## 133 | def callback_camera_drag_rotate(sender, app_data): 134 | if not dpg.is_item_focused("_primary_window"): 135 | return 136 | self.cam.orbit(app_data[1], app_data[2]) 137 | 138 | def callback_camera_wheel_scale(sender, app_data): 139 | if not dpg.is_item_focused("_primary_window"): 140 | return 141 | self.cam.scale(app_data) 142 | 143 | def callback_camera_drag_pan(sender, app_data): 144 | if not dpg.is_item_focused("_primary_window"): 145 | return 146 | self.cam.pan(app_data[1], app_data[2]) 147 | 148 | with dpg.handler_registry(): 149 | dpg.add_mouse_drag_handler( 150 | button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate 151 | ) 152 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 153 | dpg.add_mouse_drag_handler( 154 | button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan 155 | ) 156 | 157 | ## Avoid scroll bar in the window ## 158 | with dpg.theme() as theme_no_padding: 159 | with dpg.theme_component(dpg.mvAll): 160 | dpg.add_theme_style( 161 | dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core 162 | ) 163 | dpg.add_theme_style( 164 | dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core 165 | ) 166 | dpg.add_theme_style( 167 | dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core 168 | ) 169 | dpg.bind_item_theme("_primary_window", 
theme_no_padding) 170 | 171 | ## Launch the gui ## 172 | dpg.setup_dearpygui() 173 | dpg.set_viewport_small_icon("assets/icon.png") 174 | dpg.set_viewport_large_icon("assets/icon.png") 175 | dpg.show_viewport() 176 | 177 | def render(self): 178 | while dpg.is_dearpygui_running(): 179 | dpg.set_value("_texture", self.render_cam(self.cam)) 180 | dpg.set_value("_log_time", f'Render time: {1000*self.dt:.2f} ms') 181 | dpg.set_value("_samples_per_ray", f'Samples/ray: {self.mean_samples:.2f}') 182 | dpg.render_dearpygui_frame() 183 | 184 | 185 | if __name__ == "__main__": 186 | hparams = get_opts() 187 | kwargs = {'root_dir': hparams.root_dir, 188 | 'downsample': hparams.downsample, 189 | 'read_meta': False} 190 | dataset = dataset_dict[hparams.dataset_name](**kwargs) 191 | 192 | NGPGUI(hparams, dataset.K, dataset.img_wh).render() 193 | dpg.destroy_context() 194 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from opt import get_opts 4 | import os 5 | import glob 6 | import imageio 7 | import numpy as np 8 | import cv2 9 | from einops import rearrange 10 | 11 | # data 12 | from torch.utils.data import DataLoader 13 | from datasets import dataset_dict 14 | from datasets.ray_utils import axisangle_to_R, get_rays 15 | 16 | # models 17 | from kornia.utils.grid import create_meshgrid3d 18 | from models.networks import NGP 19 | from models.rendering import render, MAX_SAMPLES 20 | 21 | # optimizer, losses 22 | from apex.optimizers import FusedAdam 23 | from torch.optim.lr_scheduler import CosineAnnealingLR 24 | from losses import NeRFLoss 25 | 26 | # metrics 27 | from torchmetrics import ( 28 | PeakSignalNoiseRatio, 29 | StructuralSimilarityIndexMeasure 30 | ) 31 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 32 | 33 | # pytorch-lightning 34 | from pytorch_lightning.plugins import DDPPlugin 35 | from pytorch_lightning import LightningModule, Trainer 36 | from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint 37 | from pytorch_lightning.loggers import TensorBoardLogger 38 | from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available 39 | 40 | from utils import slim_ckpt, load_ckpt 41 | 42 | import warnings; warnings.filterwarnings("ignore") 43 | 44 | 45 | def depth2img(depth): 46 | depth = (depth-depth.min())/(depth.max()-depth.min()) 47 | depth_img = cv2.applyColorMap((depth*255).astype(np.uint8), 48 | cv2.COLORMAP_TURBO) 49 | 50 | return depth_img 51 | 52 | 53 | class NeRFSystem(LightningModule): 54 | def __init__(self, hparams): 55 | super().__init__() 56 | self.save_hyperparameters(hparams) 57 | 58 | self.warmup_steps = 256 59 | self.update_interval = 16 60 | 61 | self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w) 62 | self.train_psnr = PeakSignalNoiseRatio(data_range=1) 63 | self.val_psnr = PeakSignalNoiseRatio(data_range=1) 64 | self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1) 65 | if self.hparams.eval_lpips: 66 | self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg') 67 | for p in self.val_lpips.net.parameters(): 68 | p.requires_grad = False 69 | 70 | rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid' 71 | self.model = NGP(scale=self.hparams.scale, rgb_act=rgb_act) 72 | G = self.model.grid_size 73 | self.model.register_buffer('density_grid', 74 | torch.zeros(self.model.cascades, G**3)) 75 | 
self.model.register_buffer('grid_coords', 76 | create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3)) 77 | 78 | def forward(self, batch, split): 79 | if split=='train': 80 | poses = self.poses[batch['img_idxs']] 81 | directions = self.directions[batch['pix_idxs']] 82 | else: 83 | poses = batch['pose'] 84 | directions = self.directions 85 | 86 | if self.hparams.optimize_ext: 87 | dR = axisangle_to_R(self.dR[batch['img_idxs']]) 88 | poses[..., :3] = dR @ poses[..., :3] 89 | poses[..., 3] += self.dT[batch['img_idxs']] 90 | 91 | rays_o, rays_d = get_rays(directions, poses) 92 | 93 | kwargs = {'test_time': split!='train', 94 | 'random_bg': self.hparams.random_bg} 95 | if self.hparams.scale > 0.5: 96 | kwargs['exp_step_factor'] = 1/256 97 | if self.hparams.use_exposure: 98 | kwargs['exposure'] = batch['exposure'] 99 | 100 | return render(self.model, rays_o, rays_d, **kwargs) 101 | 102 | def setup(self, stage): 103 | dataset = dataset_dict[self.hparams.dataset_name] 104 | kwargs = {'root_dir': self.hparams.root_dir, 105 | 'downsample': self.hparams.downsample} 106 | self.train_dataset = dataset(split=self.hparams.split, **kwargs) 107 | self.train_dataset.batch_size = self.hparams.batch_size 108 | self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy 109 | 110 | self.test_dataset = dataset(split='test', **kwargs) 111 | 112 | def configure_optimizers(self): 113 | # define additional parameters 114 | self.register_buffer('directions', self.train_dataset.directions.to(self.device)) 115 | self.register_buffer('poses', self.train_dataset.poses.to(self.device)) 116 | 117 | if self.hparams.optimize_ext: 118 | N = len(self.train_dataset.poses) 119 | self.register_parameter('dR', 120 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 121 | self.register_parameter('dT', 122 | nn.Parameter(torch.zeros(N, 3, device=self.device))) 123 | 124 | load_ckpt(self.model, self.hparams.weight_path) 125 | 126 | net_params = [] 127 | for n, p in self.named_parameters(): 128 | if n not in ['dR', 'dT']: net_params += [p] 129 | 130 | opts = [] 131 | self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15) 132 | opts += [self.net_opt] 133 | if self.hparams.optimize_ext: 134 | opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded 135 | net_sch = CosineAnnealingLR(self.net_opt, 136 | self.hparams.num_epochs, 137 | self.hparams.lr/30) 138 | 139 | return opts, [net_sch] 140 | 141 | def train_dataloader(self): 142 | return DataLoader(self.train_dataset, 143 | num_workers=16, 144 | persistent_workers=True, 145 | batch_size=None, 146 | pin_memory=True) 147 | 148 | def val_dataloader(self): 149 | return DataLoader(self.test_dataset, 150 | num_workers=8, 151 | batch_size=None, 152 | pin_memory=True) 153 | 154 | def on_train_start(self): 155 | self.model.mark_invisible_cells(self.train_dataset.K.to(self.device), 156 | self.poses, 157 | self.train_dataset.img_wh) 158 | 159 | def training_step(self, batch, batch_nb, *args): 160 | if self.global_step%self.update_interval == 0: 161 | self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5, 162 | warmup=self.global_step 1 c h w', h=h) 205 | rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h) 206 | self.val_ssim(rgb_pred, rgb_gt) 207 | logs['ssim'] = self.val_ssim.compute() 208 | self.val_ssim.reset() 209 | if self.hparams.eval_lpips: 210 | self.val_lpips(torch.clip(rgb_pred*2-1, -1, 1), 211 | torch.clip(rgb_gt*2-1, -1, 1)) 212 | logs['lpips'] = self.val_lpips.compute() 213 | self.val_lpips.reset() 214 | 215 | if not 
self.hparams.no_save_test: # save test image to disk 216 | idx = batch['img_idxs'] 217 | rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h) 218 | rgb_pred = (rgb_pred*255).astype(np.uint8) 219 | depth = depth2img(rearrange(results['depth'].cpu().numpy(), '(h w) -> h w', h=h)) 220 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred) 221 | imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}_d.png'), depth) 222 | 223 | return logs 224 | 225 | def validation_epoch_end(self, outputs): 226 | psnrs = torch.stack([x['psnr'] for x in outputs]) 227 | mean_psnr = all_gather_ddp_if_available(psnrs).mean() 228 | self.log('test/psnr', mean_psnr, True) 229 | 230 | ssims = torch.stack([x['ssim'] for x in outputs]) 231 | mean_ssim = all_gather_ddp_if_available(ssims).mean() 232 | self.log('test/ssim', mean_ssim) 233 | 234 | if self.hparams.eval_lpips: 235 | lpipss = torch.stack([x['lpips'] for x in outputs]) 236 | mean_lpips = all_gather_ddp_if_available(lpipss).mean() 237 | self.log('test/lpips_vgg', mean_lpips) 238 | 239 | def get_progress_bar_dict(self): 240 | # don't show the version number 241 | items = super().get_progress_bar_dict() 242 | items.pop("v_num", None) 243 | return items 244 | 245 | 246 | if __name__ == '__main__': 247 | hparams = get_opts() 248 | if hparams.val_only and (not hparams.ckpt_path): 249 | raise ValueError('You need to provide a @ckpt_path for validation!') 250 | system = NeRFSystem(hparams) 251 | 252 | ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.dataset_name}/{hparams.exp_name}', 253 | filename='{epoch:d}', 254 | save_weights_only=True, 255 | every_n_epochs=hparams.num_epochs, 256 | save_on_train_epoch_end=True, 257 | save_top_k=-1) 258 | callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)] 259 | 260 | logger = TensorBoardLogger(save_dir=f"logs/{hparams.dataset_name}", 261 | name=hparams.exp_name, 262 | default_hp_metric=False) 263 | 264 | trainer = Trainer(max_epochs=hparams.num_epochs, 265 | check_val_every_n_epoch=hparams.num_epochs, 266 | callbacks=callbacks, 267 | logger=logger, 268 | enable_model_summary=False, 269 | accelerator='gpu', 270 | devices=hparams.num_gpus, 271 | strategy=DDPPlugin(find_unused_parameters=False) 272 | if hparams.num_gpus>1 else None, 273 | num_sanity_val_steps=-1 if hparams.val_only else 0, 274 | precision=16) 275 | 276 | trainer.fit(system, ckpt_path=hparams.ckpt_path) 277 | 278 | if not hparams.val_only: # save slimmed ckpt for the last epoch 279 | ckpt_ = \ 280 | slim_ckpt(f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/epoch={hparams.num_epochs-1}.ckpt', 281 | save_poses=hparams.optimize_ext) 282 | torch.save(ckpt_, f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/epoch={hparams.num_epochs-1}_slim.ckpt') 283 | 284 | if (not hparams.no_save_test) and \ 285 | hparams.dataset_name=='nsvf' and \ 286 | 'Synthetic' in hparams.root_dir: # save video 287 | imgs = sorted(glob.glob(os.path.join(system.val_dir, '*.png'))) 288 | imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'), 289 | [imageio.imread(img) for img in imgs[::2]], 290 | fps=30, macro_block_size=1) 291 | imageio.mimsave(os.path.join(system.val_dir, 'depth.mp4'), 292 | [imageio.imread(img) for img in imgs[1::2]], 293 | fps=30, macro_block_size=1) 294 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_model_state_dict(ckpt_path, model_name='model', 
prefixes_to_ignore=[]): 5 | checkpoint = torch.load(ckpt_path, map_location='cpu') 6 | checkpoint_ = {} 7 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 8 | checkpoint = checkpoint['state_dict'] 9 | for k, v in checkpoint.items(): 10 | if not k.startswith(model_name): 11 | continue 12 | k = k[len(model_name)+1:] 13 | for prefix in prefixes_to_ignore: 14 | if k.startswith(prefix): 15 | break 16 | else: 17 | checkpoint_[k] = v 18 | return checkpoint_ 19 | 20 | 21 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 22 | if not ckpt_path: return 23 | model_dict = model.state_dict() 24 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 25 | model_dict.update(checkpoint_) 26 | model.load_state_dict(model_dict) 27 | 28 | 29 | def slim_ckpt(ckpt_path, save_poses=False): 30 | ckpt = torch.load(ckpt_path, map_location='cpu') 31 | # pop unused parameters 32 | keys_to_pop = ['directions', 'model.density_grid', 'model.grid_coords'] 33 | if not save_poses: keys_to_pop += ['poses'] 34 | for k in ckpt['state_dict']: 35 | if k.startswith('val_lpips'): 36 | keys_to_pop += [k] 37 | for k in keys_to_pop: 38 | ckpt['state_dict'].pop(k, None) 39 | return ckpt['state_dict'] 40 | --------------------------------------------------------------------------------
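For reference, a minimal sketch (not taken from the repository) of how the checkpoint helpers in utils.py are combined in practice, mirroring their usage in train.py and show_gui.py; the checkpoint path below is a placeholder and the NGP arguments are the defaults from opt.py:

import torch
from models.networks import NGP
from utils import load_ckpt, slim_ckpt

ckpt_path = 'ckpts/nsvf/exp/epoch=29.ckpt'  # hypothetical full Lightning checkpoint

# 1. drop training-only entries (density grid, grid coords, directions, LPIPS weights,
#    and the poses unless save_poses=True) so that only the network weights remain
slim_state_dict = slim_ckpt(ckpt_path, save_poses=False)
torch.save(slim_state_dict, ckpt_path.replace('.ckpt', '_slim.ckpt'))

# 2. rebuild the model with the same scale/rgb_act used for training and restore its
#    weights; load_ckpt strips the 'model.' prefix of the Lightning keys internally
model = NGP(scale=0.5, rgb_act='Sigmoid').cuda()
load_ckpt(model, ckpt_path.replace('.ckpt', '_slim.ckpt'))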