├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── blender.py ├── depth_utils.py └── ray_utils.py ├── losses.py ├── metrics.py ├── models ├── __init__.py ├── nerf.py ├── rendering.py └── sh.py ├── opt.py ├── requirements.txt ├── train.py ├── train.sh └── utils ├── __init__.py ├── optimizers.py ├── visualization.py └── warmup_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | scripts/ 3 | logs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## The official code for "[EfficientNeRF: Efficient Neural Radiance Fields](https://arxiv.org/abs/2206.00878)" in CVPR 2022. 2 | 3 | ### Environment (Tested) 4 | - Ubuntu 18.04 5 | - Python 3.7 6 | - CUDA 11.x 7 | - PyTorch 1.9.1 8 | - PyTorch-Lightning 1.6.4 9 | 10 | ### Install via Anaconda 11 | ``` 12 | $ conda create -n EfficientNeRF python=3.8 13 | $ conda activate EfficientNeRF 14 | $ pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 15 | $ pip install -r requirements.txt 16 | ``` 17 | 18 | ### Training 19 | ``` 20 | $ DATA_DIR=/path/to/lego 21 | $ python train.py \ 22 | --dataset_name blender \ 23 | --root_dir $DATA_DIR \ 24 | --N_samples 128 \ 25 | --N_importance 5 --img_wh 800 800 \ 26 | --num_epochs 16 --batch_size 4096 \ 27 | --optimizer radam --lr 2e-3 \ 28 | --lr_scheduler poly \ 29 | --coord_scope 3.0 \ 30 | --warmup_step 5000 \ 31 | --sigma_init 30.0 \ 32 | --weight_threashold 1e-5 \ 33 | --exp_name lego_coarse128_fine5_V384 34 | ``` 35 | 36 | ### Visualization 37 | ``` 38 | $ tensorboard --logdir=./logs 39 | ``` 40 | 41 | ### Questions 42 | - Q1. Why do the hyperparameters differ from the original paper? 43 | * A1. There are many possible combinations of these hyperparameters. You are free to trade off training speed against accuracy by modifying them. 44 | - Q2. When will NeRF-Tree be released? 45 | * A2. It is hard to give a specific date. The NeRF-Tree data structure is close to an Octree. 46 | 47 | ### Progress 48 | More scenes and applications will be supported soon. Stay tuned!
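### Saved voxels
Besides the Lightning checkpoints, training also writes the sparse density voxels maintained by `NerfTree_Pytorch` to `logs/<exp_name>/ckpts/nerftree.pt` (see `train.py`). Below is a minimal inspection sketch, assuming the experiment name from the training example above:
```
import torch

voxels = torch.load('logs/lego_coarse128_fine5_V384/ckpts/nerftree.pt', map_location='cpu')
print(voxels.keys())                        # sigma_voxels_coarse, index_voxels_coarse, voxels_fine
print(voxels['sigma_voxels_coarse'].shape)  # coarse density grid, (384, 384, 384) by default
```
Note that `voxels_fine` is only populated during the final training epoch (see `extract_time` in `train.py`); before that it is saved as `None`.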
49 | 50 | ### Acknowledgement 51 | Our initial code was borrowed from 52 | - [nerf-pl](https://github.com/kwea123/nerf_pl) 53 | 54 | ### Citation 55 | If you find our code or paper helpful, please cite our paper: 56 | ``` 57 | @InProceedings{Hu_2022_CVPR, 58 | author = {Hu, Tao and Liu, Shu and Chen, Yilun and Shen, Tiancheng and Jia, Jiaya}, 59 | title = {EfficientNeRF: Efficient Neural Radiance Fields}, 60 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 61 | month = {June}, 62 | year = {2022}, 63 | pages = {12902-12911} 64 | } 65 | ``` 66 | 67 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .blender import BlenderDataset 2 | 3 | dataset_dict = {'blender': BlenderDataset} -------------------------------------------------------------------------------- /datasets/blender.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | trans_t = lambda t : torch.Tensor([ 12 | [1,0,0,0], 13 | [0,1,0,0], 14 | [0,0,1,t], 15 | [0,0,0,1]]).float() 16 | 17 | rot_phi = lambda phi : torch.Tensor([ 18 | [1,0,0,0], 19 | [0,np.cos(phi),-np.sin(phi),0], 20 | [0,np.sin(phi), np.cos(phi),0], 21 | [0,0,0,1]]).float() 22 | 23 | rot_theta = lambda th : torch.Tensor([ 24 | [np.cos(th),0,-np.sin(th),0], 25 | [0,1,0,0], 26 | [np.sin(th),0, np.cos(th),0], 27 | [0,0,0,1]]).float() 28 | 29 | 30 | def pose_spherical(theta, phi, radius): 31 | c2w = trans_t(radius) 32 | c2w = rot_phi(phi/180.*np.pi) @ c2w 33 | c2w = rot_theta(theta/180.*np.pi) @ c2w 34 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 35 | return c2w 36 | 37 | 38 | class BlenderDataset(Dataset): 39 | def __init__(self, root_dir, split='train', img_wh=(800, 800)): 40 | self.root_dir = root_dir 41 | self.split = split 42 | assert img_wh[0] == img_wh[1], 'image width must equal image height!'
43 | self.img_wh = img_wh 44 | self.define_transforms() 45 | 46 | self.read_meta() 47 | self.white_back = True 48 | 49 | def read_meta(self): 50 | with open(os.path.join(self.root_dir, 51 | f"transforms_{self.split}.json"), 'r') as f: 52 | self.meta = json.load(f) 53 | 54 | w, h = self.img_wh 55 | self.focal = 0.5*800/np.tan(0.5*self.meta['camera_angle_x']) # original focal length 56 | # when W=800 57 | 58 | self.focal *= self.img_wh[0]/800 # modify focal length to match size self.img_wh 59 | 60 | # bounds, common for all scenes 61 | self.near = 2.0 62 | self.far = 6.0 63 | self.bounds = np.array([self.near, self.far]) 64 | 65 | # ray directions for all pixels, same for all images (same H, W, focal) 66 | self.directions = \ 67 | get_ray_directions(h, w, self.focal) # (h, w, 3) 68 | 69 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 70 | 71 | if self.split == 'train': # create buffer of all rays and rgb data 72 | self.image_paths = [] 73 | self.poses = [] 74 | self.all_rays = [] 75 | self.all_rgbs = [] 76 | for frame in self.meta['frames']: 77 | pose = np.array(frame['transform_matrix'])[:3, :4] 78 | self.poses += [pose] 79 | c2w = torch.FloatTensor(pose) 80 | 81 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 82 | self.image_paths += [image_path] 83 | img = Image.open(image_path) 84 | if self.img_wh[0] != img.size[0]: 85 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 86 | img = self.transform(img) # (4, h, w) 87 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 88 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB 89 | self.all_rgbs += [img] 90 | 91 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 92 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 93 | self.all_rays += [torch.cat([rays_o, rays_d], 94 | 1)] # (h*w, 8) 95 | 96 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 97 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 98 | elif self.split == 'test': 99 | self.meta['frames'] = self.meta['frames'][::10] # we select 1/10 for fast testing during training 100 | else: 101 | self.pose_vis = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,1000+1)[:-1]], 0) 102 | 103 | 104 | def define_transforms(self): 105 | self.transform = T.ToTensor() 106 | 107 | def __len__(self): 108 | if self.split == 'train': 109 | return len(self.all_rays) 110 | elif self.split == 'test': 111 | return len(self.meta['frames']) 112 | elif self.split == 'val': 113 | return self.pose_vis.shape[0] 114 | 115 | def __getitem__(self, idx): 116 | if self.split == 'train': # use data in the buffers 117 | sample = {'rays': self.all_rays[idx], 118 | 'rgbs': self.all_rgbs[idx]} 119 | 120 | elif self.split == 'test': # create data for each image separately 121 | frame = self.meta['frames'][idx] 122 | c2w = torch.FloatTensor(frame['transform_matrix'])[:3, :4] 123 | 124 | img = Image.open(os.path.join(self.root_dir, f"{frame['file_path']}.png")) 125 | if self.img_wh[0] != img.size[0]: 126 | img = img.resize(self.img_wh, Image.Resampling.LANCZOS) 127 | img = self.transform(img) # (4, H, W) 128 | valid_mask = (img[-1]>0).flatten() # (H*W) valid color area 129 | img = img.view(4, -1).permute(1, 0) # (H*W, 4) RGBA 130 | img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB 131 | 132 | rays_o, rays_d = get_rays(self.directions, c2w) 133 | 134 | rays = torch.cat([rays_o, rays_d], 1) # (H*W, 8) 135 | 136 | sample 
= {'rays': rays, 137 | 'rgbs': img, 138 | 'c2w': c2w, 139 | 'valid_mask': valid_mask} 140 | elif self.split == 'val': 141 | c2w = torch.FloatTensor(self.pose_vis[idx])[:3, :4] 142 | rays_o, rays_d = get_rays(self.directions, c2w) 143 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 144 | rays = torch.cat([rays_o, rays_d], 1) # (H*W, 8) 145 | 146 | sample = {'rays': rays} 147 | 148 | return sample -------------------------------------------------------------------------------- /datasets/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | 5 | def read_pfm(filename): 6 | file = open(filename, 'rb') 7 | color = None 8 | width = None 9 | height = None 10 | scale = None 11 | endian = None 12 | 13 | header = file.readline().decode('utf-8').rstrip() 14 | if header == 'PF': 15 | color = True 16 | elif header == 'Pf': 17 | color = False 18 | else: 19 | raise Exception('Not a PFM file.') 20 | 21 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 22 | if dim_match: 23 | width, height = map(int, dim_match.groups()) 24 | else: 25 | raise Exception('Malformed PFM header.') 26 | 27 | scale = float(file.readline().rstrip()) 28 | if scale < 0: # little-endian 29 | endian = '<' 30 | scale = -scale 31 | else: 32 | endian = '>' # big-endian 33 | 34 | data = np.fromfile(file, endian + 'f') 35 | shape = (height, width, 3) if color else (height, width) 36 | 37 | data = np.reshape(data, shape) 38 | data = np.flipud(data) 39 | file.close() 40 | return data, scale 41 | 42 | 43 | def save_pfm(filename, image, scale=1): 44 | file = open(filename, "wb") 45 | color = None 46 | 47 | image = np.flipud(image) 48 | 49 | if image.dtype.name != 'float32': 50 | raise Exception('Image dtype must be float32.') 51 | 52 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 53 | color = True 54 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 55 | color = False 56 | else: 57 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 58 | 59 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 60 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 61 | 62 | endian = image.dtype.byteorder 63 | 64 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 65 | scale = -scale 66 | 67 | file.write(('%f\n' % scale).encode('utf-8')) 68 | 69 | image.tofile(file) 70 | file.close() -------------------------------------------------------------------------------- /datasets/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | 4 | 5 | def get_ray_directions(H, W, focal): 6 | """ 7 | Get ray directions for all pixels in camera coordinate. 
8 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 9 | ray-tracing-generating-camera-rays/standard-coordinate-systems 10 | 11 | Inputs: 12 | H, W, focal: image height, width and focal length 13 | 14 | Outputs: 15 | directions: (H, W, 3), the direction of the rays in camera coordinate 16 | """ 17 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 18 | i, j = grid.unbind(-1) 19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 20 | # see https://github.com/bmild/nerf/issues/24 21 | directions = \ 22 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3) 23 | 24 | return directions 25 | 26 | 27 | def get_rays(directions, c2w): 28 | """ 29 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 30 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 31 | ray-tracing-generating-camera-rays/standard-coordinate-systems 32 | Inputs: 33 | directions: (H, W, 3) precomputed ray directions in camera coordinate 34 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 35 | Outputs: 36 | rays_o: (H*W, 3), the origin of the rays in world coordinate 37 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 38 | """ 39 | # Rotate ray directions from camera coordinate to the world coordinate 40 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 41 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 42 | # The origin of all rays is the camera origin in world coordinate 43 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 44 | 45 | rays_d = rays_d.view(-1, 3) 46 | rays_o = rays_o.view(-1, 3) 47 | 48 | return rays_o, rays_d 49 | 50 | 51 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 52 | # Shift ray origins to near plane 53 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 54 | rays_o = rays_o + t[..., None] * rays_d 55 | 56 | # Projection 57 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 58 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 59 | o2 = 1. + 2. * near / rays_o[..., 2] 60 | 61 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 62 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 63 | d2 = -2. 
* near / rays_o[..., 2] 64 | 65 | rays_o = torch.stack([o0, o1, o2], -1) 66 | rays_d = torch.stack([d0, d1, d2], -1) 67 | 68 | return rays_o, rays_d -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class MSELoss(nn.Module): 5 | def __init__(self): 6 | super(MSELoss, self).__init__() 7 | self.loss = nn.MSELoss(reduction='mean') 8 | 9 | def forward(self, inputs, targets): 10 | loss = 0.0 11 | if 'rgb_coarse' in inputs: 12 | loss += self.loss(inputs['rgb_coarse'], targets) 13 | if 'rgb_fine' in inputs: 14 | loss += self.loss(inputs['rgb_fine'], targets) 15 | 16 | return loss 17 | 18 | 19 | loss_dict = {'mse': MSELoss} -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt)**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 13 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 14 | 15 | def ssim(image_pred, image_gt, reduction='mean'): 16 | """ 17 | image_pred and image_gt: (1, 3, H, W) 18 | """ 19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1] 20 | return 1-2*dssim_ # in [-1, 1] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/EfficientNeRF/fc89c0cb54a123e3e2a047676eb2bd0604883d38/models/__init__.py -------------------------------------------------------------------------------- /models/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.sh import eval_sh 4 | 5 | class Embedding(nn.Module): 6 | def __init__(self, in_channels, N_freqs, logscale=True): 7 | """ 8 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 9 | in_channels: number of input channels (3 for both xyz and direction) 10 | """ 11 | super(Embedding, self).__init__() 12 | self.N_freqs = N_freqs 13 | self.in_channels = in_channels 14 | self.funcs = [torch.sin, torch.cos] 15 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) 16 | 17 | if logscale: 18 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) 19 | else: 20 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) 21 | 22 | def forward(self, x): 23 | """ 24 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 
25 | Different from the paper, "x" is also in the output 26 | See https://github.com/bmild/nerf/issues/12 27 | 28 | Inputs: 29 | x: (B, self.in_channels) 30 | 31 | Outputs: 32 | out: (B, self.out_channels) 33 | """ 34 | out = [x] 35 | for freq in self.freq_bands: 36 | for func in self.funcs: 37 | out += [func(freq*x)] 38 | 39 | return torch.cat(out, -1) 40 | 41 | 42 | class NeRF(nn.Module): 43 | def __init__(self, 44 | D=8, W=256, 45 | in_channels_xyz=63, in_channels_dir=27, 46 | skips=[4], deg=2): 47 | """ 48 | D: number of layers for density (sigma) encoder 49 | W: number of hidden units in each layer 50 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default) 51 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) 52 | skips: add skip connection in the Dth layer 53 | """ 54 | super(NeRF, self).__init__() 55 | self.D = D 56 | self.W = W 57 | self.in_channels_xyz = in_channels_xyz 58 | self.in_channels_dir = in_channels_dir 59 | self.skips = skips 60 | self.deg = deg 61 | 62 | # xyz encoding layers 63 | for i in range(D): 64 | if i == 0: 65 | layer = nn.Linear(in_channels_xyz, W) 66 | elif i in skips: 67 | layer = nn.Linear(W+in_channels_xyz, W) 68 | else: 69 | layer = nn.Linear(W, W) 70 | layer = nn.Sequential(layer, nn.ReLU(True)) 71 | setattr(self, f"xyz_encoding_{i+1}", layer) 72 | # self.xyz_encoding_final = nn.Linear(W, W) 73 | 74 | # # direction encoding layers 75 | # self.dir_encoding = nn.Sequential( 76 | # nn.Linear(W+in_channels_dir, W), 77 | # nn.ReLU(True)) 78 | 79 | # output layers 80 | self.sigma = nn.Sequential(nn.Linear(W, W), 81 | nn.ReLU(True), 82 | nn.Linear(W, 1)) 83 | # self.sh = nn.Linear(W, 3 * (self.deg + 1)**2) 84 | self.sh = nn.Sequential(nn.Linear(W, W), 85 | nn.ReLU(True), 86 | nn.Linear(W, 3 * (self.deg + 1)**2)) 87 | 88 | def forward(self, x, dirs=None, sigma_sh_only=False): 89 | """ 90 | Encodes input (xyz+dir) to sh+sigma (not ready to render yet). 91 | For rendering this ray, please see rendering.py 92 | 93 | Inputs: 94 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 95 | the embedded vector of position and direction 96 | sigma_only: whether to infer sigma only. 
If True, 97 | x is of shape (B, self.in_channels_xyz) 98 | 99 | Outputs: 100 | if sigma_ony: 101 | sigma: (B, 1) sigma 102 | else: 103 | out: (B, 4), sh and sigma 104 | """ 105 | input_xyz = x 106 | 107 | xyz_ = input_xyz 108 | for i in range(self.D): 109 | if i in self.skips: 110 | xyz_ = torch.cat([input_xyz, xyz_], -1) 111 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 112 | 113 | sigma = self.sigma(xyz_) 114 | sh = self.sh(xyz_) 115 | 116 | if sigma_sh_only: 117 | out = torch.cat([sigma, sh], -1) 118 | return out 119 | 120 | rgb = eval_sh(deg=self.deg, sh=sh.reshape(-1, 3, (self.deg + 1)**2), dirs=dirs) # sh: [..., C, (deg + 1) ** 2] 121 | rgb = torch.sigmoid(rgb) 122 | 123 | # if extract_time: 124 | out = torch.cat([sigma, rgb, sh], -1) 125 | return out -------------------------------------------------------------------------------- /models/rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from torchsearchsorted import searchsorted 3 | 4 | __all__ = ['render_rays'] 5 | 6 | 7 | def render_rays(models, 8 | embeddings, 9 | rays, 10 | sigma_voxels, 11 | N_samples=64, 12 | use_disp=False, 13 | perturb=0, 14 | noise_std=1, 15 | N_importance=0, 16 | chunk=1024*32, 17 | white_back=False, 18 | test_time=False 19 | ): 20 | """ 21 | Render rays by computing the output of @model applied on @rays 22 | 23 | Inputs: 24 | models: list of NeRF models (coarse and fine) defined in nerf.py 25 | embeddings: list of embedding models of origin and direction defined in nerf.py 26 | rays: (N_rays, 3+3+2), ray origins, directions and near, far depth bounds 27 | N_samples: number of coarse samples per ray 28 | use_disp: whether to sample in disparity space (inverse depth) 29 | perturb: factor to perturb the sampling position on the ray (for coarse model only) 30 | noise_std: factor to perturb the model's prediction of sigma 31 | N_importance: number of fine samples per ray 32 | chunk: the chunk size in batched inference 33 | white_back: whether the background is white (dataset dependent) 34 | test_time: whether it is test (inference only) or not. If True, it will not do inference 35 | on coarse rgb to save time 36 | 37 | Outputs: 38 | result: dictionary containing final rgb and depth maps for coarse and fine models 39 | """ 40 | 41 | def sigma2weights(z_vals, sigmas, dirs): 42 | # Convert these values using volume rendering (Section 4) 43 | deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1) 44 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity 45 | deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 46 | 47 | # Multiply each distance by the norm of its corresponding direction ray 48 | # to convert to real world distance (accounts for non-unit directions). 49 | deltas = deltas * torch.norm(dirs.unsqueeze(1), dim=-1) 50 | 51 | noise = torch.randn(sigmas.shape, device=sigmas.device) * 0.0 52 | 53 | # compute alpha by the formula (3) 54 | alphas = 1-torch.exp(-deltas*torch.nn.Softplus()(sigmas+noise)) # (N_rays, N_samples_) 55 | alphas_shifted = \ 56 | torch.cat([torch.ones_like(alphas[:, :1]), 1-alphas+1e-10], -1) # [1, a1, a2, ...] 57 | weights = \ 58 | alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_) 59 | return weights 60 | 61 | 62 | def inference(model, embedding_xyz, xyz_, dirs, dir_embedded, z_vals, idx_render, weights_only=False): 63 | """ 64 | Helper function that performs model inference. 
65 | 66 | Inputs: 67 | model: NeRF model (coarse or fine) 68 | embedding_xyz: embedding module for xyz 69 | xyz_: (N_rays, N_samples_, 3) sampled positions 70 | N_samples_ is the number of sampled points in each ray; 71 | = N_samples for coarse model 72 | = N_samples+N_importance for fine model 73 | dirs: (N_rays, 3) ray directions 74 | dir_embedded: (N_rays, embed_dir_channels) embedded directions 75 | z_vals: (N_rays, N_samples_) depths of the sampled positions 76 | weights_only: do inference on sigma only or not 77 | 78 | Outputs: 79 | if weights_only: 80 | weights: (N_rays, N_samples_): weights of each sample 81 | else: 82 | rgb_final: (N_rays, 3) the final rgb image 83 | depth_final: (N_rays) depth map 84 | weights: (N_rays, N_samples_): weights of each sample 85 | """ 86 | N_samples_ = xyz_.shape[1] 87 | # Embed directions 88 | xyz_ = xyz_[idx_render[:, 0], idx_render[:, 1]].view(-1, 3) # (N_rays*N_samples_, 3) 89 | if not weights_only: 90 | dir_embedded = dir_embedded.unsqueeze(1).expand(-1, N_samples_, -1) 91 | dir_embedded = dir_embedded[idx_render[:, 0], idx_render[:, 1]] 92 | view_dir = dirs.unsqueeze(1).expand(-1, N_samples_, -1) 93 | view_dir = view_dir[idx_render[:, 0], idx_render[:, 1]] 94 | # Perform model inference to get rgb and raw sigma 95 | B = xyz_.shape[0] 96 | out_chunks = [] 97 | for i in range(0, B, chunk): 98 | # Embed positions by chunk 99 | xyz_embedded = embedding_xyz(xyz_[i:i+chunk]) 100 | if not weights_only: 101 | xyzdir_embedded = torch.cat([xyz_embedded, 102 | dir_embedded[i:i+chunk]], 1) 103 | else: 104 | xyzdir_embedded = xyz_embedded 105 | out_chunks += [model(xyzdir_embedded, view_dir[i:i+chunk], sigma_only=weights_only)] 106 | 107 | out = torch.cat(out_chunks, 0) 108 | if weights_only: 109 | out_sigma = torch.full((N_rays, N_samples_, 1), -20.0, device=rays.device) 110 | out_sigma[idx_render[:, 0], idx_render[:, 1]] = out 111 | out = out_sigma 112 | sigmas = out.view(N_rays, N_samples_) 113 | else: 114 | out_rgb = torch.full((N_rays, N_samples_, 3), 1.0, device=rays.device) 115 | out_sigma = torch.full((N_rays, N_samples_, 1), -20.0, device=rays.device) 116 | out_defaults = torch.cat([out_rgb, out_sigma], dim=2) 117 | out_defaults[idx_render[:, 0], idx_render[:, 1]] = out 118 | out = out_defaults 119 | 120 | rgbsigma = out.view(N_rays, N_samples_, 4) 121 | rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3) 122 | sigmas = rgbsigma[..., 3] # (N_rays, N_samples_) 123 | 124 | # Convert these values using volume rendering (Section 4) 125 | # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1) 126 | # delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity 127 | # deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 128 | 129 | # # Multiply each distance by the norm of its corresponding direction ray 130 | # # to convert to real world distance (accounts for non-unit directions). 131 | # deltas = deltas * torch.norm(dirs.unsqueeze(1), dim=-1) 132 | 133 | # noise = torch.randn(sigmas.shape, device=sigmas.device) * noise_std 134 | 135 | # # compute alpha by the formula (3) 136 | # alphas = 1-torch.exp(-deltas*torch.nn.Softplus()(sigmas+noise)) # (N_rays, N_samples_) 137 | # alphas_shifted = \ 138 | # torch.cat([torch.ones_like(alphas[:, :1]), 1-alphas+1e-10], -1) # [1, a1, a2, ...] 
139 | # weights = \ 140 | # alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_) 141 | 142 | weights = sigma2weights(z_vals, sigmas, dirs) 143 | weights_sum = weights.sum(1) # (N_rays), the accumulated opacity along the rays 144 | # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically 145 | if weights_only: 146 | return weights 147 | 148 | # compute final weighted outputs 149 | rgb_final = torch.sum(weights.unsqueeze(-1)*rgbs, -2) # (N_rays, 3) 150 | depth_final = torch.sum(weights*z_vals, -1) # (N_rays) 151 | 152 | if white_back: 153 | rgb_final = rgb_final + 1-weights_sum.unsqueeze(-1) 154 | 155 | return rgb_final, depth_final, weights, rgbs, sigmas 156 | 157 | 158 | # Extract models from lists 159 | model_coarse = models[0] 160 | embedding_xyz = embeddings[0] 161 | embedding_dir = embeddings[1] 162 | 163 | is_training = model_coarse.training 164 | 165 | # Decompose the inputs 166 | N_rays = rays.shape[0] 167 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 168 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 169 | 170 | # Embed direction 171 | dir_embedded = embedding_dir(rays_d) # (N_rays, embed_dir_channels) 172 | 173 | # Sample depth points 174 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) 175 | if not use_disp: # use linear sampling in depth space 176 | z_vals = near * (1-z_steps) + far * z_steps 177 | else: # use linear sampling in disparity space 178 | z_vals = 1/(1/near * (1-z_steps) + 1/far * z_steps) 179 | 180 | z_vals = z_vals.expand(N_rays, N_samples) 181 | 182 | if is_training: 183 | z_vals = z_vals + torch.empty_like(z_vals).normal_(0.0, 0.002) * (far - near) 184 | 185 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 186 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 187 | 188 | # range of voxel 189 | scope = 1.1 * (far[0, 0] - near[0, 0]) 190 | xyz_coarse_sampled_ = xyz_coarse_sampled.reshape(-1, 3) 191 | 192 | VOXEL_SIZE = sigma_voxels.shape[0] 193 | idx_voxels = ((xyz_coarse_sampled_ / scope / 2 + 0.5) * VOXEL_SIZE).round() 194 | idx_voxels = torch.clamp(idx_voxels, 0, VOXEL_SIZE-1).long() 195 | N_rays = rays.shape[0] 196 | sigmas = sigma_voxels[idx_voxels[:, 0], idx_voxels[:, 1], idx_voxels[:, 2]].reshape(N_rays, N_samples) 197 | weights = sigma2weights(z_vals, sigmas, rays_d) 198 | 199 | if is_training: 200 | idx_render_coarse = torch.nonzero(sigmas >= -20) 201 | 202 | rgb_coarse, depth_coarse, weights_coarse, colors_coarse, sigmas_coarse = \ 203 | inference(model_coarse, embedding_xyz, xyz_coarse_sampled, rays_d, 204 | dir_embedded, z_vals, idx_render_coarse, weights_only=False) 205 | result = {'rgb_coarse': rgb_coarse, 206 | 'depth_coarse': depth_coarse, 207 | 'opacity_coarse': weights_coarse.sum(1), 208 | # 'z_val_coarse': z_vals, 209 | # 'sigma_coarse': sigmas_coarse, 210 | # 'weight_coarse': weights_coarse 211 | } 212 | sigma_voxels[idx_voxels[:, 0], idx_voxels[:, 1], idx_voxels[:, 2]] = \ 213 | 0.9 * sigma_voxels[idx_voxels[:, 0], idx_voxels[:, 1], idx_voxels[:, 2]] + \ 214 | 0.1 * sigmas_coarse.reshape(-1).detach() 215 | else: 216 | weights_coarse = weights 217 | result = { 218 | # 'weight_coarse': weights_coarse 219 | } 220 | 221 | if N_importance > 0: # sample points for fine model 222 | idx_render = torch.nonzero(weights_coarse >= min(1e-3, weights_coarse.max().item())).long() # (M, 2) 223 | 224 | scale = N_importance 225 | z_0 = torch.cat([z_vals, z_vals[:, -1:]], dim=-1) 226 | for i in range(1, scale): 227 | z_vals_mid = i / scale * z_0[:, 1:] + (1 - i / scale) 
* z_0[:, :-1] 228 | z_vals = torch.sort(torch.cat([z_vals, z_vals_mid], dim=-1), dim=-1)[0] 229 | 230 | idxs = [idx_render.clone() for _ in range(scale)] 231 | for i in range(scale): 232 | idxs[i][:, 1] = idxs[i][:, 1] * scale + i - scale // 2 233 | idx_render_fine = torch.cat(idxs, dim=0) 234 | idx_render_fine[:, 1] = torch.clamp(idx_render_fine[:, 1], 0, int(N_samples * scale)) 235 | 236 | if idx_render_fine.shape[0] >= N_rays * 64: 237 | indices = torch.randperm(idx_render_fine.shape[0])[:N_rays * 64] 238 | idx_render_fine = idx_render_fine[indices] 239 | 240 | xyz_fine_sampled = rays_o.unsqueeze(1) + \ 241 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) 242 | # (N_rays, N_samples+N_importance, 3) 243 | 244 | model_fine = models[1] 245 | rgb_fine, depth_fine, weights_fine, colors_fine, sigmas_fine = \ 246 | inference(model_fine, embedding_xyz, xyz_fine_sampled, rays_d, 247 | dir_embedded, z_vals, idx_render_fine, weights_only=False) 248 | 249 | result['rgb_fine'] = rgb_fine 250 | result['depth_fine'] = depth_fine 251 | result['opacity_fine'] = weights_fine.sum(1) 252 | if is_training: 253 | result['mean_samples_coarse'] = torch.FloatTensor([idx_render_coarse.shape[0] / N_rays]) 254 | result['mean_samples_fine'] = torch.FloatTensor([idx_render_fine.shape[0] / N_rays]) 255 | # result['z_val_fine'] = z_vals 256 | # result['sigma_fine'] = sigmas_fine 257 | # result['weight_fine'] = weights_fine 258 | 259 | return result -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 
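# Usage note for eval_sh() below: it evaluates real spherical harmonics with hardcoded
# polynomials up to degree 4. In this repo, models/nerf.py calls it with deg=2,
# SH coefficients of shape [N, 3, (deg + 1) ** 2] = [N, 3, 9] and unit viewing
# directions of shape [N, 3], then applies a sigmoid to map the result to RGB in [0, 1].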
23 | 24 | C0 = 0.28209479177387814 25 | C1 = 0.4886025119029199 26 | C2 = [ 27 | 1.0925484305920792, 28 | -1.0925484305920792, 29 | 0.31539156525252005, 30 | -1.0925484305920792, 31 | 0.5462742152960396 32 | ] 33 | C3 = [ 34 | -0.5900435899266435, 35 | 2.890611442640554, 36 | -0.4570457994644658, 37 | 0.3731763325901154, 38 | -0.4570457994644658, 39 | 1.445305721320277, 40 | -0.5900435899266435 41 | ] 42 | C4 = [ 43 | 2.5033429417967046, 44 | -1.7701307697799304, 45 | 0.9461746957575601, 46 | -0.6690465435572892, 47 | 0.10578554691520431, 48 | -0.6690465435572892, 49 | 0.47308734787878004, 50 | -1.7701307697799304, 51 | 0.6258357354491761, 52 | ] 53 | 54 | def eval_sh(deg, sh, dirs): 55 | """ 56 | Evaluate spherical harmonics at unit directions 57 | using hardcoded SH polynomials. 58 | Works with torch/np/jnp. 59 | ... Can be 0 or more batch dimensions. 60 | 61 | Args: 62 | deg: int SH deg. Currently, 0-3 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | 66 | Returns: 67 | [..., C] 68 | """ 69 | assert deg <= 4 and deg >= 0 70 | assert (deg + 1) ** 2 == sh.shape[-1] 71 | C = sh.shape[-2] 72 | 73 | result = C0 * sh[..., 0] 74 | if deg > 0: 75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 76 | result = (result - 77 | C1 * y * sh[..., 1] + 78 | C1 * z * sh[..., 2] - 79 | C1 * x * sh[..., 3]) 80 | if deg > 1: 81 | xx, yy, zz = x * x, y * y, z * z 82 | xy, yz, xz = x * y, y * z, x * z 83 | result = (result + 84 | C2[0] * xy * sh[..., 4] + 85 | C2[1] * yz * sh[..., 5] + 86 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 87 | C2[3] * xz * sh[..., 7] + 88 | C2[4] * (xx - yy) * sh[..., 8]) 89 | 90 | if deg > 2: 91 | result = (result + 92 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 93 | C3[1] * xy * z * sh[..., 10] + 94 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 95 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 96 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 97 | C3[5] * z * (xx - yy) * sh[..., 14] + 98 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 99 | if deg > 3: 100 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 101 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 102 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 103 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 104 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 105 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 106 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 107 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 108 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 109 | return result -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_opts(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('--root_dir', type=str, 7 | default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego', 8 | help='root directory of dataset') 9 | parser.add_argument('--dataset_name', type=str, default='blender', 10 | help='which dataset to train/val') 11 | parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800], 12 | help='resolution (img_w, img_h) of the image') 13 | parser.add_argument('--spheric_poses', default=False, action="store_true", 14 | help='whether images are taken in spheric poses ()') 15 | 16 | parser.add_argument('--N_samples', type=int, default=64, 17 | help='number of coarse samples') 18 | parser.add_argument('--N_importance', type=int, 
default=128, 19 | help='number of additional fine samples') 20 | parser.add_argument('--use_disp', default=False, action="store_true", 21 | help='use disparity depth sampling') 22 | parser.add_argument('--perturb', type=float, default=1.0, 23 | help='factor to perturb depth sampling points') 24 | parser.add_argument('--noise_std', type=float, default=1.0, 25 | help='std dev of noise added to regularize sigma') 26 | 27 | parser.add_argument('--loss_type', type=str, default='mse', 28 | choices=['mse'], 29 | help='loss to use') 30 | 31 | parser.add_argument('--batch_size', type=int, default=1024, 32 | help='batch size') 33 | parser.add_argument('--chunk', type=int, default=32*1024, 34 | help='chunk size to split the input to avoid OOM') 35 | parser.add_argument('--num_epochs', type=int, default=16, 36 | help='number of training epochs') 37 | parser.add_argument('--num_gpus', type=int, default=1, 38 | help='number of gpus') 39 | 40 | parser.add_argument('--ckpt_path', type=str, default=None, 41 | help='pretrained checkpoint path to load') 42 | parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'], 43 | help='the prefixes to ignore in the checkpoint state dict') 44 | 45 | parser.add_argument('--optimizer', type=str, default='radam', 46 | help='optimizer type', 47 | choices=['sgd', 'adam', 'radam', 'ranger', 'adamw']) 48 | parser.add_argument('--lr', type=float, default=5e-4, 49 | help='learning rate') 50 | parser.add_argument('--momentum', type=float, default=0.9, 51 | help='learning rate momentum') 52 | parser.add_argument('--weight_decay', type=float, default=5e-4, 53 | help='weight decay') 54 | parser.add_argument('--lr_scheduler', type=str, default='steplr', 55 | help='scheduler type', 56 | choices=['steplr', 'cosine', 'poly']) 57 | #### params for warmup, only applied when optimizer == 'sgd' or 'adam' 58 | parser.add_argument('--warmup_multiplier', type=float, default=1.0, 59 | help='lr is multiplied by this factor after --warmup_epochs') 60 | parser.add_argument('--warmup_epochs', type=int, default=0, 61 | help='Gradually warm-up(increasing) learning rate in optimizer') 62 | ########################### 63 | #### params for steplr #### 64 | parser.add_argument('--decay_step', nargs='+', type=int, default=[20], 65 | help='scheduler decay step') 66 | parser.add_argument('--decay_gamma', type=float, default=0.1, 67 | help='learning rate decay amount') 68 | ########################### 69 | #### params for poly #### 70 | parser.add_argument('--poly_exp', type=float, default=0.9, 71 | help='exponent for polynomial learning rate decay') 72 | ########################### 73 | 74 | parser.add_argument('--exp_name', type=str, default='exp', 75 | help='experiment name') 76 | 77 | parser.add_argument('--coord_scope', type=float, 78 | help='the scope of world coordnates') 79 | 80 | parser.add_argument('--sigma_init', type=float, default=30.0, 81 | help='the init sigma') 82 | 83 | parser.add_argument('--sigma_default', type=float, default=-20.0, 84 | help='the default sigma') 85 | 86 | parser.add_argument('--weight_threashold', type=float, default=1e-4, 87 | help='the weight threashold') 88 | 89 | parser.add_argument('--uniform_ratio', type=float, default=0.01, 90 | help='the percentage of uniform sampling') 91 | 92 | parser.add_argument('--beta', type=float, default=0.1, 93 | help='update rate') 94 | 95 | parser.add_argument('--warmup_step', type=int, default=0, 96 | help='the warmup step') 97 | 98 | parser.add_argument('--weight_sparse', type=float, default=0.0, 99 | 
help='weight of sparse loss') 100 | 101 | parser.add_argument('--weight_tv', type=float, default=0.0, 102 | help='weight of tv loss') 103 | return parser.parse_args() 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kornia==0.6.0 2 | opencv-python==4.2.0.34 3 | matplotlib 4 | pytorch-lightning==1.6.4 5 | setuptools==59.5.0 6 | test_tube 7 | pillow==8.0.1 8 | imageio==2.19.3 9 | imageio-ffmpeg 10 | tensorboard -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | from torch._C import device, dtype 4 | from opt import get_opts 5 | import torch 6 | import torch.nn.functional as F 7 | from collections import defaultdict 8 | from torchvision import transforms 9 | 10 | from torch.utils.data import DataLoader 11 | from datasets import dataset_dict 12 | 13 | # models 14 | from models.nerf import Embedding, NeRF 15 | from utils import * 16 | 17 | # losses 18 | from losses import loss_dict 19 | 20 | # metrics 21 | from metrics import * 22 | 23 | # pytorch-lightning 24 | from pytorch_lightning.callbacks import ModelCheckpoint 25 | from pytorch_lightning import LightningModule, Trainer 26 | from pytorch_lightning.loggers import TensorBoardLogger 27 | # from spherical_harmonic import eval_sh_torch 28 | from models.sh import eval_sh 29 | 30 | import matplotlib 31 | matplotlib.use('Agg') 32 | import matplotlib.pyplot as plt 33 | import time 34 | import imageio.v2 as imageio 35 | import glob 36 | 37 | import logging 38 | logging.getLogger("lightning").setLevel(logging.ERROR) 39 | 40 | class NerfTree_Pytorch(object): # This is only based on Pytorch implementation 41 | def __init__(self, xyz_min, xyz_max, grid_coarse, grid_fine, deg, sigma_init, sigma_default, device): 42 | ''' 43 | xyz_min: list (3,) or (1, 3) 44 | scope: float 45 | ''' 46 | super().__init__() 47 | self.sigma_init = sigma_init 48 | self.sigma_default = sigma_default 49 | 50 | self.sigma_voxels_coarse = torch.full((grid_coarse,grid_coarse,grid_coarse), self.sigma_init, device=device) 51 | self.index_voxels_coarse = torch.full((grid_coarse,grid_coarse,grid_coarse), 0, dtype=torch.long, device=device) 52 | self.voxels_fine = None 53 | 54 | self.xyz_min = xyz_min[0] 55 | self.xyz_max = xyz_max[0] 56 | self.xyz_scope = self.xyz_max - self.xyz_min 57 | self.grid_coarse = grid_coarse 58 | self.grid_fine = grid_fine 59 | self.res_coarse = grid_coarse 60 | self.res_fine = grid_coarse * grid_fine 61 | self.dim_sh = 3 * (deg + 1)**2 62 | self.device = device 63 | 64 | def calc_index_coarse(self, xyz): 65 | ijk_coarse = ((xyz - self.xyz_min) / self.xyz_scope * self.grid_coarse).long().clamp(min=0, max=self.grid_coarse-1) 66 | # return index_coarse[:, 0] * (self.grid_coarse**2) + index_coarse[:, 1] * self.grid_coarse + index_coarse[:, 2] 67 | return ijk_coarse 68 | 69 | def update_coarse(self, xyz, sigma, beta): 70 | ''' 71 | xyz: (N, 3) 72 | sigma: (N,) 73 | ''' 74 | ijk_coarse = self.calc_index_coarse(xyz) 75 | 76 | self.sigma_voxels_coarse[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] \ 77 | = (1 - beta) * self.sigma_voxels_coarse[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] + \ 78 | beta * sigma 79 | 80 | def create_voxels_fine(self): 81 | ijk_coarse = torch.logical_and(self.sigma_voxels_coarse > 0, self.sigma_voxels_coarse != 
self.sigma_init).nonzero().squeeze(1) # (N, 3) 82 | num_valid = ijk_coarse.shape[0] + 1 83 | 84 | index = torch.arange(1, num_valid, dtype=torch.long, device=ijk_coarse.device) 85 | self.index_voxels_coarse[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] = index 86 | 87 | self.voxels_fine = torch.zeros(num_valid, self.grid_fine, self.grid_fine, self.grid_fine, self.dim_sh+1, device=self.device) 88 | self.voxels_fine[..., 0] = self.sigma_default 89 | self.voxels_fine[..., 1:] = 0.0 90 | 91 | def calc_index_fine(self, xyz): 92 | # xyz_norm = (xyz - self.xyz_min) / self.xyz_scope 93 | # xyz_coarse = (xyz_norm * self.grid_coarse).long() * self.grid_fine 94 | # xyz_fine = (xyz_norm * self.res_fine).long() 95 | # index_fine = ((xyz_fine - xyz_coarse)).clamp(0, self.grid_fine-1) 96 | 97 | xyz_norm = (xyz - self.xyz_min) / self.xyz_scope 98 | xyz_fine = (xyz_norm * self.res_fine).long() 99 | index_fine = xyz_fine % self.grid_fine 100 | return index_fine 101 | 102 | def update_fine(self, xyz, sigma, sh): 103 | ''' 104 | xyz: (N, 3) 105 | sigma: (N, 1) 106 | sh: (N, F) 107 | ''' 108 | # calc ijk_coarse 109 | index_coarse = self.query_coarse(xyz, 'index') 110 | nonzero_index_coarse = torch.nonzero(index_coarse).squeeze(1) 111 | index_coarse = index_coarse[nonzero_index_coarse] 112 | 113 | # calc index_fine 114 | ijk_fine = self.calc_index_fine(xyz[nonzero_index_coarse]) 115 | 116 | # feat 117 | feat = torch.cat([sigma, sh], dim=-1) 118 | 119 | self.voxels_fine[index_coarse, ijk_fine[:, 0], ijk_fine[:, 1], ijk_fine[:, 2]] = feat[nonzero_index_coarse] 120 | 121 | def query_coarse(self, xyz, type='sigma'): 122 | ''' 123 | xyz: (N, 3) 124 | ''' 125 | ijk_coarse = self.calc_index_coarse(xyz) 126 | 127 | if type == 'sigma': 128 | out = self.sigma_voxels_coarse[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] 129 | else: 130 | out = self.index_voxels_coarse[ijk_coarse[:, 0], ijk_coarse[:, 1], ijk_coarse[:, 2]] 131 | return out 132 | 133 | def query_fine(self, xyz): 134 | ''' 135 | x: (N, 3) 136 | ''' 137 | # calc index_coarse 138 | index_coarse = self.query_coarse(xyz, 'index') 139 | 140 | # calc index_fine 141 | ijk_fine = self.calc_index_fine(xyz) 142 | 143 | return self.voxels_fine[index_coarse, ijk_fine[:, 0], ijk_fine[:, 1], ijk_fine[:, 2]] 144 | 145 | 146 | class EfficientNeRFSystem(LightningModule): 147 | def __init__(self, hparams): 148 | super(EfficientNeRFSystem, self).__init__() 149 | self.save_hyperparameters(hparams) 150 | 151 | self.loss = loss_dict[hparams.loss_type]() 152 | 153 | self.embedding_xyz = Embedding(3, 10) # 10 is the default number 154 | self.embedding_dir = Embedding(3, 4) # 4 is the default number 155 | self.embeddings = [self.embedding_xyz, self.embedding_dir] 156 | 157 | self.deg = 2 158 | self.dim_sh = 3 * (self.deg + 1)**2 159 | 160 | self.nerf_coarse = NeRF(D=4, W=128, 161 | in_channels_xyz=63, in_channels_dir=27, 162 | skips=[2], deg=self.deg) 163 | self.models = [self.nerf_coarse] 164 | if hparams.N_importance > 0: 165 | self.nerf_fine = NeRF(D=8, W=256, 166 | in_channels_xyz=63, in_channels_dir=27, 167 | skips=[4], deg=self.deg) 168 | self.models += [self.nerf_fine] 169 | self.sigma_init = hparams.sigma_init 170 | self.sigma_default = hparams.sigma_default 171 | 172 | # sparse voxels 173 | coord_scope = hparams.coord_scope 174 | self.nerf_tree = NerfTree_Pytorch(xyz_min=[-coord_scope, -coord_scope, -coord_scope], 175 | xyz_max=[coord_scope, coord_scope, coord_scope], 176 | grid_coarse=384, 177 | grid_fine=3, 178 | deg=self.deg, 179 | 
sigma_init=self.sigma_init, 180 | sigma_default=self.sigma_default, 181 | device='cuda') 182 | os.makedirs(f'logs/{self.hparams.exp_name}/ckpts', exist_ok=True) 183 | self.nerftree_path = os.path.join(f'logs/{self.hparams.exp_name}/ckpts', 'nerftree.pt') 184 | if self.hparams.ckpt_path != None and os.path.exists(self.nerftree_path): 185 | voxels_dict = torch.load(self.nerftree_path) 186 | self.nerf_tree.sigma_voxels_coarse = voxels_dict['sigma_voxels_coarse'] 187 | 188 | self.xyz_min = self.nerf_tree.xyz_min 189 | self.xyz_max = self.nerf_tree.xyz_max 190 | self.xyz_scope = self.nerf_tree.xyz_scope 191 | self.grid_coarse = self.nerf_tree.grid_coarse 192 | self.grid_fine = self.nerf_tree.grid_fine 193 | self.res_coarse = self.nerf_tree.res_coarse 194 | self.res_fine = self.nerf_tree.res_fine 195 | 196 | def decode_batch(self, batch): 197 | rays = batch['rays'] # (B, 8) 198 | rgbs = batch['rgbs'] # (B, 3) 199 | return rays, rgbs 200 | 201 | def sigma2weights(self, deltas, sigmas): 202 | # compute alpha by the formula (3) 203 | # if self.training: 204 | noise = torch.randn(sigmas.shape, device=sigmas.device) 205 | sigmas = sigmas + noise 206 | 207 | # alphas = 1-torch.exp(-deltas*torch.nn.ReLU()(sigmas)) # (N_rays, N_samples_) 208 | alphas = 1-torch.exp(-deltas*torch.nn.Softplus()(sigmas)) # (N_rays, N_samples_) 209 | alphas_shifted = torch.cat([torch.ones_like(alphas[:, :1]), 1-alphas+1e-10], -1) # [1, a1, a2, ...] 210 | weights = alphas * torch.cumprod(alphas_shifted, -1)[:, :-1] # (N_rays, N_samples_) 211 | return weights, alphas 212 | 213 | def render_rays(self, 214 | models, 215 | embeddings, 216 | rays, 217 | N_samples=64, 218 | use_disp=False, 219 | noise_std=0.0, 220 | N_importance=0, 221 | chunk=1024*32, 222 | white_back=False 223 | ): 224 | 225 | def inference(model, embedding_xyz, xyz_, dir_, dir_embedded, z_vals, idx_render): 226 | N_samples_ = xyz_.shape[1] 227 | # Embed directions 228 | xyz_ = xyz_[idx_render[:, 0], idx_render[:, 1]].view(-1, 3) # (N_rays*N_samples_, 3) 229 | view_dir = dir_.unsqueeze(1).expand(-1, N_samples_, -1) 230 | view_dir = view_dir[idx_render[:, 0], idx_render[:, 1]] 231 | # Perform model inference to get rgb and raw sigma 232 | B = xyz_.shape[0] 233 | out_chunks = [] 234 | for i in range(0, B, chunk): 235 | out_chunks += [model(embedding_xyz(xyz_[i:i+chunk]), view_dir[i:i+chunk])] 236 | out = torch.cat(out_chunks, 0) 237 | 238 | out_rgb = torch.full((N_rays, N_samples_, 3), 1.0, device=device) 239 | out_sigma = torch.full((N_rays, N_samples_, 1), self.sigma_default, device=device) 240 | out_sh = torch.full((N_rays, N_samples_, self.dim_sh), 0.0, device=device) 241 | out_defaults = torch.cat([out_sigma, out_rgb, out_sh], dim=2) 242 | out_defaults[idx_render[:, 0], idx_render[:, 1]] = out 243 | out = out_defaults 244 | 245 | sigmas, rgbs, shs = torch.split(out, (1, 3, self.dim_sh), dim=-1) 246 | del out 247 | sigmas = sigmas.squeeze(-1) 248 | 249 | # Convert these values using volume rendering (Section 4) 250 | deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples_-1) 251 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity 252 | deltas = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 253 | 254 | weights, alphas = self.sigma2weights(deltas, sigmas) 255 | 256 | weights_sum = weights.sum(1) # (N_rays), the accumulated opacity along the rays 257 | # equals "1 - (1-a1)(1-a2)...(1-an)" mathematically 258 | 259 | # compute final weighted outputs 260 | rgb_final = torch.sum(weights.unsqueeze(-1)*rgbs, -2) 
# (N_rays, 3) 261 | depth_final = torch.sum(weights*z_vals, -1) # (N_rays) 262 | 263 | if white_back: 264 | rgb_final = rgb_final + 1-weights_sum.unsqueeze(-1) 265 | 266 | return rgb_final, depth_final, weights, sigmas, shs 267 | 268 | # Extract models from lists 269 | model_coarse = models[0] 270 | embedding_xyz = embeddings[0] 271 | device = rays.device 272 | is_training = model_coarse.training 273 | result = {} 274 | 275 | # Decompose the inputs 276 | N_rays = rays.shape[0] 277 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 278 | 279 | # Embed direction 280 | dir_embedded = None 281 | 282 | N_samples_coarse = self.N_samples_coarse 283 | z_vals_coarse = self.z_vals_coarse.clone().expand(N_rays, -1) 284 | if is_training: 285 | delta_z_vals = torch.empty(N_rays, 1, device=device).uniform_(0.0, self.distance/N_samples_coarse) 286 | z_vals_coarse = z_vals_coarse + delta_z_vals 287 | 288 | xyz_sampled_coarse = rays_o.unsqueeze(1) + \ 289 | rays_d.unsqueeze(1) * z_vals_coarse.unsqueeze(2) # (N_rays, N_samples_coarse, 3) 290 | 291 | xyz_coarse = xyz_sampled_coarse.reshape(-1, 3) 292 | 293 | # valid sampling 294 | sigmas = self.nerf_tree.query_coarse(xyz_coarse, type='sigma').reshape(N_rays, N_samples_coarse) 295 | 296 | # update density voxel during coarse training 297 | if is_training and self.nerf_tree.voxels_fine == None: 298 | with torch.no_grad(): 299 | # introduce uniform sampling, not necessary 300 | sigmas[torch.rand_like(sigmas[:, 0]) < self.hparams.uniform_ratio] = self.sigma_init 301 | 302 | if self.hparams.warmup_step > 0 and self.trainer.global_step <= self.hparams.warmup_step: 303 | # during warmup, treat all points as valid samples 304 | idx_render_coarse = torch.nonzero(sigmas >= -1e10).detach() 305 | else: 306 | # or else, treat points whose density > 0 as valid samples 307 | idx_render_coarse = torch.nonzero(sigmas > 0.0).detach() 308 | 309 | rgb_coarse, depth_coarse, weights_coarse, sigmas_coarse, _ = \ 310 | inference(model_coarse, embedding_xyz, xyz_sampled_coarse, rays_d, 311 | dir_embedded, z_vals_coarse, idx_render_coarse) 312 | result['rgb_coarse'] = rgb_coarse 313 | result['z_vals_coarse'] = self.z_vals_coarse 314 | result['depth_coarse'] = depth_coarse 315 | result['sigma_coarse'] = sigmas_coarse 316 | result['weight_coarse'] = weights_coarse 317 | result['opacity_coarse'] = weights_coarse.sum(1) 318 | result['num_samples_coarse'] = torch.FloatTensor([idx_render_coarse.shape[0] / N_rays]) 319 | 320 | # update 321 | xyz_coarse_ = xyz_sampled_coarse[idx_render_coarse[:, 0], idx_render_coarse[:, 1]] 322 | sigmas_coarse_ = sigmas_coarse.detach()[idx_render_coarse[:, 0], idx_render_coarse[:, 1]] 323 | self.nerf_tree.update_coarse(xyz_coarse_, sigmas_coarse_, self.hparams.beta) 324 | 325 | # deltas_coarse = self.deltas_coarse 326 | with torch.no_grad(): 327 | deltas_coarse = z_vals_coarse[:, 1:] - z_vals_coarse[:, :-1] # (N_rays, N_samples_-1) 328 | delta_inf = 1e10 * torch.ones_like(deltas_coarse[:, :1]) # (N_rays, 1) the last delta is infinity 329 | deltas_coarse = torch.cat([deltas_coarse, delta_inf], -1) # (N_rays, N_samples_) 330 | weights_coarse, _ = self.sigma2weights(deltas_coarse, sigmas) 331 | weights_coarse = weights_coarse.detach() 332 | 333 | # pivotal sampling 334 | idx_render = torch.nonzero(weights_coarse >= min(self.hparams.weight_threashold, weights_coarse.max().item())) 335 | scale = N_importance 336 | z_vals_fine = self.z_vals_fine.clone() 337 | if is_training: 338 | z_vals_fine = z_vals_fine + delta_z_vals 339 | 340 | idx_render = 
idx_render.unsqueeze(1).expand(-1, scale, -1) # (B, scale, 2) 341 | idx_render_fine = idx_render.clone() 342 | idx_render_fine[..., 1] = idx_render[..., 1] * scale + (torch.arange(scale, device=device)).reshape(1, scale) 343 | idx_render_fine = idx_render_fine.reshape(-1, 2) 344 | 345 | if idx_render_fine.shape[0] > N_rays * 64: 346 | indices = torch.randperm(idx_render_fine.shape[0])[:N_rays * 64] 347 | idx_render_fine = idx_render_fine[indices] 348 | 349 | xyz_sampled_fine = rays_o.unsqueeze(1) + \ 350 | rays_d.unsqueeze(1) * z_vals_fine.unsqueeze(2) # (N_rays, N_samples*scale, 3) 351 | 352 | # if self.nerf_tree.voxels_fine != None: 353 | # xyz_norm = (xyz_sampled_fine - self.xyz_min) / self.xyz_scope 354 | # xyz_norm = (xyz_norm * self.res_fine).long().float() / float(self.res_fine) 355 | # xyz_sampled_fine = xyz_norm * self.xyz_scope + self.xyz_min 356 | 357 | model_fine = models[1] 358 | rgb_fine, depth_fine, _, sigmas_fine, shs_fine = \ 359 | inference(model_fine, embedding_xyz, xyz_sampled_fine, rays_d, 360 | dir_embedded, z_vals_fine, idx_render_fine) 361 | 362 | if is_training and self.nerf_tree.voxels_fine != None: 363 | with torch.no_grad(): 364 | xyz_fine_ = xyz_sampled_fine[idx_render_fine[:, 0], idx_render_fine[:, 1]] 365 | sigmas_fine_ = sigmas_fine.detach()[idx_render_fine[:, 0], idx_render_fine[:, 1]].unsqueeze(-1) 366 | shs_fine_ = shs_fine.detach()[idx_render_fine[:, 0], idx_render_fine[:, 1]] 367 | self.nerf_tree.update_fine(xyz_fine_, sigmas_fine_, shs_fine_) 368 | 369 | result['rgb_fine'] = rgb_fine 370 | result['depth_fine'] = depth_fine 371 | result['num_samples_fine'] = torch.FloatTensor([idx_render_fine.shape[0] / N_rays]) 372 | 373 | return result 374 | 375 | def forward(self, rays): 376 | """Do batched inference on rays using chunk.""" 377 | B = rays.shape[0] 378 | results = defaultdict(list) 379 | # if self.nerf_tree.voxels_fine == None or self.models[0].training: 380 | # chunk = self.hparams.chunk 381 | # else: 382 | # chunk = B // 8 383 | chunk = self.hparams.chunk 384 | for i in range(0, B, chunk): 385 | rendered_ray_chunks = \ 386 | self.render_rays(self.models, 387 | self.embeddings, 388 | rays[i:i+chunk], 389 | self.hparams.N_samples, 390 | self.hparams.use_disp, 391 | self.hparams.noise_std, 392 | self.hparams.N_importance, 393 | chunk, # chunk size is effective in val mode 394 | self.train_dataset.white_back 395 | ) 396 | 397 | for k, v in rendered_ray_chunks.items(): 398 | results[k] += [v] 399 | 400 | for k, v in results.items(): 401 | results[k] = torch.cat(v, 0) 402 | return results 403 | 404 | def optimizer_step(self, epoch=None, 405 | batch_idx=None, 406 | optimizer=None, 407 | optimizer_idx=None, 408 | optimizer_closure=None, 409 | on_tpu=None, 410 | using_native_amp=None, 411 | using_lbfgs=None): 412 | if self.hparams.warmup_step > 0 and self.trainer.global_step < self.hparams.warmup_step: 413 | lr_scale = min(1., float(self.trainer.global_step + 1) / float(self.hparams.warmup_step)) 414 | for pg in optimizer.param_groups: 415 | pg['lr'] = lr_scale * self.hparams.lr 416 | optimizer.step(closure=optimizer_closure) 417 | optimizer.zero_grad() 418 | 419 | def prepare_data(self): 420 | dataset = dataset_dict[self.hparams.dataset_name] 421 | kwargs = {'root_dir': self.hparams.root_dir, 422 | 'img_wh': tuple(self.hparams.img_wh)} 423 | self.train_dataset = dataset(split='train', **kwargs) 424 | if self.hparams.dataset_name == 'blender': 425 | self.val_dataset = dataset(split='test', **kwargs) 426 | else: 427 | self.val_dataset = dataset(split='val', 
**kwargs) 428 | 429 | self.near = self.train_dataset.near 430 | self.far = self.train_dataset.far 431 | self.distance = self.far - self.near 432 | near = torch.full((1,), self.near, dtype=torch.float32, device='cuda') 433 | far = torch.full((1,), self.far, dtype=torch.float32, device='cuda') 434 | 435 | # z_vals_coarse 436 | self.N_samples_coarse = self.hparams.N_samples 437 | z_vals_coarse = torch.linspace(0, 1, self.N_samples_coarse, device='cuda') # (N_samples_coarse) 438 | if not self.hparams.use_disp: # use linear sampling in depth space 439 | z_vals_coarse = near * (1-z_vals_coarse) + far * z_vals_coarse 440 | else: # use linear sampling in disparity space 441 | z_vals_coarse = 1/(1/near * (1-z_vals_coarse) + 1/far * z_vals_coarse) # (N_rays, N_samples_coarse) 442 | self.z_vals_coarse = z_vals_coarse.unsqueeze(0) 443 | 444 | # z_vals_fine 445 | self.N_samples_fine = self.hparams.N_samples * self.hparams.N_importance 446 | z_vals_fine = torch.linspace(0, 1, self.N_samples_fine, device='cuda') # (N_samples_coarse) 447 | if not self.hparams.use_disp: # use linear sampling in depth space 448 | z_vals_fine = near * (1-z_vals_fine) + far * z_vals_fine 449 | else: # use linear sampling in disparity space 450 | z_vals_fine = 1/(1/near * (1-z_vals_fine) + 1/far * z_vals_fine) # (N_rays, N_samples_coarse) 451 | self.z_vals_fine = z_vals_fine.unsqueeze(0) 452 | 453 | # delta 454 | deltas = self.z_vals_coarse[:, 1:] - self.z_vals_coarse[:, :-1] # (N_rays, N_samples_-1) 455 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity 456 | self.deltas_coarse = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 457 | 458 | deltas = self.z_vals_fine[:, 1:] - self.z_vals_fine[:, :-1] # (N_rays, N_samples_-1) 459 | delta_inf = 1e10 * torch.ones_like(deltas[:, :1]) # (N_rays, 1) the last delta is infinity 460 | self.deltas_fine = torch.cat([deltas, delta_inf], -1) # (N_rays, N_samples_) 461 | 462 | 463 | def configure_optimizers(self): 464 | self.optimizer = get_optimizer(self.hparams, self.models) 465 | scheduler = get_scheduler(self.hparams, self.optimizer) 466 | 467 | return [self.optimizer], [scheduler] 468 | 469 | def train_dataloader(self): 470 | return DataLoader(self.train_dataset, 471 | shuffle=True, 472 | num_workers=8, 473 | batch_size=self.hparams.batch_size, 474 | pin_memory=True) 475 | 476 | def val_dataloader(self): 477 | return DataLoader(self.val_dataset, 478 | shuffle=False, 479 | num_workers=4, 480 | batch_size=1, # validate one image (H*W rays) at a time 481 | pin_memory=True) 482 | 483 | def training_step(self, batch, batch_idx): 484 | self.log('train/lr', get_learning_rate(self.optimizer), on_step=True, prog_bar=True) 485 | rays, rgbs = self.decode_batch(batch) 486 | extract_time = self.current_epoch >= (self.hparams.num_epochs - 1) 487 | 488 | if extract_time and self.nerf_tree.voxels_fine == None: 489 | self.nerf_tree.create_voxels_fine() 490 | 491 | results = self(rays) 492 | 493 | loss_total = loss_rgb = self.loss(results, rgbs) 494 | self.log('train/loss_rgb', loss_rgb, on_step=True) 495 | 496 | # if self.hparams.weight_tv > 0.0: 497 | # alphas_coarse = results['alpha_coarse'] 498 | # loss_tv = self.hparams.weight_tv * (alphas_coarse[:, 1:] - alphas_coarse[:, :-1]).pow(2).mean() 499 | # self.log('train/loss_tv', loss_tv, on_step=True) 500 | # loss_total += loss_tv 501 | 502 | self.log('train/loss_total', loss_total, on_step=True) 503 | 504 | if 'num_samples_coarse' in results: 505 | self.log('train/num_samples_coarse', 
results['num_samples_coarse'].mean(), on_step=True) 506 | 507 | if 'num_samples_fine' in results: 508 | self.log('train/num_samples_fine', results['num_samples_fine'].mean(), on_step=True) 509 | 510 | typ = 'fine' if 'rgb_fine' in results else 'coarse' 511 | 512 | if batch_idx % 1000 == 0 and self.nerf_tree.voxels_fine == None: 513 | fig = plt.figure() 514 | depths = results['z_vals_coarse'][0].detach().cpu().numpy() 515 | sigmas = torch.nn.ReLU()(results['sigma_coarse'][0]).detach().cpu().numpy() 516 | weights = results['weight_coarse'][0].detach().cpu().numpy() 517 | near = self.near - (self.far - self.near) * 0.1 518 | far = self.far + (self.far - self.near) * 0.1 519 | fig, ax = plt.subplots(1, 2, figsize=(12, 5), dpi=120) 520 | ax[0].scatter(x=depths, y=sigmas) 521 | ax[0].set_xlabel('Depth', fontsize=16) 522 | ax[0].set_ylabel('Density', fontsize=16) 523 | ax[0].set_title('Density Distribution of a Ray', fontsize=16) 524 | ax[0].set_xlim([near, far]) 525 | 526 | ax[1].scatter(x=depths, y=weights) 527 | ax[1].set_xlabel('Depth', fontsize=16) 528 | ax[1].set_ylabel('Weight', fontsize=16) 529 | ax[1].set_title('Weight Distribution of a Ray', fontsize=16) 530 | ax[1].set_xlim([near, far]) 531 | 532 | self.logger.experiment.add_figure('train/distribution', 533 | fig, self.global_step) 534 | plt.close() 535 | 536 | feats = {} 537 | with torch.no_grad(): 538 | psnr_fine = psnr(results[f'rgb_{typ}'], rgbs) 539 | self.log('train/psnr_fine', psnr_fine, on_step=True, prog_bar=True) 540 | 541 | if 'rgb_coarse' in results: 542 | psnr_coarse = psnr(results['rgb_coarse'], rgbs) 543 | self.log('train/psnr_coarse', psnr_coarse, on_step=True) 544 | 545 | if batch_idx % 1000 == 0: 546 | torch.cuda.empty_cache() 547 | return loss_total 548 | 549 | def validation_step(self, batch, batch_idx): 550 | rays, rgbs = self.decode_batch(batch) 551 | rays = rays.squeeze() # (H*W, 3) 552 | rgbs = rgbs.squeeze() # (H*W, 3) 553 | 554 | results = self(rays) 555 | log = {} 556 | log['val_loss'] = self.loss(results, rgbs) 557 | typ = 'fine' if 'rgb_fine' in results else 'coarse' 558 | 559 | W, H = self.hparams.img_wh 560 | img = results[f'rgb_{typ}'].view(H, W, 3).cpu() 561 | img = img.permute(2, 0, 1) # (3, H, W) 562 | img_path = os.path.join(f'logs/{hparams.exp_name}/video', "%06d.png" % batch_idx) 563 | os.makedirs(os.path.dirname(img_path), exist_ok=True) 564 | transforms.ToPILImage()(img).convert("RGB").save(img_path) 565 | 566 | idx_selected = 0 567 | if batch_idx == idx_selected: 568 | W, H = self.hparams.img_wh 569 | img = results[f'rgb_{typ}'].view(H, W, 3).cpu() 570 | img = img.permute(2, 0, 1) # (3, H, W) 571 | img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) 572 | stack = torch.stack([img_gt, img]) # (3, 3, H, W) 573 | self.logger.experiment.add_images('val/gt_pred', 574 | stack, self.global_step) 575 | 576 | img_path = os.path.join(f'logs/{hparams.exp_name}', f'epoch_{self.current_epoch}.png') 577 | transforms.ToPILImage()(img).convert("RGB").save(img_path) 578 | 579 | log['val_psnr'] = psnr(results[f'rgb_{typ}'], rgbs) 580 | torch.cuda.empty_cache() 581 | return log 582 | 583 | def validation_epoch_end(self, outputs): 584 | log = {} 585 | mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 586 | mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean() 587 | num_voxels_coarse = torch.logical_and(self.nerf_tree.sigma_voxels_coarse > 0, self.nerf_tree.sigma_voxels_coarse != self.sigma_init).nonzero().shape[0] 588 | self.log('val/loss', mean_loss, on_epoch=True) 589 | 
self.log('val/psnr', mean_psnr, on_epoch=True, prog_bar=True) 590 | self.log('val/num_voxels_coarse', num_voxels_coarse, on_epoch=True) 591 | 592 | # save sparse voxels 593 | sigma_voxels_coarse_clean = self.nerf_tree.sigma_voxels_coarse.clone() 594 | sigma_voxels_coarse_clean[sigma_voxels_coarse_clean == self.sigma_init] = self.sigma_default 595 | voxels_dict = { 596 | 'sigma_voxels_coarse': sigma_voxels_coarse_clean, 597 | 'index_voxels_coarse': self.nerf_tree.index_voxels_coarse, 598 | 'voxels_fine': self.nerf_tree.voxels_fine 599 | } 600 | torch.save(voxels_dict, self.nerftree_path) 601 | 602 | img_paths = glob.glob(f'logs/{hparams.exp_name}/video/*.png') 603 | writer = imageio.get_writer(f'logs/{hparams.exp_name}/video/video_{self.current_epoch}.mp4', fps=40) 604 | for im in img_paths: 605 | writer.append_data(imageio.imread(im)) 606 | writer.close() 607 | 608 | 609 | if __name__ == '__main__': 610 | hparams = get_opts() 611 | system = EfficientNeRFSystem(hparams) 612 | checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(f'logs/{hparams.exp_name}/ckpts', 613 | '{epoch:d}'), 614 | monitor='val/psnr', 615 | mode='max', 616 | save_top_k=5,) 617 | 618 | logger = TensorBoardLogger( 619 | save_dir="logs", 620 | name=hparams.exp_name, 621 | ) 622 | 623 | trainer = Trainer(max_epochs=hparams.num_epochs, 624 | checkpoint_callback=checkpoint_callback, 625 | logger=logger, 626 | gpus=hparams.num_gpus, 627 | strategy='ddp' if hparams.num_gpus>1 else None, 628 | benchmark=True) 629 | 630 | trainer.fit(system, ckpt_path=hparams.ckpt_path) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate EfficientNeRF 4 | 5 | DATA_DIR=/research/dept6/taohu/Data/NeRF_Data/nerf_synthetic/lego 6 | 7 | python train.py \ 8 | --dataset_name blender \ 9 | --root_dir $DATA_DIR \ 10 | --N_samples 128 \ 11 | --N_importance 5 --img_wh 800 800 \ 12 | --num_epochs 16 --batch_size 4096 \ 13 | --lr 2e-3 \ 14 | --lr_scheduler poly \ 15 | --coord_scope 3.0 \ 16 | --warmup_step 5000\ 17 | --sigma_init 30.0 \ 18 | --weight_threashold 1e-5 \ 19 | --exp_name lego_coarse128_fine5_V384 20 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | from torch.optim import SGD, Adam 3 | from .optimizers import * 4 | # scheduler 5 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, LambdaLR 6 | from .warmup_scheduler import GradualWarmupScheduler 7 | 8 | from .visualization import * 9 | 10 | def get_optimizer(hparams, models): 11 | eps = 1e-8 12 | parameters = [] 13 | for model in models: 14 | parameters += list(model.parameters()) 15 | if hparams.optimizer == 'sgd': 16 | optimizer = SGD(parameters, lr=hparams.lr, 17 | momentum=hparams.momentum, weight_decay=hparams.weight_decay) 18 | elif hparams.optimizer == 'adam': 19 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps, 20 | weight_decay=hparams.weight_decay) 21 | elif hparams.optimizer == 'radam': 22 | optimizer = RAdam(parameters, lr=hparams.lr, eps=eps, 23 | weight_decay=hparams.weight_decay) 24 | elif hparams.optimizer == 'ranger': 25 | optimizer = Ranger(parameters, lr=hparams.lr, eps=eps, 26 | weight_decay=hparams.weight_decay) 27 | elif hparams.optimizer == 'adamw': 28 | optimizer = AdamW(parameters, lr=hparams.lr, eps=eps, 29 | 
weight_decay=hparams.weight_decay) 30 | else: 31 | raise ValueError('optimizer not recognized!') 32 | 33 | return optimizer 34 | 35 | def get_scheduler(hparams, optimizer): 36 | eps = 1e-8 37 | if hparams.lr_scheduler == 'steplr': 38 | scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, 39 | gamma=hparams.decay_gamma) 40 | elif hparams.lr_scheduler == 'cosine': 41 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) 42 | elif hparams.lr_scheduler == 'poly': 43 | scheduler = LambdaLR(optimizer, 44 | lambda epoch: 0.01**(epoch/hparams.num_epochs)) 45 | else: 46 | raise ValueError('scheduler not recognized!') 47 | 48 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: 49 | scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, 50 | total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) 51 | 52 | return scheduler 53 | 54 | def get_learning_rate(optimizer): 55 | for param_group in optimizer.param_groups: 56 | return param_group['lr'] 57 | 58 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): 59 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 60 | checkpoint_ = {} 61 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 62 | checkpoint = checkpoint['state_dict'] 63 | for k, v in checkpoint.items(): 64 | if not k.startswith(model_name): 65 | continue 66 | k = k[len(model_name)+1:] 67 | for prefix in prefixes_to_ignore: 68 | if k.startswith(prefix): 69 | print('ignore', k) 70 | break 71 | else: 72 | checkpoint_[k] = v 73 | return checkpoint_ 74 | 75 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 76 | model_dict = model.state_dict() 77 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 78 | model_dict.update(checkpoint_) 79 | model.load_state_dict(model_dict) 80 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | import itertools as it 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 <= betas[0] < 1.0: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= betas[1] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 17 | 18 | self.degenerated_to_sgd = degenerated_to_sgd 19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 20 | for param in params: 21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 22 | param['buffer'] = [[None, None, None] for _ in range(10)] 23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 24 | super(RAdam, self).__init__(params, defaults) 25 | 26 | def __setstate__(self, state): 27 | super(RAdam, self).__setstate__(state) 28 | 29 | def step(self, closure=None): 30 | 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | 35 | for group in self.param_groups: 36 | 37 | for p in 
group['params']: 38 | if p.grad is None: 39 | continue 40 | grad = p.grad.data.float() 41 | if grad.is_sparse: 42 | raise RuntimeError('RAdam does not support sparse gradients') 43 | 44 | p_data_fp32 = p.data.float() 45 | 46 | state = self.state[p] 47 | 48 | if len(state) == 0: 49 | state['step'] = 0 50 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 51 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 52 | else: 53 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 54 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 55 | 56 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 57 | beta1, beta2 = group['betas'] 58 | 59 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 60 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 61 | 62 | state['step'] += 1 63 | buffered = group['buffer'][int(state['step'] % 10)] 64 | if state['step'] == buffered[0]: 65 | N_sma, step_size = buffered[1], buffered[2] 66 | else: 67 | buffered[0] = state['step'] 68 | beta2_t = beta2 ** state['step'] 69 | N_sma_max = 2 / (1 - beta2) - 1 70 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 71 | buffered[1] = N_sma 72 | 73 | # more conservative since it's an approximated value 74 | if N_sma >= 5: 75 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 76 | elif self.degenerated_to_sgd: 77 | step_size = 1.0 / (1 - beta1 ** state['step']) 78 | else: 79 | step_size = -1 80 | buffered[2] = step_size 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | if group['weight_decay'] != 0: 85 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 86 | denom = exp_avg_sq.sqrt().add_(group['eps']) 87 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 88 | p.data.copy_(p_data_fp32) 89 | elif step_size > 0: 90 | if group['weight_decay'] != 0: 91 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 92 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 93 | p.data.copy_(p_data_fp32) 94 | 95 | return loss 96 | 97 | class PlainRAdam(Optimizer): 98 | 99 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 100 | if not 0.0 <= lr: 101 | raise ValueError("Invalid learning rate: {}".format(lr)) 102 | if not 0.0 <= eps: 103 | raise ValueError("Invalid epsilon value: {}".format(eps)) 104 | if not 0.0 <= betas[0] < 1.0: 105 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 106 | if not 0.0 <= betas[1] < 1.0: 107 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 108 | 109 | self.degenerated_to_sgd = degenerated_to_sgd 110 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 111 | 112 | super(PlainRAdam, self).__init__(params, defaults) 113 | 114 | def __setstate__(self, state): 115 | super(PlainRAdam, self).__setstate__(state) 116 | 117 | def step(self, closure=None): 118 | 119 | loss = None 120 | if closure is not None: 121 | loss = closure() 122 | 123 | for group in self.param_groups: 124 | 125 | for p in group['params']: 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data.float() 129 | if grad.is_sparse: 130 | raise RuntimeError('RAdam does not support sparse gradients') 131 | 132 | p_data_fp32 = p.data.float() 133 | 134 | state = self.state[p] 135 | 136 | if len(state) == 0: 137 | state['step'] = 0 138 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 139 | 
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 140 | else: 141 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 142 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 143 | 144 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 145 | beta1, beta2 = group['betas'] 146 | 147 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 148 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 149 | 150 | state['step'] += 1 151 | beta2_t = beta2 ** state['step'] 152 | N_sma_max = 2 / (1 - beta2) - 1 153 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 154 | 155 | 156 | # more conservative since it's an approximated value 157 | if N_sma >= 5: 158 | if group['weight_decay'] != 0: 159 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 160 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 161 | denom = exp_avg_sq.sqrt().add_(group['eps']) 162 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 163 | p.data.copy_(p_data_fp32) 164 | elif self.degenerated_to_sgd: 165 | if group['weight_decay'] != 0: 166 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 167 | step_size = group['lr'] / (1 - beta1 ** state['step']) 168 | p_data_fp32.add_(-step_size, exp_avg) 169 | p.data.copy_(p_data_fp32) 170 | 171 | return loss 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 
| scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss 245 | 246 | 247 | #Ranger deep learning optimizer - RAdam + Lookahead combined. 248 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 249 | 250 | #Ranger has now been used to capture 12 records on the FastAI leaderboard. 251 | 252 | #This version = 9.3.19 253 | 254 | #Credits: 255 | #RAdam --> https://github.com/LiyuanLucasLiu/RAdam 256 | #Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 257 | #Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 258 | 259 | #summary of changes: 260 | #full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 261 | #supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 262 | #changes 8/31/19 - fix references to *self*.N_sma_threshold; 263 | #changed eps to 1e-5 as better default than 1e-8. 264 | 265 | 266 | class Ranger(Optimizer): 267 | 268 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0): 269 | #parameter checks 270 | if not 0.0 <= alpha <= 1.0: 271 | raise ValueError(f'Invalid slow update rate: {alpha}') 272 | if not 1 <= k: 273 | raise ValueError(f'Invalid lookahead steps: {k}') 274 | if not lr > 0: 275 | raise ValueError(f'Invalid Learning Rate: {lr}') 276 | if not eps > 0: 277 | raise ValueError(f'Invalid eps: {eps}') 278 | 279 | #parameter comments: 280 | # beta1 (momentum) of .95 seems to work better than .90... 281 | #N_sma_threshold of 5 seems better in testing than 4. 282 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 283 | 284 | #prep defaults and init torch.optim base 285 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 286 | super().__init__(params,defaults) 287 | 288 | #adjustable threshold 289 | self.N_sma_threshhold = N_sma_threshhold 290 | 291 | #now we can get to work... 292 | #removed as we now use step from RAdam...no need for duplicate step counting 293 | #for group in self.param_groups: 294 | # group["step_counter"] = 0 295 | #print("group step counter init") 296 | 297 | #look ahead params 298 | self.alpha = alpha 299 | self.k = k 300 | 301 | #radam buffer for state 302 | self.radam_buffer = [[None,None,None] for ind in range(10)] 303 | 304 | #self.first_run_check=0 305 | 306 | #lookahead weights 307 | #9/2/19 - lookahead param tensors have been moved to state storage. 308 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs. 
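# Lookahead bookkeeping (descriptive note): each parameter keeps a per-parameter slow copy in
# state['slow_buffer']; every k optimizer steps the update in step() applies
# slow <- slow + alpha * (fast - slow) and then copies the interpolated weights back into p.data.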
309 | 310 | #self.slow_weights = [[p.clone().detach() for p in group['params']] 311 | # for group in self.param_groups] 312 | 313 | #don't use grad for lookahead weights 314 | #for w in it.chain(*self.slow_weights): 315 | # w.requires_grad = False 316 | 317 | def __setstate__(self, state): 318 | print("set state called") 319 | super(Ranger, self).__setstate__(state) 320 | 321 | 322 | def step(self, closure=None): 323 | loss = None 324 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 325 | #Uncomment if you need to use the actual closure... 326 | 327 | #if closure is not None: 328 | #loss = closure() 329 | 330 | #Evaluate averages and grad, update param tensors 331 | for group in self.param_groups: 332 | 333 | for p in group['params']: 334 | if p.grad is None: 335 | continue 336 | grad = p.grad.data.float() 337 | if grad.is_sparse: 338 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 339 | 340 | p_data_fp32 = p.data.float() 341 | 342 | state = self.state[p] #get state dict for this param 343 | 344 | if len(state) == 0: #if first time to run...init dictionary with our desired entries 345 | #if self.first_run_check==0: 346 | #self.first_run_check=1 347 | #print("Initializing slow buffer...should not see this at load from saved model!") 348 | state['step'] = 0 349 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 350 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 351 | 352 | #look ahead weight storage now in state dict 353 | state['slow_buffer'] = torch.empty_like(p.data) 354 | state['slow_buffer'].copy_(p.data) 355 | 356 | else: 357 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 358 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 359 | 360 | #begin computations 361 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 362 | beta1, beta2 = group['betas'] 363 | 364 | #compute variance mov avg 365 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 366 | #compute mean moving avg 367 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 368 | 369 | state['step'] += 1 370 | 371 | 372 | buffered = self.radam_buffer[int(state['step'] % 10)] 373 | if state['step'] == buffered[0]: 374 | N_sma, step_size = buffered[1], buffered[2] 375 | else: 376 | buffered[0] = state['step'] 377 | beta2_t = beta2 ** state['step'] 378 | N_sma_max = 2 / (1 - beta2) - 1 379 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 380 | buffered[1] = N_sma 381 | if N_sma > self.N_sma_threshhold: 382 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 383 | else: 384 | step_size = 1.0 / (1 - beta1 ** state['step']) 385 | buffered[2] = step_size 386 | 387 | if group['weight_decay'] != 0: 388 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 389 | 390 | if N_sma > self.N_sma_threshhold: 391 | denom = exp_avg_sq.sqrt().add_(group['eps']) 392 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 393 | else: 394 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 395 | 396 | p.data.copy_(p_data_fp32) 397 | 398 | #integrated look ahead... 
399 | #we do it at the param level instead of group level 400 | if state['step'] % group['k'] == 0: 401 | slow_p = state['slow_buffer'] #get access to slow param tensor 402 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha 403 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor 404 | 405 | return loss -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | 6 | def visualize_depth(depth, cmap=cv2.COLORMAP_JET): 7 | """ 8 | depth: (H, W) 9 | """ 10 | x = depth.cpu().numpy() 11 | x = np.nan_to_num(x) # change nan to 0 12 | mi = np.min(x) # get minimum depth 13 | ma = np.max(x) 14 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 15 | x = (255*x).astype(np.uint8) 16 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 17 | x_ = T.ToTensor()(x_) # (3, H, W) 18 | return x_ -------------------------------------------------------------------------------- /utils/warmup_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | class GradualWarmupScheduler(_LRScheduler): 5 | """ Gradually warm-up(increasing) learning rate in optimizer. 6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 7 | Args: 8 | optimizer (Optimizer): Wrapped optimizer. 9 | multiplier: target learning rate = base lr * multiplier 10 | total_epoch: target learning rate is reached at total_epoch, gradually 11 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 12 | """ 13 | 14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 15 | self.multiplier = multiplier 16 | if self.multiplier < 1.: 17 | raise ValueError('multiplier should be greater thant or equal to 1.') 18 | self.total_epoch = total_epoch 19 | self.after_scheduler = after_scheduler 20 | self.finished = False 21 | super().__init__(optimizer) 22 | 23 | def get_lr(self): 24 | if self.last_epoch > self.total_epoch: 25 | if self.after_scheduler: 26 | if not self.finished: 27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 28 | self.finished = True 29 | return self.after_scheduler.get_lr() 30 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 31 | 32 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 33 | 34 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 35 | if epoch is None: 36 | epoch = self.last_epoch + 1 37 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 38 | if self.last_epoch <= self.total_epoch: 39 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 40 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 41 | param_group['lr'] = lr 42 | else: 43 | if epoch is None: 44 | self.after_scheduler.step(metrics, None) 45 | else: 46 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 47 | 48 | def step(self, epoch=None, metrics=None): 49 | if type(self.after_scheduler) != ReduceLROnPlateau: 50 | if self.finished and self.after_scheduler: 51 | if epoch is None: 52 | self.after_scheduler.step(None) 53 | else: 54 | self.after_scheduler.step(epoch - self.total_epoch) 55 | else: 56 | return super(GradualWarmupScheduler, self).step(epoch) 57 | else: 58 | self.step_ReduceLROnPlateau(metrics, epoch) --------------------------------------------------------------------------------
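For reference, a minimal usage sketch of `GradualWarmupScheduler`: the learning rate ramps from the base lr toward `multiplier * base_lr` over `total_epoch` epochs, after which the wrapped scheduler takes over. The model, optimizer, and hyperparameter values below are placeholders for illustration only, not the settings used in training.

```
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(3, 3)                      # placeholder model
optimizer = Adam(model.parameters(), lr=2e-3)
cosine = CosineAnnealingLR(optimizer, T_max=16, eta_min=1e-8)
scheduler = GradualWarmupScheduler(optimizer, multiplier=2.0,
                                   total_epoch=3, after_scheduler=cosine)

for epoch in range(16):
    # ... run one training epoch, calling optimizer.step() per batch ...
    scheduler.step()                               # epoch-level step
    print(epoch, optimizer.param_groups[0]['lr'])
```

In the repository this composition is produced by `get_scheduler` in `utils/__init__.py`; note that the warmup wrapper is bypassed when the optimizer is `radam` or `ranger`, in which case warmup is instead handled per training step inside `optimizer_step` via `--warmup_step`.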