├── .github ├── FUNDING.yml └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── .gitmodules ├── LICENSE ├── LICENSE_bmild ├── README.md ├── README_Unity.md ├── README_mesh.md ├── datasets ├── __init__.py ├── blender.py ├── depth_utils.py ├── llff.py └── ray_utils.py ├── docs ├── .gitignore ├── Gemfile ├── Gemfile.lock ├── _config.yml ├── index.md └── style.css ├── eval.py ├── extract_color_mesh.py ├── extract_mesh.ipynb ├── losses.py ├── metrics.py ├── models ├── __init__.py ├── nerf.py └── rendering.py ├── opt.py ├── requirements.txt ├── test.ipynb ├── train.py └── utils ├── __init__.py ├── optimizers.py ├── save_weights_only.py ├── visualization.py └── warmup_scheduler.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: kwea123 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Which branch you use** 14 | Only issues of the *dev* and *nerfw* branches will be considered currently! 15 | 16 | **To Reproduce** 17 | Steps to reproduce the behavior: 18 | 1. Go to '...' 19 | 2. Click on '....' 20 | 3. Scroll down to '....' 21 | 4. See error 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | logs/ 3 | ckpts/ 4 | results/ 5 | *.ply 6 | *.vol 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "torchsearchsorted"] 2 | path = torchsearchsorted 3 | url = https://github.com/aliutkus/torchsearchsorted.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Quei-An Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE_bmild: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 bmild 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nerf_pl 2 | 3 | ### Update: NVIDIA open-sourced a lightning-fast version of NeRF: [NGP](https://github.com/NVlabs/instant-ngp). I re-implemented it in pytorch [here](https://github.com/kwea123/ngp_pl). That version is ~100x faster than this repo, with better quality as well! 4 | 5 | ### Update: an improved [NSFF](https://www.cs.cornell.edu/~zl548/NSFF/) implementation to handle dynamic scenes is [open](https://github.com/kwea123/nsff_pl)! 6 | 7 | ### Update: a [NeRF-W](https://nerf-w.github.io/) (NeRF in the Wild) implementation has been added to the [nerfw](https://github.com/kwea123/nerf_pl/tree/nerfw) branch! 8 | 9 | ### Update: The latest code (using the latest libraries) will be updated on the [dev](https://github.com/kwea123/nerf_pl/tree/dev) branch. The master branch remains as-is to support the colab files. If you don't use colab, it is recommended to switch to the dev branch. Only issues on the dev and nerfw branches will be considered currently. 10 | 11 | ### :gem: [**Project page**](https://kwea123.github.io/nerf_pl/) (live demo!) 12 | 13 | Unofficial implementation of [NeRF](https://arxiv.org/pdf/2003.08934.pdf) (Neural Radiance Fields) using pytorch ([pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning)). This repo doesn't aim at reproducibility, but aims at providing a simpler and faster training procedure (and simpler code with detailed comments to help understand the work). Moreover, I try to open up more possibilities by integrating this algorithm into game engines like Unity. 14 | 15 | Official implementation: [nerf](https://github.com/bmild/nerf) . Reference pytorch implementation: [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) 16 | 17 | ### Recommended reading: a detailed NeRF extension list: [awesome-NeRF](https://github.com/yenchenlin/awesome-NeRF) 18 | 19 | ## :milky_way: Features 20 | 21 | * Multi-GPU training: training on 8 GPUs finishes within 1 hour for the synthetic dataset! 22 | * [Colab](#mortar_board-colab) notebooks to allow easy usage!
23 | * [Reconstruct](#ribbon-mesh) **colored** mesh! 24 | * [Mixed Reality](https://youtu.be/S5phWFTs2iM) in Unity! 25 | * [REAL TIME volume rendering](https://youtu.be/w9qTbVzCdWk) in Unity! 26 | * [Portable Scenes](#portable-scenes) to let you play with other people's scenes! 27 | 28 | ### You can find the Unity project including mesh, mixed reality and volume rendering [here](https://github.com/kwea123/nerf_Unity)! See [README_Unity](README_Unity.md) for generating your own data for Unity rendering! 29 | 30 | ## :beginner: Tutorial 31 | 32 | ### What can NeRF do? 33 | 34 | 35 | ### Tutorial videos 36 | 37 | 38 | 39 | 40 | # :computer: Installation 41 | 42 | ## Hardware 43 | 44 | * OS: Ubuntu 18.04 45 | * NVIDIA GPU with **CUDA>=10.1** (tested with 1 RTX2080Ti) 46 | 47 | ## Software 48 | 49 | * Clone this repo by `git clone --recursive https://github.com/kwea123/nerf_pl` 50 | * Python>=3.6 (installation via [anaconda](https://www.anaconda.com/distribution/) is recommended, use `conda create -n nerf_pl python=3.6` to create a conda environment and activate it by `conda activate nerf_pl`) 51 | * Python libraries 52 | * Install core requirements by `pip install -r requirements.txt` 53 | * Install `torchsearchsorted` by `cd torchsearchsorted` then `pip install .` 54 | 55 | # :key: Training 56 | 57 | Please see each subsection for training on different datasets. Available training datasets: 58 | 59 | * [Blender](#blender) (Realistic Synthetic 360) 60 | * [LLFF](#llff) (Real Forward-Facing) 61 | * [Your own data](#your-own-data) (Forward-Facing/360 inward-facing) 62 | 63 | ## Blender 64 |
65 | Steps 66 | 67 | ### Data download 68 | 69 | Download `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 70 | 71 | ### Training model 72 | 73 | Run (example) 74 | ``` 75 | python train.py \ 76 | --dataset_name blender \ 77 | --root_dir $BLENDER_DIR \ 78 | --N_importance 64 --img_wh 400 400 --noise_std 0 \ 79 | --num_epochs 16 --batch_size 1024 \ 80 | --optimizer adam --lr 5e-4 \ 81 | --lr_scheduler steplr --decay_step 2 4 8 --decay_gamma 0.5 \ 82 | --exp_name exp 83 | ``` 84 | 85 | These parameters are chosen to best mimic the training settings in the original repo. See [opt.py](opt.py) for all configurations. 86 | 87 | NOTE: the above configuration doesn't work for some scenes such as `drums` and `ship`. In that case, consider increasing the `batch_size` or changing the `optimizer` to `radam`. I managed to train on all scenes with these modifications. 88 | 89 | You can monitor the training process by running `tensorboard --logdir logs/` and going to `localhost:6006` in your browser. 90 |
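For reference, `--lr_scheduler steplr --decay_step 2 4 8 --decay_gamma 0.5` describes an epoch-wise step decay (the decay happens per epoch, not per step; see the notes section further below). Below is a minimal sketch of the schedule these flags correspond to, assuming they map onto PyTorch's `MultiStepLR`; the toy `model` is just a stand-in, not the repo's NeRF.

```
import torch

# Stand-in module; in the repo the scheduler wraps the NeRF optimizer built in train.py.
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)   # --optimizer adam --lr 5e-4
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[2, 4, 8], gamma=0.5)             # --decay_step 2 4 8 --decay_gamma 0.5

for epoch in range(16):                                     # --num_epochs 16
    # ... one full training epoch would run here (optimizer.step() per batch) ...
    scheduler.step()                                        # stepped once per epoch, not per batch
    print(epoch, scheduler.get_last_lr())                   # 5e-4 -> 2.5e-4 -> 1.25e-4 -> 6.25e-5
```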
91 | 92 | ## LLFF 93 |
94 | Steps 95 | 96 | ### Data download 97 | 98 | Download `nerf_llff_data.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 99 | 100 | ### Training model 101 | 102 | Run (example) 103 | ``` 104 | python train.py \ 105 | --dataset_name llff \ 106 | --root_dir $LLFF_DIR \ 107 | --N_importance 64 --img_wh 504 378 \ 108 | --num_epochs 30 --batch_size 1024 \ 109 | --optimizer adam --lr 5e-4 \ 110 | --lr_scheduler steplr --decay_step 10 20 --decay_gamma 0.5 \ 111 | --exp_name exp 112 | ``` 113 | 114 | These parameters are chosen to best mimic the training settings in the original repo. See [opt.py](opt.py) for all configurations. 115 | 116 | You can monitor the training process by running `tensorboard --logdir logs/` and going to `localhost:6006` in your browser. 117 |
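Note that `--img_wh` must preserve the aspect ratio of the original captures (`datasets/llff.py` asserts this). A small sketch of how you might pick a valid resolution; the 4032x3024 full resolution below is only an illustrative assumption, use your own capture size.

```
# Pick an img_wh that preserves the original aspect ratio (datasets/llff.py asserts this).
W, H = 4032, 3024                 # illustrative full resolution of the captures
down = 8                          # downscale factor
img_wh = (W // down, H // down)   # -> (504, 378), the value used in the command above
assert H * img_wh[0] == W * img_wh[1], 'aspect ratio must be preserved'
print(img_wh)
```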
118 | 119 | ## Your own data 120 |
121 | Steps 122 | 123 | 1. Install [COLMAP](https://github.com/colmap/colmap) following the [installation guide](https://colmap.github.io/install.html) 124 | 2. Prepare your images in a folder (around 20 to 30 for forward-facing scenes, and 40 to 50 for 360 inward-facing scenes) 125 | 3. Clone [LLFF](https://github.com/Fyusion/LLFF) and run `python img2poses.py $your-images-folder` to generate the camera poses (`poses_bounds.npy`); a quick sanity check of this file is sketched right after these steps 126 | 4. Train the model using the same command as in [LLFF](#llff). If the scene is captured in a 360 inward-facing manner, add the `--spheric` argument. 127 | 128 | For more details on training a good model, please see the video [here](#colab). 129 |
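As mentioned in step 3, you can sanity-check the generated `poses_bounds.npy` before training. The sketch below mirrors how `datasets/llff.py` parses it; the plain filename is an assumption, point it at wherever the file was written (the scene folder that contains your `images/` directory).

```
import numpy as np

poses_bounds = np.load('poses_bounds.npy')         # path assumed; written next to your images/ folder
print(poses_bounds.shape)                          # expect (N_images, 17)

poses = poses_bounds[:, :15].reshape(-1, 3, 5)     # 3x4 camera-to-world pose + [H, W, focal] column
bounds = poses_bounds[:, -2:]                      # per-image near/far depth bounds
H, W, focal = poses[0, :, -1]
print(H, W, focal)                                 # original intrinsics, shared by all images
print(bounds.min(), bounds.max())                  # depth range seen by COLMAP
```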
130 | 131 | ## Pretrained models and logs 132 | Download the pretrained models and training logs from the [release](https://github.com/kwea123/nerf_pl/releases) page. 133 | 134 | ## Comparison with other repos 135 | 136 | | | training GPU memory in GB | Speed (1 step) | 137 | | :---: | :---: | :---: | 138 | | [Original](https://github.com/bmild/nerf) | 8.5 | 0.177s | 139 | | [Ref pytorch](https://github.com/yenchenlin/nerf-pytorch) | 6.0 | 0.147s | 140 | | This repo | 3.2 | 0.12s | 141 | 142 | The speed is measured on 1 RTX2080Ti. A detailed profile can be found in the [release](https://github.com/kwea123/nerf_pl/releases). 143 | Training memory is largely reduced, since the original repo loads all the data to the GPU at the beginning, while we only pass batches to the GPU at every step. 144 | 145 | # :mag_right: Testing 146 | 147 | See [test.ipynb](test.ipynb) for a simple view synthesis and depth prediction on a single image. 148 | 149 | Use [eval.py](eval.py) to create the whole sequence of moving views. 150 | E.g. 151 | ``` 152 | python eval.py \ 153 | --root_dir $BLENDER \ 154 | --dataset_name blender --scene_name lego \ 155 | --img_wh 400 400 --N_importance 64 --ckpt_path $CKPT_PATH 156 | ``` 157 | **IMPORTANT**: Don't forget to add `--spheric_poses` if the model is trained under the `--spheric` setting! 158 | 159 | It will create the folder `results/{dataset_name}/{scene_name}`, run inference on all test data, and finally create a gif out of the results. 160 | 161 | Example of the lego scene using the pretrained model and the reconstructed **colored** mesh: (PSNR=31.39, paper=32.54) 162 | 163 |

164 | 165 | 166 |

167 | 168 | Example of fern scene using pretrained model: 169 | 170 | ![fern](https://user-images.githubusercontent.com/11364490/79932650-f9d31380-8488-11ea-8dad-b70a6a3daa6e.gif) 171 | 172 | Example of own scene ([Silica GGO figure](https://www.youtube.com/watch?v=hVQIvEq_Av0)) and the reconstructed **colored** mesh. Click to link to youtube video. 173 | 174 |

175 | 176 | 177 | 178 | 179 |

180 | 181 | ## Portable scenes 182 | The concept of NeRF is that the whole scene is compressed into a NeRF model, then we can render from any pose we want. To render from plausible poses, we can leverage the training poses; therefore, you can generate video with **only** the trained model and the poses (hence the name of portable scenes). I provided my silica model in [release](https://github.com/kwea123/nerf_pl/releases), feel free to play around with it! 183 | 184 | If you trained some interesting scenes, you are also welcomed to share the model (and the `poses_bounds.npy`) by sending me an email, or post in issues! After all, a model is just around **5MB**! Please run `python utils/save_weights_only.py --ckpt_path $YOUR_MODEL_PATH` to extract the final model. 185 | 186 | # :ribbon: Mesh 187 | 188 | See [README_mesh](README_mesh.md) for reconstruction of **colored** mesh. Only supported for blender dataset and 360 inward-facing data! 189 | 190 | # :warning: Notes on differences with the original repo 191 | 192 | * The learning rate decay in the original repo is **by step**, which means it decreases every step, here I use learning rate decay **by epoch**, which means it changes only at the end of 1 epoch. 193 | * The validation image for LLFF dataset is chosen as the most centered image here, whereas the original repo chooses every 8th image. 194 | * The rendering spiral path is slightly different from the original repo (I use approximate values to simplify the code). 195 | 196 | # :mortar_board: COLAB 197 | 198 | I also prepared colab notebooks that allow you to run the algorithm on any machine without GPU requirement. 199 | 200 | * [colmap](https://gist.github.com/kwea123/f0e8f38ff2aa94495dbfe7ae9219f75c) to prepare camera poses for your own training data 201 | * [nerf](https://gist.github.com/kwea123/a3c541a325e895ef79ecbc0d2e6d7221) to train on your data 202 | * [extract_mesh](https://gist.github.com/kwea123/77ed1640f9bc9550136dc13a6a419e88) to extract colored mesh 203 | 204 | Please see [this playlist](https://www.youtube.com/playlist?list=PLDV2CyUo4q-K02pNEyDr7DYpTQuka3mbV) for the detailed tutorials. 205 | 206 | # :jack_o_lantern: SHOWOFF 207 | 208 | We can incorporate *ray tracing* techniques into the volume rendering pipeline, and realize realistic scene editing (following is the `materials` scene with an object removed, and a mesh is inserted and rendered with ray tracing). The code **will not** be released. 209 | 210 | ![add](https://user-images.githubusercontent.com/11364490/90312710-92face00-df41-11ea-9eea-10f24849b407.gif) 211 | ![add2](https://user-images.githubusercontent.com/11364490/90360796-92744b80-e097-11ea-859d-159aa2519375.gif) 212 | 213 | With my integration in Unity, I can realize realistic mixed reality photos (note my character casts shadow on the scene, **zero** post- image editing required): 214 | ![defer](https://user-images.githubusercontent.com/11364490/140264589-295acebe-8ace-4d61-b871-26eb8ae10ab0.png) 215 | ![defer2](https://user-images.githubusercontent.com/11364490/140264596-59daebe5-b88d-48e7-82bd-5ccaaff2283f.png) 216 | BTW, I would like to visit the museum one day... 
217 | 218 | # :book: Citation 219 | If you use (part of) my code or find my work helpful, please consider citing 220 | ``` 221 | @misc{queianchen_nerf, 222 | author={Quei-An, Chen}, 223 | title={Nerf_pl: a pytorch-lightning implementation of NeRF}, 224 | url={https://github.com/kwea123/nerf_pl/}, 225 | year={2020}, 226 | } 227 | ``` 228 | -------------------------------------------------------------------------------- /README_Unity.md: -------------------------------------------------------------------------------- 1 | # Generating files for Unity rendering 2 | 3 | This readme contains guidance on generating the files required for Unity rendering in my [Unity project](https://github.com/kwea123/nerf_Unity). 4 | 5 | ## MeshRender 6 | 7 | See [README_mesh](README_mesh.md) for generating the mesh. 8 | You then need [this plugin](https://github.com/kwea123/Pcx) to import `.ply` files into Unity. 9 | 10 | ## MixedReality 11 | 12 | Use `eval.py` with `--save_depth --depth_format bytes` to create the whole sequence of moving views. E.g. 13 | ``` 14 | python eval.py \ 15 | --root_dir $BLENDER \ 16 | --dataset_name blender --scene_name lego \ 17 | --img_wh 400 400 --N_importance 64 --ckpt_path $CKPT_PATH \ 18 | --save_depth --depth_format bytes 19 | ``` 20 | You will get `*.png` files and corresponding `depth_*` files. Now import the image you want to show and its corresponding depth file into Unity, and replace the files in my Unity project. 21 | 22 | ## VolumeRender 23 | 24 | Use `extract_mesh.ipynb` (not `extract_color_mesh.py`!) to find the tight bounds of the object as for mesh generation (see [this video](https://www.youtube.com/watch?v=t06qu-gXrxA&t=1355)), but this time stop before the cell "Extract colored mesh". Remember to set `N=512` in the cell "Search for tight bounds of the object" and comment out the lines for visualization. Now run the cell "Generate .vol file for volume rendering in Unity"; after that, you should obtain a `.vol` file, which you can import into my Unity project and render. 25 | 26 | **NOTE:** If you use colab as in the video, copy the cell "Generate .vol file for volume rendering in Unity" into the colab notebook and execute it. 27 | 28 | If you want to render in your own project, you need the script [LoadVolume.cs](https://github.com/kwea123/nerf_Unity/blob/master/Assets/Editor/LoadVolume.cs), which reads this custom `.vol` format into a `Texture3D`. 29 | -------------------------------------------------------------------------------- /README_mesh.md: -------------------------------------------------------------------------------- 1 | # Reconstruct mesh 2 | 3 | Use `extract_mesh.ipynb` (the notebook, **not** the py script) to extract a **colored** mesh. The guideline for choosing good parameters is commented in the notebook. 4 | Here, I'll give a detailed explanation of how it works. There is also a [video](https://youtu.be/t06qu-gXrxA) that explains the same thing. 5 | 6 | ## Step 1. Predict occupancy 7 | 8 | As in the [original repo](https://github.com/bmild/nerf/blob/master/extract_mesh.ipynb), we first need to infer which locations are occupied by the object. This is done by first creating a grid volume in the form of a cuboid covering the whole object, then using the NeRF model to predict whether each cell is occupied or not. This is the main reason why mesh construction is only available for 360 inward-facing scenes, as forward-facing scenes would require a **huge** volume to cover the whole space, and it would be computationally infeasible to predict the occupancy of all those cells. 
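A minimal sketch of this occupancy prediction, assuming a generic `sigma_fn` wrapper around the trained NeRF that returns the density for a batch of xyz points; the grid size `N`, the bound `limit` and `sigma_threshold` are illustrative values that you tune in the notebook.

```
import numpy as np
import torch

@torch.no_grad()
def predict_occupancy(sigma_fn, N=256, limit=1.2, sigma_threshold=20.0, chunk=32*1024):
    """sigma_fn: callable mapping a (B, 3) xyz tensor to a (B,) density tensor."""
    t = np.linspace(-limit, limit, N)
    # cuboid grid of N^3 cell centers covering the object
    xyz = torch.FloatTensor(np.stack(np.meshgrid(t, t, t), -1).reshape(-1, 3))

    sigmas = []
    for i in range(0, xyz.shape[0], chunk):        # query in chunks to avoid OOM
        sigmas += [sigma_fn(xyz[i:i+chunk]).cpu()]
    sigma = torch.cat(sigmas).numpy().reshape(N, N, N)

    return sigma > sigma_threshold                 # boolean occupancy grid
```

The resulting boolean grid is what the marching cubes step below consumes.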
9 | 10 | ## Step 2. Perform the marching cubes algorithm 11 | 12 | After we know which cells are occupied, we can use the [marching cubes algorithm](https://en.wikipedia.org/wiki/Marching_cubes) to extract the mesh. This mesh will only contain vertices and faces; if you don't require color, you can stop here and export the mesh. Up to here, the code is the same as in the original repo. 13 | 14 | ## Step 3. Remove noise 15 | 16 | The mesh might contain some noise, which could be due to wrongly predicted occupancy in step 1, or you might consider the floor as noise. To remove this noise, we use a simple method: only keep the largest cluster. We cluster the triangles into groups (two triangles are in the same group if they are connected), and only keep the biggest one. After removing the noise, we then compute the color for each vertex. 17 | 18 | ## Step 4. Compute color for each vertex 19 | 20 | We adopt the concept of assigning colors to vertices instead of faces (they are roughly equivalent, as you can think of the color of a vertex as the average color of its neighboring faces and vice versa). To compute the color of a vertex, we leverage the **training images**: we project this vertex onto the training images to get its rgb values, then average these values as its final color. Notice that the projected pixel coordinates are floating-point numbers, so we use *bilinear interpolation* to get the rgb value. 21 | 22 | This process might seem correct at first sight; however, this is what we'll get: 23 | 24 | 25 | 26 | by projecting the vertices onto this input image: 27 | 28 | 29 | 30 | You'll notice the face appears on the mantle. Why is that? It is because of **occlusion**. 31 | 32 | From the input image view, that spurious part of the mantle is actually occluded (blocked) by the face, so in reality we **shouldn't** assign color to it, but the above process assigns it the same color as the face because those vertices are projected onto the face (in pixel coordinates) as well! 33 | 34 | So the problem becomes: how do we correctly infer occlusion information, to know which vertices shouldn't be assigned colors? I tried two methods; the first turned out not to work well: 35 | 36 | 1. Use depth information 37 | 38 | The first intuitive way is to leverage the vertices' depths (which are obtained when projecting the vertices onto the image plane): if two (or more) vertices are projected onto the **same** pixel coordinates, then only the nearest vertex is assigned a color and the rest remain untouched. However, this method won't work, since no two vertices will be projected onto exactly the same location! As we mentioned earlier, the pixel coordinates are floating-point numbers, so it is impossible for them to be exactly the same. If we round the numbers to integers (which I tried as well), then this method works, but still with a lot of misclassified (occluded/non-occluded) vertices in my experiments. 39 | 40 | 2. Leverage the NeRF model 41 | 42 | A more intelligent way I found to infer occlusion is to use the NeRF model itself. Recall that a NeRF model can estimate the opacity (or density) along a ray path (figure (c) below): 43 | ![nerf](https://github.com/bmild/nerf/blob/master/imgs/pipeline.jpg) 44 | We can leverage that information to tell whether a vertex is occluded or not. More concretely, we form rays originating from the camera origin and ending at the vertices, and compute the total opacity along these rays. 
If a vertex is not occluded, the opacity will be small; otherwise, the value will be large, meaning that something lies between the vertex and the camera. 45 | 46 | After applying this method, this is what we get (by projecting the vertices onto the input view as above): 47 | 48 | 49 | The spurious face on the mantle disappears, and the colored pixels are almost exactly the ones we can observe from the image. By default, we set all vertices to black, so a black vertex means it's occluded in this view, but it will be assigned a color when we switch to other views. 50 | 51 | 52 | # Finally... 53 | 54 | This is the final result: 55 | 56 | 57 | 58 | We can then export this `.ply` file to any other format and embed it in programs, like I did in Unity: 59 | (I use [this plugin](https://github.com/kwea123/Pcx) to import `.ply` files into Unity, and made some modifications so that it can also read the mesh triangles, not only the points.) 60 | 61 | ![image](https://user-images.githubusercontent.com/11364490/80859833-9e7dfe00-8c9e-11ea-9fa1-ec48237e3873.png) 62 | 63 | The meshes can have a mesh collider attached so that they can interact with other objects. You can see [this video](https://youtu.be/I2M0xhnrBos) for a demo. 64 | 65 | ## Further reading 66 | The author suggested [another way](https://github.com/bmild/nerf/issues/44#issuecomment-622961303) to extract color; in my experiments it didn't turn out to be good, but the idea is reasonable and interesting. You can also test this by setting `--use_vertex_normal`. 67 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .blender import BlenderDataset 2 | from .llff import LLFFDataset 3 | 4 | dataset_dict = {'blender': BlenderDataset, 5 | 'llff': LLFFDataset} -------------------------------------------------------------------------------- /datasets/blender.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | class BlenderDataset(Dataset): 12 | def __init__(self, root_dir, split='train', img_wh=(800, 800)): 13 | self.root_dir = root_dir 14 | self.split = split 15 | assert img_wh[0] == img_wh[1], 'image width must equal image height!' 
16 | self.img_wh = img_wh 17 | self.define_transforms() 18 | 19 | self.read_meta() 20 | self.white_back = True 21 | 22 | def read_meta(self): 23 | with open(os.path.join(self.root_dir, 24 | f"transforms_{self.split}.json"), 'r') as f: 25 | self.meta = json.load(f) 26 | 27 | w, h = self.img_wh 28 | self.focal = 0.5*800/np.tan(0.5*self.meta['camera_angle_x']) # original focal length 29 | # when W=800 30 | 31 | self.focal *= self.img_wh[0]/800 # modify focal length to match size self.img_wh 32 | 33 | # bounds, common for all scenes 34 | self.near = 2.0 35 | self.far = 6.0 36 | self.bounds = np.array([self.near, self.far]) 37 | 38 | # ray directions for all pixels, same for all images (same H, W, focal) 39 | self.directions = \ 40 | get_ray_directions(h, w, self.focal) # (h, w, 3) 41 | 42 | if self.split == 'train': # create buffer of all rays and rgb data 43 | self.image_paths = [] 44 | self.poses = [] 45 | self.all_rays = [] 46 | self.all_rgbs = [] 47 | for frame in self.meta['frames']: 48 | pose = np.array(frame['transform_matrix'])[:3, :4] 49 | self.poses += [pose] 50 | c2w = torch.FloatTensor(pose) 51 | 52 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 53 | self.image_paths += [image_path] 54 | img = Image.open(image_path) 55 | img = img.resize(self.img_wh, Image.LANCZOS) 56 | img = self.transform(img) # (4, h, w) 57 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 58 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB 59 | self.all_rgbs += [img] 60 | 61 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 62 | 63 | self.all_rays += [torch.cat([rays_o, rays_d, 64 | self.near*torch.ones_like(rays_o[:, :1]), 65 | self.far*torch.ones_like(rays_o[:, :1])], 66 | 1)] # (h*w, 8) 67 | 68 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 69 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 70 | 71 | def define_transforms(self): 72 | self.transform = T.ToTensor() 73 | 74 | def __len__(self): 75 | if self.split == 'train': 76 | return len(self.all_rays) 77 | if self.split == 'val': 78 | return 8 # only validate 8 images (to support <=8 gpus) 79 | return len(self.meta['frames']) 80 | 81 | def __getitem__(self, idx): 82 | if self.split == 'train': # use data in the buffers 83 | sample = {'rays': self.all_rays[idx], 84 | 'rgbs': self.all_rgbs[idx]} 85 | 86 | else: # create data for each image separately 87 | frame = self.meta['frames'][idx] 88 | c2w = torch.FloatTensor(frame['transform_matrix'])[:3, :4] 89 | 90 | img = Image.open(os.path.join(self.root_dir, f"{frame['file_path']}.png")) 91 | img = img.resize(self.img_wh, Image.LANCZOS) 92 | img = self.transform(img) # (4, H, W) 93 | valid_mask = (img[-1]>0).flatten() # (H*W) valid color area 94 | img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA 95 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB 96 | 97 | rays_o, rays_d = get_rays(self.directions, c2w) 98 | 99 | rays = torch.cat([rays_o, rays_d, 100 | self.near*torch.ones_like(rays_o[:, :1]), 101 | self.far*torch.ones_like(rays_o[:, :1])], 102 | 1) # (H*W, 8) 103 | 104 | sample = {'rays': rays, 105 | 'rgbs': img, 106 | 'c2w': c2w, 107 | 'valid_mask': valid_mask} 108 | 109 | return sample -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | 5 | def read_pfm(filename): 6 | file = 
open(filename, 'rb') 7 | color = None 8 | width = None 9 | height = None 10 | scale = None 11 | endian = None 12 | 13 | header = file.readline().decode('utf-8').rstrip() 14 | if header == 'PF': 15 | color = True 16 | elif header == 'Pf': 17 | color = False 18 | else: 19 | raise Exception('Not a PFM file.') 20 | 21 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 22 | if dim_match: 23 | width, height = map(int, dim_match.groups()) 24 | else: 25 | raise Exception('Malformed PFM header.') 26 | 27 | scale = float(file.readline().rstrip()) 28 | if scale < 0: # little-endian 29 | endian = '<' 30 | scale = -scale 31 | else: 32 | endian = '>' # big-endian 33 | 34 | data = np.fromfile(file, endian + 'f') 35 | shape = (height, width, 3) if color else (height, width) 36 | 37 | data = np.reshape(data, shape) 38 | data = np.flipud(data) 39 | file.close() 40 | return data, scale 41 | 42 | 43 | def save_pfm(filename, image, scale=1): 44 | file = open(filename, "wb") 45 | color = None 46 | 47 | image = np.flipud(image) 48 | 49 | if image.dtype.name != 'float32': 50 | raise Exception('Image dtype must be float32.') 51 | 52 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 53 | color = True 54 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 55 | color = False 56 | else: 57 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 58 | 59 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 60 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 61 | 62 | endian = image.dtype.byteorder 63 | 64 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 65 | scale = -scale 66 | 67 | file.write(('%f\n' % scale).encode('utf-8')) 68 | 69 | image.tofile(file) 70 | file.close() -------------------------------------------------------------------------------- /datasets/llff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import glob 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | def normalize(v): 13 | """Normalize a vector.""" 14 | return v/np.linalg.norm(v) 15 | 16 | 17 | def average_poses(poses): 18 | """ 19 | Calculate the average pose, which is then used to center all poses 20 | using @center_poses. Its computation is as follows: 21 | 1. Compute the center: the average of pose centers. 22 | 2. Compute the z axis: the normalized average z axis. 23 | 3. Compute axis y': the average y axis. 24 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 25 | 5. Compute the y axis: z cross product x. 26 | 27 | Note that at step 3, we cannot directly use y' as y axis since it's 28 | not necessarily orthogonal to z axis. We need to pass from x to y. 29 | 30 | Inputs: 31 | poses: (N_images, 3, 4) 32 | 33 | Outputs: 34 | pose_avg: (3, 4) the average pose 35 | """ 36 | # 1. Compute the center 37 | center = poses[..., 3].mean(0) # (3) 38 | 39 | # 2. Compute the z axis 40 | z = normalize(poses[..., 2].mean(0)) # (3) 41 | 42 | # 3. Compute axis y' (no need to normalize as it's not the final output) 43 | y_ = poses[..., 1].mean(0) # (3) 44 | 45 | # 4. Compute the x axis 46 | x = normalize(np.cross(y_, z)) # (3) 47 | 48 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 49 | y = np.cross(z, x) # (3) 50 | 51 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 52 | 53 | return pose_avg 54 | 55 | 56 | def center_poses(poses): 57 | """ 58 | Center the poses so that we can use NDC. 59 | See https://github.com/bmild/nerf/issues/34 60 | 61 | Inputs: 62 | poses: (N_images, 3, 4) 63 | 64 | Outputs: 65 | poses_centered: (N_images, 3, 4) the centered poses 66 | pose_avg: (3, 4) the average pose 67 | """ 68 | 69 | pose_avg = average_poses(poses) # (3, 4) 70 | pose_avg_homo = np.eye(4) 71 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 72 | # by simply adding 0, 0, 0, 1 as the last row 73 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 74 | poses_homo = \ 75 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 76 | 77 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 78 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 79 | 80 | return poses_centered, np.linalg.inv(pose_avg_homo) 81 | 82 | 83 | def create_spiral_poses(radii, focus_depth, n_poses=120): 84 | """ 85 | Computes poses that follow a spiral path for rendering purpose. 86 | See https://github.com/Fyusion/LLFF/issues/19 87 | In particular, the path looks like: 88 | https://tinyurl.com/ybgtfns3 89 | 90 | Inputs: 91 | radii: (3) radii of the spiral for each axis 92 | focus_depth: float, the depth that the spiral poses look at 93 | n_poses: int, number of poses to create along the path 94 | 95 | Outputs: 96 | poses_spiral: (n_poses, 3, 4) the poses in the spiral path 97 | """ 98 | 99 | poses_spiral = [] 100 | for t in np.linspace(0, 4*np.pi, n_poses+1)[:-1]: # rotate 4pi (2 rounds) 101 | # the parametric function of the spiral (see the interactive web) 102 | center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5*t)]) * radii 103 | 104 | # the viewing z axis is the vector pointing from the @focus_depth plane 105 | # to @center 106 | z = normalize(center - np.array([0, 0, -focus_depth])) 107 | 108 | # compute other axes as in @average_poses 109 | y_ = np.array([0, 1, 0]) # (3) 110 | x = normalize(np.cross(y_, z)) # (3) 111 | y = np.cross(z, x) # (3) 112 | 113 | poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4) 114 | 115 | return np.stack(poses_spiral, 0) # (n_poses, 3, 4) 116 | 117 | 118 | def create_spheric_poses(radius, n_poses=120): 119 | """ 120 | Create circular poses around z axis. 121 | Inputs: 122 | radius: the (negative) height and the radius of the circle. 
123 | 124 | Outputs: 125 | spheric_poses: (n_poses, 3, 4) the poses in the circular path 126 | """ 127 | def spheric_pose(theta, phi, radius): 128 | trans_t = lambda t : np.array([ 129 | [1,0,0,0], 130 | [0,1,0,-0.9*t], 131 | [0,0,1,t], 132 | [0,0,0,1], 133 | ]) 134 | 135 | rot_phi = lambda phi : np.array([ 136 | [1,0,0,0], 137 | [0,np.cos(phi),-np.sin(phi),0], 138 | [0,np.sin(phi), np.cos(phi),0], 139 | [0,0,0,1], 140 | ]) 141 | 142 | rot_theta = lambda th : np.array([ 143 | [np.cos(th),0,-np.sin(th),0], 144 | [0,1,0,0], 145 | [np.sin(th),0, np.cos(th),0], 146 | [0,0,0,1], 147 | ]) 148 | 149 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) 150 | c2w = np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]) @ c2w 151 | return c2w[:3] 152 | 153 | spheric_poses = [] 154 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: 155 | spheric_poses += [spheric_pose(th, -np.pi/5, radius)] # 36 degree view downwards 156 | return np.stack(spheric_poses, 0) 157 | 158 | 159 | class LLFFDataset(Dataset): 160 | def __init__(self, root_dir, split='train', img_wh=(504, 378), spheric_poses=False, val_num=1): 161 | """ 162 | spheric_poses: whether the images are taken in a spheric inward-facing manner 163 | default: False (forward-facing) 164 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 165 | """ 166 | self.root_dir = root_dir 167 | self.split = split 168 | self.img_wh = img_wh 169 | self.spheric_poses = spheric_poses 170 | self.val_num = max(1, val_num) # at least 1 171 | self.define_transforms() 172 | 173 | self.read_meta() 174 | self.white_back = False 175 | 176 | def read_meta(self): 177 | poses_bounds = np.load(os.path.join(self.root_dir, 178 | 'poses_bounds.npy')) # (N_images, 17) 179 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*'))) 180 | # load full resolution image then resize 181 | if self.split in ['train', 'val']: 182 | assert len(poses_bounds) == len(self.image_paths), \ 183 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!' 184 | 185 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 186 | self.bounds = poses_bounds[:, -2:] # (N_images, 2) 187 | 188 | # Step 1: rescale focal length according to training resolution 189 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images 190 | assert H*self.img_wh[0] == W*self.img_wh[1], \ 191 | f'You must set @img_wh to have the same aspect ratio as ({W}, {H}) !' 
192 | 193 | self.focal *= self.img_wh[0]/W 194 | 195 | # Step 2: correct poses 196 | # Original poses has rotation in form "down right back", change to "right up back" 197 | # See https://github.com/bmild/nerf/issues/34 198 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 199 | # (N_images, 3, 4) exclude H, W, focal 200 | self.poses, self.pose_avg = center_poses(poses) 201 | distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 202 | val_idx = np.argmin(distances_from_center) # choose val image as the closest to 203 | # center image 204 | 205 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 206 | # See https://github.com/bmild/nerf/issues/34 207 | near_original = self.bounds.min() 208 | scale_factor = near_original*0.75 # 0.75 is the default parameter 209 | # the nearest depth is at 1/0.75=1.33 210 | self.bounds /= scale_factor 211 | self.poses[..., 3] /= scale_factor 212 | 213 | # ray directions for all pixels, same for all images (same H, W, focal) 214 | self.directions = \ 215 | get_ray_directions(self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3) 216 | 217 | if self.split == 'train': # create buffer of all rays and rgb data 218 | # use first N_images-1 to train, the LAST is val 219 | self.all_rays = [] 220 | self.all_rgbs = [] 221 | for i, image_path in enumerate(self.image_paths): 222 | if i == val_idx: # exclude the val image 223 | continue 224 | c2w = torch.FloatTensor(self.poses[i]) 225 | 226 | img = Image.open(image_path).convert('RGB') 227 | assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \ 228 | f'''{image_path} has different aspect ratio than img_wh, 229 | please check your data!''' 230 | img = img.resize(self.img_wh, Image.LANCZOS) 231 | img = self.transform(img) # (3, h, w) 232 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 233 | self.all_rgbs += [img] 234 | 235 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 236 | if not self.spheric_poses: 237 | near, far = 0, 1 238 | rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0], 239 | self.focal, 1.0, rays_o, rays_d) 240 | # near plane is always at 1.0 241 | # near and far in NDC are always 0 and 1 242 | # See https://github.com/bmild/nerf/issues/34 243 | else: 244 | near = self.bounds.min() 245 | far = min(8 * near, self.bounds.max()) # focus on central object only 246 | 247 | self.all_rays += [torch.cat([rays_o, rays_d, 248 | near*torch.ones_like(rays_o[:, :1]), 249 | far*torch.ones_like(rays_o[:, :1])], 250 | 1)] # (h*w, 8) 251 | 252 | self.all_rays = torch.cat(self.all_rays, 0) # ((N_images-1)*h*w, 8) 253 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # ((N_images-1)*h*w, 3) 254 | 255 | elif self.split == 'val': 256 | print('val image is', self.image_paths[val_idx]) 257 | self.c2w_val = self.poses[val_idx] 258 | self.image_path_val = self.image_paths[val_idx] 259 | 260 | else: # for testing, create a parametric rendering path 261 | if self.split.endswith('train'): # test on training set 262 | self.poses_test = self.poses 263 | elif not self.spheric_poses: 264 | focus_depth = 3.5 # hardcoded, this is numerically close to the formula 265 | # given in the original repo. 
Mathematically if near=1 266 | # and far=infinity, then this number will converge to 4 267 | radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0) 268 | self.poses_test = create_spiral_poses(radii, focus_depth) 269 | else: 270 | radius = 1.1 * self.bounds.min() 271 | self.poses_test = create_spheric_poses(radius) 272 | 273 | def define_transforms(self): 274 | self.transform = T.ToTensor() 275 | 276 | def __len__(self): 277 | if self.split == 'train': 278 | return len(self.all_rays) 279 | if self.split == 'val': 280 | return self.val_num 281 | return len(self.poses_test) 282 | 283 | def __getitem__(self, idx): 284 | if self.split == 'train': # use data in the buffers 285 | sample = {'rays': self.all_rays[idx], 286 | 'rgbs': self.all_rgbs[idx]} 287 | 288 | else: 289 | if self.split == 'val': 290 | c2w = torch.FloatTensor(self.c2w_val) 291 | else: 292 | c2w = torch.FloatTensor(self.poses_test[idx]) 293 | 294 | rays_o, rays_d = get_rays(self.directions, c2w) 295 | if not self.spheric_poses: 296 | near, far = 0, 1 297 | rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0], 298 | self.focal, 1.0, rays_o, rays_d) 299 | else: 300 | near = self.bounds.min() 301 | far = min(8 * near, self.bounds.max()) 302 | 303 | rays = torch.cat([rays_o, rays_d, 304 | near*torch.ones_like(rays_o[:, :1]), 305 | far*torch.ones_like(rays_o[:, :1])], 306 | 1) # (h*w, 8) 307 | 308 | sample = {'rays': rays, 309 | 'c2w': c2w} 310 | 311 | if self.split == 'val': 312 | img = Image.open(self.image_path_val).convert('RGB') 313 | img = img.resize(self.img_wh, Image.LANCZOS) 314 | img = self.transform(img) # (3, h, w) 315 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) 316 | sample['rgbs'] = img 317 | 318 | return sample 319 | -------------------------------------------------------------------------------- /datasets/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | 4 | 5 | def get_ray_directions(H, W, focal): 6 | """ 7 | Get ray directions for all pixels in camera coordinate. 8 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 9 | ray-tracing-generating-camera-rays/standard-coordinate-systems 10 | 11 | Inputs: 12 | H, W, focal: image height, width and focal length 13 | 14 | Outputs: 15 | directions: (H, W, 3), the direction of the rays in camera coordinate 16 | """ 17 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 18 | i, j = grid.unbind(-1) 19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 20 | # see https://github.com/bmild/nerf/issues/24 21 | directions = \ 22 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3) 23 | 24 | return directions 25 | 26 | 27 | def get_rays(directions, c2w): 28 | """ 29 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 
30 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 31 | ray-tracing-generating-camera-rays/standard-coordinate-systems 32 | 33 | Inputs: 34 | directions: (H, W, 3) precomputed ray directions in camera coordinate 35 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 36 | 37 | Outputs: 38 | rays_o: (H*W, 3), the origin of the rays in world coordinate 39 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 40 | """ 41 | # Rotate ray directions from camera coordinate to the world coordinate 42 | rays_d = directions @ c2w[:, :3].T # (H, W, 3) 43 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 44 | # The origin of all rays is the camera origin in world coordinate 45 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3) 46 | 47 | rays_d = rays_d.view(-1, 3) 48 | rays_o = rays_o.view(-1, 3) 49 | 50 | return rays_o, rays_d 51 | 52 | 53 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 54 | """ 55 | Transform rays from world coordinate to NDC. 56 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. 57 | For detailed derivation, please see: 58 | http://www.songho.ca/opengl/gl_projectionmatrix.html 59 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf 60 | 61 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 62 | See https://github.com/bmild/nerf/issues/18 63 | 64 | Inputs: 65 | H, W, focal: image height, width and focal length 66 | near: (N_rays) or float, the depths of the near plane 67 | rays_o: (N_rays, 3), the origin of the rays in world coordinate 68 | rays_d: (N_rays, 3), the direction of the rays in world coordinate 69 | 70 | Outputs: 71 | rays_o: (N_rays, 3), the origin of the rays in NDC 72 | rays_d: (N_rays, 3), the direction of the rays in NDC 73 | """ 74 | # Shift ray origins to near plane 75 | t = -(near + rays_o[...,2]) / rays_d[...,2] 76 | rays_o = rays_o + t[...,None] * rays_d 77 | 78 | # Store some intermediate homogeneous results 79 | ox_oz = rays_o[...,0] / rays_o[...,2] 80 | oy_oz = rays_o[...,1] / rays_o[...,2] 81 | 82 | # Projection 83 | o0 = -1./(W/(2.*focal)) * ox_oz 84 | o1 = -1./(H/(2.*focal)) * oy_oz 85 | o2 = 1. + 2. 
* near / rays_o[...,2] 86 | 87 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz) 88 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz) 89 | d2 = 1 - o2 90 | 91 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) 92 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) 93 | 94 | return rays_o, rays_d -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _site 2 | .sass-cache 3 | .jekyll-cache 4 | .jekyll-metadata 5 | vendor 6 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | gem 'github-pages', group: :jekyll_plugins 3 | -------------------------------------------------------------------------------- /docs/Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | activesupport (>= 6.0.3.1) 5 | concurrent-ruby (~> 1.0, >= 1.0.2) 6 | i18n (>= 0.7, < 2) 7 | minitest (~> 5.1) 8 | tzinfo (~> 1.1) 9 | zeitwerk (~> 2.2, >= 2.2.2) 10 | addressable (2.7.0) 11 | public_suffix (>= 2.0.2, < 5.0) 12 | coffee-script (2.4.1) 13 | coffee-script-source 14 | execjs 15 | coffee-script-source (1.11.1) 16 | colorator (1.1.0) 17 | commonmarker (0.17.13) 18 | ruby-enum (~> 0.5) 19 | concurrent-ruby (1.1.6) 20 | dnsruby (1.61.3) 21 | addressable (~> 2.5) 22 | em-websocket (0.5.1) 23 | eventmachine (>= 0.12.9) 24 | http_parser.rb (~> 0.6.0) 25 | ethon (0.12.0) 26 | ffi (>= 1.3.0) 27 | eventmachine (1.2.7) 28 | execjs (2.7.0) 29 | faraday (1.0.1) 30 | multipart-post (>= 1.2, < 3) 31 | ffi (1.12.2) 32 | forwardable-extended (2.6.0) 33 | gemoji (3.0.1) 34 | github-pages (204) 35 | github-pages-health-check (= 1.16.1) 36 | jekyll (= 3.8.5) 37 | jekyll-avatar (= 0.7.0) 38 | jekyll-coffeescript (= 1.1.1) 39 | jekyll-commonmark-ghpages (= 0.1.6) 40 | jekyll-default-layout (= 0.1.4) 41 | jekyll-feed (= 0.13.0) 42 | jekyll-gist (= 1.5.0) 43 | jekyll-github-metadata (= 2.13.0) 44 | jekyll-mentions (= 1.5.1) 45 | jekyll-optional-front-matter (= 0.3.2) 46 | jekyll-paginate (= 1.1.0) 47 | jekyll-readme-index (= 0.3.0) 48 | jekyll-redirect-from (= 0.15.0) 49 | jekyll-relative-links (= 0.6.1) 50 | jekyll-remote-theme (= 0.4.1) 51 | jekyll-sass-converter (= 1.5.2) 52 | jekyll-seo-tag (= 2.6.1) 53 | jekyll-sitemap (= 1.4.0) 54 | jekyll-swiss (= 1.0.0) 55 | jekyll-theme-architect (= 0.1.1) 56 | jekyll-theme-cayman (= 0.1.1) 57 | jekyll-theme-dinky (= 0.1.1) 58 | jekyll-theme-hacker (= 0.1.1) 59 | jekyll-theme-leap-day (= 0.1.1) 60 | jekyll-theme-merlot (= 0.1.1) 61 | jekyll-theme-midnight (= 0.1.1) 62 | jekyll-theme-minimal (= 0.1.1) 63 | jekyll-theme-modernist (= 0.1.1) 64 | jekyll-theme-primer (= 0.5.4) 65 | jekyll-theme-slate (= 0.1.1) 66 | jekyll-theme-tactile (= 0.1.1) 67 | jekyll-theme-time-machine (= 0.1.1) 68 | jekyll-titles-from-headings (= 0.5.3) 69 | jemoji (= 0.11.1) 70 | kramdown (>= 2.3.0) 71 | liquid (= 4.0.3) 72 | mercenary (~> 0.3) 73 | minima (= 2.5.1) 74 | nokogiri (>= 1.11.0.rc4, < 2.0) 75 | rouge (= 3.13.0) 76 | terminal-table (~> 1.4) 77 | github-pages-health-check (1.16.1) 78 | addressable (~> 2.3) 79 | dnsruby (~> 1.60) 80 | octokit (~> 4.0) 81 | public_suffix (~> 3.0) 82 | typhoeus (~> 1.3) 83 | html-pipeline (2.12.3) 84 | activesupport (>= 2) 85 | nokogiri (>= 1.4) 86 | http_parser.rb (0.6.0) 87 | 
i18n (0.9.5) 88 | concurrent-ruby (~> 1.0) 89 | jekyll (3.8.5) 90 | addressable (~> 2.4) 91 | colorator (~> 1.0) 92 | em-websocket (~> 0.5) 93 | i18n (~> 0.7) 94 | jekyll-sass-converter (~> 1.0) 95 | jekyll-watch (~> 2.0) 96 | kramdown (~> 1.14) 97 | liquid (~> 4.0) 98 | mercenary (~> 0.3.3) 99 | pathutil (~> 0.9) 100 | rouge (>= 1.7, < 4) 101 | safe_yaml (~> 1.0) 102 | jekyll-avatar (0.7.0) 103 | jekyll (>= 3.0, < 5.0) 104 | jekyll-coffeescript (1.1.1) 105 | coffee-script (~> 2.2) 106 | coffee-script-source (~> 1.11.1) 107 | jekyll-commonmark (1.3.1) 108 | commonmarker (~> 0.14) 109 | jekyll (>= 3.7, < 5.0) 110 | jekyll-commonmark-ghpages (0.1.6) 111 | commonmarker (~> 0.17.6) 112 | jekyll-commonmark (~> 1.2) 113 | rouge (>= 2.0, < 4.0) 114 | jekyll-default-layout (0.1.4) 115 | jekyll (~> 3.0) 116 | jekyll-feed (0.13.0) 117 | jekyll (>= 3.7, < 5.0) 118 | jekyll-gist (1.5.0) 119 | octokit (~> 4.2) 120 | jekyll-github-metadata (2.13.0) 121 | jekyll (>= 3.4, < 5.0) 122 | octokit (~> 4.0, != 4.4.0) 123 | jekyll-mentions (1.5.1) 124 | html-pipeline (~> 2.3) 125 | jekyll (>= 3.7, < 5.0) 126 | jekyll-optional-front-matter (0.3.2) 127 | jekyll (>= 3.0, < 5.0) 128 | jekyll-paginate (1.1.0) 129 | jekyll-readme-index (0.3.0) 130 | jekyll (>= 3.0, < 5.0) 131 | jekyll-redirect-from (0.15.0) 132 | jekyll (>= 3.3, < 5.0) 133 | jekyll-relative-links (0.6.1) 134 | jekyll (>= 3.3, < 5.0) 135 | jekyll-remote-theme (0.4.1) 136 | addressable (~> 2.0) 137 | jekyll (>= 3.5, < 5.0) 138 | rubyzip (>= 1.3.0) 139 | jekyll-sass-converter (1.5.2) 140 | sass (~> 3.4) 141 | jekyll-seo-tag (2.6.1) 142 | jekyll (>= 3.3, < 5.0) 143 | jekyll-sitemap (1.4.0) 144 | jekyll (>= 3.7, < 5.0) 145 | jekyll-swiss (1.0.0) 146 | jekyll-theme-architect (0.1.1) 147 | jekyll (~> 3.5) 148 | jekyll-seo-tag (~> 2.0) 149 | jekyll-theme-cayman (0.1.1) 150 | jekyll (~> 3.5) 151 | jekyll-seo-tag (~> 2.0) 152 | jekyll-theme-dinky (0.1.1) 153 | jekyll (~> 3.5) 154 | jekyll-seo-tag (~> 2.0) 155 | jekyll-theme-hacker (0.1.1) 156 | jekyll (~> 3.5) 157 | jekyll-seo-tag (~> 2.0) 158 | jekyll-theme-leap-day (0.1.1) 159 | jekyll (~> 3.5) 160 | jekyll-seo-tag (~> 2.0) 161 | jekyll-theme-merlot (0.1.1) 162 | jekyll (~> 3.5) 163 | jekyll-seo-tag (~> 2.0) 164 | jekyll-theme-midnight (0.1.1) 165 | jekyll (~> 3.5) 166 | jekyll-seo-tag (~> 2.0) 167 | jekyll-theme-minimal (0.1.1) 168 | jekyll (~> 3.5) 169 | jekyll-seo-tag (~> 2.0) 170 | jekyll-theme-modernist (0.1.1) 171 | jekyll (~> 3.5) 172 | jekyll-seo-tag (~> 2.0) 173 | jekyll-theme-primer (0.5.4) 174 | jekyll (> 3.5, < 5.0) 175 | jekyll-github-metadata (~> 2.9) 176 | jekyll-seo-tag (~> 2.0) 177 | jekyll-theme-slate (0.1.1) 178 | jekyll (~> 3.5) 179 | jekyll-seo-tag (~> 2.0) 180 | jekyll-theme-tactile (0.1.1) 181 | jekyll (~> 3.5) 182 | jekyll-seo-tag (~> 2.0) 183 | jekyll-theme-time-machine (0.1.1) 184 | jekyll (~> 3.5) 185 | jekyll-seo-tag (~> 2.0) 186 | jekyll-titles-from-headings (0.5.3) 187 | jekyll (>= 3.3, < 5.0) 188 | jekyll-watch (2.2.1) 189 | listen (~> 3.0) 190 | jemoji (0.11.1) 191 | gemoji (~> 3.0) 192 | html-pipeline (~> 2.2) 193 | jekyll (>= 3.0, < 5.0) 194 | kramdown (>= 2.3.0) 195 | liquid (4.0.3) 196 | listen (3.2.1) 197 | rb-fsevent (~> 0.10, >= 0.10.3) 198 | rb-inotify (~> 0.9, >= 0.9.10) 199 | mercenary (0.3.6) 200 | mini_portile2 (2.4.0) 201 | minima (2.5.1) 202 | jekyll (>= 3.5, < 5.0) 203 | jekyll-feed (~> 0.9) 204 | jekyll-seo-tag (~> 2.1) 205 | minitest (5.14.1) 206 | multipart-post (2.1.1) 207 | nokogiri (1.10.9) 208 | mini_portile2 (~> 2.4.0) 209 | octokit (4.18.0) 210 | faraday 
(>= 0.9) 211 | sawyer (~> 0.8.0, >= 0.5.3) 212 | pathutil (0.16.2) 213 | forwardable-extended (~> 2.6) 214 | public_suffix (3.1.1) 215 | rb-fsevent (0.10.4) 216 | rb-inotify (0.10.1) 217 | ffi (~> 1.0) 218 | rouge (3.13.0) 219 | ruby-enum (0.8.0) 220 | i18n 221 | rubyzip (2.3.0) 222 | safe_yaml (1.0.5) 223 | sass (3.7.4) 224 | sass-listen (~> 4.0.0) 225 | sass-listen (4.0.0) 226 | rb-fsevent (~> 0.9, >= 0.9.4) 227 | rb-inotify (~> 0.9, >= 0.9.7) 228 | sawyer (0.8.2) 229 | addressable (>= 2.3.5) 230 | faraday (> 0.8, < 2.0) 231 | terminal-table (1.8.0) 232 | unicode-display_width (~> 1.1, >= 1.1.1) 233 | thread_safe (0.3.6) 234 | typhoeus (1.4.0) 235 | ethon (>= 0.9.0) 236 | tzinfo (1.2.7) 237 | thread_safe (~> 0.1) 238 | unicode-display_width (1.7.0) 239 | zeitwerk (2.3.0) 240 | 241 | PLATFORMS 242 | ruby 243 | 244 | DEPENDENCIES 245 | github-pages 246 | 247 | BUNDLED WITH 248 | 2.1.4 249 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate 2 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | # Author of this implementation 11 |

If you find my work helpful, please consider starring this project! 12 | Star 13 |

14 | 15 | Quei-An Chen ([kwea123](https://github.com/kwea123)). Original author and photo credits: Ben Mildenhall ([bmild](https://github.com/bmild)) 16 | 17 | # What can NeRF do? 18 | 19 | 20 |
21 | 22 | # 360 degree view synthesis 23 |
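The 360° views in this section are rendered with `eval.py`, which renders every test pose and assembles the frames into `results/<dataset_name>/<scene_name>/<scene_name>.gif` (30 fps). A hypothetical invocation, using only flags defined in `eval.py` (the data path and checkpoint name are placeholders):

```
python eval.py \
    --root_dir /path/to/nerf_synthetic/lego \
    --dataset_name blender --scene_name lego \
    --img_wh 800 800 \
    --ckpt_path ckpts/exp/epoch=15.ckpt
```

Add `--save_depth` if you also want per-view depth maps (saved as `.pfm` by default).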
55 | 56 | # Colored 3D mesh reconstruction (photogrammetry) 57 | We can generate a real colored mesh that allows the object to interact with other physical objects. 58 | 59 | 60 |
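Such a colored mesh can be produced with `extract_color_mesh.py`; the invocation below mirrors the one issued at the end of `extract_mesh.ipynb` (the paths, bounds and thresholds are scene-dependent values that you first tune in that notebook):

```
python extract_color_mesh.py \
    --root_dir /path/to/nerf_synthetic/lego \
    --dataset_name blender --scene_name lego \
    --img_wh 800 800 \
    --ckpt_path ckpts/exp/epoch=15.ckpt \
    --N_grid 256 \
    --x_range -1.2 1.2 --y_range -1.2 1.2 --z_range -1.2 1.2 \
    --sigma_threshold 50 --occ_threshold 0.2
```

The script writes `<scene_name>.ply` with per-vertex colors.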
61 | 62 | # Real time volume rendering in Unity 63 | [Volume rendering](https://en.wikipedia.org/wiki/Volume_rendering) is a technique that doesn't require a "real" object: the model you see here is rendered by marching rays through a volume, so we can cut away parts to reveal internal structures and apply deformation effects in real time. 64 | 65 | 66 |
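The `.vol` file used for the Unity volume rendering demo is produced by a cell in `extract_mesh.ipynb`: for every voxel with alpha > 0 it stores a pair of `uint32` values, the flattened voxel index and the color packed as `R<<24 | G<<16 | B<<8 | A`. A minimal decoding sketch (the function name is ours, for illustration only):

```python
import numpy as np

def decode_vol(path):
    """Read a .vol file written by extract_mesh.ipynb:
    a flat uint32 array of (voxel_index, packed_RGBA) pairs."""
    data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2)
    idx, packed = data[:, 0], data[:, 1]
    rgba = np.stack([(packed >> 24) & 0xFF,   # R
                     (packed >> 16) & 0xFF,   # G
                     (packed >> 8) & 0xFF,    # B
                     packed & 0xFF], -1)      # A (opacity * 255)
    return idx, rgba.astype(np.uint8)
```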
67 | 68 | # Mixed reality in Unity (doesn't work on Firefox, please use Chrome) 69 | Accurate depth allows us to embed virtual objects inside real scenes with the correct z-order. 70 | 71 | 72 |
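Concretely, the compositing is an ordinary per-pixel depth test: given the NeRF depth map (e.g. exported with `eval.py --save_depth`) and the virtual object's depth buffer in the same units, each pixel shows whichever surface is closer. A tiny illustrative sketch (array names are hypothetical; the real compositing happens inside Unity):

```python
import numpy as np

def composite(nerf_rgb, nerf_depth, virt_rgb, virt_depth):
    """Per-pixel z-test: show the virtual object only where it is
    closer to the camera than the reconstructed scene."""
    in_front = virt_depth < nerf_depth            # (H, W) boolean mask
    return np.where(in_front[..., None], virt_rgb, nerf_rgb)
```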
73 | 74 | # Tutorial 75 | 76 | I also have tutorials on how to achieve above results using google colab: 77 | 78 | 79 | 80 | 81 | 82 | # Call for contribution 83 | If you are expert in Unity and know how to make more visual appealing effects for the models shown above, feel free to contact me! I can share my code and data with you, and put your name on my github page! 84 | -------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- 1 | #main_content { 2 | max-width: 960px; 3 | } 4 | 5 | .slick-slide { 6 | margin: 0 10px; 7 | } 8 | 9 | .slick-prev{ 10 | left: -40px; 11 | } 12 | 13 | .slick-prev:before { 14 | font-size: 40px; 15 | color: blueviolet; 16 | } 17 | .slick-next:before { 18 | font-size: 40px; 19 | color: blueviolet; 20 | } 21 | 22 | .slick-dots li button:before{ 23 | font-size: 20px; 24 | line-height: 20px; 25 | } -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | import imageio 7 | from argparse import ArgumentParser 8 | 9 | from models.rendering import render_rays 10 | from models.nerf import * 11 | 12 | from utils import load_ckpt 13 | import metrics 14 | 15 | from datasets import dataset_dict 16 | from datasets.depth_utils import * 17 | 18 | torch.backends.cudnn.benchmark = True 19 | 20 | def get_opts(): 21 | parser = ArgumentParser() 22 | parser.add_argument('--root_dir', type=str, 23 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego', 24 | help='root directory of dataset') 25 | parser.add_argument('--dataset_name', type=str, default='blender', 26 | choices=['blender', 'llff'], 27 | help='which dataset to validate') 28 | parser.add_argument('--scene_name', type=str, default='test', 29 | help='scene name, used as output folder name') 30 | parser.add_argument('--split', type=str, default='test', 31 | help='test or test_train') 32 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800], 33 | help='resolution (img_w, img_h) of the image') 34 | parser.add_argument('--spheric_poses', default=False, action="store_true", 35 | help='whether images are taken in spheric poses (for llff)') 36 | 37 | parser.add_argument('--N_samples', type=int, default=64, 38 | help='number of coarse samples') 39 | parser.add_argument('--N_importance', type=int, default=128, 40 | help='number of additional fine samples') 41 | parser.add_argument('--use_disp', default=False, action="store_true", 42 | help='use disparity depth sampling') 43 | parser.add_argument('--chunk', type=int, default=32*1024*4, 44 | help='chunk size to split the input to avoid OOM') 45 | 46 | parser.add_argument('--ckpt_path', type=str, required=True, 47 | help='pretrained checkpoint path to load') 48 | 49 | parser.add_argument('--save_depth', default=False, action="store_true", 50 | help='whether to save depth prediction') 51 | parser.add_argument('--depth_format', type=str, default='pfm', 52 | choices=['pfm', 'bytes'], 53 | help='which format to save') 54 | 55 | return parser.parse_args() 56 | 57 | 58 | @torch.no_grad() 59 | def batched_inference(models, embeddings, 60 | rays, N_samples, N_importance, use_disp, 61 | chunk, 62 | white_back): 63 | """Do batched inference on rays using chunk.""" 64 | B = rays.shape[0] 65 | chunk = 1024*32 66 | 
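    # Note: as written, this function overrides the `chunk` argument with a fixed
    # 1024*32 and passes the global `dataset.white_back` to render_rays instead of
    # the `white_back` parameter, so those two arguments are effectively unused.
    # The loop below splits the (B, 8) ray tensor into chunks, renders each chunk
    # in test mode, and concatenates the outputs key by key.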
results = defaultdict(list) 67 | for i in range(0, B, chunk): 68 | rendered_ray_chunks = \ 69 | render_rays(models, 70 | embeddings, 71 | rays[i:i+chunk], 72 | N_samples, 73 | use_disp, 74 | 0, 75 | 0, 76 | N_importance, 77 | chunk, 78 | dataset.white_back, 79 | test_time=True) 80 | 81 | for k, v in rendered_ray_chunks.items(): 82 | results[k] += [v] 83 | 84 | for k, v in results.items(): 85 | results[k] = torch.cat(v, 0) 86 | return results 87 | 88 | 89 | if __name__ == "__main__": 90 | args = get_opts() 91 | w, h = args.img_wh 92 | 93 | kwargs = {'root_dir': args.root_dir, 94 | 'split': args.split, 95 | 'img_wh': tuple(args.img_wh)} 96 | if args.dataset_name == 'llff': 97 | kwargs['spheric_poses'] = args.spheric_poses 98 | dataset = dataset_dict[args.dataset_name](**kwargs) 99 | 100 | embedding_xyz = Embedding(3, 10) 101 | embedding_dir = Embedding(3, 4) 102 | nerf_coarse = NeRF() 103 | nerf_fine = NeRF() 104 | load_ckpt(nerf_coarse, args.ckpt_path, model_name='nerf_coarse') 105 | load_ckpt(nerf_fine, args.ckpt_path, model_name='nerf_fine') 106 | nerf_coarse.cuda().eval() 107 | nerf_fine.cuda().eval() 108 | 109 | models = [nerf_coarse, nerf_fine] 110 | embeddings = [embedding_xyz, embedding_dir] 111 | 112 | imgs = [] 113 | psnrs = [] 114 | dir_name = f'results/{args.dataset_name}/{args.scene_name}' 115 | os.makedirs(dir_name, exist_ok=True) 116 | 117 | for i in tqdm(range(len(dataset))): 118 | sample = dataset[i] 119 | rays = sample['rays'].cuda() 120 | results = batched_inference(models, embeddings, rays, 121 | args.N_samples, args.N_importance, args.use_disp, 122 | args.chunk, 123 | dataset.white_back) 124 | 125 | img_pred = results['rgb_fine'].view(h, w, 3).cpu().numpy() 126 | 127 | if args.save_depth: 128 | depth_pred = results['depth_fine'].view(h, w).cpu().numpy() 129 | depth_pred = np.nan_to_num(depth_pred) 130 | if args.depth_format == 'pfm': 131 | save_pfm(os.path.join(dir_name, f'depth_{i:03d}.pfm'), depth_pred) 132 | else: 133 | with open(f'depth_{i:03d}', 'wb') as f: 134 | f.write(depth_pred.tobytes()) 135 | 136 | img_pred_ = (img_pred*255).astype(np.uint8) 137 | imgs += [img_pred_] 138 | imageio.imwrite(os.path.join(dir_name, f'{i:03d}.png'), img_pred_) 139 | 140 | if 'rgbs' in sample: 141 | rgbs = sample['rgbs'] 142 | img_gt = rgbs.view(h, w, 3) 143 | psnrs += [metrics.psnr(img_gt, img_pred).item()] 144 | 145 | imageio.mimsave(os.path.join(dir_name, f'{args.scene_name}.gif'), imgs, fps=30) 146 | 147 | if psnrs: 148 | mean_psnr = np.mean(psnrs) 149 | print(f'Mean PSNR : {mean_psnr:.2f}') -------------------------------------------------------------------------------- /extract_color_mesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import mcubes 9 | import open3d as o3d 10 | from plyfile import PlyData, PlyElement 11 | from argparse import ArgumentParser 12 | 13 | from models.rendering import * 14 | from models.nerf import * 15 | 16 | from utils import load_ckpt 17 | 18 | from datasets import dataset_dict 19 | 20 | torch.backends.cudnn.benchmark = True 21 | 22 | def get_opts(): 23 | parser = ArgumentParser() 24 | parser.add_argument('--root_dir', type=str, 25 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego', 26 | help='root directory of dataset') 27 | parser.add_argument('--dataset_name', type=str, default='blender', 28 | choices=['blender', 'llff'], 29 | 
help='which dataset to validate') 30 | parser.add_argument('--scene_name', type=str, default='test', 31 | help='scene name, used as output ply filename') 32 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800], 33 | help='resolution (img_w, img_h) of the image') 34 | 35 | parser.add_argument('--N_samples', type=int, default=64, 36 | help='number of samples to infer the acculmulated opacity') 37 | parser.add_argument('--chunk', type=int, default=32*1024, 38 | help='chunk size to split the input to avoid OOM') 39 | parser.add_argument('--ckpt_path', type=str, required=True, 40 | help='pretrained checkpoint path to load') 41 | 42 | parser.add_argument('--N_grid', type=int, default=256, 43 | help='size of the grid on 1 side, larger=higher resolution') 44 | parser.add_argument('--x_range', nargs="+", type=float, default=[-1.0, 1.0], 45 | help='x range of the object') 46 | parser.add_argument('--y_range', nargs="+", type=float, default=[-1.0, 1.0], 47 | help='x range of the object') 48 | parser.add_argument('--z_range', nargs="+", type=float, default=[-1.0, 1.0], 49 | help='x range of the object') 50 | parser.add_argument('--sigma_threshold', type=float, default=20.0, 51 | help='threshold to consider a location is occupied') 52 | parser.add_argument('--occ_threshold', type=float, default=0.2, 53 | help='''threshold to consider a vertex is occluded. 54 | larger=fewer occluded pixels''') 55 | 56 | #### method using vertex normals #### 57 | parser.add_argument('--use_vertex_normal', action="store_true", 58 | help='use vertex normals to compute color') 59 | parser.add_argument('--N_importance', type=int, default=64, 60 | help='number of fine samples to infer the acculmulated opacity') 61 | parser.add_argument('--near_t', type=float, default=1.0, 62 | help='the near bound factor to start the ray') 63 | 64 | return parser.parse_args() 65 | 66 | 67 | @torch.no_grad() 68 | def f(models, embeddings, rays, N_samples, N_importance, chunk, white_back): 69 | """Do batched inference on rays using chunk.""" 70 | B = rays.shape[0] 71 | results = defaultdict(list) 72 | for i in range(0, B, chunk): 73 | rendered_ray_chunks = \ 74 | render_rays(models, 75 | embeddings, 76 | rays[i:i+chunk], 77 | N_samples, 78 | False, 79 | 0, 80 | 0, 81 | N_importance, 82 | chunk, 83 | white_back, 84 | test_time=True) 85 | 86 | for k, v in rendered_ray_chunks.items(): 87 | results[k] += [v] 88 | 89 | for k, v in results.items(): 90 | results[k] = torch.cat(v, 0) 91 | return results 92 | 93 | 94 | if __name__ == "__main__": 95 | args = get_opts() 96 | 97 | kwargs = {'root_dir': args.root_dir, 98 | 'img_wh': tuple(args.img_wh)} 99 | if args.dataset_name == 'llff': 100 | kwargs['spheric_poses'] = True 101 | kwargs['split'] = 'test' 102 | else: 103 | kwargs['split'] = 'train' 104 | dataset = dataset_dict[args.dataset_name](**kwargs) 105 | 106 | embedding_xyz = Embedding(3, 10) 107 | embedding_dir = Embedding(3, 4) 108 | embeddings = [embedding_xyz, embedding_dir] 109 | nerf_fine = NeRF() 110 | load_ckpt(nerf_fine, args.ckpt_path, model_name='nerf_fine') 111 | nerf_fine.cuda().eval() 112 | 113 | # define the dense grid for query 114 | N = args.N_grid 115 | xmin, xmax = args.x_range 116 | ymin, ymax = args.y_range 117 | zmin, zmax = args.z_range 118 | # assert xmax-xmin == ymax-ymin == zmax-zmin, 'the ranges must have the same length!' 
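    # The steps below build a dense N x N x N grid of xyz locations spanning the given
    # ranges, query the fine NeRF for sigma at every location in chunks (the viewing
    # direction is set to zero, since sigma does not depend on it), and finally run
    # marching cubes on the resulting density volume with --sigma_threshold to obtain
    # the initial, colorless mesh.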
119 | x = np.linspace(xmin, xmax, N) 120 | y = np.linspace(ymin, ymax, N) 121 | z = np.linspace(zmin, zmax, N) 122 | 123 | xyz_ = torch.FloatTensor(np.stack(np.meshgrid(x, y, z), -1).reshape(-1, 3)).cuda() 124 | dir_ = torch.zeros_like(xyz_).cuda() 125 | # sigma is independent of direction, so any value here will produce the same result 126 | 127 | # predict sigma (occupancy) for each grid location 128 | print('Predicting occupancy ...') 129 | with torch.no_grad(): 130 | B = xyz_.shape[0] 131 | out_chunks = [] 132 | for i in tqdm(range(0, B, args.chunk)): 133 | xyz_embedded = embedding_xyz(xyz_[i:i+args.chunk]) # (N, embed_xyz_channels) 134 | dir_embedded = embedding_dir(dir_[i:i+args.chunk]) # (N, embed_dir_channels) 135 | xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1) 136 | out_chunks += [nerf_fine(xyzdir_embedded)] 137 | rgbsigma = torch.cat(out_chunks, 0) 138 | 139 | sigma = rgbsigma[:, -1].cpu().numpy() 140 | sigma = np.maximum(sigma, 0).reshape(N, N, N) 141 | 142 | # perform marching cube algorithm to retrieve vertices and triangle mesh 143 | print('Extracting mesh ...') 144 | vertices, triangles = mcubes.marching_cubes(sigma, args.sigma_threshold) 145 | 146 | ##### Until mesh extraction here, it is the same as the original repo. ###### 147 | 148 | vertices_ = (vertices/N).astype(np.float32) 149 | ## invert x and y coordinates (WHY? maybe because of the marching cubes algo) 150 | x_ = (ymax-ymin) * vertices_[:, 1] + ymin 151 | y_ = (xmax-xmin) * vertices_[:, 0] + xmin 152 | vertices_[:, 0] = x_ 153 | vertices_[:, 1] = y_ 154 | vertices_[:, 2] = (zmax-zmin) * vertices_[:, 2] + zmin 155 | vertices_.dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 156 | 157 | face = np.empty(len(triangles), dtype=[('vertex_indices', 'i4', (3,))]) 158 | face['vertex_indices'] = triangles 159 | 160 | PlyData([PlyElement.describe(vertices_[:, 0], 'vertex'), 161 | PlyElement.describe(face, 'face')]).write(f'{args.scene_name}.ply') 162 | 163 | # remove noise in the mesh by keeping only the biggest cluster 164 | print('Removing noise ...') 165 | mesh = o3d.io.read_triangle_mesh(f"{args.scene_name}.ply") 166 | idxs, count, _ = mesh.cluster_connected_triangles() 167 | max_cluster_idx = np.argmax(count) 168 | triangles_to_remove = [i for i in range(len(face)) if idxs[i] != max_cluster_idx] 169 | mesh.remove_triangles_by_index(triangles_to_remove) 170 | mesh.remove_unreferenced_vertices() 171 | print(f'Mesh has {len(mesh.vertices)/1e6:.2f} M vertices and {len(mesh.triangles)/1e6:.2f} M faces.') 172 | 173 | vertices_ = np.asarray(mesh.vertices).astype(np.float32) 174 | triangles = np.asarray(mesh.triangles) 175 | 176 | # perform color prediction 177 | # Step 0. define constants (image width, height and intrinsics) 178 | W, H = args.img_wh 179 | K = np.array([[dataset.focal, 0, W/2], 180 | [0, dataset.focal, H/2], 181 | [0, 0, 1]]).astype(np.float32) 182 | 183 | # Step 1. transform vertices into world coordinate 184 | N_vertices = len(vertices_) 185 | vertices_homo = np.concatenate([vertices_, np.ones((N_vertices, 1))], 1) # (N, 4) 186 | 187 | if args.use_vertex_normal: ## use normal vector method as suggested by the author. 
188 | ## see https://github.com/bmild/nerf/issues/44 189 | mesh.compute_vertex_normals() 190 | rays_d = torch.FloatTensor(np.asarray(mesh.vertex_normals)) 191 | near = dataset.bounds.min() * torch.ones_like(rays_d[:, :1]) 192 | far = dataset.bounds.max() * torch.ones_like(rays_d[:, :1]) 193 | rays_o = torch.FloatTensor(vertices_) - rays_d * near * args.near_t 194 | 195 | nerf_coarse = NeRF() 196 | load_ckpt(nerf_coarse, args.ckpt_path, model_name='nerf_coarse') 197 | nerf_coarse.cuda().eval() 198 | 199 | results = f([nerf_coarse, nerf_fine], embeddings, 200 | torch.cat([rays_o, rays_d, near, far], 1).cuda(), 201 | args.N_samples, 202 | args.N_importance, 203 | args.chunk, 204 | dataset.white_back) 205 | 206 | else: ## use my color average method. see README_mesh.md 207 | ## buffers to store the final averaged color 208 | non_occluded_sum = np.zeros((N_vertices, 1)) 209 | v_color_sum = np.zeros((N_vertices, 3)) 210 | 211 | # Step 2. project the vertices onto each training image to infer the color 212 | print('Fusing colors ...') 213 | for idx in tqdm(range(len(dataset.image_paths))): 214 | ## read image of this pose 215 | image = Image.open(dataset.image_paths[idx]).convert('RGB') 216 | image = image.resize(tuple(args.img_wh), Image.LANCZOS) 217 | image = np.array(image) 218 | 219 | ## read the camera to world relative pose 220 | P_c2w = np.concatenate([dataset.poses[idx], np.array([0, 0, 0, 1]).reshape(1, 4)], 0) 221 | P_w2c = np.linalg.inv(P_c2w)[:3] # (3, 4) 222 | ## project vertices from world coordinate to camera coordinate 223 | vertices_cam = (P_w2c @ vertices_homo.T) # (3, N) in "right up back" 224 | vertices_cam[1:] *= -1 # (3, N) in "right down forward" 225 | ## project vertices from camera coordinate to pixel coordinate 226 | vertices_image = (K @ vertices_cam).T # (N, 3) 227 | depth = vertices_image[:, -1:]+1e-5 # the depth of the vertices, used as far plane 228 | vertices_image = vertices_image[:, :2]/depth 229 | vertices_image = vertices_image.astype(np.float32) 230 | vertices_image[:, 0] = np.clip(vertices_image[:, 0], 0, W-1) 231 | vertices_image[:, 1] = np.clip(vertices_image[:, 1], 0, H-1) 232 | 233 | ## compute the color on these projected pixel coordinates 234 | ## using bilinear interpolation. 235 | ## NOTE: opencv's implementation has a size limit of 32768 pixels per side, 236 | ## so we split the input into chunks. 237 | colors = [] 238 | remap_chunk = int(3e4) 239 | for i in range(0, N_vertices, remap_chunk): 240 | colors += [cv2.remap(image, 241 | vertices_image[i:i+remap_chunk, 0], 242 | vertices_image[i:i+remap_chunk, 1], 243 | interpolation=cv2.INTER_LINEAR)[:, 0]] 244 | colors = np.vstack(colors) # (N_vertices, 3) 245 | 246 | ## predict occlusion of each vertex 247 | ## we leverage the concept of NeRF by constructing rays coming out from the camera 248 | ## and hitting each vertex; by computing the accumulated opacity along this path, 249 | ## we can know if the vertex is occluded or not. 250 | ## for vertices that appear to be occluded from every input view, we make the 251 | ## assumption that its color is the same as its neighbors that are facing our side. 
252 | ## (think of a surface with one side facing us: we assume the other side has the same color) 253 | 254 | ## ray's origin is camera origin 255 | rays_o = torch.FloatTensor(dataset.poses[idx][:, -1]).expand(N_vertices, 3) 256 | ## ray's direction is the vector pointing from camera origin to the vertices 257 | rays_d = torch.FloatTensor(vertices_) - rays_o # (N_vertices, 3) 258 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 259 | near = dataset.bounds.min() * torch.ones_like(rays_o[:, :1]) 260 | ## the far plane is the depth of the vertices, since what we want is the accumulated 261 | ## opacity along the path from camera origin to the vertices 262 | far = torch.FloatTensor(depth) * torch.ones_like(rays_o[:, :1]) 263 | results = f([nerf_fine], embeddings, 264 | torch.cat([rays_o, rays_d, near, far], 1).cuda(), 265 | args.N_samples, 266 | 0, 267 | args.chunk, 268 | dataset.white_back) 269 | opacity = results['opacity_coarse'].cpu().numpy()[:, np.newaxis] # (N_vertices, 1) 270 | opacity = np.nan_to_num(opacity, 1) 271 | 272 | non_occluded = np.ones_like(non_occluded_sum) * 0.1/depth # weight by inverse depth 273 | # near=more confident in color 274 | non_occluded += opacity < args.occ_threshold 275 | 276 | v_color_sum += colors * non_occluded 277 | non_occluded_sum += non_occluded 278 | 279 | # Step 3. combine the output and write to file 280 | if args.use_vertex_normal: 281 | v_colors = results['rgb_fine'].cpu().numpy() * 255.0 282 | else: ## the combined color is the average color among all views 283 | v_colors = v_color_sum/non_occluded_sum 284 | v_colors = v_colors.astype(np.uint8) 285 | v_colors.dtype = [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 286 | vertices_.dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 287 | vertex_all = np.empty(N_vertices, vertices_.dtype.descr+v_colors.dtype.descr) 288 | for prop in vertices_.dtype.names: 289 | vertex_all[prop] = vertices_[prop][:, 0] 290 | for prop in v_colors.dtype.names: 291 | vertex_all[prop] = v_colors[prop][:, 0] 292 | 293 | face = np.empty(len(triangles), dtype=[('vertex_indices', 'i4', (3,))]) 294 | face['vertex_indices'] = triangles 295 | 296 | PlyData([PlyElement.describe(vertex_all, 'vertex'), 297 | PlyElement.describe(face, 'face')]).write(f'{args.scene_name}.ply') 298 | 299 | print('Done!') 300 | -------------------------------------------------------------------------------- /extract_mesh.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from collections import defaultdict\n", 11 | "import numpy as np\n", 12 | "import mcubes\n", 13 | "import trimesh\n", 14 | "\n", 15 | "from models.rendering import *\n", 16 | "from models.nerf import *\n", 17 | "\n", 18 | "from datasets import dataset_dict\n", 19 | "\n", 20 | "from utils import load_ckpt" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Load model and data" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# Change here #\n", 37 | "img_wh = (800, 800) # full resolution of the input images\n", 38 | "dataset_name = 'blender' # blender or llff (own data)\n", 39 | "scene_name = 'lego' # whatever you want\n", 40 | "root_dir = '/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego/' # the folder containing data\n", 41 | 
"ckpt_path = 'ckpts/exp2/epoch=05.ckpt' # the model path\n", 42 | "###############\n", 43 | "\n", 44 | "kwargs = {'root_dir': root_dir,\n", 45 | " 'img_wh': img_wh}\n", 46 | "if dataset_name == 'llff':\n", 47 | " kwargs['spheric_poses'] = True\n", 48 | " kwargs['split'] = 'test'\n", 49 | "else:\n", 50 | " kwargs['split'] = 'train'\n", 51 | " \n", 52 | "chunk = 1024*32\n", 53 | "dataset = dataset_dict[dataset_name](**kwargs)\n", 54 | "\n", 55 | "embedding_xyz = Embedding(3, 10)\n", 56 | "embedding_dir = Embedding(3, 4)\n", 57 | "\n", 58 | "nerf_fine = NeRF()\n", 59 | "load_ckpt(nerf_fine, ckpt_path, model_name='nerf_fine')\n", 60 | "nerf_fine.cuda().eval();" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "# Search for tight bounds of the object (trial and error!)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "### Tune these parameters until the whole object lies tightly in range with little noise ###\n", 77 | "N = 128 # controls the resolution, set this number small here because we're only finding\n", 78 | " # good ranges here, not yet for mesh reconstruction; we can set this number high\n", 79 | " # when it comes to final reconstruction.\n", 80 | "xmin, xmax = -1.2, 1.2 # left/right range\n", 81 | "ymin, ymax = -1.2, 1.2 # forward/backward range\n", 82 | "zmin, zmax = -1.2, 1.2 # up/down range\n", 83 | "## Attention! the ranges MUST have the same length!\n", 84 | "sigma_threshold = 50. # controls the noise (lower=maybe more noise; higher=some mesh might be missing)\n", 85 | "############################################################################################\n", 86 | "\n", 87 | "x = np.linspace(xmin, xmax, N)\n", 88 | "y = np.linspace(ymin, ymax, N)\n", 89 | "z = np.linspace(zmin, zmax, N)\n", 90 | "\n", 91 | "xyz_ = torch.FloatTensor(np.stack(np.meshgrid(x, y, z), -1).reshape(-1, 3)).cuda()\n", 92 | "dir_ = torch.zeros_like(xyz_).cuda()\n", 93 | "\n", 94 | "with torch.no_grad():\n", 95 | " B = xyz_.shape[0]\n", 96 | " out_chunks = []\n", 97 | " for i in range(0, B, chunk):\n", 98 | " xyz_embedded = embedding_xyz(xyz_[i:i+chunk]) # (N, embed_xyz_channels)\n", 99 | " dir_embedded = embedding_dir(dir_[i:i+chunk]) # (N, embed_dir_channels)\n", 100 | " xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1)\n", 101 | " out_chunks += [nerf_fine(xyzdir_embedded)]\n", 102 | " rgbsigma = torch.cat(out_chunks, 0)\n", 103 | " \n", 104 | "sigma = rgbsigma[:, -1].cpu().numpy()\n", 105 | "sigma = np.maximum(sigma, 0)\n", 106 | "sigma = sigma.reshape(N, N, N)\n", 107 | "\n", 108 | "# The below lines are for visualization, COMMENT OUT once you find the best range and increase N!\n", 109 | "vertices, triangles = mcubes.marching_cubes(sigma, sigma_threshold)\n", 110 | "mesh = trimesh.Trimesh(vertices/N, triangles)\n", 111 | "mesh.show()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# # You can already export \"colorless\" mesh if you don't need color\n", 121 | "# mcubes.export_mesh(vertices, triangles, f\"{scene_name}.dae\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "# Generate .vol file for volume rendering in Unity" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "assert N==512, \\\n", 
138 | " 'Please set N to 512 in the two above cell! Remember to comment out the visualization code (last 3 lines)!'\n", 139 | "\n", 140 | "a = 1-np.exp(-(xmax-xmin)/N*sigma)\n", 141 | "a = a.flatten()\n", 142 | "rgb = (rgbsigma[:, :3].numpy()*255).astype(np.uint32)\n", 143 | "i = np.where(a>0)[0] # valid indices (alpha>0)\n", 144 | "\n", 145 | "rgb = rgb[i]\n", 146 | "a = a[i]\n", 147 | "s = rgb.dot(np.array([1<<24, 1<<16, 1<<8])) + (a*255).astype(np.uint32)\n", 148 | "res = np.stack([i, s], -1).astype(np.uint32).flatten()\n", 149 | "with open(f'{scene_name}.vol', 'wb') as f:\n", 150 | " f.write(res.tobytes())" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "# Extract colored mesh\n", 158 | "\n", 159 | "Once you find the best range, now **RESTART** the notebook, and copy the configs to the following cell\n", 160 | "and execute it." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# Copy the variables you have above here! ####\n", 170 | "img_wh = (800, 800) # full resolution of the input images\n", 171 | "dataset_name = 'blender' # blender or llff (own data)\n", 172 | "scene_name = 'lego' # whatever you want\n", 173 | "root_dir = '/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego/' # the folder containing data\n", 174 | "ckpt_path = 'ckpts/exp2/epoch=05.ckpt' # the model path\n", 175 | "\n", 176 | "N = 128 # controls the resolution, set this number small here because we're only finding\n", 177 | " # good ranges here, not yet for mesh reconstruction; we can set this number high\n", 178 | " # when it comes to final reconstruction.\n", 179 | "xmin, xmax = -1.2, 1.2 # left/right range\n", 180 | "ymin, ymax = -1.2, 1.2 # forward/backward range\n", 181 | "zmin, zmax = -1.2, 1.2 # up/down range\n", 182 | "sigma_threshold = 50. # controls the noise (lower=maybe more noise; higher=some mesh might be missing)\n", 183 | "###############################################\n", 184 | "\n", 185 | "import os\n", 186 | "os.environ['ROOT_DIR'] = root_dir\n", 187 | "os.environ['DATASET_NAME'] = dataset_name\n", 188 | "os.environ['SCENE_NAME'] = scene_name\n", 189 | "os.environ['IMG_SIZE'] = f\"{img_wh[0]} {img_wh[1]}\"\n", 190 | "os.environ['CKPT_PATH'] = ckpt_path\n", 191 | "os.environ['N_GRID'] = \"256\" # final resolution. You can set this number high to preserve more details\n", 192 | "os.environ['X_RANGE'] = f\"{xmin} {xmax}\"\n", 193 | "os.environ['Y_RANGE'] = f\"{ymin} {ymax}\"\n", 194 | "os.environ['Z_RANGE'] = f\"{zmin} {zmax}\"\n", 195 | "os.environ['SIGMA_THRESHOLD'] = str(sigma_threshold)\n", 196 | "os.environ['OCC_THRESHOLD'] = \"0.2\" # probably doesn't require tuning. 
If you find the color is not close\n", 197 | " # to real, try to set this number smaller (the effect of this number\n", 198 | " # is explained in my youtube video)\n", 199 | "\n", 200 | "!python extract_color_mesh.py \\\n", 201 | " --root_dir $ROOT_DIR \\\n", 202 | " --dataset_name $DATASET_NAME \\\n", 203 | " --scene_name $SCENE_NAME \\\n", 204 | " --img_wh $IMG_SIZE \\\n", 205 | " --ckpt_path $CKPT_PATH \\\n", 206 | " --N_grid $N_GRID \\\n", 207 | " --x_range $X_RANGE \\\n", 208 | " --y_range $Y_RANGE \\\n", 209 | " --z_range $Z_RANGE \\\n", 210 | " --sigma_threshold $SIGMA_THRESHOLD \\\n", 211 | " --occ_threshold $OCC_THRESHOLD" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "nerf_pl", 225 | "language": "python", 226 | "name": "nerf_pl" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.6.10" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 4 243 | } 244 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class MSELoss(nn.Module): 5 | def __init__(self): 6 | super(MSELoss, self).__init__() 7 | self.loss = nn.MSELoss(reduction='mean') 8 | 9 | def forward(self, inputs, targets): 10 | loss = self.loss(inputs['rgb_coarse'], targets) 11 | if 'rgb_fine' in inputs: 12 | loss += self.loss(inputs['rgb_fine'], targets) 13 | 14 | return loss 15 | 16 | 17 | loss_dict = {'mse': MSELoss} -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 13 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 14 | 15 | def ssim(image_pred, image_gt, reduction='mean'): 16 | """ 17 | image_pred and image_gt: (1, 3, H, W) 18 | """ 19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1] 20 | return 1-2*dssim_ # in [-1, 1] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwea123/nerf_pl/52aeb387da64a9ad9a0f914ea9b049ffc598b20c/models/__init__.py -------------------------------------------------------------------------------- /models/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Embedding(nn.Module): 5 | def __init__(self, in_channels, N_freqs, logscale=True): 6 | """ 7 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 
8 | in_channels: number of input channels (3 for both xyz and direction) 9 | """ 10 | super(Embedding, self).__init__() 11 | self.N_freqs = N_freqs 12 | self.in_channels = in_channels 13 | self.funcs = [torch.sin, torch.cos] 14 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) 15 | 16 | if logscale: 17 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) 18 | else: 19 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) 20 | 21 | def forward(self, x): 22 | """ 23 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 24 | Different from the paper, "x" is also in the output 25 | See https://github.com/bmild/nerf/issues/12 26 | 27 | Inputs: 28 | x: (B, self.in_channels) 29 | 30 | Outputs: 31 | out: (B, self.out_channels) 32 | """ 33 | out = [x] 34 | for freq in self.freq_bands: 35 | for func in self.funcs: 36 | out += [func(freq*x)] 37 | 38 | return torch.cat(out, -1) 39 | 40 | 41 | class NeRF(nn.Module): 42 | def __init__(self, 43 | D=8, W=256, 44 | in_channels_xyz=63, in_channels_dir=27, 45 | skips=[4]): 46 | """ 47 | D: number of layers for density (sigma) encoder 48 | W: number of hidden units in each layer 49 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default) 50 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) 51 | skips: add skip connection in the Dth layer 52 | """ 53 | super(NeRF, self).__init__() 54 | self.D = D 55 | self.W = W 56 | self.in_channels_xyz = in_channels_xyz 57 | self.in_channels_dir = in_channels_dir 58 | self.skips = skips 59 | 60 | # xyz encoding layers 61 | for i in range(D): 62 | if i == 0: 63 | layer = nn.Linear(in_channels_xyz, W) 64 | elif i in skips: 65 | layer = nn.Linear(W+in_channels_xyz, W) 66 | else: 67 | layer = nn.Linear(W, W) 68 | layer = nn.Sequential(layer, nn.ReLU(True)) 69 | setattr(self, f"xyz_encoding_{i+1}", layer) 70 | self.xyz_encoding_final = nn.Linear(W, W) 71 | 72 | # direction encoding layers 73 | self.dir_encoding = nn.Sequential( 74 | nn.Linear(W+in_channels_dir, W//2), 75 | nn.ReLU(True)) 76 | 77 | # output layers 78 | self.sigma = nn.Linear(W, 1) 79 | self.rgb = nn.Sequential( 80 | nn.Linear(W//2, 3), 81 | nn.Sigmoid()) 82 | 83 | def forward(self, x, sigma_only=False): 84 | """ 85 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 86 | For rendering this ray, please see rendering.py 87 | 88 | Inputs: 89 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 90 | the embedded vector of position and direction 91 | sigma_only: whether to infer sigma only. 
If True, 92 | x is of shape (B, self.in_channels_xyz) 93 | 94 | Outputs: 95 | if sigma_ony: 96 | sigma: (B, 1) sigma 97 | else: 98 | out: (B, 4), rgb and sigma 99 | """ 100 | if not sigma_only: 101 | input_xyz, input_dir = \ 102 | torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1) 103 | else: 104 | input_xyz = x 105 | 106 | xyz_ = input_xyz 107 | for i in range(self.D): 108 | if i in self.skips: 109 | xyz_ = torch.cat([input_xyz, xyz_], -1) 110 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 111 | 112 | sigma = self.sigma(xyz_) 113 | if sigma_only: 114 | return sigma 115 | 116 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 117 | 118 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1) 119 | dir_encoding = self.dir_encoding(dir_encoding_input) 120 | rgb = self.rgb(dir_encoding) 121 | 122 | out = torch.cat([rgb, sigma], -1) 123 | 124 | return out -------------------------------------------------------------------------------- /models/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsearchsorted import searchsorted 3 | 4 | __all__ = ['render_rays'] 5 | 6 | """ 7 | Function dependencies: (-> means function calls) 8 | 9 | @render_rays -> @inference 10 | 11 | @render_rays -> @sample_pdf if there is fine model 12 | """ 13 | 14 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): 15 | """ 16 | Sample @N_importance samples from @bins with distribution defined by @weights. 17 | 18 | Inputs: 19 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" 20 | weights: (N_rays, N_samples_) 21 | N_importance: the number of samples to draw from the distribution 22 | det: deterministic or not 23 | eps: a small number to prevent division by zero 24 | 25 | Outputs: 26 | samples: the sampled samples 27 | """ 28 | N_rays, N_samples_ = weights.shape 29 | weights = weights + eps # prevent division by zero (don't do inplace op!) 
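    # Inverse-transform sampling: normalize the weights into a pdf over the bins,
    # cumulatively sum it into a cdf (padded with a leading zero so it spans 0~1),
    # draw uniform samples u (evenly spaced if det=True), locate the cdf interval
    # containing each u with searchsorted, and place the sample by linear
    # interpolation between the two enclosing bin edges.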
30 | pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) 31 | cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function 32 | cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) 33 | # padded to 0~1 inclusive 34 | 35 | if det: 36 | u = torch.linspace(0, 1, N_importance, device=bins.device) 37 | u = u.expand(N_rays, N_importance) 38 | else: 39 | u = torch.rand(N_rays, N_importance, device=bins.device) 40 | u = u.contiguous() 41 | 42 | inds = searchsorted(cdf, u, side='right') 43 | below = torch.clamp_min(inds-1, 0) 44 | above = torch.clamp_max(inds, N_samples_) 45 | 46 | inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) 47 | cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) 48 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) 49 | 50 | denom = cdf_g[...,1]-cdf_g[...,0] 51 | denom[denom 0: # perturb sampling depths (z_vals) 198 | z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) # (N_rays, N_samples-1) interval mid points 199 | # get intervals between samples 200 | upper = torch.cat([z_vals_mid, z_vals[: ,-1:]], -1) 201 | lower = torch.cat([z_vals[: ,:1], z_vals_mid], -1) 202 | 203 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 204 | z_vals = lower + (upper - lower) * perturb_rand 205 | 206 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 207 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 208 | 209 | if test_time: 210 | weights_coarse = \ 211 | inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d, 212 | dir_embedded, z_vals, weights_only=True) 213 | result = {'opacity_coarse': weights_coarse.sum(1)} 214 | else: 215 | rgb_coarse, depth_coarse, weights_coarse = \ 216 | inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d, 217 | dir_embedded, z_vals, weights_only=False) 218 | result = {'rgb_coarse': rgb_coarse, 219 | 'depth_coarse': depth_coarse, 220 | 'opacity_coarse': weights_coarse.sum(1) 221 | } 222 | 223 | if N_importance > 0: # sample points for fine model 224 | z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) # (N_rays, N_samples-1) interval mid points 225 | z_vals_ = sample_pdf(z_vals_mid, weights_coarse[:, 1:-1], 226 | N_importance, det=(perturb==0)).detach() 227 | # detach so that grad doesn't propogate to weights_coarse from here 228 | 229 | z_vals, _ = torch.sort(torch.cat([z_vals, z_vals_], -1), -1) 230 | 231 | xyz_fine_sampled = rays_o.unsqueeze(1) + \ 232 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) 233 | # (N_rays, N_samples+N_importance, 3) 234 | 235 | model_fine = models[1] 236 | rgb_fine, depth_fine, weights_fine = \ 237 | inference(model_fine, embedding_xyz, xyz_fine_sampled, rays_d, 238 | dir_embedded, z_vals, weights_only=False) 239 | 240 | result['rgb_fine'] = rgb_fine 241 | result['depth_fine'] = depth_fine 242 | result['opacity_fine'] = weights_fine.sum(1) 243 | 244 | return result 245 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('--root_dir', type=str, 7 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego', 8 | help='root directory of dataset') 9 | parser.add_argument('--dataset_name', type=str, default='blender', 10 | choices=['blender', 'llff'], 11 | help='which dataset to train/val') 12 | 
parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800], 13 | help='resolution (img_w, img_h) of the image') 14 | parser.add_argument('--spheric_poses', default=False, action="store_true", 15 | help='whether images are taken in spheric poses (for llff)') 16 | 17 | parser.add_argument('--N_samples', type=int, default=64, 18 | help='number of coarse samples') 19 | parser.add_argument('--N_importance', type=int, default=128, 20 | help='number of additional fine samples') 21 | parser.add_argument('--use_disp', default=False, action="store_true", 22 | help='use disparity depth sampling') 23 | parser.add_argument('--perturb', type=float, default=1.0, 24 | help='factor to perturb depth sampling points') 25 | parser.add_argument('--noise_std', type=float, default=1.0, 26 | help='std dev of noise added to regularize sigma') 27 | 28 | parser.add_argument('--loss_type', type=str, default='mse', 29 | choices=['mse'], 30 | help='loss to use') 31 | 32 | parser.add_argument('--batch_size', type=int, default=1024, 33 | help='batch size') 34 | parser.add_argument('--chunk', type=int, default=32*1024, 35 | help='chunk size to split the input to avoid OOM') 36 | parser.add_argument('--num_epochs', type=int, default=16, 37 | help='number of training epochs') 38 | parser.add_argument('--num_gpus', type=int, default=1, 39 | help='number of gpus') 40 | 41 | parser.add_argument('--ckpt_path', type=str, default=None, 42 | help='pretrained checkpoint path to load') 43 | parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'], 44 | help='the prefixes to ignore in the checkpoint state dict') 45 | 46 | parser.add_argument('--optimizer', type=str, default='adam', 47 | help='optimizer type', 48 | choices=['sgd', 'adam', 'radam', 'ranger']) 49 | parser.add_argument('--lr', type=float, default=5e-4, 50 | help='learning rate') 51 | parser.add_argument('--momentum', type=float, default=0.9, 52 | help='learning rate momentum') 53 | parser.add_argument('--weight_decay', type=float, default=0, 54 | help='weight decay') 55 | parser.add_argument('--lr_scheduler', type=str, default='steplr', 56 | help='scheduler type', 57 | choices=['steplr', 'cosine', 'poly']) 58 | #### params for warmup, only applied when optimizer == 'sgd' or 'adam' 59 | parser.add_argument('--warmup_multiplier', type=float, default=1.0, 60 | help='lr is multiplied by this factor after --warmup_epochs') 61 | parser.add_argument('--warmup_epochs', type=int, default=0, 62 | help='Gradually warm-up(increasing) learning rate in optimizer') 63 | ########################### 64 | #### params for steplr #### 65 | parser.add_argument('--decay_step', nargs='+', type=int, default=[20], 66 | help='scheduler decay step') 67 | parser.add_argument('--decay_gamma', type=float, default=0.1, 68 | help='learning rate decay amount') 69 | ########################### 70 | #### params for poly #### 71 | parser.add_argument('--poly_exp', type=float, default=0.9, 72 | help='exponent for polynomial learning rate decay') 73 | ########################### 74 | 75 | parser.add_argument('--exp_name', type=str, default='exp', 76 | help='experiment name') 77 | 78 | return parser.parse_args() 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.5.0 3 | pytorch-lightning==0.7.5 4 | test-tube 5 | kornia==0.2.0 6 | opencv-python==4.2.0.34 7 | matplotlib 8 | jupyter 9 | 10 | # for mesh 11 | 
PyMCubes 12 | pycollada 13 | trimesh 14 | pyglet 15 | 16 | # for point cloud 17 | plyfile 18 | open3d 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from opt import get_opts 3 | import torch 4 | from collections import defaultdict 5 | 6 | from torch.utils.data import DataLoader 7 | from datasets import dataset_dict 8 | 9 | # models 10 | from models.nerf import Embedding, NeRF 11 | from models.rendering import render_rays 12 | 13 | # optimizer, scheduler, visualization 14 | from utils import * 15 | 16 | # losses 17 | from losses import loss_dict 18 | 19 | # metrics 20 | from metrics import * 21 | 22 | # pytorch-lightning 23 | from pytorch_lightning.callbacks import ModelCheckpoint 24 | from pytorch_lightning import LightningModule, Trainer 25 | from pytorch_lightning.logging import TestTubeLogger 26 | 27 | class NeRFSystem(LightningModule): 28 | def __init__(self, hparams): 29 | super(NeRFSystem, self).__init__() 30 | self.hparams = hparams 31 | 32 | self.loss = loss_dict[hparams.loss_type]() 33 | 34 | self.embedding_xyz = Embedding(3, 10) # 10 is the default number 35 | self.embedding_dir = Embedding(3, 4) # 4 is the default number 36 | self.embeddings = [self.embedding_xyz, self.embedding_dir] 37 | 38 | self.nerf_coarse = NeRF() 39 | self.models = [self.nerf_coarse] 40 | if hparams.N_importance > 0: 41 | self.nerf_fine = NeRF() 42 | self.models += [self.nerf_fine] 43 | 44 | def decode_batch(self, batch): 45 | rays = batch['rays'] # (B, 8) 46 | rgbs = batch['rgbs'] # (B, 3) 47 | return rays, rgbs 48 | 49 | def forward(self, rays): 50 | """Do batched inference on rays using chunk.""" 51 | B = rays.shape[0] 52 | results = defaultdict(list) 53 | for i in range(0, B, self.hparams.chunk): 54 | rendered_ray_chunks = \ 55 | render_rays(self.models, 56 | self.embeddings, 57 | rays[i:i+self.hparams.chunk], 58 | self.hparams.N_samples, 59 | self.hparams.use_disp, 60 | self.hparams.perturb, 61 | self.hparams.noise_std, 62 | self.hparams.N_importance, 63 | self.hparams.chunk, # chunk size is effective in val mode 64 | self.train_dataset.white_back) 65 | 66 | for k, v in rendered_ray_chunks.items(): 67 | results[k] += [v] 68 | 69 | for k, v in results.items(): 70 | results[k] = torch.cat(v, 0) 71 | return results 72 | 73 | def prepare_data(self): 74 | dataset = dataset_dict[self.hparams.dataset_name] 75 | kwargs = {'root_dir': self.hparams.root_dir, 76 | 'img_wh': tuple(self.hparams.img_wh)} 77 | if self.hparams.dataset_name == 'llff': 78 | kwargs['spheric_poses'] = self.hparams.spheric_poses 79 | kwargs['val_num'] = self.hparams.num_gpus 80 | self.train_dataset = dataset(split='train', **kwargs) 81 | self.val_dataset = dataset(split='val', **kwargs) 82 | 83 | def configure_optimizers(self): 84 | self.optimizer = get_optimizer(self.hparams, self.models) 85 | scheduler = get_scheduler(self.hparams, self.optimizer) 86 | 87 | return [self.optimizer], [scheduler] 88 | 89 | def train_dataloader(self): 90 | return DataLoader(self.train_dataset, 91 | shuffle=True, 92 | num_workers=4, 93 | batch_size=self.hparams.batch_size, 94 | pin_memory=True) 95 | 96 | def val_dataloader(self): 97 | return DataLoader(self.val_dataset, 98 | shuffle=False, 99 | num_workers=4, 100 | batch_size=1, # validate one image (H*W rays) at a time 101 | pin_memory=True) 102 | 103 | def training_step(self, batch, batch_nb): 104 | log = {'lr': get_learning_rate(self.optimizer)} 105 | 
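        # One training step: decode the batch into rays (B, 8) and target colors (B, 3),
        # render them with the coarse (and, when N_importance > 0, fine) model via
        # self(rays), compute the MSE loss summed over coarse and fine outputs (see
        # losses.py), and log the PSNR of the finest available prediction.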
rays, rgbs = self.decode_batch(batch) 106 | results = self(rays) 107 | log['train/loss'] = loss = self.loss(results, rgbs) 108 | typ = 'fine' if 'rgb_fine' in results else 'coarse' 109 | 110 | with torch.no_grad(): 111 | psnr_ = psnr(results[f'rgb_{typ}'], rgbs) 112 | log['train/psnr'] = psnr_ 113 | 114 | return {'loss': loss, 115 | 'progress_bar': {'train_psnr': psnr_}, 116 | 'log': log 117 | } 118 | 119 | def validation_step(self, batch, batch_nb): 120 | rays, rgbs = self.decode_batch(batch) 121 | rays = rays.squeeze() # (H*W, 3) 122 | rgbs = rgbs.squeeze() # (H*W, 3) 123 | results = self(rays) 124 | log = {'val_loss': self.loss(results, rgbs)} 125 | typ = 'fine' if 'rgb_fine' in results else 'coarse' 126 | 127 | if batch_nb == 0: 128 | W, H = self.hparams.img_wh 129 | img = results[f'rgb_{typ}'].view(H, W, 3).cpu() 130 | img = img.permute(2, 0, 1) # (3, H, W) 131 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) 132 | depth = visualize_depth(results[f'depth_{typ}'].view(H, W)) # (3, H, W) 133 | stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W) 134 | self.logger.experiment.add_images('val/GT_pred_depth', 135 | stack, self.global_step) 136 | 137 | log['val_psnr'] = psnr(results[f'rgb_{typ}'], rgbs) 138 | return log 139 | 140 | def validation_epoch_end(self, outputs): 141 | mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 142 | mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean() 143 | 144 | return {'progress_bar': {'val_loss': mean_loss, 145 | 'val_psnr': mean_psnr}, 146 | 'log': {'val/loss': mean_loss, 147 | 'val/psnr': mean_psnr} 148 | } 149 | 150 | 151 | if __name__ == '__main__': 152 | hparams = get_opts() 153 | system = NeRFSystem(hparams) 154 | checkpoint_callback = ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}', 155 | '{epoch:d}'), 156 | monitor='val/loss', 157 | mode='min', 158 | save_top_k=5,) 159 | 160 | logger = TestTubeLogger( 161 | save_dir="logs", 162 | name=hparams.exp_name, 163 | debug=False, 164 | create_git_tag=False 165 | ) 166 | 167 | trainer = Trainer(max_epochs=hparams.num_epochs, 168 | checkpoint_callback=checkpoint_callback, 169 | resume_from_checkpoint=hparams.ckpt_path, 170 | logger=logger, 171 | early_stop_callback=None, 172 | weights_summary=None, 173 | progress_bar_refresh_rate=1, 174 | gpus=hparams.num_gpus, 175 | distributed_backend='ddp' if hparams.num_gpus>1 else None, 176 | num_sanity_val_steps=1, 177 | benchmark=True, 178 | profiler=hparams.num_gpus==1) 179 | 180 | trainer.fit(system) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | from torch.optim import SGD, Adam 3 | from .optimizers import * 4 | # scheduler 5 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 6 | from .warmup_scheduler import GradualWarmupScheduler 7 | 8 | from .visualization import * 9 | 10 | def get_optimizer(hparams, models): 11 | eps = 1e-8 12 | parameters = [] 13 | for model in models: 14 | parameters += list(model.parameters()) 15 | if hparams.optimizer == 'sgd': 16 | optimizer = SGD(parameters, lr=hparams.lr, 17 | momentum=hparams.momentum, weight_decay=hparams.weight_decay) 18 | elif hparams.optimizer == 'adam': 19 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps, 20 | weight_decay=hparams.weight_decay) 21 | elif hparams.optimizer == 'radam': 22 | optimizer = RAdam(parameters, lr=hparams.lr, eps=eps, 23 | 
weight_decay=hparams.weight_decay) 24 | elif hparams.optimizer == 'ranger': 25 | optimizer = Ranger(parameters, lr=hparams.lr, eps=eps, 26 | weight_decay=hparams.weight_decay) 27 | else: 28 | raise ValueError('optimizer not recognized!') 29 | 30 | return optimizer 31 | 32 | def get_scheduler(hparams, optimizer): 33 | eps = 1e-8 34 | if hparams.lr_scheduler == 'steplr': 35 | scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, 36 | gamma=hparams.decay_gamma) 37 | elif hparams.lr_scheduler == 'cosine': 38 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) 39 | elif hparams.lr_scheduler == 'poly': 40 | scheduler = LambdaLR(optimizer, 41 | lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp) 42 | else: 43 | raise ValueError('scheduler not recognized!') 44 | 45 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: 46 | scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, 47 | total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) 48 | 49 | return scheduler 50 | 51 | def get_learning_rate(optimizer): 52 | for param_group in optimizer.param_groups: 53 | return param_group['lr'] 54 | 55 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): 56 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 57 | checkpoint_ = {} 58 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 59 | checkpoint = checkpoint['state_dict'] 60 | for k, v in checkpoint.items(): 61 | if not k.startswith(model_name): 62 | continue 63 | k = k[len(model_name)+1:] 64 | for prefix in prefixes_to_ignore: 65 | if k.startswith(prefix): 66 | print('ignore', k) 67 | break 68 | else: 69 | checkpoint_[k] = v 70 | return checkpoint_ 71 | 72 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 73 | model_dict = model.state_dict() 74 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 75 | model_dict.update(checkpoint_) 76 | model.load_state_dict(model_dict) 77 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | import itertools as it 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 <= betas[0] < 1.0: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= betas[1] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 17 | 18 | self.degenerated_to_sgd = degenerated_to_sgd 19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 20 | for param in params: 21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 22 | param['buffer'] = [[None, None, None] for _ in range(10)] 23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 24 | super(RAdam, self).__init__(params, defaults) 25 | 26 | def __setstate__(self, state): 27 | super(RAdam, self).__setstate__(state) 28 | 29 | def 
step(self, closure=None): 30 | 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | 35 | for group in self.param_groups: 36 | 37 | for p in group['params']: 38 | if p.grad is None: 39 | continue 40 | grad = p.grad.data.float() 41 | if grad.is_sparse: 42 | raise RuntimeError('RAdam does not support sparse gradients') 43 | 44 | p_data_fp32 = p.data.float() 45 | 46 | state = self.state[p] 47 | 48 | if len(state) == 0: 49 | state['step'] = 0 50 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 51 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 52 | else: 53 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 54 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 55 | 56 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 57 | beta1, beta2 = group['betas'] 58 | 59 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 60 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 61 | 62 | state['step'] += 1 63 | buffered = group['buffer'][int(state['step'] % 10)] 64 | if state['step'] == buffered[0]: 65 | N_sma, step_size = buffered[1], buffered[2] 66 | else: 67 | buffered[0] = state['step'] 68 | beta2_t = beta2 ** state['step'] 69 | N_sma_max = 2 / (1 - beta2) - 1 70 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 71 | buffered[1] = N_sma 72 | 73 | # more conservative since it's an approximated value 74 | if N_sma >= 5: 75 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 76 | elif self.degenerated_to_sgd: 77 | step_size = 1.0 / (1 - beta1 ** state['step']) 78 | else: 79 | step_size = -1 80 | buffered[2] = step_size 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | if group['weight_decay'] != 0: 85 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 86 | denom = exp_avg_sq.sqrt().add_(group['eps']) 87 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 88 | p.data.copy_(p_data_fp32) 89 | elif step_size > 0: 90 | if group['weight_decay'] != 0: 91 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 92 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 93 | p.data.copy_(p_data_fp32) 94 | 95 | return loss 96 | 97 | class PlainRAdam(Optimizer): 98 | 99 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 100 | if not 0.0 <= lr: 101 | raise ValueError("Invalid learning rate: {}".format(lr)) 102 | if not 0.0 <= eps: 103 | raise ValueError("Invalid epsilon value: {}".format(eps)) 104 | if not 0.0 <= betas[0] < 1.0: 105 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 106 | if not 0.0 <= betas[1] < 1.0: 107 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 108 | 109 | self.degenerated_to_sgd = degenerated_to_sgd 110 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 111 | 112 | super(PlainRAdam, self).__init__(params, defaults) 113 | 114 | def __setstate__(self, state): 115 | super(PlainRAdam, self).__setstate__(state) 116 | 117 | def step(self, closure=None): 118 | 119 | loss = None 120 | if closure is not None: 121 | loss = closure() 122 | 123 | for group in self.param_groups: 124 | 125 | for p in group['params']: 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data.float() 129 | if grad.is_sparse: 130 | raise RuntimeError('RAdam does not support sparse gradients') 131 | 132 | p_data_fp32 = p.data.float() 
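                # Lazily create per-parameter state on first use: a step counter and
                # exponential moving averages of the gradient (exp_avg) and of the
                # squared gradient (exp_avg_sq). N_sma below is the length of the
                # approximated simple moving average; when N_sma >= 5 the variance-
                # rectified update is applied, otherwise the step degenerates to a
                # plain SGD-like update (when degenerated_to_sgd is enabled).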
133 |
134 |                 state = self.state[p]
135 |
136 |                 if len(state) == 0:
137 |                     state['step'] = 0
138 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
139 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
140 |                 else:
141 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
142 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
143 |
144 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
145 |                 beta1, beta2 = group['betas']
146 |
147 |                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
148 |                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
149 |
150 |                 state['step'] += 1
151 |                 beta2_t = beta2 ** state['step']
152 |                 N_sma_max = 2 / (1 - beta2) - 1
153 |                 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
154 |
155 |
156 |                 # more conservative since it's an approximated value
157 |                 if N_sma >= 5:
158 |                     if group['weight_decay'] != 0:
159 |                         p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
160 |                     step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
161 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
162 |                     p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
163 |                     p.data.copy_(p_data_fp32)
164 |                 elif self.degenerated_to_sgd:
165 |                     if group['weight_decay'] != 0:
166 |                         p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
167 |                     step_size = group['lr'] / (1 - beta1 ** state['step'])
168 |                     p_data_fp32.add_(-step_size, exp_avg)
169 |                     p.data.copy_(p_data_fp32)
170 |
171 |         return loss
172 |
173 | class AdamW(Optimizer):
174 |
175 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
176 |         if not 0.0 <= lr:
177 |             raise ValueError("Invalid learning rate: {}".format(lr))
178 |         if not 0.0 <= eps:
179 |             raise ValueError("Invalid epsilon value: {}".format(eps))
180 |         if not 0.0 <= betas[0] < 1.0:
181 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
182 |         if not 0.0 <= betas[1] < 1.0:
183 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
184 |
185 |         defaults = dict(lr=lr, betas=betas, eps=eps,
186 |                         weight_decay=weight_decay, warmup = warmup)
187 |         super(AdamW, self).__init__(params, defaults)
188 |
189 |     def __setstate__(self, state):
190 |         super(AdamW, self).__setstate__(state)
191 |
192 |     def step(self, closure=None):
193 |         loss = None
194 |         if closure is not None:
195 |             loss = closure()
196 |
197 |         for group in self.param_groups:
198 |
199 |             for p in group['params']:
200 |                 if p.grad is None:
201 |                     continue
202 |                 grad = p.grad.data.float()
203 |                 if grad.is_sparse:
204 |                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
205 |
206 |                 p_data_fp32 = p.data.float()
207 |
208 |                 state = self.state[p]
209 |
210 |                 if len(state) == 0:
211 |                     state['step'] = 0
212 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
213 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
214 |                 else:
215 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
216 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
217 |
218 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
219 |                 beta1, beta2 = group['betas']
220 |
221 |                 state['step'] += 1
222 |
223 |                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
224 |                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
225 |
226 |                 denom = exp_avg_sq.sqrt().add_(group['eps'])
227 |                 bias_correction1 = 1 - beta1 ** state['step']
228 |                 bias_correction2 = 1 - beta2 ** state['step']
229 |
230 |                 if group['warmup'] > state['step']:
231 |                     scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
232 |                 else:
233 |                     scheduled_lr = group['lr']
234 |
235 |                 step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
236 |
237 |                 if group['weight_decay'] != 0:
238 |                     p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
239 |
240 |                 p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
241 |
242 |                 p.data.copy_(p_data_fp32)
243 |
244 |         return loss
245 |
246 |
247 | #Ranger deep learning optimizer - RAdam + Lookahead combined.
248 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
249 |
250 | #Ranger has now been used to capture 12 records on the FastAI leaderboard.
251 |
252 | #This version = 9.3.19
253 |
254 | #Credits:
255 | #RAdam --> https://github.com/LiyuanLucasLiu/RAdam
256 | #Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
257 | #Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
258 |
259 | #summary of changes:
260 | #full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
261 | #supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
262 | #changes 8/31/19 - fix references to *self*.N_sma_threshold;
263 | #changed eps to 1e-5 as better default than 1e-8.
264 |
265 |
266 | class Ranger(Optimizer):
267 |
268 |     def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0):
269 |         #parameter checks
270 |         if not 0.0 <= alpha <= 1.0:
271 |             raise ValueError(f'Invalid slow update rate: {alpha}')
272 |         if not 1 <= k:
273 |             raise ValueError(f'Invalid lookahead steps: {k}')
274 |         if not lr > 0:
275 |             raise ValueError(f'Invalid Learning Rate: {lr}')
276 |         if not eps > 0:
277 |             raise ValueError(f'Invalid eps: {eps}')
278 |
279 |         #parameter comments:
280 |         # beta1 (momentum) of .95 seems to work better than .90...
281 |         #N_sma_threshold of 5 seems better in testing than 4.
282 |         #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
283 |
284 |         #prep defaults and init torch.optim base
285 |         defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
286 |         super().__init__(params,defaults)
287 |
288 |         #adjustable threshold
289 |         self.N_sma_threshhold = N_sma_threshhold
290 |
291 |         #now we can get to work...
292 |         #removed as we now use step from RAdam...no need for duplicate step counting
293 |         #for group in self.param_groups:
294 |         #    group["step_counter"] = 0
295 |         #print("group step counter init")
296 |
297 |         #look ahead params
298 |         self.alpha = alpha
299 |         self.k = k
300 |
301 |         #radam buffer for state
302 |         self.radam_buffer = [[None,None,None] for ind in range(10)]
303 |
304 |         #self.first_run_check=0
305 |
306 |         #lookahead weights
307 |         #9/2/19 - lookahead param tensors have been moved to state storage.
308 |         #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
309 |
310 |         #self.slow_weights = [[p.clone().detach() for p in group['params']]
311 |         #                     for group in self.param_groups]
312 |
313 |         #don't use grad for lookahead weights
314 |         #for w in it.chain(*self.slow_weights):
315 |         #    w.requires_grad = False
316 |
317 |     def __setstate__(self, state):
318 |         print("set state called")
319 |         super(Ranger, self).__setstate__(state)
320 |
321 |
322 |     def step(self, closure=None):
323 |         loss = None
324 |         #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
325 |         #Uncomment if you need to use the actual closure...
326 |
327 |         #if closure is not None:
328 |         #    loss = closure()
329 |
330 |         #Evaluate averages and grad, update param tensors
331 |         for group in self.param_groups:
332 |
333 |             for p in group['params']:
334 |                 if p.grad is None:
335 |                     continue
336 |                 grad = p.grad.data.float()
337 |                 if grad.is_sparse:
338 |                     raise RuntimeError('Ranger optimizer does not support sparse gradients')
339 |
340 |                 p_data_fp32 = p.data.float()
341 |
342 |                 state = self.state[p] #get state dict for this param
343 |
344 |                 if len(state) == 0: #if first time to run...init dictionary with our desired entries
345 |                     #if self.first_run_check==0:
346 |                     #self.first_run_check=1
347 |                     #print("Initializing slow buffer...should not see this at load from saved model!")
348 |                     state['step'] = 0
349 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
350 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
351 |
352 |                     #look ahead weight storage now in state dict
353 |                     state['slow_buffer'] = torch.empty_like(p.data)
354 |                     state['slow_buffer'].copy_(p.data)
355 |
356 |                 else:
357 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
358 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
359 |
360 |                 #begin computations
361 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
362 |                 beta1, beta2 = group['betas']
363 |
364 |                 #compute variance mov avg
365 |                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
366 |                 #compute mean moving avg
367 |                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
368 |
369 |                 state['step'] += 1
370 |
371 |
372 |                 buffered = self.radam_buffer[int(state['step'] % 10)]
373 |                 if state['step'] == buffered[0]:
374 |                     N_sma, step_size = buffered[1], buffered[2]
375 |                 else:
376 |                     buffered[0] = state['step']
377 |                     beta2_t = beta2 ** state['step']
378 |                     N_sma_max = 2 / (1 - beta2) - 1
379 |                     N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
380 |                     buffered[1] = N_sma
381 |                     if N_sma > self.N_sma_threshhold:
382 |                         step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
383 |                     else:
384 |                         step_size = 1.0 / (1 - beta1 ** state['step'])
385 |                     buffered[2] = step_size
386 |
387 |                 if group['weight_decay'] != 0:
388 |                     p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
389 |
390 |                 if N_sma > self.N_sma_threshhold:
391 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
392 |                     p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
393 |                 else:
394 |                     p_data_fp32.add_(-step_size * group['lr'], exp_avg)
395 |
396 |                 p.data.copy_(p_data_fp32)
397 |
398 |                 #integrated look ahead...
399 |                 #we do it at the param level instead of group level
400 |                 if state['step'] % group['k'] == 0:
401 |                     slow_p = state['slow_buffer'] #get access to slow param tensor
402 |                     slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha
403 |                     p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
404 |
405 |         return loss
--------------------------------------------------------------------------------
/utils/save_weights_only.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 |
4 | def get_opts():
5 |     parser = argparse.ArgumentParser()
6 |
7 |     parser.add_argument('--ckpt_path', type=str, required=True,
8 |                         help='checkpoint path')
9 |
10 |     return parser.parse_args()
11 |
12 | if __name__ == "__main__":
13 |     args = get_opts()
14 |     checkpoint = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
15 |     torch.save(checkpoint['state_dict'], args.ckpt_path.split('/')[-2]+'.ckpt')
16 |     print('Done!')
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | import numpy as np
3 | import cv2
4 | from PIL import Image
5 |
6 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET):
7 |     """
8 |     depth: (H, W)
9 |     """
10 |     x = depth.cpu().numpy()
11 |     x = np.nan_to_num(x) # change nan to 0
12 |     mi = np.min(x) # get minimum depth
13 |     ma = np.max(x)
14 |     x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
15 |     x = (255*x).astype(np.uint8)
16 |     x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
17 |     x_ = T.ToTensor()(x_) # (3, H, W)
18 |     return x_
--------------------------------------------------------------------------------
/utils/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 |     """ Gradually warm up (increase) the learning rate in the optimizer.
6 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 |     Args:
8 |         optimizer (Optimizer): Wrapped optimizer.
9 |         multiplier: target learning rate = base lr * multiplier
10 |         total_epoch: target learning rate is reached gradually at total_epoch
11 |         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
12 |     """
13 |
14 |     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
15 |         self.multiplier = multiplier
16 |         if self.multiplier < 1.:
17 |             raise ValueError('multiplier should be greater than or equal to 1.')
18 |         self.total_epoch = total_epoch
19 |         self.after_scheduler = after_scheduler
20 |         self.finished = False
21 |         super().__init__(optimizer)
22 |
23 |     def get_lr(self):
24 |         if self.last_epoch > self.total_epoch:
25 |             if self.after_scheduler:
26 |                 if not self.finished:
27 |                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
28 |                     self.finished = True
29 |                 return self.after_scheduler.get_lr()
30 |             return [base_lr * self.multiplier for base_lr in self.base_lrs]
31 |
32 |         return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
33 |
34 |     def step_ReduceLROnPlateau(self, metrics, epoch=None):
35 |         if epoch is None:
36 |             epoch = self.last_epoch + 1
37 |         self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
38 |         if self.last_epoch <= self.total_epoch:
39 |             warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
40 |             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
41 |                 param_group['lr'] = lr
42 |         else:
43 |             if epoch is None:
44 |                 self.after_scheduler.step(metrics, None)
45 |             else:
46 |                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
47 |
48 |     def step(self, epoch=None, metrics=None):
49 |         if type(self.after_scheduler) != ReduceLROnPlateau:
50 |             if self.finished and self.after_scheduler:
51 |                 if epoch is None:
52 |                     self.after_scheduler.step(None)
53 |                 else:
54 |                     self.after_scheduler.step(epoch - self.total_epoch)
55 |             else:
56 |                 return super(GradualWarmupScheduler, self).step(epoch)
57 |         else:
58 |             self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
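
A usage sketch tying the utilities above together (an illustrative example, not one of the repository files): it builds the RAdam optimizer from utils/optimizers.py, wraps it in GradualWarmupScheduler from utils/warmup_scheduler.py with cosine annealing after the warm-up, and colour-maps a depth prediction with visualize_depth from utils/visualization.py. The toy model, learning rate, multiplier and epoch count are assumed values, and the imports assume the repository root is on PYTHONPATH.

import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR

from utils.optimizers import RAdam
from utils.warmup_scheduler import GradualWarmupScheduler
from utils.visualization import visualize_depth

model = nn.Linear(63, 4)  # stand-in for a real network such as the NeRF MLP
optimizer = RAdam(model.parameters(), lr=2.5e-4)

# Ramp the learning rate from lr to multiplier*lr over the first epoch,
# then hand control to cosine annealing for the remaining epochs.
cosine = CosineAnnealingLR(optimizer, T_max=16, eta_min=1e-8)
scheduler = GradualWarmupScheduler(optimizer, multiplier=2.0, total_epoch=1,
                                   after_scheduler=cosine)

for epoch in range(16):
    # ... run one training epoch here, calling optimizer.step() per batch ...
    scheduler.step()

# Colour-map an (H, W) depth prediction into a (3, H, W) tensor,
# e.g. for logging to TensorBoard.
depth = torch.rand(400, 640)
depth_vis = visualize_depth(depth)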