├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── chair.txt
│   ├── drums.txt
│   ├── fern.txt
│   ├── ficus.txt
│   ├── flower.txt
│   ├── fortress.txt
│   ├── horns.txt
│   ├── hotdog.txt
│   ├── leaves.txt
│   ├── lego.txt
│   ├── materials.txt
│   ├── mic.txt
│   ├── orchids.txt
│   ├── room.txt
│   ├── ship.txt
│   └── trex.txt
├── download_example_data.sh
├── imgs
│   └── pipeline.jpg
├── nerf
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── load_LINEMOD.py
│   │   ├── load_blender.py
│   │   ├── load_deepvoxels.py
│   │   └── load_llff.py
│   ├── engine
│   │   ├── __init__.py
│   │   └── trainer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── embed.py
│   │   └── nerf.py
│   └── utils
│       ├── __init__.py
│       ├── helpers.py
│       └── render.py
├── requirements.txt
├── run.py
├── setup.cfg
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 Megvii
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # NeRF-megengine
 2 | 
 3 | This is an implementation of [NeRF](https://arxiv.org/abs/2003.08934) for [MegEngine](https://github.com/MegEngine/MegEngine).
 4 | 
 5 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
 6 | 
 7 | ![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif)
 8 | ![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif)
 9 | 
10 | This project is a faithful MegEngine implementation of NeRF. The code is based on the authors' TensorFlow implementation [here](https://github.com/bmild/nerf) and the PyTorch implementation [here](https://github.com/yenchenlin/nerf-pytorch) by [Yen-Chen Lin](https://github.com/yenchenlin).
11 | 
12 | If you run into any problems when using this repo, please feel free to contact [FateScript](https://github.com/FateScript) @ [Megvii](https://megvii.com/).
13 | 
14 | ## Installation
15 | 
16 | It's recommended to use a venv to train your NeRF model.
17 | 
18 | A simple venv example:
19 | ```bash
20 | python3 -m venv nerf_mge
21 | source nerf_mge/bin/activate
22 | ```
23 | 
24 | ```
25 | git clone https://github.com/MegEngine/NeRF.git
26 | cd NeRF
27 | python3 -m pip install -v -e .
28 | ```
29 | The LLFF data loader requires ImageMagick.
30 | 
31 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
32 | 
33 | ## How To Run?
34 | 
35 | ### Quick Start
36 | 
37 | Download the data for the two example datasets, `lego` and `fern`:
38 | ```
39 | bash download_example_data.sh
40 | ```
41 | 
42 | To train a low-res `lego` NeRF:
43 | ```
44 | python3 run.py --config configs/lego.txt
45 | ```
46 | After training for 100k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/blender_paper_lego/blender_paper_lego_spiral_100000_rgb.mp4` (the directory and file prefix follow the `expname` set in the config).
47 | 
48 | ![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif)
49 | 
50 | ---
51 | 
52 | To train a low-res `fern` NeRF:
53 | ```
54 | python run.py --config configs/fern.txt
55 | ```
56 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.
57 | 
58 | ![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif)
59 | 
60 | ---
61 | 
62 | ### More Datasets
63 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
64 | ```
65 | ├── configs
66 | │   ├── ...
67 | │
68 | ├── data
69 | │   ├── nerf_llff_data
70 | │   │   ├── fern
71 | │   │   ├── flower  # downloaded llff dataset
72 | │   │   ├── horns   # downloaded llff dataset
73 | │   │   └── ...
74 | │   └── nerf_synthetic
75 | │       ├── lego
76 | │       ├── ship    # downloaded synthetic dataset
77 | │       └── ...
78 | ```
79 | 
80 | ---
81 | 
82 | To train NeRF on different datasets:
83 | 
84 | ```
85 | python run.py --config configs/{DATASET}.txt
86 | ```
87 | 
88 | Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
89 | 
90 | ---
91 | 
92 | To test NeRF trained on different datasets:
93 | 
94 | ```
95 | python run.py --config configs/{DATASET}.txt --render_only
96 | ```
97 | 
98 | Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
99 | 
100 | 
101 | ### Pre-trained Models
102 | 
103 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
104 | 
105 | ```
106 | ├── logs
107 | │   ├── fern_test
108 | │   ├── flower_test  # downloaded logs
109 | │   └── trex_test    # downloaded logs
110 | ```
111 | 
112 | ## Method
113 | 
114 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
115 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
116 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
117 | [Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
118 | [Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
119 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
120 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup>
121 | <sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
122 | \*denotes equal contribution
123 | 
124 | 
125 | 
126 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume", so we can use volume rendering to differentiably render new views.
127 | 
128 | 
129 | ## Citation
130 | Kudos to the authors for their amazing results:
131 | ```
132 | @misc{mildenhall2020nerf,
133 |     title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
134 |     author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
135 |     year={2020},
136 |     eprint={2003.08934},
137 |     archivePrefix={arXiv},
138 |     primaryClass={cs.CV}
139 | }
140 | ```
141 | ## Acknowledgements
142 | Some code in this repo is based on the following repos:
143 | * https://github.com/bmild/nerf (MIT License)
144 | * https://github.com/yenchenlin/nerf-pytorch (MIT License)
145 | 
146 | 
147 | ## Known issues
148 | * Some operators are not well supported in MegEngine; they are worked around with operators written in pure Python (see e.g. `cumprod` in `nerf/utils/helpers.py`), which may make NeRF-megengine run slower.
--------------------------------------------------------------------------------
/configs/chair.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_chair
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/chair
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/drums.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_drums
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/drums
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/fern.txt:
--------------------------------------------------------------------------------
 1 | expname = fern_test
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_llff_data/fern
 4 | dataset_type = llff
 5 | 
 6 | factor = 8
 7 | llffhold = 8
 8 | 
 9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 | 
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/ficus.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_ficus
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/ficus
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/flower.txt:
--------------------------------------------------------------------------------
 1 | expname = flower_test
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_llff_data/flower
 4 | dataset_type = llff
 5 | 
 6 | factor = 8
 7 | llffhold = 8
 8 | 
 9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 | 
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/fortress.txt:
--------------------------------------------------------------------------------
 1 | expname = fortress_test
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_llff_data/fortress
 4 | dataset_type = llff
 5 | 
 6 | factor = 8
 7 | llffhold = 8
 8 | 
 9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 | 
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/horns.txt:
--------------------------------------------------------------------------------
 1 | expname = horns_test
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_llff_data/horns
 4 | dataset_type = llff
 5 | 
 6 | factor = 8
 7 | llffhold = 8
 8 | 
 9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 | 
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/hotdog.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_hotdog
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/hotdog
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/leaves.txt:
--------------------------------------------------------------------------------
 1 | expname = leaves_test
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_llff_data/leaves
 4 | dataset_type = llff
 5 | 
 6 | factor = 8
 7 | llffhold = 8
 8 | 
 9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 | 
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/lego.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_lego
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/lego
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/materials.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_materials
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/materials
 4 | dataset_type = blender
 5 | 
 6 | no_batching = True
 7 | 
 8 | use_viewdirs = True
 9 | white_bkgd = True
10 | lrate_decay = 500
11 | 
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 | 
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 | 
19 | half_res = True
20 | 
--------------------------------------------------------------------------------
/configs/mic.txt:
--------------------------------------------------------------------------------
 1 | expname = blender_paper_mic
 2 | basedir = ./logs
 3 | datadir = ./data/nerf_synthetic/mic
 4 | dataset_type
= blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/orchids.txt: -------------------------------------------------------------------------------- 1 | expname = orchids_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/orchids 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/room.txt: -------------------------------------------------------------------------------- 1 | expname = room_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/room 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/ship.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ship 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/ship 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/trex.txt: -------------------------------------------------------------------------------- 1 | expname = trex_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/trex 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /download_example_data.sh: -------------------------------------------------------------------------------- 1 | wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz 2 | mkdir -p data 3 | cd data 4 | wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/nerf_example_data.zip 5 | unzip nerf_example_data.zip 6 | cd .. 
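# Unzipping yields the two Quick Start example scenes referenced in the README:
# `lego` (under data/nerf_synthetic) and `fern` (under data/nerf_llff_data).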
7 | -------------------------------------------------------------------------------- /imgs/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MegEngine/NeRF/b60724ed7ee859a20ed545198e97cc7467fb5829/imgs/pipeline.jpg -------------------------------------------------------------------------------- /nerf/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /nerf/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | from .build import build_loader 5 | -------------------------------------------------------------------------------- /nerf/data/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | import numpy as np 4 | from loguru import logger 5 | 6 | from .load_llff import load_llff_data 7 | from .load_deepvoxels import load_dv_data 8 | from .load_blender import load_blender_data 9 | from .load_LINEMOD import load_LINEMOD_data 10 | 11 | 12 | def build_loader(dataset_type, args): 13 | K = None 14 | if dataset_type == "llff": 15 | images, poses, bds, render_poses, i_test = load_llff_data( 16 | args.datadir, 17 | args.factor, 18 | recenter=True, 19 | bd_factor=0.75, 20 | spherify=args.spherify, 21 | ) 22 | hwf = poses[0, :3, -1] 23 | poses = poses[:, :3, :4] 24 | logger.info(f"Loaded llff {images.shape} {render_poses.shape} {hwf} {args.datadir}") 25 | num_images = images.shape[0] 26 | if not isinstance(i_test, list): 27 | i_test = [i_test] 28 | 29 | if args.llffhold > 0: 30 | logger.info(f"Auto LLFF holdout, {args.llffhold}") 31 | i_test = np.arange(num_images)[:: args.llffhold] 32 | 33 | i_val = i_test 34 | i_train = np.array( 35 | [i for i in range(num_images) if (i not in i_test and i not in i_val)] 36 | ) 37 | i_split = (i_train, i_val, i_test) 38 | 39 | logger.info("DEFINING BOUNDS") 40 | if args.no_ndc: 41 | near, far = np.ndarray.min(bds) * 0.9, np.ndarray.max(bds) * 1.0 42 | else: 43 | near, far = 0.0, 1.0 44 | logger.info(f"NEAR: {near} FAR: {far}") 45 | 46 | elif dataset_type == "blender": 47 | images, poses, render_poses, hwf, i_split = load_blender_data( 48 | args.datadir, args.half_res, args.testskip 49 | ) 50 | logger.info(f"Loaded blender {images.shape} {render_poses.shape} {hwf} {args.datadir}") 51 | if args.white_bkgd: 52 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) 53 | else: 54 | images = images[..., :3] 55 | near, far = 2.0, 6.0 56 | 57 | elif dataset_type == "LINEMOD": 58 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data( 59 | args.datadir, args.half_res, args.testskip 60 | ) 61 | logger.info(f"Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}") 62 | logger.info(f"[CHECK HERE] near: {near}, far: {far}.") 63 | 64 | if args.white_bkgd: 65 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:]) 66 | else: 67 | images = images[..., :3] 68 | 69 | elif dataset_type == "deepvoxels": 70 | images, poses, render_poses, hwf, i_split = load_dv_data( 71 | scene=args.shape, basedir=args.datadir, testskip=args.testskip 72 | ) 73 | logger.info(f"Loaded deepvoxels {images.shape} {render_poses.shape} {hwf} {args.datadir}") 74 
| hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) 75 | near, far = hemi_R - 1.0, hemi_R + 1.0 76 | 77 | else: 78 | raise ValueError(f"Unknown dataset type {dataset_type}") 79 | 80 | # cast height and wigth to right types 81 | height, width, focal = hwf 82 | height, width = int(height), int(width) 83 | hwf = [height, width, focal] 84 | 85 | if K is None: 86 | K = np.array([ 87 | [focal, 0, 0.5 * width], 88 | [0, focal, 0.5 * height], 89 | [0, 0, 1] 90 | ]) 91 | 92 | return images, poses, render_poses, hwf, i_split, near, far, K 93 | -------------------------------------------------------------------------------- /nerf/data/load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 
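        # Halving H, W, and focal together preserves the camera's field of
        # view, since fov_x = 2 * arctan(W / (2 * focal)).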
84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /nerf/data/load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import cv2 7 | 8 | 9 | trans_t = lambda t: torch.Tensor( 10 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]] 11 | ).float() 12 | 13 | 14 | rot_phi = lambda phi: torch.Tensor( 15 | [ 16 | [1, 0, 0, 0], 17 | [0, np.cos(phi), -np.sin(phi), 0], 18 | [0, np.sin(phi), np.cos(phi), 0], 19 | [0, 0, 0, 1], 20 | ] 21 | ).float() 22 | 23 | 24 | rot_theta = lambda th: torch.Tensor( 25 | [ 26 | [np.cos(th), 0, -np.sin(th), 0], 27 | [0, 1, 0, 0], 28 | [np.sin(th), 0, np.cos(th), 0], 29 | [0, 0, 0, 1], 30 | ] 31 | ).float() 32 | 33 | 34 | def pose_spherical(theta, phi, radius): 35 | c2w = trans_t(radius) 36 | c2w = rot_phi(phi / 180.0 * np.pi) @ c2w 37 | c2w = rot_theta(theta / 180.0 * np.pi) @ c2w 38 | c2w = ( 39 | torch.Tensor( 40 | np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 41 | ) 42 | @ c2w 43 | ) 44 | return c2w 45 | 46 | 47 | def load_blender_data(basedir, half_res=False, testskip=1): 48 | splits = ["train", "val", "test"] 49 | metas = {} 50 | for s in splits: 51 | with open(os.path.join(basedir, "transforms_{}.json".format(s)), "r") as fp: 52 | metas[s] = json.load(fp) 53 | 54 | all_imgs = [] 55 | all_poses = [] 56 | counts = [0] 57 | for s in splits: 58 | meta = metas[s] 59 | imgs = [] 60 | poses = [] 61 | if s == "train" or testskip == 0: 62 | skip = 1 63 | else: 64 | skip = testskip 65 | 66 | for frame in meta["frames"][::skip]: 67 | fname = os.path.join(basedir, frame["file_path"] + ".png") 68 | imgs.append(imageio.imread(fname)) 69 | poses.append(np.array(frame["transform_matrix"])) 70 | imgs = (np.array(imgs) / 255.0).astype(np.float32) # keep all 4 channels (RGBA) 71 | poses = np.array(poses).astype(np.float32) 72 | counts.append(counts[-1] + imgs.shape[0]) 73 | all_imgs.append(imgs) 74 | all_poses.append(poses) 75 | 76 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 77 | 78 | imgs = np.concatenate(all_imgs, 0) 79 | poses = np.concatenate(all_poses, 0) 80 | 81 | H, W = imgs[0].shape[:2] 82 | camera_angle_x = float(meta["camera_angle_x"]) 83 | focal = 0.5 * W / np.tan(0.5 * camera_angle_x) 84 | 85 | render_poses = torch.stack( 86 | [ 87 | pose_spherical(angle, -30.0, 4.0) 88 | for angle in np.linspace(-180, 180, 40 + 1)[:-1] 89 | ], 90 | 0, 91 | ) 92 | 93 | if half_res: 94 | H = H // 2 95 | W = W // 2 96 | focal = focal / 2.0 97 | 98 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 99 | for i, img in enumerate(imgs): 100 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 101 | imgs = imgs_half_res 102 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 103 | 104 | return imgs, poses, render_poses, [H, W, focal], i_split 105 | -------------------------------------------------------------------------------- /nerf/data/load_deepvoxels.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. 
for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /nerf/data/load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import imageio 4 | from loguru import logger 5 | from subprocess import check_output 6 | 7 | 8 | # Slightly modified version of LLFF data loading code 9 | # see https://github.com/Fyusion/LLFF for original 10 | 11 | def _minify(basedir, factors=[], resolutions=[]): 12 | needtoload = False 13 | for r in factors: 14 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 15 | if not os.path.exists(imgdir): 16 | needtoload = True 17 | for r in resolutions: 18 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 19 | if not os.path.exists(imgdir): 20 | needtoload = True 21 | if not needtoload: 22 | return 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = f"images_{r}" 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | name = 'images_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | logger.info(f"Minifying {r} {basedir}") 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | logger.info(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | 60 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 61 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 62 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 63 | bds = poses_arr[:, -2:].transpose([1, 0]) 64 | 65 | img0 = [ 66 | os.path.join(basedir, 'images', f) 67 | for f in sorted(os.listdir(os.path.join(basedir, 'images'))) 68 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') 69 | ][0] 70 | sh = imageio.imread(img0).shape 71 | 72 | sfx = "" 73 | 74 | if factor is not None: 75 | sfx = '_{}'.format(factor) 76 | _minify(basedir, factors=[factor]) 77 | factor = factor 78 | elif height is not None: 79 | factor = sh[0] / float(height) 80 | width = int(sh[1] / factor) 81 | _minify(basedir, resolutions=[[height, width]]) 82 | sfx = '_{}x{}'.format(width, height) 83 | elif width is not None: 84 | factor = sh[1] / float(width) 85 | height = int(sh[0] / factor) 86 | _minify(basedir, resolutions=[[height, width]]) 87 | sfx = '_{}x{}'.format(width, 
height) 88 | else: 89 | factor = 1 90 | 91 | imgdir = os.path.join(basedir, 'images' + sfx) 92 | if not os.path.exists(imgdir): 93 | logger.info(f"{imgdir} does not exist, returning") 94 | return 95 | 96 | imgfiles = [ 97 | os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) 98 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png') 99 | ] 100 | if poses.shape[-1] != len(imgfiles): 101 | logger.info(f'Mismatch between imgs {len(imgfiles)} and poses {poses.shape[-1]} !!!!') 102 | return 103 | 104 | # modify intrinsics of camera 105 | sh = imageio.imread(imgfiles[0]).shape 106 | # poses[:, 4, :] *= 1. / factor 107 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 108 | poses[2, 4, :] *= 1./factor # focal / factor 109 | 110 | if not load_imgs: 111 | return poses, bds 112 | 113 | def imread(f): 114 | if f.endswith('png'): 115 | return imageio.imread(f, ignoregamma=True) 116 | else: 117 | return imageio.imread(f) 118 | 119 | imgs = [imread(f)[..., :3] / 255. for f in imgfiles] 120 | imgs = np.stack(imgs, -1) 121 | 122 | logger.info(f'Loaded image data {imgs.shape} {poses[:, -1, 0]}') 123 | return poses, bds, imgs 124 | 125 | 126 | def normalize(x): 127 | return x / np.linalg.norm(x) 128 | 129 | 130 | def viewmatrix(z, up, pos): 131 | vec2 = normalize(z) # z axis 132 | vec0 = normalize(np.cross(up, vec2)) # x axis 133 | vec1 = normalize(np.cross(vec2, vec0)) # y axis 134 | m = np.stack([vec0, vec1, vec2, pos], 1) 135 | # return c2w matrix 136 | return m 137 | 138 | 139 | def ptstocam(pts, c2w): 140 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0] 141 | return tt 142 | 143 | 144 | def poses_avg(poses): 145 | hwf = poses[0, :3, -1:] 146 | center = poses[:, :3, 3].mean(0) # camera coordinates center 147 | z = poses[:, :3, 2].sum(0) 148 | y = poses[:, :3, 1].sum(0) 149 | c2w = np.concatenate([viewmatrix(z, y, center), hwf], 1) 150 | return c2w 151 | 152 | 153 | def render_path_spiral(c2w, up, rads, focal, zrate, rots, N): 154 | render_poses = [] 155 | rads = np.array(list(rads) + [1.]) 156 | hwf, c2w_matrix = c2w[:, 4:5], c2w[:, :4] 157 | 158 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 159 | c = np.dot( 160 | c2w_matrix, 161 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads 162 | ) 163 | z = normalize(c - np.dot(c2w_matrix, np.array([0, 0, -focal, 1.]))) 164 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 165 | 166 | return render_poses 167 | 168 | 169 | def recenter_poses(poses): 170 | # process of w2c @ intrinsics 171 | poses_ = poses.copy() 172 | bottom = np.array([[0, 0, 0, 1.]]) # bottom part of camera to world matrix 173 | c2w = poses_avg(poses) 174 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) # c2w shape to (4, 4) 175 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 176 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 177 | 178 | poses = np.linalg.inv(c2w) @ poses 179 | poses_[:, :3, :4] = poses[:, :3, :4] 180 | return poses_ 181 | 182 | 183 | def spherify_poses(poses, bds): 184 | 185 | def p34_to_44(p): 186 | return np.concatenate( 187 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1 188 | ) 189 | 190 | rays_d = poses[:, :3, 2:3] 191 | rays_o = poses[:, :3, 3:4] 192 | 193 | def min_line_dist(rays_o, rays_d): 194 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 195 | b_i = -A_i @ rays_o 196 | pt_mindist = np.squeeze( 197 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0) 198 | ) 199 | return pt_mindist 200 | 201 | pt_mindist = min_line_dist(rays_o, rays_d) 202 | 203 | center = pt_mindist 204 | up = (poses[:, :3, 3] - center).mean(0) 205 | 206 | vec0 = normalize(up) 207 | vec1 = normalize(np.cross([.1, .2, .3], vec0)) 208 | vec2 = normalize(np.cross(vec0, vec1)) 209 | pos = center 210 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 211 | 212 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 213 | 214 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 215 | 216 | sc = 1./rad 217 | poses_reset[:, :3, 3] *= sc 218 | bds *= sc 219 | rad *= sc 220 | 221 | centroid = np.mean(poses_reset[:, :3, 3], 0) 222 | zh = centroid[2] 223 | radcircle = np.sqrt(rad**2-zh**2) 224 | new_poses = [] 225 | 226 | for th in np.linspace(0., 2.*np.pi, 120): 227 | 228 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 229 | up = np.array([0, 0, -1.]) 230 | 231 | vec2 = normalize(camorigin) 232 | vec0 = normalize(np.cross(vec2, up)) 233 | vec1 = normalize(np.cross(vec2, vec0)) 234 | pos = camorigin 235 | p = np.stack([vec0, vec1, vec2, pos], 1) 236 | 237 | new_poses.append(p) 238 | 239 | new_poses = np.stack(new_poses, 0) 240 | 241 | new_poses = np.concatenate( 242 | [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1 243 | ) 244 | poses_reset = np.concatenate([ 245 | poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape) 246 | ], -1) 247 | 248 | return poses_reset, new_poses, bds 249 | 250 | 251 | def load_llff_data( 252 | basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False 253 | ): 254 | # factor=8 downsamples original imgs by 8x 255 | poses, bds, images = _load_data(basedir, factor=factor) 256 | logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}") 257 | 258 | # Correct rotation matrix ordering and move variable dim to axis 0 259 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 260 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 261 | images = np.moveaxis(images, -1, 0).astype(np.float32) # (H, W, 
C, N)(N, H, W, C) 262 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 263 | 264 | # Rescale if bd_factor is provided 265 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) 266 | poses[:, :3, 3] *= sc 267 | bds *= sc 268 | 269 | if recenter: 270 | poses = recenter_poses(poses) 271 | 272 | if spherify: 273 | poses, render_poses, bds = spherify_poses(poses, bds) 274 | 275 | else: 276 | c2w = poses_avg(poses) 277 | logger.info(f'recentered c2w shape: {c2w.shape}') 278 | logger.info(f"c2w: {c2w[:3, :4]}") 279 | 280 | # Get spiral 281 | # Get average pose 282 | up = normalize(poses[:, :3, 1].sum(0)) 283 | 284 | # Find a reasonable "focus depth" for this dataset 285 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 286 | dt = .75 287 | focal = 1. / (((1. - dt) / close_depth + dt / inf_depth)) 288 | 289 | # Get radii for spiral path 290 | tt = poses[:, :3, 3] 291 | rads = np.percentile(np.abs(tt), 90, 0) 292 | c2w_path = c2w 293 | N_views, N_rots = 120, 2 294 | if path_zflat: 295 | zloc = -close_depth * .1 296 | c2w_path[:3, 3] += zloc * c2w_path[:3, 2] 297 | rads[2] = 0. 298 | N_rots = 1 299 | N_views /= 2 300 | 301 | # generate poses for spiral path 302 | render_poses = render_path_spiral( 303 | c2w_path, up, rads, focal, zrate=.5, rots=N_rots, N=N_views 304 | ) 305 | 306 | render_poses = np.array(render_poses).astype(np.float32) 307 | 308 | c2w = poses_avg(poses) 309 | logger.info( 310 | f"Data:\nposes shape: {poses.shape}, images shape: {images.shape}, bds shape: {bds.shape}" 311 | ) 312 | 313 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1) 314 | i_test = np.argmin(dists) 315 | logger.info(f"HOLDOUT view is {i_test}") 316 | 317 | images = images.astype(np.float32) 318 | poses = poses.astype(np.float32) 319 | 320 | return images, poses, bds, render_poses, i_test 321 | -------------------------------------------------------------------------------- /nerf/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer -------------------------------------------------------------------------------- /nerf/engine/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import datetime 5 | import numpy as np 6 | import os 7 | import imageio 8 | import time 9 | from loguru import logger 10 | 11 | import megengine as mge 12 | import megengine.functional as F 13 | import megengine.distributed as dist 14 | from megengine.autodiff import GradManager 15 | 16 | from nerf.data import build_loader 17 | from nerf.models.build import create_nerf 18 | from nerf.utils import ( 19 | ensure_dir, get_rays, get_rays_np, to8b, img2mse, mse2psnr, render_path, render, meshgrid 20 | ) 21 | 22 | 23 | def rendering_only(args, images, i_test, render_poses, hwf, K, render_kwargs_test, start_iter): 24 | logger.info("RENDER ONLY") 25 | # render_test switches to test poses, Default is smoother render_poses path 26 | images = images[i_test] if args.render_test else None 27 | 28 | testsavedir = os.path.join( 29 | args.basedir, 30 | args.expname, 31 | "renderonly_{}_{:06d}".format("test" if args.render_test else "path", start_iter), 32 | ) 33 | ensure_dir(testsavedir) 34 | logger.info(f"test poses shape: {render_poses.shape}") 35 | 36 | rgbs, _ = render_path( 37 | render_poses, 38 | hwf, 39 | K, 40 | args.chunk, 41 | render_kwargs_test, 42 | savedir=testsavedir, 43 | render_factor=args.render_factor, 44 | ) 45 | logger.info(f"Done 
rendering {testsavedir}") 46 | imageio.mimwrite(os.path.join(testsavedir, "video.mp4"), to8b(rgbs), fps=30, quality=8) 47 | 48 | 49 | class Trainer: 50 | def __init__(self, args): 51 | # init function only defines some basic attr, other attrs like model, optimizer 52 | # are built in `before_train` methods. 53 | self.args = args 54 | self.start_iter, self.max_iter = (0, 200000) 55 | self.rank = 0 56 | self.amp_training = False 57 | 58 | def train(self): 59 | self.before_train() 60 | try: 61 | self.train_in_iter() 62 | except Exception: 63 | raise 64 | finally: 65 | self.after_train() 66 | 67 | def before_train(self): 68 | args = self.args 69 | 70 | logger.info(f"Full args:\n{args}") 71 | # model related init 72 | images, poses, render_poses, hwf, i_split, near, far, K = build_loader(args.dataset_type, args) # noqa 73 | 74 | i_train, i_val, i_test = i_split 75 | info_string = f"Train views are {i_train}\nTest views are {i_test}\nVal views are {i_val}" 76 | if args.render_test: 77 | render_poses = np.array(poses[i_test]) 78 | 79 | self.i_split = i_split 80 | self.hwf = hwf 81 | self.K = K 82 | self.render_poses = mge.Tensor(render_poses) 83 | 84 | # create dir to save all the results 85 | self.save_dir = os.path.join(args.basedir, args.expname) 86 | ensure_dir(self.save_dir) 87 | logger.add(os.path.join(self.save_dir, "log.txt")) 88 | 89 | # save args.txt config.txt 90 | with open(os.path.join(self.save_dir, "args.txt"), "w") as f: 91 | for arg in sorted(vars(args)): 92 | f.write(f"{arg} = {getattr(args, arg)}\n") 93 | 94 | if args.config is not None: 95 | with open(os.path.join(self.save_dir, "config.txt"), "w") as f: 96 | f.write(open(args.config, "r").read()) 97 | 98 | # Create nerf model 99 | render_kwargs_train, render_kwargs_test, optimizer, gm = self.build_nerf() 100 | bds_dict = {"near": near, "far": far} 101 | render_kwargs_train.update(bds_dict) 102 | render_kwargs_test.update(bds_dict) 103 | self.render_kwargs_train = render_kwargs_train 104 | self.render_kwargs_test = render_kwargs_test 105 | self.optimizer = optimizer 106 | self.grad_manager = gm 107 | 108 | # Short circuit if only rendering out from trained model 109 | if args.render_only: 110 | rendering_only( 111 | args, images, i_test, self.render_poses, self.hwf, self.K, 112 | self.render_kwargs_test, self.start_iter, 113 | ) 114 | return 115 | 116 | # Prepare raybatch tensor if batching random rays 117 | use_batching = not args.no_batching 118 | assert use_batching 119 | if use_batching: 120 | # For random ray batching 121 | logger.info("get rays") 122 | rays = np.stack( 123 | [get_rays_np(self.hwf[0], self.hwf[1], self.K, p) for p in poses[:, :3, :4]], 0 124 | ) # [N, ro+rd, H, W, 3] 125 | logger.info("get rays done, start concats") 126 | rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3] 127 | rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3] 128 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 129 | rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3] 130 | rays_rgb = rays_rgb.astype(np.float32) 131 | logger.info("shuffle rays") 132 | np.random.shuffle(rays_rgb) 133 | logger.info("shuffle rasy done") 134 | 135 | images = mge.Tensor(images) 136 | rays_rgb = mge.Tensor(rays_rgb) 137 | i_batch = 0 138 | 139 | self.poses = mge.Tensor(poses) 140 | self.images = images 141 | self.rays_rgb = rays_rgb 142 | self.i_batch = i_batch 143 | 144 | logger.info("Begin training\n" + info_string) 145 | 146 | def build_nerf(self): 147 | 
args = self.args 148 | 149 | model, model_fine, network_query_fn = create_nerf(args) 150 | params = list(model.parameters()) 151 | logger.info(f"Model:\n{model}") 152 | if model_fine is not None: 153 | logger.info(f"Model Fine:\n{model_fine}") 154 | params += list(model_fine.parameters()) 155 | 156 | gm = GradManager() 157 | world_size = dist.get_world_size() 158 | callbacks = [dist.make_allreduce_cb("MEAN", dist.WORLD)] if world_size > 1 else None # noqa 159 | gm.attach(params, callbacks=callbacks) 160 | 161 | optimizer = mge.optimizer.Adam(params=params, lr=args.lr, betas=(0.9, 0.999)) 162 | self.resume_ckpt(model, model_fine, optimizer) 163 | 164 | render_kwargs_train = { 165 | "network_query_fn": network_query_fn, 166 | "perturb": args.perturb, 167 | "N_importance": args.N_importance, 168 | "network_fine": model_fine, 169 | "N_samples": args.N_samples, 170 | "network_fn": model, 171 | "use_viewdirs": args.use_viewdirs, 172 | "white_bkgd": args.white_bkgd, 173 | "raw_noise_std": args.raw_noise_std, 174 | } 175 | 176 | # NDC only good for LLFF-style forward facing data 177 | if args.dataset_type != "llff" or args.no_ndc: 178 | logger.info("Not ndc!") 179 | render_kwargs_train["ndc"] = False 180 | render_kwargs_train["lindisp"] = args.lindisp 181 | 182 | render_kwargs_test = {k: v for k, v in render_kwargs_train.items()} 183 | render_kwargs_test["perturb"] = False 184 | render_kwargs_test["raw_noise_std"] = 0.0 185 | 186 | return render_kwargs_train, render_kwargs_test, optimizer, gm 187 | 188 | def after_train(self): 189 | logger.info("Training of experiment is done.") 190 | 191 | def train_in_iter(self): 192 | H, W, _ = self.hwf 193 | 194 | for i in range(self.start_iter, self.max_iter): 195 | self.global_step = i + 1 196 | 197 | iter_start_time = time.time() 198 | batch_rays, target_s = self.sample_rays() 199 | 200 | # Core optimization loop 201 | with self.grad_manager: 202 | rgb, disp, acc, extras = render( 203 | H, W, self.K, chunk=self.args.chunk, rays=batch_rays, 204 | retraw=True, **self.render_kwargs_train, 205 | ) 206 | 207 | loss = img2mse(rgb, target_s) 208 | psnr = mse2psnr(loss.detach()) 209 | if "rgb0" in extras: 210 | loss += img2mse(extras["rgb0"], target_s) 211 | 212 | self.grad_manager.backward(loss) 213 | self.optimizer.step().clear_grad() 214 | 215 | lr = self.update_lr() 216 | iter_time = time.time() - iter_start_time 217 | 218 | self.save_ckpt() 219 | self.save_test() 220 | 221 | # log training info 222 | if self.global_step % self.args.log_interval == 0: 223 | eta_seconds = (self.max_iter - self.global_step) * iter_time 224 | logger.info( 225 | f"iter: {self.global_step}/{self.max_iter}, " 226 | f"loss: {loss.item():.4f}, " 227 | f"PSNR: {psnr.item():.3f}, " 228 | f"lr: {lr:.4e}, " 229 | f"iter time: {iter_time:.3f}s, " 230 | f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}" 231 | ) 232 | 233 | def update_lr(self, decay_rate=0.1): 234 | decay_steps = self.args.lrate_decay * 1000 235 | new_lr = self.args.lr * (decay_rate ** (self.global_step / decay_steps)) 236 | 237 | for param_group in self.optimizer.param_groups: 238 | param_group["lr"] = new_lr 239 | 240 | return new_lr 241 | 242 | def save_ckpt(self): 243 | if self.rank == 0 and self.global_step % self.args.i_weights == 0: 244 | path = os.path.join(self.save_dir, f"{self.global_step:06d}.tar") 245 | ckpt_state = { 246 | "global_step": self.global_step, 247 | "network_fn_state_dict": self.render_kwargs_train["network_fn"].state_dict(), 248 | "network_fine_state_dict": 
self.render_kwargs_train["network_fine"].state_dict(), 249 | "optimizer_state_dict": self.optimizer.state_dict(), 250 | } 251 | mge.save(ckpt_state, path) 252 | logger.info(f"Save checkpoint at {path}") 253 | 254 | def save_test(self): 255 | if self.global_step % self.args.i_video == 0: 256 | # Turn on testing mode 257 | rgbs, disps = render_path( 258 | self.render_poses, self.hwf, self.K, self.args.chunk, self.render_kwargs_test 259 | ) 260 | logger.info(f"Done, saving {rgbs.shape} {disps.shape}") 261 | moviebase = os.path.join( 262 | self.save_dir, "{}_spiral_{:06d}_".format(self.args.expname, self.global_step) 263 | ) 264 | imageio.mimwrite(moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8) 265 | imageio.mimwrite( 266 | moviebase + "disp.mp4", to8b(disps / np.max(disps)), fps=30, quality=8 267 | ) 268 | 269 | if self.global_step % self.args.i_testset == 0: 270 | i_test = self.i_split[-1] 271 | test_save_dir = os.path.join(self.save_dir, f"testset_{self.global_step:06d}") 272 | ensure_dir(test_save_dir) 273 | logger.info(f"test poses shape: {self.poses[i_test].shape}") 274 | render_path( 275 | mge.Tensor(self.poses[i_test]), 276 | self.hwf, 277 | self.K, 278 | self.args.chunk, 279 | self.render_kwargs_test, 280 | savedir=test_save_dir, 281 | ) 282 | logger.info("Saved test set") 283 | 284 | def resume_ckpt(self, model, model_fine, optimizer): 285 | if self.args.ft_path is not None and self.args.ft_path != "None": 286 | ckpts = [self.args.ft_path] 287 | else: 288 | ckpts = [ 289 | os.path.join(self.save_dir, f) 290 | for f in sorted(os.listdir(self.save_dir)) 291 | if f.endswith("tar") 292 | ] 293 | 294 | if ckpts and not self.args.no_reload: 295 | logger.info(f"Found ckpts: {ckpts}") 296 | ckpt_to_load = ckpts[-1] 297 | logger.info(f"Reloading from {ckpt_to_load}") 298 | ckpt = mge.load(ckpt_to_load) 299 | 300 | self.start_iter = ckpt["global_step"] 301 | model.load_state_dict(ckpt["network_fn_state_dict"]) 302 | if model_fine is not None: 303 | model_fine.load_state_dict(ckpt["network_fine_state_dict"]) 304 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 305 | 306 | return model, model_fine, optimizer 307 | 308 | def sample_rays(self): 309 | # Sample random ray batch 310 | N_rand = self.args.N_rand 311 | use_batching = not self.args.no_batching 312 | rays_rgb = self.rays_rgb 313 | i_train = self.i_split[0] 314 | images = self.images 315 | H, W, _ = self.hwf 316 | 317 | if use_batching: 318 | # Random over all images 319 | batch = rays_rgb[self.i_batch: self.i_batch + N_rand] # [B, 2+1, 3*?] 
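            # Each rays_rgb row packs [ray origin, ray direction, target rgb]; after
            # the transpose below, batch[:2] holds the (o, d) ray pair and batch[2]
            # the ground-truth color for every sampled ray.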
320 | batch = batch.transpose(1, 0, 2) 321 | batch_rays, target_s = batch[:2], batch[2] 322 | 323 | self.i_batch += N_rand 324 | if self.i_batch >= rays_rgb.shape[0]: 325 | logger.info("Shuffle data after an epoch!") 326 | rand_idx = mge.Tensor(np.random.permutation(rays_rgb.shape[0])) 327 | rays_rgb = rays_rgb[rand_idx] 328 | self.i_batch = 0 329 | 330 | else: 331 | # Random from one image 332 | img_i = np.random.choice(i_train) 333 | target = images[img_i] 334 | target = mge.Tensor(target) 335 | pose = self.poses[img_i, :3, :4] 336 | 337 | if N_rand is not None: 338 | rays_o, rays_d = get_rays(H, W, self.K, pose) # (H, W, 3), (H, W, 3) 339 | 340 | if self.global_step < self.args.precrop_iters: 341 | dH = int(H // 2 * self.args.precrop_frac) 342 | dW = int(W // 2 * self.args.precrop_frac) 343 | coords = F.stack( 344 | meshgrid( 345 | F.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), 346 | F.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW), 347 | indexing="ij" 348 | ), 349 | -1, 350 | ) 351 | if self.global_step == 1: 352 | logger.info( 353 | f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {self.args.precrop_iters}" # noqa 354 | ) 355 | else: 356 | coords = F.stack( 357 | meshgrid(F.linspace(0, H - 1, H), F.linspace(0, W - 1, W), indexing="ij"), 358 | -1, 359 | ) # (H, W, 2) 360 | 361 | coords = coords.reshape(-1, 2) # (H * W, 2) 362 | select_inds = np.random.choice( 363 | coords.shape[0], size=[N_rand], replace=False 364 | ) # (N_rand,) 365 | select_coords = coords[select_inds].long() # (N_rand, 2) 366 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 367 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 368 | batch_rays = F.stack([rays_o, rays_d], 0) 369 | target_s = target[ 370 | select_coords[:, 0], select_coords[:, 1] 371 | ] # (N_rand, 3) 372 | 373 | return batch_rays, target_s 374 | -------------------------------------------------------------------------------- /nerf/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | from .nerf import NeRF 5 | -------------------------------------------------------------------------------- /nerf/models/build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import megengine.functional as F 5 | 6 | from .nerf import NeRF 7 | from .embed import get_embedder 8 | 9 | __all__ = ["create_nerf"] 10 | 11 | 12 | def batchify(fn, chunk): 13 | """Constructs a version of 'fn' that applies to smaller batches.""" 14 | if chunk is None: 15 | return fn 16 | 17 | def ret(inputs): 18 | return F.concat([fn(inputs[i: i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 19 | 20 | return ret 21 | 22 | 23 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64): 24 | """Prepares inputs and applies network 'fn'.""" 25 | inputs_flat = inputs.reshape([-1, inputs.shape[-1]]) 26 | embedded = embed_fn(inputs_flat) 27 | 28 | if viewdirs is not None: 29 | input_dirs = F.broadcast_to(F.expand_dims(viewdirs, axis=1), inputs.shape) 30 | input_dirs_flat = input_dirs.reshape([-1, input_dirs.shape[-1]]) 31 | embedded_dirs = embeddirs_fn(input_dirs_flat) 32 | embedded = F.concat([embedded, embedded_dirs], -1) 33 | 34 | outputs_flat = batchify(fn, netchunk)(embedded) 35 | outputs = outputs_flat.reshape(list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 36 | return outputs 37 | 
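# A note on chunking (sketch only; the shapes below are illustrative): `batchify`
# above makes the peak memory of a forward pass independent of how many embedded
# points a ray batch contains, because `fn` is applied to `chunk`-sized slices and
# the results are concatenated, matching a single full-batch call:
#
#   chunked = batchify(model, chunk=1024 * 32)
#   out = chunked(embedded)  # equals model(embedded), computed piecewise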
38 | 39 | def create_nerf(args): 40 | """Instantiate NeRF's MLP model.""" 41 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 42 | 43 | input_ch_views = 0 44 | embeddirs_fn = None 45 | if args.use_viewdirs: 46 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 47 | output_ch = 5 if args.N_importance > 0 else 4 48 | skips = [4] 49 | model = NeRF( 50 | depth=args.netdepth, 51 | width=args.netwidth, 52 | input_ch=input_ch, 53 | output_ch=output_ch, 54 | skips=skips, 55 | input_ch_views=input_ch_views, 56 | use_viewdirs=args.use_viewdirs, 57 | ) 58 | 59 | model_fine = None 60 | if args.N_importance > 0: 61 | model_fine = NeRF( 62 | depth=args.netdepth_fine, 63 | width=args.netwidth_fine, 64 | input_ch=input_ch, 65 | output_ch=output_ch, 66 | skips=skips, 67 | input_ch_views=input_ch_views, 68 | use_viewdirs=args.use_viewdirs, 69 | ) 70 | 71 | def network_query_fn(inputs, viewdirs, network_fn): 72 | return run_network( 73 | inputs, viewdirs, network_fn, embed_fn=embed_fn, 74 | embeddirs_fn=embeddirs_fn, netchunk=args.netchunk, 75 | ) 76 | 77 | return model, model_fine, network_query_fn 78 | -------------------------------------------------------------------------------- /nerf/models/embed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | import megengine.functional as F 4 | import megengine.module as nn 5 | 6 | __all__ = ["Embedder", "get_embedder"] 7 | 8 | 9 | # Positional encoding (section 5.1) 10 | class Embedder: 11 | def __init__(self, **kwargs): 12 | self.kwargs = kwargs 13 | self.create_embedding_fn() 14 | 15 | def create_embedding_fn(self): 16 | embed_fns = [] 17 | d = self.kwargs["input_dims"] 18 | out_dim = 0 19 | if self.kwargs["include_input"]: 20 | embed_fns.append(lambda x: x) 21 | out_dim += d 22 | 23 | max_freq = self.kwargs["max_freq_log2"] 24 | N_freqs = self.kwargs["num_freqs"] 25 | 26 | if self.kwargs["log_sampling"]: 27 | freq_bands = 2.0 ** F.linspace(0.0, max_freq, N_freqs) 28 | else: 29 | freq_bands = F.linspace(2.0**0.0, 2.0**max_freq, N_freqs) 30 | 31 | for freq in freq_bands: 32 | for p_fn in self.kwargs["periodic_fns"]: 33 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 34 | out_dim += d 35 | 36 | self.embed_fns = embed_fns 37 | self.out_dim = out_dim 38 | 39 | def embed(self, inputs): 40 | return F.concat([fn(inputs) for fn in self.embed_fns], -1) 41 | 42 | 43 | def get_embedder(multires, i=0): 44 | if i == -1: 45 | return nn.Identity(), 3 46 | 47 | embed_kwargs = { 48 | "include_input": True, 49 | "input_dims": 3, 50 | "max_freq_log2": multires - 1, 51 | "num_freqs": multires, 52 | "log_sampling": True, 53 | "periodic_fns": [F.sin, F.cos], 54 | } 55 | 56 | embedder_obj = Embedder(**embed_kwargs) 57 | 58 | def embed(x, eo=embedder_obj): 59 | return eo.embed(x) 60 | 61 | return embed, embedder_obj.out_dim 62 | -------------------------------------------------------------------------------- /nerf/models/nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | import megengine.module as nn 4 | import megengine.functional as F 5 | 6 | 7 | class NeRF(nn.Module): 8 | """NeRF module""" 9 | def __init__( 10 | self, depth=8, width=256, input_ch=3, input_ch_views=3, 11 | output_ch=4, skips=[4], use_viewdirs=False, 12 | ): 13 | super().__init__() 14 | self.depth = depth 15 | self.width = width 16 | self.input_ch = input_ch 17 | 
self.input_ch_views = input_ch_views 18 | self.skips = skips 19 | self.use_viewdirs = use_viewdirs 20 | 21 | self.pts_linears = [nn.Linear(input_ch, width)] + [ 22 | nn.Linear(width, width) if i not in self.skips else nn.Linear(width + input_ch, width) 23 | for i in range(depth - 1) 24 | ] 25 | 26 | if use_viewdirs: 27 | self.alpha_linear = nn.Linear(width, 1) 28 | self.feature_linear = nn.Linear(width, width) 29 | self.views_linears = nn.Linear(input_ch_views + width, width // 2) 30 | self.rgb_linear = nn.Linear(width // 2, 3) 31 | else: 32 | self.output_linear = nn.Linear(width, output_ch) 33 | 34 | def forward(self, x): 35 | input_pts, input_views = F.split(x, [self.input_ch], axis=-1) 36 | h = input_pts 37 | 38 | for i, layer in enumerate(self.pts_linears): 39 | h = F.relu(layer(h)) 40 | if i in self.skips: 41 | h = F.concat([input_pts, h], -1) 42 | 43 | if self.use_viewdirs: 44 | alpha = self.alpha_linear(h) 45 | feature = self.feature_linear(h) 46 | h = F.concat([feature, input_views], -1) 47 | h = F.relu(self.views_linears(h)) 48 | rgb = self.rgb_linear(h) 49 | outputs = F.concat([rgb, alpha], -1) 50 | else: 51 | outputs = self.output_linear(h) 52 | 53 | return outputs 54 | -------------------------------------------------------------------------------- /nerf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | from .helpers import * 5 | from .render import * -------------------------------------------------------------------------------- /nerf/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import megengine as mge 3 | import megengine.functional as F 4 | import numpy as np 5 | import math 6 | 7 | 8 | __all__ = [ 9 | "ensure_dir", 10 | "img2mse", 11 | "mse2psnr", 12 | "to8b", 13 | "get_rays", 14 | "get_rays_np", 15 | "ndc_rays", 16 | "sample_pdf", 17 | "meshgrid", 18 | "cumprod", 19 | ] 20 | 21 | 22 | def cumprod(x: mge.Tensor, axis: int): 23 | dim = x.ndim 24 | axis = axis if axis > 0 else axis + dim 25 | num_loop = x.shape[axis] 26 | t_shape = [i + 1 if i < axis else i for i in range(dim)] 27 | t_shape[axis] = 0 28 | x = x.transpose(*t_shape) 29 | assert len(x) == num_loop 30 | cum_val = F.ones(x[0].shape) 31 | for i in range(num_loop): 32 | cum_val *= x[i] 33 | x[i] = cum_val 34 | return x.transpose(*t_shape) 35 | 36 | 37 | def ensure_dir(path: str): 38 | """create directories if *path* does not exist""" 39 | if not os.path.isdir(path): 40 | os.makedirs(path) 41 | 42 | 43 | def meshgrid(x, y, indexing="xy"): 44 | """meshgrid wrapper for megengine""" 45 | assert len(x.shape) == 1 46 | assert len(y.shape) == 1 47 | mesh_shape = (y.shape[0], x.shape[0]) 48 | mesh_x = F.broadcast_to(x, mesh_shape) 49 | mesh_y = F.broadcast_to(y.reshape(-1, 1), mesh_shape) 50 | if indexing == "ij": 51 | mesh_x, mesh_y = mesh_x.T, mesh_y.T 52 | return mesh_x, mesh_y 53 | 54 | 55 | # Misc 56 | def img2mse(x, y): 57 | return F.mean((x - y) ** 2) 58 | 59 | 60 | def mse2psnr(x): 61 | return -10.0 * (F.log(x) / math.log(10.0)) 62 | 63 | 64 | def to8b(x): 65 | return (255 * np.clip(x, 0, 1)).astype(np.uint8) 66 | 67 | 68 | # Ray helpers 69 | def get_rays(H, W, K, c2w): 70 | i, j = meshgrid(F.linspace(0, W - 1, W), F.linspace(0, H - 1, H), indexing="xy") 71 | dirs = F.stack( 72 | [ 73 | (i - float(K[0][2])) / float(K[0][0]), 74 | -(j - float(K[1][2])) / float(K[1][1]), 75 | -F.ones_like(i), 76 | ], -1 77 | ) 78 | # Rotate ray 
directions from camera frame to the world frame 79 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] 80 | rays_d = (F.expand_dims(dirs, axis=-2) * c2w[:3, :3]).sum(axis=-1) 81 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 82 | rays_o = F.broadcast_to(c2w[:3, -1], rays_d.shape) 83 | return rays_o, rays_d 84 | 85 | 86 | def get_rays_np(H: int, W: int, K: np.array, c2w: np.array): 87 | """ 88 | 89 | Args: 90 | H (int): height of image. 91 | W (int): width of image. 92 | K (np.array): intrinsic matrix. 93 | c2w (np.array): camera to world matrix. 94 | """ 95 | i, j = np.meshgrid( 96 | np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing="xy" 97 | ) 98 | # K @ dirs is (x, -y0, -1), and K is intrinsics of camera 99 | dirs = np.stack([(i - K[0, 2]) / K[0, 0], -(j - K[1, 2]) / K[1, 1], -np.ones_like(i)], -1) 100 | # Rotate ray directions from camera frame to the world frame 101 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] # noqa 102 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 103 | rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape) 104 | return rays_o, rays_d 105 | 106 | 107 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 108 | """ 109 | get ray o' and d' in normalized device coordinates space. 110 | check more details in appendix C of nerf paper. 111 | 112 | Args: 113 | rays_o: rays origin of shape (N, 3). 114 | rays_d: rays direction of shape (N, 3). 115 | """ 116 | # Shift ray origins to near plane 117 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 118 | rays_o = rays_o + F.expand_dims(t, axis=-1) * rays_d 119 | 120 | # Projection 121 | # according to paper, a_x = - f_cam / (W / 2), a_y = -f_cam / (H / 2) 122 | a_x, a_y = - float((2.0 * focal) / W), - float((2.0 * focal) / H) 123 | o_x = a_x * rays_o[..., 0] / rays_o[..., 2] 124 | o_y = a_y * rays_o[..., 1] / rays_o[..., 2] 125 | o_z = 1.0 + 2.0 * near / rays_o[..., 2] 126 | 127 | d_x = a_x * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 128 | d_y = a_y * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 129 | d_z = -2.0 * near / rays_o[..., 2] 130 | 131 | rays_o = F.stack([o_x, o_y, o_z], -1) 132 | rays_d = F.stack([d_x, d_y, d_z], -1) 133 | 134 | return rays_o, rays_d 135 | 136 | 137 | def search_sorted(cdf, value): 138 | # TODO: torch to pure mge 139 | import torch 140 | inds = torch.searchsorted(torch.tensor(cdf.numpy()), torch.tensor(value.numpy()), right=True) 141 | inds = mge.tensor(inds) 142 | return inds 143 | 144 | 145 | # Hierarchical sampling (section 5.2) 146 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 147 | # Get pdf 148 | weights = weights + 1e-5 # prevent nans 149 | pdf = weights / F.sum(weights, -1, keepdims=True) 150 | cdf = F.cumsum(pdf, -1 + pdf.ndim) 151 | cdf = F.concat([F.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 152 | 153 | # Take uniform samples 154 | if det: 155 | u = F.linspace(0.0, 1.0, N_samples) 156 | u = F.broadcast_to(u, list(cdf.shape[:-1]) + [N_samples]) 157 | else: 158 | u = mge.random.uniform(size=list(cdf.shape[:-1]) + [N_samples]) 159 | 160 | # Pytest, overwrite u with numpy's fixed random numbers 161 | if pytest: 162 | np.random.seed(0) 163 | new_shape = list(cdf.shape[:-1]) + [N_samples] 164 | if det: 165 | u = np.linspace(0.0, 1.0, N_samples) 166 | u = np.broadcast_to(u, new_shape) 167 | else: 168 | u = np.random.rand(*new_shape) 169 | u = 
mge.Tensor(u)
170 |
171 |     # Invert CDF
172 |     inds = search_sorted(cdf, u)
173 |     below = F.maximum(F.zeros_like(inds - 1), inds - 1)
174 |     above = F.minimum((cdf.shape[-1] - 1) * F.ones_like(inds), inds)
175 |     inds_g = F.stack([below, above], -1)  # (batch, N_samples, 2)
176 |
177 |     matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
178 |     cdf_g = F.gather(F.broadcast_to(F.expand_dims(cdf, axis=1), matched_shape), 2, inds_g)
179 |     bins_g = F.gather(F.broadcast_to(F.expand_dims(bins, axis=1), matched_shape), 2, inds_g)
180 |     denom = cdf_g[..., 1] - cdf_g[..., 0]
181 |     denom = F.where(denom < 1e-5, F.ones_like(denom), denom)
182 |     t = (u - cdf_g[..., 0]) / denom
183 |     samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
184 |
185 |     return samples
186 |
--------------------------------------------------------------------------------
/nerf/utils/render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | import os
5 | import imageio
6 | import numpy as np
7 | import megengine as mge
8 | import megengine.functional as F
9 | import time
10 | from collections import defaultdict
11 | from loguru import logger
12 | from tqdm import tqdm
13 |
14 | from nerf.utils import get_rays, ndc_rays, to8b, sample_pdf, cumprod
15 |
16 | __all__ = [
17 |     "batchify_rays",
18 |     "render_rays",
19 |     "render_path",
20 |     "render",
21 | ]
22 |
23 |
24 | def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
25 |     """Render rays in smaller minibatches to avoid OOM."""
26 |     all_ret = defaultdict(list)
27 |     for i in range(0, rays_flat.shape[0], chunk):
28 |         ret = render_rays(rays_flat[i: i + chunk], **kwargs)
29 |         for k, v in ret.items():
30 |             all_ret[k].append(v)
31 |
32 |     return {k: F.concat(all_ret[k], 0) for k in all_ret}
33 |
34 |
35 | def render(
36 |     H: int,
37 |     W: int,
38 |     K,
39 |     chunk: int = 1024 * 32,
40 |     rays=None,
41 |     c2w=None,
42 |     ndc: bool = True,
43 |     near: float = 0.0,
44 |     far: float = 1.0,
45 |     use_viewdirs: bool = False,
46 |     c2w_staticcam=None,
47 |     **kwargs,
48 | ):
49 |     """Render rays
50 |
51 |     Args:
52 |         H (int): Height of image in pixels.
53 |         W (int): Width of image in pixels.
54 |         K: array of shape [3, 3]. Intrinsics matrix of the pinhole camera;
55 |             K[0][0] / K[1][1] hold the focal lengths in pixels.
56 |         chunk: int. Maximum number of rays to process simultaneously. Used to
57 |             control maximum memory usage. Does not affect final results.
58 |         rays: array of shape [2, batch_size, 3]. Ray origin and direction for
59 |             each example in batch.
60 |         c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
61 |         ndc: If True, represent ray origin, direction in NDC coordinates.
62 |         near: float or array of shape [batch_size]. Nearest distance for a ray.
63 |         far: float or array of shape [batch_size]. Farthest distance for a ray.
64 |         use_viewdirs: bool. If True, use viewing direction of a point in space in model.
65 |         c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
66 |             camera while using other c2w argument for viewing directions.
67 |
68 |     Returns:
69 |         rgb_map: [batch_size, 3]. Predicted RGB values for rays.
70 |         disp_map: [batch_size]. Disparity map. Inverse of depth.
71 |         acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
72 |         extras: dict with everything returned by render_rays().
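
    Example (a sketch; `render_kwargs` stands in for the network_fn /
    network_query_fn / N_samples keyword arguments assembled by the trainer):
        rgb, disp, acc, extras = render(
            H, W, K, chunk=1024 * 32, rays=batch_rays, **render_kwargs
        )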
73 | """ 74 | if c2w is not None: 75 | # special case to render full image 76 | rays_o, rays_d = get_rays(H, W, K, c2w) 77 | else: 78 | # use provided ray batch 79 | rays_o, rays_d = rays 80 | 81 | if use_viewdirs: 82 | # provide ray directions as input 83 | viewdirs = rays_d 84 | if c2w_staticcam is not None: 85 | # special case to visualize effect of viewdirs 86 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 87 | viewdirs = viewdirs / F.norm(viewdirs, axis=-1, keepdims=True) 88 | viewdirs = viewdirs.reshape(-1, 3) 89 | 90 | sh = rays_d.shape # [..., 3] 91 | if ndc: 92 | # for forward facing scenes 93 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1.0, rays_o, rays_d) 94 | 95 | # Create ray batch 96 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 97 | 98 | near, far = near * F.ones_like(rays_d[..., :1]), far * F.ones_like(rays_d[..., :1]) 99 | rays = F.concat([rays_o, rays_d, near, far], -1) 100 | if use_viewdirs: 101 | rays = F.concat([rays, viewdirs], -1) 102 | 103 | # Render and reshape 104 | all_ret = batchify_rays(rays, chunk, **kwargs) 105 | for k, v in all_ret.items(): 106 | all_ret[k] = v.reshape(list(sh[:-1]) + list(v.shape[1:])) 107 | 108 | k_extract = ["rgb_map", "disp_map", "acc_map"] 109 | ret_list = [all_ret[k] for k in k_extract] 110 | ret_dict = {k: v for k, v in all_ret.items() if k not in k_extract} 111 | return ret_list + [ret_dict] 112 | 113 | 114 | def render_path(render_poses, hwf, K, chunk, render_kwargs, savedir=None, render_factor=0): 115 | H, W, focal = hwf 116 | 117 | if render_factor != 0: 118 | # Render downsampled for speed 119 | H = H // render_factor 120 | W = W // render_factor 121 | focal = focal / render_factor 122 | 123 | rgbs, disps = [], [] 124 | 125 | t = time.time() 126 | for i, c2w in enumerate(tqdm(render_poses)): 127 | logger.info(f"{i} {time.time() - t}") 128 | t = time.time() 129 | rgb, disp, acc, _ = render( 130 | H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs 131 | ) 132 | rgbs.append(rgb.numpy()) 133 | disps.append(disp.numpy()) 134 | if i == 0: 135 | logger.info(f"rgb shape: {rgb.shape}, disp shape: {disp.shape}") 136 | 137 | if savedir is not None: 138 | rgb8 = to8b(rgbs[-1]) 139 | filename = os.path.join(savedir, "{:03d}.png".format(i)) 140 | imageio.imwrite(filename, rgb8) 141 | 142 | rgbs = np.stack(rgbs, 0) 143 | disps = np.stack(disps, 0) 144 | 145 | return rgbs, disps 146 | 147 | 148 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 149 | """Transforms model's predictions to semantically meaningful values. 150 | 151 | Args: 152 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 153 | z_vals: [num_rays, num_samples along ray]. Integration time. 154 | rays_d: [num_rays, 3]. Direction of each ray. 155 | 156 | Returns: 157 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 158 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 159 | acc_map: [num_rays]. Sum of weights along each ray. 160 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 161 | depth_map: [num_rays]. Estimated distance to object. 
162 |
163 |     """
164 |     def raw2alpha(raw, dists, act_fn=F.relu):
165 |         return 1.0 - F.exp(-act_fn(raw) * dists)
166 |
167 |     dists = z_vals[..., 1:] - z_vals[..., :-1]
168 |     dists = F.concat([dists, F.full(dists[..., :1].shape, 1e10)], -1)  # [N_rays, N_samples]
169 |
170 |     dists = dists * F.norm(F.expand_dims(rays_d, axis=-2), axis=-1)
171 |
172 |     rgb = F.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
173 |     noise = 0.0
174 |     if raw_noise_std > 0.0:
175 |         noise = mge.random.normal(size=raw[..., 3].shape) * raw_noise_std
176 |
177 |     # Overwrite randomly sampled data if pytest
178 |     if pytest:
179 |         np.random.seed(0)
180 |         noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
181 |         noise = mge.Tensor(noise)
182 |
183 |     alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
184 |     # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
185 |     weights = (
186 |         alpha * cumprod(
187 |             F.concat([F.ones((alpha.shape[0], 1)), 1.0 - alpha + 1e-10], -1), -1
188 |         )[:, :-1]
189 |     )
190 |     rgb_map = F.sum(F.expand_dims(weights, axis=-1) * rgb, -2)  # [N_rays, 3]
191 |
192 |     depth_map = F.sum(weights * z_vals, -1)
193 |     disp_map = 1.0 / F.maximum(
194 |         1e-10 * F.ones_like(depth_map), depth_map / F.sum(weights, -1)
195 |     )
196 |     acc_map = F.sum(weights, -1)
197 |
198 |     if white_bkgd:
199 |         rgb_map = rgb_map + (1.0 - acc_map[..., None])
200 |
201 |     return rgb_map, disp_map, acc_map, weights, depth_map
202 |
203 |
204 | def render_rays(
205 |     ray_batch,
206 |     network_fn,
207 |     network_query_fn,
208 |     N_samples,
209 |     retraw=False,
210 |     lindisp=False,
211 |     perturb=0.0,
212 |     N_importance=0,
213 |     network_fine=None,
214 |     white_bkgd=False,
215 |     raw_noise_std=0.0,
216 |     pytest=False,
217 | ):
218 |     """Volumetric rendering.
219 |
220 |     Args:
221 |         ray_batch: array of shape [batch_size, ...]. All information necessary
222 |             for sampling along a ray, including: ray origin, ray direction, min
223 |             dist, max dist, and unit-magnitude viewing direction.
224 |         network_fn: function. Model for predicting RGB and density at each point
225 |             in space.
226 |         network_query_fn: function used for passing queries to network_fn.
227 |         N_samples: int. Number of different times to sample along each ray.
228 |         retraw: bool. If True, include model's raw, unprocessed predictions.
229 |         lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
230 |         perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
231 |             random points in time.
232 |         N_importance: int. Number of additional times to sample along each ray.
233 |             These samples are only passed to network_fine.
234 |         network_fine: "fine" network with same spec as network_fn.
235 |         white_bkgd: bool. If True, assume a white background.
236 |         raw_noise_std: float. Std dev of noise added to the raw density (sigma)
237 |             output as regularization; 0.0 disables it.
238 |     Returns:
239 |         rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
240 |         disp_map: [num_rays]. Disparity map. 1 / depth.
241 |         acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
242 |         raw: [num_rays, num_samples, 4]. Raw predictions from model.
243 |         rgb0: See rgb_map. Output for coarse model.
244 |         disp0: See disp_map. Output for coarse model.
245 |         acc0: See acc_map. Output for coarse model.
246 |         z_std: [num_rays]. Standard deviation of distances along ray for each sample.
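        Note: when N_importance > 0 a second, "fine" pass is run: N_importance
            extra depths are drawn by inverse-CDF sampling of the coarse weights
            (sample_pdf), all depths are re-sorted, and the outputs above come
            from that pass; coarse results are kept as rgb0 / disp0 / acc0.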
247 | """ 248 | N_rays = ray_batch.shape[0] 249 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each 250 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None 251 | bounds = ray_batch[..., 6:8].reshape(-1, 1, 2) 252 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1] 253 | 254 | # Generate sample bins 255 | bin_vals = F.linspace(0.0, 1.0, N_samples) 256 | if not lindisp: 257 | z_vals = near + (far - near) * bin_vals 258 | else: 259 | z_vals = 1.0 / (1.0 / near * (1.0 - bin_vals) + 1.0 / far * (bin_vals)) 260 | z_vals = F.broadcast_to(z_vals, [N_rays, N_samples]) 261 | 262 | if perturb > 0.0: 263 | # get intervals between samples 264 | mids = (z_vals[..., :-1] + z_vals[..., 1:]) / 2.0 265 | upper = F.concat([mids, z_vals[..., -1:]], -1) 266 | lower = F.concat([z_vals[..., :1], mids], -1) 267 | # stratified samples in those intervals 268 | t_rand = mge.random.uniform(size=z_vals.shape) 269 | 270 | # Pytest, overwrite u with numpy's fixed random numbers 271 | if pytest: 272 | np.random.seed(0) 273 | t_rand = mge.Tensor(np.random.rand(*list(z_vals.shape))) 274 | 275 | z_vals = lower + (upper - lower) * t_rand 276 | 277 | pts = F.expand_dims(rays_o, axis=-2) + F.expand_dims(rays_d, axis=-2) * F.expand_dims(z_vals, axis=-1) # noqa 278 | # shape of pts: [N_rays, N_samples, 3] 279 | 280 | raw = network_query_fn(pts, viewdirs, network_fn) 281 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs( 282 | raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest 283 | ) 284 | 285 | if N_importance > 0: 286 | 287 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 288 | 289 | z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) 290 | z_samples = sample_pdf( 291 | z_vals_mid, 292 | weights[..., 1:-1], 293 | N_importance, 294 | det=(perturb == 0.0), 295 | pytest=pytest, 296 | ) 297 | z_samples = z_samples.detach() 298 | 299 | # note that sort in megengine is different from torch 300 | z_vals, _ = F.sort(F.concat([z_vals, z_samples], -1,), descending=False) 301 | pts = F.expand_dims(rays_o, -2) + F.expand_dims(rays_d, -2) * F.expand_dims(z_vals, -1) 302 | # [N_rays, N_samples + N_importance, 3] 303 | 304 | run_fn = network_fn if network_fine is None else network_fine 305 | raw = network_query_fn(pts, viewdirs, run_fn) 306 | 307 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs( 308 | raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest 309 | ) 310 | 311 | ret = {"rgb_map": rgb_map, "disp_map": disp_map, "acc_map": acc_map} 312 | if retraw: 313 | ret["raw"] = raw 314 | if N_importance > 0: 315 | ret["rgb0"] = rgb_map_0 316 | ret["disp0"] = disp_map_0 317 | ret["acc0"] = acc_map_0 318 | ret["z_std"] = F.std(z_samples, axis=-1) # [N_rays] 319 | 320 | return ret 321 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio 2 | imageio-ffmpeg 3 | torch>=1.4 4 | loguru 5 | configargparse 6 | tqdm 7 | opencv-python 8 | megengine -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from loguru import logger 3 | 4 | from nerf.engine import Trainer 5 | 6 | 7 | def config_parser(): 8 | 9 | import configargparse 10 | 11 | parser = configargparse.ArgumentParser() 12 | parser.add_argument("--config", is_config_file=True, help="config file path") 13 | 
parser.add_argument("--expname", type=str, help="experiment name") 14 | parser.add_argument( 15 | "--basedir", type=str, default="./logs/", help="where to store ckpts and logs" 16 | ) 17 | parser.add_argument( 18 | "--datadir", type=str, default="./data/llff/fern", help="input data directory" 19 | ) 20 | 21 | # training options 22 | parser.add_argument("--netdepth", type=int, default=8, help="layers in network") 23 | parser.add_argument("--netwidth", type=int, default=256, help="channels per layer") 24 | parser.add_argument( 25 | "--netdepth_fine", type=int, default=8, help="layers in fine network" 26 | ) 27 | parser.add_argument( 28 | "--netwidth_fine", 29 | type=int, 30 | default=256, 31 | help="channels per layer in fine network", 32 | ) 33 | parser.add_argument( 34 | "--N_rand", 35 | type=int, 36 | default=32 * 32 * 4, 37 | help="batch size (number of random rays per gradient step)", 38 | ) 39 | parser.add_argument("--lr", type=float, default=5e-4, help="learning rate") 40 | parser.add_argument( 41 | "--lrate_decay", 42 | type=int, 43 | default=250, 44 | help="exponential learning rate decay (in 1000 steps)", 45 | ) 46 | parser.add_argument( 47 | "--chunk", 48 | type=int, 49 | default=1024 * 32, 50 | help="number of rays processed in parallel, decrease if running out of memory", 51 | ) 52 | parser.add_argument( 53 | "--netchunk", 54 | type=int, 55 | default=1024 * 64, 56 | help="number of pts sent through network in parallel, decrease if running out of memory", 57 | ) 58 | parser.add_argument( 59 | "--no_batching", 60 | action="store_true", 61 | help="only take random rays from 1 image at a time", 62 | ) 63 | parser.add_argument( 64 | "--no_reload", action="store_true", help="do not reload weights from saved ckpt" 65 | ) 66 | parser.add_argument( 67 | "--ft_path", 68 | type=str, 69 | default=None, 70 | help="specific weights npy file to reload for coarse network", 71 | ) 72 | 73 | # rendering options 74 | parser.add_argument( 75 | "--N_samples", type=int, default=64, help="number of coarse samples per ray" 76 | ) 77 | parser.add_argument( 78 | "--N_importance", 79 | type=int, 80 | default=0, 81 | help="number of additional fine samples per ray", 82 | ) 83 | parser.add_argument( 84 | "--perturb", 85 | type=float, 86 | default=1.0, 87 | help="set to 0. for no jitter, 1. 
for jitter", 88 | ) 89 | parser.add_argument( 90 | "--use_viewdirs", action="store_true", help="use full 5D input instead of 3D" 91 | ) 92 | parser.add_argument( 93 | "--i_embed", 94 | type=int, 95 | default=0, 96 | help="set 0 for default positional encoding, -1 for none", 97 | ) 98 | parser.add_argument( 99 | "--multires", 100 | type=int, 101 | default=10, 102 | help="log2 of max freq for positional encoding (3D location)", 103 | ) 104 | parser.add_argument( 105 | "--multires_views", 106 | type=int, 107 | default=4, 108 | help="log2 of max freq for positional encoding (2D direction)", 109 | ) 110 | parser.add_argument( 111 | "--raw_noise_std", 112 | type=float, 113 | default=0.0, 114 | help="std dev of noise added to regularize sigma_a output, 1e0 recommended", 115 | ) 116 | 117 | parser.add_argument( 118 | "--render_only", 119 | action="store_true", 120 | help="do not optimize, reload weights and render out render_poses path", 121 | ) 122 | parser.add_argument( 123 | "--render_test", 124 | action="store_true", 125 | help="render the test set instead of render_poses path", 126 | ) 127 | parser.add_argument( 128 | "--render_factor", 129 | type=int, 130 | default=0, 131 | help="downsampling factor to speed up rendering, set 4 or 8 for fast preview", 132 | ) 133 | 134 | # training options 135 | parser.add_argument( 136 | "--precrop_iters", 137 | type=int, 138 | default=0, 139 | help="number of steps to train on central crops", 140 | ) 141 | parser.add_argument( 142 | "--precrop_frac", 143 | type=float, 144 | default=0.5, 145 | help="fraction of img taken for central crops", 146 | ) 147 | 148 | # dataset options 149 | parser.add_argument( 150 | "--dataset_type", 151 | type=str, 152 | default="llff", 153 | help="options: llff / blender / deepvoxels", 154 | ) 155 | parser.add_argument( 156 | "--testskip", 157 | type=int, 158 | default=8, 159 | help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels", 160 | ) 161 | 162 | # deepvoxels flags 163 | parser.add_argument( 164 | "--shape", 165 | type=str, 166 | default="greek", 167 | help="options : armchair / cube / greek / vase", 168 | ) 169 | 170 | # blender flags 171 | parser.add_argument( 172 | "--white_bkgd", 173 | action="store_true", 174 | help="set to render synthetic data on a white bkgd (always use for dvoxels)", 175 | ) 176 | parser.add_argument( 177 | "--half_res", 178 | action="store_true", 179 | help="load blender synthetic data at 400x400 instead of 800x800", 180 | ) 181 | 182 | # llff flags 183 | parser.add_argument( 184 | "--factor", type=int, default=8, help="downsample factor for LLFF images" 185 | ) 186 | parser.add_argument( 187 | "--no_ndc", 188 | action="store_true", 189 | help="do not use normalized device coordinates (set for non-forward facing scenes)", 190 | ) 191 | parser.add_argument( 192 | "--lindisp", 193 | action="store_true", 194 | help="sampling linearly in disparity rather than depth", 195 | ) 196 | parser.add_argument( 197 | "--spherify", action="store_true", help="set for spherical 360 scenes" 198 | ) 199 | parser.add_argument( 200 | "--llffhold", 201 | type=int, 202 | default=8, 203 | help="will take every 1/N images as LLFF test set, paper uses 8", 204 | ) 205 | 206 | # logging/saving options 207 | parser.add_argument( 208 | "--log_interval", 209 | type=int, 210 | default=50, 211 | help="frequency of information logging", 212 | ) 213 | parser.add_argument( 214 | "--i_img", type=int, default=500, help="frequency of tensorboard image logging" 215 | ) 216 | parser.add_argument( 217 
| "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving" 218 | ) 219 | parser.add_argument( 220 | "--i_testset", type=int, default=50000, help="frequency of testset saving" 221 | ) 222 | parser.add_argument( 223 | "--i_video", 224 | type=int, 225 | default=50000, 226 | help="frequency of render_poses video saving", 227 | ) 228 | 229 | return parser 230 | 231 | 232 | @logger.catch 233 | def main(): 234 | np.random.seed(0) 235 | args = config_parser().parse_args() 236 | trainer = Trainer(args) 237 | trainer.train() 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 100 3 | multi_line_output = 3 4 | balanced_wrapping = True 5 | known_third_party = loguru,tabulate,yaml,cv2,numpy,PIL,pycocotools 6 | known_deeplearning = megengine,torch,torchvision 7 | known_myself = nerf 8 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,deeplearning,myself,LOCALFOLDER 9 | no_lines_before=STDLIB,THIRDPARTY 10 | default_section = FIRSTPARTY 11 | 12 | [flake8] 13 | max-line-length = 100 14 | per-file-ignores = 15 | **/__init__.py:F401,F403 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import setuptools 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | 8 | with open("nerf/__init__.py", "r") as f: 9 | version = re.search( 10 | r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', 11 | f.read(), re.MULTILINE 12 | ).group(1) 13 | 14 | 15 | with open("requirements.txt", "r") as f: 16 | reqs = [x.strip() for x in f.readlines()] 17 | 18 | 19 | setuptools.setup( 20 | name="nerf", 21 | version=version, 22 | author="FateScript", 23 | author_email="wangfeng02@megvii.com", 24 | description="Nerf implemented by megengine", 25 | long_description=long_description, 26 | long_description_content_type="text/markdown", 27 | url=None, 28 | packages=setuptools.find_packages(), 29 | classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"], 30 | install_requires=reqs, 31 | ) 32 | --------------------------------------------------------------------------------