├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── chair.txt
│   ├── drums.txt
│   ├── fern.txt
│   ├── ficus.txt
│   ├── flower.txt
│   ├── fortress.txt
│   ├── horns.txt
│   ├── hotdog.txt
│   ├── leaves.txt
│   ├── lego.txt
│   ├── materials.txt
│   ├── mic.txt
│   ├── orchids.txt
│   ├── room.txt
│   ├── ship.txt
│   └── trex.txt
├── download_example_data.sh
├── imgs
│   └── pipeline.jpg
├── nerf
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── load_LINEMOD.py
│   │   ├── load_blender.py
│   │   ├── load_deepvoxels.py
│   │   └── load_llff.py
│   ├── engine
│   │   ├── __init__.py
│   │   └── trainer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── embed.py
│   │   └── nerf.py
│   └── utils
│       ├── __init__.py
│       ├── helpers.py
│       └── render.py
├── requirements.txt
├── run.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Megvii
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeRF-megengine
2 |
3 | This is an implementation of [NeRF](https://arxiv.org/abs/2003.08934) for [MegEngine](https://github.com/MegEngine/MegEngine).
4 |
5 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
6 |
7 | 
8 | 
9 |
10 | This project is a faithful MegEngine implementation of NeRF. The code is based on the authors' TensorFlow implementation [here](https://github.com/bmild/nerf) and the PyTorch implementation [here](https://github.com/yenchenlin/nerf-pytorch) by [Yen-Chen Lin](https://github.com/yenchenlin).
11 |
12 | If you run into any problems while using this repo, please feel free to contact [FateScript](https://github.com/FateScript) @ [Megvii](https://megvii.com/).
13 |
14 | ## Installation
15 |
16 | It's recommended to use a virtual environment (venv) to train your NeRF model.
17 |
18 | A simple venv setup example:
19 | ```bash
20 | python3 -m venv nerf_mge
21 | source nerf_mge/bin/activate
22 | ```
23 |
24 | ```
25 | git clone https://github.com/MegEngine/NeRF.git
26 | cd NeRF
27 | python3 -m pip install -v -e .
28 | ```
29 | The LLFF data loader requires ImageMagick.
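
On Debian/Ubuntu, for example, it can be installed with:

```bash
sudo apt-get install imagemagick
```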
30 |
31 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
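
For example, once LLFF is cloned and your images sit in `<your_scenedir>/images`, LLFF's `imgs2poses.py` computes the poses:

```bash
python imgs2poses.py <your_scenedir>
```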
32 |
33 | ## How To Run?
34 |
35 | ### Quick Start
36 |
37 | Download data for two example datasets: `lego` and `fern`
38 | ```
39 | bash download_example_data.sh
40 | ```
41 |
42 | To train a low-res `lego` NeRF:
43 | ```
44 | python3 run.py --config configs/lego.txt
45 | ```
46 | After training for 100k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/blender_paper_lego/blender_paper_lego_spiral_100000_rgb.mp4` (the path follows the `expname` set in `configs/lego.txt`).
47 |
48 | 
49 |
50 | ---
51 |
52 | To train a low-res `fern` NeRF:
53 | ```
54 | python run.py --config configs/fern.txt
55 | ```
56 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.
57 |
58 | 
59 |
60 | ---
61 |
62 | ### More Datasets
63 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
64 | ```
65 | ├── configs
66 | │   ├── ...
67 | │
68 | ├── data
69 | │   ├── nerf_llff_data
70 | │   │   ├── fern
71 | │   │   ├── flower      # downloaded llff dataset
72 | │   │   ├── horns       # downloaded llff dataset
73 | │   │   └── ...
74 | │   └── nerf_synthetic
75 | │       ├── lego
76 | │       ├── ship        # downloaded synthetic dataset
77 | │       └── ...
78 | ```
79 |
80 | ---
81 |
82 | To train NeRF on different datasets:
83 |
84 | ```
85 | python run.py --config configs/{DATASET}.txt
86 | ```
87 |
88 | Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
89 |
90 | ---
91 |
92 | To test NeRF trained on different datasets:
93 |
94 | ```
95 | python run.py --config configs/{DATASET}.txt --render_only
96 | ```
97 |
98 | Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
99 |
100 |
101 | ### Pre-trained Models
102 |
103 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
104 |
105 | ```
106 | ├── logs
107 | │   ├── fern_test
108 | │   ├── flower_test   # downloaded logs
109 | │   └── trex_test     # downloaded logs
110 | ```
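
For example, to render the test views with the downloaded `fern_test` model:

```
python run.py --config configs/fern.txt --render_only
```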
111 |
112 | ## Method
113 |
114 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
115 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
116 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
117 | [Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
118 | [Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
119 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
120 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup>
121 | <sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
122 | \*denotes equal contribution
123 |
124 | <img src='imgs/pipeline.jpg'/>
125 |
126 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views
127 |
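The compositing step described above can be sketched in a few lines. The snippet below is a minimal NumPy illustration of volume rendering along a single ray (the repo's actual implementation lives in `nerf/utils/render.py`; the function and argument names here are illustrative only):

```python
import numpy as np

def composite_ray(rgb, sigma, deltas):
    """Alpha-composite per-sample colors and densities along one ray (Eq. 3 of the paper).

    rgb:    (N_samples, 3) colors predicted by the MLP
    sigma:  (N_samples,)   densities predicted by the MLP
    deltas: (N_samples,)   distances between adjacent samples
    """
    alpha = 1.0 - np.exp(-sigma * deltas)        # opacity contributed by each segment
    trans = np.cumprod(1.0 - alpha + 1e-10)      # accumulated transmittance
    trans = np.concatenate([[1.0], trans[:-1]])  # light reaches sample i through samples < i
    weights = alpha * trans                      # per-sample contribution to the pixel
    return (weights[:, None] * rgb).sum(axis=0)  # final pixel color
```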
128 |
129 | ## Citation
130 | Kudos to the authors for their amazing results:
131 | ```
132 | @misc{mildenhall2020nerf,
133 | title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
134 | author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
135 | year={2020},
136 | eprint={2003.08934},
137 | archivePrefix={arXiv},
138 | primaryClass={cs.CV}
139 | }
140 | ```
141 | ## Acknowledgements
142 | Some code in this repo is based on the following repos:
143 | * https://github.com/bmild/nerf MIT License
144 | * https://github.com/yenchenlin/nerf-pytorch MIT License
145 |
146 |
147 | ## Known issues
148 | * Some operators are not yet well supported in MegEngine; this repo works around that by implementing them in Python (see the sketch below), which may make NeRF-megengine run slower.
149 |
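As an illustration of that workaround style, a `meshgrid` (which the trainer imports from `nerf.utils`) can be emulated with broadcasting primitives. This is only a sketch of the idea, not necessarily the exact code in `nerf/utils`:

```python
import megengine.functional as F

def meshgrid(x, y, indexing="ij"):
    """Emulate a meshgrid kernel with expand_dims + broadcast_to."""
    assert indexing == "ij"
    nx, ny = x.shape[0], y.shape[0]
    grid_x = F.broadcast_to(F.expand_dims(x, 1), (nx, ny))  # rows vary with x
    grid_y = F.broadcast_to(F.expand_dims(y, 0), (nx, ny))  # columns vary with y
    return grid_x, grid_y
```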
--------------------------------------------------------------------------------
/configs/chair.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_chair
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/chair
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/drums.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_drums
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/drums
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/fern.txt:
--------------------------------------------------------------------------------
1 | expname = fern_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fern
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/ficus.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ficus
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/ficus
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/flower.txt:
--------------------------------------------------------------------------------
1 | expname = flower_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/flower
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/fortress.txt:
--------------------------------------------------------------------------------
1 | expname = fortress_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fortress
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/horns.txt:
--------------------------------------------------------------------------------
1 | expname = horns_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/horns
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/hotdog.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_hotdog
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/hotdog
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/leaves.txt:
--------------------------------------------------------------------------------
1 | expname = leaves_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/leaves
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/lego.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_lego
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/lego
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/materials.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_materials
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/materials
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/mic.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_mic
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/mic
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/orchids.txt:
--------------------------------------------------------------------------------
1 | expname = orchids_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/orchids
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/room.txt:
--------------------------------------------------------------------------------
1 | expname = room_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/room
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/configs/ship.txt:
--------------------------------------------------------------------------------
1 | expname = blender_paper_ship
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/ship
4 | dataset_type = blender
5 |
6 | no_batching = True
7 |
8 | use_viewdirs = True
9 | white_bkgd = True
10 | lrate_decay = 500
11 |
12 | N_samples = 64
13 | N_importance = 128
14 | N_rand = 1024
15 |
16 | precrop_iters = 500
17 | precrop_frac = 0.5
18 |
19 | half_res = True
20 |
--------------------------------------------------------------------------------
/configs/trex.txt:
--------------------------------------------------------------------------------
1 | expname = trex_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/trex
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/download_example_data.sh:
--------------------------------------------------------------------------------
1 | wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
2 | mkdir -p data
3 | cd data
4 | wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/nerf_example_data.zip
5 | unzip nerf_example_data.zip
6 | cd ..
7 |
--------------------------------------------------------------------------------
/imgs/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MegEngine/NeRF/b60724ed7ee859a20ed545198e97cc7467fb5829/imgs/pipeline.jpg
--------------------------------------------------------------------------------
/nerf/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | __version__ = "0.1.0"
5 |
--------------------------------------------------------------------------------
/nerf/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | from .build import build_loader
5 |
--------------------------------------------------------------------------------
/nerf/data/build.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | import numpy as np
4 | from loguru import logger
5 |
6 | from .load_llff import load_llff_data
7 | from .load_deepvoxels import load_dv_data
8 | from .load_blender import load_blender_data
9 | from .load_LINEMOD import load_LINEMOD_data
10 |
11 |
12 | def build_loader(dataset_type, args):
13 | K = None
14 | if dataset_type == "llff":
15 | images, poses, bds, render_poses, i_test = load_llff_data(
16 | args.datadir,
17 | args.factor,
18 | recenter=True,
19 | bd_factor=0.75,
20 | spherify=args.spherify,
21 | )
22 | hwf = poses[0, :3, -1]
23 | poses = poses[:, :3, :4]
24 | logger.info(f"Loaded llff {images.shape} {render_poses.shape} {hwf} {args.datadir}")
25 | num_images = images.shape[0]
26 | if not isinstance(i_test, list):
27 | i_test = [i_test]
28 |
29 | if args.llffhold > 0:
30 | logger.info(f"Auto LLFF holdout, {args.llffhold}")
31 | i_test = np.arange(num_images)[:: args.llffhold]
32 |
33 | i_val = i_test
34 | i_train = np.array(
35 | [i for i in range(num_images) if (i not in i_test and i not in i_val)]
36 | )
37 | i_split = (i_train, i_val, i_test)
38 |
39 | logger.info("DEFINING BOUNDS")
40 | if args.no_ndc:
41 |             near, far = bds.min() * 0.9, bds.max() * 1.0
42 | else:
43 | near, far = 0.0, 1.0
44 | logger.info(f"NEAR: {near} FAR: {far}")
45 |
46 | elif dataset_type == "blender":
47 | images, poses, render_poses, hwf, i_split = load_blender_data(
48 | args.datadir, args.half_res, args.testskip
49 | )
50 | logger.info(f"Loaded blender {images.shape} {render_poses.shape} {hwf} {args.datadir}")
51 | if args.white_bkgd:
52 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
53 | else:
54 | images = images[..., :3]
55 | near, far = 2.0, 6.0
56 |
57 | elif dataset_type == "LINEMOD":
58 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(
59 | args.datadir, args.half_res, args.testskip
60 | )
61 | logger.info(f"Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}")
62 | logger.info(f"[CHECK HERE] near: {near}, far: {far}.")
63 |
64 | if args.white_bkgd:
65 | images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
66 | else:
67 | images = images[..., :3]
68 |
69 | elif dataset_type == "deepvoxels":
70 | images, poses, render_poses, hwf, i_split = load_dv_data(
71 | scene=args.shape, basedir=args.datadir, testskip=args.testskip
72 | )
73 | logger.info(f"Loaded deepvoxels {images.shape} {render_poses.shape} {hwf} {args.datadir}")
74 | hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
75 | near, far = hemi_R - 1.0, hemi_R + 1.0
76 |
77 | else:
78 | raise ValueError(f"Unknown dataset type {dataset_type}")
79 |
80 |     # cast height and width to the right types
81 | height, width, focal = hwf
82 | height, width = int(height), int(width)
83 | hwf = [height, width, focal]
84 |
85 | if K is None:
86 | K = np.array([
87 | [focal, 0, 0.5 * width],
88 | [0, focal, 0.5 * height],
89 | [0, 0, 1]
90 | ])
91 |
92 | return images, poses, render_poses, hwf, i_split, near, far, K
93 |
--------------------------------------------------------------------------------
/nerf/data/load_LINEMOD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import json
5 | import cv2
6 |
7 |
8 |
9 |
10 | trans_t = lambda t: np.array([
11 |     [1, 0, 0, 0],
12 |     [0, 1, 0, 0],
13 |     [0, 0, 1, t],
14 |     [0, 0, 0, 1]], dtype=np.float32)
15 |
16 | rot_phi = lambda phi: np.array([
17 |     [1, 0, 0, 0],
18 |     [0, np.cos(phi), -np.sin(phi), 0],
19 |     [0, np.sin(phi), np.cos(phi), 0],
20 |     [0, 0, 0, 1]], dtype=np.float32)
21 |
22 | rot_theta = lambda th: np.array([
23 |     [np.cos(th), 0, -np.sin(th), 0],
24 |     [0, 1, 0, 0],
25 |     [np.sin(th), 0, np.cos(th), 0],
26 |     [0, 0, 0, 1]], dtype=np.float32)
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 |     c2w = trans_t(radius)
31 |     c2w = rot_phi(phi / 180. * np.pi) @ c2w
32 |     c2w = rot_theta(theta / 180. * np.pi) @ c2w
33 |     c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float32) @ c2w
34 |     return c2w
35 |
36 |
37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1):
38 | splits = ['train', 'val', 'test']
39 | metas = {}
40 | for s in splits:
41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
42 | metas[s] = json.load(fp)
43 |
44 | all_imgs = []
45 | all_poses = []
46 | counts = [0]
47 | for s in splits:
48 | meta = metas[s]
49 | imgs = []
50 | poses = []
51 | if s=='train' or testskip==0:
52 | skip = 1
53 | else:
54 | skip = testskip
55 |
56 | for idx_test, frame in enumerate(meta['frames'][::skip]):
57 | fname = frame['file_path']
58 | if s == 'test':
59 | print(f"{idx_test}th test frame: {fname}")
60 | imgs.append(imageio.imread(fname))
61 | poses.append(np.array(frame['transform_matrix']))
62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
63 | poses = np.array(poses).astype(np.float32)
64 | counts.append(counts[-1] + imgs.shape[0])
65 | all_imgs.append(imgs)
66 | all_poses.append(poses)
67 |
68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
69 |
70 | imgs = np.concatenate(all_imgs, 0)
71 | poses = np.concatenate(all_poses, 0)
72 |
73 | H, W = imgs[0].shape[:2]
74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
75 | K = meta['frames'][0]['intrinsic_matrix']
76 | print(f"Focal: {focal}")
77 |
78 |     render_poses = np.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
79 |
80 | if half_res:
81 | H = H//2
82 | W = W//2
83 | focal = focal/2.
84 |
85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
86 | for i, img in enumerate(imgs):
87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
88 | imgs = imgs_half_res
89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
90 |
91 | near = np.floor(min(metas['train']['near'], metas['test']['near']))
92 | far = np.ceil(max(metas['train']['far'], metas['test']['far']))
93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far
94 |
95 |
96 |
--------------------------------------------------------------------------------
/nerf/data/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | import json
5 | import cv2
6 |
7 |
8 |
9 | trans_t = lambda t: np.array(
10 |     [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]],
11 |     dtype=np.float32)
12 |
13 |
14 | rot_phi = lambda phi: np.array(
15 |     [
16 |         [1, 0, 0, 0],
17 |         [0, np.cos(phi), -np.sin(phi), 0],
18 |         [0, np.sin(phi), np.cos(phi), 0],
19 |         [0, 0, 0, 1],
20 |     ],
21 |     dtype=np.float32)
22 |
23 |
24 | rot_theta = lambda th: np.array(
25 |     [
26 |         [np.cos(th), 0, -np.sin(th), 0],
27 |         [0, 1, 0, 0],
28 |         [np.sin(th), 0, np.cos(th), 0],
29 |         [0, 0, 0, 1],
30 |     ],
31 |     dtype=np.float32)
32 |
33 |
34 | def pose_spherical(theta, phi, radius):
35 |     c2w = trans_t(radius)
36 |     c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
37 |     c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
38 |     c2w = (
39 |         np.array(
40 |             [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
41 |             dtype=np.float32,
42 |         )
43 |         @ c2w
44 |     )
45 |     return c2w
46 |
47 | def load_blender_data(basedir, half_res=False, testskip=1):
48 | splits = ["train", "val", "test"]
49 | metas = {}
50 | for s in splits:
51 | with open(os.path.join(basedir, "transforms_{}.json".format(s)), "r") as fp:
52 | metas[s] = json.load(fp)
53 |
54 | all_imgs = []
55 | all_poses = []
56 | counts = [0]
57 | for s in splits:
58 | meta = metas[s]
59 | imgs = []
60 | poses = []
61 | if s == "train" or testskip == 0:
62 | skip = 1
63 | else:
64 | skip = testskip
65 |
66 | for frame in meta["frames"][::skip]:
67 | fname = os.path.join(basedir, frame["file_path"] + ".png")
68 | imgs.append(imageio.imread(fname))
69 | poses.append(np.array(frame["transform_matrix"]))
70 | imgs = (np.array(imgs) / 255.0).astype(np.float32) # keep all 4 channels (RGBA)
71 | poses = np.array(poses).astype(np.float32)
72 | counts.append(counts[-1] + imgs.shape[0])
73 | all_imgs.append(imgs)
74 | all_poses.append(poses)
75 |
76 | i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
77 |
78 | imgs = np.concatenate(all_imgs, 0)
79 | poses = np.concatenate(all_poses, 0)
80 |
81 | H, W = imgs[0].shape[:2]
82 | camera_angle_x = float(meta["camera_angle_x"])
83 | focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
84 |
85 |     render_poses = np.stack(
86 | [
87 | pose_spherical(angle, -30.0, 4.0)
88 | for angle in np.linspace(-180, 180, 40 + 1)[:-1]
89 | ],
90 | 0,
91 | )
92 |
93 | if half_res:
94 | H = H // 2
95 | W = W // 2
96 | focal = focal / 2.0
97 |
98 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
99 | for i, img in enumerate(imgs):
100 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
101 | imgs = imgs_half_res
102 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
103 |
104 | return imgs, poses, render_poses, [H, W, focal], i_split
105 |
--------------------------------------------------------------------------------
/nerf/data/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 |
8 |
9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
10 | # Get camera intrinsics
11 | with open(filepath, 'r') as file:
12 | f, cx, cy = list(map(float, file.readline().split()))[:3]
13 | grid_barycenter = np.array(list(map(float, file.readline().split())))
14 | near_plane = float(file.readline())
15 | scale = float(file.readline())
16 | height, width = map(float, file.readline().split())
17 |
18 | try:
19 | world2cam_poses = int(file.readline())
20 | except ValueError:
21 | world2cam_poses = None
22 |
23 | if world2cam_poses is None:
24 | world2cam_poses = False
25 |
26 | world2cam_poses = bool(world2cam_poses)
27 |
28 | print(cx,cy,f,height,width)
29 |
30 | cx = cx / width * trgt_sidelength
31 | cy = cy / height * trgt_sidelength
32 | f = trgt_sidelength / height * f
33 |
34 | fx = f
35 | if invert_y:
36 | fy = -f
37 | else:
38 | fy = f
39 |
40 | # Build the intrinsic matrices
41 | full_intrinsic = np.array([[fx, 0., cx, 0.],
42 | [0., fy, cy, 0],
43 | [0., 0, 1, 0],
44 | [0, 0, 0, 1]])
45 |
46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
47 |
48 |
49 | def load_pose(filename):
50 | assert os.path.isfile(filename)
51 | nums = open(filename).read().split()
52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
53 |
54 |
55 | H = 512
56 | W = 512
57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
58 |
59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
61 | focal = full_intrinsic[0,0]
62 | print(H, W, focal)
63 |
64 |
65 | def dir2poses(posedir):
66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
67 | transf = np.array([
68 | [1,0,0,0],
69 | [0,-1,0,0],
70 | [0,0,-1,0],
71 | [0,0,0,1.],
72 | ])
73 | poses = poses @ transf
74 | poses = poses[:,:3,:4].astype(np.float32)
75 | return poses
76 |
77 | posedir = os.path.join(deepvoxels_base, 'pose')
78 | poses = dir2poses(posedir)
79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
80 | testposes = testposes[::testskip]
81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
82 | valposes = valposes[::testskip]
83 |
84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
86 |
87 |
88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
91 |
92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
95 |
96 | all_imgs = [imgs, valimgs, testimgs]
97 | counts = [0] + [x.shape[0] for x in all_imgs]
98 | counts = np.cumsum(counts)
99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
100 |
101 | imgs = np.concatenate(all_imgs, 0)
102 | poses = np.concatenate([poses, valposes, testposes], 0)
103 |
104 | render_poses = testposes
105 |
106 | print(poses.shape, imgs.shape)
107 |
108 | return imgs, poses, render_poses, [H,W,focal], i_split
109 |
110 |
111 |
--------------------------------------------------------------------------------
/nerf/data/load_llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import imageio
4 | from loguru import logger
5 | from subprocess import check_output
6 |
7 |
8 | # Slightly modified version of LLFF data loading code
9 | # see https://github.com/Fyusion/LLFF for original
10 |
11 | def _minify(basedir, factors=[], resolutions=[]):
12 | needtoload = False
13 | for r in factors:
14 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
15 | if not os.path.exists(imgdir):
16 | needtoload = True
17 | for r in resolutions:
18 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
19 | if not os.path.exists(imgdir):
20 | needtoload = True
21 | if not needtoload:
22 | return
23 |
24 | imgdir = os.path.join(basedir, 'images')
25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
27 | imgdir_orig = imgdir
28 |
29 | wd = os.getcwd()
30 |
31 | for r in factors + resolutions:
32 | if isinstance(r, int):
33 | name = f"images_{r}"
34 | resizearg = '{}%'.format(100./r)
35 | else:
36 | name = 'images_{}x{}'.format(r[1], r[0])
37 | resizearg = '{}x{}'.format(r[1], r[0])
38 | imgdir = os.path.join(basedir, name)
39 | if os.path.exists(imgdir):
40 | continue
41 |
42 | logger.info(f"Minifying {r} {basedir}")
43 |
44 | os.makedirs(imgdir)
45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
46 |
47 | ext = imgs[0].split('.')[-1]
48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
49 | logger.info(args)
50 | os.chdir(imgdir)
51 | check_output(args, shell=True)
52 | os.chdir(wd)
53 |
54 | if ext != 'png':
55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
56 |             logger.info('Removed duplicates')
57 |     logger.info('Done')
58 |
59 |
60 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
61 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
62 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
63 | bds = poses_arr[:, -2:].transpose([1, 0])
64 |
65 | img0 = [
66 | os.path.join(basedir, 'images', f)
67 | for f in sorted(os.listdir(os.path.join(basedir, 'images')))
68 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')
69 | ][0]
70 | sh = imageio.imread(img0).shape
71 |
72 | sfx = ""
73 |
74 | if factor is not None:
75 | sfx = '_{}'.format(factor)
76 | _minify(basedir, factors=[factor])
77 | factor = factor
78 | elif height is not None:
79 | factor = sh[0] / float(height)
80 | width = int(sh[1] / factor)
81 | _minify(basedir, resolutions=[[height, width]])
82 | sfx = '_{}x{}'.format(width, height)
83 | elif width is not None:
84 | factor = sh[1] / float(width)
85 | height = int(sh[0] / factor)
86 | _minify(basedir, resolutions=[[height, width]])
87 | sfx = '_{}x{}'.format(width, height)
88 | else:
89 | factor = 1
90 |
91 | imgdir = os.path.join(basedir, 'images' + sfx)
92 | if not os.path.exists(imgdir):
93 | logger.info(f"{imgdir} does not exist, returning")
94 | return
95 |
96 | imgfiles = [
97 | os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))
98 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')
99 | ]
100 | if poses.shape[-1] != len(imgfiles):
101 | logger.info(f'Mismatch between imgs {len(imgfiles)} and poses {poses.shape[-1]} !!!!')
102 | return
103 |
104 | # modify intrinsics of camera
105 | sh = imageio.imread(imgfiles[0]).shape
106 | # poses[:, 4, :] *= 1. / factor
107 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
108 | poses[2, 4, :] *= 1./factor # focal / factor
109 |
110 | if not load_imgs:
111 | return poses, bds
112 |
113 | def imread(f):
114 | if f.endswith('png'):
115 | return imageio.imread(f, ignoregamma=True)
116 | else:
117 | return imageio.imread(f)
118 |
119 | imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
120 | imgs = np.stack(imgs, -1)
121 |
122 | logger.info(f'Loaded image data {imgs.shape} {poses[:, -1, 0]}')
123 | return poses, bds, imgs
124 |
125 |
126 | def normalize(x):
127 | return x / np.linalg.norm(x)
128 |
129 |
130 | def viewmatrix(z, up, pos):
131 | vec2 = normalize(z) # z axis
132 | vec0 = normalize(np.cross(up, vec2)) # x axis
133 | vec1 = normalize(np.cross(vec2, vec0)) # y axis
134 | m = np.stack([vec0, vec1, vec2, pos], 1)
135 | # return c2w matrix
136 | return m
137 |
138 |
139 | def ptstocam(pts, c2w):
140 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0]
141 | return tt
142 |
143 |
144 | def poses_avg(poses):
145 | hwf = poses[0, :3, -1:]
146 | center = poses[:, :3, 3].mean(0) # camera coordinates center
147 | z = poses[:, :3, 2].sum(0)
148 | y = poses[:, :3, 1].sum(0)
149 | c2w = np.concatenate([viewmatrix(z, y, center), hwf], 1)
150 | return c2w
151 |
152 |
153 | def render_path_spiral(c2w, up, rads, focal, zrate, rots, N):
154 | render_poses = []
155 | rads = np.array(list(rads) + [1.])
156 | hwf, c2w_matrix = c2w[:, 4:5], c2w[:, :4]
157 |
158 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
159 | c = np.dot(
160 | c2w_matrix,
161 | np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads
162 | )
163 | z = normalize(c - np.dot(c2w_matrix, np.array([0, 0, -focal, 1.])))
164 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
165 |
166 | return render_poses
167 |
168 |
169 | def recenter_poses(poses):
170 | # process of w2c @ intrinsics
171 | poses_ = poses.copy()
172 | bottom = np.array([[0, 0, 0, 1.]]) # bottom part of camera to world matrix
173 | c2w = poses_avg(poses)
174 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) # c2w shape to (4, 4)
175 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
176 | poses = np.concatenate([poses[:, :3, :4], bottom], -2)
177 |
178 | poses = np.linalg.inv(c2w) @ poses
179 | poses_[:, :3, :4] = poses[:, :3, :4]
180 | return poses_
181 |
182 |
183 | def spherify_poses(poses, bds):
184 |
185 | def p34_to_44(p):
186 | return np.concatenate(
187 | [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
188 | )
189 |
190 | rays_d = poses[:, :3, 2:3]
191 | rays_o = poses[:, :3, 3:4]
192 |
193 | def min_line_dist(rays_o, rays_d):
194 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
195 | b_i = -A_i @ rays_o
196 | pt_mindist = np.squeeze(
197 | -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
198 | )
199 | return pt_mindist
200 |
201 | pt_mindist = min_line_dist(rays_o, rays_d)
202 |
203 | center = pt_mindist
204 | up = (poses[:, :3, 3] - center).mean(0)
205 |
206 | vec0 = normalize(up)
207 | vec1 = normalize(np.cross([.1, .2, .3], vec0))
208 | vec2 = normalize(np.cross(vec0, vec1))
209 | pos = center
210 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
211 |
212 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
213 |
214 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
215 |
216 | sc = 1./rad
217 | poses_reset[:, :3, 3] *= sc
218 | bds *= sc
219 | rad *= sc
220 |
221 | centroid = np.mean(poses_reset[:, :3, 3], 0)
222 | zh = centroid[2]
223 | radcircle = np.sqrt(rad**2-zh**2)
224 | new_poses = []
225 |
226 | for th in np.linspace(0., 2.*np.pi, 120):
227 |
228 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
229 | up = np.array([0, 0, -1.])
230 |
231 | vec2 = normalize(camorigin)
232 | vec0 = normalize(np.cross(vec2, up))
233 | vec1 = normalize(np.cross(vec2, vec0))
234 | pos = camorigin
235 | p = np.stack([vec0, vec1, vec2, pos], 1)
236 |
237 | new_poses.append(p)
238 |
239 | new_poses = np.stack(new_poses, 0)
240 |
241 | new_poses = np.concatenate(
242 | [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
243 | )
244 | poses_reset = np.concatenate([
245 | poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
246 | ], -1)
247 |
248 | return poses_reset, new_poses, bds
249 |
250 |
251 | def load_llff_data(
252 | basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False
253 | ):
254 | # factor=8 downsamples original imgs by 8x
255 | poses, bds, images = _load_data(basedir, factor=factor)
256 | logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")
257 |
258 | # Correct rotation matrix ordering and move variable dim to axis 0
259 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
260 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
261 |     images = np.moveaxis(images, -1, 0).astype(np.float32)  # (H, W, C, N) -> (N, H, W, C)
262 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
263 |
264 | # Rescale if bd_factor is provided
265 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
266 | poses[:, :3, 3] *= sc
267 | bds *= sc
268 |
269 | if recenter:
270 | poses = recenter_poses(poses)
271 |
272 | if spherify:
273 | poses, render_poses, bds = spherify_poses(poses, bds)
274 |
275 | else:
276 | c2w = poses_avg(poses)
277 | logger.info(f'recentered c2w shape: {c2w.shape}')
278 | logger.info(f"c2w: {c2w[:3, :4]}")
279 |
280 | # Get spiral
281 | # Get average pose
282 | up = normalize(poses[:, :3, 1].sum(0))
283 |
284 | # Find a reasonable "focus depth" for this dataset
285 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
286 | dt = .75
287 | focal = 1. / (((1. - dt) / close_depth + dt / inf_depth))
288 |
289 | # Get radii for spiral path
290 | tt = poses[:, :3, 3]
291 | rads = np.percentile(np.abs(tt), 90, 0)
292 | c2w_path = c2w
293 | N_views, N_rots = 120, 2
294 | if path_zflat:
295 | zloc = -close_depth * .1
296 | c2w_path[:3, 3] += zloc * c2w_path[:3, 2]
297 | rads[2] = 0.
298 | N_rots = 1
299 |             N_views = N_views // 2  # keep linspace's sample count an integer
300 |
301 | # generate poses for spiral path
302 | render_poses = render_path_spiral(
303 | c2w_path, up, rads, focal, zrate=.5, rots=N_rots, N=N_views
304 | )
305 |
306 | render_poses = np.array(render_poses).astype(np.float32)
307 |
308 | c2w = poses_avg(poses)
309 | logger.info(
310 | f"Data:\nposes shape: {poses.shape}, images shape: {images.shape}, bds shape: {bds.shape}"
311 | )
312 |
313 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1)
314 | i_test = np.argmin(dists)
315 | logger.info(f"HOLDOUT view is {i_test}")
316 |
317 | images = images.astype(np.float32)
318 | poses = poses.astype(np.float32)
319 |
320 | return images, poses, bds, render_poses, i_test
321 |
--------------------------------------------------------------------------------
/nerf/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
--------------------------------------------------------------------------------
/nerf/engine/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | import datetime
5 | import numpy as np
6 | import os
7 | import imageio
8 | import time
9 | from loguru import logger
10 |
11 | import megengine as mge
12 | import megengine.functional as F
13 | import megengine.distributed as dist
14 | from megengine.autodiff import GradManager
15 |
16 | from nerf.data import build_loader
17 | from nerf.models.build import create_nerf
18 | from nerf.utils import (
19 | ensure_dir, get_rays, get_rays_np, to8b, img2mse, mse2psnr, render_path, render, meshgrid
20 | )
21 |
22 |
23 | def rendering_only(args, images, i_test, render_poses, hwf, K, render_kwargs_test, start_iter):
24 | logger.info("RENDER ONLY")
25 |     # render_test switches to the test poses; the default is the smoother render_poses path
26 | images = images[i_test] if args.render_test else None
27 |
28 | testsavedir = os.path.join(
29 | args.basedir,
30 | args.expname,
31 | "renderonly_{}_{:06d}".format("test" if args.render_test else "path", start_iter),
32 | )
33 | ensure_dir(testsavedir)
34 | logger.info(f"test poses shape: {render_poses.shape}")
35 |
36 | rgbs, _ = render_path(
37 | render_poses,
38 | hwf,
39 | K,
40 | args.chunk,
41 | render_kwargs_test,
42 | savedir=testsavedir,
43 | render_factor=args.render_factor,
44 | )
45 | logger.info(f"Done rendering {testsavedir}")
46 | imageio.mimwrite(os.path.join(testsavedir, "video.mp4"), to8b(rgbs), fps=30, quality=8)
47 |
48 |
49 | class Trainer:
50 | def __init__(self, args):
51 |         # __init__ only defines basic attributes; others, such as the model and
52 |         # optimizer, are built in the `before_train` method.
53 | self.args = args
54 | self.start_iter, self.max_iter = (0, 200000)
55 | self.rank = 0
56 | self.amp_training = False
57 |
58 | def train(self):
59 | self.before_train()
60 | try:
61 | self.train_in_iter()
62 | except Exception:
63 | raise
64 | finally:
65 | self.after_train()
66 |
67 | def before_train(self):
68 | args = self.args
69 |
70 | logger.info(f"Full args:\n{args}")
71 | # model related init
72 | images, poses, render_poses, hwf, i_split, near, far, K = build_loader(args.dataset_type, args) # noqa
73 |
74 | i_train, i_val, i_test = i_split
75 | info_string = f"Train views are {i_train}\nTest views are {i_test}\nVal views are {i_val}"
76 | if args.render_test:
77 | render_poses = np.array(poses[i_test])
78 |
79 | self.i_split = i_split
80 | self.hwf = hwf
81 | self.K = K
82 | self.render_poses = mge.Tensor(render_poses)
83 |
84 | # create dir to save all the results
85 | self.save_dir = os.path.join(args.basedir, args.expname)
86 | ensure_dir(self.save_dir)
87 | logger.add(os.path.join(self.save_dir, "log.txt"))
88 |
89 | # save args.txt config.txt
90 | with open(os.path.join(self.save_dir, "args.txt"), "w") as f:
91 | for arg in sorted(vars(args)):
92 | f.write(f"{arg} = {getattr(args, arg)}\n")
93 |
94 | if args.config is not None:
95 | with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
96 | f.write(open(args.config, "r").read())
97 |
98 | # Create nerf model
99 | render_kwargs_train, render_kwargs_test, optimizer, gm = self.build_nerf()
100 | bds_dict = {"near": near, "far": far}
101 | render_kwargs_train.update(bds_dict)
102 | render_kwargs_test.update(bds_dict)
103 | self.render_kwargs_train = render_kwargs_train
104 | self.render_kwargs_test = render_kwargs_test
105 | self.optimizer = optimizer
106 | self.grad_manager = gm
107 |
108 | # Short circuit if only rendering out from trained model
109 | if args.render_only:
110 | rendering_only(
111 | args, images, i_test, self.render_poses, self.hwf, self.K,
112 | self.render_kwargs_test, self.start_iter,
113 | )
114 | return
115 |
116 | # Prepare raybatch tensor if batching random rays
117 | use_batching = not args.no_batching
118 |         rays_rgb, i_batch = None, 0  # populated below when ray batching is enabled
119 | if use_batching:
120 | # For random ray batching
121 | logger.info("get rays")
122 | rays = np.stack(
123 | [get_rays_np(self.hwf[0], self.hwf[1], self.K, p) for p in poses[:, :3, :4]], 0
124 | ) # [N, ro+rd, H, W, 3]
125 | logger.info("get rays done, start concats")
126 | rays_rgb = np.concatenate([rays, images[:, None]], 1) # [N, ro+rd+rgb, H, W, 3]
127 | rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb, 3]
128 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
129 | rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]) # [(N-1)*H*W, ro+rd+rgb, 3]
130 | rays_rgb = rays_rgb.astype(np.float32)
131 | logger.info("shuffle rays")
132 | np.random.shuffle(rays_rgb)
133 | logger.info("shuffle rasy done")
134 |
135 | images = mge.Tensor(images)
136 | rays_rgb = mge.Tensor(rays_rgb)
137 | i_batch = 0
138 |
139 | self.poses = mge.Tensor(poses)
140 | self.images = images
141 | self.rays_rgb = rays_rgb
142 | self.i_batch = i_batch
143 |
144 | logger.info("Begin training\n" + info_string)
145 |
146 | def build_nerf(self):
147 | args = self.args
148 |
149 | model, model_fine, network_query_fn = create_nerf(args)
150 | params = list(model.parameters())
151 | logger.info(f"Model:\n{model}")
152 | if model_fine is not None:
153 | logger.info(f"Model Fine:\n{model_fine}")
154 | params += list(model_fine.parameters())
155 |
156 | gm = GradManager()
157 | world_size = dist.get_world_size()
158 | callbacks = [dist.make_allreduce_cb("MEAN", dist.WORLD)] if world_size > 1 else None # noqa
159 | gm.attach(params, callbacks=callbacks)
160 |
161 | optimizer = mge.optimizer.Adam(params=params, lr=args.lr, betas=(0.9, 0.999))
162 | self.resume_ckpt(model, model_fine, optimizer)
163 |
164 | render_kwargs_train = {
165 | "network_query_fn": network_query_fn,
166 | "perturb": args.perturb,
167 | "N_importance": args.N_importance,
168 | "network_fine": model_fine,
169 | "N_samples": args.N_samples,
170 | "network_fn": model,
171 | "use_viewdirs": args.use_viewdirs,
172 | "white_bkgd": args.white_bkgd,
173 | "raw_noise_std": args.raw_noise_std,
174 | }
175 |
176 | # NDC only good for LLFF-style forward facing data
177 | if args.dataset_type != "llff" or args.no_ndc:
178 | logger.info("Not ndc!")
179 | render_kwargs_train["ndc"] = False
180 | render_kwargs_train["lindisp"] = args.lindisp
181 |
182 | render_kwargs_test = {k: v for k, v in render_kwargs_train.items()}
183 | render_kwargs_test["perturb"] = False
184 | render_kwargs_test["raw_noise_std"] = 0.0
185 |
186 | return render_kwargs_train, render_kwargs_test, optimizer, gm
187 |
188 | def after_train(self):
189 | logger.info("Training of experiment is done.")
190 |
191 | def train_in_iter(self):
192 | H, W, _ = self.hwf
193 |
194 | for i in range(self.start_iter, self.max_iter):
195 | self.global_step = i + 1
196 |
197 | iter_start_time = time.time()
198 | batch_rays, target_s = self.sample_rays()
199 |
200 | # Core optimization loop
201 | with self.grad_manager:
202 | rgb, disp, acc, extras = render(
203 | H, W, self.K, chunk=self.args.chunk, rays=batch_rays,
204 | retraw=True, **self.render_kwargs_train,
205 | )
206 |
207 | loss = img2mse(rgb, target_s)
208 | psnr = mse2psnr(loss.detach())
209 | if "rgb0" in extras:
210 | loss += img2mse(extras["rgb0"], target_s)
211 |
212 | self.grad_manager.backward(loss)
213 | self.optimizer.step().clear_grad()
214 |
215 | lr = self.update_lr()
216 | iter_time = time.time() - iter_start_time
217 |
218 | self.save_ckpt()
219 | self.save_test()
220 |
221 | # log training info
222 | if self.global_step % self.args.log_interval == 0:
223 | eta_seconds = (self.max_iter - self.global_step) * iter_time
224 | logger.info(
225 | f"iter: {self.global_step}/{self.max_iter}, "
226 | f"loss: {loss.item():.4f}, "
227 | f"PSNR: {psnr.item():.3f}, "
228 | f"lr: {lr:.4e}, "
229 | f"iter time: {iter_time:.3f}s, "
230 | f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}"
231 | )
232 |
233 | def update_lr(self, decay_rate=0.1):
234 | decay_steps = self.args.lrate_decay * 1000
235 | new_lr = self.args.lr * (decay_rate ** (self.global_step / decay_steps))
236 |
237 | for param_group in self.optimizer.param_groups:
238 | param_group["lr"] = new_lr
239 |
240 | return new_lr
241 |
242 | def save_ckpt(self):
243 | if self.rank == 0 and self.global_step % self.args.i_weights == 0:
244 | path = os.path.join(self.save_dir, f"{self.global_step:06d}.tar")
245 | ckpt_state = {
246 | "global_step": self.global_step,
247 | "network_fn_state_dict": self.render_kwargs_train["network_fn"].state_dict(),
248 | "network_fine_state_dict": self.render_kwargs_train["network_fine"].state_dict(),
249 | "optimizer_state_dict": self.optimizer.state_dict(),
250 | }
251 | mge.save(ckpt_state, path)
252 | logger.info(f"Save checkpoint at {path}")
253 |
254 | def save_test(self):
255 | if self.global_step % self.args.i_video == 0:
256 | # Turn on testing mode
257 | rgbs, disps = render_path(
258 | self.render_poses, self.hwf, self.K, self.args.chunk, self.render_kwargs_test
259 | )
260 | logger.info(f"Done, saving {rgbs.shape} {disps.shape}")
261 | moviebase = os.path.join(
262 | self.save_dir, "{}_spiral_{:06d}_".format(self.args.expname, self.global_step)
263 | )
264 | imageio.mimwrite(moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8)
265 | imageio.mimwrite(
266 | moviebase + "disp.mp4", to8b(disps / np.max(disps)), fps=30, quality=8
267 | )
268 |
269 | if self.global_step % self.args.i_testset == 0:
270 | i_test = self.i_split[-1]
271 | test_save_dir = os.path.join(self.save_dir, f"testset_{self.global_step:06d}")
272 | ensure_dir(test_save_dir)
273 | logger.info(f"test poses shape: {self.poses[i_test].shape}")
274 | render_path(
275 | mge.Tensor(self.poses[i_test]),
276 | self.hwf,
277 | self.K,
278 | self.args.chunk,
279 | self.render_kwargs_test,
280 | savedir=test_save_dir,
281 | )
282 | logger.info("Saved test set")
283 |
284 | def resume_ckpt(self, model, model_fine, optimizer):
285 | if self.args.ft_path is not None and self.args.ft_path != "None":
286 | ckpts = [self.args.ft_path]
287 | else:
288 | ckpts = [
289 | os.path.join(self.save_dir, f)
290 | for f in sorted(os.listdir(self.save_dir))
291 | if f.endswith("tar")
292 | ]
293 |
294 | if ckpts and not self.args.no_reload:
295 | logger.info(f"Found ckpts: {ckpts}")
296 | ckpt_to_load = ckpts[-1]
297 | logger.info(f"Reloading from {ckpt_to_load}")
298 | ckpt = mge.load(ckpt_to_load)
299 |
300 | self.start_iter = ckpt["global_step"]
301 | model.load_state_dict(ckpt["network_fn_state_dict"])
302 | if model_fine is not None:
303 | model_fine.load_state_dict(ckpt["network_fine_state_dict"])
304 | optimizer.load_state_dict(ckpt["optimizer_state_dict"])
305 |
306 | return model, model_fine, optimizer
307 |
308 | def sample_rays(self):
309 | # Sample random ray batch
310 | N_rand = self.args.N_rand
311 | use_batching = not self.args.no_batching
312 | rays_rgb = self.rays_rgb
313 | i_train = self.i_split[0]
314 | images = self.images
315 | H, W, _ = self.hwf
316 |
317 | if use_batching:
318 | # Random over all images
319 | batch = rays_rgb[self.i_batch: self.i_batch + N_rand] # [B, 2+1, 3*?]
320 | batch = batch.transpose(1, 0, 2)
321 | batch_rays, target_s = batch[:2], batch[2]
322 |
323 | self.i_batch += N_rand
324 | if self.i_batch >= rays_rgb.shape[0]:
325 | logger.info("Shuffle data after an epoch!")
326 | rand_idx = mge.Tensor(np.random.permutation(rays_rgb.shape[0]))
327 |                 self.rays_rgb = rays_rgb[rand_idx]  # persist the shuffled order for the next epoch
328 | self.i_batch = 0
329 |
330 | else:
331 | # Random from one image
332 | img_i = np.random.choice(i_train)
333 | target = images[img_i]
334 | target = mge.Tensor(target)
335 | pose = self.poses[img_i, :3, :4]
336 |
337 | if N_rand is not None:
338 | rays_o, rays_d = get_rays(H, W, self.K, pose) # (H, W, 3), (H, W, 3)
339 |
340 | if self.global_step < self.args.precrop_iters:
341 | dH = int(H // 2 * self.args.precrop_frac)
342 | dW = int(W // 2 * self.args.precrop_frac)
343 | coords = F.stack(
344 | meshgrid(
345 | F.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
346 | F.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW),
347 | indexing="ij"
348 | ),
349 | -1,
350 | )
351 | if self.global_step == 1:
352 | logger.info(
353 | f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {self.args.precrop_iters}" # noqa
354 | )
355 | else:
356 | coords = F.stack(
357 | meshgrid(F.linspace(0, H - 1, H), F.linspace(0, W - 1, W), indexing="ij"),
358 | -1,
359 | ) # (H, W, 2)
360 |
361 | coords = coords.reshape(-1, 2) # (H * W, 2)
362 | select_inds = np.random.choice(
363 | coords.shape[0], size=[N_rand], replace=False
364 | ) # (N_rand,)
365 |                 select_coords = coords[select_inds].astype("int32")  # (N_rand, 2)
366 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
367 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
368 | batch_rays = F.stack([rays_o, rays_d], 0)
369 | target_s = target[
370 | select_coords[:, 0], select_coords[:, 1]
371 | ] # (N_rand, 3)
372 |
373 | return batch_rays, target_s
374 |
--------------------------------------------------------------------------------
/nerf/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | from .nerf import NeRF
5 |
--------------------------------------------------------------------------------
/nerf/models/build.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | import megengine.functional as F
5 |
6 | from .nerf import NeRF
7 | from .embed import get_embedder
8 |
9 | __all__ = ["create_nerf"]
10 |
11 |
12 | def batchify(fn, chunk):
13 | """Constructs a version of 'fn' that applies to smaller batches."""
14 | if chunk is None:
15 | return fn
16 |
17 | def ret(inputs):
18 | return F.concat([fn(inputs[i: i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
19 |
20 | return ret
21 |
22 |
23 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64):
24 | """Prepares inputs and applies network 'fn'."""
25 | inputs_flat = inputs.reshape([-1, inputs.shape[-1]])
26 | embedded = embed_fn(inputs_flat)
27 |
28 | if viewdirs is not None:
29 | input_dirs = F.broadcast_to(F.expand_dims(viewdirs, axis=1), inputs.shape)
30 | input_dirs_flat = input_dirs.reshape([-1, input_dirs.shape[-1]])
31 | embedded_dirs = embeddirs_fn(input_dirs_flat)
32 | embedded = F.concat([embedded, embedded_dirs], -1)
33 |
34 | outputs_flat = batchify(fn, netchunk)(embedded)
35 | outputs = outputs_flat.reshape(list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
36 | return outputs
37 |
38 |
39 | def create_nerf(args):
40 | """Instantiate NeRF's MLP model."""
41 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
42 |
43 | input_ch_views = 0
44 | embeddirs_fn = None
45 | if args.use_viewdirs:
46 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
47 | output_ch = 5 if args.N_importance > 0 else 4
48 | skips = [4]
49 | model = NeRF(
50 | depth=args.netdepth,
51 | width=args.netwidth,
52 | input_ch=input_ch,
53 | output_ch=output_ch,
54 | skips=skips,
55 | input_ch_views=input_ch_views,
56 | use_viewdirs=args.use_viewdirs,
57 | )
58 |
59 | model_fine = None
60 | if args.N_importance > 0:
61 | model_fine = NeRF(
62 | depth=args.netdepth_fine,
63 | width=args.netwidth_fine,
64 | input_ch=input_ch,
65 | output_ch=output_ch,
66 | skips=skips,
67 | input_ch_views=input_ch_views,
68 | use_viewdirs=args.use_viewdirs,
69 | )
70 |
71 | def network_query_fn(inputs, viewdirs, network_fn):
72 | return run_network(
73 | inputs, viewdirs, network_fn, embed_fn=embed_fn,
74 | embeddirs_fn=embeddirs_fn, netchunk=args.netchunk,
75 | )
76 |
77 | return model, model_fine, network_query_fn
78 |
--------------------------------------------------------------------------------
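A usage sketch for create_nerf (hedged: the bare `Namespace` below carries exactly the attributes the function reads, all of which correspond to flags defined in run.py):

```python
from argparse import Namespace

from nerf.models.build import create_nerf

args = Namespace(
    multires=10, multires_views=4, i_embed=0, use_viewdirs=True,
    N_importance=64, netdepth=8, netwidth=256,
    netdepth_fine=8, netwidth_fine=256, netchunk=1024 * 64,
)
model, model_fine, query_fn = create_nerf(args)
# query_fn(pts, viewdirs, model) embeds its inputs and evaluates the MLP in
# netchunk-sized slices, keeping peak memory bounded.
```
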
/nerf/models/embed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | import megengine.functional as F
4 | import megengine.module as nn
5 |
6 | __all__ = ["Embedder", "get_embedder"]
7 |
8 |
9 | # Positional encoding (section 5.1)
10 | class Embedder:
11 | def __init__(self, **kwargs):
12 | self.kwargs = kwargs
13 | self.create_embedding_fn()
14 |
15 | def create_embedding_fn(self):
16 | embed_fns = []
17 | d = self.kwargs["input_dims"]
18 | out_dim = 0
19 | if self.kwargs["include_input"]:
20 | embed_fns.append(lambda x: x)
21 | out_dim += d
22 |
23 | max_freq = self.kwargs["max_freq_log2"]
24 | N_freqs = self.kwargs["num_freqs"]
25 |
26 | if self.kwargs["log_sampling"]:
27 | freq_bands = 2.0 ** F.linspace(0.0, max_freq, N_freqs)
28 | else:
29 | freq_bands = F.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
30 |
31 | for freq in freq_bands:
32 | for p_fn in self.kwargs["periodic_fns"]:
33 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
34 | out_dim += d
35 |
36 | self.embed_fns = embed_fns
37 | self.out_dim = out_dim
38 |
39 | def embed(self, inputs):
40 | return F.concat([fn(inputs) for fn in self.embed_fns], -1)
41 |
42 |
43 | def get_embedder(multires, i=0):
44 | if i == -1:
45 | return nn.Identity(), 3
46 |
47 | embed_kwargs = {
48 | "include_input": True,
49 | "input_dims": 3,
50 | "max_freq_log2": multires - 1,
51 | "num_freqs": multires,
52 | "log_sampling": True,
53 | "periodic_fns": [F.sin, F.cos],
54 | }
55 |
56 | embedder_obj = Embedder(**embed_kwargs)
57 |
58 | def embed(x, eo=embedder_obj):
59 | return eo.embed(x)
60 |
61 | return embed, embedder_obj.out_dim
62 |
--------------------------------------------------------------------------------
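A quick check of the encoding width (a sketch, assuming the package is importable as `nerf`): with `include_input=True`, a 3-D input maps to `3 * (1 + 2 * multires)` channels, one identity copy plus a sin and a cos per frequency band.

```python
import megengine as mge

from nerf.models.embed import get_embedder

embed_fn, out_dim = get_embedder(multires=10)
assert out_dim == 3 * (1 + 2 * 10)  # 63 channels for the default multires=10
x = mge.tensor([[0.1, 0.2, 0.3]])
assert embed_fn(x).shape == (1, 63)
```
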
/nerf/models/nerf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | import megengine.module as nn
4 | import megengine.functional as F
5 |
6 |
7 | class NeRF(nn.Module):
8 | """NeRF module"""
9 | def __init__(
10 | self, depth=8, width=256, input_ch=3, input_ch_views=3,
11 | output_ch=4, skips=[4], use_viewdirs=False,
12 | ):
13 | super().__init__()
14 | self.depth = depth
15 | self.width = width
16 | self.input_ch = input_ch
17 | self.input_ch_views = input_ch_views
18 | self.skips = skips
19 | self.use_viewdirs = use_viewdirs
20 |
21 | self.pts_linears = [nn.Linear(input_ch, width)] + [
22 | nn.Linear(width, width) if i not in self.skips else nn.Linear(width + input_ch, width)
23 | for i in range(depth - 1)
24 | ]
25 |
26 | if use_viewdirs:
27 | self.alpha_linear = nn.Linear(width, 1)
28 | self.feature_linear = nn.Linear(width, width)
29 | self.views_linears = nn.Linear(input_ch_views + width, width // 2)
30 | self.rgb_linear = nn.Linear(width // 2, 3)
31 | else:
32 | self.output_linear = nn.Linear(width, output_ch)
33 |
34 | def forward(self, x):
35 | input_pts, input_views = F.split(x, [self.input_ch], axis=-1)
36 | h = input_pts
37 |
38 | for i, layer in enumerate(self.pts_linears):
39 | h = F.relu(layer(h))
40 | if i in self.skips:
41 | h = F.concat([input_pts, h], -1)
42 |
43 | if self.use_viewdirs:
44 | alpha = self.alpha_linear(h)
45 | feature = self.feature_linear(h)
46 | h = F.concat([feature, input_views], -1)
47 | h = F.relu(self.views_linears(h))
48 | rgb = self.rgb_linear(h)
49 | outputs = F.concat([rgb, alpha], -1)
50 | else:
51 | outputs = self.output_linear(h)
52 |
53 | return outputs
54 |
--------------------------------------------------------------------------------
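A shape sanity check for the module (a sketch; 63 and 27 are the widths `get_embedder` returns for the default `multires=10` and `multires_views=4`): the forward pass consumes embedded points concatenated with embedded view directions and emits RGB plus density.

```python
import megengine as mge

from nerf.models.nerf import NeRF

model = NeRF(input_ch=63, input_ch_views=27, use_viewdirs=True)
x = mge.random.normal(size=(1024, 63 + 27))  # [N_pts, embedded pts + dirs]
out = model(x)
assert out.shape == (1024, 4)  # 3 RGB channels + 1 raw density
```
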
/nerf/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | from .helpers import *
5 | from .render import *
--------------------------------------------------------------------------------
/nerf/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import megengine as mge
3 | import megengine.functional as F
4 | import numpy as np
5 | import math
6 |
7 |
8 | __all__ = [
9 | "ensure_dir",
10 | "img2mse",
11 | "mse2psnr",
12 | "to8b",
13 | "get_rays",
14 | "get_rays_np",
15 | "ndc_rays",
16 | "sample_pdf",
17 | "meshgrid",
18 | "cumprod",
19 | ]
20 |
21 |
22 | def cumprod(x: mge.Tensor, axis: int):
23 |     """Cumulative product along ``axis`` (MegEngine has no builtin cumprod)."""
24 |     dim = x.ndim
25 |     axis = axis if axis >= 0 else axis + dim
26 |     # bring the target axis to the front, scan over it, then restore the layout
27 |     perm = [axis] + [i for i in range(dim) if i != axis]
28 |     inv_perm = [perm.index(i) for i in range(dim)]
29 |     x = x.transpose(*perm)
30 |     cum_val = F.ones(x[0].shape)
31 |     for i in range(x.shape[0]):
32 |         cum_val = cum_val * x[i]
33 |         x[i] = cum_val
34 |     return x.transpose(*inv_perm)
35 |
36 |
37 | def ensure_dir(path: str):
38 | """create directories if *path* does not exist"""
39 | if not os.path.isdir(path):
40 | os.makedirs(path)
41 |
42 |
43 | def meshgrid(x, y, indexing="xy"):
44 | """meshgrid wrapper for megengine"""
45 | assert len(x.shape) == 1
46 | assert len(y.shape) == 1
47 | mesh_shape = (y.shape[0], x.shape[0])
48 | mesh_x = F.broadcast_to(x, mesh_shape)
49 | mesh_y = F.broadcast_to(y.reshape(-1, 1), mesh_shape)
50 | if indexing == "ij":
51 | mesh_x, mesh_y = mesh_x.T, mesh_y.T
52 | return mesh_x, mesh_y
53 |
54 |
55 | # Misc
56 | def img2mse(x, y):
57 | return F.mean((x - y) ** 2)
58 |
59 |
60 | def mse2psnr(x):
61 | return -10.0 * (F.log(x) / math.log(10.0))
62 |
63 |
64 | def to8b(x):
65 | return (255 * np.clip(x, 0, 1)).astype(np.uint8)
66 |
67 |
68 | # Ray helpers
69 | def get_rays(H, W, K, c2w):
70 | i, j = meshgrid(F.linspace(0, W - 1, W), F.linspace(0, H - 1, H), indexing="xy")
71 | dirs = F.stack(
72 | [
73 | (i - float(K[0][2])) / float(K[0][0]),
74 | -(j - float(K[1][2])) / float(K[1][1]),
75 | -F.ones_like(i),
76 | ], -1
77 | )
78 | # Rotate ray directions from camera frame to the world frame
79 | # dot product, equals to: [c2w.dot(dir) for dir in dirs]
80 | rays_d = (F.expand_dims(dirs, axis=-2) * c2w[:3, :3]).sum(axis=-1)
81 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
82 | rays_o = F.broadcast_to(c2w[:3, -1], rays_d.shape)
83 | return rays_o, rays_d
84 |
85 |
86 | def get_rays_np(H: int, W: int, K: np.ndarray, c2w: np.ndarray):
87 |     """Generate one ray per pixel (numpy version of ``get_rays``).
88 | 
89 |     Args:
90 |         H (int): height of image.
91 |         W (int): width of image.
92 |         K (np.ndarray): intrinsic matrix.
93 |         c2w (np.ndarray): camera-to-world matrix.
94 |     """
95 | i, j = np.meshgrid(
96 | np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing="xy"
97 | )
98 |     # dirs: per-pixel ray directions in the camera frame, ((i - cx) / fx, -(j - cy) / fy, -1)
99 | dirs = np.stack([(i - K[0, 2]) / K[0, 0], -(j - K[1, 2]) / K[1, 1], -np.ones_like(i)], -1)
100 | # Rotate ray directions from camera frame to the world frame
101 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] # noqa
102 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
103 | rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
104 | return rays_o, rays_d
105 |
106 |
107 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
108 |     """
109 |     Get ray origins o' and directions d' in normalized device coordinate (NDC) space.
110 |     See Appendix C of the NeRF paper for the derivation.
111 | 
112 |     Args:
113 |         rays_o: ray origins of shape (N, 3).
114 |         rays_d: ray directions of shape (N, 3).
115 |     """
116 | # Shift ray origins to near plane
117 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
118 | rays_o = rays_o + F.expand_dims(t, axis=-1) * rays_d
119 |
120 | # Projection
121 | # according to paper, a_x = - f_cam / (W / 2), a_y = -f_cam / (H / 2)
122 | a_x, a_y = - float((2.0 * focal) / W), - float((2.0 * focal) / H)
123 | o_x = a_x * rays_o[..., 0] / rays_o[..., 2]
124 | o_y = a_y * rays_o[..., 1] / rays_o[..., 2]
125 | o_z = 1.0 + 2.0 * near / rays_o[..., 2]
126 |
127 | d_x = a_x * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
128 | d_y = a_y * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
129 | d_z = -2.0 * near / rays_o[..., 2]
130 |
131 | rays_o = F.stack([o_x, o_y, o_z], -1)
132 | rays_d = F.stack([d_x, d_y, d_z], -1)
133 |
134 | return rays_o, rays_d
135 |
136 |
137 | def search_sorted(cdf, value):
138 |     # TODO: replace torch.searchsorted with a pure MegEngine implementation
139 | import torch
140 | inds = torch.searchsorted(torch.tensor(cdf.numpy()), torch.tensor(value.numpy()), right=True)
141 | inds = mge.tensor(inds)
142 | return inds
143 |
144 |
145 | # Hierarchical sampling (section 5.2)
146 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
147 | # Get pdf
148 | weights = weights + 1e-5 # prevent nans
149 | pdf = weights / F.sum(weights, -1, keepdims=True)
150 |     cdf = F.cumsum(pdf, -1 + pdf.ndim)  # cumulative sum along the last axis (written as a non-negative axis index)
151 | cdf = F.concat([F.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
152 |
153 | # Take uniform samples
154 | if det:
155 | u = F.linspace(0.0, 1.0, N_samples)
156 | u = F.broadcast_to(u, list(cdf.shape[:-1]) + [N_samples])
157 | else:
158 | u = mge.random.uniform(size=list(cdf.shape[:-1]) + [N_samples])
159 |
160 | # Pytest, overwrite u with numpy's fixed random numbers
161 | if pytest:
162 | np.random.seed(0)
163 | new_shape = list(cdf.shape[:-1]) + [N_samples]
164 | if det:
165 | u = np.linspace(0.0, 1.0, N_samples)
166 | u = np.broadcast_to(u, new_shape)
167 | else:
168 | u = np.random.rand(*new_shape)
169 | u = mge.Tensor(u)
170 |
171 | # Invert CDF
172 | inds = search_sorted(cdf, u)
173 | below = F.maximum(F.zeros_like(inds - 1), inds - 1)
174 | above = F.minimum((cdf.shape[-1] - 1) * F.ones_like(inds), inds)
175 | inds_g = F.stack([below, above], -1) # (batch, N_samples, 2)
176 |
177 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
178 | cdf_g = F.gather(F.broadcast_to(F.expand_dims(cdf, axis=1), matched_shape), 2, inds_g)
179 | bins_g = F.gather(F.broadcast_to(F.expand_dims(bins, axis=1), matched_shape), 2, inds_g)
180 | denom = cdf_g[..., 1] - cdf_g[..., 0]
181 | denom = F.where(denom < 1e-5, F.ones_like(denom), denom)
182 | t = (u - cdf_g[..., 0]) / denom
183 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
184 |
185 | return samples
186 |
--------------------------------------------------------------------------------
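A worked example for the ray helpers (hedged sketch): with an identity camera-to-world matrix, every ray starts at the camera origin and the direction's z component is -1, matching the OpenGL-style convention used in get_rays / get_rays_np.

```python
import numpy as np

from nerf.utils.helpers import get_rays_np

H, W, focal = 4, 6, 10.0
K = np.array([[focal, 0.0, 0.5 * W],
              [0.0, focal, 0.5 * H],
              [0.0, 0.0, 1.0]], dtype=np.float32)
c2w = np.eye(4, dtype=np.float32)[:3, :4]  # identity pose

rays_o, rays_d = get_rays_np(H, W, K, c2w)
assert rays_o.shape == rays_d.shape == (H, W, 3)
assert np.allclose(rays_o, 0.0)  # all rays share the camera origin
assert np.allclose(rays_d[..., 2], -1.0)  # camera looks down the -z axis
```
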
/nerf/utils/render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | import os
5 | import imageio
6 | import numpy as np
7 | import megengine as mge
8 | import megengine.functional as F
9 | import time
10 | from collections import defaultdict
11 | from loguru import logger
12 | from tqdm import tqdm
13 |
14 | from nerf.utils import get_rays, ndc_rays, to8b, sample_pdf, cumprod
15 |
16 | __all__ = [
17 | "batchify_rays",
18 | "render_rays",
19 | "render_path",
20 | "render",
21 | ]
22 |
23 |
24 | def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
25 | """Render rays in smaller minibatches to avoid OOM."""
26 | all_ret = defaultdict(list)
27 | for i in range(0, rays_flat.shape[0], chunk):
28 | ret = render_rays(rays_flat[i: i + chunk], **kwargs)
29 | for k, v in ret.items():
30 | all_ret[k].append(v)
31 |
32 | return {k: F.concat(all_ret[k], 0) for k in all_ret}
33 |
34 |
35 | def render(
36 | H: int,
37 | W: int,
38 |     K: np.ndarray,
39 | chunk: int = 1024 * 32,
40 | rays=None,
41 | c2w=None,
42 | ndc: bool = True,
43 | near: float = 0.0,
44 | far: float = 1.0,
45 | use_viewdirs: bool = False,
46 | c2w_staticcam=None,
47 | **kwargs,
48 | ):
49 | """Render rays
50 |
51 | Args:
52 | H (int): Height of image in pixels.
53 | W (int): Width of image in pixels.
54 |         K: array of shape [3, 3]. Intrinsic matrix of the pinhole camera
55 |             (focal lengths and principal point).
56 | chunk: int. Maximum number of rays to process simultaneously. Used to
57 | control maximum memory usage. Does not affect final results.
58 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
59 | each example in batch.
60 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
61 | ndc: If True, represent ray origin, direction in NDC coordinates.
62 | near: float or array of shape [batch_size]. Nearest distance for a ray.
63 | far: float or array of shape [batch_size]. Farthest distance for a ray.
64 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
65 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
66 | camera while using other c2w argument for viewing directions.
67 |
68 | Returns:
69 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
70 | disp_map: [batch_size]. Disparity map. Inverse of depth.
71 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
72 | extras: dict with everything returned by render_rays().
73 | """
74 | if c2w is not None:
75 | # special case to render full image
76 | rays_o, rays_d = get_rays(H, W, K, c2w)
77 | else:
78 | # use provided ray batch
79 | rays_o, rays_d = rays
80 |
81 | if use_viewdirs:
82 | # provide ray directions as input
83 | viewdirs = rays_d
84 | if c2w_staticcam is not None:
85 | # special case to visualize effect of viewdirs
86 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
87 | viewdirs = viewdirs / F.norm(viewdirs, axis=-1, keepdims=True)
88 | viewdirs = viewdirs.reshape(-1, 3)
89 |
90 | sh = rays_d.shape # [..., 3]
91 | if ndc:
92 | # for forward facing scenes
93 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1.0, rays_o, rays_d)
94 |
95 | # Create ray batch
96 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
97 |
98 | near, far = near * F.ones_like(rays_d[..., :1]), far * F.ones_like(rays_d[..., :1])
99 | rays = F.concat([rays_o, rays_d, near, far], -1)
100 | if use_viewdirs:
101 | rays = F.concat([rays, viewdirs], -1)
102 |
103 | # Render and reshape
104 | all_ret = batchify_rays(rays, chunk, **kwargs)
105 | for k, v in all_ret.items():
106 | all_ret[k] = v.reshape(list(sh[:-1]) + list(v.shape[1:]))
107 |
108 | k_extract = ["rgb_map", "disp_map", "acc_map"]
109 | ret_list = [all_ret[k] for k in k_extract]
110 | ret_dict = {k: v for k, v in all_ret.items() if k not in k_extract}
111 | return ret_list + [ret_dict]
112 |
113 |
114 | def render_path(render_poses, hwf, K, chunk, render_kwargs, savedir=None, render_factor=0):
115 | H, W, focal = hwf
116 |
117 | if render_factor != 0:
118 | # Render downsampled for speed
119 | H = H // render_factor
120 | W = W // render_factor
121 | focal = focal / render_factor
122 |
123 | rgbs, disps = [], []
124 |
125 | t = time.time()
126 | for i, c2w in enumerate(tqdm(render_poses)):
127 | logger.info(f"{i} {time.time() - t}")
128 | t = time.time()
129 | rgb, disp, acc, _ = render(
130 | H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs
131 | )
132 | rgbs.append(rgb.numpy())
133 | disps.append(disp.numpy())
134 | if i == 0:
135 | logger.info(f"rgb shape: {rgb.shape}, disp shape: {disp.shape}")
136 |
137 | if savedir is not None:
138 | rgb8 = to8b(rgbs[-1])
139 | filename = os.path.join(savedir, "{:03d}.png".format(i))
140 | imageio.imwrite(filename, rgb8)
141 |
142 | rgbs = np.stack(rgbs, 0)
143 | disps = np.stack(disps, 0)
144 |
145 | return rgbs, disps
146 |
147 |
148 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
149 | """Transforms model's predictions to semantically meaningful values.
150 |
151 | Args:
152 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
153 | z_vals: [num_rays, num_samples along ray]. Integration time.
154 | rays_d: [num_rays, 3]. Direction of each ray.
155 |
156 | Returns:
157 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
158 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
159 | acc_map: [num_rays]. Sum of weights along each ray.
160 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
161 | depth_map: [num_rays]. Estimated distance to object.
162 |
163 | """
164 | def raw2alpha(raw, dists, act_fn=F.relu):
165 | return 1.0 - F.exp(-act_fn(raw) * dists)
166 |
167 | dists = z_vals[..., 1:] - z_vals[..., :-1]
168 | dists = F.concat([dists, F.full(dists[..., :1].shape, 1e10)], -1) # [N_rays, N_samples]
169 |
170 | dists = dists * F.norm(F.expand_dims(rays_d, axis=-2), axis=-1)
171 |
172 | rgb = F.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3]
173 | noise = 0.0
174 | if raw_noise_std > 0.0:
175 | noise = mge.random.normal(size=raw[..., 3].shape) * raw_noise_std
176 |
177 | # Overwrite randomly sampled data if pytest
178 | if pytest:
179 | np.random.seed(0)
180 | noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
181 | noise = mge.Tensor(noise)
182 |
183 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples]
184 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
185 | weights = (
186 | alpha * cumprod(
187 | F.concat([F.ones((alpha.shape[0], 1)), 1.0 - alpha + 1e-10], -1), -1
188 | )[:, :-1]
189 | )
190 | rgb_map = F.sum(F.expand_dims(weights, axis=-1) * rgb, -2) # [N_rays, 3]
191 |
192 | depth_map = F.sum(weights * z_vals, -1)
193 | disp_map = 1.0 / F.maximum(
194 | 1e-10 * F.ones_like(depth_map), depth_map / F.sum(weights, -1)
195 | )
196 | acc_map = F.sum(weights, -1)
197 |
198 | if white_bkgd:
199 | rgb_map = rgb_map + (1.0 - acc_map[..., None])
200 |
201 | return rgb_map, disp_map, acc_map, weights, depth_map
202 |
203 |
204 | def render_rays(
205 | ray_batch,
206 | network_fn,
207 | network_query_fn,
208 | N_samples,
209 | retraw=False,
210 | lindisp=False,
211 | perturb=0.0,
212 | N_importance=0,
213 | network_fine=None,
214 | white_bkgd=False,
215 | raw_noise_std=0.0,
216 | pytest=False,
217 | ):
218 | """Volumetric rendering.
219 |
220 | Args:
221 | ray_batch: array of shape [batch_size, ...]. All information necessary
222 | for sampling along a ray, including: ray origin, ray direction, min
223 | dist, max dist, and unit-magnitude viewing direction.
224 | network_fn: function. Model for predicting RGB and density at each point
225 | in space.
226 | network_query_fn: function used for passing queries to network_fn.
227 | N_samples: int. Number of different times to sample along each ray.
228 | retraw: bool. If True, include model's raw, unprocessed predictions.
229 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
230 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
231 | random points in time.
232 | N_importance: int. Number of additional times to sample along each ray.
233 | These samples are only passed to network_fine.
234 | network_fine: "fine" network with same spec as network_fn.
235 | white_bkgd: bool. If True, assume a white background.
236 |         raw_noise_std: float. Std dev of noise added to the raw density output (regularizer).
237 |
238 | Returns:
239 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
240 | disp_map: [num_rays]. Disparity map. 1 / depth.
241 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
242 | raw: [num_rays, num_samples, 4]. Raw predictions from model.
243 | rgb0: See rgb_map. Output for coarse model.
244 | disp0: See disp_map. Output for coarse model.
245 | acc0: See acc_map. Output for coarse model.
246 | z_std: [num_rays]. Standard deviation of distances along ray for each sample.
247 | """
248 | N_rays = ray_batch.shape[0]
249 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each
250 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
251 | bounds = ray_batch[..., 6:8].reshape(-1, 1, 2)
252 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1]
253 |
254 | # Generate sample bins
255 | bin_vals = F.linspace(0.0, 1.0, N_samples)
256 | if not lindisp:
257 | z_vals = near + (far - near) * bin_vals
258 | else:
259 | z_vals = 1.0 / (1.0 / near * (1.0 - bin_vals) + 1.0 / far * (bin_vals))
260 | z_vals = F.broadcast_to(z_vals, [N_rays, N_samples])
261 |
262 | if perturb > 0.0:
263 | # get intervals between samples
264 | mids = (z_vals[..., :-1] + z_vals[..., 1:]) / 2.0
265 | upper = F.concat([mids, z_vals[..., -1:]], -1)
266 | lower = F.concat([z_vals[..., :1], mids], -1)
267 | # stratified samples in those intervals
268 | t_rand = mge.random.uniform(size=z_vals.shape)
269 |
270 | # Pytest, overwrite u with numpy's fixed random numbers
271 | if pytest:
272 | np.random.seed(0)
273 | t_rand = mge.Tensor(np.random.rand(*list(z_vals.shape)))
274 |
275 | z_vals = lower + (upper - lower) * t_rand
276 |
277 | pts = F.expand_dims(rays_o, axis=-2) + F.expand_dims(rays_d, axis=-2) * F.expand_dims(z_vals, axis=-1) # noqa
278 | # shape of pts: [N_rays, N_samples, 3]
279 |
280 | raw = network_query_fn(pts, viewdirs, network_fn)
281 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
282 | raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest
283 | )
284 |
285 | if N_importance > 0:
286 |
287 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
288 |
289 | z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
290 | z_samples = sample_pdf(
291 | z_vals_mid,
292 | weights[..., 1:-1],
293 | N_importance,
294 | det=(perturb == 0.0),
295 | pytest=pytest,
296 | )
297 | z_samples = z_samples.detach()
298 |
299 | # note that sort in megengine is different from torch
300 | z_vals, _ = F.sort(F.concat([z_vals, z_samples], -1,), descending=False)
301 | pts = F.expand_dims(rays_o, -2) + F.expand_dims(rays_d, -2) * F.expand_dims(z_vals, -1)
302 | # [N_rays, N_samples + N_importance, 3]
303 |
304 | run_fn = network_fn if network_fine is None else network_fine
305 | raw = network_query_fn(pts, viewdirs, run_fn)
306 |
307 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
308 | raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest
309 | )
310 |
311 | ret = {"rgb_map": rgb_map, "disp_map": disp_map, "acc_map": acc_map}
312 | if retraw:
313 | ret["raw"] = raw
314 | if N_importance > 0:
315 | ret["rgb0"] = rgb_map_0
316 | ret["disp0"] = disp_map_0
317 | ret["acc0"] = acc_map_0
318 | ret["z_std"] = F.std(z_samples, axis=-1) # [N_rays]
319 |
320 | return ret
321 |
--------------------------------------------------------------------------------
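The core of raw2outputs is the compositing rule from the paper: the weight of sample i is its alpha times the transmittance accumulated before it, w_i = alpha_i * prod_{j<i}(1 - alpha_j). A tiny numpy sketch of that exclusive-cumprod formulation:

```python
import numpy as np

alpha = np.array([0.1, 0.5, 0.9])
# exclusive cumulative product: transmittance remaining before each sample
trans = np.cumprod(np.concatenate([[1.0], 1.0 - alpha]))[:-1]
weights = alpha * trans
assert np.allclose(weights, [0.1, 0.45, 0.405])
assert weights.sum() <= 1.0  # leftover mass is what white_bkgd fills with white
```
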
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio
2 | imageio-ffmpeg
3 | torch>=1.4  # used for torch.searchsorted in nerf/utils/helpers.py
4 | loguru
5 | configargparse
6 | tqdm
7 | opencv-python
8 | megengine
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from loguru import logger
3 |
4 | from nerf.engine import Trainer
5 |
6 |
7 | def config_parser():
8 |
9 | import configargparse
10 |
11 | parser = configargparse.ArgumentParser()
12 | parser.add_argument("--config", is_config_file=True, help="config file path")
13 | parser.add_argument("--expname", type=str, help="experiment name")
14 | parser.add_argument(
15 | "--basedir", type=str, default="./logs/", help="where to store ckpts and logs"
16 | )
17 | parser.add_argument(
18 | "--datadir", type=str, default="./data/llff/fern", help="input data directory"
19 | )
20 |
21 | # training options
22 | parser.add_argument("--netdepth", type=int, default=8, help="layers in network")
23 | parser.add_argument("--netwidth", type=int, default=256, help="channels per layer")
24 | parser.add_argument(
25 | "--netdepth_fine", type=int, default=8, help="layers in fine network"
26 | )
27 | parser.add_argument(
28 | "--netwidth_fine",
29 | type=int,
30 | default=256,
31 | help="channels per layer in fine network",
32 | )
33 | parser.add_argument(
34 | "--N_rand",
35 | type=int,
36 | default=32 * 32 * 4,
37 | help="batch size (number of random rays per gradient step)",
38 | )
39 | parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
40 | parser.add_argument(
41 | "--lrate_decay",
42 | type=int,
43 | default=250,
44 | help="exponential learning rate decay (in 1000 steps)",
45 | )
46 | parser.add_argument(
47 | "--chunk",
48 | type=int,
49 | default=1024 * 32,
50 | help="number of rays processed in parallel, decrease if running out of memory",
51 | )
52 | parser.add_argument(
53 | "--netchunk",
54 | type=int,
55 | default=1024 * 64,
56 | help="number of pts sent through network in parallel, decrease if running out of memory",
57 | )
58 | parser.add_argument(
59 | "--no_batching",
60 | action="store_true",
61 | help="only take random rays from 1 image at a time",
62 | )
63 | parser.add_argument(
64 | "--no_reload", action="store_true", help="do not reload weights from saved ckpt"
65 | )
66 | parser.add_argument(
67 | "--ft_path",
68 | type=str,
69 | default=None,
70 | help="specific weights npy file to reload for coarse network",
71 | )
72 |
73 | # rendering options
74 | parser.add_argument(
75 | "--N_samples", type=int, default=64, help="number of coarse samples per ray"
76 | )
77 | parser.add_argument(
78 | "--N_importance",
79 | type=int,
80 | default=0,
81 | help="number of additional fine samples per ray",
82 | )
83 | parser.add_argument(
84 | "--perturb",
85 | type=float,
86 | default=1.0,
87 | help="set to 0. for no jitter, 1. for jitter",
88 | )
89 | parser.add_argument(
90 | "--use_viewdirs", action="store_true", help="use full 5D input instead of 3D"
91 | )
92 | parser.add_argument(
93 | "--i_embed",
94 | type=int,
95 | default=0,
96 | help="set 0 for default positional encoding, -1 for none",
97 | )
98 | parser.add_argument(
99 | "--multires",
100 | type=int,
101 | default=10,
102 | help="log2 of max freq for positional encoding (3D location)",
103 | )
104 | parser.add_argument(
105 | "--multires_views",
106 | type=int,
107 | default=4,
108 | help="log2 of max freq for positional encoding (2D direction)",
109 | )
110 | parser.add_argument(
111 | "--raw_noise_std",
112 | type=float,
113 | default=0.0,
114 | help="std dev of noise added to regularize sigma_a output, 1e0 recommended",
115 | )
116 |
117 | parser.add_argument(
118 | "--render_only",
119 | action="store_true",
120 | help="do not optimize, reload weights and render out render_poses path",
121 | )
122 | parser.add_argument(
123 | "--render_test",
124 | action="store_true",
125 | help="render the test set instead of render_poses path",
126 | )
127 | parser.add_argument(
128 | "--render_factor",
129 | type=int,
130 | default=0,
131 | help="downsampling factor to speed up rendering, set 4 or 8 for fast preview",
132 | )
133 |
134 |     # precrop (center-crop) options
135 | parser.add_argument(
136 | "--precrop_iters",
137 | type=int,
138 | default=0,
139 | help="number of steps to train on central crops",
140 | )
141 | parser.add_argument(
142 | "--precrop_frac",
143 | type=float,
144 | default=0.5,
145 | help="fraction of img taken for central crops",
146 | )
147 |
148 | # dataset options
149 | parser.add_argument(
150 | "--dataset_type",
151 | type=str,
152 | default="llff",
153 | help="options: llff / blender / deepvoxels",
154 | )
155 | parser.add_argument(
156 | "--testskip",
157 | type=int,
158 | default=8,
159 | help="will load 1/N images from test/val sets, useful for large datasets like deepvoxels",
160 | )
161 |
162 | # deepvoxels flags
163 | parser.add_argument(
164 | "--shape",
165 | type=str,
166 | default="greek",
167 | help="options : armchair / cube / greek / vase",
168 | )
169 |
170 | # blender flags
171 | parser.add_argument(
172 | "--white_bkgd",
173 | action="store_true",
174 | help="set to render synthetic data on a white bkgd (always use for dvoxels)",
175 | )
176 | parser.add_argument(
177 | "--half_res",
178 | action="store_true",
179 | help="load blender synthetic data at 400x400 instead of 800x800",
180 | )
181 |
182 | # llff flags
183 | parser.add_argument(
184 | "--factor", type=int, default=8, help="downsample factor for LLFF images"
185 | )
186 | parser.add_argument(
187 | "--no_ndc",
188 | action="store_true",
189 | help="do not use normalized device coordinates (set for non-forward facing scenes)",
190 | )
191 | parser.add_argument(
192 | "--lindisp",
193 | action="store_true",
194 | help="sampling linearly in disparity rather than depth",
195 | )
196 | parser.add_argument(
197 | "--spherify", action="store_true", help="set for spherical 360 scenes"
198 | )
199 | parser.add_argument(
200 | "--llffhold",
201 | type=int,
202 | default=8,
203 | help="will take every 1/N images as LLFF test set, paper uses 8",
204 | )
205 |
206 | # logging/saving options
207 | parser.add_argument(
208 | "--log_interval",
209 | type=int,
210 | default=50,
211 | help="frequency of information logging",
212 | )
213 | parser.add_argument(
214 | "--i_img", type=int, default=500, help="frequency of tensorboard image logging"
215 | )
216 | parser.add_argument(
217 | "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving"
218 | )
219 | parser.add_argument(
220 | "--i_testset", type=int, default=50000, help="frequency of testset saving"
221 | )
222 | parser.add_argument(
223 | "--i_video",
224 | type=int,
225 | default=50000,
226 | help="frequency of render_poses video saving",
227 | )
228 |
229 | return parser
230 |
231 |
232 | @logger.catch
233 | def main():
234 | np.random.seed(0)
235 | args = config_parser().parse_args()
236 | trainer = Trainer(args)
237 | trainer.train()
238 |
239 |
240 | if __name__ == "__main__":
241 | main()
242 |
--------------------------------------------------------------------------------
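Because run.py builds its parser with configargparse, values can come from a config file (via --config) or the command line, with CLI flags overriding file values and both overriding the defaults above. A minimal sketch exercising the defaults:

```python
from run import config_parser

args = config_parser().parse_args([])  # no CLI flags: pure defaults
assert args.N_samples == 64
assert args.lr == 5e-4
assert args.dataset_type == "llff"
```
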
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length = 100
3 | multi_line_output = 3
4 | balanced_wrapping = True
5 | known_third_party = loguru,tabulate,yaml,cv2,numpy,PIL,pycocotools
6 | known_deeplearning = megengine,torch,torchvision
7 | known_myself = nerf
8 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,deeplearning,myself,LOCALFOLDER
9 | no_lines_before=STDLIB,THIRDPARTY
10 | default_section = FIRSTPARTY
11 |
12 | [flake8]
13 | max-line-length = 100
14 | per-file-ignores =
15 | **/__init__.py:F401,F403
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | import setuptools
3 |
4 | with open("README.md", "r") as fh:
5 | long_description = fh.read()
6 |
7 |
8 | with open("nerf/__init__.py", "r") as f:
9 | version = re.search(
10 | r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
11 | f.read(), re.MULTILINE
12 | ).group(1)
13 |
14 |
15 | with open("requirements.txt", "r") as f:
16 | reqs = [x.strip() for x in f.readlines()]
17 |
18 |
19 | setuptools.setup(
20 | name="nerf",
21 | version=version,
22 | author="FateScript",
23 | author_email="wangfeng02@megvii.com",
24 | description="Nerf implemented by megengine",
25 | long_description=long_description,
26 | long_description_content_type="text/markdown",
27 | url=None,
28 | packages=setuptools.find_packages(),
29 | classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"],
30 | install_requires=reqs,
31 | )
32 |
--------------------------------------------------------------------------------