├── .github
│   ├── FUNDING.yml
│   └── ISSUE_TEMPLATE
│       └── bug_report.md
├── .gitignore
├── .gitmodules
├── LICENSE
├── LICENSE_bmild
├── README.md
├── README_Unity.md
├── README_mesh.md
├── datasets
│   ├── __init__.py
│   ├── blender.py
│   ├── depth_utils.py
│   ├── llff.py
│   └── ray_utils.py
├── docs
│   ├── .gitignore
│   ├── Gemfile
│   ├── Gemfile.lock
│   ├── _config.yml
│   ├── index.md
│   └── style.css
├── eval.py
├── extract_color_mesh.py
├── extract_mesh.ipynb
├── losses.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── nerf.py
│   └── rendering.py
├── opt.py
├── requirements.txt
├── test.ipynb
├── train.py
└── utils
    ├── __init__.py
    ├── optimizers.py
    ├── save_weights_only.py
    ├── visualization.py
    └── warmup_scheduler.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: kwea123
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **Which branch you use**
14 | Only issues of the *dev* and *nerfw* branches will be considered currently!
15 |
16 | **To Reproduce**
17 | Steps to reproduce the behavior:
18 | 1. Go to '...'
19 | 2. Click on '....'
20 | 3. Scroll down to '....'
21 | 4. See error
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | logs/
3 | ckpts/
4 | results/
5 | *.ply
6 | *.vol
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "torchsearchsorted"]
2 | path = torchsearchsorted
3 | url = https://github.com/aliutkus/torchsearchsorted.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Quei-An Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE_bmild:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bmild
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nerf_pl
2 |
3 | ### Update: NVIDIA open-sourced a lightning-fast version of NeRF: [NGP](https://github.com/NVlabs/instant-ngp). I re-implemented it in pytorch [here](https://github.com/kwea123/ngp_pl). That version is ~100x faster than this repo, with better quality as well!
4 |
5 | ### Update: an improved [NSFF](https://www.cs.cornell.edu/~zl548/NSFF/) implementation to handle dynamic scenes is now [open](https://github.com/kwea123/nsff_pl)!
6 |
7 | ### Update: [NeRF-W](https://nerf-w.github.io/) (NeRF in the Wild) implementation is added to [nerfw](https://github.com/kwea123/nerf_pl/tree/nerfw) branch!
8 |
9 | ### Update: The latest code (using the latest libraries) is maintained on the [dev](https://github.com/kwea123/nerf_pl/tree/dev) branch. The master branch remains to support the colab notebooks. If you don't use colab, it is recommended to switch to the dev branch. Only issues on the dev and nerfw branches will be considered currently.
10 |
11 | ### :gem: [**Project page**](https://kwea123.github.io/nerf_pl/) (live demo!)
12 |
13 | Unofficial implementation of [NeRF](https://arxiv.org/pdf/2003.08934.pdf) (Neural Radiance Fields) using pytorch ([pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning)). This repo doesn't aim at reproducibility, but aims at providing a simpler and faster training procedure (also simpler code with detailed comments to help understand the work). Moreover, I try to open up more possibilities by integrating this algorithm into game engines like Unity.
14 |
15 | Official implementation: [nerf](https://github.com/bmild/nerf). Reference pytorch implementation: [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch)
16 |
17 | ### Recommended reading: a detailed NeRF extension list: [awesome-NeRF](https://github.com/yenchenlin/awesome-NeRF)
18 |
19 | ## :milky_way: Features
20 |
21 | * Multi-gpu training: Training on 8 GPUs finishes within 1 hour for the synthetic dataset!
22 | * [Colab](#mortar_board-colab) notebooks to allow easy usage!
23 | * [Reconstruct](#ribbon-mesh) **colored** mesh!
24 | * [Mixed Reality](https://youtu.be/S5phWFTs2iM) in Unity!
25 | * [REAL TIME volume rendering](https://youtu.be/w9qTbVzCdWk) in Unity!
26 | * [Portable Scenes](#portable-scenes) to let you play with other people's scenes!
27 |
28 | ### You can find the Unity project including mesh, mixed reality and volume rendering [here](https://github.com/kwea123/nerf_Unity)! See [README_Unity](README_Unity.md) for generating your own data for Unity rendering!
29 |
30 | ## :beginner: Tutorial
31 |
32 | ### What can NeRF do?
33 |
34 |
35 | ### Tutorial videos
36 |
37 |
38 |
39 |
40 | # :computer: Installation
41 |
42 | ## Hardware
43 |
44 | * OS: Ubuntu 18.04
45 | * NVIDIA GPU with **CUDA>=10.1** (tested with 1 RTX2080Ti)
46 |
47 | ## Software
48 |
49 | * Clone this repo by `git clone --recursive https://github.com/kwea123/nerf_pl`
50 | * Python>=3.6 (installation via [anaconda](https://www.anaconda.com/distribution/) is recommended, use `conda create -n nerf_pl python=3.6` to create a conda environment and activate it by `conda activate nerf_pl`)
51 | * Python libraries
52 | * Install core requirements by `pip install -r requirements.txt`
53 | * Install `torchsearchsorted` by `cd torchsearchsorted` then `pip install .`
54 |
55 | # :key: Training
56 |
57 | Please see each subsection for training on different datasets. Available training datasets:
58 |
59 | * [Blender](#blender) (Realistic Synthetic 360)
60 | * [LLFF](#llff) (Real Forward-Facing)
61 | * [Your own data](#your-own-data) (Forward-Facing/360 inward-facing)
62 |
63 | ## Blender
64 |
65 | Steps
66 |
67 | ### Data download
68 |
69 | Download `nerf_synthetic.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
70 |
71 | ### Training model
72 |
73 | Run (example)
74 | ```
75 | python train.py \
76 | --dataset_name blender \
77 | --root_dir $BLENDER_DIR \
78 | --N_importance 64 --img_wh 400 400 --noise_std 0 \
79 | --num_epochs 16 --batch_size 1024 \
80 | --optimizer adam --lr 5e-4 \
81 | --lr_scheduler steplr --decay_step 2 4 8 --decay_gamma 0.5 \
82 | --exp_name exp
83 | ```
84 |
85 | These parameters are chosen to best mimic the training settings in the original repo. See [opt.py](opt.py) for all configurations.
86 |
87 | NOTE: the above configuration doesn't work for some scenes like `drums`, `ship`. In that case, consider increasing the `batch_size` or changing the `optimizer` to `radam`. I managed to train on all scenes with these modifications.
88 |
89 | You can monitor the training process by running `tensorboard --logdir logs/` and going to `localhost:6006` in your browser.
90 |
91 |
92 | ## LLFF
93 |
94 | Steps
95 |
96 | ### Data download
97 |
98 | Download `nerf_llff_data.zip` from [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
99 |
100 | ### Training model
101 |
102 | Run (example)
103 | ```
104 | python train.py \
105 | --dataset_name llff \
106 | --root_dir $LLFF_DIR \
107 | --N_importance 64 --img_wh 504 378 \
108 | --num_epochs 30 --batch_size 1024 \
109 | --optimizer adam --lr 5e-4 \
110 | --lr_scheduler steplr --decay_step 10 20 --decay_gamma 0.5 \
111 | --exp_name exp
112 | ```
113 |
114 | These parameters are chosen to best mimic the training settings in the original repo. See [opt.py](opt.py) for all configurations.
115 |
116 | You can monitor the training process by running `tensorboard --logdir logs/` and going to `localhost:6006` in your browser.
117 |
118 |
119 | ## Your own data
120 |
121 | Steps
122 |
123 | 1. Install [COLMAP](https://github.com/colmap/colmap) following [installation guide](https://colmap.github.io/install.html)
124 | 2. Prepare your images in a folder (around 20 to 30 for forward facing, and 40 to 50 for 360 inward-facing)
125 | 3. Clone [LLFF](https://github.com/Fyusion/LLFF) and run `python imgs2poses.py $your-images-folder`
126 | 4. Train the model using the same command as in [LLFF](#llff). If the scene is captured in a 360 inward-facing manner, add the `--spheric` argument.
127 |
128 | For more details on training a good model, please see the video [here](#colab).
129 |
130 |
131 | ## Pretrained models and logs
132 | Download the pretrained models and training logs in [release](https://github.com/kwea123/nerf_pl/releases).
133 |
134 | ## Comparison with other repos
135 |
136 | | | training GPU memory in GB | Speed (1 step) |
137 | | :---: | :---: | :---: |
138 | | [Original](https://github.com/bmild/nerf) | 8.5 | 0.177s |
139 | | [Ref pytorch](https://github.com/yenchenlin/nerf-pytorch) | 6.0 | 0.147s |
140 | | This repo | 3.2 | 0.12s |
141 |
142 | The speed is measured on 1 RTX2080Ti. A detailed profile can be found in [release](https://github.com/kwea123/nerf_pl/releases).
143 | Training memory is largely reduced, since the original repo loads all the data to the GPU at the beginning, while we only pass batches to the GPU at every step.
144 |
145 | # :mag_right: Testing
146 |
147 | See [test.ipynb](test.ipynb) for a simple view synthesis and depth prediction on 1 image.
148 |
149 | Use [eval.py](eval.py) to create the whole sequence of moving views.
150 | E.g.
151 | ```
152 | python eval.py \
153 | --root_dir $BLENDER \
154 | --dataset_name blender --scene_name lego \
155 | --img_wh 400 400 --N_importance 64 --ckpt_path $CKPT_PATH
156 | ```
157 | **IMPORTANT** : Don't forget to add `--spheric_poses` if the model is trained under `--spheric` setting!
158 |
159 | It will create the folder `results/{dataset_name}/{scene_name}`, run inference on all test data, and finally create a gif out of the results.
160 |
161 | Example of the lego scene using the pretrained model and the reconstructed **colored** mesh (PSNR=31.39, paper=32.54):
162 |
163 |
164 |
165 |
166 |
167 |
168 | Example of the fern scene using the pretrained model:
169 |
170 | 
171 |
172 | Example of my own scene ([Silica GGO figure](https://www.youtube.com/watch?v=hVQIvEq_Av0)) and the reconstructed **colored** mesh. Click to go to the YouTube video.
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | ## Portable scenes
182 | The concept of NeRF is that the whole scene is compressed into a NeRF model, from which we can then render from any pose we want. To render from plausible poses, we can leverage the training poses; therefore, you can generate a video with **only** the trained model and the poses (hence the name portable scenes). I provide my silica model in [release](https://github.com/kwea123/nerf_pl/releases), feel free to play around with it!
183 |
184 | If you trained some interesting scenes, you are also welcome to share the model (and the `poses_bounds.npy`) by sending me an email, or posting in the issues! After all, a model is only around **5MB**! Please run `python utils/save_weights_only.py --ckpt_path $YOUR_MODEL_PATH` to extract the final model.
185 |
186 | # :ribbon: Mesh
187 |
188 | See [README_mesh](README_mesh.md) for reconstruction of **colored** mesh. This is only supported for the blender dataset and 360 inward-facing data!
189 |
190 | # :warning: Notes on differences with the original repo
191 |
192 | * The learning rate decay in the original repo is **by step**, which means it decreases every step; here I use learning rate decay **by epoch**, which means it changes only at the end of each epoch (see the sketch after this list).
193 | * The validation image for the LLFF dataset is chosen as the most centered image here, whereas the original repo chooses every 8th image.
194 | * The rendering spiral path is slightly different from the original repo (I use approximate values to simplify the code).
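
Below is a minimal sketch of the "decay by epoch" behavior, matching the `--lr_scheduler steplr --decay_step 2 4 8 --decay_gamma 0.5` example above. It is not the actual training code; `model` and `train_one_epoch` are placeholders.

```
import torch

# Hypothetical sketch: the scheduler is stepped once per epoch, so the learning
# rate only changes at epoch boundaries (epochs 2, 4 and 8 here), not every step.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4, 8], gamma=0.5)

for epoch in range(16):
    train_one_epoch(model, optimizer)  # placeholder for one epoch of training
    scheduler.step()                   # decay happens here, at the end of the epoch
```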
195 |
196 | # :mortar_board: COLAB
197 |
198 | I also prepared colab notebooks that allow you to run the algorithm on any machine, without requiring a GPU.
199 |
200 | * [colmap](https://gist.github.com/kwea123/f0e8f38ff2aa94495dbfe7ae9219f75c) to prepare camera poses for your own training data
201 | * [nerf](https://gist.github.com/kwea123/a3c541a325e895ef79ecbc0d2e6d7221) to train on your data
202 | * [extract_mesh](https://gist.github.com/kwea123/77ed1640f9bc9550136dc13a6a419e88) to extract colored mesh
203 |
204 | Please see [this playlist](https://www.youtube.com/playlist?list=PLDV2CyUo4q-K02pNEyDr7DYpTQuka3mbV) for the detailed tutorials.
205 |
206 | # :jack_o_lantern: SHOWOFF
207 |
208 | We can incorporate *ray tracing* techniques into the volume rendering pipeline and realize realistic scene editing (the following is the `materials` scene with one object removed, and a mesh inserted and rendered with ray tracing). The code **will not** be released.
209 |
210 | 
211 | 
212 |
213 | With my integration in Unity, I can create realistic mixed reality photos (note that my character casts a shadow on the scene, with **zero** post-capture image editing required):
214 | 
215 | 
216 | BTW, I would like to visit the museum one day...
217 |
218 | # :book: Citation
219 | If you use (part of) my code or find my work helpful, please consider citing
220 | ```
221 | @misc{queianchen_nerf,
222 | author={Quei-An, Chen},
223 | title={Nerf_pl: a pytorch-lightning implementation of NeRF},
224 | url={https://github.com/kwea123/nerf_pl/},
225 | year={2020},
226 | }
227 | ```
228 |
--------------------------------------------------------------------------------
/README_Unity.md:
--------------------------------------------------------------------------------
1 | # Generating files for Unity rendering
2 |
3 | This readme contains guidance for generating the files required for Unity rendering in my [Unity project](https://github.com/kwea123/nerf_Unity)
4 |
5 | ## MeshRender
6 |
7 | See [README_mesh](README_mesh.md) for generating mesh.
8 | You then need [this plugin](https://github.com/kwea123/Pcx) to import `.ply` files into Unity.
9 |
10 | ## MixedReality
11 |
12 | Use `eval.py` with `--save_depth --depth_format bytes` to create the whole sequence of moving views. E.g.
13 | ```
14 | python eval.py \
15 | --root_dir $BLENDER \
16 | --dataset_name blender --scene_name lego \
17 | --img_wh 400 400 --N_importance 64 --ckpt_path $CKPT_PATH \
18 | --save_depth --depth_format bytes
19 | ```
20 | You will get `*.png` files and corresponding `depth_*` files. Now import the image you want to show and its corresponding depth file into Unity, and replace the files in my Unity project.
21 |
22 | ## VolumeRender
23 |
24 | Use `extract_mesh.ipynb` (not `extract_color_mesh.py`!) to find the tight bounds for the object as you would for mesh generation (see [this video](https://www.youtube.com/watch?v=t06qu-gXrxA&t=1355)), but this time stop before the cell "Extract colored mesh". Remember to set `N=512` in the cell "Search for tight bounds of the object" and comment out the lines for visualization. Now run the cell "Generate .vol file for volume rendering in Unity"; after that, you should obtain a `.vol` file, which you can import into my Unity project and render.
25 |
26 | **NOTE:** If you use colab as in the video, copy the cell "Generate .vol file for volume rendering in Unity" into colab notebook and execute it.
27 |
28 | If you want to render in your own project, you need the script [LoadVolume.cs](https://github.com/kwea123/nerf_Unity/blob/master/Assets/Editor/LoadVolume.cs), which reads this custom `.vol` format into a `Texture3D`.
29 |
--------------------------------------------------------------------------------
/README_mesh.md:
--------------------------------------------------------------------------------
1 | # Reconstruct mesh
2 |
3 | Use `extract_mesh.ipynb` (the notebook, **not** the py script) to extract a **colored** mesh. The guidelines for choosing good parameters are commented in the notebook.
4 | Here, I'll give a detailed explanation of how it works. There is also a [video](https://youtu.be/t06qu-gXrxA) that explains the same thing.
5 |
6 | ## Step 1. Predict occupancy
7 |
8 | As in the [original repo](https://github.com/bmild/nerf/blob/master/extract_mesh.ipynb), we first need to infer which locations are occupied by the object. This is done by first creating a grid volume in the form of a cuboid covering the whole object, then using the NeRF model to predict whether each cell is occupied or not. This is the main reason why mesh construction is only available for 360 inward-facing scenes: forward-facing scenes would require a **huge** volume to cover the whole space, and it is computationally impossible to predict the occupancy for all cells.
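
Below is a minimal sketch of this step. It assumes a trained fine model `nerf_fine` and a positional encoding `embedding_xyz`; the grid bounds, resolution, chunk size and the exact call signature are illustrative only (see `extract_mesh.ipynb` for the real code and the parameter guidelines).

```
import numpy as np
import torch

N = 128                                  # grid resolution per axis (illustrative)
xmin, xmax = -1.2, 1.2                   # cuboid bounds covering the object (illustrative)
x = np.linspace(xmin, xmax, N)
xyz = np.stack(np.meshgrid(x, x, x), -1).reshape(-1, 3)   # (N^3, 3) cell centers
xyz = torch.FloatTensor(xyz).cuda()

sigma = []
with torch.no_grad():
    for chunk in torch.split(xyz, 32*1024):          # query in chunks to avoid OOM
        xyz_embedded = embedding_xyz(chunk)          # positional encoding
        sigma += [nerf_fine(xyz_embedded, sigma_only=True).cpu()]
sigma = torch.cat(sigma).numpy().reshape(N, N, N)    # density volume; high value = occupied
```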
9 |
10 | ## Step 2. Perform the marching cubes algorithm
11 |
12 | After we know which cells are occupied, we can use the [marching cubes algorithm](https://en.wikipedia.org/wiki/Marching_cubes) to extract the mesh. This mesh will only contain vertices and faces; if you don't require color, you can stop here and export the mesh. Up to this point, the code is the same as the original repo.
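
As an illustration, this step can be done with PyMCubes (one possible choice of library; the threshold is a per-scene parameter to tune):

```
import mcubes

sigma_threshold = 20.0                                        # illustrative value; tune per scene
vertices, triangles = mcubes.marching_cubes(sigma, sigma_threshold)
mcubes.export_obj(vertices, triangles, 'mesh_no_color.obj')   # vertices + faces only, no color
```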
13 |
14 | ## Step 3. Remove noise
15 |
16 | The mesh might contain some noise, which could come from wrongly predicted occupancy in step 1, or you might consider the floor as noise. To remove this noise, we use a simple method: keep only the largest cluster. We cluster the triangles into groups (two triangles are in the same group if they are connected) and keep only the biggest one. After removing the noise, we then compute the color for each vertex.
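
A sketch of this "keep only the largest cluster" idea, here using Open3D's connected-triangle clustering (an assumed choice of library; `vertices` and `triangles` come from the previous step):

```
import numpy as np
import open3d as o3d

mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(vertices)
mesh.triangles = o3d.utility.Vector3iVector(triangles)

cluster_idx, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
cluster_idx = np.asarray(cluster_idx)
largest = int(np.asarray(cluster_n_triangles).argmax())
mesh.remove_triangles_by_mask(cluster_idx != largest)   # drop every cluster except the biggest
mesh.remove_unreferenced_vertices()                     # clean up orphaned vertices
```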
17 |
18 | ## Step 4. Compute color for each vertex
19 |
20 | We adopt the concept of assigning colors to vertices instead of faces (they are roughly equivalent, as you can think of the color of a vertex as the average color of its neighboring faces and vice versa). To compute the color of a vertex, we leverage the **training images**: we project the vertex onto the training images to get its rgb values, then average these values as its final color. Notice that the projected pixel coordinates are floating-point numbers, so we use *bilinear interpolation* to obtain the rgb value.
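
For illustration, here is a sketch of the projection and bilinear lookup for one training image. It assumes an OpenCV-style pinhole camera with world-to-camera matrix `w2c` (3, 4) and intrinsics `K` (3, 3), the image as a (3, H, W) tensor in [0, 1], and the mesh vertices `verts` (V, 3); all names are placeholders and the notebook's actual camera convention differs in details.

```
import torch
import torch.nn.functional as F

cam = verts @ w2c[:, :3].T + w2c[:, 3]            # (V, 3) vertices in camera coordinates
px = K[0, 0] * cam[:, 0] / cam[:, 2] + K[0, 2]    # projected pixel x (a float)
py = K[1, 1] * cam[:, 1] / cam[:, 2] + K[1, 2]    # projected pixel y (a float)

# bilinear interpolation of the image at the floating-point pixel coordinates
H, W = image.shape[1:]
grid = torch.stack([2*px/(W-1) - 1, 2*py/(H-1) - 1], -1)     # normalize to [-1, 1]
rgb = F.grid_sample(image[None], grid[None, :, None],        # bilinear by default
                    align_corners=True)[0, :, :, 0].T        # (V, 3) color per vertex
```

Averaging `rgb` over all training views then gives each vertex its final color.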
21 |
22 | This process might seem correct at first sight; however, this is what we'll get:
23 |
24 |
25 |
26 | by projecting the vertices onto this input image:
27 |
28 |
29 |
30 | You'll notice the face appears on the mantle. Why is that? It is because of **occlusion**.
31 |
32 | From the input image view, that spurious part of the mantle is actually occluded (blocked) by the face, so in reality we **shouldn't** assign color to it, but the above process assigns it the same color as the face because those vertices are projected onto the face (in pixel coordinates) as well!
33 |
34 | So the problem becomes: how do we correctly infer occlusion information, to know which vertices shouldn't be assigned colors? I tried two methods, of which the first turns out not to work well:
35 |
36 | 1. Use depth information
37 |
38 | The first intuitive way is to leverage the vertices' depths (which are obtained when projecting the vertices onto the image plane): if two (or more) vertices are projected onto the **same** pixel coordinates, then only the nearest vertex is assigned a color; the rest remain untouched. However, this method won't work, since no two vertices will be projected onto exactly the same location! As we mentioned earlier, the pixel coordinates are floating-point numbers, so it is impossible for them to be exactly equal. If we round the numbers to integers (which I tried as well), then this method works, but still with a lot of misclassified (occluded/non-occluded) vertices in my experiments.
39 |
40 | 2. Leverage NeRF model
41 |
42 | What I find to be an intelligent way to infer occlusion is to use the NeRF model itself. Recall that the NeRF model can estimate the opacity (or density) along a ray path (figure (c) below):
43 | 
44 | We can leverage that information to tell if a vertex is occluded or not. More concretely, we form rays originating from the camera origin and ending at the vertices, and compute the total opacity along these rays. If a vertex is not occluded, the opacity will be small; otherwise, the value will be large, meaning that something lies between the camera and the vertex.
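
A sketch of this test is shown below. All names are placeholders: `cam_origin` is the camera origin of the current view, `verts` are the mesh vertices, and `query_sigma` stands for a chunked density query like the one sketched in step 1; the threshold is illustrative.

```
import torch

N_samples = 64
rays_o = cam_origin.expand(len(verts), 3)            # (V, 3) camera origin for this view
rays_d = verts - rays_o                              # un-normalized, so t=1 lands on the vertex
t = torch.linspace(0, 1, N_samples)[None, :, None]   # (1, N_samples, 1) sample positions
xyz = rays_o[:, None] + t * rays_d[:, None]          # (V, N_samples, 3) points camera -> vertex

sigma = query_sigma(xyz.reshape(-1, 3)).reshape(len(verts), N_samples)
delta = rays_d.norm(dim=-1, keepdim=True) / N_samples      # step length along each ray
alpha = 1 - torch.exp(-delta * torch.relu(sigma))          # per-sample opacity
T = torch.cumprod(1 - alpha + 1e-10, dim=-1)               # accumulated transmittance
opacity = 1 - T[:, -1]                                     # total opacity from camera to vertex
occluded = opacity > 0.2                                   # illustrative threshold
```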
45 |
46 | After applying this method, this is what we get (by projecting the vertices onto the input view as above):
47 |
48 |
49 | The spurious face on the mantle disappears, and the colored pixels are almost exactly the ones we can observe from the image. By default we set all vertices to black, so a black vertex means it is occluded in this view, but it will be assigned a color when we switch to other views.
50 |
51 |
52 | # Finally...
53 |
54 | This is the final result:
55 |
56 |
57 |
58 | We can then export this `.ply` file to any other format and embed it in programs, like I did in Unity:
59 | (I use [this plugin](https://github.com/kwea123/Pcx) to import the `.ply` file into Unity, and made some modifications so that it can also read mesh triangles, not only the points.)
60 |
61 | 
62 |
63 | The meshes can be given a mesh collider so that they can interact with other objects. You can see [this video](https://youtu.be/I2M0xhnrBos) for a demo.
64 |
65 | ## Further reading
66 | The author suggested [another way](https://github.com/bmild/nerf/issues/44#issuecomment-622961303) to extract color; in my experiments it doesn't turn out well, but the idea is reasonable and interesting. You can also test this by setting `--use_vertex_normal`.
67 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .blender import BlenderDataset
2 | from .llff import LLFFDataset
3 |
4 | dataset_dict = {'blender': BlenderDataset,
5 | 'llff': LLFFDataset}
--------------------------------------------------------------------------------
/datasets/blender.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import json
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 | class BlenderDataset(Dataset):
12 | def __init__(self, root_dir, split='train', img_wh=(800, 800)):
13 | self.root_dir = root_dir
14 | self.split = split
15 | assert img_wh[0] == img_wh[1], 'image width must equal image height!'
16 | self.img_wh = img_wh
17 | self.define_transforms()
18 |
19 | self.read_meta()
20 | self.white_back = True
21 |
22 | def read_meta(self):
23 | with open(os.path.join(self.root_dir,
24 | f"transforms_{self.split}.json"), 'r') as f:
25 | self.meta = json.load(f)
26 |
27 | w, h = self.img_wh
28 | self.focal = 0.5*800/np.tan(0.5*self.meta['camera_angle_x']) # original focal length
29 | # when W=800
30 |
31 | self.focal *= self.img_wh[0]/800 # modify focal length to match size self.img_wh
32 |
33 | # bounds, common for all scenes
34 | self.near = 2.0
35 | self.far = 6.0
36 | self.bounds = np.array([self.near, self.far])
37 |
38 | # ray directions for all pixels, same for all images (same H, W, focal)
39 | self.directions = \
40 | get_ray_directions(h, w, self.focal) # (h, w, 3)
41 |
42 | if self.split == 'train': # create buffer of all rays and rgb data
43 | self.image_paths = []
44 | self.poses = []
45 | self.all_rays = []
46 | self.all_rgbs = []
47 | for frame in self.meta['frames']:
48 | pose = np.array(frame['transform_matrix'])[:3, :4]
49 | self.poses += [pose]
50 | c2w = torch.FloatTensor(pose)
51 |
52 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
53 | self.image_paths += [image_path]
54 | img = Image.open(image_path)
55 | img = img.resize(self.img_wh, Image.LANCZOS)
56 | img = self.transform(img) # (4, h, w)
57 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
58 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB
59 | self.all_rgbs += [img]
60 |
61 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
62 |
63 | self.all_rays += [torch.cat([rays_o, rays_d,
64 | self.near*torch.ones_like(rays_o[:, :1]),
65 | self.far*torch.ones_like(rays_o[:, :1])],
66 | 1)] # (h*w, 8)
67 |
68 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames'])*h*w, 8)
69 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames'])*h*w, 3)
70 |
71 | def define_transforms(self):
72 | self.transform = T.ToTensor()
73 |
74 | def __len__(self):
75 | if self.split == 'train':
76 | return len(self.all_rays)
77 | if self.split == 'val':
78 | return 8 # only validate 8 images (to support <=8 gpus)
79 | return len(self.meta['frames'])
80 |
81 | def __getitem__(self, idx):
82 | if self.split == 'train': # use data in the buffers
83 | sample = {'rays': self.all_rays[idx],
84 | 'rgbs': self.all_rgbs[idx]}
85 |
86 | else: # create data for each image separately
87 | frame = self.meta['frames'][idx]
88 | c2w = torch.FloatTensor(frame['transform_matrix'])[:3, :4]
89 |
90 | img = Image.open(os.path.join(self.root_dir, f"{frame['file_path']}.png"))
91 | img = img.resize(self.img_wh, Image.LANCZOS)
92 | img = self.transform(img) # (4, H, W)
93 | valid_mask = (img[-1]>0).flatten() # (H*W) valid color area
94 | img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA
95 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB
96 |
97 | rays_o, rays_d = get_rays(self.directions, c2w)
98 |
99 | rays = torch.cat([rays_o, rays_d,
100 | self.near*torch.ones_like(rays_o[:, :1]),
101 | self.far*torch.ones_like(rays_o[:, :1])],
102 | 1) # (H*W, 8)
103 |
104 | sample = {'rays': rays,
105 | 'rgbs': img,
106 | 'c2w': c2w,
107 | 'valid_mask': valid_mask}
108 |
109 | return sample
--------------------------------------------------------------------------------
/datasets/depth_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import sys
4 |
5 | def read_pfm(filename):
6 | file = open(filename, 'rb')
7 | color = None
8 | width = None
9 | height = None
10 | scale = None
11 | endian = None
12 |
13 | header = file.readline().decode('utf-8').rstrip()
14 | if header == 'PF':
15 | color = True
16 | elif header == 'Pf':
17 | color = False
18 | else:
19 | raise Exception('Not a PFM file.')
20 |
21 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
22 | if dim_match:
23 | width, height = map(int, dim_match.groups())
24 | else:
25 | raise Exception('Malformed PFM header.')
26 |
27 | scale = float(file.readline().rstrip())
28 | if scale < 0: # little-endian
29 | endian = '<'
30 | scale = -scale
31 | else:
32 | endian = '>' # big-endian
33 |
34 | data = np.fromfile(file, endian + 'f')
35 | shape = (height, width, 3) if color else (height, width)
36 |
37 | data = np.reshape(data, shape)
38 | data = np.flipud(data)
39 | file.close()
40 | return data, scale
41 |
42 |
43 | def save_pfm(filename, image, scale=1):
44 | file = open(filename, "wb")
45 | color = None
46 |
47 | image = np.flipud(image)
48 |
49 | if image.dtype.name != 'float32':
50 | raise Exception('Image dtype must be float32.')
51 |
52 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
53 | color = True
54 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
55 | color = False
56 | else:
57 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
58 |
59 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
60 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))
61 |
62 | endian = image.dtype.byteorder
63 |
64 | if endian == '<' or endian == '=' and sys.byteorder == 'little':
65 | scale = -scale
66 |
67 | file.write(('%f\n' % scale).encode('utf-8'))
68 |
69 | image.tofile(file)
70 | file.close()
--------------------------------------------------------------------------------
/datasets/llff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import glob
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | def normalize(v):
13 | """Normalize a vector."""
14 | return v/np.linalg.norm(v)
15 |
16 |
17 | def average_poses(poses):
18 | """
19 | Calculate the average pose, which is then used to center all poses
20 | using @center_poses. Its computation is as follows:
21 | 1. Compute the center: the average of pose centers.
22 | 2. Compute the z axis: the normalized average z axis.
23 | 3. Compute axis y': the average y axis.
24 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
25 | 5. Compute the y axis: z cross product x.
26 |
27 | Note that at step 3, we cannot directly use y' as y axis since it's
28 | not necessarily orthogonal to z axis. We need to pass from x to y.
29 |
30 | Inputs:
31 | poses: (N_images, 3, 4)
32 |
33 | Outputs:
34 | pose_avg: (3, 4) the average pose
35 | """
36 | # 1. Compute the center
37 | center = poses[..., 3].mean(0) # (3)
38 |
39 | # 2. Compute the z axis
40 | z = normalize(poses[..., 2].mean(0)) # (3)
41 |
42 | # 3. Compute axis y' (no need to normalize as it's not the final output)
43 | y_ = poses[..., 1].mean(0) # (3)
44 |
45 | # 4. Compute the x axis
46 | x = normalize(np.cross(y_, z)) # (3)
47 |
48 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
49 | y = np.cross(z, x) # (3)
50 |
51 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
52 |
53 | return pose_avg
54 |
55 |
56 | def center_poses(poses):
57 | """
58 | Center the poses so that we can use NDC.
59 | See https://github.com/bmild/nerf/issues/34
60 |
61 | Inputs:
62 | poses: (N_images, 3, 4)
63 |
64 | Outputs:
65 | poses_centered: (N_images, 3, 4) the centered poses
66 | pose_avg: (3, 4) the average pose
67 | """
68 |
69 | pose_avg = average_poses(poses) # (3, 4)
70 | pose_avg_homo = np.eye(4)
71 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
72 | # by simply adding 0, 0, 0, 1 as the last row
73 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
74 | poses_homo = \
75 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
76 |
77 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
78 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
79 |
80 | return poses_centered, np.linalg.inv(pose_avg_homo)
81 |
82 |
83 | def create_spiral_poses(radii, focus_depth, n_poses=120):
84 | """
85 | Computes poses that follow a spiral path for rendering purpose.
86 | See https://github.com/Fyusion/LLFF/issues/19
87 | In particular, the path looks like:
88 | https://tinyurl.com/ybgtfns3
89 |
90 | Inputs:
91 | radii: (3) radii of the spiral for each axis
92 | focus_depth: float, the depth that the spiral poses look at
93 | n_poses: int, number of poses to create along the path
94 |
95 | Outputs:
96 | poses_spiral: (n_poses, 3, 4) the poses in the spiral path
97 | """
98 |
99 | poses_spiral = []
100 | for t in np.linspace(0, 4*np.pi, n_poses+1)[:-1]: # rotate 4pi (2 rounds)
101 | # the parametric function of the spiral (see the interactive web)
102 | center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5*t)]) * radii
103 |
104 | # the viewing z axis is the vector pointing from the @focus_depth plane
105 | # to @center
106 | z = normalize(center - np.array([0, 0, -focus_depth]))
107 |
108 | # compute other axes as in @average_poses
109 | y_ = np.array([0, 1, 0]) # (3)
110 | x = normalize(np.cross(y_, z)) # (3)
111 | y = np.cross(z, x) # (3)
112 |
113 | poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4)
114 |
115 | return np.stack(poses_spiral, 0) # (n_poses, 3, 4)
116 |
117 |
118 | def create_spheric_poses(radius, n_poses=120):
119 | """
120 | Create circular poses around z axis.
121 | Inputs:
122 | radius: the (negative) height and the radius of the circle.
123 |
124 | Outputs:
125 | spheric_poses: (n_poses, 3, 4) the poses in the circular path
126 | """
127 | def spheric_pose(theta, phi, radius):
128 | trans_t = lambda t : np.array([
129 | [1,0,0,0],
130 | [0,1,0,-0.9*t],
131 | [0,0,1,t],
132 | [0,0,0,1],
133 | ])
134 |
135 | rot_phi = lambda phi : np.array([
136 | [1,0,0,0],
137 | [0,np.cos(phi),-np.sin(phi),0],
138 | [0,np.sin(phi), np.cos(phi),0],
139 | [0,0,0,1],
140 | ])
141 |
142 | rot_theta = lambda th : np.array([
143 | [np.cos(th),0,-np.sin(th),0],
144 | [0,1,0,0],
145 | [np.sin(th),0, np.cos(th),0],
146 | [0,0,0,1],
147 | ])
148 |
149 | c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius)
150 | c2w = np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]) @ c2w
151 | return c2w[:3]
152 |
153 | spheric_poses = []
154 | for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]:
155 | spheric_poses += [spheric_pose(th, -np.pi/5, radius)] # 36 degree view downwards
156 | return np.stack(spheric_poses, 0)
157 |
158 |
159 | class LLFFDataset(Dataset):
160 | def __init__(self, root_dir, split='train', img_wh=(504, 378), spheric_poses=False, val_num=1):
161 | """
162 | spheric_poses: whether the images are taken in a spheric inward-facing manner
163 | default: False (forward-facing)
164 | val_num: number of val images (used for multigpu training, validate same image for all gpus)
165 | """
166 | self.root_dir = root_dir
167 | self.split = split
168 | self.img_wh = img_wh
169 | self.spheric_poses = spheric_poses
170 | self.val_num = max(1, val_num) # at least 1
171 | self.define_transforms()
172 |
173 | self.read_meta()
174 | self.white_back = False
175 |
176 | def read_meta(self):
177 | poses_bounds = np.load(os.path.join(self.root_dir,
178 | 'poses_bounds.npy')) # (N_images, 17)
179 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*')))
180 | # load full resolution image then resize
181 | if self.split in ['train', 'val']:
182 | assert len(poses_bounds) == len(self.image_paths), \
183 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!'
184 |
185 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
186 | self.bounds = poses_bounds[:, -2:] # (N_images, 2)
187 |
188 | # Step 1: rescale focal length according to training resolution
189 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
190 | assert H*self.img_wh[0] == W*self.img_wh[1], \
191 | f'You must set @img_wh to have the same aspect ratio as ({W}, {H}) !'
192 |
193 | self.focal *= self.img_wh[0]/W
194 |
195 | # Step 2: correct poses
196 | # Original poses has rotation in form "down right back", change to "right up back"
197 | # See https://github.com/bmild/nerf/issues/34
198 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
199 | # (N_images, 3, 4) exclude H, W, focal
200 | self.poses, self.pose_avg = center_poses(poses)
201 | distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
202 | val_idx = np.argmin(distances_from_center) # choose val image as the closest to
203 | # center image
204 |
205 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
206 | # See https://github.com/bmild/nerf/issues/34
207 | near_original = self.bounds.min()
208 | scale_factor = near_original*0.75 # 0.75 is the default parameter
209 | # the nearest depth is at 1/0.75=1.33
210 | self.bounds /= scale_factor
211 | self.poses[..., 3] /= scale_factor
212 |
213 | # ray directions for all pixels, same for all images (same H, W, focal)
214 | self.directions = \
215 | get_ray_directions(self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3)
216 |
217 | if self.split == 'train': # create buffer of all rays and rgb data
218 | # use N_images-1 images to train; the image closest to the center is held out for val
219 | self.all_rays = []
220 | self.all_rgbs = []
221 | for i, image_path in enumerate(self.image_paths):
222 | if i == val_idx: # exclude the val image
223 | continue
224 | c2w = torch.FloatTensor(self.poses[i])
225 |
226 | img = Image.open(image_path).convert('RGB')
227 | assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \
228 | f'''{image_path} has different aspect ratio than img_wh,
229 | please check your data!'''
230 | img = img.resize(self.img_wh, Image.LANCZOS)
231 | img = self.transform(img) # (3, h, w)
232 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
233 | self.all_rgbs += [img]
234 |
235 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
236 | if not self.spheric_poses:
237 | near, far = 0, 1
238 | rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0],
239 | self.focal, 1.0, rays_o, rays_d)
240 | # near plane is always at 1.0
241 | # near and far in NDC are always 0 and 1
242 | # See https://github.com/bmild/nerf/issues/34
243 | else:
244 | near = self.bounds.min()
245 | far = min(8 * near, self.bounds.max()) # focus on central object only
246 |
247 | self.all_rays += [torch.cat([rays_o, rays_d,
248 | near*torch.ones_like(rays_o[:, :1]),
249 | far*torch.ones_like(rays_o[:, :1])],
250 | 1)] # (h*w, 8)
251 |
252 | self.all_rays = torch.cat(self.all_rays, 0) # ((N_images-1)*h*w, 8)
253 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # ((N_images-1)*h*w, 3)
254 |
255 | elif self.split == 'val':
256 | print('val image is', self.image_paths[val_idx])
257 | self.c2w_val = self.poses[val_idx]
258 | self.image_path_val = self.image_paths[val_idx]
259 |
260 | else: # for testing, create a parametric rendering path
261 | if self.split.endswith('train'): # test on training set
262 | self.poses_test = self.poses
263 | elif not self.spheric_poses:
264 | focus_depth = 3.5 # hardcoded, this is numerically close to the formula
265 | # given in the original repo. Mathematically if near=1
266 | # and far=infinity, then this number will converge to 4
267 | radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0)
268 | self.poses_test = create_spiral_poses(radii, focus_depth)
269 | else:
270 | radius = 1.1 * self.bounds.min()
271 | self.poses_test = create_spheric_poses(radius)
272 |
273 | def define_transforms(self):
274 | self.transform = T.ToTensor()
275 |
276 | def __len__(self):
277 | if self.split == 'train':
278 | return len(self.all_rays)
279 | if self.split == 'val':
280 | return self.val_num
281 | return len(self.poses_test)
282 |
283 | def __getitem__(self, idx):
284 | if self.split == 'train': # use data in the buffers
285 | sample = {'rays': self.all_rays[idx],
286 | 'rgbs': self.all_rgbs[idx]}
287 |
288 | else:
289 | if self.split == 'val':
290 | c2w = torch.FloatTensor(self.c2w_val)
291 | else:
292 | c2w = torch.FloatTensor(self.poses_test[idx])
293 |
294 | rays_o, rays_d = get_rays(self.directions, c2w)
295 | if not self.spheric_poses:
296 | near, far = 0, 1
297 | rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0],
298 | self.focal, 1.0, rays_o, rays_d)
299 | else:
300 | near = self.bounds.min()
301 | far = min(8 * near, self.bounds.max())
302 |
303 | rays = torch.cat([rays_o, rays_d,
304 | near*torch.ones_like(rays_o[:, :1]),
305 | far*torch.ones_like(rays_o[:, :1])],
306 | 1) # (h*w, 8)
307 |
308 | sample = {'rays': rays,
309 | 'c2w': c2w}
310 |
311 | if self.split == 'val':
312 | img = Image.open(self.image_path_val).convert('RGB')
313 | img = img.resize(self.img_wh, Image.LANCZOS)
314 | img = self.transform(img) # (3, h, w)
315 | img = img.view(3, -1).permute(1, 0) # (h*w, 3)
316 | sample['rgbs'] = img
317 |
318 | return sample
319 |
--------------------------------------------------------------------------------
/datasets/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia import create_meshgrid
3 |
4 |
5 | def get_ray_directions(H, W, focal):
6 | """
7 | Get ray directions for all pixels in camera coordinate.
8 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
9 | ray-tracing-generating-camera-rays/standard-coordinate-systems
10 |
11 | Inputs:
12 | H, W, focal: image height, width and focal length
13 |
14 | Outputs:
15 | directions: (H, W, 3), the direction of the rays in camera coordinate
16 | """
17 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
18 | i, j = grid.unbind(-1)
19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate
20 | # see https://github.com/bmild/nerf/issues/24
21 | directions = \
22 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3)
23 |
24 | return directions
25 |
26 |
27 | def get_rays(directions, c2w):
28 | """
29 | Get ray origin and normalized directions in world coordinate for all pixels in one image.
30 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
31 | ray-tracing-generating-camera-rays/standard-coordinate-systems
32 |
33 | Inputs:
34 | directions: (H, W, 3) precomputed ray directions in camera coordinate
35 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
36 |
37 | Outputs:
38 | rays_o: (H*W, 3), the origin of the rays in world coordinate
39 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
40 | """
41 | # Rotate ray directions from camera coordinate to the world coordinate
42 | rays_d = directions @ c2w[:, :3].T # (H, W, 3)
43 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
44 | # The origin of all rays is the camera origin in world coordinate
45 | rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)
46 |
47 | rays_d = rays_d.view(-1, 3)
48 | rays_o = rays_o.view(-1, 3)
49 |
50 | return rays_o, rays_d
51 |
52 |
53 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d):
54 | """
55 | Transform rays from world coordinate to NDC.
56 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis.
57 | For detailed derivation, please see:
58 | http://www.songho.ca/opengl/gl_projectionmatrix.html
59 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf
60 |
61 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth).
62 | See https://github.com/bmild/nerf/issues/18
63 |
64 | Inputs:
65 | H, W, focal: image height, width and focal length
66 | near: (N_rays) or float, the depths of the near plane
67 | rays_o: (N_rays, 3), the origin of the rays in world coordinate
68 | rays_d: (N_rays, 3), the direction of the rays in world coordinate
69 |
70 | Outputs:
71 | rays_o: (N_rays, 3), the origin of the rays in NDC
72 | rays_d: (N_rays, 3), the direction of the rays in NDC
73 | """
74 | # Shift ray origins to near plane
75 | t = -(near + rays_o[...,2]) / rays_d[...,2]
76 | rays_o = rays_o + t[...,None] * rays_d
77 |
78 | # Store some intermediate homogeneous results
79 | ox_oz = rays_o[...,0] / rays_o[...,2]
80 | oy_oz = rays_o[...,1] / rays_o[...,2]
81 |
82 | # Projection
83 | o0 = -1./(W/(2.*focal)) * ox_oz
84 | o1 = -1./(H/(2.*focal)) * oy_oz
85 | o2 = 1. + 2. * near / rays_o[...,2]
86 |
87 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz)
88 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz)
89 | d2 = 1 - o2
90 |
91 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3)
92 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3)
93 |
94 | return rays_o, rays_d
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _site
2 | .sass-cache
3 | .jekyll-cache
4 | .jekyll-metadata
5 | vendor
6 |
--------------------------------------------------------------------------------
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source 'https://rubygems.org'
2 | gem 'github-pages', group: :jekyll_plugins
3 |
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | activesupport (>= 6.0.3.1)
5 | concurrent-ruby (~> 1.0, >= 1.0.2)
6 | i18n (>= 0.7, < 2)
7 | minitest (~> 5.1)
8 | tzinfo (~> 1.1)
9 | zeitwerk (~> 2.2, >= 2.2.2)
10 | addressable (2.7.0)
11 | public_suffix (>= 2.0.2, < 5.0)
12 | coffee-script (2.4.1)
13 | coffee-script-source
14 | execjs
15 | coffee-script-source (1.11.1)
16 | colorator (1.1.0)
17 | commonmarker (0.17.13)
18 | ruby-enum (~> 0.5)
19 | concurrent-ruby (1.1.6)
20 | dnsruby (1.61.3)
21 | addressable (~> 2.5)
22 | em-websocket (0.5.1)
23 | eventmachine (>= 0.12.9)
24 | http_parser.rb (~> 0.6.0)
25 | ethon (0.12.0)
26 | ffi (>= 1.3.0)
27 | eventmachine (1.2.7)
28 | execjs (2.7.0)
29 | faraday (1.0.1)
30 | multipart-post (>= 1.2, < 3)
31 | ffi (1.12.2)
32 | forwardable-extended (2.6.0)
33 | gemoji (3.0.1)
34 | github-pages (204)
35 | github-pages-health-check (= 1.16.1)
36 | jekyll (= 3.8.5)
37 | jekyll-avatar (= 0.7.0)
38 | jekyll-coffeescript (= 1.1.1)
39 | jekyll-commonmark-ghpages (= 0.1.6)
40 | jekyll-default-layout (= 0.1.4)
41 | jekyll-feed (= 0.13.0)
42 | jekyll-gist (= 1.5.0)
43 | jekyll-github-metadata (= 2.13.0)
44 | jekyll-mentions (= 1.5.1)
45 | jekyll-optional-front-matter (= 0.3.2)
46 | jekyll-paginate (= 1.1.0)
47 | jekyll-readme-index (= 0.3.0)
48 | jekyll-redirect-from (= 0.15.0)
49 | jekyll-relative-links (= 0.6.1)
50 | jekyll-remote-theme (= 0.4.1)
51 | jekyll-sass-converter (= 1.5.2)
52 | jekyll-seo-tag (= 2.6.1)
53 | jekyll-sitemap (= 1.4.0)
54 | jekyll-swiss (= 1.0.0)
55 | jekyll-theme-architect (= 0.1.1)
56 | jekyll-theme-cayman (= 0.1.1)
57 | jekyll-theme-dinky (= 0.1.1)
58 | jekyll-theme-hacker (= 0.1.1)
59 | jekyll-theme-leap-day (= 0.1.1)
60 | jekyll-theme-merlot (= 0.1.1)
61 | jekyll-theme-midnight (= 0.1.1)
62 | jekyll-theme-minimal (= 0.1.1)
63 | jekyll-theme-modernist (= 0.1.1)
64 | jekyll-theme-primer (= 0.5.4)
65 | jekyll-theme-slate (= 0.1.1)
66 | jekyll-theme-tactile (= 0.1.1)
67 | jekyll-theme-time-machine (= 0.1.1)
68 | jekyll-titles-from-headings (= 0.5.3)
69 | jemoji (= 0.11.1)
70 | kramdown (>= 2.3.0)
71 | liquid (= 4.0.3)
72 | mercenary (~> 0.3)
73 | minima (= 2.5.1)
74 | nokogiri (>= 1.11.0.rc4, < 2.0)
75 | rouge (= 3.13.0)
76 | terminal-table (~> 1.4)
77 | github-pages-health-check (1.16.1)
78 | addressable (~> 2.3)
79 | dnsruby (~> 1.60)
80 | octokit (~> 4.0)
81 | public_suffix (~> 3.0)
82 | typhoeus (~> 1.3)
83 | html-pipeline (2.12.3)
84 | activesupport (>= 2)
85 | nokogiri (>= 1.4)
86 | http_parser.rb (0.6.0)
87 | i18n (0.9.5)
88 | concurrent-ruby (~> 1.0)
89 | jekyll (3.8.5)
90 | addressable (~> 2.4)
91 | colorator (~> 1.0)
92 | em-websocket (~> 0.5)
93 | i18n (~> 0.7)
94 | jekyll-sass-converter (~> 1.0)
95 | jekyll-watch (~> 2.0)
96 | kramdown (~> 1.14)
97 | liquid (~> 4.0)
98 | mercenary (~> 0.3.3)
99 | pathutil (~> 0.9)
100 | rouge (>= 1.7, < 4)
101 | safe_yaml (~> 1.0)
102 | jekyll-avatar (0.7.0)
103 | jekyll (>= 3.0, < 5.0)
104 | jekyll-coffeescript (1.1.1)
105 | coffee-script (~> 2.2)
106 | coffee-script-source (~> 1.11.1)
107 | jekyll-commonmark (1.3.1)
108 | commonmarker (~> 0.14)
109 | jekyll (>= 3.7, < 5.0)
110 | jekyll-commonmark-ghpages (0.1.6)
111 | commonmarker (~> 0.17.6)
112 | jekyll-commonmark (~> 1.2)
113 | rouge (>= 2.0, < 4.0)
114 | jekyll-default-layout (0.1.4)
115 | jekyll (~> 3.0)
116 | jekyll-feed (0.13.0)
117 | jekyll (>= 3.7, < 5.0)
118 | jekyll-gist (1.5.0)
119 | octokit (~> 4.2)
120 | jekyll-github-metadata (2.13.0)
121 | jekyll (>= 3.4, < 5.0)
122 | octokit (~> 4.0, != 4.4.0)
123 | jekyll-mentions (1.5.1)
124 | html-pipeline (~> 2.3)
125 | jekyll (>= 3.7, < 5.0)
126 | jekyll-optional-front-matter (0.3.2)
127 | jekyll (>= 3.0, < 5.0)
128 | jekyll-paginate (1.1.0)
129 | jekyll-readme-index (0.3.0)
130 | jekyll (>= 3.0, < 5.0)
131 | jekyll-redirect-from (0.15.0)
132 | jekyll (>= 3.3, < 5.0)
133 | jekyll-relative-links (0.6.1)
134 | jekyll (>= 3.3, < 5.0)
135 | jekyll-remote-theme (0.4.1)
136 | addressable (~> 2.0)
137 | jekyll (>= 3.5, < 5.0)
138 | rubyzip (>= 1.3.0)
139 | jekyll-sass-converter (1.5.2)
140 | sass (~> 3.4)
141 | jekyll-seo-tag (2.6.1)
142 | jekyll (>= 3.3, < 5.0)
143 | jekyll-sitemap (1.4.0)
144 | jekyll (>= 3.7, < 5.0)
145 | jekyll-swiss (1.0.0)
146 | jekyll-theme-architect (0.1.1)
147 | jekyll (~> 3.5)
148 | jekyll-seo-tag (~> 2.0)
149 | jekyll-theme-cayman (0.1.1)
150 | jekyll (~> 3.5)
151 | jekyll-seo-tag (~> 2.0)
152 | jekyll-theme-dinky (0.1.1)
153 | jekyll (~> 3.5)
154 | jekyll-seo-tag (~> 2.0)
155 | jekyll-theme-hacker (0.1.1)
156 | jekyll (~> 3.5)
157 | jekyll-seo-tag (~> 2.0)
158 | jekyll-theme-leap-day (0.1.1)
159 | jekyll (~> 3.5)
160 | jekyll-seo-tag (~> 2.0)
161 | jekyll-theme-merlot (0.1.1)
162 | jekyll (~> 3.5)
163 | jekyll-seo-tag (~> 2.0)
164 | jekyll-theme-midnight (0.1.1)
165 | jekyll (~> 3.5)
166 | jekyll-seo-tag (~> 2.0)
167 | jekyll-theme-minimal (0.1.1)
168 | jekyll (~> 3.5)
169 | jekyll-seo-tag (~> 2.0)
170 | jekyll-theme-modernist (0.1.1)
171 | jekyll (~> 3.5)
172 | jekyll-seo-tag (~> 2.0)
173 | jekyll-theme-primer (0.5.4)
174 | jekyll (> 3.5, < 5.0)
175 | jekyll-github-metadata (~> 2.9)
176 | jekyll-seo-tag (~> 2.0)
177 | jekyll-theme-slate (0.1.1)
178 | jekyll (~> 3.5)
179 | jekyll-seo-tag (~> 2.0)
180 | jekyll-theme-tactile (0.1.1)
181 | jekyll (~> 3.5)
182 | jekyll-seo-tag (~> 2.0)
183 | jekyll-theme-time-machine (0.1.1)
184 | jekyll (~> 3.5)
185 | jekyll-seo-tag (~> 2.0)
186 | jekyll-titles-from-headings (0.5.3)
187 | jekyll (>= 3.3, < 5.0)
188 | jekyll-watch (2.2.1)
189 | listen (~> 3.0)
190 | jemoji (0.11.1)
191 | gemoji (~> 3.0)
192 | html-pipeline (~> 2.2)
193 | jekyll (>= 3.0, < 5.0)
194 | kramdown (>= 2.3.0)
195 | liquid (4.0.3)
196 | listen (3.2.1)
197 | rb-fsevent (~> 0.10, >= 0.10.3)
198 | rb-inotify (~> 0.9, >= 0.9.10)
199 | mercenary (0.3.6)
200 | mini_portile2 (2.4.0)
201 | minima (2.5.1)
202 | jekyll (>= 3.5, < 5.0)
203 | jekyll-feed (~> 0.9)
204 | jekyll-seo-tag (~> 2.1)
205 | minitest (5.14.1)
206 | multipart-post (2.1.1)
207 | nokogiri (1.10.9)
208 | mini_portile2 (~> 2.4.0)
209 | octokit (4.18.0)
210 | faraday (>= 0.9)
211 | sawyer (~> 0.8.0, >= 0.5.3)
212 | pathutil (0.16.2)
213 | forwardable-extended (~> 2.6)
214 | public_suffix (3.1.1)
215 | rb-fsevent (0.10.4)
216 | rb-inotify (0.10.1)
217 | ffi (~> 1.0)
218 | rouge (3.13.0)
219 | ruby-enum (0.8.0)
220 | i18n
221 | rubyzip (2.3.0)
222 | safe_yaml (1.0.5)
223 | sass (3.7.4)
224 | sass-listen (~> 4.0.0)
225 | sass-listen (4.0.0)
226 | rb-fsevent (~> 0.9, >= 0.9.4)
227 | rb-inotify (~> 0.9, >= 0.9.7)
228 | sawyer (0.8.2)
229 | addressable (>= 2.3.5)
230 | faraday (> 0.8, < 2.0)
231 | terminal-table (1.8.0)
232 | unicode-display_width (~> 1.1, >= 1.1.1)
233 | thread_safe (0.3.6)
234 | typhoeus (1.4.0)
235 | ethon (>= 0.9.0)
236 | tzinfo (1.2.7)
237 | thread_safe (~> 0.1)
238 | unicode-display_width (1.7.0)
239 | zeitwerk (2.3.0)
240 |
241 | PLATFORMS
242 | ruby
243 |
244 | DEPENDENCIES
245 | github-pages
246 |
247 | BUNDLED WITH
248 | 2.1.4
249 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
2 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | # Author of this implementation
11 | If you find my work helpful, please consider starring this project!
12 |
13 |
14 |
15 | Quei-An Chen ([kwea123](https://github.com/kwea123)). Original author and photo credits: Ben Mildenhall ([bmild](https://github.com/bmild))
16 |
17 | # What can NeRF do?
18 |
19 |
20 |
21 |
22 | # 360 degree view synthesis
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
53 |
54 |
55 |
56 | # Colored 3D mesh reconstruction (photogrammetry)
57 | We can generate a real colored mesh that allows the object to interact with other physical objects.
58 |
59 |
60 |
61 |
62 | # Real time volume rendering in Unity
63 | [Volume rendering](https://en.wikipedia.org/wiki/Volume_rendering) is a technique that doesn't require a "real object". The model you see here is composed of rays, so we can cut off parts to see internal structures and apply deformation effects in real time.
64 |
65 |
66 |
67 |
68 | # Mixed reality in Unity (doesn't work on Firefox, please use Chrome)
69 | Accurate depth allows us to embed virtual objects inside real scenes with the correct z-order.
70 |
71 |
72 |
73 |
74 | # Tutorial
75 |
76 | I also have tutorials on how to achieve the above results using Google Colab:
77 |
78 |
79 |
80 |
81 |
82 | # Call for contribution
83 | If you are an expert in Unity and know how to make more visually appealing effects for the models shown above, feel free to contact me! I can share my code and data with you, and put your name on my GitHub page!
84 |
--------------------------------------------------------------------------------
/docs/style.css:
--------------------------------------------------------------------------------
1 | #main_content {
2 | max-width: 960px;
3 | }
4 |
5 | .slick-slide {
6 | margin: 0 10px;
7 | }
8 |
9 | .slick-prev{
10 | left: -40px;
11 | }
12 |
13 | .slick-prev:before {
14 | font-size: 40px;
15 | color: blueviolet;
16 | }
17 | .slick-next:before {
18 | font-size: 40px;
19 | color: blueviolet;
20 | }
21 |
22 | .slick-dots li button:before{
23 | font-size: 20px;
24 | line-height: 20px;
25 | }
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from collections import defaultdict
5 | from tqdm import tqdm
6 | import imageio
7 | from argparse import ArgumentParser
8 |
9 | from models.rendering import render_rays
10 | from models.nerf import *
11 |
12 | from utils import load_ckpt
13 | import metrics
14 |
15 | from datasets import dataset_dict
16 | from datasets.depth_utils import *
17 |
18 | torch.backends.cudnn.benchmark = True
19 |
20 | def get_opts():
21 | parser = ArgumentParser()
22 | parser.add_argument('--root_dir', type=str,
23 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego',
24 | help='root directory of dataset')
25 | parser.add_argument('--dataset_name', type=str, default='blender',
26 | choices=['blender', 'llff'],
27 | help='which dataset to validate')
28 | parser.add_argument('--scene_name', type=str, default='test',
29 | help='scene name, used as output folder name')
30 | parser.add_argument('--split', type=str, default='test',
31 | help='test or test_train')
32 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800],
33 | help='resolution (img_w, img_h) of the image')
34 | parser.add_argument('--spheric_poses', default=False, action="store_true",
35 | help='whether images are taken in spheric poses (for llff)')
36 |
37 | parser.add_argument('--N_samples', type=int, default=64,
38 | help='number of coarse samples')
39 | parser.add_argument('--N_importance', type=int, default=128,
40 | help='number of additional fine samples')
41 | parser.add_argument('--use_disp', default=False, action="store_true",
42 | help='use disparity depth sampling')
43 | parser.add_argument('--chunk', type=int, default=32*1024*4,
44 | help='chunk size to split the input to avoid OOM')
45 |
46 | parser.add_argument('--ckpt_path', type=str, required=True,
47 | help='pretrained checkpoint path to load')
48 |
49 | parser.add_argument('--save_depth', default=False, action="store_true",
50 | help='whether to save depth prediction')
51 | parser.add_argument('--depth_format', type=str, default='pfm',
52 | choices=['pfm', 'bytes'],
53 | help='which format to save')
54 |
55 | return parser.parse_args()
56 |
57 |
58 | @torch.no_grad()
59 | def batched_inference(models, embeddings,
60 | rays, N_samples, N_importance, use_disp,
61 | chunk,
62 | white_back):
63 | """Do batched inference on rays using chunk."""
64 | B = rays.shape[0]
65 |     chunk = 1024*32 # render-time chunk size; this overrides the --chunk argument
66 | results = defaultdict(list)
67 | for i in range(0, B, chunk):
68 | rendered_ray_chunks = \
69 | render_rays(models,
70 | embeddings,
71 | rays[i:i+chunk],
72 | N_samples,
73 | use_disp,
74 | 0,
75 | 0,
76 | N_importance,
77 | chunk,
78 |                             white_back,
79 | test_time=True)
80 |
81 | for k, v in rendered_ray_chunks.items():
82 | results[k] += [v]
83 |
84 | for k, v in results.items():
85 | results[k] = torch.cat(v, 0)
86 | return results
87 |
88 |
89 | if __name__ == "__main__":
90 | args = get_opts()
91 | w, h = args.img_wh
92 |
93 | kwargs = {'root_dir': args.root_dir,
94 | 'split': args.split,
95 | 'img_wh': tuple(args.img_wh)}
96 | if args.dataset_name == 'llff':
97 | kwargs['spheric_poses'] = args.spheric_poses
98 | dataset = dataset_dict[args.dataset_name](**kwargs)
99 |
100 | embedding_xyz = Embedding(3, 10)
101 | embedding_dir = Embedding(3, 4)
102 | nerf_coarse = NeRF()
103 | nerf_fine = NeRF()
104 | load_ckpt(nerf_coarse, args.ckpt_path, model_name='nerf_coarse')
105 | load_ckpt(nerf_fine, args.ckpt_path, model_name='nerf_fine')
106 | nerf_coarse.cuda().eval()
107 | nerf_fine.cuda().eval()
108 |
109 | models = [nerf_coarse, nerf_fine]
110 | embeddings = [embedding_xyz, embedding_dir]
111 |
112 | imgs = []
113 | psnrs = []
114 | dir_name = f'results/{args.dataset_name}/{args.scene_name}'
115 | os.makedirs(dir_name, exist_ok=True)
116 |
117 | for i in tqdm(range(len(dataset))):
118 | sample = dataset[i]
119 | rays = sample['rays'].cuda()
120 | results = batched_inference(models, embeddings, rays,
121 | args.N_samples, args.N_importance, args.use_disp,
122 | args.chunk,
123 | dataset.white_back)
124 |
125 | img_pred = results['rgb_fine'].view(h, w, 3).cpu().numpy()
126 |
127 | if args.save_depth:
128 | depth_pred = results['depth_fine'].view(h, w).cpu().numpy()
129 | depth_pred = np.nan_to_num(depth_pred)
130 | if args.depth_format == 'pfm':
131 | save_pfm(os.path.join(dir_name, f'depth_{i:03d}.pfm'), depth_pred)
132 | else:
133 |             with open(os.path.join(dir_name, f'depth_{i:03d}'), 'wb') as f:
134 | f.write(depth_pred.tobytes())
135 |
136 | img_pred_ = (img_pred*255).astype(np.uint8)
137 | imgs += [img_pred_]
138 | imageio.imwrite(os.path.join(dir_name, f'{i:03d}.png'), img_pred_)
139 |
140 | if 'rgbs' in sample:
141 | rgbs = sample['rgbs']
142 | img_gt = rgbs.view(h, w, 3)
143 | psnrs += [metrics.psnr(img_gt, img_pred).item()]
144 |
145 | imageio.mimsave(os.path.join(dir_name, f'{args.scene_name}.gif'), imgs, fps=30)
146 |
147 | if psnrs:
148 | mean_psnr = np.mean(psnrs)
149 | print(f'Mean PSNR : {mean_psnr:.2f}')
--------------------------------------------------------------------------------
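A note on the two depth formats above: the 'pfm' branch goes through save_pfm, while the 'bytes' branch writes the raw float32 buffer of the (h, w) depth map. A minimal sketch of reading such a raw file back; the path and resolution below are illustrative and must match your own output folder and --img_wh:

    import numpy as np

    # illustrative values; use the resolution passed via --img_wh and your own results folder
    w, h = 800, 800
    with open('results/blender/test/depth_000', 'rb') as f:
        depth = np.frombuffer(f.read(), dtype=np.float32).reshape(h, w)
    print(depth.min(), depth.max())

--------------------------------------------------------------------------------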
/extract_color_mesh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | import cv2
5 | from PIL import Image
6 | from collections import defaultdict
7 | from tqdm import tqdm
8 | import mcubes
9 | import open3d as o3d
10 | from plyfile import PlyData, PlyElement
11 | from argparse import ArgumentParser
12 |
13 | from models.rendering import *
14 | from models.nerf import *
15 |
16 | from utils import load_ckpt
17 |
18 | from datasets import dataset_dict
19 |
20 | torch.backends.cudnn.benchmark = True
21 |
22 | def get_opts():
23 | parser = ArgumentParser()
24 | parser.add_argument('--root_dir', type=str,
25 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego',
26 | help='root directory of dataset')
27 | parser.add_argument('--dataset_name', type=str, default='blender',
28 | choices=['blender', 'llff'],
29 | help='which dataset to validate')
30 | parser.add_argument('--scene_name', type=str, default='test',
31 | help='scene name, used as output ply filename')
32 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800],
33 | help='resolution (img_w, img_h) of the image')
34 |
35 | parser.add_argument('--N_samples', type=int, default=64,
36 |                         help='number of samples to infer the accumulated opacity')
37 | parser.add_argument('--chunk', type=int, default=32*1024,
38 | help='chunk size to split the input to avoid OOM')
39 | parser.add_argument('--ckpt_path', type=str, required=True,
40 | help='pretrained checkpoint path to load')
41 |
42 | parser.add_argument('--N_grid', type=int, default=256,
43 | help='size of the grid on 1 side, larger=higher resolution')
44 | parser.add_argument('--x_range', nargs="+", type=float, default=[-1.0, 1.0],
45 | help='x range of the object')
46 |     parser.add_argument('--y_range', nargs="+", type=float, default=[-1.0, 1.0],
47 |                         help='y range of the object')
48 |     parser.add_argument('--z_range', nargs="+", type=float, default=[-1.0, 1.0],
49 |                         help='z range of the object')
50 | parser.add_argument('--sigma_threshold', type=float, default=20.0,
51 | help='threshold to consider a location is occupied')
52 | parser.add_argument('--occ_threshold', type=float, default=0.2,
53 | help='''threshold to consider a vertex is occluded.
54 | larger=fewer occluded pixels''')
55 |
56 | #### method using vertex normals ####
57 | parser.add_argument('--use_vertex_normal', action="store_true",
58 | help='use vertex normals to compute color')
59 | parser.add_argument('--N_importance', type=int, default=64,
60 |                         help='number of fine samples to infer the accumulated opacity')
61 | parser.add_argument('--near_t', type=float, default=1.0,
62 | help='the near bound factor to start the ray')
63 |
64 | return parser.parse_args()
65 |
66 |
67 | @torch.no_grad()
68 | def f(models, embeddings, rays, N_samples, N_importance, chunk, white_back):
69 | """Do batched inference on rays using chunk."""
70 | B = rays.shape[0]
71 | results = defaultdict(list)
72 | for i in range(0, B, chunk):
73 | rendered_ray_chunks = \
74 | render_rays(models,
75 | embeddings,
76 | rays[i:i+chunk],
77 | N_samples,
78 | False,
79 | 0,
80 | 0,
81 | N_importance,
82 | chunk,
83 | white_back,
84 | test_time=True)
85 |
86 | for k, v in rendered_ray_chunks.items():
87 | results[k] += [v]
88 |
89 | for k, v in results.items():
90 | results[k] = torch.cat(v, 0)
91 | return results
92 |
93 |
94 | if __name__ == "__main__":
95 | args = get_opts()
96 |
97 | kwargs = {'root_dir': args.root_dir,
98 | 'img_wh': tuple(args.img_wh)}
99 | if args.dataset_name == 'llff':
100 | kwargs['spheric_poses'] = True
101 | kwargs['split'] = 'test'
102 | else:
103 | kwargs['split'] = 'train'
104 | dataset = dataset_dict[args.dataset_name](**kwargs)
105 |
106 | embedding_xyz = Embedding(3, 10)
107 | embedding_dir = Embedding(3, 4)
108 | embeddings = [embedding_xyz, embedding_dir]
109 | nerf_fine = NeRF()
110 | load_ckpt(nerf_fine, args.ckpt_path, model_name='nerf_fine')
111 | nerf_fine.cuda().eval()
112 |
113 | # define the dense grid for query
114 | N = args.N_grid
115 | xmin, xmax = args.x_range
116 | ymin, ymax = args.y_range
117 | zmin, zmax = args.z_range
118 | # assert xmax-xmin == ymax-ymin == zmax-zmin, 'the ranges must have the same length!'
119 | x = np.linspace(xmin, xmax, N)
120 | y = np.linspace(ymin, ymax, N)
121 | z = np.linspace(zmin, zmax, N)
122 |
123 | xyz_ = torch.FloatTensor(np.stack(np.meshgrid(x, y, z), -1).reshape(-1, 3)).cuda()
124 | dir_ = torch.zeros_like(xyz_).cuda()
125 | # sigma is independent of direction, so any value here will produce the same result
126 |
127 | # predict sigma (occupancy) for each grid location
128 | print('Predicting occupancy ...')
129 | with torch.no_grad():
130 | B = xyz_.shape[0]
131 | out_chunks = []
132 | for i in tqdm(range(0, B, args.chunk)):
133 | xyz_embedded = embedding_xyz(xyz_[i:i+args.chunk]) # (N, embed_xyz_channels)
134 | dir_embedded = embedding_dir(dir_[i:i+args.chunk]) # (N, embed_dir_channels)
135 | xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1)
136 | out_chunks += [nerf_fine(xyzdir_embedded)]
137 | rgbsigma = torch.cat(out_chunks, 0)
138 |
139 | sigma = rgbsigma[:, -1].cpu().numpy()
140 | sigma = np.maximum(sigma, 0).reshape(N, N, N)
141 |
142 | # perform marching cube algorithm to retrieve vertices and triangle mesh
143 | print('Extracting mesh ...')
144 | vertices, triangles = mcubes.marching_cubes(sigma, args.sigma_threshold)
145 |
146 | ##### Until mesh extraction here, it is the same as the original repo. ######
147 |
148 | vertices_ = (vertices/N).astype(np.float32)
149 | ## invert x and y coordinates (WHY? maybe because of the marching cubes algo)
150 | x_ = (ymax-ymin) * vertices_[:, 1] + ymin
151 | y_ = (xmax-xmin) * vertices_[:, 0] + xmin
152 | vertices_[:, 0] = x_
153 | vertices_[:, 1] = y_
154 | vertices_[:, 2] = (zmax-zmin) * vertices_[:, 2] + zmin
155 | vertices_.dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
156 |
157 | face = np.empty(len(triangles), dtype=[('vertex_indices', 'i4', (3,))])
158 | face['vertex_indices'] = triangles
159 |
160 | PlyData([PlyElement.describe(vertices_[:, 0], 'vertex'),
161 | PlyElement.describe(face, 'face')]).write(f'{args.scene_name}.ply')
162 |
163 | # remove noise in the mesh by keeping only the biggest cluster
164 | print('Removing noise ...')
165 | mesh = o3d.io.read_triangle_mesh(f"{args.scene_name}.ply")
166 | idxs, count, _ = mesh.cluster_connected_triangles()
167 | max_cluster_idx = np.argmax(count)
168 | triangles_to_remove = [i for i in range(len(face)) if idxs[i] != max_cluster_idx]
169 | mesh.remove_triangles_by_index(triangles_to_remove)
170 | mesh.remove_unreferenced_vertices()
171 | print(f'Mesh has {len(mesh.vertices)/1e6:.2f} M vertices and {len(mesh.triangles)/1e6:.2f} M faces.')
172 |
173 | vertices_ = np.asarray(mesh.vertices).astype(np.float32)
174 | triangles = np.asarray(mesh.triangles)
175 |
176 | # perform color prediction
177 | # Step 0. define constants (image width, height and intrinsics)
178 | W, H = args.img_wh
179 | K = np.array([[dataset.focal, 0, W/2],
180 | [0, dataset.focal, H/2],
181 | [0, 0, 1]]).astype(np.float32)
182 |
183 | # Step 1. transform vertices into world coordinate
184 | N_vertices = len(vertices_)
185 | vertices_homo = np.concatenate([vertices_, np.ones((N_vertices, 1))], 1) # (N, 4)
186 |
187 | if args.use_vertex_normal: ## use normal vector method as suggested by the author.
188 | ## see https://github.com/bmild/nerf/issues/44
189 | mesh.compute_vertex_normals()
190 | rays_d = torch.FloatTensor(np.asarray(mesh.vertex_normals))
191 | near = dataset.bounds.min() * torch.ones_like(rays_d[:, :1])
192 | far = dataset.bounds.max() * torch.ones_like(rays_d[:, :1])
193 | rays_o = torch.FloatTensor(vertices_) - rays_d * near * args.near_t
194 |
195 | nerf_coarse = NeRF()
196 | load_ckpt(nerf_coarse, args.ckpt_path, model_name='nerf_coarse')
197 | nerf_coarse.cuda().eval()
198 |
199 | results = f([nerf_coarse, nerf_fine], embeddings,
200 | torch.cat([rays_o, rays_d, near, far], 1).cuda(),
201 | args.N_samples,
202 | args.N_importance,
203 | args.chunk,
204 | dataset.white_back)
205 |
206 | else: ## use my color average method. see README_mesh.md
207 | ## buffers to store the final averaged color
208 | non_occluded_sum = np.zeros((N_vertices, 1))
209 | v_color_sum = np.zeros((N_vertices, 3))
210 |
211 | # Step 2. project the vertices onto each training image to infer the color
212 | print('Fusing colors ...')
213 | for idx in tqdm(range(len(dataset.image_paths))):
214 | ## read image of this pose
215 | image = Image.open(dataset.image_paths[idx]).convert('RGB')
216 | image = image.resize(tuple(args.img_wh), Image.LANCZOS)
217 | image = np.array(image)
218 |
219 | ## read the camera to world relative pose
220 | P_c2w = np.concatenate([dataset.poses[idx], np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
221 | P_w2c = np.linalg.inv(P_c2w)[:3] # (3, 4)
222 | ## project vertices from world coordinate to camera coordinate
223 | vertices_cam = (P_w2c @ vertices_homo.T) # (3, N) in "right up back"
224 | vertices_cam[1:] *= -1 # (3, N) in "right down forward"
225 | ## project vertices from camera coordinate to pixel coordinate
226 | vertices_image = (K @ vertices_cam).T # (N, 3)
227 | depth = vertices_image[:, -1:]+1e-5 # the depth of the vertices, used as far plane
228 | vertices_image = vertices_image[:, :2]/depth
229 | vertices_image = vertices_image.astype(np.float32)
230 | vertices_image[:, 0] = np.clip(vertices_image[:, 0], 0, W-1)
231 | vertices_image[:, 1] = np.clip(vertices_image[:, 1], 0, H-1)
232 |
233 | ## compute the color on these projected pixel coordinates
234 | ## using bilinear interpolation.
235 | ## NOTE: opencv's implementation has a size limit of 32768 pixels per side,
236 | ## so we split the input into chunks.
237 | colors = []
238 | remap_chunk = int(3e4)
239 | for i in range(0, N_vertices, remap_chunk):
240 | colors += [cv2.remap(image,
241 | vertices_image[i:i+remap_chunk, 0],
242 | vertices_image[i:i+remap_chunk, 1],
243 | interpolation=cv2.INTER_LINEAR)[:, 0]]
244 | colors = np.vstack(colors) # (N_vertices, 3)
245 |
246 | ## predict occlusion of each vertex
247 | ## we leverage the concept of NeRF by constructing rays coming out from the camera
248 | ## and hitting each vertex; by computing the accumulated opacity along this path,
249 | ## we can know if the vertex is occluded or not.
250 | ## for vertices that appear to be occluded from every input view, we make the
251 | ## assumption that its color is the same as its neighbors that are facing our side.
252 | ## (think of a surface with one side facing us: we assume the other side has the same color)
253 |
254 | ## ray's origin is camera origin
255 | rays_o = torch.FloatTensor(dataset.poses[idx][:, -1]).expand(N_vertices, 3)
256 | ## ray's direction is the vector pointing from camera origin to the vertices
257 | rays_d = torch.FloatTensor(vertices_) - rays_o # (N_vertices, 3)
258 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
259 | near = dataset.bounds.min() * torch.ones_like(rays_o[:, :1])
260 | ## the far plane is the depth of the vertices, since what we want is the accumulated
261 | ## opacity along the path from camera origin to the vertices
262 | far = torch.FloatTensor(depth) * torch.ones_like(rays_o[:, :1])
263 | results = f([nerf_fine], embeddings,
264 | torch.cat([rays_o, rays_d, near, far], 1).cuda(),
265 | args.N_samples,
266 | 0,
267 | args.chunk,
268 | dataset.white_back)
269 | opacity = results['opacity_coarse'].cpu().numpy()[:, np.newaxis] # (N_vertices, 1)
270 |             opacity = np.nan_to_num(opacity, nan=1.0) # treat NaN opacity as occluded
271 |
272 | non_occluded = np.ones_like(non_occluded_sum) * 0.1/depth # weight by inverse depth
273 | # near=more confident in color
274 | non_occluded += opacity < args.occ_threshold
275 |
276 | v_color_sum += colors * non_occluded
277 | non_occluded_sum += non_occluded
278 |
279 | # Step 3. combine the output and write to file
280 | if args.use_vertex_normal:
281 | v_colors = results['rgb_fine'].cpu().numpy() * 255.0
282 | else: ## the combined color is the average color among all views
283 | v_colors = v_color_sum/non_occluded_sum
284 | v_colors = v_colors.astype(np.uint8)
285 | v_colors.dtype = [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
286 | vertices_.dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
287 | vertex_all = np.empty(N_vertices, vertices_.dtype.descr+v_colors.dtype.descr)
288 | for prop in vertices_.dtype.names:
289 | vertex_all[prop] = vertices_[prop][:, 0]
290 | for prop in v_colors.dtype.names:
291 | vertex_all[prop] = v_colors[prop][:, 0]
292 |
293 | face = np.empty(len(triangles), dtype=[('vertex_indices', 'i4', (3,))])
294 | face['vertex_indices'] = triangles
295 |
296 | PlyData([PlyElement.describe(vertex_all, 'vertex'),
297 | PlyElement.describe(face, 'face')]).write(f'{args.scene_name}.ply')
298 |
299 | print('Done!')
300 |
--------------------------------------------------------------------------------
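The color-fusion loop above projects every mesh vertex into each training view with the pinhole model: world to camera via P_w2c, a flip from "right up back" to "right down forward", then pixel coordinates via K and a division by depth. A minimal numpy sketch of that projection for a single vertex, using an illustrative focal length and an identity camera-to-world pose:

    import numpy as np

    W, H, focal = 800, 800, 1111.0                  # illustrative intrinsics
    K = np.array([[focal, 0, W/2],
                  [0, focal, H/2],
                  [0, 0,     1]], dtype=np.float32)

    P_c2w = np.eye(4, dtype=np.float32)             # illustrative pose: camera at the world origin
    P_w2c = np.linalg.inv(P_c2w)[:3]                # (3, 4)

    X = np.array([0.1, -0.2, -3.0, 1.0], dtype=np.float32)  # homogeneous world point in front of the camera
    X_cam = P_w2c @ X                               # camera coordinates in "right up back"
    X_cam[1:] *= -1                                 # flip to "right down forward"
    u, v, depth = K @ X_cam
    print(u / depth, v / depth, depth)              # pixel coordinates and the depth used as the far plane

--------------------------------------------------------------------------------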
/extract_mesh.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "from collections import defaultdict\n",
11 | "import numpy as np\n",
12 | "import mcubes\n",
13 | "import trimesh\n",
14 | "\n",
15 | "from models.rendering import *\n",
16 | "from models.nerf import *\n",
17 | "\n",
18 | "from datasets import dataset_dict\n",
19 | "\n",
20 | "from utils import load_ckpt"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "# Load model and data"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "# Change here #\n",
37 | "img_wh = (800, 800) # full resolution of the input images\n",
38 | "dataset_name = 'blender' # blender or llff (own data)\n",
39 | "scene_name = 'lego' # whatever you want\n",
40 | "root_dir = '/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego/' # the folder containing data\n",
41 | "ckpt_path = 'ckpts/exp2/epoch=05.ckpt' # the model path\n",
42 | "###############\n",
43 | "\n",
44 | "kwargs = {'root_dir': root_dir,\n",
45 | " 'img_wh': img_wh}\n",
46 | "if dataset_name == 'llff':\n",
47 | " kwargs['spheric_poses'] = True\n",
48 | " kwargs['split'] = 'test'\n",
49 | "else:\n",
50 | " kwargs['split'] = 'train'\n",
51 | " \n",
52 | "chunk = 1024*32\n",
53 | "dataset = dataset_dict[dataset_name](**kwargs)\n",
54 | "\n",
55 | "embedding_xyz = Embedding(3, 10)\n",
56 | "embedding_dir = Embedding(3, 4)\n",
57 | "\n",
58 | "nerf_fine = NeRF()\n",
59 | "load_ckpt(nerf_fine, ckpt_path, model_name='nerf_fine')\n",
60 | "nerf_fine.cuda().eval();"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "# Search for tight bounds of the object (trial and error!)"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "### Tune these parameters until the whole object lies tightly in range with little noise ###\n",
77 | "N = 128 # controls the resolution, set this number small here because we're only finding\n",
78 | " # good ranges here, not yet for mesh reconstruction; we can set this number high\n",
79 | " # when it comes to final reconstruction.\n",
80 | "xmin, xmax = -1.2, 1.2 # left/right range\n",
81 | "ymin, ymax = -1.2, 1.2 # forward/backward range\n",
82 | "zmin, zmax = -1.2, 1.2 # up/down range\n",
83 | "## Attention! the ranges MUST have the same length!\n",
84 | "sigma_threshold = 50. # controls the noise (lower=maybe more noise; higher=some mesh might be missing)\n",
85 | "############################################################################################\n",
86 | "\n",
87 | "x = np.linspace(xmin, xmax, N)\n",
88 | "y = np.linspace(ymin, ymax, N)\n",
89 | "z = np.linspace(zmin, zmax, N)\n",
90 | "\n",
91 | "xyz_ = torch.FloatTensor(np.stack(np.meshgrid(x, y, z), -1).reshape(-1, 3)).cuda()\n",
92 | "dir_ = torch.zeros_like(xyz_).cuda()\n",
93 | "\n",
94 | "with torch.no_grad():\n",
95 | " B = xyz_.shape[0]\n",
96 | " out_chunks = []\n",
97 | " for i in range(0, B, chunk):\n",
98 | " xyz_embedded = embedding_xyz(xyz_[i:i+chunk]) # (N, embed_xyz_channels)\n",
99 | " dir_embedded = embedding_dir(dir_[i:i+chunk]) # (N, embed_dir_channels)\n",
100 | " xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1)\n",
101 | " out_chunks += [nerf_fine(xyzdir_embedded)]\n",
102 | " rgbsigma = torch.cat(out_chunks, 0)\n",
103 | " \n",
104 | "sigma = rgbsigma[:, -1].cpu().numpy()\n",
105 | "sigma = np.maximum(sigma, 0)\n",
106 | "sigma = sigma.reshape(N, N, N)\n",
107 | "\n",
108 | "# The below lines are for visualization, COMMENT OUT once you find the best range and increase N!\n",
109 | "vertices, triangles = mcubes.marching_cubes(sigma, sigma_threshold)\n",
110 | "mesh = trimesh.Trimesh(vertices/N, triangles)\n",
111 | "mesh.show()"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "# # You can already export \"colorless\" mesh if you don't need color\n",
121 | "# mcubes.export_mesh(vertices, triangles, f\"{scene_name}.dae\")"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "# Generate .vol file for volume rendering in Unity"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "assert N==512, \\\n",
138 | "    'Please set N to 512 in the cell above! Remember to comment out the visualization code (last 3 lines)!'\n",
139 | "\n",
140 | "a = 1-np.exp(-(xmax-xmin)/N*sigma)\n",
141 | "a = a.flatten()\n",
142 | "rgb = (rgbsigma[:, :3].cpu().numpy()*255).astype(np.uint32)\n",
143 | "i = np.where(a>0)[0] # valid indices (alpha>0)\n",
144 | "\n",
145 | "rgb = rgb[i]\n",
146 | "a = a[i]\n",
147 | "s = rgb.dot(np.array([1<<24, 1<<16, 1<<8])) + (a*255).astype(np.uint32)\n",
148 | "res = np.stack([i, s], -1).astype(np.uint32).flatten()\n",
149 | "with open(f'{scene_name}.vol', 'wb') as f:\n",
150 | " f.write(res.tobytes())"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "# Extract colored mesh\n",
158 | "\n",
159 | "Once you find the best range, now **RESTART** the notebook, and copy the configs to the following cell\n",
160 | "and execute it."
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "# Copy the variables you have above here! ####\n",
170 | "img_wh = (800, 800) # full resolution of the input images\n",
171 | "dataset_name = 'blender' # blender or llff (own data)\n",
172 | "scene_name = 'lego' # whatever you want\n",
173 | "root_dir = '/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego/' # the folder containing data\n",
174 | "ckpt_path = 'ckpts/exp2/epoch=05.ckpt' # the model path\n",
175 | "\n",
176 | "N = 128 # controls the resolution, set this number small here because we're only finding\n",
177 | " # good ranges here, not yet for mesh reconstruction; we can set this number high\n",
178 | " # when it comes to final reconstruction.\n",
179 | "xmin, xmax = -1.2, 1.2 # left/right range\n",
180 | "ymin, ymax = -1.2, 1.2 # forward/backward range\n",
181 | "zmin, zmax = -1.2, 1.2 # up/down range\n",
182 | "sigma_threshold = 50. # controls the noise (lower=maybe more noise; higher=some mesh might be missing)\n",
183 | "###############################################\n",
184 | "\n",
185 | "import os\n",
186 | "os.environ['ROOT_DIR'] = root_dir\n",
187 | "os.environ['DATASET_NAME'] = dataset_name\n",
188 | "os.environ['SCENE_NAME'] = scene_name\n",
189 | "os.environ['IMG_SIZE'] = f\"{img_wh[0]} {img_wh[1]}\"\n",
190 | "os.environ['CKPT_PATH'] = ckpt_path\n",
191 | "os.environ['N_GRID'] = \"256\" # final resolution. You can set this number high to preserve more details\n",
192 | "os.environ['X_RANGE'] = f\"{xmin} {xmax}\"\n",
193 | "os.environ['Y_RANGE'] = f\"{ymin} {ymax}\"\n",
194 | "os.environ['Z_RANGE'] = f\"{zmin} {zmax}\"\n",
195 | "os.environ['SIGMA_THRESHOLD'] = str(sigma_threshold)\n",
196 | "os.environ['OCC_THRESHOLD'] = \"0.2\" # probably doesn't require tuning. If you find the color is not close\n",
197 | " # to real, try to set this number smaller (the effect of this number\n",
198 | " # is explained in my youtube video)\n",
199 | "\n",
200 | "!python extract_color_mesh.py \\\n",
201 | " --root_dir $ROOT_DIR \\\n",
202 | " --dataset_name $DATASET_NAME \\\n",
203 | " --scene_name $SCENE_NAME \\\n",
204 | " --img_wh $IMG_SIZE \\\n",
205 | " --ckpt_path $CKPT_PATH \\\n",
206 | " --N_grid $N_GRID \\\n",
207 | " --x_range $X_RANGE \\\n",
208 | " --y_range $Y_RANGE \\\n",
209 | " --z_range $Z_RANGE \\\n",
210 | " --sigma_threshold $SIGMA_THRESHOLD \\\n",
211 | " --occ_threshold $OCC_THRESHOLD"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": []
220 | }
221 | ],
222 | "metadata": {
223 | "kernelspec": {
224 | "display_name": "nerf_pl",
225 | "language": "python",
226 | "name": "nerf_pl"
227 | },
228 | "language_info": {
229 | "codemirror_mode": {
230 | "name": "ipython",
231 | "version": 3
232 | },
233 | "file_extension": ".py",
234 | "mimetype": "text/x-python",
235 | "name": "python",
236 | "nbconvert_exporter": "python",
237 | "pygments_lexer": "ipython3",
238 | "version": "3.6.10"
239 | }
240 | },
241 | "nbformat": 4,
242 | "nbformat_minor": 4
243 | }
244 |
--------------------------------------------------------------------------------
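The .vol export cell packs each non-empty voxel as two uint32 values: its flat grid index and a color word with R, G, B in the top three bytes and alpha*255 in the lowest byte. A minimal sketch of that packing and the matching unpacking; the color and alpha are illustrative, and the byte layout is simply the one produced by the cell above:

    import numpy as np

    r, g, b = 200, 120, 30                      # illustrative color in [0, 255]
    a = 0.6                                     # illustrative alpha in [0, 1]
    packed = np.uint32((r << 24) + (g << 16) + (b << 8) + int(a * 255))

    # unpack to verify the byte layout
    s = int(packed)
    print((s >> 24) & 0xFF, (s >> 16) & 0xFF, (s >> 8) & 0xFF, (s & 0xFF) / 255.0)

--------------------------------------------------------------------------------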
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class MSELoss(nn.Module):
5 | def __init__(self):
6 | super(MSELoss, self).__init__()
7 | self.loss = nn.MSELoss(reduction='mean')
8 |
9 | def forward(self, inputs, targets):
10 | loss = self.loss(inputs['rgb_coarse'], targets)
11 | if 'rgb_fine' in inputs:
12 | loss += self.loss(inputs['rgb_fine'], targets)
13 |
14 | return loss
15 |
16 |
17 | loss_dict = {'mse': MSELoss}
--------------------------------------------------------------------------------
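MSELoss takes the raw dictionary returned by render_rays, so the coarse and (if present) fine predictions are supervised against the same target. A minimal sketch with random tensors, following the (B, 3) rgb convention used in the rest of the repo:

    import torch
    from losses import loss_dict

    loss_fn = loss_dict['mse']()
    B = 1024
    targets = torch.rand(B, 3)
    results = {'rgb_coarse': torch.rand(B, 3, requires_grad=True),
               'rgb_fine':   torch.rand(B, 3, requires_grad=True)}
    loss = loss_fn(results, targets)   # coarse MSE + fine MSE
    loss.backward()

--------------------------------------------------------------------------------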
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim as dssim
3 |
4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
5 | value = (image_pred-image_gt)**2
6 | if valid_mask is not None:
7 | value = value[valid_mask]
8 | if reduction == 'mean':
9 | return torch.mean(value)
10 | return value
11 |
12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'):
13 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction))
14 |
15 | def ssim(image_pred, image_gt, reduction='mean'):
16 | """
17 | image_pred and image_gt: (1, 3, H, W)
18 | """
19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1]
20 | return 1-2*dssim_ # in [-1, 1]
--------------------------------------------------------------------------------
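Shape conventions differ between the two metrics: psnr works on any matching shapes, while ssim wraps kornia's dssim and therefore expects batched (1, 3, H, W) images. A minimal sketch with random inputs, so the values only confirm the shapes and ranges:

    import torch
    from metrics import psnr, ssim

    img_pred = torch.rand(1, 3, 64, 64)
    img_gt   = torch.rand(1, 3, 64, 64)
    print(psnr(img_pred, img_gt))   # scalar PSNR in dB
    print(ssim(img_pred, img_gt))   # scalar in [-1, 1]

--------------------------------------------------------------------------------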
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwea123/nerf_pl/52aeb387da64a9ad9a0f914ea9b049ffc598b20c/models/__init__.py
--------------------------------------------------------------------------------
/models/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class Embedding(nn.Module):
5 | def __init__(self, in_channels, N_freqs, logscale=True):
6 | """
7 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
8 | in_channels: number of input channels (3 for both xyz and direction)
9 | """
10 | super(Embedding, self).__init__()
11 | self.N_freqs = N_freqs
12 | self.in_channels = in_channels
13 | self.funcs = [torch.sin, torch.cos]
14 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
15 |
16 | if logscale:
17 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
18 | else:
19 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
20 |
21 | def forward(self, x):
22 | """
23 | Embeds x to (x, sin(2^k x), cos(2^k x), ...)
24 | Different from the paper, "x" is also in the output
25 | See https://github.com/bmild/nerf/issues/12
26 |
27 | Inputs:
28 | x: (B, self.in_channels)
29 |
30 | Outputs:
31 | out: (B, self.out_channels)
32 | """
33 | out = [x]
34 | for freq in self.freq_bands:
35 | for func in self.funcs:
36 | out += [func(freq*x)]
37 |
38 | return torch.cat(out, -1)
39 |
40 |
41 | class NeRF(nn.Module):
42 | def __init__(self,
43 | D=8, W=256,
44 | in_channels_xyz=63, in_channels_dir=27,
45 | skips=[4]):
46 | """
47 | D: number of layers for density (sigma) encoder
48 | W: number of hidden units in each layer
49 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default)
50 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default)
51 |         skips: layers at which a skip connection from the input xyz is added (default: the 4th layer)
52 | """
53 | super(NeRF, self).__init__()
54 | self.D = D
55 | self.W = W
56 | self.in_channels_xyz = in_channels_xyz
57 | self.in_channels_dir = in_channels_dir
58 | self.skips = skips
59 |
60 | # xyz encoding layers
61 | for i in range(D):
62 | if i == 0:
63 | layer = nn.Linear(in_channels_xyz, W)
64 | elif i in skips:
65 | layer = nn.Linear(W+in_channels_xyz, W)
66 | else:
67 | layer = nn.Linear(W, W)
68 | layer = nn.Sequential(layer, nn.ReLU(True))
69 | setattr(self, f"xyz_encoding_{i+1}", layer)
70 | self.xyz_encoding_final = nn.Linear(W, W)
71 |
72 | # direction encoding layers
73 | self.dir_encoding = nn.Sequential(
74 | nn.Linear(W+in_channels_dir, W//2),
75 | nn.ReLU(True))
76 |
77 | # output layers
78 | self.sigma = nn.Linear(W, 1)
79 | self.rgb = nn.Sequential(
80 | nn.Linear(W//2, 3),
81 | nn.Sigmoid())
82 |
83 | def forward(self, x, sigma_only=False):
84 | """
85 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet).
86 | For rendering this ray, please see rendering.py
87 |
88 | Inputs:
89 | x: (B, self.in_channels_xyz(+self.in_channels_dir))
90 | the embedded vector of position and direction
91 | sigma_only: whether to infer sigma only. If True,
92 | x is of shape (B, self.in_channels_xyz)
93 |
94 | Outputs:
95 |             if sigma_only:
96 | sigma: (B, 1) sigma
97 | else:
98 | out: (B, 4), rgb and sigma
99 | """
100 | if not sigma_only:
101 | input_xyz, input_dir = \
102 | torch.split(x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)
103 | else:
104 | input_xyz = x
105 |
106 | xyz_ = input_xyz
107 | for i in range(self.D):
108 | if i in self.skips:
109 | xyz_ = torch.cat([input_xyz, xyz_], -1)
110 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
111 |
112 | sigma = self.sigma(xyz_)
113 | if sigma_only:
114 | return sigma
115 |
116 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
117 |
118 | dir_encoding_input = torch.cat([xyz_encoding_final, input_dir], -1)
119 | dir_encoding = self.dir_encoding(dir_encoding_input)
120 | rgb = self.rgb(dir_encoding)
121 |
122 | out = torch.cat([rgb, sigma], -1)
123 |
124 | return out
--------------------------------------------------------------------------------
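The default channel counts of NeRF follow directly from the embeddings: 3*(2*10+1) = 63 channels for xyz and 3*(2*4+1) = 27 for the viewing direction. A minimal shape check with random inputs:

    import torch
    from models.nerf import Embedding, NeRF

    embedding_xyz = Embedding(3, 10)    # out_channels = 3*(2*10+1) = 63
    embedding_dir = Embedding(3, 4)     # out_channels = 3*(2*4+1)  = 27
    model = NeRF()

    B = 4
    xyz, dirs = torch.rand(B, 3), torch.rand(B, 3)
    x = torch.cat([embedding_xyz(xyz), embedding_dir(dirs)], -1)    # (B, 90)
    out = model(x)                                                  # (B, 4): rgb (sigmoid) + sigma
    sigma = model(embedding_xyz(xyz), sigma_only=True)              # (B, 1)
    print(out.shape, sigma.shape)

--------------------------------------------------------------------------------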
/models/rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchsearchsorted import searchsorted
3 |
4 | __all__ = ['render_rays']
5 |
6 | """
7 | Function dependencies: (-> means function calls)
8 |
9 | @render_rays -> @inference
10 |
11 | @render_rays -> @sample_pdf if there is fine model
12 | """
13 |
14 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
15 | """
16 | Sample @N_importance samples from @bins with distribution defined by @weights.
17 |
18 | Inputs:
19 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
20 | weights: (N_rays, N_samples_)
21 | N_importance: the number of samples to draw from the distribution
22 | det: deterministic or not
23 | eps: a small number to prevent division by zero
24 |
25 | Outputs:
26 | samples: the sampled samples
27 | """
28 | N_rays, N_samples_ = weights.shape
29 | weights = weights + eps # prevent division by zero (don't do inplace op!)
30 | pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
31 | cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
32 | cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
33 | # padded to 0~1 inclusive
34 |
35 | if det:
36 | u = torch.linspace(0, 1, N_importance, device=bins.device)
37 | u = u.expand(N_rays, N_importance)
38 | else:
39 | u = torch.rand(N_rays, N_importance, device=bins.device)
40 | u = u.contiguous()
41 |
42 | inds = searchsorted(cdf, u, side='right')
43 | below = torch.clamp_min(inds-1, 0)
44 | above = torch.clamp_max(inds, N_samples_)
45 |
46 | inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
47 | cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
48 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
49 |
50 | denom = cdf_g[...,1]-cdf_g[...,0]
51 |     denom[denom<eps] = 1 # a denom of 0 means the bin has zero weight, so it will never be sampled anyway
197 |     if perturb > 0: # perturb sampling depths (z_vals)
198 | z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) # (N_rays, N_samples-1) interval mid points
199 | # get intervals between samples
200 | upper = torch.cat([z_vals_mid, z_vals[: ,-1:]], -1)
201 | lower = torch.cat([z_vals[: ,:1], z_vals_mid], -1)
202 |
203 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
204 | z_vals = lower + (upper - lower) * perturb_rand
205 |
206 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \
207 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
208 |
209 | if test_time:
210 | weights_coarse = \
211 | inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
212 | dir_embedded, z_vals, weights_only=True)
213 | result = {'opacity_coarse': weights_coarse.sum(1)}
214 | else:
215 | rgb_coarse, depth_coarse, weights_coarse = \
216 | inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d,
217 | dir_embedded, z_vals, weights_only=False)
218 | result = {'rgb_coarse': rgb_coarse,
219 | 'depth_coarse': depth_coarse,
220 | 'opacity_coarse': weights_coarse.sum(1)
221 | }
222 |
223 | if N_importance > 0: # sample points for fine model
224 | z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) # (N_rays, N_samples-1) interval mid points
225 | z_vals_ = sample_pdf(z_vals_mid, weights_coarse[:, 1:-1],
226 | N_importance, det=(perturb==0)).detach()
227 | # detach so that grad doesn't propogate to weights_coarse from here
228 |
229 | z_vals, _ = torch.sort(torch.cat([z_vals, z_vals_], -1), -1)
230 |
231 | xyz_fine_sampled = rays_o.unsqueeze(1) + \
232 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2)
233 | # (N_rays, N_samples+N_importance, 3)
234 |
235 | model_fine = models[1]
236 | rgb_fine, depth_fine, weights_fine = \
237 | inference(model_fine, embedding_xyz, xyz_fine_sampled, rays_d,
238 | dir_embedded, z_vals, weights_only=False)
239 |
240 | result['rgb_fine'] = rgb_fine
241 | result['depth_fine'] = depth_fine
242 | result['opacity_fine'] = weights_fine.sum(1)
243 |
244 | return result
245 |
--------------------------------------------------------------------------------
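sample_pdf above is the inverse-CDF (importance) sampling step that draws fine samples where the coarse weights are large. As an aid to reading it, here is a standalone sketch of the same idea using the built-in torch.searchsorted (available in newer PyTorch) instead of the torchsearchsorted extension; it illustrates the mechanics and is not the repo's exact implementation:

    import torch

    def sample_pdf_sketch(bins, weights, N_importance, eps=1e-5):
        """bins: (N_rays, M+1) depth bin edges, weights: (N_rays, M) coarse weights."""
        pdf = (weights + eps) / (weights + eps).sum(-1, keepdim=True)      # (N_rays, M)
        cdf = torch.cumsum(pdf, -1)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)           # (N_rays, M+1)

        u = torch.rand(weights.shape[0], N_importance)                     # uniform samples in [0, 1)
        inds = torch.searchsorted(cdf, u, right=True)                      # which CDF interval each u falls in
        below = (inds - 1).clamp(min=0)
        above = inds.clamp(max=cdf.shape[-1] - 1)

        cdf_lo, cdf_hi = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)
        bin_lo, bin_hi = torch.gather(bins, 1, below), torch.gather(bins, 1, above)
        t = (u - cdf_lo) / (cdf_hi - cdf_lo).clamp(min=eps)                # position inside the interval
        return bin_lo + t * (bin_hi - bin_lo)                              # (N_rays, N_importance)

    bins = torch.linspace(2.0, 6.0, 65).expand(8, 65)                      # illustrative near/far bin edges
    weights = torch.rand(8, 64)
    samples = sample_pdf_sketch(bins, weights, 128)
    print(samples.shape)                                                   # torch.Size([8, 128])

--------------------------------------------------------------------------------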
/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_opts():
4 | parser = argparse.ArgumentParser()
5 |
6 | parser.add_argument('--root_dir', type=str,
7 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego',
8 | help='root directory of dataset')
9 | parser.add_argument('--dataset_name', type=str, default='blender',
10 | choices=['blender', 'llff'],
11 | help='which dataset to train/val')
12 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800],
13 | help='resolution (img_w, img_h) of the image')
14 | parser.add_argument('--spheric_poses', default=False, action="store_true",
15 | help='whether images are taken in spheric poses (for llff)')
16 |
17 | parser.add_argument('--N_samples', type=int, default=64,
18 | help='number of coarse samples')
19 | parser.add_argument('--N_importance', type=int, default=128,
20 | help='number of additional fine samples')
21 | parser.add_argument('--use_disp', default=False, action="store_true",
22 | help='use disparity depth sampling')
23 | parser.add_argument('--perturb', type=float, default=1.0,
24 | help='factor to perturb depth sampling points')
25 | parser.add_argument('--noise_std', type=float, default=1.0,
26 | help='std dev of noise added to regularize sigma')
27 |
28 | parser.add_argument('--loss_type', type=str, default='mse',
29 | choices=['mse'],
30 | help='loss to use')
31 |
32 | parser.add_argument('--batch_size', type=int, default=1024,
33 | help='batch size')
34 | parser.add_argument('--chunk', type=int, default=32*1024,
35 | help='chunk size to split the input to avoid OOM')
36 | parser.add_argument('--num_epochs', type=int, default=16,
37 | help='number of training epochs')
38 | parser.add_argument('--num_gpus', type=int, default=1,
39 | help='number of gpus')
40 |
41 | parser.add_argument('--ckpt_path', type=str, default=None,
42 | help='pretrained checkpoint path to load')
43 | parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'],
44 | help='the prefixes to ignore in the checkpoint state dict')
45 |
46 | parser.add_argument('--optimizer', type=str, default='adam',
47 | help='optimizer type',
48 | choices=['sgd', 'adam', 'radam', 'ranger'])
49 | parser.add_argument('--lr', type=float, default=5e-4,
50 | help='learning rate')
51 | parser.add_argument('--momentum', type=float, default=0.9,
52 | help='learning rate momentum')
53 | parser.add_argument('--weight_decay', type=float, default=0,
54 | help='weight decay')
55 | parser.add_argument('--lr_scheduler', type=str, default='steplr',
56 | help='scheduler type',
57 | choices=['steplr', 'cosine', 'poly'])
58 | #### params for warmup, only applied when optimizer == 'sgd' or 'adam'
59 | parser.add_argument('--warmup_multiplier', type=float, default=1.0,
60 | help='lr is multiplied by this factor after --warmup_epochs')
61 | parser.add_argument('--warmup_epochs', type=int, default=0,
62 |                         help='number of epochs to gradually warm up (increase) the learning rate')
63 | ###########################
64 | #### params for steplr ####
65 | parser.add_argument('--decay_step', nargs='+', type=int, default=[20],
66 | help='scheduler decay step')
67 | parser.add_argument('--decay_gamma', type=float, default=0.1,
68 | help='learning rate decay amount')
69 | ###########################
70 | #### params for poly ####
71 | parser.add_argument('--poly_exp', type=float, default=0.9,
72 | help='exponent for polynomial learning rate decay')
73 | ###########################
74 |
75 | parser.add_argument('--exp_name', type=str, default='exp',
76 | help='experiment name')
77 |
78 | return parser.parse_args()
79 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.4.0
2 | torchvision==0.5.0
3 | pytorch-lightning==0.7.5
4 | test-tube
5 | kornia==0.2.0
6 | opencv-python==4.2.0.34
7 | matplotlib
8 | jupyter
9 |
10 | # for mesh
11 | PyMCubes
12 | pycollada
13 | trimesh
14 | pyglet
15 |
16 | # for point cloud
17 | plyfile
18 | open3d
19 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from opt import get_opts
3 | import torch
4 | from collections import defaultdict
5 |
6 | from torch.utils.data import DataLoader
7 | from datasets import dataset_dict
8 |
9 | # models
10 | from models.nerf import Embedding, NeRF
11 | from models.rendering import render_rays
12 |
13 | # optimizer, scheduler, visualization
14 | from utils import *
15 |
16 | # losses
17 | from losses import loss_dict
18 |
19 | # metrics
20 | from metrics import *
21 |
22 | # pytorch-lightning
23 | from pytorch_lightning.callbacks import ModelCheckpoint
24 | from pytorch_lightning import LightningModule, Trainer
25 | from pytorch_lightning.logging import TestTubeLogger
26 |
27 | class NeRFSystem(LightningModule):
28 | def __init__(self, hparams):
29 | super(NeRFSystem, self).__init__()
30 | self.hparams = hparams
31 |
32 | self.loss = loss_dict[hparams.loss_type]()
33 |
34 | self.embedding_xyz = Embedding(3, 10) # 10 is the default number
35 | self.embedding_dir = Embedding(3, 4) # 4 is the default number
36 | self.embeddings = [self.embedding_xyz, self.embedding_dir]
37 |
38 | self.nerf_coarse = NeRF()
39 | self.models = [self.nerf_coarse]
40 | if hparams.N_importance > 0:
41 | self.nerf_fine = NeRF()
42 | self.models += [self.nerf_fine]
43 |
44 | def decode_batch(self, batch):
45 | rays = batch['rays'] # (B, 8)
46 | rgbs = batch['rgbs'] # (B, 3)
47 | return rays, rgbs
48 |
49 | def forward(self, rays):
50 | """Do batched inference on rays using chunk."""
51 | B = rays.shape[0]
52 | results = defaultdict(list)
53 | for i in range(0, B, self.hparams.chunk):
54 | rendered_ray_chunks = \
55 | render_rays(self.models,
56 | self.embeddings,
57 | rays[i:i+self.hparams.chunk],
58 | self.hparams.N_samples,
59 | self.hparams.use_disp,
60 | self.hparams.perturb,
61 | self.hparams.noise_std,
62 | self.hparams.N_importance,
63 | self.hparams.chunk, # chunk size is effective in val mode
64 | self.train_dataset.white_back)
65 |
66 | for k, v in rendered_ray_chunks.items():
67 | results[k] += [v]
68 |
69 | for k, v in results.items():
70 | results[k] = torch.cat(v, 0)
71 | return results
72 |
73 | def prepare_data(self):
74 | dataset = dataset_dict[self.hparams.dataset_name]
75 | kwargs = {'root_dir': self.hparams.root_dir,
76 | 'img_wh': tuple(self.hparams.img_wh)}
77 | if self.hparams.dataset_name == 'llff':
78 | kwargs['spheric_poses'] = self.hparams.spheric_poses
79 | kwargs['val_num'] = self.hparams.num_gpus
80 | self.train_dataset = dataset(split='train', **kwargs)
81 | self.val_dataset = dataset(split='val', **kwargs)
82 |
83 | def configure_optimizers(self):
84 | self.optimizer = get_optimizer(self.hparams, self.models)
85 | scheduler = get_scheduler(self.hparams, self.optimizer)
86 |
87 | return [self.optimizer], [scheduler]
88 |
89 | def train_dataloader(self):
90 | return DataLoader(self.train_dataset,
91 | shuffle=True,
92 | num_workers=4,
93 | batch_size=self.hparams.batch_size,
94 | pin_memory=True)
95 |
96 | def val_dataloader(self):
97 | return DataLoader(self.val_dataset,
98 | shuffle=False,
99 | num_workers=4,
100 | batch_size=1, # validate one image (H*W rays) at a time
101 | pin_memory=True)
102 |
103 | def training_step(self, batch, batch_nb):
104 | log = {'lr': get_learning_rate(self.optimizer)}
105 | rays, rgbs = self.decode_batch(batch)
106 | results = self(rays)
107 | log['train/loss'] = loss = self.loss(results, rgbs)
108 | typ = 'fine' if 'rgb_fine' in results else 'coarse'
109 |
110 | with torch.no_grad():
111 | psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
112 | log['train/psnr'] = psnr_
113 |
114 | return {'loss': loss,
115 | 'progress_bar': {'train_psnr': psnr_},
116 | 'log': log
117 | }
118 |
119 | def validation_step(self, batch, batch_nb):
120 | rays, rgbs = self.decode_batch(batch)
121 |         rays = rays.squeeze() # (H*W, 8)
122 | rgbs = rgbs.squeeze() # (H*W, 3)
123 | results = self(rays)
124 | log = {'val_loss': self.loss(results, rgbs)}
125 | typ = 'fine' if 'rgb_fine' in results else 'coarse'
126 |
127 | if batch_nb == 0:
128 | W, H = self.hparams.img_wh
129 | img = results[f'rgb_{typ}'].view(H, W, 3).cpu()
130 | img = img.permute(2, 0, 1) # (3, H, W)
131 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
132 | depth = visualize_depth(results[f'depth_{typ}'].view(H, W)) # (3, H, W)
133 | stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W)
134 | self.logger.experiment.add_images('val/GT_pred_depth',
135 | stack, self.global_step)
136 |
137 | log['val_psnr'] = psnr(results[f'rgb_{typ}'], rgbs)
138 | return log
139 |
140 | def validation_epoch_end(self, outputs):
141 | mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
142 | mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()
143 |
144 | return {'progress_bar': {'val_loss': mean_loss,
145 | 'val_psnr': mean_psnr},
146 | 'log': {'val/loss': mean_loss,
147 | 'val/psnr': mean_psnr}
148 | }
149 |
150 |
151 | if __name__ == '__main__':
152 | hparams = get_opts()
153 | system = NeRFSystem(hparams)
154 | checkpoint_callback = ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}',
155 | '{epoch:d}'),
156 | monitor='val/loss',
157 | mode='min',
158 | save_top_k=5,)
159 |
160 | logger = TestTubeLogger(
161 | save_dir="logs",
162 | name=hparams.exp_name,
163 | debug=False,
164 | create_git_tag=False
165 | )
166 |
167 | trainer = Trainer(max_epochs=hparams.num_epochs,
168 | checkpoint_callback=checkpoint_callback,
169 | resume_from_checkpoint=hparams.ckpt_path,
170 | logger=logger,
171 | early_stop_callback=None,
172 | weights_summary=None,
173 | progress_bar_refresh_rate=1,
174 | gpus=hparams.num_gpus,
175 | distributed_backend='ddp' if hparams.num_gpus>1 else None,
176 | num_sanity_val_steps=1,
177 | benchmark=True,
178 | profiler=hparams.num_gpus==1)
179 |
180 | trainer.fit(system)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | from torch.optim import SGD, Adam
3 | from .optimizers import *
4 | # scheduler
5 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, LambdaLR
6 | from .warmup_scheduler import GradualWarmupScheduler
7 |
8 | from .visualization import *
9 |
10 | def get_optimizer(hparams, models):
11 | eps = 1e-8
12 | parameters = []
13 | for model in models:
14 | parameters += list(model.parameters())
15 | if hparams.optimizer == 'sgd':
16 | optimizer = SGD(parameters, lr=hparams.lr,
17 | momentum=hparams.momentum, weight_decay=hparams.weight_decay)
18 | elif hparams.optimizer == 'adam':
19 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps,
20 | weight_decay=hparams.weight_decay)
21 | elif hparams.optimizer == 'radam':
22 | optimizer = RAdam(parameters, lr=hparams.lr, eps=eps,
23 | weight_decay=hparams.weight_decay)
24 | elif hparams.optimizer == 'ranger':
25 | optimizer = Ranger(parameters, lr=hparams.lr, eps=eps,
26 | weight_decay=hparams.weight_decay)
27 | else:
28 | raise ValueError('optimizer not recognized!')
29 |
30 | return optimizer
31 |
32 | def get_scheduler(hparams, optimizer):
33 | eps = 1e-8
34 | if hparams.lr_scheduler == 'steplr':
35 | scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
36 | gamma=hparams.decay_gamma)
37 | elif hparams.lr_scheduler == 'cosine':
38 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
39 | elif hparams.lr_scheduler == 'poly':
40 | scheduler = LambdaLR(optimizer,
41 | lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp)
42 | else:
43 | raise ValueError('scheduler not recognized!')
44 |
45 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']:
46 | scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier,
47 | total_epoch=hparams.warmup_epochs, after_scheduler=scheduler)
48 |
49 | return scheduler
50 |
51 | def get_learning_rate(optimizer):
52 | for param_group in optimizer.param_groups:
53 | return param_group['lr']
54 |
55 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]):
56 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
57 | checkpoint_ = {}
58 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
59 | checkpoint = checkpoint['state_dict']
60 | for k, v in checkpoint.items():
61 | if not k.startswith(model_name):
62 | continue
63 | k = k[len(model_name)+1:]
64 | for prefix in prefixes_to_ignore:
65 | if k.startswith(prefix):
66 | print('ignore', k)
67 | break
68 | else:
69 | checkpoint_[k] = v
70 | return checkpoint_
71 |
72 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
73 | model_dict = model.state_dict()
74 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
75 | model_dict.update(checkpoint_)
76 | model.load_state_dict(model_dict)
77 |
--------------------------------------------------------------------------------
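get_optimizer and get_scheduler only read attribute names off the hparams object, so they can be tried out without the full argparse setup in opt.py. A minimal sketch with an illustrative Namespace and a toy model standing in for the NeRF networks:

    import torch
    from argparse import Namespace
    from utils import get_optimizer, get_scheduler

    hparams = Namespace(optimizer='adam', lr=5e-4, weight_decay=0, momentum=0.9,
                        lr_scheduler='steplr', decay_step=[20], decay_gamma=0.1,
                        warmup_epochs=0, warmup_multiplier=1.0, num_epochs=16)

    model = torch.nn.Linear(63, 4)                 # toy stand-in for the NeRF models
    optimizer = get_optimizer(hparams, [model])
    scheduler = get_scheduler(hparams, optimizer)

    for epoch in range(hparams.num_epochs):
        # ... one epoch of training would go here ...
        scheduler.step()

--------------------------------------------------------------------------------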
/utils/optimizers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 | import itertools as it
5 |
6 | class RAdam(Optimizer):
7 |
8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
9 | if not 0.0 <= lr:
10 | raise ValueError("Invalid learning rate: {}".format(lr))
11 | if not 0.0 <= eps:
12 | raise ValueError("Invalid epsilon value: {}".format(eps))
13 | if not 0.0 <= betas[0] < 1.0:
14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
15 | if not 0.0 <= betas[1] < 1.0:
16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
17 |
18 | self.degenerated_to_sgd = degenerated_to_sgd
19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
20 | for param in params:
21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
22 | param['buffer'] = [[None, None, None] for _ in range(10)]
23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
24 | super(RAdam, self).__init__(params, defaults)
25 |
26 | def __setstate__(self, state):
27 | super(RAdam, self).__setstate__(state)
28 |
29 | def step(self, closure=None):
30 |
31 | loss = None
32 | if closure is not None:
33 | loss = closure()
34 |
35 | for group in self.param_groups:
36 |
37 | for p in group['params']:
38 | if p.grad is None:
39 | continue
40 | grad = p.grad.data.float()
41 | if grad.is_sparse:
42 | raise RuntimeError('RAdam does not support sparse gradients')
43 |
44 | p_data_fp32 = p.data.float()
45 |
46 | state = self.state[p]
47 |
48 | if len(state) == 0:
49 | state['step'] = 0
50 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
51 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
52 | else:
53 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
54 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
55 |
56 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
57 | beta1, beta2 = group['betas']
58 |
59 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
60 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
61 |
62 | state['step'] += 1
63 | buffered = group['buffer'][int(state['step'] % 10)]
64 | if state['step'] == buffered[0]:
65 | N_sma, step_size = buffered[1], buffered[2]
66 | else:
67 | buffered[0] = state['step']
68 | beta2_t = beta2 ** state['step']
69 | N_sma_max = 2 / (1 - beta2) - 1
70 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
71 | buffered[1] = N_sma
72 |
73 | # more conservative since it's an approximated value
74 | if N_sma >= 5:
75 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
76 | elif self.degenerated_to_sgd:
77 | step_size = 1.0 / (1 - beta1 ** state['step'])
78 | else:
79 | step_size = -1
80 | buffered[2] = step_size
81 |
82 | # more conservative since it's an approximated value
83 | if N_sma >= 5:
84 | if group['weight_decay'] != 0:
85 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
86 | denom = exp_avg_sq.sqrt().add_(group['eps'])
87 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
88 | p.data.copy_(p_data_fp32)
89 | elif step_size > 0:
90 | if group['weight_decay'] != 0:
91 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
92 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
93 | p.data.copy_(p_data_fp32)
94 |
95 | return loss
96 |
97 | class PlainRAdam(Optimizer):
98 |
99 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
100 | if not 0.0 <= lr:
101 | raise ValueError("Invalid learning rate: {}".format(lr))
102 | if not 0.0 <= eps:
103 | raise ValueError("Invalid epsilon value: {}".format(eps))
104 | if not 0.0 <= betas[0] < 1.0:
105 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
106 | if not 0.0 <= betas[1] < 1.0:
107 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
108 |
109 | self.degenerated_to_sgd = degenerated_to_sgd
110 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
111 |
112 | super(PlainRAdam, self).__init__(params, defaults)
113 |
114 | def __setstate__(self, state):
115 | super(PlainRAdam, self).__setstate__(state)
116 |
117 | def step(self, closure=None):
118 |
119 | loss = None
120 | if closure is not None:
121 | loss = closure()
122 |
123 | for group in self.param_groups:
124 |
125 | for p in group['params']:
126 | if p.grad is None:
127 | continue
128 | grad = p.grad.data.float()
129 | if grad.is_sparse:
130 |                     raise RuntimeError('PlainRAdam does not support sparse gradients')
131 |
132 | p_data_fp32 = p.data.float()
133 |
134 | state = self.state[p]
135 |
136 | if len(state) == 0:
137 | state['step'] = 0
138 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
139 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
140 | else:
141 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
142 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
143 |
144 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
145 | beta1, beta2 = group['betas']
146 |
147 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
148 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
149 |
150 | state['step'] += 1
151 | beta2_t = beta2 ** state['step']
152 | N_sma_max = 2 / (1 - beta2) - 1
153 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
154 |
155 |
156 | # more conservative since it's an approximated value
157 | if N_sma >= 5:
158 | if group['weight_decay'] != 0:
159 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
160 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
161 | denom = exp_avg_sq.sqrt().add_(group['eps'])
162 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
163 | p.data.copy_(p_data_fp32)
164 | elif self.degenerated_to_sgd:
165 | if group['weight_decay'] != 0:
166 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
167 | step_size = group['lr'] / (1 - beta1 ** state['step'])
168 | p_data_fp32.add_(-step_size, exp_avg)
169 | p.data.copy_(p_data_fp32)
170 |
171 | return loss
172 |
173 | class AdamW(Optimizer):
174 |
175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
176 | if not 0.0 <= lr:
177 | raise ValueError("Invalid learning rate: {}".format(lr))
178 | if not 0.0 <= eps:
179 | raise ValueError("Invalid epsilon value: {}".format(eps))
180 | if not 0.0 <= betas[0] < 1.0:
181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
182 | if not 0.0 <= betas[1] < 1.0:
183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
184 |
185 | defaults = dict(lr=lr, betas=betas, eps=eps,
186 | weight_decay=weight_decay, warmup = warmup)
187 | super(AdamW, self).__init__(params, defaults)
188 |
189 | def __setstate__(self, state):
190 | super(AdamW, self).__setstate__(state)
191 |
192 | def step(self, closure=None):
193 | loss = None
194 | if closure is not None:
195 | loss = closure()
196 |
197 | for group in self.param_groups:
198 |
199 | for p in group['params']:
200 | if p.grad is None:
201 | continue
202 | grad = p.grad.data.float()
203 | if grad.is_sparse:
204 |                     raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')
205 |
206 | p_data_fp32 = p.data.float()
207 |
208 | state = self.state[p]
209 |
210 | if len(state) == 0:
211 | state['step'] = 0
212 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
214 | else:
215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
217 |
218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
219 | beta1, beta2 = group['betas']
220 |
221 | state['step'] += 1
222 |
223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
224 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
225 |
226 | denom = exp_avg_sq.sqrt().add_(group['eps'])
227 | bias_correction1 = 1 - beta1 ** state['step']
228 | bias_correction2 = 1 - beta2 ** state['step']
229 |
230 | if group['warmup'] > state['step']:
231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
232 | else:
233 | scheduled_lr = group['lr']
234 |
235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
236 |
237 | if group['weight_decay'] != 0:
238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
239 |
240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
241 |
242 | p.data.copy_(p_data_fp32)
243 |
244 | return loss
245 |
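# A minimal usage sketch (illustrative only) of AdamW with its built-in linear warmup:
# with warmup=500, the effective learning rate ramps roughly linearly from ~0 to lr
# over the first 500 optimizer steps, then stays at lr. The helper name and the toy
# model are examples, not part of this repository's training code.
def _example_adamw_usage():
    import torch.nn as nn
    model = nn.Linear(10, 1)                                   # toy stand-in model
    optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2, warmup=500)
    for _ in range(3):                                         # a few dummy steps
        loss = model(torch.randn(8, 10)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
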
246 |
247 | #Ranger deep learning optimizer - RAdam + Lookahead combined.
248 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
249 |
250 | #Ranger has now been used to capture 12 records on the FastAI leaderboard.
251 |
252 | #This version = 9.3.19
253 |
254 | #Credits:
255 | #RAdam --> https://github.com/LiyuanLucasLiu/RAdam
256 | #Lookahead --> rewritten by lessw2020, with big thanks to GitHub users @LonePatient and @RWightman for ideas from their code.
257 | #Lookahead paper --> M. Zhang, G. Hinton, https://arxiv.org/abs/1907.08610
258 |
259 | #Summary of changes:
260 | #full code integration with all updates at the param level instead of the group level; slow weights moved into the state dict (instead of generic weights);
261 | #supports per-group learning rates (thanks @SHolderbach); fixes sporadic load-from-saved-model issues.
262 | #changes 8/31/19 - fix references to *self*.N_sma_threshhold;
263 | #changed the default eps to 1e-5, which works better than 1e-8.
264 |
265 |
266 | class Ranger(Optimizer):
267 |
268 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0):
269 | #parameter checks
270 | if not 0.0 <= alpha <= 1.0:
271 | raise ValueError(f'Invalid slow update rate: {alpha}')
272 | if not 1 <= k:
273 | raise ValueError(f'Invalid lookahead steps: {k}')
274 | if not lr > 0:
275 | raise ValueError(f'Invalid Learning Rate: {lr}')
276 | if not eps > 0:
277 | raise ValueError(f'Invalid eps: {eps}')
278 |
279 |         #parameter comments:
280 |         #beta1 (momentum) of .95 seems to work better than .90...
281 |         #N_sma_threshhold of 5 seems better in testing than 4.
282 |         #In both cases it is worth testing on your dataset (.90 vs .95, 4 vs 5) to see which works best for you; a minimal usage sketch follows the class definition at the end of this file.
283 |
284 | #prep defaults and init torch.optim base
285 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
286 | super().__init__(params,defaults)
287 |
288 | #adjustable threshold
289 | self.N_sma_threshhold = N_sma_threshhold
290 |
291 | #now we can get to work...
292 | #removed as we now use step from RAdam...no need for duplicate step counting
293 | #for group in self.param_groups:
294 | # group["step_counter"] = 0
295 | #print("group step counter init")
296 |
297 | #look ahead params
298 | self.alpha = alpha
299 | self.k = k
300 |
301 | #radam buffer for state
302 | self.radam_buffer = [[None,None,None] for ind in range(10)]
303 |
304 | #self.first_run_check=0
305 |
306 | #lookahead weights
307 | #9/2/19 - lookahead param tensors have been moved to state storage.
308 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
309 |
310 | #self.slow_weights = [[p.clone().detach() for p in group['params']]
311 | # for group in self.param_groups]
312 |
313 | #don't use grad for lookahead weights
314 | #for w in it.chain(*self.slow_weights):
315 | # w.requires_grad = False
316 |
317 | def __setstate__(self, state):
318 | print("set state called")
319 | super(Ranger, self).__setstate__(state)
320 |
321 |
322 | def step(self, closure=None):
323 | loss = None
324 |         #Note: the closure call below is commented out because other code passes the loss back as a float rather than as a callable closure.
325 |         #Uncomment it if you need to use an actual closure...
326 |
327 | #if closure is not None:
328 | #loss = closure()
329 |
330 | #Evaluate averages and grad, update param tensors
331 | for group in self.param_groups:
332 |
333 | for p in group['params']:
334 | if p.grad is None:
335 | continue
336 | grad = p.grad.data.float()
337 | if grad.is_sparse:
338 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
339 |
340 | p_data_fp32 = p.data.float()
341 |
342 | state = self.state[p] #get state dict for this param
343 |
344 | if len(state) == 0: #if first time to run...init dictionary with our desired entries
345 | #if self.first_run_check==0:
346 | #self.first_run_check=1
347 | #print("Initializing slow buffer...should not see this at load from saved model!")
348 | state['step'] = 0
349 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
350 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
351 |
352 | #look ahead weight storage now in state dict
353 | state['slow_buffer'] = torch.empty_like(p.data)
354 | state['slow_buffer'].copy_(p.data)
355 |
356 | else:
357 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
358 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
359 |
360 | #begin computations
361 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
362 | beta1, beta2 = group['betas']
363 |
364 | #compute variance mov avg
365 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
366 | #compute mean moving avg
367 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
368 |
369 | state['step'] += 1
370 |
371 |
372 | buffered = self.radam_buffer[int(state['step'] % 10)]
373 | if state['step'] == buffered[0]:
374 | N_sma, step_size = buffered[1], buffered[2]
375 | else:
376 | buffered[0] = state['step']
377 | beta2_t = beta2 ** state['step']
378 | N_sma_max = 2 / (1 - beta2) - 1
379 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
380 | buffered[1] = N_sma
381 | if N_sma > self.N_sma_threshhold:
382 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
383 | else:
384 | step_size = 1.0 / (1 - beta1 ** state['step'])
385 | buffered[2] = step_size
386 |
387 | if group['weight_decay'] != 0:
388 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
389 |
390 | if N_sma > self.N_sma_threshhold:
391 | denom = exp_avg_sq.sqrt().add_(group['eps'])
392 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
393 | else:
394 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
395 |
396 | p.data.copy_(p_data_fp32)
397 |
398 | #integrated look ahead...
399 | #we do it at the param level instead of group level
400 | if state['step'] % group['k'] == 0:
401 | slow_p = state['slow_buffer'] #get access to slow param tensor
402 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha
403 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
404 |
405 | return loss
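
# A minimal usage sketch (illustrative only): Ranger combines the RAdam update above
# with Lookahead, interpolating towards slow weights every k steps with factor alpha.
# The helper name and the toy model are examples, not part of this repository's
# training code.
def _example_ranger_usage():
    import torch.nn as nn
    model = nn.Linear(10, 1)                                   # toy stand-in model
    optimizer = Ranger(model.parameters(), lr=5e-4, alpha=0.5, k=6)
    for _ in range(12):                                        # two Lookahead cycles (k=6)
        loss = model(torch.randn(8, 10)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                                       # slow weights updated every k steps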
--------------------------------------------------------------------------------
/utils/save_weights_only.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 |
4 | def get_opts():
5 | parser = argparse.ArgumentParser()
6 |
7 | parser.add_argument('--ckpt_path', type=str, required=True,
8 | help='checkpoint path')
9 |
10 | return parser.parse_args()
11 |
12 | if __name__ == "__main__":
13 | args = get_opts()
14 | checkpoint = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
15 | torch.save(checkpoint['state_dict'], args.ckpt_path.split('/')[-2]+'.ckpt')
16 | print('Done!')
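
# Example invocation (the checkpoint path below is hypothetical -- substitute your own):
#
#     python utils/save_weights_only.py --ckpt_path ckpts/exp_name/epoch=15.ckpt
#
# This writes only checkpoint['state_dict'] (the model weights) to 'exp_name.ckpt' in
# the current directory; the output name is taken from the parent folder of --ckpt_path.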
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | import numpy as np
3 | import cv2
4 | from PIL import Image
5 |
6 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET):
7 |     """
8 |     depth: (H, W) depth map (torch tensor). Returns a (3, H, W) color-mapped tensor in [0, 1].
9 |     """
10 | x = depth.cpu().numpy()
11 | x = np.nan_to_num(x) # change nan to 0
12 | mi = np.min(x) # get minimum depth
13 | ma = np.max(x)
14 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
15 | x = (255*x).astype(np.uint8)
16 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
17 | x_ = T.ToTensor()(x_) # (3, H, W)
18 | return x_
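
# A minimal usage sketch (illustrative only): the random depth map below is a stand-in
# for a rendered depth prediction. The helper name is an example, not used elsewhere.
def _example_visualize_depth():
    import torch
    depth = torch.rand(400, 400) * 5           # fake (H, W) depth in arbitrary units
    img = visualize_depth(depth)               # (3, H, W) float tensor in [0, 1]
    return img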
--------------------------------------------------------------------------------
/utils/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 |     """Gradually warm up (i.e. increase) the learning rate in the optimizer.
6 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 |     Args:
8 |         optimizer (Optimizer): wrapped optimizer.
9 |         multiplier: target learning rate = base lr * multiplier.
10 |         total_epoch: the learning rate is increased gradually and reaches the target at total_epoch.
11 |         after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau); see the usage sketch at the end of this file.
12 |     """
13 |
14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
15 | self.multiplier = multiplier
16 | if self.multiplier < 1.:
17 |             raise ValueError('multiplier should be greater than or equal to 1.')
18 | self.total_epoch = total_epoch
19 | self.after_scheduler = after_scheduler
20 | self.finished = False
21 | super().__init__(optimizer)
22 |
23 | def get_lr(self):
24 | if self.last_epoch > self.total_epoch:
25 | if self.after_scheduler:
26 | if not self.finished:
27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
28 | self.finished = True
29 | return self.after_scheduler.get_lr()
30 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
31 |
32 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
33 |
34 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
35 | if epoch is None:
36 | epoch = self.last_epoch + 1
37 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
38 | if self.last_epoch <= self.total_epoch:
39 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
40 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
41 | param_group['lr'] = lr
42 | else:
43 | if epoch is None:
44 | self.after_scheduler.step(metrics, None)
45 | else:
46 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
47 |
48 | def step(self, epoch=None, metrics=None):
49 | if type(self.after_scheduler) != ReduceLROnPlateau:
50 | if self.finished and self.after_scheduler:
51 | if epoch is None:
52 | self.after_scheduler.step(None)
53 | else:
54 | self.after_scheduler.step(epoch - self.total_epoch)
55 | else:
56 | return super(GradualWarmupScheduler, self).step(epoch)
57 | else:
58 | self.step_ReduceLROnPlateau(metrics, epoch)
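
# A minimal usage sketch (illustrative only): warm up from 5e-5 to 5e-4 (= base lr *
# multiplier) over the first 2 epochs, then hand over to a cosine schedule. The helper
# name, the optimizer, the epoch counts and the cosine schedule are examples, not this
# repository's actual training configuration.
def _example_gradual_warmup_usage():
    import torch
    from torch.optim import Adam
    from torch.optim.lr_scheduler import CosineAnnealingLR

    params = [torch.nn.Parameter(torch.zeros(3))]
    optimizer = Adam(params, lr=5e-5)
    cosine = CosineAnnealingLR(optimizer, T_max=16)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=10,
                                       total_epoch=2, after_scheduler=cosine)
    for epoch in range(16):
        # ... run one training epoch here ...
        scheduler.step()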
--------------------------------------------------------------------------------