├── EDI.py ├── LICENSE ├── Papers.pdf ├── README.md ├── config_real-world.txt ├── config_synthetic.txt ├── event_loss_helpers.py ├── figure.png ├── load_LINEMOD.py ├── load_blender.py ├── load_deepvoxels.py ├── load_event.py ├── load_llff.py ├── orginal_data&preprocessing ├── real-world │ └── Data-preprocessing │ │ └── data-preprocessing.py └── synthetic │ ├── Blur-synthesizing │ ├── README.md │ ├── main.py │ ├── process.py │ ├── unprocess.py │ └── util.py │ └── Event-preprocessing │ └── event-preprocessing.py ├── requirements.txt ├── run_nerf_exp.py └── run_nerf_helpers.py /EDI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import math 4 | 5 | 6 | # -------------------EDI MODEL---------------------- 7 | def EDI_rgb(data_path, frame_num, bin_num, events): 8 | event_sum = torch.zeros(260, 346) 9 | EDI = torch.ones(260, 346) 10 | 11 | for i in range(bin_num): 12 | event_sum = event_sum + events[frame_num][i] 13 | EDI = EDI + torch.exp(threshold * event_sum) 14 | 15 | EDI = torch.stack([EDI, EDI, EDI], dim=-1) 16 | img = (bin_num + 1) * blurry_image / EDI # blurry_image is the global blurry frame loaded in the loop below 17 | img = torch.clamp(img, max=255) 18 | cv2.imwrite(data_path + "images_for_colmap/{0:03d}.jpg".format(frame_num * (bin_num + 1)), img.numpy()) # save the first deblurred image 19 | 20 | offset = torch.zeros(260, 346) 21 | for i in range(bin_num): 22 | offset = offset + events[frame_num][i] 23 | imgs = img * torch.exp(threshold * torch.stack([offset, offset, offset], dim=-1)) 24 | cv2.imwrite(data_path + "images_for_colmap/{0:03d}.jpg".format(frame_num * (bin_num + 1) + 1 + i), imgs.numpy()) # save the rest of the deblurred images 25 | 26 | threshold = 0.3 # event contrast threshold used by the EDI model 27 | bin_num = 4 28 | view_num = 30 29 | data_path = "./data/" 30 | events = torch.load(data_path + "events.pt").view(view_num, bin_num, 260, 346) # load the preprocessed events from the .pt file 31 | 32 | for i in range(view_num): 33 | blurry_image = torch.tensor(cv2.imread(data_path + "images/{0:03d}.jpg".format(i * (bin_num + 1))), dtype=torch.float) # load the target blurry image 34 | EDI_rgb(data_path, i, bin_num, events) 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CVTEAM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Papers.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/E2NeRF/487fec175c8729fdb49c22a2de1a093ca9192fe9/Papers.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for E2NeRF: Event Enhanced Neural Radiance Fields from Blurry Images (ICCV 2023) 2 | This is the official PyTorch implementation of E2NeRF. Click [here](https://icvteam.github.io/E2NeRF.html) to see the video and supplementary materials on our project website. 3 | 4 | 5 | 6 | ## Method Overview 7 | 8 | ![](./figure.png) 9 | 10 | 11 | 12 | ## Installation 13 | The code is based on [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) and uses the same environment. 14 | Please refer to its GitHub [repository](https://github.com/yenchenlin/nerf-pytorch) for environment installation. 15 | 16 | 17 | 18 | ## Code 19 | 20 | ### Synthetic Data 21 | 22 | The configs of the synthetic data are in the config_synthetic.txt file. Please download the synthetic data below and put it into the corresponding folder (./data/synthetic/). Then you can use the command below to train the model. 23 | 24 | ``` 25 | python run_nerf_exp.py --config config_synthetic.txt 26 | ``` 27 | 28 | ### Real-World Data 29 | 30 | The configs of the real-world data are in the config_real-world.txt file. Please download the real-world data below and put it into the corresponding folder (./data/real-world/). Then you can use the command below to train the model. 31 | 32 | ``` 33 | python run_nerf_exp.py --config config_real-world.txt 34 | ``` 35 | 36 | Note that for the real-world experiments in the paper, we use the video-rendering poses to render the novel-view images. 37 | Simply replace "test_poses" in line 853 of run_nerf_exp.py with "render_poses" to generate the novel-view images (120 images in total). 38 | We use no-reference image quality assessment metrics to evaluate the novel-view images and the blurry-view images together (120 + 30 = 150 images). 39 | 40 | 41 | ## Dataset 42 | Download the dataset [here](https://drive.google.com/drive/folders/1XhOEp4UdLL7EnDNyWdxxX8aRvzF53fWo?usp=sharing). 43 | The dataset contains the "data" used for training and the "original data". 44 | 45 | ### Synthetic Data: 46 | In each scene's folder, the training images are in the "train" folder together with the corresponding event data "events.pt". The ground-truth images are in the "test" folder. 47 | 48 | As in the original NeRF, the training and test poses are in the "transform_train.json" and "transform_test.json" files. 49 | Note that at test time we use the first pose of each view in "transform_test.json" to render the test images, and the ground-truth images are also rendered at this pose. 50 | 51 | ### Real-World Data: 52 | The structure follows the original NeRF's LLFF data, and the event data is in "events.pt". 53 | 54 | ### Event Data: 55 | For easy reading, we convert the event stream into event bins stored in the events.pt file, which can be loaded with PyTorch. The shape of the tensor is (view_number, bin_number, H, W), and each element is the signed number of events at that pixel (positive and negative values indicate polarity). 56 | 57 | 58 |
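As a minimal sketch (the path below is just an example for the synthetic lego scene), the event bins can be loaded and inspected like this:

```
import torch

events = torch.load("./data/synthetic/lego/events.pt")
print(events.shape)

# Depending on the release, the spatial dimensions may be stored flattened as H * W;
# reshaping recovers per-pixel event-count maps (800 x 800 for the synthetic scenes).
event_maps = events.view(events.shape[0], events.shape[1], 800, 800)

# Each entry is a signed event count; the sign encodes polarity.
print(event_maps[0, 0].abs().sum().item(), "events in the first bin of the first view")
```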
59 | ## Original Data & Preprocessing 60 | ### Synthetic Data: 61 | We provide the original images used to synthesize the blurry images, together with the synthesis code. Besides, we supply the original event data generated by v2e and the code that converts the ".txt" events into "events.pt" for E2NeRF training. 62 | 63 | ### Real-World Data: 64 | We supply the original ".aedat4" data captured by a DAVIS346 camera and the corresponding processing code. We also convert the event data into "events.pt" for training. 65 | 66 | ### EDI: 67 | We provide the EDI code in the repository. 68 | You can use it to deblur the images in the "train" folder with the corresponding events.pt data. 69 | The deblurred images are saved to the "images_for_colmap" folder. 70 | You can then run COLMAP on them to generate the poses, as in NeRF. 71 | 72 | 73 | 74 | ## Citation 75 | 76 | If you find this useful, please consider citing our paper: 77 | 78 | ```bibtex 79 | @inproceedings{qi2023e2nerf, 80 | title={E2nerf: Event enhanced neural radiance fields from blurry images}, 81 | author={Qi, Yunshan and Zhu, Lin and Zhang, Yu and Li, Jia}, 82 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 83 | pages={13254--13264}, 84 | year={2023} 85 | } 86 | ``` 87 | 88 | 89 | 90 | ## Acknowledgment 91 | 92 | The overall framework is derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/). We appreciate the effort of the contributors to this repository. 93 | -------------------------------------------------------------------------------- /config_real-world.txt: -------------------------------------------------------------------------------- 1 | expname = lego 2 | basedir = ./logs/real-world 3 | datadir = ./data/real-world/lego 4 | dataset_type = ellff 5 | 6 | factor = 1 7 | llffhold = 0 8 | no_batching = True 9 | 10 | N_rand = 1024 11 | N_samples = 64 12 | N_importance = 128 13 | 14 | use_viewdirs = True 15 | raw_noise_std = 1e0 16 | 17 | spherify = True 18 | no_ndc = True -------------------------------------------------------------------------------- /config_synthetic.txt: -------------------------------------------------------------------------------- 1 | expname = lego 2 | basedir = ./logs 3 | datadir = ./data/synthetic/lego 4 | dataset_type = blender 5 | 6 | no_batching = True 7 | lrate_decay = 500 8 | 9 | N_samples = 64 10 | N_importance = 128 11 | 12 | use_viewdirs = True 13 | 14 | white_bkgd = False 15 | 16 | N_rand = 1024 -------------------------------------------------------------------------------- /event_loss_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | 5 | def lin_log(x, threshold=20): 6 | """ 7 | linear mapping + logarithmic mapping. 8 | :param x: torch.Tensor, the input linear intensity values in range 0-255 9 | :param threshold: float in 0-255, the threshold for the transition from linear to log mapping 10 | """ 11 | # convert x to float64 for precision before the log mapping.
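# lin_log is linear below `threshold` and logarithmic above it:
#     y = x * ln(threshold) / threshold   for x <= threshold
#     y = ln(x)                           for x >  threshold
# the two branches agree at x == threshold, so the mapping is continuous there.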
12 | if x.dtype is not torch.float64: 13 | x = x.double() 14 | f = (1./threshold) * math.log(threshold) 15 | y = torch.where(x <= threshold, x*f, torch.log(x)) 16 | 17 | return y.float() 18 | 19 | 20 | def event_loss_call(all_rgb, event_data, select_coords, combination, rgb2gray, resolution_h, resolution_w): 21 | ''' 22 | simulate the generation of event stream and calculate the event loss 23 | ''' 24 | if rgb2gray == "rgb": 25 | rgb2grey = torch.tensor([0.299,0.587,0.114]) 26 | elif rgb2gray == "ave": 27 | rgb2grey = torch.tensor([1/3, 1/3, 1/3]) 28 | loss = [] 29 | 30 | chose = random.sample(combination, 10) 31 | for its in range(10): 32 | start = chose[its][0] 33 | end = chose[its][1] 34 | 35 | thres_pos = (lin_log(torch.mv(all_rgb[end], rgb2grey) * 255) - lin_log(torch.mv(all_rgb[start], rgb2grey) * 255)) / 0.3 36 | thres_neg = (lin_log(torch.mv(all_rgb[end], rgb2grey) * 255) - lin_log(torch.mv(all_rgb[start], rgb2grey) * 255)) / 0.2 37 | 38 | event_cur = event_data[start].view(resolution_h, resolution_w)[select_coords[:, 0], select_coords[:, 1]] 39 | for j in range(start + 1, end): 40 | event_cur += event_data[j].view(resolution_h, resolution_w)[select_coords[:, 0], select_coords[:, 1]] 41 | 42 | pos = event_cur > 0 43 | neg = event_cur < 0 44 | 45 | loss_pos = torch.mean(((thres_pos * pos) - ((event_cur + 0.5) * pos)) ** 2) 46 | loss_neg = torch.mean(((thres_neg * neg) - ((event_cur - 0.5) * neg)) ** 2) 47 | 48 | loss.append(loss_pos + loss_neg) 49 | 50 | event_loss = torch.mean(torch.stack(loss, dim=0), dim=0) 51 | return event_loss 52 | -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/E2NeRF/487fec175c8729fdb49c22a2de1a093ca9192fe9/figure.png -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | 
imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | 7 | 8 | trans_t = lambda t : torch.Tensor([ 9 | [1,0,0,0], 10 | [0,1,0,0], 11 | [0,0,1,t], 12 | [0,0,0,1]]).float() 13 | 14 | rot_phi = lambda phi : torch.Tensor([ 15 | [1,0,0,0], 16 | [0,np.cos(phi),-np.sin(phi),0], 17 | [0,np.sin(phi), np.cos(phi),0], 18 | [0,0,0,1]]).float() 19 | 20 | rot_theta = lambda th : torch.Tensor([ 21 | [np.cos(th),0,-np.sin(th),0], 22 | [0,1,0,0], 23 | [np.sin(th),0, np.cos(th),0], 24 | [0,0,0,1]]).float() 25 | 26 | 27 | def pose_spherical(theta, phi, radius): 28 | c2w = trans_t(radius) 29 | c2w = rot_phi(phi/180.*np.pi) @ c2w 30 | c2w = rot_theta(theta/180.*np.pi) @ c2w 31 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 32 | return c2w 33 | 34 | 35 | def load_blender_data(basedir): 36 | splits = ['train', 'test'] 37 | metas = {} 38 | for s in splits: 39 | with open(os.path.join(basedir, 'transform_{}.json'.format(s)), 'r') as fp: 40 | metas[s] = json.load(fp) 41 | 42 | imgs = [] 43 | poses = [] 44 | test_imgs = [] 45 | test_poses = [] 46 | for s in splits: 47 | meta = metas[s] 48 | 49 | if s == 'train': 50 | num = 0 51 | for frame in meta['frames'][0:200:2]: 52 | fname = os.path.join(basedir, "./train/r_{}.png".format(num)) 53 | num = num + 2 54 | imgs.append(imageio.imread(fname)) 55 | poses.append(np.array([frame['transform_matrix'][1], frame['transform_matrix'][5], frame['transform_matrix'][9], frame['transform_matrix'][13], frame['transform_matrix'][17]])) 56 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 57 | poses = np.array(poses).astype(np.float32) 58 | else: 59 | num = 0 60 | for frame in meta['frames'][0:200:1]: 61 | fname = os.path.join(basedir, "./test/r_{}.png".format(num)) 62 | num = num + 1 63 | test_imgs.append(imageio.imread(fname)) 64 | test_poses.append(np.array(frame['transform_matrix'][1])) 65 | 66 | 67 | test_imgs = (np.array(test_imgs) / 255.).astype(np.float32) # keep all 
4 channels (RGBA) 68 | test_poses = np.array(test_poses).astype(np.float32) 69 | 70 | 71 | H, W = imgs[0].shape[:2] 72 | meta = metas['train'] 73 | camera_angle_x = float(meta['camera_angle_x']) 74 | focal = .5 * W / np.tan(.5 * camera_angle_x) 75 | 76 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 180 + 1)[:-1]], 0) 77 | 78 | return imgs, poses, test_imgs, test_poses, render_poses, [H, W, focal] 79 | -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. 
for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /load_event.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | # read event data to list 11 | def load_event_data_v1(basedir): 12 | data_path = os.path.join(basedir, "event") 13 | event_list = [] 14 | 15 | for i in range(0, 10, 2): 16 | 17 | file = os.path.join(data_path, "r_" + str(i) + "/v2e-dvs-events.txt") 18 | print(file) 19 | fp = open(file, "r") 20 | events = [] 21 | event_block = [] 22 | counter = 1 23 | 24 | for j in range(6): 25 | fp.readline() 26 | 27 | while True: 28 | line = fp.readline() 29 | if not line: 30 | events.append(event_block) 31 | break 32 | 33 | info = line.split() 34 | t = float(info[0]) 35 | x = int(info[1]) 36 | y = int(info[2]) 37 | p = int(info[3]) 38 | if t > counter * 0.001: 39 | events.append(event_block) 40 | event_block = [] 41 | counter += 1 42 | event = [x, y, t, p] 43 | event_block.append(event) 44 | while counter < 20: 45 | counter += 1 46 | events.append([]) 47 | event_list.append(events) 48 | return 49 | 50 | 51 | # read event data to numpy 52 | def load_event_data_v2(basedir): 53 | data_path = os.path.join(basedir, "event") 54 | event_map = np.zeros((100, 20, 800, 800), dtype=np.int) 55 | 56 | for i in range(0, 200, 2): 57 | 58 | file = os.path.join(data_path, "r_" + str(i) + "/v2e-dvs-events.txt") 59 | fp = open(file, "r") 60 | counter = 1 61 | 62 | for j in range(6): 63 | fp.readline() 64 | 65 | while True: 66 | line = fp.readline() 67 | if not line: 68 | break 69 | 70 | info = line.split() 71 | t = float(info[0]) 72 | x = int(info[1]) 73 | y = int(info[2]) 74 | p = int(info[3]) 75 | if t > counter * 0.001: 76 | counter += 1 77 | if p == 0: 78 | event_map[int(i / 2)][counter - 1][y][x] -= 1 79 | else: 80 | event_map[int(i / 2)][counter - 1][y][x] += 1 81 | 82 | return event_map 83 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | for r in factors: 11 | 
imgdir = os.path.join(basedir, 'images_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = 'images_{}'.format(r) 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | name = 'images_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | print('Minifying', r, basedir) 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | print(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | 60 | 61 | 62 | def _load_data(basedir, factor=8, width=None, height=None, load_imgs=True): 63 | 64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) 66 | bds = poses_arr[:, -2:].transpose([1,0]) 67 | 68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 70 | sh = imageio.imread(img0).shape 71 | 72 | sfx = '' 73 | 74 | if factor > 1: 75 | sfx = '_{}'.format(factor) 76 | _minify(basedir, factors=[factor]) 77 | factor = factor 78 | elif height is not None: 79 | factor = sh[0] / float(height) 80 | width = int(sh[1] / factor) 81 | _minify(basedir, resolutions=[[height, width]]) 82 | sfx = '_{}x{}'.format(width, height) 83 | elif width is not None: 84 | factor = sh[1] / float(width) 85 | height = int(sh[0] / factor) 86 | _minify(basedir, resolutions=[[height, width]]) 87 | sfx = '_{}x{}'.format(width, height) 88 | else: 89 | factor = 1 90 | 91 | imgdir = os.path.join(basedir, 'images' + sfx) 92 | if not os.path.exists(imgdir): 93 | print( imgdir, 'does not exist, returning' ) 94 | return 95 | 96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 97 | 98 | 99 | if poses.shape[-1] != len(imgfiles) * 5: 100 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 101 | return 102 | 103 | sh = imageio.imread(imgfiles[0]).shape 104 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 105 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 106 | 107 | if not load_imgs: 108 | return poses, bds 109 | 110 | def imread(f): 111 | if f.endswith('png'): 112 | return imageio.imread(f, ignoregamma=True) 113 | else: 114 | return imageio.imread(f) 115 | 116 | imgs = imgs = [imread(f)[...,:3]/255. 
for f in imgfiles] 117 | imgs = np.stack(imgs, -1) 118 | 119 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 120 | 121 | return poses, bds, imgs 122 | 123 | 124 | 125 | 126 | 127 | 128 | def normalize(x): 129 | return x / np.linalg.norm(x) 130 | 131 | def viewmatrix(z, up, pos): 132 | vec2 = normalize(z) 133 | vec1_avg = up 134 | vec0 = normalize(np.cross(vec1_avg, vec2)) 135 | vec1 = normalize(np.cross(vec2, vec0)) 136 | m = np.stack([vec0, vec1, vec2, pos], 1) 137 | return m 138 | 139 | def ptstocam(pts, c2w): 140 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 141 | return tt 142 | 143 | def poses_avg(poses): 144 | 145 | hwf = poses[0, :3, -1:] 146 | 147 | center = poses[:, :3, 3].mean(0) 148 | vec2 = normalize(poses[:, :3, 2].sum(0)) 149 | up = poses[:, :3, 1].sum(0) 150 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 151 | 152 | return c2w 153 | 154 | 155 | 156 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 157 | render_poses = [] 158 | rads = np.array(list(rads) + [1.]) 159 | hwf = c2w[:,4:5] 160 | 161 | for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]: 162 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 163 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 164 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 165 | return render_poses 166 | 167 | 168 | 169 | def recenter_poses(poses): 170 | 171 | poses_ = poses+0 172 | bottom = np.reshape([0,0,0,1.], [1,4]) 173 | c2w = poses_avg(poses) 174 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 175 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 176 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 177 | 178 | poses = np.linalg.inv(c2w) @ poses 179 | poses_[:,:3,:4] = poses[:,:3,:4] 180 | poses = poses_ 181 | return poses 182 | 183 | 184 | ##################### 185 | 186 | 187 | def spherify_poses(poses, bds): 188 | 189 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 190 | 191 | rays_d = poses[:,:3,2:3] 192 | rays_o = poses[:,:3,3:4] 193 | 194 | def min_line_dist(rays_o, rays_d): 195 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 196 | b_i = -A_i @ rays_o 197 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 198 | return pt_mindist 199 | 200 | pt_mindist = min_line_dist(rays_o, rays_d) 201 | 202 | center = pt_mindist 203 | up = (poses[:,:3,3] - center).mean(0) 204 | 205 | vec0 = normalize(up) 206 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 207 | vec2 = normalize(np.cross(vec0, vec1)) 208 | pos = center 209 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 210 | 211 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 212 | 213 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 214 | 215 | sc = 1./rad 216 | poses_reset[:,:3,3] *= sc 217 | bds *= sc 218 | rad *= sc 219 | 220 | centroid = np.mean(poses_reset[:,:3,3], 0) 221 | zh = centroid[2] 222 | radcircle = np.sqrt(rad**2-zh**2) 223 | new_poses = [] 224 | 225 | for th in np.linspace(0.,2.*np.pi, 120): 226 | 227 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 228 | up = np.array([0,0,-1.]) 229 | 230 | vec2 = normalize(camorigin) 231 | vec0 = normalize(np.cross(vec2, up)) 232 | vec1 = normalize(np.cross(vec2, vec0)) 233 | pos = camorigin 234 | p = np.stack([vec0, vec1, vec2, pos], 1) 235 | 236 | 
new_poses.append(p) 237 | 238 | new_poses = np.stack(new_poses, 0) 239 | 240 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 241 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 242 | 243 | return poses_reset, new_poses, bds 244 | 245 | 246 | def load_llff_data(basedir, factor=None, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 247 | 248 | 249 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 250 | print('Loaded', basedir, bds.min(), bds.max()) 251 | 252 | # Correct rotation matrix ordering and move variable dim to axis 0 253 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 254 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 255 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 256 | images = imgs 257 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 258 | 259 | # Rescale if bd_factor is provided 260 | sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) 261 | poses[:,:3,3] *= sc 262 | bds *= sc 263 | 264 | if recenter: 265 | poses = recenter_poses(poses) 266 | 267 | if spherify: 268 | poses, render_poses, bds = spherify_poses(poses, bds) 269 | 270 | else: 271 | 272 | c2w = poses_avg(poses) 273 | print('recentered', c2w.shape) 274 | print(c2w[:3,:4]) 275 | 276 | ## Get spiral 277 | # Get average pose 278 | up = normalize(poses[:, :3, 1].sum(0)) 279 | 280 | # Find a reasonable "focus depth" for this dataset 281 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 282 | dt = .75 283 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 284 | focal = mean_dz 285 | 286 | # Get radii for spiral path 287 | shrink_factor = .8 288 | zdelta = close_depth * .2 289 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 290 | rads = np.percentile(np.abs(tt), 90, 0) 291 | c2w_path = c2w 292 | N_views = 120 293 | N_rots = 2 294 | if path_zflat: 295 | # zloc = np.percentile(tt, 10, 0)[2] 296 | zloc = -close_depth * .1 297 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 298 | rads[2] = 0. 
299 | N_rots = 1 300 | N_views/=2 301 | 302 | # Generate poses for spiral path 303 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 304 | 305 | 306 | render_poses = np.array(render_poses).astype(np.float32) 307 | 308 | c2w = poses_avg(poses) 309 | print('Data:') 310 | print(poses.shape, images.shape, bds.shape) 311 | 312 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 313 | i_test = np.argmin(dists) 314 | print('HOLDOUT view is', i_test) 315 | 316 | images = images.astype(np.float32) 317 | poses = poses.astype(np.float32) 318 | 319 | return images, poses, bds, render_poses, i_test 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/real-world/Data-preprocessing/data-preprocessing.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from dv import AedatFile 6 | 7 | ''' 8 | preload the aedat file 9 | transform the video into separate blurry frames 10 | transform the corresponding event into .pt file for E2NeRF training 11 | ''' 12 | 13 | def trans(x): 14 | if x: 15 | return 1 16 | else: 17 | return 0 18 | 19 | def load_events(file, time, views_num, inter_num): 20 | ''' 21 | according to the start exposure timestamp and end exposure timestamp of the blurry image to select the corresponding event 22 | output the .pt preloaded event data for E2NeRF training 23 | ''' 24 | f = AedatFile(file) 25 | frame_num = 0 26 | event_map = np.zeros((views_num, inter_num - 1, 260, 346)) 27 | 28 | for event in f["events"]: 29 | if time[frame_num][1] < event.timestamp: 30 | frame_num = frame_num + 1 31 | if frame_num >= views_num: 32 | break 33 | 34 | if time[frame_num][0] <= event.timestamp <= time[frame_num][1]: 35 | if event.polarity: 36 | event_map[frame_num][int((event.timestamp - time[frame_num][0]) / 25001)][event.y][event.x] += 1 37 | else: 38 | event_map[frame_num][int((event.timestamp - time[frame_num][0]) / 25001)][event.y][event.x] -= 1 39 | return event_map 40 | 41 | 42 | def load_frames(file, basedir, views_num, inter_num): 43 | ''' 44 | preload the blurry frames in aedat4 file 45 | output the start exposure timestamp and end exposure timestamp of the corresponding blurry image 46 | ''' 47 | global s 48 | f = AedatFile(file) 49 | sum = 0 50 | times = [] 51 | for frame in f["frames"]: 52 | if sum % 10 == 0: 53 | cv2.imwrite(basedir + data_name + "/images/{0:03d}.jpg".format(s * inter_num), frame.image) 54 | times.append([frame.timestamp_start_of_exposure, frame.timestamp_end_of_exposure]) 55 | s += 1 56 | sum += 1 57 | if s == views_num: 58 | break 59 | return times 60 | 61 | 62 | basedir = "../davis-aedat4/" 63 | data_name = "lego" 64 | height = 260 65 | width = 346 66 | global s 67 | s = 0 68 | 69 | if __name__ == '__main__': 70 | events = [] 71 | 72 | inter_num = 5 # The number of the event bin (b in paper) + 1 73 | views_num = 30 # The number of the views of the scene 74 | 75 | if not os.path.exists(basedir + data_name + "/images"): 76 | os.mkdir(basedir + data_name + "/images") 77 | 78 | file = basedir + data_name + ".aedat4" 79 | 80 | times = load_frames(file, basedir, views_num, inter_num) 81 | events.append(load_events(file, times, views_num, inter_num)) 82 | events = np.concatenate(events) 83 | events = torch.tensor(events).view(-1, inter_num - 1, 89960) 84 | torch.save(events, basedir + data_name + "/events.pt") 85 | 86 | 87 | 88 | 89 
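# The saved tensor has shape (views_num, inter_num - 1, 260 * 346): one flattened event-count
# map per bin, with the sign of each entry encoding polarity. The "/ 25001" above slices each
# exposure into bins of about 25 ms, which implicitly assumes an exposure time of roughly
# 100 ms per frame (4 bins). A quick sanity check might look like this (a sketch, not part of
# the original pipeline):
#   events = torch.load(basedir + data_name + "/events.pt")
#   print(events.shape)                                        # e.g. torch.Size([30, 4, 89960])
#   maps = events.view(views_num, inter_num - 1, height, width)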
| 90 | 91 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Blur-synthesizing/README.md: -------------------------------------------------------------------------------- 1 | # Generation of blurry images 2 | 3 | ## Introduction 4 | 5 | Following the inverse ISP and forward ISP processes, this tool converts an original image sequence into a single image with motion blur. 6 | 7 | ## Usage 8 | 9 | ### Parameters 10 | 11 | * input_dir: directory of the input frames 12 | * output_name: output frame name 13 | * scale_factor: convert scale_factor frames into 1 blurry frame; note that scale_factor must evenly divide the total number of input frames 14 | * input_exposure: base exposure time of the input frames in microseconds 15 | * input_iso: assumed ISO for the input data 16 | * output_iso: expected ISO for the output 17 | 18 | ### Sample 19 | 20 | ``` 21 | python main.py 22 | --input_dir ../blender-synthetic-images/chair/r_0/ 23 | --output_name ../blender-synthetic-images/chair/r_0.png 24 | --scale_factor 18 25 | --input_exposure 10 26 | --input_iso 50 27 | --output_iso 50 28 | ``` 29 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Blur-synthesizing/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from process import * 3 | from unprocess import * 4 | import cv2 as cv 5 | import os 6 | from tqdm import tqdm 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input_dir', type=str, default='../blender-synthetic-images/chair/r_0/', help='input frames dir') 10 | parser.add_argument('--output_name', type=str, default='../blender-synthetic-images/chair/r_0.png', help='output frame name') 11 | parser.add_argument('--scale_factor', type=int, default=18, help='convert scale_factor frames to 1 frame with blur; note that scale_factor must evenly divide the total number of input frames') 12 | parser.add_argument('--input_exposure', type=float, default=10, help='base exposure time of input frames in microseconds') 13 | parser.add_argument('--input_iso', type=int, default=50, help='Assumed ISO for input data') 14 | parser.add_argument('--output_iso', type=int, default=50, help='Expected ISO for output') 15 | parser.add_argument('--debug', dest='debug', action='store_true', default=False, help='Debug mode') 16 | parser.add_argument('--show_effect', dest='show_effect', action='store_true', default=False, 17 | help='Show the initial image and final image') 18 | args = parser.parse_args() 19 | 20 | if __name__ == '__main__': 21 | # when file_cnt reaches scale_factor, output the accumulated frame and reset file_cnt to 0 22 | file_cnt = 0 23 | file_tot_cnt = 0 24 | # acc_img: accumulator for the unprocessed (raw-space) frames 25 | acc_img = 0.
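# Pipeline for each output frame (see the loop below): every input image is inverted through
# the ISP back to a raw-like signal, scale_factor of these raw frames are summed into acc_img
# to emulate a long exposure, and the accumulated sum is then passed through the forward ISP
# once to synthesize the final blurry image.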
26 | # get file list 27 | # input_list = sorted(os.listdir(args.input_dir)) 28 | input_list = os.listdir(args.input_dir) 29 | input_list.sort(key=lambda x: int(x[:-4])) 30 | # get image list 31 | img_list = [] 32 | for input_file in input_list: 33 | if input_file.split('.')[1] == 'jpg' or input_file.split('.')[1] == 'png': 34 | img_list.append(input_file) 35 | 36 | # scan the input dir 37 | with tqdm(total=len(img_list), desc='process of file datasimul') as pbar: 38 | for filename in img_list: 39 | assert(filename.split('.')[1] == 'jpg' or filename.split('.')[1] == 'png') 40 | # Read the image 41 | img = imread_rgb(args.input_dir + filename) 42 | 43 | # make the image H W to even 44 | flagH, flagW = img.shape[0] % 2, img.shape[1] % 2 45 | img = cv.copyMakeBorder(img, flagH, 0, flagW, 0, cv.BORDER_WRAP) 46 | 47 | H, W, C = img.shape 48 | 49 | # Transform image from numpy.ndarray to torch.Tensor type 50 | img = torch.from_numpy(img) 51 | 52 | # Display the initial image 53 | if args.show_effect: 54 | single_8bit_image_display(img, "Initial") 55 | 56 | # Transform image from [0, 255] to [0, 1] 57 | img = image_8bit_to_real(img) 58 | 59 | # Fundamental Arguments Settings 60 | args.rgb2cam, args.cam2rgb = random_ccm() 61 | args.red_gain, args.blue_gain = random_gains() 62 | args.output_exposure = args.input_iso * args.input_exposure / args.output_iso 63 | 64 | # Invert ISP 65 | img = unprocess(image=img, args=args, debug=args.debug) 66 | 67 | # file_cnt accum 68 | file_cnt += 1 69 | file_tot_cnt += 1 70 | acc_img += img 71 | 72 | if(file_cnt == args.scale_factor): 73 | 74 | file_cnt = 0 75 | acc_img.clamp_(0., 1.) 76 | # ISP 77 | acc_img = process(image=acc_img, args=args, debug=args.debug) 78 | # Transform image from [0, 1] to [0, 255] 79 | acc_img = image_real_to_8bit(acc_img) 80 | 81 | # Display the final image 82 | if args.show_effect: 83 | single_8bit_image_display(acc_img, "Final") 84 | 85 | # Transform image from torch.Tensor type to numpy.ndarray 86 | acc_img = acc_img.numpy() 87 | 88 | # Write the image 89 | rgba_img = cv2.cvtColor(acc_img, cv2.COLOR_BGR2RGBA) 90 | cv2.imwrite(args.output_name, rgba_img) 91 | acc_img = 0. 
92 | 93 | pbar.update(args.scale_factor) 94 | 95 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Blur-synthesizing/process.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | import math 3 | import torch.distributions as tdist 4 | 5 | 6 | def tone_mapping(image: torch.Tensor, debug=False): 7 | image = torch.clamp(image, min=0.0, max=1.0) 8 | out = 3.0 * torch.pow(image, 2) - 2.0 * torch.pow(image, 3) 9 | if debug: 10 | single_real_image_display(out, "Image_After_Tone_Mapping") 11 | return out 12 | 13 | 14 | def gamma_compression(image: torch.Tensor, debug=False): 15 | out = torch.pow(torch.clamp(image, min=1e-8), 1 / 2.2) 16 | if debug: 17 | single_real_image_display(out, "Image_After_Gamma_Compression") 18 | return out 19 | 20 | 21 | def color_correction(image: torch.Tensor, ccm, debug=False): 22 | shape = image.size() 23 | image = torch.reshape(image, [-1, 3]) 24 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]]) 25 | out = torch.reshape(image, shape) 26 | if debug: 27 | single_real_image_display(out, "Image_After_Color_Correction") 28 | return out 29 | 30 | 31 | def demosaic(image: torch.Tensor, debug=False): 32 | """ Reference Code: https://github.com/Chinmayi5577/Demosaicing """ 33 | shape = image.size() 34 | image = image.numpy() 35 | out_np = np.zeros((shape[0] * 2, shape[1] * 2, 3), dtype=np.float32) 36 | 37 | out_np[0::2, 0::2, 0] = image[:, :, 0] 38 | out_np[0::2, 1::2, 1] = image[:, :, 1] 39 | out_np[1::2, 0::2, 1] = image[:, :, 2] 40 | out_np[1::2, 1::2, 2] = image[:, :, 3] 41 | 42 | for i in range(1, shape[0] * 2 - 1): 43 | for j in range(1, shape[1] * 2 - 1): 44 | if i % 2 == 1 and j % 2 == 0: # case when green pixel is present with red on top and bottom 45 | out_np[i][j][0] = (out_np[i - 1][j][0] + out_np[i + 1][j][0]) / 2 # red 46 | out_np[i][j][2] = (out_np[i][j - 1][2] + out_np[i][j + 1][2]) / 2 # blue 47 | elif i % 2 == 0 and j % 2 == 0: # case when red pixel is present 48 | out_np[i][j][1] = (out_np[i - 1][j][1] + out_np[i][j + 1][1] + 49 | out_np[i + 1][j][1] + out_np[i][j - 1][1]) / 4 50 | out_np[i][j][2] = (out_np[i - 1][j - 1][2] + out_np[i + 1][j - 1][2] + 51 | out_np[i - 1][j + 1][2] + out_np[i + 1][j + 1][2]) / 4 52 | elif i % 2 == 0 and j % 2 == 1: # case when green pixel is present with blue on top and bottom 53 | out_np[i][j][0] = (out_np[i][j + 1][0] + out_np[i][j - 1][0]) / 2 54 | out_np[i][j][2] = (out_np[i + 1][j][2] + out_np[i - 1][j][2]) / 2 55 | else: # case when blue pixel is present 56 | out_np[i][j][0] = (out_np[i - 1][j - 1][0] + out_np[i - 1][j + 1][0] + 57 | out_np[i + 1][j + 1][0] + out_np[i + 1][j - 1][0]) / 4 58 | out_np[i][j][1] = (out_np[i - 1][j][1] + out_np[i][j + 1][1] + 59 | out_np[i + 1][j][1] + out_np[i][j - 1][1]) / 4 60 | 61 | last_row = shape[0] * 2 - 1 62 | for j in range(1, shape[1] * 2 - 1): 63 | if j % 2 == 0: # case when red pixel is present on first row or green pixel on last row 64 | out_np[0][j][1] = (out_np[0][j - 1][1] + out_np[0][j + 1][1] + out_np[1][j][1]) / 3 65 | out_np[0][j][2] = (out_np[1][j - 1][2] + out_np[1][j + 1][2]) / 2 66 | out_np[last_row][j][0] = out_np[last_row - 1][j][0] 67 | out_np[last_row][j][2] = (out_np[last_row][j - 1][2] + out_np[last_row][j + 1][2]) / 2 68 | else: # case when green pixel is present on first row or blue pixel on last row 69 | out_np[0][j][0] = (out_np[0][j - 1][0] + out_np[0][j + 1][0]) / 2 70 | out_np[0][j][2] = out_np[1][j][2] 71 | 
out_np[last_row][j][0] = (out_np[last_row - 1][j - 1][0] + out_np[last_row - 1][j + 1][0]) / 2 72 | out_np[last_row][j][1] = (out_np[last_row][j - 1][1] + out_np[last_row][j + 1][1] + 73 | out_np[last_row - 1][j][1]) / 3 74 | 75 | last_column = shape[1] * 2 - 1 76 | for i in range(1, shape[0] * 2 - 1): 77 | if i % 2 == 0: # case when red pixel is present on first column or green pixel on last column 78 | out_np[i][0][1] = (out_np[i - 1][0][1] + out_np[i + 1][0][1] + out_np[i][1][1]) / 3 79 | out_np[i][0][2] = (out_np[i - 1][1][2] + out_np[i + 1][1][2]) / 2 80 | out_np[i][last_column][0] = out_np[i][last_column - 1][0] 81 | out_np[i][last_column][2] = (out_np[i - 1][last_column][2] + out_np[i + 1][last_column][2]) / 2 82 | else: # case when green pixel is present on first column or blue pixel on last column 83 | out_np[i][0][0] = (out_np[i - 1][1][0] + out_np[i + 1][1][0]) / 2 84 | out_np[i][0][2] = out_np[i][1][2] 85 | out_np[i][last_column][0] = (out_np[i - 1][last_column - 1][0] + out_np[i + 1][last_column - 1][0]) / 2 86 | out_np[i][last_column][1] = (out_np[i - 1][last_column][1] + out_np[i + 1][last_column][1] + 87 | out_np[i][last_column - 1][1]) / 3 88 | 89 | out_np[0][0][1] = (out_np[0][1][1] + out_np[1][0][1]) / 2 90 | out_np[0][0][2] = out_np[1][1][2] 91 | out_np[0][last_column][0] = out_np[0][last_column - 1][0] 92 | out_np[0][last_column][2] = out_np[1][last_column][2] 93 | out_np[last_row][0][0] = out_np[last_row - 1][0][0] 94 | out_np[last_row][0][2] = out_np[last_row][1][2] 95 | out_np[last_row][last_column][0] = out_np[last_row - 1][last_column - 1][0] 96 | out_np[last_row][last_column][1] = (out_np[last_row - 1][last_column][1] + 97 | out_np[last_row][last_column - 1][1]) / 2 98 | 99 | out = torch.from_numpy(out_np) 100 | out = torch.clamp(out, min=0.0, max=1.0) 101 | if debug: 102 | single_real_image_display(out, "Image_After_Demosaic") 103 | return out 104 | 105 | 106 | def white_balance(image: torch.Tensor, red_gain=1.9, blue_gain=1.5, debug=False): 107 | f_red = image[:, :, 0] * red_gain 108 | f_blue = image[:, :, 3] * blue_gain 109 | out = torch.stack((f_red, image[:, :, 1], image[:, :, 2], f_blue), dim=-1) 110 | out = torch.clamp(out, min=0.0, max=1.0) 111 | if debug: 112 | single_raw_image_display(out, "Image_After_White_Balance") 113 | return out 114 | 115 | 116 | def digital_gain(image: torch.Tensor, iso=800, debug=False): 117 | out = image * iso 118 | if debug: 119 | tmp = torch.clamp(out, min=0.0, max=1.0) 120 | single_raw_image_display(tmp, "Image_After_Digital_Gain") 121 | return out 122 | 123 | 124 | def exposure_time_accumulate(image: torch.Tensor, exposure_time=10, debug=False): 125 | out = image * exposure_time 126 | if debug: 127 | tmp = torch.clamp(out, min=0.0, max=1.0) 128 | single_raw_image_display(tmp, "Image_After_Exposure_Time_Accumulation") 129 | return out 130 | 131 | 132 | def add_read_noise(image: torch.Tensor, iso=800, debug=False): 133 | image *= (1023 - 64) 134 | r_channel = image[:, :, 0] 135 | gr_channel = image[:, :, 1] 136 | gb_channel = image[:, :, 2] 137 | b_channel = image[:, :, 3] 138 | 139 | R0 = {'R': 0.300575, 'G': 0.347856, 'B': 0.356116} 140 | R1 = {'R': 1.293143, 'G': 0.403101, 'B': 0.403101} 141 | 142 | r_noise_sigma_square = (iso / 100.) ** 2 * R0['R'] + R1['R'] 143 | gb_noise_sigma_square = (iso / 100.) ** 2 * R0['G'] + R1['G'] 144 | gr_noise_sigma_square = (iso / 100.) ** 2 * R0['G'] + R1['G'] 145 | b_noise_sigma_square = (iso / 100.) 
** 2 * R0['B'] + R1['B'] 146 | 147 | r_samples = tdist.Normal(loc=torch.zeros_like(r_channel), scale=math.sqrt(r_noise_sigma_square)).sample() 148 | gr_samples = tdist.Normal(loc=torch.zeros_like(gr_channel), scale=math.sqrt(gb_noise_sigma_square)).sample() 149 | gb_samples = tdist.Normal(loc=torch.zeros_like(gb_channel), scale=math.sqrt(gr_noise_sigma_square)).sample() 150 | b_samples = tdist.Normal(loc=torch.zeros_like(b_channel), scale=math.sqrt(b_noise_sigma_square)).sample() 151 | 152 | r_channel += r_samples 153 | gr_channel += gr_samples 154 | gb_channel += gb_samples 155 | b_channel += b_samples 156 | 157 | out = torch.stack((r_channel, gr_channel, gb_channel, b_channel), dim=-1) 158 | out /= (1023 - 64) 159 | out = torch.clamp(out, min=0.0, max=1.0) 160 | 161 | if debug: 162 | single_raw_image_display(out, "Image_After_Add_Read_Noise") 163 | return out 164 | 165 | 166 | def add_shot_noise(image: torch.Tensor, debug=False): 167 | image *= (1023 - 64) 168 | r_channel = image[:, :, 0] 169 | gr_channel = image[:, :, 1] 170 | gb_channel = image[:, :, 2] 171 | b_channel = image[:, :, 3] 172 | 173 | S = {'R': 0.343334/2, 'G': 0.348052/2, 'B': 0.346563/2} 174 | 175 | r_noise_sigma_square = (S['R'] / 100) * r_channel 176 | gb_noise_sigma_square = (S['G'] / 100) * gr_channel 177 | gr_noise_sigma_square = (S['G'] / 100) * gb_channel 178 | b_noise_sigma_square = (S['B'] / 100) * b_channel 179 | 180 | # print(r_noise_sigma_square) 181 | # print(torch.sqrt(r_noise_sigma_square).shape, r_channel.shape) 182 | 183 | r_samples = tdist.Normal(loc=torch.zeros_like(r_channel), scale=torch.sqrt(r_noise_sigma_square)).sample() 184 | gr_samples = tdist.Normal(loc=torch.zeros_like(gr_channel), scale=torch.sqrt(gb_noise_sigma_square)).sample() 185 | gb_samples = tdist.Normal(loc=torch.zeros_like(gb_channel), scale=torch.sqrt(gr_noise_sigma_square)).sample() 186 | b_samples = tdist.Normal(loc=torch.zeros_like(b_channel), scale=torch.sqrt(b_noise_sigma_square)).sample() 187 | 188 | r_channel += r_samples 189 | gr_channel += gr_samples 190 | gb_channel += gb_samples 191 | b_channel += b_samples 192 | 193 | out = torch.stack((r_channel, gr_channel, gb_channel, b_channel), dim=-1) 194 | out /= (1023 - 64) 195 | out = torch.clamp(out, min=0.0, max=1.0) 196 | 197 | if debug: 198 | single_raw_image_display(out, "Image_After_Add_Shot_Noise") 199 | return out 200 | 201 | 202 | def process(image: torch.Tensor, args, debug=False): 203 | # Followings are ISP Steps 204 | # ISP-step -7: the exposure time accumulation 205 | image = exposure_time_accumulate(image, exposure_time=args.output_exposure, debug=debug) 206 | image = add_shot_noise(image, debug) 207 | # ISP-step -6: the digital gain with read noise 208 | image = digital_gain(image, iso=args.output_iso, debug=debug) 209 | image = add_read_noise(image, iso=args.output_iso, debug=debug) 210 | # ISP-step -5: the white balance 211 | image = white_balance(image, red_gain=args.red_gain, blue_gain=args.blue_gain, debug=debug) 212 | # ISP-step -4: the demosaicing 213 | image = demosaic(image, debug) 214 | # ISP-step -3: the color correction 215 | image = color_correction(image, args.cam2rgb, debug) 216 | # ISP-step -2: the gamma compression 217 | image = gamma_compression(image, debug) 218 | # ISP-step -1: the tone mapping 219 | image = tone_mapping(image, debug) 220 | return image 221 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Blur-synthesizing/unprocess.py: 
-------------------------------------------------------------------------------- 1 | from util import * 2 | 3 | 4 | def inv_tone_mapping(image: torch.Tensor, debug=False): 5 | image = torch.clamp(image, min=0.0, max=1.0) 6 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) 7 | if debug: 8 | single_real_image_display(out, "Image_After_Inv_Tone_Mapping") 9 | return out 10 | 11 | 12 | def inv_gamma_compression(image: torch.Tensor, debug=False): 13 | out = torch.pow(torch.clamp(image, min=1e-8), 2.2) 14 | if debug: 15 | single_real_image_display(out, "Image_After_Inv_Gamma_Compression") 16 | return out 17 | 18 | 19 | def inv_color_correction(image: torch.Tensor, ccm, debug=False): 20 | shape = image.size() 21 | image = torch.reshape(image, [-1, 3]) 22 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]]) 23 | out = torch.reshape(image, shape) 24 | if debug: 25 | single_real_image_display(out, "Image_After_Inv_Color_Correction") 26 | return out 27 | 28 | 29 | def mosaic(image: torch.Tensor, debug=False): 30 | shape = image.size() 31 | red = image[0::2, 0::2, 0] 32 | green_red = image[0::2, 1::2, 1] 33 | green_blue = image[1::2, 0::2, 1] 34 | blue = image[1::2, 1::2, 2] 35 | # import pdb 36 | # pdb.set_trace() 37 | # maybe shape[0] or shape[1] is odd, that cannot be divide by 2!!!!!!! 38 | 39 | out = torch.stack((red, green_red, green_blue, blue), dim=-1) 40 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4)) 41 | if debug: 42 | single_raw_image_display(out, "Image_After_Mosaic") 43 | return out 44 | 45 | 46 | def inv_white_balance(image: torch.Tensor, red_gain=1.9, blue_gain=1.5, threshold=0.9, debug=False): 47 | red = image[:, :, 0] 48 | alpha_red = (torch.max(red - threshold, torch.zeros_like(red)) / (1 - threshold)) ** 2 49 | f_red = torch.max(red / red_gain, (1 - alpha_red) * (red / red_gain) + alpha_red * red) 50 | blue = image[:, :, 3] 51 | alpha_blue = (torch.max(blue - threshold, torch.zeros_like(blue)) / (1 - threshold)) ** 2 52 | f_blue = torch.max(blue / blue_gain, (1 - alpha_blue) * (blue / blue_gain) + alpha_blue * blue) 53 | out = torch.stack((f_red, image[:, :, 1], image[:, :, 2], f_blue), dim=-1) 54 | if debug: 55 | single_raw_image_display(out, "Image_After_Inv_White_Balance") 56 | return out 57 | 58 | 59 | def inv_digital_gain(image: torch.Tensor, iso=800, scale_factor=5, debug=False): 60 | # out = image / iso 61 | out = image / (iso * scale_factor) 62 | # out /= scale_factor 63 | if debug: 64 | single_raw_image_display(out, "Image_After_Inv_Digital_Gain") 65 | return out 66 | 67 | 68 | def inv_exposure_time(image: torch.Tensor, exposure_time=10, debug=False): 69 | out = image / exposure_time 70 | if debug: 71 | single_raw_image_display(out, "Image_After_Inv_Exposure_Time") 72 | return out 73 | 74 | 75 | def unprocess(image: torch.Tensor, args, debug=False): 76 | # Followings are Inversion-ISP Steps 77 | # Inv-step 1: the inversion of tone mapping 78 | image = inv_tone_mapping(image, debug) 79 | # Inv-step 2: the inversion of gamma compression 80 | image = inv_gamma_compression(image, debug) 81 | # Inv-step 3: the inversion of color correction 82 | image = inv_color_correction(image, args.rgb2cam, debug) 83 | # Inv-step 4: the mosaicing 84 | image = mosaic(image, debug) 85 | # Inv-step 5: the inversion of white balance 86 | image = inv_white_balance(image, red_gain=args.red_gain, blue_gain=args.blue_gain, debug=debug) 87 | #single_raw_image_display(image, "balance") 88 | # Inv-step 6: the inversion of digital gain 89 | image = inv_digital_gain(image, 
iso=args.input_iso, scale_factor=args.scale_factor, debug=debug) 90 | 91 | # Inv-step 7: split exposure time to 1ms 92 | image = inv_exposure_time(image, exposure_time=args.input_exposure, debug=debug) 93 | 94 | return image 95 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Blur-synthesizing/util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import numpy 5 | 6 | 7 | def imread_rgb(path): 8 | image = cv2.imread(path) 9 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 10 | return image 11 | 12 | 13 | def imwrite_rgb(image, filename): 14 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 15 | cv2.imwrite(filename, image) 16 | 17 | 18 | def image_8bit_to_real(image_8bit: torch.Tensor): 19 | """ Covert image from [0, 255] to [0, 1] """ 20 | return torch.div(image_8bit, 255) 21 | 22 | 23 | def image_real_to_8bit(image_real: torch.Tensor): 24 | """ Covert image from [0, 1] to [0, 255] """ 25 | image_real = torch.clamp(image_real, min=0.0, max=1.0) 26 | return torch.mul(image_real, 255).type(torch.uint8) 27 | 28 | 29 | def single_8bit_image_display(image: torch.Tensor, description=None): 30 | """ Display a single [0, 255] image """ 31 | image = image.numpy() 32 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 33 | 34 | # cv2.imshow(description, image) 35 | # cv2.waitKey(0) 36 | cv2.imwrite('./resfig/ans.jpg', image) 37 | 38 | 39 | # cv2.destroyAllWindows() 40 | 41 | 42 | def single_real_image_display(image: torch.Tensor, description=None): 43 | """ Display a single [0, 1] image """ 44 | # image_real_to_8bit(image) 45 | image = image.numpy() 46 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 47 | cv2.imshow(description, image) 48 | cv2.waitKey(0) 49 | # cv2.destroyAllWindows() 50 | 51 | 52 | def single_8bit_image_display_numpy(image: numpy.ndarray, description=None): 53 | cv2.imshow(description, image) 54 | cv2.waitKey(0) 55 | # cv2.destroyAllWindows() 56 | 57 | 58 | def single_raw_image_display(image: torch.Tensor, description=None): 59 | shape = image.size() 60 | image = image.numpy() 61 | raw_in_rgb = numpy.zeros([shape[0] * 2, shape[1] * 2, 3], dtype=np.float32) 62 | raw_in_rgb[0::2, 0::2, 0] = image[:, :, 0] 63 | raw_in_rgb[0::2, 1::2, 1] = image[:, :, 1] 64 | raw_in_rgb[1::2, 0::2, 1] = image[:, :, 2] 65 | raw_in_rgb[1::2, 1::2, 2] = image[:, :, 3] 66 | raw_in_rgb = cv2.cvtColor(raw_in_rgb, cv2.COLOR_RGB2BGR) 67 | cv2.imshow(description, raw_in_rgb) 68 | cv2.imwrite(description + ".jpg", raw_in_rgb * 255) 69 | cv2.waitKey(0) 70 | 71 | 72 | def random_ccm(): 73 | """Generates random RGB -> Camera color correction matrices.""" 74 | # Takes a random convex combination of XYZ -> Camera CCMs. 75 | xyz2cams = [[[1.0234, -0.2969, -0.2266], 76 | [-0.5625, 1.6328, -0.0469], 77 | [-0.0703, 0.2188, 0.6406]], 78 | [[0.4913, -0.0541, -0.0202], 79 | [-0.613, 1.3513, 0.2906], 80 | [-0.1564, 0.2151, 0.7183]], 81 | [[0.838, -0.263, -0.0639], 82 | [-0.2887, 1.0725, 0.2496], 83 | [-0.0627, 0.1427, 0.5438]], 84 | [[0.6596, -0.2079, -0.0562], 85 | [-0.4782, 1.3016, 0.1933], 86 | [-0.097, 0.1581, 0.5181]]] 87 | num_ccms = len(xyz2cams) 88 | xyz2cams = torch.FloatTensor(xyz2cams) 89 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8) 90 | weights_sum = torch.sum(weights, dim=0) 91 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum 92 | 93 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. 
94 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375], 95 | [0.2126729, 0.7151522, 0.0721750], 96 | [0.0193339, 0.1191920, 0.9503041]]) 97 | rgb2cam = torch.mm(xyz2cam, rgb2xyz) 98 | 99 | # Normalizes each row. 100 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True) 101 | 102 | cam2rgb = torch.inverse(rgb2cam) 103 | 104 | rgb2cam = torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 105 | cam2rgb = torch.inverse(rgb2cam) 106 | 107 | 108 | return rgb2cam, cam2rgb 109 | 110 | 111 | def random_gains(): 112 | # Red and blue gains represent white balance. 113 | # red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4) 114 | red_gain = torch.FloatTensor([2.15]) 115 | # blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9) 116 | blue_gain = torch.FloatTensor([1.7]) 117 | return red_gain, blue_gain 118 | -------------------------------------------------------------------------------- /orginal_data&preprocessing/synthetic/Event-preprocessing/event-preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | ''' 6 | Event data loader 7 | Transform the original event data to .pt file for E2NeRF training 8 | ''' 9 | def load_event_data_noisy(data_path): 10 | event_map = np.zeros((100, 4, 800, 800)) # 100 views of input, 4 event bin (b in the paper) for each view, resolution of 800*800 11 | 12 | for i in range(0, 200, 2): 13 | file = os.path.join(data_path, "r_{}/v2e-dvs-events.txt".format(i)) 14 | fp = open(file, "r") 15 | 16 | print("Processing data_path_{}".format(i)) 17 | counter = 1 18 | for j in range(6): 19 | fp.readline() 20 | 21 | while True: 22 | line = fp.readline() 23 | if not line: 24 | break 25 | 26 | info = line.split() 27 | t = float(info[0]) 28 | x = int(info[1]) 29 | y = int(info[2]) 30 | p = int(info[3]) 31 | 32 | if t > counter * 0.04 + 0.01: 33 | counter += 1 34 | if counter >= 5: 35 | break 36 | 37 | if p == 0: 38 | event_map[int(i / 2)][counter - 1][y][x] -= 1 39 | else: 40 | event_map[int(i / 2)][counter - 1][y][x] += 1 41 | return event_map 42 | 43 | input_data_path = "../blender-v2e-synthetic-events/lego/" 44 | output_data_path = "../blender-v2e-synthetic-events/lego/events.pt" 45 | 46 | if __name__ == '__main__': 47 | events = load_event_data_noisy(input_data_path) 48 | events = torch.tensor(events).view(100, 4, 640000) 49 | torch.save(events, output_data_path) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.20.2 2 | imageio~=2.9.0 3 | torch~=1.9.0+cu111 4 | opencv-python~=4.5.1.48 5 | matplotlib~=3.4.2 6 | tqdm~=4.62.0 7 | dv~=1.0.10 -------------------------------------------------------------------------------- /run_nerf_exp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm, trange 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | import load_blender 15 | from run_nerf_helpers import * 16 | 17 | from load_llff import load_llff_data 18 | from load_blender import load_blender_data 19 | from load_event import * 20 | 21 | import pynvml 22 | from torch.utils.tensorboard import SummaryWriter 23 | import math 24 | from event_loss_helpers import * 25 | 26 | 
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | np.random.seed(0) 29 | DEBUG = False 30 | 31 | 32 | def batchify(fn, chunk): 33 | """Constructs a version of 'fn' that applies to smaller batches. 34 | """ 35 | if chunk is None: 36 | return fn 37 | def ret(inputs): 38 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 39 | return ret 40 | 41 | 42 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 43 | """Prepares inputs and applies network 'fn'. 44 | """ 45 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 46 | embedded = embed_fn(inputs_flat) 47 | 48 | if viewdirs is not None: 49 | input_dirs = viewdirs[:,None].expand(inputs.shape) 50 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 51 | embedded_dirs = embeddirs_fn(input_dirs_flat) 52 | embedded = torch.cat([embedded, embedded_dirs], -1) 53 | 54 | outputs_flat = batchify(fn, netchunk)(embedded) 55 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 56 | return outputs 57 | 58 | 59 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 60 | """Render rays in smaller minibatches to avoid OOM. 61 | """ 62 | all_ret = {} 63 | for i in range(0, rays_flat.shape[0], chunk): 64 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 65 | for k in ret: 66 | if k not in all_ret: 67 | all_ret[k] = [] 68 | all_ret[k].append(ret[k]) 69 | 70 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 71 | return all_ret 72 | 73 | 74 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 75 | near=0., far=1., 76 | use_viewdirs=False, c2w_staticcam=None, 77 | **kwargs): 78 | """Render rays 79 | Args: 80 | H: int. Height of image in pixels. 81 | W: int. Width of image in pixels. 82 | focal: float. Focal length of pinhole camera. 83 | chunk: int. Maximum number of rays to process simultaneously. Used to 84 | control maximum memory usage. Does not affect final results. 85 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 86 | each example in batch. 87 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 88 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 89 | near: float or array of shape [batch_size]. Nearest distance for a ray. 90 | far: float or array of shape [batch_size]. Farthest distance for a ray. 91 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 92 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 93 | camera while using other c2w argument for viewing directions. 94 | Returns: 95 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 96 | disp_map: [batch_size]. Disparity map. Inverse of depth. 97 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 98 | extras: dict with everything returned by render_rays(). 
99 | """ 100 | 101 | if c2w is not None: 102 | # special case to render full image 103 | rays_o, rays_d = get_rays(H, W, K, c2w) 104 | else: 105 | # use provided ray batch 106 | rays_o, rays_d = rays 107 | 108 | if use_viewdirs: 109 | # provide ray directions as input 110 | viewdirs = rays_d 111 | if c2w_staticcam is not None: 112 | # special case to visualize effect of viewdirs 113 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 114 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 115 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 116 | 117 | sh = rays_d.shape # [..., 3] 118 | if ndc: 119 | # for forward facing scenes 120 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 121 | 122 | # Create ray batch 123 | rays_o = torch.reshape(rays_o, [-1,3]).float() 124 | rays_d = torch.reshape(rays_d, [-1,3]).float() 125 | 126 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 127 | rays = torch.cat([rays_o, rays_d, near, far], -1) 128 | if use_viewdirs: 129 | rays = torch.cat([rays, viewdirs], -1) 130 | 131 | # Render and reshape 132 | all_ret = batchify_rays(rays, chunk, **kwargs) 133 | for k in all_ret: 134 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 135 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 136 | 137 | k_extract = ['rgb_map', 'disp_map', 'acc_map'] 138 | ret_list = [all_ret[k] for k in k_extract] 139 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 140 | return ret_list + [ret_dict] 141 | 142 | 143 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 144 | 145 | H, W, focal = hwf 146 | 147 | if render_factor!=0: 148 | # Render downsampled for speed 149 | H = H//render_factor 150 | W = W//render_factor 151 | focal = focal/render_factor 152 | 153 | rgbs = [] 154 | disps = [] 155 | rgbs_extras = [] 156 | 157 | t = time.time() 158 | for i, c2w in enumerate(tqdm(render_poses)): 159 | print(i, time.time() - t) 160 | t = time.time() 161 | #rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 162 | rgb, disp, acc, extra = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs) 163 | rgbs.append(rgb.cpu().numpy()) 164 | disps.append(disp.cpu().numpy()) 165 | 166 | if i==0: 167 | print(rgb.shape, disp.shape) 168 | 169 | 170 | if gt_imgs is not None and render_factor==0: 171 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 172 | print(p) 173 | 174 | if savedir is not None: 175 | rgb8 = to8b(rgbs[-1]) 176 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 177 | imageio.imwrite(filename, rgb8) 178 | 179 | rgbs = np.stack(rgbs, 0) 180 | disps = np.stack(disps, 0) 181 | 182 | return rgbs, disps 183 | 184 | 185 | def create_nerf(args): 186 | """Instantiate NeRF's MLP model. 
187 | """ 188 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 189 | 190 | input_ch_views = 0 191 | embeddirs_fn = None 192 | if args.use_viewdirs: 193 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 194 | output_ch = 5 if args.N_importance > 0 else 4 195 | skips = [4] 196 | model = NeRF(D=args.netdepth, W=args.netwidth, 197 | input_ch=input_ch, output_ch=output_ch, skips=skips, 198 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 199 | grad_vars = list(model.parameters()) 200 | 201 | model_fine = None 202 | if args.N_importance > 0: 203 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 204 | input_ch=input_ch, output_ch=output_ch, skips=skips, 205 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 206 | grad_vars += list(model_fine.parameters()) 207 | 208 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 209 | embed_fn=embed_fn, 210 | embeddirs_fn=embeddirs_fn, 211 | netchunk=args.netchunk) 212 | 213 | # Create optimizer 214 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 215 | 216 | start = 0 217 | basedir = args.basedir 218 | expname = args.expname 219 | 220 | ########################## 221 | 222 | # Load checkpoints 223 | if args.ft_path is not None and args.ft_path!='None': 224 | ckpts = [args.ft_path] 225 | else: 226 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] 227 | 228 | print('Found ckpts', ckpts) 229 | if len(ckpts) > 0 and not args.no_reload: 230 | ckpt_path = ckpts[-1] 231 | print('Reloading from', ckpt_path) 232 | ckpt = torch.load(ckpt_path) 233 | 234 | start = ckpt['global_step'] 235 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 236 | 237 | # Load model 238 | model.load_state_dict(ckpt['network_fn_state_dict']) 239 | if model_fine is not None: 240 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 241 | 242 | ########################## 243 | 244 | render_kwargs_train = { 245 | 'network_query_fn' : network_query_fn, 246 | 'perturb' : args.perturb, 247 | 'N_importance' : args.N_importance, 248 | 'network_fine' : model_fine, 249 | 'N_samples' : args.N_samples, 250 | 'network_fn' : model, 251 | 'use_viewdirs' : args.use_viewdirs, 252 | 'white_bkgd' : args.white_bkgd, 253 | 'raw_noise_std' : args.raw_noise_std, 254 | } 255 | 256 | # NDC only good for LLFF-style forward facing data 257 | if args.dataset_type != 'llff' or args.no_ndc: 258 | print('Not ndc!') 259 | render_kwargs_train['ndc'] = False 260 | render_kwargs_train['lindisp'] = args.lindisp 261 | 262 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 263 | render_kwargs_test['perturb'] = False 264 | render_kwargs_test['raw_noise_std'] = 0. 265 | 266 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 267 | 268 | 269 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 270 | """Transforms model's predictions to semantically meaningful values. 271 | Args: 272 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 273 | z_vals: [num_rays, num_samples along ray]. Integration time. 274 | rays_d: [num_rays, 3]. Direction of each ray. 275 | Returns: 276 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 277 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 278 | acc_map: [num_rays]. Sum of weights along each ray. 
279 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 280 | depth_map: [num_rays]. Estimated distance to object. 281 | """ 282 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 283 | 284 | dists = z_vals[...,1:] - z_vals[...,:-1] 285 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 286 | 287 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 288 | 289 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 290 | noise = 0. 291 | if raw_noise_std > 0.: 292 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 293 | 294 | # Overwrite randomly sampled data if pytest 295 | if pytest: 296 | np.random.seed(0) 297 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 298 | noise = torch.Tensor(noise) 299 | 300 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 301 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 302 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 303 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 304 | 305 | depth_map = torch.sum(weights * z_vals, -1) 306 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 307 | acc_map = torch.sum(weights, -1) 308 | 309 | if white_bkgd: 310 | rgb_map = rgb_map + (1.-acc_map[...,None]) 311 | 312 | return rgb_map, disp_map, acc_map, weights, depth_map 313 | 314 | 315 | def render_rays(ray_batch, 316 | network_fn, 317 | network_query_fn, 318 | N_samples, 319 | retraw=False, 320 | lindisp=False, 321 | perturb=0., 322 | N_importance=0, 323 | network_fine=None, 324 | white_bkgd=False, 325 | raw_noise_std=0., 326 | verbose=False, 327 | pytest=False): 328 | """Volumetric rendering. 329 | Args: 330 | ray_batch: array of shape [batch_size, ...]. All information necessary 331 | for sampling along a ray, including: ray origin, ray direction, min 332 | dist, max dist, and unit-magnitude viewing direction. 333 | network_fn: function. Model for predicting RGB and density at each point 334 | in space. 335 | network_query_fn: function used for passing queries to network_fn. 336 | N_samples: int. Number of different times to sample along each ray. 337 | retraw: bool. If True, include model's raw, unprocessed predictions. 338 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 339 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 340 | random points in time. 341 | N_importance: int. Number of additional times to sample along each ray. 342 | These samples are only passed to network_fine. 343 | network_fine: "fine" network with same spec as network_fn. 344 | white_bkgd: bool. If True, assume a white background. 345 | raw_noise_std: ... 346 | verbose: bool. If True, print more debugging info. 347 | Returns: 348 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 349 | disp_map: [num_rays]. Disparity map. 1 / depth. 350 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. 351 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 352 | rgb0: See rgb_map. Output for coarse model. 353 | disp0: See disp_map. Output for coarse model. 354 | acc0: See acc_map. Output for coarse model. 355 | z_std: [num_rays]. Standard deviation of distances along ray for each 356 | sample. 
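        Note: when N_importance > 0, a second pass draws N_importance additional depths from
        the coarse weights with sample_pdf() and re-renders them through network_fine; the
        first-pass results are kept as rgb0 / disp0 / acc0.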
357 | """ 358 | N_rays = ray_batch.shape[0] 359 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 360 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 361 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 362 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 363 | 364 | t_vals = torch.linspace(0., 1., steps=N_samples) 365 | if not lindisp: 366 | z_vals = near * (1.-t_vals) + far * (t_vals) 367 | else: 368 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 369 | 370 | z_vals = z_vals.expand([N_rays, N_samples]) 371 | 372 | if perturb > 0.: 373 | # get intervals between samples 374 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 375 | upper = torch.cat([mids, z_vals[...,-1:]], -1) 376 | lower = torch.cat([z_vals[...,:1], mids], -1) 377 | # stratified samples in those intervals 378 | t_rand = torch.rand(z_vals.shape) 379 | 380 | # Pytest, overwrite u with numpy's fixed random numbers 381 | if pytest: 382 | np.random.seed(0) 383 | t_rand = np.random.rand(*list(z_vals.shape)) 384 | t_rand = torch.Tensor(t_rand) 385 | 386 | z_vals = lower + (upper - lower) * t_rand 387 | 388 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 389 | 390 | 391 | # raw = run_network(pts) 392 | raw = network_query_fn(pts, viewdirs, network_fn) 393 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 394 | 395 | if N_importance > 0: 396 | 397 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 398 | 399 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 400 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 401 | z_samples = z_samples.detach() 402 | 403 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 404 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 405 | 406 | run_fn = network_fn if network_fine is None else network_fine 407 | # raw = run_network(pts, fn=run_fn) 408 | raw = network_query_fn(pts, viewdirs, run_fn) 409 | 410 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 411 | 412 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} 413 | if retraw: 414 | ret['raw'] = raw 415 | if N_importance > 0: 416 | ret['rgb0'] = rgb_map_0 417 | ret['disp0'] = disp_map_0 418 | ret['acc0'] = acc_map_0 419 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 420 | 421 | for k in ret: 422 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 423 | print(f"! 
[Numerical Error] {k} contains nan or inf.") 424 | 425 | return ret 426 | 427 | 428 | def config_parser(): 429 | 430 | import configargparse 431 | parser = configargparse.ArgumentParser() 432 | parser.add_argument('--config', is_config_file=True, 433 | help='config file path') 434 | parser.add_argument("--expname", type=str, 435 | help='experiment name') 436 | parser.add_argument("--basedir", type=str, default='./logs/', 437 | help='where to store ckpts and logs') 438 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 439 | help='input data directory') 440 | 441 | # training options 442 | parser.add_argument("--netdepth", type=int, default=8, 443 | help='layers in network') 444 | parser.add_argument("--netwidth", type=int, default=256, 445 | help='channels per layer') 446 | parser.add_argument("--netdepth_fine", type=int, default=8, 447 | help='layers in fine network') 448 | parser.add_argument("--netwidth_fine", type=int, default=256, 449 | help='channels per layer in fine network') 450 | parser.add_argument("--N_rand", type=int, default=32*32*4, 451 | help='batch size (number of random rays per gradient step)') 452 | parser.add_argument("--lrate", type=float, default=5e-4, 453 | help='learning rate') 454 | parser.add_argument("--lrate_decay", type=int, default=250, 455 | help='exponential learning rate decay (in 1000 steps)') 456 | parser.add_argument("--chunk", type=int, default=1024*32, 457 | help='number of rays processed in parallel, decrease if running out of memory') 458 | parser.add_argument("--netchunk", type=int, default=1024*64, 459 | help='number of pts sent through network in parallel, decrease if running out of memory') 460 | parser.add_argument("--no_batching", action='store_true', 461 | help='only take random rays from 1 image at a time') 462 | parser.add_argument("--no_reload", action='store_true', 463 | help='do not reload weights from saved ckpt') 464 | parser.add_argument("--ft_path", type=str, default=None, 465 | help='specific weights npy file to reload for coarse network') 466 | 467 | # rendering options 468 | parser.add_argument("--N_samples", type=int, default=64, 469 | help='number of coarse samples per ray') 470 | parser.add_argument("--N_importance", type=int, default=0, 471 | help='number of additional fine samples per ray') 472 | parser.add_argument("--perturb", type=float, default=1., 473 | help='set to 0. for no jitter, 1. 
for jitter') 474 | parser.add_argument("--use_viewdirs", action='store_true', 475 | help='use full 5D input instead of 3D') 476 | parser.add_argument("--i_embed", type=int, default=0, 477 | help='set 0 for default positional encoding, -1 for none') 478 | parser.add_argument("--multires", type=int, default=10, 479 | help='log2 of max freq for positional encoding (3D location)') 480 | parser.add_argument("--multires_views", type=int, default=4, 481 | help='log2 of max freq for positional encoding (2D direction)') 482 | parser.add_argument("--raw_noise_std", type=float, default=0., 483 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 484 | 485 | parser.add_argument("--render_only", action='store_true', 486 | help='do not optimize, reload weights and render out render_poses path') 487 | parser.add_argument("--render_factor", type=int, default=0, 488 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 489 | 490 | # training options 491 | parser.add_argument("--precrop_iters", type=int, default=0, 492 | help='number of steps to train on central crops') 493 | parser.add_argument("--precrop_frac", type=float, 494 | default=.5, help='fraction of img taken for central crops') 495 | 496 | # dataset options 497 | parser.add_argument("--dataset_type", type=str, default='blender', 498 | help='options: blender / ellff') 499 | parser.add_argument("--testskip", type=int, default=8, 500 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 501 | 502 | # blender flags 503 | parser.add_argument("--white_bkgd", action='store_true', 504 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 505 | 506 | # llff flags 507 | parser.add_argument("--factor", type=int, default=8, 508 | help='downsample factor for LLFF images') 509 | parser.add_argument("--no_ndc", action='store_true', 510 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 511 | parser.add_argument("--lindisp", action='store_true', 512 | help='sampling linearly in disparity rather than depth') 513 | parser.add_argument("--spherify", action='store_true', 514 | help='set for spherical 360 scenes') 515 | parser.add_argument("--llffhold", type=int, default=8, 516 | help='will take every 1/N images as LLFF test set, paper uses 8') 517 | 518 | # event options 519 | parser.add_argument("--use_event", type=bool, default=True, 520 | help='use event to help the training') 521 | 522 | # logging/saving options 523 | parser.add_argument("--i_print", type=int, default=10, 524 | help='frequency of console printout and metric loggin') 525 | parser.add_argument("--i_img", type=int, default=500, 526 | help='frequency of tensorboard image logging') 527 | parser.add_argument("--i_weights", type=int, default=10000, 528 | help='frequency of weight ckpt saving') 529 | parser.add_argument("--i_testset", type=int, default=200000, 530 | help='frequency of testset saving') 531 | parser.add_argument("--i_video", type=int, default=200000, 532 | help='frequency of render_poses video saving') 533 | 534 | return parser 535 | 536 | def train(): 537 | print("----------------starting----------------") 538 | pynvml.nvmlInit() 539 | parser = config_parser() 540 | args = parser.parse_args() 541 | 542 | print("---------------data_loading----------------") 543 | # Load data 544 | K = None 545 | if args.dataset_type == 'blender': 546 | resolution_h, resolution_w = 800, 800 547 | view_num = 100 548 | 549 | images, poses, test_images, 
test_poses, render_poses, hwf = load_blender_data(args.datadir) 550 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 551 | 552 | near = 2. 553 | far = 6. 554 | 555 | if args.white_bkgd: 556 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 557 | test_images = test_images[...,:3]*test_images[...,-1:] + (1.-test_images[...,-1:]) 558 | else: 559 | images = images[...,:3] 560 | test_imges = test_images[...,:3] 561 | 562 | 563 | elif args.dataset_type == 'ellff': 564 | resolution_h, resolution_w = 260, 346 565 | view_num = 30 566 | 567 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 568 | recenter=True, bd_factor=.75, 569 | spherify=args.spherify) 570 | hwf = poses[0, :3, -1] 571 | poses = poses[:, :3, :4] 572 | 573 | new_poses = [] 574 | test_poses = [] 575 | for i in range(30): 576 | pose = [] 577 | test_poses.append(poses[i * 5]) 578 | for j in range(5): 579 | pose.append(poses[i * 5 + j]) 580 | new_poses.append(np.stack(pose, axis=0)) 581 | poses = np.stack(new_poses, axis=0) 582 | test_poses = np.stack(test_poses, axis=0) 583 | test_images = None 584 | 585 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 586 | 587 | print('DEFINING BOUNDS') 588 | if args.no_ndc: 589 | near = np.ndarray.min(bds) * .9 590 | far = np.ndarray.max(bds) * 1. 591 | 592 | else: 593 | near = 0. 594 | far = 1. 595 | print('NEAR FAR', near, far) 596 | 597 | else: 598 | print('Unknown dataset type', args.dataset_type, 'exiting') 599 | return 600 | 601 | 602 | # Cast intrinsics to right types 603 | H, W, focal = hwf 604 | H, W = int(H), int(W) 605 | hwf = [H, W, focal] 606 | 607 | if K is None: 608 | K = np.array([ 609 | [focal, 0, 0.5*W], 610 | [0, focal, 0.5*H], 611 | [0, 0, 1] 612 | ]) 613 | 614 | # load event 615 | use_event = args.use_event 616 | if use_event: 617 | print("---------------event_data_loading----------------") 618 | event_map = torch.load(os.path.join(args.datadir, "events.pt")) 619 | print("event size: " + str(event_map.size())) 620 | event_map = event_map.to(device) 621 | combination = [] 622 | for m in range(4): 623 | for n in range(m + 1, 5): 624 | combination.append([m, n]) 625 | 626 | print("---------------creat_nerf----------------") 627 | # Create log dir and copy the config file 628 | basedir = args.basedir 629 | expname = args.expname 630 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 631 | f = os.path.join(basedir, expname, 'args.txt') 632 | with open(f, 'w') as file: 633 | for arg in sorted(vars(args)): 634 | attr = getattr(args, arg) 635 | file.write('{} = {}\n'.format(arg, attr)) 636 | if args.config is not None: 637 | f = os.path.join(basedir, expname, 'config.txt') 638 | with open(f, 'w') as file: 639 | file.write(open(args.config, 'r').read()) 640 | 641 | # Create nerf model 642 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 643 | global_step = start 644 | 645 | bds_dict = { 646 | 'near' : near, 647 | 'far' : far, 648 | } 649 | render_kwargs_train.update(bds_dict) 650 | render_kwargs_test.update(bds_dict) 651 | 652 | # Move testing data to GPU 653 | test_poses = torch.Tensor(test_poses).to(device) 654 | render_poses = torch.Tensor(render_poses).to(device) 655 | 656 | # Short circuit if only rendering out from trained model 657 | ''' 658 | if args.render_only: 659 | print('RENDER ONLY') 660 | with torch.no_grad(): 661 | if args.render_test: 662 | # render_test switches to test poses 663 | images = test_images 664 | 
else: 665 | # Default is smoother render_poses path 666 | images = None 667 | 668 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 669 | os.makedirs(testsavedir, exist_ok=True) 670 | print('test poses shape', render_poses.shape) 671 | 672 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) 673 | print('Done rendering', testsavedir) 674 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 675 | 676 | return 677 | ''' 678 | 679 | # Prepare raybatch tensor if batching random rays 680 | N_rand = args.N_rand 681 | use_batching = not args.no_batching 682 | 683 | ''' 684 | if use_batching: 685 | print("To be ccomplete") 686 | # For random ray batching--------------------------------------------------------------------------------------- 687 | print('get rays') 688 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 689 | print('done, concats') 690 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 691 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 692 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 693 | rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] 694 | rays_rgb = rays_rgb.astype(np.float32) 695 | print('shuffle rays') 696 | np.random.shuffle(rays_rgb) 697 | 698 | print('done') 699 | i_batch = 0 700 | 701 | # Move training data to GPU 702 | if use_batching: 703 | images = torch.Tensor(images).to(device) 704 | if use_batching: 705 | rays_rgb = torch.Tensor(rays_rgb).to(device) 706 | ''' 707 | 708 | 709 | print("---------------start_training----------------") 710 | N_iters = 200000 + 1 711 | print('Begin') 712 | 713 | 714 | # Summary writers 715 | writer = SummaryWriter(os.path.join(basedir, expname, 'logs')) 716 | 717 | start = start + 1 718 | for i in trange(start, N_iters): 719 | time0 = time.time() 720 | 721 | # Sample random ray batch--------------------------------------------------------------------------------------- 722 | 723 | if use_batching: 724 | print("To be ccomplete") 725 | ''' 726 | # Random over all images 727 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] 
728 | batch = torch.transpose(batch, 0, 1) 729 | batch_rays, target_s = batch[:2], batch[2] 730 | 731 | i_batch += N_rand 732 | if i_batch >= rays_rgb.shape[0]: 733 | print("Shuffle data after an epoch!") 734 | rand_idx = torch.randperm(rays_rgb.shape[0]) 735 | rays_rgb = rays_rgb[rand_idx] 736 | i_batch = 0 737 | ''' 738 | else: 739 | # Random from one image 740 | img_i = np.random.choice(view_num) 741 | target = images[img_i] 742 | target = torch.Tensor(target).to(device) 743 | pose = poses[img_i, :, :3,:4] 744 | 745 | if N_rand is not None: 746 | rays_os = [] 747 | rays_ds = [] 748 | for j in range(5): 749 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose[j])) 750 | rays_os.append(rays_o) 751 | rays_ds.append(rays_d) 752 | if i < args.precrop_iters: 753 | dH = int(H // 2 * args.precrop_frac) 754 | dW = int(W // 2 * args.precrop_frac) 755 | coords = torch.stack( 756 | torch.meshgrid( 757 | torch.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH), 758 | torch.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW) 759 | ), -1) 760 | if i == start: 761 | print(f"[Config] Center cropping of size {2 * dH} x {2 * dW} is enabled until iter {args.precrop_iters}") 762 | else: 763 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W)),-1) # (H, W, 2) 764 | 765 | coords = torch.reshape(coords, [-1, 2]) # (H * W, 2) 766 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 767 | select_coords = coords[select_inds].long() # (N_rand, 2) 768 | 769 | all_extra = [] 770 | all_rgb = [] 771 | for j in range(5): 772 | rays_o = rays_os[j][select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 773 | rays_d = rays_ds[j][select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 774 | batch_rays = torch.stack([rays_o, rays_d], 0) 775 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 776 | verbose=i < 10, retraw=True, 777 | **render_kwargs_train) 778 | all_rgb.append(rgb) 779 | all_extra.append(extras['rgb0']) 780 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 781 | 782 | ##### Core optimization loop ##### 783 | #change No.1 average-nerf 784 | all_rgb = torch.stack(all_rgb, dim=0) 785 | rgb = torch.mean(all_rgb, dim = 0) 786 | 787 | #change No.2 event-nerf 788 | if use_event: 789 | event_data = event_map[img_i] 790 | event_loss = event_loss_call(all_rgb, event_data, select_coords, combination, "rgb", resolution_h, resolution_w) * 0.001 791 | 792 | img_loss = img2mse(rgb, target_s) 793 | trans = extras['raw'][...,-1] 794 | loss = img_loss 795 | psnr = mse2psnr(img_loss) 796 | 797 | if 'rgb0' in extras: 798 | all_extra = torch.stack(all_extra, dim=0) 799 | extras = torch.mean(all_extra, dim=0) 800 | img_loss0 = img2mse(extras, target_s) 801 | loss = loss + img_loss0 802 | psnr0 = mse2psnr(img_loss0) 803 | loss = loss + event_loss 804 | 805 | optimizer.zero_grad() 806 | loss.backward() 807 | optimizer.step() 808 | 809 | # NOTE: IMPORTANT! 
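        # The schedule below is NeRF's standard exponential learning-rate decay:
        #   new_lrate = lrate * 0.1 ** (global_step / (lrate_decay * 1000))
        # With the defaults (lrate=5e-4, lrate_decay=250) this gives roughly 7.9e-5 at
        # step 200k (0.1 ** 0.8 ~= 0.158) and exactly 5e-5 at step 250k.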
810 | ### update learning rate ### 811 | decay_rate = 0.1 812 | decay_steps = args.lrate_decay * 1000 813 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 814 | for param_group in optimizer.param_groups: 815 | param_group['lr'] = new_lrate 816 | ################################ 817 | 818 | dt = time.time()-time0 819 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") 820 | ##### end ##### 821 | 822 | # Rest is logging 823 | if i%args.i_weights==0: 824 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 825 | torch.save({ 826 | 'global_step': global_step, 827 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 828 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 829 | 'optimizer_state_dict': optimizer.state_dict(), 830 | }, path) 831 | print('Saved checkpoints at', path) 832 | 833 | if i%args.i_video==0 and i > 0: 834 | # Turn on testing mode 835 | with torch.no_grad(): 836 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) 837 | print('Done, saving', rgbs.shape, disps.shape) 838 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 839 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) 840 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) 841 | 842 | # if args.use_viewdirs: 843 | # render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4] 844 | # with torch.no_grad(): 845 | # rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test) 846 | # render_kwargs_test['c2w_staticcam'] = None 847 | # imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8) 848 | 849 | if i%args.i_testset==0 and i > 0: 850 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 851 | os.makedirs(testsavedir, exist_ok=True) 852 | with torch.no_grad(): 853 | render_path(torch.Tensor(test_poses).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=test_images, savedir=testsavedir) 854 | print('Saved test set') 855 | 856 | if i%args.i_print==0: 857 | writer.add_scalar("loss", loss, i) 858 | writer.add_scalar("event_loss", event_loss, i) 859 | writer.add_scalar("img_loss", img_loss, i) 860 | writer.add_scalar("img_loss0", img_loss0, i) 861 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 862 | tqdm.write(f"[TRAIN] Iter: {i} Event_Loss: {event_loss.item()} Img_Loss: {img_loss.item()} Img_Loss0: {img_loss0.item()}") 863 | 864 | global_step += 1 865 | 866 | if __name__=='__main__': 867 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 868 | train() 869 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | # Misc 9 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 10 | mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 11 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 12 | 13 | # Positional encoding (section 5.1) 14 | class Embedder: 15 | def __init__(self, **kwargs): 16 | self.kwargs = kwargs 17 | self.create_embedding_fn() 18 | 19 | def create_embedding_fn(self): 20 | embed_fns = [] 21 | d = self.kwargs['input_dims'] 22 | out_dim = 0 23 | if self.kwargs['include_input']: 24 | embed_fns.append(lambda x : x) 25 | out_dim += d 26 | 27 | max_freq = self.kwargs['max_freq_log2'] 28 | N_freqs = self.kwargs['num_freqs'] 29 | 30 | if self.kwargs['log_sampling']: 31 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 32 | else: 33 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 34 | 35 | for freq in freq_bands: 36 | for p_fn in self.kwargs['periodic_fns']: 37 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 38 | out_dim += d 39 | 40 | self.embed_fns = embed_fns 41 | self.out_dim = out_dim 42 | 43 | def embed(self, inputs): 44 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 45 | 46 | 47 | def get_embedder(multires, i=0): 48 | if i == -1: 49 | return nn.Identity(), 3 50 | 51 | embed_kwargs = { 52 | 'include_input' : True, 53 | 'input_dims' : 3, 54 | 'max_freq_log2' : multires-1, 55 | 'num_freqs' : multires, 56 | 'log_sampling' : True, 57 | 'periodic_fns' : [torch.sin, torch.cos], 58 | } 59 | 60 | embedder_obj = Embedder(**embed_kwargs) 61 | embed = lambda x, eo=embedder_obj : eo.embed(x) 62 | return embed, embedder_obj.out_dim 63 | 64 | 65 | # Model 66 | class NeRF(nn.Module): 67 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 68 | """ 69 | """ 70 | super(NeRF, self).__init__() 71 | self.D = D 72 | self.W = W 73 | self.input_ch = input_ch 74 | self.input_ch_views = input_ch_views 75 | self.skips = skips 76 | self.use_viewdirs = use_viewdirs 77 | 78 | self.pts_linears = nn.ModuleList( 79 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 80 | 81 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 82 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 83 | 84 | ### Implementation according to the paper 85 | # self.views_linears = nn.ModuleList( 86 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 87 | 88 | if use_viewdirs: 89 | self.feature_linear = nn.Linear(W, W) 90 | self.alpha_linear = nn.Linear(W, 1) 91 | self.rgb_linear = nn.Linear(W//2, 3) 92 | else: 93 | self.output_linear = nn.Linear(W, output_ch) 94 | 95 | def forward(self, x): 96 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 97 | h = input_pts 98 | for i, l in enumerate(self.pts_linears): 99 | h = self.pts_linears[i](h) 100 | h = F.relu(h) 101 | if i in self.skips: 102 | h = torch.cat([input_pts, h], -1) 103 | 104 | if self.use_viewdirs: 105 | alpha = self.alpha_linear(h) 106 | feature = self.feature_linear(h) 107 | h = torch.cat([feature, input_views], -1) 108 | 109 | for i, l in enumerate(self.views_linears): 110 | h = self.views_linears[i](h) 111 | h = F.relu(h) 112 | 113 | rgb = self.rgb_linear(h) 114 | outputs = torch.cat([rgb, alpha], -1) 115 | else: 116 | outputs = self.output_linear(h) 117 | 118 | return outputs 119 | 120 | def load_weights_from_keras(self, weights): 121 | assert self.use_viewdirs, "Not 
implemented if use_viewdirs=False" 122 | 123 | # Load pts_linears 124 | for i in range(self.D): 125 | idx_pts_linears = 2 * i 126 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 127 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 128 | 129 | # Load feature_linear 130 | idx_feature_linear = 2 * self.D 131 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 132 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 133 | 134 | # Load views_linears 135 | idx_views_linears = 2 * self.D + 2 136 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 137 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 138 | 139 | # Load rgb_linear 140 | idx_rbg_linear = 2 * self.D + 4 141 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 142 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 143 | 144 | # Load alpha_linear 145 | idx_alpha_linear = 2 * self.D + 6 146 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 147 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 148 | 149 | 150 | 151 | # Ray helpers 152 | def get_rays(H, W, K, c2w): 153 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 154 | i = i.t() 155 | j = j.t() 156 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 157 | # Rotate ray directions from camera frame to the world frame 158 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 159 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 160 | rays_o = c2w[:3,-1].expand(rays_d.shape) 161 | return rays_o, rays_d 162 | 163 | 164 | def get_rays_np(H, W, K, c2w): 165 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 166 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) 167 | # Rotate ray directions from camera frame to the world frame 168 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 169 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 170 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 171 | return rays_o, rays_d 172 | 173 | 174 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 175 | # Shift ray origins to near plane 176 | t = -(near + rays_o[...,2]) / rays_d[...,2] 177 | rays_o = rays_o + t[...,None] * rays_d 178 | 179 | # Projection 180 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 181 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 182 | o2 = 1. + 2. * near / rays_o[...,2] 183 | 184 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 185 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 186 | d2 = -2. 
* near / rays_o[...,2] 187 | 188 | rays_o = torch.stack([o0,o1,o2], -1) 189 | rays_d = torch.stack([d0,d1,d2], -1) 190 | 191 | return rays_o, rays_d 192 | 193 | 194 | # Hierarchical sampling (section 5.2) 195 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 196 | # Get pdf 197 | weights = weights + 1e-5 # prevent nans 198 | pdf = weights / torch.sum(weights, -1, keepdim=True) 199 | cdf = torch.cumsum(pdf, -1) 200 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 201 | 202 | # Take uniform samples 203 | if det: 204 | u = torch.linspace(0., 1., steps=N_samples) 205 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 206 | else: 207 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 208 | 209 | # Pytest, overwrite u with numpy's fixed random numbers 210 | if pytest: 211 | np.random.seed(0) 212 | new_shape = list(cdf.shape[:-1]) + [N_samples] 213 | if det: 214 | u = np.linspace(0., 1., N_samples) 215 | u = np.broadcast_to(u, new_shape) 216 | else: 217 | u = np.random.rand(*new_shape) 218 | u = torch.Tensor(u) 219 | 220 | # Invert CDF 221 | u = u.contiguous() 222 | inds = torch.searchsorted(cdf, u, right=True) 223 | below = torch.max(torch.zeros_like(inds-1), inds-1) 224 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 225 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 226 | 227 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 228 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 229 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 230 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 231 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 232 | 233 | denom = (cdf_g[...,1]-cdf_g[...,0]) 234 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 235 | t = (u-cdf_g[...,0])/denom 236 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 237 | 238 | return samples 239 | --------------------------------------------------------------------------------
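The training loop in run_nerf_exp.py adds an event-consistency term through `event_loss_call` from event_loss_helpers.py, which is not reproduced above. As a rough orientation only, the sketch below shows the kind of loss such a call could compute under the standard linear event-generation model with an assumed contrast threshold `C`. The function name `event_loss_sketch`, the default `C=0.3`, and the simple channel-mean gray conversion are illustrative assumptions; the argument shapes follow the call site in `train()` (all_rgb: (b+1, N_rand, 3), event_data: (b, H*W), select_coords: (N_rand, 2), combination: all pairs [m, n] with m < n).

```
import torch


def event_loss_sketch(all_rgb, event_data, select_coords, combination,
                      resolution_h, resolution_w, C=0.3):
    """Hypothetical event-consistency loss; NOT the released implementation."""
    # Flatten sampled (row, col) pixel coordinates to indices into an (H*W,) layout.
    pix = (select_coords[:, 0] * resolution_w + select_coords[:, 1]).long()

    # Signed event counts per bin at the sampled pixels: (b, N_rand).
    events = event_data.view(-1, resolution_h * resolution_w)[:, pix].float()

    # Log-brightness of each rendered sharp view, using a simple channel mean: (b+1, N_rand).
    log_lum = torch.log(all_rgb.mean(dim=-1).clamp(min=1e-5))

    loss = 0.0
    for m, n in combination:
        # Events fired between poses m and n predict the log-brightness change C * sum(e_k).
        pred_change = log_lum[n] - log_lum[m]
        evt_change = C * events[m:n].sum(dim=0)
        loss = loss + torch.mean((pred_change - evt_change) ** 2)
    return loss / len(combination)
```

At the call site the returned value is multiplied by 0.001 before being added to the image loss, so any weighting of this term is left to the caller.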