├── requirements.txt
├── configs
│   ├── Ball.txt
│   ├── Pen.txt
│   ├── Glass.txt
│   └── WineGlass.txt
├── LICENSE
├── load_blender.py
├── load_LINEMOD.py
├── load_deepvoxels.py
├── README.md
├── load_llff.py
├── run_nerf_helpers.py
├── find_bounding_box.py
├── render_model.py
├── run_nerf_inside.py
├── run_ior.py
└── run_nerf.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu113
2 | torch==1.11.0+cu113
3 | torchvision==0.12.0+cu113
4 | imageio
5 | imageio-ffmpeg
6 | matplotlib
7 | configargparse
8 | tqdm
9 | opencv-python
10 | torchdiffeq
--------------------------------------------------------------------------------
/configs/Ball.txt:
--------------------------------------------------------------------------------
1 | expname = Ball
2 | datadir = data\Ball
3 | basedir = logs
4 | dataset_type = llff
5 |
6 | factor = 1
7 |
8 | N_samples = 64
9 | N_importance = 64
10 | N_rand = 1024
11 |
12 | spherify = True
13 | use_viewdirs = True
14 | no_ndc = True
15 | lindisp = False
16 |
17 | llffhold = 10
18 | raw_noise_std = 1.0
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/configs/Pen.txt:
--------------------------------------------------------------------------------
1 | expname = Pen
2 | datadir = data\Pen
3 | basedir = logs
4 | dataset_type = llff
5 |
6 | factor = 1
7 |
8 | N_samples = 64
9 | N_importance = 64
10 | N_rand = 1024
11 |
12 | spherify = True
13 | use_viewdirs = True
14 | no_ndc = True
15 | lindisp = False
16 |
17 | llffhold = 10
18 | raw_noise_std = 1.0
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/configs/Glass.txt:
--------------------------------------------------------------------------------
1 | expname = Glass
2 | datadir = data\Glass
3 | basedir = logs
4 | dataset_type = llff
5 |
6 | factor = 1
7 |
8 | N_samples = 64
9 | N_importance = 64
10 | N_rand = 1024
11 |
12 | spherify = True
13 | use_viewdirs = True
14 | no_ndc = True
15 | lindisp = False
16 |
17 | llffhold = 10
18 | raw_noise_std = 1.0
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/configs/WineGlass.txt:
--------------------------------------------------------------------------------
1 | expname = WineGlass
2 | datadir = data\WineGlass
3 | basedir = logs
4 | dataset_type = llff
5 |
6 | factor = 1
7 |
8 | N_samples = 64
9 | N_importance = 64
10 | N_rand = 1024
11 |
12 | spherify = True
13 | use_viewdirs = True
14 | no_ndc = True
15 | lindisp = False
16 |
17 | llffhold = 10
18 | raw_noise_std = 1.0
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Mojtaba Bemana
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 |
37 | def load_blender_data(basedir, half_res=False, testskip=1):
38 | splits = ['train', 'val', 'test']
39 | metas = {}
40 | for s in splits:
41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
42 | metas[s] = json.load(fp)
43 |
44 | all_imgs = []
45 | all_poses = []
46 | counts = [0]
47 | for s in splits:
48 | meta = metas[s]
49 | imgs = []
50 | poses = []
51 | if s=='train' or testskip==0:
52 | skip = 1
53 | else:
54 | skip = testskip
55 |
56 | for frame in meta['frames'][::skip]:
57 | fname = os.path.join(basedir, frame['file_path'] + '.png')
58 | imgs.append(imageio.imread(fname))
59 | poses.append(np.array(frame['transform_matrix']))
60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
61 | poses = np.array(poses).astype(np.float32)
62 | counts.append(counts[-1] + imgs.shape[0])
63 | all_imgs.append(imgs)
64 | all_poses.append(poses)
65 |
66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
67 |
68 | imgs = np.concatenate(all_imgs, 0)
69 | poses = np.concatenate(all_poses, 0)
70 |
71 | H, W = imgs[0].shape[:2]
72 | camera_angle_x = float(meta['camera_angle_x'])
73 | focal = .5 * W / np.tan(.5 * camera_angle_x)
74 |
75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
76 |
77 | if half_res:
78 | H = H//2
79 | W = W//2
80 | focal = focal/2.
81 |
82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
83 | for i, img in enumerate(imgs):
84 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
85 | imgs = imgs_half_res
86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
87 |
88 |
89 | return imgs, poses, render_poses, [H, W, focal], i_split
90 |
91 |
92 |
--------------------------------------------------------------------------------
/load_LINEMOD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 |
37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1):
38 | splits = ['train', 'val', 'test']
39 | metas = {}
40 | for s in splits:
41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
42 | metas[s] = json.load(fp)
43 |
44 | all_imgs = []
45 | all_poses = []
46 | counts = [0]
47 | for s in splits:
48 | meta = metas[s]
49 | imgs = []
50 | poses = []
51 | if s=='train' or testskip==0:
52 | skip = 1
53 | else:
54 | skip = testskip
55 |
56 | for idx_test, frame in enumerate(meta['frames'][::skip]):
57 | fname = frame['file_path']
58 | if s == 'test':
59 | print(f"{idx_test}th test frame: {fname}")
60 | imgs.append(imageio.imread(fname))
61 | poses.append(np.array(frame['transform_matrix']))
62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
63 | poses = np.array(poses).astype(np.float32)
64 | counts.append(counts[-1] + imgs.shape[0])
65 | all_imgs.append(imgs)
66 | all_poses.append(poses)
67 |
68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
69 |
70 | imgs = np.concatenate(all_imgs, 0)
71 | poses = np.concatenate(all_poses, 0)
72 |
73 | H, W = imgs[0].shape[:2]
74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
75 | K = meta['frames'][0]['intrinsic_matrix']
76 | print(f"Focal: {focal}")
77 |
78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
79 |
80 | if half_res:
81 | H = H//2
82 | W = W//2
83 | focal = focal/2.
84 |
85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
86 | for i, img in enumerate(imgs):
87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
88 | imgs = imgs_half_res
89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
90 |
91 | near = np.floor(min(metas['train']['near'], metas['test']['near']))
92 | far = np.ceil(max(metas['train']['far'], metas['test']['far']))
93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far
94 |
95 |
96 |
--------------------------------------------------------------------------------
/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 |
8 |
9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
10 | # Get camera intrinsics
11 | with open(filepath, 'r') as file:
12 | f, cx, cy = list(map(float, file.readline().split()))[:3]
13 | grid_barycenter = np.array(list(map(float, file.readline().split())))
14 | near_plane = float(file.readline())
15 | scale = float(file.readline())
16 | height, width = map(float, file.readline().split())
17 |
18 | try:
19 | world2cam_poses = int(file.readline())
20 | except ValueError:
21 | world2cam_poses = None
22 |
23 | if world2cam_poses is None:
24 | world2cam_poses = False
25 |
26 | world2cam_poses = bool(world2cam_poses)
27 |
28 | print(cx,cy,f,height,width)
29 |
30 | cx = cx / width * trgt_sidelength
31 | cy = cy / height * trgt_sidelength
32 | f = trgt_sidelength / height * f
33 |
34 | fx = f
35 | if invert_y:
36 | fy = -f
37 | else:
38 | fy = f
39 |
40 | # Build the intrinsic matrices
41 | full_intrinsic = np.array([[fx, 0., cx, 0.],
42 | [0., fy, cy, 0],
43 | [0., 0, 1, 0],
44 | [0, 0, 0, 1]])
45 |
46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
47 |
48 |
49 | def load_pose(filename):
50 | assert os.path.isfile(filename)
51 | nums = open(filename).read().split()
52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
53 |
54 |
55 | H = 512
56 | W = 512
57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
58 |
59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
61 | focal = full_intrinsic[0,0]
62 | print(H, W, focal)
63 |
64 |
65 | def dir2poses(posedir):
66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
67 | transf = np.array([
68 | [1,0,0,0],
69 | [0,-1,0,0],
70 | [0,0,-1,0],
71 | [0,0,0,1.],
72 | ])
73 | poses = poses @ transf
74 | poses = poses[:,:3,:4].astype(np.float32)
75 | return poses
76 |
77 | posedir = os.path.join(deepvoxels_base, 'pose')
78 | poses = dir2poses(posedir)
79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
80 | testposes = testposes[::testskip]
81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
82 | valposes = valposes[::testskip]
83 |
84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
86 |
87 |
88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
91 |
92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
95 |
96 | all_imgs = [imgs, valimgs, testimgs]
97 | counts = [0] + [x.shape[0] for x in all_imgs]
98 | counts = np.cumsum(counts)
99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
100 |
101 | imgs = np.concatenate(all_imgs, 0)
102 | poses = np.concatenate([poses, valposes, testposes], 0)
103 |
104 | render_poses = testposes
105 |
106 | print(poses.shape, imgs.shape)
107 |
108 | return imgs, poses, render_poses, [H,W,focal], i_split
109 |
110 |
111 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Eikonal Fields for Refractive Novel-View Synthesis [[Project Page]](https://eikonalfield.mpi-inf.mpg.de)
2 |
3 |
4 |
5 |
6 |
7 |
8 | ## Installation
9 |
10 | ```
11 | conda create -n eikonalfield python=3.8
12 | conda activate eikonalfield
13 | pip install -r requirements.txt
14 | ```
15 |
16 |
17 |
18 |
19 | ### Dependencies
20 | * torch>=1.8
21 | * torchvision>=0.9.1
22 | * matplotlib
23 | * imageio
24 | * imageio-ffmpeg
25 | * configargparse
26 | * tqdm
27 | * opencv-python
28 | * [torchdiffeq](https://github.com/rtqichen/torchdiffeq)
29 |
30 |
31 |
32 | ## Dataset
33 |
34 |
35 |
36 |
37 | * [``Ball``](https://eikonalfield.mpi-inf.mpg.de//datasets/Ball.zip)
38 | * [``Glass``](https://eikonalfield.mpi-inf.mpg.de//datasets/Glass.zip)
39 | * [``Pen``](https://eikonalfield.mpi-inf.mpg.de//datasets/Pen.zip)
40 | * [``WineGlass``](https://eikonalfield.mpi-inf.mpg.de//datasets/WineGlass.zip)
41 |
42 | Each dataset contains the captured images and a short video of the scene.
43 | In the `captured images` folder, we provide the images at the original 4K resolution and at a smaller resolution, together with camera poses estimated using COLMAP and the [LLFF code](https://github.com/fyusion/llff). The `captured video` folder contains the video frames with their estimated camera poses.
44 |
45 |
46 | ## Training
47 |
48 | * __Step 0__: Finding the camera poses ``poses_bounds.npy`` following the instructions given [here](https://github.com/bmild/nerf#generating-poses-for-your-own-scenes)
49 |
50 | (For our dataset the camera parameters are already provided!)
51 | * __Step 1__: Estimating the radiance field for the entire scene by running ``run_nerf.py`` (The code is borrowed from [``nerf-pytorch``](https://github.com/yenchenlin/nerf-pytorch))
52 | ```
53 | python run_nerf.py --config configs/Glass.txt
54 | ```
55 | The config files for the scenes in our dataset are located in the `configs` folder.
56 |
57 | * __Step 2__: Finding the 3D bounding box containing the transparent object using ``find_bounding_box.py``.
58 | ```
59 | python find_bounding_box.py --config configs/Glass.txt
60 | ```
61 |
62 |
63 | In this step, 1/10 of the training images are displayed so that you can mark a few points at the extent of the transparent object.
64 |
65 |
66 |
67 |
68 |
69 |
70 | * __Step 3__: Learning the index of refraction (IoR) field with ``run_ior.py``
71 | ```
72 | python run_ior.py --config configs/Glass.txt --N_rand 32000 --N_samples 128
73 | ```
74 | * __Step 4__: Learning the radiance field inside the transparent object using ``run_nerf_inside.py``
75 | ```
76 | python run_nerf_inside.py --config configs/Glass.txt --N_samples 512
77 | ```
78 | (Please note that for the Ball scene we skipped this step)
79 |
80 | ## Rendering
81 |
82 | Please run ``render_model.py`` with different modes to render the models learned at each training step.
83 |
84 | ```
85 | python render_model.py --config configs/Glass.txt --N_samples 512 --mode 1 --render_video
86 | ```
87 |
88 | The rendering options are:
89 |
90 | ```
91 | --mode # use 0 to render the output of step 1 (Original NeRF)
92 | # use 1 to render the output of step 3 (Learned IoR)
93 | # use 2 to render the output of step 4 (Complete model with the inside NeRF)
94 | --render_test # rendering the test set images
95 | --render_video # rendering a video from a precomputed path
96 | --render_from_path # rendering a video from a specified path
97 | ```
98 | ## Models
99 |
100 | Please find below our results and pre-trained model for each scene:
101 | * [``Ball``](https://eikonalfield.mpi-inf.mpg.de//results/Ball.zip)
102 | * [``Glass``](https://eikonalfield.mpi-inf.mpg.de//results/Glass.zip)
103 | * [``Pen``](https://eikonalfield.mpi-inf.mpg.de//results/Pen.zip)
104 | * [``WineGlass``](https://eikonalfield.mpi-inf.mpg.de//results/WineGlass.zip)
105 |
106 | Each scene contains the following folders:
107 |
108 | * ``model_weights`` --> the pre-trained model
109 | * ``bounding_box`` --> the parameters of the bounding box
110 | * ``masked_regions`` --> the masked images identifying the regions crossing the bounding box in each view
111 | * ``rendered_from_a_path`` --> the rendered video result along the camera trajectory of the real video capture
112 |
113 |
114 | ## Details
115 | ### Capturing
116 | Our method works with a general capturing setup and does not require any calibration pattern. We capture the scene spherically and get close enough to the transparent object to sample it properly.
117 |
118 |
119 | ### Bounding Box
120 | Our bounding box (BB) is a rectangular cuboid parameterized by its center $c = (c_x,c_y,c_z)$ and the distances from the center to a face in each dimension $d=(d_x,d_y,d_z)$.
121 | For a 3D point $(x,y,z)$ in space, the bounding box function is expressed analytically as:
122 |
123 | $$ BB(x,y,z) = 1 - S(\beta\,(d_x - |x - c_x|)) \cdot S(\beta\,(d_y - |y - c_y|)) \cdot S(\beta\,(d_z - |z - c_z|)) $$
124 |
125 |
126 | where
127 | $S(x)=\frac{1}{1+e^{-x}}$ is the sigmoid function and $\beta$ is the steepness coefficient. We use $\beta=200$ in our experiments. With this function, a point inside the box gets a value close to zero and points outside get a value close to one.
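
For reference, the repository implements this soft box as `get_bb_weights` in `run_nerf_helpers.py`. A minimal, self-contained sketch of the same idea (function and argument names here are illustrative, not the repo's API) is:

```python
import torch

def bb_weights(pts, center, d, beta=200.0):
    """Soft bounding-box indicator: close to 0 inside the box, close to 1 outside.

    pts:    [..., 3] query points
    center: (c_x, c_y, c_z), box center
    d:      (d_x, d_y, d_z), distance from the center to the faces
    """
    center = torch.as_tensor(center).to(pts)
    d = torch.as_tensor(d).to(pts)
    dist = torch.abs(pts - center)                      # |x - c| per axis
    inside = torch.sigmoid(beta * (d - dist)).prod(-1)  # product of per-axis sigmoids
    return 1.0 - inside
```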
128 |
129 | ### Voxel grid
130 | In our IoR optimization we first need to smooth the learned radiance field; however, explicitly smoothing an MLP-based radiance field is not straightforward,
131 | so we instead fit a uniform 3D grid to the learned radiance field. We then band-limit the grid in the Fourier domain using a Gaussian blur kernel to obtain a coarse-to-fine radiance field.
132 | Note that we fit the voxel grid to the NeRF coarse model rather than the fine one to avoid aliasing, and for the spherical captures we limit the scene far bound to 2.5.
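
The band-limiting step corresponds to `lowpass_3d` and `voxel_lowpass_filtering` in `run_nerf_helpers.py`. A simplified sketch of the idea, assuming a voxel grid of shape `[res, res, res, C]` (names here are illustrative), is:

```python
import torch

def gaussian_lowpass_3d(res, sigma):
    # Centered Gaussian transfer function on a res^3 frequency grid in [0, 1]^3.
    coords = torch.linspace(0.0, 1.0, res)
    xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing='ij')
    dist2 = (xx - 0.5) ** 2 + (yy - 0.5) ** 2 + (zz - 0.5) ** 2
    return torch.exp(-dist2 / (2.0 * sigma ** 2))

def bandlimit_voxels(grid, sigma):
    # grid: [res, res, res, C]; filter every channel in the Fourier domain.
    filt = gaussian_lowpass_3d(grid.shape[0], sigma).to(grid)
    spec = torch.fft.fftshift(torch.fft.fftn(grid, dim=(0, 1, 2)), dim=(0, 1, 2))
    spec = spec * filt[..., None]
    out = torch.fft.ifftn(torch.fft.ifftshift(spec, dim=(0, 1, 2)), dim=(0, 1, 2))
    return out.real
```

Decreasing `sigma` removes more high frequencies, which gives the coarser levels of the coarse-to-fine radiance field.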
133 |
134 | ### IoR optimization
135 | Since our complete volume rendering model takes the form of ODEs, we use the differentiable ODE solver package [torchdiffeq](https://github.com/rtqichen/torchdiffeq) (Neural ODE) to backpropagate through the ODEs. Using this package, the memory cost of training is independent of the step count, which allows processing more rays (up to 32k) in each iteration.
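
The constant-memory backpropagation comes from torchdiffeq's adjoint method. A toy usage sketch (the dynamics below are a stand-in, not the project's actual eikonal ray equations) is:

```python
import torch
from torchdiffeq import odeint_adjoint as odeint  # adjoint: memory independent of step count

class ToyRayDynamics(torch.nn.Module):
    """Stand-in ODE right-hand side; the project integrates the eikonal ray equations."""
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(6, 64), torch.nn.Softplus(beta=5), torch.nn.Linear(64, 6))

    def forward(self, t, state):      # state: [N_rays, 6] (position and direction)
        return self.net(state)        # d(state)/ds

func = ToyRayDynamics()
state0 = torch.zeros(32, 6)           # a small batch of rays
s = torch.linspace(0.0, 1.0, 128)     # integration steps along the ray
traj = odeint(func, state0, s)        # [128, 32, 6]; gradients flow through the adjoint ODE
traj[-1].pow(2).mean().backward()     # backprop without storing all intermediate solver steps
```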
136 |
137 | ### IoR model
138 | When using differentiable ODE solvers, we found it very important to use a smooth non-linear activation such as Softplus in our IoR MLP; otherwise the optimization becomes unstable.
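
In the repository this corresponds to the `MLP_IOR` class in `run_nerf_helpers.py`, which applies a smooth activation after every layer:

```python
import torch.nn as nn

# Smooth activation used in the IoR MLP (see MLP_IOR in run_nerf_helpers.py);
# per the note above, non-smooth activations such as ReLU make the ODE-based optimization unstable.
act = nn.Softplus(beta=5)
```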
139 |
140 |
141 | ### Rendering
142 | Since we could not use hierarchical sampling in our ODE-based volume rendering, we take
143 | 512 steps along each ray to properly sample both the interior and exterior radiance fields.
144 |
145 |
146 |
147 | ## Citation
148 |
149 | @inproceedings{bemana2022eikonal,
150 | title={Eikonal Fields for Refractive Novel-View Synthesis},
151 | author={Bemana, Mojtaba and Myszkowski, Karol and Revall Frisvad, Jeppe and Seidel, Hans-Peter and Ritschel, Tobias},
152 | booktitle={Special Interest Group on Computer Graphics and Interactive Techniques Conference Proceedings},
153 | pages={1--9},
154 | year={2022}
155 |     }
156 | ## Contact
157 | mbemana@mpi-inf.mpg.de
158 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, imageio
3 |
4 |
5 | ########## Slightly modified version of LLFF data loading code
6 | ########## see https://github.com/Fyusion/LLFF for original
7 |
8 | def _minify(basedir, factors=[], resolutions=[]):
9 | needtoload = False
10 | for r in factors:
11 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
12 | if not os.path.exists(imgdir):
13 | needtoload = True
14 | for r in resolutions:
15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
16 | if not os.path.exists(imgdir):
17 | needtoload = True
18 | if not needtoload:
19 | return
20 |
21 | from shutil import copy
22 | from subprocess import check_output
23 |
24 | imgdir = os.path.join(basedir, 'images')
25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
27 | imgdir_orig = imgdir
28 | wd = os.getcwd()
29 |
30 | for r in factors + resolutions:
31 | if isinstance(r, int):
32 | name = 'images_{}'.format(r)
33 | resizearg = '{}%'.format(100./r)
34 | else:
35 | name = 'images_{}x{}'.format(r[1], r[0])
36 | resizearg = '{}x{}'.format(r[1], r[0])
37 | imgdir = os.path.join(basedir, name)
38 | if os.path.exists(imgdir):
39 | continue
40 |
41 | print('Minifying', r, basedir)
42 |
43 | os.makedirs(imgdir)
44 | check_output('copy {}\* {}'.format(imgdir_orig, imgdir), shell=True)
45 |
46 | ext = imgs[0].split('.')[-1]
47 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
48 | print(args)
49 | os.chdir(imgdir)
50 | check_output(args, shell=True)
51 | os.chdir(wd)
52 |
53 | if ext != 'png':
54 | check_output('del {}\*.{}'.format(imgdir, ext), shell=True)
55 | print('Removed duplicates')
56 | print('Done')
57 |
58 |
59 |
60 |
61 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
62 |
63 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
64 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
65 | bds = poses_arr[:, -2:].transpose([1,0])
66 |
67 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
68 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
69 | sh = imageio.imread(img0).shape
70 |
71 | sfx = ''
72 |
73 | if factor is not None and factor != 1:
74 | sfx = '_{}'.format(factor)
75 | _minify(basedir, factors=[factor])
76 | factor = factor
77 | elif height is not None:
78 | factor = sh[0] / float(height)
79 | width = int(sh[1] / factor)
80 | _minify(basedir, resolutions=[[height, width]])
81 | sfx = '_{}x{}'.format(width, height)
82 | elif width is not None:
83 | factor = sh[1] / float(width)
84 | height = int(sh[0] / factor)
85 | _minify(basedir, resolutions=[[height, width]])
86 | sfx = '_{}x{}'.format(width, height)
87 | else:
88 | factor = 1
89 |
90 | imgdir = os.path.join(basedir, 'images' + sfx)
91 | if not os.path.exists(imgdir):
92 | print( imgdir, 'does not exist, returning' )
93 | return
94 |
95 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
96 | # if poses.shape[-1] != len(imgfiles):
97 | # print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
98 | # return
99 |
100 | sh = imageio.imread(imgfiles[0]).shape
101 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
102 | poses[2, 4, :] = poses[2, 4, :] * 1./factor
103 |
104 | if not load_imgs:
105 | return poses, bds
106 |
107 | def imread(f):
108 | if f.endswith('png'):
109 | return imageio.imread(f, ignoregamma=True)
110 | else:
111 | return imageio.imread(f)
112 |
113 |     imgs = [imread(f)[...,:3]/255. for f in imgfiles]
114 | imgs = np.stack(imgs, -1)
115 |
116 | print('Loaded image data', imgs.shape, poses[:,-1,0])
117 | return poses, bds, imgs
118 |
119 |
120 |
121 |
122 |
123 |
124 | def normalize(x):
125 | return x / np.linalg.norm(x)
126 |
127 | def viewmatrix(z, up, pos):
128 | vec2 = normalize(z)
129 | vec1_avg = up
130 | vec0 = normalize(np.cross(vec1_avg, vec2))
131 | vec1 = normalize(np.cross(vec2, vec0))
132 | m = np.stack([vec0, vec1, vec2, pos], 1)
133 | return m
134 |
135 | def ptstocam(pts, c2w):
136 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
137 | return tt
138 |
139 | def poses_avg(poses):
140 |
141 | hwf = poses[0, :3, -1:]
142 |
143 | center = poses[:, :3, 3].mean(0)
144 | vec2 = normalize(poses[:, :3, 2].sum(0))
145 | up = poses[:, :3, 1].sum(0)
146 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
147 |
148 | return c2w
149 |
150 |
151 |
152 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
153 | render_poses = []
154 | rads = np.array(list(rads) + [1.])
155 | hwf = c2w[:,4:5]
156 |
157 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
158 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
159 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
160 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
161 | return render_poses
162 |
163 |
164 |
165 | def recenter_poses(poses):
166 |
167 | poses_ = poses+0
168 | bottom = np.reshape([0,0,0,1.], [1,4])
169 | c2w = poses_avg(poses)
170 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
171 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
172 | poses = np.concatenate([poses[:,:3,:4], bottom], -2)
173 |
174 | poses = np.linalg.inv(c2w) @ poses
175 | poses_[:,:3,:4] = poses[:,:3,:4]
176 | poses = poses_
177 | return poses
178 |
179 |
180 | #####################
181 |
182 |
183 | def spherify_poses(poses, bds):
184 |
185 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
186 |
187 | rays_d = poses[:,:3,2:3]
188 | rays_o = poses[:,:3,3:4]
189 |
190 | def min_line_dist(rays_o, rays_d):
191 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
192 | b_i = -A_i @ rays_o
193 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
194 | return pt_mindist
195 |
196 | pt_mindist = min_line_dist(rays_o, rays_d)
197 |
198 | center = pt_mindist
199 | up = (poses[:,:3,3] - center).mean(0)
200 |
201 | vec0 = normalize(up)
202 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
203 | vec2 = normalize(np.cross(vec0, vec1))
204 | pos = center
205 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
206 |
207 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
208 |
209 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
210 |
211 | sc = 1./rad
212 | poses_reset[:,:3,3] *= sc
213 | bds *= sc
214 | rad *= sc
215 |
216 | centroid = np.mean(poses_reset[:,:3,3], 0)
217 | zh = centroid[2]
218 | radcircle = np.sqrt(rad**2-zh**2)
219 | new_poses = []
220 |
221 | for th in np.linspace(0.,2.*np.pi, 120):
222 |
223 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
224 | up = np.array([0,0,-1.])
225 |
226 | vec2 = normalize(camorigin)
227 | vec0 = normalize(np.cross(vec2, up))
228 | vec1 = normalize(np.cross(vec2, vec0))
229 | pos = camorigin
230 | p = np.stack([vec0, vec1, vec2, pos], 1)
231 |
232 | new_poses.append(p)
233 |
234 | new_poses = np.stack(new_poses, 0)
235 |
236 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
237 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
238 |
239 | return poses_reset, new_poses, bds
240 |
241 |
242 | def register_video_path(basedir,factor,poses,scale):
243 |
244 | poses_video, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
245 | poses_video = np.concatenate([poses_video[:, 1:2, :], -poses_video[:, 0:1, :], poses_video[:, 2:, :]], 1)
246 | poses_video = np.moveaxis(poses_video, -1, 0).astype(np.float32)
247 | # poses_video = poses_video[84:]
248 | poses_video[:,:3,3] *= scale
249 |
250 | poses_ = poses_video+0
251 | bottom = np.reshape([0,0,0,1.], [1,4])
252 | c2w = poses_avg(poses)
253 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
254 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses_video.shape[0],1,1])
255 | poses_video = np.concatenate([poses_video[:,:3,:4], bottom], -2)
256 |
257 | poses_video = np.linalg.inv(c2w) @ poses_video
258 | poses_[:,:3,:4] = poses_video[:,:3,:4]
259 | poses_video = poses_
260 |
261 | poses = recenter_poses(poses)
262 |
263 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
264 |
265 | rays_d = poses[:,:3,2:3]
266 | rays_o = poses[:,:3,3:4]
267 |
268 | def min_line_dist(rays_o, rays_d):
269 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
270 | b_i = -A_i @ rays_o
271 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
272 | return pt_mindist
273 |
274 | pt_mindist = min_line_dist(rays_o, rays_d)
275 |
276 | center = pt_mindist
277 | up = (poses[:,:3,3] - center).mean(0)
278 |
279 | vec0 = normalize(up)
280 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
281 | vec2 = normalize(np.cross(vec0, vec1))
282 | pos = center
283 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
284 |
285 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses_video[:,:3,:4])
286 |
287 | poses_reset_ = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
288 |
289 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset_[:,:3,3]), -1)))
290 |
291 | sc = 1./rad
292 | poses_reset[:,:3,3] *= sc
293 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses_video[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
294 |
295 | poses_reset = poses_reset.astype(np.float32)
296 |
297 |
298 | return poses_reset
299 |
300 |
301 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_video=None, path_zflat=False):
302 |
303 |
304 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
305 | print('Loaded', basedir, bds.min(), bds.max())
306 |
307 | # Correct rotation matrix ordering and move variable dim to axis 0
308 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
309 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
310 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
311 | images = imgs
312 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
313 |
314 | # Rescale if bd_factor is provided
315 | # sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
316 | sc = 1./(bds.max() - bds.min())
317 | print(sc)
318 | poses[:,:3,3] *= sc
319 | bds *= sc
320 | print('Loaded', basedir, bds.min(), bds.max())
321 |
322 | if path_video is not None:
323 | render_path = register_video_path(path_video,factor,poses,sc)
324 |
325 |
326 | if recenter:
327 | poses = recenter_poses(poses)
328 |
329 | if spherify:
330 | poses, render_poses, bds = spherify_poses(poses, bds)
331 |
332 | else:
333 |
334 | c2w = poses_avg(poses)
335 | print('recentered', c2w.shape)
336 | print(c2w[:3,:4])
337 |
338 | ## Get spiral
339 | # Get average pose
340 | up = normalize(poses[:, :3, 1].sum(0))
341 |
342 | # Find a reasonable "focus depth" for this dataset
343 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
344 | dt = .75
345 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
346 | focal = mean_dz
347 |
348 | # Get radii for spiral path
349 | shrink_factor = .8
350 | zdelta = close_depth * .2
351 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T
352 | rads = np.percentile(np.abs(tt), 90, 0)
353 | c2w_path = c2w
354 | N_views = 120
355 | N_rots = 2
356 | if path_zflat:
357 | # zloc = np.percentile(tt, 10, 0)[2]
358 | zloc = -close_depth * .1
359 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
360 | rads[2] = 0.
361 | N_rots = 1
362 | N_views/=2
363 |
364 | # Generate poses for spiral path
365 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
366 |
367 |
368 | render_poses = np.array(render_poses).astype(np.float32)
369 |
370 | c2w = poses_avg(poses)
371 | print('Data:')
372 | print(poses.shape, images.shape, bds.shape)
373 |
374 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
375 | i_test = np.argmin(dists)
376 | print('HOLDOUT view is', i_test)
377 |
378 | images = images.astype(np.float32)
379 | poses = poses.astype(np.float32)
380 |
381 | if path_video is None:
382 | render_path = render_poses
383 |
384 |
385 | return images, poses, bds, render_path, i_test
386 |
387 |
388 |
389 |
--------------------------------------------------------------------------------
/run_nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from tqdm import tqdm, trange
8 |
9 | # Misc
10 | img2mse = lambda x, y : torch.mean((x - y) ** 2)
11 | img2abs = lambda x, y : torch.mean(torch.abs(x - y))
12 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
13 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
14 |
15 |
16 | # Positional encoding (section 5.1)
17 | class Embedder:
18 | def __init__(self, **kwargs):
19 | self.kwargs = kwargs
20 | self.create_embedding_fn()
21 |
22 | def create_embedding_fn(self):
23 | embed_fns = []
24 | d = self.kwargs['input_dims']
25 | out_dim = 0
26 | if self.kwargs['include_input']:
27 | embed_fns.append(lambda x : x)
28 | out_dim += d
29 |
30 | max_freq = self.kwargs['max_freq_log2']
31 | N_freqs = self.kwargs['num_freqs']
32 |
33 | if self.kwargs['log_sampling']:
34 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
35 | else:
36 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
37 |
38 | for freq in freq_bands:
39 | for p_fn in self.kwargs['periodic_fns']:
40 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
41 | out_dim += d
42 |
43 | self.embed_fns = embed_fns
44 | self.out_dim = out_dim
45 |
46 | def embed(self, inputs):
47 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
48 |
49 |
50 | def get_embedder(multires, i=0):
51 | if i == -1:
52 | return nn.Identity(), 3
53 |
54 | embed_kwargs = {
55 | 'include_input' : True,
56 | 'input_dims' : 3,
57 | 'max_freq_log2' : multires-1,
58 | 'num_freqs' : multires,
59 | 'log_sampling' : True,
60 | 'periodic_fns' : [torch.sin, torch.cos],
61 | }
62 |
63 | embedder_obj = Embedder(**embed_kwargs)
64 | embed = lambda x, eo=embedder_obj : eo.embed(x)
65 | return embed, embedder_obj.out_dim
66 |
67 |
68 | # Model
69 | class NeRF(nn.Module):
70 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
71 |         """Standard NeRF MLP: D layers of width W with skip connections at `skips`,
72 |         optionally conditioned on the viewing direction when use_viewdirs is True."""
73 | super(NeRF, self).__init__()
74 | self.D = D
75 | self.W = W
76 | self.input_ch = input_ch
77 | self.input_ch_views = input_ch_views
78 | self.skips = skips
79 | self.use_viewdirs = use_viewdirs
80 |
81 | self.pts_linears = nn.ModuleList(
82 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
83 |
84 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
85 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
86 |
87 | ### Implementation according to the paper
88 | # self.views_linears = nn.ModuleList(
89 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
90 |
91 | if use_viewdirs:
92 | self.feature_linear = nn.Linear(W, W)
93 | self.alpha_linear = nn.Linear(W, 1)
94 | self.rgb_linear = nn.Linear(W//2, 3)
95 | else:
96 | self.output_linear = nn.Linear(W, output_ch)
97 |
98 | def forward(self, x):
99 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
100 | h = input_pts
101 | for i, l in enumerate(self.pts_linears):
102 | h = self.pts_linears[i](h)
103 | h = F.relu(h)
104 | if i in self.skips:
105 | h = torch.cat([input_pts, h], -1)
106 |
107 | if self.use_viewdirs:
108 | alpha = self.alpha_linear(h)
109 | feature = self.feature_linear(h)
110 | h = torch.cat([feature, input_views], -1)
111 |
112 | for i, l in enumerate(self.views_linears):
113 | h = self.views_linears[i](h)
114 | h = F.relu(h)
115 |
116 | rgb = self.rgb_linear(h)
117 | outputs = torch.cat([rgb, alpha], -1)
118 | else:
119 | outputs = self.output_linear(h)
120 |
121 | return outputs
122 |
123 | def load_weights_from_keras(self, weights):
124 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
125 |
126 | # Load pts_linears
127 | for i in range(self.D):
128 | idx_pts_linears = 2 * i
129 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
130 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
131 |
132 | # Load feature_linear
133 | idx_feature_linear = 2 * self.D
134 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
135 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
136 |
137 | # Load views_linears
138 | idx_views_linears = 2 * self.D + 2
139 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
140 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
141 |
142 | # Load rgb_linear
143 | idx_rbg_linear = 2 * self.D + 4
144 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
145 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))
146 |
147 | # Load alpha_linear
148 | idx_alpha_linear = 2 * self.D + 6
149 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
150 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
151 |
152 |
153 |
154 | def init_weights(m):
155 | if type(m) == nn.Linear:
156 | torch.nn.init.xavier_uniform_(m.weight)
157 |
158 | def init_weights_zeros_(m):
159 | if type(m) == nn.Linear:
160 | torch.nn.init.zeros_(m.weight)
161 | torch.nn.init.zeros_(m.bias)
162 |
163 |
164 |
165 | class MLP_IOR(nn.Module):
166 | def __init__(self, D=8, W=64, input_ch=3 + 3*2*6, output_ch=1, skips=[4],is_index=False):
167 |         """MLP predicting the index-of-refraction (IoR) field; Softplus activations
168 |         keep the network smooth for the ODE-based ray integration."""
169 | super(MLP_IOR, self).__init__()
170 | self.D = D
171 | self.W = W
172 | self.input_ch = input_ch
173 | self.skips = skips
174 | self.is_index = is_index
175 | self.mlp_ior = nn.ModuleList(
176 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
177 |
178 |
179 | self.mlp_ior_end = nn.Linear(W, output_ch)
180 | self.softplus = nn.Softplus(beta=5)
181 | # self.IOR = nn.Parameter(0.5*torch.ones(1))
182 | def forward(self, x):
183 | input_pts = x
184 | h = input_pts
185 | for i, l in enumerate(self.mlp_ior):
186 | h = self.mlp_ior[i](h)
187 | h = self.softplus(h)
188 | if i in self.skips:
189 | h = torch.cat([input_pts, h], -1)
190 |
191 |
192 | outputs = self.mlp_ior_end(h)
193 | shape_out = self.softplus(outputs)
194 |
195 | return shape_out
196 | # Ray helpers
197 | def get_rays(H, W, K, c2w):
198 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
199 | i = i.t()
200 | j = j.t()
201 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
202 | # Rotate ray directions from camera frame to the world frame
203 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
204 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
205 | rays_o = c2w[:3,-1].expand(rays_d.shape)
206 | return rays_o, rays_d
207 |
208 |
209 | def get_rays_np(H, W, K, c2w):
210 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
211 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
212 | # Rotate ray directions from camera frame to the world frame
213 |
214 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
215 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
216 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
217 | return rays_o, rays_d
218 |
219 |
220 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
221 | # Shift ray origins to near plane
222 | t = -(near + rays_o[...,2]) / rays_d[...,2]
223 | rays_o = rays_o + t[...,None] * rays_d
224 |
225 | # Projection
226 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
227 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
228 | o2 = 1. + 2. * near / rays_o[...,2]
229 |
230 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
231 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
232 | d2 = -2. * near / rays_o[...,2]
233 |
234 | rays_o = torch.stack([o0,o1,o2], -1)
235 | rays_d = torch.stack([d0,d1,d2], -1)
236 |
237 | return rays_o, rays_d
238 |
239 |
240 | # Hierarchical sampling (section 5.2)
241 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
242 | # Get pdf
243 | weights = weights + 1e-5 # prevent nans
244 | pdf = weights / torch.sum(weights, -1, keepdim=True)
245 | cdf = torch.cumsum(pdf, -1)
246 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
247 |
248 | # Take uniform samples
249 | if det:
250 | u = torch.linspace(0., 1., steps=N_samples)
251 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
252 | else:
253 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
254 |
255 | # Pytest, overwrite u with numpy's fixed random numbers
256 | if pytest:
257 | np.random.seed(0)
258 | new_shape = list(cdf.shape[:-1]) + [N_samples]
259 | if det:
260 | u = np.linspace(0., 1., N_samples)
261 | u = np.broadcast_to(u, new_shape)
262 | else:
263 | u = np.random.rand(*new_shape)
264 | u = torch.Tensor(u)
265 |
266 | # Invert CDF
267 | u = u.contiguous()
268 | inds = torch.searchsorted(cdf, u, right=True)
269 | below = torch.max(torch.zeros_like(inds-1), inds-1)
270 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
271 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
272 |
273 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
274 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
275 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
276 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
277 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
278 |
279 | denom = (cdf_g[...,1]-cdf_g[...,0])
280 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
281 | t = (u-cdf_g[...,0])/denom
282 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
283 |
284 | return samples
285 |
286 | dtype_long = torch.cuda.LongTensor
287 |
288 | def trilinear_interpolation(inputs,xq,scene_bound):
289 |
290 | res = inputs.shape[0]
291 | _min = scene_bound[0,:]
292 | _max = scene_bound[1,:]
293 | xq = (res-1)*(xq-_min)/(_max-_min)
294 | xq = xq.clip(0,res-1)
295 | x = xq[:,0]
296 | y = xq[:,1]
297 | z = xq[:,2]
298 |
299 | x_0 = torch.clamp(torch.floor(x).type(dtype_long),0, res-1)
300 | x_1 = torch.clamp(x_0 + 1, 0, res-1)
301 |
302 | y_0 = torch.clamp(torch.floor(y).type(dtype_long),0, res-1)
303 | y_1 = torch.clamp(y_0 + 1, 0, res-1)
304 |
305 | z_0 = torch.clamp(torch.floor(z).type(dtype_long),0, res-1)
306 | z_1 = torch.clamp(z_0 + 1, 0, res-1)
307 |
308 |
309 | u, v, w = x-x_0, y-y_0, z-z_0
310 | u = u[:,None]
311 | v = v[:,None]
312 | w = w[:,None]
313 |
314 | c_000 = inputs[x_0,y_0,z_0]
315 | c_001 = inputs[x_0,y_0,z_1]
316 | c_010 = inputs[x_0,y_1,z_0]
317 | c_011 = inputs[x_0,y_1,z_1]
318 | c_100 = inputs[x_1,y_0,z_0]
319 | c_101 = inputs[x_1,y_0,z_1]
320 | c_110 = inputs[x_1,y_1,z_0]
321 | c_111 = inputs[x_1,y_1,z_1]
322 | # print(c_111.shape)
323 | # print(u.shape)
324 |
325 | c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \
326 | (1.0-u)*(1.0-v)*(w)*c_001 + \
327 | (1.0-u)*(v)*(1.0-w)*c_010 + \
328 | (1.0-u)*(v)*(w)*c_011 + \
329 | (u)*(1.0-v)*(1.0-w)*c_100 + \
330 | (u)*(1.0-v)*(w)*c_101 + \
331 | (u)*(v)*(1.0-w)*c_110 + \
332 | (u)*(v)*(w)*c_111
333 |
334 |
335 |
336 | return c_xyz
337 |
338 |
339 |
340 | def lowpass_3d(res,sigma):
341 |
342 | res = res-1
343 | xx, yy,zz = torch.meshgrid(torch.linspace(0, res, res+1), torch.linspace(0, res, res+1), torch.linspace(0, res, res+1))
344 | xx = xx/res
345 | yy = yy/res
346 | zz = zz/res
347 | dist = (torch.square(xx-0.5)+torch.square(yy-0.5)+torch.square(zz-0.5))
348 |
349 | return torch.exp(-dist/(2*sigma**2))
350 |
351 |
352 | def voxel_lowpass_filtering(input_voxel,filter_3d):
353 |
354 | fftn_grid = torch.fft.fftshift(torch.fft.fftn(input_voxel,input_voxel.shape,norm='forward'))
355 | filtered_grid = torch.fft.ifftn(torch.fft.ifftshift(fftn_grid*filter_3d[...,None].to(input_voxel)),norm='forward')
356 |
357 |
358 | return filtered_grid.real
359 |
360 |
361 | eps = 1e-6
362 | def normalizing(vec):
363 |
364 | return vec/(torch.linalg.norm(vec, dim=-1, keepdim=True)+eps)
365 |
366 |
367 |
368 | def get_scene_bound(near,far,H,W,K,poses,min_=.25,max_=.25):
369 |
370 |
371 | N_samples = 2
372 | t_vals = torch.linspace(0., 1., steps=N_samples)
373 | z_vals = near * (1.-t_vals) + far * (t_vals)
374 |
375 | z_vals = z_vals.expand([H,W, N_samples])
376 | scene_bound_min = []
377 | scene_bound_max = []
378 |
379 | for cams in tqdm(range(len(poses))):
380 | rays_o, rays_d = get_rays(H, W, K, poses[cams])
381 | pts = rays_o[...,None,:] + (rays_d[...,None,:]) * z_vals[...,:,None]
382 | pts = pts.reshape(-1,3)
383 | max_pts = torch.quantile((pts),1.-max_,axis=0)
384 | min_pts = torch.quantile((pts),min_,axis=0)
385 | scene_bound_min.append(min_pts)
386 | scene_bound_max.append(max_pts)
387 |
388 |
389 | scene_bound_min = torch.stack(scene_bound_min,0)
390 | scene_bound_max = torch.stack(scene_bound_max,0)
391 |
392 | max_pts = torch.quantile(scene_bound_max,1.-max_,axis=0)
393 | min_pts = torch.quantile(scene_bound_min,min_,axis=0)
394 |
395 | return torch.stack([min_pts,max_pts],0)
396 |
397 |
398 | def get_bb_weights(pts,bounding_box_val,beta=200):
399 |
400 | center = bounding_box_val[0]
401 | rad = bounding_box_val[1]
402 |
403 | x_dist = torch.abs(pts[...,0:1] - torch.tensor(center[0]).to(pts))
404 | y_dist = torch.abs(pts[...,1:2] - torch.tensor(center[1]).to(pts))
405 | z_dist = torch.abs(pts[...,2:3] - torch.tensor(center[2]).to(pts))
406 |
407 | weights = torch.sigmoid(beta*(rad[0]-x_dist))*torch.sigmoid(beta*(rad[1]-y_dist))*torch.sigmoid(beta*(rad[2]-z_dist))
408 |
409 | return 1.0 - weights
410 |
411 |
412 |
413 |
414 | def get_voxel_grid(voxel_res,scene_bound,poses,query_fn,nerf_model,masking=False,bb_vals=None,flag=False):
415 |
416 |
417 | min_bound,max_bound = scene_bound.cpu().numpy()
418 |
419 |
420 | x = np.arange(voxel_res)
421 | grid_x,grid_y,grid_z = np.meshgrid(x,x,x)
422 |
423 | pts_ = np.stack((grid_y,grid_x,grid_z),-1)
424 | pts_ = pts_.reshape(-1,3)
425 | pts_ = pts_/(voxel_res-1)
426 | pts_ = (pts_)*(max_bound-min_bound) + min_bound
427 |
428 | raw_avg = 0
429 |
430 | for cam in tqdm(range(len(poses))):
431 |
432 | raw_stack = []
433 | pose = poses[cam,:,-1]
434 |
435 | batch = 128*128*128
436 | for ii in range(0,pts_.shape[0],batch):
437 | with torch.no_grad():
438 |
439 | new_pts_ = torch.tensor(pts_[ii:ii+batch,:]).to(poses).to(torch.float32)
440 | viewdirs = new_pts_ - pose
441 |
442 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
443 |
444 | raw = query_fn(new_pts_, viewdirs, nerf_model)
445 | raw[...,0:3] = torch.sigmoid(raw[...,0:3])
446 | raw[...,3:4] = F.relu(raw[...,3:4])
447 | raw_stack.append(raw)
448 |
449 | raw_stack = torch.cat(raw_stack,0)
450 |
451 |
452 | raw_avg += raw_stack
453 |
454 | raw_avg = raw_avg/len(poses)
455 | return raw_avg.reshape(voxel_res,voxel_res,voxel_res,4)
456 |
457 |
458 |
459 |
460 |
461 | def raw2outputs(raw, dists):
462 |     """Computes volume rendering weights from the model's raw predictions.
463 |     Args:
464 |         raw: [num_rays, num_samples along ray, 4]. Prediction from model.
465 |         dists: [num_rays, num_samples along ray]. Distances between adjacent
466 |             samples along each ray.
467 |     Returns:
468 |         weights: [num_rays, num_samples]. Weights assigned to each sampled color
469 |             (alpha-compositing weights used for volume rendering).
470 |     """
474 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
475 |
476 |
477 |
478 | alpha = raw2alpha(raw[...,3] , dists) # [N_rays, N_samples]
479 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
480 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
481 |
482 | return weights
483 |
--------------------------------------------------------------------------------
/find_bounding_box.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys,cv2
3 | import numpy as np
4 | import imageio
5 | import json
6 | import random
7 | import time
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from tqdm import tqdm, trange
12 |
13 | from run_nerf_helpers import *
14 |
15 | from load_llff import load_llff_data
16 | from load_deepvoxels import load_dv_data
17 | from load_blender import load_blender_data
18 | from load_LINEMOD import load_LINEMOD_data
19 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
20 |
21 |
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 | np.random.seed(0)
24 | DEBUG = False
25 |
26 |
27 | def batchify(fn, chunk):
28 | """Constructs a version of 'fn' that applies to smaller batches.
29 | """
30 | if chunk is None:
31 | return fn
32 | def ret(inputs):
33 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
34 | return ret
35 |
36 |
37 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
38 | """Prepares inputs and applies network 'fn'.
39 | """
40 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
41 | embedded = embed_fn(inputs_flat)
42 |
43 | if viewdirs is not None:
44 | input_dirs = viewdirs[:,None].expand(inputs.shape)
45 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
46 | embedded_dirs = embeddirs_fn(input_dirs_flat)
47 | embedded = torch.cat([embedded, embedded_dirs], -1)
48 |
49 | outputs_flat = batchify(fn, netchunk)(embedded)
50 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
51 | return outputs
52 |
53 |
54 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
55 | """Render rays in smaller minibatches to avoid OOM.
56 | """
57 | all_ret = {}
58 | for i in range(0, rays_flat.shape[0], chunk):
59 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
60 | for k in ret:
61 | if k not in all_ret:
62 | all_ret[k] = []
63 | all_ret[k].append(ret[k])
64 |
65 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
66 | return all_ret
67 |
68 |
69 |
70 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
71 | near=0., far=1.,
72 | use_viewdirs=False, c2w_staticcam=None,
73 | **kwargs):
74 | """Render rays
75 | Args:
76 | H: int. Height of image in pixels.
77 | W: int. Width of image in pixels.
78 |       K: array of shape [3, 3]. Camera intrinsics matrix.
79 | chunk: int. Maximum number of rays to process simultaneously. Used to
80 | control maximum memory usage. Does not affect final results.
81 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
82 | each example in batch.
83 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
84 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
85 | near: float or array of shape [batch_size]. Nearest distance for a ray.
86 | far: float or array of shape [batch_size]. Farthest distance for a ray.
87 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
88 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
89 | camera while using other c2w argument for viewing directions.
90 | Returns:
91 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
92 | disp_map: [batch_size]. Disparity map. Inverse of depth.
93 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
94 | extras: dict with everything returned by render_rays().
95 | """
96 | if c2w is not None:
97 | # special case to render full image
98 | rays_o, rays_d = get_rays(H, W, K, c2w)
99 | else:
100 | # use provided ray batch
101 | rays_o, rays_d = rays
102 |
103 | if use_viewdirs:
104 | # provide ray directions as input
105 | viewdirs = rays_d
106 | if c2w_staticcam is not None:
107 | # special case to visualize effect of viewdirs
108 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
109 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
110 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
111 |
112 | sh = rays_d.shape # [..., 3]
113 | if ndc:
114 | # for forward facing scenes
115 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
116 |
117 | # Create ray batch
118 | rays_o = torch.reshape(rays_o, [-1,3]).float()
119 | rays_d = torch.reshape(rays_d, [-1,3]).float()
120 |
121 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
122 | rays = torch.cat([rays_o, rays_d, near, far], -1)
123 | if use_viewdirs:
124 | rays = torch.cat([rays, viewdirs], -1)
125 |
126 | # Render and reshape
127 | all_ret = batchify_rays(rays, chunk, **kwargs)
128 | for k in all_ret:
129 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
130 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
131 |
132 | k_extract = ['rgb_map', 'disp_map', 'acc_map','weights']
133 | ret_list = [all_ret[k] for k in k_extract]
134 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
135 | return ret_list + [ret_dict]
136 |
137 |
138 | def create_nerf(args):
139 | """Instantiate NeRF's MLP model.
140 | """
141 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
142 |
143 | input_ch_views = 0
144 | embeddirs_fn = None
145 | if args.use_viewdirs:
146 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
147 | output_ch = 5 if args.N_importance > 0 else 4
148 | skips = [4]
149 | model = NeRF(D=args.netdepth, W=args.netwidth,
150 | input_ch=input_ch, output_ch=output_ch, skips=skips,
151 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
152 | grad_vars = list(model.parameters())
153 |
154 | model_fine = None
155 | if args.N_importance > 0:
156 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
157 | input_ch=input_ch, output_ch=output_ch, skips=skips,
158 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
159 | grad_vars += list(model_fine.parameters())
160 |
161 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
162 | embed_fn=embed_fn,
163 | embeddirs_fn=embeddirs_fn,
164 | netchunk=args.netchunk)
165 |
166 | # Create optimizer
167 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
168 |
169 | start = 0
170 | basedir = args.basedir
171 | expname = args.expname
172 |
173 | ##########################
174 |
175 | # Load checkpoints
176 |
177 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f]
178 |     print('Trained model path:',os.path.join(basedir, expname,args.model_path))
179 | print('Found ckpts', ckpts)
180 | if len(ckpts) > 0 and not args.no_reload:
181 | ckpt_path = ckpts[-1]
182 | print('Reloading from', ckpt_path)
183 | ckpt = torch.load(ckpt_path)
184 |
185 | start = ckpt['global_step']
186 |
187 | # Load model
188 | model.load_state_dict(ckpt['network_fn_state_dict'])
189 | if model_fine is not None:
190 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
191 |
192 | ##########################
193 |
194 | render_kwargs_train = {
195 | 'network_query_fn' : network_query_fn,
196 | 'perturb' : args.perturb,
197 | 'N_importance' : args.N_importance,
198 | 'network_fine' : model_fine,
199 | 'N_samples' : args.N_samples,
200 | 'network_fn' : model,
201 | 'use_viewdirs' : args.use_viewdirs,
202 | 'white_bkgd' : args.white_bkgd,
203 | 'raw_noise_std' : args.raw_noise_std,
204 | }
205 |
206 | # NDC only good for LLFF-style forward facing data
207 | if args.dataset_type != 'llff' or args.no_ndc:
208 | print('Not ndc!')
209 | render_kwargs_train['ndc'] = False
210 | render_kwargs_train['lindisp'] = args.lindisp
211 |
212 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
213 | render_kwargs_test['perturb'] = False
214 | render_kwargs_test['raw_noise_std'] = 0.
215 |
216 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
217 |
218 |
219 | def raw2outputs(raw, z_vals, rays_d,raw_noise_std=0, white_bkgd=False, pytest=False):
220 | """Transforms model's predictions to semantically meaningful values.
221 | Args:
222 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
223 | z_vals: [num_rays, num_samples along ray]. Integration time.
224 | rays_d: [num_rays, 3]. Direction of each ray.
225 | Returns:
226 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
227 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
228 | acc_map: [num_rays]. Sum of weights along each ray.
229 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
230 | depth_map: [num_rays]. Estimated distance to object.
231 | """
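    # Standard NeRF compositing: for sample i with predicted density sigma_i and segment
    # length delta_i, alpha_i = 1 - exp(-relu(sigma_i) * delta_i), the weight is
    # w_i = alpha_i * prod_{j<i}(1 - alpha_j), and rgb_map = sum_i w_i * c_i.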
232 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
233 |
234 | dists = z_vals[...,1:] - z_vals[...,:-1]
235 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
236 |
237 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
238 |
239 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
240 | noise = 0.
241 | if raw_noise_std > 0.:
242 | noise = torch.randn(raw[...,3].shape) * raw_noise_std
243 |
244 | # Overwrite randomly sampled data if pytest
245 | if pytest:
246 | np.random.seed(0)
247 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
248 | noise = torch.Tensor(noise)
249 |
250 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
251 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
252 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
253 |
254 |     # composite the per-sample colors into a single color per ray
255 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
256 |
257 | depth_map = torch.sum(weights * z_vals, -1)
258 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
259 | acc_map = torch.sum(weights, -1)
260 |
261 | if white_bkgd:
262 | rgb_map = rgb_map + (1.-acc_map[...,None])
263 |
264 | return rgb_map, disp_map, acc_map, weights, depth_map
265 |
266 |
267 | def render_rays(ray_batch,
268 | network_fn,
269 | network_query_fn,
270 | N_samples,
271 | masking,
272 | bb_vals,
273 | retraw=True,
274 | lindisp=False,
275 | perturb=0.,
276 | N_importance=0,
277 | network_fine=None,
278 | white_bkgd=False,
279 | raw_noise_std=0.,
280 | verbose=False,
281 | pytest=False):
282 |
283 | N_rays = ray_batch.shape[0]
284 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
285 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
286 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
287 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
288 |
289 | t_vals = torch.linspace(0., 1., steps=N_samples)
290 | if not lindisp:
291 | z_vals = near * (1.-t_vals) + far * (t_vals)
292 | else:
293 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
294 |
295 | z_vals = z_vals.expand([N_rays, N_samples])
296 |
297 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
298 |
299 |
300 | raw = network_query_fn(pts, viewdirs, network_fn)
301 |
302 |
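    # get_bb_weights (defined in run_nerf_helpers.py) is expected to return values near 0 for
    # points inside the annotated bounding box and 1 outside, so multiplying the raw network
    # output by it suppresses the transparent object's color and density.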
303 | if masking:
304 | weights = get_bb_weights(pts,bb_vals)
305 | raw = raw*weights
306 |
307 |
308 |
309 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
310 |
311 | if N_importance > 0:
312 |
313 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
314 |
315 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
316 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
317 | z_samples = z_samples.detach()
318 |
319 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
320 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
321 |
322 | run_fn = network_fn if network_fine is None else network_fine
323 | raw = network_query_fn(pts, viewdirs, run_fn)
324 |
325 | if masking:
326 | weights = get_bb_weights(pts,bb_vals)
327 | raw = raw*weights
328 |
329 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
330 |
331 |
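    # Note: this script repurposes the usual output keys so that find_box() can build a point
    # cloud: 'disp_map' carries the depth map, 'acc_map' the sample depths z_vals, and
    # 'weights' the expected 3D termination point of each ray (rays_o + rays_d * depth).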
332 | ret = {'rgb_map' : rgb_map, 'disp_map' : depth_map, 'acc_map' : z_vals,'weights':rays_o + rays_d*depth_map[:,None]}
333 | if retraw:
334 | ret['raw'] = pts
335 | if N_importance > 0:
336 | ret['rgb0'] = rgb_map_0
337 | ret['disp0'] = disp_map_0
338 | ret['acc0'] = acc_map_0
339 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
340 |
341 |
342 | for k in ret:
343 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
344 | print(f"! [Numerical Error] {k} contains nan or inf.")
345 |
346 | return ret
347 |
348 |
349 |
350 | def config_parser():
351 |
352 | import configargparse
353 | parser = configargparse.ArgumentParser()
354 | parser.add_argument('--config', is_config_file=True,
355 | help='config file path')
356 | parser.add_argument("--expname", default='glass_ball_new',type=str,
357 | help='experiment name')
358 | parser.add_argument("--basedir", type=str, default='logs',
359 | help='where to store ckpts and logs')
360 | parser.add_argument("--datadir", type=str, default='data\\glass_ball',
361 | help='input data directory')
362 |
363 | # training options
364 | parser.add_argument("--netdepth", type=int, default=8,
365 | help='layers in network')
366 | parser.add_argument("--netwidth", type=int, default=256,
367 | help='channels per layer')
368 | parser.add_argument("--netdepth_fine", type=int, default=8,
369 | help='layers in fine network')
370 | parser.add_argument("--netwidth_fine", type=int, default=256,
371 | help='channels per layer in fine network')
372 | parser.add_argument("--N_rand", type=int, default=32*32,
373 | help='batch size (number of random rays per gradient step)')
374 | parser.add_argument("--lrate", type=float, default=5e-4,
375 | help='learning rate')
376 | parser.add_argument("--lrate_decay", type=int, default=250,
377 | help='exponential learning rate decay (in 1000 steps)')
378 | parser.add_argument("--chunk", type=int, default=1024*32,
379 | help='number of rays processed in parallel, decrease if running out of memory')
380 | parser.add_argument("--netchunk", type=int, default=1024*64,
381 | help='number of pts sent through network in parallel, decrease if running out of memory')
382 | parser.add_argument("--no_batching", default=False,action='store_true',
383 | help='only take random rays from 1 image at a time')
384 | parser.add_argument("--no_reload", action='store_true',
385 | help='do not reload weights from saved ckpt')
386 | parser.add_argument("--model_path", type=str, default='model_weights',
387 | help='path to trained model weights')
388 | # rendering options
389 | parser.add_argument("--N_samples", type=int, default=64,
390 | help='number of coarse samples per ray')
391 | parser.add_argument("--N_importance", type=int, default=64,
392 | help='number of additional fine samples per ray')
393 | parser.add_argument("--perturb", type=float, default=1.,
394 | help='set to 0. for no jitter, 1. for jitter')
395 | parser.add_argument("--use_viewdirs", default=True,action='store_true',
396 | help='use full 5D input instead of 3D')
397 | parser.add_argument("--use_mask", default=False,action='store_true',
398 |                         help='use the precomputed object-region masks (see find_bounding_box.py)')
399 | parser.add_argument("--i_embed", type=int, default=0,
400 | help='set 0 for default positional encoding, -1 for none')
401 | parser.add_argument("--multires", type=int, default=10,
402 | help='log2 of max freq for positional encoding (3D location)')
403 | parser.add_argument("--multires_views", type=int, default=4,
404 | help='log2 of max freq for positional encoding (2D direction)')
405 | parser.add_argument("--raw_noise_std", type=float, default=1.,
406 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
407 |
408 | parser.add_argument("--render_only", action='store_true',
409 | help='do not optimize, reload weights and render out render_poses path')
410 | parser.add_argument("--render_test", action='store_true',
411 | help='render the test set instead of render_poses path')
412 | parser.add_argument("--render_factor", type=int, default=0,
413 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
414 |
415 | # training options
416 | parser.add_argument("--precrop_iters", type=int, default=0,
417 | help='number of steps to train on central crops')
418 | parser.add_argument("--precrop_frac", type=float,
419 | default=.5, help='fraction of img taken for central crops')
420 |
421 | # dataset options
422 | parser.add_argument("--dataset_type", type=str, default='llff',
423 | help='options: llff / blender / deepvoxels')
424 | parser.add_argument("--testskip", type=int, default=8,
425 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
426 |
427 | ## deepvoxels flags
428 | parser.add_argument("--shape", type=str, default='greek',
429 | help='options : armchair / cube / greek / vase')
430 |
431 | ## blender flags
432 | parser.add_argument("--white_bkgd", action='store_true',
433 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
434 | parser.add_argument("--half_res", action='store_true',
435 | help='load blender synthetic data at 400x400 instead of 800x800')
436 |
437 | ## llff flags
438 | parser.add_argument("--factor", type=int, default=1,
439 | help='downsample factor for LLFF images')
440 | parser.add_argument("--no_ndc", default = True,action='store_true',
441 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
442 | parser.add_argument("--lindisp", action='store_true',
443 | help='sampling linearly in disparity rather than depth')
444 | parser.add_argument("--spherify", default = True,action='store_true',
445 | help='set for spherical 360 scenes')
446 | parser.add_argument("--llffhold", type=int, default=10,
447 | help='will take every 1/N images as LLFF test set, paper uses 8')
448 |
449 | parser.add_argument("--quantile", type=float, default=0.03,
450 |                         help='outlier-rejection percentile passed to np.percentile when estimating the bounding box')
451 | parser.add_argument("--ratio", type=float, default=1.1,
452 |                         help='factor by which the estimated bounding box is enlarged')
453 | parser.add_argument("--num_im", type=int, default=10,
454 |                         help='will take every 1/N images for manual annotation')
455 | parser.add_argument("--reload_bb", default=False, action='store_true',
456 | help='reload an existing bounding box value')
457 |
458 |
459 | return parser
460 |
461 | def find_box():
462 |
463 | parser = config_parser()
464 | args = parser.parse_args()
465 |
466 | # Load data
467 | K = None
468 | if args.dataset_type == 'llff':
469 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
470 | recenter=True, bd_factor=.75,
471 | spherify=args.spherify)
472 | hwf = poses[0,:3,-1]
473 | poses = poses[:,:3,:4]
474 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
475 | if not isinstance(i_test, list):
476 | i_test = [i_test]
477 |
478 | if args.num_im > 0:
479 |         print('annotating every Nth image, N =', args.num_im)
480 | i_test = np.arange(images.shape[0])[::args.num_im]
481 |
482 | i_val = i_test
483 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
484 | (i not in i_test and i not in i_val)])
485 |
486 | print('DEFINING BOUNDS')
487 | if args.no_ndc:
488 | near = np.ndarray.min(bds) * .9
489 | far = np.ndarray.max(bds) * 1.
490 |
491 | else:
492 | near = 0.
493 | far = 1.
494 | print('NEAR FAR', near, far)
495 |
496 |
497 | # Cast intrinsics to right types
498 | H, W, focal = hwf
499 | H, W = int(H), int(W)
500 | hwf = [H, W, focal]
501 |
502 | if K is None:
503 | K = np.array([
504 | [focal, 0, 0.5*W],
505 | [0, focal, 0.5*H],
506 | [0, 0, 1]
507 | ])
508 |
509 | if args.render_test:
510 | render_poses = np.array(poses[i_test])
511 |
512 | # Create log dir and copy the config file
513 | basedir = args.basedir
514 | expname = args.expname
515 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
516 | f = os.path.join(basedir, expname, 'args.txt')
517 | with open(f, 'w') as file:
518 | for arg in sorted(vars(args)):
519 | attr = getattr(args, arg)
520 | file.write('{} = {}\n'.format(arg, attr))
521 | if args.config is not None:
522 | f = os.path.join(basedir, expname, 'config.txt')
523 | with open(f, 'w') as file:
524 | file.write(open(args.config, 'r').read())
525 |
526 | # Create nerf model
527 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
528 | global_step = start
529 |
530 | bds_dict = {
531 | 'near' : near,
532 | 'far' : far,
533 | }
534 | render_kwargs_test.update(bds_dict)
535 |
536 | render_poses = torch.Tensor(render_poses).to(device)
537 | N_rand = args.N_rand
538 | use_batching = not args.no_batching
539 |
540 | poses = torch.Tensor(poses).to(device)
541 |
542 | if not args.reload_bb:
543 | point_matrix = []
544 |
545 | counter = 0
546 | def mousePoints(event,x,y,flags,params):
547 | if event == cv2.EVENT_LBUTTONDOWN:
548 | point_matrix.append([x,y])
549 | cv2.circle(img,(x,y),5,(0,0,255),cv2.FILLED)
550 |
551 |
552 | num_ = len(i_test) #len(poses)//10
553 | cam_ind = i_test# np.arange(0,len(poses),num_)
554 | counter = 0
555 |
556 |         img = to8b(images[cam_ind[counter],:,:,3::-1])  # reverse channels (RGB -> BGR) for OpenCV display
557 |         rays_o, rays_d = get_rays(H, W, K, poses[cam_ind[counter]])
558 |         rays = torch.stack([rays_o, rays_d],0)
559 |         all_rays = []
560 |
561 |         text = "Please press 'a' to continue to the next image "
562 |         coordinates = (10,20)
563 |         font = cv2.FONT_HERSHEY_SIMPLEX
564 |         fontScale = 0.75
565 |         color = (0,0,255)
566 |         thickness = 0
567 |
568 |         while True:
569 |
570 |
571 |             img = cv2.putText(img, text, coordinates, font, fontScale, color, thickness, cv2.LINE_AA)
572 |             cv2.imshow(" Image ", img)
573 |
574 |             cv2.setMouseCallback(" Image ", mousePoints)
575 |             key = cv2.waitKey(1) & 0xFF
576 |             if key == 27:
577 |                 break
578 |
579 |             if key == ord('a'):
580 |                 counter = counter+1
581 |                 pts_inside = np.array(point_matrix)
582 |                 all_rays.append(rays[:,pts_inside[:,1],pts_inside[:,0],:])
583 |                 if counter == num_:
584 |                     break
585 |                 rays_o, rays_d = get_rays(H, W, K, poses[cam_ind[counter]])
586 |                 rays = torch.stack([rays_o, rays_d],0)
587 |                 point_matrix = []
588 |                 img = to8b(images[cam_ind[counter],:,:,3::-1])
589 |
590 |
591 |
592 |
593 |         cv2.destroyAllWindows()
594 |
595 |         all_rays = torch.cat(all_rays,1)
596 |
597 |         print('Started finding the transparent object bounding box')
598 |         render_kwargs_test['bb_vals'] = []
599 |         render_kwargs_test['masking'] = False
600 |         render_kwargs_test['N_importance'] = 256
601 |
602 |         with torch.no_grad():
603 |             _, _, _,weights, _ = render(H, W,K, chunk=args.chunk, rays= all_rays,
604 |                                                     **render_kwargs_test)
605 |
606 |         point_cloud = weights.cpu().numpy()
607 |
608 |
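        # Each row of point_cloud is the expected 3D termination point of one annotated ray
        # (the repurposed 'weights' output of render above). Per-axis percentiles reject
        # outliers; the box is stored as (center, half-extent) and enlarged by args.ratio.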
609 |         max_pts = np.percentile((point_cloud),100 - args.quantile,axis=0)
610 |         min_pts = np.percentile((point_cloud),args.quantile,axis=0)
611 |         center = 0.5*(min_pts+max_pts)
612 |         radi = max_pts-center
613 |         radi = args.ratio*radi
614 |
615 |         bounding_box_vals = np.stack((center,radi),0)
616 |
617 |         path_bounding_box = os.path.join(basedir, expname,"bounding_box")
618 |         os.makedirs(path_bounding_box, exist_ok=True)
619 |         path = os.path.join(path_bounding_box, 'bounding_box_vals.npy')
620 |         np.save(path,bounding_box_vals)
621 |         print('Finished finding the transparent object bounding box')
622 |
623 |
624 | else:
625 |
626 | path_bounding_box = os.path.join(basedir, expname,"bounding_box")
627 | path = os.path.join(path_bounding_box, 'bounding_box_vals.npy')
628 | bounding_box_vals = np.load(path)
629 |
630 |
631 |
632 |
633 | render_kwargs_test['bb_vals'] = bounding_box_vals
634 | render_kwargs_test['masking'] = True
635 | render_kwargs_test['N_importance'] = 64
636 |
637 |
638 |
639 |     print('Rendering a test image with and without the transparent object')
640 |
641 | path_bounding_box = os.path.join(basedir, expname,"bounding_box")
642 |
643 | img_i = i_test[5]
644 |
645 | with torch.no_grad():
646 | rgb_gt_without, _,_,_,_ = render(H, W,K, chunk=args.chunk, c2w= poses[img_i],
647 | **render_kwargs_test)
648 |
649 |
650 | imageio.imwrite(os.path.join(path_bounding_box, 'without_{:03d}.png'.format(img_i)), to8b(rgb_gt_without.cpu().numpy()))
651 |
652 | render_kwargs_test['masking'] = False
653 |
654 | with torch.no_grad():
655 | rgb_gt_with, _,_,_,_ = render(H, W,K, chunk=args.chunk, c2w= poses[img_i],
656 | **render_kwargs_test)
657 |
658 | imageio.imwrite(os.path.join(path_bounding_box, 'with_{:03d}.png'.format(img_i)), to8b(rgb_gt_with.cpu().numpy()))
659 |
660 | print('Done!')
661 | print('Saved in the folder \"bounding_box\"')
662 |
663 |
664 |     print('Finding the region in each image that intersects the 3D bounding box')
665 |
666 | path_to_mask_out = os.path.join(basedir, expname, 'masked_regions')
667 | os.makedirs(path_to_mask_out, exist_ok=True)
668 |
669 |
670 | for img_i in tqdm(range(len(poses))):
671 |
672 | pose = poses[img_i, :3,:4]
673 | rays_o, rays_d = get_rays(H, W, K, pose)
674 | t_vals = torch.linspace(0., 1., steps=128)
675 | z_vals = near * (1.-t_vals) + far * (t_vals)
676 | pts_ = rays_o[...,None,:] + rays_d[...,None,:]*z_vals[:,None]
677 | mask = 1. - get_bb_weights(pts_,bounding_box_vals)
678 | diff = torch.sum(mask,2)>0.0
679 | imageio.imwrite(os.path.join(path_to_mask_out, 'img_%0.3d.png'%(img_i)), to8b(diff.cpu().numpy()))
680 |
681 | print('Done!')
682 | print('Saved in the folder \"masked_regions\"')
683 |
684 |
685 | if __name__=='__main__':
686 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
687 |
688 | find_box()
--------------------------------------------------------------------------------
/render_model.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys,cv2
3 | import numpy as np
4 | import imageio
5 | import json
6 | import random
7 | import time
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from tqdm import tqdm, trange
12 | from torchdiffeq import odeint_adjoint as odeint
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | from run_nerf_helpers import *
17 |
18 | from load_llff import load_llff_data
19 | from load_deepvoxels import load_dv_data
20 | from load_blender import load_blender_data
21 | from load_LINEMOD import load_LINEMOD_data
22 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
23 |
24 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
25 |
26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | np.random.seed(0)
28 | DEBUG = False
29 |
30 |
31 | def batchify(fn, chunk):
32 | """Constructs a version of 'fn' that applies to smaller batches.
33 | """
34 | if chunk is None:
35 | return fn
36 | def ret(inputs):
37 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
38 | return ret
39 |
40 |
41 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
42 | """Prepares inputs and applies network 'fn'.
43 | """
44 | embedded = embed_fn(inputs)
45 |
46 | if viewdirs is not None:
47 |
48 | embedded_dirs = embeddirs_fn(viewdirs)
49 | embedded = torch.cat([embedded, embedded_dirs], -1)
50 |
51 | outputs_flat = batchify(fn, netchunk)(embedded)
52 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
53 | return outputs
54 |
55 |
56 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64):
57 |
58 | embedded = embed_fn(inputs)
59 |
60 |
61 | outputs = batchify(fn, netchunk)(embedded)
62 |
63 | return outputs
64 |
65 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
66 | """Render rays in smaller minibatches to avoid OOM.
67 | """
68 | all_ret = {}
69 | for i in range(0, rays_flat.shape[0], chunk):
70 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
71 | for k in ret:
72 | if k not in all_ret:
73 | all_ret[k] = []
74 | all_ret[k].append(ret[k])
75 |
76 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
77 | return all_ret
78 |
79 |
80 |
81 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
82 | near=0., far=1.,
83 | use_viewdirs=False, c2w_staticcam=None,
84 | **kwargs):
85 | """Render rays
86 | Args:
87 | H: int. Height of image in pixels.
88 | W: int. Width of image in pixels.
89 |       K: array of shape [3, 3]. Camera intrinsics matrix (pinhole model).
90 | chunk: int. Maximum number of rays to process simultaneously. Used to
91 | control maximum memory usage. Does not affect final results.
92 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
93 | each example in batch.
94 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
95 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
96 | near: float or array of shape [batch_size]. Nearest distance for a ray.
97 | far: float or array of shape [batch_size]. Farthest distance for a ray.
98 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
99 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
100 | camera while using other c2w argument for viewing directions.
101 | Returns:
102 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
103 | disp_map: [batch_size]. Disparity map. Inverse of depth.
104 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
105 | extras: dict with everything returned by render_rays().
106 | """
107 | if c2w is not None:
108 | # special case to render full image
109 | rays_o, rays_d = get_rays(H, W, K, c2w)
110 | else:
111 | # use provided ray batch
112 | rays_o, rays_d = rays
113 |
114 | if use_viewdirs:
115 | # provide ray directions as input
116 | viewdirs = rays_d
117 | if c2w_staticcam is not None:
118 | # special case to visualize effect of viewdirs
119 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
120 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
121 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
122 |
123 | sh = rays_d.shape # [..., 3]
124 | if ndc:
125 | # for forward facing scenes
126 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
127 |
128 | # Create ray batch
129 | rays_o = torch.reshape(rays_o, [-1,3]).float()
130 | rays_d = torch.reshape(rays_d, [-1,3]).float()
131 |
132 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
133 | rays = torch.cat([rays_o, rays_d, near, far], -1)
134 | if use_viewdirs:
135 | rays = torch.cat([rays, viewdirs], -1)
136 |
137 | # Render and reshape
138 | all_ret = batchify_rays(rays, chunk, **kwargs)
139 | for k in all_ret:
140 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
141 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
142 |
143 | k_extract = ['rgb_map']
144 | ret_list = [all_ret[k] for k in k_extract]
145 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
146 | return ret_list + [ret_dict]
147 |
148 |
149 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
150 |
151 | H, W, focal = hwf
152 |
153 | if render_factor!=0:
154 | # Render downsampled for speed
155 | H = H//render_factor
156 | W = W//render_factor
157 | focal = focal/render_factor
158 |
159 | rgbs = []
160 | # disps = []
161 |
162 | t = time.time()
163 | for i, c2w in enumerate(tqdm(render_poses)):
164 | print(i, time.time() - t)
165 | t = time.time()
166 | rgb, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
167 | rgbs.append(rgb.cpu().numpy())
168 | # disps.append(disp.cpu().numpy())
169 | # if i==0:
170 | # print(rgb.shape, disp.shape)
171 |
172 | """
173 | if gt_imgs is not None and render_factor==0:
174 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
175 | print(p)
176 | """
177 |
178 | if savedir is not None:
179 | rgb8 = to8b(rgbs[-1])
180 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
181 | imageio.imwrite(filename, rgb8)
182 |
183 |
184 | rgbs = np.stack(rgbs, 0)
185 |
186 | return rgbs
187 |
188 | def create_models(args):
189 |     """Instantiate the NeRF, fine NeRF, IoR, and inside-NeRF MLP models.
190 | """
191 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
192 |
193 | input_ch_views = 0
194 | embeddirs_fn = None
195 | if args.use_viewdirs:
196 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
197 | output_ch = 4
198 | skips = [4]
199 | model = NeRF(D=args.netdepth, W=args.netwidth,
200 | input_ch=input_ch, output_ch=output_ch, skips=skips,
201 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
202 |
203 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
204 | input_ch=input_ch, output_ch=output_ch, skips=skips,
205 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
206 |
207 |
208 |
209 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed)
210 | model_ior = MLP_IOR(input_ch = input_ch_ior,D=args.netdepth_ior, W= args.netwidth_ior ,skips=[3]).to(device)
211 | model_ior.apply(init_weights)
212 |
213 | model_inside = NeRF(D=args.netdepth, W=args.netwidth_inside,
214 | input_ch=input_ch, output_ch=output_ch, skips=skips,
215 | input_ch_views=input_ch_views, use_viewdirs=True).to(device)
216 |
217 | grad_vars = list(model_inside.parameters())
218 |
219 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
220 | embed_fn=embed_fn,
221 | embeddirs_fn=embeddirs_fn,
222 | netchunk=args.netchunk)
223 |
224 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn,
225 | embed_fn = embed_fn_ior,
226 | netchunk=args.netchunk)
227 | # Create optimizer
228 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
229 |
230 | start = 0
231 | basedir = args.basedir
232 | expname = args.expname
233 |
234 | ##########################
235 |
236 | # Load checkpoints
237 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f]
238 |     print('Trained model path:',os.path.join(basedir, expname,args.model_path))
239 |
240 |
241 | print('Found ckpts', ckpts)
242 |
243 | if len(ckpts) > 0 and not args.no_reload:
244 | ckpt_path = ckpts[-1]
245 | print('Reloading from', ckpt_path)
246 | ckpt = torch.load(ckpt_path)
247 |
248 | start = ckpt['global_step']
249 |
250 | # Load model
251 | model.load_state_dict(ckpt['network_fn_state_dict'])
252 |
253 | if model_fine is not None:
254 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
255 |
256 | if "network_ior_state_dict" in ckpt:
257 | model_ior.load_state_dict(ckpt['network_ior_state_dict'])
258 |
259 | if "network_inside_state_dict" in ckpt:
260 | model_inside.load_state_dict(ckpt['network_inside_state_dict'])
261 |
262 |
263 |
264 | ##########################
265 |
266 | render_kwargs_train = {
267 |
268 | 'network_query_fn' : network_query_fn,
269 | 'network_query_fn_ior': network_query_fn_ior,
270 | 'perturb' : args.perturb,
271 | 'N_samples' : args.N_samples,
272 | 'network_fn' : model,
273 | 'network_fine' : model_fine,
274 | 'network_ior': model_ior,
275 | 'network_inside':model_inside,
276 | 'use_viewdirs' : args.use_viewdirs,
277 | 'white_bkgd' : args.white_bkgd,
278 | 'raw_noise_std' : args.raw_noise_std,
279 | }
280 |
281 | # NDC only good for LLFF-style forward facing data
282 | if args.dataset_type != 'llff' or args.no_ndc:
283 | print('Not ndc!')
284 | render_kwargs_train['ndc'] = False
285 | render_kwargs_train['lindisp'] = args.lindisp
286 |
287 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
288 | render_kwargs_test['perturb'] = False
289 | render_kwargs_test['raw_noise_std'] = 0.
290 |
291 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
292 |
293 |
294 | class IoR_emissionAbsorptionODE(nn.Module):
295 |
296 | def __init__(self,query_nerf,model_nerf,nerf_inside,query_ior,model_ior,n_samples,step_size,vals,viewdir,mode):
297 | super(IoR_emissionAbsorptionODE, self).__init__()
298 |
299 | self.query_ior = query_ior
300 | self.model_ior = model_ior
301 |
302 | self.query_nerf = query_nerf
303 | self.model_nerf = model_nerf
304 |
305 | self.nerf_inside = nerf_inside
306 |
307 | self.n_samples = n_samples
308 | self.step_size = step_size
309 | self.vals = vals
310 | self.viewdir = viewdir
311 | self.mode = mode
312 |
313 | def forward(self,t,y):
314 |
315 |
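        # ODE state per ray: y = [position(0:3), current ray direction(3:6),
        # accumulated color(6:9), remaining transmittance(9:10)]; t runs from 0 to 1
        # in N_samples fixed Euler steps (see the odeint call in render_rays).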
316 | pts = y[:,0:3]
317 | ray_dir = y[:,3:6]
318 | transmition_ = y[:,9:10]
319 |
320 |
321 | # querying the original radiance field for the entire scene
322 | with torch.no_grad():
323 | raw = self.query_nerf(pts,self.viewdir,self.model_nerf)
324 |
325 | rgb = torch.sigmoid(raw[...,:3])
326 | density = F.relu(raw[...,3:])
327 |
328 | # determining whether a point is inside the bounding box or not
329 | weights= get_bb_weights(pts,self.vals)
330 |
331 | # finding the points inside the bounding box
332 |
333 | insides = torch.where(weights<1.0)[0]
334 | ior_grad = torch.zeros_like(pts)
335 | steps = torch.ones_like(density)*self.step_size
336 |
337 |
338 |
339 |
340 | # calculating NeRF model only for the points inside the bounding box
341 | if self.mode != 0:
342 |
343 | if len(insides) > 0:
344 |
345 |
346 | # querying the radiance field for the content inside the bounding box
347 | pts_inside = pts[insides,:]
348 | viewdir_inside = self.viewdir[insides,:]
349 |
350 |
351 | if self.mode == 1:
352 | density = density*weights
353 |
354 | if self.mode == 2:
355 | raw = self.query_nerf(pts_inside,viewdir_inside,self.nerf_inside)
356 | rgb_inside = torch.sigmoid(raw[...,:3])
357 | density_inside = F.relu(raw[...,3:])
358 |
359 |
360 |                     # linearly blending the radiance fields for the content inside and outside the bounding box
361 | density[insides,:] = density[insides,:]*(weights[insides,:])+ density_inside*(1.-weights[insides,:])
362 | rgb[insides,:] = rgb[insides,:]*(weights[insides,:])+ rgb_inside*(1.-weights[insides,:])
363 |
364 |
365 | # computing the gradient of IoR
366 | with torch.enable_grad():
367 |
368 | dn = torch.autograd.functional.vjp(lambda x : self.query_ior(x,self.model_ior), pts_inside ,v = torch.ones_like(pts_inside[:,0:1]))
369 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:])
370 |
371 |
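        # The position advances along the normalized current direction while the direction is
        # updated by the gradient of the learned index of refraction, following the eikonal
        # ray-equation formulation, so rays bend where the IoR varies; dc_dt and dT_dt
        # accumulate emission-absorption color and transmittance along the curved path.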
372 | dv_ds = steps*ior_grad
373 | dx_ds = steps*normalizing(ray_dir)
374 | alpha = 1. - torch.exp(-density*steps/self.n_samples)
375 | alpha = alpha.clip(0.,1.)
376 | dc_dt = transmition_*rgb*alpha*self.n_samples
377 | dT_dt = -transmition_*alpha*self.n_samples
378 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1)
379 |
380 |
381 | return dy_dt
382 |
383 |
384 | def render_rays(ray_batch,
385 | N_samples,
386 | network_query_fn_ior,
387 | bb_vals,
388 | network_fine,
389 | network_ior,
390 | network_inside,
391 | mode,
392 | network_fn=None,
393 | network_query_fn=None,
394 | retraw=False,
395 | lindisp=False,
396 | perturb=0.,
397 | N_importance=0,
398 | white_bkgd=False,
399 | raw_noise_std=0.,
400 | verbose=False,
401 | pytest=False):
402 |
403 | N_rays = ray_batch.shape[0]
404 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
405 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
406 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
407 |
408 | t_vals = torch.linspace(0., 1., steps=N_samples)
409 |
410 |
411 | pts = rays_o + normalizing(rays_d)*near
412 | viewdir = normalizing(rays_d)
413 | color_ = torch.zeros_like(pts)
414 | transmition_ = torch.ones((N_rays,1))
415 |
416 |
417 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device)
418 | step_size = (far[0]-near[0])
419 | output = odeint(IoR_emissionAbsorptionODE(network_query_fn,network_fine,network_inside,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,viewdir,mode), y0, t_vals,method='euler')
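    # Fixed-step Euler integration of the coupled ray/radiance ODE over t in [0, 1];
    # output has shape [len(t_vals), N_rays, 10], and channels 6:9 of the final step
    # hold the composited color of each ray.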
420 |
421 | pts = output[:,:,0:3].permute(1,0,2)
422 |
423 | rgb_map = output[-1,:,6:9]
424 |
425 |
426 | ret = {'rgb_map' : rgb_map}
427 | if retraw:
428 | ret['raw'] = pts
429 |
430 | for k in ret:
431 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
432 | print(f"! [Numerical Error] {k} contains nan or inf.")
433 |
434 | return ret
435 |
436 |
437 |
438 | def config_parser():
439 |
440 | import configargparse
441 | parser = configargparse.ArgumentParser()
442 | parser.add_argument('--config', is_config_file=True,
443 | help='config file path')
444 | parser.add_argument("--expname", default='glass_ball_new',type=str,
445 | help='experiment name')
446 | parser.add_argument("--basedir", type=str, default='logs',
447 | help='where to store ckpts and logs')
448 | parser.add_argument("--datadir", type=str, default='data\\glass_ball',
449 | help='input data directory')
450 | parser.add_argument("--render_from_path", type=str, default=None,
451 |                         help='optional camera path passed to load_llff_data as path_video to generate render poses')
452 |
453 |
454 | # training options
455 | parser.add_argument("--netdepth", type=int, default=8,
456 | help='layers in network')
457 | parser.add_argument("--netwidth", type=int, default=256,
458 | help='channels per layer')
459 | parser.add_argument("--netdepth_fine", type=int, default=8,
460 | help='layers in fine network')
461 | parser.add_argument("--netwidth_fine", type=int, default=256,
462 | help='channels per layer in fine network')
463 | parser.add_argument("--N_rand", type=int, default=32*32,
464 | help='batch size (number of random rays per gradient step)')
465 | parser.add_argument("--lrate", type=float, default=5e-4,
466 | help='learning rate')
467 | parser.add_argument("--lrate_decay", type=int, default=250,
468 | help='exponential learning rate decay (in 1000 steps)')
469 | parser.add_argument("--chunk", type=int, default=1024*32*32,
470 | help='number of rays processed in parallel, decrease if running out of memory')
471 | parser.add_argument("--netchunk", type=int, default=1024*64*32,
472 | help='number of pts sent through network in parallel, decrease if running out of memory')
473 | parser.add_argument("--no_batching", default=False,action='store_true',
474 | help='only take random rays from 1 image at a time')
475 | parser.add_argument("--no_reload", action='store_true',
476 | help='do not reload weights from saved ckpt')
477 | parser.add_argument("--model_path", type=str, default='model_weights',
478 | help='path to trained model weights')
479 |
480 |
481 | # IoR model parameters
482 |
483 | parser.add_argument("--netdepth_ior", type=int, default=6,
484 | help='layers in the IoR network')
485 | parser.add_argument("--netwidth_ior", type=int, default=64,
486 | help='channels per layer')
487 | parser.add_argument("--multires_views_ior", type=int, default=5,
488 |                         help='log2 of max freq for positional encoding of the IoR field (3D location)')
489 |
490 |
491 | # NeRF model inside parameters
492 |
493 | parser.add_argument("--netwidth_inside", type=int, default=128,
494 |                         help='channels per layer in the inside NeRF network')
495 |
496 | # NeRF model inside parameters
497 |     # rendering mode
498 | parser.add_argument("--mode", type=int, default=1,
499 | help='rendering mode: 0:NeRF 1:NeRF+IoR 2:NeRF+IOR+NeRF_inside')
500 |
501 |
502 |
503 | # rendering options
504 | parser.add_argument("--N_samples", type=int, default=512,
505 | help='number of coarse samples per ray')
506 | parser.add_argument("--N_importance", type=int,
507 | help='Not used in this code')
508 | parser.add_argument("--perturb", type=float, default=1.,
509 | help='set to 0. for no jitter, 1. for jitter')
510 | parser.add_argument("--use_viewdirs", default=True,action='store_true',
511 | help='use full 5D input instead of 3D')
512 | parser.add_argument("--use_mask", default=False,action='store_true',
513 |                         help='use the precomputed object-region masks (see find_bounding_box.py)')
514 | parser.add_argument("--i_embed", type=int, default=0,
515 | help='set 0 for default positional encoding, -1 for none')
516 | parser.add_argument("--multires", type=int, default=10,
517 | help='log2 of max freq for positional encoding (3D location)')
518 | parser.add_argument("--multires_views", type=int, default=4,
519 | help='log2 of max freq for positional encoding (2D direction)')
520 | parser.add_argument("--raw_noise_std", type=float, default=1.,
521 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
522 |
523 | parser.add_argument("--render_video", action='store_true',
524 |                         help='reload weights and render a video along the render_poses path')
525 | parser.add_argument("--render_test", action='store_true',
526 | help='render the test set instead of render_poses path')
527 | parser.add_argument("--render_factor", type=int, default=0,
528 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
529 |
530 | # training options
531 | parser.add_argument("--precrop_iters", type=int, default=0,
532 | help='number of steps to train on central crops')
533 | parser.add_argument("--precrop_frac", type=float,
534 | default=.5, help='fraction of img taken for central crops')
535 |
536 | # dataset options
537 | parser.add_argument("--dataset_type", type=str, default='llff',
538 | help='options: llff / blender / deepvoxels')
539 | parser.add_argument("--testskip", type=int, default=8,
540 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
541 |
542 | ## deepvoxels flags
543 | parser.add_argument("--shape", type=str, default='greek',
544 | help='options : armchair / cube / greek / vase')
545 |
546 | ## blender flags
547 | parser.add_argument("--white_bkgd", action='store_true',
548 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
549 | parser.add_argument("--half_res", action='store_true',
550 | help='load blender synthetic data at 400x400 instead of 800x800')
551 |
552 | ## llff flags
553 | parser.add_argument("--factor", type=int, default=1,
554 | help='downsample factor for LLFF images')
555 | parser.add_argument("--no_ndc", default = True,action='store_true',
556 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
557 | parser.add_argument("--lindisp", action='store_true',
558 | help='sampling linearly in disparity rather than depth')
559 | parser.add_argument("--spherify", default = True, action='store_true',
560 | help='set for spherical 360 scenes')
561 | parser.add_argument("--llffhold", type=int, default=10,
562 | help='will take every 1/N images as LLFF test set, paper uses 8')
563 |
564 | # logging/saving options
565 | parser.add_argument("--i_print", type=int, default=100,
566 |                         help='frequency of console printout and metric logging')
567 | parser.add_argument("--i_img", type=int, default=100,
568 | help='frequency of tensorboard image logging')
569 | parser.add_argument("--i_weights", type=int, default=500,
570 | help='frequency of weight ckpt saving')
571 |
572 |
573 | return parser
574 |
575 |
576 | def main():
577 |
578 | parser = config_parser()
579 | args = parser.parse_args()
580 |
581 | # Load data
582 | K = None
583 | if args.dataset_type == 'llff':
584 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
585 | recenter=True, bd_factor=.75,
586 | spherify=args.spherify,path_video=args.render_from_path) #_
587 | hwf = render_poses[0,:3,-1]
588 | poses = poses[:,:3,:4]
589 |
590 | # plt.plot(poses[:,0,3],poses[:,1,3],'.')
591 | # plt.plot(render_poses[:,0,3],render_poses[:,1,3],'.')
592 | # plt.show()
593 |
594 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
595 | if not isinstance(i_test, list):
596 | i_test = [i_test]
597 |
598 | if args.llffhold > 0:
599 | print('Auto LLFF holdout,', args.llffhold)
600 | i_test = np.arange(images.shape[0])[::args.llffhold]
601 |
602 | i_val = i_test
603 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
604 | (i not in i_test and i not in i_val)])
605 |
606 | print('DEFINING BOUNDS')
607 | if args.no_ndc:
608 | near = np.ndarray.min(bds) * .9
609 | far = np.ndarray.max(bds) * 1.
610 |
611 | else:
612 | near = 0.
613 | far = 1.
614 | if args.spherify:
615 | far = np.minimum(far,2.5)
616 | print('NEAR FAR', near, far)
617 |
618 | # Cast intrinsics to right types
619 | H, W, focal = hwf
620 | H, W = int(H), int(W)
621 | hwf = [H, W, focal]
622 |
623 | if K is None:
624 | K = np.array([
625 | [focal, 0, 0.5*W],
626 | [0, focal, 0.5*H],
627 | [0, 0, 1]
628 | ])
629 |
630 |
631 | # Create log dir and copy the config file
632 | basedir = args.basedir
633 | expname = args.expname
634 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
635 | f = os.path.join(basedir, expname, 'args.txt')
636 | with open(f, 'w') as file:
637 | for arg in sorted(vars(args)):
638 | attr = getattr(args, arg)
639 | file.write('{} = {}\n'.format(arg, attr))
640 | if args.config is not None:
641 | f = os.path.join(basedir, expname, 'config.txt')
642 | with open(f, 'w') as file:
643 | file.write(open(args.config, 'r').read())
644 |
645 |
646 |
647 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args)
648 |
649 |
650 |
651 | bds_dict = {
652 | 'near' : near,
653 | 'far' : far,
654 | }
655 |
656 | render_kwargs_test.update(bds_dict)
657 | render_kwargs_test['mode'] = args.mode
658 |
659 | print('loading bounding box values')
660 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy'))
661 |
662 | render_kwargs_test['bb_vals'] = bounding_box_vals
663 |
664 |
665 | with torch.no_grad():
666 |
667 |
668 |
669 | if args.render_from_path is not None:
670 |
671 | end_ = len(render_poses) - len(poses)
672 | render_poses = torch.Tensor(render_poses[:end_]).to(device)
673 |
674 |
675 | testsavedir = os.path.join(basedir, expname, 'rendered_from_a_path')
676 | os.makedirs(testsavedir, exist_ok=True)
677 | print('test poses shape', render_poses.shape)
678 |
679 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor)
680 | print('Done rendering', testsavedir)
681 |                 imageio.mimwrite(os.path.join(testsavedir, 'rendered.mp4'), to8b(rgbs), fps=10, quality=8)
682 |
683 |
684 | if args.render_test:
685 |
686 | render_poses = np.array(poses[i_test])
687 | render_poses = torch.Tensor(render_poses).to(device)
688 |
689 |
690 | testsavedir = os.path.join(basedir, expname, 'test_imgs')
691 | os.makedirs(testsavedir, exist_ok=True)
692 | print('test poses shape', render_poses.shape)
693 |
694 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor)
695 |
696 | if args.render_video:
697 |
698 | render_poses = torch.Tensor(render_poses).to(device)
699 |
700 | testsavedir = os.path.join(basedir, expname, 'rendered_video')
701 | os.makedirs(testsavedir, exist_ok=True)
702 | print('test poses shape', render_poses.shape)
703 |
704 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor)
705 | print('Done rendering', testsavedir)
706 |                 imageio.mimwrite(os.path.join(testsavedir, 'rendered.mp4'), to8b(rgbs), fps=10, quality=8)
707 |
708 |
709 | if __name__=='__main__':
710 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
711 |
712 | main()
713 |
714 |
715 |
--------------------------------------------------------------------------------
/run_nerf_inside.py:
--------------------------------------------------------------------------------
1 | import os, sys,cv2
2 | import numpy as np
3 | import imageio
4 | import json
5 | import random
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from tqdm import tqdm, trange
11 | from torchdiffeq import odeint_adjoint as odeint
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from run_nerf_helpers import *
16 |
17 | from load_llff import load_llff_data
18 | from load_deepvoxels import load_dv_data
19 | from load_blender import load_blender_data
20 | from load_LINEMOD import load_LINEMOD_data
21 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
22 |
23 |
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | np.random.seed(0)
26 | DEBUG = False
27 |
28 |
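# get_rnd: draw an (approximately) equal number of rays from each training image, restricted
# to the pixel indices in all_idx_inside (the region intersecting the object's bounding box).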
29 | def get_rnd(all_rays,all_idx_inside,num_ray):
30 |
31 |
32 | all_ = []
33 | chunk_per_im = num_ray//len(all_idx_inside)
34 | for i in range(len(all_idx_inside)):
35 |
36 | idx_inside = all_idx_inside[i][np.random.choice(len(all_idx_inside[i]), size=[chunk_per_im], replace=False)]
37 | all_.append(all_rays[i][idx_inside])
38 |
39 | return torch.cat(all_,0)
40 |
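# get_mask_idx: load the per-image masks (e.g. the 'masked_regions' images written by
# find_bounding_box.py) and return, for each training image, the flat indices of masked pixels.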
41 | def get_mask_idx(path_to_mask,num_im,i_train):
42 |
43 | mask_ = []
44 | for img_i in range(num_im):
45 |
46 | im = imageio.imread(os.path.join(path_to_mask, 'img_%0.3d.png'%(img_i)))
47 | im = np.float32(im)/255.0
48 | im = im.reshape([-1])
49 | inside = np.where(im== 1.0)[0]
50 | mask_.append(inside)
51 |
52 |
53 | mask = [mask_[i] for i in i_train]
54 |
55 | return mask
56 |
57 |
58 |
59 | def batchify(fn, chunk):
60 | """Constructs a version of 'fn' that applies to smaller batches.
61 | """
62 | if chunk is None:
63 | return fn
64 | def ret(inputs):
65 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
66 | return ret
67 |
68 |
69 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
70 | """Prepares inputs and applies network 'fn'.
71 | """
72 | embedded = embed_fn(inputs)
73 |
74 | if viewdirs is not None:
75 |
76 | embedded_dirs = embeddirs_fn(viewdirs)
77 | embedded = torch.cat([embedded, embedded_dirs], -1)
78 |
79 | outputs_flat = batchify(fn, netchunk)(embedded)
80 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
81 | return outputs
82 |
83 |
84 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64):
85 |
86 | embedded = embed_fn(inputs)
87 |
88 |
89 | outputs = batchify(fn, netchunk)(embedded)
90 |
91 | return outputs
92 |
93 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
94 | """Render rays in smaller minibatches to avoid OOM.
95 | """
96 | all_ret = {}
97 | for i in range(0, rays_flat.shape[0], chunk):
98 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
99 | for k in ret:
100 | if k not in all_ret:
101 | all_ret[k] = []
102 | all_ret[k].append(ret[k])
103 |
104 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
105 | return all_ret
106 |
107 |
108 |
109 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
110 | near=0., far=1.,
111 | use_viewdirs=False, c2w_staticcam=None,
112 | **kwargs):
113 | """Render rays
114 | Args:
115 | H: int. Height of image in pixels.
116 | W: int. Width of image in pixels.
117 |       K: array of shape [3, 3]. Camera intrinsics matrix (pinhole model).
118 | chunk: int. Maximum number of rays to process simultaneously. Used to
119 | control maximum memory usage. Does not affect final results.
120 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
121 | each example in batch.
122 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
123 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
124 | near: float or array of shape [batch_size]. Nearest distance for a ray.
125 | far: float or array of shape [batch_size]. Farthest distance for a ray.
126 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
127 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
128 | camera while using other c2w argument for viewing directions.
129 | Returns:
130 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
131 | disp_map: [batch_size]. Disparity map. Inverse of depth.
132 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
133 | extras: dict with everything returned by render_rays().
134 | """
135 | if c2w is not None:
136 | # special case to render full image
137 | rays_o, rays_d = get_rays(H, W, K, c2w)
138 | else:
139 | # use provided ray batch
140 | rays_o, rays_d = rays
141 |
142 | if use_viewdirs:
143 | # provide ray directions as input
144 | viewdirs = rays_d
145 | if c2w_staticcam is not None:
146 | # special case to visualize effect of viewdirs
147 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
148 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
149 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
150 |
151 | sh = rays_d.shape # [..., 3]
152 | if ndc:
153 | # for forward facing scenes
154 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
155 |
156 | # Create ray batch
157 | rays_o = torch.reshape(rays_o, [-1,3]).float()
158 | rays_d = torch.reshape(rays_d, [-1,3]).float()
159 |
160 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
161 | rays = torch.cat([rays_o, rays_d, near, far], -1)
162 | if use_viewdirs:
163 | rays = torch.cat([rays, viewdirs], -1)
164 |
165 | # Render and reshape
166 | all_ret = batchify_rays(rays, chunk, **kwargs)
167 | for k in all_ret:
168 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
169 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
170 |
171 | k_extract = ['rgb_map']
172 | ret_list = [all_ret[k] for k in k_extract]
173 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
174 | return ret_list + [ret_dict]
175 |
176 |
177 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
178 |
179 | H, W, focal = hwf
180 |
181 | if render_factor!=0:
182 | # Render downsampled for speed
183 | H = H//render_factor
184 | W = W//render_factor
185 | focal = focal/render_factor
186 |
187 | rgbs = []
188 | # disps = []
189 |
190 | t = time.time()
191 | for i, c2w in enumerate(tqdm(render_poses)):
192 | print(i, time.time() - t)
193 | t = time.time()
194 | rgb, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
195 | rgbs.append(rgb.cpu().numpy())
196 | # disps.append(disp.cpu().numpy())
197 | # if i==0:
198 | # print(rgb.shape, disp.shape)
199 |
200 | """
201 | if gt_imgs is not None and render_factor==0:
202 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
203 | print(p)
204 | """
205 |
206 | if savedir is not None:
207 | rgb8 = to8b(rgbs[-1])
208 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
209 | imageio.imwrite(filename, rgb8)
210 |
211 |
212 | rgbs = np.stack(rgbs, 0)
213 | # disps = np.stack(disps, 0)
214 |
215 | return rgbs
216 |
217 |
218 | def create_models(args):
219 |     """Instantiate the NeRF, fine NeRF, IoR, and inside-NeRF MLP models.
220 | """
221 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
222 |
223 | input_ch_views = 0
224 | embeddirs_fn = None
225 | if args.use_viewdirs:
226 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
227 | output_ch = 4
228 | skips = [4]
229 | model = NeRF(D=args.netdepth, W=args.netwidth,
230 | input_ch=input_ch, output_ch=output_ch, skips=skips,
231 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
232 |
233 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
234 | input_ch=input_ch, output_ch=output_ch, skips=skips,
235 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
236 |
237 |
238 |
239 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed)
240 | model_ior = MLP_IOR(input_ch = input_ch_ior,D=args.netdepth_ior, W= args.netwidth_ior ,skips=[3]).to(device)
241 | model_ior.apply(init_weights)
242 |
243 | model_inside = NeRF(D=args.netdepth, W=args.netwidth_inside,
244 | input_ch=input_ch, output_ch=output_ch, skips=skips,
245 | input_ch_views=input_ch_views, use_viewdirs=True).to(device)
246 |
247 | grad_vars = list(model_inside.parameters())
248 |
249 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
250 | embed_fn=embed_fn,
251 | embeddirs_fn=embeddirs_fn,
252 | netchunk=args.netchunk)
253 |
254 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn,
255 | embed_fn = embed_fn_ior,
256 | netchunk=args.netchunk)
257 | # Create optimizer
258 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
259 |
260 | start = 0
261 | basedir = args.basedir
262 | expname = args.expname
263 |
264 | ##########################
265 |
266 | # Load checkpoints
267 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f]
268 |     print('Trained model path:',os.path.join(basedir, expname,args.model_path))
269 |
270 |
271 | print('Found ckpts', ckpts)
272 |
273 | if len(ckpts) > 0:
274 | ckpt_path = ckpts[-1]
275 | print('Reloading from', ckpt_path)
276 | ckpt = torch.load(ckpt_path)
277 |
278 | start = ckpt['global_step']
279 |
280 | # Load model
281 | model.load_state_dict(ckpt['network_fn_state_dict'])
282 |
283 | if model_fine is not None:
284 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
285 |
286 | model_ior.load_state_dict(ckpt['network_ior_state_dict'])
287 | print('Reloading the weights for the NeRF and IoR MLPs')
288 |
289 | if "network_inside_state_dict" in ckpt and not args.no_reload:
290 | model_inside.load_state_dict(ckpt['network_inside_state_dict'])
291 | print('Reloading the weights for the inside NeRF MLP')
292 |
293 |
294 |
295 | ##########################
296 |
297 | render_kwargs_train = {
298 |
299 | 'network_query_fn' : network_query_fn,
300 | 'network_query_fn_ior': network_query_fn_ior,
301 | 'perturb' : args.perturb,
302 | 'N_samples' : args.N_samples,
303 | 'network_fn' : model,
304 | 'network_fine' : model_fine,
305 | 'network_ior': model_ior,
306 | 'network_inside':model_inside,
307 | 'use_viewdirs' : args.use_viewdirs,
308 | 'white_bkgd' : args.white_bkgd,
309 | 'raw_noise_std' : args.raw_noise_std,
310 | }
311 |
312 | # NDC only good for LLFF-style forward facing data
313 | if args.dataset_type != 'llff' or args.no_ndc:
314 | print('Not ndc!')
315 | render_kwargs_train['ndc'] = False
316 | render_kwargs_train['lindisp'] = args.lindisp
317 |
318 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
319 | render_kwargs_test['perturb'] = False
320 | render_kwargs_test['raw_noise_std'] = 0.
321 |
322 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
323 |
324 |
325 |
326 | class IoR_emissionAbsorptionODE(nn.Module):
327 |
328 | def __init__(self,query_nerf,model_nerf,nerf_inside,query_ior,model_ior,n_samples,step_size,vals,viewdir):
329 | super(IoR_emissionAbsorptionODE, self).__init__()
330 |
331 | self.query_ior = query_ior
332 | self.model_ior = model_ior
333 |
334 | self.query_nerf = query_nerf
335 | self.model_nerf = model_nerf
336 |
337 | self.nerf_inside = nerf_inside
338 |
339 | self.n_samples = n_samples
340 | self.step_size = step_size
341 | self.vals = vals
342 | self.viewdir = viewdir
343 |
344 | def forward(self,t,y):
345 |
346 |
347 | pts = y[:,0:3]
348 | ray_dir = y[:,3:6]
349 | transmition_ = y[:,9:10]
350 |
351 |
352 | # querying the original radiance field for the entire scene
353 | with torch.no_grad():
354 | raw = self.query_nerf(pts,self.viewdir,self.model_nerf)
355 |
356 | rgb = torch.sigmoid(raw[...,:3])
357 | density = F.relu(raw[...,3:])
358 |
359 | # determining whether a point is inside the bounding box or not
360 | weights= get_bb_weights(pts,self.vals)
361 |
362 | # finding the points inside the bounding box
363 |
364 | insides = torch.where(weights<1.0)[0]
365 | ior_grad = torch.zeros_like(pts)
366 | steps = torch.ones_like(density)*self.step_size
367 |
368 |
369 | # evaluating the inside NeRF model only for the points inside the bounding box
370 |
371 | if len(insides) > 0:
372 |
373 |
374 | # querying the radiance field for the content inside the bounding box
375 |
376 | pts_inside = pts[insides,:]
377 | viewdir_inside = self.viewdir[insides,:]
378 | raw = self.query_nerf(pts_inside,viewdir_inside,self.nerf_inside)
379 | rgb_inside = torch.sigmoid(raw[...,:3])
380 | density_inside = F.relu(raw[...,3:])
381 |
382 |
383 | # linearly blending the radiance fields for the content inside and outside the bounding box
384 | density[insides,:] = density[insides,:]*(weights[insides,:])+ density_inside*(1.-weights[insides,:])
385 | rgb[insides,:] = rgb[insides,:]*(weights[insides,:])+ rgb_inside*(1.-weights[insides,:])
386 |
387 | # computing the gradient of IoR
388 | with torch.enable_grad():
389 |
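# Note: vjp of the scalar IoR field against a vector of ones returns
# (value, gradient), so dn[1] below is grad n evaluated at pts_inside; it is
# faded by the bounding-box weights so the gradient vanishes at the box boundary.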
390 | dn = torch.autograd.functional.vjp(lambda x : self.query_ior(x,self.model_ior), pts_inside ,v = torch.ones_like(pts_inside[:,0:1]))
391 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:])
392 |
393 |
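# Note: t runs over [0, 1] in N_samples Euler steps, so each step covers roughly
# step_size/n_samples of the ray; alpha below is then the usual NeRF opacity
# 1 - exp(-sigma * ds), and the factor n_samples converts the per-step
# emission-absorption update into a derivative with respect to t.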
394 | dv_ds = steps*ior_grad
395 | dx_ds = steps*normalizing(ray_dir)
396 | alpha = 1. - torch.exp(-density*steps/self.n_samples)
397 | alpha = alpha.clip(0.,1.)
398 | dc_dt = transmition_*rgb*alpha*self.n_samples
399 | dT_dt = -transmition_*alpha*self.n_samples
400 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1)
401 |
402 |
403 | return dy_dt
404 |
405 |
406 | def render_rays(ray_batch,
407 | N_samples,
408 | network_query_fn_ior,
409 | bb_vals,
410 | network_ior,
411 | network_inside,
412 | network_fn=None,
413 | network_query_fn=None,
414 | retraw=False,
415 | lindisp=False,
416 | perturb=0.,
417 | network_fine=None,
418 | white_bkgd=False,
419 | raw_noise_std=0.,
420 | verbose=False,
421 | pytest=False):
422 |
423 | N_rays = ray_batch.shape[0]
424 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
425 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
426 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
427 |
428 | t_vals = torch.linspace(0., 1., steps=N_samples)
429 |
430 |
431 | pts = rays_o + normalizing(rays_d)*near
432 | viewdir = normalizing(rays_d)
433 | color_ = torch.zeros_like(pts)
434 | transmition_ = torch.ones((N_rays,1))
435 |
436 |
437 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device)
438 | step_size = far[0]-near[0]
439 | output = odeint(IoR_emissionAbsorptionODE(network_query_fn,network_fine,network_inside,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,viewdir), y0, t_vals,method='euler')
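# Note: with torchdiffeq's fixed-step 'euler' method and no explicit step size,
# the solver appears to step on the t_vals grid, i.e. the ray is marched with
# N_samples Euler steps from near to far; output has shape [N_samples, N_rays, 10]
# and the final colour is read from the last step below.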
440 |
441 | pts = output[:,:,0:3].permute(1,0,2)
442 |
443 | rgb_map = output[-1,:,6:9]
444 |
445 |
446 | ret = {'rgb_map' : rgb_map}
447 | if retraw:
448 | ret['raw'] = pts
449 |
450 | for k in ret:
451 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
452 | print(f"! [Numerical Error] {k} contains nan or inf.")
453 |
454 | return ret
455 |
456 | def config_parser():
457 |
458 | import configargparse
459 | parser = configargparse.ArgumentParser()
460 | parser.add_argument('--config', is_config_file=True,
461 | help='config file path')
462 | parser.add_argument("--expname", default='glass_ball_new',type=str,
463 | help='experiment name')
464 | parser.add_argument("--basedir", type=str, default='logs',
465 | help='where to store ckpts and logs')
466 | parser.add_argument("--datadir", type=str, default='data\\glass_ball',
467 | help='input data directory')
468 |
469 | # training options
470 | parser.add_argument("--netdepth", type=int, default=8,
471 | help='layers in network')
472 | parser.add_argument("--netwidth", type=int, default=256,
473 | help='channels per layer')
474 | parser.add_argument("--netdepth_fine", type=int, default=8,
475 | help='layers in fine network')
476 | parser.add_argument("--netwidth_fine", type=int, default=256,
477 | help='channels per layer in fine network')
478 | parser.add_argument("--N_rand", type=int, default=32*32,
479 | help='batch size (number of random rays per gradient step)')
480 | parser.add_argument("--lrate", type=float, default=5e-4,
481 | help='learning rate')
482 | parser.add_argument("--lrate_decay", type=int, default=250,
483 | help='exponential learning rate decay (in 1000 steps)')
484 | parser.add_argument("--chunk", type=int, default=1024*32*32,
485 | help='number of rays processed in parallel, decrease if running out of memory')
486 | parser.add_argument("--netchunk", type=int, default=1024*64*32,
487 | help='number of pts sent through network in parallel, decrease if running out of memory')
488 | parser.add_argument("--no_batching", default=False,action='store_true',
489 | help='only take random rays from 1 image at a time')
490 | parser.add_argument("--no_reload", default=True, action='store_true',
491 | help='do not reload weights of MLP for NeRF inside')
492 | parser.add_argument("--model_path", type=str, default='model_weights',
493 | help='path to trained model weights')
494 |
495 |
496 | # IoR model parameters
497 |
498 | parser.add_argument("--netdepth_ior", type=int, default=6,
499 | help='layers in the IoR network')
500 | parser.add_argument("--netwidth_ior", type=int, default=64,
501 | help='channels per layer')
502 | parser.add_argument("--multires_views_ior", type=int, default=5,
503 | help='log2 of max freq for positional encoding of the IoR network input (3D location)')
504 |
505 |
506 | # NeRF model inside parameters
507 |
508 | parser.add_argument("--netwidth_inside", type=int, default=128,
509 | help='channels per layer in the inside NeRF network')
510 |
511 |
512 | # rendering options
513 | parser.add_argument("--N_samples", type=int, default=512,
514 | help='number of coarse samples per ray')
515 | parser.add_argument("--N_importance", type=int,
516 | help='Not used in this code')
517 | parser.add_argument("--perturb", type=float, default=1.,
518 | help='set to 0. for no jitter, 1. for jitter')
519 | parser.add_argument("--use_viewdirs", default=True,action='store_true',
520 | help='use full 5D input instead of 3D')
521 | parser.add_argument("--use_mask", default=False,action='store_true',
522 | help='use mask images for training')
523 | parser.add_argument("--i_embed", type=int, default=0,
524 | help='set 0 for default positional encoding, -1 for none')
525 | parser.add_argument("--multires", type=int, default=10,
526 | help='log2 of max freq for positional encoding (3D location)')
527 | parser.add_argument("--multires_views", type=int, default=4,
528 | help='log2 of max freq for positional encoding (2D direction)')
529 | parser.add_argument("--raw_noise_std", type=float, default=1.,
530 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
531 |
532 | parser.add_argument("--render_only", action='store_true',
533 | help='do not optimize, reload weights and render out render_poses path')
534 | parser.add_argument("--render_test", action='store_true',
535 | help='render the test set instead of render_poses path')
536 | parser.add_argument("--render_factor", type=int, default=0,
537 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
538 |
539 | # training options
540 | parser.add_argument("--precrop_iters", type=int, default=0,
541 | help='number of steps to train on central crops')
542 | parser.add_argument("--precrop_frac", type=float,
543 | default=.5, help='fraction of img taken for central crops')
544 |
545 | # dataset options
546 | parser.add_argument("--dataset_type", type=str, default='llff',
547 | help='options: llff / blender / deepvoxels')
548 | parser.add_argument("--testskip", type=int, default=8,
549 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
550 |
551 | ## deepvoxels flags
552 | parser.add_argument("--shape", type=str, default='greek',
553 | help='options : armchair / cube / greek / vase')
554 |
555 | ## blender flags
556 | parser.add_argument("--white_bkgd", action='store_true',
557 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
558 | parser.add_argument("--half_res", action='store_true',
559 | help='load blender synthetic data at 400x400 instead of 800x800')
560 |
561 | ## llff flags
562 | parser.add_argument("--factor", type=int, default=1,
563 | help='downsample factor for LLFF images')
564 | parser.add_argument("--no_ndc", default = True,action='store_true',
565 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
566 | parser.add_argument("--lindisp", action='store_true',
567 | help='sampling linearly in disparity rather than depth')
568 | parser.add_argument("--spherify", default = True, action='store_true',
569 | help='set for spherical 360 scenes')
570 | parser.add_argument("--llffhold", type=int, default=10,
571 | help='will take every 1/N images as LLFF test set, paper uses 8')
572 |
573 | # logging/saving options
574 | parser.add_argument("--i_print", type=int, default=100,
575 | help='frequency of console printout and metric logging')
576 | parser.add_argument("--i_img", type=int, default=100,
577 | help='frequency of test-view image logging')
578 | parser.add_argument("--i_weights", type=int, default=500,
579 | help='frequency of weight ckpt saving')
580 |
581 |
582 | return parser
583 |
584 |
585 | def train():
586 |
587 | parser = config_parser()
588 | args = parser.parse_args()
589 |
590 | # Load data
591 | K = None
592 | if args.dataset_type == 'llff':
593 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
594 | recenter=True, bd_factor=.75,
595 | spherify=args.spherify)
596 | hwf = poses[0,:3,-1]
597 | poses = poses[:,:3,:4]
598 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
599 | if not isinstance(i_test, list):
600 | i_test = [i_test]
601 |
602 | if args.llffhold > 0:
603 | print('Auto LLFF holdout,', args.llffhold)
604 | i_test = np.arange(images.shape[0])[::args.llffhold]
605 |
606 | i_val = i_test
607 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
608 | (i not in i_test and i not in i_val)])
609 |
610 | print('DEFINING BOUNDS')
611 | if args.no_ndc:
612 | near = np.ndarray.min(bds) * .9
613 | far = np.ndarray.max(bds) * 1.
614 |
615 | else:
616 | near = 0.
617 | far = 1.
618 |
619 |
620 | if args.spherify:
621 | far = np.minimum(far,2.5)
622 | print('NEAR FAR', near, far)
623 |
624 | # Cast intrinsics to right types
625 | H, W, focal = hwf
626 | H, W = int(H), int(W)
627 | hwf = [H, W, focal]
628 |
629 | if K is None:
630 | K = np.array([
631 | [focal, 0, 0.5*W],
632 | [0, focal, 0.5*H],
633 | [0, 0, 1]
634 | ])
635 |
636 | if args.render_test:
637 | render_poses = np.array(poses[i_test])
638 |
639 | # Create log dir and copy the config file
640 | basedir = args.basedir
641 | expname = args.expname
642 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
643 | f = os.path.join(basedir, expname, 'args.txt')
644 | with open(f, 'w') as file:
645 | for arg in sorted(vars(args)):
646 | attr = getattr(args, arg)
647 | file.write('{} = {}\n'.format(arg, attr))
648 | if args.config is not None:
649 | f = os.path.join(basedir, expname, 'config.txt')
650 | with open(f, 'w') as file:
651 | file.write(open(args.config, 'r').read())
652 |
653 |
654 |
655 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args)
656 | start = 0
657 | global_step = start
658 |
659 |
660 | bds_dict = {
661 | 'near' : near,
662 | 'far' : far,
663 | }
664 | render_kwargs_train.update(bds_dict)
665 | render_kwargs_test.update(bds_dict)
666 |
667 |
668 | print('ray batching')
669 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
670 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
671 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
672 | rays_rgb = np.reshape(rays_rgb, [len(images),-1,3,3]) # [N, H*W, ro+rd+rgb, 3]
673 | rays_rgb = rays_rgb.astype(np.float32)
674 | rays_rgb = torch.Tensor(rays_rgb).to(device)
675 | rays_rgb = [rays_rgb[i] for i in i_train] # train images only
676 |
677 |
678 | poses = torch.Tensor(poses).to(device)
679 | render_poses = torch.Tensor(render_poses).to(device)
680 |
681 |
682 | print('loading mask images')
683 | path_mask = os.path.join(basedir, expname, 'masked_regions') # concatenating the indices of masked region in each view
684 | inside_idx = get_mask_idx(path_mask,len(poses),i_train)
685 |
686 |
687 | print('loading bounding box values')
688 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy'))
689 |
690 |
691 | render_kwargs_test['bb_vals'] = bounding_box_vals
692 | render_kwargs_train['bb_vals'] = bounding_box_vals
693 |
694 |
695 |
696 | N_iters = 10000 + 1 # total number of iterations
697 | N_rand = args.N_rand
698 | print('Begin')
699 | print('TEST views are', i_test)
700 |
701 |
702 | for i in trange(start,N_iters):
703 | time0 = time.time()
704 |
705 |
706 | # taking random rays over all views
707 | batch = get_rnd(rays_rgb,inside_idx,N_rand)
708 | batch = torch.transpose(batch, 0, 1)
709 |
710 | batch_rays, target_s = batch[:2], batch[2]
711 |
712 | render_kwargs_train['N_samples'] = 256 + 256*i//N_iters # starting with 256 samples and linearly increasing to 512 samples
713 |
714 |
715 | rgb, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
716 | verbose=i < 10, retraw=False,
717 | **render_kwargs_train)
718 |
719 | optimizer.zero_grad()
720 | loss = img2abs(rgb, target_s)
721 | psnr = mse2psnr(img2mse(rgb, target_s))
722 |
723 |
724 |
725 | loss.backward()
726 | optimizer.step()
727 |
728 | decay_rate = 0.1
729 | decay_steps = N_iters
730 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
731 | for param_group in optimizer.param_groups:
732 | param_group['lr'] = new_lrate
733 | ################################
734 |
735 | dt = time.time()-time0
736 |
737 | if i%args.i_weights==0:
738 | model_path_dir = os.path.join(basedir, expname,args.model_path)
739 | os.makedirs(model_path_dir, exist_ok=True)
740 | path = os.path.join(basedir, expname, args.model_path, 'weights.tar')
741 | torch.save({
742 | 'global_step': global_step,
743 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
744 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
745 | 'network_ior_state_dict': render_kwargs_train['network_ior'].state_dict(),
746 | 'network_inside_state_dict': render_kwargs_train['network_inside'].state_dict(),
747 | 'optimizer_state_dict': optimizer.state_dict(),
748 | }, path)
749 | print('Saved checkpoints at', path)
750 |
751 |
752 | if i%args.i_img==0:
753 |
754 | img_i= i_test[5]
755 | pose = poses[img_i, :3,:4]
756 | with torch.no_grad():
757 | rgb ,extras = render(H, W,K, chunk=args.chunk, c2w=pose,retraw=False,
758 | **render_kwargs_test)
759 |
760 |
761 | testimgdir = os.path.join(basedir, expname, 'training_nerf_inside')
762 | os.makedirs(testimgdir, exist_ok=True)
763 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:03d}.png'.format(i)), to8b(rgb.cpu().numpy()))
764 | imageio.imwrite(os.path.join(testimgdir, 'ref_{:03d}.png'.format(img_i)), to8b(images[img_i]))
765 |
766 |
767 | if i%args.i_print==0:
768 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
769 |
770 |
771 | global_step += 1
772 |
773 |
774 | if __name__=='__main__':
775 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
776 |
777 | train()
--------------------------------------------------------------------------------
/run_ior.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys, cv2
3 | import numpy as np
4 | import imageio
5 | import json
6 | import random
7 | import time
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from tqdm import tqdm, trange
12 | from torchdiffeq import odeint_adjoint as odeint
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | from run_nerf_helpers import *
17 |
18 | from load_llff import load_llff_data
19 | from load_deepvoxels import load_dv_data
20 | from load_blender import load_blender_data
21 | from load_LINEMOD import load_LINEMOD_data
22 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
23 |
24 |
25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 | np.random.seed(0)
27 | DEBUG = False
28 |
29 |
30 | def get_rnd(all_rays,all_idx_inside,num_ray):
31 |
32 |
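# Note: draws num_ray rays in total, split evenly across the training views, and
# samples only pixel indices that fall inside the mask ("inside" region) of each view.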
33 | all_ = []
34 | chunk_per_im = num_ray//len(all_idx_inside)
35 | for i in range(len(all_idx_inside)):
36 |
37 | idx_inside = all_idx_inside[i][np.random.choice(len(all_idx_inside[i]), size=[chunk_per_im], replace=False)]
38 | all_.append(all_rays[i][idx_inside])
39 |
40 | return torch.cat(all_,0)
41 |
42 |
43 | def get_mask_idx(path_to_mask,num_im,i_train):
44 |
45 | mask_ = []
46 | for img_i in range(num_im):
47 |
48 | im = imageio.imread(os.path.join(path_to_mask, 'img_%0.3d.png'%(img_i)))
49 | im = np.float32(im)/255.0
50 | im = im.reshape([-1])
51 | inside = np.where(im== 1.0)[0]
52 | mask_.append(inside)
53 |
54 |
55 | mask = [mask_[i] for i in i_train]
56 |
57 | return mask
58 |
59 |
60 | def batchify(fn, chunk):
61 | """Constructs a version of 'fn' that applies to smaller batches.
62 | """
63 | if chunk is None:
64 | return fn
65 | def ret(inputs):
66 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
67 | return ret
68 |
69 |
70 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
71 | """Prepares inputs and applies network 'fn'.
72 | """
73 | embedded = embed_fn(inputs)
74 |
75 | if viewdirs is not None:
76 | embedded_dirs = embeddirs_fn(viewdirs)
77 |
78 | embedded = torch.cat([embedded, embedded_dirs], -1)
79 |
80 | outputs_flat = batchify(fn, netchunk)(embedded)
81 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
82 | return outputs
83 |
84 |
85 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64):
86 |
87 | embedded = embed_fn(inputs)
88 |
89 |
90 | outputs = batchify(fn, netchunk)(embedded)
91 |
92 | return outputs
93 |
94 |
95 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
96 | """Render rays in smaller minibatches to avoid OOM.
97 | """
98 | all_ret = {}
99 | for i in range(0, rays_flat.shape[0], chunk):
100 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
101 | for k in ret:
102 | if k not in all_ret:
103 | all_ret[k] = []
104 | all_ret[k].append(ret[k])
105 |
106 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
107 | return all_ret
108 |
109 |
110 |
111 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
112 | near=0., far=1.,
113 | use_viewdirs=False, c2w_staticcam=None,
114 | **kwargs):
115 | """Render rays
116 | Args:
117 | H: int. Height of image in pixels.
118 | W: int. Width of image in pixels.
119 | focal: float. Focal length of pinhole camera.
120 | chunk: int. Maximum number of rays to process simultaneously. Used to
121 | control maximum memory usage. Does not affect final results.
122 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
123 | each example in batch.
124 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
125 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
126 | near: float or array of shape [batch_size]. Nearest distance for a ray.
127 | far: float or array of shape [batch_size]. Farthest distance for a ray.
128 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
129 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
130 | camera while using other c2w argument for viewing directions.
131 | Returns:
132 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
133 | extras: dict with everything returned by render_rays().
134 | Note: unlike the original NeRF render(), this variant returns only
135 | rgb_map and the extras dict (no disp_map or acc_map).
136 | """
137 | if c2w is not None:
138 | # special case to render full image
139 | rays_o, rays_d = get_rays(H, W, K, c2w)
140 | else:
141 | # use provided ray batch
142 | rays_o, rays_d = rays
143 |
144 | if use_viewdirs:
145 | # provide ray directions as input
146 | viewdirs = rays_d
147 | if c2w_staticcam is not None:
148 | # special case to visualize effect of viewdirs
149 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
150 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
151 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
152 |
153 | sh = rays_d.shape # [..., 3]
154 | if ndc:
155 | # for forward facing scenes
156 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
157 |
158 | # Create ray batch
159 | rays_o = torch.reshape(rays_o, [-1,3]).float()
160 | rays_d = torch.reshape(rays_d, [-1,3]).float()
161 |
162 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
163 | rays = torch.cat([rays_o, rays_d, near, far], -1)
164 | if use_viewdirs:
165 | rays = torch.cat([rays, viewdirs], -1)
166 |
167 | # Render and reshape
168 | all_ret = batchify_rays(rays, chunk, **kwargs)
169 | for k in all_ret:
170 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
171 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
172 |
173 | k_extract = ['rgb_map']
174 | ret_list = [all_ret[k] for k in k_extract]
175 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
176 | return ret_list + [ret_dict]
177 |
178 |
179 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
180 |
181 | H, W, focal = hwf
182 |
183 | if render_factor!=0:
184 | # Render downsampled for speed
185 | H = H//render_factor
186 | W = W//render_factor
187 | focal = focal/render_factor
188 |
189 | rgbs = []
190 | disps = []
191 |
192 | t = time.time()
193 | for i, c2w in enumerate(tqdm(render_poses)):
194 | print(i, time.time() - t)
195 | t = time.time()
196 | rgb, disp, acc, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
197 | rgbs.append(rgb.cpu().numpy())
198 | disps.append(disp.cpu().numpy())
199 | if i==0:
200 | print(rgb.shape, disp.shape)
201 |
202 | """
203 | if gt_imgs is not None and render_factor==0:
204 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
205 | print(p)
206 | """
207 |
208 | if savedir is not None:
209 | rgb8 = to8b(rgbs[-1])
210 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
211 | imageio.imwrite(filename, rgb8)
212 |
213 |
214 | rgbs = np.stack(rgbs, 0)
215 | disps = np.stack(disps, 0)
216 |
217 | return rgbs, disps
218 |
219 |
220 | def create_models(args):
221 | """Instantiate NeRF and IOR MLP models.
222 | """
223 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
224 |
225 | input_ch_views = 0
226 | embeddirs_fn = None
227 | if args.use_viewdirs:
228 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
229 | output_ch = 0
230 | skips = [4]
231 | model = NeRF(D=args.netdepth, W=args.netwidth,
232 | input_ch=input_ch, output_ch=output_ch, skips=skips,
233 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
234 |
235 |
236 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
237 | input_ch=input_ch, output_ch=output_ch, skips=skips,
238 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
239 |
240 |
241 |
242 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed)
243 | model_ior = MLP_IOR(input_ch=input_ch_ior, D=args.netdepth_ior, W=args.netwidth_ior, skips=[3]).to(device)
244 | model_ior.apply(init_weights)
245 |
246 |
247 | # only optimizing for the IoR field
248 | grad_vars = list(model_ior.parameters())
249 |
250 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
251 | embed_fn=embed_fn,
252 | embeddirs_fn=embeddirs_fn,
253 | netchunk=args.netchunk)
254 |
255 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn,
256 | embed_fn = embed_fn_ior,
257 | netchunk=args.netchunk)
258 | # Create optimizer
259 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
260 |
261 |
262 | start = 0
263 | basedir = args.basedir
264 | expname = args.expname
265 |
266 |
267 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f]
268 | print('Trained model path:', os.path.join(basedir, expname, args.model_path))
269 |
270 |
271 | print('Found ckpts', ckpts)
272 |
273 | if len(ckpts) > 0 :
274 | ckpt_path = ckpts[-1]
275 | print('Reloading from', ckpt_path)
276 | ckpt = torch.load(ckpt_path)
277 |
278 | start = ckpt['global_step']
279 |
280 | # Load model
281 | model.load_state_dict(ckpt['network_fn_state_dict'])
282 | print('Reloading the weights for the NeRF MLPs')
283 |
284 | if model_fine is not None:
285 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
286 |
287 | if "network_ior_state_dict" in ckpt and not args.no_reload:
288 | model_ior.load_state_dict(ckpt['network_ior_state_dict'])
289 | print('Reloading the weights for the IoR MLP')
290 |
291 |
292 |
293 | ##########################
294 |
295 | render_kwargs_train = {
296 |
297 | 'network_query_fn' : network_query_fn,
298 | 'network_query_fn_ior': network_query_fn_ior,
299 | 'perturb' : args.perturb,
300 | 'network_fine' : model_fine,
301 | 'N_samples' : args.N_samples,
302 | 'network_fn' : model,
303 | 'network_ior': model_ior,
304 | 'use_viewdirs' : args.use_viewdirs,
305 | 'white_bkgd' : args.white_bkgd,
306 | 'raw_noise_std' : args.raw_noise_std,
307 | }
308 |
309 | # NDC only good for LLFF-style forward facing data
310 | if args.dataset_type != 'llff' or args.no_ndc:
311 | print('Not ndc!')
312 | render_kwargs_train['ndc'] = False
313 | render_kwargs_train['lindisp'] = args.lindisp
314 |
315 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
316 | render_kwargs_test['perturb'] = False
317 | render_kwargs_test['raw_noise_std'] = 0.
318 |
319 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
320 |
321 |
322 |
323 | class IoR_emissionAbsorptionODE(nn.Module):
324 |
325 | def __init__(self,voxel_grid,query_fn,model,n_samples,step_size,bb_vals,scene_bound):
326 | super(IoR_emissionAbsorptionODE, self).__init__()
327 |
328 | self.voxel_grid = voxel_grid
329 | self.model = model
330 | self.query_fn = query_fn
331 | self.n_samples = n_samples
332 | self.step_size = step_size
333 | self.bounding_box_vals = bb_vals
334 | self.scene_bound = scene_bound
335 |
336 | def forward(self,t,y):
337 |
338 |
339 | pts = y[:,0:3]
340 | ray_dir = y[:,3:6]
341 | transmition_ = y[:,9:10]
342 |
343 | # query the 3D grid with trilinear interpolation
344 | raw = trilinear_interpolation(self.voxel_grid,pts,self.scene_bound)
345 | rgb = raw[...,:3]
346 | density = raw[...,3:]
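# Note: the scene radiance field has been baked into voxel_grid beforehand
# (see get_voxel_grid in train()), so each ODE step only needs a cheap trilinear
# lookup instead of a full NeRF MLP evaluation.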
347 |
348 |
349 | # determining whether a point is inside the bounding box or not
350 | weights = get_bb_weights(pts,self.bounding_box_vals)
351 |
352 |
353 | # finding the points inside the bounding box
354 | insides = torch.where((weights<1.))[0]
355 | ior_grad = torch.zeros_like(pts)
356 |
357 |
358 | # fading the density of the reconstructed scene towards zero inside the bounding box
359 | density= density*weights
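# Note: with the opaque reconstruction suppressed inside the box, the transparent
# object is presumably represented only through the ray bending induced by the
# learnt IoR field rather than by opacity.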
360 |
361 | # calculating the gradient of IoR only for the points inside the bounding box
362 | if len(insides) > 0:
363 | pts_inside = pts[insides,:]
364 |
365 | with torch.enable_grad():
366 |
367 | dn = torch.autograd.functional.vjp(lambda x : self.query_fn(x,self.model), pts_inside ,v = torch.ones_like(pts_inside[:,0:1]),create_graph =True )
368 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:])
369 |
370 |
371 |
372 | dv_ds = self.step_size*ior_grad
373 | dx_ds = self.step_size*normalizing(ray_dir)
374 | alpha = 1. - torch.exp(-density*self.step_size/self.n_samples)
375 | alpha = alpha.clip(0.,1.)
376 | dc_dt = transmition_*rgb*alpha*self.n_samples
377 | dT_dt = -transmition_*alpha*self.n_samples
378 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1)
379 |
380 |
381 | return dy_dt
382 |
383 |
384 | def render_rays(ray_batch,
385 | voxel_grid,
386 | N_samples,
387 | network_query_fn_ior,
388 | bb_vals,
389 | network_ior,
390 | scene_bound,
391 | network_fn=None,
392 | network_query_fn=None,
393 | retraw=False,
394 | lindisp=False,
395 | perturb=0.,
396 | network_fine=None,
397 | white_bkgd=False,
398 | raw_noise_std=0.,
399 | verbose=False,
400 | pytest=False):
401 |
402 | N_rays = ray_batch.shape[0]
403 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
404 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
405 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
406 |
407 | t_vals = torch.linspace(0., 1., steps=N_samples)
408 |
409 |
410 | pts = rays_o + normalizing(rays_d)*near
411 | viewdir = normalizing(rays_d)
412 | color_ = torch.zeros_like(pts)
413 | transmition_ = torch.ones((N_rays,1))
414 |
415 |
416 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device)
417 | step_size = far[0]-near[0]
418 | output = odeint(IoR_emissionAbsorptionODE(voxel_grid,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,scene_bound), y0, t_vals,method='euler')
419 |
420 | pts = output[:,:,0:3].permute(1,0,2)
421 |
422 | rgb_map = output[-1,:,6:9]
423 |
424 |
425 | ret = {'rgb_map' : rgb_map}
426 | if retraw:
427 | ret['raw'] = pts
428 |
429 | for k in ret:
430 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
431 | print(f"! [Numerical Error] {k} contains nan or inf.")
432 |
433 | return ret
434 |
435 |
436 |
437 | def config_parser():
438 |
439 | import configargparse
440 | parser = configargparse.ArgumentParser()
441 | parser.add_argument('--config', is_config_file=True,
442 | help='config file path')
443 | parser.add_argument("--expname", default='glass_ball_new',type=str,
444 | help='experiment name')
445 | parser.add_argument("--basedir", type=str, default='logs',
446 | help='where to store ckpts and logs')
447 | parser.add_argument("--datadir", type=str, default='data\\glass_ball',
448 | help='input data directory')
449 |
450 | # training options
451 | parser.add_argument("--netdepth", type=int, default=8,
452 | help='layers in network')
453 | parser.add_argument("--netwidth", type=int, default=256,
454 | help='channels per layer')
455 | parser.add_argument("--netdepth_fine", type=int, default=8,
456 | help='layers in fine network')
457 | parser.add_argument("--netwidth_fine", type=int, default=256,
458 | help='channels per layer in fine network')
459 | parser.add_argument("--N_rand", type=int, default=32*32*32,
460 | help='batch size (number of random rays per gradient step)')
461 | parser.add_argument("--lrate", type=float, default=5e-4,
462 | help='learning rate')
463 | parser.add_argument("--lrate_decay", type=int, default=250,
464 | help='exponential learning rate decay (in 1000 steps)')
465 | parser.add_argument("--chunk", type=int, default=1024*32*32,
466 | help='number of rays processed in parallel, decrease if running out of memory')
467 | parser.add_argument("--netchunk", type=int, default=1024*64,
468 | help='number of pts sent through network in parallel, decrease if running out of memory')
469 | parser.add_argument("--no_batching", default=False,action='store_true',
470 | help='only take random rays from 1 image at a time')
471 | parser.add_argument("--no_reload", default=True, action='store_true',
472 | help='do not reload the weights for the IoR MLP')
473 | parser.add_argument("--model_path", type=str, default='model_weights',
474 | help='path to trained model weights')
475 |
476 |
477 | # IoR model parameters
478 |
479 | parser.add_argument("--netdepth_ior", type=int, default=6,
480 | help='layers in the IoR network')
481 | parser.add_argument("--netwidth_ior", type=int, default=64,
482 | help='channels per layer')
483 | parser.add_argument("--multires_views_ior", type=int, default=5,
484 | help='log2 of max freq for positional encoding of the IoR network input (3D location)')
485 |
486 |
487 | # rendering options
488 | parser.add_argument("--N_samples", type=int, default=128,
489 | help='number of coarse samples per ray')
490 | parser.add_argument("--N_importance", type=int,
491 | help='Not used in this code')
492 | parser.add_argument("--perturb", type=float, default=1.,
493 | help='set to 0. for no jitter, 1. for jitter')
494 | parser.add_argument("--use_viewdirs", default=True,action='store_true',
495 | help='use full 5D input instead of 3D')
496 | parser.add_argument("--use_mask", default=False,action='store_true',
497 | help='use mask images for training')
498 | parser.add_argument("--i_embed", type=int, default=0,
499 | help='set 0 for default positional encoding, -1 for none')
500 | parser.add_argument("--multires", type=int, default=10,
501 | help='log2 of max freq for positional encoding (3D location)')
502 | parser.add_argument("--multires_views", type=int, default=4,
503 | help='log2 of max freq for positional encoding (2D direction)')
504 | parser.add_argument("--raw_noise_std", type=float, default=1.,
505 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
506 |
507 | parser.add_argument("--render_only", action='store_true',
508 | help='do not optimize, reload weights and render out render_poses path')
509 | parser.add_argument("--render_test", action='store_true',
510 | help='render the test set instead of render_poses path')
511 | parser.add_argument("--render_factor", type=int, default=0,
512 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
513 |
514 | # training options
515 | parser.add_argument("--precrop_iters", type=int, default=0,
516 | help='number of steps to train on central crops')
517 | parser.add_argument("--precrop_frac", type=float,
518 | default=.5, help='fraction of img taken for central crops')
519 |
520 | # dataset options
521 | parser.add_argument("--dataset_type", type=str, default='llff',
522 | help='options: llff / blender / deepvoxels')
523 | parser.add_argument("--testskip", type=int, default=8,
524 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
525 |
526 | ## deepvoxels flags
527 | parser.add_argument("--shape", type=str, default='greek',
528 | help='options : armchair / cube / greek / vase')
529 |
530 | ## blender flags
531 | parser.add_argument("--white_bkgd", action='store_true',
532 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
533 | parser.add_argument("--half_res", action='store_true',
534 | help='load blender synthetic data at 400x400 instead of 800x800')
535 |
536 | ## llff flags
537 | parser.add_argument("--factor", type=int, default=1,
538 | help='downsample factor for LLFF images')
539 | parser.add_argument("--no_ndc", default = True,action='store_true',
540 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
541 | parser.add_argument("--lindisp", action='store_true',
542 | help='sampling linearly in disparity rather than depth')
543 | parser.add_argument("--spherify", default = True, action='store_true',
544 | help='set for spherical 360 scenes')
545 | parser.add_argument("--llffhold", type=int, default=10,
546 | help='will take every 1/N images as LLFF test set, paper uses 8')
547 |
548 | # logging/saving options
549 | parser.add_argument("--i_print", type=int, default=100,
550 | help='frequency of console printout and metric logging')
551 | parser.add_argument("--i_img", type=int, default=50,
552 | help='frequency of test-view image logging')
553 | parser.add_argument("--i_weights", type=int, default=500,
554 | help='frequency of weight ckpt saving')
555 |
556 |
557 | return parser
558 |
559 |
560 |
561 | def train():
562 |
563 | parser = config_parser()
564 | args = parser.parse_args()
565 |
566 | # Load data
567 | K = None
568 | if args.dataset_type == 'llff':
569 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
570 | recenter=True, bd_factor=.75,
571 | spherify=args.spherify)
572 | hwf = poses[0,:3,-1]
573 | poses = poses[:,:3,:4]
574 | i_test = []
575 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
576 | if not isinstance(i_test, list):
577 | i_test = [i_test]
578 |
579 | if args.llffhold > 0:
580 | print('Auto LLFF holdout,', args.llffhold)
581 | i_test = np.arange(images.shape[0])[::args.llffhold]
582 |
583 | i_val = i_test
584 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
585 | (i not in i_test and i not in i_val)])
586 |
587 | print('DEFINING BOUNDS')
588 | if args.no_ndc:
589 | near = np.ndarray.min(bds) * .9
590 | far = np.ndarray.max(bds) * 1.
591 |
592 | else:
593 | near = 0.
594 | far = 1.
595 |
596 |
597 | if args.spherify:
598 | far = np.minimum(far,2.0)
599 | print('NEAR FAR', near, far)
600 |
601 |
602 | # Cast intrinsics to right types
603 | H, W, focal = hwf
604 | H, W = int(H), int(W)
605 | hwf = [H, W, focal]
606 |
607 | if K is None:
608 | K = np.array([
609 | [focal, 0, 0.5*W],
610 | [0, focal, 0.5*H],
611 | [0, 0, 1]
612 | ])
613 |
614 | if args.render_test:
615 | render_poses = np.array(poses[i_test])
616 |
617 | # Create log dir and copy the config file
618 | basedir = args.basedir
619 | expname = args.expname
620 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
621 | f = os.path.join(basedir, expname, 'args.txt')
622 | with open(f, 'w') as file:
623 | for arg in sorted(vars(args)):
624 | attr = getattr(args, arg)
625 | file.write('{} = {}\n'.format(arg, attr))
626 | if args.config is not None:
627 | f = os.path.join(basedir, expname, 'config.txt')
628 | with open(f, 'w') as file:
629 | file.write(open(args.config, 'r').read())
630 |
631 |
632 |
633 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args)
634 | start = 0
635 | global_step = start
636 |
637 |
638 | bds_dict = {
639 | 'near' : near,
640 | 'far' : far,
641 | }
642 | render_kwargs_train.update(bds_dict)
643 | render_kwargs_test.update(bds_dict)
644 |
645 |
646 | print('ray batching')
647 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
648 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
649 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
650 | rays_rgb = np.reshape(rays_rgb, [len(images),-1,3,3]) # [N, H*W, ro+rd+rgb, 3]
651 | rays_rgb = rays_rgb.astype(np.float32)
652 | rays_rgb = torch.Tensor(rays_rgb).to(device)
653 | rays_rgb = [rays_rgb[i] for i in i_train] # train images only
654 |
655 | poses = torch.Tensor(poses).to(device)
656 | render_poses = torch.Tensor(render_poses).to(device)
657 |
658 |
659 | print('loading mask images')
660 | path_mask = os.path.join(basedir, expname, 'masked_regions')
661 | inside_idx = get_mask_idx(path_mask,len(poses),i_train) # concatenating the indices of masked region in each view
662 |
663 |
664 | print('calculating the scene bounds for the 3D grid')
665 | scene_bound = get_scene_bound(near,far,H,W,K,poses,min_=.01,max_=.01)
666 |
667 | render_kwargs_test['scene_bound'] = scene_bound
668 | render_kwargs_train['scene_bound'] = scene_bound
669 |
670 |
671 | print('loading bounding box values')
672 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy'))
673 |
674 |
675 | render_kwargs_test['bb_vals'] = bounding_box_vals
676 | render_kwargs_train['bb_vals'] = bounding_box_vals
677 |
678 |
679 | print('fitting a 3D grid to the learnt radiance field (NeRF)')
680 | query_function = render_kwargs_test['network_query_fn']
681 | nerf_coarse_model = render_kwargs_test['network_fn']
682 | Grid_res = 128
683 |
684 | reuse = False
685 | testimgdir = os.path.join(basedir, expname, 'voxel_grid.npy')
686 |
687 | skip_im = 4
688 | if reuse:
689 | voxel_grid = torch.tensor(np.load(testimgdir)).to(device)
690 | else:
691 | voxel_grid = get_voxel_grid(Grid_res,scene_bound,poses[0:-1:skip_im],query_function,nerf_coarse_model,False,bounding_box_vals)
692 | np.save(testimgdir,voxel_grid.cpu().numpy())
693 |
694 |
695 |
696 | N_iters = 5000 + 1 # total number of iterations
697 | iter_level = 1000 # number of iterations per level
698 | sigma = 0.08 # initial sigma for the 3D Gaussian blur kernel
699 | N_rand = args.N_rand
700 | print('Begin')
701 | print('TEST views are', i_test)
702 |
703 | start = start + 0
704 | loss_vals = []
705 |
706 |
707 | for i in trange(start,N_iters):
708 | time0 = time.time()
709 |
710 |
711 | # taking random rays over all views
712 | batch = get_rnd(rays_rgb,inside_idx,N_rand)
713 | batch = torch.transpose(batch, 0, 1)
714 |
715 | batch_rays, target_s = batch[:2], batch[2]
716 |
717 |
718 | # creating a 3D Gaussian blur kernel
719 | filter_3d = lowpass_3d(Grid_res,sigma*2**(i//iter_level))
720 |
721 | # smoothing the grid with the 3D Gaussian blur kernel in Fourier domain
722 | voxel_grid_blur = voxel_lowpass_filtering(voxel_grid,filter_3d)
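# Note: sigma doubles every iter_level iterations; assuming lowpass_3d builds a
# Gaussian in the frequency domain, a larger sigma admits higher frequencies of
# the grid, giving a coarse-to-fine schedule for optimising the IoR field.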
723 |
724 | render_kwargs_train['voxel_grid'] = voxel_grid_blur
725 | render_kwargs_train['N_samples'] = 128
726 |
727 |
728 | ##### Core optimization loop #####
729 | rgb, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
730 | verbose=i < 10, retraw=False,
731 | **render_kwargs_train)
732 |
733 | loss = img2mse(rgb, target_s)
734 | psnr = mse2psnr(loss)
735 |
736 |
737 |
738 | loss.backward()
739 | optimizer.step()
740 | optimizer.zero_grad()
741 |
742 |
743 | decay_rate = 0.1
744 | new_lrate = args.lrate* (decay_rate ** (global_step / N_iters))
745 | for param_group in optimizer.param_groups:
746 | param_group['lr'] = new_lrate
747 | ################################
748 |
749 | dt = time.time()-time0
750 |
751 | loss_vals.append(loss.detach().cpu().numpy())
752 |
753 |
754 | if i%args.i_weights==0:
755 | model_path_dir = os.path.join(basedir, expname,args.model_path)
756 | os.makedirs(model_path_dir, exist_ok=True)
757 | path = os.path.join(basedir, expname, args.model_path, 'weights.tar')
758 | torch.save({
759 | 'global_step': global_step,
760 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
761 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
762 | 'network_ior_state_dict': render_kwargs_train['network_ior'].state_dict(),
763 | 'optimizer_state_dict': optimizer.state_dict(),
764 | }, path)
765 | print('Saved checkpoints at', path)
766 |
767 |
768 |
769 | if i%args.i_img==0:
770 |
771 | render_kwargs_test['voxel_grid'] = voxel_grid # rendering test images with the finest (unblurred) version of the grid
772 | img_i= i_test[5]
773 | pose = poses[img_i, :3,:4]
774 | with torch.no_grad():
775 | rgb ,extras = render(H, W,K, chunk=args.chunk, c2w=pose,
776 | **render_kwargs_test)
777 |
778 |
779 | testimgdir = os.path.join(basedir, expname, 'training_ior')
780 | os.makedirs(testimgdir, exist_ok=True)
781 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:03d}.png'.format(i)), to8b(rgb.cpu().numpy()))
782 | imageio.imwrite(os.path.join(testimgdir, 'ref_{:03d}.png'.format(img_i)), to8b(images[img_i]))
783 |
784 | # plt.plot(loss_vals)
785 | # plt.show()
786 |
787 | if i%args.i_print==0:
788 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
789 |
790 |
791 | global_step += 1
792 |
793 |
794 | if __name__=='__main__':
795 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
796 |
797 | train()
--------------------------------------------------------------------------------
/run_nerf.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import imageio
4 | import json
5 | import random
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from tqdm import tqdm, trange
11 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from run_nerf_helpers import *
16 |
17 | from load_llff import load_llff_data
18 | from load_deepvoxels import load_dv_data
19 | from load_blender import load_blender_data
20 | from load_LINEMOD import load_LINEMOD_data
21 |
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 | np.random.seed(0)
25 | DEBUG = False
26 |
27 |
28 | def batchify(fn, chunk):
29 | """Constructs a version of 'fn' that applies to smaller batches.
30 | """
31 | if chunk is None:
32 | return fn
33 | def ret(inputs):
34 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
35 | return ret
36 |
37 |
38 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
39 | """Prepares inputs and applies network 'fn'.
40 | """
41 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
42 | embedded = embed_fn(inputs_flat)
43 |
44 | if viewdirs is not None:
45 | input_dirs = viewdirs[:,None].expand(inputs.shape)
46 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
47 | embedded_dirs = embeddirs_fn(input_dirs_flat)
48 | embedded = torch.cat([embedded, embedded_dirs], -1)
49 |
50 | outputs_flat = batchify(fn, netchunk)(embedded)
51 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
52 | return outputs
53 |
54 |
55 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
56 | """Render rays in smaller minibatches to avoid OOM.
57 | """
58 | all_ret = {}
59 | for i in range(0, rays_flat.shape[0], chunk):
60 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
61 | for k in ret:
62 | if k not in all_ret:
63 | all_ret[k] = []
64 | all_ret[k].append(ret[k])
65 |
66 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
67 | return all_ret
68 |
69 |
70 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
71 | near=0., far=1.,
72 | use_viewdirs=False, c2w_staticcam=None,
73 | **kwargs):
74 | """Render rays
75 | Args:
76 | H: int. Height of image in pixels.
77 | W: int. Width of image in pixels.
78 | focal: float. Focal length of pinhole camera.
79 | chunk: int. Maximum number of rays to process simultaneously. Used to
80 | control maximum memory usage. Does not affect final results.
81 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
82 | each example in batch.
83 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
84 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
85 | near: float or array of shape [batch_size]. Nearest distance for a ray.
86 | far: float or array of shape [batch_size]. Farthest distance for a ray.
87 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
88 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
89 | camera while using other c2w argument for viewing directions.
90 | Returns:
91 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
92 | disp_map: [batch_size]. Disparity map. Inverse of depth.
93 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
94 | extras: dict with everything returned by render_rays().
95 | """
96 | if c2w is not None:
97 | # special case to render full image
98 | rays_o, rays_d = get_rays(H, W, K, c2w)
99 | else:
100 | # use provided ray batch
101 | rays_o, rays_d = rays
102 |
103 | if use_viewdirs:
104 | # provide ray directions as input
105 | viewdirs = rays_d
106 | if c2w_staticcam is not None:
107 | # special case to visualize effect of viewdirs
108 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
109 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
110 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
111 |
112 | sh = rays_d.shape # [..., 3]
113 | if ndc:
114 | # for forward facing scenes
115 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
116 |
117 | # Create ray batch
118 | rays_o = torch.reshape(rays_o, [-1,3]).float()
119 | rays_d = torch.reshape(rays_d, [-1,3]).float()
120 |
121 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
122 | rays = torch.cat([rays_o, rays_d, near, far], -1)
123 | if use_viewdirs:
124 | rays = torch.cat([rays, viewdirs], -1)
125 |
126 | # Render and reshape
127 | all_ret = batchify_rays(rays, chunk, **kwargs)
128 | for k in all_ret:
129 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
130 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
131 |
132 | k_extract = ['rgb_map', 'disp_map', 'acc_map']
133 | ret_list = [all_ret[k] for k in k_extract]
134 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
135 | return ret_list + [ret_dict]
136 |
137 |
138 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
139 |
140 | H, W, focal = hwf
141 |
142 | if render_factor!=0:
143 | # Render downsampled for speed
144 | H = H//render_factor
145 | W = W//render_factor
146 | focal = focal/render_factor
147 |
148 | rgbs = []
149 | disps = []
150 |
151 | t = time.time()
152 | for i, c2w in enumerate(tqdm(render_poses)):
153 | print(i, time.time() - t)
154 | t = time.time()
155 | rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
156 | rgbs.append(rgb.cpu().numpy())
157 | disps.append(disp.cpu().numpy())
158 | if i==0:
159 | print(rgb.shape, disp.shape)
160 |
161 | """
162 | if gt_imgs is not None and render_factor==0:
163 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
164 | print(p)
165 | """
166 |
167 | if savedir is not None:
168 | rgb8 = to8b(rgbs[-1])
169 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
170 | imageio.imwrite(filename, rgb8)
171 |
172 |
173 | rgbs = np.stack(rgbs, 0)
174 | disps = np.stack(disps, 0)
175 |
176 | return rgbs, disps
177 |
178 |
179 | def create_nerf(args):
180 | """Instantiate NeRF's MLP model.
181 | """
182 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
183 |
184 | input_ch_views = 0
185 | embeddirs_fn = None
186 | if args.use_viewdirs:
187 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
188 | output_ch = 5 if args.N_importance > 0 else 4
189 | skips = [4]
190 | model = NeRF(D=args.netdepth, W=args.netwidth,
191 | input_ch=input_ch, output_ch=output_ch, skips=skips,
192 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
193 | grad_vars = list(model.parameters())
194 |
195 | model_fine = None
196 | if args.N_importance > 0:
197 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
198 | input_ch=input_ch, output_ch=output_ch, skips=skips,
199 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
200 | grad_vars += list(model_fine.parameters())
201 |
202 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
203 | embed_fn=embed_fn,
204 | embeddirs_fn=embeddirs_fn,
205 | netchunk=args.netchunk)
206 |
207 | # Create optimizer
208 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
209 |
210 | start = 0
211 | basedir = args.basedir
212 | expname = args.expname
213 |
214 | ##########################
215 |
216 | # Load checkpoints
217 | os.makedirs(os.path.join(basedir, expname,args.model_path), exist_ok=True)
218 |
219 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f]
220 | print('Trained model path:', os.path.join(basedir, expname, args.model_path))
221 | print('Found ckpts', ckpts)
222 | if len(ckpts) > 0 and not args.no_reload:
223 | ckpt_path = ckpts[-1]
224 | print('Reloading from', ckpt_path)
225 | ckpt = torch.load(ckpt_path)
226 |
227 | start = ckpt['global_step']
228 | optimizer.load_state_dict(ckpt['optimizer_state_dict'])
229 |
230 | # Load model
231 | model.load_state_dict(ckpt['network_fn_state_dict'])
232 | if model_fine is not None:
233 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
234 |
235 | ##########################
236 |
237 | render_kwargs_train = {
238 | 'network_query_fn' : network_query_fn,
239 | 'perturb' : args.perturb,
240 | 'N_importance' : args.N_importance,
241 | 'network_fine' : model_fine,
242 | 'N_samples' : args.N_samples,
243 | 'network_fn' : model,
244 | 'use_viewdirs' : args.use_viewdirs,
245 | 'white_bkgd' : args.white_bkgd,
246 | 'raw_noise_std' : args.raw_noise_std,
247 | }
248 |
249 | # NDC only good for LLFF-style forward facing data
250 | if args.dataset_type != 'llff' or args.no_ndc:
251 | print('Not ndc!')
252 | render_kwargs_train['ndc'] = False
253 | render_kwargs_train['lindisp'] = args.lindisp
254 |
255 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
256 | render_kwargs_test['perturb'] = False
257 | render_kwargs_test['raw_noise_std'] = 0.
258 |
259 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
260 |
261 |
262 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
263 | """Transforms model's predictions to semantically meaningful values.
264 | Args:
265 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
266 | z_vals: [num_rays, num_samples along ray]. Integration time.
267 | rays_d: [num_rays, 3]. Direction of each ray.
268 | Returns:
269 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
270 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
271 | acc_map: [num_rays]. Sum of weights along each ray.
272 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
273 | depth_map: [num_rays]. Estimated distance to object.
274 | """
275 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
276 |
277 | dists = z_vals[...,1:] - z_vals[...,:-1]
278 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
279 |
280 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
281 |
282 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
283 | noise = 0.
284 | if raw_noise_std > 0.:
285 | noise = torch.randn(raw[...,3].shape) * raw_noise_std
286 |
287 | # Overwrite randomly sampled data if pytest
288 | if pytest:
289 | np.random.seed(0)
290 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
291 | noise = torch.Tensor(noise)
292 |
293 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
294 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
295 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
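# The cat/cumprod/slice above form an exclusive cumulative product, so
# weights[i] = alpha_i * prod_{j<i} (1 - alpha_j), i.e. the opacity of sample i
# times the transmittance accumulated before it.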
296 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
297 |
298 | depth_map = torch.sum(weights * z_vals, -1)
299 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
300 | acc_map = torch.sum(weights, -1)
301 |
302 | if white_bkgd:
303 | rgb_map = rgb_map + (1.-acc_map[...,None])
304 |
305 | return rgb_map, disp_map, acc_map, weights, depth_map
306 |
307 |
308 | def render_rays(ray_batch,
309 | network_fn,
310 | network_query_fn,
311 | N_samples,
312 | retraw=False,
313 | lindisp=False,
314 | perturb=0.,
315 | N_importance=0,
316 | network_fine=None,
317 | white_bkgd=False,
318 | raw_noise_std=0.,
319 | verbose=False,
320 | pytest=False):
321 | """Volumetric rendering.
322 | Args:
323 | ray_batch: array of shape [batch_size, ...]. All information necessary
324 | for sampling along a ray, including: ray origin, ray direction, min
325 | dist, max dist, and unit-magnitude viewing direction.
326 | network_fn: function. Model for predicting RGB and density at each point
327 | in space.
328 | network_query_fn: function used for passing queries to network_fn.
329 | N_samples: int. Number of different times to sample along each ray.
330 | retraw: bool. If True, include model's raw, unprocessed predictions.
331 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
332 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
333 | random points in time.
334 | N_importance: int. Number of additional times to sample along each ray.
335 | These samples are only passed to network_fine.
336 | network_fine: "fine" network with same spec as network_fn.
337 | white_bkgd: bool. If True, assume a white background.
338 |       raw_noise_std: float. Std dev of noise added to the raw sigma predictions as a regularizer.
339 | verbose: bool. If True, print more debugging info.
340 | Returns:
341 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
342 | disp_map: [num_rays]. Disparity map. 1 / depth.
343 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
344 | raw: [num_rays, num_samples, 4]. Raw predictions from model.
345 | rgb0: See rgb_map. Output for coarse model.
346 | disp0: See disp_map. Output for coarse model.
347 | acc0: See acc_map. Output for coarse model.
348 | z_std: [num_rays]. Standard deviation of distances along ray for each
349 | sample.
350 | """
351 | N_rays = ray_batch.shape[0]
352 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
353 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
354 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
355 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
356 |
357 | t_vals = torch.linspace(0., 1., steps=N_samples)
358 | if not lindisp:
359 | z_vals = near * (1.-t_vals) + far * (t_vals)
360 | else:
361 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
362 |
363 | z_vals = z_vals.expand([N_rays, N_samples])
364 |
365 | if perturb > 0.:
366 | # get intervals between samples
367 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
368 | upper = torch.cat([mids, z_vals[...,-1:]], -1)
369 | lower = torch.cat([z_vals[...,:1], mids], -1)
370 | # stratified samples in those intervals
371 | t_rand = torch.rand(z_vals.shape)
372 |
373 | # Pytest, overwrite u with numpy's fixed random numbers
374 | if pytest:
375 | np.random.seed(0)
376 | t_rand = np.random.rand(*list(z_vals.shape))
377 | t_rand = torch.Tensor(t_rand)
378 |
379 | z_vals = lower + (upper - lower) * t_rand
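        # Illustration (exposition only): with perturb=1 each depth is redrawn uniformly
        # from its own interval, z_i = lower_i + (upper_i - lower_i) * u_i with u_i ~ U(0, 1),
        # so successive iterations query different depths and the estimate integrates over
        # the whole near/far range (stratified sampling) rather than always hitting the same
        # fixed N_samples depths.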
380 |
381 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
382 |
383 | raw = network_query_fn(pts, viewdirs, network_fn)
384 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
385 |
386 | if N_importance > 0:
387 |
388 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
389 |
390 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
391 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
392 | z_samples = z_samples.detach()
393 |
394 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
395 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
396 |
397 | run_fn = network_fn if network_fine is None else network_fine
398 | raw = network_query_fn(pts, viewdirs, run_fn)
399 |
400 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
401 |
402 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
403 | if retraw:
404 | ret['raw'] = raw
405 | if N_importance > 0:
406 | ret['rgb0'] = rgb_map_0
407 | ret['disp0'] = disp_map_0
408 | ret['acc0'] = acc_map_0
409 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
410 |
411 | for k in ret:
412 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
413 | print(f"! [Numerical Error] {k} contains nan or inf.")
414 |
415 | return ret
416 |
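# Illustration (exposition only, not the repo's own helper): the N_importance branch above
# resamples the ray where the coarse weights are large, by inverse-transform sampling from
# the piecewise-constant PDF over the z_vals_mid bins; this is roughly what the sample_pdf
# helper used above is expected to do. A rough NumPy sketch of the idea:
#
#   import numpy as np
#   def inverse_cdf_sample(bins, weights, n, rng=np.random):
#       # bins: (M+1,) bin edges, weights: (M,) unnormalized bin masses
#       pdf = weights / np.maximum(weights.sum(), 1e-8)
#       cdf = np.concatenate([[0.0], np.cumsum(pdf)])
#       u = rng.rand(n)                                     # uniform draws in [0, 1)
#       idx = np.clip(np.searchsorted(cdf, u, side='right') - 1, 0, len(weights) - 1)
#       t = (u - cdf[idx]) / np.maximum(cdf[idx + 1] - cdf[idx], 1e-8)
#       return bins[idx] + t * (bins[idx + 1] - bins[idx])  # samples inside chosen bins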
417 |
418 | def config_parser():
419 |
420 | import configargparse
421 | parser = configargparse.ArgumentParser()
422 | parser.add_argument('--config', is_config_file=True,
423 | help='config file path')
424 | parser.add_argument("--expname", default='Ball',type=str,
425 | help='experiment name')
426 | parser.add_argument("--basedir", type=str, default='logs',
427 | help='where to store ckpts and logs')
428 | parser.add_argument("--datadir", type=str, default='data\\Ball',
429 | help='input data directory')
430 |
431 | # training options
432 | parser.add_argument("--netdepth", type=int, default=8,
433 | help='layers in network')
434 | parser.add_argument("--netwidth", type=int, default=256,
435 | help='channels per layer')
436 | parser.add_argument("--netdepth_fine", type=int, default=8,
437 | help='layers in fine network')
438 | parser.add_argument("--netwidth_fine", type=int, default=256,
439 | help='channels per layer in fine network')
440 | parser.add_argument("--N_rand", type=int, default=32*32,
441 | help='batch size (number of random rays per gradient step)')
442 | parser.add_argument("--lrate", type=float, default=5e-4,
443 | help='learning rate')
444 | parser.add_argument("--lrate_decay", type=int, default=150,
445 | help='exponential learning rate decay (in 1000 steps)')
446 | parser.add_argument("--chunk", type=int, default=1024*32,
447 | help='number of rays processed in parallel, decrease if running out of memory')
448 | parser.add_argument("--netchunk", type=int, default=1024*64,
449 | help='number of pts sent through network in parallel, decrease if running out of memory')
450 | parser.add_argument("--no_batching", default=False,action='store_true',
451 | help='only take random rays from 1 image at a time')
452 | parser.add_argument("--no_reload", action='store_true',
453 | help='do not reload weights from saved ckpt')
454 | parser.add_argument("--model_path", type=str, default='model_weights',
455 | help='path to trained model weights')
456 | # rendering options
457 | parser.add_argument("--N_samples", type=int, default=64,
458 | help='number of coarse samples per ray')
459 | parser.add_argument("--N_importance", type=int, default=64,
460 | help='number of additional fine samples per ray')
461 | parser.add_argument("--perturb", type=float, default=1.,
462 | help='set to 0. for no jitter, 1. for jitter')
463 | parser.add_argument("--use_viewdirs", default=True,action='store_true',
464 | help='use full 5D input instead of 3D')
465 | parser.add_argument("--i_embed", type=int, default=0,
466 | help='set 0 for default positional encoding, -1 for none')
467 | parser.add_argument("--multires", type=int, default=10,
468 | help='log2 of max freq for positional encoding (3D location)')
469 | parser.add_argument("--multires_views", type=int, default=4,
470 | help='log2 of max freq for positional encoding (2D direction)')
471 | parser.add_argument("--raw_noise_std", type=float, default=1.,
472 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
473 |
474 | parser.add_argument("--render_only", action='store_true',
475 | help='do not optimize, reload weights and render out render_poses path')
476 | parser.add_argument("--render_test", action='store_true',
477 | help='render the test set instead of render_poses path')
478 | parser.add_argument("--render_factor", type=int, default=0,
479 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
480 |
481 | # training options
482 | parser.add_argument("--precrop_iters", type=int, default=0,
483 | help='number of steps to train on central crops')
484 | parser.add_argument("--precrop_frac", type=float,
485 | default=.5, help='fraction of img taken for central crops')
486 |
487 | # dataset options
488 | parser.add_argument("--dataset_type", type=str, default='llff',
489 | help='options: llff / blender / deepvoxels')
490 | parser.add_argument("--testskip", type=int, default=None,
491 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
492 |
493 | ## deepvoxels flags
494 | parser.add_argument("--shape", type=str, default='greek',
495 | help='options : armchair / cube / greek / vase')
496 |
497 | ## blender flags
498 | parser.add_argument("--white_bkgd", action='store_true',
499 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
500 | parser.add_argument("--half_res", action='store_true',
501 | help='load blender synthetic data at 400x400 instead of 800x800')
502 |
503 | ## llff flags
504 | parser.add_argument("--factor", type=int, default=1,
505 | help='downsample factor for LLFF images')
506 | parser.add_argument("--no_ndc", default=True, action='store_true',
507 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
508 | parser.add_argument("--lindisp", default=False, action='store_true',
509 | help='sampling linearly in disparity rather than depth')
510 | parser.add_argument("--spherify", default=True,action='store_true',
511 | help='set for spherical 360 scenes')
512 | parser.add_argument("--llffhold", type=int, default=10,
513 | help='will take every 1/N images as LLFF test set, paper uses 8')
514 |
515 | # logging/saving options
516 | parser.add_argument("--i_print", type=int, default=100,
517 |                         help='frequency of console printout and metric logging')
518 | parser.add_argument("--i_img", type=int, default=1000,
519 | help='frequency of tensorboard image logging')
520 | parser.add_argument("--i_weights", type=int, default=1000,
521 | help='frequency of weight ckpt saving')
522 | parser.add_argument("--i_testset", type=int, default=400000,
523 | help='frequency of testset saving')
524 | parser.add_argument("--i_video", type=int, default=150000,
525 | help='frequency of render_poses video saving')
526 |
527 | return parser
528 |
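# Usage note (exposition only): because configargparse is used, any flag above can also be
# set in a text config file passed via --config, and command-line values take precedence
# over the file. Hypothetical invocations (adjust the script name to this file):
#
#   python run_nerf.py --config configs/Ball.txt
#   python run_nerf.py --config configs/Ball.txt --render_only --render_test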
529 |
530 | def train():
531 |
532 | parser = config_parser()
533 | args = parser.parse_args()
534 |
535 | # Load data
536 | K = None
537 | if args.dataset_type == 'llff':
538 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
539 | recenter=True, bd_factor=.75,
540 | spherify=args.spherify)
541 | hwf = poses[0,:3,-1]
542 | poses = poses[:,:3,:4]
543 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
544 | if not isinstance(i_test, list):
545 | i_test = [i_test]
546 |
547 | if args.llffhold > 0:
548 | print('Auto LLFF holdout,', args.llffhold)
549 | i_test = np.arange(images.shape[0])[::args.llffhold]
550 |
551 | i_val = i_test
552 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if
553 | (i not in i_test and i not in i_val)])
554 |
555 | print('DEFINING BOUNDS')
556 | if args.no_ndc:
557 | near = np.ndarray.min(bds) * .9
558 | far = np.ndarray.max(bds) * 1.
559 |
560 | else:
561 | near = 0.
562 | far = 1.
563 | print('NEAR FAR', near, far)
564 |
565 | elif args.dataset_type == 'blender':
566 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
567 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
568 | i_train, i_val, i_test = i_split
569 |
570 | near = 2.
571 | far = 6.
572 |
573 | if args.white_bkgd:
574 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
575 | else:
576 | images = images[...,:3]
577 |
578 | elif args.dataset_type == 'LINEMOD':
579 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
580 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
581 | print(f'[CHECK HERE] near: {near}, far: {far}.')
582 | i_train, i_val, i_test = i_split
583 |
584 | if args.white_bkgd:
585 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
586 | else:
587 | images = images[...,:3]
588 |
589 | elif args.dataset_type == 'deepvoxels':
590 |
591 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
592 | basedir=args.datadir,
593 | testskip=args.testskip)
594 |
595 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
596 | i_train, i_val, i_test = i_split
597 |
598 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
599 | near = hemi_R-1.
600 | far = hemi_R+1.
601 |
602 | else:
603 | print('Unknown dataset type', args.dataset_type, 'exiting')
604 | return
605 |
606 | # Cast intrinsics to right types
607 | H, W, focal = hwf
608 | H, W = int(H), int(W)
609 | hwf = [H, W, focal]
610 |
611 | if K is None:
612 | K = np.array([
613 | [focal, 0, 0.5*W],
614 | [0, focal, 0.5*H],
615 | [0, 0, 1]
616 | ])
617 |
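    # Note (exposition only): the fallback K above is a plain pinhole intrinsic matrix with
    # the principal point at the image centre; up to the axis conventions used by get_rays,
    # a camera-space point (x, y, z) maps to pixel
    # (u, v) = (focal * x / z + 0.5 * W, focal * y / z + 0.5 * H).
    # For example, with focal = 1000 and an 800x800 image the optical axis lands on (400, 400).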
618 | if args.render_test:
619 | render_poses = np.array(poses[i_test])
620 |
621 | # Create log dir and copy the config file
622 | basedir = args.basedir
623 | expname = args.expname
624 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
625 | f = os.path.join(basedir, expname, 'args.txt')
626 | with open(f, 'w') as file:
627 | for arg in sorted(vars(args)):
628 | attr = getattr(args, arg)
629 | file.write('{} = {}\n'.format(arg, attr))
630 | if args.config is not None:
631 | f = os.path.join(basedir, expname, 'config.txt')
632 | with open(f, 'w') as file:
633 | file.write(open(args.config, 'r').read())
634 |
635 | # Create nerf model
636 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
637 | global_step = start
638 |
639 | bds_dict = {
640 | 'near' : near,
641 | 'far' : far,
642 | }
643 | render_kwargs_train.update(bds_dict)
644 | render_kwargs_test.update(bds_dict)
645 |
646 | # Move testing data to GPU
647 | render_poses = torch.Tensor(render_poses).to(device)
648 |
649 | # Short circuit if only rendering out from trained model
650 | if args.render_only:
651 | print('RENDER ONLY')
652 | with torch.no_grad():
653 | if args.render_test:
654 | # render_test switches to test poses
655 | images = images[i_test]
656 | else:
657 | # Default is smoother render_poses path
658 | images = None
659 |
660 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
661 | os.makedirs(testsavedir, exist_ok=True)
662 | print('test poses shape', render_poses.shape)
663 |
664 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
665 | print('Done rendering', testsavedir)
666 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
667 |
668 | return
669 |
670 | # Prepare raybatch tensor if batching random rays
671 | N_rand = args.N_rand
672 | use_batching = not args.no_batching
673 | if use_batching:
674 | # For random ray batching
675 | print('get rays')
676 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
677 | print('done, concats')
678 |
679 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
680 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
681 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
682 |         rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N_train*H*W, ro+rd+rgb, 3]
683 | rays_rgb = rays_rgb.astype(np.float32)
684 | print('shuffle rays')
685 | np.random.shuffle(rays_rgb)
686 | print('done')
687 | i_batch = 0
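        # Note (exposition only): after the reshape, each row of rays_rgb is one (3, 3)
        # training sample:
        #   rays_rgb[k, 0] -> ray origin (ro)
        #   rays_rgb[k, 1] -> ray direction (rd)
        #   rays_rgb[k, 2] -> ground-truth pixel colour (rgb)
        # so slicing N_rand consecutive rows below yields a complete random ray batch.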
688 |
689 | # Move training data to GPU
690 | if use_batching:
691 | images = torch.Tensor(images).to(device)
692 | poses = torch.Tensor(poses).to(device)
693 | if use_batching:
694 | rays_rgb = torch.Tensor(rays_rgb).to(device)
695 |
696 |
697 |
698 |
699 |
700 |
701 | N_iters = 150000 + 1
702 | print('Begin')
703 | print('TRAIN views are', i_train)
704 | print('TEST views are', i_test)
705 | print('VAL views are', i_val)
706 |
707 | # Summary writers
708 | # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
709 |
710 | start = start + 1
711 | for i in trange(start, N_iters):
712 | time0 = time.time()
713 |
714 | # Sample random ray batch
715 | if use_batching:
716 | # Random over all images
717 |             batch = rays_rgb[i_batch:i_batch+N_rand] # [N_rand, ro+rd+rgb, 3]
718 | batch = torch.transpose(batch, 0, 1)
719 |
720 | batch_rays, target_s = batch[:2], batch[2]
721 |
722 | i_batch += N_rand
723 | if i_batch >= rays_rgb.shape[0]:
724 | print("Shuffle data after an epoch!")
725 | rand_idx = torch.randperm(rays_rgb.shape[0])
726 | rays_rgb = rays_rgb[rand_idx]
727 | i_batch = 0
728 |
729 | else:
730 | # Random from one image
731 | img_i = np.random.choice(i_train)
732 | target = images[img_i]
733 | target = torch.Tensor(target).to(device)
734 | pose = poses[img_i, :3,:4]
735 |
736 | if N_rand is not None:
737 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
738 |
739 | if i < args.precrop_iters:
740 | dH = int(H//2 * args.precrop_frac)
741 | dW = int(W//2 * args.precrop_frac)
742 | coords = torch.stack(
743 | torch.meshgrid(
744 | torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
745 | torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
746 | ), -1)
747 | if i == start:
748 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
749 | else:
750 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
751 |
752 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
753 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
754 | select_coords = coords[select_inds].long() # (N_rand, 2)
755 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
756 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
757 | batch_rays = torch.stack([rays_o, rays_d], 0)
758 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
759 |
760 | ##### Core optimization loop #####
761 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
762 | verbose=i < 10, retraw=True,
763 | **render_kwargs_train)
764 |
765 | optimizer.zero_grad()
766 | img_loss = img2mse(rgb, target_s)
767 |
768 | loss = img_loss
769 | psnr = mse2psnr(img_loss)
770 |
771 | if 'rgb0' in extras:
772 | img_loss0 = img2mse(extras['rgb0'], target_s)
773 | loss = loss + img_loss0
774 | psnr0 = mse2psnr(img_loss0)
775 |
776 | loss.backward()
777 | optimizer.step()
778 |
779 | # NOTE: IMPORTANT!
780 | ### update learning rate ###
781 | decay_rate = 0.1
782 | decay_steps = args.lrate_decay * 1000
783 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
784 | for param_group in optimizer.param_groups:
785 | param_group['lr'] = new_lrate
786 | ################################
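        # Worked example (exposition only): with the defaults lrate = 5e-4 and
        # lrate_decay = 150, decay_steps = 150000, so
        # lr = 5e-4 * 0.1 ** (global_step / 150000): roughly 2.3e-4 after 50k steps,
        # 1.1e-4 after 100k, and exactly 5e-5 at the 150k-step budget (N_iters).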
787 |
788 | dt = time.time()-time0
789 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
790 | ##### end #####
791 |
792 | # Rest is logging
793 | if i%args.i_weights==0:
794 | model_path_dir = os.path.join(basedir, expname,args.model_path)
795 | os.makedirs(model_path_dir, exist_ok=True)
796 |             path = os.path.join(model_path_dir, 'weights.tar')  # single rolling checkpoint, overwritten each save
797 |
798 | torch.save({
799 | 'global_step': global_step,
800 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
801 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
802 | 'optimizer_state_dict': optimizer.state_dict(),
803 | }, path)
804 | print('Saved checkpoints at', path)
805 |
806 | if i%args.i_video==0 and i > 0:
807 | # Turn on testing mode
808 | with torch.no_grad():
809 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
810 | print('Done, saving', rgbs.shape, disps.shape)
811 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
812 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=15, quality=8)
813 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
814 |
815 | if i%args.i_testset==0 and i > 0:
816 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
817 | os.makedirs(testsavedir, exist_ok=True)
818 | print('test poses shape', poses[i_test].shape)
819 | with torch.no_grad():
820 | render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
821 | print('Saved test set')
822 |
823 | if i%args.i_img==0:
824 |
825 | # Log a rendered validation view to Tensorboard
826 |             img_i = i_test[5]  # a fixed held-out view (originally np.random.choice(i_val))
827 | target = images[img_i]
828 | pose = poses[img_i, :3,:4]
829 | with torch.no_grad():
830 | rgb, disp, acc, extras = render(H, W,K, chunk=args.chunk, c2w=pose,
831 | **render_kwargs_test)
832 | testimgdir = os.path.join(basedir, expname, 'training_nerf')
833 | os.makedirs(testimgdir, exist_ok=True)
834 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:06d}.png'.format(i)), to8b(rgb.cpu().numpy()))
835 | imageio.imwrite(os.path.join(testimgdir, 'ref_{:06d}.png'.format(img_i)), to8b(images[img_i].cpu().numpy()))
836 |
837 |
838 | if i%args.i_print==0:
839 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
840 |
841 |
842 | global_step += 1
843 |
844 |
845 | if __name__=='__main__':
846 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
847 |
848 | train()
849 |
--------------------------------------------------------------------------------