├── .gitignore ├── LICENSE ├── NeRF - 源码.pptx ├── README.md ├── README_origin.md ├── configs ├── chair.txt ├── drums.txt ├── fern.txt ├── ficus.txt ├── flower.txt ├── fortress.txt ├── horns.txt ├── hotdog.txt ├── leaves.txt ├── lego-use_batching.txt ├── lego.txt ├── materials.txt ├── mic.txt ├── orchids.txt ├── room.txt ├── ship.txt └── trex.txt ├── inference.py ├── load_LINEMOD.py ├── load_blender.py ├── load_deepvoxels.py ├── load_llff.py ├── nerf_helpers.py ├── nerf_model.py ├── opts.py ├── render.py ├── requirements.txt ├── run_nerf.py ├── test └── frame.py ├── 模型.txt ├── 源码结构.md └── 说明.md /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ 3 | *.png 4 | *.mp4 5 | *.npy 6 | *.npz 7 | *.dae 8 | data/* 9 | logs/* 10 | 11 | .idea 12 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 bmild 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NeRF - 源码.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xunull/read-nerf-pytorch/5940fc27e0ea82674859c61996e178ff71bddd1e/NeRF - 源码.pptx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # read-nerf-pytorch 2 | 3 | 4 | 原始仓库: https://github.com/yenchenlin/nerf-pytorch 5 | nerf团队的tf实现: https://github.com/bmild/nerf 6 | 7 | 8 | -------------------------------------------------------------------------------- /README_origin.md: -------------------------------------------------------------------------------- 1 | # NeRF-pytorch 2 | 3 | 4 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. 
Here are some videos generated by this repository (pre-trained models are provided below): 5 | 6 | ![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif) 7 | ![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif) 8 | 9 | This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is based on authors' Tensorflow implementation [here](https://github.com/bmild/nerf), and has been tested to match it numerically. 10 | 11 | ## Installation 12 | 13 | ``` 14 | git clone https://github.com/yenchenlin/nerf-pytorch.git 15 | cd nerf-pytorch 16 | pip install -r requirements.txt 17 | ``` 18 | 19 |
20 | 21 | 22 | ## Dependencies 23 | - PyTorch 1.4 24 | - matplotlib 25 | - numpy 26 | - imageio 27 | - imageio-ffmpeg 28 | - configargparse 29 | 30 | The LLFF data loader requires ImageMagick. 31 | 32 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data. 33 | 34 |
35 | 36 | ## How To Run? 37 | 38 | ### Quick Start 39 | 40 | Download data for two example datasets: `lego` and `fern` 41 | ``` 42 | bash download_example_data.sh 43 | ``` 44 | 45 | To train a low-res `lego` NeRF: 46 | ``` 47 | python run_nerf.py --config configs/lego.txt 48 | ``` 49 | After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`. 50 | 51 | ![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif) 52 | 53 | --- 54 | 55 | To train a low-res `fern` NeRF: 56 | ``` 57 | python run_nerf.py --config configs/fern.txt 58 | ``` 59 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4` 60 | 61 | ![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif) 62 | 63 | --- 64 | 65 | ### More Datasets 66 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure: 67 | ``` 68 | ├── configs 69 | │   ├── ... 70 | │   71 | ├── data 72 | │   ├── nerf_llff_data 73 | │   │   └── fern 74 | │   │  └── flower # downloaded llff dataset 75 | │   │  └── horns # downloaded llff dataset 76 | | | └── ... 77 | | ├── nerf_synthetic 78 | | | └── lego 79 | | | └── ship # downloaded synthetic dataset 80 | | | └── ... 81 | ``` 82 | 83 | --- 84 | 85 | To train NeRF on different datasets: 86 | 87 | ``` 88 | python run_nerf.py --config configs/{DATASET}.txt 89 | ``` 90 | 91 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc. 92 | 93 | --- 94 | 95 | To test NeRF trained on different datasets: 96 | 97 | ``` 98 | python run_nerf.py --config configs/{DATASET}.txt --render_only 99 | ``` 100 | 101 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc. 102 | 103 | 104 | ### Pre-trained Models 105 | 106 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example: 107 | 108 | ``` 109 | ├── logs 110 | │   ├── fern_test 111 | │   ├── flower_test # downloaded logs 112 | │ ├── trex_test # downloaded logs 113 | ``` 114 | 115 | ### Reproducibility 116 | 117 | Tests that ensure the results of all functions and training loop match the official implentation are contained in a different branch `reproduce`. One can check it out and run the tests: 118 | ``` 119 | git checkout reproduce 120 | py.test 121 | ``` 122 | 123 | ## Method 124 | 125 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf) 126 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*1, 127 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*1, 128 | [Matthew Tancik](http://tancik.com/)\*1, 129 | [Jonathan T. Barron](http://jonbarron.info/)2, 130 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)3, 131 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)1
132 | 1UC Berkeley, 2Google Research, 3UC San Diego 133 | \*denotes equal contribution 134 | 135 | 136 | 137 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views 138 | 139 | 140 | ## Citation 141 | Kudos to the authors for their amazing results: 142 | ``` 143 | @misc{mildenhall2020nerf, 144 | title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis}, 145 | author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng}, 146 | year={2020}, 147 | eprint={2003.08934}, 148 | archivePrefix={arXiv}, 149 | primaryClass={cs.CV} 150 | } 151 | ``` 152 | 153 | However, if you find this implementation or pre-trained models helpful, please consider to cite: 154 | ``` 155 | @misc{lin2020nerfpytorch, 156 | title={NeRF-pytorch}, 157 | author={Yen-Chen, Lin}, 158 | publisher = {GitHub}, 159 | journal = {GitHub repository}, 160 | howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}}, 161 | year={2020} 162 | } 163 | ``` 164 | -------------------------------------------------------------------------------- /configs/chair.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_chair 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/chair 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/drums.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_drums 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/drums 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/fern.txt: -------------------------------------------------------------------------------- 1 | expname = fern_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fern 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/ficus.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ficus 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/ficus 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- 
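The config files in this directory are plain `key = value` text consumed by configargparse in `opts.py` (shown later in this dump); every key corresponds to a command-line flag of `run_nerf.py`. As a rough illustration only (a trimmed parser, not the full option set defined in `opts.py`), this is the pattern by which a file such as `configs/lego.txt` is read:

```python
# Minimal sketch of how these .txt configs are parsed, mirroring opts.py.
# Only a handful of the options are declared here for brevity.
import configargparse

parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, help='config file path')
parser.add_argument('--expname', type=str)
parser.add_argument('--basedir', type=str, default='./logs/')
parser.add_argument('--datadir', type=str, default='./data/llff/fern')
parser.add_argument('--dataset_type', type=str, default='llff')
parser.add_argument('--N_rand', type=int, default=32 * 32 * 4)   # rays per gradient step
parser.add_argument('--N_samples', type=int, default=64)         # coarse samples per ray
parser.add_argument('--N_importance', type=int, default=0)       # extra fine samples per ray
parser.add_argument('--no_batching', action='store_true')
parser.add_argument('--half_res', action='store_true')

# Each `key = value` line in the config file overrides the matching default.
args = parser.parse_args(['--config', 'configs/lego.txt'])
print(args.expname, args.N_samples, args.N_importance, args.half_res)
```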
/configs/flower.txt: -------------------------------------------------------------------------------- 1 | expname = flower_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/flower 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/fortress.txt: -------------------------------------------------------------------------------- 1 | expname = fortress_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fortress 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/horns.txt: -------------------------------------------------------------------------------- 1 | expname = horns_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/horns 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/hotdog.txt: -------------------------------------------------------------------------------- 1 | expname = hotdog_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/hotdog 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/leaves.txt: -------------------------------------------------------------------------------- 1 | expname = leaves_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/leaves 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/lego-use_batching.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_lego 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/lego 4 | dataset_type = blender 5 | 6 | 7 | use_viewdirs = True 8 | white_bkgd = True 9 | lrate_decay = 500 10 | 11 | N_samples = 64 12 | N_importance = 128 13 | N_rand = 1024 14 | 15 | precrop_iters = 500 16 | precrop_frac = 0.5 17 | 18 | half_res = True 19 | -------------------------------------------------------------------------------- /configs/lego.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_lego 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/lego 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/materials.txt: -------------------------------------------------------------------------------- 1 | expname = materials_test 2 
| basedir = ./logs 3 | datadir = ./data/nerf_llff_data/materials 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/mic.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_mic 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/mic 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/orchids.txt: -------------------------------------------------------------------------------- 1 | expname = orchids_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/orchids 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/room.txt: -------------------------------------------------------------------------------- 1 | expname = room_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/room 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /configs/ship.txt: -------------------------------------------------------------------------------- 1 | expname = blender_paper_ship 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/ship 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | 8 | use_viewdirs = True 9 | white_bkgd = True 10 | lrate_decay = 500 11 | 12 | N_samples = 64 13 | N_importance = 128 14 | N_rand = 1024 15 | 16 | precrop_iters = 500 17 | precrop_frac = 0.5 18 | 19 | half_res = True 20 | -------------------------------------------------------------------------------- /configs/trex.txt: -------------------------------------------------------------------------------- 1 | expname = trex_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/trex 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import time 4 | from tqdm import tqdm 5 | from render import * 6 | from nerf_helpers import to8b 7 | import numpy as np 8 | 9 | 10 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 11 | H, W, focal = hwf 12 | 13 | if render_factor != 0: 14 | # Render downsampled for speed 15 | H = H // render_factor 16 | W = W // render_factor 17 | focal = focal / render_factor 18 | 19 | rgbs = [] 20 | disps = [] 21 | 22 | t = time.time() 23 | for i, c2w in enumerate(tqdm(render_poses)): 24 | print(i, time.time() - t) 25 | t = time.time() 26 | rgb, 
disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs) 27 | rgbs.append(rgb.cpu().numpy()) 28 | disps.append(disp.cpu().numpy()) 29 | if i == 0: 30 | print(rgb.shape, disp.shape) 31 | 32 | """ 33 | if gt_imgs is not None and render_factor==0: 34 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 35 | print(p) 36 | """ 37 | 38 | if savedir is not None: 39 | rgb8 = to8b(rgbs[-1]) 40 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 41 | imageio.imwrite(filename, rgb8) 42 | 43 | rgbs = np.stack(rgbs, 0) 44 | disps = np.stack(disps, 0) 45 | 46 | return rgbs, disps 47 | -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | trans_t = lambda t: torch.Tensor([ 10 | [1, 0, 0, 0], 11 | [0, 1, 0, 0], 12 | [0, 0, 1, t], 13 | [0, 0, 0, 1]]).float() 14 | 15 | rot_phi = lambda phi: torch.Tensor([ 16 | [1, 0, 0, 0], 17 | [0, np.cos(phi), -np.sin(phi), 0], 18 | [0, np.sin(phi), np.cos(phi), 0], 19 | [0, 0, 0, 1]]).float() 20 | 21 | rot_theta = lambda th: torch.Tensor([ 22 | [np.cos(th), 0, -np.sin(th), 0], 23 | [0, 1, 0, 0], 24 | [np.sin(th), 0, np.cos(th), 0], 25 | [0, 0, 0, 1]]).float() 26 | 27 | 28 | def pose_spherical(theta, phi, radius): 29 | c2w = trans_t(radius) 30 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 31 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 32 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 33 | return c2w 34 | 35 | 36 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 37 | splits = ['train', 'val', 'test'] 38 | metas = {} 39 | for s in splits: 40 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 41 | metas[s] = json.load(fp) 42 | 43 | all_imgs = [] 44 | all_poses = [] 45 | counts = [0] 46 | for s in splits: 47 | meta = metas[s] 48 | imgs = [] 49 | poses = [] 50 | if s == 'train' or testskip == 0: 51 | skip = 1 52 | else: 53 | skip = testskip 54 | 55 | for idx_test, frame in enumerate(meta['frames'][::skip]): 56 | fname = frame['file_path'] 57 | if s == 'test': 58 | print(f"{idx_test}th test frame: {fname}") 59 | imgs.append(imageio.imread(fname)) 60 | poses.append(np.array(frame['transform_matrix'])) 61 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 62 | poses = np.array(poses).astype(np.float32) 63 | counts.append(counts[-1] + imgs.shape[0]) 64 | all_imgs.append(imgs) 65 | all_poses.append(poses) 66 | 67 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 68 | 69 | imgs = np.concatenate(all_imgs, 0) 70 | poses = np.concatenate(all_poses, 0) 71 | 72 | H, W = imgs[0].shape[:2] 73 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 74 | K = meta['frames'][0]['intrinsic_matrix'] 75 | print(f"Focal: {focal}") 76 | 77 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0) 78 | 79 | if half_res: 80 | H = H // 2 81 | W = W // 2 82 | focal = focal / 2. 
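# Halving the resolution halves the focal length in pixels so the projection stays
# consistent with the resized images; note that only H, W and focal are rescaled here,
# while the intrinsic matrix K loaded above is returned as-is.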
83 | 84 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 85 | for i, img in enumerate(imgs): 86 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 87 | imgs = imgs_half_res 88 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 89 | 90 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 91 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 92 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 93 | -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import cv2 7 | 8 | # 平移 9 | trans_t = lambda t: torch.Tensor([ 10 | [1, 0, 0, 0], 11 | [0, 1, 0, 0], 12 | [0, 0, 1, t], 13 | [0, 0, 0, 1]]).float() 14 | 15 | # 绕x轴的旋转 16 | rot_phi = lambda phi: torch.Tensor([ 17 | [1, 0, 0, 0], 18 | [0, np.cos(phi), -np.sin(phi), 0], 19 | [0, np.sin(phi), np.cos(phi), 0], 20 | [0, 0, 0, 1]]).float() 21 | 22 | # 绕y轴的旋转 23 | rot_theta = lambda th: torch.Tensor([ 24 | [np.cos(th), 0, -np.sin(th), 0], 25 | [0, 1, 0, 0], 26 | [np.sin(th), 0, np.cos(th), 0], 27 | [0, 0, 0, 1]]).float() 28 | 29 | 30 | def pose_spherical(theta, phi, radius): 31 | """ 32 | theta: -180 -- +180,间隔为9 33 | phi: 固定值 -30 34 | radius: 固定值 4 35 | """ 36 | c2w = trans_t(radius) 37 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 38 | c2w = rot_theta(theta / 180. * np.pi) @ c2w 39 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 40 | return c2w 41 | 42 | 43 | def load_blender_data(basedir, half_res=False, testskip=1): 44 | """ 45 | testskip: test和val数据集,只会读取其中的一部分数据,跳着读取 46 | """ 47 | splits = ['train', 'val', 'test'] 48 | # 存储了三个json文件的数据 49 | metas = {} 50 | for s in splits: 51 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 52 | metas[s] = json.load(fp) 53 | 54 | all_imgs = [] 55 | all_poses = [] 56 | counts = [0] 57 | for s in splits: 58 | meta = metas[s] 59 | imgs = [] 60 | poses = [] 61 | if s == 'train' or testskip == 0: 62 | skip = 1 63 | else: 64 | # 测试集如果数量很多,可能会设置testskip 65 | skip = testskip 66 | # 读取所有的图片,以及所有对应的transform_matrix 67 | for frame in meta['frames'][::skip]: 68 | fname = os.path.join(basedir, frame['file_path'] + '.png') 69 | imgs.append(imageio.imread(fname)) 70 | poses.append(np.array(frame['transform_matrix'])) 71 | # 归一化 72 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA),4通道 rgba 73 | poses = np.array(poses).astype(np.float32) 74 | # 用于计算train val test的递增值 75 | counts.append(counts[-1] + imgs.shape[0]) 76 | all_imgs.append(imgs) 77 | all_poses.append(poses) 78 | # train val test 三个list 79 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 80 | # train test val 拼一起 81 | imgs = np.concatenate(all_imgs, 0) 82 | poses = np.concatenate(all_poses, 0) 83 | 84 | H, W = imgs[0].shape[:2] 85 | # meta使用了上面的局部变量,train test val 这个变量值是相同的,文件中这三个值确实是相同的 86 | camera_angle_x = float(meta['camera_angle_x']) 87 | # 焦距 88 | focal = .5 * W / np.tan(.5 * camera_angle_x) 89 | 90 | # np.linspace(-180, 180, 40 + 1) 9度一个间隔 91 | # (40,4,4), 渲染的结果就是40帧 92 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0) 93 | 94 | if half_res: 95 | H = H // 2 96 | W = W // 2 97 | # 焦距一半 98 | focal = focal / 2. 
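# The focal length in pixels scales linearly with the image width
# (focal = .5 * W / np.tan(.5 * camera_angle_x) above), so halving W and H
# also halves focal, keeping the intrinsics consistent with the resized images.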
99 | 100 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 101 | for i, img in enumerate(imgs): 102 | # 调整成一半的大小 103 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 104 | imgs = imgs_half_res 105 | 106 | return imgs, poses, render_poses, [H, W, focal], i_split 107 | -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 8 | # Get camera intrinsics 9 | with open(filepath, 'r') as file: 10 | f, cx, cy = list(map(float, file.readline().split()))[:3] 11 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 12 | near_plane = float(file.readline()) 13 | scale = float(file.readline()) 14 | height, width = map(float, file.readline().split()) 15 | 16 | try: 17 | world2cam_poses = int(file.readline()) 18 | except ValueError: 19 | world2cam_poses = None 20 | 21 | if world2cam_poses is None: 22 | world2cam_poses = False 23 | 24 | world2cam_poses = bool(world2cam_poses) 25 | 26 | print(cx, cy, f, height, width) 27 | 28 | cx = cx / width * trgt_sidelength 29 | cy = cy / height * trgt_sidelength 30 | f = trgt_sidelength / height * f 31 | 32 | fx = f 33 | if invert_y: 34 | fy = -f 35 | else: 36 | fy = f 37 | 38 | # Build the intrinsic matrices 39 | full_intrinsic = np.array([[fx, 0., cx, 0.], 40 | [0., fy, cy, 0], 41 | [0., 0, 1, 0], 42 | [0, 0, 0, 1]]) 43 | 44 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 45 | 46 | def load_pose(filename): 47 | assert os.path.isfile(filename) 48 | nums = open(filename).read().split() 49 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32) 50 | 51 | H = 512 52 | W = 512 53 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 54 | 55 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics( 56 | os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 57 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 58 | focal = full_intrinsic[0, 0] 59 | print(H, W, focal) 60 | 61 | def dir2poses(posedir): 62 | poses = np.stack( 63 | [load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 64 | transf = np.array([ 65 | [1, 0, 0, 0], 66 | [0, -1, 0, 0], 67 | [0, 0, -1, 0], 68 | [0, 0, 0, 1.], 69 | ]) 70 | poses = poses @ transf 71 | poses = poses[:, :3, :4].astype(np.float32) 72 | return poses 73 | 74 | posedir = os.path.join(deepvoxels_base, 'pose') 75 | poses = dir2poses(posedir) 76 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 77 | testposes = testposes[::testskip] 78 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 79 | valposes = valposes[::testskip] 80 | 81 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 82 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f)) / 255. for f in imgfiles], 0).astype( 83 | np.float32) 84 | 85 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 86 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 87 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f)) / 255. 
for f in imgfiles[::testskip]], 0).astype( 88 | np.float32) 89 | 90 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 91 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 92 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f)) / 255. for f in imgfiles[::testskip]], 0).astype( 93 | np.float32) 94 | 95 | all_imgs = [imgs, valimgs, testimgs] 96 | counts = [0] + [x.shape[0] for x in all_imgs] 97 | counts = np.cumsum(counts) 98 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] 99 | 100 | imgs = np.concatenate(all_imgs, 0) 101 | poses = np.concatenate([poses, valposes, testposes], 0) 102 | 103 | render_poses = testposes 104 | 105 | print(poses.shape, imgs.shape) 106 | 107 | return imgs, poses, render_poses, [H, W, focal], i_split 108 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | # images_4 11 | # images_8 目录 12 | for r in factors: 13 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 14 | if not os.path.exists(imgdir): 15 | # 不存在对应的目录 16 | needtoload = True 17 | for r in resolutions: 18 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 19 | if not os.path.exists(imgdir): 20 | # 不存在对应的目录 21 | needtoload = True 22 | 23 | # 如果存在那些目录, 那么这里就返回 24 | if not needtoload: 25 | return 26 | 27 | from shutil import copy 28 | from subprocess import check_output 29 | 30 | imgdir = os.path.join(basedir, 'images') 31 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 32 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 33 | imgdir_orig = imgdir 34 | 35 | wd = os.getcwd() 36 | 37 | for r in factors + resolutions: 38 | if isinstance(r, int): 39 | name = 'images_{}'.format(r) 40 | resizearg = '{}%'.format(100. 
/ r) 41 | else: 42 | name = 'images_{}x{}'.format(r[1], r[0]) 43 | resizearg = '{}x{}'.format(r[1], r[0]) 44 | imgdir = os.path.join(basedir, name) 45 | if os.path.exists(imgdir): 46 | continue 47 | 48 | print('Minifying', r, basedir) 49 | 50 | os.makedirs(imgdir) 51 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 52 | 53 | ext = imgs[0].split('.')[-1] 54 | # 执行一个shell命令 55 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 56 | print(args) 57 | os.chdir(imgdir) 58 | check_output(args, shell=True) 59 | os.chdir(wd) 60 | 61 | if ext != 'png': 62 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 63 | print('Removed duplicates') 64 | print('Done') 65 | 66 | 67 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 68 | # 加载 poses_bounds.npy 69 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 70 | # like [3,5,24] 71 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 72 | 73 | bds = poses_arr[:, -2:].transpose([1, 0]) 74 | 75 | # images目录 76 | # 最原始的image目录 77 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 78 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 79 | sh = imageio.imread(img0).shape 80 | 81 | sfx = '' 82 | 83 | if factor is not None: 84 | sfx = '_{}'.format(factor) 85 | _minify(basedir, factors=[factor]) 86 | factor = factor 87 | elif height is not None: 88 | factor = sh[0] / float(height) 89 | width = int(sh[1] / factor) 90 | _minify(basedir, resolutions=[[height, width]]) 91 | sfx = '_{}x{}'.format(width, height) 92 | elif width is not None: 93 | factor = sh[1] / float(width) 94 | height = int(sh[0] / factor) 95 | _minify(basedir, resolutions=[[height, width]]) 96 | sfx = '_{}x{}'.format(width, height) 97 | else: 98 | factor = 1 99 | 100 | # 这个就是 _4, _8 那些目录 101 | imgdir = os.path.join(basedir, 'images' + sfx) 102 | if not os.path.exists(imgdir): 103 | print(imgdir, 'does not exist, returning') 104 | return 105 | 106 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if 107 | f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 108 | 109 | # 判断数量是否相同 110 | if poses.shape[-1] != len(imgfiles): 111 | print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1])) 112 | return 113 | 114 | sh = imageio.imread(imgfiles[0]).shape 115 | 116 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 117 | poses[2, 4, :] = poses[2, 4, :] * 1. / factor 118 | 119 | if not load_imgs: 120 | return poses, bds 121 | 122 | def imread(f): 123 | if f.endswith('png'): 124 | return imageio.imread(f, ignoregamma=True) 125 | else: 126 | return imageio.imread(f) 127 | 128 | imgs = [imread(f)[..., :3] / 255. 
for f in imgfiles] 129 | imgs = np.stack(imgs, -1) 130 | 131 | print('Loaded image data', imgs.shape, poses[:, -1, 0]) 132 | return poses, bds, imgs 133 | 134 | 135 | def normalize(x): 136 | return x / np.linalg.norm(x) 137 | 138 | 139 | def viewmatrix(z, up, pos): 140 | vec2 = normalize(z) 141 | vec1_avg = up 142 | vec0 = normalize(np.cross(vec1_avg, vec2)) 143 | vec1 = normalize(np.cross(vec2, vec0)) 144 | m = np.stack([vec0, vec1, vec2, pos], 1) 145 | return m 146 | 147 | 148 | def ptstocam(pts, c2w): 149 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 150 | return tt 151 | 152 | 153 | def poses_avg(poses): 154 | hwf = poses[0, :3, -1:] 155 | 156 | center = poses[:, :3, 3].mean(0) 157 | vec2 = normalize(poses[:, :3, 2].sum(0)) 158 | up = poses[:, :3, 1].sum(0) 159 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 160 | 161 | return c2w 162 | 163 | 164 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 165 | render_poses = [] 166 | rads = np.array(list(rads) + [1.]) 167 | hwf = c2w[:, 4:5] 168 | 169 | for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]: 170 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 171 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 172 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 173 | return render_poses 174 | 175 | 176 | def recenter_poses(poses): 177 | poses_ = poses + 0 178 | bottom = np.reshape([0, 0, 0, 1.], [1, 4]) 179 | c2w = poses_avg(poses) 180 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 181 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 182 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 183 | 184 | poses = np.linalg.inv(c2w) @ poses 185 | poses_[:, :3, :4] = poses[:, :3, :4] 186 | poses = poses_ 187 | return poses 188 | 189 | 190 | ##################### 191 | 192 | 193 | def spherify_poses(poses, bds): 194 | p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1) 195 | 196 | rays_d = poses[:, :3, 2:3] 197 | rays_o = poses[:, :3, 3:4] 198 | 199 | def min_line_dist(rays_o, rays_d): 200 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 201 | b_i = -A_i @ rays_o 202 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)) 203 | return pt_mindist 204 | 205 | pt_mindist = min_line_dist(rays_o, rays_d) 206 | 207 | center = pt_mindist 208 | up = (poses[:, :3, 3] - center).mean(0) 209 | 210 | vec0 = normalize(up) 211 | vec1 = normalize(np.cross([.1, .2, .3], vec0)) 212 | vec2 = normalize(np.cross(vec0, vec1)) 213 | pos = center 214 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 215 | 216 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 217 | 218 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 219 | 220 | sc = 1. / rad 221 | poses_reset[:, :3, 3] *= sc 222 | bds *= sc 223 | rad *= sc 224 | 225 | centroid = np.mean(poses_reset[:, :3, 3], 0) 226 | zh = centroid[2] 227 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 228 | new_poses = [] 229 | 230 | for th in np.linspace(0., 2. 
* np.pi, 120): 231 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 232 | up = np.array([0, 0, -1.]) 233 | 234 | vec2 = normalize(camorigin) 235 | vec0 = normalize(np.cross(vec2, up)) 236 | vec1 = normalize(np.cross(vec2, vec0)) 237 | pos = camorigin 238 | p = np.stack([vec0, vec1, vec2, pos], 1) 239 | 240 | new_poses.append(p) 241 | 242 | new_poses = np.stack(new_poses, 0) 243 | 244 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1) 245 | poses_reset = np.concatenate( 246 | [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1) 247 | 248 | return poses_reset, new_poses, bds 249 | 250 | 251 | # 1. 调用_load_data 252 | # 1. _minify 253 | # 这里可能会调用check_output对图片进行缩放 254 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 255 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 256 | print('Loaded', basedir, bds.min(), bds.max()) 257 | 258 | # Correct rotation matrix ordering and move variable dim to axis 0 259 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 260 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 261 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 262 | images = imgs 263 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 264 | 265 | # Rescale if bd_factor is provided 266 | sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor) 267 | poses[:, :3, 3] *= sc 268 | bds *= sc 269 | 270 | if recenter: 271 | poses = recenter_poses(poses) 272 | 273 | if spherify: 274 | poses, render_poses, bds = spherify_poses(poses, bds) 275 | 276 | else: 277 | 278 | c2w = poses_avg(poses) 279 | print('recentered', c2w.shape) 280 | print(c2w[:3, :4]) 281 | 282 | ## Get spiral 283 | # Get average pose 284 | up = normalize(poses[:, :3, 1].sum(0)) 285 | 286 | # Find a reasonable "focus depth" for this dataset 287 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5. 288 | dt = .75 289 | mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth)) 290 | focal = mean_dz 291 | 292 | # Get radii for spiral path 293 | shrink_factor = .8 294 | zdelta = close_depth * .2 295 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 296 | rads = np.percentile(np.abs(tt), 90, 0) 297 | c2w_path = c2w 298 | N_views = 120 299 | N_rots = 2 300 | if path_zflat: 301 | # zloc = np.percentile(tt, 10, 0)[2] 302 | zloc = -close_depth * .1 303 | c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2] 304 | rads[2] = 0. 
305 | N_rots = 1 306 | N_views /= 2 307 | 308 | # Generate poses for spiral path 309 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 310 | 311 | render_poses = np.array(render_poses).astype(np.float32) 312 | 313 | c2w = poses_avg(poses) 314 | print('Data:') 315 | print(poses.shape, images.shape, bds.shape) 316 | 317 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1) 318 | i_test = np.argmin(dists) 319 | print('HOLDOUT view is', i_test) 320 | 321 | images = images.astype(np.float32) 322 | poses = poses.astype(np.float32) 323 | 324 | return images, poses, bds, render_poses, i_test 325 | -------------------------------------------------------------------------------- /nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | __all__ = ['img2mse', 'mse2psnr', 'to8b', 'get_embedder', 'get_rays', 'get_rays_np', 'ndc_rays', 'sample_pdf'] 6 | 7 | # Misc 8 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 9 | 10 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 11 | 12 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 13 | 14 | 15 | # Positional encoding (section 5.1) 16 | class Embedder: 17 | def __init__(self, **kwargs): 18 | self.kwargs = kwargs 19 | self.create_embedding_fn() 20 | 21 | def create_embedding_fn(self): 22 | embed_fns = [] 23 | d = self.kwargs['input_dims'] # 3 24 | out_dim = 0 25 | if self.kwargs['include_input']: 26 | embed_fns.append(lambda x: x) 27 | out_dim += d 28 | 29 | max_freq = self.kwargs['max_freq_log2'] 30 | N_freqs = self.kwargs['num_freqs'] 31 | 32 | if self.kwargs['log_sampling']: 33 | # tensor([ 1., 2., 4., 8., 16., 32., 64., 128., 256., 512.]) 34 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 35 | else: 36 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs) 37 | 38 | for freq in freq_bands: 39 | for p_fn in self.kwargs['periodic_fns']: 40 | # sin(x),sin(2x),sin(4x),sin(8x),sin(16x),sin(32x),sin(64x),sin(128x),sin(256x),sin(512x) 41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 42 | out_dim += d 43 | 44 | self.embed_fns = embed_fns 45 | 46 | # 3D坐标是63,2D方向是27 47 | self.out_dim = out_dim 48 | 49 | def embed(self, inputs): 50 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 51 | 52 | 53 | # 位置编码相关 54 | def get_embedder(multires, i=0): 55 | """ 56 | multires: 3D 坐标是10,2D方向是4 57 | """ 58 | if i == -1: 59 | return nn.Identity(), 3 60 | 61 | embed_kwargs = { 62 | 'include_input': True, 63 | 'input_dims': 3, 64 | 'max_freq_log2': multires - 1, 65 | 'num_freqs': multires, 66 | 'log_sampling': True, 67 | 'periodic_fns': [torch.sin, torch.cos], 68 | } 69 | 70 | embedder_obj = Embedder(**embed_kwargs) 71 | embed = lambda x, eo=embedder_obj: eo.embed(x) 72 | # 第一个返回值是lamda,给定x,返回其位置编码 73 | return embed, embedder_obj.out_dim 74 | 75 | 76 | # ---------------------------------------------------------------------------------------------------------------------- 77 | 78 | # Ray helpers 79 | def get_rays(H, W, K, c2w): 80 | """ 81 | K:相机内参矩阵 82 | c2w: 相机到世界坐标系的转换 83 | """ 84 | # j 85 | # [0,......] 86 | # [1,......] 87 | # [W-1,....] 
88 | # i 89 | # [0,..,H-1] 90 | # [0,..,H-1] 91 | # [0,..,H-1] 92 | 93 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H), indexing='ij') 94 | i = i.t() 95 | j = j.t() 96 | # [400,400,3] 97 | dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1) 98 | # Rotate ray directions from camera frame to the world frame 99 | # dirs [400,400,3] -> [400,400,1,3] 100 | # dot product, equals to: [c2w.dot(dir) for dir in dirs] 101 | # rays_d [400,400,3] 102 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) 103 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 104 | # 前三行,最后一列,定义了相机的平移,因此可以得到射线的原点o 105 | rays_o = c2w[:3, -1].expand(rays_d.shape) 106 | return rays_o, rays_d 107 | 108 | 109 | def get_rays_np(H, W, K, c2w): 110 | # 与上面的方法相似,这个是使用的numpy,上面是使用的torch 111 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 112 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1) 113 | # Rotate ray directions from camera frame to the world frame 114 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 115 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 116 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 117 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) 118 | return rays_o, rays_d 119 | 120 | 121 | # Hierarchical sampling (section 5.2) 122 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 123 | """ 124 | bins: z_vals_mid 125 | """ 126 | 127 | # Get pdf 128 | weights = weights + 1e-5 # prevent nans 129 | # 归一化 [bs, 62] 130 | # 概率密度函数 131 | pdf = weights / torch.sum(weights, -1, keepdim=True) 132 | # 累积分布函数 133 | cdf = torch.cumsum(pdf, -1) 134 | # 在第一个位置补0 135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 136 | 137 | # Take uniform samples 138 | if det: 139 | u = torch.linspace(0., 1., steps=N_samples) 140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 141 | else: 142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) # [bs,128] 143 | 144 | # Pytest, overwrite u with numpy's fixed random numbers 145 | if pytest: 146 | np.random.seed(0) 147 | new_shape = list(cdf.shape[:-1]) + [N_samples] 148 | if det: 149 | u = np.linspace(0., 1., N_samples) 150 | u = np.broadcast_to(u, new_shape) 151 | else: 152 | u = np.random.rand(*new_shape) 153 | u = torch.Tensor(u) 154 | 155 | # Invert CDF 156 | 157 | u = u.contiguous() 158 | # u 是随机生成的 159 | # 找到对应的插入的位置 160 | inds = torch.searchsorted(cdf, u, right=True) 161 | # 前一个位置,为了防止inds中的0的前一个是-1,这里就还是0 162 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 163 | # 最大的位置就是cdf的上限位置,防止过头,跟上面的意义相同 164 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 165 | # (batch, N_samples, 2) 166 | inds_g = torch.stack([below, above], -1) 167 | 168 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 169 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 170 | # (batch, N_samples, 63) 171 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 172 | # 如[1024,128,63] 提取 根据 inds_g[i][j][0] inds_g[i][j][1] 173 | # cdf_g [1024,128,2] 174 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 175 | # 如上, bins 是从2到6的采样点,是64个点的中间值 176 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 177 | # 差值 178 | denom = (cdf_g[..., 1] - cdf_g[..., 
0]) 179 | # 防止过小 180 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 181 | 182 | t = (u - cdf_g[..., 0]) / denom 183 | 184 | # lower+线性插值 185 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 186 | 187 | return samples 188 | 189 | 190 | # ---------------------------------------------------------------------------------------------------------------------- 191 | 192 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 193 | # Shift ray origins to near plane 194 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 195 | rays_o = rays_o + t[..., None] * rays_d 196 | 197 | # Projection 198 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 199 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 200 | o2 = 1. + 2. * near / rays_o[..., 2] 201 | 202 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 203 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 204 | d2 = -2. * near / rays_o[..., 2] 205 | 206 | rays_o = torch.stack([o0, o1, o2], -1) 207 | rays_d = torch.stack([d0, d1, d2], -1) 208 | 209 | return rays_o, rays_d 210 | -------------------------------------------------------------------------------- /nerf_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class NeRF(nn.Module): 9 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 10 | """ 11 | D: 深度,多少层网络 12 | W: 网络内的channel 宽度 13 | input_ch: xyz的宽度 14 | input_ch_views: direction的宽度 15 | output_ch: 这个参数尽在 use_viewdirs=False的时候会被使用 16 | skips: 类似resnet的残差连接,表明在第几层进行连接 17 | use_viewdirs: 18 | """ 19 | super(NeRF, self).__init__() 20 | self.D = D 21 | self.W = W 22 | self.input_ch = input_ch 23 | self.input_ch_views = input_ch_views 24 | self.skips = skips 25 | self.use_viewdirs = use_viewdirs 26 | 27 | # 神经网络,MLP 28 | # 3D的空间坐标进入的网络 29 | # 这个跳跃连接层是直接拼接,不是resnet的那种相加 30 | self.pts_linears = nn.ModuleList( 31 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in 32 | range(D - 1)]) 33 | 34 | # 这里channel削减一半 128 35 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 36 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)]) 37 | 38 | if use_viewdirs: 39 | # 特征 40 | self.feature_linear = nn.Linear(W, W) 41 | # 透明度,一个值 42 | self.alpha_linear = nn.Linear(W, 1) 43 | # rgb颜色,三个值 44 | self.rgb_linear = nn.Linear(W // 2, 3) 45 | else: 46 | self.output_linear = nn.Linear(W, output_ch) 47 | 48 | def forward(self, x): 49 | # x [bs*64, 90] 50 | # input_pts [bs*64, 63] 51 | # input_views [bs*64,27] 52 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 53 | 54 | h = input_pts 55 | 56 | for i, l in enumerate(self.pts_linears): 57 | 58 | h = self.pts_linears[i](h) 59 | h = F.relu(h) 60 | # 第四层后相加 61 | if i in self.skips: 62 | h = torch.cat([input_pts, h], -1) 63 | 64 | if self.use_viewdirs: 65 | # alpha只与xyz有关 66 | alpha = self.alpha_linear(h) 67 | feature = self.feature_linear(h) 68 | # rgb与xyz和d都有关 69 | h = torch.cat([feature, input_views], -1) 70 | 71 | for i, l in enumerate(self.views_linears): 72 | h = self.views_linears[i](h) 73 | h = F.relu(h) 74 | 75 
| rgb = self.rgb_linear(h) 76 | outputs = torch.cat([rgb, alpha], -1) 77 | else: 78 | outputs = self.output_linear(h) 79 | 80 | return outputs 81 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | def config_parser(): 2 | import configargparse 3 | parser = configargparse.ArgumentParser() 4 | 5 | parser.add_argument('--config', is_config_file=True, 6 | help='config file path') 7 | # 本次实验的名称,作为log中文件夹的名字 8 | parser.add_argument("--expname", type=str, 9 | help='experiment name') 10 | # 输出目录 11 | parser.add_argument("--basedir", type=str, default='./logs/', 12 | help='where to store ckpts and logs') 13 | # 指定数据集的目录 14 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 15 | help='input data directory') 16 | 17 | # training options 18 | # 全连接的层数 19 | parser.add_argument("--netdepth", type=int, default=8, 20 | help='layers in network') 21 | # 网络宽度 22 | parser.add_argument("--netwidth", type=int, default=256, 23 | help='channels per layer') 24 | 25 | # 精细网络的全连接层数 26 | # 默认精细网络的深度和宽度与粗糙网络是相同的 27 | parser.add_argument("--netdepth_fine", type=int, default=8, 28 | help='layers in fine network') 29 | parser.add_argument("--netwidth_fine", type=int, default=256, 30 | help='channels per layer in fine network') 31 | 32 | # 这里的batch size,指的是光线的数量,像素点的数量 33 | # N_rand 配置文件中是1024 34 | # 32*32*4=4096 35 | # 800*800/4096=156 400*400/1024=156 36 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 4, 37 | help='batch size (number of random rays per gradient step)') 38 | # 学习率 39 | parser.add_argument("--lrate", type=float, default=5e-4, 40 | help='learning rate') 41 | # 学习率衰减 42 | parser.add_argument("--lrate_decay", type=int, default=250, 43 | help='exponential learning rate decay (in 1000 steps)') 44 | 45 | parser.add_argument("--chunk", type=int, default=1024 * 32, 46 | help='number of rays processed in parallel, decrease if running out of memory') 47 | 48 | # 网络中处理的点的数量 49 | parser.add_argument("--netchunk", type=int, default=1024 * 64, 50 | help='number of pts sent through network in parallel, decrease if running out of memory') 51 | 52 | # 合成的数据集一般都是True, 每次只从一张图片中选取随机光线 53 | # 真实的数据集一般都是False, 图形先混在一起 54 | parser.add_argument("--no_batching", action='store_true', 55 | help='only take random rays from 1 image at a time') 56 | 57 | # 不加载权重 58 | parser.add_argument("--no_reload", action='store_true', 59 | help='do not reload weights from saved ckpt') 60 | # 粗网络的权重文件的位置 61 | parser.add_argument("--ft_path", type=str, default=None, 62 | help='specific weights npy file to reload for coarse network') 63 | 64 | # rendering options 65 | parser.add_argument("--N_samples", type=int, default=64, 66 | help='number of coarse samples per ray') 67 | parser.add_argument("--N_importance", type=int, default=0, 68 | help='number of additional fine samples per ray') 69 | 70 | parser.add_argument("--perturb", type=float, default=1., 71 | help='set to 0. for no jitter, 1. 
for jitter') 72 | # 不适用视角数据 73 | parser.add_argument("--use_viewdirs", action='store_true', 74 | help='use full 5D input instead of 3D') 75 | # 0 使用位置编码,-1 不使用位置编码 76 | parser.add_argument("--i_embed", type=int, default=0, 77 | help='set 0 for default positional encoding, -1 for none') 78 | 79 | # L=10 80 | parser.add_argument("--multires", type=int, default=10, 81 | help='log2 of max freq for positional encoding (3D location)') 82 | # L=4 83 | parser.add_argument("--multires_views", type=int, default=4, 84 | help='log2 of max freq for positional encoding (2D direction)') 85 | 86 | parser.add_argument("--raw_noise_std", type=float, default=0., 87 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 88 | 89 | # 仅进行渲染 90 | parser.add_argument("--render_only", action='store_true', 91 | help='do not optimize, reload weights and render out render_poses path') 92 | # 渲染test数据集 93 | parser.add_argument("--render_test", action='store_true', 94 | help='render the test set instead of render_poses path') 95 | # 下采样的倍数 96 | parser.add_argument("--render_factor", type=int, default=0, 97 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 98 | 99 | # training options 100 | # 中心裁剪的训练轮数 101 | parser.add_argument("--precrop_iters", type=int, default=0, 102 | help='number of steps to train on central crops') 103 | parser.add_argument("--precrop_frac", type=float, 104 | default=.5, help='fraction of img taken for central crops') 105 | 106 | # dataset options 107 | # 数据格式 108 | parser.add_argument("--dataset_type", type=str, default='llff', 109 | help='options: llff / blender / deepvoxels') 110 | 111 | # 对于大的数据集,test和val数据集,只使用其中的一部分数据 112 | parser.add_argument("--testskip", type=int, default=8, 113 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 114 | 115 | ## deepvoxels flags 116 | parser.add_argument("--shape", type=str, default='greek', 117 | help='options : armchair / cube / greek / vase') 118 | 119 | ## blender flags 120 | # 白色背景 121 | parser.add_argument("--white_bkgd", action='store_true', 122 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 123 | 124 | # 使用一半分辨率 125 | parser.add_argument("--half_res", action='store_true', 126 | help='load blender synthetic data at 400x400 instead of 800x800') 127 | 128 | ## llff flags 129 | parser.add_argument("--factor", type=int, default=8, 130 | help='downsample factor for LLFF images') 131 | parser.add_argument("--no_ndc", action='store_true', 132 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 133 | parser.add_argument("--lindisp", action='store_true', 134 | help='sampling linearly in disparity rather than depth') 135 | parser.add_argument("--spherify", action='store_true', 136 | help='set for spherical 360 scenes') 137 | parser.add_argument("--llffhold", type=int, default=8, 138 | help='will take every 1/N images as LLFF test set, paper uses 8') 139 | 140 | # logging/saving options 141 | # log输出的频率 142 | parser.add_argument("--i_print", type=int, default=100, 143 | help='frequency of console printout and metric loggin') 144 | parser.add_argument("--i_img", type=int, default=500, 145 | help='frequency of tensorboard image logging') 146 | # 保存模型的频率 147 | # 每隔1w保存一个 148 | parser.add_argument("--i_weights", type=int, default=10000, 149 | help='frequency of weight ckpt saving') 150 | # 执行测试集渲染的频率 151 | parser.add_argument("--i_testset", type=int, default=50000, 152 | help='frequency of testset 
saving') 153 | # 执行渲染视频的频率 154 | parser.add_argument("--i_video", type=int, default=50000, 155 | help='frequency of render_poses video saving') 156 | 157 | return parser 158 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | from nerf_helpers import sample_pdf, get_rays, ndc_rays 6 | 7 | DEBUG = False 8 | 9 | __all__ = ['render', 'batchify_rays', 'render_rays', 'raw2outputs'] 10 | 11 | 12 | def render(H, W, K, 13 | chunk=1024 * 32, rays=None, c2w=None, ndc=True, 14 | near=0., far=1., 15 | use_viewdirs=False, c2w_staticcam=None, 16 | **kwargs): 17 | """Render rays 18 | Args: 19 | H: int. Height of image in pixels. 20 | W: int. Width of image in pixels. 21 | K: 相机内参 focal 22 | chunk: int. Maximum number of rays to process simultaneously. Used to 23 | control maximum memory usage. Does not affect final results. 24 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 25 | each example in batch. 26 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 27 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 28 | near: float or array of shape [batch_size]. Nearest distance for a ray. 29 | far: float or array of shape [batch_size]. Farthest distance for a ray. 30 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 31 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 32 | camera while using other c2w argument for viewing directions. 33 | Returns: 34 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 35 | disp_map: [batch_size]. Disparity map. Inverse of depth. 36 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 37 | extras: dict with everything returned by render_rays(). 
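    Example (rough sketch only, not code from this repo; it assumes a `render_kwargs`
    dict produced by create_nerf() in run_nerf.py, as used in inference.py, which
    already holds near/far and the coarse/fine networks):
        rays_o, rays_d = get_rays(H, W, K, c2w)                    # [H, W, 3] each
        rays = torch.stack([rays_o.reshape(-1, 3),
                            rays_d.reshape(-1, 3)], 0)             # [2, H*W, 3]
        rgb, disp, acc, extras = render(H, W, K, chunk=1024 * 32,
                                        rays=rays, **render_kwargs)
        # Alternatively, pass c2w=c2w[:3, :4] (and no rays) to render a full image,
        # as done in inference.py.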
38 | """ 39 | 40 | if c2w is not None: 41 | # special case to render full image 42 | rays_o, rays_d = get_rays(H, W, K, c2w) 43 | else: 44 | # use provided ray batch 45 | # 光线的起始位置, 方向 46 | rays_o, rays_d = rays 47 | 48 | if use_viewdirs: 49 | # provide ray directions as input 50 | viewdirs = rays_d 51 | # 静态相机 相机坐标到世界坐标的转换 52 | if c2w_staticcam is not None: 53 | # special case to visualize effect of viewdirs 54 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 55 | # 单位向量 [bs,3] 56 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 57 | viewdirs = torch.reshape(viewdirs, [-1, 3]).float() 58 | 59 | sh = rays_d.shape # [..., 3] 60 | 61 | if ndc: 62 | # for forward facing scenes 63 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 64 | 65 | # Create ray batch 66 | rays_o = torch.reshape(rays_o, [-1, 3]).float() 67 | rays_d = torch.reshape(rays_d, [-1, 3]).float() 68 | # [bs,1],[bs,1] 69 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) 70 | # 8=3+3+1+1 71 | rays = torch.cat([rays_o, rays_d, near, far], -1) 72 | if use_viewdirs: 73 | # 加了direction的三个坐标 74 | # 3 3 1 1 3 75 | rays = torch.cat([rays, viewdirs], -1) # [bs,11] 76 | 77 | # Render and reshape 78 | 79 | # rgb_map,disp_map,acc_map,raw,rbg0,disp0,acc0,z_std 80 | all_ret = batchify_rays(rays, chunk, **kwargs) 81 | for k in all_ret: 82 | # 对所有的返回值进行reshape 83 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 84 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 85 | 86 | # 讲精细网络的输出单独拿了出来 87 | k_extract = ['rgb_map', 'disp_map', 'acc_map'] 88 | ret_list = [all_ret[k] for k in k_extract] 89 | ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract} 90 | # 前三是list,后5还是在map中 91 | return ret_list + [ret_dict] 92 | 93 | 94 | def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs): 95 | """ 96 | Render rays in smaller minibatches to avoid OOM. 97 | rays_flat: [N_rand,11] 98 | """ 99 | 100 | all_ret = {} 101 | for i in range(0, rays_flat.shape[0], chunk): 102 | ret = render_rays(rays_flat[i:i + chunk], **kwargs) 103 | for k in ret: 104 | if k not in all_ret: 105 | all_ret[k] = [] 106 | all_ret[k].append(ret[k]) 107 | # 将分批处理的结果拼接在一起 108 | all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} 109 | return all_ret 110 | 111 | 112 | # 这里面会经过神经网络 113 | def render_rays(ray_batch, 114 | network_fn, 115 | network_query_fn, 116 | N_samples, 117 | retraw=False, 118 | lindisp=False, 119 | perturb=0., 120 | N_importance=0, 121 | network_fine=None, 122 | white_bkgd=False, 123 | raw_noise_std=0., 124 | verbose=False, 125 | pytest=False): 126 | """Volumetric rendering. 127 | Args: 128 | ray_batch: array of shape [batch_size, ...]. All information necessary 129 | for sampling along a ray, including: ray origin, ray direction, min 130 | dist, max dist, and unit-magnitude viewing direction. 单位大小查看方向 131 | 粗网络 132 | network_fn: function. Model for predicting RGB and density at each point 133 | in space. 134 | network_query_fn: function used for passing queries to network_fn. 135 | N_samples: int. Number of different times to sample along each ray. 136 | 137 | raw 是指神经网络的输出 138 | retraw: bool. If True, include model's raw, unprocessed predictions. 139 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 140 | 141 | 142 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified random points in time. 143 | 144 | 精细网络中的光线上的采样频率 145 | N_importance: int. Number of additional times to sample along each ray. 
146 | These samples are only passed to network_fine. 147 | 精细网络 148 | network_fine: "fine" network with same spec as network_fn. 149 | white_bkgd: bool. If True, assume a white background. 白色背景 150 | raw_noise_std: ... 151 | 152 | 153 | verbose: bool. If True, print more debugging info. 154 | Returns: 155 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 156 | disp_map: [num_rays]. Disparity map. 1 / depth. 157 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. 158 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 159 | rgb0: See rgb_map. Output for coarse model. 160 | disp0: See disp_map. Output for coarse model. 161 | acc0: See acc_map. Output for coarse model. 162 | z_std: [num_rays]. Standard deviation of distances along ray for each 163 | sample. 164 | """ 165 | N_rays = ray_batch.shape[0] # N_rand 166 | # 光线起始位置,光线的方向 167 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each 168 | # 视角的单位向量 169 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None 170 | bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) # [bs,1,2] near和far 171 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1] 172 | # 采样点 173 | t_vals = torch.linspace(0., 1., steps=N_samples) 174 | if not lindisp: 175 | z_vals = near * (1. - t_vals) + far * (t_vals) # 插值采样 176 | else: 177 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals)) 178 | # [N_rand,64] -> [N_rand,64] 179 | z_vals = z_vals.expand([N_rays, N_samples]) 180 | 181 | if perturb > 0.: 182 | # get intervals between samples,64个采样点的中点 183 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 184 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 185 | lower = torch.cat([z_vals[..., :1], mids], -1) 186 | # stratified samples in those intervals 187 | t_rand = torch.rand(z_vals.shape) 188 | 189 | # Pytest, overwrite u with numpy's fixed random numbers 190 | if pytest: 191 | np.random.seed(0) 192 | t_rand = np.random.rand(*list(z_vals.shape)) 193 | t_rand = torch.Tensor(t_rand) 194 | # [bs,64] 加上随机的噪声 195 | z_vals = lower + (upper - lower) * t_rand 196 | 197 | # 空间中的采样点 198 | # [N_rand, 64, 3] 199 | # 出发点+距离*方向 200 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3] 201 | 202 | # 使用神经网络 viewdirs [N_rand,3], network_fn 指的是粗糙NeRF或者精细NeRF 203 | # raw [bs,64,3] 204 | raw = network_query_fn(pts, viewdirs, network_fn) 205 | 206 | # rgb值,xx,权重的和,weights就是论文中的那个Ti和alpha的乘积 207 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, 208 | pytest=pytest) 209 | 210 | # 精细网络部分 211 | if N_importance > 0: 212 | # _0 是第一个阶段 粗糙网络的结果 213 | # 这三个留着放在dict中输出用 214 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 215 | # 第二次计算mid,取中点位置 216 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 217 | z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest) 218 | z_samples = z_samples.detach() 219 | 220 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 221 | # 给精细网络使用的点 222 | # [N_rays, N_samples + N_importance, 3] 223 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] 224 | 225 | run_fn = network_fn if network_fine is None else network_fine 226 | 227 | # 使用神经网络 228 | # create_nerf 中的 network_query_fn 那个lambda 函数 229 | # viewdirs 与粗糙网络是相同的 230 | raw = network_query_fn(pts, viewdirs, run_fn) 231 | 232 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, 
z_vals, rays_d, raw_noise_std, white_bkgd, 233 | pytest=pytest) 234 | 235 | ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map} 236 | 237 | if retraw: 238 | # 如果是两个网络,那么这个raw就是最后精细网络的输出 239 | ret['raw'] = raw 240 | 241 | if N_importance > 0: 242 | # 下面的0是粗糙网络的输出 243 | ret['rgb0'] = rgb_map_0 244 | ret['disp0'] = disp_map_0 245 | ret['acc0'] = acc_map_0 246 | 247 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 248 | 249 | # 检查是否有异常值 250 | for k in ret: 251 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 252 | print(f"! [Numerical Error] {k} contains nan or inf.") 253 | 254 | return ret 255 | 256 | 257 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 258 | """Transforms model's predictions to semantically meaningful values. 259 | Args: 260 | Model的输出 261 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 262 | 采样,并未变成空间点的那个采样点 263 | z_vals: [num_rays, num_samples along ray]. Integration time. 264 | 光线的方向 265 | rays_d: [num_rays, 3]. Direction of each ray. 266 | Returns: 267 | RGB颜色值 268 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 269 | 逆深度 270 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 271 | 权重和? 272 | acc_map: [num_rays]. Sum of weights along each ray. 273 | 权重 274 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 275 | 估计的深度 276 | depth_map: [num_rays]. Estimated distance to object. 277 | """ 278 | # Alpha的计算 279 | # relu, 负数拉平为0 280 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists) 281 | # [bs,63] 282 | # 采样点之间的距离 283 | dists = z_vals[..., 1:] - z_vals[..., :-1] 284 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 285 | # rays_d[...,None,:] [bs,3] -> [bs,1,3] 286 | # 1维 -> 3维 287 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 288 | # RGB经过sigmoid处理 289 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 290 | noise = 0. 291 | if raw_noise_std > 0.: 292 | noise = torch.randn(raw[..., 3].shape) * raw_noise_std 293 | 294 | # Overwrite randomly sampled data if pytest 295 | if pytest: 296 | np.random.seed(0) 297 | noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std 298 | noise = torch.Tensor(noise) 299 | # 计算公式3 [bs, 64], 300 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] 301 | 302 | # 后面这部分就是Ti,前面是alpha,这个就是论文上的那个权重w [bs,64] 303 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 304 | 1. - alpha + 1e-10], -1), 305 | -1)[:, :-1] 306 | # [bs, 64,1] * [bs,64,3] 307 | # 在第二个维度,64将所有的点的值相加 -> [32,3] 308 | # 公式3的结果值 309 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 310 | # (32,) 311 | # 深度图 312 | # Estimated depth map is expected distance. 313 | depth_map = torch.sum(weights * z_vals, -1) 314 | # 视差图 315 | # Disparity map is inverse depth. 316 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 317 | 318 | # 权重和 319 | # 这个值仅做了输出用,后续并无使用 320 | acc_map = torch.sum(weights, -1) 321 | 322 | if white_bkgd: 323 | rgb_map = rgb_map + (1. 
- acc_map[..., None]) 324 | 325 | return rgb_map, disp_map, acc_map, weights, depth_map 326 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio 2 | imageio-ffmpeg 3 | matplotlib 4 | configargparse 5 | tensorboard>=2.0 6 | tqdm 7 | opencv-python 8 | -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import time 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm, trange 7 | from nerf_helpers import * 8 | from nerf_model import NeRF 9 | from load_llff import load_llff_data 10 | from load_deepvoxels import load_dv_data 11 | from load_blender import load_blender_data 12 | from load_LINEMOD import load_LINEMOD_data 13 | from render import * 14 | from inference import render_path 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | np.random.seed(0) 18 | DEBUG = False 19 | 20 | 21 | def batchify(fn, chunk): 22 | """ 23 | Constructs a version of 'fn' that applies to smaller batches. 24 | """ 25 | if chunk is None: 26 | return fn 27 | 28 | def ret(inputs): 29 | # 以chunk分批进入网络,防止显存爆掉,然后在拼接 30 | return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 31 | 32 | return ret 33 | 34 | 35 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64): 36 | """ 37 | 被下面的create_nerf 封装到了lambda方法里面 38 | Prepares inputs and applies network 'fn'. 39 | inputs: pts,光线上的点 如 [1024,64,3],1024条光线,一条光线上64个点 40 | viewdirs: 光线起点的方向 41 | fn: 神经网络模型 粗糙网络或者精细网络 42 | embed_fn: 43 | embeddirs_fn: 44 | netchunk: 45 | """ 46 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) # [N_rand*64,3] 47 | # 坐标点进行编码嵌入 [N_rand*64,63] 48 | embedded = embed_fn(inputs_flat) 49 | 50 | if viewdirs is not None: 51 | input_dirs = viewdirs[:, None].expand(inputs.shape) 52 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 53 | # 方向进行位置编码 54 | embedded_dirs = embeddirs_fn(input_dirs_flat) # [N_rand*64,27] 55 | embedded = torch.cat([embedded, embedded_dirs], -1) 56 | 57 | # 里面经过网络 [bs*64,3] 58 | outputs_flat = batchify(fn, netchunk)(embedded) 59 | # [bs*4,4] -> [bs,64,4] 60 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 61 | return outputs 62 | 63 | 64 | def create_nerf(args): 65 | """Instantiate NeRF's MLP model. 
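    Returns (as assembled in the body below):
        render_kwargs_train / render_kwargs_test: dicts of keyword arguments for render()
            (network_query_fn, coarse/fine models, sampling counts, perturb / noise settings).
        start: global step restored from the latest checkpoint, or 0 when training from scratch.
        grad_vars: list of trainable parameters (coarse model, plus fine model if N_importance > 0).
        optimizer: Adam optimizer built over grad_vars.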
66 | """ 67 | 68 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 69 | 70 | input_ch_views = 0 71 | embeddirs_fn = None 72 | if args.use_viewdirs: 73 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 74 | 75 | # 想要=5生效,首先需要use_viewdirs=False and N_importance>0 76 | output_ch = 5 if args.N_importance > 0 else 4 77 | skips = [4] 78 | # 粗网络 79 | model = NeRF(D=args.netdepth, W=args.netwidth, 80 | input_ch=input_ch, output_ch=output_ch, skips=skips, 81 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 82 | grad_vars = list(model.parameters()) 83 | 84 | model_fine = None 85 | if args.N_importance > 0: 86 | # 精细网络 87 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 88 | input_ch=input_ch, output_ch=output_ch, skips=skips, 89 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 90 | # 模型参数 91 | grad_vars += list(model_fine.parameters()) 92 | 93 | # netchunk 是网络中处理的点的batch_size 94 | network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn, 95 | embed_fn=embed_fn, 96 | embeddirs_fn=embeddirs_fn, 97 | netchunk=args.netchunk) 98 | 99 | # Create optimizer 100 | # 优化器 101 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 102 | 103 | start = 0 104 | basedir = args.basedir 105 | expname = args.expname 106 | 107 | ########################## 108 | 109 | # Load checkpoints 110 | if args.ft_path is not None and args.ft_path != 'None': 111 | ckpts = [args.ft_path] 112 | else: 113 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 114 | 'tar' in f] 115 | 116 | print('Found ckpts', ckpts) 117 | 118 | # load参数 119 | if len(ckpts) > 0 and not args.no_reload: 120 | ckpt_path = ckpts[-1] 121 | print('Reloading from', ckpt_path) 122 | ckpt = torch.load(ckpt_path) 123 | 124 | start = ckpt['global_step'] 125 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 126 | 127 | # Load model 128 | model.load_state_dict(ckpt['network_fn_state_dict']) 129 | if model_fine is not None: 130 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 131 | 132 | ########################## 133 | 134 | render_kwargs_train = { 135 | 'network_query_fn': network_query_fn, 136 | 'perturb': args.perturb, 137 | 'N_importance': args.N_importance, 138 | # 精细网络 139 | 'network_fine': model_fine, 140 | 'N_samples': args.N_samples, 141 | # 粗网络 142 | 'network_fn': model, 143 | 'use_viewdirs': args.use_viewdirs, 144 | 'white_bkgd': args.white_bkgd, 145 | 'raw_noise_std': args.raw_noise_std, 146 | } 147 | 148 | print(model_fine) 149 | 150 | # NDC only good for LLFF-style forward facing data 151 | if args.dataset_type != 'llff' or args.no_ndc: 152 | print('Not ndc!') 153 | render_kwargs_train['ndc'] = False 154 | render_kwargs_train['lindisp'] = args.lindisp 155 | 156 | render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train} 157 | render_kwargs_test['perturb'] = False 158 | render_kwargs_test['raw_noise_std'] = 0. 
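    # A minimal sketch, for reference only, of the positional encoding applied by
    # embed_fn / embeddirs_fn above (the actual implementation is get_embedder / Embedder
    # in nerf_helpers.py). Including the identity term, multires=10 gives
    # 3 + 3*2*10 = 63 input channels for xyz, and multires_views=4 gives
    # 3 + 3*2*4 = 27 channels for the viewing direction, matching the 63/27 widths
    # noted in run_network and 模型.txt:
    #
    #   def positional_encoding(x, L):                 # x: [..., 3]
    #       feats = [x]                                # identity term
    #       for k in range(L):
    #           feats += [torch.sin(2.0 ** k * x), torch.cos(2.0 ** k * x)]
    #       return torch.cat(feats, dim=-1)            # [..., 3 + 3*2*L]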
159 | 160 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 161 | 162 | 163 | # ---------------------------------------------------------------------------------------------------------------------- 164 | 165 | def create_log_files(basedir, expname, args): 166 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 167 | 168 | # 保存一份参数文件 169 | f = os.path.join(basedir, expname, 'args.txt') 170 | with open(f, 'w') as file: 171 | for arg in sorted(vars(args)): 172 | attr = getattr(args, arg) 173 | file.write('{} = {}\n'.format(arg, attr)) 174 | 175 | # 保存一份配置文件 176 | if args.config is not None: 177 | f = os.path.join(basedir, expname, 'config.txt') 178 | with open(f, 'w') as file: 179 | file.write(open(args.config, 'r').read()) 180 | 181 | return basedir, expname 182 | 183 | 184 | def run_render_only(args, images, i_test, basedir, expname, render_poses, hwf, K, render_kwargs_test, start): 185 | with torch.no_grad(): 186 | if args.render_test: 187 | # render_test switches to test poses 188 | images = images[i_test] 189 | else: 190 | # Default is smoother render_poses path 191 | images = None 192 | 193 | testsavedir = os.path.join(basedir, expname, 194 | 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 195 | os.makedirs(testsavedir, exist_ok=True) 196 | print('test poses shape', render_poses.shape) 197 | 198 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, 199 | savedir=testsavedir, render_factor=args.render_factor) 200 | print('Done rendering', testsavedir) 201 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 202 | 203 | 204 | def train(): 205 | # 解析参数 206 | from opts import config_parser 207 | parser = config_parser() 208 | args = parser.parse_args() 209 | 210 | # -------------------------------------------------------------------------------------------------------- 211 | 212 | # Load data 213 | 214 | # 在这个数据集会特殊些 LINEMOD 215 | K = None 216 | 217 | # 一共有四种类型的数据集 218 | # 是configs目录中 只有llff和blender两种类型 219 | # 原始的nerf仓库中有deepvoxels类型的数据 220 | # LINEMOD 没见过 221 | 222 | # llff Local Light Field Fusion 223 | if args.dataset_type == 'llff': 224 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 225 | recenter=True, bd_factor=.75, 226 | spherify=args.spherify) 227 | hwf = poses[0, :3, -1] 228 | poses = poses[:, :3, :4] 229 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 230 | if not isinstance(i_test, list): 231 | i_test = [i_test] 232 | 233 | if args.llffhold > 0: 234 | print('Auto LLFF holdout,', args.llffhold) 235 | i_test = np.arange(images.shape[0])[::args.llffhold] 236 | 237 | i_val = i_test 238 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 239 | (i not in i_test and i not in i_val)]) 240 | 241 | print('DEFINING BOUNDS') 242 | if args.no_ndc: 243 | near = np.ndarray.min(bds) * .9 244 | far = np.ndarray.max(bds) * 1. 245 | 246 | else: 247 | near = 0. 248 | far = 1. 249 | print('NEAR FAR', near, far) 250 | 251 | elif args.dataset_type == 'blender': 252 | # images,所有的图片,train val test在一起,poses也一样 253 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) 254 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 255 | i_train, i_val, i_test = i_split 256 | 257 | near = 2. 258 | far = 6. 
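        # Note on the block below: the blender PNGs are RGBA, so compositing onto a white
        # background is plain alpha blending, rgb_white = rgb * a + 1.0 * (1 - a);
        # fully transparent pixels become white. Without --white_bkgd the alpha channel
        # is simply dropped.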
259 | 260 | if args.white_bkgd: 261 | # todo 这个是什么操作,为什么白色背景要这样操作 262 | images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) 263 | else: 264 | images = images[..., :3] 265 | 266 | elif args.dataset_type == 'LINEMOD': 267 | # 这个数据类型 原始的nerf中没有 268 | 269 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, 270 | args.testskip) 271 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') 272 | print(f'[CHECK HERE] near: {near}, far: {far}.') 273 | i_train, i_val, i_test = i_split 274 | 275 | if args.white_bkgd: 276 | images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:]) 277 | else: 278 | images = images[..., :3] 279 | 280 | elif args.dataset_type == 'deepvoxels': 281 | 282 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 283 | basedir=args.datadir, 284 | testskip=args.testskip) 285 | 286 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) 287 | i_train, i_val, i_test = i_split 288 | 289 | hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1)) 290 | near = hemi_R - 1. 291 | far = hemi_R + 1. 292 | 293 | else: 294 | print('Unknown dataset type', args.dataset_type, 'exiting') 295 | return 296 | 297 | # Cast intrinsics to right types 298 | H, W, focal = hwf 299 | H, W = int(H), int(W) 300 | hwf = [H, W, focal] 301 | 302 | # K 相机内参 focal 是焦距,0.5w 0.5h 是中心点坐标 303 | # 这个矩阵是相机坐标到图像坐标转换使用 304 | if K is None: 305 | K = np.array([ 306 | [focal, 0, 0.5 * W], 307 | [0, focal, 0.5 * H], 308 | [0, 0, 1] 309 | ]) 310 | 311 | # -------------------------------------------------------------------------------------------------------- 312 | 313 | # render the test set instead of render_poses path 314 | # 使用测试集的pose,而不是用那个固定生成的render_poses 315 | if args.render_test: 316 | render_poses = np.array(poses[i_test]) 317 | 318 | # Move testing data to GPU 319 | render_poses = torch.Tensor(render_poses).to(device) 320 | 321 | # -------------------------------------------------------------------------------------------------------- 322 | 323 | # Create log dir and copy the config file 324 | 325 | basedir = args.basedir 326 | expname = args.expname 327 | 328 | create_log_files(basedir, expname, args) 329 | 330 | # -------------------------------------------------------------------------------------------------------- 331 | 332 | # Create nerf model 333 | # 创建模型 334 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 335 | # 有可能从中间迭代恢复运行的 336 | global_step = start 337 | 338 | bds_dict = { 339 | 'near': near, 340 | 'far': far, 341 | } 342 | render_kwargs_train.update(bds_dict) 343 | render_kwargs_test.update(bds_dict) 344 | 345 | # -------------------------------------------------------------------------------------------------------- 346 | 347 | # Short circuit if only rendering out from trained model 348 | # 这里会使用render_poses 349 | if args.render_only: 350 | # 仅进行渲染,不进行训练 351 | print('RENDER ONLY') 352 | run_render_only(args, images, i_test, basedir, expname, render_poses, hwf, K, render_kwargs_test, start) 353 | return 354 | 355 | # -------------------------------------------------------------------------------------------------------- 356 | 357 | # Prepare ray batch tensor if batching random rays 358 | N_rand = args.N_rand 359 | 360 | use_batching = not args.no_batching 361 | 362 | if use_batching: 363 | # For random ray batching 364 | print('get rays') # (img_count,2,400,400,3) 2是 rays_o和rays_d 365 | rays = 
np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3] 366 | print('done, concats') # rays和图像混在一起 367 | rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3] 368 | rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3] 369 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only, 仅使用训练文件夹下的数据 370 | rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3] 371 | rays_rgb = rays_rgb.astype(np.float32) 372 | print('shuffle rays') 373 | np.random.shuffle(rays_rgb) # 打乱光线 374 | 375 | print('done') 376 | i_batch = 0 377 | 378 | # 统一一个时刻放入cuda 379 | # Move training data to GPU 380 | if use_batching: 381 | images = torch.Tensor(images).to(device) 382 | rays_rgb = torch.Tensor(rays_rgb).to(device) 383 | 384 | poses = torch.Tensor(poses).to(device) 385 | 386 | # -------------------------------------------------------------------------------------------------------- 387 | 388 | print('Begin') 389 | print('TRAIN views are', i_train) 390 | print('TEST views are', i_test) 391 | print('VAL views are', i_val) 392 | 393 | # 训练部分的代码 394 | # 两万次迭代 395 | # 可能是强迫症,不想在保存文件的时候,出现19999这种数字 396 | N_iters = 200000 + 1 397 | start = start + 1 398 | for i in trange(start, N_iters): 399 | time0 = time.time() 400 | 401 | # Sample random ray batch 402 | if use_batching: 403 | # Random over all images 404 | # 一批光线 405 | batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?] 406 | 407 | batch = torch.transpose(batch, 0, 1) 408 | batch_rays, target_s = batch[:2], batch[2] # 前两个是rays_o和rays_d, 第三个是target就是image的rgb 409 | 410 | i_batch += N_rand 411 | if i_batch >= rays_rgb.shape[0]: 412 | # 所用光线用过之后,重新打乱 413 | print("Shuffle data after an epoch!") 414 | rand_idx = torch.randperm(rays_rgb.shape[0]) 415 | rays_rgb = rays_rgb[rand_idx] 416 | i_batch = 0 417 | 418 | else: 419 | # Random from one image 420 | img_i = np.random.choice(i_train) 421 | target = images[img_i] # [400,400,3] 图像内容 422 | target = torch.Tensor(target).to(device) 423 | pose = poses[img_i, :3, :4] 424 | 425 | if N_rand is not None: 426 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) 427 | 428 | # precrop_iters: number of steps to train on central crops 429 | if i < args.precrop_iters: 430 | dH = int(H // 2 * args.precrop_frac) 431 | dW = int(W // 2 * args.precrop_frac) 432 | coords = torch.stack( 433 | torch.meshgrid( 434 | torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), 435 | torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW), indexing='ij', 436 | ), -1) 437 | if i == start: 438 | print( 439 | f"[Config] Center cropping of size {2 * dH} x {2 * dW} is enabled until iter {args.precrop_iters}") 440 | else: 441 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H), 442 | torch.linspace(0, W - 1, W), indexing='ij'), 443 | -1) # (H, W, 2) 444 | 445 | coords = torch.reshape(coords, [-1, 2]) # (H * W, 2) 446 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 447 | # 选出的像素坐标 448 | select_coords = coords[select_inds].long() # (N_rand, 2) 449 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 450 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 451 | batch_rays = torch.stack([rays_o, rays_d], 0) # 堆叠 o和d 452 | # target 也同样选出对应位置的点 453 | # target 用来最后的mse loss 计算 454 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 455 | 456 | ##### Core optimization loop ##### 457 | # rgb 网络计算出的图像 458 | 
# 前三是精细网络的输出内容,其他的还保存在一个dict中,有5项 459 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 460 | verbose=i < 10, retraw=True, 461 | **render_kwargs_train) 462 | 463 | optimizer.zero_grad() 464 | # 计算loss 465 | img_loss = img2mse(rgb, target_s) 466 | loss = img_loss 467 | # 计算指标 468 | psnr = mse2psnr(img_loss) 469 | 470 | # rgb0 粗网络的输出 471 | if 'rgb0' in extras: 472 | img_loss0 = img2mse(extras['rgb0'], target_s) 473 | loss = loss + img_loss0 474 | psnr0 = mse2psnr(img_loss0) 475 | 476 | loss.backward() 477 | optimizer.step() 478 | 479 | # NOTE: IMPORTANT! 480 | ### update learning rate ### 481 | # 学习率衰减 482 | decay_rate = 0.1 483 | decay_steps = args.lrate_decay * 1000 484 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 485 | for param_group in optimizer.param_groups: 486 | param_group['lr'] = new_lrate 487 | ################################ 488 | 489 | # 保存模型 490 | if i % args.i_weights == 0: 491 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 492 | torch.save({ 493 | # 运行的轮次数目 494 | 'global_step': global_step, 495 | # 粗网络的权重 496 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 497 | # 精细网络的权重 498 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 499 | # 优化器的状态 500 | 'optimizer_state_dict': optimizer.state_dict(), 501 | }, path) 502 | print('Saved checkpoints at', path) 503 | 504 | # 生成测试视频,使用的是render_poses (这个不等同于test数据) 505 | if i % args.i_video == 0 and i > 0: 506 | # Turn on testing mode 507 | with torch.no_grad(): 508 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) 509 | print('Done, saving', rgbs.shape, disps.shape) 510 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 511 | # 360度转一圈的视频 512 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) 513 | 514 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) 515 | 516 | # 执行测试,使用测试数据 517 | if i % args.i_testset == 0 and i > 0: 518 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 519 | os.makedirs(testsavedir, exist_ok=True) 520 | print('test poses shape', poses[i_test].shape) 521 | with torch.no_grad(): 522 | render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, 523 | gt_imgs=images[i_test], savedir=testsavedir) 524 | print('Saved test set') 525 | 526 | # 用时 527 | dt = time.time() - time0 528 | # 打印log信息的频率 529 | if i % args.i_print == 0: 530 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} Time: {dt}") 531 | 532 | global_step += 1 533 | 534 | 535 | if __name__ == '__main__': 536 | if torch.cuda.is_available(): 537 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 538 | train() 539 | -------------------------------------------------------------------------------- /test/frame.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | 3 | t = imageio.mimread('blender_paper_lego_spiral_200000_rgb.mp4') 4 | 5 | print(len(t)) 6 | -------------------------------------------------------------------------------- /模型.txt: -------------------------------------------------------------------------------- 1 | 2 | pts_linears 3 | 4 | 256+63=319 5 | 6 | ModuleList( 7 | (0): Linear(in_features=63, out_features=256, bias=True) 8 | (1): Linear(in_features=256, out_features=256, bias=True) 9 | (2): Linear(in_features=256, out_features=256, bias=True) 10 | (3): 
Linear(in_features=256, out_features=256, bias=True) 11 | (4): Linear(in_features=256, out_features=256, bias=True) 12 | (5): Linear(in_features=319, out_features=256, bias=True) 13 | (6): Linear(in_features=256, out_features=256, bias=True) 14 | (7): Linear(in_features=256, out_features=256, bias=True) 15 | ) 16 | 17 | 一共八层 18 | 19 | views_linears 20 | 21 | 256+27=283 22 | 23 | ModuleList( 24 | (0): Linear(in_features=283, out_features=128, bias=True) 25 | ) 26 | 27 | alpha_linear 28 | Linear(in_features=256, out_features=1, bias=True) 29 | 30 | 31 | feature_linear 32 | Linear(in_features=256, out_features=256, bias=True) 33 | 34 | 35 | rgb_linear 36 | Linear(in_features=128, out_features=3, bias=True) -------------------------------------------------------------------------------- /源码结构.md: -------------------------------------------------------------------------------- 1 | # NeRF源码 2 | 3 | ## 核心文件 4 | 5 | 1. [opts](opts.py) 6 | 2. [load_blender_data](load_blender.py) 7 | 3. [run_nerf](run_nerf.py) 8 | 4. [nerf_helpers](nerf_helpers.py) 9 | 5. [nerf_model](nerf_model.py) 10 | 6. [render](render.py) 11 | 7. [inference](inference.py) 12 | 13 | ## 简易流程 14 | 15 | 1. [load_blender_data](load_blender.py) 16 | 1. [pose_spherical](load_blender.py) 17 | 2. [create_nerf](run_nerf.py) 18 | 1. get_embedder 19 | 1. create Embedder 20 | 2. 里面有个在train时候会被调用的lambda network_query_fn 21 | 3. [in train iteration](run_nerf.py): 22 | 1. use_batching or not 23 | 1. get_rays 24 | 2. render(训练相关的代码从这里开始) 25 | 1. rays_o, rays_d = get_rays(H, W, K, c2w) 26 | 2. [batchify_rays](render.py) 分批处理 27 | 1. render_rays 28 | 1. 准备工作 29 | 1. 分解出rays_o,rays_d, viewdirs, near, far 30 | 2. 构造采样点,给采样点加上随机的噪声 31 | 2. network_query_fn (pts, viewdirs, network_fn) 这个函数是create_nerf中的那个lambda函数 32 | 1. run_network 33 | 1. xyz pe 34 | 2. viewdirs pe 35 | 3. batchify 在这里调用的fn就是NeRF model 36 | 1. 将pts,viewdirs 分开,63,27 37 | 2. pts 经过8层Linear 38 | 3. 8层后的输出经过一层Linear 输出 Alpha 39 | 4. 8层后的输出再来一层Linear (feature Linear) 40 | 5. feature 和 input_views拼接 再经过一层Linear 41 | 6. 最后再经过一层Linear 得到RGB 42 | 3. [raw2outputs](render.py) 体渲染在这里 43 | 4. sample_pdf(z_vals_mid, weights, N_importance) 精细网络使用的采样方案 44 | 5. network_query_fn (pts, viewdirs, network_fn) 第二次是精细网络的 45 | 6. raw2outputs 体渲染在这里 46 | 3. [img2mse](nerf_helpers.py) 47 | 4. [mse2psnr](nerf_helpers.py) 48 | 5. 调整学习率 49 | 6. 定期保存模型,定期生成测试视频,定期渲染测试数据 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /说明.md: -------------------------------------------------------------------------------- 1 | # 说明 2 | 3 | transforms文件中的每一项的内容如下: 4 | 5 | ![img.png](img.png) 6 | 7 | --------------------------------------------------------------------------------
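For reference, since `img.png` is excluded by `.gitignore`, a typical entry of the Blender-set `transforms_*.json` that 说明.md describes has the shape sketched below; the numeric values are illustrative only:

```json
{
  "camera_angle_x": 0.6911112070083618,
  "frames": [
    {
      "file_path": "./train/r_0",
      "rotation": 0.012566370614359171,
      "transform_matrix": [
        [-0.9999,  0.0042, -0.0134, -0.0538],
        [-0.0140, -0.2996,  0.9539,  3.8458],
        [ 0.0000,  0.9540,  0.2996,  1.2081],
        [ 0.0000,  0.0000,  0.0000,  1.0000]
      ]
    }
  ]
}
```

`camera_angle_x` is the horizontal field of view (load_blender_data derives the focal length from it), and `transform_matrix` is the 4x4 camera-to-world pose used to generate rays for that frame.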