├── requirements.txt ├── configs ├── Ball.txt ├── Pen.txt ├── Glass.txt └── WineGlass.txt ├── LICENSE ├── load_blender.py ├── load_LINEMOD.py ├── load_deepvoxels.py ├── README.md ├── load_llff.py ├── run_nerf_helpers.py ├── find_bounding_box.py ├── render_model.py ├── run_nerf_inside.py ├── run_ior.py └── run_nerf.py /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.11.0+cu113 3 | torchvision==0.12.0+cu113 4 | imageio 5 | imageio-ffmpeg 6 | matplotlib 7 | configargparse 8 | tqdm 9 | opencv-python 10 | torchdiffeq -------------------------------------------------------------------------------- /configs/Ball.txt: -------------------------------------------------------------------------------- 1 | expname = Ball 2 | datadir = data\Ball 3 | basedir = logs 4 | dataset_type = llff 5 | 6 | factor = 1 7 | 8 | N_samples = 64 9 | N_importance = 64 10 | N_rand = 1024 11 | 12 | spherify = True 13 | use_viewdirs = True 14 | no_ndc = True 15 | lindisp = False 16 | 17 | llffhold = 10 18 | raw_noise_std = 1.0 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /configs/Pen.txt: -------------------------------------------------------------------------------- 1 | expname = Pen 2 | datadir = data\Pen 3 | basedir = logs 4 | dataset_type = llff 5 | 6 | factor = 1 7 | 8 | N_samples = 64 9 | N_importance = 64 10 | N_rand = 1024 11 | 12 | spherify = True 13 | use_viewdirs = True 14 | no_ndc = True 15 | lindisp = False 16 | 17 | llffhold = 10 18 | raw_noise_std = 1.0 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /configs/Glass.txt: -------------------------------------------------------------------------------- 1 | expname = Glass 2 | datadir = data\Glass 3 | basedir = logs 4 | dataset_type = llff 5 | 6 | factor = 1 7 | 8 | N_samples = 64 9 | N_importance = 64 10 | N_rand = 1024 11 | 12 | spherify = True 13 | use_viewdirs = True 14 | no_ndc = True 15 | lindisp = False 16 | 17 | llffhold = 10 18 | raw_noise_std = 1.0 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /configs/WineGlass.txt: -------------------------------------------------------------------------------- 1 | expname = WineGlass 2 | datadir = data\WineGlass 3 | basedir = logs 4 | dataset_type = llff 5 | 6 | factor = 1 7 | 8 | N_samples = 64 9 | N_importance = 64 10 | N_rand = 1024 11 | 12 | spherify = True 13 | use_viewdirs = True 14 | no_ndc = True 15 | lindisp = False 16 | 17 | llffhold = 10 18 | raw_noise_std = 1.0 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Mojtaba Bemana 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions 
of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_blender_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for frame in meta['frames'][::skip]: 57 | fname = os.path.join(basedir, frame['file_path'] + '.png') 58 | imgs.append(imageio.imread(fname)) 59 | poses.append(np.array(frame['transform_matrix'])) 60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 61 | poses = np.array(poses).astype(np.float32) 62 | counts.append(counts[-1] + imgs.shape[0]) 63 | all_imgs.append(imgs) 64 | all_poses.append(poses) 65 | 66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 67 | 68 | imgs = np.concatenate(all_imgs, 0) 69 | poses = np.concatenate(all_poses, 0) 70 | 71 | H, W = imgs[0].shape[:2] 72 | camera_angle_x = float(meta['camera_angle_x']) 73 | focal = .5 * W / np.tan(.5 * camera_angle_x) 74 | 75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 76 | 77 | if half_res: 78 | H = H//2 79 | W = W//2 80 | focal = focal/2. 
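        # NOTE: halving H, W, and the focal length together keeps the camera field of view unchanged at the reduced resolution.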
81 | 82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 83 | for i, img in enumerate(imgs): 84 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 85 | imgs = imgs_half_res 86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 87 | 88 | 89 | return imgs, poses, render_poses, [H, W, focal], i_split 90 | 91 | 92 | -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 
84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. 
for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Eikonal Fields for Refractive Novel-View Synthesis [[Project Page]](https://eikonalfield.mpi-inf.mpg.de) 2 | 3 | 4 |

5 | 6 |

7 | 8 | ## Installation 9 | 10 | ``` 11 | conda create -n eikonalfield python=3.8 12 | conda activate eikonalfield 13 | pip install -r requirements.txt 14 | ``` 15 | 16 |
17 | Dependencies (Click to expand) 18 | 19 | ### Dependencies 20 | * torch>=1.8 21 | * torchvision>=0.9.1 22 | * matplotlib 23 | * imageio 24 | * imageio-ffmpeg 25 | * configargparse 26 | * tqdm 27 | * opencv-python 28 | * [torchdiffeq](https://github.com/rtqichen/torchdiffeq) 29 | 30 |
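The IoR optimization (Step 3 below) relies on `torchdiffeq` for differentiable ODE integration. A quick, hypothetical sanity check of the installed environment (not part of this repository) could look like:

```
import torch
from torchdiffeq import odeint

# Integrate dy/dt = -y from t=0 to t=1; the result should be close to exp(-1) ~ 0.3679.
y0 = torch.tensor([1.0])
t = torch.linspace(0.0, 1.0, 11)
ys = odeint(lambda t, y: -y, y0, t)
print(ys[-1].item())
```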
31 | 32 | ## Dataset 33 | 34 | 35 | 36 | 37 | * [``Ball``](https://eikonalfield.mpi-inf.mpg.de//datasets/Ball.zip) 38 | * [``Glass``](https://eikonalfield.mpi-inf.mpg.de//datasets/Glass.zip) 39 | * [``Pen``](https://eikonalfield.mpi-inf.mpg.de//datasets/Pen.zip) 40 | * [``WineGlass``](https://eikonalfield.mpi-inf.mpg.de//datasets/WineGlass.zip) 41 | 42 | Each dataset contains the captured images and a short video of the scene. 43 | In the `captured images` folder, we provide the images at the original 4K resolution as well as at a smaller resolution, together with camera poses estimated using COLMAP and the [LLFF code](https://github.com/fyusion/llff). In the `captured video` folder, the video frames with their estimated camera poses are provided. 44 | 45 | 46 | ## Training 47 | 48 | * __Step 0__: Finding the camera poses ``poses_bounds.npy`` following the instructions given [here](https://github.com/bmild/nerf#generating-poses-for-your-own-scenes) 49 | 50 | (For our dataset the camera parameters are already provided!) 51 | * __Step 1__: Estimating the radiance field for the entire scene by running ``run_nerf.py`` (The code is borrowed from [``nerf-pytorch``](https://github.com/yenchenlin/nerf-pytorch)) 52 | ``` 53 | python run_nerf.py --config configs/Glass.txt 54 | ``` 55 | The config files for the scenes in our dataset are located in the `configs` folder. 56 | 57 | * __Step 2__: Finding the 3D bounding box containing the transparent object using ``find_bounding_box.py``. 58 | ``` 59 | python find_bounding_box.py --config configs/Glass.txt 60 | ``` 61 | 62 |
63 | (Click to expand) In this step, 1/10 of the training images are displayed so that you can mark a few points at the extent of the transparent object. 64 | 65 | 66 | 67 | 68 |
69 | 70 | * __Step 3__: Learning the index of refraction (IoR) field with ``run_ior.py`` 71 | ``` 72 | python run_ior.py --config configs/Glass.txt --N_rand 32000 --N_samples 128 73 | ``` 74 | * __Step 4__: Learning the radiance field for the object inside the transparent object using ``run_nerf_inside.py`` 75 | ``` 76 | python run_nerf_inside.py --config configs/Glass.txt --N_samples 512 77 | ``` 78 | (Please note that for the Ball scene we skipped this step) 79 | 80 | ## Rendering 81 | 82 | Please run ``render_model.py`` with different modes to render the models learned at each training step. 83 | 84 | ``` 85 | python render_model.py --config configs/Glass.txt --N_samples 512 --mode 1 --render_video 86 | ``` 87 | 88 | The rendering options are: 89 | 90 | ``` 91 | --mode # use 0 to render the output of step 1 (Original NeRF) 92 | # use 1 to render the output of step 3 (Learned IoR) 93 | # use 2 to render the output of step 4 (Complete model with the inside NeRF) 94 | --render_test # rendering the test set images 95 | --render_video # rendering a video from a precomputed path 96 | --render_from_path # rendering a video from a specified path 97 | ``` 98 | ## Models 99 | 100 | Please find below our results and pre-trained models for each scene: 101 | * [``Ball``](https://eikonalfield.mpi-inf.mpg.de//results/Ball.zip) 102 | * [``Glass``](https://eikonalfield.mpi-inf.mpg.de//results/Glass.zip) 103 | * [``Pen``](https://eikonalfield.mpi-inf.mpg.de//results/Pen.zip) 104 | * [``WineGlass``](https://eikonalfield.mpi-inf.mpg.de//results/WineGlass.zip) 105 | 106 | Each scene contains the following folders: 107 | 108 | * ``model_weights`` --> the pre-trained model 109 | * ``bounding_box`` --> the parameters of the bounding box 110 | * ``masked_regions`` --> the masked images identifying the regions crossing the bounding box in each view 111 | * ``rendered_from_a_path`` --> the rendered video result along the camera trajectory of the real video capture 112 | 113 | 114 | ## Details 115 | ### Capturing 116 | Our method works with a general capture setup and does not require any calibration pattern or a specific rig. We capture the scene spherically and get close enough to the transparent object to sample it properly. 117 | 118 | 119 | ### Bounding Box 120 | Our bounding box (BB) is a rectangular cuboid parameterized by its center $c = (c_x,c_y,c_z)$ and the distances from the center to a face in each dimension $d=(d_x,d_y,d_z)$. 121 | For a 3D point $(x,y,z)$ in space, the bounding box is analytically expressed as follows: 122 | 123 | $$ BB(x,y,z) = 1 - S(\beta(d_x-|x-c_x|)) \cdot S(\beta(d_y-|y-c_y|)) \cdot S(\beta(d_z-|z-c_z|)) $$ 124 | 125 | 126 | where 127 | $S(x)=\frac{1}{1+e^{-x}}$ is the sigmoid function and $\beta$ is the steepness coefficient; we use $\beta=200$ in our experiments. With this function, a point inside the box gets a value of zero and points outside get a value close to one. 128 | 129 | ### Voxel grid 130 | In our IoR optimization we first need to smooth the learned radiance field; however, explicitly smoothing an MLP-based radiance field is not straightforward, 131 | so we instead fit a uniform 3D grid to the learned radiance field. We then band-limit the grid in the Fourier domain using a Gaussian blur kernel to obtain a coarse-to-fine radiance field. 132 | Note that we fit the voxel grid to the NeRF coarse model rather than the fine one to avoid aliasing, and for the spherical captures, we limit the scene far bound to 2.5.
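The band-limiting described above amounts to multiplying the voxelized radiance field by a Gaussian mask in the frequency domain. Below is a minimal sketch of that idea, mirroring `lowpass_3d` and `voxel_lowpass_filtering` in `run_nerf_helpers.py`; the grid resolution and `sigma` value here are illustrative, not the training defaults.

```
import torch

def gaussian_lowpass_mask(res, sigma):
    # Gaussian centred on the (fft-shifted) zero frequency of a res^3 grid.
    lin = torch.linspace(0.0, 1.0, res)
    xx, yy, zz = torch.meshgrid(lin, lin, lin, indexing="ij")
    dist2 = (xx - 0.5) ** 2 + (yy - 0.5) ** 2 + (zz - 0.5) ** 2
    return torch.exp(-dist2 / (2.0 * sigma ** 2))

def bandlimit_voxels(grid, sigma):
    # grid: [res, res, res, 4] voxelized radiance field (RGB + density).
    spec = torch.fft.fftshift(torch.fft.fftn(grid, dim=(0, 1, 2)), dim=(0, 1, 2))
    spec = spec * gaussian_lowpass_mask(grid.shape[0], sigma)[..., None]
    return torch.fft.ifftn(torch.fft.ifftshift(spec, dim=(0, 1, 2)), dim=(0, 1, 2)).real

# Smaller sigma keeps fewer frequencies (stronger blur); a coarse-to-fine schedule gradually relaxes it.
grid = torch.rand(64, 64, 64, 4)   # illustrative resolution
smoothed = bandlimit_voxels(grid, sigma=0.05)
```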
133 | 134 | ### IoR optimization 135 | Since we have a complete volume rendering model in the form of ODEs, we use a differentiable ODE solvers package provided by the [Neural ODE](https://github.com/rtqichen/torchdiffeq) to backpropagate through the ODEs. Moreover, using this package our training proceeds in a memory independent of the step count which allows the processing of more rays (as large as 32k rays) in each iteration. 136 | 137 | ### IoR model 138 | When using differentiable ODE solvers, we found it very important to use a smooth non-linear activation such as Softplus in our IoR MLP model otherwise the optimization becomes unstable. 139 | 140 | 141 | ### Rendering 142 | Since we could not utilize a hierarchical sampling in our volume rendering with ODE formulation, we consider 143 | 512 steps along the ray to properly sample both interior and exterior radiance fields. 144 | 145 | 146 | 147 | ## Citation 148 | 149 | @inproceedings{bemana2022eikonal, 150 | title={Eikonal Fields for Refractive Novel-View Synthesis}, 151 | author={Bemana, Mojtaba and Myszkowski, Karol and Revall Frisvad, Jeppe and Seidel, Hans-Peter and Ritschel, Tobias}, 152 | booktitle={Special Interest Group on Computer Graphics and Interactive Techniques Conference Proceedings}, 153 | pages={1--9}, 154 | year={2022} 155 | 156 | ## Contact 157 | mbemana@mpi-inf.mpg.de 158 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | for r in factors: 11 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | wd = os.getcwd() 29 | 30 | for r in factors + resolutions: 31 | if isinstance(r, int): 32 | name = 'images_{}'.format(r) 33 | resizearg = '{}%'.format(100./r) 34 | else: 35 | name = 'images_{}x{}'.format(r[1], r[0]) 36 | resizearg = '{}x{}'.format(r[1], r[0]) 37 | imgdir = os.path.join(basedir, name) 38 | if os.path.exists(imgdir): 39 | continue 40 | 41 | print('Minifying', r, basedir) 42 | 43 | os.makedirs(imgdir) 44 | check_output('copy {}\* {}'.format(imgdir_orig, imgdir), shell=True) 45 | 46 | ext = imgs[0].split('.')[-1] 47 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 48 | print(args) 49 | os.chdir(imgdir) 50 | check_output(args, shell=True) 51 | os.chdir(wd) 52 | 53 | if ext != 'png': 54 | check_output('del {}\*.{}'.format(imgdir, ext), shell=True) 55 | print('Removed duplicates') 56 | print('Done') 57 | 58 | 59 | 60 | 61 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 62 | 63 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 64 | poses = poses_arr[:, 
:-2].reshape([-1, 3, 5]).transpose([1,2,0]) 65 | bds = poses_arr[:, -2:].transpose([1,0]) 66 | 67 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 68 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 69 | sh = imageio.imread(img0).shape 70 | 71 | sfx = '' 72 | 73 | if factor is not None and factor != 1: 74 | sfx = '_{}'.format(factor) 75 | _minify(basedir, factors=[factor]) 76 | factor = factor 77 | elif height is not None: 78 | factor = sh[0] / float(height) 79 | width = int(sh[1] / factor) 80 | _minify(basedir, resolutions=[[height, width]]) 81 | sfx = '_{}x{}'.format(width, height) 82 | elif width is not None: 83 | factor = sh[1] / float(width) 84 | height = int(sh[0] / factor) 85 | _minify(basedir, resolutions=[[height, width]]) 86 | sfx = '_{}x{}'.format(width, height) 87 | else: 88 | factor = 1 89 | 90 | imgdir = os.path.join(basedir, 'images' + sfx) 91 | if not os.path.exists(imgdir): 92 | print( imgdir, 'does not exist, returning' ) 93 | return 94 | 95 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 96 | # if poses.shape[-1] != len(imgfiles): 97 | # print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 98 | # return 99 | 100 | sh = imageio.imread(imgfiles[0]).shape 101 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 102 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 103 | 104 | if not load_imgs: 105 | return poses, bds 106 | 107 | def imread(f): 108 | if f.endswith('png'): 109 | return imageio.imread(f, ignoregamma=True) 110 | else: 111 | return imageio.imread(f) 112 | 113 | imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] 114 | imgs = np.stack(imgs, -1) 115 | 116 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 117 | return poses, bds, imgs 118 | 119 | 120 | 121 | 122 | 123 | 124 | def normalize(x): 125 | return x / np.linalg.norm(x) 126 | 127 | def viewmatrix(z, up, pos): 128 | vec2 = normalize(z) 129 | vec1_avg = up 130 | vec0 = normalize(np.cross(vec1_avg, vec2)) 131 | vec1 = normalize(np.cross(vec2, vec0)) 132 | m = np.stack([vec0, vec1, vec2, pos], 1) 133 | return m 134 | 135 | def ptstocam(pts, c2w): 136 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 137 | return tt 138 | 139 | def poses_avg(poses): 140 | 141 | hwf = poses[0, :3, -1:] 142 | 143 | center = poses[:, :3, 3].mean(0) 144 | vec2 = normalize(poses[:, :3, 2].sum(0)) 145 | up = poses[:, :3, 1].sum(0) 146 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 147 | 148 | return c2w 149 | 150 | 151 | 152 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 153 | render_poses = [] 154 | rads = np.array(list(rads) + [1.]) 155 | hwf = c2w[:,4:5] 156 | 157 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 158 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 159 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 160 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 161 | return render_poses 162 | 163 | 164 | 165 | def recenter_poses(poses): 166 | 167 | poses_ = poses+0 168 | bottom = np.reshape([0,0,0,1.], [1,4]) 169 | c2w = poses_avg(poses) 170 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 171 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 172 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 173 | 174 | poses = np.linalg.inv(c2w) @ poses 175 | poses_[:,:3,:4] = poses[:,:3,:4] 176 | poses = poses_ 177 | return poses 178 | 179 | 180 | ##################### 181 | 182 | 183 | def spherify_poses(poses, bds): 184 | 185 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 186 | 187 | rays_d = poses[:,:3,2:3] 188 | rays_o = poses[:,:3,3:4] 189 | 190 | def min_line_dist(rays_o, rays_d): 191 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 192 | b_i = -A_i @ rays_o 193 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 194 | return pt_mindist 195 | 196 | pt_mindist = min_line_dist(rays_o, rays_d) 197 | 198 | center = pt_mindist 199 | up = (poses[:,:3,3] - center).mean(0) 200 | 201 | vec0 = normalize(up) 202 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 203 | vec2 = normalize(np.cross(vec0, vec1)) 204 | pos = center 205 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 206 | 207 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 208 | 209 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 210 | 211 | sc = 1./rad 212 | poses_reset[:,:3,3] *= sc 213 | bds *= sc 214 | rad *= sc 215 | 216 | centroid = np.mean(poses_reset[:,:3,3], 0) 217 | zh = centroid[2] 218 | radcircle = np.sqrt(rad**2-zh**2) 219 | new_poses = [] 220 | 221 | for th in np.linspace(0.,2.*np.pi, 120): 222 | 223 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 224 | up = np.array([0,0,-1.]) 225 | 226 | vec2 = normalize(camorigin) 227 | vec0 = normalize(np.cross(vec2, up)) 228 | vec1 = normalize(np.cross(vec2, vec0)) 229 | pos = camorigin 230 | p = np.stack([vec0, vec1, vec2, pos], 1) 231 | 232 | new_poses.append(p) 233 | 234 | new_poses = np.stack(new_poses, 0) 235 | 236 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 237 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 238 | 239 | return poses_reset, new_poses, bds 240 | 241 | 242 | def register_video_path(basedir,factor,poses,scale): 243 | 244 | poses_video, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 245 | poses_video = np.concatenate([poses_video[:, 1:2, :], -poses_video[:, 0:1, :], poses_video[:, 2:, :]], 1) 246 | poses_video = np.moveaxis(poses_video, -1, 0).astype(np.float32) 247 | # poses_video = poses_video[84:] 248 | poses_video[:,:3,3] *= scale 249 | 250 | poses_ = poses_video+0 251 | bottom = np.reshape([0,0,0,1.], [1,4]) 252 | c2w = poses_avg(poses) 253 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 254 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses_video.shape[0],1,1]) 255 | poses_video = np.concatenate([poses_video[:,:3,:4], bottom], -2) 256 | 257 | 
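    # Recenter the video poses: express them in the coordinate frame of the average training camera (the same transform that recenter_poses applies to the training poses).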
poses_video = np.linalg.inv(c2w) @ poses_video 258 | poses_[:,:3,:4] = poses_video[:,:3,:4] 259 | poses_video = poses_ 260 | 261 | poses = recenter_poses(poses) 262 | 263 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 264 | 265 | rays_d = poses[:,:3,2:3] 266 | rays_o = poses[:,:3,3:4] 267 | 268 | def min_line_dist(rays_o, rays_d): 269 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 270 | b_i = -A_i @ rays_o 271 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 272 | return pt_mindist 273 | 274 | pt_mindist = min_line_dist(rays_o, rays_d) 275 | 276 | center = pt_mindist 277 | up = (poses[:,:3,3] - center).mean(0) 278 | 279 | vec0 = normalize(up) 280 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 281 | vec2 = normalize(np.cross(vec0, vec1)) 282 | pos = center 283 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 284 | 285 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses_video[:,:3,:4]) 286 | 287 | poses_reset_ = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 288 | 289 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset_[:,:3,3]), -1))) 290 | 291 | sc = 1./rad 292 | poses_reset[:,:3,3] *= sc 293 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses_video[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 294 | 295 | poses_reset = poses_reset.astype(np.float32) 296 | 297 | 298 | return poses_reset 299 | 300 | 301 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_video=None, path_zflat=False): 302 | 303 | 304 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 305 | print('Loaded', basedir, bds.min(), bds.max()) 306 | 307 | # Correct rotation matrix ordering and move variable dim to axis 0 308 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 309 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 310 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 311 | images = imgs 312 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 313 | 314 | # Rescale if bd_factor is provided 315 | # sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) 316 | sc = 1./(bds.max() - bds.min()) 317 | print(sc) 318 | poses[:,:3,3] *= sc 319 | bds *= sc 320 | print('Loaded', basedir, bds.min(), bds.max()) 321 | 322 | if path_video is not None: 323 | render_path = register_video_path(path_video,factor,poses,sc) 324 | 325 | 326 | if recenter: 327 | poses = recenter_poses(poses) 328 | 329 | if spherify: 330 | poses, render_poses, bds = spherify_poses(poses, bds) 331 | 332 | else: 333 | 334 | c2w = poses_avg(poses) 335 | print('recentered', c2w.shape) 336 | print(c2w[:3,:4]) 337 | 338 | ## Get spiral 339 | # Get average pose 340 | up = normalize(poses[:, :3, 1].sum(0)) 341 | 342 | # Find a reasonable "focus depth" for this dataset 343 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 344 | dt = .75 345 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 346 | focal = mean_dz 347 | 348 | # Get radii for spiral path 349 | shrink_factor = .8 350 | zdelta = close_depth * .2 351 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 352 | rads = np.percentile(np.abs(tt), 90, 0) 353 | c2w_path = c2w 354 | N_views = 120 355 | N_rots = 2 356 | if path_zflat: 357 | # zloc = np.percentile(tt, 10, 0)[2] 358 | zloc = -close_depth * .1 359 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 360 | rads[2] = 0. 
361 | N_rots = 1 362 | N_views/=2 363 | 364 | # Generate poses for spiral path 365 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 366 | 367 | 368 | render_poses = np.array(render_poses).astype(np.float32) 369 | 370 | c2w = poses_avg(poses) 371 | print('Data:') 372 | print(poses.shape, images.shape, bds.shape) 373 | 374 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 375 | i_test = np.argmin(dists) 376 | print('HOLDOUT view is', i_test) 377 | 378 | images = images.astype(np.float32) 379 | poses = poses.astype(np.float32) 380 | 381 | if path_video is None: 382 | render_path = render_poses 383 | 384 | 385 | return images, poses, bds, render_path, i_test 386 | 387 | 388 | 389 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from tqdm import tqdm, trange 8 | 9 | # Misc 10 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 11 | img2abs = lambda x, y : torch.mean(torch.abs(x - y)) 12 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 13 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 14 | 15 | 16 | # Positional encoding (section 5.1) 17 | class Embedder: 18 | def __init__(self, **kwargs): 19 | self.kwargs = kwargs 20 | self.create_embedding_fn() 21 | 22 | def create_embedding_fn(self): 23 | embed_fns = [] 24 | d = self.kwargs['input_dims'] 25 | out_dim = 0 26 | if self.kwargs['include_input']: 27 | embed_fns.append(lambda x : x) 28 | out_dim += d 29 | 30 | max_freq = self.kwargs['max_freq_log2'] 31 | N_freqs = self.kwargs['num_freqs'] 32 | 33 | if self.kwargs['log_sampling']: 34 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 35 | else: 36 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 37 | 38 | for freq in freq_bands: 39 | for p_fn in self.kwargs['periodic_fns']: 40 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 41 | out_dim += d 42 | 43 | self.embed_fns = embed_fns 44 | self.out_dim = out_dim 45 | 46 | def embed(self, inputs): 47 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 48 | 49 | 50 | def get_embedder(multires, i=0): 51 | if i == -1: 52 | return nn.Identity(), 3 53 | 54 | embed_kwargs = { 55 | 'include_input' : True, 56 | 'input_dims' : 3, 57 | 'max_freq_log2' : multires-1, 58 | 'num_freqs' : multires, 59 | 'log_sampling' : True, 60 | 'periodic_fns' : [torch.sin, torch.cos], 61 | } 62 | 63 | embedder_obj = Embedder(**embed_kwargs) 64 | embed = lambda x, eo=embedder_obj : eo.embed(x) 65 | return embed, embedder_obj.out_dim 66 | 67 | 68 | # Model 69 | class NeRF(nn.Module): 70 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 71 | """ 72 | """ 73 | super(NeRF, self).__init__() 74 | self.D = D 75 | self.W = W 76 | self.input_ch = input_ch 77 | self.input_ch_views = input_ch_views 78 | self.skips = skips 79 | self.use_viewdirs = use_viewdirs 80 | 81 | self.pts_linears = nn.ModuleList( 82 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 83 | 84 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 85 | self.views_linears = 
nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 86 | 87 | ### Implementation according to the paper 88 | # self.views_linears = nn.ModuleList( 89 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 90 | 91 | if use_viewdirs: 92 | self.feature_linear = nn.Linear(W, W) 93 | self.alpha_linear = nn.Linear(W, 1) 94 | self.rgb_linear = nn.Linear(W//2, 3) 95 | else: 96 | self.output_linear = nn.Linear(W, output_ch) 97 | 98 | def forward(self, x): 99 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 100 | h = input_pts 101 | for i, l in enumerate(self.pts_linears): 102 | h = self.pts_linears[i](h) 103 | h = F.relu(h) 104 | if i in self.skips: 105 | h = torch.cat([input_pts, h], -1) 106 | 107 | if self.use_viewdirs: 108 | alpha = self.alpha_linear(h) 109 | feature = self.feature_linear(h) 110 | h = torch.cat([feature, input_views], -1) 111 | 112 | for i, l in enumerate(self.views_linears): 113 | h = self.views_linears[i](h) 114 | h = F.relu(h) 115 | 116 | rgb = self.rgb_linear(h) 117 | outputs = torch.cat([rgb, alpha], -1) 118 | else: 119 | outputs = self.output_linear(h) 120 | 121 | return outputs 122 | 123 | def load_weights_from_keras(self, weights): 124 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 125 | 126 | # Load pts_linears 127 | for i in range(self.D): 128 | idx_pts_linears = 2 * i 129 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 130 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 131 | 132 | # Load feature_linear 133 | idx_feature_linear = 2 * self.D 134 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 135 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 136 | 137 | # Load views_linears 138 | idx_views_linears = 2 * self.D + 2 139 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 140 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 141 | 142 | # Load rgb_linear 143 | idx_rbg_linear = 2 * self.D + 4 144 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 145 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 146 | 147 | # Load alpha_linear 148 | idx_alpha_linear = 2 * self.D + 6 149 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 150 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 151 | 152 | 153 | 154 | def init_weights(m): 155 | if type(m) == nn.Linear: 156 | torch.nn.init.xavier_uniform_(m.weight) 157 | 158 | def init_weights_zeros_(m): 159 | if type(m) == nn.Linear: 160 | torch.nn.init.zeros_(m.weight) 161 | torch.nn.init.zeros_(m.bias) 162 | 163 | 164 | 165 | class MLP_IOR(nn.Module): 166 | def __init__(self, D=8, W=64, input_ch=3 + 3*2*6, output_ch=1, skips=[4],is_index=False): 167 | """ 168 | """ 169 | super(MLP_IOR, self).__init__() 170 | self.D = D 171 | self.W = W 172 | self.input_ch = input_ch 173 | self.skips = skips 174 | self.is_index = is_index 175 | self.mlp_ior = nn.ModuleList( 176 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 177 | 178 | 179 | self.mlp_ior_end = nn.Linear(W, output_ch) 180 | self.softplus = nn.Softplus(beta=5) 181 | # self.IOR = 
nn.Parameter(0.5*torch.ones(1)) 182 | def forward(self, x): 183 | input_pts = x 184 | h = input_pts 185 | for i, l in enumerate(self.mlp_ior): 186 | h = self.mlp_ior[i](h) 187 | h = self.softplus(h) 188 | if i in self.skips: 189 | h = torch.cat([input_pts, h], -1) 190 | 191 | 192 | outputs = self.mlp_ior_end(h) 193 | shape_out = self.softplus(outputs) 194 | 195 | return shape_out 196 | # Ray helpers 197 | def get_rays(H, W, K, c2w): 198 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 199 | i = i.t() 200 | j = j.t() 201 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 202 | # Rotate ray directions from camera frame to the world frame 203 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 204 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 205 | rays_o = c2w[:3,-1].expand(rays_d.shape) 206 | return rays_o, rays_d 207 | 208 | 209 | def get_rays_np(H, W, K, c2w): 210 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 211 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) 212 | # Rotate ray directions from camera frame to the world frame 213 | 214 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 215 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 216 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 217 | return rays_o, rays_d 218 | 219 | 220 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 221 | # Shift ray origins to near plane 222 | t = -(near + rays_o[...,2]) / rays_d[...,2] 223 | rays_o = rays_o + t[...,None] * rays_d 224 | 225 | # Projection 226 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 227 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 228 | o2 = 1. + 2. * near / rays_o[...,2] 229 | 230 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 231 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 232 | d2 = -2. 
* near / rays_o[...,2] 233 | 234 | rays_o = torch.stack([o0,o1,o2], -1) 235 | rays_d = torch.stack([d0,d1,d2], -1) 236 | 237 | return rays_o, rays_d 238 | 239 | 240 | # Hierarchical sampling (section 5.2) 241 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 242 | # Get pdf 243 | weights = weights + 1e-5 # prevent nans 244 | pdf = weights / torch.sum(weights, -1, keepdim=True) 245 | cdf = torch.cumsum(pdf, -1) 246 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 247 | 248 | # Take uniform samples 249 | if det: 250 | u = torch.linspace(0., 1., steps=N_samples) 251 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 252 | else: 253 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 254 | 255 | # Pytest, overwrite u with numpy's fixed random numbers 256 | if pytest: 257 | np.random.seed(0) 258 | new_shape = list(cdf.shape[:-1]) + [N_samples] 259 | if det: 260 | u = np.linspace(0., 1., N_samples) 261 | u = np.broadcast_to(u, new_shape) 262 | else: 263 | u = np.random.rand(*new_shape) 264 | u = torch.Tensor(u) 265 | 266 | # Invert CDF 267 | u = u.contiguous() 268 | inds = torch.searchsorted(cdf, u, right=True) 269 | below = torch.max(torch.zeros_like(inds-1), inds-1) 270 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 271 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 272 | 273 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 274 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 275 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 276 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 277 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 278 | 279 | denom = (cdf_g[...,1]-cdf_g[...,0]) 280 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 281 | t = (u-cdf_g[...,0])/denom 282 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 283 | 284 | return samples 285 | 286 | dtype_long = torch.cuda.LongTensor 287 | 288 | def trilinear_interpolation(inputs,xq,scene_bound): 289 | 290 | res = inputs.shape[0] 291 | _min = scene_bound[0,:] 292 | _max = scene_bound[1,:] 293 | xq = (res-1)*(xq-_min)/(_max-_min) 294 | xq = xq.clip(0,res-1) 295 | x = xq[:,0] 296 | y = xq[:,1] 297 | z = xq[:,2] 298 | 299 | x_0 = torch.clamp(torch.floor(x).type(dtype_long),0, res-1) 300 | x_1 = torch.clamp(x_0 + 1, 0, res-1) 301 | 302 | y_0 = torch.clamp(torch.floor(y).type(dtype_long),0, res-1) 303 | y_1 = torch.clamp(y_0 + 1, 0, res-1) 304 | 305 | z_0 = torch.clamp(torch.floor(z).type(dtype_long),0, res-1) 306 | z_1 = torch.clamp(z_0 + 1, 0, res-1) 307 | 308 | 309 | u, v, w = x-x_0, y-y_0, z-z_0 310 | u = u[:,None] 311 | v = v[:,None] 312 | w = w[:,None] 313 | 314 | c_000 = inputs[x_0,y_0,z_0] 315 | c_001 = inputs[x_0,y_0,z_1] 316 | c_010 = inputs[x_0,y_1,z_0] 317 | c_011 = inputs[x_0,y_1,z_1] 318 | c_100 = inputs[x_1,y_0,z_0] 319 | c_101 = inputs[x_1,y_0,z_1] 320 | c_110 = inputs[x_1,y_1,z_0] 321 | c_111 = inputs[x_1,y_1,z_1] 322 | # print(c_111.shape) 323 | # print(u.shape) 324 | 325 | c_xyz = (1.0-u)*(1.0-v)*(1.0-w)*c_000 + \ 326 | (1.0-u)*(1.0-v)*(w)*c_001 + \ 327 | (1.0-u)*(v)*(1.0-w)*c_010 + \ 328 | (1.0-u)*(v)*(w)*c_011 + \ 329 | (u)*(1.0-v)*(1.0-w)*c_100 + \ 330 | (u)*(1.0-v)*(w)*c_101 + \ 331 | (u)*(v)*(1.0-w)*c_110 + \ 332 | (u)*(v)*(w)*c_111 333 | 334 | 335 | 336 | return c_xyz 337 | 338 | 339 | 340 | def lowpass_3d(res,sigma): 341 | 342 | res = res-1 343 | xx, yy,zz = 
torch.meshgrid(torch.linspace(0, res, res+1), torch.linspace(0, res, res+1), torch.linspace(0, res, res+1)) 344 | xx = xx/res 345 | yy = yy/res 346 | zz = zz/res 347 | dist = (torch.square(xx-0.5)+torch.square(yy-0.5)+torch.square(zz-0.5)) 348 | 349 | return torch.exp(-dist/(2*sigma**2)) 350 | 351 | 352 | def voxel_lowpass_filtering(input_voxel,filter_3d): 353 | 354 | fftn_grid = torch.fft.fftshift(torch.fft.fftn(input_voxel,input_voxel.shape,norm='forward')) 355 | filtered_grid = torch.fft.ifftn(torch.fft.ifftshift(fftn_grid*filter_3d[...,None].to(input_voxel)),norm='forward') 356 | 357 | 358 | return filtered_grid.real 359 | 360 | 361 | eps = 1e-6 362 | def normalizing(vec): 363 | 364 | return vec/(torch.linalg.norm(vec, dim=-1, keepdim=True)+eps) 365 | 366 | 367 | 368 | def get_scene_bound(near,far,H,W,K,poses,min_=.25,max_=.25): 369 | 370 | 371 | N_samples = 2 372 | t_vals = torch.linspace(0., 1., steps=N_samples) 373 | z_vals = near * (1.-t_vals) + far * (t_vals) 374 | 375 | z_vals = z_vals.expand([H,W, N_samples]) 376 | scene_bound_min = [] 377 | scene_bound_max = [] 378 | 379 | for cams in tqdm(range(len(poses))): 380 | rays_o, rays_d = get_rays(H, W, K, poses[cams]) 381 | pts = rays_o[...,None,:] + (rays_d[...,None,:]) * z_vals[...,:,None] 382 | pts = pts.reshape(-1,3) 383 | max_pts = torch.quantile((pts),1.-max_,axis=0) 384 | min_pts = torch.quantile((pts),min_,axis=0) 385 | scene_bound_min.append(min_pts) 386 | scene_bound_max.append(max_pts) 387 | 388 | 389 | scene_bound_min = torch.stack(scene_bound_min,0) 390 | scene_bound_max = torch.stack(scene_bound_max,0) 391 | 392 | max_pts = torch.quantile(scene_bound_max,1.-max_,axis=0) 393 | min_pts = torch.quantile(scene_bound_min,min_,axis=0) 394 | 395 | return torch.stack([min_pts,max_pts],0) 396 | 397 | 398 | def get_bb_weights(pts,bounding_box_val,beta=200): 399 | 400 | center = bounding_box_val[0] 401 | rad = bounding_box_val[1] 402 | 403 | x_dist = torch.abs(pts[...,0:1] - torch.tensor(center[0]).to(pts)) 404 | y_dist = torch.abs(pts[...,1:2] - torch.tensor(center[1]).to(pts)) 405 | z_dist = torch.abs(pts[...,2:3] - torch.tensor(center[2]).to(pts)) 406 | 407 | weights = torch.sigmoid(beta*(rad[0]-x_dist))*torch.sigmoid(beta*(rad[1]-y_dist))*torch.sigmoid(beta*(rad[2]-z_dist)) 408 | 409 | return 1.0 - weights 410 | 411 | 412 | 413 | 414 | def get_voxel_grid(voxel_res,scene_bound,poses,query_fn,nerf_model,masking=False,bb_vals=None,flag=False): 415 | 416 | 417 | min_bound,max_bound = scene_bound.cpu().numpy() 418 | 419 | 420 | x = np.arange(voxel_res) 421 | grid_x,grid_y,grid_z = np.meshgrid(x,x,x) 422 | 423 | pts_ = np.stack((grid_y,grid_x,grid_z),-1) 424 | pts_ = pts_.reshape(-1,3) 425 | pts_ = pts_/(voxel_res-1) 426 | pts_ = (pts_)*(max_bound-min_bound) + min_bound 427 | 428 | raw_avg = 0 429 | 430 | for cam in tqdm(range(len(poses))): 431 | 432 | raw_stack = [] 433 | pose = poses[cam,:,-1] 434 | 435 | batch = 128*128*128 436 | for ii in range(0,pts_.shape[0],batch): 437 | with torch.no_grad(): 438 | 439 | new_pts_ = torch.tensor(pts_[ii:ii+batch,:]).to(poses).to(torch.float32) 440 | viewdirs = new_pts_ - pose 441 | 442 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 443 | 444 | raw = query_fn(new_pts_, viewdirs, nerf_model) 445 | raw[...,0:3] = torch.sigmoid(raw[...,0:3]) 446 | raw[...,3:4] = F.relu(raw[...,3:4]) 447 | raw_stack.append(raw) 448 | 449 | raw_stack = torch.cat(raw_stack,0) 450 | 451 | 452 | raw_avg += raw_stack 453 | 454 | raw_avg = raw_avg/len(poses) 455 | return 
raw_avg.reshape(voxel_res,voxel_res,voxel_res,4) 456 | 457 | 458 | 459 | 460 | 461 | def raw2outputs(raw, dists): 462 | """Transforms model's predictions to semantically meaningful values. 463 | Args: 464 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 465 | z_vals: [num_rays, num_samples along ray]. Integration time. 466 | rays_d: [num_rays, 3]. Direction of each ray. 467 | Returns: 468 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 469 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 470 | acc_map: [num_rays]. Sum of weights along each ray. 471 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 472 | depth_map: [num_rays]. Estimated distance to object. 473 | """ 474 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 475 | 476 | 477 | 478 | alpha = raw2alpha(raw[...,3] , dists) # [N_rays, N_samples] 479 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 480 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 481 | 482 | return weights 483 | -------------------------------------------------------------------------------- /find_bounding_box.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys,cv2 3 | import numpy as np 4 | import imageio 5 | import json 6 | import random 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from tqdm import tqdm, trange 12 | 13 | from run_nerf_helpers import * 14 | 15 | from load_llff import load_llff_data 16 | from load_deepvoxels import load_dv_data 17 | from load_blender import load_blender_data 18 | from load_LINEMOD import load_LINEMOD_data 19 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 20 | 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | np.random.seed(0) 24 | DEBUG = False 25 | 26 | 27 | def batchify(fn, chunk): 28 | """Constructs a version of 'fn' that applies to smaller batches. 29 | """ 30 | if chunk is None: 31 | return fn 32 | def ret(inputs): 33 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 34 | return ret 35 | 36 | 37 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 38 | """Prepares inputs and applies network 'fn'. 39 | """ 40 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 41 | embedded = embed_fn(inputs_flat) 42 | 43 | if viewdirs is not None: 44 | input_dirs = viewdirs[:,None].expand(inputs.shape) 45 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 46 | embedded_dirs = embeddirs_fn(input_dirs_flat) 47 | embedded = torch.cat([embedded, embedded_dirs], -1) 48 | 49 | outputs_flat = batchify(fn, netchunk)(embedded) 50 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 51 | return outputs 52 | 53 | 54 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 55 | """Render rays in smaller minibatches to avoid OOM. 
56 | """ 57 | all_ret = {} 58 | for i in range(0, rays_flat.shape[0], chunk): 59 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 60 | for k in ret: 61 | if k not in all_ret: 62 | all_ret[k] = [] 63 | all_ret[k].append(ret[k]) 64 | 65 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 66 | return all_ret 67 | 68 | 69 | 70 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 71 | near=0., far=1., 72 | use_viewdirs=False, c2w_staticcam=None, 73 | **kwargs): 74 | """Render rays 75 | Args: 76 | H: int. Height of image in pixels. 77 | W: int. Width of image in pixels. 78 | focal: float. Focal length of pinhole camera. 79 | chunk: int. Maximum number of rays to process simultaneously. Used to 80 | control maximum memory usage. Does not affect final results. 81 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 82 | each example in batch. 83 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 84 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 85 | near: float or array of shape [batch_size]. Nearest distance for a ray. 86 | far: float or array of shape [batch_size]. Farthest distance for a ray. 87 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 88 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 89 | camera while using other c2w argument for viewing directions. 90 | Returns: 91 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 92 | disp_map: [batch_size]. Disparity map. Inverse of depth. 93 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 94 | extras: dict with everything returned by render_rays(). 95 | """ 96 | if c2w is not None: 97 | # special case to render full image 98 | rays_o, rays_d = get_rays(H, W, K, c2w) 99 | else: 100 | # use provided ray batch 101 | rays_o, rays_d = rays 102 | 103 | if use_viewdirs: 104 | # provide ray directions as input 105 | viewdirs = rays_d 106 | if c2w_staticcam is not None: 107 | # special case to visualize effect of viewdirs 108 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 109 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 110 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 111 | 112 | sh = rays_d.shape # [..., 3] 113 | if ndc: 114 | # for forward facing scenes 115 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 116 | 117 | # Create ray batch 118 | rays_o = torch.reshape(rays_o, [-1,3]).float() 119 | rays_d = torch.reshape(rays_d, [-1,3]).float() 120 | 121 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 122 | rays = torch.cat([rays_o, rays_d, near, far], -1) 123 | if use_viewdirs: 124 | rays = torch.cat([rays, viewdirs], -1) 125 | 126 | # Render and reshape 127 | all_ret = batchify_rays(rays, chunk, **kwargs) 128 | for k in all_ret: 129 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 130 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 131 | 132 | k_extract = ['rgb_map', 'disp_map', 'acc_map','weights'] 133 | ret_list = [all_ret[k] for k in k_extract] 134 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 135 | return ret_list + [ret_dict] 136 | 137 | 138 | def create_nerf(args): 139 | """Instantiate NeRF's MLP model. 
140 | """ 141 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 142 | 143 | input_ch_views = 0 144 | embeddirs_fn = None 145 | if args.use_viewdirs: 146 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 147 | output_ch = 5 if args.N_importance > 0 else 4 148 | skips = [4] 149 | model = NeRF(D=args.netdepth, W=args.netwidth, 150 | input_ch=input_ch, output_ch=output_ch, skips=skips, 151 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 152 | grad_vars = list(model.parameters()) 153 | 154 | model_fine = None 155 | if args.N_importance > 0: 156 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 157 | input_ch=input_ch, output_ch=output_ch, skips=skips, 158 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 159 | grad_vars += list(model_fine.parameters()) 160 | 161 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 162 | embed_fn=embed_fn, 163 | embeddirs_fn=embeddirs_fn, 164 | netchunk=args.netchunk) 165 | 166 | # Create optimizer 167 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 168 | 169 | start = 0 170 | basedir = args.basedir 171 | expname = args.expname 172 | 173 | ########################## 174 | 175 | # Load checkpoints 176 | 177 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f] 178 | print('Trainined model path:',os.path.join(basedir, expname,args.model_path)) 179 | print('Found ckpts', ckpts) 180 | if len(ckpts) > 0 and not args.no_reload: 181 | ckpt_path = ckpts[-1] 182 | print('Reloading from', ckpt_path) 183 | ckpt = torch.load(ckpt_path) 184 | 185 | start = ckpt['global_step'] 186 | 187 | # Load model 188 | model.load_state_dict(ckpt['network_fn_state_dict']) 189 | if model_fine is not None: 190 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 191 | 192 | ########################## 193 | 194 | render_kwargs_train = { 195 | 'network_query_fn' : network_query_fn, 196 | 'perturb' : args.perturb, 197 | 'N_importance' : args.N_importance, 198 | 'network_fine' : model_fine, 199 | 'N_samples' : args.N_samples, 200 | 'network_fn' : model, 201 | 'use_viewdirs' : args.use_viewdirs, 202 | 'white_bkgd' : args.white_bkgd, 203 | 'raw_noise_std' : args.raw_noise_std, 204 | } 205 | 206 | # NDC only good for LLFF-style forward facing data 207 | if args.dataset_type != 'llff' or args.no_ndc: 208 | print('Not ndc!') 209 | render_kwargs_train['ndc'] = False 210 | render_kwargs_train['lindisp'] = args.lindisp 211 | 212 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 213 | render_kwargs_test['perturb'] = False 214 | render_kwargs_test['raw_noise_std'] = 0. 215 | 216 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 217 | 218 | 219 | def raw2outputs(raw, z_vals, rays_d,raw_noise_std=0, white_bkgd=False, pytest=False): 220 | """Transforms model's predictions to semantically meaningful values. 221 | Args: 222 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 223 | z_vals: [num_rays, num_samples along ray]. Integration time. 224 | rays_d: [num_rays, 3]. Direction of each ray. 225 | Returns: 226 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 227 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 228 | acc_map: [num_rays]. Sum of weights along each ray. 229 | weights: [num_rays, num_samples]. 
Weights assigned to each sampled color. 230 | depth_map: [num_rays]. Estimated distance to object. 231 | """ 232 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 233 | 234 | dists = z_vals[...,1:] - z_vals[...,:-1] 235 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 236 | 237 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 238 | 239 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 240 | noise = 0. 241 | if raw_noise_std > 0.: 242 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 243 | 244 | # Overwrite randomly sampled data if pytest 245 | if pytest: 246 | np.random.seed(0) 247 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 248 | noise = torch.Tensor(noise) 249 | 250 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 251 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 252 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 253 | 254 | # 255 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 256 | 257 | depth_map = torch.sum(weights * z_vals, -1) 258 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 259 | acc_map = torch.sum(weights, -1) 260 | 261 | if white_bkgd: 262 | rgb_map = rgb_map + (1.-acc_map[...,None]) 263 | 264 | return rgb_map, disp_map, acc_map, weights, depth_map 265 | 266 | 267 | def render_rays(ray_batch, 268 | network_fn, 269 | network_query_fn, 270 | N_samples, 271 | masking, 272 | bb_vals, 273 | retraw=True, 274 | lindisp=False, 275 | perturb=0., 276 | N_importance=0, 277 | network_fine=None, 278 | white_bkgd=False, 279 | raw_noise_std=0., 280 | verbose=False, 281 | pytest=False): 282 | 283 | N_rays = ray_batch.shape[0] 284 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 285 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 286 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 287 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 288 | 289 | t_vals = torch.linspace(0., 1., steps=N_samples) 290 | if not lindisp: 291 | z_vals = near * (1.-t_vals) + far * (t_vals) 292 | else: 293 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 294 | 295 | z_vals = z_vals.expand([N_rays, N_samples]) 296 | 297 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 298 | 299 | 300 | raw = network_query_fn(pts, viewdirs, network_fn) 301 | 302 | 303 | if masking: 304 | weights = get_bb_weights(pts,bb_vals) 305 | raw = raw*weights 306 | 307 | 308 | 309 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 310 | 311 | if N_importance > 0: 312 | 313 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 314 | 315 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 316 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 317 | z_samples = z_samples.detach() 318 | 319 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 320 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 321 | 322 | run_fn = network_fn if network_fine is None else network_fine 323 | raw = network_query_fn(pts, viewdirs, run_fn) 324 | 325 | if masking: 326 | weights = get_bb_weights(pts,bb_vals) 327 | raw = raw*weights 328 | 329 | 
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 330 | 331 | 332 | ret = {'rgb_map' : rgb_map, 'disp_map' : depth_map, 'acc_map' : z_vals,'weights':rays_o + rays_d*depth_map[:,None]} 333 | if retraw: 334 | ret['raw'] = pts 335 | if N_importance > 0: 336 | ret['rgb0'] = rgb_map_0 337 | ret['disp0'] = disp_map_0 338 | ret['acc0'] = acc_map_0 339 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 340 | 341 | 342 | for k in ret: 343 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 344 | print(f"! [Numerical Error] {k} contains nan or inf.") 345 | 346 | return ret 347 | 348 | 349 | 350 | def config_parser(): 351 | 352 | import configargparse 353 | parser = configargparse.ArgumentParser() 354 | parser.add_argument('--config', is_config_file=True, 355 | help='config file path') 356 | parser.add_argument("--expname", default='glass_ball_new',type=str, 357 | help='experiment name') 358 | parser.add_argument("--basedir", type=str, default='logs', 359 | help='where to store ckpts and logs') 360 | parser.add_argument("--datadir", type=str, default='data\\glass_ball', 361 | help='input data directory') 362 | 363 | # training options 364 | parser.add_argument("--netdepth", type=int, default=8, 365 | help='layers in network') 366 | parser.add_argument("--netwidth", type=int, default=256, 367 | help='channels per layer') 368 | parser.add_argument("--netdepth_fine", type=int, default=8, 369 | help='layers in fine network') 370 | parser.add_argument("--netwidth_fine", type=int, default=256, 371 | help='channels per layer in fine network') 372 | parser.add_argument("--N_rand", type=int, default=32*32, 373 | help='batch size (number of random rays per gradient step)') 374 | parser.add_argument("--lrate", type=float, default=5e-4, 375 | help='learning rate') 376 | parser.add_argument("--lrate_decay", type=int, default=250, 377 | help='exponential learning rate decay (in 1000 steps)') 378 | parser.add_argument("--chunk", type=int, default=1024*32, 379 | help='number of rays processed in parallel, decrease if running out of memory') 380 | parser.add_argument("--netchunk", type=int, default=1024*64, 381 | help='number of pts sent through network in parallel, decrease if running out of memory') 382 | parser.add_argument("--no_batching", default=False,action='store_true', 383 | help='only take random rays from 1 image at a time') 384 | parser.add_argument("--no_reload", action='store_true', 385 | help='do not reload weights from saved ckpt') 386 | parser.add_argument("--model_path", type=str, default='model_weights', 387 | help='path to trained model weights') 388 | # rendering options 389 | parser.add_argument("--N_samples", type=int, default=64, 390 | help='number of coarse samples per ray') 391 | parser.add_argument("--N_importance", type=int, default=64, 392 | help='number of additional fine samples per ray') 393 | parser.add_argument("--perturb", type=float, default=1., 394 | help='set to 0. for no jitter, 1. 
for jitter') 395 | parser.add_argument("--use_viewdirs", default=True,action='store_true', 396 | help='use full 5D input instead of 3D') 397 | parser.add_argument("--use_mask", default=False,action='store_true', 398 | help='use full 5D input instead of 3D') 399 | parser.add_argument("--i_embed", type=int, default=0, 400 | help='set 0 for default positional encoding, -1 for none') 401 | parser.add_argument("--multires", type=int, default=10, 402 | help='log2 of max freq for positional encoding (3D location)') 403 | parser.add_argument("--multires_views", type=int, default=4, 404 | help='log2 of max freq for positional encoding (2D direction)') 405 | parser.add_argument("--raw_noise_std", type=float, default=1., 406 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 407 | 408 | parser.add_argument("--render_only", action='store_true', 409 | help='do not optimize, reload weights and render out render_poses path') 410 | parser.add_argument("--render_test", action='store_true', 411 | help='render the test set instead of render_poses path') 412 | parser.add_argument("--render_factor", type=int, default=0, 413 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 414 | 415 | # training options 416 | parser.add_argument("--precrop_iters", type=int, default=0, 417 | help='number of steps to train on central crops') 418 | parser.add_argument("--precrop_frac", type=float, 419 | default=.5, help='fraction of img taken for central crops') 420 | 421 | # dataset options 422 | parser.add_argument("--dataset_type", type=str, default='llff', 423 | help='options: llff / blender / deepvoxels') 424 | parser.add_argument("--testskip", type=int, default=8, 425 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 426 | 427 | ## deepvoxels flags 428 | parser.add_argument("--shape", type=str, default='greek', 429 | help='options : armchair / cube / greek / vase') 430 | 431 | ## blender flags 432 | parser.add_argument("--white_bkgd", action='store_true', 433 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 434 | parser.add_argument("--half_res", action='store_true', 435 | help='load blender synthetic data at 400x400 instead of 800x800') 436 | 437 | ## llff flags 438 | parser.add_argument("--factor", type=int, default=1, 439 | help='downsample factor for LLFF images') 440 | parser.add_argument("--no_ndc", default = True,action='store_true', 441 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 442 | parser.add_argument("--lindisp", action='store_true', 443 | help='sampling linearly in disparity rather than depth') 444 | parser.add_argument("--spherify", default = True,action='store_true', 445 | help='set for spherical 360 scenes') 446 | parser.add_argument("--llffhold", type=int, default=10, 447 | help='will take every 1/N images as LLFF test set, paper uses 8') 448 | 449 | parser.add_argument("--quantile", type=float, default=0.03, 450 | help='q-th quantiles between [0 1]') 451 | parser.add_argument("--ratio", type=float, default=1.1, 452 | help='enlarging the bounding box with ratio') 453 | parser.add_argument("--num_im", type=int, default=10, 454 | help='will take every 1/N images to visualize') 455 | parser.add_argument("--reload_bb", default=False, action='store_true', 456 | help='reload an existing bounding box value') 457 | 458 | 459 | return parser 460 | 461 | def find_box(): 462 | 463 | parser = config_parser() 464 | args = parser.parse_args() 465 
| 466 | # Load data 467 | K = None 468 | if args.dataset_type == 'llff': 469 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 470 | recenter=True, bd_factor=.75, 471 | spherify=args.spherify) 472 | hwf = poses[0,:3,-1] 473 | poses = poses[:,:3,:4] 474 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 475 | if not isinstance(i_test, list): 476 | i_test = [i_test] 477 | 478 | if args.num_im > 0: 479 | print('number of images to annotate: ', args.num_im) 480 | i_test = np.arange(images.shape[0])[::args.num_im] 481 | 482 | i_val = i_test 483 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 484 | (i not in i_test and i not in i_val)]) 485 | 486 | print('DEFINING BOUNDS') 487 | if args.no_ndc: 488 | near = np.ndarray.min(bds) * .9 489 | far = np.ndarray.max(bds) * 1. 490 | 491 | else: 492 | near = 0. 493 | far = 1. 494 | print('NEAR FAR', near, far) 495 | 496 | 497 | # Cast intrinsics to right types 498 | H, W, focal = hwf 499 | H, W = int(H), int(W) 500 | hwf = [H, W, focal] 501 | 502 | if K is None: 503 | K = np.array([ 504 | [focal, 0, 0.5*W], 505 | [0, focal, 0.5*H], 506 | [0, 0, 1] 507 | ]) 508 | 509 | if args.render_test: 510 | render_poses = np.array(poses[i_test]) 511 | 512 | # Create log dir and copy the config file 513 | basedir = args.basedir 514 | expname = args.expname 515 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 516 | f = os.path.join(basedir, expname, 'args.txt') 517 | with open(f, 'w') as file: 518 | for arg in sorted(vars(args)): 519 | attr = getattr(args, arg) 520 | file.write('{} = {}\n'.format(arg, attr)) 521 | if args.config is not None: 522 | f = os.path.join(basedir, expname, 'config.txt') 523 | with open(f, 'w') as file: 524 | file.write(open(args.config, 'r').read()) 525 | 526 | # Create nerf model 527 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 528 | global_step = start 529 | 530 | bds_dict = { 531 | 'near' : near, 532 | 'far' : far, 533 | } 534 | render_kwargs_test.update(bds_dict) 535 | 536 | render_poses = torch.Tensor(render_poses).to(device) 537 | N_rand = args.N_rand 538 | use_batching = not args.no_batching 539 | 540 | poses = torch.Tensor(poses).to(device) 541 | 542 | if not args.reload_bb: 543 | point_matrix = [] 544 | 545 | counter = 0 546 | def mousePoints(event,x,y,flags,params): 547 | if event == cv2.EVENT_LBUTTONDOWN: 548 | point_matrix.append([x,y]) 549 | cv2.circle(img,(x,y),5,(0,0,255),cv2.FILLED) 550 | 551 | 552 | num_ = len(i_test) #len(poses)//10 553 | cam_ind = i_test# np.arange(0,len(poses),num_) 554 | counter = 0 555 | 556 | img = to8b(images[cam_ind[counter],:,:,3::-1]) 557 | rays_o, rays_d = get_rays(H, W, K, poses[cam_ind[counter]]) 558 | rays = torch.stack([rays_o, rays_d],0) 559 | all_rays = [] 560 | 561 | text = "Please press A to continue to the next image " 562 | coordinates = (10,20) 563 | font = cv2.FONT_HERSHEY_SIMPLEX 564 | fontScale = 0.75 565 | color = (0,0,255) 566 | thickness = 0 567 | 568 | while True: 569 | 570 | 571 | img = cv2.putText(img, text, coordinates, font, fontScale, color, thickness, cv2.LINE_AA) 572 | cv2.imshow(" Image ", img) 573 | 574 | cv2.setMouseCallback(" Image ", mousePoints) 575 | key = cv2.waitKey(1) & 0xFF 576 | if key == 27: 577 | break 578 | 579 | if key == ord('a'): 580 | counter = counter+1 581 | pts_inside = np.array(point_matrix) 582 | all_rays.append(rays[:,pts_inside[:,1],pts_inside[:,0],:]) 583 | if counter == num_: 584 | break 585 | rays_o, rays_d = 
get_rays(H, W, K, poses[cam_ind[counter]]) 586 | rays = torch.stack([rays_o, rays_d],0) 587 | point_matrix = [] 588 | img = to8b(images[cam_ind[counter],:,:,3::-1]) 589 | 590 | 591 | 592 | 593 | cv2.destroyAllWindows() 594 | 595 | all_rays = torch.cat(all_rays,1) 596 | 597 | print('Started finding the transparent object bounding box') 598 | render_kwargs_test['bb_vals'] = [] 599 | render_kwargs_test['masking'] = False 600 | render_kwargs_test['N_importance'] = 256 601 | 602 | with torch.no_grad(): 603 | _, _, _,weights, _ = render(H, W,K, chunk=args.chunk, rays= all_rays, 604 | **render_kwargs_test) 605 | 606 | point_cloud = weights.cpu().numpy() 607 | 608 | 609 | max_pts = np.percentile((point_cloud),100 - args.quantile,axis=0) 610 | min_pts = np.percentile((point_cloud),args.quantile,axis=0) 611 | center = 0.5*(min_pts+max_pts) 612 | radi = max_pts-center 613 | radi = args.ratio*radi 614 | 615 | bounding_box_vals = np.stack((center,radi),0) 616 | 617 | path_bounding_box = os.path.join(basedir, expname,"bounding_box") 618 | os.makedirs(path_bounding_box, exist_ok=True) 619 | path = os.path.join(path_bounding_box, 'bounding_box_vals.npy') 620 | np.save(path,bounding_box_vals) 621 | print('Finised finding the transparent object bounding box') 622 | 623 | 624 | else: 625 | 626 | path_bounding_box = os.path.join(basedir, expname,"bounding_box") 627 | path = os.path.join(path_bounding_box, 'bounding_box_vals.npy') 628 | bounding_box_vals = np.load(path) 629 | 630 | 631 | 632 | 633 | render_kwargs_test['bb_vals'] = bounding_box_vals 634 | render_kwargs_test['masking'] = True 635 | render_kwargs_test['N_importance'] = 64 636 | 637 | 638 | 639 | print('Rendering a test image with and without transparent object') 640 | 641 | path_bounding_box = os.path.join(basedir, expname,"bounding_box") 642 | 643 | img_i = i_test[5] 644 | 645 | with torch.no_grad(): 646 | rgb_gt_without, _,_,_,_ = render(H, W,K, chunk=args.chunk, c2w= poses[img_i], 647 | **render_kwargs_test) 648 | 649 | 650 | imageio.imwrite(os.path.join(path_bounding_box, 'without_{:03d}.png'.format(img_i)), to8b(rgb_gt_without.cpu().numpy())) 651 | 652 | render_kwargs_test['masking'] = False 653 | 654 | with torch.no_grad(): 655 | rgb_gt_with, _,_,_,_ = render(H, W,K, chunk=args.chunk, c2w= poses[img_i], 656 | **render_kwargs_test) 657 | 658 | imageio.imwrite(os.path.join(path_bounding_box, 'with_{:03d}.png'.format(img_i)), to8b(rgb_gt_with.cpu().numpy())) 659 | 660 | print('Done!') 661 | print('Saved in the folder \"bounding_box\"') 662 | 663 | 664 | print('Finding the region in each image crossing the 3D bounding box') 665 | 666 | path_to_mask_out = os.path.join(basedir, expname, 'masked_regions') 667 | os.makedirs(path_to_mask_out, exist_ok=True) 668 | 669 | 670 | for img_i in tqdm(range(len(poses))): 671 | 672 | pose = poses[img_i, :3,:4] 673 | rays_o, rays_d = get_rays(H, W, K, pose) 674 | t_vals = torch.linspace(0., 1., steps=128) 675 | z_vals = near * (1.-t_vals) + far * (t_vals) 676 | pts_ = rays_o[...,None,:] + rays_d[...,None,:]*z_vals[:,None] 677 | mask = 1. 
- get_bb_weights(pts_,bounding_box_vals) 678 | diff = torch.sum(mask,2)>0.0 679 | imageio.imwrite(os.path.join(path_to_mask_out, 'img_%0.3d.png'%(img_i)), to8b(diff.cpu().numpy())) 680 | 681 | print('Done!') 682 | print('Saved in the folder \"masked_regions\"') 683 | 684 | 685 | if __name__=='__main__': 686 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 687 | 688 | find_box() -------------------------------------------------------------------------------- /render_model.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys,cv2 3 | import numpy as np 4 | import imageio 5 | import json 6 | import random 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from tqdm import tqdm, trange 12 | from torchdiffeq import odeint_adjoint as odeint 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | from run_nerf_helpers import * 17 | 18 | from load_llff import load_llff_data 19 | from load_deepvoxels import load_dv_data 20 | from load_blender import load_blender_data 21 | from load_LINEMOD import load_LINEMOD_data 22 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 23 | 24 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 25 | 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | np.random.seed(0) 28 | DEBUG = False 29 | 30 | 31 | def batchify(fn, chunk): 32 | """Constructs a version of 'fn' that applies to smaller batches. 33 | """ 34 | if chunk is None: 35 | return fn 36 | def ret(inputs): 37 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 38 | return ret 39 | 40 | 41 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 42 | """Prepares inputs and applies network 'fn'. 43 | """ 44 | embedded = embed_fn(inputs) 45 | 46 | if viewdirs is not None: 47 | 48 | embedded_dirs = embeddirs_fn(viewdirs) 49 | embedded = torch.cat([embedded, embedded_dirs], -1) 50 | 51 | outputs_flat = batchify(fn, netchunk)(embedded) 52 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 53 | return outputs 54 | 55 | 56 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64): 57 | 58 | embedded = embed_fn(inputs) 59 | 60 | 61 | outputs = batchify(fn, netchunk)(embedded) 62 | 63 | return outputs 64 | 65 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 66 | """Render rays in smaller minibatches to avoid OOM. 67 | """ 68 | all_ret = {} 69 | for i in range(0, rays_flat.shape[0], chunk): 70 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 71 | for k in ret: 72 | if k not in all_ret: 73 | all_ret[k] = [] 74 | all_ret[k].append(ret[k]) 75 | 76 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 77 | return all_ret 78 | 79 | 80 | 81 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 82 | near=0., far=1., 83 | use_viewdirs=False, c2w_staticcam=None, 84 | **kwargs): 85 | """Render rays 86 | Args: 87 | H: int. Height of image in pixels. 88 | W: int. Width of image in pixels. 89 | focal: float. Focal length of pinhole camera. 90 | chunk: int. Maximum number of rays to process simultaneously. Used to 91 | control maximum memory usage. Does not affect final results. 92 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 93 | each example in batch. 94 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 95 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 
96 | near: float or array of shape [batch_size]. Nearest distance for a ray. 97 | far: float or array of shape [batch_size]. Farthest distance for a ray. 98 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 99 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 100 | camera while using other c2w argument for viewing directions. 101 | Returns: 102 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 103 | disp_map: [batch_size]. Disparity map. Inverse of depth. 104 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 105 | extras: dict with everything returned by render_rays(). 106 | """ 107 | if c2w is not None: 108 | # special case to render full image 109 | rays_o, rays_d = get_rays(H, W, K, c2w) 110 | else: 111 | # use provided ray batch 112 | rays_o, rays_d = rays 113 | 114 | if use_viewdirs: 115 | # provide ray directions as input 116 | viewdirs = rays_d 117 | if c2w_staticcam is not None: 118 | # special case to visualize effect of viewdirs 119 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 120 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 121 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 122 | 123 | sh = rays_d.shape # [..., 3] 124 | if ndc: 125 | # for forward facing scenes 126 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 127 | 128 | # Create ray batch 129 | rays_o = torch.reshape(rays_o, [-1,3]).float() 130 | rays_d = torch.reshape(rays_d, [-1,3]).float() 131 | 132 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 133 | rays = torch.cat([rays_o, rays_d, near, far], -1) 134 | if use_viewdirs: 135 | rays = torch.cat([rays, viewdirs], -1) 136 | 137 | # Render and reshape 138 | all_ret = batchify_rays(rays, chunk, **kwargs) 139 | for k in all_ret: 140 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 141 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 142 | 143 | k_extract = ['rgb_map'] 144 | ret_list = [all_ret[k] for k in k_extract] 145 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 146 | return ret_list + [ret_dict] 147 | 148 | 149 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 150 | 151 | H, W, focal = hwf 152 | 153 | if render_factor!=0: 154 | # Render downsampled for speed 155 | H = H//render_factor 156 | W = W//render_factor 157 | focal = focal/render_factor 158 | 159 | rgbs = [] 160 | # disps = [] 161 | 162 | t = time.time() 163 | for i, c2w in enumerate(tqdm(render_poses)): 164 | print(i, time.time() - t) 165 | t = time.time() 166 | rgb, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 167 | rgbs.append(rgb.cpu().numpy()) 168 | # disps.append(disp.cpu().numpy()) 169 | # if i==0: 170 | # print(rgb.shape, disp.shape) 171 | 172 | """ 173 | if gt_imgs is not None and render_factor==0: 174 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 175 | print(p) 176 | """ 177 | 178 | if savedir is not None: 179 | rgb8 = to8b(rgbs[-1]) 180 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 181 | imageio.imwrite(filename, rgb8) 182 | 183 | 184 | rgbs = np.stack(rgbs, 0) 185 | 186 | return rgbs 187 | 188 | def create_models(args): 189 | """Instantiate NeRF's MLP model. 
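    In addition to the coarse/fine NeRF MLPs, this variant also builds the IoR MLP
    (MLP_IOR) and a separate NeRF for the contents inside the bounding box.
    Checkpoints for all of them are reloaded from <basedir>/<expname>/<model_path>
    when available, and only the inside-NeRF parameters are handed to the optimizer.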
190 | """ 191 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 192 | 193 | input_ch_views = 0 194 | embeddirs_fn = None 195 | if args.use_viewdirs: 196 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 197 | output_ch = 4 198 | skips = [4] 199 | model = NeRF(D=args.netdepth, W=args.netwidth, 200 | input_ch=input_ch, output_ch=output_ch, skips=skips, 201 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 202 | 203 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 204 | input_ch=input_ch, output_ch=output_ch, skips=skips, 205 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 206 | 207 | 208 | 209 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed) 210 | model_ior = MLP_IOR(input_ch = input_ch_ior,D=args.netdepth_ior, W= args.netwidth_ior ,skips=[3]).to(device) 211 | model_ior.apply(init_weights) 212 | 213 | model_inside = NeRF(D=args.netdepth, W=args.netwidth_inside, 214 | input_ch=input_ch, output_ch=output_ch, skips=skips, 215 | input_ch_views=input_ch_views, use_viewdirs=True).to(device) 216 | 217 | grad_vars = list(model_inside.parameters()) 218 | 219 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 220 | embed_fn=embed_fn, 221 | embeddirs_fn=embeddirs_fn, 222 | netchunk=args.netchunk) 223 | 224 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn, 225 | embed_fn = embed_fn_ior, 226 | netchunk=args.netchunk) 227 | # Create optimizer 228 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 229 | 230 | start = 0 231 | basedir = args.basedir 232 | expname = args.expname 233 | 234 | ########################## 235 | 236 | # Load checkpoints 237 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f] 238 | print('Trainined model path:',os.path.join(basedir, expname,args.model_path)) 239 | 240 | 241 | print('Found ckpts', ckpts) 242 | 243 | if len(ckpts) > 0 and not args.no_reload: 244 | ckpt_path = ckpts[-1] 245 | print('Reloading from', ckpt_path) 246 | ckpt = torch.load(ckpt_path) 247 | 248 | start = ckpt['global_step'] 249 | 250 | # Load model 251 | model.load_state_dict(ckpt['network_fn_state_dict']) 252 | 253 | if model_fine is not None: 254 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 255 | 256 | if "network_ior_state_dict" in ckpt: 257 | model_ior.load_state_dict(ckpt['network_ior_state_dict']) 258 | 259 | if "network_inside_state_dict" in ckpt: 260 | model_inside.load_state_dict(ckpt['network_inside_state_dict']) 261 | 262 | 263 | 264 | ########################## 265 | 266 | render_kwargs_train = { 267 | 268 | 'network_query_fn' : network_query_fn, 269 | 'network_query_fn_ior': network_query_fn_ior, 270 | 'perturb' : args.perturb, 271 | 'N_samples' : args.N_samples, 272 | 'network_fn' : model, 273 | 'network_fine' : model_fine, 274 | 'network_ior': model_ior, 275 | 'network_inside':model_inside, 276 | 'use_viewdirs' : args.use_viewdirs, 277 | 'white_bkgd' : args.white_bkgd, 278 | 'raw_noise_std' : args.raw_noise_std, 279 | } 280 | 281 | # NDC only good for LLFF-style forward facing data 282 | if args.dataset_type != 'llff' or args.no_ndc: 283 | print('Not ndc!') 284 | render_kwargs_train['ndc'] = False 285 | render_kwargs_train['lindisp'] = args.lindisp 286 | 287 | render_kwargs_test = {k : render_kwargs_train[k] for k 
in render_kwargs_train} 288 | render_kwargs_test['perturb'] = False 289 | render_kwargs_test['raw_noise_std'] = 0. 290 | 291 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 292 | 293 | 294 | class IoR_emissionAbsorptionODE(nn.Module): 295 | 296 | def __init__(self,query_nerf,model_nerf,nerf_inside,query_ior,model_ior,n_samples,step_size,vals,viewdir,mode): 297 | super(IoR_emissionAbsorptionODE, self).__init__() 298 | 299 | self.query_ior = query_ior 300 | self.model_ior = model_ior 301 | 302 | self.query_nerf = query_nerf 303 | self.model_nerf = model_nerf 304 | 305 | self.nerf_inside = nerf_inside 306 | 307 | self.n_samples = n_samples 308 | self.step_size = step_size 309 | self.vals = vals 310 | self.viewdir = viewdir 311 | self.mode = mode 312 | 313 | def forward(self,t,y): 314 | 315 | 316 | pts = y[:,0:3] 317 | ray_dir = y[:,3:6] 318 | transmition_ = y[:,9:10] 319 | 320 | 321 | # querying the original radiance field for the entire scene 322 | with torch.no_grad(): 323 | raw = self.query_nerf(pts,self.viewdir,self.model_nerf) 324 | 325 | rgb = torch.sigmoid(raw[...,:3]) 326 | density = F.relu(raw[...,3:]) 327 | 328 | # determining whether a point is inside the bounding box or not 329 | weights= get_bb_weights(pts,self.vals) 330 | 331 | # finding the points inside the bounding box 332 | 333 | insides = torch.where(weights<1.0)[0] 334 | ior_grad = torch.zeros_like(pts) 335 | steps = torch.ones_like(density)*self.step_size 336 | 337 | 338 | 339 | 340 | # calculating NeRF model only for the points inside the bounding box 341 | if self.mode != 0: 342 | 343 | if len(insides) > 0: 344 | 345 | 346 | # querying the radiance field for the content inside the bounding box 347 | pts_inside = pts[insides,:] 348 | viewdir_inside = self.viewdir[insides,:] 349 | 350 | 351 | if self.mode == 1: 352 | density = density*weights 353 | 354 | if self.mode == 2: 355 | raw = self.query_nerf(pts_inside,viewdir_inside,self.nerf_inside) 356 | rgb_inside = torch.sigmoid(raw[...,:3]) 357 | density_inside = F.relu(raw[...,3:]) 358 | 359 | 360 | # linearly blnding the radiance field for the content inside and outisde the bounding box 361 | density[insides,:] = density[insides,:]*(weights[insides,:])+ density_inside*(1.-weights[insides,:]) 362 | rgb[insides,:] = rgb[insides,:]*(weights[insides,:])+ rgb_inside*(1.-weights[insides,:]) 363 | 364 | 365 | # computing the gradient of IoR 366 | with torch.enable_grad(): 367 | 368 | dn = torch.autograd.functional.vjp(lambda x : self.query_ior(x,self.model_ior), pts_inside ,v = torch.ones_like(pts_inside[:,0:1])) 369 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:]) 370 | 371 | 372 | dv_ds = steps*ior_grad 373 | dx_ds = steps*normalizing(ray_dir) 374 | alpha = 1. - torch.exp(-density*steps/self.n_samples) 375 | alpha = alpha.clip(0.,1.) 
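        # A rough reading of this integration step (assuming the state layout set up in
        # render_rays below): y = [position(0:3), ray direction(3:6), accumulated color(6:9),
        # transmittance(9:10)], integrated by odeint with an Euler scheme over t in [0, 1]
        # using N_samples steps, so each step has dt ~ 1/N_samples.
        #  - dx_ds advances the ray along its re-normalized direction and dv_ds adds the
        #    IoR gradient (eikonal ray bending), both scaled by step_size = far - near.
        #  - dc_dt / dT_dt below are the emission-absorption updates; multiplying by
        #    self.n_samples cancels the ~1/N_samples Euler step, so one step contributes
        #    roughly T*rgb*alpha to the color and scales T by (1 - alpha), matching standard
        #    NeRF compositing with alpha = 1 - exp(-sigma * step_size / N_samples).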
376 | dc_dt = transmition_*rgb*alpha*self.n_samples 377 | dT_dt = -transmition_*alpha*self.n_samples 378 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1) 379 | 380 | 381 | return dy_dt 382 | 383 | 384 | def render_rays(ray_batch, 385 | N_samples, 386 | network_query_fn_ior, 387 | bb_vals, 388 | network_fine, 389 | network_ior, 390 | network_inside, 391 | mode, 392 | network_fn=None, 393 | network_query_fn=None, 394 | retraw=False, 395 | lindisp=False, 396 | perturb=0., 397 | N_importance=0, 398 | white_bkgd=False, 399 | raw_noise_std=0., 400 | verbose=False, 401 | pytest=False): 402 | 403 | N_rays = ray_batch.shape[0] 404 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 405 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 406 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 407 | 408 | t_vals = torch.linspace(0., 1., steps=N_samples) 409 | 410 | 411 | pts = rays_o + normalizing(rays_d)*near 412 | viewdir = normalizing(rays_d) 413 | color_ = torch.zeros_like(pts) 414 | transmition_ = torch.ones((N_rays,1)) 415 | 416 | 417 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device) 418 | step_size = (far[0]-near[0]) 419 | output = odeint(IoR_emissionAbsorptionODE(network_query_fn,network_fine,network_inside,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,viewdir,mode), y0, t_vals,method='euler') 420 | 421 | pts = output[:,:,0:3].permute(1,0,2) 422 | 423 | rgb_map = output[-1,:,6:9] 424 | 425 | 426 | ret = {'rgb_map' : rgb_map} 427 | if retraw: 428 | ret['raw'] = pts 429 | 430 | for k in ret: 431 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 432 | print(f"! [Numerical Error] {k} contains nan or inf.") 433 | 434 | return ret 435 | 436 | 437 | 438 | def config_parser(): 439 | 440 | import configargparse 441 | parser = configargparse.ArgumentParser() 442 | parser.add_argument('--config', is_config_file=True, 443 | help='config file path') 444 | parser.add_argument("--expname", default='glass_ball_new',type=str, 445 | help='experiment name') 446 | parser.add_argument("--basedir", type=str, default='logs', 447 | help='where to store ckpts and logs') 448 | parser.add_argument("--datadir", type=str, default='data\\glass_ball', 449 | help='input data directory') 450 | parser.add_argument("--render_from_path", type=str, default=None, 451 | help='input data directory') 452 | 453 | 454 | # training options 455 | parser.add_argument("--netdepth", type=int, default=8, 456 | help='layers in network') 457 | parser.add_argument("--netwidth", type=int, default=256, 458 | help='channels per layer') 459 | parser.add_argument("--netdepth_fine", type=int, default=8, 460 | help='layers in fine network') 461 | parser.add_argument("--netwidth_fine", type=int, default=256, 462 | help='channels per layer in fine network') 463 | parser.add_argument("--N_rand", type=int, default=32*32, 464 | help='batch size (number of random rays per gradient step)') 465 | parser.add_argument("--lrate", type=float, default=5e-4, 466 | help='learning rate') 467 | parser.add_argument("--lrate_decay", type=int, default=250, 468 | help='exponential learning rate decay (in 1000 steps)') 469 | parser.add_argument("--chunk", type=int, default=1024*32*32, 470 | help='number of rays processed in parallel, decrease if running out of memory') 471 | parser.add_argument("--netchunk", type=int, default=1024*64*32, 472 | help='number of pts sent through network in parallel, decrease if running out of memory') 473 | parser.add_argument("--no_batching", 
default=False,action='store_true', 474 | help='only take random rays from 1 image at a time') 475 | parser.add_argument("--no_reload", action='store_true', 476 | help='do not reload weights from saved ckpt') 477 | parser.add_argument("--model_path", type=str, default='model_weights', 478 | help='path to trained model weights') 479 | 480 | 481 | # IoR model parameters 482 | 483 | parser.add_argument("--netdepth_ior", type=int, default=6, 484 | help='layers in the IoR network') 485 | parser.add_argument("--netwidth_ior", type=int, default=64, 486 | help='channels per layer') 487 | parser.add_argument("--multires_views_ior", type=int, default=5, 488 | help='log2 of max freq for positional encoding (2D direction)') 489 | 490 | 491 | # NeRF model inside parameters 492 | 493 | parser.add_argument("--netwidth_inside", type=int, default=128, 494 | help='layers in the IoR network') 495 | 496 | # NeRF model inside parameters 497 | 498 | parser.add_argument("--mode", type=int, default=1, 499 | help='rendering mode: 0:NeRF 1:NeRF+IoR 2:NeRF+IOR+NeRF_inside') 500 | 501 | 502 | 503 | # rendering options 504 | parser.add_argument("--N_samples", type=int, default=512, 505 | help='number of coarse samples per ray') 506 | parser.add_argument("--N_importance", type=int, 507 | help='Not used in this code') 508 | parser.add_argument("--perturb", type=float, default=1., 509 | help='set to 0. for no jitter, 1. for jitter') 510 | parser.add_argument("--use_viewdirs", default=True,action='store_true', 511 | help='use full 5D input instead of 3D') 512 | parser.add_argument("--use_mask", default=False,action='store_true', 513 | help='use full 5D input instead of 3D') 514 | parser.add_argument("--i_embed", type=int, default=0, 515 | help='set 0 for default positional encoding, -1 for none') 516 | parser.add_argument("--multires", type=int, default=10, 517 | help='log2 of max freq for positional encoding (3D location)') 518 | parser.add_argument("--multires_views", type=int, default=4, 519 | help='log2 of max freq for positional encoding (2D direction)') 520 | parser.add_argument("--raw_noise_std", type=float, default=1., 521 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 522 | 523 | parser.add_argument("--render_video", action='store_true', 524 | help='do not optimize, reload weights and render out render_poses path') 525 | parser.add_argument("--render_test", action='store_true', 526 | help='render the test set instead of render_poses path') 527 | parser.add_argument("--render_factor", type=int, default=0, 528 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 529 | 530 | # training options 531 | parser.add_argument("--precrop_iters", type=int, default=0, 532 | help='number of steps to train on central crops') 533 | parser.add_argument("--precrop_frac", type=float, 534 | default=.5, help='fraction of img taken for central crops') 535 | 536 | # dataset options 537 | parser.add_argument("--dataset_type", type=str, default='llff', 538 | help='options: llff / blender / deepvoxels') 539 | parser.add_argument("--testskip", type=int, default=8, 540 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 541 | 542 | ## deepvoxels flags 543 | parser.add_argument("--shape", type=str, default='greek', 544 | help='options : armchair / cube / greek / vase') 545 | 546 | ## blender flags 547 | parser.add_argument("--white_bkgd", action='store_true', 548 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 
549 | parser.add_argument("--half_res", action='store_true', 550 | help='load blender synthetic data at 400x400 instead of 800x800') 551 | 552 | ## llff flags 553 | parser.add_argument("--factor", type=int, default=1, 554 | help='downsample factor for LLFF images') 555 | parser.add_argument("--no_ndc", default = True,action='store_true', 556 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 557 | parser.add_argument("--lindisp", action='store_true', 558 | help='sampling linearly in disparity rather than depth') 559 | parser.add_argument("--spherify", default = True, action='store_true', 560 | help='set for spherical 360 scenes') 561 | parser.add_argument("--llffhold", type=int, default=10, 562 | help='will take every 1/N images as LLFF test set, paper uses 8') 563 | 564 | # logging/saving options 565 | parser.add_argument("--i_print", type=int, default=100, 566 | help='frequency of console printout and metric loggin') 567 | parser.add_argument("--i_img", type=int, default=100, 568 | help='frequency of tensorboard image logging') 569 | parser.add_argument("--i_weights", type=int, default=500, 570 | help='frequency of weight ckpt saving') 571 | 572 | 573 | return parser 574 | 575 | 576 | def main(): 577 | 578 | parser = config_parser() 579 | args = parser.parse_args() 580 | 581 | # Load data 582 | K = None 583 | if args.dataset_type == 'llff': 584 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 585 | recenter=True, bd_factor=.75, 586 | spherify=args.spherify,path_video=args.render_from_path) #_ 587 | hwf = render_poses[0,:3,-1] 588 | poses = poses[:,:3,:4] 589 | 590 | # plt.plot(poses[:,0,3],poses[:,1,3],'.') 591 | # plt.plot(render_poses[:,0,3],render_poses[:,1,3],'.') 592 | # plt.show() 593 | 594 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 595 | if not isinstance(i_test, list): 596 | i_test = [i_test] 597 | 598 | if args.llffhold > 0: 599 | print('Auto LLFF holdout,', args.llffhold) 600 | i_test = np.arange(images.shape[0])[::args.llffhold] 601 | 602 | i_val = i_test 603 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 604 | (i not in i_test and i not in i_val)]) 605 | 606 | print('DEFINING BOUNDS') 607 | if args.no_ndc: 608 | near = np.ndarray.min(bds) * .9 609 | far = np.ndarray.max(bds) * 1. 610 | 611 | else: 612 | near = 0. 613 | far = 1. 
614 | if args.spherify: 615 | far = np.minimum(far,2.5) 616 | print('NEAR FAR', near, far) 617 | 618 | # Cast intrinsics to right types 619 | H, W, focal = hwf 620 | H, W = int(H), int(W) 621 | hwf = [H, W, focal] 622 | 623 | if K is None: 624 | K = np.array([ 625 | [focal, 0, 0.5*W], 626 | [0, focal, 0.5*H], 627 | [0, 0, 1] 628 | ]) 629 | 630 | 631 | # Create log dir and copy the config file 632 | basedir = args.basedir 633 | expname = args.expname 634 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 635 | f = os.path.join(basedir, expname, 'args.txt') 636 | with open(f, 'w') as file: 637 | for arg in sorted(vars(args)): 638 | attr = getattr(args, arg) 639 | file.write('{} = {}\n'.format(arg, attr)) 640 | if args.config is not None: 641 | f = os.path.join(basedir, expname, 'config.txt') 642 | with open(f, 'w') as file: 643 | file.write(open(args.config, 'r').read()) 644 | 645 | 646 | 647 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args) 648 | 649 | 650 | 651 | bds_dict = { 652 | 'near' : near, 653 | 'far' : far, 654 | } 655 | 656 | render_kwargs_test.update(bds_dict) 657 | render_kwargs_test['mode'] = args.mode 658 | 659 | print('loading bounding box values') 660 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy')) 661 | 662 | render_kwargs_test['bb_vals'] = bounding_box_vals 663 | 664 | 665 | with torch.no_grad(): 666 | 667 | 668 | 669 | if args.render_from_path is not None: 670 | 671 | end_ = len(render_poses) - len(poses) 672 | render_poses = torch.Tensor(render_poses[:end_]).to(device) 673 | 674 | 675 | testsavedir = os.path.join(basedir, expname, 'rendered_from_a_path') 676 | os.makedirs(testsavedir, exist_ok=True) 677 | print('test poses shape', render_poses.shape) 678 | 679 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor) 680 | print('Done rendering', testsavedir) 681 | imageio.mimwrite(os.path.join(testsavedir, 'rendred.mp4'), to8b(rgbs), fps=10, quality=8) 682 | 683 | 684 | if args.render_test: 685 | 686 | render_poses = np.array(poses[i_test]) 687 | render_poses = torch.Tensor(render_poses).to(device) 688 | 689 | 690 | testsavedir = os.path.join(basedir, expname, 'test_imgs') 691 | os.makedirs(testsavedir, exist_ok=True) 692 | print('test poses shape', render_poses.shape) 693 | 694 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor) 695 | 696 | if args.render_video: 697 | 698 | render_poses = torch.Tensor(render_poses).to(device) 699 | 700 | testsavedir = os.path.join(basedir, expname, 'rendered_video') 701 | os.makedirs(testsavedir, exist_ok=True) 702 | print('test poses shape', render_poses.shape) 703 | 704 | rgbs = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir, render_factor=args.render_factor) 705 | print('Done rendering', testsavedir) 706 | imageio.mimwrite(os.path.join(testsavedir, 'rendred.mp4'), to8b(rgbs), fps=10, quality=8) 707 | 708 | 709 | if __name__=='__main__': 710 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 711 | 712 | main() 713 | 714 | 715 | -------------------------------------------------------------------------------- /run_nerf_inside.py: -------------------------------------------------------------------------------- 1 | import os, sys,cv2 2 | import numpy as np 3 | import imageio 4 | import json 5 | 
import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm, trange 11 | from torchdiffeq import odeint_adjoint as odeint 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from run_nerf_helpers import * 16 | 17 | from load_llff import load_llff_data 18 | from load_deepvoxels import load_dv_data 19 | from load_blender import load_blender_data 20 | from load_LINEMOD import load_LINEMOD_data 21 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 22 | 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | np.random.seed(0) 26 | DEBUG = False 27 | 28 | 29 | def get_rnd(all_rays,all_idx_inside,num_ray): 30 | 31 | 32 | all_ = [] 33 | chunk_per_im = num_ray//len(all_idx_inside) 34 | for i in range(len(all_idx_inside)): 35 | 36 | idx_inside = all_idx_inside[i][np.random.choice(len(all_idx_inside[i]), size=[chunk_per_im], replace=False)] 37 | all_.append(all_rays[i][idx_inside]) 38 | 39 | return torch.cat(all_,0) 40 | 41 | def get_mask_idx(path_to_mask,num_im,i_train): 42 | 43 | mask_ = [] 44 | for img_i in range(num_im): 45 | 46 | im = imageio.imread(os.path.join(path_to_mask, 'img_%0.3d.png'%(img_i))) 47 | im = np.float32(im)/255.0 48 | im = im.reshape([-1]) 49 | inside = np.where(im== 1.0)[0] 50 | mask_.append(inside) 51 | 52 | 53 | mask = [mask_[i] for i in i_train] 54 | 55 | return mask 56 | 57 | 58 | 59 | def batchify(fn, chunk): 60 | """Constructs a version of 'fn' that applies to smaller batches. 61 | """ 62 | if chunk is None: 63 | return fn 64 | def ret(inputs): 65 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 66 | return ret 67 | 68 | 69 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 70 | """Prepares inputs and applies network 'fn'. 71 | """ 72 | embedded = embed_fn(inputs) 73 | 74 | if viewdirs is not None: 75 | 76 | embedded_dirs = embeddirs_fn(viewdirs) 77 | embedded = torch.cat([embedded, embedded_dirs], -1) 78 | 79 | outputs_flat = batchify(fn, netchunk)(embedded) 80 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 81 | return outputs 82 | 83 | 84 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64): 85 | 86 | embedded = embed_fn(inputs) 87 | 88 | 89 | outputs = batchify(fn, netchunk)(embedded) 90 | 91 | return outputs 92 | 93 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 94 | """Render rays in smaller minibatches to avoid OOM. 95 | """ 96 | all_ret = {} 97 | for i in range(0, rays_flat.shape[0], chunk): 98 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 99 | for k in ret: 100 | if k not in all_ret: 101 | all_ret[k] = [] 102 | all_ret[k].append(ret[k]) 103 | 104 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 105 | return all_ret 106 | 107 | 108 | 109 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 110 | near=0., far=1., 111 | use_viewdirs=False, c2w_staticcam=None, 112 | **kwargs): 113 | """Render rays 114 | Args: 115 | H: int. Height of image in pixels. 116 | W: int. Width of image in pixels. 117 | focal: float. Focal length of pinhole camera. 118 | chunk: int. Maximum number of rays to process simultaneously. Used to 119 | control maximum memory usage. Does not affect final results. 120 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 121 | each example in batch. 122 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 123 | ndc: bool. 
If True, represent ray origin, direction in NDC coordinates. 124 | near: float or array of shape [batch_size]. Nearest distance for a ray. 125 | far: float or array of shape [batch_size]. Farthest distance for a ray. 126 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 127 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 128 | camera while using other c2w argument for viewing directions. 129 | Returns: 130 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 131 | disp_map: [batch_size]. Disparity map. Inverse of depth. 132 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 133 | extras: dict with everything returned by render_rays(). 134 | """ 135 | if c2w is not None: 136 | # special case to render full image 137 | rays_o, rays_d = get_rays(H, W, K, c2w) 138 | else: 139 | # use provided ray batch 140 | rays_o, rays_d = rays 141 | 142 | if use_viewdirs: 143 | # provide ray directions as input 144 | viewdirs = rays_d 145 | if c2w_staticcam is not None: 146 | # special case to visualize effect of viewdirs 147 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 148 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 149 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 150 | 151 | sh = rays_d.shape # [..., 3] 152 | if ndc: 153 | # for forward facing scenes 154 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 155 | 156 | # Create ray batch 157 | rays_o = torch.reshape(rays_o, [-1,3]).float() 158 | rays_d = torch.reshape(rays_d, [-1,3]).float() 159 | 160 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 161 | rays = torch.cat([rays_o, rays_d, near, far], -1) 162 | if use_viewdirs: 163 | rays = torch.cat([rays, viewdirs], -1) 164 | 165 | # Render and reshape 166 | all_ret = batchify_rays(rays, chunk, **kwargs) 167 | for k in all_ret: 168 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 169 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 170 | 171 | k_extract = ['rgb_map'] 172 | ret_list = [all_ret[k] for k in k_extract] 173 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 174 | return ret_list + [ret_dict] 175 | 176 | 177 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 178 | 179 | H, W, focal = hwf 180 | 181 | if render_factor!=0: 182 | # Render downsampled for speed 183 | H = H//render_factor 184 | W = W//render_factor 185 | focal = focal/render_factor 186 | 187 | rgbs = [] 188 | # disps = [] 189 | 190 | t = time.time() 191 | for i, c2w in enumerate(tqdm(render_poses)): 192 | print(i, time.time() - t) 193 | t = time.time() 194 | rgb, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 195 | rgbs.append(rgb.cpu().numpy()) 196 | # disps.append(disp.cpu().numpy()) 197 | # if i==0: 198 | # print(rgb.shape, disp.shape) 199 | 200 | """ 201 | if gt_imgs is not None and render_factor==0: 202 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 203 | print(p) 204 | """ 205 | 206 | if savedir is not None: 207 | rgb8 = to8b(rgbs[-1]) 208 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 209 | imageio.imwrite(filename, rgb8) 210 | 211 | 212 | rgbs = np.stack(rgbs, 0) 213 | # disps = np.stack(disps, 0) 214 | 215 | return rgbs 216 | 217 | 218 | def create_models(args): 219 | """Instantiate NeRF's MLP model. 
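    Same construction as in render_model.py: coarse/fine NeRF MLPs, the IoR MLP and an
    inside-the-box NeRF. Here the scene NeRF and IoR weights are always reloaded from the
    latest checkpoint, the inside-NeRF weights only if the checkpoint contains them and
    args.no_reload is False (it defaults to True in this script), and the optimizer again
    trains the inside NeRF alone.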
220 | """ 221 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 222 | 223 | input_ch_views = 0 224 | embeddirs_fn = None 225 | if args.use_viewdirs: 226 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 227 | output_ch = 4 228 | skips = [4] 229 | model = NeRF(D=args.netdepth, W=args.netwidth, 230 | input_ch=input_ch, output_ch=output_ch, skips=skips, 231 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 232 | 233 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 234 | input_ch=input_ch, output_ch=output_ch, skips=skips, 235 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 236 | 237 | 238 | 239 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed) 240 | model_ior = MLP_IOR(input_ch = input_ch_ior,D=args.netdepth_ior, W= args.netwidth_ior ,skips=[3]).to(device) 241 | model_ior.apply(init_weights) 242 | 243 | model_inside = NeRF(D=args.netdepth, W=args.netwidth_inside, 244 | input_ch=input_ch, output_ch=output_ch, skips=skips, 245 | input_ch_views=input_ch_views, use_viewdirs=True).to(device) 246 | 247 | grad_vars = list(model_inside.parameters()) 248 | 249 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 250 | embed_fn=embed_fn, 251 | embeddirs_fn=embeddirs_fn, 252 | netchunk=args.netchunk) 253 | 254 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn, 255 | embed_fn = embed_fn_ior, 256 | netchunk=args.netchunk) 257 | # Create optimizer 258 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 259 | 260 | start = 0 261 | basedir = args.basedir 262 | expname = args.expname 263 | 264 | ########################## 265 | 266 | # Load checkpoints 267 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f] 268 | print('Trainined model path:',os.path.join(basedir, expname,args.model_path)) 269 | 270 | 271 | print('Found ckpts', ckpts) 272 | 273 | if len(ckpts) > 0: 274 | ckpt_path = ckpts[-1] 275 | print('Reloading from', ckpt_path) 276 | ckpt = torch.load(ckpt_path) 277 | 278 | start = ckpt['global_step'] 279 | 280 | # Load model 281 | model.load_state_dict(ckpt['network_fn_state_dict']) 282 | 283 | if model_fine is not None: 284 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 285 | 286 | model_ior.load_state_dict(ckpt['network_ior_state_dict']) 287 | print('Reloading the weights for the NeRF and IoR MLPs') 288 | 289 | if "network_inside_state_dict" in ckpt and not args.no_reload: 290 | model_inside.load_state_dict(ckpt['network_inside_state_dict']) 291 | print('Reloading the weights for the inside NeRF MLP') 292 | 293 | 294 | 295 | ########################## 296 | 297 | render_kwargs_train = { 298 | 299 | 'network_query_fn' : network_query_fn, 300 | 'network_query_fn_ior': network_query_fn_ior, 301 | 'perturb' : args.perturb, 302 | 'N_samples' : args.N_samples, 303 | 'network_fn' : model, 304 | 'network_fine' : model_fine, 305 | 'network_ior': model_ior, 306 | 'network_inside':model_inside, 307 | 'use_viewdirs' : args.use_viewdirs, 308 | 'white_bkgd' : args.white_bkgd, 309 | 'raw_noise_std' : args.raw_noise_std, 310 | } 311 | 312 | # NDC only good for LLFF-style forward facing data 313 | if args.dataset_type != 'llff' or args.no_ndc: 314 | print('Not ndc!') 315 | render_kwargs_train['ndc'] = False 316 | render_kwargs_train['lindisp'] 
= args.lindisp 317 | 318 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 319 | render_kwargs_test['perturb'] = False 320 | render_kwargs_test['raw_noise_std'] = 0. 321 | 322 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 323 | 324 | 325 | 326 | class IoR_emissionAbsorptionODE(nn.Module): 327 | 328 | def __init__(self,query_nerf,model_nerf,nerf_inside,query_ior,model_ior,n_samples,step_size,vals,viewdir): 329 | super(IoR_emissionAbsorptionODE, self).__init__() 330 | 331 | self.query_ior = query_ior 332 | self.model_ior = model_ior 333 | 334 | self.query_nerf = query_nerf 335 | self.model_nerf = model_nerf 336 | 337 | self.nerf_inside = nerf_inside 338 | 339 | self.n_samples = n_samples 340 | self.step_size = step_size 341 | self.vals = vals 342 | self.viewdir = viewdir 343 | 344 | def forward(self,t,y): 345 | 346 | 347 | pts = y[:,0:3] 348 | ray_dir = y[:,3:6] 349 | transmition_ = y[:,9:10] 350 | 351 | 352 | # querying the original radiance field for the entire scene 353 | with torch.no_grad(): 354 | raw = self.query_nerf(pts,self.viewdir,self.model_nerf) 355 | 356 | rgb = torch.sigmoid(raw[...,:3]) 357 | density = F.relu(raw[...,3:]) 358 | 359 | # determining whether a point is inside the bounding box or not 360 | weights= get_bb_weights(pts,self.vals) 361 | 362 | # finding the points inside the bounding box 363 | 364 | insides = torch.where(weights<1.0)[0] 365 | ior_grad = torch.zeros_like(pts) 366 | steps = torch.ones_like(density)*self.step_size 367 | 368 | 369 | # calculating NeRF model only for the points inside the bounding box 370 | 371 | if len(insides) > 0: 372 | 373 | 374 | # querying the radiance field for the content inside the bounding box 375 | 376 | pts_inside = pts[insides,:] 377 | viewdir_inside = self.viewdir[insides,:] 378 | raw = self.query_nerf(pts_inside,viewdir_inside,self.nerf_inside) 379 | rgb_inside = torch.sigmoid(raw[...,:3]) 380 | density_inside = F.relu(raw[...,3:]) 381 | 382 | 383 | # linearly blnding the radiance field for the content inside and outisde the bounding box 384 | density[insides,:] = density[insides,:]*(weights[insides,:])+ density_inside*(1.-weights[insides,:]) 385 | rgb[insides,:] = rgb[insides,:]*(weights[insides,:])+ rgb_inside*(1.-weights[insides,:]) 386 | 387 | # computing the gradient of IoR 388 | with torch.enable_grad(): 389 | 390 | dn = torch.autograd.functional.vjp(lambda x : self.query_ior(x,self.model_ior), pts_inside ,v = torch.ones_like(pts_inside[:,0:1])) 391 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:]) 392 | 393 | 394 | dv_ds = steps*ior_grad 395 | dx_ds = steps*normalizing(ray_dir) 396 | alpha = 1. - torch.exp(-density*steps/self.n_samples) 397 | alpha = alpha.clip(0.,1.) 
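        # Note on gradient flow in this training script (as the code reads): the scene-wide
        # NeRF above is queried under torch.no_grad(), and create_models() hands only
        # model_inside's parameters to the optimizer, so the Euler-integrated color computed
        # below only updates the inside-the-box NeRF, while the pretrained radiance and IoR
        # networks remain fixed.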
398 | dc_dt = transmition_*rgb*alpha*self.n_samples 399 | dT_dt = -transmition_*alpha*self.n_samples 400 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1) 401 | 402 | 403 | return dy_dt 404 | 405 | 406 | def render_rays(ray_batch, 407 | N_samples, 408 | network_query_fn_ior, 409 | bb_vals, 410 | network_ior, 411 | network_inside, 412 | network_fn=None, 413 | network_query_fn=None, 414 | retraw=False, 415 | lindisp=False, 416 | perturb=0., 417 | network_fine=None, 418 | white_bkgd=False, 419 | raw_noise_std=0., 420 | verbose=False, 421 | pytest=False): 422 | 423 | N_rays = ray_batch.shape[0] 424 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 425 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 426 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 427 | 428 | t_vals = torch.linspace(0., 1., steps=N_samples) 429 | 430 | 431 | pts = rays_o + normalizing(rays_d)*near 432 | viewdir = normalizing(rays_d) 433 | color_ = torch.zeros_like(pts) 434 | transmition_ = torch.ones((N_rays,1)) 435 | 436 | 437 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device) 438 | step_size = far[0]-near[0] 439 | output = odeint(IoR_emissionAbsorptionODE(network_query_fn,network_fine,network_inside,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,viewdir), y0, t_vals,method='euler') 440 | 441 | pts = output[:,:,0:3].permute(1,0,2) 442 | 443 | rgb_map = output[-1,:,6:9] 444 | 445 | 446 | ret = {'rgb_map' : rgb_map} 447 | if retraw: 448 | ret['raw'] = pts 449 | 450 | for k in ret: 451 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 452 | print(f"! [Numerical Error] {k} contains nan or inf.") 453 | 454 | return ret 455 | 456 | def config_parser(): 457 | 458 | import configargparse 459 | parser = configargparse.ArgumentParser() 460 | parser.add_argument('--config', is_config_file=True, 461 | help='config file path') 462 | parser.add_argument("--expname", default='glass_ball_new',type=str, 463 | help='experiment name') 464 | parser.add_argument("--basedir", type=str, default='logs', 465 | help='where to store ckpts and logs') 466 | parser.add_argument("--datadir", type=str, default='data\\glass_ball', 467 | help='input data directory') 468 | 469 | # training options 470 | parser.add_argument("--netdepth", type=int, default=8, 471 | help='layers in network') 472 | parser.add_argument("--netwidth", type=int, default=256, 473 | help='channels per layer') 474 | parser.add_argument("--netdepth_fine", type=int, default=8, 475 | help='layers in fine network') 476 | parser.add_argument("--netwidth_fine", type=int, default=256, 477 | help='channels per layer in fine network') 478 | parser.add_argument("--N_rand", type=int, default=32*32, 479 | help='batch size (number of random rays per gradient step)') 480 | parser.add_argument("--lrate", type=float, default=5e-4, 481 | help='learning rate') 482 | parser.add_argument("--lrate_decay", type=int, default=250, 483 | help='exponential learning rate decay (in 1000 steps)') 484 | parser.add_argument("--chunk", type=int, default=1024*32*32, 485 | help='number of rays processed in parallel, decrease if running out of memory') 486 | parser.add_argument("--netchunk", type=int, default=1024*64*32, 487 | help='number of pts sent through network in parallel, decrease if running out of memory') 488 | parser.add_argument("--no_batching", default=False,action='store_true', 489 | help='only take random rays from 1 image at a time') 490 | parser.add_argument("--no_reload", default=True, action='store_true', 491 | 
help='do not reload weights of MLP for NeRF inside') 492 | parser.add_argument("--model_path", type=str, default='model_weights', 493 | help='path to trained model weights') 494 | 495 | 496 | # IoR model parameters 497 | 498 | parser.add_argument("--netdepth_ior", type=int, default=6, 499 | help='layers in the IoR network') 500 | parser.add_argument("--netwidth_ior", type=int, default=64, 501 | help='channels per layer') 502 | parser.add_argument("--multires_views_ior", type=int, default=5, 503 | help='log2 of max freq for positional encoding (2D direction)') 504 | 505 | 506 | # NeRF model inside parameters 507 | 508 | parser.add_argument("--netwidth_inside", type=int, default=128, 509 | help='layers in the IoR network') 510 | 511 | 512 | # rendering options 513 | parser.add_argument("--N_samples", type=int, default=512, 514 | help='number of coarse samples per ray') 515 | parser.add_argument("--N_importance", type=int, 516 | help='Not used in this code') 517 | parser.add_argument("--perturb", type=float, default=1., 518 | help='set to 0. for no jitter, 1. for jitter') 519 | parser.add_argument("--use_viewdirs", default=True,action='store_true', 520 | help='use full 5D input instead of 3D') 521 | parser.add_argument("--use_mask", default=False,action='store_true', 522 | help='use full 5D input instead of 3D') 523 | parser.add_argument("--i_embed", type=int, default=0, 524 | help='set 0 for default positional encoding, -1 for none') 525 | parser.add_argument("--multires", type=int, default=10, 526 | help='log2 of max freq for positional encoding (3D location)') 527 | parser.add_argument("--multires_views", type=int, default=4, 528 | help='log2 of max freq for positional encoding (2D direction)') 529 | parser.add_argument("--raw_noise_std", type=float, default=1., 530 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 531 | 532 | parser.add_argument("--render_only", action='store_true', 533 | help='do not optimize, reload weights and render out render_poses path') 534 | parser.add_argument("--render_test", action='store_true', 535 | help='render the test set instead of render_poses path') 536 | parser.add_argument("--render_factor", type=int, default=0, 537 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 538 | 539 | # training options 540 | parser.add_argument("--precrop_iters", type=int, default=0, 541 | help='number of steps to train on central crops') 542 | parser.add_argument("--precrop_frac", type=float, 543 | default=.5, help='fraction of img taken for central crops') 544 | 545 | # dataset options 546 | parser.add_argument("--dataset_type", type=str, default='llff', 547 | help='options: llff / blender / deepvoxels') 548 | parser.add_argument("--testskip", type=int, default=8, 549 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 550 | 551 | ## deepvoxels flags 552 | parser.add_argument("--shape", type=str, default='greek', 553 | help='options : armchair / cube / greek / vase') 554 | 555 | ## blender flags 556 | parser.add_argument("--white_bkgd", action='store_true', 557 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 558 | parser.add_argument("--half_res", action='store_true', 559 | help='load blender synthetic data at 400x400 instead of 800x800') 560 | 561 | ## llff flags 562 | parser.add_argument("--factor", type=int, default=1, 563 | help='downsample factor for LLFF images') 564 | parser.add_argument("--no_ndc", default = True,action='store_true', 
565 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 566 | parser.add_argument("--lindisp", action='store_true', 567 | help='sampling linearly in disparity rather than depth') 568 | parser.add_argument("--spherify", default = True, action='store_true', 569 | help='set for spherical 360 scenes') 570 | parser.add_argument("--llffhold", type=int, default=10, 571 | help='will take every 1/N images as LLFF test set, paper uses 8') 572 | 573 | # logging/saving options 574 | parser.add_argument("--i_print", type=int, default=100, 575 | help='frequency of console printout and metric loggin') 576 | parser.add_argument("--i_img", type=int, default=100, 577 | help='frequency of tensorboard image logging') 578 | parser.add_argument("--i_weights", type=int, default=500, 579 | help='frequency of weight ckpt saving') 580 | 581 | 582 | return parser 583 | 584 | 585 | def train(): 586 | 587 | parser = config_parser() 588 | args = parser.parse_args() 589 | 590 | # Load data 591 | K = None 592 | if args.dataset_type == 'llff': 593 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 594 | recenter=True, bd_factor=.75, 595 | spherify=args.spherify) 596 | hwf = poses[0,:3,-1] 597 | poses = poses[:,:3,:4] 598 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 599 | if not isinstance(i_test, list): 600 | i_test = [i_test] 601 | 602 | if args.llffhold > 0: 603 | print('Auto LLFF holdout,', args.llffhold) 604 | i_test = np.arange(images.shape[0])[::args.llffhold] 605 | 606 | i_val = i_test 607 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 608 | (i not in i_test and i not in i_val)]) 609 | 610 | print('DEFINING BOUNDS') 611 | if args.no_ndc: 612 | near = np.ndarray.min(bds) * .9 613 | far = np.ndarray.max(bds) * 1. 614 | 615 | else: 616 | near = 0. 617 | far = 1. 
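        # Bounds recap: with no_ndc (the default here) near/far come straight
        # from the LLFF depth bounds, near = 0.9 * min(bds) and far = max(bds);
        # for spherified 360 captures far is additionally clamped to 2.5 below,
        # which keeps step_size = far - near in render_rays() from growing large.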
618 | 619 | 620 | if args.spherify: 621 | far = np.minimum(far,2.5) 622 | print('NEAR FAR', near, far) 623 | 624 | # Cast intrinsics to right types 625 | H, W, focal = hwf 626 | H, W = int(H), int(W) 627 | hwf = [H, W, focal] 628 | 629 | if K is None: 630 | K = np.array([ 631 | [focal, 0, 0.5*W], 632 | [0, focal, 0.5*H], 633 | [0, 0, 1] 634 | ]) 635 | 636 | if args.render_test: 637 | render_poses = np.array(poses[i_test]) 638 | 639 | # Create log dir and copy the config file 640 | basedir = args.basedir 641 | expname = args.expname 642 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 643 | f = os.path.join(basedir, expname, 'args.txt') 644 | with open(f, 'w') as file: 645 | for arg in sorted(vars(args)): 646 | attr = getattr(args, arg) 647 | file.write('{} = {}\n'.format(arg, attr)) 648 | if args.config is not None: 649 | f = os.path.join(basedir, expname, 'config.txt') 650 | with open(f, 'w') as file: 651 | file.write(open(args.config, 'r').read()) 652 | 653 | 654 | 655 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args) 656 | start = 0 657 | global_step = start 658 | 659 | 660 | bds_dict = { 661 | 'near' : near, 662 | 'far' : far, 663 | } 664 | render_kwargs_train.update(bds_dict) 665 | render_kwargs_test.update(bds_dict) 666 | 667 | 668 | print('ray batching') 669 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 670 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 671 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 672 | rays_rgb = np.reshape(rays_rgb, [len(images),-1,3,3]) # [N, H*W, ro+rd+rgb, 3] 673 | rays_rgb = rays_rgb.astype(np.float32) 674 | rays_rgb = torch.Tensor(rays_rgb).to(device) 675 | rays_rgb = [rays_rgb[i] for i in i_train] # train images only 676 | 677 | 678 | poses = torch.Tensor(poses).to(device) 679 | render_poses = torch.Tensor(render_poses).to(device) 680 | 681 | 682 | print('loading mask images') 683 | path_mask = os.path.join(basedir, expname, 'masked_regions') # collecting the indices of the masked region in each view 684 | inside_idx = get_mask_idx(path_mask,len(poses),i_train) 685 | 686 | 687 | print('loading bounding box values') 688 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy')) 689 | 690 | 691 | render_kwargs_test['bb_vals'] = bounding_box_vals 692 | render_kwargs_train['bb_vals'] = bounding_box_vals 693 | 694 | 695 | 696 | N_iters = 10000 + 1 # total number of iterations 697 | N_rand = args.N_rand 698 | print('Begin') 699 | print('TEST views are', i_test) 700 | 701 | 702 | for i in trange(start,N_iters): 703 | time0 = time.time() 704 | 705 | 706 | # taking random rays over all views 707 | batch = get_rnd(rays_rgb,inside_idx,N_rand) 708 | batch = torch.transpose(batch, 0, 1) 709 | 710 | batch_rays, target_s = batch[:2], batch[2] 711 | 712 | render_kwargs_train['N_samples'] = 256 + 256*i//N_iters # starting with 256 samples and linearly increasing it to 512 samples 713 | 714 | 715 | rgb, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 716 | verbose=i < 10, retraw=False, 717 | **render_kwargs_train) 718 | 719 | optimizer.zero_grad() 720 | loss = img2abs(rgb, target_s) 721 | psnr = mse2psnr(img2mse(rgb, target_s)) 722 | 723 | 724 | 725 | loss.backward() 726 | optimizer.step() 727 | 728 | decay_rate = 0.1 729 | decay_steps = N_iters 730 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 731 | for param_group in
optimizer.param_groups: 732 | param_group['lr'] = new_lrate 733 | ################################ 734 | 735 | dt = time.time()-time0 736 | 737 | if i%args.i_weights==0: 738 | model_path_dir = os.path.join(basedir, expname,args.model_path) 739 | os.makedirs(model_path_dir, exist_ok=True) 740 | path = os.path.join(basedir, expname,args.model_path, 'weights.tar'.format(i)) 741 | torch.save({ 742 | 'global_step': global_step, 743 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 744 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 745 | 'network_ior_state_dict': render_kwargs_train['network_ior'].state_dict(), 746 | 'network_inside_state_dict': render_kwargs_train['network_inside'].state_dict(), 747 | 'optimizer_state_dict': optimizer.state_dict(), 748 | }, path) 749 | print('Saved checkpoints at', path) 750 | 751 | 752 | if i%args.i_img==0: 753 | 754 | img_i= i_test[5] 755 | pose = poses[img_i, :3,:4] 756 | with torch.no_grad(): 757 | rgb ,extras = render(H, W,K, chunk=args.chunk, c2w=pose,retraw=False, 758 | **render_kwargs_test) 759 | 760 | 761 | testimgdir = os.path.join(basedir, expname, 'training_nerf_inside') 762 | os.makedirs(testimgdir, exist_ok=True) 763 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:03d}.png'.format(i)), to8b(rgb.cpu().numpy())) 764 | imageio.imwrite(os.path.join(testimgdir, 'ref_{:03d}.png'.format(img_i)), to8b(images[img_i])) 765 | 766 | 767 | if i%args.i_print==0: 768 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 769 | 770 | 771 | global_step += 1 772 | 773 | 774 | if __name__=='__main__': 775 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 776 | 777 | train() -------------------------------------------------------------------------------- /run_ior.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys, cv2 3 | import numpy as np 4 | import imageio 5 | import json 6 | import random 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from tqdm import tqdm, trange 12 | from torchdiffeq import odeint_adjoint as odeint 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | from run_nerf_helpers import * 17 | 18 | from load_llff import load_llff_data 19 | from load_deepvoxels import load_dv_data 20 | from load_blender import load_blender_data 21 | from load_LINEMOD import load_LINEMOD_data 22 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 23 | 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | np.random.seed(0) 27 | DEBUG = False 28 | 29 | 30 | def get_rnd(all_rays,all_idx_inside,num_ray): 31 | 32 | 33 | all_ = [] 34 | chunk_per_im = num_ray//len(all_idx_inside) 35 | for i in range(len(all_idx_inside)): 36 | 37 | idx_inside = all_idx_inside[i][np.random.choice(len(all_idx_inside[i]), size=[chunk_per_im], replace=False)] 38 | all_.append(all_rays[i][idx_inside]) 39 | 40 | return torch.cat(all_,0) 41 | 42 | 43 | def get_mask_idx(path_to_mask,num_im,i_train): 44 | 45 | mask_ = [] 46 | for img_i in range(num_im): 47 | 48 | im = imageio.imread(os.path.join(path_to_mask, 'img_%0.3d.png'%(img_i))) 49 | im = np.float32(im)/255.0 50 | im = im.reshape([-1]) 51 | inside = np.where(im== 1.0)[0] 52 | mask_.append(inside) 53 | 54 | 55 | mask = [mask_[i] for i in i_train] 56 | 57 | return mask 58 | 59 | 60 | def batchify(fn, chunk): 61 | """Constructs a version of 'fn' that applies to smaller batches. 
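    For example, batchify(fn, 1024)(x) evaluates fn on consecutive 1024-row
    slices of x and concatenates the results, keeping peak memory bounded no
    matter how many points are queried at once.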
62 | """ 63 | if chunk is None: 64 | return fn 65 | def ret(inputs): 66 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 67 | return ret 68 | 69 | 70 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 71 | """Prepares inputs and applies network 'fn'. 72 | """ 73 | embedded = embed_fn(inputs) 74 | 75 | if viewdirs is not None: 76 | embedded_dirs = embeddirs_fn(viewdirs) 77 | 78 | embedded = torch.cat([embedded, embedded_dirs], -1) 79 | 80 | outputs_flat = batchify(fn, netchunk)(embedded) 81 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 82 | return outputs 83 | 84 | 85 | def run_IOR_network(inputs, fn, embed_fn, netchunk=1024*64): 86 | 87 | embedded = embed_fn(inputs) 88 | 89 | 90 | outputs = batchify(fn, netchunk)(embedded) 91 | 92 | return outputs 93 | 94 | 95 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 96 | """Render rays in smaller minibatches to avoid OOM. 97 | """ 98 | all_ret = {} 99 | for i in range(0, rays_flat.shape[0], chunk): 100 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 101 | for k in ret: 102 | if k not in all_ret: 103 | all_ret[k] = [] 104 | all_ret[k].append(ret[k]) 105 | 106 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 107 | return all_ret 108 | 109 | 110 | 111 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 112 | near=0., far=1., 113 | use_viewdirs=False, c2w_staticcam=None, 114 | **kwargs): 115 | """Render rays 116 | Args: 117 | H: int. Height of image in pixels. 118 | W: int. Width of image in pixels. 119 | focal: float. Focal length of pinhole camera. 120 | chunk: int. Maximum number of rays to process simultaneously. Used to 121 | control maximum memory usage. Does not affect final results. 122 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 123 | each example in batch. 124 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 125 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 126 | near: float or array of shape [batch_size]. Nearest distance for a ray. 127 | far: float or array of shape [batch_size]. Farthest distance for a ray. 128 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 129 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 130 | camera while using other c2w argument for viewing directions. 131 | Returns: 132 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 133 | disp_map: [batch_size]. Disparity map. Inverse of depth. 134 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 135 | extras: dict with everything returned by render_rays(). 
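    Note: in this script only 'rgb_map' is actually extracted below
    (k_extract = ['rgb_map']); the ODE-based render_rays() does not produce
    disp_map or acc_map.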
136 | """ 137 | if c2w is not None: 138 | # special case to render full image 139 | rays_o, rays_d = get_rays(H, W, K, c2w) 140 | else: 141 | # use provided ray batch 142 | rays_o, rays_d = rays 143 | 144 | if use_viewdirs: 145 | # provide ray directions as input 146 | viewdirs = rays_d 147 | if c2w_staticcam is not None: 148 | # special case to visualize effect of viewdirs 149 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 150 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 151 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 152 | 153 | sh = rays_d.shape # [..., 3] 154 | if ndc: 155 | # for forward facing scenes 156 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 157 | 158 | # Create ray batch 159 | rays_o = torch.reshape(rays_o, [-1,3]).float() 160 | rays_d = torch.reshape(rays_d, [-1,3]).float() 161 | 162 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 163 | rays = torch.cat([rays_o, rays_d, near, far], -1) 164 | if use_viewdirs: 165 | rays = torch.cat([rays, viewdirs], -1) 166 | 167 | # Render and reshape 168 | all_ret = batchify_rays(rays, chunk, **kwargs) 169 | for k in all_ret: 170 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 171 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 172 | 173 | k_extract = ['rgb_map'] 174 | ret_list = [all_ret[k] for k in k_extract] 175 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 176 | return ret_list + [ret_dict] 177 | 178 | 179 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 180 | 181 | H, W, focal = hwf 182 | 183 | if render_factor!=0: 184 | # Render downsampled for speed 185 | H = H//render_factor 186 | W = W//render_factor 187 | focal = focal/render_factor 188 | 189 | rgbs = [] 190 | disps = [] 191 | 192 | t = time.time() 193 | for i, c2w in enumerate(tqdm(render_poses)): 194 | print(i, time.time() - t) 195 | t = time.time() 196 | rgb, disp, acc, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 197 | rgbs.append(rgb.cpu().numpy()) 198 | disps.append(disp.cpu().numpy()) 199 | if i==0: 200 | print(rgb.shape, disp.shape) 201 | 202 | """ 203 | if gt_imgs is not None and render_factor==0: 204 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 205 | print(p) 206 | """ 207 | 208 | if savedir is not None: 209 | rgb8 = to8b(rgbs[-1]) 210 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 211 | imageio.imwrite(filename, rgb8) 212 | 213 | 214 | rgbs = np.stack(rgbs, 0) 215 | disps = np.stack(disps, 0) 216 | 217 | return rgbs, disps 218 | 219 | 220 | def create_models(args): 221 | """Instantiate NeRF and IOR MLP models. 
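    Builds the coarse/fine NeRF MLPs and the IoR MLP, reloads weights from the
    newest checkpoint under <basedir>/<expname>/<model_path> (the IoR MLP only
    when --no_reload is not set), and returns
    (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer),
    where grad_vars holds only the IoR MLP parameters.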
222 | """ 223 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 224 | 225 | input_ch_views = 0 226 | embeddirs_fn = None 227 | if args.use_viewdirs: 228 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 229 | output_ch = 0 230 | skips = [4] 231 | model = NeRF(D=args.netdepth, W=args.netwidth, 232 | input_ch=input_ch, output_ch=output_ch, skips=skips, 233 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 234 | 235 | 236 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 237 | input_ch=input_ch, output_ch=output_ch, skips=skips, 238 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 239 | 240 | 241 | 242 | embed_fn_ior, input_ch_ior = get_embedder(args.multires_views_ior, args.i_embed) 243 | model_ior = MLP_IOR(input_ch = input_ch_ior,D=args.netdepth_ior, W= args.netwidth_ior ,skips=[3]).to(device) 244 | model_ior.apply(init_weights) 245 | 246 | 247 | # only optimizing for the IoR field 248 | grad_vars = list(model_ior.parameters()) 249 | 250 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 251 | embed_fn=embed_fn, 252 | embeddirs_fn=embeddirs_fn, 253 | netchunk=args.netchunk) 254 | 255 | network_query_fn_ior = lambda inputs, network_fn : run_IOR_network(inputs, network_fn, 256 | embed_fn = embed_fn_ior, 257 | netchunk=args.netchunk) 258 | # Create optimizer 259 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 260 | 261 | 262 | start = 0 263 | basedir = args.basedir 264 | expname = args.expname 265 | 266 | 267 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f] 268 | print('Trainined model path:',os.path.join(basedir, expname,args.model_path)) 269 | 270 | 271 | print('Found ckpts', ckpts) 272 | 273 | if len(ckpts) > 0 : 274 | ckpt_path = ckpts[-1] 275 | print('Reloading from', ckpt_path) 276 | ckpt = torch.load(ckpt_path) 277 | 278 | start = ckpt['global_step'] 279 | 280 | # Load model 281 | model.load_state_dict(ckpt['network_fn_state_dict']) 282 | print('Reloading the weights for the NeRF MLPs') 283 | 284 | if model_fine is not None: 285 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 286 | 287 | if "network_ior_state_dict" in ckpt and not args.no_reload: 288 | model_ior.load_state_dict(ckpt['network_ior_state_dict']) 289 | print('Reloading the weights for the IoR MLP') 290 | 291 | 292 | 293 | ########################## 294 | 295 | render_kwargs_train = { 296 | 297 | 'network_query_fn' : network_query_fn, 298 | 'network_query_fn_ior': network_query_fn_ior, 299 | 'perturb' : args.perturb, 300 | 'network_fine' : model_fine, 301 | 'N_samples' : args.N_samples, 302 | 'network_fn' : model, 303 | 'network_ior': model_ior, 304 | 'use_viewdirs' : args.use_viewdirs, 305 | 'white_bkgd' : args.white_bkgd, 306 | 'raw_noise_std' : args.raw_noise_std, 307 | } 308 | 309 | # NDC only good for LLFF-style forward facing data 310 | if args.dataset_type != 'llff' or args.no_ndc: 311 | print('Not ndc!') 312 | render_kwargs_train['ndc'] = False 313 | render_kwargs_train['lindisp'] = args.lindisp 314 | 315 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 316 | render_kwargs_test['perturb'] = False 317 | render_kwargs_test['raw_noise_std'] = 0. 
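    # The two query closures above are used by the ODE roughly as
    #   raw = network_query_fn(pts, viewdirs, model)    # [..., 4] -> rgb + density
    #   n   = network_query_fn_ior(pts, model_ior)      # [..., 1] -> index of refraction
    # and only the IoR value and its spatial gradient (via
    # torch.autograd.functional.vjp) are needed during IoR optimization;
    # the NeRF weights stay frozen since grad_vars excludes them.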
318 | 319 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 320 | 321 | 322 | 323 | class IoR_emissionAbsorptionODE(nn.Module): 324 | 325 | def __init__(self,voxel_grid,query_fn,model,n_samples,step_size,bb_vals,scene_bound): 326 | super(IoR_emissionAbsorptionODE, self).__init__() 327 | 328 | self.voxel_grid = voxel_grid 329 | self.model = model 330 | self.query_fn = query_fn 331 | self.n_samples = n_samples 332 | self.step_size = step_size 333 | self.bounding_box_vals = bb_vals 334 | self.scene_bound = scene_bound 335 | 336 | def forward(self,t,y): 337 | 338 | 339 | pts = y[:,0:3] 340 | ray_dir = y[:,3:6] 341 | transmition_ = y[:,9:10] 342 | 343 | # query the 3D grid with trilinear interpolation 344 | raw = trilinear_interpolation(self.voxel_grid,pts,self.scene_bound) 345 | rgb = raw[...,:3] 346 | density = raw[...,3:] 347 | 348 | 349 | # determining whether a point is inside the bounding box or not 350 | weights = get_bb_weights(pts,self.bounding_box_vals) 351 | 352 | 353 | # finding the points inside the bounding box 354 | insides = torch.where((weights<1.))[0] 355 | ior_grad = torch.zeros_like(pts) 356 | 357 | 358 | # setting the density to zero for the points inside the bounding box 359 | density= density*weights 360 | 361 | # calculating the gradient of IoR only for the points inside the bounding box 362 | if len(insides) > 0: 363 | pts_inside = pts[insides,:] 364 | 365 | with torch.enable_grad(): 366 | 367 | dn = torch.autograd.functional.vjp(lambda x : self.query_fn(x,self.model), pts_inside ,v = torch.ones_like(pts_inside[:,0:1]),create_graph =True ) 368 | ior_grad[insides,:] = dn[1]*(1.-weights[insides,:]) 369 | 370 | 371 | 372 | dv_ds = self.step_size*ior_grad 373 | dx_ds = self.step_size*normalizing(ray_dir) 374 | alpha = 1. - torch.exp(-density*self.step_size/self.n_samples) 375 | alpha = alpha.clip(0.,1.) 376 | dc_dt = transmition_*rgb*alpha*self.n_samples 377 | dT_dt = -transmition_*alpha*self.n_samples 378 | dy_dt = torch.cat([dx_ds,dv_ds,dc_dt,dT_dt],-1) 379 | 380 | 381 | return dy_dt 382 | 383 | 384 | def render_rays(ray_batch, 385 | voxel_grid, 386 | N_samples, 387 | network_query_fn_ior, 388 | bb_vals, 389 | network_ior, 390 | scene_bound, 391 | network_fn=None, 392 | network_query_fn=None, 393 | retraw=False, 394 | lindisp=False, 395 | perturb=0., 396 | network_fine=None, 397 | white_bkgd=False, 398 | raw_noise_std=0., 399 | verbose=False, 400 | pytest=False): 401 | 402 | N_rays = ray_batch.shape[0] 403 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 404 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 405 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 406 | 407 | t_vals = torch.linspace(0., 1., steps=N_samples) 408 | 409 | 410 | pts = rays_o + normalizing(rays_d)*near 411 | viewdir = normalizing(rays_d) 412 | color_ = torch.zeros_like(pts) 413 | transmition_ = torch.ones((N_rays,1)) 414 | 415 | 416 | y0 = torch.cat((pts,viewdir,color_,transmition_),-1).to(device) 417 | step_size = far[0]-near[0] 418 | output = odeint(IoR_emissionAbsorptionODE(voxel_grid,network_query_fn_ior,network_ior,N_samples,step_size,bb_vals,scene_bound), y0, t_vals,method='euler') 419 | 420 | pts = output[:,:,0:3].permute(1,0,2) 421 | 422 | rgb_map = output[-1,:,6:9] 423 | 424 | 425 | ret = {'rgb_map' : rgb_map} 426 | if retraw: 427 | ret['raw'] = pts 428 | 429 | for k in ret: 430 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 431 | print(f"! 
[Numerical Error] {k} contains nan or inf.") 432 | 433 | return ret 434 | 435 | 436 | 437 | def config_parser(): 438 | 439 | import configargparse 440 | parser = configargparse.ArgumentParser() 441 | parser.add_argument('--config', is_config_file=True, 442 | help='config file path') 443 | parser.add_argument("--expname", default='glass_ball_new',type=str, 444 | help='experiment name') 445 | parser.add_argument("--basedir", type=str, default='logs', 446 | help='where to store ckpts and logs') 447 | parser.add_argument("--datadir", type=str, default='data\\glass_ball', 448 | help='input data directory') 449 | 450 | # training options 451 | parser.add_argument("--netdepth", type=int, default=8, 452 | help='layers in network') 453 | parser.add_argument("--netwidth", type=int, default=256, 454 | help='channels per layer') 455 | parser.add_argument("--netdepth_fine", type=int, default=8, 456 | help='layers in fine network') 457 | parser.add_argument("--netwidth_fine", type=int, default=256, 458 | help='channels per layer in fine network') 459 | parser.add_argument("--N_rand", type=int, default=32*32*32, 460 | help='batch size (number of random rays per gradient step)') 461 | parser.add_argument("--lrate", type=float, default=5e-4, 462 | help='learning rate') 463 | parser.add_argument("--lrate_decay", type=int, default=250, 464 | help='exponential learning rate decay (in 1000 steps)') 465 | parser.add_argument("--chunk", type=int, default=1024*32*32, 466 | help='number of rays processed in parallel, decrease if running out of memory') 467 | parser.add_argument("--netchunk", type=int, default=1024*64, 468 | help='number of pts sent through network in parallel, decrease if running out of memory') 469 | parser.add_argument("--no_batching", default=False,action='store_true', 470 | help='only take random rays from 1 image at a time') 471 | parser.add_argument("--no_reload", default=True, action='store_true', 472 | help='do not reload the weights for the IoR MLP') 473 | parser.add_argument("--model_path", type=str, default='model_weights', 474 | help='path to trained model weights') 475 | 476 | 477 | # IoR model parameters 478 | 479 | parser.add_argument("--netdepth_ior", type=int, default=6, 480 | help='layers in the IoR network') 481 | parser.add_argument("--netwidth_ior", type=int, default=64, 482 | help='channels per layer') 483 | parser.add_argument("--multires_views_ior", type=int, default=5, 484 | help='log2 of max freq for positional encoding (2D direction)') 485 | 486 | 487 | # rendering options 488 | parser.add_argument("--N_samples", type=int, default=128, 489 | help='number of coarse samples per ray') 490 | parser.add_argument("--N_importance", type=int, 491 | help='Not used in this code') 492 | parser.add_argument("--perturb", type=float, default=1., 493 | help='set to 0. for no jitter, 1. 
for jitter') 494 | parser.add_argument("--use_viewdirs", default=True,action='store_true', 495 | help='use full 5D input instead of 3D') 496 | parser.add_argument("--use_mask", default=False,action='store_true', 497 | help='use full 5D input instead of 3D') 498 | parser.add_argument("--i_embed", type=int, default=0, 499 | help='set 0 for default positional encoding, -1 for none') 500 | parser.add_argument("--multires", type=int, default=10, 501 | help='log2 of max freq for positional encoding (3D location)') 502 | parser.add_argument("--multires_views", type=int, default=4, 503 | help='log2 of max freq for positional encoding (2D direction)') 504 | parser.add_argument("--raw_noise_std", type=float, default=1., 505 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 506 | 507 | parser.add_argument("--render_only", action='store_true', 508 | help='do not optimize, reload weights and render out render_poses path') 509 | parser.add_argument("--render_test", action='store_true', 510 | help='render the test set instead of render_poses path') 511 | parser.add_argument("--render_factor", type=int, default=0, 512 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 513 | 514 | # training options 515 | parser.add_argument("--precrop_iters", type=int, default=0, 516 | help='number of steps to train on central crops') 517 | parser.add_argument("--precrop_frac", type=float, 518 | default=.5, help='fraction of img taken for central crops') 519 | 520 | # dataset options 521 | parser.add_argument("--dataset_type", type=str, default='llff', 522 | help='options: llff / blender / deepvoxels') 523 | parser.add_argument("--testskip", type=int, default=8, 524 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 525 | 526 | ## deepvoxels flags 527 | parser.add_argument("--shape", type=str, default='greek', 528 | help='options : armchair / cube / greek / vase') 529 | 530 | ## blender flags 531 | parser.add_argument("--white_bkgd", action='store_true', 532 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 533 | parser.add_argument("--half_res", action='store_true', 534 | help='load blender synthetic data at 400x400 instead of 800x800') 535 | 536 | ## llff flags 537 | parser.add_argument("--factor", type=int, default=1, 538 | help='downsample factor for LLFF images') 539 | parser.add_argument("--no_ndc", default = True,action='store_true', 540 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 541 | parser.add_argument("--lindisp", action='store_true', 542 | help='sampling linearly in disparity rather than depth') 543 | parser.add_argument("--spherify", default = True, action='store_true', 544 | help='set for spherical 360 scenes') 545 | parser.add_argument("--llffhold", type=int, default=10, 546 | help='will take every 1/N images as LLFF test set, paper uses 8') 547 | 548 | # logging/saving options 549 | parser.add_argument("--i_print", type=int, default=100, 550 | help='frequency of console printout and metric loggin') 551 | parser.add_argument("--i_img", type=int, default=50, 552 | help='frequency of tensorboard image logging') 553 | parser.add_argument("--i_weights", type=int, default=500, 554 | help='frequency of weight ckpt saving') 555 | 556 | 557 | return parser 558 | 559 | 560 | 561 | def train(): 562 | 563 | parser = config_parser() 564 | args = parser.parse_args() 565 | 566 | # Load data 567 | K = None 568 | if args.dataset_type == 'llff': 569 | 
images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 570 | recenter=True, bd_factor=.75, 571 | spherify=args.spherify) 572 | hwf = poses[0,:3,-1] 573 | poses = poses[:,:3,:4] 574 | i_test = [] 575 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 576 | if not isinstance(i_test, list): 577 | i_test = [i_test] 578 | 579 | if args.llffhold > 0: 580 | print('Auto LLFF holdout,', args.llffhold) 581 | i_test = np.arange(images.shape[0])[::args.llffhold] 582 | 583 | i_val = i_test 584 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 585 | (i not in i_test and i not in i_val)]) 586 | 587 | print('DEFINING BOUNDS') 588 | if args.no_ndc: 589 | near = np.ndarray.min(bds) * .9 590 | far = np.ndarray.max(bds) * 1. 591 | 592 | else: 593 | near = 0. 594 | far = 1. 595 | 596 | 597 | if args.spherify: 598 | far = np.minimum(far,2.0) 599 | print('NEAR FAR', near, far) 600 | 601 | 602 | # Cast intrinsics to right types 603 | H, W, focal = hwf 604 | H, W = int(H), int(W) 605 | hwf = [H, W, focal] 606 | 607 | if K is None: 608 | K = np.array([ 609 | [focal, 0, 0.5*W], 610 | [0, focal, 0.5*H], 611 | [0, 0, 1] 612 | ]) 613 | 614 | if args.render_test: 615 | render_poses = np.array(poses[i_test]) 616 | 617 | # Create log dir and copy the config file 618 | basedir = args.basedir 619 | expname = args.expname 620 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 621 | f = os.path.join(basedir, expname, 'args.txt') 622 | with open(f, 'w') as file: 623 | for arg in sorted(vars(args)): 624 | attr = getattr(args, arg) 625 | file.write('{} = {}\n'.format(arg, attr)) 626 | if args.config is not None: 627 | f = os.path.join(basedir, expname, 'config.txt') 628 | with open(f, 'w') as file: 629 | file.write(open(args.config, 'r').read()) 630 | 631 | 632 | 633 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_models(args) 634 | start = 0 635 | global_step = start 636 | 637 | 638 | bds_dict = { 639 | 'near' : near, 640 | 'far' : far, 641 | } 642 | render_kwargs_train.update(bds_dict) 643 | render_kwargs_test.update(bds_dict) 644 | 645 | 646 | print('ray batching') 647 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 648 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 649 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 650 | rays_rgb = np.reshape(rays_rgb, [len(images),-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] 651 | rays_rgb = rays_rgb.astype(np.float32) 652 | rays_rgb = torch.Tensor(rays_rgb).to(device) 653 | rays_rgb = [rays_rgb[i] for i in i_train] # train images only 654 | 655 | poses = torch.Tensor(poses).to(device) 656 | render_poses = torch.Tensor(render_poses).to(device) 657 | 658 | 659 | print('loading mask images') 660 | path_mask = os.path.join(basedir, expname, 'masked_regions') 661 | inside_idx = get_mask_idx(path_mask,len(poses),i_train) # concatenating the indices of masked region in each view 662 | 663 | 664 | print('calculating the scene bounds for 3D grid') 665 | scene_bound = get_scene_bound(near,far,H,W,K,poses,min_=.01,max_=.01) 666 | 667 | render_kwargs_test['scene_bound'] = scene_bound 668 | render_kwargs_train['scene_bound'] = scene_bound 669 | 670 | 671 | print('loading bounding box values') 672 | bounding_box_vals = np.load(os.path.join(basedir, expname, 'bounding_box\\bounding_box_vals.npy')) 673 | 674 | 675 | render_kwargs_test['bb_vals'] = bounding_box_vals 676 | 
render_kwargs_train['bb_vals'] = bounding_box_vals 677 | 678 | 679 | print('fitting a 3D grid to learnt radiance field (Nerf)') 680 | query_function = render_kwargs_test['network_query_fn'] 681 | nerf_coarse_model = render_kwargs_test['network_fn'] 682 | Grid_res = 128 683 | 684 | reuse = False 685 | testimgdir = os.path.join(basedir, expname, 'voxel_grid.npy') 686 | 687 | skip_im = 4 688 | if reuse: 689 | voxel_grid = torch.tensor(np.load(testimgdir)).to(device) 690 | else: 691 | voxel_grid = get_voxel_grid(Grid_res,scene_bound,poses[0:-1:skip_im],query_function,nerf_coarse_model,False,bounding_box_vals) 692 | np.save(testimgdir,voxel_grid.cpu().numpy()) 693 | 694 | 695 | 696 | N_iters = 5000 + 1 # number of total iteration 697 | iter_level = 1000 # number of iteration per each level 698 | sigma = 0.08 # initial sigma for the 3D Gaussian blur kernel 699 | N_rand = args.N_rand 700 | print('Begin') 701 | print('TEST views are', i_test) 702 | 703 | start = start + 0 704 | loss_vals = [] 705 | 706 | 707 | for i in trange(start,N_iters): 708 | time0 = time.time() 709 | 710 | 711 | # taking random rays overl all views 712 | batch = get_rnd(rays_rgb,inside_idx,N_rand) 713 | batch = torch.transpose(batch, 0, 1) 714 | 715 | batch_rays, target_s = batch[:2], batch[2] 716 | 717 | 718 | # creating a 3D Gaussian blur kernel 719 | filter_3d = lowpass_3d(Grid_res,sigma*2**(i//iter_level)) 720 | 721 | # smoothing the grid with the 3D Gaussian blur kernel in Fourier domain 722 | voxel_grid_blur = voxel_lowpass_filtering(voxel_grid,filter_3d) 723 | 724 | render_kwargs_train['voxel_grid'] = voxel_grid_blur 725 | render_kwargs_train['N_samples'] = 128 726 | 727 | 728 | ##### Core optimization loop ##### 729 | rgb, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 730 | verbose=i < 10, retraw=False, 731 | **render_kwargs_train) 732 | 733 | loss = img2mse(rgb, target_s) 734 | psnr = mse2psnr(loss) 735 | 736 | 737 | 738 | loss.backward() 739 | optimizer.step() 740 | optimizer.zero_grad() 741 | 742 | 743 | decay_rate = 0.1 744 | new_lrate = args.lrate* (decay_rate ** (global_step / N_iters)) 745 | for param_group in optimizer.param_groups: 746 | param_group['lr'] = new_lrate 747 | ################################ 748 | 749 | dt = time.time()-time0 750 | 751 | loss_vals.append(loss.detach().cpu().numpy()) 752 | 753 | 754 | if i%args.i_weights==0: 755 | model_path_dir = os.path.join(basedir, expname,args.model_path) 756 | os.makedirs(model_path_dir, exist_ok=True) 757 | path = os.path.join(basedir, expname,args.model_path, 'weights.tar'.format(i)) 758 | torch.save({ 759 | 'global_step': global_step, 760 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 761 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 762 | 'network_ior_state_dict': render_kwargs_train['network_ior'].state_dict(), 763 | 'optimizer_state_dict': optimizer.state_dict(), 764 | }, path) 765 | print('Saved checkpoints at', path) 766 | 767 | 768 | 769 | if i%args.i_img==0: 770 | 771 | render_kwargs_test['voxel_grid'] = voxel_grid # rendering images with finest version of the grid 772 | img_i= i_test[5] 773 | pose = poses[img_i, :3,:4] 774 | with torch.no_grad(): 775 | rgb ,extras = render(H, W,K, chunk=args.chunk, c2w=pose, 776 | **render_kwargs_test) 777 | 778 | 779 | testimgdir = os.path.join(basedir, expname, 'training_ior') 780 | os.makedirs(testimgdir, exist_ok=True) 781 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:03d}.png'.format(i)), to8b(rgb.cpu().numpy())) 782 | 
imageio.imwrite(os.path.join(testimgdir, 'ref_{:03d}.png'.format(img_i)), to8b(images[img_i])) 783 | 784 | # plt.plot(loss_vals) 785 | # plt.show() 786 | 787 | if i%args.i_print==0: 788 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 789 | 790 | 791 | global_step += 1 792 | 793 | 794 | if __name__=='__main__': 795 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 796 | 797 | train() -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm, trange 11 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from run_nerf_helpers import * 16 | 17 | from load_llff import load_llff_data 18 | from load_deepvoxels import load_dv_data 19 | from load_blender import load_blender_data 20 | from load_LINEMOD import load_LINEMOD_data 21 | 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | np.random.seed(0) 25 | DEBUG = False 26 | 27 | 28 | def batchify(fn, chunk): 29 | """Constructs a version of 'fn' that applies to smaller batches. 30 | """ 31 | if chunk is None: 32 | return fn 33 | def ret(inputs): 34 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 35 | return ret 36 | 37 | 38 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 39 | """Prepares inputs and applies network 'fn'. 40 | """ 41 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 42 | embedded = embed_fn(inputs_flat) 43 | 44 | if viewdirs is not None: 45 | input_dirs = viewdirs[:,None].expand(inputs.shape) 46 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 47 | embedded_dirs = embeddirs_fn(input_dirs_flat) 48 | embedded = torch.cat([embedded, embedded_dirs], -1) 49 | 50 | outputs_flat = batchify(fn, netchunk)(embedded) 51 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 52 | return outputs 53 | 54 | 55 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 56 | """Render rays in smaller minibatches to avoid OOM. 57 | """ 58 | all_ret = {} 59 | for i in range(0, rays_flat.shape[0], chunk): 60 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 61 | for k in ret: 62 | if k not in all_ret: 63 | all_ret[k] = [] 64 | all_ret[k].append(ret[k]) 65 | 66 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 67 | return all_ret 68 | 69 | 70 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 71 | near=0., far=1., 72 | use_viewdirs=False, c2w_staticcam=None, 73 | **kwargs): 74 | """Render rays 75 | Args: 76 | H: int. Height of image in pixels. 77 | W: int. Width of image in pixels. 78 | focal: float. Focal length of pinhole camera. 79 | chunk: int. Maximum number of rays to process simultaneously. Used to 80 | control maximum memory usage. Does not affect final results. 81 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 82 | each example in batch. 83 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 84 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 85 | near: float or array of shape [batch_size]. Nearest distance for a ray. 86 | far: float or array of shape [batch_size]. 
Farthest distance for a ray. 87 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 88 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 89 | camera while using other c2w argument for viewing directions. 90 | Returns: 91 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 92 | disp_map: [batch_size]. Disparity map. Inverse of depth. 93 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 94 | extras: dict with everything returned by render_rays(). 95 | """ 96 | if c2w is not None: 97 | # special case to render full image 98 | rays_o, rays_d = get_rays(H, W, K, c2w) 99 | else: 100 | # use provided ray batch 101 | rays_o, rays_d = rays 102 | 103 | if use_viewdirs: 104 | # provide ray directions as input 105 | viewdirs = rays_d 106 | if c2w_staticcam is not None: 107 | # special case to visualize effect of viewdirs 108 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 109 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 110 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 111 | 112 | sh = rays_d.shape # [..., 3] 113 | if ndc: 114 | # for forward facing scenes 115 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 116 | 117 | # Create ray batch 118 | rays_o = torch.reshape(rays_o, [-1,3]).float() 119 | rays_d = torch.reshape(rays_d, [-1,3]).float() 120 | 121 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 122 | rays = torch.cat([rays_o, rays_d, near, far], -1) 123 | if use_viewdirs: 124 | rays = torch.cat([rays, viewdirs], -1) 125 | 126 | # Render and reshape 127 | all_ret = batchify_rays(rays, chunk, **kwargs) 128 | for k in all_ret: 129 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 130 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 131 | 132 | k_extract = ['rgb_map', 'disp_map', 'acc_map'] 133 | ret_list = [all_ret[k] for k in k_extract] 134 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 135 | return ret_list + [ret_dict] 136 | 137 | 138 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 139 | 140 | H, W, focal = hwf 141 | 142 | if render_factor!=0: 143 | # Render downsampled for speed 144 | H = H//render_factor 145 | W = W//render_factor 146 | focal = focal/render_factor 147 | 148 | rgbs = [] 149 | disps = [] 150 | 151 | t = time.time() 152 | for i, c2w in enumerate(tqdm(render_poses)): 153 | print(i, time.time() - t) 154 | t = time.time() 155 | rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 156 | rgbs.append(rgb.cpu().numpy()) 157 | disps.append(disp.cpu().numpy()) 158 | if i==0: 159 | print(rgb.shape, disp.shape) 160 | 161 | """ 162 | if gt_imgs is not None and render_factor==0: 163 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 164 | print(p) 165 | """ 166 | 167 | if savedir is not None: 168 | rgb8 = to8b(rgbs[-1]) 169 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 170 | imageio.imwrite(filename, rgb8) 171 | 172 | 173 | rgbs = np.stack(rgbs, 0) 174 | disps = np.stack(disps, 0) 175 | 176 | return rgbs, disps 177 | 178 | 179 | def create_nerf(args): 180 | """Instantiate NeRF's MLP model. 
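    Builds the coarse NeRF MLP (plus a fine MLP when N_importance > 0), optionally
    reloads the latest checkpoint from <basedir>/<expname>/<model_path>, and
    returns (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer).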
181 | """ 182 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 183 | 184 | input_ch_views = 0 185 | embeddirs_fn = None 186 | if args.use_viewdirs: 187 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 188 | output_ch = 5 if args.N_importance > 0 else 4 189 | skips = [4] 190 | model = NeRF(D=args.netdepth, W=args.netwidth, 191 | input_ch=input_ch, output_ch=output_ch, skips=skips, 192 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 193 | grad_vars = list(model.parameters()) 194 | 195 | model_fine = None 196 | if args.N_importance > 0: 197 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 198 | input_ch=input_ch, output_ch=output_ch, skips=skips, 199 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device) 200 | grad_vars += list(model_fine.parameters()) 201 | 202 | network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 203 | embed_fn=embed_fn, 204 | embeddirs_fn=embeddirs_fn, 205 | netchunk=args.netchunk) 206 | 207 | # Create optimizer 208 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 209 | 210 | start = 0 211 | basedir = args.basedir 212 | expname = args.expname 213 | 214 | ########################## 215 | 216 | # Load checkpoints 217 | os.makedirs(os.path.join(basedir, expname,args.model_path), exist_ok=True) 218 | 219 | ckpts = [os.path.join(basedir, expname,args.model_path, f) for f in sorted(os.listdir(os.path.join(basedir, expname,args.model_path))) if 'tar' in f] 220 | print('Trainined model path:',os.path.join(basedir, expname,args.model_path)) 221 | print('Found ckpts', ckpts) 222 | if len(ckpts) > 0 and not args.no_reload: 223 | ckpt_path = ckpts[-1] 224 | print('Reloading from', ckpt_path) 225 | ckpt = torch.load(ckpt_path) 226 | 227 | start = ckpt['global_step'] 228 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 229 | 230 | # Load model 231 | model.load_state_dict(ckpt['network_fn_state_dict']) 232 | if model_fine is not None: 233 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 234 | 235 | ########################## 236 | 237 | render_kwargs_train = { 238 | 'network_query_fn' : network_query_fn, 239 | 'perturb' : args.perturb, 240 | 'N_importance' : args.N_importance, 241 | 'network_fine' : model_fine, 242 | 'N_samples' : args.N_samples, 243 | 'network_fn' : model, 244 | 'use_viewdirs' : args.use_viewdirs, 245 | 'white_bkgd' : args.white_bkgd, 246 | 'raw_noise_std' : args.raw_noise_std, 247 | } 248 | 249 | # NDC only good for LLFF-style forward facing data 250 | if args.dataset_type != 'llff' or args.no_ndc: 251 | print('Not ndc!') 252 | render_kwargs_train['ndc'] = False 253 | render_kwargs_train['lindisp'] = args.lindisp 254 | 255 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 256 | render_kwargs_test['perturb'] = False 257 | render_kwargs_test['raw_noise_std'] = 0. 258 | 259 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 260 | 261 | 262 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 263 | """Transforms model's predictions to semantically meaningful values. 264 | Args: 265 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 266 | z_vals: [num_rays, num_samples along ray]. Integration time. 267 | rays_d: [num_rays, 3]. Direction of each ray. 268 | Returns: 269 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 270 | disp_map: [num_rays]. 
Disparity map. Inverse of depth map. 271 | acc_map: [num_rays]. Sum of weights along each ray. 272 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 273 | depth_map: [num_rays]. Estimated distance to object. 274 | """ 275 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 276 | 277 | dists = z_vals[...,1:] - z_vals[...,:-1] 278 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 279 | 280 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 281 | 282 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 283 | noise = 0. 284 | if raw_noise_std > 0.: 285 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 286 | 287 | # Overwrite randomly sampled data if pytest 288 | if pytest: 289 | np.random.seed(0) 290 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 291 | noise = torch.Tensor(noise) 292 | 293 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 294 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 295 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 296 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 297 | 298 | depth_map = torch.sum(weights * z_vals, -1) 299 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 300 | acc_map = torch.sum(weights, -1) 301 | 302 | if white_bkgd: 303 | rgb_map = rgb_map + (1.-acc_map[...,None]) 304 | 305 | return rgb_map, disp_map, acc_map, weights, depth_map 306 | 307 | 308 | def render_rays(ray_batch, 309 | network_fn, 310 | network_query_fn, 311 | N_samples, 312 | retraw=False, 313 | lindisp=False, 314 | perturb=0., 315 | N_importance=0, 316 | network_fine=None, 317 | white_bkgd=False, 318 | raw_noise_std=0., 319 | verbose=False, 320 | pytest=False): 321 | """Volumetric rendering. 322 | Args: 323 | ray_batch: array of shape [batch_size, ...]. All information necessary 324 | for sampling along a ray, including: ray origin, ray direction, min 325 | dist, max dist, and unit-magnitude viewing direction. 326 | network_fn: function. Model for predicting RGB and density at each point 327 | in space. 328 | network_query_fn: function used for passing queries to network_fn. 329 | N_samples: int. Number of different times to sample along each ray. 330 | retraw: bool. If True, include model's raw, unprocessed predictions. 331 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 332 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 333 | random points in time. 334 | N_importance: int. Number of additional times to sample along each ray. 335 | These samples are only passed to network_fine. 336 | network_fine: "fine" network with same spec as network_fn. 337 | white_bkgd: bool. If True, assume a white background. 338 | raw_noise_std: ... 339 | verbose: bool. If True, print more debugging info. 340 | Returns: 341 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 342 | disp_map: [num_rays]. Disparity map. 1 / depth. 343 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. 344 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 345 | rgb0: See rgb_map. Output for coarse model. 346 | disp0: See disp_map. Output for coarse model. 347 | acc0: See acc_map. Output for coarse model. 348 | z_std: [num_rays]. 
Standard deviation of distances along ray for each 349 | sample. 350 | """ 351 | N_rays = ray_batch.shape[0] 352 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 353 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 354 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 355 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 356 | 357 | t_vals = torch.linspace(0., 1., steps=N_samples) 358 | if not lindisp: 359 | z_vals = near * (1.-t_vals) + far * (t_vals) 360 | else: 361 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 362 | 363 | z_vals = z_vals.expand([N_rays, N_samples]) 364 | 365 | if perturb > 0.: 366 | # get intervals between samples 367 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 368 | upper = torch.cat([mids, z_vals[...,-1:]], -1) 369 | lower = torch.cat([z_vals[...,:1], mids], -1) 370 | # stratified samples in those intervals 371 | t_rand = torch.rand(z_vals.shape) 372 | 373 | # Pytest, overwrite u with numpy's fixed random numbers 374 | if pytest: 375 | np.random.seed(0) 376 | t_rand = np.random.rand(*list(z_vals.shape)) 377 | t_rand = torch.Tensor(t_rand) 378 | 379 | z_vals = lower + (upper - lower) * t_rand 380 | 381 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 382 | 383 | raw = network_query_fn(pts, viewdirs, network_fn) 384 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 385 | 386 | if N_importance > 0: 387 | 388 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 389 | 390 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 391 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 392 | z_samples = z_samples.detach() 393 | 394 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 395 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 396 | 397 | run_fn = network_fn if network_fine is None else network_fine 398 | raw = network_query_fn(pts, viewdirs, run_fn) 399 | 400 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 401 | 402 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} 403 | if retraw: 404 | ret['raw'] = raw 405 | if N_importance > 0: 406 | ret['rgb0'] = rgb_map_0 407 | ret['disp0'] = disp_map_0 408 | ret['acc0'] = acc_map_0 409 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 410 | 411 | for k in ret: 412 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 413 | print(f"! 
[Numerical Error] {k} contains nan or inf.") 414 | 415 | return ret 416 | 417 | 418 | def config_parser(): 419 | 420 | import configargparse 421 | parser = configargparse.ArgumentParser() 422 | parser.add_argument('--config', is_config_file=True, 423 | help='config file path') 424 | parser.add_argument("--expname", default='Ball',type=str, 425 | help='experiment name') 426 | parser.add_argument("--basedir", type=str, default='logs', 427 | help='where to store ckpts and logs') 428 | parser.add_argument("--datadir", type=str, default='data\\Ball', 429 | help='input data directory') 430 | 431 | # training options 432 | parser.add_argument("--netdepth", type=int, default=8, 433 | help='layers in network') 434 | parser.add_argument("--netwidth", type=int, default=256, 435 | help='channels per layer') 436 | parser.add_argument("--netdepth_fine", type=int, default=8, 437 | help='layers in fine network') 438 | parser.add_argument("--netwidth_fine", type=int, default=256, 439 | help='channels per layer in fine network') 440 | parser.add_argument("--N_rand", type=int, default=32*32, 441 | help='batch size (number of random rays per gradient step)') 442 | parser.add_argument("--lrate", type=float, default=5e-4, 443 | help='learning rate') 444 | parser.add_argument("--lrate_decay", type=int, default=150, 445 | help='exponential learning rate decay (in 1000 steps)') 446 | parser.add_argument("--chunk", type=int, default=1024*32, 447 | help='number of rays processed in parallel, decrease if running out of memory') 448 | parser.add_argument("--netchunk", type=int, default=1024*64, 449 | help='number of pts sent through network in parallel, decrease if running out of memory') 450 | parser.add_argument("--no_batching", default=False,action='store_true', 451 | help='only take random rays from 1 image at a time') 452 | parser.add_argument("--no_reload", action='store_true', 453 | help='do not reload weights from saved ckpt') 454 | parser.add_argument("--model_path", type=str, default='model_weights', 455 | help='path to trained model weights') 456 | # rendering options 457 | parser.add_argument("--N_samples", type=int, default=64, 458 | help='number of coarse samples per ray') 459 | parser.add_argument("--N_importance", type=int, default=64, 460 | help='number of additional fine samples per ray') 461 | parser.add_argument("--perturb", type=float, default=1., 462 | help='set to 0. for no jitter, 1. 
for jitter') 463 | parser.add_argument("--use_viewdirs", default=True,action='store_true', 464 | help='use full 5D input instead of 3D') 465 | parser.add_argument("--i_embed", type=int, default=0, 466 | help='set 0 for default positional encoding, -1 for none') 467 | parser.add_argument("--multires", type=int, default=10, 468 | help='log2 of max freq for positional encoding (3D location)') 469 | parser.add_argument("--multires_views", type=int, default=4, 470 | help='log2 of max freq for positional encoding (2D direction)') 471 | parser.add_argument("--raw_noise_std", type=float, default=1., 472 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 473 | 474 | parser.add_argument("--render_only", action='store_true', 475 | help='do not optimize, reload weights and render out render_poses path') 476 | parser.add_argument("--render_test", action='store_true', 477 | help='render the test set instead of render_poses path') 478 | parser.add_argument("--render_factor", type=int, default=0, 479 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 480 | 481 | # training options 482 | parser.add_argument("--precrop_iters", type=int, default=0, 483 | help='number of steps to train on central crops') 484 | parser.add_argument("--precrop_frac", type=float, 485 | default=.5, help='fraction of img taken for central crops') 486 | 487 | # dataset options 488 | parser.add_argument("--dataset_type", type=str, default='llff', 489 | help='options: llff / blender / deepvoxels') 490 | parser.add_argument("--testskip", type=int, default=None, 491 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 492 | 493 | ## deepvoxels flags 494 | parser.add_argument("--shape", type=str, default='greek', 495 | help='options: armchair / cube / greek / vase') 496 | 497 | ## blender flags 498 | parser.add_argument("--white_bkgd", action='store_true', 499 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 500 | parser.add_argument("--half_res", action='store_true', 501 | help='load blender synthetic data at 400x400 instead of 800x800') 502 | 503 | ## llff flags 504 | parser.add_argument("--factor", type=int, default=1, 505 | help='downsample factor for LLFF images') 506 | parser.add_argument("--no_ndc", default=True, action='store_true', 507 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 508 | parser.add_argument("--lindisp", default=False, action='store_true', 509 | help='sampling linearly in disparity rather than depth') 510 | parser.add_argument("--spherify", default=True,action='store_true', 511 | help='set for spherical 360 scenes') 512 | parser.add_argument("--llffhold", type=int, default=10, 513 | help='will take every 1/N images as LLFF test set, paper uses 8') 514 | 515 | # logging/saving options 516 | parser.add_argument("--i_print", type=int, default=100, 517 | help='frequency of console printout and metric logging') 518 | parser.add_argument("--i_img", type=int, default=1000, 519 | help='frequency of tensorboard image logging') 520 | parser.add_argument("--i_weights", type=int, default=1000, 521 | help='frequency of weight ckpt saving') 522 | parser.add_argument("--i_testset", type=int, default=400000, 523 | help='frequency of testset saving') 524 | parser.add_argument("--i_video", type=int, default=150000, 525 | help='frequency of render_poses video saving') 526 | 527 | return parser 528 | 529 | 530 | def train(): 531 | 532 | parser = config_parser() 533 |
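# Note (editorial comment): configargparse merges options read from the file passed via --config with any flags given on the command line; command-line flags take precedence over config-file entries, which take precedence over the argparse defaults above.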
args = parser.parse_args() 534 | 535 | # Load data 536 | K = None 537 | if args.dataset_type == 'llff': 538 | images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor, 539 | recenter=True, bd_factor=.75, 540 | spherify=args.spherify) 541 | hwf = poses[0,:3,-1] 542 | poses = poses[:,:3,:4] 543 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 544 | if not isinstance(i_test, list): 545 | i_test = [i_test] 546 | 547 | if args.llffhold > 0: 548 | print('Auto LLFF holdout,', args.llffhold) 549 | i_test = np.arange(images.shape[0])[::args.llffhold] 550 | 551 | i_val = i_test 552 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 553 | (i not in i_test and i not in i_val)]) 554 | 555 | print('DEFINING BOUNDS') 556 | if args.no_ndc: 557 | near = np.ndarray.min(bds) * .9 558 | far = np.ndarray.max(bds) * 1. 559 | 560 | else: 561 | near = 0. 562 | far = 1. 563 | print('NEAR FAR', near, far) 564 | 565 | elif args.dataset_type == 'blender': 566 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) 567 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 568 | i_train, i_val, i_test = i_split 569 | 570 | near = 2. 571 | far = 6. 572 | 573 | if args.white_bkgd: 574 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 575 | else: 576 | images = images[...,:3] 577 | 578 | elif args.dataset_type == 'LINEMOD': 579 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip) 580 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') 581 | print(f'[CHECK HERE] near: {near}, far: {far}.') 582 | i_train, i_val, i_test = i_split 583 | 584 | if args.white_bkgd: 585 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 586 | else: 587 | images = images[...,:3] 588 | 589 | elif args.dataset_type == 'deepvoxels': 590 | 591 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 592 | basedir=args.datadir, 593 | testskip=args.testskip) 594 | 595 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) 596 | i_train, i_val, i_test = i_split 597 | 598 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) 599 | near = hemi_R-1. 600 | far = hemi_R+1. 
601 | 602 | else: 603 | print('Unknown dataset type', args.dataset_type, 'exiting') 604 | return 605 | 606 | # Cast intrinsics to right types 607 | H, W, focal = hwf 608 | H, W = int(H), int(W) 609 | hwf = [H, W, focal] 610 | 611 | if K is None: 612 | K = np.array([ 613 | [focal, 0, 0.5*W], 614 | [0, focal, 0.5*H], 615 | [0, 0, 1] 616 | ]) 617 | 618 | if args.render_test: 619 | render_poses = np.array(poses[i_test]) 620 | 621 | # Create log dir and copy the config file 622 | basedir = args.basedir 623 | expname = args.expname 624 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 625 | f = os.path.join(basedir, expname, 'args.txt') 626 | with open(f, 'w') as file: 627 | for arg in sorted(vars(args)): 628 | attr = getattr(args, arg) 629 | file.write('{} = {}\n'.format(arg, attr)) 630 | if args.config is not None: 631 | f = os.path.join(basedir, expname, 'config.txt') 632 | with open(f, 'w') as file: 633 | file.write(open(args.config, 'r').read()) 634 | 635 | # Create nerf model 636 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 637 | global_step = start 638 | 639 | bds_dict = { 640 | 'near' : near, 641 | 'far' : far, 642 | } 643 | render_kwargs_train.update(bds_dict) 644 | render_kwargs_test.update(bds_dict) 645 | 646 | # Move testing data to GPU 647 | render_poses = torch.Tensor(render_poses).to(device) 648 | 649 | # Short circuit if only rendering out from trained model 650 | if args.render_only: 651 | print('RENDER ONLY') 652 | with torch.no_grad(): 653 | if args.render_test: 654 | # render_test switches to test poses 655 | images = images[i_test] 656 | else: 657 | # Default is smoother render_poses path 658 | images = None 659 | 660 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 661 | os.makedirs(testsavedir, exist_ok=True) 662 | print('test poses shape', render_poses.shape) 663 | 664 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) 665 | print('Done rendering', testsavedir) 666 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 667 | 668 | return 669 | 670 | # Prepare raybatch tensor if batching random rays 671 | N_rand = args.N_rand 672 | use_batching = not args.no_batching 673 | if use_batching: 674 | # For random ray batching 675 | print('get rays') 676 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 677 | print('done, concats') 678 | 679 | rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3] 680 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 681 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 682 | rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3] 683 | rays_rgb = rays_rgb.astype(np.float32) 684 | print('shuffle rays') 685 | np.random.shuffle(rays_rgb) 686 | print('done') 687 | i_batch = 0 688 | 689 | # Move training data to GPU 690 | if use_batching: 691 | images = torch.Tensor(images).to(device) 692 | poses = torch.Tensor(poses).to(device) 693 | if use_batching: 694 | rays_rgb = torch.Tensor(rays_rgb).to(device) 695 | 696 | 697 | 698 | 699 | 700 | 701 | N_iters = 150000 + 1 702 | print('Begin') 703 | print('TRAIN views are', i_train) 704 | print('TEST views are', i_test) 705 | print('VAL views are', i_val) 706 | 707 | # Summary writers 708 | # writer = 
SummaryWriter(os.path.join(basedir, 'summaries', expname)) 709 | 710 | start = start + 1 711 | for i in trange(start, N_iters): 712 | time0 = time.time() 713 | 714 | # Sample random ray batch 715 | if use_batching: 716 | # Random over all images 717 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] 718 | batch = torch.transpose(batch, 0, 1) 719 | 720 | batch_rays, target_s = batch[:2], batch[2] 721 | 722 | i_batch += N_rand 723 | if i_batch >= rays_rgb.shape[0]: 724 | print("Shuffle data after an epoch!") 725 | rand_idx = torch.randperm(rays_rgb.shape[0]) 726 | rays_rgb = rays_rgb[rand_idx] 727 | i_batch = 0 728 | 729 | else: 730 | # Random from one image 731 | img_i = np.random.choice(i_train) 732 | target = images[img_i] 733 | target = torch.Tensor(target).to(device) 734 | pose = poses[img_i, :3,:4] 735 | 736 | if N_rand is not None: 737 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) 738 | 739 | if i < args.precrop_iters: 740 | dH = int(H//2 * args.precrop_frac) 741 | dW = int(W//2 * args.precrop_frac) 742 | coords = torch.stack( 743 | torch.meshgrid( 744 | torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 745 | torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) 746 | ), -1) 747 | if i == start: 748 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") 749 | else: 750 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) 751 | 752 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 753 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 754 | select_coords = coords[select_inds].long() # (N_rand, 2) 755 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 756 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 757 | batch_rays = torch.stack([rays_o, rays_d], 0) 758 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 759 | 760 | ##### Core optimization loop ##### 761 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 762 | verbose=i < 10, retraw=True, 763 | **render_kwargs_train) 764 | 765 | optimizer.zero_grad() 766 | img_loss = img2mse(rgb, target_s) 767 | 768 | loss = img_loss 769 | psnr = mse2psnr(img_loss) 770 | 771 | if 'rgb0' in extras: 772 | img_loss0 = img2mse(extras['rgb0'], target_s) 773 | loss = loss + img_loss0 774 | psnr0 = mse2psnr(img_loss0) 775 | 776 | loss.backward() 777 | optimizer.step() 778 | 779 | # NOTE: IMPORTANT! 
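# The block below applies an exponential learning-rate schedule: lr = lrate * 0.1 ** (global_step / (lrate_decay * 1000)), i.e. the learning rate drops by a factor of 10 every lrate_decay * 1000 iterations.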
780 | ### update learning rate ### 781 | decay_rate = 0.1 782 | decay_steps = args.lrate_decay * 1000 783 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 784 | for param_group in optimizer.param_groups: 785 | param_group['lr'] = new_lrate 786 | ################################ 787 | 788 | dt = time.time()-time0 789 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") 790 | ##### end ##### 791 | 792 | # Rest is logging 793 | if i%args.i_weights==0: 794 | model_path_dir = os.path.join(basedir, expname, args.model_path) 795 | os.makedirs(model_path_dir, exist_ok=True) 796 | path = os.path.join(basedir, expname, args.model_path, 'weights.tar') 797 | 798 | torch.save({ 799 | 'global_step': global_step, 800 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 801 | 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 802 | 'optimizer_state_dict': optimizer.state_dict(), 803 | }, path) 804 | print('Saved checkpoints at', path) 805 | 806 | if i%args.i_video==0 and i > 0: 807 | # Turn on testing mode 808 | with torch.no_grad(): 809 | rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) 810 | print('Done, saving', rgbs.shape, disps.shape) 811 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 812 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=15, quality=8) 813 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) 814 | 815 | if i%args.i_testset==0 and i > 0: 816 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 817 | os.makedirs(testsavedir, exist_ok=True) 818 | print('test poses shape', poses[i_test].shape) 819 | with torch.no_grad(): 820 | render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) 821 | print('Saved test set') 822 | 823 | if i%args.i_img==0: 824 | 825 | # Log a rendered validation view to Tensorboard 826 | img_i = i_test[5] #np.random.choice(i_val) 827 | target = images[img_i] 828 | pose = poses[img_i, :3,:4] 829 | with torch.no_grad(): 830 | rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, c2w=pose, 831 | **render_kwargs_test) 832 | testimgdir = os.path.join(basedir, expname, 'training_nerf') 833 | os.makedirs(testimgdir, exist_ok=True) 834 | imageio.imwrite(os.path.join(testimgdir, 'rendered_{:06d}.png'.format(i)), to8b(rgb.cpu().numpy())) 835 | imageio.imwrite(os.path.join(testimgdir, 'ref_{:06d}.png'.format(img_i)), to8b(images[img_i].cpu().numpy())) 836 | 837 | 838 | if i%args.i_print==0: 839 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 840 | 841 | 842 | global_step += 1 843 | 844 | 845 | if __name__ == '__main__': 846 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 847 | 848 | train() 849 | --------------------------------------------------------------------------------