├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── ScanNet.md ├── camera_pose_visualizer.py ├── configs ├── chair.txt ├── drums.txt ├── fern.txt ├── ficus.txt ├── flower.txt ├── fortress.txt ├── horns.txt ├── hotdog.txt ├── leaves.txt ├── lego.txt ├── materials.txt ├── mic.txt ├── orchids.txt ├── room.txt ├── scannet_scene0000.txt ├── ship.txt └── trex.txt ├── hash_encoding.py ├── load_LINEMOD.py ├── load_blender.py ├── load_deepvoxels.py ├── load_llff.py ├── load_scannet.py ├── loss.py ├── optimizer.py ├── radam.py ├── ray_utils.py ├── run_nerf.py ├── run_nerf_helpers.py ├── scripts ├── arial.ttf ├── make_gif.py ├── plot_losses.py └── run_all_checkpoints.sh ├── train.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Datasets 132 | data/ 133 | logs/ 134 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 
2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: HashNeRF-pytorch 6 | message: >- 7 | If you use this software, please cite it using the 8 | metadata from this file. 9 | type: software 10 | authors: 11 | - given-names: Yash 12 | family-names: Bhalgat 13 | email: yashbhalgat95@gmail.com 14 | affiliation: University of Oxford 15 | orcid: 'https://orcid.org/0000-0001-7775-6250' 16 | url: 'https://github.com/yashbhalgat/HashNeRF-pytorch' 17 | abstract: >- 18 | HashNeRF-pytorch is a pure PyTorch implementation of the 19 | NVIDIA paper on Instant Training of Neural Graphics 20 | primitives (Instant-NGP). This codebase was built with the 21 | purpose of enabling AI Researchers to play around and 22 | innovate further upon this method. 23 | keywords: 24 | - machine learning 25 | - artificial intelligence 26 | - computer vision 27 | - computer graphics 28 | - nerf 29 | - 3D reconstruction 30 | - neural rendering 31 | license: MIT 32 | version: '1.0' 33 | date-released: '2022-06-01' 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yash Sanjay Bhalgat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HashNeRF-pytorch 2 | 3 | ### 🌟 Update 🌟 4 | Get answers to any questions about this repository using this [HuggingFace Chatbot](https://hf.co/chat/assistant/66b33a28bb36e2de9d8a2a93). 5 | 6 | --- 7 | 8 | [Instant-NGP](https://github.com/NVlabs/instant-ngp) recently introduced a Multi-resolution Hash Encoding for neural graphics primitives like [NeRFs](https://www.matthewtancik.com/nerf). The original NVIDIA implementation, written mainly in C++/CUDA and based on [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), can train NeRFs up to 100x faster! 9 | 10 | This project is a **pure PyTorch** implementation of [Instant-NGP](https://github.com/NVlabs/instant-ngp), built with the purpose of enabling AI Researchers to play around and innovate further upon this method. 11 | 12 | This project is built on top of the super-useful [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch) implementation. 13 | 14 | ## Convergence speed w.r.t. 
Vanilla NeRF 15 | **HashNeRF-pytorch** (left) vs [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch) (right): 16 | 17 | https://user-images.githubusercontent.com/8559512/154065666-f2eb156c-333c-4de4-99aa-8aa15a9254de.mp4 18 | 19 | After training for just 5k iterations (~10 minutes on a single 1050Ti), you start seeing a _crisp_ chair rendering. :) 20 | 21 | # Instructions 22 | Download the nerf-synthetic dataset from here: [Google Drive](https://drive.google.com/drive/folders/1JDdLGDruGNXWnM1eqY1FNL9PlStjaKWi). 23 | 24 | To train a `chair` HashNeRF model: 25 | ``` 26 | python run_nerf.py --config configs/chair.txt --finest_res 512 --log2_hashmap_size 19 --lrate 0.01 --lrate_decay 10 27 | ``` 28 | 29 | To train for other objects like `ficus`/`hotdog`, replace `configs/chair.txt` with `configs/{object}.txt`: 30 | 31 | ![hotdog_ficus](https://user-images.githubusercontent.com/8559512/154066554-d3656d4a-1738-427c-982d-3ef4e4071969.gif) 32 | 33 | ## Extras 34 | The code-base has additional support for: 35 | * Total Variation Loss for smoother embeddings (use `--tv-loss-weight` to enable) 36 | * Sparsity-inducing loss on the ray weights (use `--sparse-loss-weight` to enable) 37 | 38 | ## ScanNet dataset support 39 | The repo now supports training a NeRF model on a scene from the ScanNet dataset. I personally found setting up the ScanNet dataset to be a bit tricky. Please find some instructions/notes in [ScanNet.md](ScanNet.md). 40 | 41 | 42 | ## TODO: 43 | * Voxel pruning during training and/or inference 44 | * Accelerated ray tracing, early ray termination 45 | 46 | 47 | # Citation 48 | Kudos to [Thomas Müller](https://tom94.net/) and the NVIDIA team for this amazing work, which will greatly help accelerate Neural Graphics research: 49 | ``` 50 | @article{mueller2022instant, 51 | title = {Instant Neural Graphics Primitives with a Multiresolution Hash Encoding}, 52 | author = {Thomas M\"uller and Alex Evans and Christoph Schied and Alexander Keller}, 53 | journal = {arXiv:2201.05989}, 54 | year = {2022}, 55 | month = jan 56 | } 57 | ``` 58 | 59 | Also, thanks to [Yen-Chen Lin](https://yenchenlin.me/) for the super-useful [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch): 60 | ``` 61 | @misc{lin2020nerfpytorch, 62 | title={NeRF-pytorch}, 63 | author={Yen-Chen, Lin}, 64 | publisher = {GitHub}, 65 | journal = {GitHub repository}, 66 | howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}}, 67 | year={2020} 68 | } 69 | ``` 70 | 71 | If you find this project useful, please consider citing: 72 | ``` 73 | @misc{bhalgat2022hashnerfpytorch, 74 | title={HashNeRF-pytorch}, 75 | author={Yash Bhalgat}, 76 | publisher = {GitHub}, 77 | journal = {GitHub repository}, 78 | howpublished={\url{https://github.com/yashbhalgat/HashNeRF-pytorch/}}, 79 | year={2022} 80 | } 81 | ``` 82 | 83 | ## Star History 84 | 85 | [![Star History Chart](https://api.star-history.com/svg?repos=yashbhalgat/HashNeRF-pytorch&type=Date)](https://star-history.com/#yashbhalgat/HashNeRF-pytorch&Date) 86 | -------------------------------------------------------------------------------- /ScanNet.md: -------------------------------------------------------------------------------- 1 | # ScanNet Instructions 2 | 3 | I personally found it a bit tricky to set up the ScanNet dataset the first time I tried it. So, I am compiling some notes/instructions on how to do it in case someone finds it useful. 4 | 5 | ### 1. 
Dataset download 6 | 7 | To download ScanNet data and its labels, follow the instructions [here](https://github.com/ScanNet/ScanNet). Basically, fill out the ScanNet Terms of Use agreement and email it to [scannet@googlegroups.com](mailto:scannet@googlegroups.com). You will receive a download link to the dataset. Download the dataset and unzip it. 8 | 9 | ### 2. Use [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) to extract RGB-D and camera data 10 | Use the `reader.py` script as follows for each scene you want to work with: 11 | ``` 12 | python reader.py --filename [.sens file to export data from] --output_path [output directory to export data to] 13 | Options: 14 | --export_depth_images: export all depth frames as 16-bit pngs (depth shift 1000) 15 | --export_color_images: export all color frames as 8-bit rgb jpgs 16 | --export_poses: export all camera poses (4x4 matrix, camera to world) 17 | --export_intrinsics: export camera intrinsics (4x4 matrix) 18 | ``` 19 | 20 | ### 3. Then, use this [script](https://github.com/zju3dv/object_nerf/blob/main/data_preparation/scannet_sens_reader/convert_to_nerf_style_data.py) to convert the data to NeRF-style format. For instructions, see Step 1 [here](https://github.com/zju3dv/object_nerf/tree/main/data_preparation). 21 | 1. The transformation matrices (`c2w`) in the generated transforms_xxx.json will be in SLAM / OpenCV format (xyz -> right down forward). You need to change to NDC format (xyz -> right up back) in the dataloader for training with NeRF convention. 22 | 2. For example, see the conversion done [here](https://github.com/cvg/nice-slam/blob/7af15cc33729aa5a8ca052908d96f495e34ab34c/src/utils/datasets.py#L205). 23 | -------------------------------------------------------------------------------- /camera_pose_visualizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | from matplotlib.patches import Patch 7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 8 | import pdb 9 | 10 | class CameraPoseVisualizer: 11 | def __init__(self, xlim, ylim, zlim): 12 | self.fig = plt.figure(figsize=(18, 7)) 13 | self.ax = self.fig.gca(projection='3d') 14 | self.ax.set_aspect("auto") 15 | self.ax.set_xlim(xlim) 16 | self.ax.set_ylim(ylim) 17 | self.ax.set_zlim(zlim) 18 | self.ax.set_xlabel('x') 19 | self.ax.set_ylabel('y') 20 | self.ax.set_zlabel('z') 21 | print('initialize camera pose visualizer') 22 | 23 | def extrinsic2pyramid(self, extrinsic, color='r', focal_len_scaled=5, aspect_ratio=0.3): 24 | focal_len_scaled = -1*focal_len_scaled 25 | vertex_std = np.array([[0, 0, 0, 1], 26 | [focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 27 | [focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 28 | [-focal_len_scaled * aspect_ratio, focal_len_scaled * aspect_ratio, focal_len_scaled, 1], 29 | [-focal_len_scaled * aspect_ratio, -focal_len_scaled * aspect_ratio, focal_len_scaled, 1]]) 30 | vertex_transformed = vertex_std @ extrinsic.T 31 | meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]], 32 | [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], 33 | [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], 34 | [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, 
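# Annotation on extrinsic2pyramid (derived from the surrounding code, not original to it):
# the `meshes` list being built here holds the camera-frustum faces: four triangles that join
# the pyramid apex (vertex 0, the camera centre) to pairs of adjacent image-plane corners,
# followed by the quadrilateral base spanning all four corners. `focal_len_scaled` is negated
# at the top of this method so the frustum extends along the camera's -z axis, i.e. the viewing
# direction for the OpenGL-style Blender poses loaded in __main__.
# Compatibility note (hedged, version-dependent): on matplotlib >= 3.6 the
# `self.fig.gca(projection='3d')` call in __init__ may need to become
# `self.fig.add_subplot(projection='3d')`, since gca() no longer accepts a projection argument.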
:-1]], 35 | [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] 36 | self.ax.add_collection3d( 37 | Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35)) 38 | 39 | def customize_legend(self, list_label): 40 | list_handle = [] 41 | for idx, label in enumerate(list_label): 42 | color = plt.cm.rainbow(idx / len(list_label)) 43 | patch = Patch(color=color, label=label) 44 | list_handle.append(patch) 45 | plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle) 46 | 47 | def colorbar(self, max_frame_length): 48 | cmap = mpl.cm.rainbow 49 | norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length) 50 | self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), orientation='vertical', label='Frame Number') 51 | 52 | def show(self): 53 | plt.title('Extrinsic Parameters') 54 | plt.show() 55 | 56 | if __name__ == '__main__': 57 | poses = [] 58 | with open(os.path.join('data/nerf_synthetic/chair/', 'transforms_train.json'), 'r') as fp: 59 | meta = json.load(fp) 60 | for frame in meta['frames']: 61 | poses.append(np.array(frame['transform_matrix'])) 62 | t_arr = np.array([pose[:3,-1] for pose in poses]) 63 | maxes = t_arr.max(axis=0) 64 | mins = t_arr.min(axis=0) 65 | 66 | # argument : the minimum/maximum value of x, y, z 67 | visualizer = CameraPoseVisualizer([mins[0]-1, maxes[0]+1], [mins[1]-1, maxes[1]+1], [mins[2]-1, maxes[2]+1]) 68 | 69 | # argument : extrinsic matrix, color, scaled focal length(z-axis length of frame body of camera 70 | for pose in poses: 71 | visualizer.extrinsic2pyramid(pose, 'c', 1) 72 | 73 | visualizer.show() 74 | -------------------------------------------------------------------------------- /configs/chair.txt: -------------------------------------------------------------------------------- 1 | expname = blender_chair 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/chair 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/drums.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_drums 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/drums 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/fern.txt: -------------------------------------------------------------------------------- 1 | expname = fern_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fern 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/ficus.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ficus 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/ficus 4 | dataset_type = blender 5 | 6 
| no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/flower.txt: -------------------------------------------------------------------------------- 1 | expname = flower_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/flower 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/fortress.txt: -------------------------------------------------------------------------------- 1 | expname = fortress_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fortress 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/horns.txt: -------------------------------------------------------------------------------- 1 | expname = horns_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/horns 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/hotdog.txt: -------------------------------------------------------------------------------- 1 | expname = blender_hotdog 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/hotdog 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | 21 | -------------------------------------------------------------------------------- /configs/leaves.txt: -------------------------------------------------------------------------------- 1 | expname = leaves_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/leaves 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/lego.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_lego 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/lego 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/materials.txt: -------------------------------------------------------------------------------- 1 | expname = blender_materials 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/materials 4 | dataset_type = blender 5 | 6 | no_batching = 
True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/mic.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_mic 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/mic 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/orchids.txt: -------------------------------------------------------------------------------- 1 | expname = orchids_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/orchids 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/room.txt: -------------------------------------------------------------------------------- 1 | expname = room_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/room 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/scannet_scene0000.txt: -------------------------------------------------------------------------------- 1 | expname = scannet_scene0000_00 2 | basedir = ./logs 3 | datadir = /work/yashsb/datasets/ScanNet/ 4 | dataset_type = scannet 5 | 6 | no_batching = False 7 | 8 | use_viewdirs = True 9 | white_bkgd = False 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | -------------------------------------------------------------------------------- /configs/ship.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ship 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/ship 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/trex.txt: -------------------------------------------------------------------------------- 1 | expname = trex_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/trex 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /hash_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 
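# Summary of the multi-resolution hash encoder defined below (a reading of the code, kept as
# comments so it does not alter module behaviour): HashEmbedder keeps n_levels separate hash
# tables (nn.Embedding layers), each with 2**log2_hashmap_size entries of n_features_per_level
# features. Level i uses a grid of resolution floor(base_resolution * b**i), where the per-level
# growth factor is
#     b = exp((ln(finest_resolution) - ln(base_resolution)) / (n_levels - 1)).
# With the defaults (16 levels, base 16, finest 512), b = 32**(1/15) ~= 1.26, so the grid
# resolutions grow roughly as 16, 20, 25, ..., 512. A queried 3D point is trilinearly
# interpolated from its 8 hashed voxel-corner embeddings at every level, and the per-level
# features are concatenated into a vector of length n_levels * n_features_per_level (32 here).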
import pdb 7 | 8 | from utils import get_voxel_vertices 9 | 10 | class HashEmbedder(nn.Module): 11 | def __init__(self, bounding_box, n_levels=16, n_features_per_level=2,\ 12 | log2_hashmap_size=19, base_resolution=16, finest_resolution=512): 13 | super(HashEmbedder, self).__init__() 14 | self.bounding_box = bounding_box 15 | self.n_levels = n_levels 16 | self.n_features_per_level = n_features_per_level 17 | self.log2_hashmap_size = log2_hashmap_size 18 | self.base_resolution = torch.tensor(base_resolution) 19 | self.finest_resolution = torch.tensor(finest_resolution) 20 | self.out_dim = self.n_levels * self.n_features_per_level 21 | 22 | self.b = torch.exp((torch.log(self.finest_resolution)-torch.log(self.base_resolution))/(n_levels-1)) 23 | 24 | self.embeddings = nn.ModuleList([nn.Embedding(2**self.log2_hashmap_size, \ 25 | self.n_features_per_level) for i in range(n_levels)]) 26 | # custom uniform initialization 27 | for i in range(n_levels): 28 | nn.init.uniform_(self.embeddings[i].weight, a=-0.0001, b=0.0001) 29 | # self.embeddings[i].weight.data.zero_() 30 | 31 | 32 | def trilinear_interp(self, x, voxel_min_vertex, voxel_max_vertex, voxel_embedds): 33 | ''' 34 | x: B x 3 35 | voxel_min_vertex: B x 3 36 | voxel_max_vertex: B x 3 37 | voxel_embedds: B x 8 x 2 38 | ''' 39 | # source: https://en.wikipedia.org/wiki/Trilinear_interpolation 40 | weights = (x - voxel_min_vertex)/(voxel_max_vertex-voxel_min_vertex) # B x 3 41 | 42 | # step 1 43 | # 0->000, 1->001, 2->010, 3->011, 4->100, 5->101, 6->110, 7->111 44 | c00 = voxel_embedds[:,0]*(1-weights[:,0][:,None]) + voxel_embedds[:,4]*weights[:,0][:,None] 45 | c01 = voxel_embedds[:,1]*(1-weights[:,0][:,None]) + voxel_embedds[:,5]*weights[:,0][:,None] 46 | c10 = voxel_embedds[:,2]*(1-weights[:,0][:,None]) + voxel_embedds[:,6]*weights[:,0][:,None] 47 | c11 = voxel_embedds[:,3]*(1-weights[:,0][:,None]) + voxel_embedds[:,7]*weights[:,0][:,None] 48 | 49 | # step 2 50 | c0 = c00*(1-weights[:,1][:,None]) + c10*weights[:,1][:,None] 51 | c1 = c01*(1-weights[:,1][:,None]) + c11*weights[:,1][:,None] 52 | 53 | # step 3 54 | c = c0*(1-weights[:,2][:,None]) + c1*weights[:,2][:,None] 55 | 56 | return c 57 | 58 | def forward(self, x): 59 | # x is 3D point position: B x 3 60 | x_embedded_all = [] 61 | for i in range(self.n_levels): 62 | resolution = torch.floor(self.base_resolution * self.b**i) 63 | voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask = get_voxel_vertices(\ 64 | x, self.bounding_box, \ 65 | resolution, self.log2_hashmap_size) 66 | 67 | voxel_embedds = self.embeddings[i](hashed_voxel_indices) 68 | 69 | x_embedded = self.trilinear_interp(x, voxel_min_vertex, voxel_max_vertex, voxel_embedds) 70 | x_embedded_all.append(x_embedded) 71 | 72 | keep_mask = keep_mask.sum(dim=-1)==keep_mask.shape[-1] 73 | return torch.cat(x_embedded_all, dim=-1), keep_mask 74 | 75 | 76 | class SHEncoder(nn.Module): 77 | def __init__(self, input_dim=3, degree=4): 78 | 79 | super().__init__() 80 | 81 | self.input_dim = input_dim 82 | self.degree = degree 83 | 84 | assert self.input_dim == 3 85 | assert self.degree >= 1 and self.degree <= 5 86 | 87 | self.out_dim = degree ** 2 88 | 89 | self.C0 = 0.28209479177387814 90 | self.C1 = 0.4886025119029199 91 | self.C2 = [ 92 | 1.0925484305920792, 93 | -1.0925484305920792, 94 | 0.31539156525252005, 95 | -1.0925484305920792, 96 | 0.5462742152960396 97 | ] 98 | self.C3 = [ 99 | -0.5900435899266435, 100 | 2.890611442640554, 101 | -0.4570457994644658, 102 | 0.3731763325901154, 103 | -0.4570457994644658, 104 | 
1.445305721320277, 105 | -0.5900435899266435 106 | ] 107 | self.C4 = [ 108 | 2.5033429417967046, 109 | -1.7701307697799304, 110 | 0.9461746957575601, 111 | -0.6690465435572892, 112 | 0.10578554691520431, 113 | -0.6690465435572892, 114 | 0.47308734787878004, 115 | -1.7701307697799304, 116 | 0.6258357354491761 117 | ] 118 | 119 | def forward(self, input, **kwargs): 120 | 121 | result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device) 122 | x, y, z = input.unbind(-1) 123 | 124 | result[..., 0] = self.C0 125 | if self.degree > 1: 126 | result[..., 1] = -self.C1 * y 127 | result[..., 2] = self.C1 * z 128 | result[..., 3] = -self.C1 * x 129 | if self.degree > 2: 130 | xx, yy, zz = x * x, y * y, z * z 131 | xy, yz, xz = x * y, y * z, x * z 132 | result[..., 4] = self.C2[0] * xy 133 | result[..., 5] = self.C2[1] * yz 134 | result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy) 135 | #result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting... 136 | result[..., 7] = self.C2[3] * xz 137 | result[..., 8] = self.C2[4] * (xx - yy) 138 | if self.degree > 3: 139 | result[..., 9] = self.C3[0] * y * (3 * xx - yy) 140 | result[..., 10] = self.C3[1] * xy * z 141 | result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy) 142 | result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy) 143 | result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy) 144 | result[..., 14] = self.C3[5] * z * (xx - yy) 145 | result[..., 15] = self.C3[6] * x * (xx - 3 * yy) 146 | if self.degree > 4: 147 | result[..., 16] = self.C4[0] * xy * (xx - yy) 148 | result[..., 17] = self.C4[1] * yz * (3 * xx - yy) 149 | result[..., 18] = self.C4[2] * xy * (7 * zz - 1) 150 | result[..., 19] = self.C4[3] * yz * (7 * zz - 3) 151 | result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3) 152 | result[..., 21] = self.C4[5] * xz * (7 * zz - 3) 153 | result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1) 154 | result[..., 23] = self.C4[7] * xz * (xx - 3 * yy) 155 | result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 156 | 157 | return result 158 | -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 
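# Frame subsampling (annotation): the train split, and any split when testskip == 0, keeps
# every frame; the val/test splits keep only every `testskip`-th frame, as set just below.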
| imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | from utils import get_bbox3d_for_blenderobj 10 | 11 | trans_t = lambda t : torch.Tensor([ 12 | [1,0,0,0], 13 | [0,1,0,0], 14 | [0,0,1,t], 15 | [0,0,0,1]]).float() 16 | 17 | rot_phi = lambda phi : torch.Tensor([ 18 | [1,0,0,0], 19 | [0,np.cos(phi),-np.sin(phi),0], 20 | [0,np.sin(phi), np.cos(phi),0], 21 | [0,0,0,1]]).float() 22 | 23 | rot_theta = lambda th : torch.Tensor([ 24 | [np.cos(th),0,-np.sin(th),0], 25 | [0,1,0,0], 26 | [np.sin(th),0, np.cos(th),0], 27 | [0,0,0,1]]).float() 28 | 29 | 30 | def pose_spherical(theta, phi, radius): 31 | c2w = trans_t(radius) 32 | c2w = rot_phi(phi/180.*np.pi) @ c2w 33 | c2w = rot_theta(theta/180.*np.pi) @ c2w 34 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 35 | return c2w 36 | 37 | 38 | def load_blender_data(basedir, half_res=False, testskip=1): 39 | splits = ['train', 'val', 'test'] 40 | metas = {} 41 | for s in splits: 42 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 43 | metas[s] = json.load(fp) 44 | 45 | all_imgs = [] 46 | all_poses = [] 47 | counts = [0] 48 | for s in splits: 49 | meta = metas[s] 50 | imgs = [] 51 | poses = [] 52 | if s=='train' or testskip==0: 53 | skip = 1 54 | else: 55 | skip = testskip 56 | 57 | for frame in meta['frames'][::skip]: 58 | fname = os.path.join(basedir, frame['file_path'] + '.png') 59 | imgs.append(imageio.imread(fname)) 60 | poses.append(np.array(frame['transform_matrix'])) 61 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 62 | poses = np.array(poses).astype(np.float32) 63 | 
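# Split bookkeeping (annotation): `counts` accumulates the cumulative number of frames per
# split, so that `i_split` (built right after this loop) yields three index arrays, one each
# for the train / val / test portions of the concatenated `imgs` and `poses` arrays.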
counts.append(counts[-1] + imgs.shape[0]) 64 | all_imgs.append(imgs) 65 | all_poses.append(poses) 66 | 67 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 68 | 69 | imgs = np.concatenate(all_imgs, 0) 70 | poses = np.concatenate(all_poses, 0) 71 | 72 | H, W = imgs[0].shape[:2] 73 | camera_angle_x = float(meta['camera_angle_x']) 74 | focal = .5 * W / np.tan(.5 * camera_angle_x) 75 | 76 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 77 | 78 | if half_res: 79 | H = H//2 80 | W = W//2 81 | focal = focal/2. 82 | 83 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 84 | for i, img in enumerate(imgs): 85 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 86 | imgs = imgs_half_res 87 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 88 | 89 | bounding_box = get_bbox3d_for_blenderobj(metas["train"], H, W, near=2.0, far=6.0) 90 | 91 | return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, 
scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | from utils import get_bbox3d_for_llff 5 | 6 | ########## Slightly modified version of LLFF data loading code 7 | ########## see https://github.com/Fyusion/LLFF for original 8 | 9 | def _minify(basedir, factors=[], resolutions=[]): 10 | needtoload = False 11 | for r in factors: 12 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 13 | if not os.path.exists(imgdir): 14 | needtoload = True 15 | for r in resolutions: 16 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 17 | if not os.path.exists(imgdir): 18 | needtoload = True 19 | if not needtoload: 20 | return 21 | 22 | from shutil import copy 23 | from subprocess import check_output 24 | 25 | imgdir = os.path.join(basedir, 'images') 26 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 27 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 28 | imgdir_orig = imgdir 29 | 30 | wd = os.getcwd() 31 | 32 | for r in factors + resolutions: 33 | if isinstance(r, int): 34 | name = 'images_{}'.format(r) 35 | resizearg = '{}%'.format(100./r) 36 | else: 37 | name = 'images_{}x{}'.format(r[1], r[0]) 38 | resizearg = '{}x{}'.format(r[1], r[0]) 39 | imgdir = os.path.join(basedir, name) 40 | if os.path.exists(imgdir): 41 | continue 42 | 43 | print('Minifying', r, basedir) 44 | 45 | os.makedirs(imgdir) 46 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 47 | 48 | ext = imgs[0].split('.')[-1] 49 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 50 | print(args) 51 | os.chdir(imgdir) 52 | check_output(args, shell=True) 53 | os.chdir(wd) 54 | 55 | if ext != 'png': 56 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 57 | print('Removed duplicates') 58 | print('Done') 59 | 60 | 61 | 62 | 63 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 64 | 65 | poses_arr = 
np.load(os.path.join(basedir, 'poses_bounds.npy')) 66 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) 67 | bds = poses_arr[:, -2:].transpose([1,0]) 68 | 69 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 70 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 71 | sh = imageio.imread(img0).shape 72 | 73 | sfx = '' 74 | 75 | if factor is not None: 76 | sfx = '_{}'.format(factor) 77 | _minify(basedir, factors=[factor]) 78 | factor = factor 79 | elif height is not None: 80 | factor = sh[0] / float(height) 81 | width = int(sh[1] / factor) 82 | _minify(basedir, resolutions=[[height, width]]) 83 | sfx = '_{}x{}'.format(width, height) 84 | elif width is not None: 85 | factor = sh[1] / float(width) 86 | height = int(sh[0] / factor) 87 | _minify(basedir, resolutions=[[height, width]]) 88 | sfx = '_{}x{}'.format(width, height) 89 | else: 90 | factor = 1 91 | 92 | imgdir = os.path.join(basedir, 'images' + sfx) 93 | if not os.path.exists(imgdir): 94 | print( imgdir, 'does not exist, returning' ) 95 | return 96 | 97 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 98 | if poses.shape[-1] != len(imgfiles): 99 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 100 | return 101 | 102 | sh = imageio.imread(imgfiles[0]).shape 103 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 104 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 105 | 106 | if not load_imgs: 107 | return poses, bds 108 | 109 | def imread(f): 110 | if f.endswith('png'): 111 | return imageio.imread(f, ignoregamma=True) 112 | else: 113 | return imageio.imread(f) 114 | 115 | imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] 116 | imgs = np.stack(imgs, -1) 117 | 118 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 119 | return poses, bds, imgs 120 | 121 | 122 | 123 | 124 | 125 | 126 | def normalize(x): 127 | return x / np.linalg.norm(x) 128 | 129 | def viewmatrix(z, up, pos): 130 | vec2 = normalize(z) 131 | vec1_avg = up 132 | vec0 = normalize(np.cross(vec1_avg, vec2)) 133 | vec1 = normalize(np.cross(vec2, vec0)) 134 | m = np.stack([vec0, vec1, vec2, pos], 1) 135 | return m 136 | 137 | def ptstocam(pts, c2w): 138 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 139 | return tt 140 | 141 | def poses_avg(poses): 142 | 143 | hwf = poses[0, :3, -1:] 144 | 145 | center = poses[:, :3, 3].mean(0) 146 | vec2 = normalize(poses[:, :3, 2].sum(0)) 147 | up = poses[:, :3, 1].sum(0) 148 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 149 | 150 | return c2w 151 | 152 | 153 | 154 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 155 | render_poses = [] 156 | rads = np.array(list(rads) + [1.]) 157 | hwf = c2w[:,4:5] 158 | 159 | for theta in np.linspace(0., 2. 
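# Spiral render path (annotation): theta sweeps `rots` full turns; each camera centre `c` is
# offset from the averaged pose by a cos/sin pattern scaled by the per-axis radii `rads`, with
# the z offset oscillating at `zrate` times the angular rate. Each generated pose then looks
# back toward a point `focal` units in front of the averaged pose.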
* np.pi * rots, N+1)[:-1]: 160 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 161 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 162 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 163 | return render_poses 164 | 165 | 166 | 167 | def recenter_poses(poses): 168 | 169 | poses_ = poses+0 170 | bottom = np.reshape([0,0,0,1.], [1,4]) 171 | c2w = poses_avg(poses) 172 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 173 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 174 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 175 | 176 | poses = np.linalg.inv(c2w) @ poses 177 | poses_[:,:3,:4] = poses[:,:3,:4] 178 | poses = poses_ 179 | return poses 180 | 181 | 182 | ##################### 183 | 184 | 185 | def spherify_poses(poses, bds): 186 | 187 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 188 | 189 | rays_d = poses[:,:3,2:3] 190 | rays_o = poses[:,:3,3:4] 191 | 192 | def min_line_dist(rays_o, rays_d): 193 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 194 | b_i = -A_i @ rays_o 195 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 196 | return pt_mindist 197 | 198 | pt_mindist = min_line_dist(rays_o, rays_d) 199 | 200 | center = pt_mindist 201 | up = (poses[:,:3,3] - center).mean(0) 202 | 203 | vec0 = normalize(up) 204 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 205 | vec2 = normalize(np.cross(vec0, vec1)) 206 | pos = center 207 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 208 | 209 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 210 | 211 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 212 | 213 | sc = 1./rad 214 | poses_reset[:,:3,3] *= sc 215 | bds *= sc 216 | rad *= sc 217 | 218 | centroid = np.mean(poses_reset[:,:3,3], 0) 219 | zh = centroid[2] 220 | radcircle = np.sqrt(rad**2-zh**2) 221 | new_poses = [] 222 | 223 | for th in np.linspace(0.,2.*np.pi, 120): 224 | 225 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 226 | up = np.array([0,0,-1.]) 227 | 228 | vec2 = normalize(camorigin) 229 | vec0 = normalize(np.cross(vec2, up)) 230 | vec1 = normalize(np.cross(vec2, vec0)) 231 | pos = camorigin 232 | p = np.stack([vec0, vec1, vec2, pos], 1) 233 | 234 | new_poses.append(p) 235 | 236 | new_poses = np.stack(new_poses, 0) 237 | 238 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 239 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 240 | 241 | return poses_reset, new_poses, bds 242 | 243 | 244 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 245 | 246 | 247 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 248 | print('Loaded', basedir, bds.min(), bds.max()) 249 | 250 | # Correct rotation matrix ordering and move variable dim to axis 0 251 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 252 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 253 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 254 | images = imgs 255 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 256 | 257 | # Rescale if bd_factor is provided 258 | sc = 1. 
if bd_factor is None else 1./(bds.min() * bd_factor) 259 | poses[:,:3,3] *= sc 260 | bds *= sc 261 | 262 | if recenter: 263 | poses = recenter_poses(poses) 264 | 265 | if spherify: 266 | poses, render_poses, bds = spherify_poses(poses, bds) 267 | 268 | else: 269 | 270 | c2w = poses_avg(poses) 271 | print('recentered', c2w.shape) 272 | print(c2w[:3,:4]) 273 | 274 | ## Get spiral 275 | # Get average pose 276 | up = normalize(poses[:, :3, 1].sum(0)) 277 | 278 | # Find a reasonable "focus depth" for this dataset 279 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 280 | dt = .75 281 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 282 | focal = mean_dz 283 | 284 | # Get radii for spiral path 285 | shrink_factor = .8 286 | zdelta = close_depth * .2 287 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 288 | rads = np.percentile(np.abs(tt), 90, 0) 289 | c2w_path = c2w 290 | N_views = 120 291 | N_rots = 2 292 | if path_zflat: 293 | # zloc = np.percentile(tt, 10, 0)[2] 294 | zloc = -close_depth * .1 295 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 296 | rads[2] = 0. 297 | N_rots = 1 298 | N_views/=2 299 | 300 | # Generate poses for spiral path 301 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 302 | 303 | 304 | render_poses = np.array(render_poses).astype(np.float32) 305 | 306 | c2w = poses_avg(poses) 307 | print('Data:') 308 | print(poses.shape, images.shape, bds.shape) 309 | 310 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 311 | i_test = np.argmin(dists) 312 | print('HOLDOUT view is', i_test) 313 | 314 | images = images.astype(np.float32) 315 | poses = poses.astype(np.float32) 316 | 317 | bounding_box = get_bbox3d_for_llff(poses[:,:3,:4], poses[0,:3,-1], near=0.0, far=1.0) 318 | 319 | return images, poses, bds, render_poses, i_test, bounding_box 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /load_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | import pyvista as pv 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_scannet_data(basedir, sceneID, half_res=False, trainskip=10, testskip=1): 38 | ''' 39 | basedir is something like: "/work/yashsb/datasets/ScanNet/" 40 | ''' 41 | scansdir = os.path.join(basedir, "scans") 42 | basedir = os.path.join(basedir, "nerfstyle_"+sceneID) 43 | 44 | splits = ['train', 'val', 'test'] 45 | metas = {} 46 | for s in splits: 47 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 48 | metas[s] = json.load(fp) 49 | 50 | all_imgs = [] 51 | all_poses = [] 52 | counts = [0] 53 | for s in splits: 54 | meta = metas[s] 55 | imgs = [] 56 | 
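# Pose convention (annotation): ScanNet provides camera-to-world matrices in the OpenCV
# convention (x right, y down, z forward). The sign flips on pose[:3, 1] and pose[:3, 2]
# applied further below convert them to the OpenGL/NeRF convention (x right, y up, z backward)
# assumed by the ray generation code, matching the note in ScanNet.md.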
poses = [] 57 | if s=='train': 58 | skip = trainskip 59 | else: 60 | skip = testskip 61 | 62 | for frame in meta['frames'][::skip]: 63 | fname = os.path.join(basedir, frame['file_path'] + '.png') 64 | imgs.append(imageio.imread(fname)) 65 | pose = np.array(frame['transform_matrix']) 66 | 67 | ### NEED to do this because ScanNet uses OpenCV convention 68 | pose[:3, 1] *= -1 69 | pose[:3, 2] *= -1 70 | 71 | poses.append(pose) 72 | 73 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 74 | poses = np.array(poses).astype(np.float32) 75 | counts.append(counts[-1] + imgs.shape[0]) 76 | all_imgs.append(imgs) 77 | all_poses.append(poses) 78 | 79 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 80 | 81 | imgs = np.concatenate(all_imgs, 0) 82 | poses = np.concatenate(all_poses, 0) 83 | 84 | H, W = imgs[0].shape[:2] 85 | camera_angle_x = float(meta['camera_angle_x']) 86 | focal = .5 * W / np.tan(.5 * camera_angle_x) 87 | 88 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 89 | 90 | if half_res: 91 | H = H//2 92 | W = W//2 93 | focal = focal/2. 94 | 95 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 96 | for i, img in enumerate(imgs): 97 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 98 | imgs = imgs_half_res 99 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 100 | 101 | ## getting an approximate bounding box for the scene 102 | # load scene mesh 103 | mesh = pv.read(os.path.join(scansdir, sceneID, f"{sceneID}_vh_clean.ply")) 104 | # get the bounding box 105 | bounding_box = torch.tensor(mesh.bounds[::2]) - 1, torch.tensor(mesh.bounds[1::2]) + 1 106 | 107 | return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # Author: Yash Bhalgat 2 | 3 | from math import exp, log, floor 4 | import torch 5 | import torch.nn.functional as F 6 | import pdb 7 | 8 | from utils import hash 9 | 10 | 11 | def total_variation_loss(embeddings, min_resolution, max_resolution, level, log2_hashmap_size, n_levels=16): 12 | # Get resolution 13 | b = exp((log(max_resolution)-log(min_resolution))/(n_levels-1)) 14 | resolution = torch.tensor(floor(min_resolution * b**level)) 15 | 16 | # Cube size to apply TV loss 17 | min_cube_size = min_resolution - 1 18 | max_cube_size = 50 # can be tuned 19 | if min_cube_size > max_cube_size: 20 | print("ALERT! 
min cuboid size greater than max!") 21 | pdb.set_trace() 22 | cube_size = torch.floor(torch.clip(resolution/10.0, min_cube_size, max_cube_size)).int() 23 | 24 | # Sample cuboid 25 | min_vertex = torch.randint(0, resolution-cube_size, (3,)) 26 | idx = min_vertex + torch.stack([torch.arange(cube_size+1) for _ in range(3)], dim=-1) 27 | cube_indices = torch.stack(torch.meshgrid(idx[:,0], idx[:,1], idx[:,2]), dim=-1) 28 | 29 | hashed_indices = hash(cube_indices, log2_hashmap_size) 30 | cube_embeddings = embeddings(hashed_indices) 31 | #hashed_idx_offset_x = hash(idx+torch.tensor([1,0,0]), log2_hashmap_size) 32 | #hashed_idx_offset_y = hash(idx+torch.tensor([0,1,0]), log2_hashmap_size) 33 | #hashed_idx_offset_z = hash(idx+torch.tensor([0,0,1]), log2_hashmap_size) 34 | 35 | # Compute loss 36 | #tv_x = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_x), 2).sum() 37 | #tv_y = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_y), 2).sum() 38 | #tv_z = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_z), 2).sum() 39 | tv_x = torch.pow(cube_embeddings[1:,:,:,:]-cube_embeddings[:-1,:,:,:], 2).sum() 40 | tv_y = torch.pow(cube_embeddings[:,1:,:,:]-cube_embeddings[:,:-1,:,:], 2).sum() 41 | tv_z = torch.pow(cube_embeddings[:,:,1:,:]-cube_embeddings[:,:,:-1,:], 2).sum() 42 | 43 | return (tv_x + tv_y + tv_z)/cube_size 44 | 45 | def sigma_sparsity_loss(sigmas): 46 | # Using Cauchy Sparsity loss on sigma values 47 | return torch.log(1.0 + 2*sigmas**2).sum(dim=-1) 48 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os, sys 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.optim import Optimizer 8 | from functools import reduce 9 | from torch.optim import AdamW 10 | 11 | class MultiOptimizer: 12 | def __init__(self, optimizers={}): 13 | self.optimizers = optimizers 14 | self.keys = list(optimizers.keys()) 15 | self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()]) 16 | 17 | def state_dict(self): 18 | state_dicts = [(key, self.optimizers[key].state_dict())\ 19 | for key in self.keys] 20 | return state_dicts 21 | 22 | def load_state_dict(self, state_dict): 23 | for key, val in state_dict: 24 | try: 25 | self.optimizers[key].load_state_dict(val) 26 | except: 27 | print("Unloaded %s" % key) 28 | 29 | def step(self, key=None, scaler=None): 30 | keys = [key] if key is not None else self.keys 31 | _ = [self._step(key, scaler) for key in keys] 32 | 33 | def _step(self, key, scaler=None): 34 | if scaler is not None: 35 | scaler.step(self.optimizers[key]) 36 | scaler.update() 37 | else: 38 | self.optimizers[key].step() 39 | 40 | def zero_grad(self, key=None): 41 | if key is not None: 42 | self.optimizers[key].zero_grad() 43 | else: 44 | _ = [self.optimizers[key].zero_grad() for key in self.keys] 45 | -------------------------------------------------------------------------------- /radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise 
ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | 
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if 
grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss -------------------------------------------------------------------------------- /ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | 4 | 5 | def get_ray_directions(H, W, focal): 6 | """ 7 | Get ray directions for all pixels in camera coordinate. 8 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 9 | ray-tracing-generating-camera-rays/standard-coordinate-systems 10 | 11 | Inputs: 12 | H, W, focal: image height, width and focal length 13 | 14 | Outputs: 15 | directions: (H, W, 3), the direction of the rays in camera coordinate 16 | """ 17 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 18 | i, j = grid.unbind(-1) 19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 20 | # see https://github.com/bmild/nerf/issues/24 21 | directions = \ 22 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3) 23 | 24 | dir_bounds = directions.view(-1, 3) 25 | # print("Directions ", directions[0,0,:], directions[H-1,0,:], directions[0,W-1,:], directions[H-1, W-1, :]) 26 | # print("Directions ", dir_bounds[0], dir_bounds[W-1], dir_bounds[H*W-W], dir_bounds[H*W-1]) 27 | 28 | return directions 29 | 30 | 31 | def get_rays(directions, c2w): 32 | """ 33 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 
34 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 35 | ray-tracing-generating-camera-rays/standard-coordinate-systems 36 | 37 | Inputs: 38 | directions: (H, W, 3) precomputed ray directions in camera coordinate 39 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 40 | 41 | Outputs: 42 | rays_o: (H*W, 3), the origin of the rays in world coordinate 43 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 44 | """ 45 | # Rotate ray directions from camera coordinate to the world coordinate 46 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 47 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 48 | # The origin of all rays is the camera origin in world coordinate 49 | rays_o = c2w[:3, -1].expand(rays_d.shape) # (H, W, 3) 50 | 51 | rays_d = rays_d.view(-1, 3) 52 | rays_o = rays_o.view(-1, 3) 53 | 54 | return rays_o, rays_d 55 | 56 | 57 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 58 | """ 59 | Transform rays from world coordinate to NDC. 60 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. 61 | For detailed derivation, please see: 62 | http://www.songho.ca/opengl/gl_projectionmatrix.html 63 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf 64 | 65 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 66 | See https://github.com/bmild/nerf/issues/18 67 | 68 | Inputs: 69 | H, W, focal: image height, width and focal length 70 | near: (N_rays) or float, the depths of the near plane 71 | rays_o: (N_rays, 3), the origin of the rays in world coordinate 72 | rays_d: (N_rays, 3), the direction of the rays in world coordinate 73 | 74 | Outputs: 75 | rays_o: (N_rays, 3), the origin of the rays in NDC 76 | rays_d: (N_rays, 3), the direction of the rays in NDC 77 | """ 78 | # Shift ray origins to near plane 79 | t = -(near + rays_o[...,2]) / rays_d[...,2] 80 | rays_o = rays_o + t[...,None] * rays_d 81 | 82 | # Store some intermediate homogeneous results 83 | ox_oz = rays_o[...,0] / rays_o[...,2] 84 | oy_oz = rays_o[...,1] / rays_o[...,2] 85 | 86 | # Projection 87 | o0 = -1./(W/(2.*focal)) * ox_oz 88 | o1 = -1./(H/(2.*focal)) * oy_oz 89 | o2 = 1. + 2. 
* near / rays_o[...,2] 90 | 91 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz) 92 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz) 93 | d2 = 1 - o2 94 | 95 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) 96 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) 97 | 98 | return rays_o, rays_d 99 | -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from datetime import datetime 3 | import numpy as np 4 | import imageio 5 | import json 6 | import pdb 7 | import random 8 | import time 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.distributions import Categorical 13 | from tqdm import tqdm, trange 14 | import pickle 15 | 16 | import matplotlib.pyplot as plt 17 | 18 | from run_nerf_helpers import * 19 | from optimizer import MultiOptimizer 20 | from radam import RAdam 21 | from loss import sigma_sparsity_loss, total_variation_loss 22 | 23 | from load_llff import load_llff_data 24 | from load_deepvoxels import load_dv_data 25 | from load_blender import load_blender_data 26 | from load_scannet import load_scannet_data 27 | from load_LINEMOD import load_LINEMOD_data 28 | 29 | 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | np.random.seed(0) 32 | DEBUG = False 33 | 34 | 35 | def batchify(fn, chunk): 36 | """Constructs a version of 'fn' that applies to smaller batches. 37 | """ 38 | if chunk is None: 39 | return fn 40 | def ret(inputs): 41 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 42 | return ret 43 | 44 | 45 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 46 | """Prepares inputs and applies network 'fn'. 47 | """ 48 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 49 | embedded, keep_mask = embed_fn(inputs_flat) 50 | 51 | if viewdirs is not None: 52 | input_dirs = viewdirs[:,None].expand(inputs.shape) 53 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 54 | embedded_dirs = embeddirs_fn(input_dirs_flat) 55 | embedded = torch.cat([embedded, embedded_dirs], -1) 56 | 57 | outputs_flat = batchify(fn, netchunk)(embedded) 58 | outputs_flat[~keep_mask, -1] = 0 # set sigma to 0 for invalid points 59 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 60 | return outputs 61 | 62 | 63 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 64 | """Render rays in smaller minibatches to avoid OOM. 65 | """ 66 | all_ret = {} 67 | for i in range(0, rays_flat.shape[0], chunk): 68 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 69 | for k in ret: 70 | if k not in all_ret: 71 | all_ret[k] = [] 72 | all_ret[k].append(ret[k]) 73 | 74 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 75 | return all_ret 76 | 77 | 78 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 79 | near=0., far=1., 80 | use_viewdirs=False, c2w_staticcam=None, 81 | **kwargs): 82 | """Render rays 83 | Args: 84 | H: int. Height of image in pixels. 85 | W: int. Width of image in pixels. 86 | focal: float. Focal length of pinhole camera. 87 | chunk: int. Maximum number of rays to process simultaneously. Used to 88 | control maximum memory usage. Does not affect final results. 89 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 90 | each example in batch. 
91 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 92 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 93 | near: float or array of shape [batch_size]. Nearest distance for a ray. 94 | far: float or array of shape [batch_size]. Farthest distance for a ray. 95 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 96 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 97 | camera while using other c2w argument for viewing directions. 98 | Returns: 99 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 100 | disp_map: [batch_size]. Disparity map. Inverse of depth. 101 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 102 | extras: dict with everything returned by render_rays(). 103 | """ 104 | if c2w is not None: 105 | # special case to render full image 106 | rays_o, rays_d = get_rays(H, W, K, c2w) 107 | else: 108 | # use provided ray batch 109 | rays_o, rays_d = rays 110 | 111 | if use_viewdirs: 112 | # provide ray directions as input 113 | viewdirs = rays_d 114 | if c2w_staticcam is not None: 115 | # special case to visualize effect of viewdirs 116 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 117 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 118 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 119 | 120 | sh = rays_d.shape # [..., 3] 121 | if ndc: 122 | # for forward facing scenes 123 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 124 | 125 | # Create ray batch 126 | rays_o = torch.reshape(rays_o, [-1,3]).float() 127 | rays_d = torch.reshape(rays_d, [-1,3]).float() 128 | 129 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 130 | rays = torch.cat([rays_o, rays_d, near, far], -1) 131 | if use_viewdirs: 132 | rays = torch.cat([rays, viewdirs], -1) 133 | 134 | # Render and reshape 135 | all_ret = batchify_rays(rays, chunk, **kwargs) 136 | for k in all_ret: 137 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 138 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 139 | 140 | k_extract = ['rgb_map', 'depth_map', 'acc_map'] 141 | ret_list = [all_ret[k] for k in k_extract] 142 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 143 | return ret_list + [ret_dict] 144 | 145 | 146 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 147 | 148 | H, W, focal = hwf 149 | near, far = render_kwargs['near'], render_kwargs['far'] 150 | 151 | if render_factor!=0: 152 | # Render downsampled for speed 153 | H = H//render_factor 154 | W = W//render_factor 155 | focal = focal/render_factor 156 | 157 | rgbs = [] 158 | depths = [] 159 | psnrs = [] 160 | 161 | t = time.time() 162 | for i, c2w in enumerate(tqdm(render_poses)): 163 | print(i, time.time() - t) 164 | t = time.time() 165 | rgb, depth, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 166 | rgbs.append(rgb.cpu().numpy()) 167 | # normalize depth to [0,1] 168 | depth = (depth - near) / (far - near) 169 | depths.append(depth.cpu().numpy()) 170 | if i==0: 171 | print(rgb.shape, depth.shape) 172 | 173 | if gt_imgs is not None and render_factor==0: 174 | try: 175 | gt_img = gt_imgs[i].cpu().numpy() 176 | except: 177 | gt_img = gt_imgs[i] 178 | p = -10. 
* np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_img))) 179 | print(p) 180 | psnrs.append(p) 181 | 182 | if savedir is not None: 183 | # save rgb and depth as a figure 184 | fig = plt.figure(figsize=(25,15)) 185 | ax = fig.add_subplot(1, 2, 1) 186 | rgb8 = to8b(rgbs[-1]) 187 | ax.imshow(rgb8) 188 | ax.axis('off') 189 | ax = fig.add_subplot(1, 2, 2) 190 | ax.imshow(depths[-1], cmap='plasma', vmin=0, vmax=1) 191 | ax.axis('off') 192 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 193 | # save as png 194 | plt.savefig(filename, bbox_inches='tight', pad_inches=0) 195 | plt.close(fig) 196 | # imageio.imwrite(filename, rgb8) 197 | 198 | 199 | rgbs = np.stack(rgbs, 0) 200 | depths = np.stack(depths, 0) 201 | if gt_imgs is not None and render_factor==0: 202 | avg_psnr = sum(psnrs)/len(psnrs) 203 | print("Avg PSNR over Test set: ", avg_psnr) 204 | with open(os.path.join(savedir, "test_psnrs_avg{:0.2f}.pkl".format(avg_psnr)), "wb") as fp: 205 | pickle.dump(psnrs, fp) 206 | 207 | return rgbs, depths 208 | 209 | 210 | def create_nerf(args): 211 | """Instantiate NeRF's MLP model. 212 | """ 213 | embed_fn, input_ch = get_embedder(args.multires, args, i=args.i_embed) 214 | if args.i_embed==1: 215 | # hashed embedding table 216 | embedding_params = list(embed_fn.parameters()) 217 | 218 | input_ch_views = 0 219 | embeddirs_fn = None 220 | if args.use_viewdirs: 221 | # if using hashed for xyz, use SH for views 222 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args, i=args.i_embed_views) 223 | 224 | output_ch = 5 if args.N_importance > 0 else 4 225 | skips = [4] 226 | 227 | if args.i_embed==1: 228 | model = NeRFSmall(num_layers=2, 229 | hidden_dim=64, 230 | geo_feat_dim=15, 231 | num_layers_color=3, 232 | hidden_dim_color=64, 233 | input_ch=input_ch, input_ch_views=input_ch_views).to(device) 234 | else: 235 | model = NeRF(D=args.netdepth, W=args.netwidth, 236 | input_ch=input_ch, output_ch=output_ch, skips=skips, 237 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 238 | grad_vars = list(model.parameters()) 239 | 240 | model_fine = None 241 | 242 | if args.N_importance > 0: 243 | if args.i_embed==1: 244 | model_fine = NeRFSmall(num_layers=2, 245 | hidden_dim=64, 246 | geo_feat_dim=15, 247 | num_layers_color=3, 248 | hidden_dim_color=64, 249 | input_ch=input_ch, input_ch_views=input_ch_views).to(device) 250 | else: 251 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 252 | input_ch=input_ch, output_ch=output_ch, skips=skips, 253 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 254 | grad_vars += list(model_fine.parameters()) 255 | 256 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 257 | embed_fn=embed_fn, 258 | embeddirs_fn=embeddirs_fn, 259 | netchunk=args.netchunk) 260 | 261 | # Create optimizer 262 | if args.i_embed==1: 263 | optimizer = RAdam([ 264 | {'params': grad_vars, 'weight_decay': 1e-6}, 265 | {'params': embedding_params, 'eps': 1e-15} 266 | ], lr=args.lrate, betas=(0.9, 0.99)) 267 | else: 268 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 269 | 270 | start = 0 271 | basedir = args.basedir 272 | expname = args.expname 273 | 274 | ########################## 275 | 276 | # Load checkpoints 277 | if args.ft_path is not None and args.ft_path!='None': 278 | ckpts = [args.ft_path] 279 | else: 280 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] 
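# Note (added comment): the '*.tar' checkpoints collected here are written
# periodically by the training loop (see the --i_weights block in train() further
# below); unless --no_reload is passed, the most recent one (ckpts[-1]) is restored
# next, resuming the optimizer, the coarse/fine networks and, when --i_embed==1,
# the hash embedding table.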
281 | 282 | print('Found ckpts', ckpts) 283 | if len(ckpts) > 0 and not args.no_reload: 284 | ckpt_path = ckpts[-1] 285 | print('Reloading from', ckpt_path) 286 | ckpt = torch.load(ckpt_path) 287 | 288 | start = ckpt['global_step'] 289 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 290 | 291 | # Load model 292 | model.load_state_dict(ckpt['network_fn_state_dict']) 293 | if model_fine is not None: 294 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 295 | if args.i_embed==1: 296 | embed_fn.load_state_dict(ckpt['embed_fn_state_dict']) 297 | 298 | ########################## 299 | # pdb.set_trace() 300 | 301 | render_kwargs_train = { 302 | 'network_query_fn' : network_query_fn, 303 | 'perturb' : args.perturb, 304 | 'N_importance' : args.N_importance, 305 | 'network_fine' : model_fine, 306 | 'N_samples' : args.N_samples, 307 | 'network_fn' : model, 308 | 'embed_fn': embed_fn, 309 | 'use_viewdirs' : args.use_viewdirs, 310 | 'white_bkgd' : args.white_bkgd, 311 | 'raw_noise_std' : args.raw_noise_std, 312 | } 313 | 314 | # NDC only good for LLFF-style forward facing data 315 | if args.dataset_type != 'llff' or args.no_ndc: 316 | print('Not ndc!') 317 | render_kwargs_train['ndc'] = False 318 | render_kwargs_train['lindisp'] = args.lindisp 319 | 320 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 321 | render_kwargs_test['perturb'] = False 322 | render_kwargs_test['raw_noise_std'] = 0. 323 | 324 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 325 | 326 | 327 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 328 | """Transforms model's predictions to semantically meaningful values. 329 | Args: 330 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 331 | z_vals: [num_rays, num_samples along ray]. Integration time. 332 | rays_d: [num_rays, 3]. Direction of each ray. 333 | Returns: 334 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 335 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 336 | acc_map: [num_rays]. Sum of weights along each ray. 337 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 338 | depth_map: [num_rays]. Estimated distance to object. 339 | """ 340 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 341 | 342 | dists = z_vals[...,1:] - z_vals[...,:-1] 343 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 344 | 345 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 346 | 347 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 348 | noise = 0. 
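# Note (added comment): the block below is the standard NeRF compositing step.
# Optional Gaussian noise (std raw_noise_std) regularizes the raw density during
# training; raw2alpha converts density to per-sample opacity,
#     alpha_i = 1 - exp(-relu(sigma_i) * dist_i),
# and the cumprod line forms transmittance-weighted contributions
#     w_i = alpha_i * prod_{j<i}(1 - alpha_j),
# which are then used to composite rgb_map, depth_map and acc_map along each ray.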
349 | if raw_noise_std > 0.: 350 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 351 | 352 | # Overwrite randomly sampled data if pytest 353 | if pytest: 354 | np.random.seed(0) 355 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 356 | noise = torch.Tensor(noise) 357 | 358 | # sigma_loss = sigma_sparsity_loss(raw[...,3]) 359 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 360 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 361 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 362 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 363 | 364 | depth_map = torch.sum(weights * z_vals, -1) / torch.sum(weights, -1) 365 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map) 366 | acc_map = torch.sum(weights, -1) 367 | 368 | if white_bkgd: 369 | rgb_map = rgb_map + (1.-acc_map[...,None]) 370 | 371 | # Calculate weights sparsity loss 372 | try: 373 | entropy = Categorical(probs = torch.cat([weights, 1.0-weights.sum(-1, keepdim=True)+1e-6], dim=-1)).entropy() 374 | except: 375 | pdb.set_trace() 376 | sparsity_loss = entropy 377 | 378 | return rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss 379 | 380 | 381 | def render_rays(ray_batch, 382 | network_fn, 383 | network_query_fn, 384 | N_samples, 385 | embed_fn=None, 386 | retraw=False, 387 | lindisp=False, 388 | perturb=0., 389 | N_importance=0, 390 | network_fine=None, 391 | white_bkgd=False, 392 | raw_noise_std=0., 393 | verbose=False, 394 | pytest=False): 395 | """Volumetric rendering. 396 | Args: 397 | ray_batch: array of shape [batch_size, ...]. All information necessary 398 | for sampling along a ray, including: ray origin, ray direction, min 399 | dist, max dist, and unit-magnitude viewing direction. 400 | network_fn: function. Model for predicting RGB and density at each point 401 | in space. 402 | network_query_fn: function used for passing queries to network_fn. 403 | N_samples: int. Number of different times to sample along each ray. 404 | retraw: bool. If True, include model's raw, unprocessed predictions. 405 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 406 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 407 | random points in time. 408 | N_importance: int. Number of additional times to sample along each ray. 409 | These samples are only passed to network_fine. 410 | network_fine: "fine" network with same spec as network_fn. 411 | white_bkgd: bool. If True, assume a white background. 412 | raw_noise_std: ... 413 | verbose: bool. If True, print more debugging info. 414 | Returns: 415 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 416 | disp_map: [num_rays]. Disparity map. 1 / depth. 417 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. 418 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 419 | rgb0: See rgb_map. Output for coarse model. 420 | disp0: See disp_map. Output for coarse model. 421 | acc0: See acc_map. Output for coarse model. 422 | z_std: [num_rays]. Standard deviation of distances along ray for each 423 | sample. 
424 | """ 425 | N_rays = ray_batch.shape[0] 426 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 427 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 428 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 429 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 430 | 431 | t_vals = torch.linspace(0., 1., steps=N_samples) 432 | if not lindisp: 433 | z_vals = near * (1.-t_vals) + far * (t_vals) 434 | else: 435 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 436 | 437 | z_vals = z_vals.expand([N_rays, N_samples]) 438 | 439 | if perturb > 0.: 440 | # get intervals between samples 441 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 442 | upper = torch.cat([mids, z_vals[...,-1:]], -1) 443 | lower = torch.cat([z_vals[...,:1], mids], -1) 444 | # stratified samples in those intervals 445 | t_rand = torch.rand(z_vals.shape) 446 | 447 | # Pytest, overwrite u with numpy's fixed random numbers 448 | if pytest: 449 | np.random.seed(0) 450 | t_rand = np.random.rand(*list(z_vals.shape)) 451 | t_rand = torch.Tensor(t_rand) 452 | 453 | z_vals = lower + (upper - lower) * t_rand 454 | 455 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 456 | 457 | raw = network_query_fn(pts, viewdirs, network_fn) 458 | rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 459 | 460 | if N_importance > 0: 461 | 462 | rgb_map_0, depth_map_0, acc_map_0, sparsity_loss_0 = rgb_map, depth_map, acc_map, sparsity_loss 463 | 464 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 465 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 466 | z_samples = z_samples.detach() 467 | 468 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 469 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 470 | 471 | run_fn = network_fn if network_fine is None else network_fine 472 | # raw = run_network(pts, fn=run_fn) 473 | raw = network_query_fn(pts, viewdirs, run_fn) 474 | 475 | rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 476 | 477 | ret = {'rgb_map' : rgb_map, 'depth_map' : depth_map, 'acc_map' : acc_map, 'sparsity_loss': sparsity_loss} 478 | if retraw: 479 | ret['raw'] = raw 480 | if N_importance > 0: 481 | ret['rgb0'] = rgb_map_0 482 | ret['depth0'] = depth_map_0 483 | ret['acc0'] = acc_map_0 484 | ret['sparsity_loss0'] = sparsity_loss_0 485 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 486 | 487 | for k in ret: 488 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 489 | print(f"! 
[Numerical Error] {k} contains nan or inf.") 490 | 491 | return ret 492 | 493 | 494 | def config_parser(): 495 | 496 | import configargparse 497 | parser = configargparse.ArgumentParser() 498 | parser.add_argument('--config', is_config_file=True, 499 | help='config file path') 500 | parser.add_argument("--expname", type=str, 501 | help='experiment name') 502 | parser.add_argument("--basedir", type=str, default='./logs/', 503 | help='where to store ckpts and logs') 504 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 505 | help='input data directory') 506 | 507 | # training options 508 | parser.add_argument("--netdepth", type=int, default=8, 509 | help='layers in network') 510 | parser.add_argument("--netwidth", type=int, default=256, 511 | help='channels per layer') 512 | parser.add_argument("--netdepth_fine", type=int, default=8, 513 | help='layers in fine network') 514 | parser.add_argument("--netwidth_fine", type=int, default=256, 515 | help='channels per layer in fine network') 516 | parser.add_argument("--N_rand", type=int, default=32*32*4, 517 | help='batch size (number of random rays per gradient step)') 518 | parser.add_argument("--lrate", type=float, default=5e-4, 519 | help='learning rate') 520 | parser.add_argument("--lrate_decay", type=int, default=250, 521 | help='exponential learning rate decay (in 1000 steps)') 522 | parser.add_argument("--chunk", type=int, default=1024*32, 523 | help='number of rays processed in parallel, decrease if running out of memory') 524 | parser.add_argument("--netchunk", type=int, default=1024*64, 525 | help='number of pts sent through network in parallel, decrease if running out of memory') 526 | parser.add_argument("--no_batching", action='store_true', 527 | help='only take random rays from 1 image at a time') 528 | parser.add_argument("--no_reload", action='store_true', 529 | help='do not reload weights from saved ckpt') 530 | parser.add_argument("--ft_path", type=str, default=None, 531 | help='specific weights npy file to reload for coarse network') 532 | 533 | # rendering options 534 | parser.add_argument("--N_samples", type=int, default=64, 535 | help='number of coarse samples per ray') 536 | parser.add_argument("--N_importance", type=int, default=0, 537 | help='number of additional fine samples per ray') 538 | parser.add_argument("--perturb", type=float, default=1., 539 | help='set to 0. for no jitter, 1. 
for jitter') 540 | parser.add_argument("--use_viewdirs", action='store_true', 541 | help='use full 5D input instead of 3D') 542 | parser.add_argument("--i_embed", type=int, default=1, 543 | help='set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical') 544 | parser.add_argument("--i_embed_views", type=int, default=2, 545 | help='set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical') 546 | parser.add_argument("--multires", type=int, default=10, 547 | help='log2 of max freq for positional encoding (3D location)') 548 | parser.add_argument("--multires_views", type=int, default=4, 549 | help='log2 of max freq for positional encoding (2D direction)') 550 | parser.add_argument("--raw_noise_std", type=float, default=0., 551 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 552 | 553 | parser.add_argument("--render_only", action='store_true', 554 | help='do not optimize, reload weights and render out render_poses path') 555 | parser.add_argument("--render_test", action='store_true', 556 | help='render the test set instead of render_poses path') 557 | parser.add_argument("--render_factor", type=int, default=0, 558 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 559 | 560 | # training options 561 | parser.add_argument("--precrop_iters", type=int, default=0, 562 | help='number of steps to train on central crops') 563 | parser.add_argument("--precrop_frac", type=float, 564 | default=.5, help='fraction of img taken for central crops') 565 | 566 | # dataset options 567 | parser.add_argument("--dataset_type", type=str, default='llff', 568 | help='options: llff / blender / deepvoxels') 569 | parser.add_argument("--testskip", type=int, default=8, 570 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 571 | 572 | ## deepvoxels flags 573 | parser.add_argument("--shape", type=str, default='greek', 574 | help='options : armchair / cube / greek / vase') 575 | 576 | ## blender flags 577 | parser.add_argument("--white_bkgd", action='store_true', 578 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 579 | parser.add_argument("--half_res", action='store_true', 580 | help='load blender synthetic data at 400x400 instead of 800x800') 581 | 582 | ## scannet flags 583 | parser.add_argument("--scannet_sceneID", type=str, default='scene0000_00', 584 | help='sceneID to load from scannet') 585 | 586 | ## llff flags 587 | parser.add_argument("--factor", type=int, default=8, 588 | help='downsample factor for LLFF images') 589 | parser.add_argument("--no_ndc", action='store_true', 590 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 591 | parser.add_argument("--lindisp", action='store_true', 592 | help='sampling linearly in disparity rather than depth') 593 | parser.add_argument("--spherify", action='store_true', 594 | help='set for spherical 360 scenes') 595 | parser.add_argument("--llffhold", type=int, default=8, 596 | help='will take every 1/N images as LLFF test set, paper uses 8') 597 | 598 | # logging/saving options 599 | parser.add_argument("--i_print", type=int, default=100, 600 | help='frequency of console printout and metric loggin') 601 | parser.add_argument("--i_img", type=int, default=500, 602 | help='frequency of tensorboard image logging') 603 | parser.add_argument("--i_weights", type=int, default=10000, 604 | help='frequency of weight ckpt saving') 605 | parser.add_argument("--i_testset", 
type=int, default=1000, 606 | help='frequency of testset saving') 607 | parser.add_argument("--i_video", type=int, default=5000, 608 | help='frequency of render_poses video saving') 609 | 610 | parser.add_argument("--finest_res", type=int, default=512, 611 | help='finest resolultion for hashed embedding') 612 | parser.add_argument("--log2_hashmap_size", type=int, default=19, 613 | help='log2 of hashmap size') 614 | parser.add_argument("--sparse-loss-weight", type=float, default=1e-10, 615 | help='learning rate') 616 | parser.add_argument("--tv-loss-weight", type=float, default=1e-6, 617 | help='learning rate') 618 | 619 | return parser 620 | 621 | 622 | def train(): 623 | 624 | parser = config_parser() 625 | args = parser.parse_args() 626 | 627 | # Load data 628 | K = None 629 | if args.dataset_type == 'llff': 630 | images, poses, bds, render_poses, i_test, bounding_box = load_llff_data(args.datadir, args.factor, 631 | recenter=True, bd_factor=.75, 632 | spherify=args.spherify) 633 | hwf = poses[0,:3,-1] 634 | poses = poses[:,:3,:4] 635 | args.bounding_box = bounding_box 636 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 637 | 638 | if not isinstance(i_test, list): 639 | i_test = [i_test] 640 | 641 | if args.llffhold > 0: 642 | print('Auto LLFF holdout,', args.llffhold) 643 | i_test = np.arange(images.shape[0])[::args.llffhold] 644 | 645 | i_val = i_test 646 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 647 | (i not in i_test and i not in i_val)]) 648 | 649 | print('DEFINING BOUNDS') 650 | if args.no_ndc: 651 | near = np.ndarray.min(bds) * .9 652 | far = np.ndarray.max(bds) * 1. 653 | 654 | else: 655 | near = 0. 656 | far = 1. 657 | print('NEAR FAR', near, far) 658 | 659 | elif args.dataset_type == 'blender': 660 | images, poses, render_poses, hwf, i_split, bounding_box = load_blender_data(args.datadir, args.half_res, args.testskip) 661 | args.bounding_box = bounding_box 662 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 663 | i_train, i_val, i_test = i_split 664 | 665 | near = 2. 666 | far = 6. 667 | 668 | if args.white_bkgd: 669 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 670 | else: 671 | images = images[...,:3] 672 | 673 | elif args.dataset_type == 'scannet': 674 | images, poses, render_poses, hwf, i_split, bounding_box = load_scannet_data(args.datadir, args.scannet_sceneID, args.half_res) 675 | args.bounding_box = bounding_box 676 | print('Loaded scannet', images.shape, render_poses.shape, hwf, args.datadir) 677 | i_train, i_val, i_test = i_split 678 | 679 | near = 0.1 680 | far = 10.0 681 | 682 | elif args.dataset_type == 'LINEMOD': 683 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip) 684 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') 685 | print(f'[CHECK HERE] near: {near}, far: {far}.') 686 | i_train, i_val, i_test = i_split 687 | 688 | if args.white_bkgd: 689 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 690 | else: 691 | images = images[...,:3] 692 | 693 | elif args.dataset_type == 'deepvoxels': 694 | 695 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 696 | basedir=args.datadir, 697 | testskip=args.testskip) 698 | 699 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) 700 | i_train, i_val, i_test = i_split 701 | 702 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) 703 | near = hemi_R-1. 
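# Note (added comment): hemi_R is the mean distance of the DeepVoxels camera
# origins from the world origin; the near/far planes simply bracket that
# hemisphere radius by one unit on either side.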
704 | far = hemi_R+1. 705 | 706 | else: 707 | print('Unknown dataset type', args.dataset_type, 'exiting') 708 | return 709 | 710 | # Cast intrinsics to right types 711 | H, W, focal = hwf 712 | H, W = int(H), int(W) 713 | hwf = [H, W, focal] 714 | 715 | if K is None: 716 | K = np.array([ 717 | [focal, 0, 0.5*W], 718 | [0, focal, 0.5*H], 719 | [0, 0, 1] 720 | ]) 721 | 722 | if args.render_test: 723 | render_poses = np.array(poses[i_test]) 724 | 725 | # Create log dir and copy the config file 726 | basedir = args.basedir 727 | if args.i_embed==1: 728 | args.expname += "_hashXYZ" 729 | elif args.i_embed==0: 730 | args.expname += "_posXYZ" 731 | if args.i_embed_views==2: 732 | args.expname += "_sphereVIEW" 733 | elif args.i_embed_views==0: 734 | args.expname += "_posVIEW" 735 | args.expname += "_fine"+str(args.finest_res) + "_log2T"+str(args.log2_hashmap_size) 736 | args.expname += "_lr"+str(args.lrate) + "_decay"+str(args.lrate_decay) 737 | args.expname += "_RAdam" 738 | if args.sparse_loss_weight > 0: 739 | args.expname += "_sparse" + str(args.sparse_loss_weight) 740 | args.expname += "_TV" + str(args.tv_loss_weight) 741 | #args.expname += datetime.now().strftime('_%H_%M_%d_%m_%Y') 742 | expname = args.expname 743 | 744 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 745 | f = os.path.join(basedir, expname, 'args.txt') 746 | with open(f, 'w') as file: 747 | for arg in sorted(vars(args)): 748 | attr = getattr(args, arg) 749 | file.write('{} = {}\n'.format(arg, attr)) 750 | if args.config is not None: 751 | f = os.path.join(basedir, expname, 'config.txt') 752 | with open(f, 'w') as file: 753 | file.write(open(args.config, 'r').read()) 754 | 755 | # Create nerf model 756 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 757 | global_step = start 758 | 759 | bds_dict = { 760 | 'near' : near, 761 | 'far' : far, 762 | } 763 | render_kwargs_train.update(bds_dict) 764 | render_kwargs_test.update(bds_dict) 765 | 766 | # Move testing data to GPU 767 | render_poses = torch.Tensor(render_poses).to(device) 768 | 769 | # Short circuit if only rendering out from trained model 770 | if args.render_only: 771 | print('RENDER ONLY') 772 | with torch.no_grad(): 773 | if args.render_test: 774 | # render_test switches to test poses 775 | images = images[i_test] 776 | else: 777 | # Default is smoother render_poses path 778 | images = None 779 | 780 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 781 | os.makedirs(testsavedir, exist_ok=True) 782 | print('test poses shape', render_poses.shape) 783 | 784 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) 785 | print('Done rendering', testsavedir) 786 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 787 | 788 | return 789 | 790 | # Prepare raybatch tensor if batching random rays 791 | N_rand = args.N_rand 792 | use_batching = not args.no_batching 793 | if use_batching: 794 | # For random ray batching 795 | print('get rays') 796 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 797 | print('done, concats') 798 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 799 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 800 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 801 | rays_rgb 
= np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] 802 | rays_rgb = rays_rgb.astype(np.float32) 803 | print('shuffle rays') 804 | np.random.shuffle(rays_rgb) 805 | 806 | print('done') 807 | i_batch = 0 808 | 809 | # Move training data to GPU 810 | if use_batching: 811 | images = torch.Tensor(images).to(device) 812 | poses = torch.Tensor(poses).to(device) 813 | if use_batching: 814 | rays_rgb = torch.Tensor(rays_rgb).to(device) 815 | 816 | 817 | N_iters = 50000 + 1 818 | print('Begin') 819 | print('TRAIN views are', i_train) 820 | print('TEST views are', i_test) 821 | print('VAL views are', i_val) 822 | 823 | # Summary writers 824 | # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) 825 | 826 | loss_list = [] 827 | psnr_list = [] 828 | time_list = [] 829 | start = start + 1 830 | time0 = time.time() 831 | for i in trange(start, N_iters): 832 | # Sample random ray batch 833 | if use_batching: 834 | # Random over all images 835 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] 836 | batch = torch.transpose(batch, 0, 1) 837 | batch_rays, target_s = batch[:2], batch[2] 838 | 839 | i_batch += N_rand 840 | if i_batch >= rays_rgb.shape[0]: 841 | print("Shuffle data after an epoch!") 842 | rand_idx = torch.randperm(rays_rgb.shape[0]) 843 | rays_rgb = rays_rgb[rand_idx] 844 | i_batch = 0 845 | 846 | else: 847 | # Random from one image 848 | img_i = np.random.choice(i_train) 849 | target = images[img_i] 850 | target = torch.Tensor(target).to(device) 851 | pose = poses[img_i, :3,:4] 852 | 853 | if N_rand is not None: 854 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) 855 | 856 | if i < args.precrop_iters: 857 | dH = int(H//2 * args.precrop_frac) 858 | dW = int(W//2 * args.precrop_frac) 859 | coords = torch.stack( 860 | torch.meshgrid( 861 | torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 862 | torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) 863 | ), -1) 864 | if i == start: 865 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") 866 | else: 867 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) 868 | 869 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 870 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 871 | select_coords = coords[select_inds].long() # (N_rand, 2) 872 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 873 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 874 | batch_rays = torch.stack([rays_o, rays_d], 0) 875 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 876 | 877 | ##### Core optimization loop ##### 878 | rgb, depth, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 879 | verbose=i < 10, retraw=True, 880 | **render_kwargs_train) 881 | 882 | optimizer.zero_grad() 883 | img_loss = img2mse(rgb, target_s) 884 | trans = extras['raw'][...,-1] 885 | loss = img_loss 886 | psnr = mse2psnr(img_loss) 887 | 888 | if 'rgb0' in extras: 889 | img_loss0 = img2mse(extras['rgb0'], target_s) 890 | loss = loss + img_loss0 891 | psnr0 = mse2psnr(img_loss0) 892 | 893 | sparsity_loss = args.sparse_loss_weight*(extras["sparsity_loss"].sum() + extras["sparsity_loss0"].sum()) 894 | loss = loss + sparsity_loss 895 | 896 | # add Total Variation loss 897 | if args.i_embed==1: 898 | n_levels = render_kwargs_train["embed_fn"].n_levels 899 | min_res = render_kwargs_train["embed_fn"].base_resolution 
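# Note (added comment): total-variation regularization is applied to each
# hash-table level separately (total_variation_loss in loss.py samples a small
# cube of grid vertices per level); the sum below adds the per-level terms, and
# tv_loss_weight is set to zero after the first 1000 iterations.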
900 | max_res = render_kwargs_train["embed_fn"].finest_resolution 901 | log2_hashmap_size = render_kwargs_train["embed_fn"].log2_hashmap_size 902 | TV_loss = sum(total_variation_loss(render_kwargs_train["embed_fn"].embeddings[i], \ 903 | min_res, max_res, \ 904 | i, log2_hashmap_size, \ 905 | n_levels=n_levels) for i in range(n_levels)) 906 | loss = loss + args.tv_loss_weight * TV_loss 907 | if i>1000: 908 | args.tv_loss_weight = 0.0 909 | 910 | loss.backward() 911 | # pdb.set_trace() 912 | optimizer.step() 913 | 914 | # NOTE: IMPORTANT! 915 | ### update learning rate ### 916 | decay_rate = 0.1 917 | decay_steps = args.lrate_decay * 1000 918 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 919 | for param_group in optimizer.param_groups: 920 | param_group['lr'] = new_lrate 921 | ################################ 922 | 923 | t = time.time()-time0 924 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") 925 | ##### end ##### 926 | 927 | # Rest is logging 928 | if i%args.i_weights==0: 929 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 930 | if args.i_embed==1: 931 | torch.save({ 932 | 'global_step': global_step, 933 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 934 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 935 | 'embed_fn_state_dict': render_kwargs_train['embed_fn'].state_dict(), 936 | 'optimizer_state_dict': optimizer.state_dict(), 937 | }, path) 938 | else: 939 | torch.save({ 940 | 'global_step': global_step, 941 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 942 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 943 | 'optimizer_state_dict': optimizer.state_dict(), 944 | }, path) 945 | print('Saved checkpoints at', path) 946 | 947 | if i%args.i_video==0 and i > 0: 948 | # Turn on testing mode 949 | with torch.no_grad(): 950 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) 951 | print('Done, saving', rgbs.shape, disps.shape) 952 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 953 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) 954 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) 955 | 956 | # if args.use_viewdirs: 957 | # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] 958 | # with torch.no_grad(): 959 | # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) 960 | # render_kwargs_test['c2w_staticcam'] = None 961 | # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) 962 | 963 | if i%args.i_testset==0 and i > 0: 964 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 965 | os.makedirs(testsavedir, exist_ok=True) 966 | print('test poses shape', poses[i_test].shape) 967 | with torch.no_grad(): 968 | render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) 969 | print('Saved test set') 970 | 971 | 972 | 973 | if i%args.i_print==0: 974 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 975 | loss_list.append(loss.item()) 976 | psnr_list.append(psnr.item()) 977 | time_list.append(t) 978 | loss_psnr_time = { 979 | "losses": loss_list, 980 | "psnr": psnr_list, 981 | "time": time_list 982 | } 983 | with open(os.path.join(basedir, expname, "loss_vs_time.pkl"), "wb") as fp: 984 | pickle.dump(loss_psnr_time, 
fp) 985 | 986 | global_step += 1 987 | 988 | 989 | if __name__=='__main__': 990 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 991 | 992 | train() 993 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from hash_encoding import HashEmbedder, SHEncoder 8 | 9 | # Misc 10 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 11 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 12 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 13 | 14 | 15 | # Positional encoding (section 5.1) 16 | class Embedder: 17 | def __init__(self, **kwargs): 18 | self.kwargs = kwargs 19 | self.create_embedding_fn() 20 | 21 | def create_embedding_fn(self): 22 | embed_fns = [] 23 | d = self.kwargs['input_dims'] 24 | out_dim = 0 25 | if self.kwargs['include_input']: 26 | embed_fns.append(lambda x : x) 27 | out_dim += d 28 | 29 | max_freq = self.kwargs['max_freq_log2'] 30 | N_freqs = self.kwargs['num_freqs'] 31 | 32 | if self.kwargs['log_sampling']: 33 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 34 | else: 35 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 36 | 37 | for freq in freq_bands: 38 | for p_fn in self.kwargs['periodic_fns']: 39 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 40 | out_dim += d 41 | 42 | self.embed_fns = embed_fns 43 | self.out_dim = out_dim 44 | 45 | def embed(self, inputs): 46 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 47 | 48 | 49 | def get_embedder(multires, args, i=0): 50 | if i == -1: 51 | return nn.Identity(), 3 52 | elif i==0: 53 | embed_kwargs = { 54 | 'include_input' : True, 55 | 'input_dims' : 3, 56 | 'max_freq_log2' : multires-1, 57 | 'num_freqs' : multires, 58 | 'log_sampling' : True, 59 | 'periodic_fns' : [torch.sin, torch.cos], 60 | } 61 | 62 | embedder_obj = Embedder(**embed_kwargs) 63 | embed = lambda x, eo=embedder_obj : eo.embed(x) 64 | out_dim = embedder_obj.out_dim 65 | elif i==1: 66 | embed = HashEmbedder(bounding_box=args.bounding_box, \ 67 | log2_hashmap_size=args.log2_hashmap_size, \ 68 | finest_resolution=args.finest_res) 69 | out_dim = embed.out_dim 70 | elif i==2: 71 | embed = SHEncoder() 72 | out_dim = embed.out_dim 73 | return embed, out_dim 74 | 75 | 76 | # Model 77 | class NeRF(nn.Module): 78 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 79 | """ 80 | """ 81 | super(NeRF, self).__init__() 82 | self.D = D 83 | self.W = W 84 | self.input_ch = input_ch 85 | self.input_ch_views = input_ch_views 86 | self.skips = skips 87 | self.use_viewdirs = use_viewdirs 88 | 89 | self.pts_linears = nn.ModuleList( 90 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 91 | 92 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 93 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 94 | 95 | ### Implementation according to the paper 96 | # self.views_linears = nn.ModuleList( 97 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 98 | 99 | if use_viewdirs: 100 | self.feature_linear = nn.Linear(W, W) 101 | 
self.alpha_linear = nn.Linear(W, 1) 102 | self.rgb_linear = nn.Linear(W//2, 3) 103 | else: 104 | self.output_linear = nn.Linear(W, output_ch) 105 | 106 | def forward(self, x): 107 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 108 | h = input_pts 109 | for i, l in enumerate(self.pts_linears): 110 | h = self.pts_linears[i](h) 111 | h = F.relu(h) 112 | if i in self.skips: 113 | h = torch.cat([input_pts, h], -1) 114 | 115 | if self.use_viewdirs: 116 | alpha = self.alpha_linear(h) 117 | feature = self.feature_linear(h) 118 | h = torch.cat([feature, input_views], -1) 119 | 120 | for i, l in enumerate(self.views_linears): 121 | h = self.views_linears[i](h) 122 | h = F.relu(h) 123 | 124 | rgb = self.rgb_linear(h) 125 | outputs = torch.cat([rgb, alpha], -1) 126 | else: 127 | outputs = self.output_linear(h) 128 | 129 | return outputs 130 | 131 | def load_weights_from_keras(self, weights): 132 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 133 | 134 | # Load pts_linears 135 | for i in range(self.D): 136 | idx_pts_linears = 2 * i 137 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 138 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 139 | 140 | # Load feature_linear 141 | idx_feature_linear = 2 * self.D 142 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 143 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 144 | 145 | # Load views_linears 146 | idx_views_linears = 2 * self.D + 2 147 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 148 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 149 | 150 | # Load rgb_linear 151 | idx_rbg_linear = 2 * self.D + 4 152 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 153 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 154 | 155 | # Load alpha_linear 156 | idx_alpha_linear = 2 * self.D + 6 157 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 158 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 159 | 160 | 161 | # Small NeRF for Hash embeddings 162 | class NeRFSmall(nn.Module): 163 | def __init__(self, 164 | num_layers=3, 165 | hidden_dim=64, 166 | geo_feat_dim=15, 167 | num_layers_color=4, 168 | hidden_dim_color=64, 169 | input_ch=3, input_ch_views=3, 170 | ): 171 | super(NeRFSmall, self).__init__() 172 | 173 | self.input_ch = input_ch 174 | self.input_ch_views = input_ch_views 175 | 176 | # sigma network 177 | self.num_layers = num_layers 178 | self.hidden_dim = hidden_dim 179 | self.geo_feat_dim = geo_feat_dim 180 | 181 | sigma_net = [] 182 | for l in range(num_layers): 183 | if l == 0: 184 | in_dim = self.input_ch 185 | else: 186 | in_dim = hidden_dim 187 | 188 | if l == num_layers - 1: 189 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color 190 | else: 191 | out_dim = hidden_dim 192 | 193 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 194 | 195 | self.sigma_net = nn.ModuleList(sigma_net) 196 | 197 | # color network 198 | self.num_layers_color = num_layers_color 199 | self.hidden_dim_color = hidden_dim_color 200 | 201 | color_net = [] 202 | for l in range(num_layers_color): 203 | if l == 0: 204 | in_dim = self.input_ch_views + 
self.geo_feat_dim 205 | else: 206 | in_dim = hidden_dim 207 | 208 | if l == num_layers_color - 1: 209 | out_dim = 3 # 3 rgb 210 | else: 211 | out_dim = hidden_dim 212 | 213 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 214 | 215 | self.color_net = nn.ModuleList(color_net) 216 | 217 | def forward(self, x): 218 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 219 | 220 | # sigma 221 | h = input_pts 222 | for l in range(self.num_layers): 223 | h = self.sigma_net[l](h) 224 | if l != self.num_layers - 1: 225 | h = F.relu(h, inplace=True) 226 | 227 | sigma, geo_feat = h[..., 0], h[..., 1:] 228 | 229 | # color 230 | h = torch.cat([input_views, geo_feat], dim=-1) 231 | for l in range(self.num_layers_color): 232 | h = self.color_net[l](h) 233 | if l != self.num_layers_color - 1: 234 | h = F.relu(h, inplace=True) 235 | 236 | # color = torch.sigmoid(h) 237 | color = h 238 | outputs = torch.cat([color, sigma.unsqueeze(dim=-1)], -1) 239 | 240 | return outputs 241 | 242 | 243 | 244 | # Ray helpers 245 | def get_rays(H, W, K, c2w): 246 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 247 | i = i.t() 248 | j = j.t() 249 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 250 | # Rotate ray directions from camera frame to the world frame 251 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 252 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 253 | rays_o = c2w[:3,-1].expand(rays_d.shape) 254 | return rays_o, rays_d 255 | 256 | 257 | def get_rays_np(H, W, K, c2w): 258 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 259 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) 260 | # Rotate ray directions from camera frame to the world frame 261 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 262 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 263 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 264 | return rays_o, rays_d 265 | 266 | 267 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 268 | # Shift ray origins to near plane 269 | t = -(near + rays_o[...,2]) / rays_d[...,2] 270 | rays_o = rays_o + t[...,None] * rays_d 271 | 272 | # Projection 273 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 274 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 275 | o2 = 1. + 2. * near / rays_o[...,2] 276 | 277 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 278 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 279 | d2 = -2. 
* near / rays_o[...,2] 280 | 281 | rays_o = torch.stack([o0,o1,o2], -1) 282 | rays_d = torch.stack([d0,d1,d2], -1) 283 | 284 | return rays_o, rays_d 285 | 286 | 287 | # Hierarchical sampling (section 5.2) 288 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 289 | # Get pdf 290 | weights = weights + 1e-5 # prevent nans 291 | pdf = weights / torch.sum(weights, -1, keepdim=True) 292 | cdf = torch.cumsum(pdf, -1) 293 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 294 | 295 | # Take uniform samples 296 | if det: 297 | u = torch.linspace(0., 1., steps=N_samples) 298 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 299 | else: 300 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 301 | 302 | # Pytest, overwrite u with numpy's fixed random numbers 303 | if pytest: 304 | np.random.seed(0) 305 | new_shape = list(cdf.shape[:-1]) + [N_samples] 306 | if det: 307 | u = np.linspace(0., 1., N_samples) 308 | u = np.broadcast_to(u, new_shape) 309 | else: 310 | u = np.random.rand(*new_shape) 311 | u = torch.Tensor(u) 312 | 313 | # Invert CDF 314 | u = u.contiguous() 315 | inds = torch.searchsorted(cdf, u, right=True) 316 | below = torch.max(torch.zeros_like(inds-1), inds-1) 317 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 318 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 319 | 320 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 321 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 322 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 323 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 324 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 325 | 326 | denom = (cdf_g[...,1]-cdf_g[...,0]) 327 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 328 | t = (u-cdf_g[...,0])/denom 329 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 330 | 331 | return samples 332 | -------------------------------------------------------------------------------- /scripts/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yashbhalgat/HashNeRF-pytorch/82885e698295982504eb6a26d060a6b2473e3706/scripts/arial.ttf -------------------------------------------------------------------------------- /scripts/make_gif.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL import ImageFont 3 | from PIL import ImageDraw 4 | import os 5 | import imageio 6 | import pdb 7 | import numpy as np 8 | 9 | image_idx = "000" 10 | 11 | paths = { 12 | "Hashed": "../logs/blender_hotdog_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10"} 13 | 14 | for path_name, log_path in paths.items(): 15 | folders = [name for name in os.listdir(log_path) if name.startswith("renderonly_path_")] 16 | folders.sort() 17 | images = [] 18 | writer = imageio.get_writer(os.path.join(log_path, 'convergence.mp4'), fps=2) 19 | for i, folder in enumerate(folders): 20 | if i>50: 21 | break 22 | img = Image.open(os.path.join(log_path, folder, image_idx + ".png")) 23 | font = ImageFont.truetype("arial.ttf", 30) 24 | ImageDraw.Draw( 25 | img # Image 26 | ).text( 27 | (0, 0), # Coordinates 28 | 'Iter: '+str(int(folder[-6:])), # Text 29 | (0, 0, 0), # Color 30 | font=font 31 | ) 32 | images.append(img) 33 | writer.append_data(np.array(img)) 34 | pdb.set_trace() 35 | writer.close() 36 | 
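# Optional GIF export: uncommenting the imageio.mimsave call below writes the same
# labelled frames as an animated GIF (0.25 s per frame) alongside the mp4 assembled above.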
#imageio.mimsave(os.path.join(log_path, 'convergence_dur025.gif'), images, duration=0.25) 37 | -------------------------------------------------------------------------------- /scripts/plot_losses.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import matplotlib.pyplot as plt 3 | import pdb 4 | 5 | paths = { 6 | #"Vanilla HighLR": "../logs/blender_chair_posXYZ_posVIEW_fine1024_log2T19_lr0.01_decay100/loss_vs_time.pkl", \ 7 | "Hashed Fast": "../logs/blender_chair_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay100/loss_vs_time.pkl", \ 8 | "Hashed Superfast": "../logs/blender_chair_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10/loss_vs_time.pkl", \ 9 | "Vanilla SlowLR": "../logs/blender_chair_posXYZ_posVIEW_fine1024_log2T19_lr0.0005_decay500/loss_vs_time.pkl"} 10 | 11 | # load data 12 | data_dict = {} 13 | for path_key in paths: 14 | filepath = paths[path_key] 15 | with open(filepath, "rb") as f: 16 | data_dict[path_key] = pickle.load(f) 17 | 18 | # plot data 19 | #for k in data_dict: 20 | for k in data_dict: 21 | plt.plot(data_dict[k]["psnr"][1:200][::2], label=k) 22 | 23 | plt.legend() 24 | plt.show() 25 | -------------------------------------------------------------------------------- /scripts/run_all_checkpoints.sh: -------------------------------------------------------------------------------- 1 | for i in logs/blender_hotdog_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10/*.tar; do 2 | CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/hotdog.txt --finest_res 1024 --lr 0.01 --lr_decay 10 --render_only --ft_path $i 3 | done 4 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0 python run_nerf.py --config configs/chair.txt --finest_res 1024 2 | # CUDA_VISIBLE_DEVICES=1 python run_nerf.py --config configs/chair.txt --finest_res 1024 --i_embed_views 0 3 | # CUDA_VISIBLE_DEVICES=2 python run_nerf.py --config configs/chair.txt --finest_res 1024 --lrate 0.01 --lrate_decay 100 4 | # CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/chair.txt --finest_res 1024 --log2_hashmap_size 14 5 | 6 | CUDA_VISIBLE_DEVICES=0 python run_nerf.py --config configs/chair.txt --finest_res 1024 7 | CUDA_VISIBLE_DEVICES=1 python run_nerf.py --config configs/chair.txt --finest_res 1024 --lrate 0.01 --lrate_decay 100 8 | CUDA_VISIBLE_DEVICES=2 python run_nerf.py --config configs/chair.txt --finest_res 1024 --i_embed 0 --i_embed_views 0 9 | CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/chair.txt --finest_res 1024 --i_embed 0 --i_embed_views 0 --lrate 0.01 --lrate_decay 100 10 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pdb 4 | import torch 5 | 6 | from ray_utils import get_rays, get_ray_directions, get_ndc_rays 7 | 8 | 9 | BOX_OFFSETS = torch.tensor([[[i,j,k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]], 10 | device='cuda') 11 | 12 | 13 | def hash(coords, log2_hashmap_size): 14 | ''' 15 | coords: this function can process upto 7 dim coordinates 16 | log2T: logarithm of T w.r.t 2 17 | ''' 18 | primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737] 19 | 20 | xor_result = torch.zeros_like(coords)[..., 0] 21 | for i in range(coords.shape[-1]): 22 | xor_result ^= 
coords[..., i]*primes[i] 23 | 24 | return torch.tensor((1<<log2_hashmap_size)-1).to(xor_result.device) & xor_result 25 | 26 | 27 | def get_bbox3d_for_blenderobj(camera_transforms, H, W, near=2.0, far=6.0): 28 | camera_angle_x = float(camera_transforms['camera_angle_x']) 29 | focal = 0.5*W/np.tan(0.5 * camera_angle_x) 30 | 31 | # ray directions in camera coordinates 32 | directions = get_ray_directions(H, W, focal) 33 | 34 | min_bound = [100, 100, 100] 35 | max_bound = [-100, -100, -100] 36 | 37 | points = [] 38 | 39 | for frame in camera_transforms["frames"]: 40 | c2w = torch.FloatTensor(frame["transform_matrix"]) 41 | rays_o, rays_d = get_rays(directions, c2w) 42 | 43 | def find_min_max(pt): 44 | for i in range(3): 45 | if(min_bound[i] > pt[i]): 46 | min_bound[i] = pt[i] 47 | if(max_bound[i] < pt[i]): 48 | max_bound[i] = pt[i] 49 | return 50 | 51 | for i in [0, W-1, H*W-W, H*W-1]: 52 | min_point = rays_o[i] + near*rays_d[i] 53 | max_point = rays_o[i] + far*rays_d[i] 54 | points += [min_point, max_point] 55 | find_min_max(min_point) 56 | find_min_max(max_point) 57 | 58 | return (torch.tensor(min_bound)-torch.tensor([1.0,1.0,1.0]), torch.tensor(max_bound)+torch.tensor([1.0,1.0,1.0])) 59 | 60 | 61 | def get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0): 62 | H, W, focal = hwf 63 | H, W = int(H), int(W) 64 | 65 | # ray directions in camera coordinates 66 | directions = get_ray_directions(H, W, focal) 67 | 68 | min_bound = [100, 100, 100] 69 | max_bound = [-100, -100, -100] 70 | 71 | points = [] 72 | poses = torch.FloatTensor(poses) 73 | for pose in poses: 74 | rays_o, rays_d = get_rays(directions, pose) 75 | rays_o, rays_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d) 76 | 77 | def find_min_max(pt): 78 | for i in range(3): 79 | if(min_bound[i] > pt[i]): 80 | min_bound[i] = pt[i] 81 | if(max_bound[i] < pt[i]): 82 | max_bound[i] = pt[i] 83 | return 84 | 85 | for i in [0, W-1, H*W-W, H*W-1]: 86 | min_point = rays_o[i] + near*rays_d[i] 87 | max_point = rays_o[i] + far*rays_d[i] 88 | points += [min_point, max_point] 89 | find_min_max(min_point) 90 | find_min_max(max_point) 91 | 92 | return (torch.tensor(min_bound)-torch.tensor([0.1,0.1,0.0001]), torch.tensor(max_bound)+torch.tensor([0.1,0.1,0.0001])) 93 | 94 | 95 | def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size): 96 | ''' 97 | xyz: 3D coordinates of samples. B x 3 98 | bounding_box: min and max x,y,z coordinates of object bbox 99 | resolution: number of voxels per axis 100 | ''' 101 | box_min, box_max = bounding_box 102 | 103 | keep_mask = xyz==torch.max(torch.min(xyz, box_max), box_min) 104 | if not torch.all(xyz <= box_max) or not torch.all(xyz >= box_min): 105 | # print("ALERT: some points are outside bounding box. Clipping them!") 106 | xyz = torch.clamp(xyz, min=box_min, max=box_max) 107 | 108 | grid_size = (box_max-box_min)/resolution 109 | 110 | bottom_left_idx = torch.floor((xyz-box_min)/grid_size).int() 111 | voxel_min_vertex = bottom_left_idx*grid_size + box_min 112 | voxel_max_vertex = voxel_min_vertex + torch.tensor([1.0,1.0,1.0])*grid_size 113 | 114 | voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS 115 | hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size) 116 | 117 | return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask 118 | 119 | 120 | 121 | if __name__=="__main__": 122 | with open("data/nerf_synthetic/chair/transforms_train.json", "r") as f: 123 | camera_transforms = json.load(f) 124 | 125 | bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800) 126 | --------------------------------------------------------------------------------
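To make the intent of `hash()` and `get_voxel_vertices()` in `utils.py` concrete, here is a minimal, CPU-only sketch of a single resolution level of the hash encoding: map each point to its voxel, hash the 8 corner indices into a small embedding table, and trilinearly blend the corner features. This is not the repository's `HashEmbedder` (which lives in `hash_encoding.py`); the names `toy_hash`, `toy_voxel_lookup`, `BOX_OFFSETS_CPU`, the three-prime subset, and the toy table sizes are illustrative assumptions only.

```python
import torch

# Toy single-level lookup mirroring the logic of utils.hash + utils.get_voxel_vertices.
PRIMES = [1, 2654435761, 805459861]  # first three primes used by the spatial hash
BOX_OFFSETS_CPU = torch.tensor([[[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)]])  # 1 x 8 x 3

def toy_hash(coords, log2_hashmap_size):
    """XOR integer grid coordinates with large primes, then mask to the table size."""
    xor_result = torch.zeros_like(coords)[..., 0]
    for i in range(coords.shape[-1]):
        xor_result ^= coords[..., i] * PRIMES[i]
    return ((1 << log2_hashmap_size) - 1) & xor_result

def toy_voxel_lookup(xyz, box_min, box_max, resolution, embeddings, log2_hashmap_size):
    """Return trilinearly interpolated embeddings of the 8 voxel corners around each point."""
    xyz = torch.clamp(xyz, min=box_min, max=box_max)            # keep points inside the bbox
    grid_size = (box_max - box_min) / resolution
    bottom_left_idx = torch.floor((xyz - box_min) / grid_size).long()
    voxel_min = bottom_left_idx * grid_size + box_min

    corner_idx = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS_CPU          # B x 8 x 3
    corner_emb = embeddings(toy_hash(corner_idx, log2_hashmap_size))     # B x 8 x F

    # Trilinear weights from the fractional position inside the voxel.
    t = (xyz - voxel_min) / grid_size                                    # B x 3 in [0, 1]
    w = torch.stack([torch.where(BOX_OFFSETS_CPU[0, :, d] == 1,
                                 t[:, d:d+1], 1 - t[:, d:d+1])
                     for d in range(3)], dim=-1).prod(dim=-1)            # B x 8
    return (w.unsqueeze(-1) * corner_emb).sum(dim=1)                     # B x F

if __name__ == "__main__":
    log2_T, n_features = 14, 2
    table = torch.nn.Embedding(2 ** log2_T, n_features)                  # learnable hash table
    pts = torch.rand(5, 3) * 2.0 - 1.0                                   # 5 random points in [-1, 1)^3
    feats = toy_voxel_lookup(pts, torch.tensor([-1.0, -1.0, -1.0]),
                             torch.tensor([1.0, 1.0, 1.0]), 64, table, log2_T)
    print(feats.shape)                                                   # torch.Size([5, 2])
```

In the repository itself, `HashEmbedder` performs this kind of lookup at several resolution levels and concatenates the per-level features, which is what `NeRFSmall` consumes as `input_ch` when `--i_embed 1` is used.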