├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── configs ├── nerf-blender.yaml ├── nerf-colmap.yaml ├── neuralangelo-dtu-wmask.yaml ├── neus-blender.yaml ├── neus-colmap.yaml ├── neus-dtu-wmask.yaml └── neus-dtu.yaml ├── datasets ├── __init__.py ├── blender.py ├── colmap.py ├── colmap_utils.py ├── dtu.py └── utils.py ├── launch.py ├── models ├── __init__.py ├── base.py ├── geometry.py ├── nerf.py ├── network_utils.py ├── neus.py ├── ray_utils.py ├── texture.py └── utils.py ├── requirements.txt ├── scripts └── imgs2poses.py ├── systems ├── __init__.py ├── base.py ├── criterions.py ├── nerf.py ├── neus.py └── utils.py └── utils ├── __init__.py ├── callbacks.py ├── loggers.py ├── misc.py ├── mixins.py └── obj.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | # End of https://www.toptal.com/developers/gitignore/api/python 146 | 147 | .DS_Store 148 | .vscode/ 149 | exp/ 150 | runs/ 151 | load/ 152 | extern/ 153 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | disable=R,C 2 | 3 | [TYPECHECK] 4 | # List of members which are set dynamically and missed by pylint inference 5 | # system, and so shouldn't trigger E1101 when accessed. Python regular 6 | # expressions are accepted. 7 | generated-members=numpy.*,torch.*,cv2.* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yuanchen Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Instant Neural Surface Reconstruction 2 | 3 | This repository contains a concise and extensible implementation of NeRF and NeuS for neural surface reconstruction based on Instant-NGP and the Pytorch-Lightning framework. 
**Training on a NeRF-Synthetic scene takes ~5min for NeRF and ~10min for NeuS on a single RTX3090.** 4 | 5 | ||NeRF in 5min|NeuS in 10 min| 6 | |---|---|---| 7 | |Rendering|![rendering-nerf](https://user-images.githubusercontent.com/19284678/199078178-b719676b-7e60-47f1-813b-c0b533f5480d.png)|![rendering-neus](https://user-images.githubusercontent.com/19284678/199078300-ebcf249d-b05e-431f-b035-da354705d8db.png)| 8 | |Mesh|![mesh-nerf](https://user-images.githubusercontent.com/19284678/199078661-b5cd569a-c22b-4220-9c11-d5fd13a52fb8.png)|![mesh-neus](https://user-images.githubusercontent.com/19284678/199078481-164e36a6-6d55-45cc-aaf3-795a114e4a38.png)| 9 | 10 | 11 | ## Features 12 | **This repository aims to provide a highly efficient yet customizable boilerplate for research projects based on NeRF or NeuS.** 13 | 14 | - acceleration techniques from [Instant-NGP](https://github.com/NVlabs/instant-ngp): multiresolution hash encoding and fully fused networks by [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), occupancy grid pruning and rendering by [nerfacc](https://github.com/KAIR-BAIR/nerfacc) 15 | - out-of-the-box multi-GPU and mixed precision training by [PyTorch-Lightning](https://github.com/Lightning-AI/lightning) 16 | - hierarchical project layout that is designed to be easily customized and extended, with flexible experiment configuration by [OmegaConf](https://github.com/omry/omegaconf) 17 | 18 | **Please subscribe to [#26](https://github.com/bennyguo/instant-nsr-pl/issues/26) for our latest findings on quality improvements!** 19 | 20 | ## News 21 | 22 | 🔥🔥🔥 Check out my new project on 3D content generation: https://github.com/threestudio-project/threestudio 🔥🔥🔥 23 | 24 | - 06/03/2023: Added an implementation of [Neuralangelo](https://research.nvidia.com/labs/dir/neuralangelo/). See [here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU) for details. 25 | - 03/31/2023: The NeuS model now supports background modeling. You can try it on the DTU dataset provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa) following [the instructions here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU). 26 | - 02/11/2023: The NeRF model now supports unbounded 360 scenes with a learned background. You can try it on [MipNeRF 360 data](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) following [the COLMAP configuration](https://github.com/bennyguo/instant-nsr-pl#training-on-custom-colmap-data). 27 | 28 | ## Requirements 29 | **Note:** 30 | - To utilize the multiresolution hash encoding or fully fused networks provided by tiny-cuda-nn, you should have at least an RTX 2080Ti; see [https://github.com/NVlabs/tiny-cuda-nn#requirements](https://github.com/NVlabs/tiny-cuda-nn#requirements) for more details (a quick capability check is sketched below). 31 | - Multi-GPU training is currently not supported on Windows (see [#4](https://github.com/bennyguo/instant-nsr-pl/issues/4)).
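
For a quick sanity check of the first requirement, the following standalone snippet (not part of this codebase, just a hedged sketch) verifies the GPU's compute capability — an RTX 2080Ti corresponds to compute capability 7.5 — and that the tiny-cuda-nn bindings import:

```python
# Standalone sanity check: confirm the GPU's compute capability and that the
# tiny-cuda-nn PyTorch bindings import correctly before launching training.
import torch

assert torch.cuda.is_available(), "CUDA is not available"
major, minor = torch.cuda.get_device_capability(0)
print(f"GPU: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
if (major, minor) < (7, 5):  # RTX 2080Ti-class or newer is recommended above
    print("Warning: fully fused MLPs from tiny-cuda-nn may not work on this GPU.")

try:
    import tinycudann as tcnn  # installed in the Environments step below
    print("tiny-cuda-nn bindings available.")
except ImportError:
    print("tiny-cuda-nn PyTorch bindings are not installed yet.")
```
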
32 | ### Environments 33 | - Install PyTorch>=1.10 [here](https://pytorch.org/get-started/locally/) based on the package management tool you use and your CUDA version (older PyTorch versions may work but have not been tested) 34 | - Install the tiny-cuda-nn PyTorch extension: `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch` 35 | - `pip install -r requirements.txt` 36 | 37 | 38 | ## Run 39 | ### Training on NeRF-Synthetic 40 | Download the NeRF-Synthetic data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `load/`. The file structure should be like `load/nerf_synthetic/lego`. 41 | 42 | Run the launch script with `--train`, specifying the config file, the GPU(s) to be used (GPU 0 will be used by default), and the scene name: 43 | ```bash 44 | # train NeRF 45 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=example 46 | 47 | # train NeuS with mask 48 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example 49 | # train NeuS without mask 50 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example system.loss.lambda_mask=0.0 51 | ``` 52 | The code snapshots, checkpoints and experiment outputs are saved to `exp/[name]/[tag]@[timestamp]`, and TensorBoard logs can be found at `runs/[name]/[tag]@[timestamp]`. You can change any configuration in the YAML file by specifying arguments without `--`, for example: 53 | ```bash 54 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=iter50k seed=0 trainer.max_steps=50000 55 | ``` 56 | ### Training on DTU 57 | Download the preprocessed DTU data provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa). The provided config files assume the NeuS DTU data is used. If you are using the IDR DTU data, please set `dataset.cameras_file=cameras.npz`. You may also need to adjust `dataset.root_dir` to point to your downloaded data location. 58 | ```bash 59 | # train NeuS on DTU without mask 60 | python launch.py --config configs/neus-dtu.yaml --gpu 0 --train 61 | # train NeuS on DTU with mask 62 | python launch.py --config configs/neus-dtu-wmask.yaml --gpu 0 --train 63 | # train NeuS on DTU with mask using tricks from Neuralangelo (experimental) 64 | python launch.py --config configs/neuralangelo-dtu-wmask.yaml --gpu 0 --train 65 | ``` 66 | Notes: 67 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing. 68 | - The Neuralangelo results do not yet match those reported in the original paper. Some potential improvements: more iterations; a larger `system.geometry.xyz_encoding_config.update_steps`; a larger `system.geometry.xyz_encoding_config.n_features_per_level`; a larger `system.geometry.xyz_encoding_config.log2_hashmap_size`; adopting the curvature loss. 69 | 70 | ### Training on Custom COLMAP Data 71 | To get COLMAP data from custom images, you should have COLMAP installed (see [here](https://colmap.github.io/install.html) for installation instructions). Then put your images in the `images/` folder, and run `scripts/imgs2poses.py`, specifying the path that contains the `images/` folder.
For example: 72 | ```bash 73 | python scripts/imgs2poses.py ./load/bmvs_dog # images are in ./load/bmvs_dog/images 74 | ``` 75 | Existing data following this file structure also works as long as images are stored in `images/` and there is a `sparse/` folder for the COLMAP output, for example [the data provided by MipNeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip). An optional `masks/` folder can be provided for object mask supervision. To train on COLMAP data, please refer to the example config files `configs/*-colmap.yaml`. Some notes: 76 | - Adapt the `root_dir` and `img_wh` (or `img_downscale`) options in the config file to your data. 77 | - The scene is normalized so that cameras have a minimum distance of `1.0` from the center of the scene. Setting `model.radius=1.0` works in most cases. If not, try setting a smaller radius that wraps tightly around your foreground object. 78 | - There are three choices to determine the scene center: `dataset.center_est_method=camera` uses the center of all camera positions as the scene center; `dataset.center_est_method=lookat` assumes the cameras are looking at the same point and calculates an approximate look-at point as the scene center; `dataset.center_est_method=point` uses the center of all points (reconstructed by COLMAP) that are bounded by cameras as the scene center. Please choose an appropriate method according to your capture. 79 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing. 80 | 81 | ### Testing 82 | The training procedure is by default followed by testing, which computes metrics on the test data, generates animations, and exports the geometry as triangular meshes. If you want to run testing alone, simply resume from the trained checkpoint and replace `--train` with `--test`, for example: 83 | ```bash 84 | python launch.py --config path/to/your/exp/config/parsed.yaml --resume path/to/your/exp/ckpt/epoch=0-step=20000.ckpt --gpu 0 --test 85 | ``` 86 | 87 | 88 | ## Benchmarks 89 | All experiments are conducted on a single NVIDIA RTX3090. 90 | 91 | |PSNR|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.| 92 | |---|---|---|---|---|---|---|---|---|---| 93 | |NeRF Paper|33.00|25.01|30.13|36.18|32.54|29.62|32.91|28.65|31.01| 94 | |NeRF Ours (20k)|34.80|26.04|33.89|37.42|35.33|29.46|35.22|31.17|32.92| 95 | |NeuS Ours (20k, with masks)|34.04|25.26|32.47|35.94|33.78|27.67|33.43|29.50|31.51| 96 | 97 | |Training Time (mm:ss)|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.| 98 | |---|---|---|---|---|---|---|---|---|---| 99 | |NeRF Ours (20k)|04:34|04:35|04:18|04:46|04:39|04:35|04:26|05:41|04:42| 100 | |NeuS Ours (20k, with masks)|11:25|10:34|09:51|12:11|11:37|11:46|09:59|16:25|11:44| 101 | 102 | 103 | ## TODO 104 | - [✅] Support more dataset formats, like COLMAP outputs and DTU 105 | - [✅] Support simple background model 106 | - [ ] Support GUI training and interaction 107 | - [ ] More illustrations about the framework 108 | 109 | ## Related Projects 110 | - [ngp_pl](https://github.com/kwea123/ngp_pl): Great Instant-NGP implementation in PyTorch-Lightning! Background model and GUI supported. 111 | - [Instant-NSR](https://github.com/zhaofuq/Instant-NSR): NeuS implementation using multiresolution hash encoding.
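
As a side note on extensibility: new datasets are picked up through a simple name registry (see `datasets/__init__.py`). The sketch below only illustrates that mechanism — the file name, class names and the `mydata` key are hypothetical, and the actual data loading is stubbed out:

```python
# datasets/mydata.py -- hypothetical skeleton mirroring datasets/blender.py
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

import datasets  # the package exposing the @datasets.register decorator


class MyDataset(Dataset):
    def __init__(self, config, split):
        self.config = config
        self.split = split
        # load images, camera poses, masks, ... here

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return {'index': index}


@datasets.register('mydata')  # selectable from a config via `dataset.name: mydata`
class MyDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, 'fit']:
            self.train_dataset = MyDataset(self.config, 'train')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=1)
```
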
112 | 113 | ## Citation 114 | If you find this codebase useful, please consider citing: 115 | ``` 116 | @misc{instant-nsr-pl, 117 | Author = {Yuan-Chen Guo}, 118 | Year = {2022}, 119 | Note = {https://github.com/bennyguo/instant-nsr-pl}, 120 | Title = {Instant Neural Surface Reconstruction} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /configs/nerf-blender.yaml: -------------------------------------------------------------------------------- 1 | name: nerf-blender-${dataset.scene} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: blender 7 | scene: ??? 8 | root_dir: ./load/nerf_synthetic/${dataset.scene} 9 | img_wh: 10 | - 800 11 | - 800 12 | # img_downscale: 1 # specify training image size by either img_wh or img_downscale 13 | near_plane: 2.0 14 | far_plane: 6.0 15 | train_split: "train" 16 | val_split: "val" 17 | test_split: "test" 18 | 19 | model: 20 | name: nerf 21 | radius: 1.5 22 | num_samples_per_ray: 1024 23 | train_num_rays: 256 24 | max_train_num_rays: 8192 25 | grid_prune: true 26 | dynamic_ray_sampling: true 27 | batch_image_sampling: true 28 | randomized: true 29 | ray_chunk: 32768 30 | learned_background: false 31 | background_color: random 32 | geometry: 33 | name: volume-density 34 | radius: ${model.radius} 35 | feature_dim: 16 36 | density_activation: trunc_exp 37 | density_bias: -1 38 | isosurface: 39 | method: mc 40 | resolution: 256 41 | chunk: 2097152 42 | threshold: 5.0 43 | xyz_encoding_config: 44 | otype: HashGrid 45 | n_levels: 16 46 | n_features_per_level: 2 47 | log2_hashmap_size: 19 48 | base_resolution: 16 49 | per_level_scale: 1.447269237440378 50 | mlp_network_config: 51 | otype: FullyFusedMLP 52 | activation: ReLU 53 | output_activation: none 54 | n_neurons: 64 55 | n_hidden_layers: 1 56 | texture: 57 | name: volume-radiance 58 | input_feature_dim: ${model.geometry.feature_dim} 59 | dir_encoding_config: 60 | otype: SphericalHarmonics 61 | degree: 4 62 | mlp_network_config: 63 | otype: FullyFusedMLP 64 | activation: ReLU 65 | output_activation: Sigmoid 66 | n_neurons: 64 67 | n_hidden_layers: 2 68 | 69 | system: 70 | name: nerf-system 71 | loss: 72 | lambda_rgb: 1. 73 | lambda_distortion: 0. 
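    # distortion regularization is disabled for this bounded Blender scene;
    # the unbounded COLMAP config (configs/nerf-colmap.yaml) uses lambda_distortion: 0.001 instead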
74 | optimizer: 75 | name: AdamW 76 | args: 77 | lr: 0.01 78 | betas: [0.9, 0.99] 79 | eps: 1.e-15 80 | scheduler: 81 | name: MultiStepLR 82 | interval: step 83 | args: 84 | milestones: [10000, 15000, 18000] 85 | gamma: 0.33 86 | 87 | checkpoint: 88 | save_top_k: -1 89 | every_n_train_steps: ${trainer.max_steps} 90 | 91 | export: 92 | chunk_size: 2097152 93 | export_vertex_color: False 94 | 95 | trainer: 96 | max_steps: 20000 97 | log_every_n_steps: 200 98 | num_sanity_val_steps: 0 99 | val_check_interval: 10000 100 | limit_train_batches: 1.0 101 | limit_val_batches: 2 102 | enable_progress_bar: true 103 | precision: 16 104 | -------------------------------------------------------------------------------- /configs/nerf-colmap.yaml: -------------------------------------------------------------------------------- 1 | name: nerf-colmap-${basename:${dataset.root_dir}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: colmap 7 | root_dir: ./load/unbounded360/garden 8 | img_downscale: 4 # specify training image size by either img_wh or img_downscale 9 | up_est_method: ground # if true, use estimated ground plane normal direction as up direction 10 | center_est_method: lookat 11 | n_test_traj_steps: 120 12 | apply_mask: false 13 | load_data_on_gpu: false 14 | 15 | model: 16 | name: nerf 17 | radius: 1.0 18 | num_samples_per_ray: 2048 19 | train_num_rays: 128 20 | max_train_num_rays: 8192 21 | grid_prune: true 22 | dynamic_ray_sampling: true 23 | batch_image_sampling: true 24 | randomized: true 25 | ray_chunk: 16384 26 | learned_background: true 27 | background_color: random 28 | geometry: 29 | name: volume-density 30 | radius: ${model.radius} 31 | feature_dim: 16 32 | density_activation: trunc_exp 33 | density_bias: -1 34 | isosurface: 35 | method: mc 36 | resolution: 256 37 | chunk: 2097152 38 | threshold: 5.0 39 | xyz_encoding_config: 40 | otype: HashGrid 41 | n_levels: 16 42 | n_features_per_level: 2 43 | log2_hashmap_size: 19 44 | base_resolution: 16 45 | per_level_scale: 1.447269237440378 46 | mlp_network_config: 47 | otype: FullyFusedMLP 48 | activation: ReLU 49 | output_activation: none 50 | n_neurons: 64 51 | n_hidden_layers: 1 52 | texture: 53 | name: volume-radiance 54 | input_feature_dim: ${model.geometry.feature_dim} 55 | dir_encoding_config: 56 | otype: SphericalHarmonics 57 | degree: 4 58 | mlp_network_config: 59 | otype: FullyFusedMLP 60 | activation: ReLU 61 | output_activation: Sigmoid 62 | n_neurons: 64 63 | n_hidden_layers: 2 64 | 65 | system: 66 | name: nerf-system 67 | loss: 68 | lambda_rgb: 1. 
69 | lambda_distortion: 0.001 70 | optimizer: 71 | name: AdamW 72 | args: 73 | lr: 0.01 74 | betas: [0.9, 0.99] 75 | eps: 1.e-15 76 | scheduler: 77 | name: MultiStepLR 78 | interval: step 79 | args: 80 | milestones: [10000, 15000, 18000] 81 | gamma: 0.33 82 | 83 | checkpoint: 84 | save_top_k: -1 85 | every_n_train_steps: ${trainer.max_steps} 86 | 87 | export: 88 | chunk_size: 2097152 89 | export_vertex_color: False 90 | 91 | trainer: 92 | max_steps: 20000 93 | log_every_n_steps: 200 94 | num_sanity_val_steps: 0 95 | val_check_interval: 5000 96 | limit_train_batches: 1.0 97 | limit_val_batches: 2 98 | enable_progress_bar: true 99 | precision: 16 100 | -------------------------------------------------------------------------------- /configs/neuralangelo-dtu-wmask.yaml: -------------------------------------------------------------------------------- 1 | name: neuralangelo-dtu-wmask-${basename:${dataset.root_dir}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: dtu 7 | root_dir: ./load/DTU-neus/dtu_scan63 8 | cameras_file: cameras_sphere.npz 9 | img_downscale: 2 # specify training image size by either img_wh or img_downscale 10 | n_test_traj_steps: 60 11 | apply_mask: true 12 | 13 | model: 14 | name: neus 15 | radius: 1.0 16 | num_samples_per_ray: 1024 17 | train_num_rays: 256 18 | max_train_num_rays: 8192 19 | grid_prune: true 20 | grid_prune_occ_thre: 0.001 21 | dynamic_ray_sampling: true 22 | batch_image_sampling: true 23 | randomized: true 24 | ray_chunk: 2048 25 | cos_anneal_end: 20000 26 | learned_background: false 27 | background_color: white 28 | variance: 29 | init_val: 0.3 30 | modulate: false 31 | geometry: 32 | name: volume-sdf 33 | radius: ${model.radius} 34 | feature_dim: 13 35 | grad_type: finite_difference 36 | finite_difference_eps: progressive 37 | isosurface: 38 | method: mc 39 | resolution: 512 40 | chunk: 2097152 41 | threshold: 0. 42 | xyz_encoding_config: 43 | otype: ProgressiveBandHashGrid 44 | n_levels: 16 45 | n_features_per_level: 2 46 | log2_hashmap_size: 19 47 | base_resolution: 32 48 | per_level_scale: 1.3195079107728942 49 | include_xyz: true 50 | start_level: 4 51 | start_step: 0 52 | update_steps: 1000 53 | mlp_network_config: 54 | otype: VanillaMLP 55 | activation: ReLU 56 | output_activation: none 57 | n_neurons: 64 58 | n_hidden_layers: 1 59 | sphere_init: true 60 | sphere_init_radius: 0.5 61 | weight_norm: true 62 | texture: 63 | name: volume-radiance 64 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 65 | dir_encoding_config: 66 | otype: SphericalHarmonics 67 | degree: 4 68 | mlp_network_config: 69 | otype: VanillaMLP 70 | activation: ReLU 71 | output_activation: none 72 | n_neurons: 64 73 | n_hidden_layers: 2 74 | color_activation: sigmoid 75 | 76 | system: 77 | name: neus-system 78 | loss: 79 | lambda_rgb_mse: 0. 80 | lambda_rgb_l1: 1. 81 | lambda_mask: 0.1 82 | lambda_eikonal: 0.1 83 | # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup 84 | lambda_curvature: 0. 85 | lambda_sparsity: 0.0 86 | lambda_distortion: 0.0 87 | lambda_distortion_bg: 0.0 88 | lambda_opaque: 0.0 89 | sparsity_scale: 1. 
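  # per-module learning rates below: geometry and texture at lr 0.01;
  # the NeuS variance parameter (see init_val above) at a smaller lr 0.001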
90 | optimizer: 91 | name: AdamW 92 | args: 93 | lr: 0.01 94 | betas: [0.9, 0.99] 95 | eps: 1.e-15 96 | params: 97 | geometry: 98 | lr: 0.01 99 | texture: 100 | lr: 0.01 101 | variance: 102 | lr: 0.001 103 | constant_steps: 5000 104 | scheduler: 105 | name: SequentialLR 106 | interval: step 107 | milestones: 108 | - ${system.constant_steps} 109 | schedulers: 110 | - name: ConstantLR 111 | args: 112 | factor: 1.0 113 | total_iters: ${system.constant_steps} 114 | - name: ExponentialLR 115 | args: 116 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} 117 | 118 | checkpoint: 119 | save_top_k: -1 120 | every_n_train_steps: ${trainer.max_steps} 121 | 122 | export: 123 | chunk_size: 2097152 124 | export_vertex_color: True 125 | 126 | trainer: 127 | max_steps: 20000 128 | log_every_n_steps: 100 129 | num_sanity_val_steps: 0 130 | val_check_interval: 500 131 | limit_train_batches: 1.0 132 | limit_val_batches: 2 133 | enable_progress_bar: true 134 | precision: 16 135 | -------------------------------------------------------------------------------- /configs/neus-blender.yaml: -------------------------------------------------------------------------------- 1 | name: neus-blender-${dataset.scene} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: blender 7 | scene: ??? 8 | root_dir: ./load/nerf_synthetic/${dataset.scene} 9 | img_wh: 10 | - 800 11 | - 800 12 | # img_downscale: 1 # specify training image size by either img_wh or img_downscale 13 | near_plane: 2.0 14 | far_plane: 6.0 15 | train_split: "train" 16 | val_split: "val" 17 | test_split: "test" 18 | 19 | model: 20 | name: neus 21 | radius: 1.5 22 | num_samples_per_ray: 1024 23 | train_num_rays: 256 24 | max_train_num_rays: 8192 25 | grid_prune: true 26 | grid_prune_occ_thre: 0.001 27 | dynamic_ray_sampling: true 28 | batch_image_sampling: true 29 | randomized: true 30 | ray_chunk: 4096 31 | cos_anneal_end: 20000 32 | learned_background: false 33 | background_color: random 34 | variance: 35 | init_val: 0.3 36 | modulate: false 37 | geometry: 38 | name: volume-sdf 39 | radius: ${model.radius} 40 | feature_dim: 13 41 | grad_type: analytic 42 | isosurface: 43 | method: mc 44 | resolution: 512 45 | chunk: 2097152 46 | threshold: 0. 47 | xyz_encoding_config: 48 | otype: HashGrid 49 | n_levels: 16 50 | n_features_per_level: 2 51 | log2_hashmap_size: 19 52 | base_resolution: 32 53 | per_level_scale: 1.3195079107728942 54 | include_xyz: true 55 | mlp_network_config: 56 | otype: VanillaMLP 57 | activation: ReLU 58 | output_activation: none 59 | n_neurons: 64 60 | n_hidden_layers: 1 61 | sphere_init: true 62 | sphere_init_radius: 0.5 63 | weight_norm: true 64 | texture: 65 | name: volume-radiance 66 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 67 | dir_encoding_config: 68 | otype: SphericalHarmonics 69 | degree: 4 70 | mlp_network_config: 71 | otype: FullyFusedMLP 72 | activation: ReLU 73 | output_activation: none 74 | n_neurons: 64 75 | n_hidden_layers: 2 76 | color_activation: sigmoid 77 | 78 | system: 79 | name: neus-system 80 | loss: 81 | lambda_rgb_mse: 10. 82 | lambda_rgb_l1: 0. 83 | lambda_mask: 0.1 84 | lambda_eikonal: 0.1 85 | lambda_curvature: 0. 86 | lambda_sparsity: 0.0 87 | lambda_distortion: 0. 88 | lambda_opaque: 0. 89 | sparsity_scale: 1. 
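  # loss weights above: this Blender config supervises color with MSE (lambda_rgb_mse: 10.),
  # while the DTU configs use L1 (lambda_rgb_l1: 1.)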
90 | optimizer: 91 | name: AdamW 92 | args: 93 | lr: 0.01 94 | betas: [0.9, 0.99] 95 | eps: 1.e-15 96 | params: 97 | geometry: 98 | lr: 0.01 99 | texture: 100 | lr: 0.01 101 | variance: 102 | lr: 0.001 103 | warmup_steps: 500 104 | scheduler: 105 | name: SequentialLR 106 | interval: step 107 | milestones: 108 | - ${system.warmup_steps} 109 | schedulers: 110 | - name: LinearLR # linear warm-up in the first system.warmup_steps steps 111 | args: 112 | start_factor: 0.01 113 | end_factor: 1.0 114 | total_iters: ${system.warmup_steps} 115 | - name: ExponentialLR 116 | args: 117 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.warmup_steps}}} 118 | 119 | checkpoint: 120 | save_top_k: -1 121 | every_n_train_steps: ${trainer.max_steps} 122 | 123 | export: 124 | chunk_size: 2097152 125 | export_vertex_color: True 126 | 127 | trainer: 128 | max_steps: 20000 129 | log_every_n_steps: 100 130 | num_sanity_val_steps: 0 131 | val_check_interval: 10000 132 | limit_train_batches: 1.0 133 | limit_val_batches: 2 134 | enable_progress_bar: true 135 | precision: 16 136 | -------------------------------------------------------------------------------- /configs/neus-colmap.yaml: -------------------------------------------------------------------------------- 1 | name: neus-colmap-${basename:${dataset.root_dir}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: colmap 7 | root_dir: ./load/unbounded360/garden 8 | img_downscale: 4 # specify training image size by either img_wh or img_downscale 9 | up_est_method: ground # if true, use estimated ground plane normal direction as up direction 10 | center_est_method: lookat 11 | n_test_traj_steps: 120 12 | apply_mask: false 13 | load_data_on_gpu: false 14 | 15 | model: 16 | name: neus 17 | radius: 0.6 18 | num_samples_per_ray: 1024 19 | train_num_rays: 256 20 | max_train_num_rays: 8192 21 | grid_prune: true 22 | grid_prune_occ_thre: 0.001 23 | dynamic_ray_sampling: true 24 | batch_image_sampling: true 25 | randomized: true 26 | ray_chunk: 2048 27 | cos_anneal_end: 20000 28 | learned_background: true 29 | background_color: random 30 | variance: 31 | init_val: 0.3 32 | modulate: false 33 | geometry: 34 | name: volume-sdf 35 | radius: ${model.radius} 36 | feature_dim: 13 37 | grad_type: analytic 38 | isosurface: 39 | method: mc 40 | resolution: 512 41 | chunk: 2097152 42 | threshold: 0. 
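    # the ProgressiveBandHashGrid below appears to enable hash-grid levels coarse-to-fine,
    # starting from start_level and growing every update_steps steps (Neuralangelo-style schedule)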
43 | xyz_encoding_config: 44 | otype: ProgressiveBandHashGrid 45 | n_levels: 16 46 | n_features_per_level: 2 47 | log2_hashmap_size: 19 48 | base_resolution: 32 49 | per_level_scale: 1.3195079107728942 50 | include_xyz: true 51 | start_level: 4 52 | start_step: 0 53 | update_steps: 1000 54 | mlp_network_config: 55 | otype: VanillaMLP 56 | activation: ReLU 57 | output_activation: none 58 | n_neurons: 64 59 | n_hidden_layers: 1 60 | sphere_init: true 61 | sphere_init_radius: 0.5 62 | weight_norm: true 63 | texture: 64 | name: volume-radiance 65 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 66 | dir_encoding_config: 67 | otype: SphericalHarmonics 68 | degree: 4 69 | mlp_network_config: 70 | otype: VanillaMLP 71 | activation: ReLU 72 | output_activation: none 73 | n_neurons: 64 74 | n_hidden_layers: 2 75 | color_activation: sigmoid 76 | # background model configurations 77 | num_samples_per_ray_bg: 256 78 | geometry_bg: 79 | name: volume-density 80 | radius: ${model.radius} 81 | feature_dim: 8 82 | density_activation: trunc_exp 83 | density_bias: -1 84 | isosurface: null 85 | xyz_encoding_config: 86 | otype: HashGrid 87 | n_levels: 16 88 | n_features_per_level: 2 89 | log2_hashmap_size: 19 90 | base_resolution: 32 91 | per_level_scale: 1.3195079107728942 92 | mlp_network_config: 93 | otype: VanillaMLP 94 | activation: ReLU 95 | output_activation: none 96 | n_neurons: 64 97 | n_hidden_layers: 1 98 | texture_bg: 99 | name: volume-radiance 100 | input_feature_dim: ${model.geometry_bg.feature_dim} 101 | dir_encoding_config: 102 | otype: SphericalHarmonics 103 | degree: 4 104 | mlp_network_config: 105 | otype: VanillaMLP 106 | activation: ReLU 107 | output_activation: none 108 | n_neurons: 64 109 | n_hidden_layers: 2 110 | color_activation: sigmoid 111 | 112 | system: 113 | name: neus-system 114 | loss: 115 | lambda_rgb_mse: 10. 116 | lambda_rgb_l1: 0. 117 | lambda_mask: 0.0 118 | lambda_eikonal: 0.1 119 | lambda_curvature: 0. 120 | lambda_sparsity: 0.0 121 | lambda_distortion: 0.0 122 | lambda_distortion_bg: 0.0 123 | lambda_opaque: 0.0 124 | sparsity_scale: 1. 
125 | optimizer: 126 | name: AdamW 127 | args: 128 | lr: 0.01 129 | betas: [0.9, 0.99] 130 | eps: 1.e-15 131 | params: 132 | geometry: 133 | lr: 0.01 134 | texture: 135 | lr: 0.01 136 | geometry_bg: 137 | lr: 0.01 138 | texture_bg: 139 | lr: 0.01 140 | variance: 141 | lr: 0.001 142 | warmup_steps: 500 143 | scheduler: 144 | name: SequentialLR 145 | interval: step 146 | milestones: 147 | - ${system.warmup_steps} 148 | schedulers: 149 | - name: LinearLR # linear warm-up in the first system.warmup_steps steps 150 | args: 151 | start_factor: 0.01 152 | end_factor: 1.0 153 | total_iters: ${system.warmup_steps} 154 | - name: ExponentialLR 155 | args: 156 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.warmup_steps}}} 157 | 158 | checkpoint: 159 | save_top_k: -1 160 | every_n_train_steps: ${trainer.max_steps} 161 | 162 | export: 163 | chunk_size: 2097152 164 | export_vertex_color: True 165 | 166 | trainer: 167 | max_steps: 20000 168 | log_every_n_steps: 100 169 | num_sanity_val_steps: 0 170 | val_check_interval: 5000 171 | limit_train_batches: 1.0 172 | limit_val_batches: 2 173 | enable_progress_bar: true 174 | precision: 16 175 | -------------------------------------------------------------------------------- /configs/neus-dtu-wmask.yaml: -------------------------------------------------------------------------------- 1 | name: neus-dtu-wmask-${basename:${dataset.root_dir}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: dtu 7 | root_dir: ./load/DTU-neus/dtu_scan63 8 | cameras_file: cameras_sphere.npz 9 | img_downscale: 2 # specify training image size by either img_wh or img_downscale 10 | n_test_traj_steps: 60 11 | apply_mask: true 12 | 13 | model: 14 | name: neus 15 | radius: 1.0 16 | num_samples_per_ray: 1024 17 | train_num_rays: 256 18 | max_train_num_rays: 8192 19 | grid_prune: true 20 | grid_prune_occ_thre: 0.001 21 | dynamic_ray_sampling: true 22 | batch_image_sampling: true 23 | randomized: true 24 | ray_chunk: 2048 25 | cos_anneal_end: 20000 26 | learned_background: false 27 | background_color: white 28 | variance: 29 | init_val: 0.3 30 | modulate: false 31 | geometry: 32 | name: volume-sdf 33 | radius: ${model.radius} 34 | feature_dim: 13 35 | grad_type: analytic 36 | isosurface: 37 | method: mc 38 | resolution: 512 39 | chunk: 2097152 40 | threshold: 0. 41 | xyz_encoding_config: 42 | otype: HashGrid 43 | n_levels: 16 44 | n_features_per_level: 2 45 | log2_hashmap_size: 19 46 | base_resolution: 32 47 | per_level_scale: 1.3195079107728942 48 | include_xyz: true 49 | mlp_network_config: 50 | otype: VanillaMLP 51 | activation: ReLU 52 | output_activation: none 53 | n_neurons: 64 54 | n_hidden_layers: 1 55 | sphere_init: true 56 | sphere_init_radius: 0.5 57 | weight_norm: true 58 | texture: 59 | name: volume-radiance 60 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 61 | dir_encoding_config: 62 | otype: SphericalHarmonics 63 | degree: 4 64 | mlp_network_config: 65 | otype: VanillaMLP 66 | activation: ReLU 67 | output_activation: none 68 | n_neurons: 64 69 | n_hidden_layers: 2 70 | color_activation: sigmoid 71 | 72 | system: 73 | name: neus-system 74 | loss: 75 | lambda_rgb_mse: 0. 76 | lambda_rgb_l1: 1. 77 | lambda_mask: 0.0 78 | lambda_eikonal: 0.1 79 | lambda_curvature: 0. 80 | lambda_sparsity: 0.0 81 | lambda_distortion: 0.0 82 | lambda_distortion_bg: 0.0 83 | lambda_opaque: 0.0 84 | sparsity_scale: 1. 
85 | optimizer: 86 | name: AdamW 87 | args: 88 | lr: 0.01 89 | betas: [0.9, 0.99] 90 | eps: 1.e-15 91 | params: 92 | geometry: 93 | lr: 0.01 94 | texture: 95 | lr: 0.01 96 | variance: 97 | lr: 0.001 98 | constant_steps: 5000 99 | scheduler: 100 | name: SequentialLR 101 | interval: step 102 | milestones: 103 | - ${system.constant_steps} 104 | schedulers: 105 | - name: ConstantLR 106 | args: 107 | factor: 1.0 108 | total_iters: ${system.constant_steps} 109 | - name: ExponentialLR 110 | args: 111 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} 112 | 113 | checkpoint: 114 | save_top_k: -1 115 | every_n_train_steps: ${trainer.max_steps} 116 | 117 | export: 118 | chunk_size: 2097152 119 | export_vertex_color: True 120 | 121 | trainer: 122 | max_steps: 20000 123 | log_every_n_steps: 100 124 | num_sanity_val_steps: 0 125 | val_check_interval: 5000 126 | limit_train_batches: 1.0 127 | limit_val_batches: 2 128 | enable_progress_bar: true 129 | precision: 16 130 | -------------------------------------------------------------------------------- /configs/neus-dtu.yaml: -------------------------------------------------------------------------------- 1 | name: neus-dtu-${basename:${dataset.root_dir}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: dtu 7 | root_dir: ./load/DTU-neus/dtu_scan63 8 | cameras_file: cameras_sphere.npz 9 | img_downscale: 2 # specify training image size by either img_wh or img_downscale 10 | n_test_traj_steps: 60 11 | apply_mask: false 12 | 13 | model: 14 | name: neus 15 | radius: 1.0 16 | num_samples_per_ray: 1024 17 | train_num_rays: 256 18 | max_train_num_rays: 8192 19 | grid_prune: true 20 | grid_prune_occ_thre: 0.001 21 | dynamic_ray_sampling: true 22 | batch_image_sampling: true 23 | randomized: true 24 | ray_chunk: 2048 25 | cos_anneal_end: 20000 26 | learned_background: true 27 | background_color: random 28 | variance: 29 | init_val: 0.3 30 | modulate: false 31 | geometry: 32 | name: volume-sdf 33 | radius: ${model.radius} 34 | feature_dim: 13 35 | grad_type: analytic 36 | isosurface: 37 | method: mc 38 | resolution: 512 39 | chunk: 2097152 40 | threshold: 0. 
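    # isosurface above: the mesh is extracted by marching cubes (method: mc) at 512^3 resolution;
    # threshold: 0. targets the zero level set of the SDF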
41 | xyz_encoding_config: 42 | otype: HashGrid 43 | n_levels: 16 44 | n_features_per_level: 2 45 | log2_hashmap_size: 19 46 | base_resolution: 32 47 | per_level_scale: 1.3195079107728942 48 | include_xyz: true 49 | mlp_network_config: 50 | otype: VanillaMLP 51 | activation: ReLU 52 | output_activation: none 53 | n_neurons: 64 54 | n_hidden_layers: 1 55 | sphere_init: true 56 | sphere_init_radius: 0.5 57 | weight_norm: true 58 | texture: 59 | name: volume-radiance 60 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 61 | dir_encoding_config: 62 | otype: SphericalHarmonics 63 | degree: 4 64 | mlp_network_config: 65 | otype: VanillaMLP 66 | activation: ReLU 67 | output_activation: none 68 | n_neurons: 64 69 | n_hidden_layers: 2 70 | color_activation: sigmoid 71 | # background model configurations 72 | num_samples_per_ray_bg: 64 73 | geometry_bg: 74 | name: volume-density 75 | radius: ${model.radius} 76 | feature_dim: 8 77 | density_activation: trunc_exp 78 | density_bias: -1 79 | isosurface: null 80 | xyz_encoding_config: 81 | otype: HashGrid 82 | n_levels: 16 83 | n_features_per_level: 2 84 | log2_hashmap_size: 19 85 | base_resolution: 32 86 | per_level_scale: 1.3195079107728942 87 | mlp_network_config: 88 | otype: VanillaMLP 89 | activation: ReLU 90 | output_activation: none 91 | n_neurons: 64 92 | n_hidden_layers: 1 93 | texture_bg: 94 | name: volume-radiance 95 | input_feature_dim: ${model.geometry_bg.feature_dim} 96 | dir_encoding_config: 97 | otype: SphericalHarmonics 98 | degree: 4 99 | mlp_network_config: 100 | otype: VanillaMLP 101 | activation: ReLU 102 | output_activation: none 103 | n_neurons: 64 104 | n_hidden_layers: 2 105 | color_activation: sigmoid 106 | 107 | system: 108 | name: neus-system 109 | loss: 110 | lambda_rgb_mse: 0. 111 | lambda_rgb_l1: 1. 112 | lambda_mask: 0.0 113 | lambda_eikonal: 0.1 114 | lambda_curvature: 0. 115 | lambda_sparsity: 0.0 116 | lambda_distortion: 0.0 117 | lambda_distortion_bg: 0.0 118 | lambda_opaque: 0.0 119 | sparsity_scale: 1. 
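  # the optimizer below has separate parameter groups for the background model
  # (geometry_bg / texture_bg) used for the learned background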
120 | optimizer: 121 | name: AdamW 122 | args: 123 | lr: 0.01 124 | betas: [0.9, 0.99] 125 | eps: 1.e-15 126 | params: 127 | geometry: 128 | lr: 0.01 129 | texture: 130 | lr: 0.01 131 | geometry_bg: 132 | lr: 0.01 133 | texture_bg: 134 | lr: 0.01 135 | variance: 136 | lr: 0.001 137 | constant_steps: 5000 138 | scheduler: 139 | name: SequentialLR 140 | interval: step 141 | milestones: 142 | - ${system.constant_steps} 143 | schedulers: 144 | - name: ConstantLR 145 | args: 146 | factor: 1.0 147 | total_iters: ${system.constant_steps} 148 | - name: ExponentialLR 149 | args: 150 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} 151 | 152 | checkpoint: 153 | save_top_k: -1 154 | every_n_train_steps: ${trainer.max_steps} 155 | 156 | export: 157 | chunk_size: 2097152 158 | export_vertex_color: True 159 | 160 | trainer: 161 | max_steps: 20000 162 | log_every_n_steps: 100 163 | num_sanity_val_steps: 0 164 | val_check_interval: 5000 165 | limit_train_batches: 1.0 166 | limit_val_batches: 2 167 | enable_progress_bar: true 168 | precision: 16 169 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | datasets = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | datasets[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config): 12 | dataset = datasets[name](config) 13 | return dataset 14 | 15 | 16 | from . import blender, colmap, dtu 17 | -------------------------------------------------------------------------------- /datasets/blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader, IterableDataset 9 | import torchvision.transforms.functional as TF 10 | 11 | import pytorch_lightning as pl 12 | 13 | import datasets 14 | from models.ray_utils import get_ray_directions 15 | from utils.misc import get_rank 16 | 17 | 18 | class BlenderDatasetBase(): 19 | def setup(self, config, split): 20 | self.config = config 21 | self.split = split 22 | self.rank = get_rank() 23 | 24 | self.has_mask = True 25 | self.apply_mask = True 26 | 27 | with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f: 28 | meta = json.load(f) 29 | 30 | if 'w' in meta and 'h' in meta: 31 | W, H = int(meta['w']), int(meta['h']) 32 | else: 33 | W, H = 800, 800 34 | 35 | if 'img_wh' in self.config: 36 | w, h = self.config.img_wh 37 | assert round(W / w * h) == H 38 | elif 'img_downscale' in self.config: 39 | w, h = W // self.config.img_downscale, H // self.config.img_downscale 40 | else: 41 | raise KeyError("Either img_wh or img_downscale should be specified.") 42 | 43 | self.w, self.h = w, h 44 | self.img_wh = (self.w, self.h) 45 | 46 | self.near, self.far = self.config.near_plane, self.config.far_plane 47 | 48 | self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) # scaled focal length 49 | 50 | # ray directions for all pixels, same for all images (same H, W, focal) 51 | self.directions = \ 52 | get_ray_directions(self.w, self.h, self.focal, self.focal, self.w//2, self.h//2).to(self.rank) # (h, w, 3) 53 | 54 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] 55 | 56 | for i, frame in enumerate(meta['frames']): 57 | c2w = 
torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4]) 58 | self.all_c2w.append(c2w) 59 | 60 | img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png") 61 | img = Image.open(img_path) 62 | img = img.resize(self.img_wh, Image.BICUBIC) 63 | img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) 64 | 65 | self.all_fg_masks.append(img[..., -1]) # (h, w) 66 | self.all_images.append(img[...,:3]) 67 | 68 | self.all_c2w, self.all_images, self.all_fg_masks = \ 69 | torch.stack(self.all_c2w, dim=0).float().to(self.rank), \ 70 | torch.stack(self.all_images, dim=0).float().to(self.rank), \ 71 | torch.stack(self.all_fg_masks, dim=0).float().to(self.rank) 72 | 73 | 74 | class BlenderDataset(Dataset, BlenderDatasetBase): 75 | def __init__(self, config, split): 76 | self.setup(config, split) 77 | 78 | def __len__(self): 79 | return len(self.all_images) 80 | 81 | def __getitem__(self, index): 82 | return { 83 | 'index': index 84 | } 85 | 86 | 87 | class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): 88 | def __init__(self, config, split): 89 | self.setup(config, split) 90 | 91 | def __iter__(self): 92 | while True: 93 | yield {} 94 | 95 | 96 | @datasets.register('blender') 97 | class BlenderDataModule(pl.LightningDataModule): 98 | def __init__(self, config): 99 | super().__init__() 100 | self.config = config 101 | 102 | def setup(self, stage=None): 103 | if stage in [None, 'fit']: 104 | self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split) 105 | if stage in [None, 'fit', 'validate']: 106 | self.val_dataset = BlenderDataset(self.config, self.config.val_split) 107 | if stage in [None, 'test']: 108 | self.test_dataset = BlenderDataset(self.config, self.config.test_split) 109 | if stage in [None, 'predict']: 110 | self.predict_dataset = BlenderDataset(self.config, self.config.train_split) 111 | 112 | def prepare_data(self): 113 | pass 114 | 115 | def general_loader(self, dataset, batch_size): 116 | sampler = None 117 | return DataLoader( 118 | dataset, 119 | num_workers=os.cpu_count(), 120 | batch_size=batch_size, 121 | pin_memory=True, 122 | sampler=sampler 123 | ) 124 | 125 | def train_dataloader(self): 126 | return self.general_loader(self.train_dataset, batch_size=1) 127 | 128 | def val_dataloader(self): 129 | return self.general_loader(self.val_dataset, batch_size=1) 130 | 131 | def test_dataloader(self): 132 | return self.general_loader(self.test_dataset, batch_size=1) 133 | 134 | def predict_dataloader(self): 135 | return self.general_loader(self.predict_dataset, batch_size=1) 136 | -------------------------------------------------------------------------------- /datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset, DataLoader, IterableDataset 9 | import torchvision.transforms.functional as TF 10 | 11 | import pytorch_lightning as pl 12 | 13 | import datasets 14 | from datasets.colmap_utils import \ 15 | read_cameras_binary, read_images_binary, read_points3d_binary 16 | from models.ray_utils import get_ray_directions 17 | from utils.misc import get_rank 18 | 19 | 20 | def get_center(pts): 21 | center = pts.mean(0) 22 | dis = (pts - center[None,:]).norm(p=2, dim=-1) 23 | mean, std = dis.mean(), dis.std() 24 | q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75) 25 | valid = (dis > mean - 1.5 * 
std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5) 26 | center = pts[valid].mean(0) 27 | return center 28 | 29 | def normalize_poses(poses, pts, up_est_method, center_est_method): 30 | if center_est_method == 'camera': 31 | # estimation scene center as the average of all camera positions 32 | center = poses[...,3].mean(0) 33 | elif center_est_method == 'lookat': 34 | # estimation scene center as the average of the intersection of selected pairs of camera rays 35 | cams_ori = poses[...,3] 36 | cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.]) 37 | cams_dir = F.normalize(cams_dir, dim=-1) 38 | A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1) 39 | b = -cams_ori + cams_ori.roll(1,0) 40 | t = torch.linalg.lstsq(A, b).solution 41 | center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2)) 42 | elif center_est_method == 'point': 43 | # first estimation scene center as the average of all camera positions 44 | # later we'll use the center of all points bounded by the cameras as the final scene center 45 | center = poses[...,3].mean(0) 46 | else: 47 | raise NotImplementedError(f'Unknown center estimation method: {center_est_method}') 48 | 49 | if up_est_method == 'ground': 50 | # estimate up direction as the normal of the estimated ground plane 51 | # use RANSAC to estimate the ground plane in the point cloud 52 | import pyransac3d as pyrsc 53 | ground = pyrsc.Plane() 54 | plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale 55 | plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0 56 | z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction 57 | signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1) 58 | if signed_distance.mean() < 0: 59 | z = -z # flip the direction if points lie under the plane 60 | elif up_est_method == 'camera': 61 | # estimate up direction as the average of all camera up directions 62 | z = F.normalize((poses[...,3] - center).mean(0), dim=0) 63 | else: 64 | raise NotImplementedError(f'Unknown up estimation method: {up_est_method}') 65 | 66 | # new axis 67 | y_ = torch.as_tensor([z[1], -z[0], 0.]) 68 | x = F.normalize(y_.cross(z), dim=0) 69 | y = z.cross(x) 70 | 71 | if center_est_method == 'point': 72 | # rotation 73 | Rc = torch.stack([x, y, z], dim=1) 74 | R = Rc.T 75 | poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) 76 | inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) 77 | poses_norm = (inv_trans @ poses_homo)[:,:3] 78 | pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] 79 | 80 | # translation and scaling 81 | poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0] 82 | pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])] 83 | center = get_center(pts_fg) 84 | tc = center.reshape(3, 1) 85 | t = -tc 86 | poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1) 87 | inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) 88 | poses_norm = (inv_trans @ poses_homo)[:,:3] 89 | scale = poses_norm[...,3].norm(p=2, dim=-1).min() 90 | 
poses_norm[...,3] /= scale 91 | pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] 92 | pts = pts / scale 93 | else: 94 | # rotation and translation 95 | Rc = torch.stack([x, y, z], dim=1) 96 | tc = center.reshape(3, 1) 97 | R, t = Rc.T, -Rc.T @ tc 98 | poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) 99 | inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) 100 | poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 4, 4) 101 | 102 | # scaling 103 | scale = poses_norm[...,3].norm(p=2, dim=-1).min() 104 | poses_norm[...,3] /= scale 105 | 106 | # apply the transformation to the point cloud 107 | pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] 108 | pts = pts / scale 109 | 110 | return poses_norm, pts 111 | 112 | def create_spheric_poses(cameras, n_steps=120): 113 | center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) 114 | mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean() 115 | mean_h = cameras[:,2].mean() 116 | r = (mean_d**2 - mean_h**2).sqrt() 117 | up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device) 118 | 119 | all_c2w = [] 120 | for theta in torch.linspace(0, 2 * math.pi, n_steps): 121 | cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h]) 122 | l = F.normalize(center - cam_pos, p=2, dim=0) 123 | s = F.normalize(l.cross(up), p=2, dim=0) 124 | u = F.normalize(s.cross(l), p=2, dim=0) 125 | c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) 126 | all_c2w.append(c2w) 127 | 128 | all_c2w = torch.stack(all_c2w, dim=0) 129 | 130 | return all_c2w 131 | 132 | class ColmapDatasetBase(): 133 | # the data only has to be processed once 134 | initialized = False 135 | properties = {} 136 | 137 | def setup(self, config, split): 138 | self.config = config 139 | self.split = split 140 | self.rank = get_rank() 141 | 142 | if not ColmapDatasetBase.initialized: 143 | camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin')) 144 | 145 | H = int(camdata[1].height) 146 | W = int(camdata[1].width) 147 | 148 | if 'img_wh' in self.config: 149 | w, h = self.config.img_wh 150 | assert round(W / w * h) == H 151 | elif 'img_downscale' in self.config: 152 | w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) 153 | else: 154 | raise KeyError("Either img_wh or img_downscale should be specified.") 155 | 156 | img_wh = (w, h) 157 | factor = w / W 158 | 159 | if camdata[1].model == 'SIMPLE_RADIAL': 160 | fx = fy = camdata[1].params[0] * factor 161 | cx = camdata[1].params[1] * factor 162 | cy = camdata[1].params[2] * factor 163 | elif camdata[1].model in ['PINHOLE', 'OPENCV']: 164 | fx = camdata[1].params[0] * factor 165 | fy = camdata[1].params[1] * factor 166 | cx = camdata[1].params[2] * factor 167 | cy = camdata[1].params[3] * factor 168 | else: 169 | raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") 170 | 171 | directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank) 172 | 173 | imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin')) 174 | 175 | mask_dir = os.path.join(self.config.root_dir, 'masks') 176 | has_mask = os.path.exists(mask_dir) # TODO: support partial masks 177 | apply_mask = has_mask and self.config.apply_mask 178 | 179 | all_c2w, all_images, all_fg_masks = [], [], [] 180 | 181 
| for i, d in enumerate(imdata.values()): 182 | R = d.qvec2rotmat() 183 | t = d.tvec.reshape(3, 1) 184 | c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float() 185 | c2w[:,1:3] *= -1. # COLMAP => OpenGL 186 | all_c2w.append(c2w) 187 | if self.split in ['train', 'val']: 188 | img_path = os.path.join(self.config.root_dir, 'images', d.name) 189 | img = Image.open(img_path) 190 | img = img.resize(img_wh, Image.BICUBIC) 191 | img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] 192 | img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu() 193 | if has_mask: 194 | mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])] 195 | mask_paths = list(filter(os.path.exists, mask_paths)) 196 | assert len(mask_paths) == 1 197 | mask = Image.open(mask_paths[0]).convert('L') # (H, W, 1) 198 | mask = mask.resize(img_wh, Image.BICUBIC) 199 | mask = TF.to_tensor(mask)[0] 200 | else: 201 | mask = torch.ones_like(img[...,0], device=img.device) 202 | all_fg_masks.append(mask) # (h, w) 203 | all_images.append(img) 204 | 205 | all_c2w = torch.stack(all_c2w, dim=0) 206 | 207 | pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin')) 208 | pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float() 209 | all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method) 210 | 211 | ColmapDatasetBase.properties = { 212 | 'w': w, 213 | 'h': h, 214 | 'img_wh': img_wh, 215 | 'factor': factor, 216 | 'has_mask': has_mask, 217 | 'apply_mask': apply_mask, 218 | 'directions': directions, 219 | 'pts3d': pts3d, 220 | 'all_c2w': all_c2w, 221 | 'all_images': all_images, 222 | 'all_fg_masks': all_fg_masks 223 | } 224 | 225 | ColmapDatasetBase.initialized = True 226 | 227 | for k, v in ColmapDatasetBase.properties.items(): 228 | setattr(self, k, v) 229 | 230 | if self.split == 'test': 231 | self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) 232 | self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) 233 | self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) 234 | else: 235 | self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float() 236 | 237 | """ 238 | # for debug use 239 | from models.ray_utils import get_rays 240 | rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True) 241 | pts_out = [] 242 | pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()])) 243 | 244 | t_vals = torch.linspace(0, 1, 8) 245 | z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals 246 | 247 | ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :]) 248 | pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()])) 249 | 250 | ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :]) 251 | pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) 252 | 253 | ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :]) 254 | pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) 255 | 256 | ray_pts = (rays_o[:,0,0][..., None, :] + 
z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :]) 257 | pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) 258 | 259 | open('cameras.txt', 'w').write('\n'.join(pts_out)) 260 | open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()])) 261 | 262 | exit(1) 263 | """ 264 | 265 | self.all_c2w = self.all_c2w.float().to(self.rank) 266 | if self.config.load_data_on_gpu: 267 | self.all_images = self.all_images.to(self.rank) 268 | self.all_fg_masks = self.all_fg_masks.to(self.rank) 269 | 270 | 271 | class ColmapDataset(Dataset, ColmapDatasetBase): 272 | def __init__(self, config, split): 273 | self.setup(config, split) 274 | 275 | def __len__(self): 276 | return len(self.all_images) 277 | 278 | def __getitem__(self, index): 279 | return { 280 | 'index': index 281 | } 282 | 283 | 284 | class ColmapIterableDataset(IterableDataset, ColmapDatasetBase): 285 | def __init__(self, config, split): 286 | self.setup(config, split) 287 | 288 | def __iter__(self): 289 | while True: 290 | yield {} 291 | 292 | 293 | @datasets.register('colmap') 294 | class ColmapDataModule(pl.LightningDataModule): 295 | def __init__(self, config): 296 | super().__init__() 297 | self.config = config 298 | 299 | def setup(self, stage=None): 300 | if stage in [None, 'fit']: 301 | self.train_dataset = ColmapIterableDataset(self.config, 'train') 302 | if stage in [None, 'fit', 'validate']: 303 | self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train')) 304 | if stage in [None, 'test']: 305 | self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test')) 306 | if stage in [None, 'predict']: 307 | self.predict_dataset = ColmapDataset(self.config, 'train') 308 | 309 | def prepare_data(self): 310 | pass 311 | 312 | def general_loader(self, dataset, batch_size): 313 | sampler = None 314 | return DataLoader( 315 | dataset, 316 | num_workers=os.cpu_count(), 317 | batch_size=batch_size, 318 | pin_memory=True, 319 | sampler=sampler 320 | ) 321 | 322 | def train_dataloader(self): 323 | return self.general_loader(self.train_dataset, batch_size=1) 324 | 325 | def val_dataloader(self): 326 | return self.general_loader(self.val_dataset, batch_size=1) 327 | 328 | def test_dataloader(self): 329 | return self.general_loader(self.test_dataset, batch_size=1) 330 | 331 | def predict_dataloader(self): 332 | return self.general_loader(self.predict_dataset, batch_size=1) 333 | -------------------------------------------------------------------------------- /datasets/colmap_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 
13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import collections 34 | import numpy as np 35 | import struct 36 | 37 | 38 | CameraModel = collections.namedtuple( 39 | "CameraModel", ["model_id", "model_name", "num_params"]) 40 | Camera = collections.namedtuple( 41 | "Camera", ["id", "model", "width", "height", "params"]) 42 | BaseImage = collections.namedtuple( 43 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 44 | Point3D = collections.namedtuple( 45 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 46 | 47 | class Image(BaseImage): 48 | def qvec2rotmat(self): 49 | return qvec2rotmat(self.qvec) 50 | 51 | 52 | CAMERA_MODELS = { 53 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 54 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 55 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 56 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 57 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 58 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 59 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 60 | CameraModel(model_id=7, model_name="FOV", num_params=5), 61 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 62 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 63 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 64 | } 65 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 66 | for camera_model in CAMERA_MODELS]) 67 | 68 | 69 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 70 | """Read and unpack the next bytes from a binary file. 71 | :param fid: 72 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 73 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 74 | :param endian_character: Any of {@, =, <, >, !} 75 | :return: Tuple of read and unpacked values. 
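    Example, mirroring the call used in read_cameras_binary below
    (i = int32, Q = uint64, so 4 + 4 + 8 + 8 = 24 bytes):
        read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ")
        # -> (camera_id, model_id, width, height)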
76 | """ 77 | data = fid.read(num_bytes) 78 | return struct.unpack(endian_character + format_char_sequence, data) 79 | 80 | 81 | def read_cameras_text(path): 82 | """ 83 | see: src/base/reconstruction.cc 84 | void Reconstruction::WriteCamerasText(const std::string& path) 85 | void Reconstruction::ReadCamerasText(const std::string& path) 86 | """ 87 | cameras = {} 88 | with open(path, "r") as fid: 89 | while True: 90 | line = fid.readline() 91 | if not line: 92 | break 93 | line = line.strip() 94 | if len(line) > 0 and line[0] != "#": 95 | elems = line.split() 96 | camera_id = int(elems[0]) 97 | model = elems[1] 98 | width = int(elems[2]) 99 | height = int(elems[3]) 100 | params = np.array(tuple(map(float, elems[4:]))) 101 | cameras[camera_id] = Camera(id=camera_id, model=model, 102 | width=width, height=height, 103 | params=params) 104 | return cameras 105 | 106 | 107 | def read_cameras_binary(path_to_model_file): 108 | """ 109 | see: src/base/reconstruction.cc 110 | void Reconstruction::WriteCamerasBinary(const std::string& path) 111 | void Reconstruction::ReadCamerasBinary(const std::string& path) 112 | """ 113 | cameras = {} 114 | with open(path_to_model_file, "rb") as fid: 115 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 116 | for camera_line_index in range(num_cameras): 117 | camera_properties = read_next_bytes( 118 | fid, num_bytes=24, format_char_sequence="iiQQ") 119 | camera_id = camera_properties[0] 120 | model_id = camera_properties[1] 121 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 122 | width = camera_properties[2] 123 | height = camera_properties[3] 124 | num_params = CAMERA_MODEL_IDS[model_id].num_params 125 | params = read_next_bytes(fid, num_bytes=8*num_params, 126 | format_char_sequence="d"*num_params) 127 | cameras[camera_id] = Camera(id=camera_id, 128 | model=model_name, 129 | width=width, 130 | height=height, 131 | params=np.array(params)) 132 | assert len(cameras) == num_cameras 133 | return cameras 134 | 135 | 136 | def read_images_text(path): 137 | """ 138 | see: src/base/reconstruction.cc 139 | void Reconstruction::ReadImagesText(const std::string& path) 140 | void Reconstruction::WriteImagesText(const std::string& path) 141 | """ 142 | images = {} 143 | with open(path, "r") as fid: 144 | while True: 145 | line = fid.readline() 146 | if not line: 147 | break 148 | line = line.strip() 149 | if len(line) > 0 and line[0] != "#": 150 | elems = line.split() 151 | image_id = int(elems[0]) 152 | qvec = np.array(tuple(map(float, elems[1:5]))) 153 | tvec = np.array(tuple(map(float, elems[5:8]))) 154 | camera_id = int(elems[8]) 155 | image_name = elems[9] 156 | elems = fid.readline().split() 157 | xys = np.column_stack([tuple(map(float, elems[0::3])), 158 | tuple(map(float, elems[1::3]))]) 159 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 160 | images[image_id] = Image( 161 | id=image_id, qvec=qvec, tvec=tvec, 162 | camera_id=camera_id, name=image_name, 163 | xys=xys, point3D_ids=point3D_ids) 164 | return images 165 | 166 | 167 | def read_images_binary(path_to_model_file): 168 | """ 169 | see: src/base/reconstruction.cc 170 | void Reconstruction::ReadImagesBinary(const std::string& path) 171 | void Reconstruction::WriteImagesBinary(const std::string& path) 172 | """ 173 | images = {} 174 | with open(path_to_model_file, "rb") as fid: 175 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 176 | for image_index in range(num_reg_images): 177 | binary_image_properties = read_next_bytes( 178 | fid, num_bytes=64, format_char_sequence="idddddddi") 179 
| image_id = binary_image_properties[0] 180 | qvec = np.array(binary_image_properties[1:5]) 181 | tvec = np.array(binary_image_properties[5:8]) 182 | camera_id = binary_image_properties[8] 183 | image_name = "" 184 | current_char = read_next_bytes(fid, 1, "c")[0] 185 | while current_char != b"\x00": # look for the ASCII 0 entry 186 | image_name += current_char.decode("utf-8") 187 | current_char = read_next_bytes(fid, 1, "c")[0] 188 | num_points2D = read_next_bytes(fid, num_bytes=8, 189 | format_char_sequence="Q")[0] 190 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 191 | format_char_sequence="ddq"*num_points2D) 192 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 193 | tuple(map(float, x_y_id_s[1::3]))]) 194 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 195 | images[image_id] = Image( 196 | id=image_id, qvec=qvec, tvec=tvec, 197 | camera_id=camera_id, name=image_name, 198 | xys=xys, point3D_ids=point3D_ids) 199 | return images 200 | 201 | 202 | def read_points3D_text(path): 203 | """ 204 | see: src/base/reconstruction.cc 205 | void Reconstruction::ReadPoints3DText(const std::string& path) 206 | void Reconstruction::WritePoints3DText(const std::string& path) 207 | """ 208 | points3D = {} 209 | with open(path, "r") as fid: 210 | while True: 211 | line = fid.readline() 212 | if not line: 213 | break 214 | line = line.strip() 215 | if len(line) > 0 and line[0] != "#": 216 | elems = line.split() 217 | point3D_id = int(elems[0]) 218 | xyz = np.array(tuple(map(float, elems[1:4]))) 219 | rgb = np.array(tuple(map(int, elems[4:7]))) 220 | error = float(elems[7]) 221 | image_ids = np.array(tuple(map(int, elems[8::2]))) 222 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 223 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 224 | error=error, image_ids=image_ids, 225 | point2D_idxs=point2D_idxs) 226 | return points3D 227 | 228 | 229 | def read_points3d_binary(path_to_model_file): 230 | """ 231 | see: src/base/reconstruction.cc 232 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 233 | void Reconstruction::WritePoints3DBinary(const std::string& path) 234 | """ 235 | points3D = {} 236 | with open(path_to_model_file, "rb") as fid: 237 | num_points = read_next_bytes(fid, 8, "Q")[0] 238 | for point_line_index in range(num_points): 239 | binary_point_line_properties = read_next_bytes( 240 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 241 | point3D_id = binary_point_line_properties[0] 242 | xyz = np.array(binary_point_line_properties[1:4]) 243 | rgb = np.array(binary_point_line_properties[4:7]) 244 | error = np.array(binary_point_line_properties[7]) 245 | track_length = read_next_bytes( 246 | fid, num_bytes=8, format_char_sequence="Q")[0] 247 | track_elems = read_next_bytes( 248 | fid, num_bytes=8*track_length, 249 | format_char_sequence="ii"*track_length) 250 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 251 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 252 | points3D[point3D_id] = Point3D( 253 | id=point3D_id, xyz=xyz, rgb=rgb, 254 | error=error, image_ids=image_ids, 255 | point2D_idxs=point2D_idxs) 256 | return points3D 257 | 258 | 259 | def read_model(path, ext): 260 | if ext == ".txt": 261 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 262 | images = read_images_text(os.path.join(path, "images" + ext)) 263 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 264 | else: 265 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 266 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 267 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 268 | return cameras, images, points3D 269 | 270 | 271 | def qvec2rotmat(qvec): 272 | return np.array([ 273 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 274 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 275 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 276 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 277 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 278 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 279 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 280 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 281 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 282 | 283 | 284 | def rotmat2qvec(R): 285 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 286 | K = np.array([ 287 | [Rxx - Ryy - Rzz, 0, 0, 0], 288 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 289 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 290 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 291 | eigvals, eigvecs = np.linalg.eigh(K) 292 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 293 | if qvec[0] < 0: 294 | qvec *= -1 295 | return qvec 296 | -------------------------------------------------------------------------------- /datasets/dtu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader, IterableDataset 11 | import torchvision.transforms.functional as TF 12 | 13 | import pytorch_lightning as pl 14 | 15 | import datasets 16 | from models.ray_utils import get_ray_directions 17 | from utils.misc import get_rank 18 | 19 | 20 | def load_K_Rt_from_P(P=None): 21 | out = cv2.decomposeProjectionMatrix(P) 22 | K = out[0] 23 | R = out[1] 24 | t = out[2] 25 | 26 | K = K / K[2, 2] 27 | intrinsics = np.eye(4) 28 | intrinsics[:3, :3] = K 29 | 30 | pose = np.eye(4, dtype=np.float32) 31 | pose[:3, :3] = R.transpose() 32 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 33 | 34 | return intrinsics, pose 35 | 36 | def create_spheric_poses(cameras, n_steps=120): 37 | center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) 38 | cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2) 39 | eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors 40 | rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1) 41 | up = rot_axis 42 | rot_dir = torch.cross(rot_axis, cam_center) 43 | max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max() 44 | 45 | all_c2w = [] 46 | for theta in torch.linspace(-max_angle, max_angle, n_steps): 47 | cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta) 48 | l = F.normalize(center - cam_pos, p=2, dim=0) 49 | s = F.normalize(l.cross(up), p=2, dim=0) 50 | u = F.normalize(s.cross(l), p=2, dim=0) 51 | c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) 52 | all_c2w.append(c2w) 53 | 54 | all_c2w = torch.stack(all_c2w, dim=0) 55 | 56 | return all_c2w 57 | 58 | class DTUDatasetBase(): 59 | def setup(self, config, split): 60 | self.config = config 61 | self.split = split 62 | self.rank = get_rank() 63 | 64 | cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file)) 65 | 66 | img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png')) 67 | H, W = img_sample.shape[0], 
img_sample.shape[1] 68 | 69 | if 'img_wh' in self.config: 70 | w, h = self.config.img_wh 71 | assert round(W / w * h) == H 72 | elif 'img_downscale' in self.config: 73 | w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) 74 | else: 75 | raise KeyError("Either img_wh or img_downscale should be specified.") 76 | 77 | self.w, self.h = w, h 78 | self.img_wh = (w, h) 79 | self.factor = w / W 80 | 81 | mask_dir = os.path.join(self.config.root_dir, 'mask') 82 | self.has_mask = True 83 | self.apply_mask = self.config.apply_mask 84 | 85 | self.directions = [] 86 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] 87 | 88 | n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1 89 | 90 | for i in range(n_images): 91 | world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}'] 92 | P = (world_mat @ scale_mat)[:3,:4] 93 | K, c2w = load_K_Rt_from_P(P) 94 | fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor 95 | directions = get_ray_directions(w, h, fx, fy, cx, cy) 96 | self.directions.append(directions) 97 | 98 | c2w = torch.from_numpy(c2w).float() 99 | 100 | # blender follows opengl camera coordinates (right up back) 101 | # NeuS DTU data coordinate system (right down front) is different from blender 102 | # https://github.com/Totoro97/NeuS/issues/9 103 | # for c2w, flip the sign of input camera coordinate yz 104 | c2w_ = c2w.clone() 105 | c2w_[:3,1:3] *= -1. # flip input sign 106 | self.all_c2w.append(c2w_[:3,:4]) 107 | 108 | if self.split in ['train', 'val']: 109 | img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png') 110 | img = Image.open(img_path) 111 | img = img.resize(self.img_wh, Image.BICUBIC) 112 | img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] 113 | 114 | mask_path = os.path.join(mask_dir, f'{i:03d}.png') 115 | mask = Image.open(mask_path).convert('L') # (H, W, 1) 116 | mask = mask.resize(self.img_wh, Image.BICUBIC) 117 | mask = TF.to_tensor(mask)[0] 118 | 119 | self.all_fg_masks.append(mask) # (h, w) 120 | self.all_images.append(img) 121 | 122 | self.all_c2w = torch.stack(self.all_c2w, dim=0) 123 | 124 | if self.split == 'test': 125 | self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) 126 | self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) 127 | self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) 128 | self.directions = self.directions[0] 129 | else: 130 | self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0) 131 | self.directions = torch.stack(self.directions, dim=0) 132 | 133 | self.directions = self.directions.float().to(self.rank) 134 | self.all_c2w, self.all_images, self.all_fg_masks = \ 135 | self.all_c2w.float().to(self.rank), \ 136 | self.all_images.float().to(self.rank), \ 137 | self.all_fg_masks.float().to(self.rank) 138 | 139 | 140 | class DTUDataset(Dataset, DTUDatasetBase): 141 | def __init__(self, config, split): 142 | self.setup(config, split) 143 | 144 | def __len__(self): 145 | return len(self.all_images) 146 | 147 | def __getitem__(self, index): 148 | return { 149 | 'index': index 150 | } 151 | 152 | 153 | class DTUIterableDataset(IterableDataset, DTUDatasetBase): 154 | def __init__(self, config, split): 155 | self.setup(config, split) 156 | 157 | def __iter__(self): 158 | while True: 159 | yield {} 160 | 161 | 162 | 
@datasets.register('dtu') 163 | class DTUDataModule(pl.LightningDataModule): 164 | def __init__(self, config): 165 | super().__init__() 166 | self.config = config 167 | 168 | def setup(self, stage=None): 169 | if stage in [None, 'fit']: 170 | self.train_dataset = DTUIterableDataset(self.config, 'train') 171 | if stage in [None, 'fit', 'validate']: 172 | self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train')) 173 | if stage in [None, 'test']: 174 | self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test')) 175 | if stage in [None, 'predict']: 176 | self.predict_dataset = DTUDataset(self.config, 'train') 177 | 178 | def prepare_data(self): 179 | pass 180 | 181 | def general_loader(self, dataset, batch_size): 182 | sampler = None 183 | return DataLoader( 184 | dataset, 185 | num_workers=os.cpu_count(), 186 | batch_size=batch_size, 187 | pin_memory=True, 188 | sampler=sampler 189 | ) 190 | 191 | def train_dataloader(self): 192 | return self.general_loader(self.train_dataset, batch_size=1) 193 | 194 | def val_dataloader(self): 195 | return self.general_loader(self.val_dataset, batch_size=1) 196 | 197 | def test_dataloader(self): 198 | return self.general_loader(self.test_dataset, batch_size=1) 199 | 200 | def predict_dataloader(self): 201 | return self.general_loader(self.predict_dataset, batch_size=1) 202 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bennyguo/instant-nsr-pl/e5fe3f246cf2d512494a73c727174c3ca1c3c695/datasets/utils.py -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import time 5 | import logging 6 | from datetime import datetime 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--config', required=True, help='path to config file') 12 | parser.add_argument('--gpu', default='0', help='GPU(s) to be used') 13 | parser.add_argument('--resume', default=None, help='path to the weights to be resumed') 14 | parser.add_argument( 15 | '--resume_weights_only', 16 | action='store_true', 17 | help='specify this argument to restore only the weights (w/o training states), e.g. 
--resume path/to/resume --resume_weights_only' 18 | ) 19 | 20 | group = parser.add_mutually_exclusive_group(required=True) 21 | group.add_argument('--train', action='store_true') 22 | group.add_argument('--validate', action='store_true') 23 | group.add_argument('--test', action='store_true') 24 | group.add_argument('--predict', action='store_true') 25 | # group.add_argument('--export', action='store_true') # TODO: a separate export action 26 | 27 | parser.add_argument('--exp_dir', default='./exp') 28 | parser.add_argument('--runs_dir', default='./runs') 29 | parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG') 30 | 31 | args, extras = parser.parse_known_args() 32 | 33 | # set CUDA_VISIBLE_DEVICES then import pytorch-lightning 34 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 36 | n_gpus = len(args.gpu.split(',')) 37 | 38 | import datasets 39 | import systems 40 | import pytorch_lightning as pl 41 | from pytorch_lightning import Trainer 42 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 43 | from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger 44 | from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar 45 | from utils.misc import load_config 46 | 47 | # parse YAML config to OmegaConf 48 | config = load_config(args.config, cli_args=extras) 49 | config.cmd_args = vars(args) 50 | 51 | config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S')) 52 | config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name) 53 | config.save_dir = config.get('save_dir') or os.path.join(config.exp_dir, config.trial_name, 'save') 54 | config.ckpt_dir = config.get('ckpt_dir') or os.path.join(config.exp_dir, config.trial_name, 'ckpt') 55 | config.code_dir = config.get('code_dir') or os.path.join(config.exp_dir, config.trial_name, 'code') 56 | config.config_dir = config.get('config_dir') or os.path.join(config.exp_dir, config.trial_name, 'config') 57 | 58 | logger = logging.getLogger('pytorch_lightning') 59 | if args.verbose: 60 | logger.setLevel(logging.DEBUG) 61 | 62 | if 'seed' not in config: 63 | config.seed = int(time.time() * 1000) % 1000 64 | pl.seed_everything(config.seed) 65 | 66 | dm = datasets.make(config.dataset.name, config.dataset) 67 | system = systems.make(config.system.name, config, load_from_checkpoint=None if not args.resume_weights_only else args.resume) 68 | 69 | callbacks = [] 70 | if args.train: 71 | callbacks += [ 72 | ModelCheckpoint( 73 | dirpath=config.ckpt_dir, 74 | **config.checkpoint 75 | ), 76 | LearningRateMonitor(logging_interval='step'), 77 | CodeSnapshotCallback( 78 | config.code_dir, use_version=False 79 | ), 80 | ConfigSnapshotCallback( 81 | config, config.config_dir, use_version=False 82 | ), 83 | CustomProgressBar(refresh_rate=1), 84 | ] 85 | 86 | loggers = [] 87 | if args.train: 88 | loggers += [ 89 | TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name), 90 | CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs') 91 | ] 92 | 93 | if sys.platform == 'win32': 94 | # does not support multi-gpu on windows 95 | strategy = 'dp' 96 | assert n_gpus == 1 97 | else: 98 | strategy = 'ddp_find_unused_parameters_false' 99 | 100 | trainer = Trainer( 101 | devices=n_gpus, 102 | accelerator='gpu', 103 | callbacks=callbacks, 104 | logger=loggers, 105 | strategy=strategy, 106 | **config.trainer 107 | ) 108 | 109 
| if args.train: 110 | if args.resume and not args.resume_weights_only: 111 | # FIXME: different behavior in pytorch-lighting>1.9 ? 112 | trainer.fit(system, datamodule=dm, ckpt_path=args.resume) 113 | else: 114 | trainer.fit(system, datamodule=dm) 115 | trainer.test(system, datamodule=dm) 116 | elif args.validate: 117 | trainer.validate(system, datamodule=dm, ckpt_path=args.resume) 118 | elif args.test: 119 | trainer.test(system, datamodule=dm, ckpt_path=args.resume) 120 | elif args.predict: 121 | trainer.predict(system, datamodule=dm, ckpt_path=args.resume) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | models = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | models[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config): 12 | model = models[name](config) 13 | return model 14 | 15 | 16 | from . import nerf, neus, geometry, texture 17 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.misc import get_rank 5 | 6 | class BaseModel(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | self.rank = get_rank() 11 | self.setup() 12 | if self.config.get('weights', None): 13 | self.load_state_dict(torch.load(self.config.weights)) 14 | 15 | def setup(self): 16 | raise NotImplementedError 17 | 18 | def update_step(self, epoch, global_step): 19 | pass 20 | 21 | def train(self, mode=True): 22 | return super().train(mode=mode) 23 | 24 | def eval(self): 25 | return super().eval() 26 | 27 | def regularizations(self, out): 28 | return {} 29 | 30 | @torch.no_grad() 31 | def export(self, export_config): 32 | return {} 33 | -------------------------------------------------------------------------------- /models/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 7 | 8 | import models 9 | from models.base import BaseModel 10 | from models.utils import scale_anything, get_activation, cleanup, chunk_batch 11 | from models.network_utils import get_encoding, get_mlp, get_encoding_with_network 12 | from utils.misc import get_rank 13 | from systems.utils import update_module_step 14 | from nerfacc import ContractionType 15 | 16 | 17 | def contract_to_unisphere(x, radius, contraction_type): 18 | if contraction_type == ContractionType.AABB: 19 | x = scale_anything(x, (-radius, radius), (0, 1)) 20 | elif contraction_type == ContractionType.UN_BOUNDED_SPHERE: 21 | x = scale_anything(x, (-radius, radius), (0, 1)) 22 | x = x * 2 - 1 # aabb is at [-1, 1] 23 | mag = x.norm(dim=-1, keepdim=True) 24 | mask = mag.squeeze(-1) > 1 25 | x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) 26 | x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] 27 | else: 28 | raise NotImplementedError 29 | return x 30 | 31 | 32 | class MarchingCubeHelper(nn.Module): 33 | def __init__(self, resolution, use_torch=True): 34 | super().__init__() 35 | self.resolution = resolution 36 | self.use_torch = use_torch 37 | self.points_range = (0, 1) 38 | if self.use_torch: 39 | 
import torchmcubes 40 | self.mc_func = torchmcubes.marching_cubes 41 | else: 42 | import mcubes 43 | self.mc_func = mcubes.marching_cubes 44 | self.verts = None 45 | 46 | def grid_vertices(self): 47 | if self.verts is None: 48 | x, y, z = torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution) 49 | x, y, z = torch.meshgrid(x, y, z, indexing='ij') 50 | verts = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1).reshape(-1, 3) 51 | self.verts = verts 52 | return self.verts 53 | 54 | def forward(self, level, threshold=0.): 55 | level = level.float().view(self.resolution, self.resolution, self.resolution) 56 | if self.use_torch: 57 | verts, faces = self.mc_func(level.to(get_rank()), threshold) 58 | verts, faces = verts.cpu(), faces.cpu().long() 59 | else: 60 | verts, faces = self.mc_func(-level.numpy(), threshold) # transform to numpy 61 | verts, faces = torch.from_numpy(verts.astype(np.float32)), torch.from_numpy(faces.astype(np.int64)) # transform back to pytorch 62 | verts = verts / (self.resolution - 1.) 63 | return { 64 | 'v_pos': verts, 65 | 't_pos_idx': faces 66 | } 67 | 68 | 69 | class BaseImplicitGeometry(BaseModel): 70 | def __init__(self, config): 71 | super().__init__(config) 72 | if self.config.isosurface is not None: 73 | assert self.config.isosurface.method in ['mc', 'mc-torch'] 74 | if self.config.isosurface.method == 'mc-torch': 75 | raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.") 76 | self.helper = MarchingCubeHelper(self.config.isosurface.resolution, use_torch=self.config.isosurface.method=='mc-torch') 77 | self.radius = self.config.radius 78 | self.contraction_type = None # assigned in system 79 | 80 | def forward_level(self, points): 81 | raise NotImplementedError 82 | 83 | def isosurface_(self, vmin, vmax): 84 | def batch_func(x): 85 | x = torch.stack([ 86 | scale_anything(x[...,0], (0, 1), (vmin[0], vmax[0])), 87 | scale_anything(x[...,1], (0, 1), (vmin[1], vmax[1])), 88 | scale_anything(x[...,2], (0, 1), (vmin[2], vmax[2])), 89 | ], dim=-1).to(self.rank) 90 | rv = self.forward_level(x).cpu() 91 | cleanup() 92 | return rv 93 | 94 | level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices()) 95 | mesh = self.helper(level, threshold=self.config.isosurface.threshold) 96 | mesh['v_pos'] = torch.stack([ 97 | scale_anything(mesh['v_pos'][...,0], (0, 1), (vmin[0], vmax[0])), 98 | scale_anything(mesh['v_pos'][...,1], (0, 1), (vmin[1], vmax[1])), 99 | scale_anything(mesh['v_pos'][...,2], (0, 1), (vmin[2], vmax[2])) 100 | ], dim=-1) 101 | return mesh 102 | 103 | @torch.no_grad() 104 | def isosurface(self): 105 | if self.config.isosurface is None: 106 | raise NotImplementedError 107 | mesh_coarse = self.isosurface_((-self.radius, -self.radius, -self.radius), (self.radius, self.radius, self.radius)) 108 | vmin, vmax = mesh_coarse['v_pos'].amin(dim=0), mesh_coarse['v_pos'].amax(dim=0) 109 | vmin_ = (vmin - (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) 110 | vmax_ = (vmax + (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) 111 | mesh_fine = self.isosurface_(vmin_, vmax_) 112 | return mesh_fine 113 | 114 | 115 | @models.register('volume-density') 116 | class VolumeDensity(BaseImplicitGeometry): 117 | def setup(self): 118 | self.n_input_dims = self.config.get('n_input_dims', 3) 119 | self.n_output_dims = self.config.feature_dim 120 | 
self.encoding_with_network = get_encoding_with_network(self.n_input_dims, self.n_output_dims, self.config.xyz_encoding_config, self.config.mlp_network_config) 121 | 122 | def forward(self, points): 123 | points = contract_to_unisphere(points, self.radius, self.contraction_type) 124 | out = self.encoding_with_network(points.view(-1, self.n_input_dims)).view(*points.shape[:-1], self.n_output_dims).float() 125 | density, feature = out[...,0], out 126 | if 'density_activation' in self.config: 127 | density = get_activation(self.config.density_activation)(density + float(self.config.density_bias)) 128 | if 'feature_activation' in self.config: 129 | feature = get_activation(self.config.feature_activation)(feature) 130 | return density, feature 131 | 132 | def forward_level(self, points): 133 | points = contract_to_unisphere(points, self.radius, self.contraction_type) 134 | density = self.encoding_with_network(points.reshape(-1, self.n_input_dims)).reshape(*points.shape[:-1], self.n_output_dims)[...,0] 135 | if 'density_activation' in self.config: 136 | density = get_activation(self.config.density_activation)(density + float(self.config.density_bias)) 137 | return -density 138 | 139 | def update_step(self, epoch, global_step): 140 | update_module_step(self.encoding_with_network, epoch, global_step) 141 | 142 | 143 | @models.register('volume-sdf') 144 | class VolumeSDF(BaseImplicitGeometry): 145 | def setup(self): 146 | self.n_output_dims = self.config.feature_dim 147 | encoding = get_encoding(3, self.config.xyz_encoding_config) 148 | network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config) 149 | self.encoding, self.network = encoding, network 150 | self.grad_type = self.config.grad_type 151 | self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3) 152 | # the actual value used in training 153 | # will update at certain steps if finite_difference_eps="progressive" 154 | self._finite_difference_eps = None 155 | if self.grad_type == 'finite_difference': 156 | rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}") 157 | 158 | def forward(self, points, with_grad=True, with_feature=True, with_laplace=False): 159 | with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')): 160 | with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')): 161 | if with_grad and self.grad_type == 'analytic': 162 | if not self.training: 163 | points = points.clone() # points may be in inference mode, get a copy to enable grad 164 | points.requires_grad_(True) 165 | 166 | points_ = points # points in the original scale 167 | points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) 168 | 169 | out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float() 170 | sdf, feature = out[...,0], out 171 | if 'sdf_activation' in self.config: 172 | sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) 173 | if 'feature_activation' in self.config: 174 | feature = get_activation(self.config.feature_activation)(feature) 175 | if with_grad: 176 | if self.grad_type == 'analytic': 177 | grad = torch.autograd.grad( 178 | sdf, points_, grad_outputs=torch.ones_like(sdf), 179 | create_graph=True, retain_graph=True, only_inputs=True 180 | )[0] 181 | elif self.grad_type == 'finite_difference': 182 | eps = self._finite_difference_eps 183 | 
offsets = torch.as_tensor( 184 | [ 185 | [eps, 0.0, 0.0], 186 | [-eps, 0.0, 0.0], 187 | [0.0, eps, 0.0], 188 | [0.0, -eps, 0.0], 189 | [0.0, 0.0, eps], 190 | [0.0, 0.0, -eps], 191 | ] 192 | ).to(points_) 193 | points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius) 194 | points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1)) 195 | points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float() 196 | grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps 197 | 198 | if with_laplace: 199 | laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2) 200 | 201 | rv = [sdf] 202 | if with_grad: 203 | rv.append(grad) 204 | if with_feature: 205 | rv.append(feature) 206 | if with_laplace: 207 | assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'" 208 | rv.append(laplace) 209 | rv = [v if self.training else v.detach() for v in rv] 210 | return rv[0] if len(rv) == 1 else rv 211 | 212 | def forward_level(self, points): 213 | points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) 214 | sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0] 215 | if 'sdf_activation' in self.config: 216 | sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) 217 | return sdf 218 | 219 | def update_step(self, epoch, global_step): 220 | update_module_step(self.encoding, epoch, global_step) 221 | update_module_step(self.network, epoch, global_step) 222 | if self.grad_type == 'finite_difference': 223 | if isinstance(self.finite_difference_eps, float): 224 | self._finite_difference_eps = self.finite_difference_eps 225 | elif self.finite_difference_eps == 'progressive': 226 | hg_conf = self.config.xyz_encoding_config 227 | assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid" 228 | current_level = min( 229 | hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, 230 | hg_conf.n_levels 231 | ) 232 | grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1) 233 | grid_size = 2 * self.config.radius / grid_res 234 | if grid_size != self._finite_difference_eps: 235 | rank_zero_info(f"Update finite_difference_eps to {grid_size}") 236 | self._finite_difference_eps = grid_size 237 | else: 238 | raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}") 239 | -------------------------------------------------------------------------------- /models/nerf.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import models 8 | from models.base import BaseModel 9 | from models.utils import chunk_batch 10 | from systems.utils import update_module_step 11 | from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays 12 | 13 | 14 | @models.register('nerf') 15 | class NeRFModel(BaseModel): 16 | def setup(self): 17 | self.geometry = models.make(self.config.geometry.name, self.config.geometry) 18 | self.texture = models.make(self.config.texture.name, self.config.texture) 19 | self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, 
-self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) 20 | 21 | if self.config.learned_background: 22 | self.occupancy_grid_res = 256 23 | self.near_plane, self.far_plane = 0.2, 1e4 24 | self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. # approximate 25 | self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size) 26 | self.contraction_type = ContractionType.UN_BOUNDED_SPHERE 27 | else: 28 | self.occupancy_grid_res = 128 29 | self.near_plane, self.far_plane = None, None 30 | self.cone_angle = 0.0 31 | self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray 32 | self.contraction_type = ContractionType.AABB 33 | 34 | self.geometry.contraction_type = self.contraction_type 35 | 36 | if self.config.grid_prune: 37 | self.occupancy_grid = OccupancyGrid( 38 | roi_aabb=self.scene_aabb, 39 | resolution=self.occupancy_grid_res, 40 | contraction_type=self.contraction_type 41 | ) 42 | self.randomized = self.config.randomized 43 | self.background_color = None 44 | 45 | def update_step(self, epoch, global_step): 46 | update_module_step(self.geometry, epoch, global_step) 47 | update_module_step(self.texture, epoch, global_step) 48 | 49 | def occ_eval_fn(x): 50 | density, _ = self.geometry(x) 51 | # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series 52 | return density[...,None] * self.render_step_size 53 | 54 | if self.training and self.config.grid_prune: 55 | self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn) 56 | 57 | def isosurface(self): 58 | mesh = self.geometry.isosurface() 59 | return mesh 60 | 61 | def forward_(self, rays): 62 | n_rays = rays.shape[0] 63 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 64 | 65 | def sigma_fn(t_starts, t_ends, ray_indices): 66 | ray_indices = ray_indices.long() 67 | t_origins = rays_o[ray_indices] 68 | t_dirs = rays_d[ray_indices] 69 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2. 70 | density, _ = self.geometry(positions) 71 | return density[...,None] 72 | 73 | def rgb_sigma_fn(t_starts, t_ends, ray_indices): 74 | ray_indices = ray_indices.long() 75 | t_origins = rays_o[ray_indices] 76 | t_dirs = rays_d[ray_indices] 77 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2. 78 | density, feature = self.geometry(positions) 79 | rgb = self.texture(feature, t_dirs) 80 | return rgb, density[...,None] 81 | 82 | with torch.no_grad(): 83 | ray_indices, t_starts, t_ends = ray_marching( 84 | rays_o, rays_d, 85 | scene_aabb=None if self.config.learned_background else self.scene_aabb, 86 | grid=self.occupancy_grid if self.config.grid_prune else None, 87 | sigma_fn=sigma_fn, 88 | near_plane=self.near_plane, far_plane=self.far_plane, 89 | render_step_size=self.render_step_size, 90 | stratified=self.randomized, 91 | cone_angle=self.cone_angle, 92 | alpha_thre=0.0 93 | ) 94 | 95 | ray_indices = ray_indices.long() 96 | t_origins = rays_o[ray_indices] 97 | t_dirs = rays_d[ray_indices] 98 | midpoints = (t_starts + t_ends) / 2. 
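# Note on the compositing below: positions are the interval midpoints along each
# ray; render_weight_from_density converts densities to per-sample weights via the
# standard volume-rendering quadrature, w_i = T_i * (1 - exp(-sigma_i * delta_i))
# with transmittance T_i = exp(-sum_{j<i} sigma_j * delta_j), and
# accumulate_along_rays then sums w_i * value_i per ray (opacity, depth, color).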
99 | positions = t_origins + t_dirs * midpoints 100 | intervals = t_ends - t_starts 101 | 102 | density, feature = self.geometry(positions) 103 | rgb = self.texture(feature, t_dirs) 104 | 105 | weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) 106 | opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) 107 | depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) 108 | comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) 109 | comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) 110 | 111 | out = { 112 | 'comp_rgb': comp_rgb, 113 | 'opacity': opacity, 114 | 'depth': depth, 115 | 'rays_valid': opacity > 0, 116 | 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) 117 | } 118 | 119 | if self.training: 120 | out.update({ 121 | 'weights': weights.view(-1), 122 | 'points': midpoints.view(-1), 123 | 'intervals': intervals.view(-1), 124 | 'ray_indices': ray_indices.view(-1) 125 | }) 126 | 127 | return out 128 | 129 | def forward(self, rays): 130 | if self.training: 131 | out = self.forward_(rays) 132 | else: 133 | out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) 134 | return { 135 | **out, 136 | } 137 | 138 | def train(self, mode=True): 139 | self.randomized = mode and self.config.randomized 140 | return super().train(mode=mode) 141 | 142 | def eval(self): 143 | self.randomized = False 144 | return super().eval() 145 | 146 | def regularizations(self, out): 147 | losses = {} 148 | losses.update(self.geometry.regularizations(out)) 149 | losses.update(self.texture.regularizations(out)) 150 | return losses 151 | 152 | @torch.no_grad() 153 | def export(self, export_config): 154 | mesh = self.isosurface() 155 | if export_config.export_vertex_color: 156 | _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank)) 157 | viewdirs = torch.zeros(feature.shape[0], 3).to(feature) 158 | viewdirs[...,2] = -1. 
# set the viewing directions to be -z (looking down) 159 | rgb = self.texture(feature, viewdirs).clamp(0,1) 160 | mesh['v_rgb'] = rgb.cpu() 161 | return mesh 162 | -------------------------------------------------------------------------------- /models/network_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import tinycudann as tcnn 7 | 8 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info 9 | 10 | from utils.misc import config_to_primitive, get_rank 11 | from models.utils import get_activation 12 | from systems.utils import update_module_step 13 | 14 | class VanillaFrequency(nn.Module): 15 | def __init__(self, in_channels, config): 16 | super().__init__() 17 | self.N_freqs = config['n_frequencies'] 18 | self.in_channels, self.n_input_dims = in_channels, in_channels 19 | self.funcs = [torch.sin, torch.cos] 20 | self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs) 21 | self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs) 22 | self.n_masking_step = config.get('n_masking_step', 0) 23 | self.update_step(None, None) # mask should be updated at the beginning each step 24 | 25 | def forward(self, x): 26 | out = [] 27 | for freq, mask in zip(self.freq_bands, self.mask): 28 | for func in self.funcs: 29 | out += [func(freq*x) * mask] 30 | return torch.cat(out, -1) 31 | 32 | def update_step(self, epoch, global_step): 33 | if self.n_masking_step <= 0 or global_step is None: 34 | self.mask = torch.ones(self.N_freqs, dtype=torch.float32) 35 | else: 36 | self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2. 37 | rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}') 38 | 39 | 40 | class ProgressiveBandHashGrid(nn.Module): 41 | def __init__(self, in_channels, config): 42 | super().__init__() 43 | self.n_input_dims = in_channels 44 | encoding_config = config.copy() 45 | encoding_config['otype'] = 'HashGrid' 46 | with torch.cuda.device(get_rank()): 47 | self.encoding = tcnn.Encoding(in_channels, encoding_config) 48 | self.n_output_dims = self.encoding.n_output_dims 49 | self.n_level = config['n_levels'] 50 | self.n_features_per_level = config['n_features_per_level'] 51 | self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps'] 52 | self.current_level = self.start_level 53 | self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank()) 54 | 55 | def forward(self, x): 56 | enc = self.encoding(x) 57 | enc = enc * self.mask 58 | return enc 59 | 60 | def update_step(self, epoch, global_step): 61 | current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level) 62 | if current_level > self.current_level: 63 | rank_zero_info(f'Update grid level to {current_level}') 64 | self.current_level = current_level 65 | self.mask[:self.current_level * self.n_features_per_level] = 1. 
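# Note: ProgressiveBandHashGrid reveals one extra hash-grid level every
# `update_steps` optimizer steps. With illustrative (hypothetical) settings
# start_level=4, start_step=0, update_steps=1000, n_levels=16, the mask covers
# 4 levels at step 0 and all 16 levels from step 12000 on; features of the
# not-yet-active levels are zeroed by the mask in forward().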
66 | 67 | 68 | class CompositeEncoding(nn.Module): 69 | def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.): 70 | super(CompositeEncoding, self).__init__() 71 | self.encoding = encoding 72 | self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset 73 | self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims 74 | 75 | def forward(self, x, *args): 76 | return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1) 77 | 78 | def update_step(self, epoch, global_step): 79 | update_module_step(self.encoding, epoch, global_step) 80 | 81 | 82 | def get_encoding(n_input_dims, config): 83 | # input suppose to be range [0, 1] 84 | if config.otype == 'VanillaFrequency': 85 | encoding = VanillaFrequency(n_input_dims, config_to_primitive(config)) 86 | elif config.otype == 'ProgressiveBandHashGrid': 87 | encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config)) 88 | else: 89 | with torch.cuda.device(get_rank()): 90 | encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config)) 91 | encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.) 92 | return encoding 93 | 94 | 95 | class VanillaMLP(nn.Module): 96 | def __init__(self, dim_in, dim_out, config): 97 | super().__init__() 98 | self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers'] 99 | self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False) 100 | self.sphere_init_radius = config.get('sphere_init_radius', 0.5) 101 | self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()] 102 | for i in range(self.n_hidden_layers - 1): 103 | self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()] 104 | self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)] 105 | self.layers = nn.Sequential(*self.layers) 106 | self.output_activation = get_activation(config['output_activation']) 107 | 108 | @torch.cuda.amp.autocast(False) 109 | def forward(self, x): 110 | x = self.layers(x.float()) 111 | x = self.output_activation(x) 112 | return x 113 | 114 | def make_linear(self, dim_in, dim_out, is_first, is_last): 115 | layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality 116 | if self.sphere_init: 117 | if is_last: 118 | torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) 119 | torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001) 120 | elif is_first: 121 | torch.nn.init.constant_(layer.bias, 0.0) 122 | torch.nn.init.constant_(layer.weight[:, 3:], 0.0) 123 | torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out)) 124 | else: 125 | torch.nn.init.constant_(layer.bias, 0.0) 126 | torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out)) 127 | else: 128 | torch.nn.init.constant_(layer.bias, 0.0) 129 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') 130 | 131 | if self.weight_norm: 132 | layer = nn.utils.weight_norm(layer) 133 | return layer 134 | 135 | def make_activation(self): 136 | if self.sphere_init: 137 | return nn.Softplus(beta=100) 138 | else: 139 | return nn.ReLU(inplace=True) 140 | 141 | 142 | def sphere_init_tcnn_network(n_input_dims, 
n_output_dims, config, network): 143 | rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.') 144 | """ 145 | from https://github.com/NVlabs/tiny-cuda-nn/issues/96 146 | It's the weight matrices of each layer laid out in row-major order and then concatenated. 147 | Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP). 148 | The padded input dimensions get a constant value of 1.0, 149 | whereas the padded output dimensions are simply ignored, 150 | so the weights pertaining to those can have any value. 151 | """ 152 | padto = 16 if config.otype == 'FullyFusedMLP' else 8 153 | n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto 154 | n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto 155 | data = list(network.parameters())[0].data 156 | assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2 157 | new_data = [] 158 | # first layer 159 | weight = torch.zeros((config.n_neurons, n_input_dims)).to(data) 160 | torch.nn.init.constant_(weight[:, 3:], 0.0) 161 | torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) 162 | new_data.append(weight.flatten()) 163 | # hidden layers 164 | for i in range(config.n_hidden_layers - 1): 165 | weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data) 166 | torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) 167 | new_data.append(weight.flatten()) 168 | # last layer 169 | weight = torch.zeros((n_output_dims, config.n_neurons)).to(data) 170 | torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001) 171 | new_data.append(weight.flatten()) 172 | new_data = torch.cat(new_data) 173 | data.copy_(new_data) 174 | 175 | 176 | def get_mlp(n_input_dims, n_output_dims, config): 177 | if config.otype == 'VanillaMLP': 178 | network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config)) 179 | else: 180 | with torch.cuda.device(get_rank()): 181 | network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config)) 182 | if config.get('sphere_init', False): 183 | sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network) 184 | return network 185 | 186 | 187 | class EncodingWithNetwork(nn.Module): 188 | def __init__(self, encoding, network): 189 | super().__init__() 190 | self.encoding, self.network = encoding, network 191 | 192 | def forward(self, x): 193 | return self.network(self.encoding(x)) 194 | 195 | def update_step(self, epoch, global_step): 196 | update_module_step(self.encoding, epoch, global_step) 197 | update_module_step(self.network, epoch, global_step) 198 | 199 | 200 | def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config): 201 | # input suppose to be range [0, 1] 202 | if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \ 203 | or network_config.otype in ['VanillaMLP']: 204 | encoding = get_encoding(n_input_dims, encoding_config) 205 | network = get_mlp(encoding.n_output_dims, n_output_dims, network_config) 206 | encoding_with_network = EncodingWithNetwork(encoding, network) 207 | else: 208 | with torch.cuda.device(get_rank()): 209 | encoding_with_network = tcnn.NetworkWithInputEncoding( 210 | n_input_dims=n_input_dims, 211 | n_output_dims=n_output_dims, 212 | encoding_config=config_to_primitive(encoding_config), 213 | network_config=config_to_primitive(network_config) 214 | ) 215 | 
return encoding_with_network 216 | -------------------------------------------------------------------------------- /models/neus.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import models 8 | from models.base import BaseModel 9 | from models.utils import chunk_batch 10 | from systems.utils import update_module_step 11 | from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, render_weight_from_alpha, accumulate_along_rays 12 | from nerfacc.intersection import ray_aabb_intersect 13 | 14 | 15 | class VarianceNetwork(nn.Module): 16 | def __init__(self, config): 17 | super(VarianceNetwork, self).__init__() 18 | self.config = config 19 | self.init_val = self.config.init_val 20 | self.register_parameter('variance', nn.Parameter(torch.tensor(self.config.init_val))) 21 | self.modulate = self.config.get('modulate', False) 22 | if self.modulate: 23 | self.mod_start_steps = self.config.mod_start_steps 24 | self.reach_max_steps = self.config.reach_max_steps 25 | self.max_inv_s = self.config.max_inv_s 26 | 27 | @property 28 | def inv_s(self): 29 | val = torch.exp(self.variance * 10.0) 30 | if self.modulate and self.do_mod: 31 | val = val.clamp_max(self.mod_val) 32 | return val 33 | 34 | def forward(self, x): 35 | return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s 36 | 37 | def update_step(self, epoch, global_step): 38 | if self.modulate: 39 | self.do_mod = global_step > self.mod_start_steps 40 | if not self.do_mod: 41 | self.prev_inv_s = self.inv_s.item() 42 | else: 43 | self.mod_val = min((global_step / self.reach_max_steps) * (self.max_inv_s - self.prev_inv_s) + self.prev_inv_s, self.max_inv_s) 44 | 45 | 46 | @models.register('neus') 47 | class NeuSModel(BaseModel): 48 | def setup(self): 49 | self.geometry = models.make(self.config.geometry.name, self.config.geometry) 50 | self.texture = models.make(self.config.texture.name, self.config.texture) 51 | self.geometry.contraction_type = ContractionType.AABB 52 | 53 | if self.config.learned_background: 54 | self.geometry_bg = models.make(self.config.geometry_bg.name, self.config.geometry_bg) 55 | self.texture_bg = models.make(self.config.texture_bg.name, self.config.texture_bg) 56 | self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE 57 | self.near_plane_bg, self.far_plane_bg = 0.1, 1e3 58 | self.cone_angle_bg = 10**(math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg) - 1. 
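# Note: as in NeRFModel, a non-zero cone_angle_bg makes nerfacc grow the marching
# step with distance from the camera (roughly step = max(t * cone_angle_bg,
# render_step_size_bg)), so about num_samples_per_ray_bg geometrically spaced
# samples cover the unbounded background out to far_plane_bg.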
59 | self.render_step_size_bg = 0.01 60 | 61 | self.variance = VarianceNetwork(self.config.variance) 62 | self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) 63 | if self.config.grid_prune: 64 | self.occupancy_grid = OccupancyGrid( 65 | roi_aabb=self.scene_aabb, 66 | resolution=128, 67 | contraction_type=ContractionType.AABB 68 | ) 69 | if self.config.learned_background: 70 | self.occupancy_grid_bg = OccupancyGrid( 71 | roi_aabb=self.scene_aabb, 72 | resolution=256, 73 | contraction_type=ContractionType.UN_BOUNDED_SPHERE 74 | ) 75 | self.randomized = self.config.randomized 76 | self.background_color = None 77 | self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray 78 | 79 | def update_step(self, epoch, global_step): 80 | update_module_step(self.geometry, epoch, global_step) 81 | update_module_step(self.texture, epoch, global_step) 82 | if self.config.learned_background: 83 | update_module_step(self.geometry_bg, epoch, global_step) 84 | update_module_step(self.texture_bg, epoch, global_step) 85 | update_module_step(self.variance, epoch, global_step) 86 | 87 | cos_anneal_end = self.config.get('cos_anneal_end', 0) 88 | self.cos_anneal_ratio = 1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end) 89 | 90 | def occ_eval_fn(x): 91 | sdf = self.geometry(x, with_grad=False, with_feature=False) 92 | inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) 93 | inv_s = inv_s.expand(sdf.shape[0], 1) 94 | estimated_next_sdf = sdf[...,None] - self.render_step_size * 0.5 95 | estimated_prev_sdf = sdf[...,None] + self.render_step_size * 0.5 96 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 97 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 98 | p = prev_cdf - next_cdf 99 | c = prev_cdf 100 | alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0) 101 | return alpha 102 | 103 | def occ_eval_fn_bg(x): 104 | density, _ = self.geometry_bg(x) 105 | # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size_bg) based on taylor series 106 | return density[...,None] * self.render_step_size_bg 107 | 108 | if self.training and self.config.grid_prune: 109 | self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn, occ_thre=self.config.get('grid_prune_occ_thre', 0.01)) 110 | if self.config.learned_background: 111 | self.occupancy_grid_bg.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn_bg, occ_thre=self.config.get('grid_prune_occ_thre_bg', 0.01)) 112 | 113 | def isosurface(self): 114 | mesh = self.geometry.isosurface() 115 | return mesh 116 | 117 | def get_alpha(self, sdf, normal, dirs, dists): 118 | inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 119 | inv_s = inv_s.expand(sdf.shape[0], 1) 120 | 121 | true_cos = (dirs * normal).sum(-1, keepdim=True) 122 | 123 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 124 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
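# Sketch of the NeuS alpha computed below: with Phi_s the sigmoid of sharpness
# inv_s and f the SDF, each section of length `dists` gets
#   alpha = clip((Phi_s(f_prev) - Phi_s(f_next)) / Phi_s(f_prev), 0, 1),
# where f_prev / f_next are estimated from the mid-point SDF using the
# (annealed) cosine between the ray direction and the surface normal.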
125 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + 126 | F.relu(-true_cos) * self.cos_anneal_ratio) # always non-positive 127 | 128 | # Estimate signed distances at section points 129 | estimated_next_sdf = sdf[...,None] + iter_cos * dists.reshape(-1, 1) * 0.5 130 | estimated_prev_sdf = sdf[...,None] - iter_cos * dists.reshape(-1, 1) * 0.5 131 | 132 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 133 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 134 | 135 | p = prev_cdf - next_cdf 136 | c = prev_cdf 137 | 138 | alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0) 139 | return alpha 140 | 141 | def forward_bg_(self, rays): 142 | n_rays = rays.shape[0] 143 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 144 | 145 | def sigma_fn(t_starts, t_ends, ray_indices): 146 | ray_indices = ray_indices.long() 147 | t_origins = rays_o[ray_indices] 148 | t_dirs = rays_d[ray_indices] 149 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2. 150 | density, _ = self.geometry_bg(positions) 151 | return density[...,None] 152 | 153 | _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb) 154 | # if the ray intersects with the bounding box, start from the farther intersection point 155 | # otherwise start from self.far_plane_bg 156 | # note that in nerfacc t_max is set to 1e10 if there is no intersection 157 | near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max) 158 | with torch.no_grad(): 159 | ray_indices, t_starts, t_ends = ray_marching( 160 | rays_o, rays_d, 161 | scene_aabb=None, 162 | grid=self.occupancy_grid_bg if self.config.grid_prune else None, 163 | sigma_fn=sigma_fn, 164 | near_plane=near_plane, far_plane=self.far_plane_bg, 165 | render_step_size=self.render_step_size_bg, 166 | stratified=self.randomized, 167 | cone_angle=self.cone_angle_bg, 168 | alpha_thre=0.0 169 | ) 170 | 171 | ray_indices = ray_indices.long() 172 | t_origins = rays_o[ray_indices] 173 | t_dirs = rays_d[ray_indices] 174 | midpoints = (t_starts + t_ends) / 2. 
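# The background branch below is a plain density field: density and color are queried at the sample
# midpoints and composited with the usual volume-rendering weights
# w_i = T_i * (1 - exp(-sigma_i * delta_i)) via nerfacc's render_weight_from_density, unlike the
# SDF-based foreground handled in forward_().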
175 | positions = t_origins + t_dirs * midpoints 176 | intervals = t_ends - t_starts 177 | 178 | density, feature = self.geometry_bg(positions) 179 | rgb = self.texture_bg(feature, t_dirs) 180 | 181 | weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) 182 | opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) 183 | depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) 184 | comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) 185 | comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) 186 | 187 | out = { 188 | 'comp_rgb': comp_rgb, 189 | 'opacity': opacity, 190 | 'depth': depth, 191 | 'rays_valid': opacity > 0, 192 | 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) 193 | } 194 | 195 | if self.training: 196 | out.update({ 197 | 'weights': weights.view(-1), 198 | 'points': midpoints.view(-1), 199 | 'intervals': intervals.view(-1), 200 | 'ray_indices': ray_indices.view(-1) 201 | }) 202 | 203 | return out 204 | 205 | def forward_(self, rays): 206 | n_rays = rays.shape[0] 207 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 208 | 209 | with torch.no_grad(): 210 | ray_indices, t_starts, t_ends = ray_marching( 211 | rays_o, rays_d, 212 | scene_aabb=self.scene_aabb, 213 | grid=self.occupancy_grid if self.config.grid_prune else None, 214 | alpha_fn=None, 215 | near_plane=None, far_plane=None, 216 | render_step_size=self.render_step_size, 217 | stratified=self.randomized, 218 | cone_angle=0.0, 219 | alpha_thre=0.0 220 | ) 221 | 222 | ray_indices = ray_indices.long() 223 | t_origins = rays_o[ray_indices] 224 | t_dirs = rays_d[ray_indices] 225 | midpoints = (t_starts + t_ends) / 2. 
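# Foreground pass: query the SDF (with analytic gradients, or finite differences plus a Laplacian
# when geometry.grad_type == 'finite_difference'), convert each sample to an alpha via get_alpha(),
# and composite with render_weight_from_alpha rather than a density-based quadrature.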
226 | positions = t_origins + t_dirs * midpoints 227 | dists = t_ends - t_starts 228 | 229 | if self.config.geometry.grad_type == 'finite_difference': 230 | sdf, sdf_grad, feature, sdf_laplace = self.geometry(positions, with_grad=True, with_feature=True, with_laplace=True) 231 | else: 232 | sdf, sdf_grad, feature = self.geometry(positions, with_grad=True, with_feature=True) 233 | normal = F.normalize(sdf_grad, p=2, dim=-1) 234 | alpha = self.get_alpha(sdf, normal, t_dirs, dists)[...,None] 235 | rgb = self.texture(feature, t_dirs, normal) 236 | 237 | weights = render_weight_from_alpha(alpha, ray_indices=ray_indices, n_rays=n_rays) 238 | opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) 239 | depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) 240 | comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) 241 | 242 | comp_normal = accumulate_along_rays(weights, ray_indices, values=normal, n_rays=n_rays) 243 | comp_normal = F.normalize(comp_normal, p=2, dim=-1) 244 | 245 | out = { 246 | 'comp_rgb': comp_rgb, 247 | 'comp_normal': comp_normal, 248 | 'opacity': opacity, 249 | 'depth': depth, 250 | 'rays_valid': opacity > 0, 251 | 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) 252 | } 253 | 254 | if self.training: 255 | out.update({ 256 | 'sdf_samples': sdf, 257 | 'sdf_grad_samples': sdf_grad, 258 | 'weights': weights.view(-1), 259 | 'points': midpoints.view(-1), 260 | 'intervals': dists.view(-1), 261 | 'ray_indices': ray_indices.view(-1) 262 | }) 263 | if self.config.geometry.grad_type == 'finite_difference': 264 | out.update({ 265 | 'sdf_laplace_samples': sdf_laplace 266 | }) 267 | 268 | if self.config.learned_background: 269 | out_bg = self.forward_bg_(rays) 270 | else: 271 | out_bg = { 272 | 'comp_rgb': self.background_color[None,:].expand(*comp_rgb.shape), 273 | 'num_samples': torch.zeros_like(out['num_samples']), 274 | 'rays_valid': torch.zeros_like(out['rays_valid']) 275 | } 276 | 277 | out_full = { 278 | 'comp_rgb': out['comp_rgb'] + out_bg['comp_rgb'] * (1.0 - out['opacity']), 279 | 'num_samples': out['num_samples'] + out_bg['num_samples'], 280 | 'rays_valid': out['rays_valid'] | out_bg['rays_valid'] 281 | } 282 | 283 | return { 284 | **out, 285 | **{k + '_bg': v for k, v in out_bg.items()}, 286 | **{k + '_full': v for k, v in out_full.items()} 287 | } 288 | 289 | def forward(self, rays): 290 | if self.training: 291 | out = self.forward_(rays) 292 | else: 293 | out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) 294 | return { 295 | **out, 296 | 'inv_s': self.variance.inv_s 297 | } 298 | 299 | def train(self, mode=True): 300 | self.randomized = mode and self.config.randomized 301 | return super().train(mode=mode) 302 | 303 | def eval(self): 304 | self.randomized = False 305 | return super().eval() 306 | 307 | def regularizations(self, out): 308 | losses = {} 309 | losses.update(self.geometry.regularizations(out)) 310 | losses.update(self.texture.regularizations(out)) 311 | return losses 312 | 313 | @torch.no_grad() 314 | def export(self, export_config): 315 | mesh = self.isosurface() 316 | if export_config.export_vertex_color: 317 | _, sdf_grad, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank), with_grad=True, with_feature=True) 318 | normal = F.normalize(sdf_grad, p=2, dim=-1) 319 | rgb = self.texture(feature, -normal, normal) # set the viewing directions to the normal to get "albedo" 320 | 
mesh['v_rgb'] = rgb.cpu() 321 | return mesh 322 | -------------------------------------------------------------------------------- /models/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def cast_rays(ori, dir, z_vals): 6 | return ori[..., None, :] + z_vals[..., None] * dir[..., None, :] 7 | 8 | 9 | def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): 10 | pixel_center = 0.5 if use_pixel_centers else 0 11 | i, j = np.meshgrid( 12 | np.arange(W, dtype=np.float32) + pixel_center, 13 | np.arange(H, dtype=np.float32) + pixel_center, 14 | indexing='xy' 15 | ) 16 | i, j = torch.from_numpy(i), torch.from_numpy(j) 17 | 18 | directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) 19 | 20 | return directions 21 | 22 | 23 | def get_rays(directions, c2w, keepdim=False): 24 | # Rotate ray directions from camera coordinate to the world coordinate 25 | # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? 26 | assert directions.shape[-1] == 3 27 | 28 | if directions.ndim == 2: # (N_rays, 3) 29 | assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) 30 | rays_d = (directions[:,None,:] * c2w[:,:3,:3]).sum(-1) # (N_rays, 3) 31 | rays_o = c2w[:,:,3].expand(rays_d.shape) 32 | elif directions.ndim == 3: # (H, W, 3) 33 | if c2w.ndim == 2: # (4, 4) 34 | rays_d = (directions[:,:,None,:] * c2w[None,None,:3,:3]).sum(-1) # (H, W, 3) 35 | rays_o = c2w[None,None,:,3].expand(rays_d.shape) 36 | elif c2w.ndim == 3: # (B, 4, 4) 37 | rays_d = (directions[None,:,:,None,:] * c2w[:,None,None,:3,:3]).sum(-1) # (B, H, W, 3) 38 | rays_o = c2w[:,None,None,:,3].expand(rays_d.shape) 39 | 40 | if not keepdim: 41 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 42 | 43 | return rays_o, rays_d 44 | -------------------------------------------------------------------------------- /models/texture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import models 5 | from models.utils import get_activation 6 | from models.network_utils import get_encoding, get_mlp 7 | from systems.utils import update_module_step 8 | 9 | 10 | @models.register('volume-radiance') 11 | class VolumeRadiance(nn.Module): 12 | def __init__(self, config): 13 | super(VolumeRadiance, self).__init__() 14 | self.config = config 15 | self.n_dir_dims = self.config.get('n_dir_dims', 3) 16 | self.n_output_dims = 3 17 | encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config) 18 | self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims 19 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) 20 | self.encoding = encoding 21 | self.network = network 22 | 23 | def forward(self, features, dirs, *args): 24 | dirs = (dirs + 1.) / 2. 
# (-1, 1) => (0, 1) 25 | dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims)) 26 | network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) 27 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() 28 | if 'color_activation' in self.config: 29 | color = get_activation(self.config.color_activation)(color) 30 | return color 31 | 32 | def update_step(self, epoch, global_step): 33 | update_module_step(self.encoding, epoch, global_step) 34 | 35 | def regularizations(self, out): 36 | return {} 37 | 38 | 39 | @models.register('volume-color') 40 | class VolumeColor(nn.Module): 41 | def __init__(self, config): 42 | super(VolumeColor, self).__init__() 43 | self.config = config 44 | self.n_output_dims = 3 45 | self.n_input_dims = self.config.input_feature_dim 46 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) 47 | self.network = network 48 | 49 | def forward(self, features, *args): 50 | network_inp = features.view(-1, features.shape[-1]) 51 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() 52 | if 'color_activation' in self.config: 53 | color = get_activation(self.config.color_activation)(color) 54 | return color 55 | 56 | def regularizations(self, out): 57 | return {} 58 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from collections import defaultdict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | from torch.cuda.amp import custom_bwd, custom_fwd 9 | 10 | import tinycudann as tcnn 11 | 12 | 13 | def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): 14 | B = None 15 | for arg in args: 16 | if isinstance(arg, torch.Tensor): 17 | B = arg.shape[0] 18 | break 19 | out = defaultdict(list) 20 | out_type = None 21 | for i in range(0, B, chunk_size): 22 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) 23 | if out_chunk is None: 24 | continue 25 | out_type = type(out_chunk) 26 | if isinstance(out_chunk, torch.Tensor): 27 | out_chunk = {0: out_chunk} 28 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): 29 | chunk_length = len(out_chunk) 30 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} 31 | elif isinstance(out_chunk, dict): 32 | pass 33 | else: 34 | print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') 35 | exit(1) 36 | for k, v in out_chunk.items(): 37 | v = v if torch.is_grad_enabled() else v.detach() 38 | v = v.cpu() if move_to_cpu else v 39 | out[k].append(v) 40 | 41 | if out_type is None: 42 | return 43 | 44 | out = {k: torch.cat(v, dim=0) for k, v in out.items()} 45 | if out_type is torch.Tensor: 46 | return out[0] 47 | elif out_type in [tuple, list]: 48 | return out_type([out[i] for i in range(chunk_length)]) 49 | elif out_type is dict: 50 | return out 51 | 52 | 53 | class _TruncExp(Function): # pylint: disable=abstract-method 54 | # Implementation from torch-ngp: 55 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 56 | @staticmethod 57 | @custom_fwd(cast_inputs=torch.float32) 58 | def forward(ctx, x): # pylint: disable=arguments-differ 59 | ctx.save_for_backward(x) 60 | 
return torch.exp(x) 61 | 62 | @staticmethod 63 | @custom_bwd 64 | def backward(ctx, g): # pylint: disable=arguments-differ 65 | x = ctx.saved_tensors[0] 66 | return g * torch.exp(torch.clamp(x, max=15)) 67 | 68 | trunc_exp = _TruncExp.apply 69 | 70 | 71 | def get_activation(name): 72 | if name is None: 73 | return lambda x: x 74 | name = name.lower() 75 | if name == 'none': 76 | return lambda x: x 77 | elif name.startswith('scale'): 78 | scale_factor = float(name[5:]) 79 | return lambda x: x.clamp(0., scale_factor) / scale_factor 80 | elif name.startswith('clamp'): 81 | clamp_max = float(name[5:]) 82 | return lambda x: x.clamp(0., clamp_max) 83 | elif name.startswith('mul'): 84 | mul_factor = float(name[3:]) 85 | return lambda x: x * mul_factor 86 | elif name == 'lin2srgb': 87 | return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) 88 | elif name == 'trunc_exp': 89 | return trunc_exp 90 | elif name.startswith('+') or name.startswith('-'): 91 | return lambda x: x + float(name) 92 | elif name == 'sigmoid': 93 | return lambda x: torch.sigmoid(x) 94 | elif name == 'tanh': 95 | return lambda x: torch.tanh(x) 96 | else: 97 | return getattr(F, name) 98 | 99 | 100 | def dot(x, y): 101 | return torch.sum(x*y, -1, keepdim=True) 102 | 103 | 104 | def reflect(x, n): 105 | return 2 * dot(x, n) * n - x 106 | 107 | 108 | def scale_anything(dat, inp_scale, tgt_scale): 109 | if inp_scale is None: 110 | inp_scale = [dat.min(), dat.max()] 111 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 112 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 113 | return dat 114 | 115 | 116 | def cleanup(): 117 | gc.collect() 118 | torch.cuda.empty_cache() 119 | tcnn.free_temporary_memory() 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning<2 2 | omegaconf==2.2.3 3 | nerfacc==0.3.3 4 | matplotlib 5 | opencv-python 6 | imageio 7 | imageio-ffmpeg 8 | scipy 9 | PyMCubes 10 | pyransac3d 11 | torch_efficient_distloss 12 | tensorboard 13 | -------------------------------------------------------------------------------- /scripts/imgs2poses.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This file is adapted from https://github.com/Fyusion/LLFF. 
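Typical invocation (a sketch; assumes the `colmap` binary is available on PATH):
    python scripts/imgs2poses.py --match_type exhaustive_matcher <scenedir>
where <scenedir> contains an `images/` folder; the sparse COLMAP reconstruction is written to
<scenedir>/sparse and later looked up under <scenedir>/sparse/0.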
4 | """ 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import subprocess 10 | 11 | 12 | def run_colmap(basedir, match_type): 13 | logfile_name = os.path.join(basedir, 'colmap_output.txt') 14 | logfile = open(logfile_name, 'w') 15 | 16 | feature_extractor_args = [ 17 | 'colmap', 'feature_extractor', 18 | '--database_path', os.path.join(basedir, 'database.db'), 19 | '--image_path', os.path.join(basedir, 'images'), 20 | '--ImageReader.single_camera', '1' 21 | ] 22 | feat_output = ( subprocess.check_output(feature_extractor_args, universal_newlines=True) ) 23 | logfile.write(feat_output) 24 | print('Features extracted') 25 | 26 | exhaustive_matcher_args = [ 27 | 'colmap', match_type, 28 | '--database_path', os.path.join(basedir, 'database.db'), 29 | ] 30 | 31 | match_output = ( subprocess.check_output(exhaustive_matcher_args, universal_newlines=True) ) 32 | logfile.write(match_output) 33 | print('Features matched') 34 | 35 | p = os.path.join(basedir, 'sparse') 36 | if not os.path.exists(p): 37 | os.makedirs(p) 38 | 39 | mapper_args = [ 40 | 'colmap', 'mapper', 41 | '--database_path', os.path.join(basedir, 'database.db'), 42 | '--image_path', os.path.join(basedir, 'images'), 43 | '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6 44 | '--Mapper.num_threads', '16', 45 | '--Mapper.init_min_tri_angle', '4', 46 | '--Mapper.multiple_models', '0', 47 | '--Mapper.extract_colors', '0', 48 | ] 49 | 50 | map_output = ( subprocess.check_output(mapper_args, universal_newlines=True) ) 51 | logfile.write(map_output) 52 | logfile.close() 53 | print('Sparse map created') 54 | 55 | print( 'Finished running COLMAP, see {} for logs'.format(logfile_name) ) 56 | 57 | 58 | def gen_poses(basedir, match_type): 59 | files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']] 60 | if os.path.exists(os.path.join(basedir, 'sparse/0')): 61 | files_had = os.listdir(os.path.join(basedir, 'sparse/0')) 62 | else: 63 | files_had = [] 64 | if not all([f in files_had for f in files_needed]): 65 | print( 'Need to run COLMAP' ) 66 | run_colmap(basedir, match_type) 67 | else: 68 | print('Don\'t need to run COLMAP') 69 | 70 | return True 71 | 72 | 73 | if __name__=='__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--match_type', type=str, 76 | default='exhaustive_matcher', help='type of matcher used. Valid options: \ 77 | exhaustive_matcher sequential_matcher. Other matchers not supported at this time') 78 | parser.add_argument('scenedir', type=str, 79 | help='input scene directory') 80 | args = parser.parse_args() 81 | 82 | if args.match_type != 'exhaustive_matcher' and args.match_type != 'sequential_matcher': 83 | print('ERROR: matcher type ' + args.match_type + ' is not valid. Aborting') 84 | sys.exit() 85 | gen_poses(args.scenedir, args.match_type) 86 | -------------------------------------------------------------------------------- /systems/__init__.py: -------------------------------------------------------------------------------- 1 | systems = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | systems[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config, load_from_checkpoint=None): 12 | if load_from_checkpoint is None: 13 | system = systems[name](config) 14 | else: 15 | system = systems[name].load_from_checkpoint(load_from_checkpoint, strict=False, config=config) 16 | return system 17 | 18 | 19 | from . 
import nerf, neus 20 | -------------------------------------------------------------------------------- /systems/base.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | import models 4 | from systems.utils import parse_optimizer, parse_scheduler, update_module_step 5 | from utils.mixins import SaverMixin 6 | from utils.misc import config_to_primitive, get_rank 7 | 8 | 9 | class BaseSystem(pl.LightningModule, SaverMixin): 10 | """ 11 | Two ways to print to console: 12 | 1. self.print: correctly handle progress bar 13 | 2. rank_zero_info: use the logging module 14 | """ 15 | def __init__(self, config): 16 | super().__init__() 17 | self.config = config 18 | self.rank = get_rank() 19 | self.prepare() 20 | self.model = models.make(self.config.model.name, self.config.model) 21 | 22 | def prepare(self): 23 | pass 24 | 25 | def forward(self, batch): 26 | raise NotImplementedError 27 | 28 | def C(self, value): 29 | if isinstance(value, int) or isinstance(value, float): 30 | pass 31 | else: 32 | value = config_to_primitive(value) 33 | if not isinstance(value, list): 34 | raise TypeError('Scalar specification only supports list, got', type(value)) 35 | if len(value) == 3: 36 | value = [0] + value 37 | assert len(value) == 4 38 | start_step, start_value, end_value, end_step = value 39 | if isinstance(end_step, int): 40 | current_step = self.global_step 41 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) 42 | elif isinstance(end_step, float): 43 | current_step = self.current_epoch 44 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) 45 | return value 46 | 47 | def preprocess_data(self, batch, stage): 48 | pass 49 | 50 | """ 51 | Implementing on_after_batch_transfer of DataModule does the same. 52 | But on_after_batch_transfer does not support DP. 
53 | """ 54 | def on_train_batch_start(self, batch, batch_idx, unused=0): 55 | self.dataset = self.trainer.datamodule.train_dataloader().dataset 56 | self.preprocess_data(batch, 'train') 57 | update_module_step(self.model, self.current_epoch, self.global_step) 58 | 59 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): 60 | self.dataset = self.trainer.datamodule.val_dataloader().dataset 61 | self.preprocess_data(batch, 'validation') 62 | update_module_step(self.model, self.current_epoch, self.global_step) 63 | 64 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx): 65 | self.dataset = self.trainer.datamodule.test_dataloader().dataset 66 | self.preprocess_data(batch, 'test') 67 | update_module_step(self.model, self.current_epoch, self.global_step) 68 | 69 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): 70 | self.dataset = self.trainer.datamodule.predict_dataloader().dataset 71 | self.preprocess_data(batch, 'predict') 72 | update_module_step(self.model, self.current_epoch, self.global_step) 73 | 74 | def training_step(self, batch, batch_idx): 75 | raise NotImplementedError 76 | 77 | """ 78 | # aggregate outputs from different devices (DP) 79 | def training_step_end(self, out): 80 | pass 81 | """ 82 | 83 | """ 84 | # aggregate outputs from different iterations 85 | def training_epoch_end(self, out): 86 | pass 87 | """ 88 | 89 | def validation_step(self, batch, batch_idx): 90 | raise NotImplementedError 91 | 92 | """ 93 | # aggregate outputs from different devices when using DP 94 | def validation_step_end(self, out): 95 | pass 96 | """ 97 | 98 | def validation_epoch_end(self, out): 99 | """ 100 | Gather metrics from all devices, compute mean. 101 | Purge repeated results using data index. 102 | """ 103 | raise NotImplementedError 104 | 105 | def test_step(self, batch, batch_idx): 106 | raise NotImplementedError 107 | 108 | def test_epoch_end(self, out): 109 | """ 110 | Gather metrics from all devices, compute mean. 111 | Purge repeated results using data index. 
112 | """ 113 | raise NotImplementedError 114 | 115 | def export(self): 116 | raise NotImplementedError 117 | 118 | def configure_optimizers(self): 119 | optim = parse_optimizer(self.config.system.optimizer, self.model) 120 | ret = { 121 | 'optimizer': optim, 122 | } 123 | if 'scheduler' in self.config.system: 124 | ret.update({ 125 | 'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim), 126 | }) 127 | return ret 128 | 129 | -------------------------------------------------------------------------------- /systems/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class WeightedLoss(nn.Module): 7 | @property 8 | def func(self): 9 | raise NotImplementedError 10 | 11 | def forward(self, inputs, targets, weight=None, reduction='mean'): 12 | assert reduction in ['none', 'sum', 'mean', 'valid_mean'] 13 | loss = self.func(inputs, targets, reduction='none') 14 | if weight is not None: 15 | while weight.ndim < inputs.ndim: 16 | weight = weight[..., None] 17 | loss *= weight.float() 18 | if reduction == 'none': 19 | return loss 20 | elif reduction == 'sum': 21 | return loss.sum() 22 | elif reduction == 'mean': 23 | return loss.mean() 24 | elif reduction == 'valid_mean': 25 | return loss.sum() / weight.float().sum() 26 | 27 | 28 | class MSELoss(WeightedLoss): 29 | @property 30 | def func(self): 31 | return F.mse_loss 32 | 33 | 34 | class L1Loss(WeightedLoss): 35 | @property 36 | def func(self): 37 | return F.l1_loss 38 | 39 | 40 | class PSNR(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, inputs, targets, valid_mask=None, reduction='mean'): 45 | assert reduction in ['mean', 'none'] 46 | value = (inputs - targets)**2 47 | if valid_mask is not None: 48 | value = value[valid_mask] 49 | if reduction == 'mean': 50 | return -10 * torch.log10(torch.mean(value)) 51 | elif reduction == 'none': 52 | return -10 * torch.log10(torch.mean(value, dim=tuple(range(value.ndim)[1:]))) 53 | 54 | 55 | class SSIM(): 56 | def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True): 57 | self.kernel_size = kernel_size 58 | self.sigma = sigma 59 | self.gaussian = gaussian 60 | 61 | if any(x % 2 == 0 or x <= 0 for x in self.kernel_size): 62 | raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.") 63 | if any(y <= 0 for y in self.sigma): 64 | raise ValueError(f"Expected sigma to have positive number. 
Got {sigma}.") 65 | 66 | data_scale = data_range[1] - data_range[0] 67 | self.c1 = (k1 * data_scale)**2 68 | self.c2 = (k2 * data_scale)**2 69 | self.pad_h = (self.kernel_size[0] - 1) // 2 70 | self.pad_w = (self.kernel_size[1] - 1) // 2 71 | self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) 72 | 73 | def _uniform(self, kernel_size): 74 | max, min = 2.5, -2.5 75 | ksize_half = (kernel_size - 1) * 0.5 76 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 77 | for i, j in enumerate(kernel): 78 | if min <= j <= max: 79 | kernel[i] = 1 / (max - min) 80 | else: 81 | kernel[i] = 0 82 | 83 | return kernel.unsqueeze(dim=0) # (1, kernel_size) 84 | 85 | def _gaussian(self, kernel_size, sigma): 86 | ksize_half = (kernel_size - 1) * 0.5 87 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 88 | gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) 89 | return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) 90 | 91 | def _gaussian_or_uniform_kernel(self, kernel_size, sigma): 92 | if self.gaussian: 93 | kernel_x = self._gaussian(kernel_size[0], sigma[0]) 94 | kernel_y = self._gaussian(kernel_size[1], sigma[1]) 95 | else: 96 | kernel_x = self._uniform(kernel_size[0]) 97 | kernel_y = self._uniform(kernel_size[1]) 98 | 99 | return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size) 100 | 101 | def __call__(self, output, target, reduction='mean'): 102 | if output.dtype != target.dtype: 103 | raise TypeError( 104 | f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}." 105 | ) 106 | 107 | if output.shape != target.shape: 108 | raise ValueError( 109 | f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}." 110 | ) 111 | 112 | if len(output.shape) != 4 or len(target.shape) != 4: 113 | raise ValueError( 114 | f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}." 
115 | ) 116 | 117 | assert reduction in ['mean', 'sum', 'none'] 118 | 119 | channel = output.size(1) 120 | if len(self._kernel.shape) < 4: 121 | self._kernel = self._kernel.expand(channel, 1, -1, -1) 122 | 123 | output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") 124 | target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") 125 | 126 | input_list = torch.cat([output, target, output * output, target * target, output * target]) 127 | outputs = F.conv2d(input_list, self._kernel, groups=channel) 128 | 129 | output_list = [outputs[x * output.size(0) : (x + 1) * output.size(0)] for x in range(len(outputs))] 130 | 131 | mu_pred_sq = output_list[0].pow(2) 132 | mu_target_sq = output_list[1].pow(2) 133 | mu_pred_target = output_list[0] * output_list[1] 134 | 135 | sigma_pred_sq = output_list[2] - mu_pred_sq 136 | sigma_target_sq = output_list[3] - mu_target_sq 137 | sigma_pred_target = output_list[4] - mu_pred_target 138 | 139 | a1 = 2 * mu_pred_target + self.c1 140 | a2 = 2 * sigma_pred_target + self.c2 141 | b1 = mu_pred_sq + mu_target_sq + self.c1 142 | b2 = sigma_pred_sq + sigma_target_sq + self.c2 143 | 144 | ssim_idx = (a1 * a2) / (b1 * b2) 145 | _ssim = torch.mean(ssim_idx, (1, 2, 3)) 146 | 147 | if reduction == 'none': 148 | return _ssim 149 | elif reduction == 'sum': 150 | return _ssim.sum() 151 | elif reduction == 'mean': 152 | return _ssim.mean() 153 | 154 | 155 | def binary_cross_entropy(input, target): 156 | """ 157 | F.binary_cross_entropy is not numerically stable in mixed-precision training. 158 | """ 159 | return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean() 160 | -------------------------------------------------------------------------------- /systems/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_efficient_distloss import flatten_eff_distloss 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug 8 | 9 | import models 10 | from models.ray_utils import get_rays 11 | import systems 12 | from systems.base import BaseSystem 13 | from systems.criterions import PSNR 14 | 15 | 16 | @systems.register('nerf-system') 17 | class NeRFSystem(BaseSystem): 18 | """ 19 | Two ways to print to console: 20 | 1. self.print: correctly handle progress bar 21 | 2. 
rank_zero_info: use the logging module 22 | """ 23 | def prepare(self): 24 | self.criterions = { 25 | 'psnr': PSNR() 26 | } 27 | self.train_num_samples = self.config.model.train_num_rays * self.config.model.num_samples_per_ray 28 | self.train_num_rays = self.config.model.train_num_rays 29 | 30 | def forward(self, batch): 31 | return self.model(batch['rays']) 32 | 33 | def preprocess_data(self, batch, stage): 34 | if 'index' in batch: # validation / testing 35 | index = batch['index'] 36 | else: 37 | if self.config.model.batch_image_sampling: 38 | index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) 39 | else: 40 | index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) 41 | if stage in ['train']: 42 | c2w = self.dataset.all_c2w[index] 43 | x = torch.randint( 44 | 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device 45 | ) 46 | y = torch.randint( 47 | 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device 48 | ) 49 | if self.dataset.directions.ndim == 3: # (H, W, 3) 50 | directions = self.dataset.directions[y, x] 51 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 52 | directions = self.dataset.directions[index, y, x] 53 | rays_o, rays_d = get_rays(directions, c2w) 54 | rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 55 | fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) 56 | else: 57 | c2w = self.dataset.all_c2w[index][0] 58 | if self.dataset.directions.ndim == 3: # (H, W, 3) 59 | directions = self.dataset.directions 60 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 61 | directions = self.dataset.directions[index][0] 62 | rays_o, rays_d = get_rays(directions, c2w) 63 | rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 64 | fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) 65 | 66 | rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) 67 | 68 | if stage in ['train']: 69 | if self.config.model.background_color == 'white': 70 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 71 | elif self.config.model.background_color == 'random': 72 | self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) 73 | else: 74 | raise NotImplementedError 75 | else: 76 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 77 | 78 | if self.dataset.apply_mask: 79 | rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) 80 | 81 | batch.update({ 82 | 'rays': rays, 83 | 'rgb': rgb, 84 | 'fg_mask': fg_mask 85 | }) 86 | 87 | def training_step(self, batch, batch_idx): 88 | out = self(batch) 89 | 90 | loss = 0. 
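# The dynamic ray-sampling update below keeps the number of samples per batch roughly constant:
# the target ray count is scaled by train_num_samples / (samples actually taken this step), then
# blended into train_num_rays with a 0.9/0.1 running average and capped at max_train_num_rays.
# For example (hypothetical numbers): with train_num_samples = 65536 and out['num_samples'].sum() = 32768,
# the target is twice the current ray count before the running average is applied.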
91 | 92 | # update train_num_rays 93 | if self.config.model.dynamic_ray_sampling: 94 | train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples'].sum().item())) 95 | self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) 96 | 97 | loss_rgb = F.smooth_l1_loss(out['comp_rgb'][out['rays_valid'][...,0]], batch['rgb'][out['rays_valid'][...,0]]) 98 | self.log('train/loss_rgb', loss_rgb) 99 | loss += loss_rgb * self.C(self.config.system.loss.lambda_rgb) 100 | 101 | # distortion loss proposed in MipNeRF360 102 | # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss, but still slows down training by ~30% 103 | if self.C(self.config.system.loss.lambda_distortion) > 0: 104 | loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) 105 | self.log('train/loss_distortion', loss_distortion) 106 | loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) 107 | 108 | losses_model_reg = self.model.regularizations(out) 109 | for name, value in losses_model_reg.items(): 110 | self.log(f'train/loss_{name}', value) 111 | loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) 112 | loss += loss_ 113 | 114 | for name, value in self.config.system.loss.items(): 115 | if name.startswith('lambda'): 116 | self.log(f'train_params/{name}', self.C(value)) 117 | 118 | self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) 119 | 120 | return { 121 | 'loss': loss 122 | } 123 | 124 | """ 125 | # aggregate outputs from different devices (DP) 126 | def training_step_end(self, out): 127 | pass 128 | """ 129 | 130 | """ 131 | # aggregate outputs from different iterations 132 | def training_epoch_end(self, out): 133 | pass 134 | """ 135 | 136 | def validation_step(self, batch, batch_idx): 137 | out = self(batch) 138 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) 139 | W, H = self.dataset.img_wh 140 | self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ 141 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 142 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 143 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 144 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} 145 | ]) 146 | return { 147 | 'psnr': psnr, 148 | 'index': batch['index'] 149 | } 150 | 151 | 152 | """ 153 | # aggregate outputs from different devices when using DP 154 | def validation_step_end(self, out): 155 | pass 156 | """ 157 | 158 | def validation_epoch_end(self, out): 159 | out = self.all_gather(out) 160 | if self.trainer.is_global_zero: 161 | out_set = {} 162 | for step_out in out: 163 | # DP 164 | if step_out['index'].ndim == 1: 165 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 166 | # DDP 167 | else: 168 | for oi, index in enumerate(step_out['index']): 169 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 170 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 171 | self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) 172 | 173 | def test_step(self, batch, batch_idx): 174 | out = self(batch) 175 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) 176 | W, H = self.dataset.img_wh 177 | 
self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ 178 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 179 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 180 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 181 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} 182 | ]) 183 | return { 184 | 'psnr': psnr, 185 | 'index': batch['index'] 186 | } 187 | 188 | def test_epoch_end(self, out): 189 | out = self.all_gather(out) 190 | if self.trainer.is_global_zero: 191 | out_set = {} 192 | for step_out in out: 193 | # DP 194 | if step_out['index'].ndim == 1: 195 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 196 | # DDP 197 | else: 198 | for oi, index in enumerate(step_out['index']): 199 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 200 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 201 | self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) 202 | 203 | self.save_img_sequence( 204 | f"it{self.global_step}-test", 205 | f"it{self.global_step}-test", 206 | '(\d+)\.png', 207 | save_format='mp4', 208 | fps=30 209 | ) 210 | 211 | self.export() 212 | 213 | def export(self): 214 | mesh = self.model.export(self.config.export) 215 | self.save_mesh( 216 | f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", 217 | **mesh 218 | ) 219 | -------------------------------------------------------------------------------- /systems/neus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_efficient_distloss import flatten_eff_distloss 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug 8 | 9 | import models 10 | from models.utils import cleanup 11 | from models.ray_utils import get_rays 12 | import systems 13 | from systems.base import BaseSystem 14 | from systems.criterions import PSNR, binary_cross_entropy 15 | 16 | 17 | @systems.register('neus-system') 18 | class NeuSSystem(BaseSystem): 19 | """ 20 | Two ways to print to console: 21 | 1. self.print: correctly handle progress bar 22 | 2. 
rank_zero_info: use the logging module 23 | """ 24 | def prepare(self): 25 | self.criterions = { 26 | 'psnr': PSNR() 27 | } 28 | self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) 29 | self.train_num_rays = self.config.model.train_num_rays 30 | 31 | def forward(self, batch): 32 | return self.model(batch['rays']) 33 | 34 | def preprocess_data(self, batch, stage): 35 | if 'index' in batch: # validation / testing 36 | index = batch['index'] 37 | else: 38 | if self.config.model.batch_image_sampling: 39 | index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) 40 | else: 41 | index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) 42 | if stage in ['train']: 43 | c2w = self.dataset.all_c2w[index] 44 | x = torch.randint( 45 | 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device 46 | ) 47 | y = torch.randint( 48 | 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device 49 | ) 50 | if self.dataset.directions.ndim == 3: # (H, W, 3) 51 | directions = self.dataset.directions[y, x] 52 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 53 | directions = self.dataset.directions[index, y, x] 54 | rays_o, rays_d = get_rays(directions, c2w) 55 | rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 56 | fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) 57 | else: 58 | c2w = self.dataset.all_c2w[index][0] 59 | if self.dataset.directions.ndim == 3: # (H, W, 3) 60 | directions = self.dataset.directions 61 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 62 | directions = self.dataset.directions[index][0] 63 | rays_o, rays_d = get_rays(directions, c2w) 64 | rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 65 | fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) 66 | 67 | rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) 68 | 69 | if stage in ['train']: 70 | if self.config.model.background_color == 'white': 71 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 72 | elif self.config.model.background_color == 'random': 73 | self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) 74 | else: 75 | raise NotImplementedError 76 | else: 77 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 78 | 79 | if self.dataset.apply_mask: 80 | rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) 81 | 82 | batch.update({ 83 | 'rays': rays, 84 | 'rgb': rgb, 85 | 'fg_mask': fg_mask 86 | }) 87 | 88 | def training_step(self, batch, batch_idx): 89 | out = self(batch) 90 | 91 | loss = 0. 
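# Loss terms assembled below (each scaled by a possibly-scheduled lambda via self.C): MSE and L1 RGB
# terms on the composited colors, the eikonal regularizer mean((||grad sdf|| - 1)^2) that keeps the
# SDF a valid distance field, mask/opaque BCE terms on accumulated opacity, an SDF sparsity term,
# an optional curvature term on the finite-difference Laplacian, and optional distortion losses.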
92 | 93 | # update train_num_rays 94 | if self.config.model.dynamic_ray_sampling: 95 | train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) 96 | self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) 97 | 98 | loss_rgb_mse = F.mse_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) 99 | self.log('train/loss_rgb_mse', loss_rgb_mse) 100 | loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) 101 | 102 | loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) 103 | self.log('train/loss_rgb', loss_rgb_l1) 104 | loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) 105 | 106 | loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() 107 | self.log('train/loss_eikonal', loss_eikonal) 108 | loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) 109 | 110 | opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) 111 | loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float()) 112 | self.log('train/loss_mask', loss_mask) 113 | loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) 114 | 115 | loss_opaque = binary_cross_entropy(opacity, opacity) 116 | self.log('train/loss_opaque', loss_opaque) 117 | loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) 118 | 119 | loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * out['sdf_samples'].abs()).mean() 120 | self.log('train/loss_sparsity', loss_sparsity) 121 | loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) 122 | 123 | if self.C(self.config.system.loss.lambda_curvature) > 0: 124 | assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" 125 | loss_curvature = out['sdf_laplace_samples'].abs().mean() 126 | self.log('train/loss_curvature', loss_curvature) 127 | loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) 128 | 129 | # distortion loss proposed in MipNeRF360 130 | # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss 131 | if self.C(self.config.system.loss.lambda_distortion) > 0: 132 | loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) 133 | self.log('train/loss_distortion', loss_distortion) 134 | loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) 135 | 136 | if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: 137 | loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) 138 | self.log('train/loss_distortion_bg', loss_distortion_bg) 139 | loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) 140 | 141 | losses_model_reg = self.model.regularizations(out) 142 | for name, value in losses_model_reg.items(): 143 | self.log(f'train/loss_{name}', value) 144 | loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) 145 | loss += loss_ 146 | 147 | self.log('train/inv_s', out['inv_s'], prog_bar=True) 148 | 149 | for name, value in self.config.system.loss.items(): 150 | if name.startswith('lambda'): 151 | self.log(f'train_params/{name}', self.C(value)) 152 | 153 | self.log('train/num_rays', 
float(self.train_num_rays), prog_bar=True) 154 | 155 | return { 156 | 'loss': loss 157 | } 158 | 159 | """ 160 | # aggregate outputs from different devices (DP) 161 | def training_step_end(self, out): 162 | pass 163 | """ 164 | 165 | """ 166 | # aggregate outputs from different iterations 167 | def training_epoch_end(self, out): 168 | pass 169 | """ 170 | 171 | def validation_step(self, batch, batch_idx): 172 | out = self(batch) 173 | psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) 174 | W, H = self.dataset.img_wh 175 | self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ 176 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 177 | {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} 178 | ] + ([ 179 | {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 180 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 181 | ] if self.config.model.learned_background else []) + [ 182 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 183 | {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} 184 | ]) 185 | return { 186 | 'psnr': psnr, 187 | 'index': batch['index'] 188 | } 189 | 190 | 191 | """ 192 | # aggregate outputs from different devices when using DP 193 | def validation_step_end(self, out): 194 | pass 195 | """ 196 | 197 | def validation_epoch_end(self, out): 198 | out = self.all_gather(out) 199 | if self.trainer.is_global_zero: 200 | out_set = {} 201 | for step_out in out: 202 | # DP 203 | if step_out['index'].ndim == 1: 204 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 205 | # DDP 206 | else: 207 | for oi, index in enumerate(step_out['index']): 208 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 209 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 210 | self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) 211 | 212 | def test_step(self, batch, batch_idx): 213 | out = self(batch) 214 | psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) 215 | W, H = self.dataset.img_wh 216 | self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ 217 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 218 | {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} 219 | ] + ([ 220 | {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 221 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 222 | ] if self.config.model.learned_background else []) + [ 223 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 224 | {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} 225 | ]) 226 | return { 227 | 'psnr': psnr, 228 | 'index': batch['index'] 229 | } 230 | 231 | def test_epoch_end(self, out): 232 | """ 233 | Synchronize devices. 234 | Generate image sequence using test outputs. 
235 | """ 236 | out = self.all_gather(out) 237 | if self.trainer.is_global_zero: 238 | out_set = {} 239 | for step_out in out: 240 | # DP 241 | if step_out['index'].ndim == 1: 242 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 243 | # DDP 244 | else: 245 | for oi, index in enumerate(step_out['index']): 246 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 247 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 248 | self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) 249 | 250 | self.save_img_sequence( 251 | f"it{self.global_step}-test", 252 | f"it{self.global_step}-test", 253 | '(\d+)\.png', 254 | save_format='mp4', 255 | fps=30 256 | ) 257 | 258 | self.export() 259 | 260 | def export(self): 261 | mesh = self.model.export(self.config.export) 262 | self.save_mesh( 263 | f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", 264 | **mesh 265 | ) 266 | -------------------------------------------------------------------------------- /systems/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug 10 | 11 | 12 | class ChainedScheduler(lr_scheduler._LRScheduler): 13 | """Chains list of learning rate schedulers. It takes a list of chainable learning 14 | rate schedulers and performs consecutive step() functions belong to them by just 15 | one call. 16 | 17 | Args: 18 | schedulers (list): List of chained schedulers. 19 | 20 | Example: 21 | >>> # Assuming optimizer uses lr = 1. for all groups 22 | >>> # lr = 0.09 if epoch == 0 23 | >>> # lr = 0.081 if epoch == 1 24 | >>> # lr = 0.729 if epoch == 2 25 | >>> # lr = 0.6561 if epoch == 3 26 | >>> # lr = 0.59049 if epoch >= 4 27 | >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) 28 | >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) 29 | >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) 30 | >>> for epoch in range(100): 31 | >>> train(...) 32 | >>> validate(...) 33 | >>> scheduler.step() 34 | """ 35 | 36 | def __init__(self, optimizer, schedulers): 37 | for scheduler_idx in range(1, len(schedulers)): 38 | if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): 39 | raise ValueError( 40 | "ChainedScheduler expects all schedulers to belong to the same optimizer, but " 41 | "got schedulers at index {} and {} to be different".format(0, scheduler_idx) 42 | ) 43 | self._schedulers = list(schedulers) 44 | self.optimizer = optimizer 45 | 46 | def step(self): 47 | for scheduler in self._schedulers: 48 | scheduler.step() 49 | 50 | def state_dict(self): 51 | """Returns the state of the scheduler as a :class:`dict`. 52 | 53 | It contains an entry for every variable in self.__dict__ which 54 | is not the optimizer. 55 | The wrapped scheduler states will also be saved. 56 | """ 57 | state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} 58 | state_dict['_schedulers'] = [None] * len(self._schedulers) 59 | 60 | for idx, s in enumerate(self._schedulers): 61 | state_dict['_schedulers'][idx] = s.state_dict() 62 | 63 | return state_dict 64 | 65 | def load_state_dict(self, state_dict): 66 | """Loads the schedulers state. 67 | 68 | Args: 69 | state_dict (dict): scheduler state. 
Should be an object returned 70 | from a call to :meth:`state_dict`. 71 | """ 72 | _schedulers = state_dict.pop('_schedulers') 73 | self.__dict__.update(state_dict) 74 | # Restore state_dict keys in order to prevent side effects 75 | # https://github.com/pytorch/pytorch/issues/32756 76 | state_dict['_schedulers'] = _schedulers 77 | 78 | for idx, s in enumerate(_schedulers): 79 | self._schedulers[idx].load_state_dict(s) 80 | 81 | 82 | class SequentialLR(lr_scheduler._LRScheduler): 83 | """Receives the list of schedulers that is expected to be called sequentially during 84 | optimization process and milestone points that provides exact intervals to reflect 85 | which scheduler is supposed to be called at a given epoch. 86 | 87 | Args: 88 | schedulers (list): List of chained schedulers. 89 | milestones (list): List of integers that reflects milestone points. 90 | 91 | Example: 92 | >>> # Assuming optimizer uses lr = 1. for all groups 93 | >>> # lr = 0.1 if epoch == 0 94 | >>> # lr = 0.1 if epoch == 1 95 | >>> # lr = 0.9 if epoch == 2 96 | >>> # lr = 0.81 if epoch == 3 97 | >>> # lr = 0.729 if epoch == 4 98 | >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) 99 | >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) 100 | >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) 101 | >>> for epoch in range(100): 102 | >>> train(...) 103 | >>> validate(...) 104 | >>> scheduler.step() 105 | """ 106 | 107 | def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): 108 | for scheduler_idx in range(1, len(schedulers)): 109 | if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): 110 | raise ValueError( 111 | "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " 112 | "got schedulers at index {} and {} to be different".format(0, scheduler_idx) 113 | ) 114 | if (len(milestones) != len(schedulers) - 1): 115 | raise ValueError( 116 | "Sequential Schedulers expects number of schedulers provided to be one more " 117 | "than the number of milestone points, but got number of schedulers {} and the " 118 | "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) 119 | ) 120 | self._schedulers = schedulers 121 | self._milestones = milestones 122 | self.last_epoch = last_epoch + 1 123 | self.optimizer = optimizer 124 | 125 | def step(self): 126 | self.last_epoch += 1 127 | idx = bisect_right(self._milestones, self.last_epoch) 128 | if idx > 0 and self._milestones[idx - 1] == self.last_epoch: 129 | self._schedulers[idx].step(0) 130 | else: 131 | self._schedulers[idx].step() 132 | 133 | def state_dict(self): 134 | """Returns the state of the scheduler as a :class:`dict`. 135 | 136 | It contains an entry for every variable in self.__dict__ which 137 | is not the optimizer. 138 | The wrapped scheduler states will also be saved. 139 | """ 140 | state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} 141 | state_dict['_schedulers'] = [None] * len(self._schedulers) 142 | 143 | for idx, s in enumerate(self._schedulers): 144 | state_dict['_schedulers'][idx] = s.state_dict() 145 | 146 | return state_dict 147 | 148 | def load_state_dict(self, state_dict): 149 | """Loads the schedulers state. 150 | 151 | Args: 152 | state_dict (dict): scheduler state. Should be an object returned 153 | from a call to :meth:`state_dict`. 
154 | """ 155 | _schedulers = state_dict.pop('_schedulers') 156 | self.__dict__.update(state_dict) 157 | # Restore state_dict keys in order to prevent side effects 158 | # https://github.com/pytorch/pytorch/issues/32756 159 | state_dict['_schedulers'] = _schedulers 160 | 161 | for idx, s in enumerate(_schedulers): 162 | self._schedulers[idx].load_state_dict(s) 163 | 164 | 165 | class ConstantLR(lr_scheduler._LRScheduler): 166 | """Decays the learning rate of each parameter group by a small constant factor until the 167 | number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can 168 | happen simultaneously with other changes to the learning rate from outside this scheduler. 169 | When last_epoch=-1, sets initial lr as lr. 170 | 171 | Args: 172 | optimizer (Optimizer): Wrapped optimizer. 173 | factor (float): The number we multiply learning rate until the milestone. Default: 1./3. 174 | total_iters (int): The number of steps that the scheduler decays the learning rate. 175 | Default: 5. 176 | last_epoch (int): The index of the last epoch. Default: -1. 177 | verbose (bool): If ``True``, prints a message to stdout for 178 | each update. Default: ``False``. 179 | 180 | Example: 181 | >>> # Assuming optimizer uses lr = 0.05 for all groups 182 | >>> # lr = 0.025 if epoch == 0 183 | >>> # lr = 0.025 if epoch == 1 184 | >>> # lr = 0.025 if epoch == 2 185 | >>> # lr = 0.025 if epoch == 3 186 | >>> # lr = 0.05 if epoch >= 4 187 | >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) 188 | >>> for epoch in range(100): 189 | >>> train(...) 190 | >>> validate(...) 191 | >>> scheduler.step() 192 | """ 193 | 194 | def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): 195 | if factor > 1.0 or factor < 0: 196 | raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') 197 | 198 | self.factor = factor 199 | self.total_iters = total_iters 200 | super(ConstantLR, self).__init__(optimizer, last_epoch, verbose) 201 | 202 | def get_lr(self): 203 | if not self._get_lr_called_within_step: 204 | warnings.warn("To get the last learning rate computed by the scheduler, " 205 | "please use `get_last_lr()`.", UserWarning) 206 | 207 | if self.last_epoch == 0: 208 | return [group['lr'] * self.factor for group in self.optimizer.param_groups] 209 | 210 | if (self.last_epoch > self.total_iters or 211 | (self.last_epoch != self.total_iters)): 212 | return [group['lr'] for group in self.optimizer.param_groups] 213 | 214 | if (self.last_epoch == self.total_iters): 215 | return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] 216 | 217 | def _get_closed_form_lr(self): 218 | return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) 219 | for base_lr in self.base_lrs] 220 | 221 | 222 | class LinearLR(lr_scheduler._LRScheduler): 223 | """Decays the learning rate of each parameter group by linearly changing small 224 | multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. 225 | Notice that such decay can happen simultaneously with other changes to the learning rate 226 | from outside this scheduler. When last_epoch=-1, sets initial lr as lr. 227 | 228 | Args: 229 | optimizer (Optimizer): Wrapped optimizer. 230 | start_factor (float): The number we multiply learning rate in the first epoch. 231 | The multiplication factor changes towards end_factor in the following epochs. 232 | Default: 1./3. 
233 | end_factor (float): The number we multiply learning rate at the end of linear changing 234 | process. Default: 1.0. 235 | total_iters (int): The number of iterations that multiplicative factor reaches to 1. 236 | Default: 5. 237 | last_epoch (int): The index of the last epoch. Default: -1. 238 | verbose (bool): If ``True``, prints a message to stdout for 239 | each update. Default: ``False``. 240 | 241 | Example: 242 | >>> # Assuming optimizer uses lr = 0.05 for all groups 243 | >>> # lr = 0.025 if epoch == 0 244 | >>> # lr = 0.03125 if epoch == 1 245 | >>> # lr = 0.0375 if epoch == 2 246 | >>> # lr = 0.04375 if epoch == 3 247 | >>> # lr = 0.05 if epoch >= 4 248 | >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) 249 | >>> for epoch in range(100): 250 | >>> train(...) 251 | >>> validate(...) 252 | >>> scheduler.step() 253 | """ 254 | 255 | def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, 256 | verbose=False): 257 | if start_factor > 1.0 or start_factor < 0: 258 | raise ValueError('Starting multiplicative factor expected to be between 0 and 1.') 259 | 260 | if end_factor > 1.0 or end_factor < 0: 261 | raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') 262 | 263 | self.start_factor = start_factor 264 | self.end_factor = end_factor 265 | self.total_iters = total_iters 266 | super(LinearLR, self).__init__(optimizer, last_epoch, verbose) 267 | 268 | def get_lr(self): 269 | if not self._get_lr_called_within_step: 270 | warnings.warn("To get the last learning rate computed by the scheduler, " 271 | "please use `get_last_lr()`.", UserWarning) 272 | 273 | if self.last_epoch == 0: 274 | return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] 275 | 276 | if (self.last_epoch > self.total_iters): 277 | return [group['lr'] for group in self.optimizer.param_groups] 278 | 279 | return [group['lr'] * (1. 
+ (self.end_factor - self.start_factor) / 280 | (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) 281 | for group in self.optimizer.param_groups] 282 | 283 | def _get_closed_form_lr(self): 284 | return [base_lr * (self.start_factor + 285 | (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) 286 | for base_lr in self.base_lrs] 287 | 288 | 289 | custom_schedulers = ['ConstantLR', 'LinearLR'] 290 | def get_scheduler(name): 291 | if hasattr(lr_scheduler, name): 292 | return getattr(lr_scheduler, name) 293 | elif name in custom_schedulers: 294 | return getattr(sys.modules[__name__], name) 295 | else: 296 | raise NotImplementedError 297 | 298 | 299 | def getattr_recursive(m, attr): 300 | for name in attr.split('.'): 301 | m = getattr(m, name) 302 | return m 303 | 304 | 305 | def get_parameters(model, name): 306 | module = getattr_recursive(model, name) 307 | if isinstance(module, nn.Module): 308 | return module.parameters() 309 | elif isinstance(module, nn.Parameter): 310 | return module 311 | return [] 312 | 313 | 314 | def parse_optimizer(config, model): 315 | if hasattr(config, 'params'): 316 | params = [{'params': get_parameters(model, name), 'name': name, **args} for name, args in config.params.items()] 317 | rank_zero_debug('Specify optimizer params:', config.params) 318 | else: 319 | params = model.parameters() 320 | if config.name in ['FusedAdam']: 321 | import apex 322 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 323 | else: 324 | optim = getattr(torch.optim, config.name)(params, **config.args) 325 | return optim 326 | 327 | 328 | def parse_scheduler(config, optimizer): 329 | interval = config.get('interval', 'epoch') 330 | assert interval in ['epoch', 'step'] 331 | if config.name == 'SequentialLR': 332 | scheduler = { 333 | 'scheduler': SequentialLR(optimizer, [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers], milestones=config.milestones), 334 | 'interval': interval 335 | } 336 | elif config.name == 'Chained': 337 | scheduler = { 338 | 'scheduler': ChainedScheduler([parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]), 339 | 'interval': interval 340 | } 341 | else: 342 | scheduler = { 343 | 'scheduler': get_scheduler(config.name)(optimizer, **config.args), 344 | 'interval': interval 345 | } 346 | return scheduler 347 | 348 | 349 | def update_module_step(m, epoch, global_step): 350 | if hasattr(m, 'update_step'): 351 | m.update_step(epoch, global_step) 352 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bennyguo/instant-nsr-pl/e5fe3f246cf2d512494a73c727174c3ca1c3c695/utils/__init__.py -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import shutil 4 | from utils.misc import dump_config, parse_version 5 | 6 | 7 | import pytorch_lightning 8 | if parse_version(pytorch_lightning.__version__) > parse_version('1.8'): 9 | from pytorch_lightning.callbacks import Callback 10 | else: 11 | from pytorch_lightning.callbacks.base import Callback 12 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 13 | from pytorch_lightning.callbacks.progress import 
TQDMProgressBar 14 | 15 | 16 | class VersionedCallback(Callback): 17 | def __init__(self, save_root, version=None, use_version=True): 18 | self.save_root = save_root 19 | self._version = version 20 | self.use_version = use_version 21 | 22 | @property 23 | def version(self) -> int: 24 | """Get the experiment version. 25 | 26 | Returns: 27 | The experiment version if specified else the next version. 28 | """ 29 | if self._version is None: 30 | self._version = self._get_next_version() 31 | return self._version 32 | 33 | def _get_next_version(self): 34 | existing_versions = [] 35 | if os.path.isdir(self.save_root): 36 | for f in os.listdir(self.save_root): 37 | bn = os.path.basename(f) 38 | if bn.startswith("version_"): 39 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 40 | existing_versions.append(int(dir_ver)) 41 | if len(existing_versions) == 0: 42 | return 0 43 | return max(existing_versions) + 1 44 | 45 | @property 46 | def savedir(self): 47 | if not self.use_version: 48 | return self.save_root 49 | return os.path.join(self.save_root, self.version if isinstance(self.version, str) else f"version_{self.version}") 50 | 51 | 52 | class CodeSnapshotCallback(VersionedCallback): 53 | def __init__(self, save_root, version=None, use_version=True): 54 | super().__init__(save_root, version, use_version) 55 | 56 | def get_file_list(self): 57 | return [ 58 | b.decode() for b in 59 | set(subprocess.check_output('git ls-files', shell=True).splitlines()) | 60 | set(subprocess.check_output('git ls-files --others --exclude-standard', shell=True).splitlines()) 61 | ] 62 | 63 | @rank_zero_only 64 | def save_code_snapshot(self): 65 | os.makedirs(self.savedir, exist_ok=True) 66 | for f in self.get_file_list(): 67 | if not os.path.exists(f) or os.path.isdir(f): 68 | continue 69 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 70 | shutil.copyfile(f, os.path.join(self.savedir, f)) 71 | 72 | def on_fit_start(self, trainer, pl_module): 73 | try: 74 | self.save_code_snapshot() 75 | except: 76 | rank_zero_warn("Code snapshot is not saved. 
Please make sure you have git installed and are in a git repository.") 77 | 78 | 79 | class ConfigSnapshotCallback(VersionedCallback): 80 | def __init__(self, config, save_root, version=None, use_version=True): 81 | super().__init__(save_root, version, use_version) 82 | self.config = config 83 | 84 | @rank_zero_only 85 | def save_config_snapshot(self): 86 | os.makedirs(self.savedir, exist_ok=True) 87 | dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config) 88 | shutil.copyfile(self.config.cmd_args['config'], os.path.join(self.savedir, 'raw.yaml')) 89 | 90 | def on_fit_start(self, trainer, pl_module): 91 | self.save_config_snapshot() 92 | 93 | 94 | class CustomProgressBar(TQDMProgressBar): 95 | def get_metrics(self, *args, **kwargs): 96 | # don't show the version number 97 | items = super().get_metrics(*args, **kwargs) 98 | items.pop("v_num", None) 99 | return items 100 | -------------------------------------------------------------------------------- /utils/loggers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pprint 3 | import logging 4 | 5 | from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment 6 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 7 | 8 | 9 | class ConsoleLogger(LightningLoggerBase): 10 | def __init__(self, log_keys=[]): 11 | super().__init__() 12 | self.log_keys = [re.compile(k) for k in log_keys] 13 | self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat 14 | 15 | def match_log_keys(self, s): 16 | return True if not self.log_keys else any(r.search(s) for r in self.log_keys) 17 | 18 | @property 19 | def name(self): 20 | return 'console' 21 | 22 | @property 23 | def version(self): 24 | return '0' 25 | 26 | @property 27 | @rank_zero_experiment 28 | def experiment(self): 29 | return logging.getLogger('pytorch_lightning') 30 | 31 | @rank_zero_only 32 | def log_hyperparams(self, params): 33 | pass 34 | 35 | @rank_zero_only 36 | def log_metrics(self, metrics, step): 37 | metrics_ = {k: v for k, v in metrics.items() if self.match_log_keys(k)} 38 | if not metrics_: 39 | return 40 | self.experiment.info(f"\nEpoch{metrics['epoch']} Step{step}\n{self.dict_printer(metrics_)}") 41 | 42 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | from packaging import version 4 | 5 | 6 | # ============ Register OmegaConf Recolvers ============= # 7 | OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) 8 | OmegaConf.register_new_resolver('add', lambda a, b: a + b) 9 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b) 10 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b) 11 | OmegaConf.register_new_resolver('div', lambda a, b: a / b) 12 | OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) 13 | OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) 14 | # ======================================================= # 15 | 16 | 17 | def prompt(question): 18 | inp = input(f"{question} (y/n)").lower().strip() 19 | if inp and inp == 'y': 20 | return True 21 | if inp and inp == 'n': 22 | return False 23 | return prompt(question) 24 | 25 | 26 | def load_config(*yaml_files, cli_args=[]): 27 | yaml_confs = [OmegaConf.load(f) for f in yaml_files] 28 | cli_conf = OmegaConf.from_cli(cli_args) 29 | conf = 
OmegaConf.merge(*yaml_confs, cli_conf) 30 | OmegaConf.resolve(conf) 31 | return conf 32 | 33 | 34 | def config_to_primitive(config, resolve=True): 35 | return OmegaConf.to_container(config, resolve=resolve) 36 | 37 | 38 | def dump_config(path, config): 39 | with open(path, 'w') as fp: 40 | OmegaConf.save(config=config, f=fp) 41 | 42 | def get_rank(): 43 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 44 | # therefore LOCAL_RANK needs to be checked first 45 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 46 | for key in rank_keys: 47 | rank = os.environ.get(key) 48 | if rank is not None: 49 | return int(rank) 50 | return 0 51 | 52 | 53 | def parse_version(ver): 54 | return version.parse(ver) 55 | -------------------------------------------------------------------------------- /utils/mixins.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import numpy as np 5 | import cv2 6 | import imageio 7 | from matplotlib import cm 8 | from matplotlib.colors import LinearSegmentedColormap 9 | import json 10 | 11 | import torch 12 | 13 | from utils.obj import write_obj 14 | 15 | 16 | class SaverMixin(): 17 | @property 18 | def save_dir(self): 19 | return self.config.save_dir 20 | 21 | def convert_data(self, data): 22 | if isinstance(data, np.ndarray): 23 | return data 24 | elif isinstance(data, torch.Tensor): 25 | return data.cpu().numpy() 26 | elif isinstance(data, list): 27 | return [self.convert_data(d) for d in data] 28 | elif isinstance(data, dict): 29 | return {k: self.convert_data(v) for k, v in data.items()} 30 | else: 31 | raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data)) 32 | 33 | def get_save_path(self, filename): 34 | save_path = os.path.join(self.save_dir, filename) 35 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 36 | return save_path 37 | 38 | DEFAULT_RGB_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1)} 39 | DEFAULT_UV_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1), 'cmap': 'checkerboard'} 40 | DEFAULT_GRAYSCALE_KWARGS = {'data_range': None, 'cmap': 'jet'} 41 | 42 | def get_rgb_image_(self, img, data_format, data_range): 43 | img = self.convert_data(img) 44 | assert data_format in ['CHW', 'HWC'] 45 | if data_format == 'CHW': 46 | img = img.transpose(1, 2, 0) 47 | img = img.clip(min=data_range[0], max=data_range[1]) 48 | img = ((img - data_range[0]) / (data_range[1] - data_range[0]) * 255.).astype(np.uint8) 49 | imgs = [img[...,start:start+3] for start in range(0, img.shape[-1], 3)] 50 | imgs = [img_ if img_.shape[-1] == 3 else np.concatenate([img_, np.zeros((img_.shape[0], img_.shape[1], 3 - img_.shape[2]), dtype=img_.dtype)], axis=-1) for img_ in imgs] 51 | img = np.concatenate(imgs, axis=1) 52 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 53 | return img 54 | 55 | def save_rgb_image(self, filename, img, data_format=DEFAULT_RGB_KWARGS['data_format'], data_range=DEFAULT_RGB_KWARGS['data_range']): 56 | img = self.get_rgb_image_(img, data_format, data_range) 57 | cv2.imwrite(self.get_save_path(filename), img) 58 | 59 | def get_uv_image_(self, img, data_format, data_range, cmap): 60 | img = self.convert_data(img) 61 | assert data_format in ['CHW', 'HWC'] 62 | if data_format == 'CHW': 63 | img = img.transpose(1, 2, 0) 64 | img = img.clip(min=data_range[0], max=data_range[1]) 65 | img = (img - data_range[0]) / (data_range[1] - data_range[0]) 66 | assert cmap in ['checkerboard', 
'color'] 67 | if cmap == 'checkerboard': 68 | n_grid = 64 69 | mask = (img * n_grid).astype(int) 70 | mask = (mask[...,0] + mask[...,1]) % 2 == 0 71 | img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 72 | img[mask] = np.array([255, 0, 255], dtype=np.uint8) 73 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 74 | elif cmap == 'color': 75 | img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) 76 | img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) 77 | img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) 78 | img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) 79 | img = img_ 80 | return img 81 | 82 | def save_uv_image(self, filename, img, data_format=DEFAULT_UV_KWARGS['data_format'], data_range=DEFAULT_UV_KWARGS['data_range'], cmap=DEFAULT_UV_KWARGS['cmap']): 83 | img = self.get_uv_image_(img, data_format, data_range, cmap) 84 | cv2.imwrite(self.get_save_path(filename), img) 85 | 86 | def get_grayscale_image_(self, img, data_range, cmap): 87 | img = self.convert_data(img) 88 | img = np.nan_to_num(img) 89 | if data_range is None: 90 | img = (img - img.min()) / (img.max() - img.min()) 91 | else: 92 | img = img.clip(data_range[0], data_range[1]) 93 | img = (img - data_range[0]) / (data_range[1] - data_range[0]) 94 | assert cmap in [None, 'jet', 'magma'] 95 | if cmap == None: 96 | img = (img * 255.).astype(np.uint8) 97 | img = np.repeat(img[...,None], 3, axis=2) 98 | elif cmap == 'jet': 99 | img = (img * 255.).astype(np.uint8) 100 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 101 | elif cmap == 'magma': 102 | img = 1. - img 103 | base = cm.get_cmap('magma') 104 | num_bins = 256 105 | colormap = LinearSegmentedColormap.from_list( 106 | f"{base.name}{num_bins}", 107 | base(np.linspace(0, 1, num_bins)), 108 | num_bins 109 | )(np.linspace(0, 1, num_bins))[:,:3] 110 | a = np.floor(img * 255.) 111 | b = (a + 1).clip(max=255.) 112 | f = img * 255. 
- a 113 | a = a.astype(np.uint16).clip(0, 255) 114 | b = b.astype(np.uint16).clip(0, 255) 115 | img = colormap[a] + (colormap[b] - colormap[a]) * f[...,None] 116 | img = (img * 255.).astype(np.uint8) 117 | return img 118 | 119 | def save_grayscale_image(self, filename, img, data_range=DEFAULT_GRAYSCALE_KWARGS['data_range'], cmap=DEFAULT_GRAYSCALE_KWARGS['cmap']): 120 | img = self.get_grayscale_image_(img, data_range, cmap) 121 | cv2.imwrite(self.get_save_path(filename), img) 122 | 123 | def get_image_grid_(self, imgs): 124 | if isinstance(imgs[0], list): 125 | return np.concatenate([self.get_image_grid_(row) for row in imgs], axis=0) 126 | cols = [] 127 | for col in imgs: 128 | assert col['type'] in ['rgb', 'uv', 'grayscale'] 129 | if col['type'] == 'rgb': 130 | rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() 131 | rgb_kwargs.update(col['kwargs']) 132 | cols.append(self.get_rgb_image_(col['img'], **rgb_kwargs)) 133 | elif col['type'] == 'uv': 134 | uv_kwargs = self.DEFAULT_UV_KWARGS.copy() 135 | uv_kwargs.update(col['kwargs']) 136 | cols.append(self.get_uv_image_(col['img'], **uv_kwargs)) 137 | elif col['type'] == 'grayscale': 138 | grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() 139 | grayscale_kwargs.update(col['kwargs']) 140 | cols.append(self.get_grayscale_image_(col['img'], **grayscale_kwargs)) 141 | return np.concatenate(cols, axis=1) 142 | 143 | def save_image_grid(self, filename, imgs): 144 | img = self.get_image_grid_(imgs) 145 | cv2.imwrite(self.get_save_path(filename), img) 146 | 147 | def save_image(self, filename, img): 148 | img = self.convert_data(img) 149 | assert img.dtype == np.uint8 150 | if img.shape[-1] == 3: 151 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 152 | elif img.shape[-1] == 4: 153 | img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) 154 | cv2.imwrite(self.get_save_path(filename), img) 155 | 156 | def save_cubemap(self, filename, img, data_range=(0, 1)): 157 | img = self.convert_data(img) 158 | assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] 159 | 160 | imgs_full = [] 161 | for start in range(0, img.shape[-1], 3): 162 | img_ = img[...,start:start+3] 163 | img_ = np.stack([self.get_rgb_image_(img_[i], 'HWC', data_range) for i in range(img_.shape[0])], axis=0) 164 | size = img_.shape[1] 165 | placeholder = np.zeros((size, size, 3), dtype=np.float32) 166 | img_full = np.concatenate([ 167 | np.concatenate([placeholder, img_[2], placeholder, placeholder], axis=1), 168 | np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), 169 | np.concatenate([placeholder, img_[3], placeholder, placeholder], axis=1) 170 | ], axis=0) 171 | img_full = cv2.cvtColor(img_full, cv2.COLOR_RGB2BGR) 172 | imgs_full.append(img_full) 173 | 174 | imgs_full = np.concatenate(imgs_full, axis=1) 175 | cv2.imwrite(self.get_save_path(filename), imgs_full) 176 | 177 | def save_data(self, filename, data): 178 | data = self.convert_data(data) 179 | if isinstance(data, dict): 180 | if not filename.endswith('.npz'): 181 | filename += '.npz' 182 | np.savez(self.get_save_path(filename), **data) 183 | else: 184 | if not filename.endswith('.npy'): 185 | filename += '.npy' 186 | np.save(self.get_save_path(filename), data) 187 | 188 | def save_state_dict(self, filename, data): 189 | torch.save(data, self.get_save_path(filename)) 190 | 191 | def save_img_sequence(self, filename, img_dir, matcher, save_format='gif', fps=30): 192 | assert save_format in ['gif', 'mp4'] 193 | if not filename.endswith(save_format): 194 | filename += f".{save_format}" 195 | matcher = 
re.compile(matcher) 196 | img_dir = os.path.join(self.save_dir, img_dir) 197 | imgs = [] 198 | for f in os.listdir(img_dir): 199 | if matcher.search(f): 200 | imgs.append(f) 201 | imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) 202 | imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] 203 | 204 | if save_format == 'gif': 205 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 206 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps, palettesize=256) 207 | elif save_format == 'mp4': 208 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 209 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps) 210 | 211 | def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None, v_rgb=None): 212 | v_pos, t_pos_idx = self.convert_data(v_pos), self.convert_data(t_pos_idx) 213 | if v_rgb is not None: 214 | v_rgb = self.convert_data(v_rgb) 215 | 216 | import trimesh 217 | mesh = trimesh.Trimesh( 218 | vertices=v_pos, 219 | faces=t_pos_idx, 220 | vertex_colors=v_rgb 221 | ) 222 | mesh.export(self.get_save_path(filename)) 223 | 224 | def save_file(self, filename, src_path): 225 | shutil.copyfile(src_path, self.get_save_path(filename)) 226 | 227 | def save_json(self, filename, payload): 228 | with open(self.get_save_path(filename), 'w') as f: 229 | f.write(json.dumps(payload)) 230 | -------------------------------------------------------------------------------- /utils/obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def load_obj(filename): 5 | # Read entire file 6 | with open(filename, 'r') as f: 7 | lines = f.readlines() 8 | 9 | # load vertices 10 | vertices, texcoords = [], [] 11 | for line in lines: 12 | if len(line.split()) == 0: 13 | continue 14 | 15 | prefix = line.split()[0].lower() 16 | if prefix == 'v': 17 | vertices.append([float(v) for v in line.split()[1:]]) 18 | elif prefix == 'vt': 19 | val = [float(v) for v in line.split()[1:]] 20 | texcoords.append([val[0], 1.0 - val[1]]) 21 | 22 | uv = len(texcoords) > 0 23 | faces, tfaces = [], [] 24 | for line in lines: 25 | if len(line.split()) == 0: 26 | continue 27 | prefix = line.split()[0].lower() 28 | if prefix == 'usemtl': # Track used materials 29 | pass 30 | elif prefix == 'f': # Parse face 31 | vs = line.split()[1:] 32 | nv = len(vs) 33 | vv = vs[0].split('/') 34 | v0 = int(vv[0]) - 1 35 | if uv: 36 | t0 = int(vv[1]) - 1 if vv[1] != "" else -1 37 | for i in range(nv - 2): # Triangulate polygons 38 | vv1 = vs[i + 1].split('/') 39 | v1 = int(vv1[0]) - 1 40 | vv2 = vs[i + 2].split('/') 41 | v2 = int(vv2[0]) - 1 42 | faces.append([v0, v1, v2]) 43 | if uv: 44 | t1 = int(vv1[1]) - 1 if vv1[1] != "" else -1 45 | t2 = int(vv2[1]) - 1 if vv2[1] != "" else -1 46 | tfaces.append([t0, t1, t2]) 47 | vertices = np.array(vertices, dtype=np.float32) 48 | faces = np.array(faces, dtype=np.int64) 49 | if uv: 50 | assert len(tfaces) == len(faces) 51 | texcoords = np.array(texcoords, dtype=np.float32) 52 | tfaces = np.array(tfaces, dtype=np.int64) 53 | else: 54 | texcoords, tfaces = None, None 55 | 56 | return vertices, faces, texcoords, tfaces 57 | 58 | 59 | def write_obj(filename, v_pos, t_pos_idx, v_tex, t_tex_idx): 60 | with open(filename, "w") as f: 61 | for v in v_pos: 62 | f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) 63 | 64 | if v_tex is not None: 65 | assert(len(t_pos_idx) == len(t_tex_idx)) 66 | for v in v_tex: 67 | f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) 68 | 69 | # Write faces 70 | for i in 
range(len(t_pos_idx)): 71 | f.write("f ") 72 | for j in range(3): 73 | f.write(' %s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1))) 74 | f.write("\n") 75 | --------------------------------------------------------------------------------
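A minimal usage sketch for utils/obj.py, assuming it is run from the repository root and that `mesh.obj` is a hypothetical input OBJ file containing texture coordinates:

    from utils.obj import load_obj, write_obj

    # load_obj returns (vertices, faces, texcoords, tfaces); texcoords and tfaces
    # are None when the file has no 'vt' entries.
    v_pos, t_pos_idx, v_tex, t_tex_idx = load_obj('mesh.obj')

    # load_obj stores the V coordinate as 1 - v, and write_obj writes it back as
    # 1 - v, so a load/write round trip preserves the original texture coordinates.
    write_obj('mesh_copy.obj', v_pos, t_pos_idx, v_tex, t_tex_idx)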