├── NeRF.py ├── PDRF_network.png ├── PDRF_plot.png ├── README.md ├── configs └── defocustanabata │ ├── tx_defocustanabata_full.txt │ ├── tx_defocustanabata_naive.txt │ ├── tx_defocustanabata_naive_c2f.txt │ ├── tx_defocustanabata_nerf.txt │ └── tx_defocustanabata_twostage_deblurnerf.txt ├── load_llff.py ├── lpips ├── __init__.py ├── lpips.py ├── pretrained_networks.py └── weights │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── metrics.py ├── optimizer.py ├── ray_utils.py ├── requirements.txt ├── results.png ├── run.sh ├── run_nerf.py ├── run_nerf_helpers.py └── utils.py /NeRF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from run_nerf_helpers import * 4 | import os 5 | import imageio 6 | import time, pickle 7 | from torch_efficient_distloss import flatten_eff_distloss 8 | 9 | import random 10 | import numpy as np 11 | seed = 10000 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.default_rng(seed=seed) 17 | 18 | def init_linear_weights(m): 19 | if isinstance(m, nn.Linear): 20 | if m.weight.shape[0] in [2, 3]: 21 | nn.init.xavier_normal_(m.weight, 0.1) 22 | else: 23 | nn.init.xavier_normal_(m.weight) 24 | # nn.init.xavier_normal_(m.weight) 25 | nn.init.constant_(m.bias, 0) 26 | elif isinstance(m, nn.ConvTranspose2d): 27 | nn.init.xavier_normal_(m.weight) 28 | nn.init.constant_(m.bias, 0) 29 | 30 | class BlurModel(nn.Module): 31 | def __init__(self, num_img, poses, num_pt, kernel_hwindow, kernel_type, img_wh=None, random_hwindow=0.25, 32 | in_embed=3, random_mode='input', img_embed=32, spatial_embed=0, depth_embed=0, 33 | num_hidden=3, num_wide=64, feat_cnl=15, short_cut=False, pattern_init_radius=0.1, 34 | isglobal=False, optim_trans=False, optim_spatialvariant_trans=False): 35 | """ 36 | num_img: number of image, used for deciding the view embedding 37 | poses: the original poses, used for generating new rays, len(poses) == num_img 38 | num_pt: number of sparse point, we use 5 in the paper 39 | kernel_hwindow: the size of physically equivalent blur kernel, the sparse points are bounded inside the blur kernel. 40 | Can be a very big number 41 | 42 | random_hwindow: in training, we randomly perturb the sparse point to model a smooth manifold 43 | random_mode: 'input' or 'output', it controls whether the random perturb is added to the input of DSK or output of DSK 44 | // the above two parameters do not have big impact on the results 45 | 46 | in_embed: embedding for the canonical kernel location 47 | img_embed: the length of the view embedding 48 | spatial_embed: embedding for the pixel location of the blur kernel inside an image 49 | depth_embed: (deprecated) the embedding for the depth of current rays 50 | 51 | num_hidden, num_wide, short_cut: control the structure of the MLP 52 | pattern_init_radius: the little gain add to the deform location described in Sec. 4.4 53 | isglobal: control whether the canonical kernel should be shared by all the input views or not, does not have big impact on the results 54 | optim_trans: whether to optimize the ray origin described in Sec. 4.3 55 | optim_spatialvariant_trans: whether to optimize the ray origin for each view or each kernel point. 
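        Example (illustrative values only; in this repo the arguments are filled in from the parsed
        command-line options in run_nerf.py, e.g. kernel_ptnum, kernel_num_hidden, kernel_feat_cnl):
            kernelsnet = BlurModel(num_img=len(poses), poses=poses, num_pt=10, kernel_hwindow=10,
                                   kernel_type='PBE', img_embed=32, spatial_embed=2, num_hidden=4,
                                   num_wide=64, feat_cnl=256, short_cut=True,
                                   optim_spatialvariant_trans=True)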
56 | """ 57 | super().__init__() 58 | self.num_pt = num_pt 59 | self.num_img = num_img 60 | self.short_cut = short_cut 61 | self.kernel_hwindow = kernel_hwindow 62 | self.random_hwindow = random_hwindow # about 1 pix 63 | self.random_mode = random_mode 64 | self.kernel_type = kernel_type 65 | self.isglobal = isglobal 66 | self.feat_cnl = feat_cnl 67 | pattern_num = 1 if isglobal else num_img 68 | assert self.random_mode in ['input', 'output'], f"BlurModel::random_mode {self.random_mode} unrecognized, " \ 69 | f"should be input/output" 70 | self.register_buffer("poses", poses) 71 | self.register_parameter("pattern_pos", 72 | nn.Parameter(torch.randn(pattern_num, num_pt, 2) 73 | .type(torch.float32) * pattern_init_radius, True)) 74 | self.optim_trans = optim_trans 75 | self.optim_sv_trans = optim_spatialvariant_trans 76 | 77 | if optim_trans: 78 | self.register_parameter("pattern_trans", 79 | nn.Parameter(torch.zeros(pattern_num, num_pt, 2) 80 | .type(torch.float32), True)) 81 | 82 | if in_embed > 0: 83 | self.in_embed_fn, self.in_embed_cnl = get_embedder(in_embed, input_dim=2) 84 | else: 85 | self.in_embed_fn, self.in_embed_cnl = None, 0 86 | 87 | self.img_embed_cnl = img_embed 88 | 89 | if spatial_embed > 0: 90 | self.spatial_embed_fn, self.spatial_embed_cnl = get_embedder(spatial_embed, input_dim=2) 91 | else: 92 | self.spatial_embed_fn, self.spatial_embed_cnl = None, 0 93 | 94 | if depth_embed > 0: 95 | self.require_depth = True 96 | self.depth_embed_fn, self.depth_embed_cnl = get_embedder(depth_embed, input_dim=1) 97 | else: 98 | self.require_depth = False 99 | self.depth_embed_fn, self.depth_embed_cnl = None, 0 100 | 101 | in_cnl = self.in_embed_cnl + self.img_embed_cnl + self.depth_embed_cnl + self.spatial_embed_cnl 102 | if self.kernel_type == 'PBE': 103 | in_cnl += self.feat_cnl 104 | out_cnl = 1 + 2 + 2 if self.optim_sv_trans else 1 + 2 # u, v, w or u, v, w, dx, dy 105 | hiddens = [nn.Linear(num_wide, num_wide) if i % 2 == 0 else nn.ReLU() 106 | for i in range((num_hidden - 1) * 2)] 107 | self.linears = nn.Sequential( 108 | nn.Linear(in_cnl, num_wide), nn.ReLU(), 109 | *hiddens, 110 | ) 111 | self.linears1 = nn.Sequential( 112 | nn.Linear((num_wide + in_cnl) if short_cut else num_wide, num_wide), nn.ReLU(), 113 | nn.Linear(num_wide, out_cnl) 114 | ) 115 | self.linears.apply(init_linear_weights) 116 | self.linears1.apply(init_linear_weights) 117 | if img_embed > 0: 118 | self.register_parameter("img_embed", 119 | nn.Parameter(torch.zeros(num_img, img_embed).type(torch.float32), True)) 120 | else: 121 | self.img_embed = None 122 | 123 | def forward(self, H, W, K, rays, rays_info,feats=None): 124 | """ 125 | inputs: all input has shape (ray_num, cnl) 126 | outputs: output shape (ray_num, ptnum, 3, 2) last two dim: [ray_o, ray_d] 127 | """ 128 | img_idx = rays_info['images_idx'].squeeze(-1) 129 | img_embed = self.img_embed[img_idx] if self.img_embed is not None else \ 130 | torch.tensor([]).reshape(len(img_idx), self.img_embed_cnl) 131 | 132 | pt_pos = self.pattern_pos.expand(len(img_idx), -1, -1) if self.isglobal \ 133 | else self.pattern_pos[img_idx] 134 | pt_pos = torch.tanh(pt_pos) * self.kernel_hwindow 135 | 136 | if self.random_hwindow > 0 and self.random_mode == "input": 137 | random_pos = torch.randn_like(pt_pos) * self.random_hwindow 138 | pt_pos = pt_pos + random_pos 139 | 140 | input_pos = pt_pos # the first point is the reference point 141 | if self.in_embed_fn is not None: 142 | pt_pos = pt_pos * (np.pi / self.kernel_hwindow) 143 | pt_pos = self.in_embed_fn(pt_pos) 144 | 
145 | img_embed_expand = img_embed[:, None].expand(len(img_embed), self.num_pt, self.img_embed_cnl) 146 | 147 | if self.kernel_type == 'DSK': 148 | x = torch.cat([pt_pos, img_embed_expand], dim=-1) 149 | else: 150 | if feats == None: 151 | x = torch.cat([pt_pos, img_embed_expand,torch.zeros(len(img_embed), self.num_pt, self.feat_cnl)], dim=-1) 152 | else: 153 | x = torch.cat([pt_pos, img_embed_expand,feats.view(len(img_embed), self.num_pt,-1)], dim=-1) 154 | 155 | rays_x, rays_y = rays_info['rays_x'], rays_info['rays_y'] 156 | if self.spatial_embed_fn is not None: 157 | spatialx = rays_x / (W / 2 / np.pi) - np.pi 158 | spatialy = rays_y / (H / 2 / np.pi) - np.pi # scale 2pi to match the freq in the embedder 159 | spatial_save = torch.cat([spatialx, spatialy], dim=-1) 160 | spatial = self.spatial_embed_fn(spatial_save) 161 | spatial = spatial[:, None].expand(len(img_idx), self.num_pt, self.spatial_embed_cnl) 162 | x = torch.cat([x, spatial], dim=-1) 163 | 164 | if self.depth_embed_fn is not None: 165 | depth = rays_info['ray_depth'] 166 | depth = depth * np.pi 167 | depth = self.depth_embed_fn(depth) 168 | depth = depth[:, None].expand(len(img_idx), self.num_pt, self.depth_embed_cnl) 169 | x = torch.cat([x, depth], dim=-1) 170 | 171 | # forward 172 | x1 = self.linears(x) 173 | x1 = torch.cat([x, x1], dim=-1) if self.short_cut else x1 174 | x1 = self.linears1(x1) 175 | 176 | delta_trans = None 177 | if self.optim_sv_trans: 178 | delta_trans, delta_pos, weight = torch.split(x1, [2, 2, 1], dim=-1) 179 | else: 180 | delta_pos, weight = torch.split(x1, [2, 1], dim=-1) 181 | 182 | if self.optim_trans: 183 | delta_trans = self.pattern_trans.expand(len(img_idx), -1, -1) if self.isglobal \ 184 | else self.pattern_trans[img_idx] 185 | 186 | if delta_trans is None: 187 | delta_trans = torch.zeros_like(delta_pos) 188 | 189 | delta_trans = delta_trans * 0.01 190 | new_rays_xy = delta_pos + input_pos 191 | if self.kernel_type == 'PBE': 192 | new_rays_xy[:, 0, :] = 0 193 | delta_trans[:, 0, :] = 0 194 | align = None 195 | else: 196 | align = new_rays_xy[:, 0, :].abs().mean() 197 | align += (delta_trans[:, 0, :].abs().mean() * 10) 198 | weight = torch.softmax(weight[..., 0], dim=-1) 199 | 200 | if self.random_hwindow > 0 and self.random_mode == 'output': 201 | raise NotImplementedError(f"{self.random_mode} for self.random_mode is not implemented") 202 | 203 | poses = self.poses[img_idx] 204 | # get rays from offsetted pt position 205 | rays_x = (rays_x - K[0, 2] + new_rays_xy[..., 0]) / K[0, 0] 206 | rays_y = -(rays_y - K[1, 2] + new_rays_xy[..., 1]) / K[1, 1] 207 | dirs = torch.stack([rays_x - delta_trans[..., 0], 208 | rays_y - delta_trans[..., 1], 209 | -torch.ones_like(rays_x)], -1) 210 | 211 | # Rotate ray directions from camera frame to the world frame 212 | rays_d = torch.sum(dirs[..., None, :] * poses[..., None, :3, :3], 213 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 214 | 215 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
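# The learned 2D offset delta_trans is treated as a small shift of the camera center in the image
# plane: stacking it into a homogeneous camera-space vector [dx, dy, 0, 1] and multiplying by the
# 3x4 camera-to-world pose below gives each kernel point its own world-space ray origin.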
216 | translation = torch.stack([ 217 | delta_trans[..., 0], 218 | delta_trans[..., 1], 219 | torch.zeros_like(rays_x), 220 | torch.ones_like(rays_x) 221 | ], dim=-1) 222 | rays_o = torch.sum(translation[..., None, :] * poses[:, None], dim=-1) 223 | 224 | return torch.stack([rays_o, rays_d], dim=-1), weight, align 225 | 226 | 227 | class NeRFAll(nn.Module): 228 | def __init__(self, args, kernelsnet=None): 229 | super().__init__() 230 | self.args = args 231 | self.embed_fn, self.input_ch = get_embedder(args.multires, args.i_embed) 232 | 233 | self.kernel_type = args.kernel_type 234 | self.mode = args.mode 235 | self.input_ch_views = 0 236 | self.kernelsnet = kernelsnet 237 | self.embeddirs_fn = None 238 | if args.use_viewdirs: 239 | self.embeddirs_fn, self.input_ch_views = get_embedder(args.multires_views, args.i_embed) 240 | 241 | self.output_ch = 5 if args.N_importance > 0 else 4 242 | 243 | skips = [4] 244 | if self.mode == 'c2f': 245 | self.mlp_coarse = NeRFSmall_ray(aabb=args.bounding_box, 246 | num_layers=args.coarse_num_layers, 247 | hidden_dim=args.coarse_hidden_dim, 248 | geo_feat_dim=args.kernel_feat_cnl, 249 | num_layers_color=args.coarse_num_layers_color, 250 | hidden_dim_color=args.coarse_hidden_dim_color, 251 | input_ch=args.coarse_app_dim+self.input_ch, input_ch_views=self.input_ch_views, 252 | render_rmnearplane=args.render_rmnearplane,app_dim=args.coarse_app_dim, 253 | app_n_comp=args.coarse_app_n_comp, n_voxels=args.coarse_n_voxels) 254 | 255 | 256 | grad_vars_vol, grad_vars_net = self.mlp_coarse.get_optparam_groups() 257 | if self.kernelsnet != None: 258 | self.grad_vars = grad_vars_net + list(self.kernelsnet.parameters()) 259 | else: 260 | self.grad_vars = grad_vars_net 261 | self.grad_vars_vol = grad_vars_vol 262 | 263 | 264 | self.mlp_fine = NeRFSmall_voxel(aabb=args.bounding_box, 265 | num_layers=args.fine_num_layers, 266 | hidden_dim=args.fine_hidden_dim, 267 | geo_feat_dim=args.fine_geo_feat_dim, 268 | num_layers_color=args.fine_num_layers_color, 269 | hidden_dim_color=args.fine_hidden_dim_color, 270 | input_ch=args.coarse_app_dim+args.fine_app_dim+self.input_ch, input_ch_views=self.input_ch_views, 271 | render_rmnearplane=args.render_rmnearplane,app_dim=args.fine_app_dim, 272 | app_n_comp=args.fine_app_n_comp, n_voxels=args.fine_n_voxels) 273 | 274 | 275 | grad_vars_vol, grad_vars_net = self.mlp_fine.get_optparam_groups() 276 | self.grad_vars += grad_vars_net 277 | self.grad_vars_vol += grad_vars_vol 278 | elif self.mode == 'nerf': 279 | self.mlp_coarse = NeRF( 280 | D=args.netdepth, W=args.netwidth, 281 | input_ch=self.input_ch, output_ch=self.output_ch, skips=skips, 282 | input_ch_views=self.input_ch_views, use_viewdirs=args.use_viewdirs, 283 | rgb_activate=args.rgb_activate,sigma_activate=args.sigma_activate, 284 | render_rmnearplane=args.render_rmnearplane) 285 | 286 | self.mlp_fine = NeRF( 287 | D=args.netdepth_fine, W=args.netwidth_fine, 288 | input_ch=self.input_ch, output_ch=self.output_ch, skips=skips, 289 | input_ch_views=self.input_ch_views, use_viewdirs=args.use_viewdirs, 290 | rgb_activate=args.rgb_activate,sigma_activate=args.sigma_activate, 291 | render_rmnearplane=args.render_rmnearplane) 292 | else: 293 | raise NotImplementedError(f"{self.mode} for rendering network is not implemented") 294 | 295 | 296 | 297 | activate = {'relu': torch.relu, 'sigmoid': torch.sigmoid, 'exp': torch.exp, 'none': lambda x: x, 298 | 'sigmoid1': lambda x: 1.002 / (torch.exp(-x) + 1) - 0.001, 299 | 'softplus': lambda x: nn.Softplus()(x - 1)} 300 | self.rgb_activate = 
activate[args.rgb_activate] 301 | self.sigma_activate = activate[args.sigma_activate] 302 | self.tonemapping = ToneMapping(args.tone_mapping_type) 303 | print(self.mlp_coarse,self.mlp_fine,self.kernelsnet) 304 | 305 | 306 | def render_rays(self, 307 | ray_batch, 308 | N_samples, 309 | retraw=False, 310 | lindisp=False, 311 | perturb=0., 312 | N_importance=0, 313 | white_bkgd=False, 314 | raw_noise_std=0., 315 | pytest=False): 316 | """Volumetric rendering. 317 | Args: 318 | ray_batch: array of shape [batch_size, ...]. All information necessary 319 | for sampling along a ray, including: ray origin, ray direction, min 320 | dist, max dist, and unit-magnitude viewing direction. 321 | N_samples: int. Number of different times to sample along each ray. 322 | retraw: bool. If True, include model's raw, unprocessed predictions. 323 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 324 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 325 | random points in time. 326 | N_importance: int. Number of additional times to sample along each ray. 327 | These samples are only passed to network_fine. 328 | white_bkgd: bool. If True, assume a white background. 329 | raw_noise_std: ... 330 | verbose: bool. If True, print more debugging info. 331 | """ 332 | N_rays = ray_batch.shape[0] 333 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each 334 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None 335 | bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) 336 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1] 337 | 338 | t_vals = torch.linspace(0., 1., steps=N_samples).type_as(rays_o) 339 | if not lindisp: 340 | z_vals = near * (1. - t_vals) + far * (t_vals) 341 | else: 342 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. 
/ far * (t_vals)) 343 | 344 | z_vals = z_vals.expand([N_rays, N_samples]) 345 | 346 | if perturb > 0.: 347 | # get intervals between samples 348 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 349 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 350 | lower = torch.cat([z_vals[..., :1], mids], -1) 351 | # stratified samples in those intervals 352 | t_rand = torch.rand(z_vals.shape).type_as(rays_o) 353 | 354 | z_vals = lower + (upper - lower) * t_rand 355 | 356 | pts0 = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3] 357 | 358 | if self.mode == 'c2f': 359 | ft_coarse = self.mlp_coarse.sample(pts0) 360 | ft_fine = self.mlp_fine.sample(pts0) 361 | ft_comb0 = torch.cat([ft_coarse,ft_fine],-1) 362 | 363 | rgb_map_0, depth_map_0, acc_map_0, weights_coarse, _ = self.mlp_coarse(pts0, viewdirs, ft_coarse, self.embed_fn, self.embeddirs_fn, z_vals, rays_d, raw_noise_std, self.training) 364 | 365 | 366 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 367 | z_samples = sample_pdf(z_vals_mid, weights_coarse[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest) 368 | z_samples = z_samples.detach() 369 | 370 | z_vals, order = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 371 | pts1 = rays_o[..., None, :] + rays_d[..., None, :] * z_samples[..., :, 372 | None] # [N_rays, N_samples + N_importance, 3] 373 | 374 | ft_coarse = self.mlp_coarse.sample(pts1) 375 | pts = torch.cat([pts0, pts1],1)[torch.arange(pts1.shape[0]).unsqueeze(1),order] 376 | ft_fine = self.mlp_fine.sample(pts1) 377 | ft_comb1 = torch.cat([ft_coarse,ft_fine],-1) 378 | ft_comb = torch.cat([ft_comb0, ft_comb1],1)[torch.arange(pts1.shape[0]).unsqueeze(1),order] 379 | 380 | rgb_map, depth_map, acc_map, weights_fine = self.mlp_fine(pts, viewdirs, ft_comb, self.embed_fn, self.embeddirs_fn, z_vals, rays_d, raw_noise_std, self.training) 381 | else: 382 | rgb_map, depth_map, acc_map, weights, _ = self.mlp_coarse(pts0, viewdirs, self.embed_fn,self.embeddirs_fn, z_vals, rays_d, raw_noise_std, white_bkgd, self.training) 383 | 384 | rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map 385 | 386 | z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 387 | z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest) 388 | z_samples = z_samples.detach() 389 | 390 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 391 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, 392 | None] # [N_rays, N_samples + N_importance, 3] 393 | 394 | # mlp = self.mlp_coarse if self.mlp_fine is None else self.mlp_fine 395 | # raw = self.mlpforward(pts, viewdirs, mlp) 396 | 397 | rgb_map, depth_map, acc_map, weights, _ = self.mlp_fine(pts, viewdirs, self.embed_fn, self.embeddirs_fn, z_vals, rays_d, raw_noise_std, white_bkgd, self.training) 398 | 399 | 400 | ret = {'rgb_map': rgb_map, 'depth_map': depth_map, 'acc_map': acc_map} 401 | if N_importance > 0: 402 | ret['rgb0'] = rgb_map_0 403 | ret['depth0'] = depth_map_0 404 | ret['acc0'] = acc_map_0 405 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 406 | 407 | for k in ret: 408 | if torch.isnan(ret[k]).any(): 409 | print(f"! [Numerical Error] {k} contains nan.") 410 | if torch.isinf(ret[k]).any(): 411 | print(f"! 
[Numerical Error] {k} contains inf.") 412 | return ret 413 | 414 | def forward(self, H, W, K, chunk=1024 * 32, rays=None, rays_info=None, poses=None, **kwargs): 415 | """ 416 | render rays or render poses, rays and poses should atleast specify one 417 | calling model.train() to render rays, where rays, rays_info, should be specified 418 | calling model.eval() to render an image, where poses should be specified 419 | 420 | optional args: 421 | force_naive: when True, will only run the naive NeRF, even if the kernelsnet is specified 422 | 423 | """ 424 | # training 425 | if self.training: 426 | assert rays is not None, "Please specify rays when in the training mode" 427 | 428 | force_baseline = kwargs.pop("force_naive", True) 429 | # kernel mode, run multiple rays to get result of one ray 430 | if self.kernelsnet is not None and not force_baseline: 431 | if self.kernel_type == 'PBE': 432 | new_rays0, weight0, _ = self.kernelsnet(H, W, K, rays, rays_info) 433 | ray_num, pt_num = new_rays0.shape[:2] 434 | 435 | rgb0, features = self.coarse_render(H, W, K, chunk, new_rays0.reshape(-1, 3, 2), **kwargs) 436 | 437 | rgb0_pts = rgb0.reshape(ray_num, pt_num, 3) 438 | rgb0 = torch.sum(rgb0_pts * weight0[..., None], dim=1) 439 | rgb0 = self.tonemapping(rgb0) 440 | else: 441 | features = None 442 | 443 | new_rays, weight1, align_loss = self.kernelsnet(H, W, K, rays, rays_info,feats=features) 444 | ray_num, pt_num = new_rays.shape[:2] 445 | rgb, depth, acc, extras = self.render(H, W, K, chunk, new_rays.reshape(-1, 3, 2), **kwargs) 446 | rgb_pts = rgb.reshape(ray_num, pt_num, 3) 447 | rgb1_pts = extras['rgb0'].reshape(ray_num, pt_num, 3) 448 | rgb = torch.sum(rgb_pts * weight1[..., None], dim=1) 449 | rgb1 = torch.sum(rgb1_pts * weight1[..., None], dim=1) 450 | rgb = self.tonemapping(rgb) 451 | rgb1 = self.tonemapping(rgb1) 452 | if self.kernel_type == 'PBE': 453 | rgb1 = (rgb0 + rgb1)/2 454 | 455 | other_loss = {} 456 | if self.mode == 'c2f': 457 | other_loss["TV"] = (self.mlp_fine.TV_loss_app()+self.mlp_coarse.TV_loss_app()) * 5 458 | if align_loss is not None: 459 | other_loss["align"] = align_loss.reshape(1, 1) 460 | return rgb, rgb1, other_loss#, [weight]+trace 461 | else: 462 | rgb, depth, acc, extras = self.render(H, W, K, chunk, rays, **kwargs) 463 | other_loss = {} 464 | if self.mode == 'c2f': 465 | other_loss["TV"] = (self.mlp_fine.TV_loss_app()+self.mlp_coarse.TV_loss_app()) * 5 466 | return self.tonemapping(rgb), self.tonemapping(extras['rgb0']), other_loss 467 | 468 | # evaluation 469 | else: 470 | assert poses is not None, "Please specify poses when in the eval model" 471 | rgbs, depths = self.render_path(H, W, K, chunk, poses, **kwargs) 472 | return self.tonemapping(rgbs), depths 473 | 474 | def render(self, H, W, K, chunk, rays=None, c2w=None, ndc=True, 475 | near=0., far=1., 476 | use_viewdirs=False, c2w_staticcam=None, 477 | **kwargs): # the render function 478 | """Render rays 479 | Args: 480 | H: int. Height of image in pixels. 481 | W: int. Width of image in pixels. 482 | focal: float. Focal length of pinhole camera. 483 | chunk: int. Maximum number of rays to process simultaneously. Used to 484 | control maximum memory usage. Does not affect final results. 485 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 486 | each example in batch. 487 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 488 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 489 | near: float or array of shape [batch_size]. 
Nearest distance for a ray. 490 | far: float or array of shape [batch_size]. Farthest distance for a ray. 491 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 492 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 493 | camera while using other c2w argument for viewing directions. 494 | Returns: 495 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 496 | disp_map: [batch_size]. Disparity map. Inverse of depth. 497 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 498 | extras: dict with everything returned by render_rays(). 499 | """ 500 | rays_o, rays_d = rays[..., 0], rays[..., 1] 501 | 502 | if use_viewdirs: 503 | # provide ray directions as input 504 | viewdirs = rays_d 505 | if c2w_staticcam is not None: 506 | # special case to visualize effect of viewdirs 507 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 508 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 509 | viewdirs = torch.reshape(viewdirs, [-1, 3]).float() 510 | 511 | sh = rays_d.shape # [..., 3] 512 | if ndc: 513 | # for forward facing scenes 514 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 515 | 516 | # Create ray batch 517 | rays_o = torch.reshape(rays_o, [-1, 3]).float() 518 | rays_d = torch.reshape(rays_d, [-1, 3]).float() 519 | 520 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) 521 | rays = torch.cat([rays_o, rays_d, near, far], -1) 522 | if use_viewdirs: 523 | rays = torch.cat([rays, viewdirs], -1) 524 | 525 | # Batchfy and Render and reshape 526 | all_ret = {} 527 | for i in range(0, rays.shape[0], chunk): 528 | ret = self.render_rays(rays[i:i + chunk], **kwargs) 529 | for k in ret: 530 | if k not in all_ret: 531 | all_ret[k] = [] 532 | all_ret[k].append(ret[k]) 533 | all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} 534 | 535 | for k in all_ret: 536 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 537 | # print(k,k_sh,all_ret[k].shape) 538 | if all_ret[k].shape[0] == 5120: 539 | continue 540 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 541 | 542 | k_extract = ['rgb_map', 'depth_map', 'acc_map'] 543 | ret_list = [all_ret[k] for k in k_extract] 544 | ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract} 545 | # ret_dict['dist'] = dist_loss 546 | return ret_list + [ret_dict] 547 | 548 | 549 | 550 | def coarse_render(self, H, W, K, chunk, rays=None, c2w=None, ndc=True, 551 | near=0., far=1., 552 | use_viewdirs=False, c2w_staticcam=None, 553 | **kwargs): # the render function 554 | """Render rays 555 | Args: 556 | H: int. Height of image in pixels. 557 | W: int. Width of image in pixels. 558 | focal: float. Focal length of pinhole camera. 559 | chunk: int. Maximum number of rays to process simultaneously. Used to 560 | control maximum memory usage. Does not affect final results. 561 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 562 | each example in batch. 563 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 564 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 565 | near: float or array of shape [batch_size]. Nearest distance for a ray. 566 | far: float or array of shape [batch_size]. Farthest distance for a ray. 567 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 568 | c2w_staticcam: array of shape [3, 4]. 
If not None, use this transformation matrix for 569 | camera while using other c2w argument for viewing directions. 570 | Returns: 571 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 572 | disp_map: [batch_size]. Disparity map. Inverse of depth. 573 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 574 | extras: dict with everything returned by render_rays(). 575 | """ 576 | rays_o, rays_d = rays[..., 0], rays[..., 1] 577 | 578 | if use_viewdirs: 579 | # provide ray directions as input 580 | viewdirs = rays_d 581 | if c2w_staticcam is not None: 582 | # special case to visualize effect of viewdirs 583 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 584 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 585 | viewdirs = torch.reshape(viewdirs, [-1, 3]).float() 586 | 587 | sh = rays_d.shape # [..., 3] 588 | if ndc: 589 | # for forward facing scenes 590 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 591 | 592 | # Create ray batch 593 | rays_o = torch.reshape(rays_o, [-1, 3]).float() 594 | rays_d = torch.reshape(rays_d, [-1, 3]).float() 595 | 596 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) 597 | rays = torch.cat([rays_o, rays_d, near, far], -1) 598 | if use_viewdirs: 599 | rays = torch.cat([rays, viewdirs], -1) 600 | 601 | # Batchfy and Render and reshape 602 | all_ret = {} 603 | rgb,feat = self.coarse_render_rays(rays, **kwargs) 604 | return rgb,feat 605 | 606 | 607 | 608 | def coarse_render_rays(self, 609 | ray_batch, 610 | N_samples, 611 | retraw=False, 612 | lindisp=False, 613 | perturb=0., 614 | N_importance=0, 615 | white_bkgd=False, 616 | raw_noise_std=0., 617 | pytest=False): 618 | """Volumetric rendering. 619 | Args: 620 | ray_batch: array of shape [batch_size, ...]. All information necessary 621 | for sampling along a ray, including: ray origin, ray direction, min 622 | dist, max dist, and unit-magnitude viewing direction. 623 | N_samples: int. Number of different times to sample along each ray. 624 | retraw: bool. If True, include model's raw, unprocessed predictions. 625 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 626 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 627 | random points in time. 628 | N_importance: int. Number of additional times to sample along each ray. 629 | These samples are only passed to network_fine. 630 | white_bkgd: bool. If True, assume a white background. 631 | raw_noise_std: ... 632 | verbose: bool. If True, print more debugging info. 633 | """ 634 | N_rays = ray_batch.shape[0] 635 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each 636 | viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None 637 | bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) 638 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1] 639 | 640 | t_vals = torch.linspace(0., 1., steps=N_samples).type_as(rays_o) 641 | if not lindisp: 642 | z_vals = near * (1. - t_vals) + far * (t_vals) 643 | else: 644 | z_vals = 1. / (1. / near * (1. - t_vals) + 1. 
/ far * (t_vals)) 645 | 646 | z_vals = z_vals.expand([N_rays, N_samples]) 647 | 648 | if perturb > 0.: 649 | # get intervals between samples 650 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 651 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 652 | lower = torch.cat([z_vals[..., :1], mids], -1) 653 | # stratified samples in those intervals 654 | t_rand = torch.rand(z_vals.shape).type_as(rays_o) 655 | 656 | # Pytest, overwrite u with numpy's fixed random numbers 657 | # pytest=True 658 | # if pytest: 659 | # np.random.seed(0) 660 | # t_rand = np.random.rand(*list(z_vals.shape)) 661 | # t_rand = torch.tensor(t_rand) 662 | 663 | z_vals = lower + (upper - lower) * t_rand 664 | 665 | pts0 = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3] 666 | 667 | if self.mode == 'c2f': 668 | ft_coarse = self.mlp_coarse.sample(pts0) 669 | rgb_map, _, _, _, feat = self.mlp_coarse(pts0, viewdirs, ft_coarse, self.embed_fn, self.embeddirs_fn, z_vals, rays_d, raw_noise_std, self.training) 670 | else: 671 | rgb_map, _, _, _, feat = self.mlp_coarse(pts0, viewdirs, self.embed_fn,self.embeddirs_fn, z_vals, rays_d, raw_noise_std, white_bkgd, self.training) 672 | return rgb_map, feat 673 | 674 | def render_path(self, H, W, K, chunk, render_poses, render_kwargs, render_factor=0, ): 675 | """ 676 | render image specified by the render_poses 677 | """ 678 | if render_factor != 0: 679 | # Render downsampled for speed 680 | H = H // render_factor 681 | W = W // render_factor 682 | 683 | rgbs = [] 684 | depths = [] 685 | 686 | t = time.time() 687 | for i, c2w in enumerate(render_poses): 688 | print(i, time.time() - t) 689 | t = time.time() 690 | rays = get_rays(H, W, K, c2w) 691 | rays = torch.stack(rays, dim=-1) 692 | rgb, depth, acc, extras = self.render(H, W, K, chunk=chunk, rays=rays, c2w=c2w[:3, :4], **render_kwargs) 693 | 694 | rgbs.append(rgb) 695 | depths.append(depth) 696 | # rgbs.append(extras['rgb0']) 697 | # depths.append(extras['depth0']) 698 | if i == 0: 699 | print(rgb.shape, depth.shape) 700 | 701 | rgbs = torch.stack(rgbs, 0) 702 | depths = torch.stack(depths, 0) 703 | 704 | return rgbs, depths 705 | -------------------------------------------------------------------------------- /PDRF_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/PDRF_network.png -------------------------------------------------------------------------------- /PDRF_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/PDRF_plot.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official implementation of the paper [PDRF: Progressively Deblurring Radiance Field for Fast and Robust Scene 2 | Reconstruction from Blurry Images](https://arxiv.org/abs/2208.08049) 3 | # PDRF 4 | ![](PDRF_network.png) 5 | ## Method Overview 6 | Progressively Deblurring Radiance Field (PDRF) is a novel approach to efficiently reconstruct high quality radiance fields from blurry images. Compared to previous methods like NeRF and DeblurNeRF, PDRF is both much faster and more performant by utilizing radiance field features to model blur. 7 | 8 | ![](results.png) 9 | 10 | 11 | 12 | ## Quick Start 13 | 14 | ### 1. 
Install environment 15 | 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ### 2. Download dataset 21 | We evaluate our method on the dataset provided by [DeblurNeRF](https://arxiv.org/abs/2111.14292). You can download the data [here](https://hkustconnect-my.sharepoint.com/:f:/g/personal/lmaag_connect_ust_hk/EqB3QrnNG5FMpGzENQq_hBMBSaCQiZXP7yGCVlBHIGuSVA?e=UaSQCC). 22 | 23 | 24 | ### 3. Set parameters in configs/ 25 | 26 | This codebase supports several configurations for radiance field modeling. Two parameters worth noting are ```args.mode={c2f,nerf}``` and ```args.kernel_type={none,DSK,PBE}```. Specifically, ```args.mode=c2f``` enables our coarse-to-fine rendering architecture (CRR+FVR), which leverages an explicit representation (implemented based on [TensoRF](https://arxiv.org/abs/2203.09517)) and an improved importance sampling scheme; ```args.kernel_type=PBE``` enables our two-stage blur-modeling design. DeblurNeRF can be reproduced with ```args.mode=nerf, args.kernel_type=DSK```; any other combination can also be tried. 27 | 28 | 29 | 30 | ### 4. Execute 31 | 32 | ``` 33 | python3 run_nerf.py --config configs/defocustanabata/tx_defocustanabata_full.txt 34 | ``` 35 | 36 | ### 5. Visualize 37 | To render a video of the learned scene: 38 | ``` 39 | python3 run_nerf.py --config configs/defocustanabata/tx_defocustanabata_full.txt --render_only 40 | ``` 41 | 42 | To render test images of the learned scene: 43 | ``` 44 | python3 run_nerf.py --config configs/defocustanabata/tx_defocustanabata_full.txt --render_only --render_test 45 | ``` 46 | 47 | 48 | ## Citation 49 | If you find this useful, please consider citing our paper: 50 | ``` 51 | @inproceedings{peng2023pdrf, 52 | title={PDRF: Progressively Deblurring Radiance Field for Fast and Robust Scene 53 | Reconstruction from Blurry Images}, 54 | author={Peng, Cheng and Chellappa, Rama}, 55 | year={2023}, 56 | booktitle = {The 37th AAAI Conference on Artificial Intelligence} 57 | } 58 | ``` 59 | 60 | ## Acknowledgements 61 | This source code is derived from multiple projects, in particular [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/), [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF), [TensoRF](https://github.com/apchenstu/TensoRF), and [HashNeRF-pytorch](https://github.com/yashbhalgat/HashNeRF-pytorch). We thank the original authors for their awesome and consistent implementations.
62 | -------------------------------------------------------------------------------- /configs/defocustanabata/tx_defocustanabata_full.txt: -------------------------------------------------------------------------------- 1 | num_gpu = 1 2 | expname = defocustanabata1_deblur_two_stage_10pts_update 3 | basedir = ./log/ 4 | datadir = ../data/synthetic_defocus_blur/defocustanabata 5 | tbdir = ./log/ 6 | dataset_type = llff 7 | 8 | seed = 10000 9 | factor = 1 10 | llffhold = 8 11 | 12 | N_rand = 1024 13 | N_samples = 64 14 | N_importance = 64 15 | N_iters = 30000 16 | lrate = 0.01 17 | lrate_decay = 10 18 | 19 | use_viewdirs = True 20 | raw_noise_std = 1e0 21 | rgb_activate = sigmoid 22 | 23 | mode = c2f 24 | coarse_num_layers = 2 25 | coarse_num_layers_color = 3 26 | coarse_hidden_dim = 64 27 | coarse_hidden_dim_color = 64 28 | coarse_app_dim = 32 29 | coarse_app_n_comp = [64,16,16] 30 | coarse_n_voxels = 16777248 31 | 32 | 33 | fine_num_layers = 2 34 | fine_num_layers_color = 3 35 | fine_hidden_dim = 256 36 | fine_hidden_dim_color = 256 37 | fine_geo_feat_dim = 128 38 | fine_app_dim = 32 39 | fine_app_n_comp = [64,16,16] 40 | fine_n_voxels = 134217984 41 | 42 | 43 | kernel_start_iter = 1200 44 | # kernel_prior_weight = 0.1 45 | # prior_start_iter = 15000 46 | kernel_align_weight = 0.1 47 | align_start_iter = 0 48 | align_end_iter = 180000 49 | 50 | kernel_type = PBE 51 | kernel_ptnum = 10 52 | kernel_random_hwindow = 0.15 53 | kernel_random_mode = input 54 | 55 | kernel_img_embed = 32 56 | kernel_rand_embed = 2 # the in_embed 57 | kernel_spatial_embed = 2 58 | kernel_depth_embed = 0 59 | 60 | kernel_num_hidden = 4 61 | kernel_num_wide = 64 62 | kernel_shortcut 63 | 64 | kernel_spatialvariant_trans 65 | tone_mapping_type = gamma 66 | 67 | render_radius_scale = 0.85 -------------------------------------------------------------------------------- /configs/defocustanabata/tx_defocustanabata_naive.txt: -------------------------------------------------------------------------------- 1 | num_gpu = 1 2 | expname = defocustanabata1_nerf 3 | basedir = ./log/ 4 | datadir = ../data/synthetic_defocus_blur/defocustanabata 5 | tbdir = ./log/ 6 | dataset_type = llff 7 | 8 | factor = 1 9 | llffhold = 8 10 | 11 | N_rand = 1024 12 | N_samples = 64 13 | N_importance = 64 14 | N_iters = 200000 15 | lrate = 5e-4 16 | lrate_decay = 250 17 | 18 | use_viewdirs = True 19 | raw_noise_std = 1e0 20 | rgb_activate = sigmoid 21 | 22 | mode = nerf 23 | kernel_type = none 24 | 25 | render_radius_scale = 0.85 26 | 27 | i_weights = 20000 28 | i_testset = 20000 29 | i_video = 20000 -------------------------------------------------------------------------------- /configs/defocustanabata/tx_defocustanabata_naive_c2f.txt: -------------------------------------------------------------------------------- 1 | num_gpu = 1 2 | expname = defocustanabata1_nerf 3 | basedir = ./log/ 4 | datadir = ../data/synthetic_defocus_blur/defocustanabata 5 | tbdir = ./log/ 6 | dataset_type = llff 7 | 8 | factor = 1 9 | llffhold = 8 10 | 11 | N_rand = 1024 12 | N_samples = 64 13 | N_importance = 64 14 | N_iters = 35000 15 | lrate = 5e-4 16 | lrate_decay = 250 17 | 18 | use_viewdirs = True 19 | raw_noise_std = 1e0 20 | rgb_activate = sigmoid 21 | 22 | mode = c2f 23 | kernel_type = none 24 | 25 | render_radius_scale = 0.85 26 | -------------------------------------------------------------------------------- /configs/defocustanabata/tx_defocustanabata_nerf.txt: -------------------------------------------------------------------------------- 1 | num_gpu = 1 
2 | expname = defocustanabata1_deblur_5pts_DeblurNeRF_update 3 | basedir = ./log/ 4 | datadir = ../data/synthetic_defocus_blur/defocustanabata 5 | tbdir = ./log/ 6 | dataset_type = llff 7 | 8 | seed = 10000 9 | factor = 1 10 | llffhold = 8 11 | 12 | N_rand = 1024 13 | N_samples = 64 14 | N_importance = 64 15 | N_iters = 200000 16 | lrate = 5e-4 17 | lrate_decay = 250 18 | 19 | use_viewdirs = True 20 | raw_noise_std = 1e0 21 | rgb_activate = sigmoid 22 | 23 | mode = nerf 24 | kernel_start_iter = 1200 25 | # kernel_prior_weight = 0.1 26 | # prior_start_iter = 15000 27 | kernel_align_weight = 0.1 28 | align_start_iter = 0 29 | align_end_iter = 180000 30 | 31 | kernel_type = DSK 32 | kernel_ptnum = 5 33 | kernel_random_hwindow = 0.15 34 | kernel_random_mode = input 35 | 36 | kernel_img_embed = 32 37 | kernel_rand_embed = 2 # the in_embed 38 | kernel_spatial_embed = 2 39 | kernel_depth_embed = 0 40 | 41 | kernel_num_hidden = 4 42 | kernel_num_wide = 64 43 | kernel_shortcut 44 | 45 | kernel_spatialvariant_trans 46 | tone_mapping_type = gamma 47 | 48 | render_radius_scale = 0.85 49 | 50 | i_weights = 20000 51 | i_testset = 20000 52 | i_video = 40000 -------------------------------------------------------------------------------- /configs/defocustanabata/tx_defocustanabata_twostage_deblurnerf.txt: -------------------------------------------------------------------------------- 1 | num_gpu = 1 2 | expname = defocustanabata1_deblur_two_stage_5pts_DeblurNeRF_update 3 | basedir = ./log/ 4 | datadir = ../data/synthetic_defocus_blur/defocustanabata 5 | tbdir = ./log/ 6 | dataset_type = llff 7 | 8 | seed = 10000 9 | factor = 1 10 | llffhold = 8 11 | 12 | N_rand = 1024 13 | N_samples = 64 14 | N_importance = 64 15 | N_iters = 200000 16 | lrate = 5e-4 17 | lrate_decay = 250 18 | 19 | use_viewdirs = True 20 | raw_noise_std = 1e0 21 | rgb_activate = sigmoid 22 | 23 | mode = nerf 24 | kernel_start_iter = 1200 25 | # kernel_prior_weight = 0.1 26 | # prior_start_iter = 15000 27 | kernel_align_weight = 0.1 28 | align_start_iter = 0 29 | align_end_iter = 180000 30 | 31 | kernel_type = PBE 32 | kernel_ptnum = 5 33 | kernel_random_hwindow = 0.15 34 | kernel_random_mode = input 35 | 36 | kernel_img_embed = 32 37 | kernel_rand_embed = 2 # the in_embed 38 | kernel_spatial_embed = 2 39 | kernel_depth_embed = 0 40 | 41 | kernel_feat_cnl = 256 42 | kernel_num_hidden = 4 43 | kernel_num_wide = 64 44 | kernel_shortcut 45 | 46 | kernel_spatialvariant_trans 47 | tone_mapping_type = gamma 48 | 49 | render_radius_scale = 0.85 50 | 51 | i_weights = 20000 52 | i_testset = 20000 53 | i_video = 40000 -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | 4 | from utils import get_bbox3d_for_llff 5 | # import random 6 | # import torch 7 | # seed = 0 8 | # random.seed(seed) 9 | # np.random.seed(seed) 10 | # torch.manual_seed(seed) 11 | # torch.cuda.manual_seed_all(seed) 12 | ########## Slightly modified version of LLFF data loading code 13 | ########## see https://github.com/Fyusion/LLFF for original 14 | 15 | def _minify(basedir, factors=[], resolutions=[]): 16 | needtoload = False 17 | for r in factors: 18 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 19 | if not os.path.exists(imgdir): 20 | needtoload = True 21 | for r in resolutions: 22 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 23 | if not os.path.exists(imgdir): 24 | 
needtoload = True 25 | if not needtoload: 26 | return 27 | 28 | from shutil import copy 29 | from subprocess import check_output 30 | 31 | imgdir = os.path.join(basedir, 'images') 32 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 33 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 34 | imgdir_orig = imgdir 35 | 36 | wd = os.getcwd() 37 | 38 | for r in factors + resolutions: 39 | if isinstance(r, int): 40 | name = 'images_{}'.format(r) 41 | resizearg = '{}%'.format(100. / r) 42 | else: 43 | name = 'images_{}x{}'.format(r[1], r[0]) 44 | resizearg = '{}x{}'.format(r[1], r[0]) 45 | imgdir = os.path.join(basedir, name) 46 | if os.path.exists(imgdir): 47 | continue 48 | 49 | print('Minifying', r, basedir) 50 | 51 | os.makedirs(imgdir) 52 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 53 | 54 | ext = imgs[0].split('.')[-1] 55 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 56 | print(args) 57 | os.chdir(imgdir) 58 | check_output(args, shell=True) 59 | os.chdir(wd) 60 | 61 | if ext != 'png': 62 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 63 | print('Removed duplicates') 64 | print('Done') 65 | 66 | 67 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 68 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 69 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) 70 | bds = poses_arr[:, -2:].transpose([1, 0]) 71 | 72 | # img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 73 | # if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 74 | # sh = imageio.imread(img0).shape 75 | 76 | sfx = '' 77 | 78 | if factor is not None: 79 | sfx = '_{}'.format(factor) 80 | _minify(basedir, factors=[factor]) 81 | factor = factor 82 | else: 83 | factor = 1 84 | 85 | imgdir = os.path.join(basedir, 'images' + sfx) 86 | if not os.path.exists(imgdir): 87 | print(imgdir, 'does not exist, returning') 88 | return 89 | 90 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if 91 | f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 92 | if poses.shape[-1] != len(imgfiles): 93 | print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1])) 94 | return 95 | 96 | sh = imageio.imread(imgfiles[0]).shape 97 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 98 | poses[2, 4, :] = poses[2, 4, :] * 1. / factor 99 | 100 | if not load_imgs: 101 | return poses, bds 102 | 103 | def imread(f): 104 | if f.endswith('png'): 105 | return imageio.imread(f, ignoregamma=True) 106 | else: 107 | return imageio.imread(f) 108 | 109 | imgs = imgs = [imread(f)[..., :3] / 255. 
for f in imgfiles] 110 | imgs = np.stack(imgs, -1) 111 | 112 | print('Loaded image data', imgs.shape, poses[:, -1, 0]) 113 | return poses, bds, imgs 114 | 115 | 116 | def normalize(x): 117 | return x / np.linalg.norm(x) 118 | 119 | 120 | def viewmatrix(z, up, pos): 121 | vec2 = normalize(z) 122 | vec1_avg = up 123 | vec0 = normalize(np.cross(vec1_avg, vec2)) 124 | vec1 = normalize(np.cross(vec2, vec0)) 125 | m = np.stack([vec0, vec1, vec2, pos], 1) 126 | return m 127 | 128 | 129 | def ptstocam(pts, c2w): 130 | tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] 131 | return tt 132 | 133 | 134 | def poses_avg(poses): 135 | hwf = poses[0, :3, -1:] 136 | 137 | center = poses[:, :3, 3].mean(0) 138 | vec2 = normalize(poses[:, :3, 2].sum(0)) 139 | up = poses[:, :3, 1].sum(0) 140 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 141 | 142 | return c2w 143 | 144 | 145 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 146 | render_poses = [] 147 | rads = np.array(list(rads) + [1.]) 148 | hwf = c2w[:, 4:5] 149 | 150 | for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]: 151 | # view direction 152 | # c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 153 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 154 | # camera poses 155 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 156 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 157 | return render_poses 158 | 159 | 160 | def render_path_epi(c2w, up, rads, N): 161 | render_poses = [] 162 | hwf = c2w[:, 4:5] 163 | 164 | for theta in np.linspace(-1, 1, N + 1)[:-1]: 165 | # view direction 166 | c = np.dot(c2w[:3, :4], np.array([theta, 0, 0, 1.]) * rads) 167 | # camera poses 168 | z = normalize(np.dot(c2w[:3, :4], np.array([0, 0, 1, 0.]))) 169 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 170 | return render_poses 171 | 172 | 173 | def recenter_poses(poses): 174 | poses_ = poses + 0 175 | bottom = np.reshape([0, 0, 0, 1.], [1, 4]) 176 | c2w = poses_avg(poses) 177 | c2w = np.concatenate([c2w[:3, :4], bottom], -2) 178 | bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) 179 | poses = np.concatenate([poses[:, :3, :4], bottom], -2) 180 | 181 | poses = np.linalg.inv(c2w) @ poses 182 | poses_[:, :3, :4] = poses[:, :3, :4] 183 | poses = poses_ 184 | return poses 185 | 186 | 187 | ##################### 188 | 189 | 190 | def spherify_poses(poses, bds): 191 | p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1) 192 | 193 | rays_d = poses[:, :3, 2:3] 194 | rays_o = poses[:, :3, 3:4] 195 | 196 | def min_line_dist(rays_o, rays_d): 197 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) 198 | b_i = -A_i @ rays_o 199 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)) 200 | return pt_mindist 201 | 202 | pt_mindist = min_line_dist(rays_o, rays_d) 203 | 204 | center = pt_mindist 205 | up = (poses[:, :3, 3] - center).mean(0) 206 | 207 | vec0 = normalize(up) 208 | vec1 = normalize(np.cross([.1, .2, .3], vec0)) 209 | vec2 = normalize(np.cross(vec0, vec1)) 210 | pos = center 211 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 212 | 213 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) 214 | 215 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) 216 | 
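# rad is the RMS distance of the recentered camera centers from the origin; rescale poses, bounds,
# and rad by 1/rad so the cameras lie roughly on a unit sphere before building the circular
# render path below.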
217 | sc = 1. / rad 218 | poses_reset[:, :3, 3] *= sc 219 | bds *= sc 220 | rad *= sc 221 | 222 | centroid = np.mean(poses_reset[:, :3, 3], 0) 223 | zh = centroid[2] 224 | radcircle = np.sqrt(rad ** 2 - zh ** 2) 225 | new_poses = [] 226 | 227 | for th in np.linspace(0., 2. * np.pi, 120): 228 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 229 | up = np.array([0, 0, -1.]) 230 | 231 | vec2 = normalize(camorigin) 232 | vec0 = normalize(np.cross(vec2, up)) 233 | vec1 = normalize(np.cross(vec2, vec0)) 234 | pos = camorigin 235 | p = np.stack([vec0, vec1, vec2, pos], 1) 236 | 237 | new_poses.append(p) 238 | 239 | new_poses = np.stack(new_poses, 0) 240 | 241 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1) 242 | poses_reset = np.concatenate( 243 | [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1) 244 | 245 | return poses_reset, new_poses, bds 246 | 247 | 248 | def load_llff_data(args, basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_epi=False): 249 | poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 250 | print('Loaded', basedir, bds.min(), bds.max()) 251 | 252 | # Correct rotation matrix ordering and move variable dim to axis 0 253 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 254 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 255 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 256 | images = imgs 257 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 258 | 259 | # Rescale if bd_factor is provided 260 | sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor) 261 | poses[:, :3, 3] *= sc 262 | bds *= sc 263 | 264 | if recenter: 265 | poses = recenter_poses(poses) 266 | 267 | # generate render_poses for video generation 268 | if spherify: 269 | poses, render_poses, bds = spherify_poses(poses, bds) 270 | 271 | else: 272 | c2w = poses_avg(poses) 273 | print('recentered', c2w.shape) 274 | print(c2w[:3, :4]) 275 | 276 | ## Get spiral 277 | # Get average pose 278 | up = normalize(poses[:, :3, 1].sum(0)) 279 | 280 | # Find a reasonable "focus depth" for this dataset 281 | close_depth, inf_depth = bds.min() * .9, bds.max() * 5. 282 | dt = .75 283 | mean_dz = 1. / (((1. 
- dt) / close_depth + dt / inf_depth)) 284 | focal = mean_dz 285 | focal = focal * args.render_focuspoint_scale 286 | # Get radii for spiral path 287 | shrink_factor = .8 288 | zdelta = close_depth * .2 289 | tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 290 | rads = np.percentile(np.abs(tt), 90, 0) 291 | rads[0] *= args.render_radius_scale 292 | rads[1] *= args.render_radius_scale 293 | c2w_path = c2w 294 | N_views = 120 295 | N_rots = 2 296 | # Generate poses for spiral path 297 | # rads = [0.7, 0.2, 0.7] 298 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 299 | 300 | if path_epi: 301 | # zloc = np.percentile(tt, 10, 0)[2] 302 | rads[0] = rads[0] / 2 303 | render_poses = render_path_epi(c2w_path, up, rads[0], N_views) 304 | 305 | render_poses = np.array(render_poses).astype(np.float32)#[:10] 306 | 307 | c2w = poses_avg(poses) 308 | print('Data:') 309 | print(poses.shape, images.shape, bds.shape) 310 | 311 | dists = np.sum(np.square(c2w[:3, 3] - poses[:, :3, 3]), -1) 312 | i_test = np.argmin(dists) 313 | print('HOLDOUT view is', i_test) 314 | 315 | images = images.astype(np.float32) 316 | poses = poses.astype(np.float32) 317 | 318 | bounding_box = get_bbox3d_for_llff(poses[:,:3,:4], poses[0,:3,-1], near=0.0, far=1.0) 319 | 320 | return images, poses, bds, render_poses, i_test, bounding_box 321 | -------------------------------------------------------------------------------- /lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | # from torch.autograd import Variable 8 | 9 | from .lpips import * 10 | -------------------------------------------------------------------------------- /lpips/lpips.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from . import pretrained_networks as pn 9 | import torch.nn 10 | 11 | 12 | def normalize_tensor(in_feat, eps=1e-10): 13 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 14 | return in_feat / (norm_factor + eps) 15 | 16 | 17 | def l2(p0, p1, range=255.): 18 | return .5 * np.mean((p0 / range - p1 / range) ** 2) 19 | 20 | 21 | def psnr(p0, p1, peak=255.): 22 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) 23 | 24 | 25 | def dssim(p0, p1, range=255.): 26 | from skimage.measure import compare_ssim 27 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
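# Note on dssim() above: skimage.measure.compare_ssim has been removed in newer scikit-image
# releases; there the equivalent call is skimage.metrics.structural_similarity.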
28 | 29 | 30 | def rgb2lab(in_img, mean_cent=False): 31 | from skimage import color 32 | img_lab = color.rgb2lab(in_img) 33 | if mean_cent: 34 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 35 | return img_lab 36 | 37 | 38 | def tensor2np(tensor_obj): 39 | # change dimension of a tensor object into a numpy array 40 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 41 | 42 | 43 | def np2tensor(np_obj): 44 | # change dimenion of np array into tensor array 45 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 46 | 47 | 48 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 49 | # image tensor to lab tensor 50 | from skimage import color 51 | 52 | img = tensor2im(image_tensor) 53 | img_lab = color.rgb2lab(img) 54 | if mc_only: 55 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 56 | if to_norm and not mc_only: 57 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 58 | img_lab = img_lab / 100. 59 | 60 | return np2tensor(img_lab) 61 | 62 | 63 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 64 | from skimage import color 65 | import warnings 66 | warnings.filterwarnings("ignore") 67 | 68 | lab = tensor2np(lab_tensor) * 100. 69 | lab[:, :, 0] = lab[:, :, 0] + 50 70 | 71 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 72 | if return_inbnd: 73 | # convert back to lab, see if we match 74 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 75 | mask = 1. * np.isclose(lab_back, lab, atol=2.) 76 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 77 | return im2tensor(rgb_back), mask 78 | else: 79 | return im2tensor(rgb_back) 80 | 81 | 82 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 83 | image_numpy = image_tensor[0].cpu().float().numpy() 84 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 85 | return image_numpy.astype(imtype) 86 | 87 | 88 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 89 | return torch.tensor((image / factor - cent) 90 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 91 | 92 | 93 | def tensor2vec(vector_tensor): 94 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 95 | 96 | 97 | def voc_ap(rec, prec, use_07_metric=False): 98 | """ ap = voc_ap(rec, prec, [use_07_metric]) 99 | Compute VOC AP given precision and recall. 100 | If use_07_metric is true, uses the 101 | VOC 07 11 point method (default:False). 102 | """ 103 | if use_07_metric: 104 | # 11 point metric 105 | ap = 0. 106 | for t in np.arange(0., 1.1, 0.1): 107 | if np.sum(rec >= t) == 0: 108 | p = 0 109 | else: 110 | p = np.max(prec[rec >= t]) 111 | ap = ap + p / 11. 
112 | else: 113 | # correct AP calculation 114 | # first append sentinel values at the end 115 | mrec = np.concatenate(([0.], rec, [1.])) 116 | mpre = np.concatenate(([0.], prec, [0.])) 117 | 118 | # compute the precision envelope 119 | for i in range(mpre.size - 1, 0, -1): 120 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 121 | 122 | # to calculate area under PR curve, look for points 123 | # where X axis (recall) changes value 124 | i = np.where(mrec[1:] != mrec[:-1])[0] 125 | 126 | # and sum (\Delta recall) * prec 127 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 128 | return ap 129 | 130 | def spatial_average(in_tens, keepdim=True): 131 | return in_tens.mean([2, 3], keepdim=keepdim) 132 | 133 | 134 | def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W 135 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 136 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 137 | 138 | 139 | # Learned perceptual metric 140 | class LPIPS(nn.Module): 141 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 142 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 143 | # lpips - [True] means with linear calibration on top of base network 144 | # pretrained - [True] means load linear weights 145 | 146 | super(LPIPS, self).__init__() 147 | if verbose: 148 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % 149 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 150 | 151 | self.pnet_type = net 152 | self.pnet_tune = pnet_tune 153 | self.pnet_rand = pnet_rand 154 | self.spatial = spatial 155 | self.lpips = lpips # false means baseline of just averaging all layers 156 | self.version = version 157 | self.scaling_layer = ScalingLayer() 158 | 159 | if self.pnet_type in ['vgg', 'vgg16']: 160 | net_type = pn.vgg16 161 | self.chns = [64, 128, 256, 512, 512] 162 | elif self.pnet_type == 'alex': 163 | net_type = pn.alexnet 164 | self.chns = [64, 192, 384, 256, 256] 165 | elif self.pnet_type == 'squeeze': 166 | net_type = pn.squeezenet 167 | self.chns = [64, 128, 256, 384, 384, 512, 512] 168 | self.L = len(self.chns) 169 | 170 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 171 | 172 | if lpips: 173 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 174 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 175 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 176 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 177 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 178 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 179 | if self.pnet_type == 'squeeze': # 7 layers for squeezenet 180 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 181 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 182 | self.lins += [self.lin5, self.lin6] 183 | self.lins = nn.ModuleList(self.lins) 184 | 185 | if pretrained: 186 | if model_path is None: 187 | import inspect 188 | import os 189 | model_path = os.path.abspath( 190 | os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))) 191 | 192 | if verbose: 193 | print('Loading model from: %s' % model_path) 194 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 195 | 196 | if eval_mode: 197 | self.eval() 198 | 199 | def forward(self, in0, in1, retPerLayer=False, 
normalize=False): 200 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 201 | in0 = 2 * in0 - 1 202 | in1 = 2 * in1 - 1 203 | 204 | # v0.0 - original release had a bug, where input was not scaled 205 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( 206 | in0, in1) 207 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 208 | feats0, feats1, diffs = {}, {}, {} 209 | 210 | for kk in range(self.L): 211 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 212 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 213 | 214 | if self.lpips: 215 | if self.spatial: 216 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 217 | else: 218 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 219 | else: 220 | if self.spatial: 221 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 222 | else: 223 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] 224 | 225 | val = res[0] 226 | for l in range(1, self.L): 227 | val += res[l] 228 | 229 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 230 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 231 | # for kk in range(self.L): 232 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 233 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 234 | # a = a/self.L 235 | # from IPython import embed 236 | # embed() 237 | # return 10*torch.log10(b/a) 238 | 239 | if retPerLayer: 240 | return val, res 241 | else: 242 | return val 243 | 244 | 245 | class ScalingLayer(nn.Module): 246 | def __init__(self): 247 | super(ScalingLayer, self).__init__() 248 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 249 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 250 | 251 | def forward(self, inp): 252 | return (inp - self.shift) / self.scale 253 | 254 | 255 | class NetLinLayer(nn.Module): 256 | ''' A single linear layer which does a 1x1 conv ''' 257 | 258 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 259 | super(NetLinLayer, self).__init__() 260 | 261 | layers = [nn.Dropout(), ] if use_dropout else [] 262 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 263 | self.model = nn.Sequential(*layers) 264 | 265 | def forward(self, x): 266 | return self.model(x) 267 | 268 | 269 | class Dist2LogitLayer(nn.Module): 270 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 271 | 272 | def __init__(self, chn_mid=32, use_sigmoid=True): 273 | super(Dist2LogitLayer, self).__init__() 274 | 275 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] 276 | layers += [nn.LeakyReLU(0.2, True), ] 277 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] 278 | layers += [nn.LeakyReLU(0.2, True), ] 279 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] 280 | if use_sigmoid: 281 | layers += [nn.Sigmoid(), ] 282 | self.model = nn.Sequential(*layers) 283 | 284 | def forward(self, d0, d1, eps=0.1): 285 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) 286 | 287 | 288 | class BCERankingLoss(nn.Module): 289 | def __init__(self, chn_mid=32): 290 | super(BCERankingLoss, self).__init__() 291 
| self.net = Dist2LogitLayer(chn_mid=chn_mid) 292 | # self.parameters = list(self.net.parameters()) 293 | self.loss = torch.nn.BCELoss() 294 | 295 | def forward(self, d0, d1, judge): 296 | per = (judge + 1.) / 2. 297 | self.logit = self.net.forward(d0, d1) 298 | return self.loss(self.logit, per) 299 | 300 | 301 | # L2, DSSIM metrics 302 | class FakeNet(nn.Module): 303 | def __init__(self, use_gpu=True, colorspace='Lab'): 304 | super(FakeNet, self).__init__() 305 | self.use_gpu = use_gpu 306 | self.colorspace = colorspace 307 | 308 | 309 | class L2(FakeNet): 310 | def forward(self, in0, in1, retPerLayer=None): 311 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 312 | 313 | if self.colorspace == 'RGB': 314 | (N, C, X, Y) = in0.size() 315 | value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), 316 | dim=3).view(N) 317 | return value 318 | elif self.colorspace == 'Lab': 319 | value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), 320 | tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 321 | 'float') 322 | ret_var = Variable(torch.Tensor((value,))) 323 | if self.use_gpu: 324 | ret_var = ret_var.cuda() 325 | return ret_var 326 | 327 | 328 | class DSSIM(FakeNet): 329 | 330 | def forward(self, in0, in1, retPerLayer=None): 331 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 332 | 333 | if self.colorspace == 'RGB': 334 | value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype( 335 | 'float') 336 | elif self.colorspace == 'Lab': 337 | value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), 338 | tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 339 | 'float') 340 | ret_var = Variable(torch.Tensor((value,))) 341 | if self.use_gpu: 342 | ret_var = ret_var.cuda() 343 | return ret_var 344 | 345 | 346 | def print_network(net): 347 | num_params = 0 348 | for param in net.parameters(): 349 | num_params += param.numel() 350 | print('Network', net) 351 | print('Total number of parameters: %d' % num_params) 352 | -------------------------------------------------------------------------------- /lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2, 5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), 
pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) 52 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | 98 | class vgg16(torch.nn.Module): 99 | def __init__(self, requires_grad=False, pretrained=True): 100 | super(vgg16, self).__init__() 101 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.N_slices = 5 108 | for x in range(4): 109 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(4, 9): 111 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(9, 16): 113 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(16, 23): 115 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(23, 30): 117 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h = self.slice1(X) 124 | h_relu1_2 = h 125 | h = self.slice2(h) 126 | h_relu2_2 = h 127 | h = self.slice3(h) 128 | h_relu3_3 = h 129 | h = self.slice4(h) 130 | h_relu4_3 = h 131 | h = self.slice5(h) 132 | h_relu5_3 = h 133 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 134 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, 
h_relu4_3, h_relu5_3) 135 | 136 | return out 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if (num == 18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif (num == 34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif (num == 50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif (num == 101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif (num == 152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from skimage import metrics 2 | import torch 3 | import torch.hub 4 | from lpips.lpips import LPIPS 5 | import os 6 | import numpy as np 7 | 8 | photometric = { 9 | "mse": None, 10 | "ssim": None, 11 | "psnr": None, 12 | "lpips": None 13 | } 14 | 15 | def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor, 16 | metric="mse", margin=0, mask=None): 17 | """ 18 | im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1) 19 | """ 20 | if metric not in photometric.keys(): 21 | raise RuntimeError(f"img_utils:: metric {metric} not recognized") 22 | if photometric[metric] is None: 23 | if metric == "mse": 24 | photometric[metric] = metrics.mean_squared_error 25 | elif metric == "ssim": 26 | photometric[metric] = metrics.structural_similarity 27 | elif metric == "psnr": 28 | photometric[metric] = metrics.peak_signal_noise_ratio 29 | elif metric == "lpips": 30 | photometric[metric] = LPIPS().cpu() 31 | 32 | if mask is not None: 33 | if mask.dim() == 3: 34 | mask = mask.unsqueeze(1) 35 | if mask.shape[1] == 1: 36 | mask = 
mask.expand(-1, 3, -1, -1) 37 | mask = mask.permute(0, 2, 3, 1).numpy() 38 | batchsz, hei, wid, _ = mask.shape 39 | if margin > 0: 40 | marginh = int(hei * margin) + 1 41 | marginw = int(wid * margin) + 1 42 | mask = mask[:, marginh:hei - marginh, marginw:wid - marginw] 43 | 44 | # convert from [0, 1] to [-1, 1] 45 | im1t = (im1t * 2 - 1).clamp(-1, 1) 46 | im2t = (im2t * 2 - 1).clamp(-1, 1) 47 | 48 | if im1t.dim() == 3: 49 | im1t = im1t.unsqueeze(0) 50 | im2t = im2t.unsqueeze(0) 51 | im1t = im1t.detach().cpu() 52 | im2t = im2t.detach().cpu() 53 | 54 | if im1t.shape[-1] == 3: 55 | im1t = im1t.permute(0, 3, 1, 2) 56 | im2t = im2t.permute(0, 3, 1, 2) 57 | 58 | im1 = im1t.permute(0, 2, 3, 1).numpy() 59 | im2 = im2t.permute(0, 2, 3, 1).numpy() 60 | batchsz, hei, wid, _ = im1.shape 61 | if margin > 0: 62 | marginh = int(hei * margin) + 1 63 | marginw = int(wid * margin) + 1 64 | im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw] 65 | im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw] 66 | values = [] 67 | 68 | for i in range(batchsz): 69 | if metric in ["mse", "psnr"]: 70 | if mask is not None: 71 | im1 = im1 * mask[i] 72 | im2 = im2 * mask[i] 73 | value = photometric[metric]( 74 | im1[i], im2[i] 75 | ) 76 | if mask is not None: 77 | hei, wid, _ = im1[i].shape 78 | pixelnum = mask[i, ..., 0].sum() 79 | value = value - 10 * np.log10(hei * wid / pixelnum) 80 | elif metric in ["ssim"]: 81 | value, ssimmap = photometric["ssim"]( 82 | im1[i], im2[i], multichannel=True, full=True 83 | ) 84 | if mask is not None: 85 | value = (ssimmap * mask[i]).sum() / mask[i].sum() 86 | elif metric in ["lpips"]: 87 | value = photometric[metric]( 88 | im1t[i:i + 1], im2t[i:i + 1] 89 | ) 90 | else: 91 | raise NotImplementedError 92 | values.append(value) 93 | 94 | return sum(values) / len(values) 95 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import os, sys 3 | import os.path as osp 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.optim import Optimizer 8 | from functools import reduce 9 | from torch.optim import AdamW 10 | 11 | 12 | # import random 13 | # import numpy as np 14 | # seed = 0 15 | # random.seed(seed) 16 | # np.random.seed(seed) 17 | # torch.manual_seed(seed) 18 | # torch.cuda.manual_seed_all(seed) 19 | class MultiOptimizer: 20 | def __init__(self, optimizers={}): 21 | self.optimizers = optimizers 22 | self.keys = list(optimizers.keys()) 23 | self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()]) 24 | 25 | def state_dict(self): 26 | state_dicts = [(key, self.optimizers[key].state_dict())\ 27 | for key in self.keys] 28 | return state_dicts 29 | 30 | def load_state_dict(self, state_dict): 31 | for key, val in state_dict: 32 | try: 33 | self.optimizers[key].load_state_dict(val) 34 | except: 35 | print("Unloaded %s" % key) 36 | 37 | def step(self, key=None, scaler=None): 38 | keys = [key] if key is not None else self.keys 39 | _ = [self._step(key, scaler) for key in keys] 40 | 41 | def _step(self, key, scaler=None): 42 | if scaler is not None: 43 | scaler.step(self.optimizers[key]) 44 | scaler.update() 45 | else: 46 | self.optimizers[key].step() 47 | 48 | def zero_grad(self, key=None): 49 | if key is not None: 50 | self.optimizers[key].zero_grad() 51 | else: 52 | _ = [self.optimizers[key].zero_grad() for key in self.keys] 53 | 
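# Hedged usage sketch (added for illustration, not part of the original file):
# MultiOptimizer groups several named optimizers behind a single step()/zero_grad()
# interface. The two toy modules and learning rates below are made-up examples;
# torch, nn and AdamW are already imported at the top of this file.
if __name__ == '__main__':
    net_a, net_b = nn.Linear(4, 4), nn.Linear(4, 4)
    optim = MultiOptimizer({
        'a': AdamW(net_a.parameters(), lr=1e-3),
        'b': AdamW(net_b.parameters(), lr=1e-2),
    })
    x = torch.randn(8, 4)
    loss = net_a(x).pow(2).mean() + net_b(x).pow(2).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()         # steps every wrapped optimizer
    optim.step(key='a')  # or step only the optimizer registered under 'a'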
-------------------------------------------------------------------------------- /ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | # import random 4 | # import numpy as np 5 | # seed = 0 6 | # random.seed(seed) 7 | # np.random.seed(seed) 8 | # torch.manual_seed(seed) 9 | # torch.cuda.manual_seed_all(seed) 10 | 11 | def get_ray_directions(H, W, focal): 12 | """ 13 | Get ray directions for all pixels in camera coordinate. 14 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 15 | ray-tracing-generating-camera-rays/standard-coordinate-systems 16 | 17 | Inputs: 18 | H, W, focal: image height, width and focal length 19 | 20 | Outputs: 21 | directions: (H, W, 3), the direction of the rays in camera coordinate 22 | """ 23 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 24 | i, j = grid.unbind(-1) 25 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 26 | # see https://github.com/bmild/nerf/issues/24 27 | directions = \ 28 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3) 29 | 30 | dir_bounds = directions.view(-1, 3) 31 | # print("Directions ", directions[0,0,:], directions[H-1,0,:], directions[0,W-1,:], directions[H-1, W-1, :]) 32 | # print("Directions ", dir_bounds[0], dir_bounds[W-1], dir_bounds[H*W-W], dir_bounds[H*W-1]) 33 | 34 | return directions 35 | 36 | 37 | def get_rays(directions, c2w): 38 | """ 39 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 40 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 41 | ray-tracing-generating-camera-rays/standard-coordinate-systems 42 | 43 | Inputs: 44 | directions: (H, W, 3) precomputed ray directions in camera coordinate 45 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 46 | 47 | Outputs: 48 | rays_o: (H*W, 3), the origin of the rays in world coordinate 49 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 50 | """ 51 | # Rotate ray directions from camera coordinate to the world coordinate 52 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 53 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 54 | # The origin of all rays is the camera origin in world coordinate 55 | rays_o = c2w[:3, -1].expand(rays_d.shape) # (H, W, 3) 56 | 57 | rays_d = rays_d.view(-1, 3) 58 | rays_o = rays_o.view(-1, 3) 59 | 60 | return rays_o, rays_d 61 | 62 | 63 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 64 | """ 65 | Transform rays from world coordinate to NDC. 66 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. 67 | For detailed derivation, please see: 68 | http://www.songho.ca/opengl/gl_projectionmatrix.html 69 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf 70 | 71 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 
72 | See https://github.com/bmild/nerf/issues/18 73 | 74 | Inputs: 75 | H, W, focal: image height, width and focal length 76 | near: (N_rays) or float, the depths of the near plane 77 | rays_o: (N_rays, 3), the origin of the rays in world coordinate 78 | rays_d: (N_rays, 3), the direction of the rays in world coordinate 79 | 80 | Outputs: 81 | rays_o: (N_rays, 3), the origin of the rays in NDC 82 | rays_d: (N_rays, 3), the direction of the rays in NDC 83 | """ 84 | # Shift ray origins to near plane 85 | t = -(near + rays_o[...,2]) / rays_d[...,2] 86 | rays_o = rays_o + t[...,None] * rays_d 87 | 88 | # Store some intermediate homogeneous results 89 | ox_oz = rays_o[...,0] / rays_o[...,2] 90 | oy_oz = rays_o[...,1] / rays_o[...,2] 91 | 92 | # Projection 93 | o0 = -1./(W/(2.*focal)) * ox_oz 94 | o1 = -1./(H/(2.*focal)) * oy_oz 95 | o2 = 1. + 2. * near / rays_o[...,2] 96 | 97 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz) 98 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz) 99 | d2 = 1 - o2 100 | 101 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) 102 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) 103 | 104 | return rays_o, rays_d 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-image 3 | torch>=1.8 4 | torchvision>=0.9.1 5 | imageio 6 | imageio-ffmpeg 7 | matplotlib 8 | configargparse 9 | tensorboardX>=2.0 10 | opencv-python 11 | -------------------------------------------------------------------------------- /results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpeng93/PDRF/9ae6d2b30d80fbbcca6cc5e96238322538f55019/results.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUBLAS_WORKSPACE_CONFIG=:4096:8 CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/defocustanabata/tx_defocustanabata_full.txt 2 | 3 | 4 | CUBLAS_WORKSPACE_CONFIG=:4096:8 CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/defocustanabata/tx_defocustanabata_full.txt --render_only --render_test -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import cv2 4 | import imageio 5 | from tensorboardX import SummaryWriter 6 | from NeRF import * 7 | from load_llff import load_llff_data 8 | from metrics import compute_img_metric 9 | from run_nerf_helpers import * 10 | import random, torch 11 | import numpy as np 12 | DEBUG = False 13 | 14 | 15 | 16 | def config_parser(): 17 | import configargparse 18 | parser = configargparse.ArgumentParser() 19 | parser.add_argument('--config', is_config_file=True, 20 | help='config file path') 21 | parser.add_argument("--expname", type=str, 22 | help='experiment name') 23 | parser.add_argument("--basedir", type=str, default='./logs/', required=True, 24 | help='where to store ckpts and logs') 25 | parser.add_argument("--datadir", type=str, required=True, 26 | help='input data directory') 27 | parser.add_argument("--datadownsample", type=float, default=-1, 28 | help='if downsample > 0, means downsample the image to scale=datadownsample') 29 | parser.add_argument("--tbdir", type=str, required=True, 30 | help="tensorboard log 
directory") 31 | parser.add_argument("--num_gpu", type=int, default=1, 32 | help=">1 will use DataParallel") 33 | parser.add_argument("--torch_hub_dir", type=str, default='', 34 | help=">1 will use DataParallel") 35 | # training options 36 | parser.add_argument("--seed", type=int, default=0, 37 | help='random seed') 38 | parser.add_argument("--mode", type=str, default='c2f', required=True, 39 | help='choose bewteen c2f (CRR+FVR) or nerf (2 MLPs) for rendering networks') 40 | parser.add_argument("--netdepth", type=int, default=8, 41 | help='layers in network') 42 | parser.add_argument("--netwidth", type=int, default=256, 43 | help='channels per layer') 44 | parser.add_argument("--netdepth_fine", type=int, default=8, 45 | help='layers in fine network') 46 | parser.add_argument("--netwidth_fine", type=int, default=256, 47 | help='channels per layer in fine network') 48 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 4, 49 | help='batch size (number of random rays per gradient step)') 50 | parser.add_argument("--lrate", type=float, default=5e-4, 51 | help='learning rate') 52 | parser.add_argument("--lrate_decay", type=int, default=250, 53 | help='exponential learning rate decay (in 1000 steps)') 54 | # generate N_rand # of rays, divide into chunk # of batch 55 | # then generate chunk * N_samples # of points, divide into netchunk # of batch 56 | parser.add_argument("--chunk", type=int, default=1024 * 32, 57 | help='number of rays processed in parallel, decrease if running out of memory') 58 | parser.add_argument("--netchunk", type=int, default=1024 * 64, 59 | help='number of pts sent through network in parallel, decrease if running out of memory') 60 | parser.add_argument("--no_reload", action='store_true', 61 | help='do not reload weights from saved ckpt') 62 | parser.add_argument("--ft_path", type=str, default=None, 63 | help='specific weights npy file to reload for coarse network') 64 | 65 | # CRR/FVR options: 66 | parser.add_argument("--coarse_num_layers", type=int, default=2, 67 | help='CRR layer for estimating sigma + feature') 68 | parser.add_argument("--coarse_num_layers_color", type=int, default=3, 69 | help='CRR layer for estimating color') 70 | parser.add_argument("--coarse_hidden_dim", type=int, default=64, 71 | help='coarse_hidden_dim') 72 | parser.add_argument("--coarse_hidden_dim_color", type=int, default=64, 73 | help='coarse_hidden_dim_color') 74 | parser.add_argument("--coarse_app_dim", type=int, default=32, 75 | help='coarse_app_dim') 76 | parser.add_argument("--coarse_app_n_comp", type=int, action="append") 77 | parser.add_argument("--coarse_n_voxels", type=int, default=16777248, 78 | help='coarse_n_voxels') 79 | 80 | parser.add_argument("--fine_num_layers", type=int, default=2, 81 | help='FVR layer for estimating sigma + feature') 82 | parser.add_argument("--fine_num_layers_color", type=int, default=3, 83 | help='FVR layer for estimating color') 84 | parser.add_argument("--fine_hidden_dim", type=int, default=256, 85 | help='fine_hidden_dim') 86 | parser.add_argument("--fine_hidden_dim_color", type=int, default=256, 87 | help='fine_hidden_dim_color') 88 | parser.add_argument("--fine_app_dim", type=int, default=32, 89 | help='fine_app_dim') 90 | parser.add_argument("--fine_geo_feat_dim", type=int, default=128, 91 | help='fine_geo_feat_dim') 92 | parser.add_argument("--fine_app_n_comp", type=int, action="append") 93 | parser.add_argument("--fine_n_voxels", type=int, default=134217984, 94 | help='fine_n_voxels') 95 | 96 | 97 | # rendering options 98 | 
parser.add_argument("--N_iters", type=int, default=50000, 99 | help='number of iteration') 100 | parser.add_argument("--N_samples", type=int, default=64, 101 | help='number of coarse samples per ray') 102 | parser.add_argument("--N_importance", type=int, default=0, 103 | help='number of additional fine samples per ray') 104 | parser.add_argument("--perturb", type=float, default=1., 105 | help='set to 0. for no jitter, 1. for jitter') 106 | parser.add_argument("--use_viewdirs", action='store_true', 107 | help='use full 5D input instead of 3D') 108 | parser.add_argument("--i_embed", type=int, default=0, 109 | help='set 0 for default positional encoding, -1 for none') 110 | parser.add_argument("--multires", type=int, default=10, 111 | help='log2 of max freq for positional encoding (3D location)') 112 | parser.add_argument("--multires_views", type=int, default=4, 113 | help='log2 of max freq for positional encoding (2D direction)') 114 | parser.add_argument("--raw_noise_std", type=float, default=0., 115 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 116 | 117 | parser.add_argument("--rgb_activate", type=str, default='sigmoid', 118 | help='activate function for rgb output, choose among "none", "sigmoid"') 119 | parser.add_argument("--sigma_activate", type=str, default='relu', 120 | help='activate function for sigma output, choose among "relu", "softplue"') 121 | 122 | # =============================== 123 | # Kernel optimizing 124 | # =============================== 125 | parser.add_argument("--kernel_type", type=str, default='kernel', 126 | help='choose among , , ') 127 | parser.add_argument("--kernel_isglobal", action='store_true', 128 | help='if specified, the canonical kernel position is global') 129 | parser.add_argument("--kernel_start_iter", type=int, default=0, 130 | help='start training kernel after # iteration') 131 | parser.add_argument("--kernel_ptnum", type=int, default=5, 132 | help='the number of sparse locations in the kernels ' 133 | 'that involves computing the final color of ray') 134 | parser.add_argument("--kernel_random_hwindow", type=float, default=0.25, 135 | help='randomly displace the predicted ray position') 136 | parser.add_argument("--kernel_img_embed", type=int, default=32, 137 | help='the dim of image laten code') 138 | parser.add_argument("--kernel_feat_cnl", type=int, default=15, 139 | help='the dim of radiance field latent code') 140 | parser.add_argument("--kernel_rand_dim", type=int, default=2, 141 | help='dimensions of input random number which uniformly sample from (0, 1)') 142 | parser.add_argument("--kernel_rand_embed", type=int, default=3, 143 | help='embed frequency of input kernel coordinate') 144 | parser.add_argument("--kernel_rand_mode", type=str, default='float', 145 | help=', <>>, ') 146 | parser.add_argument("--kernel_random_mode", type=str, default='input', 147 | help=', ') 148 | parser.add_argument("--kernel_spatial_embed", type=int, default=0, 149 | help='the dim of spatial coordinate embedding') 150 | parser.add_argument("--kernel_depth_embed", type=int, default=0, 151 | help='the dim of depth coordinate embedding') 152 | parser.add_argument("--kernel_hwindow", type=int, default=10, 153 | help='the max window of the kernel (sparse location will lie inside the window') 154 | parser.add_argument("--kernel_pattern_init_radius", type=float, default=0.1, 155 | help='the initialize radius of init pattern') 156 | parser.add_argument("--kernel_num_hidden", type=int, default=3, 157 | help='the number of hidden layer') 
158 | parser.add_argument("--kernel_num_wide", type=int, default=64, 159 | help='the wide of hidden layer') 160 | parser.add_argument("--kernel_shortcut", action='store_true', 161 | help='if yes, add a short cut to the network') 162 | parser.add_argument("--align_start_iter", type=int, default=0, 163 | help='start iteration of the align loss') 164 | parser.add_argument("--align_end_iter", type=int, default=1e10, 165 | help='end iteration of the align loss') 166 | parser.add_argument("--kernel_align_weight", type=float, default=0, 167 | help='align term weight') 168 | parser.add_argument("--kernel_spatialvariant_trans", action='store_true', 169 | help='if true, optimize spatial variant 3D translation of each sampling point') 170 | parser.add_argument("--kernel_global_trans", action='store_true', 171 | help='if true, optimize global 3D translation of each sampling point') 172 | parser.add_argument("--tone_mapping_type", type=str, default='none', 173 | help='the tone mapping of linear to LDR color space, , , ') 174 | 175 | ####### render option, will not effect training ######## 176 | parser.add_argument("--render_only", action='store_true', 177 | help='do not optimize, reload weights and render out render_poses path') 178 | parser.add_argument("--render_test", action='store_true', 179 | help='render the test set instead of render_poses path') 180 | parser.add_argument("--render_multipoints", action='store_true', 181 | help='render sub image that reconstruct the blur image') 182 | parser.add_argument("--render_rmnearplane", type=int, default=0, 183 | help='when render, set the density of nearest plane to 0') 184 | parser.add_argument("--render_focuspoint_scale", type=float, default=1., 185 | help='scale the focal point when render') 186 | parser.add_argument("--render_radius_scale", type=float, default=1., 187 | help='scale the radius of the camera path') 188 | parser.add_argument("--render_factor", type=int, default=0, 189 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 190 | parser.add_argument("--render_epi", action='store_true', 191 | help='render the video with epi path') 192 | 193 | ## llff flags 194 | parser.add_argument("--factor", type=int, default=None, 195 | help='downsample factor for LLFF images') 196 | parser.add_argument("--no_ndc", action='store_true', 197 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 198 | parser.add_argument("--lindisp", action='store_true', 199 | help='sampling linearly in disparity rather than depth') 200 | parser.add_argument("--spherify", action='store_true', 201 | help='set for spherical 360 scenes') 202 | parser.add_argument("--llffhold", type=int, default=8, 203 | help='will take every 1/N images as LLFF test set, paper uses 8') 204 | 205 | # ######### Unused params from the original ########### 206 | parser.add_argument("--precrop_iters", type=int, default=0, 207 | help='number of steps to train on central crops') 208 | parser.add_argument("--precrop_frac", type=float, 209 | default=.5, help='fraction of img taken for central crops') 210 | # dataset options 211 | parser.add_argument("--dataset_type", type=str, default='llff', 212 | help='options: llff / blender / deepvoxels') 213 | parser.add_argument("--testskip", type=int, default=8, 214 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 215 | ## deepvoxels flags 216 | parser.add_argument("--shape", type=str, default='greek', 217 | help='options : armchair / cube / greek / vase') 218 
| ## blender flags 219 | parser.add_argument("--white_bkgd", action='store_true', 220 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 221 | parser.add_argument("--half_res", action='store_true', 222 | help='load blender synthetic data at 400x400 instead of 800x800') 223 | 224 | 225 | ################# logging/saving options ################## 226 | parser.add_argument("--i_print", type=int, default=200, 227 | help='frequency of console printout and metric loggin') 228 | parser.add_argument("--i_tensorboard", type=int, default=200, 229 | help='frequency of tensorboard image logging') 230 | parser.add_argument("--i_weights", type=int, default=5000, 231 | help='frequency of weight ckpt saving') 232 | parser.add_argument("--i_testset", type=int, default=5000, 233 | help='frequency of testset saving') 234 | parser.add_argument("--i_video", type=int, default=25000, 235 | help='frequency of render_poses video saving') 236 | 237 | return parser 238 | 239 | 240 | def train(): 241 | parser = config_parser() 242 | args = parser.parse_args() 243 | args.i_embed=1 244 | 245 | if len(args.torch_hub_dir) > 0: 246 | print(f"Change torch hub cache to {args.torch_hub_dir}") 247 | torch.hub.set_dir(args.torch_hub_dir) 248 | 249 | # Load data 250 | print(args) 251 | print('RANDOM SEED',args.seed) 252 | random.seed(args.seed) 253 | np.random.seed(args.seed) 254 | np.random.default_rng(seed=args.seed) 255 | torch.manual_seed(args.seed) 256 | torch.cuda.manual_seed_all(args.seed) 257 | torch.use_deterministic_algorithms(True) 258 | K = None 259 | if args.dataset_type == 'llff': 260 | images, poses, bds, render_poses, i_test, bounding_box = load_llff_data(args, args.datadir, args.factor, 261 | recenter=True, bd_factor=.75, 262 | spherify=args.spherify, 263 | path_epi=args.render_epi) 264 | hwf = poses[0, :3, -1] 265 | poses = poses[:, :3, :4] 266 | args.bounding_box = bounding_box 267 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir, bounding_box) 268 | if not isinstance(i_test, list): 269 | i_test = [i_test] 270 | 271 | print('LLFF holdout,', args.llffhold) 272 | i_test = np.arange(images.shape[0])[::args.llffhold] 273 | 274 | i_val = i_test 275 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 276 | (i not in i_test and i not in i_val)]) 277 | 278 | print('DEFINING BOUNDS') 279 | if args.no_ndc: 280 | near = np.min(bds) * 0.9 281 | far = np.max(bds) * 1.0 282 | 283 | else: 284 | near = 0. 285 | far = 1. 
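        # with NDC enabled (the default for LLFF forward-facing data), get_ndc_rays
        # in ray_utils.py maps the scene into NDC (a [-1, 1] cube), which is why the
        # sampling bounds can be fixed to near = 0., far = 1. above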
286 | print('NEAR FAR', near, far) 287 | else: 288 | print('Unknown dataset type', args.dataset_type, 'exiting') 289 | return 290 | 291 | imagesf = images 292 | images = (images * 255).astype(np.uint8) 293 | images_idx = np.arange(0, len(images)) 294 | 295 | # Cast intrinsics to right types 296 | H, W, focal = hwf 297 | H, W = int(H), int(W) 298 | hwf = [H, W, focal] 299 | 300 | if K is None: 301 | K = np.array([ 302 | [focal, 0, 0.5 * W], 303 | [0, focal, 0.5 * H], 304 | [0, 0, 1] 305 | ]) 306 | 307 | if args.render_test: 308 | render_poses = np.array(poses) 309 | 310 | # Create log dir and copy the config file 311 | basedir = args.basedir 312 | tensorboardbase = args.tbdir 313 | 314 | args.lrate = 0.01 315 | args.lrate_decay = 10 316 | 317 | expname = args.expname 318 | test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt') 319 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 320 | os.makedirs(os.path.join(tensorboardbase, expname), exist_ok=True) 321 | 322 | tensorboard = SummaryWriter(os.path.join(tensorboardbase, expname)) 323 | 324 | f = os.path.join(basedir, expname, 'args.txt') 325 | with open(f, 'w') as file: 326 | for arg in sorted(vars(args)): 327 | attr = getattr(args, arg) 328 | file.write('{} = {}\n'.format(arg, attr)) 329 | if args.config is not None and not args.render_only: 330 | f = os.path.join(basedir, expname, 'config.txt') 331 | with open(f, 'w') as file: 332 | file.write(open(args.config, 'r').read()) 333 | 334 | with open(test_metric_file, 'a') as file: 335 | file.write(open(args.config, 'r').read()) 336 | file.write("\n============================\n" 337 | "||\n" 338 | "\\/\n") 339 | 340 | # The DSK module 341 | if args.kernel_type == 'PBE' or args.kernel_type == 'DSK': 342 | kernelnet = BlurModel(len(images), torch.tensor(poses[:, :3, :4]), 343 | args.kernel_ptnum, args.kernel_hwindow, args.kernel_type, 344 | img_wh = [W,H], 345 | random_hwindow=args.kernel_random_hwindow, in_embed=args.kernel_rand_embed, 346 | random_mode=args.kernel_random_mode, 347 | img_embed=args.kernel_img_embed, 348 | spatial_embed=args.kernel_spatial_embed, 349 | depth_embed=args.kernel_depth_embed, 350 | num_hidden=args.kernel_num_hidden, 351 | num_wide=args.kernel_num_wide, 352 | feat_cnl=args.kernel_feat_cnl, 353 | short_cut=args.kernel_shortcut, 354 | pattern_init_radius=args.kernel_pattern_init_radius, 355 | isglobal=args.kernel_isglobal, 356 | optim_trans=args.kernel_global_trans, 357 | optim_spatialvariant_trans=args.kernel_spatialvariant_trans) 358 | elif args.kernel_type == 'none': 359 | kernelnet = None 360 | else: 361 | raise RuntimeError(f"kernel_type {args.kernel_type} not recognized") 362 | 363 | # Create nerf model 364 | nerf = NeRFAll(args, kernelnet) 365 | if args.mode == 'c2f': 366 | optimizer = torch.optim.Adam(params=[{'params': nerf.grad_vars, 'lr': 0.001}, 367 | {'params': nerf.grad_vars_vol, 'lr': 0.02}], 368 | betas=(0.9, 0.999)) 369 | elif args.mode == 'nerf': 370 | optim_params = nerf.parameters() 371 | optimizer = torch.optim.Adam(params=optim_params, 372 | lr=args.lrate, 373 | betas=(0.9, 0.999)) 374 | else: 375 | raise NotImplementedError(f"{self.mode} for rendering network is not implemented") 376 | 377 | 378 | start = 0 379 | if args.ft_path is not None and args.ft_path != 'None': 380 | ckpts = [args.ft_path] 381 | else: 382 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 383 | '.tar' in f] 384 | print('Found ckpts', ckpts) 385 | if len(ckpts) > 0 and not args.no_reload: 
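        # resume from the most recent checkpoint found in the experiment folder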
386 | ckpt_path = ckpts[-1] 387 | print('Reloading from', ckpt_path) 388 | ckpt = torch.load(ckpt_path) 389 | 390 | start = ckpt['global_step'] 391 | # Load model 392 | smart_load_state_dict(nerf, ckpt) 393 | 394 | # figuring out the train/test configuration 395 | render_kwargs_train = { 396 | 'perturb': args.perturb, 397 | 'N_importance': args.N_importance, 398 | 'N_samples': args.N_samples, 399 | 'use_viewdirs': args.use_viewdirs, 400 | 'white_bkgd': args.white_bkgd, 401 | 'raw_noise_std': args.raw_noise_std, 402 | } 403 | # NDC only good for LLFF-style forward facing data 404 | if args.no_ndc: # args.dataset_type != 'llff' or 405 | print('Not ndc!') 406 | render_kwargs_train['ndc'] = False 407 | render_kwargs_train['lindisp'] = args.lindisp 408 | render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train} 409 | render_kwargs_test['perturb'] = False 410 | render_kwargs_test['raw_noise_std'] = 0. 411 | 412 | bds_dict = { 413 | 'near': near, 414 | 'far': far, 415 | } 416 | render_kwargs_train.update(bds_dict) 417 | render_kwargs_test.update(bds_dict) 418 | 419 | global_step = start 420 | 421 | # Move testing data to GPU 422 | render_poses = torch.tensor(render_poses[:, :3, :4]).cuda() 423 | nerf = nerf.cuda() 424 | # Short circuit if only rendering out from trained model 425 | if args.render_only: 426 | print('RENDER ONLY') 427 | with torch.no_grad(): 428 | testsavedir = os.path.join(basedir, expname, 429 | f"renderonly" 430 | f"_{'test' if args.render_test else 'path'}" 431 | f"_{start:06d}") 432 | os.makedirs(testsavedir, exist_ok=True) 433 | print('test poses shape', render_poses.shape) 434 | 435 | dummy_num = ((len(poses) - 1) // args.num_gpu + 1) * args.num_gpu - len(poses) 436 | dummy_poses = torch.eye(3, 4).unsqueeze(0).expand(dummy_num, 3, 4).type_as(render_poses) 437 | print(f"Append {dummy_num} # of poses to fill all the GPUs") 438 | nerf.eval() 439 | rgbshdr, disps = nerf( 440 | hwf[0], hwf[1], K, args.chunk, 441 | poses=torch.cat([render_poses, dummy_poses], dim=0), 442 | render_kwargs=render_kwargs_test, 443 | render_factor=args.render_factor, 444 | ) 445 | rgbshdr = rgbshdr[:len(rgbshdr) - dummy_num] 446 | disps = (1. 
- disps) 447 | disps = disps[:len(disps) - dummy_num].cpu().numpy() 448 | rgbs = rgbshdr 449 | rgbs = to8b(rgbs.cpu().numpy()) 450 | if args.render_test: 451 | for rgb_idx, rgb8 in enumerate(rgbs): 452 | curr_disp = to8b(disps[rgb_idx] / disps[rgb_idx].max()) 453 | imageio.imwrite(os.path.join(testsavedir, f'{rgb_idx:03d}.png'), rgb8) 454 | imageio.imwrite(os.path.join(testsavedir, f'{rgb_idx:03d}_disp.png'), cv2.applyColorMap(255-curr_disp, cv2.COLORMAP_TWILIGHT_SHIFTED)) 455 | else: 456 | prefix = 'epi_' if args.render_epi else '' 457 | imageio.mimwrite(os.path.join(testsavedir, f'{prefix}video.mp4'), rgbs, fps=30, quality=9) 458 | disps = to8b(disps / disps.max()) 459 | imageio.mimwrite(os.path.join(testsavedir, f'{prefix}video_disp.mp4'), disps, fps=30, quality=9) 460 | 461 | if args.render_test and args.render_multipoints: 462 | for pti in range(args.kernel_ptnum): 463 | nerf.eval() 464 | poses_num = len(poses) + dummy_num 465 | imgidx = torch.arange(poses_num, dtype=torch.long).to(render_poses.device).reshape(poses_num, 1) 466 | rgbs, weights = nerf( 467 | hwf[0], hwf[1], K, args.chunk, 468 | poses=torch.cat([render_poses, dummy_poses], dim=0), 469 | render_kwargs=render_kwargs_test, 470 | render_factor=args.render_factor, 471 | render_point=pti, 472 | images_indices=imgidx 473 | ) 474 | rgbs = rgbs[:len(rgbs) - dummy_num] 475 | weights = weights[:len(weights) - dummy_num] 476 | rgbs = to8b(rgbs.cpu().numpy()) 477 | weights = to8b(weights.cpu().numpy()) 478 | for rgb_idx, rgb8 in enumerate(rgbs): 479 | imageio.imwrite(os.path.join(testsavedir, f'{rgb_idx:03d}_pt{pti}.png'), rgb8) 480 | imageio.imwrite(os.path.join(testsavedir, f'w_{rgb_idx:03d}_pt{pti}.png'), weights[rgb_idx]) 481 | return 482 | 483 | # ============================================ 484 | # Prepare ray dataset if batching random rays 485 | # ============================================ 486 | N_rand = args.N_rand 487 | train_datas = {} 488 | 489 | # if downsample, downsample the images 490 | if args.datadownsample > 0: 491 | images_train = np.stack([cv2.resize(img_, None, None, 492 | 1 / args.datadownsample, 1 / args.datadownsample, 493 | cv2.INTER_AREA) for img_ in imagesf], axis=0) 494 | else: 495 | images_train = imagesf 496 | 497 | num_img, hei, wid, _ = images_train.shape 498 | print(f"train on image sequence of len = {num_img}, {wid}x{hei}") 499 | k_train = np.array([K[0, 0] * wid / W, 0, K[0, 2] * wid / W, 500 | 0, K[1, 1] * hei / H, K[1, 2] * hei / H, 501 | 0, 0, 1]).reshape(3, 3).astype(K.dtype) 502 | 503 | # For random ray batching 504 | print('get rays') 505 | rays = np.stack([get_rays_np(hei, wid, k_train, p) for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3] 506 | rays = np.transpose(rays, [0, 2, 3, 1, 4]) 507 | train_datas['rays'] = rays[i_train].reshape(-1, 2, 3) 508 | 509 | xs, ys = np.meshgrid(np.arange(wid, dtype=np.float32), np.arange(hei, dtype=np.float32), indexing='xy') 510 | xs = np.tile((xs[None, ...] + HALF_PIX) * W / wid, [num_img, 1, 1]) 511 | ys = np.tile((ys[None, ...] 
+ HALF_PIX) * H / hei, [num_img, 1, 1]) 512 | train_datas['rays_x'], train_datas['rays_y'] = xs[i_train].reshape(-1, 1), ys[i_train].reshape(-1, 1) 513 | 514 | train_datas['rgbsf'] = images_train[i_train].reshape(-1, 3) 515 | 516 | images_idx_tile = images_idx.reshape((num_img, 1, 1)) 517 | images_idx_tile = np.tile(images_idx_tile, [1, hei, wid]) 518 | train_datas['images_idx'] = images_idx_tile[i_train].reshape(-1, 1).astype(np.int64) 519 | 520 | print('shuffle rays') 521 | shuffle_idx = np.random.permutation(len(train_datas['rays'])) 522 | print(shuffle_idx) 523 | train_datas = {k: v[shuffle_idx] for k, v in train_datas.items()} 524 | 525 | print('done') 526 | i_batch = 0 527 | 528 | # Move training data to GPU 529 | images = torch.tensor(images).cuda() 530 | imagesf = torch.tensor(imagesf).cuda() 531 | 532 | poses = torch.tensor(poses).cuda() 533 | train_datas = {k: torch.tensor(v).cuda() for k, v in train_datas.items()} 534 | 535 | N_iters = args.N_iters + 1 536 | print('Begin') 537 | print('TRAIN views are', i_train) 538 | print('TEST views are', i_test) 539 | print('VAL views are', i_val) 540 | 541 | 542 | start = start + 1 543 | for i in range(start, N_iters): 544 | time0 = time.time() 545 | 546 | # Sample random ray batch 547 | iter_data = {k: v[i_batch:i_batch + N_rand] for k, v in train_datas.items()} 548 | batch_rays = iter_data.pop('rays').permute(0, 2, 1) 549 | 550 | i_batch += N_rand 551 | if i_batch >= len(train_datas['rays']): 552 | print("Shuffle data after an epoch!") 553 | shuffle_idx = np.random.permutation(len(train_datas['rays'])) 554 | train_datas = {k: v[shuffle_idx] for k, v in train_datas.items()} 555 | i_batch = 0 556 | iter_data = {k: v[i_batch:i_batch + N_rand] for k, v in train_datas.items()} 557 | batch_rays = iter_data.pop('rays').permute(0, 2, 1) 558 | i_batch += N_rand 559 | 560 | ##### Core optimization loop ##### 561 | nerf.train() 562 | if i == args.kernel_start_iter: 563 | torch.cuda.empty_cache() 564 | rgb, rgb0, extra_loss = nerf(H, W, K, chunk=args.chunk, 565 | rays=batch_rays, rays_info=iter_data, 566 | retraw=True, force_naive=i < args.kernel_start_iter, 567 | **render_kwargs_train) 568 | 569 | # Compute Losses 570 | # ===================== 571 | target_rgb = iter_data['rgbsf'].squeeze(-2) 572 | img_loss = img2mse(rgb, target_rgb) 573 | loss = img_loss 574 | psnr = mse2psnr(img_loss) 575 | 576 | 577 | img_loss0 = img2mse(rgb0, target_rgb) 578 | psnr0 = mse2psnr(img_loss0) 579 | loss = loss + img_loss0 580 | 581 | extra_loss = {k: torch.mean(v) for k, v in extra_loss.items()} 582 | if "TV" in extra_loss: 583 | loss = loss + extra_loss["TV"] 584 | if "align" in extra_loss: 585 | if vars(args)["align_start_iter"] <= i <= vars(args)["align_end_iter"]: 586 | loss = loss + extra_loss["align"] * vars(args)["kernel_align_weight"] 587 | 588 | 589 | 590 | optimizer.zero_grad() 591 | loss.backward() 592 | optimizer.step() 593 | 594 | # NOTE: IMPORTANT! 
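        # worked example of the schedule below, using the values set earlier in
        # train() (args.lrate = 0.01, args.lrate_decay = 10):
        #   decay_steps = 10 * 1000 = 10000
        #   global_step = 0      -> lr = 0.01 * 0.1 ** 0.0 = 0.01
        #   global_step = 10000  -> lr = 0.01 * 0.1 ** 1.0 = 0.001
        #   global_step = 25000  -> lr = 0.01 * 0.1 ** 2.5 ≈ 3.2e-5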
595 | ### update learning rate ### 596 | decay_rate = 0.1 597 | decay_steps = args.lrate_decay * 1000 598 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 599 | for param_group in optimizer.param_groups: 600 | param_group['lr'] = new_lrate 601 | ################################ 602 | 603 | # dt = time.time() - time0 604 | # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") 605 | ##### end ##### 606 | 607 | # Rest is logging 608 | if i % args.i_weights == 0 and i > 0: 609 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 610 | torch.save({ 611 | 'global_step': global_step, 612 | 'network_state_dict': nerf.state_dict(), 613 | 'optimizer_state_dict': optimizer.state_dict(), 614 | }, path) 615 | print('Saved checkpoints at', path) 616 | 617 | if i % args.i_video == 0 and i > 0: 618 | # Turn on testing mode 619 | with torch.no_grad(): 620 | nerf.eval() 621 | rgbs, disps = nerf(H, W, K, args.chunk, poses=render_poses, render_kwargs=render_kwargs_test) 622 | print('Done, saving', rgbs.shape, disps.shape) 623 | moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) 624 | rgbs = (rgbs - rgbs.min()) / (rgbs.max() - rgbs.min()) 625 | rgbs = rgbs.cpu().numpy() 626 | disps = disps.cpu().numpy() 627 | 628 | imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) 629 | imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / disps.max()), fps=30, quality=8) 630 | 631 | 632 | if i % args.i_testset == 0 and i > 0: 633 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 634 | os.makedirs(testsavedir, exist_ok=True) 635 | print('test poses shape', poses.shape) 636 | dummy_num = ((len(poses) - 1) // args.num_gpu + 1) * args.num_gpu - len(poses) 637 | dummy_poses = torch.eye(3, 4).unsqueeze(0).expand(dummy_num, 3, 4).type_as(render_poses) 638 | print(f"Append {dummy_num} # of poses to fill all the GPUs") 639 | with torch.no_grad(): 640 | nerf.eval() 641 | rgbs, _ = nerf(H, W, K, args.chunk, poses=torch.cat([poses, dummy_poses], dim=0).cuda(), 642 | render_kwargs=render_kwargs_test) 643 | rgbs = rgbs[:len(rgbs) - dummy_num] 644 | rgbs_save = rgbs # (rgbs - rgbs.min()) / (rgbs.max() - rgbs.min()) 645 | # saving 646 | for rgb_idx, rgb in enumerate(rgbs_save): 647 | rgb8 = to8b(rgb.cpu().numpy()) 648 | filename = os.path.join(testsavedir, f'{rgb_idx:03d}.png') 649 | imageio.imwrite(filename, rgb8) 650 | 651 | # evaluation 652 | rgbs = rgbs[i_test] 653 | target_rgb_ldr = imagesf[i_test] 654 | 655 | test_mse = compute_img_metric(rgbs, target_rgb_ldr, 'mse') 656 | test_psnr = compute_img_metric(rgbs, target_rgb_ldr, 'psnr') 657 | test_ssim = compute_img_metric(rgbs, target_rgb_ldr, 'ssim') 658 | test_lpips = compute_img_metric(rgbs, target_rgb_ldr, 'lpips') 659 | if isinstance(test_lpips, torch.Tensor): 660 | test_lpips = test_lpips.item() 661 | 662 | tensorboard.add_scalar("Test MSE", test_mse, global_step) 663 | tensorboard.add_scalar("Test PSNR", test_psnr, global_step) 664 | tensorboard.add_scalar("Test SSIM", test_ssim, global_step) 665 | tensorboard.add_scalar("Test LPIPS", test_lpips, global_step) 666 | 667 | with open(test_metric_file, 'a') as outfile: 668 | outfile.write(f"iter{i}/globalstep{global_step}: MSE:{test_mse:.8f} PSNR:{test_psnr:.8f}" 669 | f" SSIM:{test_ssim:.8f} LPIPS:{test_lpips:.8f}\n") 670 | 671 | print('Saved test set') 672 | 673 | if i % args.i_tensorboard == 0: 674 | tensorboard.add_scalar("Loss", loss.item(), global_step) 675 | tensorboard.add_scalar("PSNR", psnr.item(), global_step) 
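            # also log each extra loss term returned by the model (e.g. the TV and
            # align terms added to the loss above) under its own tensorboard tag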
676 | for k, v in extra_loss.items(): 677 | tensorboard.add_scalar(k, v.item(), global_step) 678 | 679 | if i % args.i_print == 0: 680 | print(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}") 681 | 682 | global_step += 1 683 | 684 | 685 | if __name__ == '__main__': 686 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 687 | train() 688 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import time 6 | import random 7 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 8 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 9 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 10 | 11 | HALF_PIX = 0.5 12 | # seed = 0 13 | # random.seed(seed) 14 | # np.random.seed(seed) 15 | # torch.manual_seed(seed) 16 | # torch.cuda.manual_seed_all(seed) 17 | 18 | class ToneMapping(nn.Module): 19 | def __init__(self, map_type: str): 20 | super(ToneMapping, self).__init__() 21 | assert map_type in ['none', 'gamma', 'learn', 'ycbcr'] 22 | self.map_type = map_type 23 | if map_type == 'learn': 24 | self.linear = nn.Sequential( 25 | nn.Linear(1, 16), nn.ReLU(), 26 | nn.Linear(16, 16), nn.ReLU(), 27 | nn.Linear(16, 16), nn.ReLU(), 28 | nn.Linear(16, 1) 29 | ) 30 | 31 | def forward(self, x): 32 | if self.map_type == 'none': 33 | return x 34 | elif self.map_type == 'learn': 35 | ori_shape = x.shape 36 | x_in = x.reshape(-1, 1) 37 | res_x = self.linear(x_in) * 0.1 38 | x_out = torch.sigmoid(res_x + x_in) 39 | return x_out.reshape(ori_shape) 40 | elif self.map_type == 'gamma': 41 | return x ** (1. / 2.2) 42 | else: 43 | assert RuntimeError("map_type not recognized") 44 | 45 | # Positional encoding (section 5.1) 46 | class Embedder(nn.Module): 47 | def __init__(self, **kwargs): 48 | super().__init__() 49 | self.kwargs = kwargs 50 | d = self.kwargs['input_dims'] 51 | out_dim = 0 52 | if self.kwargs['include_input']: 53 | out_dim += d 54 | 55 | max_freq = self.kwargs['max_freq_log2'] 56 | N_freqs = self.kwargs['num_freqs'] 57 | 58 | if self.kwargs['log_sampling']: 59 | self.freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 60 | else: 61 | self.freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) 62 | 63 | for freq in self.freq_bands: 64 | for p_fn in self.kwargs['periodic_fns']: 65 | out_dim += d 66 | 67 | self.out_dim = out_dim 68 | 69 | def forward(self, inputs): 70 | # print(f"input device: {inputs.device}, freq_bands device: {self.freq_bands.device}") 71 | self.freq_bands = self.freq_bands.type_as(inputs) 72 | outputs = [] 73 | if self.kwargs['include_input']: 74 | outputs.append(inputs) 75 | 76 | for freq in self.freq_bands: 77 | for p_fn in self.kwargs['periodic_fns']: 78 | outputs.append(p_fn(inputs * freq)) 79 | return torch.cat(outputs, -1) 80 | 81 | 82 | 83 | def get_embedder(multires, i=0, input_dim=3): 84 | if i == -1: 85 | return nn.Identity(), 3 86 | 87 | embed_kwargs = { 88 | 'include_input': True, 89 | 'input_dims': input_dim, 90 | 'max_freq_log2': multires - 1, 91 | 'num_freqs': multires, 92 | 'log_sampling': True, 93 | 'periodic_fns': [torch.sin, torch.cos], 94 | } 95 | 96 | embedder_obj = Embedder(**embed_kwargs) 97 | return embedder_obj, embedder_obj.out_dim 98 | 99 | 100 | class TVLoss(torch.nn.Module): 101 | def __init__(self,TVLoss_weight=1): 102 | super(TVLoss,self).__init__() 103 | self.TVLoss_weight = TVLoss_weight 104 | 105 | def forward(self,x): 106 | batch_size = x.size()[0] 107 | h_x = x.size()[2] 108 | w_x = x.size()[3] 109 | count_h = self._tensor_size(x[:,:,1:,:]) 110 | count_w = self._tensor_size(x[:,:,:,1:]) 111 | count_w = max(count_w, 1) 112 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 113 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 114 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 115 | 116 | def _tensor_size(self,t): 117 | return t.size()[1]*t.size()[2]*t.size()[3] 118 | 119 | # Model 120 | class NeRF(nn.Module): 121 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False,rgb_activate='sigmoid',sigma_activate='relu',render_rmnearplane = 0): 122 | """ 123 | """ 124 | super(NeRF, self).__init__() 125 | self.D = D 126 | self.W = W 127 | self.input_ch = input_ch 128 | self.input_ch_views = input_ch_views 129 | self.skips = skips 130 | self.use_viewdirs = use_viewdirs 131 | self.render_rmnearplane = render_rmnearplane 132 | self.pts_linears = nn.ModuleList( 133 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in 134 | range(D - 1)]) 135 | 136 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 137 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)]) 138 | 139 | ### Implementation according to the paper 140 | activate = {'relu': torch.relu, 'sigmoid': torch.sigmoid, 'exp': torch.exp, 'none': lambda x: x, 141 | 'sigmoid1': lambda x: 1.002 / (torch.exp(-x) + 1) - 0.001, 142 | 'softplus': lambda x: nn.Softplus()(x - 1)} 143 | self.rgb_activate = activate[rgb_activate] 144 | self.sigma_activate = activate[sigma_activate] 145 | 146 | 147 | if use_viewdirs: 148 | self.feature_linear = nn.Linear(W, W) 149 | self.alpha_linear = nn.Linear(W, 1) 150 | self.rgb_linear = nn.Linear(W // 2, 3) 151 | else: 152 | self.output_linear = nn.Linear(W, output_ch) 153 | 154 | def mlpforward(self, inputs, viewdirs, embed_fn, embeddirs_fn, netchunk=1024 * 64): 155 | """Prepares inputs and applies network 'fn'. 
156 | """ 157 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 158 | embedded = embed_fn(inputs_flat) 159 | 160 | if viewdirs is not None: 161 | input_dirs = viewdirs[:, None].expand(inputs.shape) 162 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 163 | embedded_dirs = embeddirs_fn(input_dirs_flat) 164 | embedded = torch.cat([embedded, embedded_dirs], -1) 165 | 166 | # batchify execution 167 | if netchunk is None: 168 | outputs_flat, feature_flat = self.eval(embedded) 169 | else: 170 | outputs_flat, feature_flat = [], [] 171 | for i in range(0, embedded.shape[0], netchunk): 172 | output,feature = self.eval(embedded[i:i + netchunk]) 173 | outputs_flat.append(output) 174 | feature_flat.append(feature) 175 | 176 | outputs_flat, feature_flat = torch.cat(outputs_flat, 0), torch.cat(feature_flat, 0) 177 | 178 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 179 | feature = torch.reshape(feature_flat, list(inputs.shape[:-1]) + [feature_flat.shape[-1]]) 180 | return outputs, feature 181 | 182 | def raw2outputs(self, raw, feature, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 183 | """Transforms model's predictions to semantically meaningful values. 184 | Args: 185 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 186 | z_vals: [num_rays, num_samples along ray]. Integration time. 187 | rays_d: [num_rays, 3]. Direction of each ray. 188 | Returns: 189 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 190 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 191 | acc_map: [num_rays]. Sum of weights along each ray. 192 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 193 | depth_map: [num_rays]. Estimated distance to object. 194 | """ 195 | 196 | def raw2alpha(raw_, dists_, act_fn): 197 | alpha_ = - torch.exp(-act_fn(raw_) * dists_) + 1. 198 | return torch.cat([alpha_, torch.ones_like(alpha_[:, 0:1])], dim=-1) 199 | 200 | dists = z_vals[..., 1:] - z_vals[..., :-1] # [N_rays, N_samples - 1] 201 | 202 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 203 | 204 | rgb = self.rgb_activate(raw[..., :3]) 205 | noise = 0. 206 | if raw_noise_std > 0.: 207 | noise = torch.randn_like(raw[..., :-1, 3]) * raw_noise_std 208 | # Overwrite randomly sampled data if pytest 209 | if pytest: 210 | np.random.seed(0) 211 | noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std 212 | noise = torch.tensor(noise) 213 | 214 | density = self.sigma_activate(raw[..., :-1, 3] + noise) 215 | if not self.training and self.render_rmnearplane > 0: 216 | mask = z_vals[:, 1:] 217 | mask = mask > self.render_rmnearplane / 128 218 | mask = mask.type_as(density) 219 | density = mask * density 220 | 221 | alpha = - torch.exp(- density * dists) + 1. 222 | alpha = torch.cat([alpha, torch.ones_like(alpha[:, 0:1])], dim=-1) 223 | weights = alpha * \ 224 | torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), - alpha + (1. + 1e-10)], -1), -1)[:, :-1] 225 | 226 | feature_map = torch.sum(weights[..., None] * feature, -2) # [N_rays, 3] 227 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 228 | depth_map = torch.sum(weights * z_vals, -1) 229 | 230 | # disp_map = 1. / torch.clamp_min(depth_map, 1e-10) 231 | acc_map = torch.sum(weights, -1) 232 | 233 | if white_bkgd: 234 | rgb_map = rgb_map + (1. 
- acc_map[..., None]) 235 | 236 | return rgb_map, feature_map, density, acc_map, weights, depth_map 237 | 238 | def eval(self, x): 239 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 240 | h = input_pts 241 | for i, l in enumerate(self.pts_linears): 242 | h = self.pts_linears[i](h) 243 | h = F.relu(h) 244 | if i in self.skips: 245 | h = torch.cat([input_pts, h], -1) 246 | 247 | if self.use_viewdirs: 248 | alpha = self.alpha_linear(h) 249 | feature = self.feature_linear(h) 250 | h = torch.cat([feature, input_views], -1) 251 | 252 | for i, l in enumerate(self.views_linears): 253 | h = self.views_linears[i](h) 254 | h = F.relu(h) 255 | 256 | rgb = self.rgb_linear(h) 257 | outputs = torch.cat([rgb, alpha], -1) 258 | else: 259 | outputs = self.output_linear(h) 260 | 261 | return outputs, feature 262 | 263 | def forward(self, pts, viewdirs, pts_embed, dirs_embed, z_vals, rays_d, raw_noise_std, white_bkgd, is_train): 264 | 265 | raw, feature = self.mlpforward(pts, viewdirs, pts_embed, dirs_embed) 266 | rgb_map, feature_map, density_map, acc_map, weights, depth_map = self.raw2outputs(raw, feature, z_vals, rays_d, raw_noise_std, white_bkgd) 267 | 268 | return rgb_map, depth_map, acc_map, weights, feature_map 269 | 270 | 271 | class NeRFSmall_ray(nn.Module): 272 | def __init__(self, 273 | aabb, 274 | num_layers=3, 275 | hidden_dim=64, 276 | geo_feat_dim=15, 277 | num_layers_color=4, 278 | hidden_dim_color=64, 279 | input_ch=3, input_ch_views=3, 280 | render_rmnearplane=0,app_dim=32, 281 | app_n_comp=[64,16,16], n_voxels=16777248): 282 | super(NeRFSmall_ray, self).__init__() 283 | 284 | self.input_ch = input_ch 285 | self.input_ch_views = input_ch_views 286 | self.render_rmnearplane = render_rmnearplane 287 | # sigma network 288 | self.num_layers = num_layers 289 | self.hidden_dim = hidden_dim 290 | self.geo_feat_dim = geo_feat_dim 291 | 292 | sigma_net = [] 293 | for l in range(num_layers): 294 | if l == 0: 295 | in_dim = self.input_ch 296 | else: 297 | in_dim = hidden_dim 298 | 299 | if l == num_layers - 1: 300 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color 301 | else: 302 | out_dim = hidden_dim 303 | 304 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 305 | 306 | self.sigma_net = nn.ModuleList(sigma_net) 307 | 308 | # color network 309 | self.num_layers_color = num_layers_color 310 | self.hidden_dim_color = hidden_dim_color 311 | 312 | color_net = [] 313 | for l in range(num_layers_color): 314 | if l == 0: 315 | in_dim = self.input_ch_views + self.geo_feat_dim 316 | else: 317 | in_dim = hidden_dim 318 | 319 | if l == num_layers_color - 1: 320 | out_dim = 3 # 3 rgb 321 | else: 322 | out_dim = hidden_dim 323 | 324 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 325 | 326 | self.color_net = nn.ModuleList(color_net) 327 | # self.aabb = torch.FloatTensor([[-2.0815, -2.3389, -1.0001], [2.2236, 2.0548, 1.0001]]).cuda() 328 | self.app_dim = app_dim #app_dim 329 | self.app_n_comp = app_n_comp#[48,12,12] 330 | 331 | 332 | self.aabb = torch.stack(aabb).cuda() 333 | xyz_min, xyz_max = aabb 334 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / 3) 335 | gridSize = ((xyz_max - xyz_min) / voxel_size).long().tolist() 336 | self.aabbSize = self.aabb[1] - self.aabb[0] 337 | self.invaabbSize = 2.0/self.aabbSize 338 | self.gridSize= torch.LongTensor(gridSize)#.to(self.device) 339 | print('Coarse Ray GridSize', self.gridSize) 340 | 341 | # self.gridSize = [164*2, 167*2, 76*2] 342 | # self.aabbSize = self.aabb[1] - 
self.aabb[0] 343 | # self.invaabbSize = 2.0/self.aabbSize 344 | 345 | self.matMode = [[0,1], [0,2], [1,2]] 346 | self.vecMode = [2, 1, 0] 347 | self.comp_w = [1,1,1] 348 | self.reg = TVLoss() 349 | 350 | 351 | self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1) 352 | self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False)#.to(device) 353 | 354 | def init_one_svd(self, n_component, gridSize, scale): 355 | plane_coef, line_coef = [], [] 356 | for i in range(len(self.vecMode)): 357 | vec_id = self.vecMode[i] 358 | mat_id_0, mat_id_1 = self.matMode[i] 359 | plane_coef.append(torch.nn.Parameter( 360 | scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) # 361 | line_coef.append( 362 | torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) 363 | 364 | # return plane_coef, line_coef 365 | return torch.nn.ParameterList(plane_coef), torch.nn.ParameterList(line_coef) 366 | 367 | 368 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): 369 | grad_vars_vol = list(self.app_line)+ list(self.app_plane) 370 | grad_vars_net = list(self.basis_mat.parameters())+list(self.color_net.parameters())+list(self.sigma_net.parameters()) 371 | return grad_vars_vol, grad_vars_net 372 | 373 | 374 | 375 | def TV_loss_app(self): 376 | total = 0 377 | for idx in range(len(self.app_plane)): 378 | total = total + self.reg(self.app_plane[idx]) * 1e-2 + self.reg(self.app_line[idx]) * 1e-3 379 | return total 380 | 381 | 382 | 383 | def compute_appfeature(self, xyz_sampled): 384 | 385 | # plane + line basis 386 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).view(3, -1, 1, 2) 387 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 388 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).view(3, -1, 1, 2) 389 | 390 | plane_coef_point,line_coef_point = [],[] 391 | for idx_plane in range(len(self.app_plane)): 392 | plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]], 393 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 394 | line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]], 395 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 396 | plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point) 397 | 398 | 399 | return self.basis_mat((plane_coef_point * line_coef_point).T) 400 | 401 | 402 | def raw2outputs(self, raw, z_vals, rays_d, raw_noise_std=0, is_train=False): 403 | """Transforms model's predictions to semantically meaningful values. 404 | Args: 405 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 406 | z_vals: [num_rays, num_samples along ray]. Integration time. 407 | rays_d: [num_rays, 3]. Direction of each ray. 408 | Returns: 409 | feature_map: [num_rays, 3]. Estimated feature sum of a ray. 410 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 411 | acc_map: [num_rays]. Sum of weights along each ray. 412 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 413 | depth_map: [num_rays]. Estimated distance to object. 414 | """ 415 | 416 | def raw2alpha(raw_, dists_, act_fn): 417 | alpha_ = - torch.exp(-act_fn(raw_) * dists_) + 1. 
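            # i.e. alpha_i = 1 - exp(-sigma_i * delta_i): the opacity contributed by each sample
            # interval under the standard volume-rendering discretization, with sigma = act_fn(raw_)
            # and delta the direction-scaled spacing between consecutive z_vals.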
418 | return torch.cat([alpha_, torch.ones_like(alpha_[:, 0:1])], dim=-1) 419 | 420 | dists = z_vals[..., 1:] - z_vals[..., :-1] # [N_rays, N_samples - 1] 421 | # dists = torch.cat([dists, torch.tensor([1e10]).expand(dists[..., :1].shape)], -1) 422 | 423 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 424 | 425 | feature = torch.relu(raw[..., 1:]) 426 | noise = 0. 427 | if raw_noise_std > 0.: 428 | noise = torch.randn_like(raw[..., :-1, 0]) * raw_noise_std 429 | 430 | 431 | density = torch.relu(raw[..., :-1, 0] + noise) 432 | # print(density.shape, raw.shape) 433 | if not is_train and self.render_rmnearplane > 0: 434 | mask = z_vals[:, 1:] 435 | mask = mask > self.render_rmnearplane / 128 436 | mask = mask.type_as(density) 437 | density = mask * density 438 | 439 | # print(density.shape, dists.shape) 440 | alpha = - torch.exp(- density * dists) + 1. 441 | alpha = torch.cat([alpha, torch.ones_like(alpha[:, 0:1])], dim=-1) 442 | 443 | # alpha = raw2alpha(raw[..., :-1, 3] + noise, dists, act_fn=self.sigma_activate) # [N_rays, N_samples] 444 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 445 | weights = alpha * \ 446 | torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), - alpha + (1. + 1e-10)], -1), -1)[:, :-1] 447 | 448 | feature_map = torch.sum(weights[..., None] * feature, -2) # [N_rays, 3] 449 | depth_map = torch.sum(weights * z_vals, -1) 450 | 451 | # disp_map = 1. / torch.clamp_min(depth_map, 1e-10) 452 | acc_map = torch.sum(weights, -1) 453 | 454 | return feature_map, density, acc_map, weights, depth_map#, sparsity_loss 455 | 456 | 457 | def sample(self, pts): 458 | xyz_sampled = (pts.reshape(-1,3)-self.aabb[0]) * self.invaabbSize - 1 459 | return self.compute_appfeature(xyz_sampled).reshape(pts.shape[0],pts.shape[1],-1) 460 | 461 | 462 | def forward(self, pts, viewdirs, fts, pts_embed, dirs_embed, z_vals, rays_d, raw_noise_std, is_train): 463 | 464 | 465 | # time1 = time.time() 466 | input_locs = torch.reshape(pts, [-1, pts.shape[-1]]) 467 | input_locs = pts_embed(input_locs) 468 | 469 | # xyz_sampled = (pts.reshape(-1,3)-self.aabb[0]) * self.invaabbSize - 1 470 | # input_pts = self.compute_appfeature(xyz_sampled) 471 | # input_dirs = viewdirs[:, None].expand(pts.shape) 472 | input_dirs = torch.reshape(viewdirs, [-1, viewdirs.shape[-1]]) 473 | input_dirs = dirs_embed(input_dirs) 474 | 475 | # time2 = time.time() 476 | 477 | # sigma 478 | h = torch.cat([fts.view(pts.shape[0]*pts.shape[1],-1),input_locs],-1) 479 | # h = input_pts 480 | for l in range(self.num_layers): 481 | h = self.sigma_net[l](h) 482 | if l != self.num_layers - 1: 483 | h = F.relu(h, inplace=True) 484 | 485 | h = h.reshape(pts.shape[0], pts.shape[1], -1) 486 | 487 | # time3 = time.time() 488 | feature_map, density_map, acc_map, weights, depth_map = self.raw2outputs(h, z_vals, rays_d, raw_noise_std, is_train=is_train) 489 | 490 | # time4 = time.time() 491 | # color 492 | h = torch.cat([feature_map,input_dirs],-1) 493 | for l in range(self.num_layers_color): 494 | h = self.color_net[l](h) 495 | if l != self.num_layers_color - 1: 496 | h = F.relu(h, inplace=True) 497 | 498 | color = torch.sigmoid(h) 499 | return color, depth_map, acc_map, weights,feature_map 500 | 501 | 502 | class NeRFSmall_voxel(nn.Module): 503 | def __init__(self, 504 | aabb, 505 | num_layers=3, 506 | hidden_dim=64, 507 | geo_feat_dim=15, 508 | num_layers_color=4, 509 | hidden_dim_color=64, 510 | input_ch=3, input_ch_views=3, 511 | render_rmnearplane=0,app_dim=32, 512 | app_n_comp=[64,16,16], 
n_voxels=134217984): 513 | super(NeRFSmall_voxel, self).__init__() 514 | 515 | self.input_ch = input_ch 516 | self.input_ch_views = input_ch_views 517 | self.render_rmnearplane = render_rmnearplane 518 | # sigma network 519 | self.num_layers = num_layers 520 | self.hidden_dim = hidden_dim 521 | self.geo_feat_dim = geo_feat_dim 522 | 523 | sigma_net = [] 524 | for l in range(num_layers): 525 | if l == 0: 526 | in_dim = self.input_ch 527 | else: 528 | in_dim = hidden_dim 529 | 530 | if l == num_layers - 1: 531 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 self.geo_feat_dim features for color 532 | else: 533 | out_dim = hidden_dim 534 | 535 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 536 | 537 | self.sigma_net = nn.ModuleList(sigma_net) 538 | 539 | # color network 540 | self.num_layers_color = num_layers_color 541 | self.hidden_dim_color = hidden_dim_color 542 | 543 | color_net = [] 544 | for l in range(num_layers_color): 545 | if l == 0: 546 | in_dim = self.input_ch_views + self.geo_feat_dim 547 | else: 548 | in_dim = hidden_dim 549 | 550 | if l == num_layers_color - 1: 551 | out_dim = 3 # 3 rgb 552 | else: 553 | out_dim = hidden_dim 554 | 555 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 556 | 557 | self.color_net = nn.ModuleList(color_net) 558 | self.aabb = torch.stack(aabb).cuda() 559 | xyz_min, xyz_max = aabb 560 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / 3) 561 | gridSize = ((xyz_max - xyz_min) / voxel_size).long().tolist() 562 | self.aabbSize = self.aabb[1] - self.aabb[0] 563 | self.invaabbSize = 2.0/self.aabbSize 564 | self.gridSize= torch.LongTensor(gridSize)#.to(self.device) 565 | print('Fine Voxel GridSize', self.gridSize) 566 | 567 | 568 | self.app_dim = app_dim 569 | self.app_n_comp = app_n_comp#[48,12,12] 570 | # self.gridSize = [164*4, 167*4, 76*4] 571 | # self.aabbSize = self.aabb[1] - self.aabb[0] 572 | # self.invaabbSize = 2.0/self.aabbSize 573 | 574 | self.matMode = [[0,1], [0,2], [1,2]] 575 | self.vecMode = [2, 1, 0] 576 | self.comp_w = [1,1,1] 577 | self.reg = TVLoss() 578 | 579 | 580 | self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1) 581 | self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False)#.to(device) 582 | 583 | def init_one_svd(self, n_component, gridSize, scale): 584 | plane_coef, line_coef = [], [] 585 | for i in range(len(self.vecMode)): 586 | vec_id = self.vecMode[i] 587 | mat_id_0, mat_id_1 = self.matMode[i] 588 | plane_coef.append(torch.nn.Parameter( 589 | scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) # 590 | line_coef.append( 591 | torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) 592 | 593 | # return plane_coef, line_coef 594 | return torch.nn.ParameterList(plane_coef), torch.nn.ParameterList(line_coef) 595 | 596 | 597 | 598 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): 599 | grad_vars_vol = list(self.app_line)+ list(self.app_plane) 600 | grad_vars_net = list(self.basis_mat.parameters())+list(self.color_net.parameters())+list(self.sigma_net.parameters()) 601 | return grad_vars_vol, grad_vars_net 602 | 603 | 604 | 605 | 606 | def TV_loss_app(self): 607 | total = 0 608 | for idx in range(len(self.app_plane)): 609 | total = total + self.reg(self.app_plane[idx]) * 1e-2 + self.reg(self.app_line[idx]) * 1e-3 610 | return total 611 | 612 | 613 | 614 | def compute_appfeature(self, xyz_sampled): 615 | 616 | # plane + line basis 617 | 
coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).view(3, -1, 1, 2) 618 | coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 619 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).view(3, -1, 1, 2) 620 | 621 | plane_coef_point,line_coef_point = [],[] 622 | for idx_plane in range(len(self.app_plane)): 623 | plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]], 624 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 625 | line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]], 626 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 627 | plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point) 628 | 629 | 630 | return self.basis_mat((plane_coef_point * line_coef_point).T) 631 | 632 | 633 | 634 | 635 | def raw2outputs(self, raw, z_vals, rays_d, raw_noise_std=0, is_train=False): 636 | """Transforms model's predictions to semantically meaningful values. 637 | Args: 638 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 639 | z_vals: [num_rays, num_samples along ray]. Integration time. 640 | rays_d: [num_rays, 3]. Direction of each ray. 641 | Returns: 642 | feature_map: [num_rays, 3]. Estimated feature sum of a ray. 643 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 644 | acc_map: [num_rays]. Sum of weights along each ray. 645 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 646 | depth_map: [num_rays]. Estimated distance to object. 647 | """ 648 | 649 | def raw2alpha(raw_, dists_, act_fn): 650 | alpha_ = - torch.exp(-act_fn(raw_) * dists_) + 1. 651 | return torch.cat([alpha_, torch.ones_like(alpha_[:, 0:1])], dim=-1) 652 | 653 | dists = z_vals[..., 1:] - z_vals[..., :-1] # [N_rays, N_samples - 1] 654 | # dists = torch.cat([dists, torch.tensor([1e10]).expand(dists[..., :1].shape)], -1) 655 | 656 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 657 | 658 | # feature = torch.relu(raw[..., 1:]) 659 | noise = 0. 660 | if raw_noise_std > 0.: 661 | noise = torch.randn_like(raw[..., :-1, 0]) * raw_noise_std 662 | 663 | 664 | density = torch.relu(raw[..., :-1, 0] + noise) 665 | # print(density.shape, raw.shape) 666 | if not is_train and self.render_rmnearplane > 0: 667 | mask = z_vals[:, 1:] 668 | mask = mask > self.render_rmnearplane / 128 669 | mask = mask.type_as(density) 670 | density = mask * density 671 | 672 | # print(density.shape, dists.shape) 673 | alpha = - torch.exp(- density * dists) + 1. 674 | alpha = torch.cat([alpha, torch.ones_like(alpha[:, 0:1])], dim=-1) 675 | 676 | # alpha = raw2alpha(raw[..., :-1, 3] + noise, dists, act_fn=self.sigma_activate) # [N_rays, N_samples] 677 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 678 | weights = alpha * \ 679 | torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), - alpha + (1. + 1e-10)], -1), -1)[:, :-1] 680 | 681 | rgb_map = torch.sum(weights[..., None] * raw[..., 1:], -2) # [N_rays, 3] 682 | depth_map = torch.sum(weights * z_vals, -1) 683 | 684 | # disp_map = 1. 
/ torch.clamp_min(depth_map, 1e-10) 685 | acc_map = torch.sum(weights, -1) 686 | 687 | 688 | # mask = weights.sum(-1) > 0.5 689 | # entropy = Categorical(probs = weights+1e-5).entropy() 690 | # sparsity_loss = entropy * mask 691 | 692 | return rgb_map, density, acc_map, weights, depth_map#, sparsity_loss 693 | 694 | def sample(self, pts): 695 | xyz_sampled = (pts.reshape(-1,3)-self.aabb[0]) * self.invaabbSize - 1 696 | return self.compute_appfeature(xyz_sampled).reshape(pts.shape[0],pts.shape[1],-1) 697 | 698 | 699 | def forward(self, pts, viewdirs, fts, pts_embed, dirs_embed, z_vals, rays_d, raw_noise_std, is_train): 700 | 701 | 702 | # time1 = time.time() 703 | input_locs = torch.reshape(pts, [-1, pts.shape[-1]]) 704 | input_locs = pts_embed(input_locs) 705 | input_dirs = viewdirs[:, None].expand(pts.shape) 706 | input_dirs = torch.reshape(input_dirs, [-1, viewdirs.shape[-1]]) 707 | input_dirs = dirs_embed(input_dirs) 708 | 709 | # time2 = time.time() 710 | 711 | # sigma 712 | h = torch.cat([fts.reshape(pts.shape[0]*pts.shape[1],-1),input_locs],-1) 713 | # h = input_pts 714 | for l in range(self.num_layers): 715 | h = self.sigma_net[l](h) 716 | if l != self.num_layers - 1: 717 | h = F.relu(h, inplace=True) 718 | 719 | sigma = h[...,[0]].reshape(pts.shape[0], pts.shape[1], -1) 720 | # color = torch.zeros((*pts.shape[:2], 3), device=pts.device) 721 | h = torch.cat([h[...,1:],input_dirs],-1)#.reshape(pts.shape[0], pts.shape[1], -1) 722 | for l in range(self.num_layers_color): 723 | h = self.color_net[l](h) 724 | if l != self.num_layers_color - 1: 725 | h = F.relu(h, inplace=True) 726 | 727 | color = torch.sigmoid(h).reshape(pts.shape[0], pts.shape[1], -1) 728 | 729 | color, density_map, acc_map, weights, depth_map = self.raw2outputs(torch.cat([sigma,color],-1), z_vals, rays_d, raw_noise_std, is_train=is_train) 730 | 731 | # time5 = time.time() 732 | 733 | # print(f"Time| embed: {time2-time1:.5f}, sigma: {time3-time2:.5f} raw2output: {time4-time3:.5f}, color: {time5-time4:.5f}") 734 | return color, depth_map, acc_map, weights 735 | 736 | 737 | # Ray helpers 738 | def get_rays(H, W, K, c2w): 739 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), 740 | torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij' 741 | i = i.t() 742 | j = j.t() 743 | dirs = torch.stack([(i + (HALF_PIX - K[0][2])) / K[0][0], -(j + (HALF_PIX - K[1][2])) / K[1][1], -torch.ones_like(i)], -1) 744 | # Rotate ray directions from camera frame to the world frame 745 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 746 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 747 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 748 | rays_o = c2w[:3, -1].expand(rays_d.shape) 749 | return rays_o, rays_d 750 | 751 | 752 | def get_rays_np(H, W, K, c2w): 753 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 754 | dirs = np.stack([(i + (HALF_PIX - K[0][2])) / K[0][0], -(j + (HALF_PIX - K[1][2])) / K[1][1], -np.ones_like(i)], -1) 755 | # Rotate ray directions from camera frame to the world frame 756 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 757 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 758 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
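    # All rays share the camera center, so the translation column of c2w is simply broadcast to the full (H, W, 3) grid below.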
759 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) 760 | return rays_o, rays_d 761 | 762 | 763 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 764 | """ 765 | See Paper supplementary for details 766 | """ 767 | # Shift ray origins to near plane 768 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 769 | rays_o = rays_o + t[..., None] * rays_d 770 | 771 | # Projection 772 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 773 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 774 | o2 = 1. + 2. * near / rays_o[..., 2] 775 | 776 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 777 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 778 | d2 = -2. * near / rays_o[..., 2] 779 | 780 | rays_o = torch.stack([o0, o1, o2], -1) 781 | rays_d = torch.stack([d0, d1, d2], -1) 782 | 783 | return rays_o, rays_d 784 | 785 | 786 | # Hierarchical sampling (section 5.2) 787 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 788 | # Get pdf 789 | weights = weights + 1e-5 # prevent nans 790 | pdf = weights / torch.sum(weights, -1, keepdim=True) 791 | cdf = torch.cumsum(pdf, -1) 792 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 793 | 794 | # Take uniform samples 795 | if det: 796 | u = torch.linspace(0., 1., steps=N_samples) 797 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 798 | else: 799 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 800 | 801 | # Pytest, overwrite u with numpy's fixed random numbers 802 | # if pytest: 803 | # np.random.seed(0) 804 | # new_shape = list(cdf.shape[:-1]) + [N_samples] 805 | # if det: 806 | # u = np.linspace(0., 1., N_samples) 807 | # u = np.broadcast_to(u, new_shape) 808 | # else: 809 | # u = np.random.rand(*new_shape) 810 | # u = torch.Tensor(u) 811 | 812 | # Invert CDF 813 | u = u.contiguous() 814 | inds = torch.searchsorted(cdf, u, right=True) 815 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 816 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 817 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 818 | 819 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 820 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 821 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 822 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 823 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 824 | 825 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 826 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 827 | t = (u - cdf_g[..., 0]) / denom 828 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 829 | 830 | return samples 831 | 832 | 833 | def smart_load_state_dict(model: nn.Module, state_dict: dict): 834 | if "network_fn_state_dict" in state_dict.keys(): 835 | state_dict_fn = {k.lstrip("module."): v for k, v in state_dict["network_fn_state_dict"].items()} 836 | state_dict_fn = {"mlp_coarse." + k: v for k, v in state_dict_fn.items()} 837 | 838 | state_dict_fine = {k.lstrip("module."): v for k, v in state_dict["network_fine_state_dict"].items()} 839 | state_dict_fine = {"mlp_fine." 
+ k: v for k, v in state_dict_fine.items()} 840 | state_dict_fn.update(state_dict_fine) 841 | state_dict = state_dict_fn 842 | # elif "network_state_dict" in state_dict.keys(): 843 | # state_dict = {k[7:]: v for k, v in state_dict["network_state_dict"].items()} 844 | else: 845 | state_dict = state_dict 846 | 847 | # if isinstance(model, nn.DataParallel): 848 | # state_dict = {"module." + k: v for k, v in state_dict.items()} 849 | model.load_state_dict(state_dict["network_state_dict"]) 850 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pdb 4 | import torch 5 | 6 | from ray_utils import get_rays, get_ray_directions, get_ndc_rays 7 | # import random 8 | # import numpy as np 9 | # seed = 0 10 | # random.seed(seed) 11 | # np.random.seed(seed) 12 | # torch.manual_seed(seed) 13 | # torch.cuda.manual_seed_all(seed) 14 | 15 | BOX_OFFSETS = torch.tensor([[[i,j,k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]], 16 | device='cuda') 17 | 18 | 19 | def hash(coords, log2_hashmap_size): 20 | ''' 21 | coords: this function can process upto 7 dim coordinates 22 | log2T: logarithm of T w.r.t 2 23 | ''' 24 | primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737] 25 | 26 | xor_result = torch.zeros_like(coords)[..., 0] 27 | for i in range(coords.shape[-1]): 28 | xor_result ^= coords[..., i]*primes[i] 29 | 30 | return torch.tensor((1< pt[i]): 52 | min_bound[i] = pt[i] 53 | if(max_bound[i] < pt[i]): 54 | max_bound[i] = pt[i] 55 | return 56 | 57 | for i in [0, W-1, H*W-W, H*W-1]: 58 | min_point = rays_o[i] + near*rays_d[i] 59 | max_point = rays_o[i] + far*rays_d[i] 60 | points += [min_point, max_point] 61 | find_min_max(min_point) 62 | find_min_max(max_point) 63 | 64 | return (torch.tensor(min_bound)-torch.tensor([1.0,1.0,1.0]), torch.tensor(max_bound)+torch.tensor([1.0,1.0,1.0])) 65 | 66 | 67 | def get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0): 68 | H, W, focal = hwf 69 | H, W = int(H), int(W) 70 | 71 | # ray directions in camera coordinates 72 | directions = get_ray_directions(H, W, focal) 73 | 74 | min_bound = [100, 100, 100] 75 | max_bound = [-100, -100, -100] 76 | 77 | points = [] 78 | poses = torch.FloatTensor(poses) 79 | for pose in poses: 80 | rays_o, rays_d = get_rays(directions, pose) 81 | rays_o, rays_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d) 82 | 83 | def find_min_max(pt): 84 | for i in range(3): 85 | if(min_bound[i] > pt[i]): 86 | min_bound[i] = pt[i] 87 | if(max_bound[i] < pt[i]): 88 | max_bound[i] = pt[i] 89 | return 90 | 91 | for i in [0, W-1, H*W-W, H*W-1]: 92 | min_point = rays_o[i] + near*rays_d[i] 93 | max_point = rays_o[i] + far*rays_d[i] 94 | points += [min_point, max_point] 95 | find_min_max(min_point) 96 | find_min_max(max_point) 97 | 98 | return (torch.tensor(min_bound)-torch.tensor([0.01,0.01,0.0001]), torch.tensor(max_bound)+torch.tensor([0.01,0.01,0.0001])) 99 | 100 | 101 | def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size): 102 | ''' 103 | xyz: 3D coordinates of samples. B x 3 104 | bounding_box: min and max x,y,z coordinates of object bbox 105 | resolution: number of voxels per axis 106 | ''' 107 | box_min, box_max = bounding_box 108 | 109 | if not torch.all(xyz <= box_max) or not torch.all(xyz >= box_min): 110 | # print("ALERT: some points are outside bounding box. 
Clipping them!") 111 | print("ALERT: some points are outside bounding box. Clipping them!") 112 | xyz = torch.clamp(xyz, min=box_min, max=box_max) 113 | 114 | grid_size = (box_max-box_min)/resolution 115 | 116 | bottom_left_idx = torch.floor((xyz-box_min)/grid_size).int() 117 | voxel_min_vertex = bottom_left_idx*grid_size + box_min 118 | voxel_max_vertex = voxel_min_vertex + torch.tensor([1.0,1.0,1.0])*grid_size 119 | 120 | # hashed_voxel_indices = [] # B x 8 ... 000,001,010,011,100,101,110,111 121 | # for i in [0, 1]: 122 | # for j in [0, 1]: 123 | # for k in [0, 1]: 124 | # vertex_idx = bottom_left_idx + torch.tensor([i,j,k]) 125 | # # vertex = bottom_left + torch.tensor([i,j,k])*grid_size 126 | # hashed_voxel_indices.append(hash(vertex_idx, log2_hashmap_size)) 127 | 128 | voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS 129 | hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size) 130 | 131 | return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices 132 | 133 | 134 | 135 | if __name__=="__main__": 136 | with open("data/nerf_synthetic/chair/transforms_train.json", "r") as f: 137 | camera_transforms = json.load(f) 138 | 139 | bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800) 140 | --------------------------------------------------------------------------------
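For reference, a minimal sketch (not part of this repository) of how the helpers in utils.py are typically consumed: the eight hashed corner indices returned by get_voxel_vertices select rows of a learned embedding table, which are then blended by trilinear interpolation. The table, its feature width, resolution and log2_hashmap_size below are illustrative assumptions, not values taken from the configs.

import torch
import torch.nn as nn
from utils import get_voxel_vertices, BOX_OFFSETS  # helpers defined above

log2_hashmap_size = 19                              # assumed: 2**19 table entries
resolution = 64                                     # assumed: voxels per axis at this level
table = nn.Embedding(2 ** log2_hashmap_size, 2)     # assumed: 2 features per hashed vertex

def hashed_trilinear_features(xyz, bounding_box):
    # Hashed indices of the 8 voxel corners enclosing each point: (B, 8)
    vmin, vmax, idx = get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size)
    corner_feats = table(idx.long())                # (B, 8, 2) learned corner features

    # Fractional position of each point inside its voxel, in [0, 1]^3
    w = (xyz - vmin) / (vmax - vmin)                # (B, 3)

    # Corner order matches BOX_OFFSETS: offset 1 along an axis weights by w, offset 0 by (1 - w)
    offs = BOX_OFFSETS.squeeze(0).to(w.dtype)       # (8, 3)
    corner_w = torch.prod(torch.where(offs[None] > 0, w[:, None, :], 1.0 - w[:, None, :]), dim=-1)  # (B, 8)

    # Trilinear blend of the 8 hashed corner features
    return (corner_w[..., None] * corner_feats).sum(dim=1)   # (B, 2)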
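Likewise, a minimal sketch of how sample_pdf from run_nerf_helpers.py is normally used for hierarchical sampling: the coarse pass's weights define a piecewise-constant PDF over the depth bins, and extra fine samples are drawn where that PDF is large. The names z_vals, weights, rays_o, rays_d, N_importance and is_train follow the original NeRF convention and are assumptions here, not identifiers from run_nerf.py.

# Bin edges for the PDF are the midpoints between consecutive coarse depths.
z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
# Interior weights only, matching the (num_bins - 1) intervals sample_pdf expects.
z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=not is_train)
z_samples = z_samples.detach()  # no gradients through the resampling itself

# Merge coarse and fine depths, then evaluate the fine model at the new points.
z_vals_fine, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_fine[..., :, None]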