├── README.md
├── confs
│   └── womask.conf
├── exp_runner.py
├── models
│   ├── __pycache__
│   │   ├── dataset.cpython-38.pyc
│   │   ├── embedder.cpython-38.pyc
│   │   ├── eval.cpython-38.pyc
│   │   ├── fields.cpython-38.pyc
│   │   ├── renderer.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── dataset.py
│   ├── embedder.py
│   ├── eval.py
│   ├── fields.py
│   ├── renderer.py
│   └── utils.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Ref-NeuS: Ambiguity-Reduced Neural Implicit Surface Learning for Multi-View Reconstruction with Reflection (ICCV, Oral, Best Paper Nomination, Top 0.2%)
4 | 
5 | ## [Project Page](https://g3956.github.io/) | [Paper](https://arxiv.org/pdf/2303.10840.pdf)
6 | 
7 | This is the official repository for the implementation of [Ref-NeuS: Ambiguity-Reduced Neural Implicit Surface Learning for Multi-View Reconstruction with Reflection](https://arxiv.org/pdf/2303.10840.pdf) by Wenhang Ge, Tao Hu, Haoyu Zhao, Shu Liu, and Ying-Cong Chen.
8 | 
9 | ## Setup
10 | 
11 | Installation
12 | 
13 | This code is built with PyTorch 1.11.0. See ```requirements.txt``` for the required Python packages.
14 | 
15 | You can create an Anaconda environment called ```refneus``` with the required dependencies by running:
16 | 
17 | ```
18 | conda create -n refneus python=3.7
19 | conda activate refneus
20 | conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=11.3 -c pytorch
21 | pip install -r requirements.txt
22 | ```
23 | 
24 | ## Data
25 | 
26 | Download the [ShinyBlender](https://storage.googleapis.com/gresearch/refraw360/ref.zip) dataset.
27 | 
28 | Download the GT dense point clouds for evaluation from [Google Drive](https://drive.google.com/file/d/1HGTD3uQUr8WrzRYZBagrg75_rQJmAK6S/view?usp=sharing).
29 | 
30 | Make sure the data is organized as follows (the object ```helmet``` is shown here):
31 | ```
32 | +-- ShinyBlender
33 | |   +-- helmet
34 | |   |   +-- test
35 | |   |   +-- train
36 | |   |   +-- dense_pcd.ply
37 | |   |   +-- points_of_interest.ply
38 | |   |   +-- test_info.json
39 | |   |   +-- transforms_test.json
40 | |   |   +-- transforms_train.json
41 | |   +-- toaster
42 | ```
43 | 
44 | ## Evaluation with pretrained model
45 | 
46 | Download the pretrained models: [Pretrained Models for reconstruction evaluation](https://drive.google.com/file/d/17A0x04nyRc9QLd31R57tWz1tcn159vr2/view?usp=sharing) and
47 | [Pretrained Models for PSNR evaluation](https://drive.google.com/file/d/1wqFJBv3hAHbBTM49yQZ_Gctm2CV_QVrr/view?usp=sharing).
48 | 
49 | Run the evaluation script with
50 | 
51 | ```python exp_runner.py --mode validate_mesh --conf ./confs/womask.conf --ckpt_path ckpt_path```
52 | 
53 | ```ckpt_path``` is the path to the pretrained model.
54 | 
55 | Make sure that ```data_dir``` in the configuration file ```./confs/womask.conf``` points to the same object as the pretrained model.
56 | 
57 | The output mesh will be in ```base_exp_dir/meshes```. You can specify ```base_exp_dir``` in the configuration file.
58 | 
59 | The evaluation metrics will be written to ```base_exp_dir/result.txt```.
60 | 
61 | The error visualization is saved to ```base_exp_dir/vis_d2s.ply```; points with large errors are marked in red.
62 | 
63 | You can also download our final mesh results [here](https://drive.google.com/file/d/1r1G4Lu3U2017PHgIImx7WXm_ERSfKaHv/view?usp=sharing).
64 | 
65 | We also provide a function that renders videos of surface normals and novel view synthesis. Run it with
66 | 
67 | ```python exp_runner.py --mode visualize_video --conf ./confs/womask.conf --ckpt_path ckpt_path```
68 | 
69 | The output videos will be written to ```base_exp_dir/normals.mp4``` and ```base_exp_dir/video.mp4```.
70 | 
71 | ## Train a model from scratch
72 | 
73 | Run the training script with
74 | 
75 | ```python exp_runner.py --mode train --conf ./confs/womask.conf```
76 | 
77 | ## Citation
78 | 
79 | 
80 | If you find our work useful in your research, please consider citing:
81 | 
82 | ```
83 | @article{ge2023ref,
84 |   title={Ref-NeuS: Ambiguity-Reduced Neural Implicit Surface Learning for Multi-View Reconstruction with Reflection},
85 |   author={Ge, Wenhang and Hu, Tao and Zhao, Haoyu and Liu, Shu and Chen, Ying-Cong},
86 |   journal={arXiv preprint arXiv:2303.10840},
87 |   year={2023}
88 | }
89 | ```
90 | 
91 | 
92 | ## Acknowledgments
93 | 
94 | Our code is partially based on the [NeuS](https://github.com/Totoro97/NeuS) project, and some code snippets are borrowed from [NeuralWarp](https://github.com/fdarmon/NeuralWarp). Thanks to the authors of these great projects.
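For reference, the ```validate_mesh``` evaluation described above can also be driven from Python. The sketch below is a minimal example built on the ```Runner``` class in ```exp_runner.py```; it assumes a CUDA device, and the checkpoint path is only a placeholder.

```
# Minimal sketch: programmatic mesh extraction from a pretrained checkpoint.
# Assumes a CUDA device; the checkpoint path below is a placeholder.
import torch
from exp_runner import Runner

torch.set_default_tensor_type('torch.cuda.FloatTensor')

runner = Runner('./confs/womask.conf', mode='validate_mesh')
# Loads the checkpoint, runs marching cubes on the learned SDF, and writes the
# extracted mesh(es) to base_exp_dir/meshes (see ./confs/womask.conf).
runner.validate_mesh(resolution=512, threshold=0.0,
                     ckpt_path='./pretrained/helmet.pth')  # placeholder path
```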
95 | 96 | 97 | -------------------------------------------------------------------------------- /confs/womask.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp/helmet/Ref_NeuS, 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./data/ShinyBlender/helmet, 11 | render_cameras_name = cameras.npz, 12 | object_cameras_name = cameras.npz 13 | } 14 | 15 | train { 16 | learning_rate = 5e-4, 17 | learning_rate_alpha = 0.05, 18 | end_iter = 200001, 19 | 20 | batch_size = 512, 21 | validate_resolution_level = 4, 22 | warm_up_end = 5000, 23 | anneal_end = 50000, 24 | use_white_bkgd = False, 25 | 26 | save_freq = 10000, 27 | val_freq = 2500, 28 | val_mesh_freq = 500, 29 | report_freq = 1000, 30 | 31 | igr_weight = 0.1, 32 | mask_weight = 0.0 33 | } 34 | 35 | model { 36 | nerf { 37 | D = 8, 38 | d_in = 4, 39 | d_in_view = 3, 40 | W = 256, 41 | multires = 10, 42 | multires_view = 4, 43 | output_ch = 4, 44 | skips=[4], 45 | use_viewdirs=True 46 | } 47 | 48 | sdf_network { 49 | d_out = 257, 50 | d_in = 3, 51 | d_hidden = 256, 52 | n_layers = 8, 53 | skip_in = [4], 54 | multires = 6, 55 | bias = 0.5, 56 | scale = 1.0, 57 | geometric_init = True, 58 | weight_norm = True 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | rendering_network { 66 | d_feature = 256, 67 | mode = idr, 68 | d_in = 9, 69 | d_out = 3, 70 | d_hidden = 256, 71 | n_layers = 4, 72 | weight_norm = True, 73 | multires_view = 4, 74 | squeeze_out = True 75 | } 76 | 77 | neus_renderer { 78 | n_samples = 64, 79 | n_importance = 64, 80 | n_outside = 32, 81 | up_sample_steps = 4, 82 | perturb = 1.0 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /exp_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import argparse 5 | import numpy as np 6 | import cv2 as cv 7 | import trimesh 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from shutil import copyfile 12 | from tqdm import tqdm 13 | from pyhocon import ConfigFactory 14 | from models.dataset import Dataset 15 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF 16 | from models.renderer import NeuSRenderer 17 | from models.eval import evaluation_shinyblender 18 | import open3d as o3d 19 | import json 20 | import torchvision 21 | import torch.nn as nn 22 | import math 23 | from models.utils import generate_spherical_cam_to_world 24 | import imageio 25 | 26 | class Runner: 27 | def __init__(self, conf_path, mode='train', is_continue=False): 28 | self.device = torch.device('cuda') 29 | 30 | # Configuration 31 | self.conf_path = conf_path 32 | f = open(self.conf_path) 33 | conf_text = f.read() 34 | f.close() 35 | 36 | self.conf = ConfigFactory.parse_string(conf_text) 37 | self.base_exp_dir = self.conf['general.base_exp_dir'] 38 | os.makedirs(self.base_exp_dir, exist_ok=True) 39 | self.dataset = Dataset(self.conf['dataset']) 40 | self.iter_step = 0 41 | 42 | # Training parameters 43 | self.end_iter = self.conf.get_int('train.end_iter') 44 | self.save_freq = self.conf.get_int('train.save_freq') 45 | self.report_freq = self.conf.get_int('train.report_freq') 46 | self.val_freq = self.conf.get_int('train.val_freq') 47 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') 48 | self.batch_size = 
self.conf.get_int('train.batch_size') 49 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') 50 | self.learning_rate = self.conf.get_float('train.learning_rate') 51 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') 52 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') 53 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) 54 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) 55 | 56 | # Weights 57 | self.igr_weight = self.conf.get_float('train.igr_weight') 58 | self.mask_weight = self.conf.get_float('train.mask_weight') 59 | self.is_continue = is_continue 60 | self.mode = mode 61 | self.model_list = [] 62 | self.writer = None 63 | 64 | # Networks 65 | params_to_train = [] 66 | self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device) 67 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) 68 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) 69 | self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device) 70 | params_to_train += list(self.nerf_outside.parameters()) 71 | params_to_train += list(self.sdf_network.parameters()) 72 | params_to_train += list(self.deviation_network.parameters()) 73 | params_to_train += list(self.color_network.parameters()) 74 | 75 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate) 76 | 77 | self.renderer = NeuSRenderer(self.nerf_outside, 78 | self.sdf_network, 79 | self.deviation_network, 80 | self.color_network, 81 | **self.conf['model.neus_renderer']) 82 | 83 | # Intermediate Mesh 84 | self.scene = None 85 | 86 | # Load checkpoint 87 | latest_model_name = None 88 | if is_continue: 89 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) 90 | model_list = [] 91 | for model_name in model_list_raw: 92 | if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter: 93 | model_list.append(model_name) 94 | model_list.sort() 95 | latest_model_name = model_list[-1] 96 | 97 | if latest_model_name is not None: 98 | logging.info('Find checkpoint: {}'.format(latest_model_name)) 99 | self.load_checkpoint(latest_model_name) 100 | 101 | # Backup codes and configs for debug 102 | if self.mode[:5] == 'train': 103 | self.file_backup() 104 | 105 | def train(self): 106 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) 107 | self.update_learning_rate() 108 | res_step = self.end_iter - self.iter_step 109 | image_perm = self.get_image_perm() 110 | 111 | for iter_i in tqdm(range(res_step)): 112 | img_idx = image_perm[self.iter_step % len(image_perm)] 113 | data, uv = self.dataset.gen_random_rays_at(img_idx, self.batch_size) 114 | 115 | rays_o, rays_d, true_rgb, mask = data[:, :3], data[:, 3: 6], data[:, 6: 9], data[:, 9: 10] 116 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) 117 | 118 | background_rgb = None 119 | if self.use_white_bkgd: 120 | background_rgb = torch.ones([1, 3]) 121 | 122 | if self.mask_weight > 0.0: 123 | mask = (mask > 0.5).float() 124 | else: 125 | mask = torch.ones_like(mask) 126 | 127 | mask_sum = mask.sum() + 1e-5 128 | 129 | if self.iter_step % self.val_mesh_freq == 0: 130 | self.scene = self.validate_mesh(resolution=128) 131 | 132 | if self.iter_step % self.val_freq == 0: 133 | self.validate_image() 134 | 135 | render_out = self.renderer.render(rays_o, rays_d, near, far, img_idx, uv, self.dataset, self.scene, 
136 | background_rgb=background_rgb, 137 | cos_anneal_ratio=self.get_cos_anneal_ratio()) 138 | 139 | color_fine = render_out['color_fine'] 140 | s_val = render_out['s_val'] 141 | cdf_fine = render_out['cdf_fine'] 142 | gradient_error = render_out['gradient_error'] 143 | weight_max = render_out['weight_max'] 144 | weight_sum = render_out['weight_sum'] 145 | RS = render_out['RS'].unsqueeze(-1) 146 | 147 | # Loss 148 | color_error = (color_fine - true_rgb) * mask 149 | color_fine_loss = (F.l1_loss(color_error, torch.zeros_like(color_error), reduction='none') / RS).sum() / mask_sum 150 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt()) 151 | 152 | eikonal_loss = gradient_error 153 | 154 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask) 155 | 156 | loss = color_fine_loss +\ 157 | eikonal_loss * self.igr_weight +\ 158 | mask_loss * self.mask_weight 159 | 160 | self.optimizer.zero_grad() 161 | loss.backward() 162 | self.optimizer.step() 163 | 164 | self.iter_step += 1 165 | 166 | self.writer.add_scalar('Loss/loss', loss, self.iter_step) 167 | self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step) 168 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step) 169 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step) 170 | self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step) 171 | self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step) 172 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step) 173 | 174 | if self.iter_step % self.report_freq == 0: 175 | print('iter:{:8>d} loss = {} color_loss={} eikonal_loss={} psnr={} lr={}'.format( 176 | self.iter_step, loss, color_fine_loss, eikonal_loss, 177 | psnr, self.optimizer.param_groups[0]['lr'])) 178 | 179 | if self.iter_step % self.save_freq == 0: 180 | self.save_checkpoint() 181 | 182 | self.update_learning_rate() 183 | 184 | if self.iter_step % len(image_perm) == 0: 185 | image_perm = self.get_image_perm() 186 | 187 | def get_image_perm(self): 188 | return torch.randperm(self.dataset.n_images) 189 | 190 | def get_cos_anneal_ratio(self): 191 | if self.anneal_end == 0.0: 192 | return 1.0 193 | else: 194 | return np.min([1.0, self.iter_step / self.anneal_end]) 195 | 196 | def update_learning_rate(self): 197 | if self.iter_step < self.warm_up_end: 198 | learning_factor = self.iter_step / self.warm_up_end 199 | else: 200 | alpha = self.learning_rate_alpha 201 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end) 202 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha 203 | 204 | for g in self.optimizer.param_groups: 205 | g['lr'] = self.learning_rate * learning_factor 206 | 207 | def file_backup(self): 208 | dir_lis = self.conf['general.recording'] 209 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) 210 | for dir_name in dir_lis: 211 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) 212 | os.makedirs(cur_dir, exist_ok=True) 213 | files = os.listdir(dir_name) 214 | for f_name in files: 215 | if f_name[-3:] == '.py': 216 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 217 | 218 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 219 | 220 | def load_checkpoint(self, checkpoint_name): 221 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 
'checkpoints', checkpoint_name), map_location=self.device) 222 | self.nerf_outside.load_state_dict(checkpoint['nerf']) 223 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine']) 224 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine']) 225 | self.color_network.load_state_dict(checkpoint['color_network_fine']) 226 | self.optimizer.load_state_dict(checkpoint['optimizer']) 227 | self.iter_step = checkpoint['iter_step'] 228 | 229 | logging.info('End') 230 | 231 | def load_ckpt_validation(self, ckpt_path): 232 | checkpoint = torch.load(ckpt_path, map_location=self.device) 233 | self.nerf_outside.load_state_dict(checkpoint['nerf']) 234 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine']) 235 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine']) 236 | self.color_network.load_state_dict(checkpoint['color_network_fine']) 237 | self.optimizer.load_state_dict(checkpoint['optimizer']) 238 | self.iter_step = checkpoint['iter_step'] 239 | 240 | logging.info('End') 241 | 242 | 243 | def save_checkpoint(self): 244 | checkpoint = { 245 | 'nerf': self.nerf_outside.state_dict(), 246 | 'sdf_network_fine': self.sdf_network.state_dict(), 247 | 'variance_network_fine': self.deviation_network.state_dict(), 248 | 'color_network_fine': self.color_network.state_dict(), 249 | 'optimizer': self.optimizer.state_dict(), 250 | 'iter_step': self.iter_step, 251 | } 252 | 253 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) 254 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 255 | 256 | def validate_image(self, idx=-1, resolution_level=-1, only_normals=False, pose=None): 257 | if idx < 0: 258 | idx = np.random.randint(self.dataset.n_images) 259 | 260 | print('Validate: camera: {}'.format(idx)) 261 | 262 | if resolution_level < 0: 263 | resolution_level = self.validate_resolution_level 264 | if pose is None: 265 | rays_o, rays_d, uv = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) 266 | else: 267 | rays_o, rays_d, uv = self.dataset.gen_rays_visu(idx, pose, resolution_level=resolution_level) 268 | H, W, _ = rays_o.shape 269 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 270 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 271 | uv = uv.reshape(-1, 2).split(self.batch_size) 272 | 273 | out_rgb_fine = [] 274 | out_normal_fine = [] 275 | RS_fine = [] 276 | 277 | for rays_o_batch, rays_d_batch, uv_batch in zip(rays_o, rays_d, uv): 278 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 279 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 280 | 281 | render_out = self.renderer.render(rays_o_batch, 282 | rays_d_batch, 283 | near, 284 | far, 285 | idx, 286 | uv_batch, 287 | self.dataset, 288 | self.scene, 289 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 290 | background_rgb=background_rgb) 291 | 292 | def feasible(key): return (key in render_out) and (render_out[key] is not None) 293 | 294 | if feasible('color_fine'): 295 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 296 | if feasible('gradients') and feasible('weights'): 297 | normals = render_out['normal_map'] 298 | out_normal_fine.append(normals) 299 | if feasible('RS'): 300 | RS = render_out['RS'][:, None].expand(render_out['RS'].shape[0], 3).cpu().numpy() 301 | RS_fine.append(1. 
/ RS) 302 | del render_out 303 | 304 | normal_img = None 305 | if len(out_normal_fine) > 0: 306 | normal_img = torch.from_numpy(np.concatenate(out_normal_fine, axis=0)) / 2. + 0.5 307 | normal_img = normal_img.permute(1, 0).reshape([3, H, W]) 308 | 309 | os.makedirs(os.path.join(self.base_exp_dir, 'normals_all'), exist_ok=True) 310 | os.makedirs(os.path.join(self.base_exp_dir, 'test_images_all'), exist_ok=True) 311 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True) 312 | os.makedirs(os.path.join(self.base_exp_dir, 'RS_all'), exist_ok=True) 313 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) 314 | 315 | torchvision.utils.save_image(normal_img.clone(), os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_0_{}.png'.format(self.iter_step, idx)), nrow=8) 316 | img_fine = None 317 | if len(out_rgb_fine) > 0: 318 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 319 | if len(RS_fine) > 0: 320 | RS_fine = (np.concatenate(RS_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 321 | 322 | if only_normals: 323 | torchvision.utils.save_image(normal_img.clone(), os.path.join(self.base_exp_dir, 'normals_all', '{:0>8d}_0_{}.png'.format(self.iter_step, idx)), nrow=8) 324 | cv.imwrite(os.path.join(self.base_exp_dir, 'test_images_all', '{:0>8d}_0_{}.png'.format(self.iter_step, idx)), img_fine[..., 0]) 325 | normal_img[0,:,:], normal_img[2,:,:] = normal_img[2,:,:].clone(), normal_img[0,:,:].clone() 326 | cv.imwrite(os.path.join(self.base_exp_dir, 'RS_all', '{:0>8d}_0_{}.png'.format(self.iter_step, idx)), RS_fine[..., 0]) 327 | return normal_img.permute(1,2,0) * 256., torch.from_numpy(img_fine[..., 0] / 256.) 328 | 329 | 330 | 331 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) 332 | 333 | 334 | 335 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) 336 | 337 | for i in range(img_fine.shape[-1]): 338 | if len(out_rgb_fine) > 0: 339 | cv.imwrite(os.path.join(self.base_exp_dir, 340 | 'validations_fine', 341 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 342 | np.concatenate([img_fine[..., i], RS_fine[..., i], 343 | self.dataset.image_at(idx, resolution_level=resolution_level)])) 344 | if pose is not None: 345 | img_fine = torch.from_numpy(img_fine[..., 0]) 346 | img_fine[:,:,0], img_fine[:,:,2] = img_fine[:,:,2].clone(), img_fine[:,:,0].clone() 347 | return img_fine, normal_img.permute(1,2,0) 348 | 349 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level): 350 | """ 351 | Interpolate view between two cameras. 
352 | """ 353 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level) 354 | H, W, _ = rays_o.shape 355 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 356 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 357 | 358 | out_rgb_fine = [] 359 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 360 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 361 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 362 | 363 | render_out = self.renderer.render(rays_o_batch, 364 | rays_d_batch, 365 | near, 366 | far, 367 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 368 | background_rgb=background_rgb) 369 | 370 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 371 | 372 | del render_out 373 | 374 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8) 375 | return img_fine 376 | 377 | def validate_mesh(self, resolution=64, threshold=0.0, ckpt_path=None): 378 | if ckpt_path is not None: 379 | self.load_ckpt_validation(ckpt_path) 380 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 381 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 382 | 383 | vertices, triangles =\ 384 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold) 385 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) 386 | 387 | mesh = trimesh.Trimesh(vertices, triangles) 388 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', 'inter_mesh.ply')) 389 | 390 | # For visibility identification 391 | mesh_ = o3d.io.read_triangle_mesh(os.path.join(self.base_exp_dir, 'meshes', 'inter_mesh.ply')) 392 | mesh_ = o3d.t.geometry.TriangleMesh.from_legacy(mesh_) 393 | scene = o3d.t.geometry.RaycastingScene() 394 | cube_id = scene.add_triangles(mesh_) 395 | 396 | if self.iter_step == self.end_iter - 1: 397 | resolution = 512 398 | 399 | vertices, triangles =\ 400 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold) 401 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) 402 | 403 | mesh = trimesh.Trimesh(vertices, triangles) 404 | mesh.apply_transform(self.dataset.scale_mat) #transform to orignial space for evaluation 405 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}_eval.ply'.format(self.iter_step))) 406 | 407 | #self.validate_all_normals(ckpt_path) 408 | 409 | logging.info('End') 410 | return scene 411 | 412 | def interpolate_view(self, img_idx_0, img_idx_1): 413 | images = [] 414 | n_frames = 60 415 | for i in range(n_frames): 416 | print(i) 417 | images.append(self.render_novel_image(img_idx_0, 418 | img_idx_1, 419 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, 420 | resolution_level=4)) 421 | for i in range(n_frames): 422 | images.append(images[n_frames - i - 1]) 423 | 424 | fourcc = cv.VideoWriter_fourcc(*'mp4v') 425 | video_dir = os.path.join(self.base_exp_dir, 'render') 426 | os.makedirs(video_dir, exist_ok=True) 427 | h, w, _ = images[0].shape 428 | writer = cv.VideoWriter(os.path.join(video_dir, 429 | '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), 430 | fourcc, 30, (w, h)) 431 | 432 | for image in images: 433 | writer.write(image) 434 | 435 | writer.release() 436 | 437 | def validate_all_normals(self, ckpt_path): 438 | if ckpt_path is not None: 439 | self.load_ckpt_validation(ckpt_path) 440 | #self.scene = 
self.validate_mesh(self.result, resolution=128) 441 | total_MAE = 0 442 | total_PNSR = 0 443 | idxs = [i for i in range(self.dataset.n_images)] 444 | f = open(os.path.join(self.base_exp_dir, 'result_normal.txt'), 'a') 445 | for idx in idxs: 446 | normal_maps, color_fine = self.validate_image(idx, resolution_level=1, only_normals=True) 447 | try: 448 | GT_normal = torch.from_numpy(self.dataset.normal_np[idx]) 449 | GT_color = torch.from_numpy(self.dataset.images_np[idx]) 450 | PSNR = 20.0 * torch.log10(1.0 / ((color_fine - GT_color)**2).mean().sqrt()) 451 | total_PNSR += PSNR 452 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 453 | cos_loss = cos(normal_maps.view(-1, 3), GT_normal.view(-1, 3)) 454 | cos_loss = torch.clamp(cos_loss, (-1.0 + 1e-10), (1.0 - 1e-10)) 455 | loss_rad = torch.acos(cos_loss) 456 | loss_deg = loss_rad * (180.0 / math.pi) 457 | total_MAE += loss_deg.mean() 458 | f.write(str(idx) + '_MAE:') 459 | f.write(str(loss_deg.mean().data.item()) + ' ') 460 | f.write(str(idx) + '_psnr:') 461 | f.write(str(PSNR.data.item())) 462 | f.write('\n') 463 | f.flush() 464 | except: 465 | continue 466 | MAE = total_MAE / self.dataset.n_images 467 | PSNR = total_PNSR / self.dataset.n_images 468 | f.write('\n') 469 | f.write('MAE_final:') 470 | f.write(str(MAE.data.item()) + ' ') 471 | f.write('PSNR_final:') 472 | f.write(str(PSNR.data.item())) 473 | f.close() 474 | 475 | def visualize(self, ckpt_path): 476 | self.load_ckpt_validation(ckpt_path) 477 | rgb_frames = [] 478 | normal_frames = [] 479 | n_poses = 200 480 | pose = generate_spherical_cam_to_world(radius=3.5, n_poses=n_poses) 481 | pose = torch.from_numpy(pose).cuda() 482 | pose = torch.matmul(pose, torch.diag(torch.tensor([1., -1., -1., 1.]))) 483 | for i in range(n_poses): 484 | print('processing:' ,i) 485 | img, normal = self.validate_image(i, resolution_level=1, pose=pose) 486 | rgb_frames.append(img) 487 | normal_frames.append(normal) 488 | imageio.mimwrite(os.path.join(self.base_exp_dir, "video.mp4"), rgb_frames, fps=30, quality=8) 489 | imageio.mimwrite(os.path.join(self.base_exp_dir, "normals.mp4"), normal_frames, fps=30, quality=8) 490 | 491 | if __name__ == '__main__': 492 | print('Hello Wooden') 493 | 494 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 495 | 496 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 497 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 498 | 499 | parser = argparse.ArgumentParser() 500 | parser.add_argument('--conf', type=str, default='./confs/base.conf') 501 | parser.add_argument('--mode', type=str, default='train') 502 | parser.add_argument('--mcube_threshold', type=float, default=0.0) 503 | parser.add_argument('--is_continue', default=False, action="store_true") 504 | parser.add_argument('--gpu', type=int, default=0) 505 | parser.add_argument('--ckpt_path', type=str, default='') 506 | 507 | args = parser.parse_args() 508 | 509 | torch.cuda.set_device(args.gpu) 510 | runner = Runner(args.conf, args.mode, args.is_continue) 511 | 512 | if args.mode == 'train': 513 | runner.train() 514 | elif args.mode == 'validate_mesh': 515 | runner.validate_mesh(runner.result, resolution=512, threshold=args.mcube_threshold, ckpt_path=args.ckpt_path) 516 | elif args.mode == 'visualize_video': 517 | runner.visualize(ckpt_path=args.ckpt_path) 518 | elif args.mode == 'validate_normal': 519 | runner.validate_all_normals(ckpt_path=args.ckpt_path) 520 | elif args.mode.startswith('interpolate'): # Interpolate views given two image indices 521 | _, img_idx_0, img_idx_1 = 
args.mode.split('_') 522 | img_idx_0 = int(img_idx_0) 523 | img_idx_1 = int(img_idx_1) 524 | runner.interpolate_view(img_idx_0, img_idx_1) 525 | -------------------------------------------------------------------------------- /models/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/embedder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/embedder.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/eval.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/fields.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/fields.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/renderer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/renderer.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/Ref-NeuS/ee38ec896444296d2f78a35863ca8d9f8d3b97ca/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | import json 7 | import trimesh 8 | import cv2 as cv 9 | from glob import glob 10 | from copy import deepcopy 11 | from scipy.spatial.transform import Rotation as Rot 12 | from scipy.spatial.transform import Slerp 13 | 14 | 15 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 16 | def load_K_Rt_from_P(filename, P=None): 17 | if P is None: 18 | lines = open(filename).read().splitlines() 19 | if len(lines) == 4: 20 | lines = lines[1:] 21 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 22 | P = np.asarray(lines).astype(np.float32).squeeze() 23 | 24 | out = cv.decomposeProjectionMatrix(P) 25 | K = out[0] 26 | R = out[1] 27 | t = out[2] 28 | 29 | K = K / K[2, 2] 30 | intrinsics = np.eye(4) 31 | intrinsics[:3, :3] = K 32 | 33 | pose = np.eye(4, dtype=np.float32) 34 | pose[:3, :3] = R.transpose() 35 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 36 | 37 | return intrinsics, pose 38 | 39 | 40 | class Dataset: 41 | def __init__(self, conf): 42 | super(Dataset, 
self).__init__() 43 | print('Load data: Begin') 44 | self.device = torch.device('cuda') 45 | self.conf = conf 46 | 47 | self.data_dir = conf.get_string('data_dir') 48 | self.render_cameras_name = conf.get_string('render_cameras_name') 49 | self.object_cameras_name = conf.get_string('object_cameras_name') 50 | 51 | self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True) 52 | self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1) 53 | 54 | with open(os.path.join(self.data_dir, 'transforms_test.json'), 'r') as fp: 55 | data_info = json.load(fp) 56 | 57 | self.images_lis = [] 58 | self.normal_lis = [] 59 | 60 | pose_all = [] 61 | 62 | for frame in data_info['frames']: 63 | img_path = os.path.join(self.data_dir, frame['file_path'][2:] + '.png') 64 | normal_path = os.path.join(self.data_dir, frame['file_path'][2:] + '_normal' + '.png') 65 | pose_all.append(torch.from_numpy(np.array(frame['transform_matrix'], dtype=np.float32))) 66 | self.images_lis.append(img_path) 67 | self.normal_lis.append(normal_path) 68 | 69 | pose_all = torch.stack(pose_all).cuda() 70 | 71 | # Scale_mat: transform the object to unit sphere for training 72 | pcd = trimesh.load(os.path.join(self.data_dir, 'points_of_interest.ply')) 73 | vertices = pcd.vertices 74 | bbox_max = np.max(vertices, axis=0) 75 | bbox_min = np.min(vertices, axis=0) 76 | center = (bbox_max + bbox_min) * 0.5 77 | radius = np.linalg.norm(vertices - center, ord=2, axis=-1).max() 78 | scale_mat = np.diag([radius, radius, radius, 1.0]).astype(np.float32) 79 | scale_mat[:3, 3] = center 80 | 81 | # Scale_mat: transform the reconstructed mesh in unit sphere to original space with scale 150 for evaluation 82 | self.scale_mat = deepcopy(scale_mat) 83 | self.scale_mat[0, 0] *= 150 84 | self.scale_mat[1, 1] *= 150 85 | self.scale_mat[2, 2] *= 150 86 | self.scale_mat[:3, 3] *= 150 87 | 88 | for i in range(pose_all.shape[0]): 89 | pose_all[i, :, 3:] = torch.from_numpy(np.linalg.inv(scale_mat)).cuda() @ pose_all[i, :, 3:] 90 | 91 | # from opencv to opengl 92 | self.pose_all = torch.matmul(pose_all, torch.diag(torch.tensor([1., -1., -1., 1.]))) 93 | 94 | self.n_images = len(self.images_lis) 95 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0 96 | self.normal_np = np.stack([cv.imread(im_name) for im_name in self.normal_lis]) 97 | self.H, self.W, _ = self.images_np[0].shape 98 | 99 | # intrinsic 100 | camera_angle_x = float(data_info['camera_angle_x']) 101 | self.focal = .5 * self.W / np.tan(.5 * camera_angle_x) 102 | self.intrinsics_all = [] 103 | intrinsics = torch.Tensor([ 104 | [self.focal, 0, self.W / 2, 0], 105 | [0, self.focal, self.H / 2, 0], 106 | [0, 0, 1, 0], 107 | [0, 0, 0, 1]]).float() 108 | for i in range(self.images_np.shape[0]): 109 | self.intrinsics_all.append(intrinsics) 110 | 111 | self.masks_np = np.ones_like(self.images_np) * 255. / 256. 
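        # Note: ShinyBlender provides no foreground object masks, so dummy all-ones
        # masks are used here; with mask_weight = 0.0 in ./confs/womask.conf the mask
        # loss is effectively disabled during training.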
112 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cuda() # [n_images, H, W, 3] 113 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cuda() # [n_images, H, W, 3] 114 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] 115 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4] 116 | self.focal = self.intrinsics_all[0][0, 0] 117 | self.inv_pose_all = torch.inverse(self.pose_all) 118 | self.H, self.W = self.images.shape[1], self.images.shape[2] 119 | self.image_pixels = self.H * self.W 120 | self.all_rays_o = self.pose_all[:,:3,3] 121 | 122 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0]) 123 | object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0]) 124 | self.object_bbox_min = object_bbox_min[:3] 125 | self.object_bbox_max = object_bbox_max[:3] 126 | print('Load data: End') 127 | 128 | def gen_rays_at(self, img_idx, resolution_level=1): 129 | """ 130 | Generate rays at world space from one camera. 131 | """ 132 | l = resolution_level 133 | tx = torch.linspace(0, self.W - 1, self.W // l) 134 | ty = torch.linspace(0, self.H - 1, self.H // l) 135 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 136 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 137 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 138 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 139 | rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 140 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3 141 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1), torch.stack([pixels_y, pixels_x], dim=-1).transpose(0, 1) 142 | 143 | def gen_rays_visu(self, img_idx, pose, resolution_level=1): 144 | """ 145 | Generate rays at world space from one camera. 146 | """ 147 | 148 | l = resolution_level 149 | tx = torch.linspace(0, self.W - 1, self.W // l) 150 | ty = torch.linspace(0, self.H - 1, self.H // l) 151 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 152 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 153 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 154 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 155 | rays_v = torch.matmul(pose[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 156 | rays_o = pose[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3 157 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1), torch.stack([pixels_y, pixels_x], dim=-1).transpose(0, 1) 158 | 159 | def gen_random_rays_at(self, img_idx, batch_size): 160 | """ 161 | Generate random rays at world space from one camera. 
162 | """ 163 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) 164 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) 165 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3 166 | mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3 167 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3 168 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3 169 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 170 | rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3 171 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3 172 | return torch.cat([rays_o, rays_v, color, mask[:, :1]], dim=-1).cuda(), torch.stack([pixels_y, pixels_x], dim=-1) # batch_size, 10 173 | 174 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 175 | """ 176 | Interpolate pose between two cameras. 177 | """ 178 | l = resolution_level 179 | tx = torch.linspace(0, self.W - 1, self.W // l) 180 | ty = torch.linspace(0, self.H - 1, self.H // l) 181 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 182 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 183 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 184 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 185 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 186 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 187 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 188 | pose_0 = np.linalg.inv(pose_0) 189 | pose_1 = np.linalg.inv(pose_1) 190 | rot_0 = pose_0[:3, :3] 191 | rot_1 = pose_1[:3, :3] 192 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 193 | key_times = [0, 1] 194 | slerp = Slerp(key_times, rots) 195 | rot = slerp(ratio) 196 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 197 | pose = pose.astype(np.float32) 198 | pose[:3, :3] = rot.as_matrix() 199 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 200 | pose = np.linalg.inv(pose) 201 | rot = torch.from_numpy(pose[:3, :3]).cuda() 202 | trans = torch.from_numpy(pose[:3, 3]).cuda() 203 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 204 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 205 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 206 | 207 | def near_far_from_sphere(self, rays_o, rays_d): 208 | a = torch.sum(rays_d**2, dim=-1, keepdim=True) 209 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 210 | mid = 0.5 * (-b) / a 211 | near = mid - 1.0 212 | far = mid + 1.0 213 | return near, far 214 | 215 | def image_at(self, idx, resolution_level): 216 | img = cv.imread(self.images_lis[idx]) 217 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 218 | 219 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 
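# The embedder maps an input x to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]
# with log-spaced frequencies, so for input_dims = d and multires = L the embedding has
# d * (2L + 1) output channels (the raw input is included when include_input is True).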
6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': True, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /models/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import sklearn.neighbors as skln 4 | from tqdm import tqdm 5 | from scipy.io import loadmat 6 | import multiprocessing as mp 7 | import argparse 8 | import os 9 | import trimesh 10 | import json 11 | 12 | def sample_single_tri(input_): 13 | n1, n2, v1, v2, tri_vert = input_ 14 | c = np.mgrid[:n1+1, :n2+1] 15 | c += 0.5 16 | c[0] /= max(n1, 1e-7) 17 | c[1] /= max(n2, 1e-7) 18 | c = np.transpose(c, (1,2,0)) 19 | k = c[c.sum(axis=-1) < 1] # m2 20 | q = v1 * k[:,:1] + v2 * k[:,1:] + tri_vert 21 | return q 22 | 23 | def write_vis_pcd(file, points, colors): 24 | pcd = o3d.geometry.PointCloud() 25 | pcd.points = o3d.utility.Vector3dVector(points) 26 | pcd.colors = o3d.utility.Vector3dVector(colors) 27 | o3d.io.write_point_cloud(file, pcd) 28 | 29 | def evaluation(data, scan, dataset_dir, vis_out_dir, mode='mesh', downsample_density=0.2, patch_size=60, max_dist=20, visualize_threshold=10): 30 | mp.freeze_support() 31 | 32 | thresh = downsample_density 33 | if mode == 'mesh': 34 | pbar = tqdm(total=9) 35 | pbar.set_description('read data mesh') 36 | data_mesh = data 37 | 38 | vertices = np.asarray(data_mesh.vertices) 39 | triangles = np.asarray(data_mesh.triangles) 40 | tri_vert = vertices[triangles] 41 | 42 | pbar.update(1) 43 | pbar.set_description('sample pcd from mesh') 44 | v1 = tri_vert[:,1] - tri_vert[:,0] 45 | v2 = tri_vert[:,2] - tri_vert[:,0] 46 | l1 = np.linalg.norm(v1, axis=-1, keepdims=True) 47 | l2 = np.linalg.norm(v2, axis=-1, keepdims=True) 48 | area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) 49 | non_zero_area = (area2 > 0)[:,0] 50 | l1, l2, area2, v1, v2, tri_vert = [ 51 | arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert] 52 | ] 53 | thr = thresh * np.sqrt(l1 * l2 / area2) 54 | n1 = np.floor(l1 / thr) 55 | n2 = np.floor(l2 / thr) 56 | 57 | with mp.Pool() as mp_pool: 58 | new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), 
chunksize=1024) 59 | 60 | new_pts = np.concatenate(new_pts, axis=0) 61 | data_pcd = np.concatenate([vertices, new_pts], axis=0) 62 | 63 | elif mode == 'pcd': 64 | pbar = tqdm(total=8) 65 | pbar.set_description('read data pcd') 66 | data_pcd_o3d = o3d.io.read_point_cloud(data) 67 | data_pcd = np.asarray(data_pcd_o3d.points) 68 | 69 | pbar.update(1) 70 | pbar.set_description('random shuffle pcd index') 71 | shuffle_rng = np.random.default_rng() 72 | shuffle_rng.shuffle(data_pcd, axis=0) 73 | 74 | pbar.update(1) 75 | pbar.set_description('downsample pcd') 76 | nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1) 77 | nn_engine.fit(data_pcd) 78 | rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) 79 | mask = np.ones(data_pcd.shape[0], dtype=np.bool_) 80 | for curr, idxs in enumerate(rnn_idxs): 81 | if mask[curr]: 82 | mask[idxs] = 0 83 | mask[curr] = 1 84 | data_down = data_pcd[mask] 85 | 86 | 87 | pbar.update(1) 88 | pbar.set_description('masking data pcd') 89 | obs_mask_file = loadmat(f'{dataset_dir}/ObsMask/ObsMask{scan}_10.mat') 90 | ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']] 91 | BB = BB.astype(np.float32) 92 | 93 | patch = patch_size 94 | inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3 95 | data_in = data_down[inbound] 96 | 97 | 98 | data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32) 99 | grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) ==3 100 | data_grid_in = data_grid[grid_inbound] 101 | in_obs = ObsMask[data_grid_in[:,0], data_grid_in[:,1], data_grid_in[:,2]].astype(np.bool_) 102 | data_in_obs = data_in[grid_inbound][in_obs] 103 | 104 | 105 | pbar.update(1) 106 | pbar.set_description('read STL pcd') 107 | stl_pcd = o3d.io.read_point_cloud(f'{dataset_dir}/Points/stl/stl{scan:03}_total.ply') 108 | stl = np.asarray(stl_pcd.points) 109 | 110 | 111 | pbar.update(1) 112 | pbar.set_description('compute data2stl') 113 | nn_engine.fit(stl) 114 | dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True) 115 | max_dist = max_dist 116 | mean_d2s = dist_d2s[dist_d2s < max_dist].mean() 117 | 118 | pbar.update(1) 119 | pbar.set_description('compute stl2data') 120 | ground_plane = loadmat(f'{dataset_dir}/ObsMask/Plane{scan}.mat')['P'] 121 | 122 | stl_hom = np.concatenate([stl, np.ones_like(stl[:,:1])], -1) 123 | above = (ground_plane.reshape((1,4)) * stl_hom).sum(-1) > 0 124 | stl_above = stl[above] 125 | 126 | nn_engine.fit(data_in) 127 | dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) 128 | mean_s2d = dist_s2d[dist_s2d < max_dist].mean() 129 | 130 | pbar.update(1) 131 | pbar.set_description('visualize error') 132 | vis_dist = visualize_threshold 133 | R = np.array([[1,0,0]], dtype=np.float64) 134 | G = np.array([[0,1,0]], dtype=np.float64) 135 | B = np.array([[0,0,1]], dtype=np.float64) 136 | W = np.array([[1,1,1]], dtype=np.float64) 137 | data_color = np.tile(B, (data_down.shape[0], 1)) 138 | data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist 139 | data_color[ np.where(inbound)[0][grid_inbound][in_obs] ] = R * data_alpha + W * (1-data_alpha) 140 | data_color[ np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:,0] >= max_dist] ] = G 141 | write_vis_pcd(f'{vis_out_dir}/vis_{scan:03}_d2s.ply', data_down, data_color) 142 | stl_color = np.tile(B, (stl.shape[0], 1)) 143 | stl_alpha = dist_s2d.clip(max=vis_dist) / 
vis_dist 144 | stl_color[ np.where(above)[0] ] = R * stl_alpha + W * (1-stl_alpha) 145 | stl_color[ np.where(above)[0][dist_s2d[:,0] >= max_dist] ] = G 146 | write_vis_pcd(f'{vis_out_dir}/vis_{scan:03}_s2d.ply', stl, stl_color) 147 | 148 | pbar.update(1) 149 | pbar.set_description('done') 150 | pbar.close() 151 | over_all = (mean_d2s + mean_s2d) / 2 152 | 153 | print(mean_d2s, mean_s2d, over_all) 154 | 155 | return mean_d2s, mean_s2d, over_all 156 | 157 | def evaluation_shinyblender(data, dataset_dir, vis_out_dir, downsample_density=0.3, patch_size=60, max_dist_d=100, 158 | max_dist_t=10, visualize_threshold=10, points_for_plane=None, nonvalid_bbox=None): 159 | mp.freeze_support() 160 | 161 | thresh = downsample_density 162 | 163 | pbar = tqdm(total=9) 164 | pbar.set_description('read data mesh') 165 | data_mesh = data 166 | 167 | vertices = np.asarray(data_mesh.vertices) 168 | triangles = np.asarray(data_mesh.triangles) 169 | tri_vert = vertices[triangles] 170 | 171 | pbar.update(1) 172 | pbar.set_description('sample pcd from mesh') 173 | 174 | v1 = tri_vert[:,1] - tri_vert[:,0] 175 | v2 = tri_vert[:,2] - tri_vert[:,0] 176 | l1 = np.linalg.norm(v1, axis=-1, keepdims=True) 177 | l2 = np.linalg.norm(v2, axis=-1, keepdims=True) 178 | area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) 179 | non_zero_area = (area2 > 0)[:,0] 180 | l1, l2, area2, v1, v2, tri_vert = [ 181 | arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]] 182 | thr = thresh * np.sqrt(l1 * l2 / area2) 183 | n1 = np.floor(l1 / thr) 184 | n2 = np.floor(l2 / thr) 185 | 186 | with mp.Pool() as mp_pool: 187 | new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), chunksize=1024) 188 | 189 | new_pts = np.concatenate(new_pts, axis=0) 190 | data_pcd = np.concatenate([vertices, new_pts], axis=0) 191 | 192 | pbar.update(1) 193 | pbar.set_description('random shuffle pcd index') 194 | shuffle_rng = np.random.default_rng() 195 | shuffle_rng.shuffle(data_pcd, axis=0) 196 | 197 | pbar.update(1) 198 | pbar.set_description('downsample pcd') 199 | nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1) 200 | nn_engine.fit(data_pcd) 201 | rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) 202 | mask = np.ones(data_pcd.shape[0], dtype=np.bool_) 203 | for curr, idxs in enumerate(rnn_idxs): 204 | if mask[curr]: 205 | mask[idxs] = 0 206 | mask[curr] = 1 207 | data_down = data_pcd[mask] 208 | 209 | pbar.update(1) 210 | pbar.set_description('read STL pcd') 211 | stl_pcd = o3d.io.read_point_cloud(dataset_dir) 212 | stl = np.asarray(stl_pcd.points) 213 | BB = np.array([stl.min(0), stl.max(0)]) 214 | 215 | # compute lowest surface 216 | p1 = np.array(points_for_plane[0]) 217 | p2 = np.array(points_for_plane[1]) 218 | p3 = np.array(points_for_plane[2]) 219 | v1 = p1 - p2 220 | v2 = p3 - p2 221 | 222 | normal = np.cross(v1, v2) 223 | # make sure the normal toward positive z 224 | if normal[-1] < 0 : 225 | normal = np.cross(v2, v1) 226 | D = np.dot(normal, p1) 227 | 228 | pbar.update(1) 229 | pbar.set_description('masking data pcd') 230 | 231 | BB = BB.astype(np.float32) 232 | 233 | patch = patch_size 234 | inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3 235 | data_in = data_down[inbound] 236 | 237 | above = (data_in @ normal - D) > 0 238 | data_in_above = data_in[above] 239 | 240 | above_stl = (stl @ normal - D) > 0 241 | stl_above = 
stl[above_stl] 242 | 243 | if nonvalid_bbox is not None: 244 | aa = nonvalid_bbox[0] 245 | bb = nonvalid_bbox[1] 246 | 247 | mask_bbox = ((data_in_above >= bb) & (data_in_above <= aa)).sum(axis=-1) ==3 248 | mask_val = ~mask_bbox 249 | else: 250 | mask_val = np.ones_like(data_in_above) 251 | mask_val = mask_val.astype(bool)[:, 0] 252 | data_in_above = data_in_above[mask_val] 253 | 254 | pbar.update(1) 255 | pbar.set_description('compute data2stl') 256 | nn_engine.fit(stl) 257 | dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_above, n_neighbors=1, return_distance=True) 258 | mean_d2s = dist_d2s[dist_d2s < max_dist_d].mean() 259 | 260 | pbar.update(1) 261 | pbar.set_description('compute stl2data') 262 | nn_engine.fit(data_in) 263 | dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) 264 | mean_s2d = dist_s2d[dist_s2d < max_dist_t].mean() 265 | 266 | pbar.update(1) 267 | pbar.set_description('visualize error') 268 | vis_dist = visualize_threshold 269 | R = np.array([[1,0,0]], dtype=np.float64) 270 | G = np.array([[0,1,0]], dtype=np.float64) 271 | B = np.array([[0,0,1]], dtype=np.float64) 272 | W = np.array([[1,1,1]], dtype=np.float64) 273 | data_color = np.tile(B, (data_down.shape[0], 1)) 274 | data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist 275 | 276 | data_color[ np.where(inbound)[0][above][mask_val] ] = R * data_alpha + W * (1-data_alpha) 277 | data_color[ np.where(inbound)[0][above][mask_val] [dist_d2s[:,0] >= max_dist_d] ] = G 278 | write_vis_pcd(f'{vis_out_dir}/vis_d2s.ply', data_down, data_color) 279 | 280 | stl_color = np.tile(B, (stl.shape[0], 1)) 281 | stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist 282 | stl_color[ np.where(above_stl)[0] ] = R * stl_alpha + W * (1-stl_alpha) 283 | stl_color[ np.where(above_stl)[0][dist_s2d[:,0] >= max_dist_t] ] = G 284 | write_vis_pcd(f'{vis_out_dir}/vis_s2d.ply', stl, stl_color) 285 | 286 | pbar.update(1) 287 | pbar.set_description('done') 288 | pbar.close() 289 | over_all = (mean_d2s + mean_s2d) / 2 290 | 291 | print(mean_d2s, mean_s2d, over_all) 292 | 293 | return mean_d2s, mean_s2d, over_all 294 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | d_in, 12 | d_out, 13 | d_hidden, 14 | n_layers, 15 | skip_in=(4,), 16 | multires=0, 17 | bias=0.5, 18 | scale=1, 19 | geometric_init=True, 20 | weight_norm=True, 21 | inside_outside=False): 22 | super(SDFNetwork, self).__init__() 23 | 24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 25 | 26 | self.embed_fn_fine = None 27 | 28 | if multires > 0: 29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 30 | self.embed_fn_fine = embed_fn 31 | dims[0] = input_ch 32 | 33 | self.num_layers = len(dims) 34 | self.skip_in = skip_in 35 | self.scale = scale 36 | 37 | for l in range(0, self.num_layers - 1): 38 | if l + 1 in self.skip_in: 39 | out_dim = dims[l + 1] - dims[0] 40 | else: 41 | out_dim = dims[l + 1] 42 | 43 | lin = nn.Linear(dims[l], out_dim) 44 | 45 | if geometric_init: 46 | if l == self.num_layers - 2: 47 | if not inside_outside: 48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), 
std=0.0001) 49 | torch.nn.init.constant_(lin.bias, -bias) 50 | else: 51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 52 | torch.nn.init.constant_(lin.bias, bias) 53 | elif multires > 0 and l == 0: 54 | torch.nn.init.constant_(lin.bias, 0.0) 55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 57 | elif multires > 0 and l in self.skip_in: 58 | torch.nn.init.constant_(lin.bias, 0.0) 59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 61 | else: 62 | torch.nn.init.constant_(lin.bias, 0.0) 63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 64 | 65 | if weight_norm: 66 | lin = nn.utils.weight_norm(lin) 67 | 68 | setattr(self, "lin" + str(l), lin) 69 | 70 | self.activation = nn.Softplus(beta=100) 71 | 72 | def forward(self, inputs): 73 | inputs = inputs * self.scale 74 | if self.embed_fn_fine is not None: 75 | inputs = self.embed_fn_fine(inputs) 76 | 77 | x = inputs 78 | for l in range(0, self.num_layers - 1): 79 | lin = getattr(self, "lin" + str(l)) 80 | 81 | if l in self.skip_in: 82 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 83 | 84 | x = lin(x) 85 | 86 | if l < self.num_layers - 2: 87 | x = self.activation(x) 88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 89 | 90 | def sdf(self, x): 91 | return self.forward(x)[:, :1] 92 | 93 | def sdf_hidden_appearance(self, x): 94 | return self.forward(x) 95 | 96 | def gradient(self, x): 97 | x.requires_grad_(True) 98 | y = self.sdf(x) 99 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 100 | gradients = torch.autograd.grad( 101 | outputs=y, 102 | inputs=x, 103 | grad_outputs=d_output, 104 | create_graph=True, 105 | retain_graph=True, 106 | only_inputs=True)[0] 107 | return gradients.unsqueeze(1) 108 | 109 | 110 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 111 | class RenderingNetwork(nn.Module): 112 | def __init__(self, 113 | d_feature, 114 | mode, 115 | d_in, 116 | d_out, 117 | d_hidden, 118 | n_layers, 119 | weight_norm=True, 120 | multires_view=0, 121 | squeeze_out=True): 122 | super().__init__() 123 | 124 | self.mode = mode 125 | self.squeeze_out = squeeze_out 126 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 127 | 128 | self.embedview_fn = None 129 | if multires_view > 0: 130 | embedview_fn, input_ch = get_embedder(multires_view) 131 | self.embedview_fn = embedview_fn 132 | dims[0] += (input_ch - 3) 133 | 134 | self.num_layers = len(dims) 135 | 136 | for l in range(0, self.num_layers - 1): 137 | out_dim = dims[l + 1] 138 | lin = nn.Linear(dims[l], out_dim) 139 | 140 | if weight_norm: 141 | lin = nn.utils.weight_norm(lin) 142 | 143 | setattr(self, "lin" + str(l), lin) 144 | 145 | self.relu = nn.ReLU() 146 | 147 | def forward(self, points, normals, view_dirs, feature_vectors): 148 | if self.embedview_fn is not None: 149 | view_dirs = self.embedview_fn(view_dirs) 150 | 151 | rendering_input = None 152 | 153 | if self.mode == 'idr': 154 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 155 | elif self.mode == 'no_view_dir': 156 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 157 | elif self.mode == 'no_normal': 158 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 159 | 160 | x = rendering_input 161 | 162 | for l in range(0, 
self.num_layers - 1): 163 | lin = getattr(self, "lin" + str(l)) 164 | 165 | x = lin(x) 166 | 167 | if l < self.num_layers - 2: 168 | x = self.relu(x) 169 | 170 | if self.squeeze_out: 171 | x = torch.sigmoid(x) 172 | return x 173 | 174 | 175 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch 176 | class NeRF(nn.Module): 177 | def __init__(self, 178 | D=8, 179 | W=256, 180 | d_in=3, 181 | d_in_view=3, 182 | multires=0, 183 | multires_view=0, 184 | output_ch=4, 185 | skips=[4], 186 | use_viewdirs=False): 187 | super(NeRF, self).__init__() 188 | self.D = D 189 | self.W = W 190 | self.d_in = d_in 191 | self.d_in_view = d_in_view 192 | self.input_ch = 3 193 | self.input_ch_view = 3 194 | self.embed_fn = None 195 | self.embed_fn_view = None 196 | 197 | if multires > 0: 198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 199 | self.embed_fn = embed_fn 200 | self.input_ch = input_ch 201 | 202 | if multires_view > 0: 203 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 204 | self.embed_fn_view = embed_fn_view 205 | self.input_ch_view = input_ch_view 206 | 207 | self.skips = skips 208 | self.use_viewdirs = use_viewdirs 209 | 210 | self.pts_linears = nn.ModuleList( 211 | [nn.Linear(self.input_ch, W)] + 212 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 213 | 214 | ### Implementation according to the official code release 215 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 216 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 217 | 218 | ### Implementation according to the paper 219 | # self.views_linears = nn.ModuleList( 220 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 221 | 222 | if use_viewdirs: 223 | self.feature_linear = nn.Linear(W, W) 224 | self.alpha_linear = nn.Linear(W, 1) 225 | self.rgb_linear = nn.Linear(W // 2, 3) 226 | else: 227 | self.output_linear = nn.Linear(W, output_ch) 228 | 229 | def forward(self, input_pts, input_views): 230 | if self.embed_fn is not None: 231 | input_pts = self.embed_fn(input_pts) 232 | if self.embed_fn_view is not None: 233 | input_views = self.embed_fn_view(input_views) 234 | 235 | h = input_pts 236 | for i, l in enumerate(self.pts_linears): 237 | h = self.pts_linears[i](h) 238 | h = F.relu(h) 239 | if i in self.skips: 240 | h = torch.cat([input_pts, h], -1) 241 | 242 | if self.use_viewdirs: 243 | alpha = self.alpha_linear(h) 244 | feature = self.feature_linear(h) 245 | h = torch.cat([feature, input_views], -1) 246 | 247 | for i, l in enumerate(self.views_linears): 248 | h = self.views_linears[i](h) 249 | h = F.relu(h) 250 | 251 | rgb = self.rgb_linear(h) 252 | return alpha, rgb 253 | else: 254 | assert False 255 | 256 | 257 | class SingleVarianceNetwork(nn.Module): 258 | def __init__(self, init_val): 259 | super(SingleVarianceNetwork, self).__init__() 260 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) 261 | 262 | def forward(self, x): 263 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 264 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | import mcubes 7 | import open3d as o3d 8 | from models.utils 
import project, normalize 9 | 10 | 11 | def extract_fields(bound_min, bound_max, resolution, query_func): 12 | N = 64 13 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 14 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 15 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 16 | 17 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 18 | with torch.no_grad(): 19 | for xi, xs in enumerate(X): 20 | for yi, ys in enumerate(Y): 21 | for zi, zs in enumerate(Z): 22 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 23 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) 24 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 25 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 26 | return u 27 | 28 | 29 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): 30 | print('threshold: {}'.format(threshold)) 31 | u = extract_fields(bound_min, bound_max, resolution, query_func) 32 | vertices, triangles = mcubes.marching_cubes(u, threshold) 33 | b_max_np = bound_max.detach().cpu().numpy() 34 | b_min_np = bound_min.detach().cpu().numpy() 35 | 36 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 37 | return vertices, triangles 38 | 39 | 40 | def sample_pdf(bins, weights, n_samples, det=False): 41 | # This implementation is from NeRF 42 | # Get pdf 43 | weights = weights + 1e-5 # prevent nans 44 | pdf = weights / torch.sum(weights, -1, keepdim=True) 45 | cdf = torch.cumsum(pdf, -1) 46 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 47 | # Take uniform samples 48 | if det: 49 | u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 0.5 / n_samples, steps=n_samples) 50 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 51 | else: 52 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 53 | 54 | # Invert CDF 55 | u = u.contiguous() 56 | inds = torch.searchsorted(cdf, u, right=True) 57 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 58 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 59 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 60 | 61 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 62 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 63 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 64 | 65 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 66 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 67 | t = (u - cdf_g[..., 0]) / denom 68 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 69 | 70 | return samples 71 | 72 | 73 | class NeuSRenderer: 74 | def __init__(self, 75 | nerf, 76 | sdf_network, 77 | deviation_network, 78 | color_network, 79 | n_samples, 80 | n_importance, 81 | n_outside, 82 | up_sample_steps, 83 | perturb): 84 | self.nerf = nerf 85 | self.sdf_network = sdf_network 86 | self.deviation_network = deviation_network 87 | self.color_network = color_network 88 | self.n_samples = n_samples 89 | self.n_importance = n_importance 90 | self.n_outside = n_outside 91 | self.up_sample_steps = up_sample_steps 92 | self.perturb = perturb 93 | 94 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None): 95 | """ 96 | Render background 97 | """ 98 | batch_size, n_samples = z_vals.shape 99 | 100 | # Section length 101 | dists = z_vals[..., 1:] - z_vals[..., :-1] 102 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 103 | mid_z_vals = z_vals + dists * 0.5 104 | 105 | # Section midpoints 106 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 107 | 108 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10) 109 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 110 | 111 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 112 | 113 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) 114 | dirs = dirs.reshape(-1, 3) 115 | 116 | density, sampled_color = nerf(pts, dirs) 117 | sampled_color = torch.sigmoid(sampled_color) 118 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists) 119 | alpha = alpha.reshape(batch_size, n_samples) 120 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 121 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3) 122 | color = (weights[:, :, None] * sampled_color).sum(dim=1) 123 | if background_rgb is not None: 124 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True)) 125 | 126 | return { 127 | 'color': color, 128 | 'sampled_color': sampled_color, 129 | 'alpha': alpha, 130 | 'weights': weights, 131 | } 132 | 133 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s): 134 | """ 135 | Up sampling give a fixed inv_s 136 | """ 137 | batch_size, n_samples = z_vals.shape 138 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 139 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False) 140 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0) 141 | sdf = sdf.reshape(batch_size, n_samples) 142 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] 143 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] 144 | mid_sdf = (prev_sdf + next_sdf) * 0.5 145 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 146 | 147 | # ---------------------------------------------------------------------------------------------------------- 148 | # Use min value of [ cos, prev_cos ] 149 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more 150 | # robust when meeting situations like below: 151 | # 152 | # SDF 153 | # ^ 154 | # |\ -----x----... 155 | # | \ / 156 | # | x x 157 | # |---\----/-------------> 0 level 158 | # | \ / 159 | # | \/ 160 | # | 161 | # ---------------------------------------------------------------------------------------------------------- 162 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1) 163 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) 164 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) 165 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere 166 | 167 | dist = (next_z_vals - prev_z_vals) 168 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 169 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5 170 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) 171 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s) 172 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 173 | weights = alpha * torch.cumprod( 174 | torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 175 | 176 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() 177 | return z_samples 178 | 179 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 180 | batch_size, n_samples = z_vals.shape 181 | _, n_importance = new_z_vals.shape 182 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] 183 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 184 | z_vals, index = torch.sort(z_vals, dim=-1) 185 | 186 | if not last: 187 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance) 188 | sdf = torch.cat([sdf, new_sdf], dim=-1) 189 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) 190 | index = index.reshape(-1) 191 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) 192 | 193 | return z_vals, sdf 194 | 195 | def render_core(self, 196 | rays_o, 197 | rays_d, 198 | z_vals, 199 | sample_dist, 200 | ref_idx, 201 | uv, 202 | dataset, 203 | inter_mesh, 204 | sdf_network, 205 | deviation_network, 206 | color_network, 207 | background_alpha=None, 208 | background_sampled_color=None, 209 | background_rgb=None, 210 | cos_anneal_ratio=0.0): 211 | batch_size, n_samples = z_vals.shape 212 | 213 | # Section length 214 | dists = z_vals[..., 1:] - z_vals[..., :-1] 215 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 216 | mid_z_vals = z_vals + dists * 0.5 217 | 218 | # Parameters for projection 219 | inv_c2w_all = dataset.inv_pose_all.cuda() 220 | intrinsics_all = dataset.intrinsics_all.cuda() 221 | scr_ind = [i for i in range(inv_c2w_all.shape[0])] 222 | scr_ind.remove(ref_idx) 223 | 224 | inv_src_pose = inv_c2w_all[scr_ind] 225 | src_intr = intrinsics_all[scr_ind][:, :3, :3] 226 | 227 | # Section midpoints 228 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 229 | dirs = rays_d[:, None, :].expand(pts.shape) 230 | 231 | pts = pts.reshape(-1, 3) 232 | dirs = dirs.reshape(-1, 3) 233 | 234 | sdf_nn_output = sdf_network(pts) 235 | sdf = sdf_nn_output[:, :1] 236 | feature_vector = sdf_nn_output[:, 1:] 237 | 238 | gradients = sdf_network.gradient(pts).squeeze() 239 | normals = F.normalize(gradients, dim=-1) 240 | 241 | refdirs = 2.0 * torch.sum(normals * -dirs, axis=-1, keepdims=True) * normals + dirs 242 | sampled_color = color_network(pts, gradients, refdirs, feature_vector).reshape(batch_size, n_samples, 3) 243 | 244 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 245 | inv_s = inv_s.expand(batch_size * n_samples, 1) 246 | 247 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 248 | 249 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 250 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
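# For example, for a grazing ray with true_cos = 0 (and a roughly unit-norm SDF gradient):
#   cos_anneal_ratio = 0  =>  iter_cos = -relu(0.5 - 0.5 * true_cos) = -0.5   (opacity kept "alive")
#   cos_anneal_ratio = 1  =>  iter_cos = -relu(-true_cos) = min(true_cos, 0) = 0.0   (the un-annealed estimate)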
251 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 252 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 253 | 254 | # Estimate signed distances at section points 255 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 256 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 257 | 258 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 259 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 260 | 261 | p = prev_cdf - next_cdf 262 | c = prev_cdf 263 | 264 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 265 | 266 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 267 | inside_sphere = (pts_norm < 1.0).float().detach() 268 | relax_inside_sphere = (pts_norm < 1.2).float().detach() 269 | 270 | # Render with background 271 | if background_alpha is not None: 272 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere) 273 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1) 274 | sampled_color = sampled_color * inside_sphere[:, :, None] +\ 275 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None] 276 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1) 277 | 278 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1] 279 | weights_sum = weights.sum(dim=-1, keepdim=True) 280 | 281 | color = (sampled_color * weights[:, :, None]).sum(dim=1) 282 | if background_rgb is not None: # Fixed background, usually black 283 | color = color + background_rgb * (1.0 - weights_sum) 284 | 285 | # Eikonal loss 286 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, 287 | dim=-1) - 1.0) ** 2 288 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5) 289 | 290 | # Normal map 291 | normals_map = F.normalize(gradients.reshape(batch_size, n_samples, 3), dim=-1) 292 | normals_map = (normals_map * weights[:, :128, None]).sum(dim=-2).detach().cpu().numpy() 293 | 294 | # Reflection Score 295 | RS = 10. 
* torch.ones(batch_size, dtype=torch.float32).cuda() 296 | if inter_mesh is not None: 297 | # __________________________________________________________________________________________________________ 298 | # _______________________________ Localize surface points with predicted sdf _______________________________ 299 | # __________________________________________________________________________________________________________ 300 | sdf_d = sdf.reshape(batch_size, n_samples) 301 | prev_sdf, next_sdf = sdf_d[:, :-1], sdf_d[:, 1:] 302 | sign = prev_sdf * next_sdf 303 | 304 | surf_exit_inds = torch.unique(torch.where(sign < 0)[0]) 305 | 306 | sign = torch.where(sign <= 0, torch.ones_like(sign), torch.zeros_like(sign)) 307 | idx = reversed(torch.Tensor(range(1, n_samples)).cuda()) 308 | tmp = torch.einsum("ab,b->ab", (sign, idx)) 309 | prev_idx = torch.argmax(tmp, 1, keepdim=True) 310 | next_idx = prev_idx + 1 311 | 312 | prev_inside_sphere = torch.gather(inside_sphere, 1, prev_idx) 313 | next_inside_sphere = torch.gather(inside_sphere, 1, next_idx) 314 | mid_inside_sphere = (0.5 * (prev_inside_sphere + next_inside_sphere) > 0.5).float() 315 | 316 | sdf1 = torch.gather(sdf_d, 1, prev_idx) 317 | sdf2 = torch.gather(sdf_d, 1, next_idx) 318 | z_vals1 = torch.gather(mid_z_vals, 1, prev_idx) 319 | z_vals2 = torch.gather(mid_z_vals, 1, next_idx) 320 | 321 | z_vals_sdf0 = (sdf1 * z_vals2 - sdf2 * z_vals1) / (sdf1 - sdf2 + 1e-10) 322 | z_vals_sdf0 = torch.where(z_vals_sdf0 < 0, torch.zeros_like(z_vals_sdf0), z_vals_sdf0) 323 | max_z_val = torch.max(z_vals) 324 | z_vals_sdf0 = torch.where(z_vals_sdf0 > max_z_val, torch.zeros_like(z_vals_sdf0), z_vals_sdf0) 325 | points_for_warp = (rays_o[:, None, :] + rays_d[:, None, :] * z_vals_sdf0[..., :, None]).detach() 326 | 327 | # __________________________________________________________________________________________________________ 328 | # _________________________________________ Occlusion Detection ____________________________________________ 329 | # __________________________________________________________________________________________________________ 330 | ref_point_dir = torch.cat((rays_o, rays_d), dim=-1).cpu().numpy() 331 | ref_point_dir = o3d.core.Tensor(ref_point_dir, dtype=o3d.core.Dtype.Float32) 332 | 333 | ans_ref = inter_mesh.cast_rays(ref_point_dir) 334 | t_hit_ref = torch.from_numpy(ans_ref['t_hit'].numpy()).cuda().squeeze(0) 335 | 336 | # inf means the ray dose not hit the surface 337 | val_ray_inds = torch.where(~torch.isinf(t_hit_ref))[0] 338 | tmp = list(set(val_ray_inds.cpu().numpy()) & set(surf_exit_inds.cpu().numpy())) # double check: sdf network and mesh 339 | val_ray_inside_inds = torch.tensor(tmp).cuda() 340 | 341 | if val_ray_inside_inds.shape[0] != 0: 342 | 343 | # get source rays_o 344 | rays_o_src = dataset.all_rays_o.cuda()[scr_ind] 345 | # get all rays_d for all validate point 346 | rays_d_scr = points_for_warp - rays_o_src 347 | rays_d_scr = F.normalize(rays_d_scr, dim=-1) 348 | 349 | rays_o_src = rays_o_src.expand(rays_d_scr.size()) 350 | 351 | points_for_warp = points_for_warp[val_ray_inside_inds] 352 | 353 | #cal source 354 | val_rays_o_scr = rays_o_src[val_ray_inside_inds] 355 | val_rays_d_scr = rays_d_scr[val_ray_inside_inds] 356 | 357 | all_point_dir = torch.cat((val_rays_o_scr, val_rays_d_scr),dim=-1).cpu().numpy() 358 | all_point_dir = o3d.core.Tensor(all_point_dir, dtype=o3d.core.Dtype.Float32) 359 | ans_source = inter_mesh.cast_rays(all_point_dir) 360 | 361 | t_hit_src = 
torch.from_numpy(ans_source['t_hit'].numpy()).cuda() 362 | t_hit_src[torch.where(torch.isinf(t_hit_src))] = -10. 363 | 364 | # distance from surface points to source rays_o 365 | dist = ((points_for_warp.repeat(1, len(scr_ind), 1) - val_rays_o_scr) / val_rays_d_scr)[..., 0] 366 | # we slightly relax the occlusion judegment. If the surfaces are optimized inward, all source views are occluded. 367 | dist_ref = ((points_for_warp.squeeze(1) - rays_o[val_ray_inside_inds]) / rays_o[val_ray_inside_inds])[..., 0].detach() 368 | diff_ref = (dist_ref - t_hit_ref[val_ray_inside_inds]).detach() 369 | 370 | diff_ref[torch.where(torch.isinf(diff_ref))] = 0. 371 | diff_ref[torch.where(diff_ref < 0 )] = 0. 372 | 373 | val_inds = torch.where(( (dist - 1.5 * diff_ref[:,None].repeat(1, rays_d_scr.shape[1]) - 0.05) <= t_hit_src)) 374 | all_val_inds = torch.zeros(val_rays_d_scr.shape[:2], dtype=torch.int64).cuda() 375 | all_val_inds[val_inds] = 1 376 | 377 | with torch.no_grad(): 378 | grid_px, in_front = project(points_for_warp.view(-1, 3), inv_src_pose[:, :3].cuda(), src_intr[:, :3, :3].cuda()) 379 | grid_px[..., 0], grid_px[..., 1] = grid_px[..., 1].clone(), grid_px[..., 0].clone() 380 | 381 | grid = normalize(grid_px.squeeze(0), dataset.H, dataset.W, clamp=10) 382 | warping_mask_full = (in_front.squeeze(0) & (grid < 1).all(dim=-1) & (grid > -1).all(dim=-1)) 383 | 384 | sampled_rgb_vals = F.grid_sample(dataset.images[scr_ind].squeeze(0).permute(0, 3, 2, 1), grid.unsqueeze(1), align_corners=True).squeeze(2).transpose(1, 2) 385 | sampled_rgb_vals[~warping_mask_full, :] = 0 #[num_scr, num_val_rays, 3] 386 | all_rgbs_warp = sampled_rgb_vals.transpose(0, 1) #[num_val_rays, num_scr, 3] 387 | 388 | bk_ind = torch.all(all_rgbs_warp == 0, dim=2) 389 | all_val_inds_fina = all_val_inds * warping_mask_full.transpose(0,1) * ~bk_ind 390 | 391 | num_val = torch.sum(all_val_inds_fina, dim=-1) 392 | bk_num = torch.sum(bk_ind * all_val_inds * warping_mask_full.transpose(0, 1), dim=-1) 393 | _val_ind = torch.where((num_val>=5) & (bk_num <= 10)) 394 | 395 | # here we use L1 distance, which achieves similar results 396 | RS_temp = 10. * torch.ones((val_ray_inside_inds.shape[0])).cuda() 397 | uv_val = uv[val_ray_inside_inds] 398 | 399 | anchor_rgb = dataset.images[ref_idx][(uv_val[:,0].long(), uv_val[:,1].long())].view(-1, 3).cuda() 400 | diff_color = torch.zeros_like(all_rgbs_warp).cuda() 401 | all_warp_color = torch.zeros_like(all_rgbs_warp).cuda() 402 | 403 | diff_color[all_val_inds_fina.bool()] = torch.abs(all_rgbs_warp - anchor_rgb[:, None, :].expand(all_rgbs_warp.size()))[all_val_inds_fina.bool()] 404 | all_warp_color[all_val_inds_fina.bool()] = torch.abs(all_rgbs_warp - anchor_rgb[:, None, :].expand(all_rgbs_warp.size()))[all_val_inds_fina.bool()] 405 | val_mean = torch.sum(all_warp_color, dim=-2) / num_val.unsqueeze(-1) 406 | 407 | RS_temp[_val_ind] = (val_mean.sum(-1)[_val_ind] * 10.).clamp(min=1., max=5.) 
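# At this point RS_temp holds the per-ray reflection score: rays whose surface point is seen by enough
# un-occluded, in-image source views (num_val >= 5 and bk_num <= 10) get 10x the channel-summed mean L1
# colour difference to the reference pixel, clamped to [1, 5]; all remaining rays keep the default of 10.
# (The returned 'RS' is consumed outside this file, presumably in exp_runner.py, to down-weight the colour
# loss on rays with a high score, i.e. strongly reflective or unverifiable ones.)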
408 | 409 | RS[val_ray_inside_inds] = RS_temp 410 | 411 | return { 412 | 'color': color, 413 | 'sdf': sdf, 414 | 'dists': dists, 415 | 'gradients': gradients.reshape(batch_size, n_samples, 3), 416 | 's_val': 1.0 / inv_s, 417 | 'mid_z_vals': mid_z_vals, 418 | 'weights': weights, 419 | 'cdf': c.reshape(batch_size, n_samples), 420 | 'gradient_error': gradient_error, 421 | 'inside_sphere': inside_sphere, 422 | 'RS': RS, 423 | 'normal_map': normals_map, 424 | } 425 | 426 | def render(self, rays_o, rays_d, near, far, img_idx, uv, dataset, inter_mesh, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0): 427 | batch_size = len(rays_o) 428 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere 429 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 430 | z_vals = near + (far - near) * z_vals[None, :] 431 | 432 | z_vals_outside = None 433 | if self.n_outside > 0: 434 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 435 | 436 | n_samples = self.n_samples 437 | perturb = self.perturb 438 | 439 | if perturb_overwrite >= 0: 440 | perturb = perturb_overwrite 441 | if perturb > 0: 442 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 443 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 444 | 445 | if self.n_outside > 0: 446 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1]) 447 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 448 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 449 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 450 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand 451 | 452 | if self.n_outside > 0: 453 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 454 | 455 | background_alpha = None 456 | background_sampled_color = None 457 | 458 | # Up sample 459 | if self.n_importance > 0: 460 | with torch.no_grad(): 461 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] 462 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) 463 | 464 | for i in range(self.up_sample_steps): 465 | new_z_vals = self.up_sample(rays_o, 466 | rays_d, 467 | z_vals, 468 | sdf, 469 | self.n_importance // self.up_sample_steps, 470 | 64 * 2**i) 471 | z_vals, sdf = self.cat_z_vals(rays_o, 472 | rays_d, 473 | z_vals, 474 | new_z_vals, 475 | sdf, 476 | last=(i + 1 == self.up_sample_steps)) 477 | 478 | n_samples = self.n_samples + self.n_importance 479 | 480 | # Background model 481 | if self.n_outside > 0: 482 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 483 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 484 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf) 485 | 486 | background_sampled_color = ret_outside['sampled_color'] 487 | background_alpha = ret_outside['alpha'] 488 | 489 | # Render core 490 | ret_fine = self.render_core(rays_o, 491 | rays_d, 492 | z_vals, 493 | sample_dist, 494 | img_idx, 495 | uv, 496 | dataset, 497 | inter_mesh, 498 | self.sdf_network, 499 | self.deviation_network, 500 | self.color_network, 501 | background_rgb=background_rgb, 502 | background_alpha=background_alpha, 503 | background_sampled_color=background_sampled_color, 504 | cos_anneal_ratio=cos_anneal_ratio) 505 | 506 | color_fine = ret_fine['color'] 507 | weights = ret_fine['weights'] 508 | weights_sum = weights.sum(dim=-1, keepdim=True) 509 | gradients = ret_fine['gradients'] 510 | s_val = ret_fine['s_val'].reshape(batch_size, 
n_samples).mean(dim=-1, keepdim=True) 511 | 512 | return { 513 | 'color_fine': color_fine, 514 | 's_val': s_val, 515 | 'cdf_fine': ret_fine['cdf'], 516 | 'weight_sum': weights_sum, 517 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0], 518 | 'gradients': gradients, 519 | 'weights': weights, 520 | 'gradient_error': ret_fine['gradient_error'], 521 | 'inside_sphere': ret_fine['inside_sphere'], 522 | 'RS': ret_fine['RS'], 523 | 'normal_map': ret_fine['normal_map'], 524 | } 525 | 526 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0): 527 | return extract_geometry(bound_min, 528 | bound_max, 529 | resolution=resolution, 530 | threshold=threshold, 531 | query_func=lambda pts: -self.sdf_network.sdf(pts)) 532 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # NeuralWarp All rights reseved to Thales LAS and ENPC. 2 | # 3 | # This code is freely available for academic use only and Provided “as is” without any warranty. 4 | # 5 | # Modification are allowed for academic research provided that the following conditions are met : 6 | # * Redistributions of source code or any format must retain the above copyright notice and this list of conditions. 7 | # * Neither the name of Thales LAS and ENPC nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def add_hom(pts): 14 | try: 15 | dev = pts.device 16 | ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1) 17 | return torch.cat((pts, ones), dim=-1) 18 | 19 | except AttributeError: 20 | ones = np.ones((pts.shape[0], 1)) 21 | return np.concatenate((pts, ones), axis=1) 22 | 23 | 24 | def quat_to_rot(q): 25 | a, b, c, d = q[:, 0], q[:, 1], q[:, 2], q[:, 3] 26 | a2, b2, c2, d2 = a ** 2, b ** 2, c ** 2, d ** 2 27 | if isinstance(q, torch.Tensor): 28 | R = torch.empty((q.shape[0], 3, 3)) 29 | else: 30 | R = np.empty((q.shape[0], 3, 3)) 31 | R[:, 0, 0] = a2 + b2 - c2 - d2 32 | R[:, 0, 1] = 2 * b * c - 2 * a * d 33 | R[:, 0, 2] = 2 * a * c + 2 * b * d 34 | R[:, 1, 0] = 2 * a * d + 2 * b * c 35 | R[:, 1, 1] = a2 - b2 + c2 - d2 36 | R[:, 1, 2] = 2 * c * d - 2 * a * b 37 | R[:, 2, 0] = 2 * b * d - 2 * a * c 38 | R[:, 2, 1] = 2 * a * b + 2 * c * d 39 | R[:, 2, 2] = a2 - b2 - c2 + d2 40 | 41 | return R 42 | 43 | 44 | def rot_to_quat(M): 45 | q = np.empty((M.shape[0], 4,)) 46 | t = np.trace(M, axis1=1, axis2=2) 47 | 48 | cond1 = t > 0 49 | cond2 = ~cond1 & (M[:, 0, 0] > M[:, 1, 1]) & (M[:, 0, 0] > M[:, 2, 2]) 50 | cond3 = ~cond1 & ~cond2 & (M[:, 1, 1] > M[:, 2, 2]) 51 | cond4 = ~cond1 & ~cond2 & ~cond3 52 | 53 | S = 2 * np.sqrt(1.0 + t[cond1]) 54 | q[cond1, 0] = 0.25 * S 55 | q[cond1, 1] = (M[cond1, 2, 1] - M[cond1, 1, 2]) / S 56 | q[cond1, 2] = (M[cond1, 0, 2] - M[cond1, 2, 0]) / S 57 | q[cond1, 3] = (M[cond1, 1, 0] - M[cond1, 0, 1]) / S 58 | 59 | S = np.sqrt(1.0 + M[cond2, 0, 0] - M[cond2, 1, 1] - M[cond2, 2,2]) * 2 60 | q[cond2, 0] = (M[cond2, 2, 1] - M[cond2, 1, 2]) / S 61 | q[cond2, 1] = 0.25 * S 62 | q[cond2, 2] = (M[cond2, 0, 1] + M[cond2, 1, 0]) / S 63 | q[cond2, 3] = (M[cond2, 0, 2] + M[cond2, 2, 0]) / S 64 | 65 | S = np.sqrt(1.0 + M[cond3, 1, 1] - M[cond3, 0, 0] - M[cond3, 2, 2]) * 2 66 | q[cond3, 0] = (M[cond3, 0, 2] - M[cond3, 2, 0]) / S 67 | q[cond3, 1] = (M[cond3, 0, 1] + M[cond3, 1, 0]) / S 68 | q[cond3, 2] = 0.25 * S 69 | 
q[cond3, 3] = (M[cond3, 1, 2] + M[cond3, 2, 1]) / S 70 | 71 | S = np.sqrt(1.0 + M[cond4, 2, 2] - M[cond4, 0, 0] - M[cond4, 1, 1]) * 2 72 | q[cond4, 0] = (M[cond4, 1, 0] - M[cond4, 0, 1]) / S 73 | q[cond4, 1] = (M[cond4, 0, 2] + M[cond4, 2, 0]) / S 74 | q[cond4, 2] = (M[cond4, 1, 2] + M[cond4, 2, 1]) / S 75 | q[cond4, 3] = 0.25 * S 76 | 77 | return q / np.linalg.norm(q, axis=1, keepdims=True) 78 | 79 | 80 | def normalize(flow, h, w, clamp=None): 81 | # either h and w are simple float or N torch.tensor where N batch size 82 | try: 83 | h.device 84 | 85 | except AttributeError: 86 | h = torch.tensor(h, device=flow.device).float().unsqueeze(0) 87 | w = torch.tensor(w, device=flow.device).float().unsqueeze(0) 88 | 89 | if len(flow.shape) == 4: 90 | w = w.unsqueeze(1).unsqueeze(2) 91 | h = h.unsqueeze(1).unsqueeze(2) 92 | elif len(flow.shape) == 3: 93 | w = w.unsqueeze(1) 94 | h = h.unsqueeze(1) 95 | elif len(flow.shape) == 5: 96 | w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2) 97 | h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2) 98 | 99 | res = torch.empty_like(flow) 100 | if res.shape[-1] == 3: 101 | res[..., 2] = 1 102 | 103 | # for grid_sample with align_corners=True 104 | # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33 105 | res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1 106 | res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1 107 | 108 | if clamp: 109 | return torch.clamp(res, -clamp, clamp) 110 | else: 111 | return res 112 | 113 | 114 | def unnormalize(flow, h, w): 115 | try: 116 | h.device 117 | except AttributeError: 118 | h = torch.tensor(h, device=flow.device).float().unsqueeze(0) 119 | w = torch.tensor(w, device=flow.device).float().unsqueeze(0) 120 | 121 | if len(flow.shape) == 4: 122 | w = w.unsqueeze(1).unsqueeze(2) 123 | h = h.unsqueeze(1).unsqueeze(2) 124 | elif len(flow.shape) == 3: 125 | w = w.unsqueeze(1) 126 | h = h.unsqueeze(1) 127 | 128 | res = torch.empty_like(flow) 129 | 130 | if res.shape[-1] == 3: 131 | res[..., 2] = 1 132 | 133 | # idem: https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33 134 | res[..., 0] = ((flow[..., 0] + 1) / 2) * (w - 1) 135 | res[..., 1] = ((flow[..., 1] + 1) / 2) * (h - 1) 136 | 137 | return res 138 | 139 | 140 | def project(points, pose, intr): 141 | xyz = (intr.unsqueeze(1) @ pose.unsqueeze(1) @ add_hom(points).unsqueeze(-1))[..., :3, 0] 142 | in_front = xyz[..., 2] > 0 143 | grid = xyz[..., :2] / torch.clamp(xyz[..., 2:], 1e-8) 144 | return grid, in_front 145 | 146 | def patch_homography(H, uv): 147 | N, Npx = uv.shape[:2] 148 | Nsrc = H.shape[0] 149 | H = H.view(Nsrc, N, -1, 3, 3) 150 | hom_uv = add_hom(uv) 151 | 152 | # einsum is 30 times faster 153 | # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3) 154 | tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3) 155 | 156 | grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8) 157 | mask = tmp[..., 2] > 0 158 | return grid, mask 159 | 160 | def generate_spherical_cam_to_world(radius, n_poses=120): 161 | """ 162 | Generate a 360 degree spherical path for rendering 163 | ref: https://github.com/kwea123/nerf_pl/blob/master/datasets/llff.py 164 | ref: https://github.com/yenchenlin/nerf-pytorch/blob/master/load_blender.py 165 | Create circular poses around z axis. 166 | Inputs: 167 | radius: the (negative) height and the radius of the circle. 
168 | Outputs: 169 | spheric_cams: (n_poses, 3, 4) the cam to world transformation matrix of a circular path 170 | """ 171 | 172 | def spheric_pose(theta, phi, radius): 173 | trans_t = lambda t: np.array([ 174 | [1, 0, 0, 0], 175 | [0, 1, 0, 0], 176 | [0, 0, 1, t], 177 | [0, 0, 0, 1], 178 | ], dtype=np.float32) 179 | 180 | rotation_phi = lambda phi: np.array([ 181 | [1, 0, 0, 0], 182 | [0, np.cos(phi), -np.sin(phi), 0], 183 | [0, np.sin(phi), np.cos(phi), 0], 184 | [0, 0, 0, 1], 185 | ], dtype=np.float32) 186 | 187 | rotation_theta = lambda th: np.array([ 188 | [np.cos(th), 0, -np.sin(th), 0], 189 | [0, 1, 0, 0], 190 | [np.sin(th), 0, np.cos(th), 0], 191 | [0, 0, 0, 1], 192 | ], dtype=np.float32) 193 | cam_to_world = trans_t(radius) 194 | cam_to_world = rotation_phi(phi / 180. * np.pi) @ cam_to_world 195 | cam_to_world = rotation_theta(theta /180.* np.pi) @ cam_to_world 196 | cam_to_world = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], 197 | dtype=np.float32) @ cam_to_world 198 | return cam_to_world 199 | 200 | spheric_cams = [] 201 | # for th in np.linspace(0, 2 * np.pi, n_poses + 1)[:-1]: 202 | # spheric_cams += [spheric_pose(th, -30, radius)] 203 | for th in np.linspace(-180,180,n_poses+1)[:-1]: 204 | spheric_cams += [spheric_pose(th, -30, radius)] 205 | return np.stack(spheric_cams, 0) 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trimesh==3.16.4 2 | numpy==1.21.2 3 | pyhocon==0.3.59 4 | opencv_python==4.6.0.66 5 | tqdm==4.62.3 6 | torch==1.11.0 7 | scipy==1.7.3 8 | open3d==0.15.2 9 | imageio==2.22.4 10 | pymcubes==0.1.2 11 | tensorboard==2.11.0 12 | chardet 13 | --------------------------------------------------------------------------------