├── .gitignore
├── LICENSE
├── NeRF - 源码.pptx
├── README.md
├── README_origin.md
├── configs
│   ├── chair.txt
│   ├── drums.txt
│   ├── fern.txt
│   ├── ficus.txt
│   ├── flower.txt
│   ├── fortress.txt
│   ├── horns.txt
│   ├── hotdog.txt
│   ├── leaves.txt
│   ├── lego-use_batching.txt
│   ├── lego.txt
│   ├── materials.txt
│   ├── mic.txt
│   ├── orchids.txt
│   ├── room.txt
│   ├── ship.txt
│   └── trex.txt
├── inference.py
├── load_LINEMOD.py
├── load_blender.py
├── load_deepvoxels.py
├── load_llff.py
├── nerf_helpers.py
├── nerf_model.py
├── opts.py
├── render.py
├── requirements.txt
├── run_nerf.py
├── test
│   └── frame.py
├── 模型.txt
├── 源码结构.md
└── 说明.md
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
10 |
11 | .idea
12 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bmild
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NeRF - 源码.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xunull/read-nerf-pytorch/5940fc27e0ea82674859c61996e178ff71bddd1e/NeRF - 源码.pptx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # read-nerf-pytorch
2 |
3 |
4 | Original repository: https://github.com/yenchenlin/nerf-pytorch
5 | The NeRF team's TensorFlow implementation: https://github.com/bmild/nerf
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README_origin.md:
--------------------------------------------------------------------------------
1 | # NeRF-pytorch
2 |
3 |
4 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
5 |
6 | 
7 | 
8 |
9 | This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is based on the authors' TensorFlow implementation [here](https://github.com/bmild/nerf), and has been tested to match it numerically.
10 |
11 | ## Installation
12 |
13 | ```
14 | git clone https://github.com/yenchenlin/nerf-pytorch.git
15 | cd nerf-pytorch
16 | pip install -r requirements.txt
17 | ```
18 |
19 |
20 |
21 |
22 | ## Dependencies
23 | - PyTorch 1.4
24 | - matplotlib
25 | - numpy
26 | - imageio
27 | - imageio-ffmpeg
28 | - configargparse
29 |
30 | The LLFF data loader requires ImageMagick.
31 |
32 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
33 |
34 |
35 |
36 | ## How To Run?
37 |
38 | ### Quick Start
39 |
40 | Download data for two example datasets: `lego` and `fern`
41 | ```
42 | bash download_example_data.sh
43 | ```
44 |
45 | To train a low-res `lego` NeRF:
46 | ```
47 | python run_nerf.py --config configs/lego.txt
48 | ```
49 | After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.
50 |
51 | 
52 |
53 | ---
54 |
55 | To train a low-res `fern` NeRF:
56 | ```
57 | python run_nerf.py --config configs/fern.txt
58 | ```
59 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.
60 |
61 | 
62 |
63 | ---
64 |
65 | ### More Datasets
66 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
67 | ```
68 | ├── configs
69 | │   ├── ...
70 | │
71 | ├── data
72 | │   ├── nerf_llff_data
73 | │   │   └── fern
74 | │   │   └── flower   # downloaded llff dataset
75 | │   │   └── horns    # downloaded llff dataset
76 | │   │   └── ...
77 | │   ├── nerf_synthetic
78 | │   │   └── lego
79 | │   │   └── ship     # downloaded synthetic dataset
80 | │   │   └── ...
81 | ```
82 |
83 | ---
84 |
85 | To train NeRF on different datasets:
86 |
87 | ```
88 | python run_nerf.py --config configs/{DATASET}.txt
89 | ```
90 |
91 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
92 |
93 | ---
94 |
95 | To test NeRF trained on different datasets:
96 |
97 | ```
98 | python run_nerf.py --config configs/{DATASET}.txt --render_only
99 | ```
100 |
101 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
102 |
103 |
104 | ### Pre-trained Models
105 |
106 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
107 |
108 | ```
109 | ├── logs
110 | │   ├── fern_test
111 | │   ├── flower_test   # downloaded logs
112 | │   ├── trex_test     # downloaded logs
113 | ```
114 |
115 | ### Reproducibility
116 |
117 | Tests that ensure the results of all functions and the training loop match the official implementation are contained in a separate branch, `reproduce`. Check it out and run the tests:
118 | ```
119 | git checkout reproduce
120 | py.test
121 | ```
122 |
123 | ## Method
124 |
125 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
126 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*1,
127 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*1,
128 | [Matthew Tancik](http://tancik.com/)\*1,
129 | [Jonathan T. Barron](http://jonbarron.info/)2,
130 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)3,
131 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)1
132 | 1UC Berkeley, 2Google Research, 3UC San Diego
133 | \*denotes equal contribution
134 |
135 |
136 |
137 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views
138 |
139 |
140 | ## Citation
141 | Kudos to the authors for their amazing results:
142 | ```
143 | @misc{mildenhall2020nerf,
144 | title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
145 | author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
146 | year={2020},
147 | eprint={2003.08934},
148 | archivePrefix={arXiv},
149 | primaryClass={cs.CV}
150 | }
151 | ```
152 |
153 | However, if you find this implementation or the pre-trained models helpful, please consider citing:
154 | ```
155 | @misc{lin2020nerfpytorch,
156 | title={NeRF-pytorch},
157 | author={Yen-Chen, Lin},
158 | publisher = {GitHub},
159 | journal = {GitHub repository},
160 | howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
161 | year={2020}
162 | }
163 | ```
164 |
--------------------------------------------------------------------------------
/configs/chair.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_chair
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/chair
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/drums.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_drums
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/drums
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/fern.txt:
--------------------------------------------------------------------------------
1 | expname = fern_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fern
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/ficus.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ficus
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/ficus
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/flower.txt:
--------------------------------------------------------------------------------
1 | expname = flower_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/flower
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/fortress.txt:
--------------------------------------------------------------------------------
1 | expname = fortress_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fortress
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/horns.txt:
--------------------------------------------------------------------------------
1 | expname = horns_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/horns
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/hotdog.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_hotdog
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/hotdog
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/leaves.txt:
--------------------------------------------------------------------------------
1 | expname = leaves_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/leaves
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/lego-use_batching.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_lego
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/lego
4 | dataset_type = blender
5 |
6 |
7 | use_viewdirs = True
8 | white_bkgd = True
9 | lrate_decay = 500
10 |
11 | N_samples = 64
12 | N_importance = 128
13 | N_rand = 1024
14 |
15 | precrop_iters = 500
16 | precrop_frac = 0.5
17 |
18 | half_res = True
19 |
--------------------------------------------------------------------------------
/configs/lego.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_lego
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/lego
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/materials.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_materials
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/materials
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/mic.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_mic
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/mic
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/orchids.txt:
--------------------------------------------------------------------------------
1 | expname = orchids_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/orchids
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/room.txt:
--------------------------------------------------------------------------------
1 | expname = room_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/room
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/ship.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ship
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/ship
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/trex.txt:
--------------------------------------------------------------------------------
1 | expname = trex_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/trex
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import time
4 | from tqdm import tqdm
5 | from render import *
6 | from nerf_helpers import to8b
7 | import numpy as np
8 |
9 |
10 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
11 | H, W, focal = hwf
12 |
13 | if render_factor != 0:
14 | # Render downsampled for speed
15 | H = H // render_factor
16 | W = W // render_factor
17 | focal = focal / render_factor
18 |
19 | rgbs = []
20 | disps = []
21 |
22 | t = time.time()
23 | for i, c2w in enumerate(tqdm(render_poses)):
24 | print(i, time.time() - t)
25 | t = time.time()
26 | rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs)
27 | rgbs.append(rgb.cpu().numpy())
28 | disps.append(disp.cpu().numpy())
29 | if i == 0:
30 | print(rgb.shape, disp.shape)
31 |
32 | """
33 | if gt_imgs is not None and render_factor==0:
34 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
35 | print(p)
36 | """
37 |
38 | if savedir is not None:
39 | rgb8 = to8b(rgbs[-1])
40 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
41 | imageio.imwrite(filename, rgb8)
42 |
43 | rgbs = np.stack(rgbs, 0)
44 | disps = np.stack(disps, 0)
45 |
46 | return rgbs, disps
47 |
--------------------------------------------------------------------------------
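The frames that `render_path` returns are what the training script turns into the spiral videos mentioned in the README. A minimal sketch of that last step, assuming imageio-ffmpeg is installed; the array below is dummy data standing in for a real render:

```python
import numpy as np
import imageio
from nerf_helpers import to8b

# Dummy stand-in for the rgbs array returned by render_path: [N_frames, H, W, 3] in [0, 1].
rgbs = np.random.rand(40, 100, 100, 3).astype(np.float32)

# Convert to uint8 and write an mp4, mirroring how the repo saves the spiral video.
imageio.mimwrite('spiral_rgb.mp4', to8b(rgbs), fps=30, quality=8)
```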
/load_LINEMOD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 | trans_t = lambda t: torch.Tensor([
10 | [1, 0, 0, 0],
11 | [0, 1, 0, 0],
12 | [0, 0, 1, t],
13 | [0, 0, 0, 1]]).float()
14 |
15 | rot_phi = lambda phi: torch.Tensor([
16 | [1, 0, 0, 0],
17 | [0, np.cos(phi), -np.sin(phi), 0],
18 | [0, np.sin(phi), np.cos(phi), 0],
19 | [0, 0, 0, 1]]).float()
20 |
21 | rot_theta = lambda th: torch.Tensor([
22 | [np.cos(th), 0, -np.sin(th), 0],
23 | [0, 1, 0, 0],
24 | [np.sin(th), 0, np.cos(th), 0],
25 | [0, 0, 0, 1]]).float()
26 |
27 |
28 | def pose_spherical(theta, phi, radius):
29 | c2w = trans_t(radius)
30 | c2w = rot_phi(phi / 180. * np.pi) @ c2w
31 | c2w = rot_theta(theta / 180. * np.pi) @ c2w
32 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
33 | return c2w
34 |
35 |
36 | def load_LINEMOD_data(basedir, half_res=False, testskip=1):
37 | splits = ['train', 'val', 'test']
38 | metas = {}
39 | for s in splits:
40 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
41 | metas[s] = json.load(fp)
42 |
43 | all_imgs = []
44 | all_poses = []
45 | counts = [0]
46 | for s in splits:
47 | meta = metas[s]
48 | imgs = []
49 | poses = []
50 | if s == 'train' or testskip == 0:
51 | skip = 1
52 | else:
53 | skip = testskip
54 |
55 | for idx_test, frame in enumerate(meta['frames'][::skip]):
56 | fname = frame['file_path']
57 | if s == 'test':
58 | print(f"{idx_test}th test frame: {fname}")
59 | imgs.append(imageio.imread(fname))
60 | poses.append(np.array(frame['transform_matrix']))
61 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
62 | poses = np.array(poses).astype(np.float32)
63 | counts.append(counts[-1] + imgs.shape[0])
64 | all_imgs.append(imgs)
65 | all_poses.append(poses)
66 |
67 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
68 |
69 | imgs = np.concatenate(all_imgs, 0)
70 | poses = np.concatenate(all_poses, 0)
71 |
72 | H, W = imgs[0].shape[:2]
73 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
74 | K = meta['frames'][0]['intrinsic_matrix']
75 | print(f"Focal: {focal}")
76 |
77 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
78 |
79 | if half_res:
80 | H = H // 2
81 | W = W // 2
82 | focal = focal / 2.
83 |
84 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
85 | for i, img in enumerate(imgs):
86 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
87 | imgs = imgs_half_res
88 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
89 |
90 | near = np.floor(min(metas['train']['near'], metas['test']['near']))
91 | far = np.ceil(max(metas['train']['far'], metas['test']['far']))
92 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far
93 |
--------------------------------------------------------------------------------
/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import cv2
7 |
8 | # Translation along the z axis
9 | trans_t = lambda t: torch.Tensor([
10 | [1, 0, 0, 0],
11 | [0, 1, 0, 0],
12 | [0, 0, 1, t],
13 | [0, 0, 0, 1]]).float()
14 |
15 | # Rotation about the x axis
16 | rot_phi = lambda phi: torch.Tensor([
17 | [1, 0, 0, 0],
18 | [0, np.cos(phi), -np.sin(phi), 0],
19 | [0, np.sin(phi), np.cos(phi), 0],
20 | [0, 0, 0, 1]]).float()
21 |
22 | # Rotation about the y axis
23 | rot_theta = lambda th: torch.Tensor([
24 | [np.cos(th), 0, -np.sin(th), 0],
25 | [0, 1, 0, 0],
26 | [np.sin(th), 0, np.cos(th), 0],
27 | [0, 0, 0, 1]]).float()
28 |
29 |
30 | def pose_spherical(theta, phi, radius):
31 | """
32 | theta: -180 to +180, in 9-degree steps
33 | phi: fixed at -30
34 | radius: fixed at 4
35 | """
36 | c2w = trans_t(radius)
37 | c2w = rot_phi(phi / 180. * np.pi) @ c2w
38 | c2w = rot_theta(theta / 180. * np.pi) @ c2w
39 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w
40 | return c2w
41 |
42 |
43 | def load_blender_data(basedir, half_res=False, testskip=1):
44 | """
45 | testskip: for the test and val splits, only every testskip-th frame is loaded
46 | """
47 | splits = ['train', 'val', 'test']
48 | # Holds the contents of the three transforms_*.json files
49 | metas = {}
50 | for s in splits:
51 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
52 | metas[s] = json.load(fp)
53 |
54 | all_imgs = []
55 | all_poses = []
56 | counts = [0]
57 | for s in splits:
58 | meta = metas[s]
59 | imgs = []
60 | poses = []
61 | if s == 'train' or testskip == 0:
62 | skip = 1
63 | else:
64 | # A large test/val split may be subsampled via testskip
65 | skip = testskip
66 | # Read every image together with its transform_matrix
67 | for frame in meta['frames'][::skip]:
68 | fname = os.path.join(basedir, frame['file_path'] + '.png')
69 | imgs.append(imageio.imread(fname))
70 | poses.append(np.array(frame['transform_matrix']))
71 | # Normalize to [0, 1]
72 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
73 | poses = np.array(poses).astype(np.float32)
74 | # Running counts used to compute the train/val/test index ranges
75 | counts.append(counts[-1] + imgs.shape[0])
76 | all_imgs.append(imgs)
77 | all_poses.append(poses)
78 | # Three index lists: train, val, test
79 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
80 | # Concatenate train/val/test together
81 | imgs = np.concatenate(all_imgs, 0)
82 | poses = np.concatenate(all_poses, 0)
83 |
84 | H, W = imgs[0].shape[:2]
85 | # meta is the loop variable left over from the last split; camera_angle_x is identical in all three json files
86 | camera_angle_x = float(meta['camera_angle_x'])
87 | # Focal length
88 | focal = .5 * W / np.tan(.5 * camera_angle_x)
89 |
90 | # np.linspace(-180, 180, 40 + 1): one pose every 9 degrees
91 | # (40, 4, 4); the rendered result is 40 frames
92 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
93 |
94 | if half_res:
95 | H = H // 2
96 | W = W // 2
97 | # Halve the focal length as well
98 | focal = focal / 2.
99 |
100 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
101 | for i, img in enumerate(imgs):
102 | # Resize to half size
103 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
104 | imgs = imgs_half_res
105 |
106 | return imgs, poses, render_poses, [H, W, focal], i_split
107 |
--------------------------------------------------------------------------------
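For orientation, a minimal usage sketch of `load_blender_data` (the dataset path is the one used by configs/lego.txt; the shapes assume `half_res=True`):

```python
from load_blender import load_blender_data

imgs, poses, render_poses, hwf, i_split = load_blender_data(
    './data/nerf_synthetic/lego', half_res=True, testskip=8)

H, W, focal = hwf                   # 400, 400, and the halved focal length
i_train, i_val, i_test = i_split    # index arrays into imgs/poses
# imgs:         [N, H, W, 4] float32 RGBA in [0, 1]
# poses:        [N, 4, 4]    camera-to-world matrices
# render_poses: [40, 4, 4]   spiral path, one pose every 9 degrees
```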
/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
8 | # Get camera intrinsics
9 | with open(filepath, 'r') as file:
10 | f, cx, cy = list(map(float, file.readline().split()))[:3]
11 | grid_barycenter = np.array(list(map(float, file.readline().split())))
12 | near_plane = float(file.readline())
13 | scale = float(file.readline())
14 | height, width = map(float, file.readline().split())
15 |
16 | try:
17 | world2cam_poses = int(file.readline())
18 | except ValueError:
19 | world2cam_poses = None
20 |
21 | if world2cam_poses is None:
22 | world2cam_poses = False
23 |
24 | world2cam_poses = bool(world2cam_poses)
25 |
26 | print(cx, cy, f, height, width)
27 |
28 | cx = cx / width * trgt_sidelength
29 | cy = cy / height * trgt_sidelength
30 | f = trgt_sidelength / height * f
31 |
32 | fx = f
33 | if invert_y:
34 | fy = -f
35 | else:
36 | fy = f
37 |
38 | # Build the intrinsic matrices
39 | full_intrinsic = np.array([[fx, 0., cx, 0.],
40 | [0., fy, cy, 0],
41 | [0., 0, 1, 0],
42 | [0, 0, 0, 1]])
43 |
44 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
45 |
46 | def load_pose(filename):
47 | assert os.path.isfile(filename)
48 | nums = open(filename).read().split()
49 | return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)
50 |
51 | H = 512
52 | W = 512
53 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
54 |
55 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(
56 | os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
57 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
58 | focal = full_intrinsic[0, 0]
59 | print(H, W, focal)
60 |
61 | def dir2poses(posedir):
62 | poses = np.stack(
63 | [load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
64 | transf = np.array([
65 | [1, 0, 0, 0],
66 | [0, -1, 0, 0],
67 | [0, 0, -1, 0],
68 | [0, 0, 0, 1.],
69 | ])
70 | poses = poses @ transf
71 | poses = poses[:, :3, :4].astype(np.float32)
72 | return poses
73 |
74 | posedir = os.path.join(deepvoxels_base, 'pose')
75 | poses = dir2poses(posedir)
76 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
77 | testposes = testposes[::testskip]
78 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
79 | valposes = valposes[::testskip]
80 |
81 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
82 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f)) / 255. for f in imgfiles], 0).astype(
83 | np.float32)
84 |
85 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
86 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
87 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f)) / 255. for f in imgfiles[::testskip]], 0).astype(
88 | np.float32)
89 |
90 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
91 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
92 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f)) / 255. for f in imgfiles[::testskip]], 0).astype(
93 | np.float32)
94 |
95 | all_imgs = [imgs, valimgs, testimgs]
96 | counts = [0] + [x.shape[0] for x in all_imgs]
97 | counts = np.cumsum(counts)
98 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
99 |
100 | imgs = np.concatenate(all_imgs, 0)
101 | poses = np.concatenate([poses, valposes, testposes], 0)
102 |
103 | render_poses = testposes
104 |
105 | print(poses.shape, imgs.shape)
106 |
107 | return imgs, poses, render_poses, [H, W, focal], i_split
108 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, imageio
3 |
4 |
5 | ########## Slightly modified version of LLFF data loading code
6 | ########## see https://github.com/Fyusion/LLFF for original
7 |
8 | def _minify(basedir, factors=[], resolutions=[]):
9 | needtoload = False
10 | # Check for the downsampled image directories,
11 | # e.g. the images_4 and images_8 directories
12 | for r in factors:
13 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
14 | if not os.path.exists(imgdir):
15 | # The corresponding directory does not exist
16 | needtoload = True
17 | for r in resolutions:
18 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
19 | if not os.path.exists(imgdir):
20 | # The corresponding directory does not exist
21 | needtoload = True
22 |
23 | # If all of those directories already exist, return here
24 | if not needtoload:
25 | return
26 |
27 | from shutil import copy
28 | from subprocess import check_output
29 |
30 | imgdir = os.path.join(basedir, 'images')
31 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
32 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
33 | imgdir_orig = imgdir
34 |
35 | wd = os.getcwd()
36 |
37 | for r in factors + resolutions:
38 | if isinstance(r, int):
39 | name = 'images_{}'.format(r)
40 | resizearg = '{}%'.format(100. / r)
41 | else:
42 | name = 'images_{}x{}'.format(r[1], r[0])
43 | resizearg = '{}x{}'.format(r[1], r[0])
44 | imgdir = os.path.join(basedir, name)
45 | if os.path.exists(imgdir):
46 | continue
47 |
48 | print('Minifying', r, basedir)
49 |
50 | os.makedirs(imgdir)
51 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
52 |
53 | ext = imgs[0].split('.')[-1]
54 | # Run an ImageMagick mogrify command in a shell
55 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
56 | print(args)
57 | os.chdir(imgdir)
58 | check_output(args, shell=True)
59 | os.chdir(wd)
60 |
61 | if ext != 'png':
62 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
63 | print('Removed duplicates')
64 | print('Done')
65 |
66 |
67 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
68 | # Load poses_bounds.npy
69 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
70 | # like [3,5,24]
71 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
72 |
73 | bds = poses_arr[:, -2:].transpose([1, 0])
74 |
75 | # The 'images' directory,
76 | # i.e. the original full-resolution images
77 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
78 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
79 | sh = imageio.imread(img0).shape
80 |
81 | sfx = ''
82 |
83 | if factor is not None:
84 | sfx = '_{}'.format(factor)
85 | _minify(basedir, factors=[factor])
86 | factor = factor
87 | elif height is not None:
88 | factor = sh[0] / float(height)
89 | width = int(sh[1] / factor)
90 | _minify(basedir, resolutions=[[height, width]])
91 | sfx = '_{}x{}'.format(width, height)
92 | elif width is not None:
93 | factor = sh[1] / float(width)
94 | height = int(sh[0] / factor)
95 | _minify(basedir, resolutions=[[height, width]])
96 | sfx = '_{}x{}'.format(width, height)
97 | else:
98 | factor = 1
99 |
100 | # This is one of the downsampled directories, e.g. images_4 or images_8
101 | imgdir = os.path.join(basedir, 'images' + sfx)
102 | if not os.path.exists(imgdir):
103 | print(imgdir, 'does not exist, returning')
104 | return
105 |
106 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if
107 | f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
108 |
109 | # Check that the number of images matches the number of poses
110 | if poses.shape[-1] != len(imgfiles):
111 | print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]))
112 | return
113 |
114 | sh = imageio.imread(imgfiles[0]).shape
115 |
116 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
117 | poses[2, 4, :] = poses[2, 4, :] * 1. / factor
118 |
119 | if not load_imgs:
120 | return poses, bds
121 |
122 | def imread(f):
123 | if f.endswith('png'):
124 | return imageio.imread(f, ignoregamma=True)
125 | else:
126 | return imageio.imread(f)
127 |
128 | imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
129 | imgs = np.stack(imgs, -1)
130 |
131 | print('Loaded image data', imgs.shape, poses[:, -1, 0])
132 | return poses, bds, imgs
133 |
134 |
135 | def normalize(x):
136 | return x / np.linalg.norm(x)
137 |
138 |
139 | def viewmatrix(z, up, pos):
140 | vec2 = normalize(z)
141 | vec1_avg = up
142 | vec0 = normalize(np.cross(vec1_avg, vec2))
143 | vec1 = normalize(np.cross(vec2, vec0))
144 | m = np.stack([vec0, vec1, vec2, pos], 1)
145 | return m
146 |
147 |
148 | def ptstocam(pts, c2w):
149 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
150 | return tt
151 |
152 |
153 | def poses_avg(poses):
154 | hwf = poses[0, :3, -1:]
155 |
156 | center = poses[:, :3, 3].mean(0)
157 | vec2 = normalize(poses[:, :3, 2].sum(0))
158 | up = poses[:, :3, 1].sum(0)
159 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
160 |
161 | return c2w
162 |
163 |
164 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
165 | render_poses = []
166 | rads = np.array(list(rads) + [1.])
167 | hwf = c2w[:, 4:5]
168 |
169 | for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:
170 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
171 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
172 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
173 | return render_poses
174 |
175 |
176 | def recenter_poses(poses):
177 | poses_ = poses + 0
178 | bottom = np.reshape([0, 0, 0, 1.], [1, 4])
179 | c2w = poses_avg(poses)
180 | c2w = np.concatenate([c2w[:3, :4], bottom], -2)
181 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
182 | poses = np.concatenate([poses[:, :3, :4], bottom], -2)
183 |
184 | poses = np.linalg.inv(c2w) @ poses
185 | poses_[:, :3, :4] = poses[:, :3, :4]
186 | poses = poses_
187 | return poses
188 |
189 |
190 | #####################
191 |
192 |
193 | def spherify_poses(poses, bds):
194 | p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)
195 |
196 | rays_d = poses[:, :3, 2:3]
197 | rays_o = poses[:, :3, 3:4]
198 |
199 | def min_line_dist(rays_o, rays_d):
200 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
201 | b_i = -A_i @ rays_o
202 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))
203 | return pt_mindist
204 |
205 | pt_mindist = min_line_dist(rays_o, rays_d)
206 |
207 | center = pt_mindist
208 | up = (poses[:, :3, 3] - center).mean(0)
209 |
210 | vec0 = normalize(up)
211 | vec1 = normalize(np.cross([.1, .2, .3], vec0))
212 | vec2 = normalize(np.cross(vec0, vec1))
213 | pos = center
214 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
215 |
216 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
217 |
218 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
219 |
220 | sc = 1. / rad
221 | poses_reset[:, :3, 3] *= sc
222 | bds *= sc
223 | rad *= sc
224 |
225 | centroid = np.mean(poses_reset[:, :3, 3], 0)
226 | zh = centroid[2]
227 | radcircle = np.sqrt(rad ** 2 - zh ** 2)
228 | new_poses = []
229 |
230 | for th in np.linspace(0., 2. * np.pi, 120):
231 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
232 | up = np.array([0, 0, -1.])
233 |
234 | vec2 = normalize(camorigin)
235 | vec0 = normalize(np.cross(vec2, up))
236 | vec1 = normalize(np.cross(vec2, vec0))
237 | pos = camorigin
238 | p = np.stack([vec0, vec1, vec2, pos], 1)
239 |
240 | new_poses.append(p)
241 |
242 | new_poses = np.stack(new_poses, 0)
243 |
244 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1)
245 | poses_reset = np.concatenate(
246 | [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)
247 |
248 | return poses_reset, new_poses, bds
249 |
250 |
251 | # 1. Calls _load_data
252 | # 2. _load_data calls _minify,
253 | #    which may invoke check_output to downscale the images
254 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):
255 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
256 | print('Loaded', basedir, bds.min(), bds.max())
257 |
258 | # Correct rotation matrix ordering and move variable dim to axis 0
259 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
260 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
261 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
262 | images = imgs
263 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
264 |
265 | # Rescale if bd_factor is provided
266 | sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
267 | poses[:, :3, 3] *= sc
268 | bds *= sc
269 |
270 | if recenter:
271 | poses = recenter_poses(poses)
272 |
273 | if spherify:
274 | poses, render_poses, bds = spherify_poses(poses, bds)
275 |
276 | else:
277 |
278 | c2w = poses_avg(poses)
279 | print('recentered', c2w.shape)
280 | print(c2w[:3, :4])
281 |
282 | ## Get spiral
283 | # Get average pose
284 | up = normalize(poses[:, :3, 1].sum(0))
285 |
286 | # Find a reasonable "focus depth" for this dataset
287 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
288 | dt = .75
289 | mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
290 | focal = mean_dz
291 |
292 | # Get radii for spiral path
293 | shrink_factor = .8
294 | zdelta = close_depth * .2
295 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
296 | rads = np.percentile(np.abs(tt), 90, 0)
297 | c2w_path = c2w
298 | N_views = 120
299 | N_rots = 2
300 | if path_zflat:
301 | # zloc = np.percentile(tt, 10, 0)[2]
302 | zloc = -close_depth * .1
303 | c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2]
304 | rads[2] = 0.
305 | N_rots = 1
306 | N_views /= 2
307 |
308 | # Generate poses for spiral path
309 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
310 |
311 | render_poses = np.array(render_poses).astype(np.float32)
312 |
313 | c2w = poses_avg(poses)
314 | print('Data:')
315 | print(poses.shape, images.shape, bds.shape)
316 |
317 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1)
318 | i_test = np.argmin(dists)
319 | print('HOLDOUT view is', i_test)
320 |
321 | images = images.astype(np.float32)
322 | poses = poses.astype(np.float32)
323 |
324 | return images, poses, bds, render_poses, i_test
325 |
--------------------------------------------------------------------------------
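For reference, the poses_bounds.npy file that `_load_data` unpacks stores one row of 17 numbers per image: a flattened 3x5 matrix (rotation, translation, and an H/W/focal column) followed by the two depth bounds. A small sketch of reading it directly; the fern path is just an example:

```python
import numpy as np

poses_arr = np.load('./data/nerf_llff_data/fern/poses_bounds.npy')   # [N, 17]

# First 15 entries per row: 3x5 matrix [R | t | (H, W, focal)^T], exactly as in _load_data
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])   # [3, 5, N]
# Last 2 entries per row: near/far depth bounds of the scene
bds = poses_arr[:, -2:].transpose([1, 0])                            # [2, N]

print(poses[:, 4, 0])          # H, W, focal of the first image
print(bds.min(), bds.max())    # overall depth range, used to rescale the poses
```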
/nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | __all__ = ['img2mse', 'mse2psnr', 'to8b', 'get_embedder', 'get_rays', 'get_rays_np', 'ndc_rays', 'sample_pdf']
6 |
7 | # Misc
8 | img2mse = lambda x, y: torch.mean((x - y) ** 2)
9 |
10 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
11 |
12 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
13 |
14 |
15 | # Positional encoding (section 5.1)
16 | class Embedder:
17 | def __init__(self, **kwargs):
18 | self.kwargs = kwargs
19 | self.create_embedding_fn()
20 |
21 | def create_embedding_fn(self):
22 | embed_fns = []
23 | d = self.kwargs['input_dims'] # 3
24 | out_dim = 0
25 | if self.kwargs['include_input']:
26 | embed_fns.append(lambda x: x)
27 | out_dim += d
28 |
29 | max_freq = self.kwargs['max_freq_log2']
30 | N_freqs = self.kwargs['num_freqs']
31 |
32 | if self.kwargs['log_sampling']:
33 | # tensor([ 1., 2., 4., 8., 16., 32., 64., 128., 256., 512.])
34 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
35 | else:
36 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
37 |
38 | for freq in freq_bands:
39 | for p_fn in self.kwargs['periodic_fns']:
40 | # sin(x),sin(2x),sin(4x),sin(8x),sin(16x),sin(32x),sin(64x),sin(128x),sin(256x),sin(512x)
41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
42 | out_dim += d
43 |
44 | self.embed_fns = embed_fns
45 |
46 | # 63 output dims for the 3D position, 27 for the 2D viewing direction
47 | self.out_dim = out_dim
48 |
49 | def embed(self, inputs):
50 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
51 |
52 |
53 | # Positional encoding helpers
54 | def get_embedder(multires, i=0):
55 | """
56 | multires: 10 for the 3D position, 4 for the 2D viewing direction
57 | """
58 | if i == -1:
59 | return nn.Identity(), 3
60 |
61 | embed_kwargs = {
62 | 'include_input': True,
63 | 'input_dims': 3,
64 | 'max_freq_log2': multires - 1,
65 | 'num_freqs': multires,
66 | 'log_sampling': True,
67 | 'periodic_fns': [torch.sin, torch.cos],
68 | }
69 |
70 | embedder_obj = Embedder(**embed_kwargs)
71 | embed = lambda x, eo=embedder_obj: eo.embed(x)
72 | # The first return value is a lambda that maps x to its positional encoding
73 | return embed, embedder_obj.out_dim
74 |
75 |
76 | # ----------------------------------------------------------------------------------------------------------------------
77 |
78 | # Ray helpers
79 | def get_rays(H, W, K, c2w):
80 | """
81 | K: camera intrinsic matrix
82 | c2w: camera-to-world transformation
83 | """
84 | # j
85 | # [0,......]
86 | # [1,......]
87 | # [W-1,....]
88 | # i
89 | # [0,..,H-1]
90 | # [0,..,H-1]
91 | # [0,..,H-1]
92 |
93 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H), indexing='ij')
94 | i = i.t()
95 | j = j.t()
96 | # [400,400,3]
97 | dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
98 | # Rotate ray directions from camera frame to the world frame
99 | # dirs [400,400,3] -> [400,400,1,3]
100 | # dot product, equals to: [c2w.dot(dir) for dir in dirs]
101 | # rays_d [400,400,3]
102 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
103 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
104 | # The last column of the top three rows is the camera translation, i.e. the origin of all rays
105 | rays_o = c2w[:3, -1].expand(rays_d.shape)
106 | return rays_o, rays_d
107 |
108 |
109 | def get_rays_np(H, W, K, c2w):
110 | # Same as get_rays above, but implemented with numpy instead of torch
111 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
112 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1)
113 | # Rotate ray directions from camera frame to the world frame
114 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3],
115 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
116 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
117 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
118 | return rays_o, rays_d
119 |
120 |
121 | # Hierarchical sampling (section 5.2)
122 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
123 | """
124 | bins: z_vals_mid
125 | """
126 |
127 | # Get pdf
128 | weights = weights + 1e-5 # prevent nans
129 | # Normalize, [bs, 62]
130 | # Probability density function
131 | pdf = weights / torch.sum(weights, -1, keepdim=True)
132 | # Cumulative distribution function
133 | cdf = torch.cumsum(pdf, -1)
134 | # Prepend a zero in the first position
135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
136 |
137 | # Take uniform samples
138 | if det:
139 | u = torch.linspace(0., 1., steps=N_samples)
140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
141 | else:
142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) # [bs,128]
143 |
144 | # Pytest, overwrite u with numpy's fixed random numbers
145 | if pytest:
146 | np.random.seed(0)
147 | new_shape = list(cdf.shape[:-1]) + [N_samples]
148 | if det:
149 | u = np.linspace(0., 1., N_samples)
150 | u = np.broadcast_to(u, new_shape)
151 | else:
152 | u = np.random.rand(*new_shape)
153 | u = torch.Tensor(u)
154 |
155 | # Invert CDF
156 |
157 | u = u.contiguous()
158 | # u was drawn randomly above
159 | # Find the insertion position of each u in the CDF
160 | inds = torch.searchsorted(cdf, u, right=True)
161 | # The bin below; clamp so that an index of 0 does not become -1
162 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
163 | # The bin above; clamp to the last CDF position so it cannot overshoot, same idea as above
164 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
165 | # (batch, N_samples, 2)
166 | inds_g = torch.stack([below, above], -1)
167 |
168 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
169 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
170 | # (batch, N_samples, 63)
171 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
172 | # e.g. expand to [1024, 128, 63] and gather using inds_g[i][j][0] and inds_g[i][j][1]
173 | # cdf_g [1024, 128, 2]
174 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
175 | # As above; bins are the midpoints of the 64 sample points between 2 and 6
176 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
177 | # Difference of the CDF across each bin
178 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
179 | # Guard against a denominator that is too small
180 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
181 |
182 | t = (u - cdf_g[..., 0]) / denom
183 |
184 | # Lower edge plus linear interpolation
185 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
186 |
187 | return samples
188 |
189 |
190 | # ----------------------------------------------------------------------------------------------------------------------
191 |
192 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
193 | # Shift ray origins to near plane
194 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
195 | rays_o = rays_o + t[..., None] * rays_d
196 |
197 | # Projection
198 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
199 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
200 | o2 = 1. + 2. * near / rays_o[..., 2]
201 |
202 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
203 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
204 | d2 = -2. * near / rays_o[..., 2]
205 |
206 | rays_o = torch.stack([o0, o1, o2], -1)
207 | rays_d = torch.stack([d0, d1, d2], -1)
208 |
209 | return rays_o, rays_d
210 |
--------------------------------------------------------------------------------
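To make the dimensions mentioned in the comments concrete, a short sketch of `get_embedder`, `get_rays`, and `sample_pdf` (the intrinsics and tensor values below are made-up numbers for illustration, not values from the repo):

```python
import numpy as np
import torch
from nerf_helpers import get_embedder, get_rays, sample_pdf

# Positional encoding: 3 + 3*2*10 = 63 dims for positions, 3 + 3*2*4 = 27 for directions.
embed_fn, input_ch = get_embedder(10)            # input_ch == 63
embeddirs_fn, input_ch_views = get_embedder(4)   # input_ch_views == 27
print(embed_fn(torch.rand(1024, 3)).shape)       # torch.Size([1024, 63])

# Ray generation for one camera with an illustrative pinhole intrinsic matrix.
H, W, focal = 400, 400, 555.5
K = np.array([[focal, 0., 0.5 * W],
              [0., focal, 0.5 * H],
              [0., 0., 1.]])
c2w = torch.eye(4)[:3, :4]                       # identity camera-to-world pose
rays_o, rays_d = get_rays(H, W, K, c2w)
print(rays_o.shape, rays_d.shape)                # [400, 400, 3] each

# Hierarchical sampling: draw 128 fine samples per ray from dummy coarse weights.
z_vals_mid = torch.sort(torch.rand(1024, 63), -1)[0]   # midpoints of the coarse bins
weights = torch.rand(1024, 62)                          # coarse weights without the two edge samples
z_samples = sample_pdf(z_vals_mid, weights, N_samples=128, det=False)
print(z_samples.shape)                                  # [1024, 128]
```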
/nerf_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | class NeRF(nn.Module):
9 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
10 | """
11 | D: depth, i.e. the number of layers
12 | W: channel width of each layer
13 | input_ch: width of the embedded xyz input
14 | input_ch_views: width of the embedded direction input
15 | output_ch: only used when use_viewdirs=False
16 | skips: which layers get a skip connection (a concatenation, unlike the addition in ResNet)
17 | use_viewdirs: whether the viewing direction is fed to the network
18 | """
19 | super(NeRF, self).__init__()
20 | self.D = D
21 | self.W = W
22 | self.input_ch = input_ch
23 | self.input_ch_views = input_ch_views
24 | self.skips = skips
25 | self.use_viewdirs = use_viewdirs
26 |
27 | # The MLP
28 | # Branch that processes the embedded 3D position
29 | # The skip layer concatenates the input rather than adding it as in ResNet
30 | self.pts_linears = nn.ModuleList(
31 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
32 | range(D - 1)])
33 |
34 | # The channel width is halved here, to 128
35 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
36 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
37 |
38 | if use_viewdirs:
39 | # Feature vector passed on to the view-dependent branch
40 | self.feature_linear = nn.Linear(W, W)
41 | # Opacity/density, a single value
42 | self.alpha_linear = nn.Linear(W, 1)
43 | # RGB color, three values
44 | self.rgb_linear = nn.Linear(W // 2, 3)
44 | self.rgb_linear = nn.Linear(W // 2, 3)
45 | else:
46 | self.output_linear = nn.Linear(W, output_ch)
47 |
48 | def forward(self, x):
49 | # x [bs*64, 90]
50 | # input_pts [bs*64, 63]
51 | # input_views [bs*64,27]
52 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
53 |
54 | h = input_pts
55 |
56 | for i, l in enumerate(self.pts_linears):
57 |
58 | h = self.pts_linears[i](h)
59 | h = F.relu(h)
60 | # After the skip layer (layer 4), concatenate the input again
61 | if i in self.skips:
62 | h = torch.cat([input_pts, h], -1)
63 |
64 | if self.use_viewdirs:
65 | # alpha depends only on the position xyz
66 | alpha = self.alpha_linear(h)
67 | feature = self.feature_linear(h)
68 | # rgb depends on both the position xyz and the direction d
69 | h = torch.cat([feature, input_views], -1)
70 |
71 | for i, l in enumerate(self.views_linears):
72 | h = self.views_linears[i](h)
73 | h = F.relu(h)
74 |
75 | rgb = self.rgb_linear(h)
76 | outputs = torch.cat([rgb, alpha], -1)
77 | else:
78 | outputs = self.output_linear(h)
79 |
80 | return outputs
81 |
--------------------------------------------------------------------------------
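A minimal sketch of instantiating the model the way the training script wires it up (63 and 27 are the embedder output widths shown above; the batch size is arbitrary):

```python
import torch
from nerf_model import NeRF

model = NeRF(D=8, W=256, input_ch=63, input_ch_views=27,
             output_ch=4, skips=[4], use_viewdirs=True)

# One chunk of embedded samples: 63 position dims + 27 direction dims = 90.
x = torch.rand(4096, 90)
out = model(x)
print(out.shape)   # [4096, 4] -> RGB (3 values) + density/alpha (1 value)
```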
/opts.py:
--------------------------------------------------------------------------------
1 | def config_parser():
2 | import configargparse
3 | parser = configargparse.ArgumentParser()
4 |
5 | parser.add_argument('--config', is_config_file=True,
6 | help='config file path')
7 | # Name of this run; used as the folder name inside the log directory
8 | parser.add_argument("--expname", type=str,
9 | help='experiment name')
10 | # Output directory
11 | parser.add_argument("--basedir", type=str, default='./logs/',
12 | help='where to store ckpts and logs')
13 | # Dataset directory
14 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
15 | help='input data directory')
16 |
17 | # training options
18 | # Number of fully connected layers
19 | parser.add_argument("--netdepth", type=int, default=8,
20 | help='layers in network')
21 | # Network width
22 | parser.add_argument("--netwidth", type=int, default=256,
23 | help='channels per layer')
24 |
25 | # Number of fully connected layers in the fine network
26 | # By default the fine network has the same depth and width as the coarse network
27 | parser.add_argument("--netdepth_fine", type=int, default=8,
28 | help='layers in fine network')
29 | parser.add_argument("--netwidth_fine", type=int, default=256,
30 | help='channels per layer in fine network')
31 |
32 | # The batch size here is the number of rays, i.e. the number of pixels
33 | # N_rand is 1024 in the config files
34 | # 32*32*4 = 4096
35 | # 800*800/4096 = 156, 400*400/1024 = 156
36 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 4,
37 | help='batch size (number of random rays per gradient step)')
38 | # Learning rate
39 | parser.add_argument("--lrate", type=float, default=5e-4,
40 | help='learning rate')
41 | # Learning rate decay
42 | parser.add_argument("--lrate_decay", type=int, default=250,
43 | help='exponential learning rate decay (in 1000 steps)')
44 |
45 | parser.add_argument("--chunk", type=int, default=1024 * 32,
46 | help='number of rays processed in parallel, decrease if running out of memory')
47 |
48 | # Number of sample points processed by the network at once
49 | parser.add_argument("--netchunk", type=int, default=1024 * 64,
50 | help='number of pts sent through network in parallel, decrease if running out of memory')
51 |
52 | # Usually True for synthetic datasets: random rays are drawn from one image at a time
53 | # Usually False for real datasets: rays from all images are shuffled together first
54 | parser.add_argument("--no_batching", action='store_true',
55 | help='only take random rays from 1 image at a time')
56 |
57 | # Do not reload saved weights
58 | parser.add_argument("--no_reload", action='store_true',
59 | help='do not reload weights from saved ckpt')
60 | # Path of the weights file for the coarse network
61 | parser.add_argument("--ft_path", type=str, default=None,
62 | help='specific weights npy file to reload for coarse network')
63 |
64 | # rendering options
65 | parser.add_argument("--N_samples", type=int, default=64,
66 | help='number of coarse samples per ray')
67 | parser.add_argument("--N_importance", type=int, default=0,
68 | help='number of additional fine samples per ray')
69 |
70 | parser.add_argument("--perturb", type=float, default=1.,
71 | help='set to 0. for no jitter, 1. for jitter')
72 | # Use the viewing direction as additional input
73 | parser.add_argument("--use_viewdirs", action='store_true',
74 | help='use full 5D input instead of 3D')
75 | # 0: use positional encoding, -1: no positional encoding
76 | parser.add_argument("--i_embed", type=int, default=0,
77 | help='set 0 for default positional encoding, -1 for none')
78 |
79 | # L=10
80 | parser.add_argument("--multires", type=int, default=10,
81 | help='log2 of max freq for positional encoding (3D location)')
82 | # L=4
83 | parser.add_argument("--multires_views", type=int, default=4,
84 | help='log2 of max freq for positional encoding (2D direction)')
85 |
86 | parser.add_argument("--raw_noise_std", type=float, default=0.,
87 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
88 |
89 | # Render only, without optimization
90 | parser.add_argument("--render_only", action='store_true',
91 | help='do not optimize, reload weights and render out render_poses path')
92 | # Render the test set
93 | parser.add_argument("--render_test", action='store_true',
94 | help='render the test set instead of render_poses path')
95 | # Downsampling factor for rendering
96 | parser.add_argument("--render_factor", type=int, default=0,
97 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
98 |
99 | # training options
100 | # Number of iterations trained on central crops of the images
101 | parser.add_argument("--precrop_iters", type=int, default=0,
102 | help='number of steps to train on central crops')
103 | parser.add_argument("--precrop_frac", type=float,
104 | default=.5, help='fraction of img taken for central crops')
105 |
106 | # dataset options
107 | # Dataset format
108 | parser.add_argument("--dataset_type", type=str, default='llff',
109 | help='options: llff / blender / deepvoxels')
110 |
111 | # For large datasets, only part of the test/val images is used
112 | parser.add_argument("--testskip", type=int, default=8,
113 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
114 |
115 | ## deepvoxels flags
116 | parser.add_argument("--shape", type=str, default='greek',
117 | help='options : armchair / cube / greek / vase')
118 |
119 | ## blender flags
120 | # White background
121 | parser.add_argument("--white_bkgd", action='store_true',
122 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
123 |
124 | # Use half resolution
125 | parser.add_argument("--half_res", action='store_true',
126 | help='load blender synthetic data at 400x400 instead of 800x800')
127 |
128 | ## llff flags
129 | parser.add_argument("--factor", type=int, default=8,
130 | help='downsample factor for LLFF images')
131 | parser.add_argument("--no_ndc", action='store_true',
132 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
133 | parser.add_argument("--lindisp", action='store_true',
134 | help='sampling linearly in disparity rather than depth')
135 | parser.add_argument("--spherify", action='store_true',
136 | help='set for spherical 360 scenes')
137 | parser.add_argument("--llffhold", type=int, default=8,
138 | help='will take every 1/N images as LLFF test set, paper uses 8')
139 |
140 | # logging/saving options
141 | # Console/metric logging frequency
142 | parser.add_argument("--i_print", type=int, default=100,
143 | help='frequency of console printout and metric logging')
144 | parser.add_argument("--i_img", type=int, default=500,
145 | help='frequency of tensorboard image logging')
146 | # Checkpoint saving frequency
147 | # One checkpoint every 10k iterations
148 | parser.add_argument("--i_weights", type=int, default=10000,
149 | help='frequency of weight ckpt saving')
150 | # Frequency of rendering the test set
151 | parser.add_argument("--i_testset", type=int, default=50000,
152 | help='frequency of testset saving')
153 | # Frequency of rendering the render_poses video
154 | parser.add_argument("--i_video", type=int, default=50000,
155 | help='frequency of render_poses video saving')
156 |
157 | return parser
158 |
--------------------------------------------------------------------------------
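A short sketch of how the parser is consumed; this mirrors the pattern in run_nerf.py, with the command line passed programmatically for illustration:

```python
from opts import config_parser

parser = config_parser()
# Equivalent to: python run_nerf.py --config configs/lego.txt --render_only
args = parser.parse_args(['--config', 'configs/lego.txt', '--render_only'])

print(args.expname)       # blender_paper_lego, read from the config file
print(args.N_rand)        # 1024
print(args.render_only)   # True
```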
/render.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 | from nerf_helpers import sample_pdf, get_rays, ndc_rays
6 |
7 | DEBUG = False
8 |
9 | __all__ = ['render', 'batchify_rays', 'render_rays', 'raw2outputs']
10 |
11 |
12 | def render(H, W, K,
13 | chunk=1024 * 32, rays=None, c2w=None, ndc=True,
14 | near=0., far=1.,
15 | use_viewdirs=False, c2w_staticcam=None,
16 | **kwargs):
17 | """Render rays
18 | Args:
19 | H: int. Height of image in pixels.
20 | W: int. Width of image in pixels.
21 | K: camera intrinsic matrix (contains the focal length)
22 | chunk: int. Maximum number of rays to process simultaneously. Used to
23 | control maximum memory usage. Does not affect final results.
24 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
25 | each example in batch.
26 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
27 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
28 | near: float or array of shape [batch_size]. Nearest distance for a ray.
29 | far: float or array of shape [batch_size]. Farthest distance for a ray.
30 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
31 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
32 | camera while using other c2w argument for viewing directions.
33 | Returns:
34 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
35 | disp_map: [batch_size]. Disparity map. Inverse of depth.
36 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
37 | extras: dict with everything returned by render_rays().
38 | """
39 |
40 | if c2w is not None:
41 | # special case to render full image
42 | rays_o, rays_d = get_rays(H, W, K, c2w)
43 | else:
44 | # use provided ray batch
45 | # ray origins and directions
46 | rays_o, rays_d = rays
47 |
48 | if use_viewdirs:
49 | # provide ray directions as input
50 | viewdirs = rays_d
51 | # static camera: a separate camera-to-world transform used only for the ray origins/directions
52 | if c2w_staticcam is not None:
53 | # special case to visualize effect of viewdirs
54 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
55 | # normalize to unit vectors [bs,3]
56 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
57 | viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
58 |
59 | sh = rays_d.shape # [..., 3]
60 |
61 | if ndc:
62 | # for forward facing scenes
63 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
64 |
65 | # Create ray batch
66 | rays_o = torch.reshape(rays_o, [-1, 3]).float()
67 | rays_d = torch.reshape(rays_d, [-1, 3]).float()
68 | # [bs,1],[bs,1]
69 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
70 | # 8=3+3+1+1
71 | rays = torch.cat([rays_o, rays_d, near, far], -1)
72 | if use_viewdirs:
73 | # append the three view-direction components
74 | # 3 + 3 + 1 + 1 + 3 = 11
75 | rays = torch.cat([rays, viewdirs], -1) # [bs,11]
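# annotation: each packed ray is [origin(3), direction(3), near(1), far(1), viewdir(3)];
# render_rays unpacks exactly these slices further below.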
76 |
77 | # Render and reshape
78 |
79 | # rgb_map, disp_map, acc_map, raw, rgb0, disp0, acc0, z_std
80 | all_ret = batchify_rays(rays, chunk, **kwargs)
81 | for k in all_ret:
82 | # reshape all returned values
83 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
84 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
85 |
86 | # pull the fine network's main outputs out on their own
87 | k_extract = ['rgb_map', 'disp_map', 'acc_map']
88 | ret_list = [all_ret[k] for k in k_extract]
89 | ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
90 | # the first three are returned as a list; the remaining five stay in the dict
91 | return ret_list + [ret_dict]
92 |
93 |
94 | def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
95 | """
96 | Render rays in smaller minibatches to avoid OOM.
97 | rays_flat: [N_rand,11]
98 | """
99 |
100 | all_ret = {}
101 | for i in range(0, rays_flat.shape[0], chunk):
102 | ret = render_rays(rays_flat[i:i + chunk], **kwargs)
103 | for k in ret:
104 | if k not in all_ret:
105 | all_ret[k] = []
106 | all_ret[k].append(ret[k])
107 | # concatenate the per-chunk results back together
108 | all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
109 | return all_ret
110 |
111 |
112 | # this is where the rays actually go through the neural network
113 | def render_rays(ray_batch,
114 | network_fn,
115 | network_query_fn,
116 | N_samples,
117 | retraw=False,
118 | lindisp=False,
119 | perturb=0.,
120 | N_importance=0,
121 | network_fine=None,
122 | white_bkgd=False,
123 | raw_noise_std=0.,
124 | verbose=False,
125 | pytest=False):
126 | """Volumetric rendering.
127 | Args:
128 | ray_batch: array of shape [batch_size, ...]. All information necessary
129 | for sampling along a ray, including: ray origin, ray direction, min
130 | dist, max dist, and unit-magnitude viewing direction.
131 | (the coarse network)
132 | network_fn: function. Model for predicting RGB and density at each point
133 | in space.
134 | network_query_fn: function used for passing queries to network_fn.
135 | N_samples: int. Number of different times to sample along each ray.
136 |
137 | "raw" refers to the network's unprocessed output
138 | retraw: bool. If True, include model's raw, unprocessed predictions.
139 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
140 |
141 |
142 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified random points in time.
143 |
144 | (number of additional per-ray samples used by the fine network)
145 | N_importance: int. Number of additional times to sample along each ray.
146 | These samples are only passed to network_fine.
147 | (the fine network)
148 | network_fine: "fine" network with same spec as network_fn.
149 | white_bkgd: bool. If True, assume a white background.
150 | raw_noise_std: float. Std. dev. of Gaussian noise added to the raw density predictions (regularization).
151 |
152 |
153 | verbose: bool. If True, print more debugging info.
154 | Returns:
155 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
156 | disp_map: [num_rays]. Disparity map. 1 / depth.
157 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
158 | raw: [num_rays, num_samples, 4]. Raw predictions from model.
159 | rgb0: See rgb_map. Output for coarse model.
160 | disp0: See disp_map. Output for coarse model.
161 | acc0: See acc_map. Output for coarse model.
162 | z_std: [num_rays]. Standard deviation of distances along ray for each
163 | sample.
164 | """
165 | N_rays = ray_batch.shape[0] # N_rand
166 | # ray origins and ray directions
167 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each
168 | # unit view-direction vectors
169 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
170 | bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) # [bs,1,2] near and far
171 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1]
172 | # sample positions along each ray
173 | t_vals = torch.linspace(0., 1., steps=N_samples)
174 | if not lindisp:
175 | z_vals = near * (1. - t_vals) + far * (t_vals) # linear interpolation between near and far
176 | else:
177 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
178 | # [N_rand,64] -> [N_rand,64]
179 | z_vals = z_vals.expand([N_rays, N_samples])
180 |
181 | if perturb > 0.:
182 | # get intervals between samples: midpoints of the 64 sample positions
183 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
184 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
185 | lower = torch.cat([z_vals[..., :1], mids], -1)
186 | # stratified samples in those intervals
187 | t_rand = torch.rand(z_vals.shape)
188 |
189 | # Pytest, overwrite u with numpy's fixed random numbers
190 | if pytest:
191 | np.random.seed(0)
192 | t_rand = np.random.rand(*list(z_vals.shape))
193 | t_rand = torch.Tensor(t_rand)
194 | # [bs,64] jitter each sample uniformly within its interval
195 | z_vals = lower + (upper - lower) * t_rand
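# annotation: this is stratified sampling, so each of the N_samples bins [lower_i, upper_i]
# receives exactly one uniformly drawn sample; the whole [near, far] range stays covered
# while sample positions still vary from iteration to iteration.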
196 |
197 | # 3-D sample points in space
198 | # [N_rand, 64, 3]
199 | # origin + distance * direction
200 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3]
201 |
202 | # query the network; viewdirs [N_rand,3], network_fn is the coarse (or fine) NeRF
203 | # raw [bs,64,4]
204 | raw = network_query_fn(pts, viewdirs, network_fn)
205 |
206 | # rgb map, disparity map, accumulated weights; "weights" is the T_i * alpha_i product from the paper
207 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd,
208 | pytest=pytest)
209 |
210 | # fine-network (hierarchical sampling) stage
211 | if N_importance > 0:
212 | # the _0 variables hold the first-stage (coarse network) results
213 | # they are kept so they can be returned in the output dict
214 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
215 | # midpoints of the coarse sample positions, used as bin edges for importance sampling
216 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
217 | z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
218 | z_samples = z_samples.detach()
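# annotation: sample_pdf performs inverse-transform sampling; the coarse weights define a
# piecewise-constant PDF over the bins bounded by z_vals_mid, and its CDF is inverted at
# N_importance uniform positions (evenly spaced when det=True), so the new samples
# concentrate where the coarse pass found density. detach() keeps gradients from flowing
# into the sample positions.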
219 |
220 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
221 | # points fed to the fine network
222 | # [N_rays, N_samples + N_importance, 3]
223 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
224 |
225 | run_fn = network_fn if network_fine is None else network_fine
226 |
227 | # query the network again
228 | # network_query_fn is the lambda created in create_nerf
229 | # viewdirs are the same as for the coarse network
230 | raw = network_query_fn(pts, viewdirs, run_fn)
231 |
232 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd,
233 | pytest=pytest)
234 |
235 | ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
236 |
237 | if retraw:
238 | # when a fine network is used, this raw is the fine network's output
239 | ret['raw'] = raw
240 |
241 | if N_importance > 0:
242 | # the *0 entries below come from the coarse network
243 | ret['rgb0'] = rgb_map_0
244 | ret['disp0'] = disp_map_0
245 | ret['acc0'] = acc_map_0
246 |
247 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
248 |
249 | # check for numerical problems
250 | for k in ret:
251 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
252 | print(f"! [Numerical Error] {k} contains nan or inf.")
253 |
254 | return ret
255 |
256 |
257 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
258 | """Transforms model's predictions to semantically meaningful values.
259 | Args:
260 | (the model's raw output)
261 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
262 | (the 1-D sample distances along each ray, not yet turned into 3-D points)
263 | z_vals: [num_rays, num_samples along ray]. Integration time.
264 | (the ray directions)
265 | rays_d: [num_rays, 3]. Direction of each ray.
266 | Returns:
267 | (RGB color values)
268 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
269 | (inverse depth)
270 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
271 | (sum of the weights)
272 | acc_map: [num_rays]. Sum of weights along each ray.
273 | (per-sample weights)
274 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
275 | (estimated depth)
276 | depth_map: [num_rays]. Estimated distance to object.
277 | """
278 | # alpha computation
279 | # relu flattens negative densities to 0
280 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)
281 | # [bs,63]
282 | # distances between adjacent sample positions
283 | dists = z_vals[..., 1:] - z_vals[..., :-1]
284 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
285 | # rays_d[...,None,:] [bs,3] -> [bs,1,3]
286 | # scale the parametric (z) spacing by the ray-direction norm to get real-world distances
287 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
288 | # RGB is passed through a sigmoid
289 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3]
290 | noise = 0.
291 | if raw_noise_std > 0.:
292 | noise = torch.randn(raw[..., 3].shape) * raw_noise_std
293 |
294 | # Overwrite randomly sampled data if pytest
295 | if pytest:
296 | np.random.seed(0)
297 | noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
298 | noise = torch.Tensor(noise)
299 | # Eq. 3 of the paper: alpha_i = 1 - exp(-sigma_i * delta_i) [bs, 64]
300 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples]
301 |
302 | # the cumprod part is T_i and the leading factor is alpha_i; together this is the weight w_i from the paper [bs,64]
303 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)),
304 | 1. - alpha + 1e-10], -1),
305 | -1)[:, :-1]
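# annotation: discrete volume rendering from the paper,
#   alpha_i = 1 - exp(-sigma_i * delta_i)
#   T_i     = prod_{j < i} (1 - alpha_j)   (the cumprod above; the prepended ones make T_1 = 1)
#   w_i     = T_i * alpha_i
# the 1e-10 keeps the cumulative product away from exact zeros.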
306 | # [bs, 64,1] * [bs,64,3]
307 | # sum over the 64 samples (dim -2) -> [bs,3]
308 | # the rendered color of Eq. 3
309 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3]
310 | # [bs]
311 | # depth map
312 | # Estimated depth map is expected distance.
313 | depth_map = torch.sum(weights * z_vals, -1)
314 | # disparity map
315 | # Disparity map is inverse depth.
316 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
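# annotation: depth_map is the expected termination depth sum_i w_i * z_i; the disparity divides
# it by the accumulated weight so partially transparent rays are not biased, and the max() with
# 1e-10 guards against division by zero.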
317 |
318 | # accumulated weight along each ray
319 | # returned for logging and used for the white-background compositing below
320 | acc_map = torch.sum(weights, -1)
321 |
322 | if white_bkgd:
323 | rgb_map = rgb_map + (1. - acc_map[..., None])
324 |
325 | return rgb_map, disp_map, acc_map, weights, depth_map
326 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio
2 | imageio-ffmpeg
3 | matplotlib
4 | configargparse
5 | tensorboard>=2.0
6 | tqdm
7 | opencv-python
8 |
--------------------------------------------------------------------------------
/run_nerf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import time
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm, trange
7 | from nerf_helpers import *
8 | from nerf_model import NeRF
9 | from load_llff import load_llff_data
10 | from load_deepvoxels import load_dv_data
11 | from load_blender import load_blender_data
12 | from load_LINEMOD import load_LINEMOD_data
13 | from render import *
14 | from inference import render_path
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 | np.random.seed(0)
18 | DEBUG = False
19 |
20 |
21 | def batchify(fn, chunk):
22 | """
23 | Constructs a version of 'fn' that applies to smaller batches.
24 | """
25 | if chunk is None:
26 | return fn
27 |
28 | def ret(inputs):
29 | # feed the network chunk by chunk to avoid exhausting GPU memory, then concatenate the results
30 | return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
31 |
32 | return ret
33 |
34 |
35 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64):
36 | """
37 | Wrapped into the network_query_fn lambda by create_nerf below.
38 | Prepares inputs and applies network 'fn'.
39 | inputs: pts, the points along the rays, e.g. [1024,64,3]: 1024 rays with 64 samples each
40 | viewdirs: per-ray viewing directions
41 | fn: the NeRF model (coarse or fine network)
42 | embed_fn: positional encoding applied to the 3-D points
43 | embeddirs_fn: positional encoding applied to the view directions
44 | netchunk: number of points pushed through the network at once
45 | """
46 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) # [N_rand*64,3]
47 | # positional-encode the points [N_rand*64,63]
48 | embedded = embed_fn(inputs_flat)
49 |
50 | if viewdirs is not None:
51 | input_dirs = viewdirs[:, None].expand(inputs.shape)
52 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
53 | # positional-encode the directions
54 | embedded_dirs = embeddirs_fn(input_dirs_flat) # [N_rand*64,27]
55 | embedded = torch.cat([embedded, embedded_dirs], -1)
56 |
57 | # pass through the network in chunks [bs*64,4]
58 | outputs_flat = batchify(fn, netchunk)(embedded)
59 | # [bs*64,4] -> [bs,64,4]
60 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
61 | return outputs
62 |
63 |
64 | def create_nerf(args):
65 | """Instantiate NeRF's MLP model.
66 | """
67 |
68 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
69 |
70 | input_ch_views = 0
71 | embeddirs_fn = None
72 | if args.use_viewdirs:
73 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
74 |
75 | # for output_ch=5 to take effect, use_viewdirs must be False and N_importance > 0
76 | output_ch = 5 if args.N_importance > 0 else 4
77 | skips = [4]
78 | # coarse network
79 | model = NeRF(D=args.netdepth, W=args.netwidth,
80 | input_ch=input_ch, output_ch=output_ch, skips=skips,
81 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
82 | grad_vars = list(model.parameters())
83 |
84 | model_fine = None
85 | if args.N_importance > 0:
86 | # fine network
87 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
88 | input_ch=input_ch, output_ch=output_ch, skips=skips,
89 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
90 | # add the fine model's parameters to the trainable variables
91 | grad_vars += list(model_fine.parameters())
92 |
93 | # netchunk is the batch size of points processed by the network at a time
94 | network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn,
95 | embed_fn=embed_fn,
96 | embeddirs_fn=embeddirs_fn,
97 | netchunk=args.netchunk)
98 |
99 | # Create optimizer
100 | # Adam optimizer
101 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
102 |
103 | start = 0
104 | basedir = args.basedir
105 | expname = args.expname
106 |
107 | ##########################
108 |
109 | # Load checkpoints
110 | if args.ft_path is not None and args.ft_path != 'None':
111 | ckpts = [args.ft_path]
112 | else:
113 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if
114 | 'tar' in f]
115 |
116 | print('Found ckpts', ckpts)
117 |
118 | # load the checkpoint parameters
119 | if len(ckpts) > 0 and not args.no_reload:
120 | ckpt_path = ckpts[-1]
121 | print('Reloading from', ckpt_path)
122 | ckpt = torch.load(ckpt_path)
123 |
124 | start = ckpt['global_step']
125 | optimizer.load_state_dict(ckpt['optimizer_state_dict'])
126 |
127 | # Load model
128 | model.load_state_dict(ckpt['network_fn_state_dict'])
129 | if model_fine is not None:
130 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
131 |
132 | ##########################
133 |
134 | render_kwargs_train = {
135 | 'network_query_fn': network_query_fn,
136 | 'perturb': args.perturb,
137 | 'N_importance': args.N_importance,
138 | # fine network
139 | 'network_fine': model_fine,
140 | 'N_samples': args.N_samples,
141 | # coarse network
142 | 'network_fn': model,
143 | 'use_viewdirs': args.use_viewdirs,
144 | 'white_bkgd': args.white_bkgd,
145 | 'raw_noise_std': args.raw_noise_std,
146 | }
147 |
148 | print(model_fine)
149 |
150 | # NDC only good for LLFF-style forward facing data
151 | if args.dataset_type != 'llff' or args.no_ndc:
152 | print('Not ndc!')
153 | render_kwargs_train['ndc'] = False
154 | render_kwargs_train['lindisp'] = args.lindisp
155 |
156 | render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train}
157 | render_kwargs_test['perturb'] = False
158 | render_kwargs_test['raw_noise_std'] = 0.
159 |
160 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
161 |
162 |
163 | # ----------------------------------------------------------------------------------------------------------------------
164 |
165 | def create_log_files(basedir, expname, args):
166 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
167 |
168 | # save a copy of the parsed arguments
169 | f = os.path.join(basedir, expname, 'args.txt')
170 | with open(f, 'w') as file:
171 | for arg in sorted(vars(args)):
172 | attr = getattr(args, arg)
173 | file.write('{} = {}\n'.format(arg, attr))
174 |
175 | # save a copy of the config file
176 | if args.config is not None:
177 | f = os.path.join(basedir, expname, 'config.txt')
178 | with open(f, 'w') as file:
179 | file.write(open(args.config, 'r').read())
180 |
181 | return basedir, expname
182 |
183 |
184 | def run_render_only(args, images, i_test, basedir, expname, render_poses, hwf, K, render_kwargs_test, start):
185 | with torch.no_grad():
186 | if args.render_test:
187 | # render_test switches to test poses
188 | images = images[i_test]
189 | else:
190 | # Default is smoother render_poses path
191 | images = None
192 |
193 | testsavedir = os.path.join(basedir, expname,
194 | 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
195 | os.makedirs(testsavedir, exist_ok=True)
196 | print('test poses shape', render_poses.shape)
197 |
198 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images,
199 | savedir=testsavedir, render_factor=args.render_factor)
200 | print('Done rendering', testsavedir)
201 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
202 |
203 |
204 | def train():
205 | # parse arguments
206 | from opts import config_parser
207 | parser = config_parser()
208 | args = parser.parse_args()
209 |
210 | # --------------------------------------------------------------------------------------------------------
211 |
212 | # Load data
213 |
214 | # special-cased for the LINEMOD dataset (its loader returns K directly)
215 | K = None
216 |
217 | # four dataset types are supported
218 | # the configs directory only ships llff and blender configs
219 | # the original nerf repository also provides deepvoxels data
220 | # LINEMOD data is not covered by the shipped configs
221 |
222 | # llff Local Light Field Fusion
223 | if args.dataset_type == 'llff':
224 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
225 | recenter=True, bd_factor=.75,
226 | spherify=args.spherify)
227 | hwf = poses[0, :3, -1]
228 | poses = poses[:, :3, :4]
229 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
230 | if not isinstance(i_test, list):
231 | i_test = [i_test]
232 |
233 | if args.llffhold > 0:
234 | print('Auto LLFF holdout,', args.llffhold)
235 | i_test = np.arange(images.shape[0])[::args.llffhold]
236 |
237 | i_val = i_test
238 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
239 | (i not in i_test and i not in i_val)])
240 |
241 | print('DEFINING BOUNDS')
242 | if args.no_ndc:
243 | near = np.ndarray.min(bds) * .9
244 | far = np.ndarray.max(bds) * 1.
245 |
246 | else:
247 | near = 0.
248 | far = 1.
249 | print('NEAR FAR', near, far)
250 |
251 | elif args.dataset_type == 'blender':
252 | # images holds all images (train, val and test together); poses likewise
253 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
254 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
255 | i_train, i_val, i_test = i_split
256 |
257 | near = 2.
258 | far = 6.
259 |
260 | if args.white_bkgd:
261 | # alpha-composite the RGBA images onto a white background: rgb*alpha + (1-alpha)
262 | images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
263 | else:
264 | images = images[..., :3]
265 |
266 | elif args.dataset_type == 'LINEMOD':
267 | # this dataset type does not exist in the original nerf repository
268 |
269 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res,
270 | args.testskip)
271 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
272 | print(f'[CHECK HERE] near: {near}, far: {far}.')
273 | i_train, i_val, i_test = i_split
274 |
275 | if args.white_bkgd:
276 | images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
277 | else:
278 | images = images[..., :3]
279 |
280 | elif args.dataset_type == 'deepvoxels':
281 |
282 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
283 | basedir=args.datadir,
284 | testskip=args.testskip)
285 |
286 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
287 | i_train, i_val, i_test = i_split
288 |
289 | hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
290 | near = hemi_R - 1.
291 | far = hemi_R + 1.
292 |
293 | else:
294 | print('Unknown dataset type', args.dataset_type, 'exiting')
295 | return
296 |
297 | # Cast intrinsics to right types
298 | H, W, focal = hwf
299 | H, W = int(H), int(W)
300 | hwf = [H, W, focal]
301 |
302 | # K: camera intrinsics; focal is the focal length, (0.5*W, 0.5*H) the principal point
303 | # this matrix maps camera coordinates to image (pixel) coordinates
304 | if K is None:
305 | K = np.array([
306 | [focal, 0, 0.5 * W],
307 | [0, focal, 0.5 * H],
308 | [0, 0, 1]
309 | ])
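# annotation: standard pinhole model; a camera-space point (X, Y, Z) projects to pixel
# coordinates u = focal * X / Z + 0.5*W, v = focal * Y / Z + 0.5*H (up to the axis
# conventions used by get_rays).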
310 |
311 | # --------------------------------------------------------------------------------------------------------
312 |
313 | # render the test set instead of render_poses path
314 | # use the test-set poses instead of the generated render_poses path
315 | if args.render_test:
316 | render_poses = np.array(poses[i_test])
317 |
318 | # Move testing data to GPU
319 | render_poses = torch.Tensor(render_poses).to(device)
320 |
321 | # --------------------------------------------------------------------------------------------------------
322 |
323 | # Create log dir and copy the config file
324 |
325 | basedir = args.basedir
326 | expname = args.expname
327 |
328 | create_log_files(basedir, expname, args)
329 |
330 | # --------------------------------------------------------------------------------------------------------
331 |
332 | # Create nerf model
333 | # build the model (coarse + fine) and the render kwargs
334 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
335 | # training may resume from a saved iteration
336 | global_step = start
337 |
338 | bds_dict = {
339 | 'near': near,
340 | 'far': far,
341 | }
342 | render_kwargs_train.update(bds_dict)
343 | render_kwargs_test.update(bds_dict)
344 |
345 | # --------------------------------------------------------------------------------------------------------
346 |
347 | # Short circuit if only rendering out from trained model
348 | # this path uses render_poses
349 | if args.render_only:
350 | # only render, no training
351 | print('RENDER ONLY')
352 | run_render_only(args, images, i_test, basedir, expname, render_poses, hwf, K, render_kwargs_test, start)
353 | return
354 |
355 | # --------------------------------------------------------------------------------------------------------
356 |
357 | # Prepare ray batch tensor if batching random rays
358 | N_rand = args.N_rand
359 |
360 | use_batching = not args.no_batching
361 |
362 | if use_batching:
363 | # For random ray batching
364 | print('get rays') # (img_count,2,400,400,3), the 2 being rays_o and rays_d
365 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3]
366 | print('done, concats') # rays and image colors are packed together
367 | rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3]
368 | rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3]
369 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
370 | rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3]
371 | rays_rgb = rays_rgb.astype(np.float32)
372 | print('shuffle rays')
373 | np.random.shuffle(rays_rgb) # shuffle the rays
374 |
375 | print('done')
376 | i_batch = 0
377 |
378 | # move everything to the GPU at the same time
379 | # Move training data to GPU
380 | if use_batching:
381 | images = torch.Tensor(images).to(device)
382 | rays_rgb = torch.Tensor(rays_rgb).to(device)
383 |
384 | poses = torch.Tensor(poses).to(device)
385 |
386 | # --------------------------------------------------------------------------------------------------------
387 |
388 | print('Begin')
389 | print('TRAIN views are', i_train)
390 | print('TEST views are', i_test)
391 | print('VAL views are', i_val)
392 |
393 | # training loop
394 | # 200,000 iterations
395 | # the +1 keeps checkpoint filenames at round numbers like 200000 rather than 199999
396 | N_iters = 200000 + 1
397 | start = start + 1
398 | for i in trange(start, N_iters):
399 | time0 = time.time()
400 |
401 | # Sample random ray batch
402 | if use_batching:
403 | # Random over all images
404 | # one batch of rays
405 | batch = rays_rgb[i_batch:i_batch + N_rand] # [B, 2+1, 3*?]
406 |
407 | batch = torch.transpose(batch, 0, 1)
408 | batch_rays, target_s = batch[:2], batch[2] # the first two are rays_o and rays_d; the third is the target image rgb
409 |
410 | i_batch += N_rand
411 | if i_batch >= rays_rgb.shape[0]:
412 | # reshuffle once every ray has been used
413 | print("Shuffle data after an epoch!")
414 | rand_idx = torch.randperm(rays_rgb.shape[0])
415 | rays_rgb = rays_rgb[rand_idx]
416 | i_batch = 0
417 |
418 | else:
419 | # Random from one image
420 | img_i = np.random.choice(i_train)
421 | target = images[img_i] # [400,400,3] image pixels
422 | target = torch.Tensor(target).to(device)
423 | pose = poses[img_i, :3, :4]
424 |
425 | if N_rand is not None:
426 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
427 |
428 | # precrop_iters: number of steps to train on central crops
429 | if i < args.precrop_iters:
430 | dH = int(H // 2 * args.precrop_frac)
431 | dW = int(W // 2 * args.precrop_frac)
432 | coords = torch.stack(
433 | torch.meshgrid(
434 | torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
435 | torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW), indexing='ij',
436 | ), -1)
437 | if i == start:
438 | print(
439 | f"[Config] Center cropping of size {2 * dH} x {2 * dW} is enabled until iter {args.precrop_iters}")
440 | else:
441 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H),
442 | torch.linspace(0, W - 1, W), indexing='ij'),
443 | -1) # (H, W, 2)
444 |
445 | coords = torch.reshape(coords, [-1, 2]) # (H * W, 2)
446 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
447 | # selected pixel coordinates
448 | select_coords = coords[select_inds].long() # (N_rand, 2)
449 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
450 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
451 | batch_rays = torch.stack([rays_o, rays_d], 0) # stack origins and directions
452 | # pick the same pixel locations from the target image
453 | # target_s is used for the final MSE loss
454 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
455 |
456 | ##### Core optimization loop #####
457 | # rgb: the image predicted by the network
458 | # the first three outputs come from the fine network; the other five entries stay in the extras dict
459 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
460 | verbose=i < 10, retraw=True,
461 | **render_kwargs_train)
462 |
463 | optimizer.zero_grad()
464 | # compute the loss
465 | img_loss = img2mse(rgb, target_s)
466 | loss = img_loss
467 | # compute the PSNR metric
468 | psnr = mse2psnr(img_loss)
469 |
470 | # rgb0 is the coarse network's output
471 | if 'rgb0' in extras:
472 | img_loss0 = img2mse(extras['rgb0'], target_s)
473 | loss = loss + img_loss0
474 | psnr0 = mse2psnr(img_loss0)
475 |
476 | loss.backward()
477 | optimizer.step()
478 |
479 | # NOTE: IMPORTANT!
480 | ### update learning rate ###
481 | # exponential learning-rate decay: the lr shrinks by 10x every lrate_decay*1000 steps
482 | decay_rate = 0.1
483 | decay_steps = args.lrate_decay * 1000
484 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
485 | for param_group in optimizer.param_groups:
486 | param_group['lr'] = new_lrate
487 | ################################
488 |
489 | # save a checkpoint
490 | if i % args.i_weights == 0:
491 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
492 | torch.save({
493 | # number of completed iterations
494 | 'global_step': global_step,
495 | # coarse network weights
496 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
497 | # fine network weights
498 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
499 | # optimizer state
500 | 'optimizer_state_dict': optimizer.state_dict(),
501 | }, path)
502 | print('Saved checkpoints at', path)
503 |
504 | # generate a test video along render_poses (not the same as the test data)
505 | if i % args.i_video == 0 and i > 0:
506 | # Turn on testing mode
507 | with torch.no_grad():
508 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
509 | print('Done, saving', rgbs.shape, disps.shape)
510 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
511 | # video orbiting the scene
512 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
513 |
514 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
515 |
516 | # evaluate on the test poses
517 | if i % args.i_testset == 0 and i > 0:
518 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
519 | os.makedirs(testsavedir, exist_ok=True)
520 | print('test poses shape', poses[i_test].shape)
521 | with torch.no_grad():
522 | render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test,
523 | gt_imgs=images[i_test], savedir=testsavedir)
524 | print('Saved test set')
525 |
526 | # elapsed time for this iteration
527 | dt = time.time() - time0
528 | # console logging frequency
529 | if i % args.i_print == 0:
530 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()} Time: {dt}")
531 |
532 | global_step += 1
533 |
534 |
535 | if __name__ == '__main__':
536 | if torch.cuda.is_available():
537 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
538 | train()
539 |
--------------------------------------------------------------------------------
/test/frame.py:
--------------------------------------------------------------------------------
1 | import imageio
2 |
3 | t = imageio.mimread('blender_paper_lego_spiral_200000_rgb.mp4')
4 |
5 | print(len(t))
6 |
--------------------------------------------------------------------------------
/模型.txt:
--------------------------------------------------------------------------------
1 |
2 | pts_linears
3 |
4 | 256+63=319
5 |
6 | ModuleList(
7 | (0): Linear(in_features=63, out_features=256, bias=True)
8 | (1): Linear(in_features=256, out_features=256, bias=True)
9 | (2): Linear(in_features=256, out_features=256, bias=True)
10 | (3): Linear(in_features=256, out_features=256, bias=True)
11 | (4): Linear(in_features=256, out_features=256, bias=True)
12 | (5): Linear(in_features=319, out_features=256, bias=True)
13 | (6): Linear(in_features=256, out_features=256, bias=True)
14 | (7): Linear(in_features=256, out_features=256, bias=True)
15 | )
16 |
17 | eight layers in total
18 |
19 | views_linears
20 |
21 | 256+27=283
22 |
23 | ModuleList(
24 | (0): Linear(in_features=283, out_features=128, bias=True)
25 | )
26 |
27 | alpha_linear
28 | Linear(in_features=256, out_features=1, bias=True)
29 |
30 |
31 | feature_linear
32 | Linear(in_features=256, out_features=256, bias=True)
33 |
34 |
35 | rgb_linear
36 | Linear(in_features=128, out_features=3, bias=True)
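
A minimal sketch (not the repository's nerf_model.py, just an illustration of how the layers
listed above are wired, assuming skips=[4] and use_viewdirs=True):

import torch
import torch.nn as nn
import torch.nn.functional as F

class NeRFSketch(nn.Module):
    def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, skips=(4,)):
        super().__init__()
        self.skips = skips
        # layer index 5 takes 256 + 63 = 319 features because of the skip connection
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] +
            [nn.Linear(W + input_ch, W) if i in skips else nn.Linear(W, W)
             for i in range(D - 1)])
        self.views_linears = nn.ModuleList([nn.Linear(W + input_ch_views, W // 2)])
        self.feature_linear = nn.Linear(W, W)
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, input_pts, input_views):
        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)   # 256 + 63 = 319
        alpha = self.alpha_linear(h)                # density
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)   # 256 + 27 = 283
        h = F.relu(self.views_linears[0](h))        # -> 128
        rgb = self.rgb_linear(h)
        return torch.cat([rgb, alpha], -1)          # [N, 4]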
--------------------------------------------------------------------------------
/源码结构.md:
--------------------------------------------------------------------------------
1 | # NeRF source code
2 |
3 | ## Core files
4 |
5 | 1. [opts](opts.py)
6 | 2. [load_blender_data](load_blender.py)
7 | 3. [run_nerf](run_nerf.py)
8 | 4. [nerf_helpers](nerf_helpers.py)
9 | 5. [nerf_model](nerf_model.py)
10 | 6. [render](render.py)
11 | 7. [inference](inference.py)
12 |
13 | ## Simplified flow
14 |
15 | 1. [load_blender_data](load_blender.py)
16 | 1. [pose_spherical](load_blender.py)
17 | 2. [create_nerf](run_nerf.py)
18 | 1. get_embedder
19 | 1. create Embedder
20 | 2. defines the network_query_fn lambda that is called during training
21 | 3. [in train iteration](run_nerf.py):
22 | 1. use_batching or not
23 | 1. get_rays
24 | 2. render (the training-related code starts here)
25 | 1. rays_o, rays_d = get_rays(H, W, K, c2w)
26 | 2. [batchify_rays](render.py) process the rays in chunks
27 | 1. render_rays
28 | 1. preparation
29 | 1. unpack rays_o, rays_d, viewdirs, near, far
30 | 2. build the sample positions and add stratified random jitter
31 | 2. network_query_fn (pts, viewdirs, network_fn), the lambda defined in create_nerf
32 | 1. run_network
33 | 1. positional encoding of xyz (see the sketch after this list)
34 | 2. positional encoding of viewdirs
35 | 3. batchify; the fn called here is the NeRF model
36 | 1. split pts and viewdirs (63 and 27 dims)
37 | 2. pts pass through 8 Linear layers
38 | 3. the 8-layer output goes through one Linear to produce alpha
39 | 4. the same output also goes through a feature Linear
40 | 5. the feature is concatenated with input_views and passed through another Linear
41 | 6. one final Linear gives RGB
42 | 3. [raw2outputs](render.py) volume rendering happens here
43 | 4. sample_pdf(z_vals_mid, weights, N_importance) the sampling scheme used for the fine network
44 | 5. network_query_fn (pts, viewdirs, network_fn) the second call, now with the fine network
45 | 6. raw2outputs volume rendering again
46 | 3. [img2mse](nerf_helpers.py)
47 | 4. [mse2psnr](nerf_helpers.py)
48 | 5. adjust the learning rate
49 | 6. periodically save checkpoints, render a test video and render the test set
50 |
51 |
52 |
53 |
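
## Positional encoding (sketch)

A minimal, illustrative version of the frequency encoding applied by get_embedder (the "xyz pe /
viewdirs pe" steps above); the real implementation lives in [nerf_helpers](nerf_helpers.py).
Assuming the default multires=10 and multires_views=4, this reproduces the 63- and 27-dimensional
inputs seen in the model.

```python
import torch

def positional_encoding(x, num_freqs, include_input=True):
    # gamma(p) = (p, sin(2^0 p), cos(2^0 p), ..., sin(2^(L-1) p), cos(2^(L-1) p))
    out = [x] if include_input else []
    for k in range(num_freqs):
        freq = 2.0 ** k
        out.append(torch.sin(freq * x))
        out.append(torch.cos(freq * x))
    return torch.cat(out, dim=-1)

pts = torch.rand(1024, 3)
dirs = torch.rand(1024, 3)
print(positional_encoding(pts, 10).shape)   # torch.Size([1024, 63]) = 3 + 3*2*10
print(positional_encoding(dirs, 4).shape)   # torch.Size([1024, 27]) = 3 + 3*2*4
```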
--------------------------------------------------------------------------------
/说明.md:
--------------------------------------------------------------------------------
1 | # Notes
2 |
3 | Each entry in the transforms file contains the following (see the sketch at the end of this file):
4 |
5 | 
6 |
7 |
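In case the image above does not render, a typical frame entry (presumably from a
Blender-synthetic transforms_*.json read by load_blender.py) looks roughly like the sketch
below; the values are illustrative only, not taken from a real file. The file-level
camera_angle_x together with the image width gives the focal length:
focal = 0.5 * W / tan(0.5 * camera_angle_x).

```python
frame = {
    "file_path": "./train/r_0",      # image path without extension
    "rotation": 0.012566,            # orbit step; typically not used by the loader
    "transform_matrix": [            # 4x4 camera-to-world pose
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 4.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
}
```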
--------------------------------------------------------------------------------