├── EDI.py
├── LICENSE
├── Papers.pdf
├── README.md
├── config_real-world.txt
├── config_synthetic.txt
├── event_loss_helpers.py
├── figure.png
├── load_LINEMOD.py
├── load_blender.py
├── load_deepvoxels.py
├── load_event.py
├── load_llff.py
├── orginal_data&preprocessing
│   ├── real-world
│   │   └── Data-preprocessing
│   │       └── data-preprocessing.py
│   └── synthetic
│       ├── Blur-synthesizing
│       │   ├── README.md
│       │   ├── main.py
│       │   ├── process.py
│       │   ├── unprocess.py
│       │   └── util.py
│       └── Event-preprocessing
│           └── event-preprocessing.py
├── requirements.txt
├── run_nerf_exp.py
└── run_nerf_helpers.py
/EDI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import math
4 |
5 |
6 | # -------------------EDI MODEL----------------------
7 | def EDI_rgb(data_path, frame_num, bin_num, events):
8 |     event_sum = torch.zeros(260, 346)  # DAVIS346 event resolution: H = 260, W = 346
9 | EDI = torch.ones(260, 346)
10 |
11 | for i in range(bin_num):
12 | event_sum = event_sum + events[frame_num][i]
13 | EDI = EDI + torch.exp(0.3 * event_sum)
14 |
15 | EDI = torch.stack([EDI, EDI, EDI], axis=-1)
16 |     img = (bin_num + 1) * blurry_image / EDI  # blurry_image is the module-level blurry frame loaded in the loop below
17 | img = torch.clamp(img, max=255)
18 | cv2.imwrite(data_path + "images_for_colmap/{0:03d}.jpg".format(frame_num * (bin_num + 1)), img.numpy()) # save the first deblurred image
19 |
20 | offset = torch.zeros(260, 346)
21 | for i in range(bin_num):
22 | offset = offset + events[frame_num][i]
23 | imgs = img * torch.exp(0.3 * torch.stack([offset, offset, offset], axis=-1))
24 | cv2.imwrite(data_path + "images_for_colmap/{0:03d}.jpg".format(frame_num * (bin_num + 1) + 1 + i), imgs.numpy()) # save the rest of the deblurred images
25 |
26 | threshold = 0.3  # event contrast threshold c (EDI_rgb currently hard-codes 0.3 internally)
27 | bin_num = 4
28 | view_num = 30
29 | data_path = "./data/"
30 | events = torch.load(data_path + "events.pt").view(view_num, bin_num, 260, 346)  # load the preprocessed events from the .pt file
31 |
32 | for i in range(view_num):
33 | blurry_image = torch.tensor(cv2.imread(data_path + "images/{0:03d}.jpg".format(i * (bin_num + 1))), dtype=torch.float) # load the target blurry image
34 | EDI_rgb(data_path, i, bin_num, events)
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 CVTEAM
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Papers.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iCVTEAM/E2NeRF/487fec175c8729fdb49c22a2de1a093ca9192fe9/Papers.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for E2NeRF: Event Enhanced Neural Radiance Fields from Blurry Images (ICCV 2023)
2 | This is the official PyTorch implementation of E2NeRF. Click [here](https://icvteam.github.io/E2NeRF.html) to see the video and supplementary materials on our project website.
3 |
4 |
5 |
6 | ## Method Overview
7 |
8 | 
9 |
10 |
11 |
12 | ## Installation
13 | The code is based on [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) and uses the same environment.
14 | Please refer to its GitHub [repository](https://github.com/yenchenlin/nerf-pytorch) for instructions on setting up the environment.
15 |
16 |
17 |
18 | ## Code
19 |
20 | ### Synthetic Data
21 |
22 | The configs for the synthetic data are in the config_synthetic.txt file. Please download the synthetic data below and put it into the corresponding folder (./data/synthetic/). Then you can use the command below to train the model.
23 |
24 | ```
25 | python run_nerf_exp.py --config config_synthetic.txt
26 | ```
27 |
28 | ### Real-World Data
29 |
30 | The configs for the real-world data are in the config_real-world.txt file. Please download the real-world data below and put it into the corresponding folder (./data/real-world/). Then you can use the command below to train the model.
31 |
32 | ```
33 | python run_nerf_exp.py --config config_real-world.txt
34 | ```
35 |
36 | Note that for the real-world data experiments in the paper, we use the poses for video rendering to render the novel-view images.
37 | Simply replace "test_poses" in line 853 of run_nerf_exp.py with "render_poses" to generate the novel-view images (120 images in total).
38 | We use no-reference image quality assessment metrics to evaluate the novel-view images and the blurry-view images together (120 + 30 = 150 images).
39 |
40 |
41 | ## Dataset
42 | Download the dataset [here](https://drive.google.com/drive/folders/1XhOEp4UdLL7EnDNyWdxxX8aRvzF53fWo?usp=sharing).
43 | The dataset contains the "data" for training and the "original data".
44 |
45 | ### Synthetic Data:
46 | In each scene's folder, the training images are in the "train" folder along with the corresponding event data "events.pt". The ground-truth images are in the "test" folder.
47 |
48 | As in the original NeRF, the training and testing poses are in the "transform_train.json" and "transform_test.json" files.
49 | Note that at test time we use the first pose of each view in "transform_test.json" to render the test images, and the ground-truth images are rendered at the same pose.
50 |
51 | ### Real-World Data:
52 | The structure follows the original NeRF's LLFF data, and the event data is in "events.pt".
53 |
54 | ### Event Data:
55 | For easy reading, we convert the event stream into event bins stored as an events.pt file, which can be loaded with PyTorch. The shape of the tensor is (view_number, bin_number, H, W), and each element is the signed number of events accumulated at that pixel (positive and negative values indicate polarity).
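
For reference, a minimal loading sketch (the scene path below is just an example; some of the provided files store H and W flattened, in which case a `.view(...)` call as in `EDI.py` restores the spatial layout):

```python
import torch

# events.pt stores signed event counts per view and per time bin.
# Expected shape: (view_number, bin_number, H, W).
events = torch.load("./data/synthetic/lego/events.pt")
print(events.shape)

# Positive entries count positive-polarity events, negative entries negative-polarity ones.
net_events_first_view = events[0].sum(dim=0)  # accumulate all bins of the first view
```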
56 |
57 |
58 |
59 | ## Original Data & Preprocessing
60 | ### Synthetic Data:
61 | We provide the original sharp images used to synthesize the blurry images, together with the synthesis code. Besides, we supply the original event data generated by v2e, and we also provide the code to convert the ".txt" events into "events.pt" for E2NeRF training.
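
As a rough illustration (not the exact preprocessing script), the ".txt" files follow the v2e format read in `load_event.py`: six header lines, then one `t x y p` record per line with `t` in seconds and `p` in {0, 1}. A minimal binning sketch, assuming b = 4 bins and an illustrative per-view event duration:

```python
import numpy as np
import torch

def txt_events_to_bins(txt_path, bin_num=4, duration=0.004, H=800, W=800):
    """Accumulate v2e text events into bin_num signed count maps.
    `duration` (seconds of events per view) is an illustrative value."""
    bins = np.zeros((bin_num, H, W), dtype=np.int32)
    with open(txt_path, "r") as fp:
        for _ in range(6):                      # skip the v2e header lines
            fp.readline()
        for line in fp:
            parts = line.split()                # t x y p
            if len(parts) != 4:
                continue
            t, x, y, p = float(parts[0]), int(parts[1]), int(parts[2]), int(parts[3])
            b = min(int(t / (duration / bin_num)), bin_num - 1)
            bins[b, y, x] += 1 if p == 1 else -1
    return torch.from_numpy(bins)

# e.g. stack one tensor per view and save:
# events = torch.stack([txt_events_to_bins(f"event/r_{2 * i}/v2e-dvs-events.txt") for i in range(100)])
# torch.save(events, "events.pt")
```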
62 |
63 | ### Real-World Data:
64 | We supply the original ".aedat4" data captured by a DAVIS346 camera, together with the processing code. We also convert the event data into events.pt for training.
65 |
66 | ### EDI:
67 | We have added the EDI code (EDI.py) to the repository.
68 | You can use it to deblur the images in the "train" folder with the corresponding events.pt data.
69 | The deblurred images are saved to the "images_for_colmap" folder.
70 | Then you can use COLMAP to estimate the camera poses, as in NeRF.
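
For reference, a minimal sketch of the EDI relation that `EDI.py` implements per color channel (B is the blurry image, b the number of event bins per exposure, and the contrast threshold is hard-coded as 0.3):

```python
import torch

def edi_deblur(blurry, event_bins, c=0.3):
    """blurry: (H, W) channel of the blurry image; event_bins: (b, H, W) signed event counts."""
    cumsum = torch.cumsum(event_bins, dim=0)          # events accumulated up to each bin
    denom = 1 + torch.exp(c * cumsum).sum(dim=0)      # 1 + sum_k exp(c * E_1..k)
    L0 = (event_bins.shape[0] + 1) * blurry / denom   # first latent (deblurred) image
    return [L0] + [L0 * torch.exp(c * s) for s in cumsum]   # the b + 1 latent images
```

`EDI.py` applies this relation to each RGB channel, writes the b + 1 deblurred images of every view into "images_for_colmap", and those images are then passed to COLMAP.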
71 |
72 |
73 |
74 | ## Citation
75 |
76 | If you find this useful, please consider citing our paper:
77 |
78 | ```bibtex
79 | @inproceedings{qi2023e2nerf,
80 | title={E2nerf: Event enhanced neural radiance fields from blurry images},
81 | author={Qi, Yunshan and Zhu, Lin and Zhang, Yu and Li, Jia},
82 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
83 | pages={13254--13264},
84 | year={2023}
85 | }
86 | ```
87 |
88 |
89 |
90 | ## Acknowledgment
91 |
92 | The overall framework is derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/). We appreciate the effort of the contributors to that repository.
93 |
--------------------------------------------------------------------------------
/config_real-world.txt:
--------------------------------------------------------------------------------
1 | expname = lego
2 | basedir = ./logs/real-world
3 | datadir = ./data/real-world/lego
4 | dataset_type = ellff
5 |
6 | factor = 1
7 | llffhold = 0
8 | no_batching = True
9 |
10 | N_rand = 1024
11 | N_samples = 64
12 | N_importance = 128
13 |
14 | use_viewdirs = True
15 | raw_noise_std = 1e0
16 |
17 | spherify = True
18 | no_ndc = True
--------------------------------------------------------------------------------
/config_synthetic.txt:
--------------------------------------------------------------------------------
1 | expname = lego
2 | basedir = ./logs
3 | datadir = ./data/synthetic/lego
4 | dataset_type = blender
5 |
6 | no_batching = True
7 | lrate_decay = 500
8 |
9 | N_samples = 64
10 | N_importance = 128
11 |
12 | use_viewdirs = True
13 |
14 | white_bkgd = False
15 |
16 | N_rand = 1024
--------------------------------------------------------------------------------
/event_loss_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import random
4 |
5 | def lin_log(x, threshold=20):
6 | """
7 | linear mapping + logarithmic mapping.
8 | :param x: float or ndarray the input linear value in range 0-255
9 |     :param threshold: float threshold 0-255 the threshold for transition from linear to log mapping
10 |     """
11 |     # convert x to double precision (float64) for a more accurate log computation
12 | if x.dtype is not torch.float64:
13 | x = x.double()
14 | f = (1./threshold) * math.log(threshold)
15 | y = torch.where(x <= threshold, x*f, torch.log(x))
16 |
17 | return y.float()
18 |
19 |
20 | def event_loss_call(all_rgb, event_data, select_coords, combination, rgb2gray, resolution_h, resolution_w):
21 | '''
22 | simulate the generation of event stream and calculate the event loss
23 | '''
24 | if rgb2gray == "rgb":
25 | rgb2grey = torch.tensor([0.299,0.587,0.114])
26 | elif rgb2gray == "ave":
27 | rgb2grey = torch.tensor([1/3, 1/3, 1/3])
28 | loss = []
29 |
30 |     chose = random.sample(combination, 10)  # randomly sample 10 (start, end) bin-index pairs
31 | for its in range(10):
32 | start = chose[its][0]
33 | end = chose[its][1]
34 |
35 |         thres_pos = (lin_log(torch.mv(all_rgb[end], rgb2grey) * 255) - lin_log(torch.mv(all_rgb[start], rgb2grey) * 255)) / 0.3  # predicted log-brightness change divided by the positive contrast threshold (0.3)
36 |         thres_neg = (lin_log(torch.mv(all_rgb[end], rgb2grey) * 255) - lin_log(torch.mv(all_rgb[start], rgb2grey) * 255)) / 0.2  # predicted log-brightness change divided by the negative contrast threshold (0.2)
37 |
38 | event_cur = event_data[start].view(resolution_h, resolution_w)[select_coords[:, 0], select_coords[:, 1]]
39 | for j in range(start + 1, end):
40 | event_cur += event_data[j].view(resolution_h, resolution_w)[select_coords[:, 0], select_coords[:, 1]]
41 |
42 | pos = event_cur > 0
43 | neg = event_cur < 0
44 |
45 | loss_pos = torch.mean(((thres_pos * pos) - ((event_cur + 0.5) * pos)) ** 2)
46 | loss_neg = torch.mean(((thres_neg * neg) - ((event_cur - 0.5) * neg)) ** 2)
47 |
48 | loss.append(loss_pos + loss_neg)
49 |
50 | event_loss = torch.mean(torch.stack(loss, dim=0), dim=0)
51 | return event_loss
52 |
--------------------------------------------------------------------------------
/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iCVTEAM/E2NeRF/487fec175c8729fdb49c22a2de1a093ca9192fe9/figure.png
--------------------------------------------------------------------------------
/load_LINEMOD.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 |
37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1):
38 | splits = ['train', 'val', 'test']
39 | metas = {}
40 | for s in splits:
41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
42 | metas[s] = json.load(fp)
43 |
44 | all_imgs = []
45 | all_poses = []
46 | counts = [0]
47 | for s in splits:
48 | meta = metas[s]
49 | imgs = []
50 | poses = []
51 | if s=='train' or testskip==0:
52 | skip = 1
53 | else:
54 | skip = testskip
55 |
56 | for idx_test, frame in enumerate(meta['frames'][::skip]):
57 | fname = frame['file_path']
58 | if s == 'test':
59 | print(f"{idx_test}th test frame: {fname}")
60 | imgs.append(imageio.imread(fname))
61 | poses.append(np.array(frame['transform_matrix']))
62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
63 | poses = np.array(poses).astype(np.float32)
64 | counts.append(counts[-1] + imgs.shape[0])
65 | all_imgs.append(imgs)
66 | all_poses.append(poses)
67 |
68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
69 |
70 | imgs = np.concatenate(all_imgs, 0)
71 | poses = np.concatenate(all_poses, 0)
72 |
73 | H, W = imgs[0].shape[:2]
74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
75 | K = meta['frames'][0]['intrinsic_matrix']
76 | print(f"Focal: {focal}")
77 |
78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
79 |
80 | if half_res:
81 | H = H//2
82 | W = W//2
83 | focal = focal/2.
84 |
85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
86 | for i, img in enumerate(imgs):
87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
88 | imgs = imgs_half_res
89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
90 |
91 | near = np.floor(min(metas['train']['near'], metas['test']['near']))
92 | far = np.ceil(max(metas['train']['far'], metas['test']['far']))
93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far
94 |
95 |
96 |
--------------------------------------------------------------------------------
/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 |
7 |
8 | trans_t = lambda t : torch.Tensor([
9 | [1,0,0,0],
10 | [0,1,0,0],
11 | [0,0,1,t],
12 | [0,0,0,1]]).float()
13 |
14 | rot_phi = lambda phi : torch.Tensor([
15 | [1,0,0,0],
16 | [0,np.cos(phi),-np.sin(phi),0],
17 | [0,np.sin(phi), np.cos(phi),0],
18 | [0,0,0,1]]).float()
19 |
20 | rot_theta = lambda th : torch.Tensor([
21 | [np.cos(th),0,-np.sin(th),0],
22 | [0,1,0,0],
23 | [np.sin(th),0, np.cos(th),0],
24 | [0,0,0,1]]).float()
25 |
26 |
27 | def pose_spherical(theta, phi, radius):
28 | c2w = trans_t(radius)
29 | c2w = rot_phi(phi/180.*np.pi) @ c2w
30 | c2w = rot_theta(theta/180.*np.pi) @ c2w
31 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
32 | return c2w
33 |
34 |
35 | def load_blender_data(basedir):
36 | splits = ['train', 'test']
37 | metas = {}
38 | for s in splits:
39 | with open(os.path.join(basedir, 'transform_{}.json'.format(s)), 'r') as fp:
40 | metas[s] = json.load(fp)
41 |
42 | imgs = []
43 | poses = []
44 | test_imgs = []
45 | test_poses = []
46 | for s in splits:
47 | meta = metas[s]
48 |
49 | if s == 'train':
50 | num = 0
51 | for frame in meta['frames'][0:200:2]:
52 | fname = os.path.join(basedir, "./train/r_{}.png".format(num))
53 | num = num + 2
54 | imgs.append(imageio.imread(fname))
55 | poses.append(np.array([frame['transform_matrix'][1], frame['transform_matrix'][5], frame['transform_matrix'][9], frame['transform_matrix'][13], frame['transform_matrix'][17]]))
56 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
57 | poses = np.array(poses).astype(np.float32)
58 | else:
59 | num = 0
60 | for frame in meta['frames'][0:200:1]:
61 | fname = os.path.join(basedir, "./test/r_{}.png".format(num))
62 | num = num + 1
63 | test_imgs.append(imageio.imread(fname))
64 | test_poses.append(np.array(frame['transform_matrix'][1]))
65 |
66 |
67 | test_imgs = (np.array(test_imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
68 | test_poses = np.array(test_poses).astype(np.float32)
69 |
70 |
71 | H, W = imgs[0].shape[:2]
72 | meta = metas['train']
73 | camera_angle_x = float(meta['camera_angle_x'])
74 | focal = .5 * W / np.tan(.5 * camera_angle_x)
75 |
76 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 180 + 1)[:-1]], 0)
77 |
78 | return imgs, poses, test_imgs, test_poses, render_poses, [H, W, focal]
79 |
--------------------------------------------------------------------------------
/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 |
8 |
9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
10 | # Get camera intrinsics
11 | with open(filepath, 'r') as file:
12 | f, cx, cy = list(map(float, file.readline().split()))[:3]
13 | grid_barycenter = np.array(list(map(float, file.readline().split())))
14 | near_plane = float(file.readline())
15 | scale = float(file.readline())
16 | height, width = map(float, file.readline().split())
17 |
18 | try:
19 | world2cam_poses = int(file.readline())
20 | except ValueError:
21 | world2cam_poses = None
22 |
23 | if world2cam_poses is None:
24 | world2cam_poses = False
25 |
26 | world2cam_poses = bool(world2cam_poses)
27 |
28 | print(cx,cy,f,height,width)
29 |
30 | cx = cx / width * trgt_sidelength
31 | cy = cy / height * trgt_sidelength
32 | f = trgt_sidelength / height * f
33 |
34 | fx = f
35 | if invert_y:
36 | fy = -f
37 | else:
38 | fy = f
39 |
40 | # Build the intrinsic matrices
41 | full_intrinsic = np.array([[fx, 0., cx, 0.],
42 | [0., fy, cy, 0],
43 | [0., 0, 1, 0],
44 | [0, 0, 0, 1]])
45 |
46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
47 |
48 |
49 | def load_pose(filename):
50 | assert os.path.isfile(filename)
51 | nums = open(filename).read().split()
52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
53 |
54 |
55 | H = 512
56 | W = 512
57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
58 |
59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
61 | focal = full_intrinsic[0,0]
62 | print(H, W, focal)
63 |
64 |
65 | def dir2poses(posedir):
66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
67 | transf = np.array([
68 | [1,0,0,0],
69 | [0,-1,0,0],
70 | [0,0,-1,0],
71 | [0,0,0,1.],
72 | ])
73 | poses = poses @ transf
74 | poses = poses[:,:3,:4].astype(np.float32)
75 | return poses
76 |
77 | posedir = os.path.join(deepvoxels_base, 'pose')
78 | poses = dir2poses(posedir)
79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
80 | testposes = testposes[::testskip]
81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
82 | valposes = valposes[::testskip]
83 |
84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
86 |
87 |
88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
91 |
92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
95 |
96 | all_imgs = [imgs, valimgs, testimgs]
97 | counts = [0] + [x.shape[0] for x in all_imgs]
98 | counts = np.cumsum(counts)
99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
100 |
101 | imgs = np.concatenate(all_imgs, 0)
102 | poses = np.concatenate([poses, valposes, testposes], 0)
103 |
104 | render_poses = testposes
105 |
106 | print(poses.shape, imgs.shape)
107 |
108 | return imgs, poses, render_poses, [H,W,focal], i_split
109 |
110 |
111 |
--------------------------------------------------------------------------------
/load_event.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 | # read event data to list
11 | def load_event_data_v1(basedir):
12 | data_path = os.path.join(basedir, "event")
13 | event_list = []
14 |
15 | for i in range(0, 10, 2):
16 |
17 | file = os.path.join(data_path, "r_" + str(i) + "/v2e-dvs-events.txt")
18 | print(file)
19 | fp = open(file, "r")
20 | events = []
21 | event_block = []
22 | counter = 1
23 |
24 | for j in range(6):
25 | fp.readline()
26 |
27 | while True:
28 | line = fp.readline()
29 | if not line:
30 | events.append(event_block)
31 | break
32 |
33 | info = line.split()
34 | t = float(info[0])
35 | x = int(info[1])
36 | y = int(info[2])
37 | p = int(info[3])
38 | if t > counter * 0.001:
39 | events.append(event_block)
40 | event_block = []
41 | counter += 1
42 | event = [x, y, t, p]
43 | event_block.append(event)
44 | while counter < 20:
45 | counter += 1
46 | events.append([])
47 | event_list.append(events)
48 |     return event_list
49 |
50 |
51 | # read event data to numpy
52 | def load_event_data_v2(basedir):
53 | data_path = os.path.join(basedir, "event")
54 |     event_map = np.zeros((100, 20, 800, 800), dtype=int)  # np.int was removed in newer NumPy; the builtin int is equivalent
55 |
56 | for i in range(0, 200, 2):
57 |
58 | file = os.path.join(data_path, "r_" + str(i) + "/v2e-dvs-events.txt")
59 | fp = open(file, "r")
60 | counter = 1
61 |
62 | for j in range(6):
63 | fp.readline()
64 |
65 | while True:
66 | line = fp.readline()
67 | if not line:
68 | break
69 |
70 | info = line.split()
71 | t = float(info[0])
72 | x = int(info[1])
73 | y = int(info[2])
74 | p = int(info[3])
75 | if t > counter * 0.001:
76 | counter += 1
77 | if p == 0:
78 | event_map[int(i / 2)][counter - 1][y][x] -= 1
79 | else:
80 | event_map[int(i / 2)][counter - 1][y][x] += 1
81 |
82 | return event_map
83 |
--------------------------------------------------------------------------------
/load_llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, imageio
3 |
4 |
5 | ########## Slightly modified version of LLFF data loading code
6 | ########## see https://github.com/Fyusion/LLFF for original
7 |
8 | def _minify(basedir, factors=[], resolutions=[]):
9 | needtoload = False
10 | for r in factors:
11 | imgdir = os.path.join(basedir, 'images_{}'.format(r))
12 | if not os.path.exists(imgdir):
13 | needtoload = True
14 | for r in resolutions:
15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
16 | if not os.path.exists(imgdir):
17 | needtoload = True
18 | if not needtoload:
19 | return
20 |
21 | from shutil import copy
22 | from subprocess import check_output
23 |
24 | imgdir = os.path.join(basedir, 'images')
25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
27 | imgdir_orig = imgdir
28 |
29 | wd = os.getcwd()
30 |
31 | for r in factors + resolutions:
32 | if isinstance(r, int):
33 | name = 'images_{}'.format(r)
34 | resizearg = '{}%'.format(100./r)
35 | else:
36 | name = 'images_{}x{}'.format(r[1], r[0])
37 | resizearg = '{}x{}'.format(r[1], r[0])
38 | imgdir = os.path.join(basedir, name)
39 | if os.path.exists(imgdir):
40 | continue
41 |
42 | print('Minifying', r, basedir)
43 |
44 | os.makedirs(imgdir)
45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
46 |
47 | ext = imgs[0].split('.')[-1]
48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
49 | print(args)
50 | os.chdir(imgdir)
51 | check_output(args, shell=True)
52 | os.chdir(wd)
53 |
54 | if ext != 'png':
55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
56 | print('Removed duplicates')
57 | print('Done')
58 |
59 |
60 |
61 |
62 | def _load_data(basedir, factor=8, width=None, height=None, load_imgs=True):
63 |
64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
66 | bds = poses_arr[:, -2:].transpose([1,0])
67 |
68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
70 | sh = imageio.imread(img0).shape
71 |
72 | sfx = ''
73 |
74 | if factor > 1:
75 | sfx = '_{}'.format(factor)
76 | _minify(basedir, factors=[factor])
77 | factor = factor
78 | elif height is not None:
79 | factor = sh[0] / float(height)
80 | width = int(sh[1] / factor)
81 | _minify(basedir, resolutions=[[height, width]])
82 | sfx = '_{}x{}'.format(width, height)
83 | elif width is not None:
84 | factor = sh[1] / float(width)
85 | height = int(sh[0] / factor)
86 | _minify(basedir, resolutions=[[height, width]])
87 | sfx = '_{}x{}'.format(width, height)
88 | else:
89 | factor = 1
90 |
91 | imgdir = os.path.join(basedir, 'images' + sfx)
92 | if not os.path.exists(imgdir):
93 | print( imgdir, 'does not exist, returning' )
94 | return
95 |
96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
97 |
98 |
99 | if poses.shape[-1] != len(imgfiles) * 5:
100 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
101 | return
102 |
103 | sh = imageio.imread(imgfiles[0]).shape
104 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
105 | poses[2, 4, :] = poses[2, 4, :] * 1./factor
106 |
107 | if not load_imgs:
108 | return poses, bds
109 |
110 | def imread(f):
111 | if f.endswith('png'):
112 | return imageio.imread(f, ignoregamma=True)
113 | else:
114 | return imageio.imread(f)
115 |
116 |     imgs = [imread(f)[...,:3]/255. for f in imgfiles]
117 | imgs = np.stack(imgs, -1)
118 |
119 | print('Loaded image data', imgs.shape, poses[:,-1,0])
120 |
121 | return poses, bds, imgs
122 |
123 |
124 |
125 |
126 |
127 |
128 | def normalize(x):
129 | return x / np.linalg.norm(x)
130 |
131 | def viewmatrix(z, up, pos):
132 | vec2 = normalize(z)
133 | vec1_avg = up
134 | vec0 = normalize(np.cross(vec1_avg, vec2))
135 | vec1 = normalize(np.cross(vec2, vec0))
136 | m = np.stack([vec0, vec1, vec2, pos], 1)
137 | return m
138 |
139 | def ptstocam(pts, c2w):
140 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
141 | return tt
142 |
143 | def poses_avg(poses):
144 |
145 | hwf = poses[0, :3, -1:]
146 |
147 | center = poses[:, :3, 3].mean(0)
148 | vec2 = normalize(poses[:, :3, 2].sum(0))
149 | up = poses[:, :3, 1].sum(0)
150 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
151 |
152 | return c2w
153 |
154 |
155 |
156 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
157 | render_poses = []
158 | rads = np.array(list(rads) + [1.])
159 | hwf = c2w[:,4:5]
160 |
161 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
162 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
163 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
164 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
165 | return render_poses
166 |
167 |
168 |
169 | def recenter_poses(poses):
170 |
171 | poses_ = poses+0
172 | bottom = np.reshape([0,0,0,1.], [1,4])
173 | c2w = poses_avg(poses)
174 | c2w = np.concatenate([c2w[:3,:4], bottom], -2)
175 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
176 | poses = np.concatenate([poses[:,:3,:4], bottom], -2)
177 |
178 | poses = np.linalg.inv(c2w) @ poses
179 | poses_[:,:3,:4] = poses[:,:3,:4]
180 | poses = poses_
181 | return poses
182 |
183 |
184 | #####################
185 |
186 |
187 | def spherify_poses(poses, bds):
188 |
189 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
190 |
191 | rays_d = poses[:,:3,2:3]
192 | rays_o = poses[:,:3,3:4]
193 |
194 | def min_line_dist(rays_o, rays_d):
195 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
196 | b_i = -A_i @ rays_o
197 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
198 | return pt_mindist
199 |
200 | pt_mindist = min_line_dist(rays_o, rays_d)
201 |
202 | center = pt_mindist
203 | up = (poses[:,:3,3] - center).mean(0)
204 |
205 | vec0 = normalize(up)
206 | vec1 = normalize(np.cross([.1,.2,.3], vec0))
207 | vec2 = normalize(np.cross(vec0, vec1))
208 | pos = center
209 | c2w = np.stack([vec1, vec2, vec0, pos], 1)
210 |
211 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
212 |
213 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
214 |
215 | sc = 1./rad
216 | poses_reset[:,:3,3] *= sc
217 | bds *= sc
218 | rad *= sc
219 |
220 | centroid = np.mean(poses_reset[:,:3,3], 0)
221 | zh = centroid[2]
222 | radcircle = np.sqrt(rad**2-zh**2)
223 | new_poses = []
224 |
225 | for th in np.linspace(0.,2.*np.pi, 120):
226 |
227 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
228 | up = np.array([0,0,-1.])
229 |
230 | vec2 = normalize(camorigin)
231 | vec0 = normalize(np.cross(vec2, up))
232 | vec1 = normalize(np.cross(vec2, vec0))
233 | pos = camorigin
234 | p = np.stack([vec0, vec1, vec2, pos], 1)
235 |
236 | new_poses.append(p)
237 |
238 | new_poses = np.stack(new_poses, 0)
239 |
240 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
241 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
242 |
243 | return poses_reset, new_poses, bds
244 |
245 |
246 | def load_llff_data(basedir, factor=None, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):
247 |
248 |
249 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x
250 | print('Loaded', basedir, bds.min(), bds.max())
251 |
252 | # Correct rotation matrix ordering and move variable dim to axis 0
253 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
254 | poses = np.moveaxis(poses, -1, 0).astype(np.float32)
255 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
256 | images = imgs
257 | bds = np.moveaxis(bds, -1, 0).astype(np.float32)
258 |
259 | # Rescale if bd_factor is provided
260 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
261 | poses[:,:3,3] *= sc
262 | bds *= sc
263 |
264 | if recenter:
265 | poses = recenter_poses(poses)
266 |
267 | if spherify:
268 | poses, render_poses, bds = spherify_poses(poses, bds)
269 |
270 | else:
271 |
272 | c2w = poses_avg(poses)
273 | print('recentered', c2w.shape)
274 | print(c2w[:3,:4])
275 |
276 | ## Get spiral
277 | # Get average pose
278 | up = normalize(poses[:, :3, 1].sum(0))
279 |
280 | # Find a reasonable "focus depth" for this dataset
281 | close_depth, inf_depth = bds.min()*.9, bds.max()*5.
282 | dt = .75
283 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
284 | focal = mean_dz
285 |
286 | # Get radii for spiral path
287 | shrink_factor = .8
288 | zdelta = close_depth * .2
289 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T
290 | rads = np.percentile(np.abs(tt), 90, 0)
291 | c2w_path = c2w
292 | N_views = 120
293 | N_rots = 2
294 | if path_zflat:
295 | # zloc = np.percentile(tt, 10, 0)[2]
296 | zloc = -close_depth * .1
297 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
298 | rads[2] = 0.
299 | N_rots = 1
300 | N_views/=2
301 |
302 | # Generate poses for spiral path
303 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
304 |
305 |
306 | render_poses = np.array(render_poses).astype(np.float32)
307 |
308 | c2w = poses_avg(poses)
309 | print('Data:')
310 | print(poses.shape, images.shape, bds.shape)
311 |
312 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
313 | i_test = np.argmin(dists)
314 | print('HOLDOUT view is', i_test)
315 |
316 | images = images.astype(np.float32)
317 | poses = poses.astype(np.float32)
318 |
319 | return images, poses, bds, render_poses, i_test
320 |
321 |
322 |
323 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/real-world/Data-preprocessing/data-preprocessing.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import cv2
3 | import torch
4 | import numpy as np
5 | from dv import AedatFile
6 |
7 | '''
8 | Preload the aedat4 file:
9 | extract the blurry frames from the recording, and
10 | convert the corresponding events into a .pt file for E2NeRF training.
11 | '''
12 |
13 | def trans(x):
14 | if x:
15 | return 1
16 | else:
17 | return 0
18 |
19 | def load_events(file, time, views_num, inter_num):
20 | '''
21 |     Select the events that fall within each blurry image's exposure interval (start to end exposure timestamp)
22 |     and accumulate them into bins, producing the preloaded .pt event data for E2NeRF training.
23 | '''
24 | f = AedatFile(file)
25 | frame_num = 0
26 |     event_map = np.zeros((views_num, inter_num - 1, 260, 346))  # (views, event bins per view, H, W); 260x346 is the DAVIS346 resolution
27 |
28 | for event in f["events"]:
29 | if time[frame_num][1] < event.timestamp:
30 | frame_num = frame_num + 1
31 | if frame_num >= views_num:
32 | break
33 |
34 | if time[frame_num][0] <= event.timestamp <= time[frame_num][1]:
35 | if event.polarity:
36 |                 event_map[frame_num][int((event.timestamp - time[frame_num][0]) / 25001)][event.y][event.x] += 1  # 25001 us per bin: map the event timestamp to one of the inter_num - 1 bins inside the exposure
37 | else:
38 | event_map[frame_num][int((event.timestamp - time[frame_num][0]) / 25001)][event.y][event.x] -= 1
39 | return event_map
40 |
41 |
42 | def load_frames(file, basedir, views_num, inter_num):
43 | '''
44 |     Save the blurry frames from the aedat4 file and
45 |     return the start and end exposure timestamps of each selected blurry image.
46 | '''
47 | global s
48 | f = AedatFile(file)
49 | sum = 0
50 | times = []
51 | for frame in f["frames"]:
52 | if sum % 10 == 0:
53 | cv2.imwrite(basedir + data_name + "/images/{0:03d}.jpg".format(s * inter_num), frame.image)
54 | times.append([frame.timestamp_start_of_exposure, frame.timestamp_end_of_exposure])
55 | s += 1
56 | sum += 1
57 | if s == views_num:
58 | break
59 | return times
60 |
61 |
62 | basedir = "../davis-aedat4/"
63 | data_name = "lego"
64 | height = 260
65 | width = 346
66 | global s
67 | s = 0
68 |
69 | if __name__ == '__main__':
70 | events = []
71 |
72 | inter_num = 5 # The number of the event bin (b in paper) + 1
73 | views_num = 30 # The number of the views of the scene
74 |
75 | if not os.path.exists(basedir + data_name + "/images"):
76 | os.mkdir(basedir + data_name + "/images")
77 |
78 | file = basedir + data_name + ".aedat4"
79 |
80 | times = load_frames(file, basedir, views_num, inter_num)
81 | events.append(load_events(file, times, views_num, inter_num))
82 | events = np.concatenate(events)
83 |     events = torch.tensor(events).view(-1, inter_num - 1, 89960)  # 89960 = 260 * 346 (H * W flattened)
84 | torch.save(events, basedir + data_name + "/events.pt")
85 |
86 |
87 |
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/synthetic/Blur-synthesizing/README.md:
--------------------------------------------------------------------------------
1 | # Generation of blurry images
2 |
3 | ## Introduction
4 |
5 | Following the inverse ISP and ISP pipeline, a sequence of original sharp images is synthesized into a single image with motion blur.
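
A rough outline of what `main.py` does (a sketch only; `args` is assumed to carry the same fields `main.py` sets up, i.e. the CCMs from `random_ccm()`, the gains from `random_gains()`, and the ISO/exposure settings, and the input frames are assumed to already have even height and width):

```python
import torch
from unprocess import unprocess
from process import process
from util import imread_rgb, image_8bit_to_real, image_real_to_8bit

def synthesize_blur(frame_paths, args):
    acc = 0.
    for path in frame_paths:                                          # scale_factor consecutive sharp frames
        img = image_8bit_to_real(torch.from_numpy(imread_rgb(path)))  # [0, 255] -> [0, 1]
        acc = acc + unprocess(image=img, args=args, debug=False)      # inverse ISP into the raw domain
    blur = process(image=acc.clamp(0., 1.), args=args, debug=False)   # one ISP pass over the accumulated raw image
    return image_real_to_8bit(blur)                                   # [0, 1] -> [0, 255] blurry image
```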
6 |
7 | ## Usage
8 |
9 | ### Parameters
10 |
11 | * input_dir: directory of the input sharp frames
12 | * output_name: path of the output blurry frame
13 | * scale_factor: number of consecutive sharp frames accumulated into one blurry frame; note that the total number of input frames must be divisible by scale_factor
14 | * input_exposure: base exposure time of the input frames, in microseconds
15 | * input_iso: assumed ISO for input data
16 | * output_iso: expected ISO for output
17 |
18 | ### Sample
19 |
20 | ```
21 | python main.py
22 | --input_dir ../blender-synthetic-images/chair/r_0/
23 | --output_name ../blender-synthetic-images/chair/r_0.png
24 | --scale_factor 18
25 | --input_exposure 10
26 | --input_iso 50
27 | --output_iso 50
28 | ```
29 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/synthetic/Blur-synthesizing/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from process import *
3 | from unprocess import *
4 | import cv2 as cv
5 | import os
6 | from tqdm import tqdm
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--input_dir', type=str, default='../blender-synthetic-images/chair/r_0/', help='input frames dir')
10 | parser.add_argument('--output_name', type=str, default='../blender-synthetic-images/chair/r_0.png', help='output frame name')
11 | parser.add_argument('--scale_factor', type=int, default=18, help='number of consecutive sharp frames accumulated into one blurry frame; the total number of input frames must be divisible by scale_factor')
12 | parser.add_argument('--input_exposure', type=float, default=10, help='base exposure time of input frames in microsecond')
13 | parser.add_argument('--input_iso', type=int, default=50, help='Assumed ISO for input data')
14 | parser.add_argument('--output_iso', type=int, default=50, help='Expected ISO for output')
15 | parser.add_argument('--debug', dest='debug', action='store_true', default=False, help='Debug mode')
16 | parser.add_argument('--show_effect', dest='show_effect', action='store_true', default=False,
17 | help='Show the initial image and final image')
18 | args = parser.parse_args()
19 |
20 | if __name__ == '__main__':
21 | # if file_cnt accum to scale_factor, output it and file_cnt reset to 0
22 | file_cnt = 0
23 | file_tot_cnt = 0
24 | # acc_img: img accum to acc_img
25 | acc_img = 0.
26 | # get file list
27 | # input_list = sorted(os.listdir(args.input_dir))
28 | input_list = os.listdir(args.input_dir)
29 | input_list.sort(key=lambda x: int(x[:-4]))
30 | # get image list
31 | img_list = []
32 | for input_file in input_list:
33 | if input_file.split('.')[1] == 'jpg' or input_file.split('.')[1] == 'png':
34 | img_list.append(input_file)
35 |
36 | # scan the input dir
37 | with tqdm(total=len(img_list), desc='process of file datasimul') as pbar:
38 | for filename in img_list:
39 | assert(filename.split('.')[1] == 'jpg' or filename.split('.')[1] == 'png')
40 | # Read the image
41 | img = imread_rgb(args.input_dir + filename)
42 |
43 | # make the image H W to even
44 | flagH, flagW = img.shape[0] % 2, img.shape[1] % 2
45 | img = cv.copyMakeBorder(img, flagH, 0, flagW, 0, cv.BORDER_WRAP)
46 |
47 | H, W, C = img.shape
48 |
49 | # Transform image from numpy.ndarray to torch.Tensor type
50 | img = torch.from_numpy(img)
51 |
52 | # Display the initial image
53 | if args.show_effect:
54 | single_8bit_image_display(img, "Initial")
55 |
56 | # Transform image from [0, 255] to [0, 1]
57 | img = image_8bit_to_real(img)
58 |
59 | # Fundamental Arguments Settings
60 | args.rgb2cam, args.cam2rgb = random_ccm()
61 | args.red_gain, args.blue_gain = random_gains()
62 | args.output_exposure = args.input_iso * args.input_exposure / args.output_iso
63 |
64 | # Invert ISP
65 | img = unprocess(image=img, args=args, debug=args.debug)
66 |
67 | # file_cnt accum
68 | file_cnt += 1
69 | file_tot_cnt += 1
70 | acc_img += img
71 |
72 | if(file_cnt == args.scale_factor):
73 |
74 | file_cnt = 0
75 | acc_img.clamp_(0., 1.)
76 | # ISP
77 | acc_img = process(image=acc_img, args=args, debug=args.debug)
78 | # Transform image from [0, 1] to [0, 255]
79 | acc_img = image_real_to_8bit(acc_img)
80 |
81 | # Display the final image
82 | if args.show_effect:
83 | single_8bit_image_display(acc_img, "Final")
84 |
85 | # Transform image from torch.Tensor type to numpy.ndarray
86 | acc_img = acc_img.numpy()
87 |
88 | # Write the image
89 | rgba_img = cv2.cvtColor(acc_img, cv2.COLOR_BGR2RGBA)
90 | cv2.imwrite(args.output_name, rgba_img)
91 | acc_img = 0.
92 |
93 | pbar.update(args.scale_factor)
94 |
95 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/synthetic/Blur-synthesizing/process.py:
--------------------------------------------------------------------------------
1 | from util import *
2 | import math
3 | import torch.distributions as tdist
4 |
5 |
6 | def tone_mapping(image: torch.Tensor, debug=False):
7 | image = torch.clamp(image, min=0.0, max=1.0)
8 | out = 3.0 * torch.pow(image, 2) - 2.0 * torch.pow(image, 3)
9 | if debug:
10 | single_real_image_display(out, "Image_After_Tone_Mapping")
11 | return out
12 |
13 |
14 | def gamma_compression(image: torch.Tensor, debug=False):
15 | out = torch.pow(torch.clamp(image, min=1e-8), 1 / 2.2)
16 | if debug:
17 | single_real_image_display(out, "Image_After_Gamma_Compression")
18 | return out
19 |
20 |
21 | def color_correction(image: torch.Tensor, ccm, debug=False):
22 | shape = image.size()
23 | image = torch.reshape(image, [-1, 3])
24 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
25 | out = torch.reshape(image, shape)
26 | if debug:
27 | single_real_image_display(out, "Image_After_Color_Correction")
28 | return out
29 |
30 |
31 | def demosaic(image: torch.Tensor, debug=False):
32 | """ Reference Code: https://github.com/Chinmayi5577/Demosaicing """
33 | shape = image.size()
34 | image = image.numpy()
35 | out_np = np.zeros((shape[0] * 2, shape[1] * 2, 3), dtype=np.float32)
36 |
37 | out_np[0::2, 0::2, 0] = image[:, :, 0]
38 | out_np[0::2, 1::2, 1] = image[:, :, 1]
39 | out_np[1::2, 0::2, 1] = image[:, :, 2]
40 | out_np[1::2, 1::2, 2] = image[:, :, 3]
41 |
42 | for i in range(1, shape[0] * 2 - 1):
43 | for j in range(1, shape[1] * 2 - 1):
44 | if i % 2 == 1 and j % 2 == 0: # case when green pixel is present with red on top and bottom
45 | out_np[i][j][0] = (out_np[i - 1][j][0] + out_np[i + 1][j][0]) / 2 # red
46 | out_np[i][j][2] = (out_np[i][j - 1][2] + out_np[i][j + 1][2]) / 2 # blue
47 | elif i % 2 == 0 and j % 2 == 0: # case when red pixel is present
48 | out_np[i][j][1] = (out_np[i - 1][j][1] + out_np[i][j + 1][1] +
49 | out_np[i + 1][j][1] + out_np[i][j - 1][1]) / 4
50 | out_np[i][j][2] = (out_np[i - 1][j - 1][2] + out_np[i + 1][j - 1][2] +
51 | out_np[i - 1][j + 1][2] + out_np[i + 1][j + 1][2]) / 4
52 | elif i % 2 == 0 and j % 2 == 1: # case when green pixel is present with blue on top and bottom
53 | out_np[i][j][0] = (out_np[i][j + 1][0] + out_np[i][j - 1][0]) / 2
54 | out_np[i][j][2] = (out_np[i + 1][j][2] + out_np[i - 1][j][2]) / 2
55 | else: # case when blue pixel is present
56 | out_np[i][j][0] = (out_np[i - 1][j - 1][0] + out_np[i - 1][j + 1][0] +
57 | out_np[i + 1][j + 1][0] + out_np[i + 1][j - 1][0]) / 4
58 | out_np[i][j][1] = (out_np[i - 1][j][1] + out_np[i][j + 1][1] +
59 | out_np[i + 1][j][1] + out_np[i][j - 1][1]) / 4
60 |
61 | last_row = shape[0] * 2 - 1
62 | for j in range(1, shape[1] * 2 - 1):
63 | if j % 2 == 0: # case when red pixel is present on first row or green pixel on last row
64 | out_np[0][j][1] = (out_np[0][j - 1][1] + out_np[0][j + 1][1] + out_np[1][j][1]) / 3
65 | out_np[0][j][2] = (out_np[1][j - 1][2] + out_np[1][j + 1][2]) / 2
66 | out_np[last_row][j][0] = out_np[last_row - 1][j][0]
67 | out_np[last_row][j][2] = (out_np[last_row][j - 1][2] + out_np[last_row][j + 1][2]) / 2
68 | else: # case when green pixel is present on first row or blue pixel on last row
69 | out_np[0][j][0] = (out_np[0][j - 1][0] + out_np[0][j + 1][0]) / 2
70 | out_np[0][j][2] = out_np[1][j][2]
71 | out_np[last_row][j][0] = (out_np[last_row - 1][j - 1][0] + out_np[last_row - 1][j + 1][0]) / 2
72 | out_np[last_row][j][1] = (out_np[last_row][j - 1][1] + out_np[last_row][j + 1][1] +
73 | out_np[last_row - 1][j][1]) / 3
74 |
75 | last_column = shape[1] * 2 - 1
76 | for i in range(1, shape[0] * 2 - 1):
77 | if i % 2 == 0: # case when red pixel is present on first column or green pixel on last column
78 | out_np[i][0][1] = (out_np[i - 1][0][1] + out_np[i + 1][0][1] + out_np[i][1][1]) / 3
79 | out_np[i][0][2] = (out_np[i - 1][1][2] + out_np[i + 1][1][2]) / 2
80 | out_np[i][last_column][0] = out_np[i][last_column - 1][0]
81 | out_np[i][last_column][2] = (out_np[i - 1][last_column][2] + out_np[i + 1][last_column][2]) / 2
82 | else: # case when green pixel is present on first column or blue pixel on last column
83 | out_np[i][0][0] = (out_np[i - 1][1][0] + out_np[i + 1][1][0]) / 2
84 | out_np[i][0][2] = out_np[i][1][2]
85 | out_np[i][last_column][0] = (out_np[i - 1][last_column - 1][0] + out_np[i + 1][last_column - 1][0]) / 2
86 | out_np[i][last_column][1] = (out_np[i - 1][last_column][1] + out_np[i + 1][last_column][1] +
87 | out_np[i][last_column - 1][1]) / 3
88 |
89 | out_np[0][0][1] = (out_np[0][1][1] + out_np[1][0][1]) / 2
90 | out_np[0][0][2] = out_np[1][1][2]
91 | out_np[0][last_column][0] = out_np[0][last_column - 1][0]
92 | out_np[0][last_column][2] = out_np[1][last_column][2]
93 | out_np[last_row][0][0] = out_np[last_row - 1][0][0]
94 | out_np[last_row][0][2] = out_np[last_row][1][2]
95 | out_np[last_row][last_column][0] = out_np[last_row - 1][last_column - 1][0]
96 | out_np[last_row][last_column][1] = (out_np[last_row - 1][last_column][1] +
97 | out_np[last_row][last_column - 1][1]) / 2
98 |
99 | out = torch.from_numpy(out_np)
100 | out = torch.clamp(out, min=0.0, max=1.0)
101 | if debug:
102 | single_real_image_display(out, "Image_After_Demosaic")
103 | return out
104 |
105 |
106 | def white_balance(image: torch.Tensor, red_gain=1.9, blue_gain=1.5, debug=False):
107 | f_red = image[:, :, 0] * red_gain
108 | f_blue = image[:, :, 3] * blue_gain
109 | out = torch.stack((f_red, image[:, :, 1], image[:, :, 2], f_blue), dim=-1)
110 | out = torch.clamp(out, min=0.0, max=1.0)
111 | if debug:
112 | single_raw_image_display(out, "Image_After_White_Balance")
113 | return out
114 |
115 |
116 | def digital_gain(image: torch.Tensor, iso=800, debug=False):
117 | out = image * iso
118 | if debug:
119 | tmp = torch.clamp(out, min=0.0, max=1.0)
120 | single_raw_image_display(tmp, "Image_After_Digital_Gain")
121 | return out
122 |
123 |
124 | def exposure_time_accumulate(image: torch.Tensor, exposure_time=10, debug=False):
125 | out = image * exposure_time
126 | if debug:
127 | tmp = torch.clamp(out, min=0.0, max=1.0)
128 | single_raw_image_display(tmp, "Image_After_Exposure_Time_Accumulation")
129 | return out
130 |
131 |
132 | def add_read_noise(image: torch.Tensor, iso=800, debug=False):
133 | image *= (1023 - 64)
134 | r_channel = image[:, :, 0]
135 | gr_channel = image[:, :, 1]
136 | gb_channel = image[:, :, 2]
137 | b_channel = image[:, :, 3]
138 |
139 | R0 = {'R': 0.300575, 'G': 0.347856, 'B': 0.356116}
140 | R1 = {'R': 1.293143, 'G': 0.403101, 'B': 0.403101}
141 |
142 | r_noise_sigma_square = (iso / 100.) ** 2 * R0['R'] + R1['R']
143 | gb_noise_sigma_square = (iso / 100.) ** 2 * R0['G'] + R1['G']
144 | gr_noise_sigma_square = (iso / 100.) ** 2 * R0['G'] + R1['G']
145 | b_noise_sigma_square = (iso / 100.) ** 2 * R0['B'] + R1['B']
146 |
147 | r_samples = tdist.Normal(loc=torch.zeros_like(r_channel), scale=math.sqrt(r_noise_sigma_square)).sample()
148 | gr_samples = tdist.Normal(loc=torch.zeros_like(gr_channel), scale=math.sqrt(gb_noise_sigma_square)).sample()
149 | gb_samples = tdist.Normal(loc=torch.zeros_like(gb_channel), scale=math.sqrt(gr_noise_sigma_square)).sample()
150 | b_samples = tdist.Normal(loc=torch.zeros_like(b_channel), scale=math.sqrt(b_noise_sigma_square)).sample()
151 |
152 | r_channel += r_samples
153 | gr_channel += gr_samples
154 | gb_channel += gb_samples
155 | b_channel += b_samples
156 |
157 | out = torch.stack((r_channel, gr_channel, gb_channel, b_channel), dim=-1)
158 | out /= (1023 - 64)
159 | out = torch.clamp(out, min=0.0, max=1.0)
160 |
161 | if debug:
162 | single_raw_image_display(out, "Image_After_Add_Read_Noise")
163 | return out
164 |
165 |
166 | def add_shot_noise(image: torch.Tensor, debug=False):
167 | image *= (1023 - 64)
168 | r_channel = image[:, :, 0]
169 | gr_channel = image[:, :, 1]
170 | gb_channel = image[:, :, 2]
171 | b_channel = image[:, :, 3]
172 |
173 | S = {'R': 0.343334/2, 'G': 0.348052/2, 'B': 0.346563/2}
174 |
175 | r_noise_sigma_square = (S['R'] / 100) * r_channel
176 | gb_noise_sigma_square = (S['G'] / 100) * gr_channel
177 | gr_noise_sigma_square = (S['G'] / 100) * gb_channel
178 | b_noise_sigma_square = (S['B'] / 100) * b_channel
179 |
180 | # print(r_noise_sigma_square)
181 | # print(torch.sqrt(r_noise_sigma_square).shape, r_channel.shape)
182 |
183 | r_samples = tdist.Normal(loc=torch.zeros_like(r_channel), scale=torch.sqrt(r_noise_sigma_square)).sample()
184 | gr_samples = tdist.Normal(loc=torch.zeros_like(gr_channel), scale=torch.sqrt(gb_noise_sigma_square)).sample()
185 | gb_samples = tdist.Normal(loc=torch.zeros_like(gb_channel), scale=torch.sqrt(gr_noise_sigma_square)).sample()
186 | b_samples = tdist.Normal(loc=torch.zeros_like(b_channel), scale=torch.sqrt(b_noise_sigma_square)).sample()
187 |
188 | r_channel += r_samples
189 | gr_channel += gr_samples
190 | gb_channel += gb_samples
191 | b_channel += b_samples
192 |
193 | out = torch.stack((r_channel, gr_channel, gb_channel, b_channel), dim=-1)
194 | out /= (1023 - 64)
195 | out = torch.clamp(out, min=0.0, max=1.0)
196 |
197 | if debug:
198 | single_raw_image_display(out, "Image_After_Add_Shot_Noise")
199 | return out
200 |
201 |
202 | def process(image: torch.Tensor, args, debug=False):
203 |     # The following are the ISP steps
204 | # ISP-step -7: the exposure time accumulation
205 | image = exposure_time_accumulate(image, exposure_time=args.output_exposure, debug=debug)
206 | image = add_shot_noise(image, debug)
207 | # ISP-step -6: the digital gain with read noise
208 | image = digital_gain(image, iso=args.output_iso, debug=debug)
209 | image = add_read_noise(image, iso=args.output_iso, debug=debug)
210 | # ISP-step -5: the white balance
211 | image = white_balance(image, red_gain=args.red_gain, blue_gain=args.blue_gain, debug=debug)
212 | # ISP-step -4: the demosaicing
213 | image = demosaic(image, debug)
214 | # ISP-step -3: the color correction
215 | image = color_correction(image, args.cam2rgb, debug)
216 | # ISP-step -2: the gamma compression
217 | image = gamma_compression(image, debug)
218 | # ISP-step -1: the tone mapping
219 | image = tone_mapping(image, debug)
220 | return image
221 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/synthetic/Blur-synthesizing/unprocess.py:
--------------------------------------------------------------------------------
1 | from util import *
2 |
3 |
4 | def inv_tone_mapping(image: torch.Tensor, debug=False):
5 | image = torch.clamp(image, min=0.0, max=1.0)
6 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
7 | if debug:
8 | single_real_image_display(out, "Image_After_Inv_Tone_Mapping")
9 | return out
10 |
11 |
12 | def inv_gamma_compression(image: torch.Tensor, debug=False):
13 | out = torch.pow(torch.clamp(image, min=1e-8), 2.2)
14 | if debug:
15 | single_real_image_display(out, "Image_After_Inv_Gamma_Compression")
16 | return out
17 |
18 |
19 | def inv_color_correction(image: torch.Tensor, ccm, debug=False):
20 | shape = image.size()
21 | image = torch.reshape(image, [-1, 3])
22 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
23 | out = torch.reshape(image, shape)
24 | if debug:
25 | single_real_image_display(out, "Image_After_Inv_Color_Correction")
26 | return out
27 |
28 |
29 | def mosaic(image: torch.Tensor, debug=False):
30 | shape = image.size()
31 | red = image[0::2, 0::2, 0]
32 | green_red = image[0::2, 1::2, 1]
33 | green_blue = image[1::2, 0::2, 1]
34 | blue = image[1::2, 1::2, 2]
35 | # import pdb
36 | # pdb.set_trace()
37 |     # note: if shape[0] or shape[1] is odd it cannot be split into the 2x2 Bayer pattern (main.py pads the input to even dimensions first)
38 |
39 | out = torch.stack((red, green_red, green_blue, blue), dim=-1)
40 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4))
41 | if debug:
42 | single_raw_image_display(out, "Image_After_Mosaic")
43 | return out
44 |
45 |
46 | def inv_white_balance(image: torch.Tensor, red_gain=1.9, blue_gain=1.5, threshold=0.9, debug=False):
47 | red = image[:, :, 0]
48 | alpha_red = (torch.max(red - threshold, torch.zeros_like(red)) / (1 - threshold)) ** 2
49 | f_red = torch.max(red / red_gain, (1 - alpha_red) * (red / red_gain) + alpha_red * red)
50 | blue = image[:, :, 3]
51 | alpha_blue = (torch.max(blue - threshold, torch.zeros_like(blue)) / (1 - threshold)) ** 2
52 | f_blue = torch.max(blue / blue_gain, (1 - alpha_blue) * (blue / blue_gain) + alpha_blue * blue)
53 | out = torch.stack((f_red, image[:, :, 1], image[:, :, 2], f_blue), dim=-1)
54 | if debug:
55 | single_raw_image_display(out, "Image_After_Inv_White_Balance")
56 | return out
57 |
58 |
59 | def inv_digital_gain(image: torch.Tensor, iso=800, scale_factor=5, debug=False):
60 | # out = image / iso
61 | out = image / (iso * scale_factor)
62 | # out /= scale_factor
63 | if debug:
64 | single_raw_image_display(out, "Image_After_Inv_Digital_Gain")
65 | return out
66 |
67 |
68 | def inv_exposure_time(image: torch.Tensor, exposure_time=10, debug=False):
69 | out = image / exposure_time
70 | if debug:
71 | single_raw_image_display(out, "Image_After_Inv_Exposure_Time")
72 | return out
73 |
74 |
75 | def unprocess(image: torch.Tensor, args, debug=False):
76 |     # The following are the inverse-ISP steps
77 | # Inv-step 1: the inversion of tone mapping
78 | image = inv_tone_mapping(image, debug)
79 | # Inv-step 2: the inversion of gamma compression
80 | image = inv_gamma_compression(image, debug)
81 | # Inv-step 3: the inversion of color correction
82 | image = inv_color_correction(image, args.rgb2cam, debug)
83 | # Inv-step 4: the mosaicing
84 | image = mosaic(image, debug)
85 | # Inv-step 5: the inversion of white balance
86 | image = inv_white_balance(image, red_gain=args.red_gain, blue_gain=args.blue_gain, debug=debug)
87 | #single_raw_image_display(image, "balance")
88 | # Inv-step 6: the inversion of digital gain
89 | image = inv_digital_gain(image, iso=args.input_iso, scale_factor=args.scale_factor, debug=debug)
90 |
91 | # Inv-step 7: split exposure time to 1ms
92 | image = inv_exposure_time(image, exposure_time=args.input_exposure, debug=debug)
93 |
94 | return image
95 |
--------------------------------------------------------------------------------
/orginal_data&preprocessing/synthetic/Blur-synthesizing/util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import numpy
5 |
6 |
7 | def imread_rgb(path):
8 | image = cv2.imread(path)
9 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
10 | return image
11 |
12 |
13 | def imwrite_rgb(image, filename):
14 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
15 | cv2.imwrite(filename, image)
16 |
17 |
18 | def image_8bit_to_real(image_8bit: torch.Tensor):
19 |     """ Convert image from [0, 255] to [0, 1] """
20 | return torch.div(image_8bit, 255)
21 |
22 |
23 | def image_real_to_8bit(image_real: torch.Tensor):
24 |     """ Convert image from [0, 1] to [0, 255] """
25 | image_real = torch.clamp(image_real, min=0.0, max=1.0)
26 | return torch.mul(image_real, 255).type(torch.uint8)
27 |
28 |
29 | def single_8bit_image_display(image: torch.Tensor, description=None):
30 | """ Display a single [0, 255] image """
31 | image = image.numpy()
32 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
33 |
34 | # cv2.imshow(description, image)
35 | # cv2.waitKey(0)
36 | cv2.imwrite('./resfig/ans.jpg', image)
37 |
38 |
39 | # cv2.destroyAllWindows()
40 |
41 |
42 | def single_real_image_display(image: torch.Tensor, description=None):
43 | """ Display a single [0, 1] image """
44 | # image_real_to_8bit(image)
45 | image = image.numpy()
46 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
47 | cv2.imshow(description, image)
48 | cv2.waitKey(0)
49 | # cv2.destroyAllWindows()
50 |
51 |
52 | def single_8bit_image_display_numpy(image: numpy.ndarray, description=None):
53 | cv2.imshow(description, image)
54 | cv2.waitKey(0)
55 | # cv2.destroyAllWindows()
56 |
57 |
58 | def single_raw_image_display(image: torch.Tensor, description=None):
59 | shape = image.size()
60 | image = image.numpy()
61 | raw_in_rgb = numpy.zeros([shape[0] * 2, shape[1] * 2, 3], dtype=np.float32)
62 | raw_in_rgb[0::2, 0::2, 0] = image[:, :, 0]
63 | raw_in_rgb[0::2, 1::2, 1] = image[:, :, 1]
64 | raw_in_rgb[1::2, 0::2, 1] = image[:, :, 2]
65 | raw_in_rgb[1::2, 1::2, 2] = image[:, :, 3]
66 | raw_in_rgb = cv2.cvtColor(raw_in_rgb, cv2.COLOR_RGB2BGR)
67 | cv2.imshow(description, raw_in_rgb)
68 | cv2.imwrite(description + ".jpg", raw_in_rgb * 255)
69 | cv2.waitKey(0)
70 |
71 |
72 | def random_ccm():
73 | """Generates random RGB -> Camera color correction matrices."""
74 | # Takes a random convex combination of XYZ -> Camera CCMs.
75 | xyz2cams = [[[1.0234, -0.2969, -0.2266],
76 | [-0.5625, 1.6328, -0.0469],
77 | [-0.0703, 0.2188, 0.6406]],
78 | [[0.4913, -0.0541, -0.0202],
79 | [-0.613, 1.3513, 0.2906],
80 | [-0.1564, 0.2151, 0.7183]],
81 | [[0.838, -0.263, -0.0639],
82 | [-0.2887, 1.0725, 0.2496],
83 | [-0.0627, 0.1427, 0.5438]],
84 | [[0.6596, -0.2079, -0.0562],
85 | [-0.4782, 1.3016, 0.1933],
86 | [-0.097, 0.1581, 0.5181]]]
87 | num_ccms = len(xyz2cams)
88 | xyz2cams = torch.FloatTensor(xyz2cams)
89 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8)
90 | weights_sum = torch.sum(weights, dim=0)
91 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum
92 |
93 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
94 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375],
95 | [0.2126729, 0.7151522, 0.0721750],
96 | [0.0193339, 0.1191920, 0.9503041]])
97 | rgb2cam = torch.mm(xyz2cam, rgb2xyz)
98 |
99 | # Normalizes each row.
100 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True)
101 |
102 | cam2rgb = torch.inverse(rgb2cam)
103 |
104 | rgb2cam = torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])  # NOTE: overrides the random CCM above with the identity, so the random sampling has no effect
105 | cam2rgb = torch.inverse(rgb2cam)
106 |
107 |
108 | return rgb2cam, cam2rgb
109 |
110 |
111 | def random_gains():
112 | # Red and blue gains represent white balance.
113 | # red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4)
114 | red_gain = torch.FloatTensor([2.15])
115 | # blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9)
116 | blue_gain = torch.FloatTensor([1.7])
117 | return red_gain, blue_gain
118 |
--------------------------------------------------------------------------------
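
The raw helpers above assume a packed RGGB layout: channel 0 is R, channels 1 and 2 are the two greens, channel 3 is B, exactly as single_raw_image_display scatters them onto even/odd rows and columns. A small sketch (not part of the repository) of going the other way, from the packed (H/2, W/2, 4) tensor back to a full-resolution Bayer plane:

    # Sketch only: expand the packed (H/2, W/2, 4) RGGB tensor used by these helpers
    # into a full-resolution (H, W) Bayer plane, matching single_raw_image_display's layout.
    import torch

    def packed_rggb_to_bayer(raw_packed: torch.Tensor) -> torch.Tensor:
        h2, w2, _ = raw_packed.shape
        bayer = torch.zeros(h2 * 2, w2 * 2, dtype=raw_packed.dtype)
        bayer[0::2, 0::2] = raw_packed[:, :, 0]  # R on even rows / even cols
        bayer[0::2, 1::2] = raw_packed[:, :, 1]  # G on even rows / odd cols
        bayer[1::2, 0::2] = raw_packed[:, :, 2]  # G on odd rows / even cols
        bayer[1::2, 1::2] = raw_packed[:, :, 3]  # B on odd rows / odd cols
        return bayer
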
/orginal_data&preprocessing/synthetic/Event-preprocessing/event-preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 | '''
6 | Event data loader
7 | Transform the original event data into a .pt file for E2NeRF training
8 | '''
9 | def load_event_data_noisy(data_path):
10 | event_map = np.zeros((100, 4, 800, 800)) # 100 input views, 4 event bins (b in the paper) per view, 800*800 resolution
11 |
12 | for i in range(0, 200, 2):
13 | file = os.path.join(data_path, "r_{}/v2e-dvs-events.txt".format(i))
14 | fp = open(file, "r")
15 |
16 | print("Processing data_path_{}".format(i))
17 | counter = 1
18 | for j in range(6):
19 | fp.readline()
20 |
21 | while True:
22 | line = fp.readline()
23 | if not line:
24 | break
25 |
26 | info = line.split()
27 | t = float(info[0])
28 | x = int(info[1])
29 | y = int(info[2])
30 | p = int(info[3])
31 |
32 | if t > counter * 0.04 + 0.01:
33 | counter += 1
34 | if counter >= 5:
35 | break
36 |
37 | if p == 0:
38 | event_map[int(i / 2)][counter - 1][y][x] -= 1
39 | else:
40 | event_map[int(i / 2)][counter - 1][y][x] += 1
41 | return event_map
42 |
43 | input_data_path = "../blender-v2e-synthetic-events/lego/"
44 | output_data_path = "../blender-v2e-synthetic-events/lego/events.pt"
45 |
46 | if __name__ == '__main__':
47 | events = load_event_data_noisy(input_data_path)
48 | events = torch.tensor(events).view(100, 4, 640000)
49 | torch.save(events, output_data_path)
--------------------------------------------------------------------------------
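
The script stores the accumulated event counts as a (100, 4, 800*800) tensor; the training code in run_nerf_exp.py later loads events.pt with torch.load and indexes pixels in the flattened spatial dimension. A minimal read-back sketch, assuming the 800x800 synthetic layout produced above:

    # Minimal read-back sketch for the events.pt produced above (800x800 synthetic layout assumed).
    import torch

    events = torch.load("events.pt")                     # (100 views, 4 bins, 800*800)
    view, bin_idx, y, x = 0, 2, 400, 400
    count = events[view, bin_idx, y * 800 + x]           # signed event count at pixel (x, y)
    event_frame = events[view, bin_idx].view(800, 800)   # back to a 2-D per-bin event frame
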
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.20.2
2 | imageio~=2.9.0
3 | torch~=1.9.0+cu111
4 | opencv-python~=4.5.1.48
5 | matplotlib~=3.4.2
6 | tqdm~=4.62.0
7 | dv~=1.0.10
--------------------------------------------------------------------------------
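
A quick, optional sanity check that the pinned CUDA build of torch and the other core packages are importable in the current environment:

    # Optional environment sanity check for the pins above.
    import torch, cv2, numpy, imageio, matplotlib, tqdm
    print("torch", torch.__version__, "cuda", torch.version.cuda, "available:", torch.cuda.is_available())
    print("numpy", numpy.__version__, "opencv", cv2.__version__)
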
/run_nerf_exp.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import imageio
4 | import json
5 | import random
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from tqdm import tqdm, trange
11 |
12 | import matplotlib.pyplot as plt
13 |
14 | import load_blender
15 | from run_nerf_helpers import *
16 |
17 | from load_llff import load_llff_data
18 | from load_blender import load_blender_data
19 | from load_event import *
20 |
21 | import pynvml
22 | from torch.utils.tensorboard import SummaryWriter
23 | import math
24 | from event_loss_helpers import *
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | np.random.seed(0)
29 | DEBUG = False
30 |
31 |
32 | def batchify(fn, chunk):
33 | """Constructs a version of 'fn' that applies to smaller batches.
34 | """
35 | if chunk is None:
36 | return fn
37 | def ret(inputs):
38 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
39 | return ret
40 |
41 |
42 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
43 | """Prepares inputs and applies network 'fn'.
44 | """
45 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
46 | embedded = embed_fn(inputs_flat)
47 |
48 | if viewdirs is not None:
49 | input_dirs = viewdirs[:,None].expand(inputs.shape)
50 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
51 | embedded_dirs = embeddirs_fn(input_dirs_flat)
52 | embedded = torch.cat([embedded, embedded_dirs], -1)
53 |
54 | outputs_flat = batchify(fn, netchunk)(embedded)
55 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
56 | return outputs
57 |
58 |
59 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
60 | """Render rays in smaller minibatches to avoid OOM.
61 | """
62 | all_ret = {}
63 | for i in range(0, rays_flat.shape[0], chunk):
64 | ret = render_rays(rays_flat[i:i+chunk], **kwargs)
65 | for k in ret:
66 | if k not in all_ret:
67 | all_ret[k] = []
68 | all_ret[k].append(ret[k])
69 |
70 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
71 | return all_ret
72 |
73 |
74 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
75 | near=0., far=1.,
76 | use_viewdirs=False, c2w_staticcam=None,
77 | **kwargs):
78 | """Render rays
79 | Args:
80 | H: int. Height of image in pixels.
81 | W: int. Width of image in pixels.
82 | K: array of shape [3, 3]. Camera intrinsics (focal lengths and principal point).
83 | chunk: int. Maximum number of rays to process simultaneously. Used to
84 | control maximum memory usage. Does not affect final results.
85 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for
86 | each example in batch.
87 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
88 | ndc: bool. If True, represent ray origin, direction in NDC coordinates.
89 | near: float or array of shape [batch_size]. Nearest distance for a ray.
90 | far: float or array of shape [batch_size]. Farthest distance for a ray.
91 | use_viewdirs: bool. If True, use viewing direction of a point in space in model.
92 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
93 | camera while using other c2w argument for viewing directions.
94 | Returns:
95 | rgb_map: [batch_size, 3]. Predicted RGB values for rays.
96 | disp_map: [batch_size]. Disparity map. Inverse of depth.
97 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
98 | extras: dict with everything returned by render_rays().
99 | """
100 |
101 | if c2w is not None:
102 | # special case to render full image
103 | rays_o, rays_d = get_rays(H, W, K, c2w)
104 | else:
105 | # use provided ray batch
106 | rays_o, rays_d = rays
107 |
108 | if use_viewdirs:
109 | # provide ray directions as input
110 | viewdirs = rays_d
111 | if c2w_staticcam is not None:
112 | # special case to visualize effect of viewdirs
113 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
114 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
115 | viewdirs = torch.reshape(viewdirs, [-1,3]).float()
116 |
117 | sh = rays_d.shape # [..., 3]
118 | if ndc:
119 | # for forward facing scenes
120 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
121 |
122 | # Create ray batch
123 | rays_o = torch.reshape(rays_o, [-1,3]).float()
124 | rays_d = torch.reshape(rays_d, [-1,3]).float()
125 |
126 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
127 | rays = torch.cat([rays_o, rays_d, near, far], -1)
128 | if use_viewdirs:
129 | rays = torch.cat([rays, viewdirs], -1)
130 |
131 | # Render and reshape
132 | all_ret = batchify_rays(rays, chunk, **kwargs)
133 | for k in all_ret:
134 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
135 | all_ret[k] = torch.reshape(all_ret[k], k_sh)
136 |
137 | k_extract = ['rgb_map', 'disp_map', 'acc_map']
138 | ret_list = [all_ret[k] for k in k_extract]
139 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
140 | return ret_list + [ret_dict]
141 |
142 |
143 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
144 |
145 | H, W, focal = hwf
146 |
147 | if render_factor!=0:
148 | # Render downsampled for speed
149 | H = H//render_factor
150 | W = W//render_factor
151 | focal = focal/render_factor
152 |
153 | rgbs = []
154 | disps = []
155 | rgbs_extras = []
156 |
157 | t = time.time()
158 | for i, c2w in enumerate(tqdm(render_poses)):
159 | print(i, time.time() - t)
160 | t = time.time()
161 | #rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
162 | rgb, disp, acc, extra = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs)
163 | rgbs.append(rgb.cpu().numpy())
164 | disps.append(disp.cpu().numpy())
165 |
166 | if i==0:
167 | print(rgb.shape, disp.shape)
168 |
169 |
170 | if gt_imgs is not None and render_factor==0:
171 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
172 | print(p)
173 |
174 | if savedir is not None:
175 | rgb8 = to8b(rgbs[-1])
176 | filename = os.path.join(savedir, '{:03d}.png'.format(i))
177 | imageio.imwrite(filename, rgb8)
178 |
179 | rgbs = np.stack(rgbs, 0)
180 | disps = np.stack(disps, 0)
181 |
182 | return rgbs, disps
183 |
184 |
185 | def create_nerf(args):
186 | """Instantiate NeRF's MLP model.
187 | """
188 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
189 |
190 | input_ch_views = 0
191 | embeddirs_fn = None
192 | if args.use_viewdirs:
193 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
194 | output_ch = 5 if args.N_importance > 0 else 4
195 | skips = [4]
196 | model = NeRF(D=args.netdepth, W=args.netwidth,
197 | input_ch=input_ch, output_ch=output_ch, skips=skips,
198 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
199 | grad_vars = list(model.parameters())
200 |
201 | model_fine = None
202 | if args.N_importance > 0:
203 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
204 | input_ch=input_ch, output_ch=output_ch, skips=skips,
205 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
206 | grad_vars += list(model_fine.parameters())
207 |
208 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
209 | embed_fn=embed_fn,
210 | embeddirs_fn=embeddirs_fn,
211 | netchunk=args.netchunk)
212 |
213 | # Create optimizer
214 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
215 |
216 | start = 0
217 | basedir = args.basedir
218 | expname = args.expname
219 |
220 | ##########################
221 |
222 | # Load checkpoints
223 | if args.ft_path is not None and args.ft_path!='None':
224 | ckpts = [args.ft_path]
225 | else:
226 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
227 |
228 | print('Found ckpts', ckpts)
229 | if len(ckpts) > 0 and not args.no_reload:
230 | ckpt_path = ckpts[-1]
231 | print('Reloading from', ckpt_path)
232 | ckpt = torch.load(ckpt_path)
233 |
234 | start = ckpt['global_step']
235 | optimizer.load_state_dict(ckpt['optimizer_state_dict'])
236 |
237 | # Load model
238 | model.load_state_dict(ckpt['network_fn_state_dict'])
239 | if model_fine is not None:
240 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
241 |
242 | ##########################
243 |
244 | render_kwargs_train = {
245 | 'network_query_fn' : network_query_fn,
246 | 'perturb' : args.perturb,
247 | 'N_importance' : args.N_importance,
248 | 'network_fine' : model_fine,
249 | 'N_samples' : args.N_samples,
250 | 'network_fn' : model,
251 | 'use_viewdirs' : args.use_viewdirs,
252 | 'white_bkgd' : args.white_bkgd,
253 | 'raw_noise_std' : args.raw_noise_std,
254 | }
255 |
256 | # NDC only good for LLFF-style forward facing data
257 | if args.dataset_type != 'llff' or args.no_ndc:
258 | print('Not ndc!')
259 | render_kwargs_train['ndc'] = False
260 | render_kwargs_train['lindisp'] = args.lindisp
261 |
262 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
263 | render_kwargs_test['perturb'] = False
264 | render_kwargs_test['raw_noise_std'] = 0.
265 |
266 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
267 |
268 |
269 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
270 | """Transforms model's predictions to semantically meaningful values.
271 | Args:
272 | raw: [num_rays, num_samples along ray, 4]. Prediction from model.
273 | z_vals: [num_rays, num_samples along ray]. Integration time.
274 | rays_d: [num_rays, 3]. Direction of each ray.
275 | Returns:
276 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
277 | disp_map: [num_rays]. Disparity map. Inverse of depth map.
278 | acc_map: [num_rays]. Sum of weights along each ray.
279 | weights: [num_rays, num_samples]. Weights assigned to each sampled color.
280 | depth_map: [num_rays]. Estimated distance to object.
281 | """
282 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
283 |
284 | dists = z_vals[...,1:] - z_vals[...,:-1]
285 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
286 |
287 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
288 |
289 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
290 | noise = 0.
291 | if raw_noise_std > 0.:
292 | noise = torch.randn(raw[...,3].shape) * raw_noise_std
293 |
294 | # Overwrite randomly sampled data if pytest
295 | if pytest:
296 | np.random.seed(0)
297 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
298 | noise = torch.Tensor(noise)
299 |
300 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
301 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
302 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
303 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
304 |
305 | depth_map = torch.sum(weights * z_vals, -1)
306 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
307 | acc_map = torch.sum(weights, -1)
308 |
309 | if white_bkgd:
310 | rgb_map = rgb_map + (1.-acc_map[...,None])
311 |
312 | return rgb_map, disp_map, acc_map, weights, depth_map
313 |
314 |
315 | def render_rays(ray_batch,
316 | network_fn,
317 | network_query_fn,
318 | N_samples,
319 | retraw=False,
320 | lindisp=False,
321 | perturb=0.,
322 | N_importance=0,
323 | network_fine=None,
324 | white_bkgd=False,
325 | raw_noise_std=0.,
326 | verbose=False,
327 | pytest=False):
328 | """Volumetric rendering.
329 | Args:
330 | ray_batch: array of shape [batch_size, ...]. All information necessary
331 | for sampling along a ray, including: ray origin, ray direction, min
332 | dist, max dist, and unit-magnitude viewing direction.
333 | network_fn: function. Model for predicting RGB and density at each point
334 | in space.
335 | network_query_fn: function used for passing queries to network_fn.
336 | N_samples: int. Number of different times to sample along each ray.
337 | retraw: bool. If True, include model's raw, unprocessed predictions.
338 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
339 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
340 | random points in time.
341 | N_importance: int. Number of additional times to sample along each ray.
342 | These samples are only passed to network_fine.
343 | network_fine: "fine" network with same spec as network_fn.
344 | white_bkgd: bool. If True, assume a white background.
345 | raw_noise_std: ...
346 | verbose: bool. If True, print more debugging info.
347 | Returns:
348 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
349 | disp_map: [num_rays]. Disparity map. 1 / depth.
350 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
351 | raw: [num_rays, num_samples, 4]. Raw predictions from model.
352 | rgb0: See rgb_map. Output for coarse model.
353 | disp0: See disp_map. Output for coarse model.
354 | acc0: See acc_map. Output for coarse model.
355 | z_std: [num_rays]. Standard deviation of distances along ray for each
356 | sample.
357 | """
358 | N_rays = ray_batch.shape[0]
359 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
360 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
361 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
362 | near, far = bounds[...,0], bounds[...,1] # [-1,1]
363 |
364 | t_vals = torch.linspace(0., 1., steps=N_samples)
365 | if not lindisp:
366 | z_vals = near * (1.-t_vals) + far * (t_vals)
367 | else:
368 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
369 |
370 | z_vals = z_vals.expand([N_rays, N_samples])
371 |
372 | if perturb > 0.:
373 | # get intervals between samples
374 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
375 | upper = torch.cat([mids, z_vals[...,-1:]], -1)
376 | lower = torch.cat([z_vals[...,:1], mids], -1)
377 | # stratified samples in those intervals
378 | t_rand = torch.rand(z_vals.shape)
379 |
380 | # Pytest, overwrite u with numpy's fixed random numbers
381 | if pytest:
382 | np.random.seed(0)
383 | t_rand = np.random.rand(*list(z_vals.shape))
384 | t_rand = torch.Tensor(t_rand)
385 |
386 | z_vals = lower + (upper - lower) * t_rand
387 |
388 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
389 |
390 |
391 | # raw = run_network(pts)
392 | raw = network_query_fn(pts, viewdirs, network_fn)
393 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
394 |
395 | if N_importance > 0:
396 |
397 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
398 |
399 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
400 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
401 | z_samples = z_samples.detach()
402 |
403 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
404 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
405 |
406 | run_fn = network_fn if network_fine is None else network_fine
407 | # raw = run_network(pts, fn=run_fn)
408 | raw = network_query_fn(pts, viewdirs, run_fn)
409 |
410 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
411 |
412 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
413 | if retraw:
414 | ret['raw'] = raw
415 | if N_importance > 0:
416 | ret['rgb0'] = rgb_map_0
417 | ret['disp0'] = disp_map_0
418 | ret['acc0'] = acc_map_0
419 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
420 |
421 | for k in ret:
422 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
423 | print(f"! [Numerical Error] {k} contains nan or inf.")
424 |
425 | return ret
426 |
427 |
428 | def config_parser():
429 |
430 | import configargparse
431 | parser = configargparse.ArgumentParser()
432 | parser.add_argument('--config', is_config_file=True,
433 | help='config file path')
434 | parser.add_argument("--expname", type=str,
435 | help='experiment name')
436 | parser.add_argument("--basedir", type=str, default='./logs/',
437 | help='where to store ckpts and logs')
438 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
439 | help='input data directory')
440 |
441 | # training options
442 | parser.add_argument("--netdepth", type=int, default=8,
443 | help='layers in network')
444 | parser.add_argument("--netwidth", type=int, default=256,
445 | help='channels per layer')
446 | parser.add_argument("--netdepth_fine", type=int, default=8,
447 | help='layers in fine network')
448 | parser.add_argument("--netwidth_fine", type=int, default=256,
449 | help='channels per layer in fine network')
450 | parser.add_argument("--N_rand", type=int, default=32*32*4,
451 | help='batch size (number of random rays per gradient step)')
452 | parser.add_argument("--lrate", type=float, default=5e-4,
453 | help='learning rate')
454 | parser.add_argument("--lrate_decay", type=int, default=250,
455 | help='exponential learning rate decay (in 1000 steps)')
456 | parser.add_argument("--chunk", type=int, default=1024*32,
457 | help='number of rays processed in parallel, decrease if running out of memory')
458 | parser.add_argument("--netchunk", type=int, default=1024*64,
459 | help='number of pts sent through network in parallel, decrease if running out of memory')
460 | parser.add_argument("--no_batching", action='store_true',
461 | help='only take random rays from 1 image at a time')
462 | parser.add_argument("--no_reload", action='store_true',
463 | help='do not reload weights from saved ckpt')
464 | parser.add_argument("--ft_path", type=str, default=None,
465 | help='specific weights npy file to reload for coarse network')
466 |
467 | # rendering options
468 | parser.add_argument("--N_samples", type=int, default=64,
469 | help='number of coarse samples per ray')
470 | parser.add_argument("--N_importance", type=int, default=0,
471 | help='number of additional fine samples per ray')
472 | parser.add_argument("--perturb", type=float, default=1.,
473 | help='set to 0. for no jitter, 1. for jitter')
474 | parser.add_argument("--use_viewdirs", action='store_true',
475 | help='use full 5D input instead of 3D')
476 | parser.add_argument("--i_embed", type=int, default=0,
477 | help='set 0 for default positional encoding, -1 for none')
478 | parser.add_argument("--multires", type=int, default=10,
479 | help='log2 of max freq for positional encoding (3D location)')
480 | parser.add_argument("--multires_views", type=int, default=4,
481 | help='log2 of max freq for positional encoding (2D direction)')
482 | parser.add_argument("--raw_noise_std", type=float, default=0.,
483 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
484 |
485 | parser.add_argument("--render_only", action='store_true',
486 | help='do not optimize, reload weights and render out render_poses path')
487 | parser.add_argument("--render_factor", type=int, default=0,
488 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
489 |
490 | # training options
491 | parser.add_argument("--precrop_iters", type=int, default=0,
492 | help='number of steps to train on central crops')
493 | parser.add_argument("--precrop_frac", type=float,
494 | default=.5, help='fraction of img taken for central crops')
495 |
496 | # dataset options
497 | parser.add_argument("--dataset_type", type=str, default='blender',
498 | help='options: blender / ellff')
499 | parser.add_argument("--testskip", type=int, default=8,
500 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
501 |
502 | # blender flags
503 | parser.add_argument("--white_bkgd", action='store_true',
504 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
505 |
506 | # llff flags
507 | parser.add_argument("--factor", type=int, default=8,
508 | help='downsample factor for LLFF images')
509 | parser.add_argument("--no_ndc", action='store_true',
510 | help='do not use normalized device coordinates (set for non-forward facing scenes)')
511 | parser.add_argument("--lindisp", action='store_true',
512 | help='sampling linearly in disparity rather than depth')
513 | parser.add_argument("--spherify", action='store_true',
514 | help='set for spherical 360 scenes')
515 | parser.add_argument("--llffhold", type=int, default=8,
516 | help='will take every 1/N images as LLFF test set, paper uses 8')
517 |
518 | # event options
519 | parser.add_argument("--use_event", type=bool, default=True,
520 | help='use event to help the training')
521 |
522 | # logging/saving options
523 | parser.add_argument("--i_print", type=int, default=10,
524 | help='frequency of console printout and metric logging')
525 | parser.add_argument("--i_img", type=int, default=500,
526 | help='frequency of tensorboard image logging')
527 | parser.add_argument("--i_weights", type=int, default=10000,
528 | help='frequency of weight ckpt saving')
529 | parser.add_argument("--i_testset", type=int, default=200000,
530 | help='frequency of testset saving')
531 | parser.add_argument("--i_video", type=int, default=200000,
532 | help='frequency of render_poses video saving')
533 |
534 | return parser
535 |
536 | def train():
537 | print("----------------starting----------------")
538 | pynvml.nvmlInit()
539 | parser = config_parser()
540 | args = parser.parse_args()
541 |
542 | print("---------------data_loading----------------")
543 | # Load data
544 | K = None
545 | if args.dataset_type == 'blender':
546 | resolution_h, resolution_w = 800, 800
547 | view_num = 100
548 |
549 | images, poses, test_images, test_poses, render_poses, hwf = load_blender_data(args.datadir)
550 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
551 |
552 | near = 2.
553 | far = 6.
554 |
555 | if args.white_bkgd:
556 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
557 | test_images = test_images[...,:3]*test_images[...,-1:] + (1.-test_images[...,-1:])
558 | else:
559 | images = images[...,:3]
560 | test_images = test_images[...,:3]
561 |
562 |
563 | elif args.dataset_type == 'ellff':
564 | resolution_h, resolution_w = 260, 346
565 | view_num = 30
566 |
567 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
568 | recenter=True, bd_factor=.75,
569 | spherify=args.spherify)
570 | hwf = poses[0, :3, -1]
571 | poses = poses[:, :3, :4]
572 |
573 | new_poses = []
574 | test_poses = []
575 | for i in range(30):
576 | pose = []
577 | test_poses.append(poses[i * 5])
578 | for j in range(5):
579 | pose.append(poses[i * 5 + j])
580 | new_poses.append(np.stack(pose, axis=0))
581 | poses = np.stack(new_poses, axis=0)
582 | test_poses = np.stack(test_poses, axis=0)
583 | test_images = None
584 |
585 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
586 |
587 | print('DEFINING BOUNDS')
588 | if args.no_ndc:
589 | near = np.ndarray.min(bds) * .9
590 | far = np.ndarray.max(bds) * 1.
591 |
592 | else:
593 | near = 0.
594 | far = 1.
595 | print('NEAR FAR', near, far)
596 |
597 | else:
598 | print('Unknown dataset type', args.dataset_type, 'exiting')
599 | return
600 |
601 |
602 | # Cast intrinsics to right types
603 | H, W, focal = hwf
604 | H, W = int(H), int(W)
605 | hwf = [H, W, focal]
606 |
607 | if K is None:
608 | K = np.array([
609 | [focal, 0, 0.5*W],
610 | [0, focal, 0.5*H],
611 | [0, 0, 1]
612 | ])
613 |
614 | # load event
615 | use_event = args.use_event
616 | if use_event:
617 | print("---------------event_data_loading----------------")
618 | event_map = torch.load(os.path.join(args.datadir, "events.pt"))
619 | print("event size: " + str(event_map.size()))
620 | event_map = event_map.to(device)
621 | combination = []
622 | for m in range(4):
623 | for n in range(m + 1, 5):
624 | combination.append([m, n])
625 |
626 | print("---------------creat_nerf----------------")
627 | # Create log dir and copy the config file
628 | basedir = args.basedir
629 | expname = args.expname
630 | os.makedirs(os.path.join(basedir, expname), exist_ok=True)
631 | f = os.path.join(basedir, expname, 'args.txt')
632 | with open(f, 'w') as file:
633 | for arg in sorted(vars(args)):
634 | attr = getattr(args, arg)
635 | file.write('{} = {}\n'.format(arg, attr))
636 | if args.config is not None:
637 | f = os.path.join(basedir, expname, 'config.txt')
638 | with open(f, 'w') as file:
639 | file.write(open(args.config, 'r').read())
640 |
641 | # Create nerf model
642 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
643 | global_step = start
644 |
645 | bds_dict = {
646 | 'near' : near,
647 | 'far' : far,
648 | }
649 | render_kwargs_train.update(bds_dict)
650 | render_kwargs_test.update(bds_dict)
651 |
652 | # Move testing data to GPU
653 | test_poses = torch.Tensor(test_poses).to(device)
654 | render_poses = torch.Tensor(render_poses).to(device)
655 |
656 | # Short circuit if only rendering out from trained model
657 | '''
658 | if args.render_only:
659 | print('RENDER ONLY')
660 | with torch.no_grad():
661 | if args.render_test:
662 | # render_test switches to test poses
663 | images = test_images
664 | else:
665 | # Default is smoother render_poses path
666 | images = None
667 |
668 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
669 | os.makedirs(testsavedir, exist_ok=True)
670 | print('test poses shape', render_poses.shape)
671 |
672 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
673 | print('Done rendering', testsavedir)
674 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
675 |
676 | return
677 | '''
678 |
679 | # Prepare raybatch tensor if batching random rays
680 | N_rand = args.N_rand
681 | use_batching = not args.no_batching
682 |
683 | '''
684 | if use_batching:
685 | print("To be ccomplete")
686 | # For random ray batching---------------------------------------------------------------------------------------
687 | print('get rays')
688 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
689 | print('done, concats')
690 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
691 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
692 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
693 | rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
694 | rays_rgb = rays_rgb.astype(np.float32)
695 | print('shuffle rays')
696 | np.random.shuffle(rays_rgb)
697 |
698 | print('done')
699 | i_batch = 0
700 |
701 | # Move training data to GPU
702 | if use_batching:
703 | images = torch.Tensor(images).to(device)
704 | if use_batching:
705 | rays_rgb = torch.Tensor(rays_rgb).to(device)
706 | '''
707 |
708 |
709 | print("---------------start_training----------------")
710 | N_iters = 200000 + 1
711 | print('Begin')
712 |
713 |
714 | # Summary writers
715 | writer = SummaryWriter(os.path.join(basedir, expname, 'logs'))
716 |
717 | start = start + 1
718 | for i in trange(start, N_iters):
719 | time0 = time.time()
720 |
721 | # Sample random ray batch---------------------------------------------------------------------------------------
722 |
723 | if use_batching:
724 | print("To be ccomplete")
725 | '''
726 | # Random over all images
727 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
728 | batch = torch.transpose(batch, 0, 1)
729 | batch_rays, target_s = batch[:2], batch[2]
730 |
731 | i_batch += N_rand
732 | if i_batch >= rays_rgb.shape[0]:
733 | print("Shuffle data after an epoch!")
734 | rand_idx = torch.randperm(rays_rgb.shape[0])
735 | rays_rgb = rays_rgb[rand_idx]
736 | i_batch = 0
737 | '''
738 | else:
739 | # Random from one image
740 | img_i = np.random.choice(view_num)
741 | target = images[img_i]
742 | target = torch.Tensor(target).to(device)
743 | pose = poses[img_i, :, :3,:4]
744 |
745 | if N_rand is not None:
746 | rays_os = []
747 | rays_ds = []
748 | for j in range(5):
749 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose[j]))
750 | rays_os.append(rays_o)
751 | rays_ds.append(rays_d)
752 | if i < args.precrop_iters:
753 | dH = int(H // 2 * args.precrop_frac)
754 | dW = int(W // 2 * args.precrop_frac)
755 | coords = torch.stack(
756 | torch.meshgrid(
757 | torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
758 | torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)
759 | ), -1)
760 | if i == start:
761 | print(f"[Config] Center cropping of size {2 * dH} x {2 * dW} is enabled until iter {args.precrop_iters}")
762 | else:
763 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)),-1) # (H, W, 2)
764 |
765 | coords = torch.reshape(coords, [-1, 2]) # (H * W, 2)
766 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
767 | select_coords = coords[select_inds].long() # (N_rand, 2)
768 |
769 | all_extra = []
770 | all_rgb = []
771 | for j in range(5):
772 | rays_o = rays_os[j][select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
773 | rays_d = rays_ds[j][select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
774 | batch_rays = torch.stack([rays_o, rays_d], 0)
775 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
776 | verbose=i < 10, retraw=True,
777 | **render_kwargs_train)
778 | all_rgb.append(rgb)
779 | all_extra.append(extras['rgb0'])
780 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
781 |
782 | ##### Core optimization loop #####
783 | #change No.1 average-nerf
784 | all_rgb = torch.stack(all_rgb, dim=0)
785 | rgb = torch.mean(all_rgb, dim = 0)
786 |
787 | #change No.2 event-nerf
788 | if use_event:  # the loss and logging below assume this stays True; event_loss is undefined otherwise
789 | event_data = event_map[img_i]
790 | event_loss = event_loss_call(all_rgb, event_data, select_coords, combination, "rgb", resolution_h, resolution_w) * 0.001
791 |
792 | img_loss = img2mse(rgb, target_s)
793 | trans = extras['raw'][...,-1]
794 | loss = img_loss
795 | psnr = mse2psnr(img_loss)
796 |
797 | if 'rgb0' in extras:
798 | all_extra = torch.stack(all_extra, dim=0)
799 | extras = torch.mean(all_extra, dim=0)
800 | img_loss0 = img2mse(extras, target_s)
801 | loss = loss + img_loss0
802 | psnr0 = mse2psnr(img_loss0)
803 | loss = loss + event_loss
804 |
805 | optimizer.zero_grad()
806 | loss.backward()
807 | optimizer.step()
808 |
809 | # NOTE: IMPORTANT!
810 | ### update learning rate ###
811 | decay_rate = 0.1
812 | decay_steps = args.lrate_decay * 1000
813 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
814 | for param_group in optimizer.param_groups:
815 | param_group['lr'] = new_lrate
816 | ################################
817 |
818 | dt = time.time()-time0
819 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
820 | ##### end #####
821 |
822 | # Rest is logging
823 | if i%args.i_weights==0:
824 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
825 | torch.save({
826 | 'global_step': global_step,
827 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
828 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
829 | 'optimizer_state_dict': optimizer.state_dict(),
830 | }, path)
831 | print('Saved checkpoints at', path)
832 |
833 | if i%args.i_video==0 and i > 0:
834 | # Turn on testing mode
835 | with torch.no_grad():
836 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
837 | print('Done, saving', rgbs.shape, disps.shape)
838 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
839 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
840 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
841 |
842 | # if args.use_viewdirs:
843 | # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
844 | # with torch.no_grad():
845 | # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
846 | # render_kwargs_test['c2w_staticcam'] = None
847 | # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)
848 |
849 | if i%args.i_testset==0 and i > 0:
850 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
851 | os.makedirs(testsavedir, exist_ok=True)
852 | with torch.no_grad():
853 | render_path(torch.Tensor(test_poses).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=test_images, savedir=testsavedir)
854 | print('Saved test set')
855 |
856 | if i%args.i_print==0:
857 | writer.add_scalar("loss", loss, i)
858 | writer.add_scalar("event_loss", event_loss, i)
859 | writer.add_scalar("img_loss", img_loss, i)
860 | writer.add_scalar("img_loss0", img_loss0, i)
861 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
862 | tqdm.write(f"[TRAIN] Iter: {i} Event_Loss: {event_loss.item()} Img_Loss: {img_loss.item()} Img_Loss0: {img_loss0.item()}")
863 |
864 | global_step += 1
865 |
866 | if __name__=='__main__':
867 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
868 | train()
869 |
--------------------------------------------------------------------------------
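
In train(), each training view is rendered at the five bin-boundary poses; their average supervises the blurry target image, while event_loss_call (defined in event_loss_helpers.py, which is not reproduced in this dump) compares pairs of those renders against the accumulated events for every pair listed in combination. The sketch below only illustrates that event generation model, with an assumed contrast threshold C (0.3 is the value EDI.py uses); it is not the repository's actual loss implementation.

    # Illustration only -- the real loss is event_loss_call in event_loss_helpers.py (not shown here).
    # Idea: the log-intensity change between two rendered bins should match the accumulated
    # event polarity between them, scaled by an assumed contrast threshold C.
    import torch

    def event_loss_sketch(rgb_a, rgb_b, event_count, C=0.3, eps=1e-5):
        # rgb_a, rgb_b: (N_rand, 3) renders at two bin boundaries; event_count: (N_rand,) signed counts
        gray_a = rgb_a.mean(dim=-1)
        gray_b = rgb_b.mean(dim=-1)
        pred_log_diff = torch.log(gray_b + eps) - torch.log(gray_a + eps)
        return torch.mean((pred_log_diff - C * event_count) ** 2)
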
/run_nerf_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | # Misc
9 | img2mse = lambda x, y : torch.mean((x - y) ** 2)
10 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
11 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
12 |
13 | # Positional encoding (section 5.1)
14 | class Embedder:
15 | def __init__(self, **kwargs):
16 | self.kwargs = kwargs
17 | self.create_embedding_fn()
18 |
19 | def create_embedding_fn(self):
20 | embed_fns = []
21 | d = self.kwargs['input_dims']
22 | out_dim = 0
23 | if self.kwargs['include_input']:
24 | embed_fns.append(lambda x : x)
25 | out_dim += d
26 |
27 | max_freq = self.kwargs['max_freq_log2']
28 | N_freqs = self.kwargs['num_freqs']
29 |
30 | if self.kwargs['log_sampling']:
31 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
32 | else:
33 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
34 |
35 | for freq in freq_bands:
36 | for p_fn in self.kwargs['periodic_fns']:
37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
38 | out_dim += d
39 |
40 | self.embed_fns = embed_fns
41 | self.out_dim = out_dim
42 |
43 | def embed(self, inputs):
44 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
45 |
46 |
47 | def get_embedder(multires, i=0):
48 | if i == -1:
49 | return nn.Identity(), 3
50 |
51 | embed_kwargs = {
52 | 'include_input' : True,
53 | 'input_dims' : 3,
54 | 'max_freq_log2' : multires-1,
55 | 'num_freqs' : multires,
56 | 'log_sampling' : True,
57 | 'periodic_fns' : [torch.sin, torch.cos],
58 | }
59 |
60 | embedder_obj = Embedder(**embed_kwargs)
61 | embed = lambda x, eo=embedder_obj : eo.embed(x)
62 | return embed, embedder_obj.out_dim
63 |
64 |
65 | # Model
66 | class NeRF(nn.Module):
67 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
68 | """
69 | """
70 | super(NeRF, self).__init__()
71 | self.D = D
72 | self.W = W
73 | self.input_ch = input_ch
74 | self.input_ch_views = input_ch_views
75 | self.skips = skips
76 | self.use_viewdirs = use_viewdirs
77 |
78 | self.pts_linears = nn.ModuleList(
79 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
80 |
81 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
82 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
83 |
84 | ### Implementation according to the paper
85 | # self.views_linears = nn.ModuleList(
86 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
87 |
88 | if use_viewdirs:
89 | self.feature_linear = nn.Linear(W, W)
90 | self.alpha_linear = nn.Linear(W, 1)
91 | self.rgb_linear = nn.Linear(W//2, 3)
92 | else:
93 | self.output_linear = nn.Linear(W, output_ch)
94 |
95 | def forward(self, x):
96 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
97 | h = input_pts
98 | for i, l in enumerate(self.pts_linears):
99 | h = self.pts_linears[i](h)
100 | h = F.relu(h)
101 | if i in self.skips:
102 | h = torch.cat([input_pts, h], -1)
103 |
104 | if self.use_viewdirs:
105 | alpha = self.alpha_linear(h)
106 | feature = self.feature_linear(h)
107 | h = torch.cat([feature, input_views], -1)
108 |
109 | for i, l in enumerate(self.views_linears):
110 | h = self.views_linears[i](h)
111 | h = F.relu(h)
112 |
113 | rgb = self.rgb_linear(h)
114 | outputs = torch.cat([rgb, alpha], -1)
115 | else:
116 | outputs = self.output_linear(h)
117 |
118 | return outputs
119 |
120 | def load_weights_from_keras(self, weights):
121 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
122 |
123 | # Load pts_linears
124 | for i in range(self.D):
125 | idx_pts_linears = 2 * i
126 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
127 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
128 |
129 | # Load feature_linear
130 | idx_feature_linear = 2 * self.D
131 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
132 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
133 |
134 | # Load views_linears
135 | idx_views_linears = 2 * self.D + 2
136 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
137 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
138 |
139 | # Load rgb_linear
140 | idx_rgb_linear = 2 * self.D + 4
141 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear]))
142 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear+1]))
143 |
144 | # Load alpha_linear
145 | idx_alpha_linear = 2 * self.D + 6
146 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
147 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
148 |
149 |
150 |
151 | # Ray helpers
152 | def get_rays(H, W, K, c2w):
153 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
154 | i = i.t()
155 | j = j.t()
156 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
157 | # Rotate ray directions from camera frame to the world frame
158 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
159 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
160 | rays_o = c2w[:3,-1].expand(rays_d.shape)
161 | return rays_o, rays_d
162 |
163 |
164 | def get_rays_np(H, W, K, c2w):
165 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
166 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
167 | # Rotate ray directions from camera frame to the world frame
168 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
169 | # Translate camera frame's origin to the world frame. It is the origin of all rays.
170 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
171 | return rays_o, rays_d
172 |
173 |
174 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
175 | # Shift ray origins to near plane
176 | t = -(near + rays_o[...,2]) / rays_d[...,2]
177 | rays_o = rays_o + t[...,None] * rays_d
178 |
179 | # Projection
180 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
181 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
182 | o2 = 1. + 2. * near / rays_o[...,2]
183 |
184 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
185 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
186 | d2 = -2. * near / rays_o[...,2]
187 |
188 | rays_o = torch.stack([o0,o1,o2], -1)
189 | rays_d = torch.stack([d0,d1,d2], -1)
190 |
191 | return rays_o, rays_d
192 |
193 |
194 | # Hierarchical sampling (section 5.2)
195 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
196 | # Get pdf
197 | weights = weights + 1e-5 # prevent nans
198 | pdf = weights / torch.sum(weights, -1, keepdim=True)
199 | cdf = torch.cumsum(pdf, -1)
200 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
201 |
202 | # Take uniform samples
203 | if det:
204 | u = torch.linspace(0., 1., steps=N_samples)
205 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
206 | else:
207 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
208 |
209 | # Pytest, overwrite u with numpy's fixed random numbers
210 | if pytest:
211 | np.random.seed(0)
212 | new_shape = list(cdf.shape[:-1]) + [N_samples]
213 | if det:
214 | u = np.linspace(0., 1., N_samples)
215 | u = np.broadcast_to(u, new_shape)
216 | else:
217 | u = np.random.rand(*new_shape)
218 | u = torch.Tensor(u)
219 |
220 | # Invert CDF
221 | u = u.contiguous()
222 | inds = torch.searchsorted(cdf, u, right=True)
223 | below = torch.max(torch.zeros_like(inds-1), inds-1)
224 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
225 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
226 |
227 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
228 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
229 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
230 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
231 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
232 |
233 | denom = (cdf_g[...,1]-cdf_g[...,0])
234 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
235 | t = (u-cdf_g[...,0])/denom
236 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
237 |
238 | return samples
239 |
--------------------------------------------------------------------------------
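
A small usage sketch of the positional encoding helper defined above: with the default multires=10 and include_input=True, a 3-D point is lifted to 3 + 3*2*10 = 63 channels, which is the input_ch that create_nerf() passes to the NeRF MLP.

    # Usage sketch of the positional encoding helper above.
    import torch
    from run_nerf_helpers import get_embedder

    embed_fn, out_dim = get_embedder(multires=10, i=0)
    pts = torch.rand(1024, 3)
    encoded = embed_fn(pts)                  # (1024, 63): x plus sin/cos at 10 frequencies per axis
    assert encoded.shape[-1] == out_dim == 63
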