├── .gitignore ├── LICENSE ├── README.md ├── confs └── dtu.conf ├── exp_runner.py ├── extensions └── chamfer_dist │ ├── __init__.py │ ├── chamfer.cu │ ├── chamfer_cuda.cpp │ ├── setup.py │ └── test.py ├── media ├── comparison.png └── pipeline.png ├── models ├── dataset.py ├── embedder.py ├── fields.py ├── renderer.py ├── udf_dataset.py ├── udf_embedder.py └── udf_fields.py ├── pretrained_model └── vismvsnet.pt ├── requirements.txt ├── tools ├── feat_utils.py ├── logger.py ├── surface_extraction.py └── utils.py └── udf_runner.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | ### VirtualEnv template 93 | # Virtualenv 94 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 95 | .Python 96 | [Bb]in 97 | [Ii]nclude 98 | [Ll]ib 99 | [Ll]ib64 100 | [Ll]ocal 101 | [Ss]cripts 102 | pyvenv.cfg 103 | .venv 104 | pip-selfcheck.json 105 | ### JetBrains template 106 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 107 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 108 | 109 | # User-specific stuff: 110 | .idea/workspace.xml 111 | .idea/tasks.xml 112 | .idea/dictionaries 113 | .idea/vcs.xml 114 | .idea/jsLibraryMappings.xml 115 | 116 | # Sensitive or high-churn files: 117 | .idea/dataSources.ids 118 | .idea/dataSources.xml 119 | .idea/dataSources.local.xml 120 | .idea/sqlDataSources.xml 121 | .idea/dynamic.xml 122 | .idea/uiDesigner.xml 123 | 124 | # Gradle: 125 | .idea/gradle.xml 126 | .idea/libraries 127 | 128 | # Mongo Explorer plugin: 129 | .idea/mongoSettings.xml 130 | 131 | .idea/ 132 | 133 | ## File-based project format: 134 | *.iws 135 | 136 | ## Plugin-specific files: 137 | 138 | # IntelliJ 139 | /out/ 140 | 141 | # mpeltonen/sbt-idea plugin 142 | .idea_modules/ 143 | 144 | # JIRA plugin 145 | atlassian-ide-plugin.xml 146 | 147 | # Crashlytics plugin (for Android Studio and IntelliJ) 148 | com_crashlytics_export_strings.xml 149 | 
crashlytics.properties 150 | crashlytics-build.properties 151 | fabric.properties 152 | 153 | data/ 154 | exp/ 155 | .data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Han Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuSurf 2 | Implementation of AAAI'24 paper *NeuSurf: On-Surface Priors for Neural Surface Reconstruction from Sparse Input Views* 3 | 4 | ### [Project Page](https://alvin528.github.io/NeuSurf/) | [Paper](https://arxiv.org/abs/2312.13977) | [Data](https://drive.google.com/drive/folders/18AZw4zi3fNQ-NKttNeVBp9cTja8NBnSA?usp=drive_link) | [Mesh Results](https://drive.google.com/drive/folders/1PVDJNa68OQm7Cisz2_CVNGQb4Zp40HSm?usp=sharing) 5 | 6 |
7 | 8 | ## Overview 9 | 10 | ![NeuSurf pipeline](./media/pipeline.png)
11 | 12 | ## Installation 13 | 14 | Our code is implemented in Python 3.10, PyTorch 2.0.0 and CUDA 11.7. 15 | - Install Python dependencies 16 | ``` 17 | conda create -n neusurf python=3.10 18 | conda activate neusurf 19 | pip install torch==2.0.0 torchvision==0.15.1 20 | pip install -r requirements.txt 21 | ``` 22 | - Compile the CUDA extension 23 | ``` 24 | cd extensions/chamfer_dist 25 | python setup.py install 26 | ``` 27 | 28 | ## Dataset 29 | 30 | Data structure: 31 | 32 | ``` 33 | data 34 | |-- DTU_pixelnerf 35 | |-- <case_name> 36 | |-- cameras_sphere.npz 37 | |-- pcd 38 | |-- <case_name>.ply 39 | |-- cam4feat 40 | |-- pair.txt 41 | |-- cam_00000000_flow3.txt 42 | |-- cam_00000001_flow3.txt 43 | ... 44 | |-- image 45 | |-- 000000.png 46 | |-- 000001.png 47 | ... 48 | |-- mask 49 | |-- 000.png 50 | |-- 001.png 51 | ... 52 | |-- DTU_sparseneus 53 | |-- blendedmvs_sparse 54 | ``` 55 | 56 | You can directly download the processed data [here](https://drive.google.com/drive/folders/18AZw4zi3fNQ-NKttNeVBp9cTja8NBnSA?usp=drive_link). A minimal sanity-check sketch for this per-case layout is included at the end of this README. 57 | 58 | ## Running 59 | 60 | - Training 61 | 62 | ``` 63 | CUDA_VISIBLE_DEVICES=0 \ 64 | python exp_runner.py --mode train --conf ./confs/dtu.conf --case <case_name> 65 | ``` 66 | 67 | - Extract mesh 68 | 69 | ``` 70 | CUDA_VISIBLE_DEVICES=0 \ 71 | python exp_runner.py --mode validate_mesh --conf ./confs/dtu.conf --case <case_name> --is_continue 72 | ``` 73 | 74 | 75 | 76 | ## Citation 77 | 78 | If you find our work useful in your research, please consider citing: 79 | 80 | ```bibtex 81 | @inproceedings{huang2024neusurf, 82 | title={NeuSurf: On-Surface Priors for Neural Surface Reconstruction from Sparse Input Views}, 83 | author={Huang, Han and Wu, Yulun and Zhou, Junsheng and Gao, Ge and Gu, Ming and Liu, Yu-Shen}, 84 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 85 | volume={38}, 86 | number={3}, 87 | pages={2312--2320}, 88 | year={2024} 89 | } 90 | ``` 91 | 92 | ## Acknowledgement 93 | 94 | This implementation is based on [CAP-UDF](https://github.com/junshengzhou/CAP-UDF/), [D-NeuS](https://github.com/fraunhoferhhi/D-NeuS) and [Vis-MVSNet](https://github.com/jzhangbs/Vis-MVSNet). Thanks for these great works.
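As referenced in the Dataset section above, here is a minimal sanity-check sketch for the per-case layout, assuming the directory structure shown there. The script name `check_data.py` and its `--data_dir`/`--case` arguments are illustrative only (not part of the released tooling); the checks mirror what `exp_runner.py` and `models/dataset.py` actually read: `cameras_sphere.npz`, `image/*.png`, `pcd/<case_name>.ply` and `cam4feat/pair.txt`.

```python
# check_data.py -- hypothetical helper, not part of the original release.
# Checks that one DTU_pixelnerf case follows the layout expected by the runners.
import argparse
import os
from glob import glob

import numpy as np
import trimesh


def check_case(data_dir, case):
    case_dir = os.path.join(data_dir, case)

    # Camera file: one world_mat_i / scale_mat_i pair per view.
    cams = np.load(os.path.join(case_dir, 'cameras_sphere.npz'))
    n_views = len([k for k in cams.files
                   if k.startswith('world_mat_') and 'inv' not in k])

    # Input images, one per view.
    images = sorted(glob(os.path.join(case_dir, 'image', '*.png')))

    # On-surface point cloud prior, loaded with trimesh in exp_runner.py.
    pcd = trimesh.load(os.path.join(case_dir, 'pcd', '{}.ply'.format(case)))
    n_points = np.asarray(pcd.vertices).shape[0]

    # View-pair list used for feature extraction in models/dataset.py.
    assert os.path.exists(os.path.join(case_dir, 'cam4feat', 'pair.txt')), \
        'missing cam4feat/pair.txt'
    assert len(images) == n_views, \
        'expected one image per camera ({} images vs {} cameras)'.format(len(images), n_views)

    print('{}: {} views, {} prior points'.format(case, n_views, n_points))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data/DTU_pixelnerf')
    parser.add_argument('--case', type=str, required=True)
    args = parser.parse_args()
    check_case(args.data_dir, args.case)
```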
95 | -------------------------------------------------------------------------------- /confs/dtu.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_dir = ./exp/DTU_pixelnerf/CASE_NAME/ 3 | base_exp_dir = ./exp/DTU_pixelnerf/CASE_NAME/womask_sphere 4 | recording = [ 5 | ./, 6 | ./models 7 | ] 8 | } 9 | 10 | udf_dataset { 11 | data_dir = ./data/DTU_pixelnerf/CASE_NAME/ 12 | } 13 | 14 | dataset { 15 | data_dir = ./data/DTU_pixelnerf/CASE_NAME/ 16 | render_cameras_name = cameras_sphere.npz 17 | object_cameras_name = cameras_sphere.npz 18 | feat_map_h = 384 19 | feat_map_w = 512 20 | } 21 | 22 | udf_train { 23 | learning_rate = 0.001 24 | step1_maxiter = 40000 25 | step2_maxiter = 60000 26 | warm_up_end = 1000 27 | eval_num_points = 1000000 28 | df_filter = 0.01 29 | far = -1 30 | outlier = 0.002 31 | extra_points_rate = 1 32 | low_range = 1.1 33 | 34 | batch_size = 5000 35 | batch_size_step2 = 20000 36 | 37 | save_freq = 5000 38 | val_freq = 2500 39 | val_mesh_freq = 2500 40 | report_freq = 5000 41 | 42 | igr_weight = 0.1 43 | mask_weight = 0.0 44 | load_ckpt = none 45 | } 46 | 47 | train { 48 | learning_rate = 5e-4 49 | learning_rate_alpha = 0.05 50 | end_iter = 300000 51 | 52 | batch_size = 512 53 | validate_resolution_level = 4 54 | warm_up_end = 5000 55 | anneal_end = 50000 56 | use_white_bkgd = False 57 | 58 | save_freq = 1000 59 | val_freq = 1000 60 | val_mesh_freq = 1000 61 | report_freq = 1000 62 | 63 | igr_weight = 0.1 64 | mask_weight = 0.0 65 | 66 | udf_thresh = 5e-2 67 | 68 | phase_delim = [0.16667, 0.5] 69 | local_weight = [0.0, 0.5, 0.05] 70 | pseudo_reg_weight = [0.01, 0.1, 0.01] 71 | depth_from_inside_only = [False, True, True] 72 | } 73 | 74 | udf_model { 75 | ckpt = 60000 76 | 77 | udf_network { 78 | d_out = 1 79 | d_in = 3 80 | d_hidden = 256 81 | n_layers = 8 82 | skip_in = [4] 83 | multires = 0 84 | bias = 0.5 85 | scale = 1.0 86 | geometric_init = True 87 | weight_norm = True 88 | } 89 | } 90 | 91 | model { 92 | nerf { 93 | D = 8, 94 | d_in = 4, 95 | d_in_view = 3, 96 | W = 256, 97 | multires = 10, 98 | multires_view = 4, 99 | output_ch = 4, 100 | skips=[4], 101 | use_viewdirs=True 102 | } 103 | 104 | sdf_network { 105 | d_out = 257 106 | d_in = 3 107 | d_hidden = 256 108 | n_layers = 8 109 | skip_in = [4] 110 | multires = 6 111 | bias = 0.5 112 | scale = 1.0 113 | geometric_init = True 114 | weight_norm = True 115 | } 116 | 117 | variance_network { 118 | init_val = 0.3 119 | } 120 | 121 | rendering_network { 122 | d_feature = 256 123 | mode = idr 124 | d_in = 9 125 | d_out = 3 126 | d_hidden = 256 127 | n_layers = 4 128 | weight_norm = True 129 | multires_view = 4 130 | squeeze_out = True 131 | } 132 | 133 | neus_renderer { 134 | n_samples = 64 135 | n_importance = 64 136 | n_outside = 32 137 | up_sample_steps = 4 138 | perturb = 1.0 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /exp_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | import numpy as np 5 | import cv2 as cv 6 | import trimesh 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.tensorboard import SummaryWriter 10 | from shutil import copyfile 11 | from tqdm import tqdm 12 | from pyhocon import ConfigFactory 13 | from models.dataset import Dataset 14 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF 15 | from models.renderer import 
NeuSRenderer 16 | 17 | from models.udf_fields import UDFNetwork 18 | from udf_runner import UDFRunner 19 | 20 | 21 | class Runner: 22 | def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False, ckpt=None): 23 | self.device = torch.device('cuda') 24 | 25 | # Configuration 26 | self.conf_path = conf_path 27 | f = open(self.conf_path) 28 | conf_text = f.read() 29 | conf_text = conf_text.replace('CASE_NAME', case) 30 | f.close() 31 | 32 | self.conf = ConfigFactory.parse_string(conf_text) 33 | self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case) 34 | self.base_exp_dir = self.conf['general.base_exp_dir'] 35 | os.makedirs(self.base_exp_dir, exist_ok=True) 36 | self.dataset = Dataset(self.conf['dataset']) 37 | self.iter_step = 0 38 | 39 | # Training parameters 40 | self.end_iter = self.conf.get_int('train.end_iter') 41 | self.save_freq = self.conf.get_int('train.save_freq') 42 | self.report_freq = self.conf.get_int('train.report_freq') 43 | self.val_freq = self.conf.get_int('train.val_freq') 44 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') 45 | self.batch_size = self.conf.get_int('train.batch_size') 46 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') 47 | self.learning_rate = self.conf.get_float('train.learning_rate') 48 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') 49 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') 50 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) 51 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) 52 | 53 | # Weights 54 | self.igr_weight = self.conf.get_float('train.igr_weight') 55 | self.mask_weight = self.conf.get_float('train.mask_weight') 56 | self.is_continue = is_continue 57 | self.mode = mode 58 | self.model_list = [] 59 | self.writer = None 60 | self.ckpt = ckpt 61 | 62 | self.udf_thresh = self.conf.get_float('train.udf_thresh') 63 | 64 | self.phase_delims = self.conf.get_list('train.phase_delim') 65 | self.pseudo_reg_weights = self.conf.get_list('train.pseudo_reg_weight') 66 | self.local_weights = self.conf.get_list('train.local_weight') 67 | self.depth_from_inside_only_s = self.conf.get_list('train.depth_from_inside_only') 68 | 69 | def get_param_in_phase(param_list, phase): 70 | if phase < self.phase_delims[0]: 71 | return param_list[0] 72 | elif phase < self.phase_delims[1]: 73 | return param_list[1] 74 | else: 75 | return param_list[2] 76 | self.get_param_in_phase = get_param_in_phase 77 | 78 | # Networks 79 | params_to_train = [] 80 | self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device) 81 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) 82 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) 83 | self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device) 84 | params_to_train += list(self.nerf_outside.parameters()) 85 | params_to_train += list(self.sdf_network.parameters()) 86 | params_to_train += list(self.deviation_network.parameters()) 87 | params_to_train += list(self.color_network.parameters()) 88 | 89 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate) 90 | 91 | self.renderer = NeuSRenderer(self.nerf_outside, 92 | self.sdf_network, 93 | self.deviation_network, 94 | self.color_network, 95 | self.dataset, 96 | **self.conf['model.neus_renderer']) 97 | 98 | pointcloud = 
trimesh.load('{}/pcd/{}.ply'.format(self.conf['dataset.data_dir'], case)).vertices 99 | pointcloud = np.asarray(pointcloud) 100 | self.shape_scale = np.max([np.max(pointcloud[:,0])-np.min(pointcloud[:,0]),np.max(pointcloud[:,1])-np.min(pointcloud[:,1]),np.max(pointcloud[:,2])-np.min(pointcloud[:,2])]) 101 | self.shape_center = [(np.max(pointcloud[:,0])+np.min(pointcloud[:,0]))/2, (np.max(pointcloud[:,1])+np.min(pointcloud[:,1]))/2, (np.max(pointcloud[:,2])+np.min(pointcloud[:,2]))/2] 102 | with torch.no_grad(): 103 | self.shape_center = torch.Tensor(self.shape_center) 104 | 105 | scale_pcd = False 106 | if scale_pcd: 107 | pointcloud = (pointcloud - self.dataset.scale_mats_np[0][:3, 3][None]) / self.dataset.scale_mats_np[0][0, 0] 108 | self.pointcloud = torch.tensor(pointcloud, requires_grad=False, dtype=torch.float32).to(self.device) 109 | 110 | self.udf_network = UDFNetwork(**self.conf['udf_model.udf_network']).to(self.device) 111 | udf_ckpt_path = '{}/udf/checkpoints/ckpt_{:0>6}.pth'.format(self.conf['general.base_dir'].replace('CASE_NAME', case), 112 | self.conf['udf_model.ckpt']) 113 | self.udf_network.load_state_dict(torch.load(udf_ckpt_path, map_location=self.device)['udf_network_fine']) 114 | self.udf_network.eval() 115 | for p in self.udf_network.parameters(): 116 | p.requires_grad = False 117 | logging.info('UDF network successfully loaded') 118 | 119 | # Load checkpoint 120 | latest_model_name = None 121 | if is_continue: 122 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) 123 | model_list = [] 124 | for model_name in model_list_raw: 125 | if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter: 126 | model_list.append(model_name) 127 | model_list.sort() 128 | if self.ckpt == 'latest': 129 | latest_model_name = model_list[-1] 130 | else: 131 | latest_model_name = 'ckpt_{:0>6}.pth'.format(self.ckpt) 132 | 133 | if latest_model_name is not None: 134 | logging.info('Find checkpoint: {}'.format(latest_model_name)) 135 | self.load_checkpoint(latest_model_name) 136 | 137 | # Backup codes and configs for debug 138 | if self.mode[:5] == 'train': 139 | self.file_backup() 140 | 141 | def init_params(self): 142 | self.iter_step = 0 143 | self.learning_rate = self.conf.get_float('train.learning_rate') 144 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') 145 | 146 | params_to_train = [] 147 | params_to_train += list(self.nerf_outside.parameters()) 148 | params_to_train += list(self.sdf_network.parameters()) 149 | params_to_train += list(self.deviation_network.parameters()) 150 | params_to_train += list(self.color_network.parameters()) 151 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate) 152 | 153 | def train(self, prior_initialization=False): 154 | if not self.is_continue: 155 | self.init_params() 156 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) 157 | self.update_learning_rate() 158 | res_step = self.end_iter - self.iter_step 159 | if prior_initialization: 160 | res_step = 5000 - self.iter_step 161 | image_perm = self.get_image_perm() 162 | 163 | for iter_i in tqdm(range(res_step)): 164 | main_img_idx = image_perm[self.iter_step % len(image_perm)] 165 | data ,sample = self.dataset.gen_random_rays_at(main_img_idx, self.batch_size) 166 | 167 | rays_o, rays_d, true_rgb = data[:, :3], data[:, 3: 6], data[:, 6: 9] 168 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) 169 | model_input = {} 170 | for attr in ['depth_cams','size', 'center', 'feat', 
'feat_src','cam', 'src_cams', 'rays_d_norm']: 171 | model_input[attr] = sample[attr].cuda() 172 | for attr in ['H', 'W', 'src_idxs']: 173 | model_input[attr] = sample[attr] 174 | 175 | background_rgb = None 176 | if self.use_white_bkgd: 177 | background_rgb = torch.ones([1, 3]) 178 | 179 | mask = torch.ones_like(true_rgb[...,0:1]) 180 | mask_sum = mask.sum() + 1e-6 181 | train_phase = self.iter_step/self.end_iter 182 | 183 | random_pts = self.pointcloud[torch.randperm(self.pointcloud.shape[0])[:512]] 184 | if prior_initialization: 185 | random_pts = None 186 | render_out = self.renderer.render(rays_o, rays_d, near, far, main_img_idx, 187 | t=self.iter_step + 1, 188 | random_pcd=random_pts, 189 | background_rgb=background_rgb, 190 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 191 | model_input = model_input, 192 | depth_from_inside_only = self.get_param_in_phase(self.depth_from_inside_only_s, train_phase), 193 | ) 194 | 195 | color_fine = render_out['color_fine'] 196 | s_val = render_out['s_val'] 197 | cdf_fine = render_out['cdf_fine'] 198 | gradient_error = render_out['gradient_error'] 199 | weight_max = render_out['weight_max'] 200 | weight_sum = render_out['weight_sum'] 201 | pseudo_pts_reg_loss = render_out['pseudo_pts_loss'] 202 | local_loss = render_out['local_loss'] 203 | 204 | query_sdf = render_out['sdf'] 205 | query_pts = render_out['query_pts'] 206 | with torch.no_grad(): 207 | udf = self.shape_scale * self.udf_network.udf((query_pts - self.shape_center) / self.shape_scale) 208 | udf = udf.reshape(query_sdf.size()) 209 | 210 | udf[udf < self.udf_thresh] = 0.0 211 | udf_residual = torch.abs(torch.abs(query_sdf) - udf) 212 | udf_residual[(udf > self.udf_thresh)] = 0.0 213 | global_loss = F.l1_loss(udf_residual, torch.zeros_like(udf_residual)) 214 | 215 | random_pcd = self.pointcloud[torch.randperm(self.pointcloud.shape[0])[:30000]] 216 | sdf = self.renderer.sdf_network.sdf(random_pcd) 217 | pcd_loss = F.l1_loss(sdf, torch.zeros_like(sdf), 218 | reduction='sum') / random_pcd.shape[0] 219 | 220 | # Loss 221 | color_error = (color_fine - true_rgb) * mask 222 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum 223 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt()) 224 | 225 | eikonal_loss = gradient_error 226 | 227 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask) 228 | 229 | pseudo_pts_reg_loss = self.get_param_in_phase(self.pseudo_reg_weights, train_phase) * pseudo_pts_reg_loss 230 | local_loss = self.get_param_in_phase(self.local_weights, train_phase) * local_loss 231 | 232 | loss = eikonal_loss * self.igr_weight +\ 233 | mask_loss * self.mask_weight +\ 234 | global_loss * 0.1 235 | 236 | if not prior_initialization: 237 | loss += color_fine_loss +\ 238 | pcd_loss +\ 239 | pseudo_pts_reg_loss +\ 240 | local_loss 241 | 242 | self.optimizer.zero_grad() 243 | loss.backward() 244 | self.optimizer.step() 245 | 246 | self.iter_step += 1 247 | 248 | self.writer.add_scalar('Loss/loss', loss, self.iter_step) 249 | self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step) 250 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step) 251 | self.writer.add_scalar('Loss/pseudo_pts_reg_loss', pseudo_pts_reg_loss, self.iter_step) 252 | self.writer.add_scalar('Loss/local_loss', local_loss, self.iter_step) 253 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step) 254 | self.writer.add_scalar('Statistics/cdf', 
(cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step) 255 | self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step) 256 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step) 257 | 258 | if self.iter_step % self.report_freq == 0: 259 | print(self.base_exp_dir) 260 | print('iter:{:8>d} loss = {} lr={}'.format(self.iter_step, loss, self.optimizer.param_groups[0]['lr'])) 261 | 262 | if self.iter_step % self.save_freq == 0: 263 | self.save_checkpoint() 264 | 265 | if self.iter_step % self.val_freq == 0: 266 | self.validate_image() 267 | 268 | if self.iter_step % self.val_mesh_freq == 0: 269 | self.validate_mesh() 270 | 271 | self.update_learning_rate() 272 | 273 | if self.iter_step % len(image_perm) == 0: 274 | image_perm = self.get_image_perm() 275 | 276 | def get_image_perm(self): 277 | return torch.randperm(self.dataset.n_images) 278 | 279 | def get_cos_anneal_ratio(self): 280 | if self.anneal_end == 0.0: 281 | return 1.0 282 | else: 283 | return np.min([1.0, self.iter_step / self.anneal_end]) 284 | 285 | def update_learning_rate(self): 286 | if self.iter_step < self.warm_up_end: 287 | learning_factor = self.iter_step / self.warm_up_end 288 | else: 289 | alpha = self.learning_rate_alpha 290 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end) 291 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha 292 | 293 | for g in self.optimizer.param_groups: 294 | g['lr'] = self.learning_rate * learning_factor 295 | 296 | def file_backup(self): 297 | dir_lis = self.conf['general.recording'] 298 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) 299 | for dir_name in dir_lis: 300 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) 301 | os.makedirs(cur_dir, exist_ok=True) 302 | files = os.listdir(dir_name) 303 | for f_name in files: 304 | if f_name[-3:] == '.py': 305 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 306 | 307 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 308 | 309 | def load_checkpoint(self, checkpoint_name): 310 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device) 311 | self.nerf_outside.load_state_dict(checkpoint['nerf']) 312 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine']) 313 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine']) 314 | self.color_network.load_state_dict(checkpoint['color_network_fine']) 315 | self.optimizer.load_state_dict(checkpoint['optimizer']) 316 | self.iter_step = checkpoint['iter_step'] 317 | 318 | logging.info('End') 319 | 320 | def save_checkpoint(self): 321 | checkpoint = { 322 | 'nerf': self.nerf_outside.state_dict(), 323 | 'sdf_network_fine': self.sdf_network.state_dict(), 324 | 'variance_network_fine': self.deviation_network.state_dict(), 325 | 'color_network_fine': self.color_network.state_dict(), 326 | 'optimizer': self.optimizer.state_dict(), 327 | 'iter_step': self.iter_step, 328 | } 329 | 330 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) 331 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 332 | 333 | def validate_image(self, idx=-1, resolution_level=-1): 334 | if idx < 0: 335 | idx = np.random.randint(self.dataset.n_images) 336 | 337 | print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx)) 338 | 339 | if 
resolution_level < 0: 340 | resolution_level = self.validate_resolution_level 341 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) 342 | H, W, _ = rays_o.shape 343 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 344 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 345 | 346 | out_rgb_fine = [] 347 | out_normal_fine = [] 348 | 349 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 350 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 351 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 352 | 353 | render_out = self.renderer.render(rays_o_batch, 354 | rays_d_batch, 355 | near, 356 | far, 357 | None, 358 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 359 | background_rgb=background_rgb, t=self.iter_step + 1) 360 | 361 | def feasible(key): return (key in render_out) and (render_out[key] is not None) 362 | 363 | if feasible('color_fine'): 364 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 365 | if feasible('gradients') and feasible('weights'): 366 | n_samples = self.renderer.n_samples + self.renderer.n_importance 367 | normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] 368 | if feasible('inside_sphere'): 369 | normals = normals * render_out['inside_sphere'][..., None] 370 | normals = normals.sum(dim=1).detach().cpu().numpy() 371 | out_normal_fine.append(normals) 372 | del render_out 373 | 374 | img_fine = None 375 | if len(out_rgb_fine) > 0: 376 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 377 | 378 | normal_img = None 379 | if len(out_normal_fine) > 0: 380 | normal_img = np.concatenate(out_normal_fine, axis=0) 381 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) 382 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]) 383 | .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255) 384 | 385 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) 386 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True) 387 | 388 | for i in range(img_fine.shape[-1]): 389 | if len(out_rgb_fine) > 0: 390 | cv.imwrite(os.path.join(self.base_exp_dir, 391 | 'validations_fine', 392 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 393 | np.concatenate([img_fine[..., i], 394 | self.dataset.image_at(idx, resolution_level=resolution_level)])) 395 | if len(out_normal_fine) > 0: 396 | cv.imwrite(os.path.join(self.base_exp_dir, 397 | 'normals', 398 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 399 | normal_img[..., i]) 400 | 401 | def rendering_image(self, idx=-1, resolution_level=-1): 402 | if idx < 0: 403 | idx = np.random.randint(self.dataset.n_images) 404 | 405 | if resolution_level < 0: 406 | resolution_level = self.validate_resolution_level 407 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) 408 | H, W, _ = rays_o.shape 409 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 410 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 411 | 412 | out_rgb_fine = [] 413 | 414 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 415 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 416 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 417 | 418 | render_out = self.renderer.render(rays_o_batch, 419 | rays_d_batch, 420 | near, 421 | far, 422 | None, 423 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 424 | background_rgb=background_rgb, 
t=self.iter_step + 1) 425 | 426 | def feasible(key): return (key in render_out) and (render_out[key] is not None) 427 | 428 | if feasible('color_fine'): 429 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 430 | 431 | img_fine = None 432 | if len(out_rgb_fine) > 0: 433 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 434 | 435 | 436 | os.makedirs(os.path.join(self.base_exp_dir, 'rendering_image'), exist_ok=True) 437 | 438 | 439 | for i in range(img_fine.shape[-1]): 440 | if len(out_rgb_fine) > 0: 441 | cv.imwrite(os.path.join(self.base_exp_dir, 442 | 'rendering_image', 443 | '{}.png'.format(idx)), 444 | img_fine[..., i] 445 | ) 446 | 447 | def output_rendering_image (self, resolution = -1 ): 448 | for i in range (self.dataset.n_images): 449 | self.rendering_image(idx= i , resolution_level=resolution) 450 | 451 | 452 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level): 453 | """ 454 | Interpolate view between two cameras. 455 | """ 456 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level) 457 | H, W, _ = rays_o.shape 458 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 459 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 460 | 461 | out_rgb_fine = [] 462 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 463 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 464 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 465 | 466 | render_out = self.renderer.render(rays_o_batch, 467 | rays_d_batch, 468 | near, 469 | far, 470 | None, 471 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 472 | background_rgb=background_rgb, t=self.iter_step + 1) 473 | 474 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 475 | 476 | del render_out 477 | 478 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8) 479 | return img_fine 480 | 481 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0): 482 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 483 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 484 | 485 | vertices, triangles =\ 486 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold) 487 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) 488 | 489 | if world_space: 490 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None] 491 | 492 | mesh = trimesh.Trimesh(vertices, triangles) 493 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}_{}.ply'.format(self.iter_step, resolution))) 494 | 495 | logging.info('End') 496 | 497 | def interpolate_view(self, img_idx_0, img_idx_1): 498 | images = [] 499 | n_frames = 60 500 | for i in range(n_frames): 501 | print(i) 502 | images.append(self.render_novel_image(img_idx_0, 503 | img_idx_1, 504 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, 505 | resolution_level=4)) 506 | for i in range(n_frames): 507 | images.append(images[n_frames - i - 1]) 508 | 509 | fourcc = cv.VideoWriter_fourcc(*'mp4v') 510 | video_dir = os.path.join(self.base_exp_dir, 'render') 511 | os.makedirs(video_dir, exist_ok=True) 512 | h, w, _ = images[0].shape 513 | writer = cv.VideoWriter(os.path.join(video_dir, 514 | '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), 515 | fourcc, 30, (w, h)) 516 | 517 | for image 
in images: 518 | writer.write(image) 519 | 520 | writer.release() 521 | 522 | 523 | if __name__ == '__main__': 524 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 525 | 526 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 527 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 528 | 529 | parser = argparse.ArgumentParser() 530 | parser.add_argument('--conf', type=str, default='./conf') 531 | parser.add_argument('--udf_dir', type=str, default='udf') 532 | parser.add_argument('--mode', type=str, default='train') 533 | parser.add_argument('--mcube_threshold', type=float, default=0.0) 534 | parser.add_argument('--volume_resolution', type=int, default=512) 535 | parser.add_argument('--is_continue', default=False, action="store_true") 536 | parser.add_argument('--gpu', type=int, default=0) 537 | parser.add_argument('--case', type=str, default='') 538 | parser.add_argument('--world_space', default=False, action="store_true") 539 | parser.add_argument('--ckpt', type=str, default='latest') 540 | 541 | args = parser.parse_args() 542 | 543 | torch.cuda.set_device(args.gpu) 544 | udf_runner = UDFRunner(args, args.conf) 545 | 546 | if args.mode == 'train': 547 | if not os.path.exists(f'{udf_runner.base_exp_dir}/checkpoints/ckpt_060000.pth'): 548 | udf_runner.train() 549 | 550 | runner = Runner(args.conf, args.mode, args.case, args.is_continue, args.ckpt) 551 | 552 | if args.mode == 'train': 553 | if not args.is_continue: 554 | base_exp_dir = runner.base_exp_dir 555 | os.makedirs(base_exp_dir, exist_ok=True) 556 | runner.base_exp_dir = os.path.join(base_exp_dir, 'prior_initialization') 557 | os.makedirs(runner.base_exp_dir, exist_ok=True) 558 | runner.train(prior_initialization=True) 559 | runner.base_exp_dir = base_exp_dir 560 | runner.train() 561 | runner.validate_mesh(world_space=True, resolution=args.volume_resolution, threshold=args.mcube_threshold) 562 | elif args.mode == 'validate_mesh': 563 | runner.validate_mesh(world_space=args.world_space, resolution=args.volume_resolution, threshold=args.mcube_threshold) 564 | elif args.mode == 'render_image': 565 | runner.output_rendering_image(resolution=1) 566 | elif args.mode.startswith('interpolate'): # Interpolate views given two image indices 567 | _, img_idx_0, img_idx_1 = args.mode.split('_') 568 | img_idx_0 = int(img_idx_0) 569 | img_idx_1 = int(img_idx_1) 570 | runner.interpolate_view(img_idx_0, img_idx_1) 571 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import chamfer 4 | 5 | 6 | class ChamferFunction(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, xyz1, xyz2): 9 | dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2) 10 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 11 | 12 | return dist1, dist2 13 | 14 | @staticmethod 15 | def backward(ctx, grad_dist1, grad_dist2): 16 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 17 | grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2) 18 | return grad_xyz1, grad_xyz2 19 | 20 | 21 | class ChamferDistanceL2(torch.nn.Module): 22 | f''' Chamder Distance L2 23 | ''' 24 | def __init__(self, ignore_zeros=False): 25 | super().__init__() 26 | self.ignore_zeros = ignore_zeros 27 | 28 | def forward(self, xyz1, xyz2): 29 | batch_size = xyz1.size(0) 30 | if batch_size == 1 and self.ignore_zeros: 31 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0) 32 | 
non_zeros2 = torch.sum(xyz2, dim=2).ne(0) 33 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0) 34 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0) 35 | 36 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2) 37 | return torch.mean(dist1) + torch.mean(dist2) 38 | 39 | class ChamferDistanceL2_split(torch.nn.Module): 40 | f''' Chamder Distance L2 41 | ''' 42 | def __init__(self, ignore_zeros=False): 43 | super().__init__() 44 | self.ignore_zeros = ignore_zeros 45 | 46 | def forward(self, xyz1, xyz2): 47 | batch_size = xyz1.size(0) 48 | if batch_size == 1 and self.ignore_zeros: 49 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0) 50 | non_zeros2 = torch.sum(xyz2, dim=2).ne(0) 51 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0) 52 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0) 53 | 54 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2) 55 | return torch.mean(dist1), torch.mean(dist2) 56 | 57 | class ChamferDistanceL1(torch.nn.Module): 58 | f''' Chamder Distance L1 59 | ''' 60 | def __init__(self, ignore_zeros=False): 61 | super().__init__() 62 | self.ignore_zeros = ignore_zeros 63 | 64 | def forward(self, xyz1, xyz2): 65 | batch_size = xyz1.size(0) 66 | if batch_size == 1 and self.ignore_zeros: 67 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0) 68 | non_zeros2 = torch.sum(xyz2, dim=2).ne(0) 69 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0) 70 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0) 71 | 72 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2) 73 | # import pdb 74 | # pdb.set_trace() 75 | dist1 = torch.sqrt(dist1) 76 | dist2 = torch.sqrt(dist2) 77 | return (torch.mean(dist1) + torch.mean(dist2))/2 78 | 79 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/chamfer.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | __global__ void chamfer_dist_kernel(int batch_size, 8 | int n, 9 | const float* xyz1, 10 | int m, 11 | const float* xyz2, 12 | float* dist, 13 | int* indexes) { 14 | const int batch = 512; 15 | __shared__ float buf[batch * 3]; 16 | for (int i = blockIdx.x; i < batch_size; i += gridDim.x) { 17 | for (int k2 = 0; k2 < m; k2 += batch) { 18 | int end_k = min(m, k2 + batch) - k2; 19 | for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) { 20 | buf[j] = xyz2[(i * m + k2) * 3 + j]; 21 | } 22 | __syncthreads(); 23 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; 24 | j += blockDim.x * gridDim.y) { 25 | float x1 = xyz1[(i * n + j) * 3 + 0]; 26 | float y1 = xyz1[(i * n + j) * 3 + 1]; 27 | float z1 = xyz1[(i * n + j) * 3 + 2]; 28 | float best_dist = 0; 29 | int best_dist_index = 0; 30 | int end_ka = end_k - (end_k & 3); 31 | if (end_ka == batch) { 32 | for (int k = 0; k < batch; k += 4) { 33 | { 34 | float x2 = buf[k * 3 + 0] - x1; 35 | float y2 = buf[k * 3 + 1] - y1; 36 | float z2 = buf[k * 3 + 2] - z1; 37 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 38 | 39 | if (k == 0 || dist < best_dist) { 40 | best_dist = dist; 41 | best_dist_index = k + k2; 42 | } 43 | } 44 | { 45 | float x2 = buf[k * 3 + 3] - x1; 46 | float y2 = buf[k * 3 + 4] - y1; 47 | float z2 = buf[k * 3 + 5] - z1; 48 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 49 | if (dist < best_dist) { 50 | best_dist = dist; 51 | best_dist_index = k + k2 + 1; 52 | } 53 | } 54 | { 55 | float x2 = buf[k * 3 + 6] - x1; 56 | float y2 = buf[k * 3 + 7] - y1; 57 | float z2 = buf[k * 3 + 8] - z1; 58 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 59 | if (dist < best_dist) { 60 | best_dist = dist; 61 | best_dist_index = k + k2 + 2; 
62 | } 63 | } 64 | { 65 | float x2 = buf[k * 3 + 9] - x1; 66 | float y2 = buf[k * 3 + 10] - y1; 67 | float z2 = buf[k * 3 + 11] - z1; 68 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 69 | if (dist < best_dist) { 70 | best_dist = dist; 71 | best_dist_index = k + k2 + 3; 72 | } 73 | } 74 | } 75 | } else { 76 | for (int k = 0; k < end_ka; k += 4) { 77 | { 78 | float x2 = buf[k * 3 + 0] - x1; 79 | float y2 = buf[k * 3 + 1] - y1; 80 | float z2 = buf[k * 3 + 2] - z1; 81 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 82 | if (k == 0 || dist < best_dist) { 83 | best_dist = dist; 84 | best_dist_index = k + k2; 85 | } 86 | } 87 | { 88 | float x2 = buf[k * 3 + 3] - x1; 89 | float y2 = buf[k * 3 + 4] - y1; 90 | float z2 = buf[k * 3 + 5] - z1; 91 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 92 | if (dist < best_dist) { 93 | best_dist = dist; 94 | best_dist_index = k + k2 + 1; 95 | } 96 | } 97 | { 98 | float x2 = buf[k * 3 + 6] - x1; 99 | float y2 = buf[k * 3 + 7] - y1; 100 | float z2 = buf[k * 3 + 8] - z1; 101 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 102 | if (dist < best_dist) { 103 | best_dist = dist; 104 | best_dist_index = k + k2 + 2; 105 | } 106 | } 107 | { 108 | float x2 = buf[k * 3 + 9] - x1; 109 | float y2 = buf[k * 3 + 10] - y1; 110 | float z2 = buf[k * 3 + 11] - z1; 111 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 112 | if (dist < best_dist) { 113 | best_dist = dist; 114 | best_dist_index = k + k2 + 3; 115 | } 116 | } 117 | } 118 | } 119 | for (int k = end_ka; k < end_k; k++) { 120 | float x2 = buf[k * 3 + 0] - x1; 121 | float y2 = buf[k * 3 + 1] - y1; 122 | float z2 = buf[k * 3 + 2] - z1; 123 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 124 | if (k == 0 || dist < best_dist) { 125 | best_dist = dist; 126 | best_dist_index = k + k2; 127 | } 128 | } 129 | if (k2 == 0 || dist[(i * n + j)] > best_dist) { 130 | dist[(i * n + j)] = best_dist; 131 | indexes[(i * n + j)] = best_dist_index; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | std::vector chamfer_cuda_forward(torch::Tensor xyz1, 140 | torch::Tensor xyz2) { 141 | const int batch_size = xyz1.size(0); 142 | const int n = xyz1.size(1); // num_points point cloud A 143 | const int m = xyz2.size(1); // num_points point cloud B 144 | torch::Tensor dist1 = 145 | torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat)); 146 | torch::Tensor dist2 = 147 | torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat)); 148 | torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt)); 149 | torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt)); 150 | 151 | chamfer_dist_kernel<<>>( 152 | batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), 153 | dist1.data_ptr(), idx1.data_ptr()); 154 | chamfer_dist_kernel<<>>( 155 | batch_size, m, xyz2.data_ptr(), n, xyz1.data_ptr(), 156 | dist2.data_ptr(), idx2.data_ptr()); 157 | 158 | cudaError_t err = cudaGetLastError(); 159 | if (err != cudaSuccess) { 160 | printf("Error in chamfer_cuda_forward: %s\n", cudaGetErrorString(err)); 161 | } 162 | return {dist1, dist2, idx1, idx2}; 163 | } 164 | 165 | __global__ void chamfer_dist_grad_kernel(int b, 166 | int n, 167 | const float* xyz1, 168 | int m, 169 | const float* xyz2, 170 | const float* grad_dist1, 171 | const int* idx1, 172 | float* grad_xyz1, 173 | float* grad_xyz2) { 174 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 175 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; 176 | j += blockDim.x * gridDim.y) { 177 | float x1 = xyz1[(i * n + j) * 3 + 0]; 178 | float y1 = xyz1[(i * n + j) * 3 + 1]; 
179 | float z1 = xyz1[(i * n + j) * 3 + 2]; 180 | int j2 = idx1[i * n + j]; 181 | float x2 = xyz2[(i * m + j2) * 3 + 0]; 182 | float y2 = xyz2[(i * m + j2) * 3 + 1]; 183 | float z2 = xyz2[(i * m + j2) * 3 + 2]; 184 | float g = grad_dist1[i * n + j] * 2; 185 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2)); 186 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2)); 187 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2)); 188 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2))); 189 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2))); 190 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2))); 191 | } 192 | } 193 | } 194 | 195 | std::vector chamfer_cuda_backward(torch::Tensor xyz1, 196 | torch::Tensor xyz2, 197 | torch::Tensor idx1, 198 | torch::Tensor idx2, 199 | torch::Tensor grad_dist1, 200 | torch::Tensor grad_dist2) { 201 | const int batch_size = xyz1.size(0); 202 | const int n = xyz1.size(1); // num_points point cloud A 203 | const int m = xyz2.size(1); // num_points point cloud B 204 | torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat)); 205 | torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat)); 206 | 207 | chamfer_dist_grad_kernel<<>>( 208 | batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), 209 | grad_dist1.data_ptr(), idx1.data_ptr(), 210 | grad_xyz1.data_ptr(), grad_xyz2.data_ptr()); 211 | chamfer_dist_grad_kernel<<>>( 212 | batch_size, m, xyz2.data_ptr(), n, xyz1.data_ptr(), 213 | grad_dist2.data_ptr(), idx2.data_ptr(), 214 | grad_xyz2.data_ptr(), grad_xyz1.data_ptr()); 215 | 216 | cudaError_t err = cudaGetLastError(); 217 | if (err != cudaSuccess) { 218 | printf("Error in chamfer_cuda_backward: %s\n", cudaGetErrorString(err)); 219 | } 220 | return {grad_xyz1, grad_xyz2}; 221 | } 222 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector chamfer_cuda_forward(torch::Tensor xyz1, 5 | torch::Tensor xyz2); 6 | 7 | std::vector chamfer_cuda_backward(torch::Tensor xyz1, 8 | torch::Tensor xyz2, 9 | torch::Tensor idx1, 10 | torch::Tensor idx2, 11 | torch::Tensor grad_dist1, 12 | torch::Tensor grad_dist2); 13 | 14 | std::vector chamfer_forward(torch::Tensor xyz1, 15 | torch::Tensor xyz2) { 16 | return chamfer_cuda_forward(xyz1, xyz2); 17 | } 18 | 19 | std::vector chamfer_backward(torch::Tensor xyz1, 20 | torch::Tensor xyz2, 21 | torch::Tensor idx1, 22 | torch::Tensor idx2, 23 | torch::Tensor grad_dist1, 24 | torch::Tensor grad_dist2) { 25 | return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2); 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)"); 30 | m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)"); 31 | } 32 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup(name='chamfer', 5 | version='2.0.0', 6 | ext_modules=[ 7 | CUDAExtension('chamfer', [ 8 | 'chamfer_cuda.cpp', 9 | 'chamfer.cu', 10 | ]), 11 | ], 12 | cmdclass={'build_ext': BuildExtension}) 13 | 
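Once the extension above is built and installed (`python setup.py install`, as described in the README), the wrappers in `extensions/chamfer_dist/__init__.py` can be used directly from the repository root. The sketch below is illustrative (tensor sizes are arbitrary); the float32/CUDA requirement follows from the `chamfer.cu` kernels, which operate on GPU float buffers of shape `(batch, num_points, 3)`. The same autograd path is what `test.py` exercises with `gradcheck`.

```python
# Minimal usage sketch for the compiled Chamfer-distance extension
# (illustrative tensor sizes; requires a CUDA device and the installed `chamfer` module).
import torch

from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2

# Two point clouds of shape (batch, num_points, 3); the point counts may differ.
xyz1 = torch.rand(4, 2048, 3, device='cuda', requires_grad=True)
xyz2 = torch.rand(4, 1024, 3, device='cuda')

cd_l2 = ChamferDistanceL2()  # mean squared nearest-neighbour distance, summed over both directions
cd_l1 = ChamferDistanceL1()  # mean of square-rooted distances, averaged over both directions

loss = cd_l2(xyz1, xyz2)
loss.backward()              # gradients flow back through ChamferFunction.backward
print(loss.item(), cd_l1(xyz1, xyz2).item(), xyz1.grad.shape)
```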
-------------------------------------------------------------------------------- /extensions/chamfer_dist/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import unittest 5 | 6 | 7 | from torch.autograd import gradcheck 8 | 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 10 | from extensions.chamfer_dist import ChamferFunction 11 | 12 | 13 | class ChamferDistanceTestCase(unittest.TestCase): 14 | def test_chamfer_dist(self): 15 | x = torch.rand(4, 64, 3).double() 16 | y = torch.rand(4, 128, 3).double() 17 | x.requires_grad = True 18 | y.requires_grad = True 19 | print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])) 20 | 21 | 22 | 23 | if __name__ == '__main__': 24 | # unittest.main() 25 | import pdb 26 | x = torch.rand(32,128,3) 27 | y = torch.rand(32,128,3) 28 | pdb.set_trace() 29 | -------------------------------------------------------------------------------- /media/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/media/comparison.png -------------------------------------------------------------------------------- /media/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/media/pipeline.png -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from scipy.spatial.transform import Rotation as Rot 8 | from scipy.spatial.transform import Slerp 9 | from tools.feat_utils import load_pair, load_cam, scale_camera, FeatExt 10 | 11 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 12 | def load_K_Rt_from_P(filename, P=None): 13 | if P is None: 14 | lines = open(filename).read().splitlines() 15 | if len(lines) == 4: 16 | lines = lines[1:] 17 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 18 | P = np.asarray(lines).astype(np.float32).squeeze() 19 | 20 | out = cv.decomposeProjectionMatrix(P) 21 | K = out[0] 22 | R = out[1] 23 | t = out[2] 24 | 25 | K = K / K[2, 2] 26 | intrinsics = np.eye(4) 27 | intrinsics[:3, :3] = K 28 | 29 | pose = np.eye(4, dtype=np.float32) 30 | pose[:3, :3] = R.transpose() 31 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 32 | 33 | return intrinsics, pose 34 | 35 | 36 | class Dataset: 37 | def __init__(self, conf): 38 | super(Dataset, self).__init__() 39 | print('Load data: Begin') 40 | self.device = torch.device('cuda') 41 | self.conf = conf 42 | 43 | self.data_dir = conf.get_string('data_dir') 44 | self.render_cameras_name = conf.get_string('render_cameras_name') 45 | self.object_cameras_name = conf.get_string('object_cameras_name') 46 | 47 | self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True) 48 | self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1) 49 | 50 | camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name)) 51 | self.camera_dict = camera_dict 52 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png'))) 53 | self.n_images = 
len(self.images_lis) 54 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0 55 | # world_mat is a projection matrix from world to image 56 | self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 57 | 58 | self.scale_mats_np = [] 59 | 60 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 61 | self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 62 | 63 | self.intrinsics_all = [] 64 | self.pose_all = [] 65 | 66 | for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np): 67 | P = world_mat @ scale_mat 68 | P = P[:3, :4] 69 | intrinsics, pose = load_K_Rt_from_P(None, P) 70 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 71 | self.pose_all.append(torch.from_numpy(pose).float()) 72 | 73 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3] 74 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] 75 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4] 76 | self.focal = self.intrinsics_all[0][0, 0] 77 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4] 78 | self.H, self.W = self.images.shape[1], self.images.shape[2] 79 | self.image_pixels = self.H * self.W 80 | self.pair = load_pair(f'{self.data_dir}/cam4feat/pair.txt') 81 | self.num_src = 2 82 | self.depth_cams = torch.stack( 83 | [torch.from_numpy( 84 | load_cam(f'{self.data_dir}/cam4feat/cam_{self.pair["id_list"][i].zfill(8)}_flow3.txt', 256, 1)).to(torch.float32) 85 | for i in range(self.n_images)], dim=0) 86 | self.feat_img_scale = 2 87 | self.cams_hd = torch.stack( # upsample of 2 from depth_cams, not 1200 * 1600 88 | [scale_camera(self.depth_cams[i], self.feat_img_scale) for i in range(self.n_images)] # NOTE: hard code 89 | ) 90 | self.img_res = self.images.shape[-3:-1] 91 | # [n_images, 3, 768, 1024] 92 | self.rgb_2xd = torch.stack([ 93 | F.interpolate( 94 | self.images[idx].reshape(-1,3).permute(1, 0).view(1, 3, *self.img_res), # 1200 x 1600 95 | size=(self.conf.feat_map_h * self.feat_img_scale, self.conf.feat_map_w * self.feat_img_scale), 96 | mode='bilinear', align_corners=False)[0] 97 | for idx in range(self.n_images) 98 | ], dim=0) # v3hw 99 | mean = torch.tensor([0.485, 0.456, 0.406]).float().cpu() 100 | std = torch.tensor([0.229, 0.224, 0.225]).float().cpu() 101 | self.rgb_2xd = (self.rgb_2xd / 2 + 0.5 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1) 102 | self.size = torch.from_numpy(self.scale_mats_np[0]).float()[0, 0] * 2 103 | self.center = torch.from_numpy(self.scale_mats_np[0]).float()[:3, 3] 104 | 105 | feat_ext = FeatExt().cuda() 106 | feat_ext.eval() 107 | for p in feat_ext.parameters(): 108 | p.requires_grad = False 109 | feats = [] 110 | for start_i in range(0, self.n_images): 111 | eval_batch = self.rgb_2xd[start_i:start_i + 1] 112 | feat2 = feat_ext(eval_batch.cuda())[2] # .detach().cpu() 113 | feats.append(feat2) 114 | self.feats = torch.cat(feats, dim=0) 115 | self.feats.requires_grad = False 116 | 117 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0]) 118 | object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0]) 119 | # Object scale mat: region of interest to **extract mesh** 120 | object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0'] 121 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ 
object_scale_mat @ object_bbox_min[:, None] 122 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None] 123 | self.object_bbox_min = object_bbox_min[:3, 0] 124 | self.object_bbox_max = object_bbox_max[:3, 0] 125 | 126 | print('Load data: End') 127 | 128 | def gen_rays_at(self, img_idx, resolution_level=1): 129 | """ 130 | Generate rays at world space from one camera. 131 | """ 132 | l = resolution_level 133 | tx = torch.linspace(0, self.W - 1, self.W // l) 134 | ty = torch.linspace(0, self.H - 1, self.H // l) 135 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 136 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 137 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 138 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 139 | rays_d = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3 140 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_d.shape) # W, H, 3 141 | return rays_o.transpose(0, 1), rays_d.transpose(0, 1) 142 | 143 | def gen_random_rays_at(self, img_idx, batch_size): 144 | """ 145 | Generate random rays at world space from one camera. 146 | """ 147 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) 148 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) 149 | color = self.images[img_idx.to(self.images.device)][(pixels_y.to(self.images.device), pixels_x.to(self.images.device))] # batch_size, 3 150 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3 151 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3 152 | rays_d_norm = torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) 153 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 154 | rays_d = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_d[:, :, None]).squeeze() # batch_size, 3 155 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_d.shape) # batch_size, 3 156 | 157 | id = self.pair['id_list'][img_idx] 158 | src_ids = self.pair[id]['pair'] 159 | src_idxs = [self.pair[src_id]['index'] for src_id in src_ids][:self.num_src] 160 | 161 | sample = {} 162 | sample['depth_cams'] = self.depth_cams[[img_idx]] 163 | sample['size'] = self.size 164 | sample['center'] = self.center 165 | sample["feat"] = self.feats[img_idx] 166 | sample["feat_src"] = self.feats[src_idxs] 167 | sample["cam"] = self.cams_hd[img_idx] 168 | sample["src_cams"] = self.cams_hd[src_idxs] 169 | sample['rays_d_norm'] = rays_d_norm 170 | sample['H'] = self.H 171 | sample['W'] = self.W 172 | sample['src_idxs'] = src_idxs 173 | 174 | return torch.cat([rays_o.cpu(), rays_d.cpu(), color], dim=-1).cuda() ,sample # batch_size, 9 175 | 176 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 177 | """ 178 | Interpolate pose between two cameras. 
179 | """ 180 | l = resolution_level 181 | tx = torch.linspace(0, self.W - 1, self.W // l) 182 | ty = torch.linspace(0, self.H - 1, self.H // l) 183 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 184 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 185 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 186 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 187 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 188 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 189 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 190 | pose_0 = np.linalg.inv(pose_0) 191 | pose_1 = np.linalg.inv(pose_1) 192 | rot_0 = pose_0[:3, :3] 193 | rot_1 = pose_1[:3, :3] 194 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 195 | key_times = [0, 1] 196 | slerp = Slerp(key_times, rots) 197 | rot = slerp(ratio) 198 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 199 | pose = pose.astype(np.float32) 200 | pose[:3, :3] = rot.as_matrix() 201 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 202 | pose = np.linalg.inv(pose) 203 | rot = torch.from_numpy(pose[:3, :3]).cuda() 204 | trans = torch.from_numpy(pose[:3, 3]).cuda() 205 | rays_d = torch.matmul(rot[None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3 206 | rays_o = trans[None, None, :3].expand(rays_d.shape) # W, H, 3 207 | return rays_o.transpose(0, 1), rays_d.transpose(0, 1) 208 | 209 | def gen_rays_between_from_pts(self, idx_0, idx_1, ratio, pts, resolution_level=1): 210 | """ 211 | Interpolate pose between two cameras. 212 | """ 213 | l = resolution_level 214 | tx = torch.linspace(0, self.W - 1, self.W // l) 215 | ty = torch.linspace(0, self.H - 1, self.H // l) 216 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 217 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 218 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 219 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 220 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 221 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 222 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 223 | pose_0 = np.linalg.inv(pose_0) 224 | pose_1 = np.linalg.inv(pose_1) 225 | rot_0 = pose_0[:3, :3] 226 | rot_1 = pose_1[:3, :3] 227 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 228 | key_times = [0, 1] 229 | slerp = Slerp(key_times, rots) 230 | rot = slerp(ratio) 231 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 232 | pose = pose.astype(np.float32) 233 | pose[:3, :3] = rot.as_matrix() 234 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 235 | pose = np.linalg.inv(pose) 236 | rot = torch.from_numpy(pose[:3, :3]).cuda() 237 | trans = torch.from_numpy(pose[:3, 3]).cuda() 238 | # rays_d = torch.matmul(rot[None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3 239 | # import pdb; pdb.set_trace() 240 | rays_o = trans[None, None, :3].expand(pts.shape) # 1, N, 3 241 | rays_d = F.normalize(pts - rays_o, dim=-1) # 1, N, 3 242 | return rays_o.squeeze(0), rays_d.squeeze(0) 243 | 244 | def near_far_from_sphere(self, rays_o, rays_d): 245 | a = torch.sum(rays_d**2, dim=-1, keepdim=True) 246 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 247 | mid = 0.5 * (-b) / a 248 | near = mid - 1.0 249 | far = mid + 1.0 250 | return near, far 251 | 252 | def image_at(self, idx, 
resolution_level): 253 | img = cv.imread(self.images_lis[idx]) 254 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 255 | 256 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | # Borrowed from https://github.com/bmild/nerf. 7 | class Embedder: 8 | def __init__(self, **kwargs): 9 | self.kwargs = kwargs 10 | self.create_embedding_fn() 11 | 12 | def create_embedding_fn(self): 13 | embed_fns = [] 14 | d = self.kwargs['input_dims'] 15 | out_dim = 0 16 | if self.kwargs['include_input']: 17 | embed_fns.append(lambda x: x) 18 | out_dim += d 19 | 20 | max_freq = self.kwargs['max_freq_log2'] 21 | N_freqs = self.kwargs['num_freqs'] 22 | 23 | if self.kwargs['log_sampling']: 24 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 25 | else: 26 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 27 | 28 | for freq in freq_bands: 29 | for p_fn in self.kwargs['periodic_fns']: 30 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 31 | out_dim += d 32 | 33 | self.embed_fns = embed_fns 34 | self.out_dim = out_dim 35 | 36 | def embed(self, inputs): 37 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 38 | 39 | 40 | def get_embedder(multires, input_dims=3): 41 | embed_kwargs = { 42 | 'include_input': True, 43 | 'input_dims': input_dims, 44 | 'max_freq_log2': multires-1, 45 | 'num_freqs': multires, 46 | 'log_sampling': True, 47 | 'periodic_fns': [torch.sin, torch.cos], 48 | } 49 | 50 | embedder_obj = Embedder(**embed_kwargs) 51 | def embed(x, eo=embedder_obj): return eo.embed(x) 52 | return embed, embedder_obj.out_dim 53 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | d_in, 12 | d_out, 13 | d_hidden, 14 | n_layers, 15 | skip_in=(4,), 16 | multires=0, 17 | bias=0.5, 18 | scale=1, 19 | geometric_init=True, 20 | weight_norm=True, 21 | inside_outside=False): 22 | super(SDFNetwork, self).__init__() 23 | 24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 25 | 26 | self.embed_fn_fine = None 27 | 28 | if multires > 0: 29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 30 | self.embed_fn_fine = embed_fn 31 | dims[0] = input_ch 32 | 33 | self.num_layers = len(dims) 34 | self.skip_in = skip_in 35 | self.scale = scale 36 | 37 | for l in range(0, self.num_layers - 1): 38 | if l + 1 in self.skip_in: 39 | out_dim = dims[l + 1] - dims[0] 40 | else: 41 | out_dim = dims[l + 1] 42 | 43 | lin = nn.Linear(dims[l], out_dim) 44 | 45 | if geometric_init: 46 | if l == self.num_layers - 2: 47 | if not inside_outside: 48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 49 | torch.nn.init.constant_(lin.bias, -bias) 50 | else: 51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 52 | torch.nn.init.constant_(lin.bias, bias) 53 | elif multires > 0 and l == 0: 54 | torch.nn.init.constant_(lin.bias, 0.0) 55 | 
torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 57 | elif multires > 0 and l in self.skip_in: 58 | torch.nn.init.constant_(lin.bias, 0.0) 59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 61 | else: 62 | torch.nn.init.constant_(lin.bias, 0.0) 63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 64 | 65 | if weight_norm: 66 | lin = nn.utils.weight_norm(lin) 67 | 68 | setattr(self, "lin" + str(l), lin) 69 | 70 | self.activation = nn.Softplus(beta=100) 71 | 72 | def forward(self, inputs): 73 | inputs = inputs * self.scale 74 | if self.embed_fn_fine is not None: 75 | inputs = self.embed_fn_fine(inputs) 76 | 77 | x = inputs 78 | for l in range(0, self.num_layers - 1): 79 | lin = getattr(self, "lin" + str(l)) 80 | 81 | if l in self.skip_in: 82 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 83 | 84 | x = lin(x) 85 | 86 | if l < self.num_layers - 2: 87 | x = self.activation(x) 88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 89 | 90 | def sdf(self, x): 91 | return self.forward(x)[:, :1] 92 | 93 | def sdf_hidden_appearance(self, x): 94 | return self.forward(x) 95 | 96 | def gradient(self, x): 97 | x.requires_grad_(True) 98 | y = self.sdf(x) 99 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 100 | gradients = torch.autograd.grad( 101 | outputs=y, 102 | inputs=x, 103 | grad_outputs=d_output, 104 | create_graph=True, 105 | retain_graph=True, 106 | only_inputs=True)[0] 107 | return gradients.unsqueeze(1) 108 | 109 | 110 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 111 | class RenderingNetwork(nn.Module): 112 | def __init__(self, 113 | d_feature, 114 | mode, 115 | d_in, 116 | d_out, 117 | d_hidden, 118 | n_layers, 119 | weight_norm=True, 120 | multires_view=0, 121 | squeeze_out=True): 122 | super().__init__() 123 | 124 | self.mode = mode 125 | self.squeeze_out = squeeze_out 126 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 127 | 128 | self.embedview_fn = None 129 | if multires_view > 0: 130 | embedview_fn, input_ch = get_embedder(multires_view) 131 | self.embedview_fn = embedview_fn 132 | dims[0] += (input_ch - 3) 133 | 134 | self.num_layers = len(dims) 135 | 136 | for l in range(0, self.num_layers - 1): 137 | out_dim = dims[l + 1] 138 | lin = nn.Linear(dims[l], out_dim) 139 | 140 | if weight_norm: 141 | lin = nn.utils.weight_norm(lin) 142 | 143 | setattr(self, "lin" + str(l), lin) 144 | 145 | self.relu = nn.ReLU() 146 | 147 | def forward(self, points, normals, view_dirs, feature_vectors): 148 | if self.embedview_fn is not None: 149 | view_dirs = self.embedview_fn(view_dirs) 150 | 151 | rendering_input = None 152 | 153 | if self.mode == 'idr': 154 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 155 | elif self.mode == 'no_view_dir': 156 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 157 | elif self.mode == 'no_normal': 158 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 159 | 160 | x = rendering_input 161 | 162 | for l in range(0, self.num_layers - 1): 163 | lin = getattr(self, "lin" + str(l)) 164 | 165 | x = lin(x) 166 | 167 | if l < self.num_layers - 2: 168 | x = self.relu(x) 169 | 170 | if self.squeeze_out: 171 | x = torch.sigmoid(x) 172 | return x 173 | 174 | 175 | # This implementation is borrowed from nerf-pytorch: 
https://github.com/yenchenlin/nerf-pytorch 176 | class NeRF(nn.Module): 177 | def __init__(self, 178 | D=8, 179 | W=256, 180 | d_in=3, 181 | d_in_view=3, 182 | multires=0, 183 | multires_view=0, 184 | output_ch=4, 185 | skips=[4], 186 | use_viewdirs=False): 187 | super(NeRF, self).__init__() 188 | self.D = D 189 | self.W = W 190 | self.d_in = d_in 191 | self.d_in_view = d_in_view 192 | self.input_ch = 3 193 | self.input_ch_view = 3 194 | self.embed_fn = None 195 | self.embed_fn_view = None 196 | 197 | if multires > 0: 198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 199 | self.embed_fn = embed_fn 200 | self.input_ch = input_ch 201 | 202 | if multires_view > 0: 203 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 204 | self.embed_fn_view = embed_fn_view 205 | self.input_ch_view = input_ch_view 206 | 207 | self.skips = skips 208 | self.use_viewdirs = use_viewdirs 209 | 210 | self.pts_linears = nn.ModuleList( 211 | [nn.Linear(self.input_ch, W)] + 212 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 213 | 214 | ### Implementation according to the official code release 215 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 216 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 217 | 218 | ### Implementation according to the paper 219 | # self.views_linears = nn.ModuleList( 220 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 221 | 222 | if use_viewdirs: 223 | self.feature_linear = nn.Linear(W, W) 224 | self.alpha_linear = nn.Linear(W, 1) 225 | self.rgb_linear = nn.Linear(W // 2, 3) 226 | else: 227 | self.output_linear = nn.Linear(W, output_ch) 228 | 229 | def forward(self, input_pts, input_views): 230 | if self.embed_fn is not None: 231 | input_pts = self.embed_fn(input_pts) 232 | if self.embed_fn_view is not None: 233 | input_views = self.embed_fn_view(input_views) 234 | 235 | h = input_pts 236 | for i, l in enumerate(self.pts_linears): 237 | h = self.pts_linears[i](h) 238 | h = F.relu(h) 239 | if i in self.skips: 240 | h = torch.cat([input_pts, h], -1) 241 | 242 | if self.use_viewdirs: 243 | alpha = self.alpha_linear(h) 244 | feature = self.feature_linear(h) 245 | h = torch.cat([feature, input_views], -1) 246 | 247 | for i, l in enumerate(self.views_linears): 248 | h = self.views_linears[i](h) 249 | h = F.relu(h) 250 | 251 | rgb = self.rgb_linear(h) 252 | return alpha, rgb 253 | else: 254 | assert False 255 | 256 | 257 | class SingleVarianceNetwork(nn.Module): 258 | def __init__(self, init_val): 259 | super(SingleVarianceNetwork, self).__init__() 260 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) 261 | 262 | def forward(self, x): 263 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 264 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import mcubes 5 | import tools.feat_utils as feat_utils 6 | 7 | 8 | # interpolate SDF zero-crossing points 9 | def find_surface_points(sdf, d_all, device='cuda'): 10 | # shape of sdf and d_all: only inside 11 | sdf_bool_1 = sdf[...,1:] * sdf[...,:-1] < 0 12 | # only find backward facing surface points, not forward facing 13 | sdf_bool_2 = sdf[...,1:] < sdf[...,:-1] 14 | sdf_bool = 
torch.logical_and(sdf_bool_1, sdf_bool_2) 15 | 16 | max, max_indices = torch.max(sdf_bool, dim=2) 17 | network_mask = max > 0 18 | d_surface = torch.zeros_like(network_mask, device=device).float() 19 | 20 | sdf_0 = torch.gather(sdf[network_mask], 1, max_indices[network_mask][..., None]).squeeze() 21 | sdf_1 = torch.gather(sdf[network_mask], 1, max_indices[network_mask][..., None]+1).squeeze() 22 | d_0 = torch.gather(d_all[network_mask], 1, max_indices[network_mask][..., None]).squeeze() 23 | d_1 = torch.gather(d_all[network_mask], 1, max_indices[network_mask][..., None]+1).squeeze() 24 | d_surface[network_mask] = (sdf_0 * d_1 - sdf_1 * d_0) / (sdf_0-sdf_1) 25 | 26 | return d_surface, network_mask 27 | 28 | def extract_fields(bound_min, bound_max, resolution, query_func): 29 | N = 64 30 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 31 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 32 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 33 | 34 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 35 | with torch.no_grad(): 36 | for xi, xs in enumerate(X): 37 | for yi, ys in enumerate(Y): 38 | for zi, zs in enumerate(Z): 39 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 40 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) 41 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 42 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 43 | return u 44 | 45 | 46 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): 47 | print('threshold: {}'.format(threshold)) 48 | u = extract_fields(bound_min, bound_max, resolution, query_func) 49 | vertices, triangles = mcubes.marching_cubes(u, threshold) 50 | b_max_np = bound_max.detach().cpu().numpy() 51 | b_min_np = bound_min.detach().cpu().numpy() 52 | 53 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 54 | return vertices, triangles 55 | 56 | 57 | def sample_pdf(bins, weights, n_samples, det=False): 58 | # This implementation is from NeRF 59 | # Get pdf 60 | weights = weights + 1e-5 # prevent nans 61 | pdf = weights / torch.sum(weights, -1, keepdim=True) 62 | cdf = torch.cumsum(pdf, -1) 63 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 64 | # Take uniform samples 65 | if det: 66 | u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 0.5 / n_samples, steps=n_samples) 67 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 68 | else: 69 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 70 | 71 | # Invert CDF 72 | u = u.contiguous() 73 | inds = torch.searchsorted(cdf, u, right=True) 74 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 75 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 76 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 77 | 78 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 79 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 80 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 81 | 82 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 83 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 84 | t = (u - cdf_g[..., 0]) / denom 85 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 86 | 87 | return samples 88 | 89 | 90 | class NeuSRenderer: 91 | def __init__(self, 92 | nerf, 93 | sdf_network, 94 | deviation_network, 95 | color_network, 96 | dataset, 97 | n_samples, 98 | n_importance, 99 | n_outside, 100 | up_sample_steps, 101 | perturb): 102 | self.nerf = nerf 103 | self.sdf_network = sdf_network 104 | self.deviation_network = deviation_network 105 | self.color_network = color_network 106 | self.dataset = dataset 107 | self.n_samples = n_samples 108 | self.n_importance = n_importance 109 | self.n_outside = n_outside 110 | self.up_sample_steps = up_sample_steps 111 | self.perturb = perturb 112 | 113 | self.feat_ext = feat_utils.FeatExt().cuda() 114 | self.feat_ext.eval() 115 | for p in self.feat_ext.parameters(): 116 | p.requires_grad = False 117 | 118 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None): 119 | """ 120 | Render background 121 | """ 122 | batch_size, n_samples = z_vals.shape 123 | 124 | # Section length 125 | dists = z_vals[..., 1:] - z_vals[..., :-1] 126 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 127 | mid_z_vals = z_vals + dists * 0.5 128 | 129 | 130 | # Section midpoints 131 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 132 | 133 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10) 134 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 135 | 136 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 137 | 138 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) 139 | dirs = dirs.reshape(-1, 3) 140 | 141 | density, sampled_color = nerf(pts, dirs) 142 | sampled_color = torch.sigmoid(sampled_color) 143 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists) 144 | alpha = alpha.reshape(batch_size, n_samples) 145 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 146 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3) 147 | color = (weights[:, :, None] * sampled_color).sum(dim=1) 148 | if background_rgb is not None: 149 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True)) 150 | 151 | return { 152 | 'color': color, 153 | 'sampled_color': sampled_color, 154 | 'alpha': alpha, 155 | 'weights': weights, 156 | 'mid_z_vals_out' : mid_z_vals 157 | } 158 | 159 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s): 160 | """ 161 | Up sampling give a fixed inv_s 162 | """ 163 | batch_size, n_samples = z_vals.shape 164 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 165 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False) 166 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0) 167 | sdf = sdf.reshape(batch_size, n_samples) 168 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] 169 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] 170 | mid_sdf = (prev_sdf + next_sdf) * 0.5 171 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 172 | 173 | # ---------------------------------------------------------------------------------------------------------- 174 | # Use min value of [ cos, prev_cos ] 175 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more 176 | # robust when meeting situations like below: 177 | # 178 | # SDF 179 | # ^ 180 | # |\ -----x----... 181 | # | \ / 182 | # | x x 183 | # |---\----/-------------> 0 level 184 | # | \ / 185 | # | \/ 186 | # | 187 | # ---------------------------------------------------------------------------------------------------------- 188 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1) 189 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) 190 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) 191 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere 192 | 193 | dist = (next_z_vals - prev_z_vals) 194 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 195 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5 196 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) 197 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s) 198 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 199 | weights = alpha * torch.cumprod( 200 | torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 201 | 202 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() 203 | return z_samples 204 | 205 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 206 | batch_size, n_samples = z_vals.shape 207 | _, n_importance = new_z_vals.shape 208 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] 209 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 210 | z_vals, index = torch.sort(z_vals, dim=-1) 211 | 212 | if not last: 213 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance) 214 | sdf = torch.cat([sdf, new_sdf], dim=-1) 215 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) 216 | index = index.reshape(-1) 217 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) 218 | 219 | return z_vals, sdf 220 | 221 | def render_color(self, 222 | rays_o, 223 | rays_d, 224 | z_vals, 225 | sample_dist, 226 | sdf_network, 227 | deviation_network, 228 | color_network, 229 | background_alpha=None, 230 | background_sampled_color=None, 231 | background_rgb=None, 232 | cos_anneal_ratio=0.0, 233 | ): 234 | 235 | batch_size, n_samples = z_vals.shape 236 | 237 | # Section length 238 | dists = z_vals[..., 1:] - z_vals[..., :-1] 239 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 240 | mid_z_vals = z_vals + dists * 0.5 241 | 242 | # Section midpoints 243 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 244 | dirs = rays_d[:, None, :].expand(pts.shape) 245 | pts = pts.reshape(-1, 3) 246 | dirs = dirs.reshape(-1, 3) 247 | 248 | sdf_nn_output = sdf_network(pts) 249 | sdf = sdf_nn_output[:, :1] 250 | feature_vector = sdf_nn_output[:, 1:] 251 | 252 | gradients = sdf_network.gradient(pts).squeeze() 253 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3) 254 | 255 | # inv_s in the code == s in the paper == 1 / standard deviation 256 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 257 | inv_s = inv_s.expand(batch_size * n_samples, 1) 258 | 259 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 260 | 261 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 262 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
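# Descriptive note on the block below (added for readability): this is the NeuS-style opacity
# estimate. The SDF at each section's two endpoints is approximated from the mid-point SDF and
# the annealed ray/normal cosine, mapped through the logistic CDF Phi_s(x) = sigmoid(s * x),
# and the discrete opacity is alpha_i = clip((Phi_s(sdf_prev) - Phi_s(sdf_next)) / Phi_s(sdf_prev), 0, 1);
# the 1e-5 terms only guard against division by zero.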
263 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 264 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 265 | 266 | # Estimate signed distances at section points 267 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 268 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 269 | 270 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 271 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 272 | 273 | p = prev_cdf - next_cdf 274 | c = prev_cdf 275 | 276 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 277 | 278 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 279 | inside_sphere = (pts_norm < 1.0).float().detach() 280 | 281 | # Render with background 282 | if background_alpha is not None: 283 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere) 284 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1) 285 | sampled_color = sampled_color * inside_sphere[:, :, None] +\ 286 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None] 287 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1) 288 | 289 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1] 290 | weights_sum = weights.sum(dim=-1, keepdim=True) 291 | 292 | color = (sampled_color * weights[:, :, None]).sum(dim=1) 293 | if background_rgb is not None: # Fixed background, usually black 294 | color = color + background_rgb * (1.0 - weights_sum) 295 | 296 | return color 297 | 298 | def render_core(self, 299 | rays_o, 300 | rays_d, 301 | z_vals, 302 | sample_dist, 303 | sdf_network, 304 | deviation_network, 305 | color_network, 306 | model_input = None, 307 | background_alpha=None, 308 | background_sampled_color=None, 309 | background_rgb=None, 310 | mid_z_vals_out=None, 311 | cos_anneal_ratio=0.0, 312 | depth_from_inside_only=None, 313 | ): 314 | 315 | batch_size, n_samples = z_vals.shape 316 | 317 | # Section length 318 | dists = z_vals[..., 1:] - z_vals[..., :-1] 319 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 320 | mid_z_vals = z_vals + dists * 0.5 321 | 322 | # Section midpoints 323 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 324 | dirs = rays_d[:, None, :].expand(pts.shape) 325 | pts = pts.reshape(-1, 3) 326 | dirs = dirs.reshape(-1, 3) 327 | 328 | query_pts = pts.clone() 329 | 330 | sdf_nn_output = sdf_network(pts) 331 | sdf = sdf_nn_output[:, :1] 332 | feature_vector = sdf_nn_output[:, 1:] 333 | 334 | gradients = sdf_network.gradient(pts).squeeze() 335 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3) 336 | 337 | # inv_s in the code == s in the paper == 1 / standard deviation 338 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 339 | inv_s = inv_s.expand(batch_size * n_samples, 1) 340 | 341 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 342 | 343 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 344 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
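# Descriptive note on the block below (added for readability): same opacity estimate as in
# render_color(). In addition, this path accumulates the Eikonal regularizer over a relaxed
# unit sphere and, when model_input is provided, (i) renders an expected depth per ray and
# pulls the SDF at those depth points toward zero (pseudo_pts_loss) and (ii) interpolates SDF
# zero-crossings into per-ray surface points, which render() later feeds to the multi-view
# feature-consistency loss.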
345 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + 346 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive 347 | 348 | # Estimate signed distances at section points 349 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 350 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 351 | 352 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 353 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 354 | 355 | p = prev_cdf - next_cdf 356 | c = prev_cdf 357 | 358 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 359 | alpha_in = alpha 360 | 361 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 362 | inside_sphere = (pts_norm < 1.0).float().detach() 363 | relax_inside_sphere = (pts_norm < 1.2).float().detach() 364 | 365 | # Render with background 366 | if background_alpha is not None: 367 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere) 368 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1) 369 | sampled_color = sampled_color * inside_sphere[:, :, None] +\ 370 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None] 371 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1) 372 | 373 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1] 374 | weights_sum = weights.sum(dim=-1, keepdim=True) 375 | if depth_from_inside_only: 376 | weights_in = alpha_in * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha_in + 1e-7], -1), -1)[:, :-1] 377 | weights_in_sum = weights_in.sum(dim=-1, keepdim=True) 378 | 379 | color = (sampled_color * weights[:, :, None]).sum(dim=1) 380 | if background_rgb is not None: # Fixed background, usually black 381 | color = color + background_rgb * (1.0 - weights_sum) 382 | 383 | # Eikonal loss 384 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, 385 | dim=-1) - 1.0) ** 2 386 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5) 387 | 388 | if model_input is not None: 389 | if background_alpha is not None: 390 | z_final = mid_z_vals_out 391 | else: 392 | z_final = mid_z_vals 393 | if depth_from_inside_only: 394 | z_final = mid_z_vals 395 | 396 | if depth_from_inside_only: 397 | dist_map = torch.sum(weights_in / (weights_in.sum(-1, keepdim=True)+1e-10) * z_final, -1) 398 | else: 399 | dist_map = torch.sum(weights / (weights.sum(-1, keepdim=True)+1e-10) * z_final, -1) 400 | 401 | sdf_all = sdf.reshape(batch_size,n_samples).unsqueeze(0) 402 | d_all = mid_z_vals.unsqueeze(0) 403 | d_surface, network_mask = find_surface_points(sdf_all, d_all) 404 | d_surface = d_surface.squeeze(0) 405 | network_mask = network_mask.squeeze(0) 406 | 407 | object_mask = network_mask 408 | 409 | point_surface = rays_o + rays_d * d_surface[:,None] 410 | point_surface_wmask = point_surface[network_mask & object_mask] 411 | 412 | points_rendered = rays_o + rays_d * dist_map[:,None] 413 | sdf_rendered_points = sdf_network(points_rendered)[:, :1] 414 | sdf_rendered_points_wmask = sdf_rendered_points[object_mask] 415 | sdf_rendered_points_0 = torch.zeros_like(sdf_rendered_points_wmask) 416 | pseudo_pts_loss = F.l1_loss(sdf_rendered_points_wmask, sdf_rendered_points_0, reduction='mean') 417 | 418 | return { 419 | 'color': color, 420 | 'sdf': sdf, 421 | 'dists': dists, 422 | 'gradients': 
gradients.reshape(batch_size, n_samples, 3), 423 | 's_val': 1.0 / inv_s, 424 | 'mid_z_vals': mid_z_vals, 425 | 'weights': weights, 426 | 'cdf': c.reshape(batch_size, n_samples), 427 | 'gradient_error': gradient_error, 428 | 'inside_sphere': inside_sphere, 429 | 'pseudo_pts_loss': pseudo_pts_loss, 430 | 'query_pts': query_pts, 431 | 'point_surface': point_surface_wmask, 432 | 'network_mask': network_mask, 433 | 'object_mask': object_mask, 434 | } 435 | else: 436 | return { 437 | 'color': color, 438 | 'sdf': sdf, 439 | 'dists': dists, 440 | 'gradients': gradients.reshape(batch_size, n_samples, 3), 441 | 's_val': 1.0 / inv_s, 442 | 'mid_z_vals': mid_z_vals, 443 | 'weights': weights, 444 | 'cdf': c.reshape(batch_size, n_samples), 445 | 'gradient_error': gradient_error, 446 | 'inside_sphere': inside_sphere, 447 | 'pseudo_pts_loss': torch.tensor(0.0).float(), 448 | 'query_pts': query_pts, 449 | } 450 | 451 | def render(self, 452 | rays_o, 453 | rays_d, 454 | near, 455 | far, 456 | main_img_idx, 457 | t, 458 | random_pcd=None, 459 | perturb_overwrite=-1, 460 | background_rgb=None, 461 | cos_anneal_ratio=0.0, 462 | model_input=None, 463 | depth_from_inside_only=False, 464 | ): 465 | batch_size = len(rays_o) 466 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere 467 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 468 | z_vals = near + (far - near) * z_vals[None, :] 469 | 470 | z_vals_outside = None 471 | if self.n_outside > 0: 472 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 473 | 474 | n_samples = self.n_samples 475 | perturb = self.perturb 476 | 477 | if perturb_overwrite >= 0: 478 | perturb = perturb_overwrite 479 | if perturb > 0: 480 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 481 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 482 | 483 | if self.n_outside > 0: 484 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1]) 485 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 486 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 487 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 488 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand 489 | 490 | if self.n_outside > 0: 491 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 492 | 493 | background_alpha = None 494 | background_sampled_color = None 495 | mid_z_vals_out = None 496 | 497 | # Up sample 498 | if self.n_importance > 0: 499 | with torch.no_grad(): 500 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] 501 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) 502 | 503 | for i in range(self.up_sample_steps): 504 | new_z_vals = self.up_sample(rays_o, 505 | rays_d, 506 | z_vals, 507 | sdf, 508 | self.n_importance // self.up_sample_steps, 509 | 64 * 2**i) 510 | z_vals, sdf = self.cat_z_vals(rays_o, 511 | rays_d, 512 | z_vals, 513 | new_z_vals, 514 | sdf, 515 | last=(i + 1 == self.up_sample_steps)) 516 | 517 | n_samples = self.n_samples + self.n_importance 518 | 519 | # Background model 520 | if self.n_outside > 0: 521 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 522 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 523 | 524 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf) 525 | 526 | background_sampled_color = ret_outside['sampled_color'] 527 | background_alpha = ret_outside['alpha'] 528 | mid_z_vals_out= ret_outside['mid_z_vals_out'] 529 | 530 | # 
Render core 531 | ret_fine = self.render_core(rays_o, 532 | rays_d, 533 | z_vals, 534 | sample_dist, 535 | self.sdf_network, 536 | self.deviation_network, 537 | self.color_network, 538 | model_input = model_input, 539 | background_rgb=background_rgb, 540 | background_alpha=background_alpha, 541 | background_sampled_color=background_sampled_color, 542 | mid_z_vals_out= mid_z_vals_out, 543 | cos_anneal_ratio=cos_anneal_ratio, 544 | depth_from_inside_only=depth_from_inside_only) 545 | 546 | color_fine = ret_fine['color'] 547 | weights = ret_fine['weights'] 548 | weights_sum = weights.sum(dim=-1, keepdim=True) 549 | gradients = ret_fine['gradients'] 550 | s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True) 551 | 552 | local_loss = torch.tensor(0).float() 553 | if model_input is not None: 554 | output = { 555 | 'color_fine': color_fine, 556 | 's_val': s_val, 557 | 'cdf_fine': ret_fine['cdf'], 558 | 'weight_sum': weights_sum, 559 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0], 560 | 'gradients': gradients, 561 | 'weights': weights, 562 | 'gradient_error': ret_fine['gradient_error'], 563 | 'inside_sphere': ret_fine['inside_sphere'], 564 | 'pseudo_pts_loss': ret_fine['pseudo_pts_loss'], 565 | 'sdf': ret_fine['sdf'], 566 | 'query_pts': ret_fine['query_pts'], 567 | } 568 | 569 | point_surface_wmask = ret_fine['point_surface'] 570 | network_mask = ret_fine['network_mask'] 571 | object_mask = ret_fine['object_mask'] 572 | 573 | size, center = model_input['size'].unsqueeze(0), model_input['center'].unsqueeze(0) 574 | size = size[:1] 575 | center = center[:1] 576 | 577 | cam = model_input['cam'] # 2, 4, 4 578 | src_cams = model_input['src_cams'] # m, 2, 4, 4 579 | feat_src = model_input['feat_src'] 580 | 581 | if (t % 100 == 0) and (random_pcd is not None): 582 | ''' unseen view rendering ''' 583 | random_pcd = random_pcd.view(1, -1, 3) 584 | random_pcd.requires_grad = False 585 | 586 | src_img_idx = model_input['src_idxs'][0] 587 | rays_o, rays_d = self.dataset.gen_rays_between_from_pts(main_img_idx, 588 | src_img_idx, 589 | 0.5, 590 | random_pcd, 591 | ) 592 | 593 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) 594 | 595 | batch_size = len(rays_o) 596 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere 597 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 598 | z_vals = near + (far - near) * z_vals[None, :] 599 | 600 | z_vals_outside = None 601 | if self.n_outside > 0: 602 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 603 | 604 | n_samples = self.n_samples 605 | perturb = self.perturb 606 | 607 | if perturb_overwrite >= 0: 608 | perturb = perturb_overwrite 609 | if perturb > 0: 610 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 611 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 612 | 613 | if self.n_outside > 0: 614 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1]) 615 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 616 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 617 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 618 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand 619 | 620 | if self.n_outside > 0: 621 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 622 | 623 | background_alpha = None 624 | background_sampled_color = None 625 | mid_z_vals_out = None 626 | 627 | # Up sample 628 | if self.n_importance > 0: 629 | with torch.no_grad(): 630 | pts = 
rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] 631 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) 632 | 633 | for i in range(self.up_sample_steps): 634 | new_z_vals = self.up_sample(rays_o, 635 | rays_d, 636 | z_vals, 637 | sdf, 638 | self.n_importance // self.up_sample_steps, 639 | 64 * 2**i) 640 | z_vals, sdf = self.cat_z_vals(rays_o, 641 | rays_d, 642 | z_vals, 643 | new_z_vals, 644 | sdf, 645 | last=(i + 1 == self.up_sample_steps)) 646 | 647 | n_samples = self.n_samples + self.n_importance 648 | 649 | # Background model 650 | if self.n_outside > 0: 651 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 652 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 653 | 654 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf) 655 | 656 | background_sampled_color = ret_outside['sampled_color'] 657 | background_alpha = ret_outside['alpha'] 658 | mid_z_vals_out= ret_outside['mid_z_vals_out'] 659 | 660 | color = self.render_color(rays_o, 661 | rays_d, 662 | z_vals, 663 | sample_dist, 664 | self.sdf_network, 665 | self.deviation_network, 666 | self.color_network, 667 | background_rgb=background_rgb, 668 | background_alpha=background_alpha, 669 | background_sampled_color=background_sampled_color, 670 | cos_anneal_ratio=cos_anneal_ratio, 671 | ) 672 | 673 | us_pose = cam.clone() 674 | us_pose[0] = feat_utils.gen_camera_between(cam[0].cpu().numpy(), src_cams[0, 0].cpu().numpy(), 0.5) 675 | us_pose.requires_grad = False 676 | us_pose = us_pose.unsqueeze(0) # 2, 4, 4 677 | 678 | us_rgb = torch.zeros([1, 3, 768, 1024]).cuda() 679 | 680 | pts_world = random_pcd.view(1, -1, 1, 3, 1) 681 | pts_world = torch.cat([pts_world, torch.ones_like(pts_world[..., -1:, :])], dim=-2) 682 | pts_img = feat_utils.idx_cam2img(feat_utils.idx_world2cam(pts_world, us_pose), us_pose).view(1, -1, 3) # 1, N, 3 683 | us_uv = pts_img[..., :2] / pts_img[..., 2:3] 684 | us_uv = us_uv.round().long() 685 | 686 | color_mask = ((us_uv[..., 0] > -1) & (us_uv[..., 0] < 1024) & (us_uv[..., 1] > -1) & (us_uv[..., 1] < 768)).squeeze(0) 687 | 688 | us_uv = us_uv[0, color_mask] # M, 2 689 | color = color[color_mask] # M, 3 690 | 691 | _, cnts = torch.unique(us_uv, sorted=False, return_counts=True, dim=0) 692 | cnts = torch.cat((torch.tensor([0]).long().cuda(), cnts)) 693 | unique_index = torch.cumsum(cnts, dim=0) 694 | unique_index = unique_index[:-1] 695 | 696 | us_uv = us_uv[unique_index] 697 | color = color[unique_index].transpose(0, 1) 698 | 699 | us_rgb[0, :, us_uv[:, 1], us_uv[:, 0]] = color 700 | 701 | us_feat = self.feat_ext(us_rgb)[2] 702 | 703 | local_loss += feat_utils.get_local_loss(random_pcd.reshape(-1, 3), None, us_feat, 704 | us_pose, feat_src.unsqueeze(0), src_cams.unsqueeze(0), 705 | 2 * torch.ones_like(size).cuda(), torch.zeros_like(center).cuda(), 706 | color_mask.reshape(-1), color_mask.reshape(-1)) 707 | 708 | local_loss += feat_utils.get_local_loss(point_surface_wmask, None, model_input['feat'].unsqueeze(0), 709 | cam.unsqueeze(0), feat_src.unsqueeze(0), src_cams.unsqueeze(0), 710 | size, center, network_mask.reshape(-1), 711 | object_mask.reshape(-1)) 712 | 713 | output['local_loss'] = local_loss 714 | return output 715 | 716 | else: 717 | return { 718 | 'color_fine': color_fine, 719 | 's_val': s_val, 720 | 'cdf_fine': ret_fine['cdf'], 721 | 'weight_sum': weights_sum, 722 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0], 723 | 'gradients': gradients, 724 | 'weights': weights, 725 | 'gradient_error': 
ret_fine['gradient_error'], 726 | 'inside_sphere': ret_fine['inside_sphere'], 727 | 'pseudo_pts_loss': ret_fine['pseudo_pts_loss'], 728 | 'local_loss': local_loss, 729 | 'sdf': ret_fine['sdf'], 730 | 'query_pts': ret_fine['query_pts'], 731 | } 732 | 733 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0): 734 | return extract_geometry(bound_min, 735 | bound_max, 736 | resolution=resolution, 737 | threshold=threshold, 738 | query_func=lambda pts: -self.sdf_network.sdf(pts)) 739 | -------------------------------------------------------------------------------- /models/udf_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | from scipy.spatial import cKDTree 6 | import trimesh 7 | 8 | def search_nearest_point(point_batch, point_gt): 9 | num_point_batch, num_point_gt = point_batch.shape[0], point_gt.shape[0] 10 | point_batch = point_batch.unsqueeze(1).repeat(1, num_point_gt, 1) 11 | point_gt = point_gt.unsqueeze(0).repeat(num_point_batch, 1, 1) 12 | 13 | distances = torch.sqrt(torch.sum((point_batch-point_gt) ** 2, axis=-1) + 1e-12) 14 | dis_idx = torch.argmin(distances, axis=1).detach().cpu().numpy() 15 | 16 | return dis_idx 17 | 18 | def process_data(data_dir, dataname): 19 | if os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.ply'): 20 | pointcloud = trimesh.load(os.path.join(data_dir, 'pcd', dataname) + '.ply').vertices 21 | pointcloud = np.asarray(pointcloud) 22 | elif os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.xyz'): 23 | pointcloud = np.loadtxt(os.path.join(data_dir, 'pcd', dataname) + '.xyz') 24 | elif os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.npy'): 25 | pointcloud = np.load(os.path.join(data_dir, 'pcd', dataname) + '.npy') 26 | else: 27 | print('Only support .ply, .xyz or .npy data. 
Please adjust your data format.') 28 | exit() 29 | shape_scale = np.max([np.max(pointcloud[:,0])-np.min(pointcloud[:,0]),np.max(pointcloud[:,1])-np.min(pointcloud[:,1]),np.max(pointcloud[:,2])-np.min(pointcloud[:,2])]) 30 | shape_center = [(np.max(pointcloud[:,0])+np.min(pointcloud[:,0]))/2, (np.max(pointcloud[:,1])+np.min(pointcloud[:,1]))/2, (np.max(pointcloud[:,2])+np.min(pointcloud[:,2]))/2] 31 | pointcloud = pointcloud - shape_center 32 | pointcloud = pointcloud / shape_scale 33 | 34 | POINT_NUM = pointcloud.shape[0] // 60 35 | POINT_NUM_GT = pointcloud.shape[0] // 60 * 60 36 | QUERY_EACH = 1000000//POINT_NUM_GT 37 | 38 | point_idx = np.random.choice(pointcloud.shape[0], POINT_NUM_GT, replace = False) 39 | pointcloud = pointcloud[point_idx,:] 40 | ptree = cKDTree(pointcloud) 41 | sigmas = [] 42 | for p in np.array_split(pointcloud,100,axis=0): 43 | d = ptree.query(p,51) 44 | sigmas.append(d[0][:,-1]) 45 | 46 | sigmas = np.concatenate(sigmas) 47 | sample = [] 48 | sample_near = [] 49 | 50 | for i in range(QUERY_EACH): 51 | scale = 0.25 if 0.25 * np.sqrt(POINT_NUM_GT / 20000) < 0.25 else 0.25 * np.sqrt(POINT_NUM_GT / 20000) 52 | tt = pointcloud + scale*np.expand_dims(sigmas,-1) * np.random.normal(0.0, 1.0, size=pointcloud.shape) 53 | sample.append(tt) 54 | tt = tt.reshape(-1,POINT_NUM,3) 55 | 56 | sample_near_tmp = [] 57 | for j in range(tt.shape[0]): 58 | nearest_idx = search_nearest_point(torch.tensor(tt[j]).float().cuda(), torch.tensor(pointcloud).float().cuda()) 59 | nearest_points = pointcloud[nearest_idx] 60 | nearest_points = np.asarray(nearest_points).reshape(-1,3) 61 | sample_near_tmp.append(nearest_points) 62 | sample_near_tmp = np.asarray(sample_near_tmp) 63 | sample_near_tmp = sample_near_tmp.reshape(-1,3) 64 | sample_near.append(sample_near_tmp) 65 | 66 | sample = np.asarray(sample) 67 | sample_near = np.asarray(sample_near) 68 | 69 | os.makedirs(os.path.join(data_dir, 'query_data'), exist_ok=True) 70 | np.savez(os.path.join(data_dir, 'query_data', dataname)+'.npz', sample = sample, point = pointcloud, sample_near = sample_near) 71 | 72 | class Dataset: 73 | def __init__(self, conf, dataname): 74 | super(Dataset, self).__init__() 75 | print('Load data: Begin') 76 | self.device = torch.device('cuda') 77 | self.conf = conf 78 | 79 | self.data_dir = conf.get_string('data_dir').replace('CASE_NAME', dataname) 80 | print(self.data_dir) 81 | self.data_name = dataname + '.npz' 82 | 83 | if os.path.exists(os.path.join(self.data_dir, 'query_data', self.data_name)): 84 | print('Query data existing. Loading data...') 85 | else: 86 | print('Query data not found. 
Processing data...') 87 | process_data(self.data_dir, dataname) 88 | 89 | load_data = np.load(os.path.join(self.data_dir, 'query_data', self.data_name)) 90 | 91 | self.point = np.asarray(load_data['sample_near']).reshape(-1,3) 92 | self.sample = np.asarray(load_data['sample']).reshape(-1,3) 93 | self.point_gt = np.asarray(load_data['point']).reshape(-1,3) 94 | self.sample_points_num = self.sample.shape[0]-1 95 | 96 | self.object_bbox_min = np.array([np.min(self.point[:,0]), np.min(self.point[:,1]), np.min(self.point[:,2])]) -0.05 97 | self.object_bbox_max = np.array([np.max(self.point[:,0]), np.max(self.point[:,1]), np.max(self.point[:,2])]) +0.05 98 | print('bd:',self.object_bbox_min,self.object_bbox_max) 99 | 100 | self.point = torch.from_numpy(self.point).to(self.device).float() 101 | self.sample = torch.from_numpy(self.sample).to(self.device).float() 102 | self.point_gt = torch.from_numpy(self.point_gt).to(self.device).float() 103 | 104 | print('NP Load data: End') 105 | 106 | def get_train_data(self, batch_size): 107 | index_coarse = np.random.choice(10, 1) 108 | index_fine = np.random.choice(self.sample_points_num//10, batch_size, replace = False) 109 | index = index_fine * 10 + index_coarse # for accelerating random choice operation 110 | points = self.point[index] 111 | sample = self.sample[index] 112 | return points, sample, self.point_gt 113 | 114 | def gen_new_data(self, tree): 115 | distance, index = tree.query(self.sample.detach().cpu().numpy(), 1) 116 | self.point_new = tree.data[index] 117 | self.point_new = torch.from_numpy(self.point_new).to(self.device).float() 118 | 119 | 120 | def get_train_data_step2(self, batch_size): 121 | index_coarse = np.random.choice(10, 1) 122 | index_fine = np.random.choice(self.sample_points_num//10, batch_size, replace = False) 123 | index = index_fine * 10 + index_coarse # for accelerating random choice operation 124 | points = self.point_new[index] 125 | sample = self.sample[index] 126 | return points, sample, self.point_gt 127 | 128 | -------------------------------------------------------------------------------- /models/udf_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': True, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /models/udf_fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | class UDFNetwork(nn.Module): 8 | def __init__(self, 9 | d_in, 10 | d_out, 11 | d_hidden, 12 | n_layers, 13 | skip_in=(4,), 14 | multires=0, 15 | bias=0.5, 16 | scale=1, 17 | geometric_init=True, 18 | weight_norm=True, 19 | inside_outside=False): 20 | super(UDFNetwork, self).__init__() 21 | 22 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 23 | 24 | self.embed_fn_fine = None 25 | 26 | if multires > 0: 27 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 28 | self.embed_fn_fine = embed_fn 29 | dims[0] = input_ch 30 | 31 | self.num_layers = len(dims) 32 | self.skip_in = skip_in 33 | self.scale = scale 34 | 35 | for l in range(0, self.num_layers - 1): 36 | if l + 1 in self.skip_in: 37 | out_dim = dims[l + 1] - dims[0] 38 | else: 39 | out_dim = dims[l + 1] 40 | 41 | lin = nn.Linear(dims[l], out_dim) 42 | 43 | if geometric_init: 44 | if l == self.num_layers - 2: 45 | if not inside_outside: 46 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 47 | torch.nn.init.constant_(lin.bias, -bias) 48 | else: 49 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 50 | torch.nn.init.constant_(lin.bias, bias) 51 | elif multires > 0 and l == 0: 52 | torch.nn.init.constant_(lin.bias, 0.0) 53 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 54 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 55 | elif multires > 0 and l in self.skip_in: 56 | torch.nn.init.constant_(lin.bias, 0.0) 57 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 58 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 59 | else: 60 | torch.nn.init.constant_(lin.bias, 0.0) 61 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 62 | 63 | if weight_norm: 64 | lin = nn.utils.weight_norm(lin) 65 | 66 | setattr(self, "lin" + str(l), lin) 67 | 68 | #self.activation = nn.Softplus(beta=100) 69 | self.activation = nn.ReLU() 70 | 71 | self.act_last = nn.Sigmoid() 72 | 73 | def forward(self, inputs): 74 | inputs = inputs * self.scale 75 | if self.embed_fn_fine is not None: 76 | inputs = self.embed_fn_fine(inputs) 77 | 78 | x = inputs 79 | for l in range(0, self.num_layers - 1): 80 | lin = getattr(self, "lin" + str(l)) 81 | 82 | if l in self.skip_in: 83 | x = torch.cat([x, 
inputs], 1) / np.sqrt(2) 84 | 85 | x = lin(x) 86 | 87 | if l < self.num_layers - 2: 88 | x = self.activation(x) 89 | 90 | # x = self.act_last(x) 91 | res = torch.abs(x) 92 | # res = 1 - torch.exp(-x) 93 | return res / self.scale 94 | 95 | def udf(self, x): 96 | return self.forward(x) 97 | 98 | def udf_hidden_appearance(self, x): 99 | return self.forward(x) 100 | 101 | def gradient(self, x): 102 | x.requires_grad_(True) 103 | y = self.udf(x) 104 | # y.requires_grad_(True) 105 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 106 | gradients = torch.autograd.grad( 107 | outputs=y, 108 | inputs=x, 109 | grad_outputs=d_output, 110 | create_graph=True, 111 | retain_graph=True, 112 | only_inputs=True)[0] 113 | return gradients.unsqueeze(1) 114 | 115 | -------------------------------------------------------------------------------- /pretrained_model/vismvsnet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/pretrained_model/vismvsnet.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.65.0 2 | pyhocon==0.3.57 3 | trimesh==3.22.5 4 | PyMCubes==0.1.4 5 | scipy==1.10.1 6 | point_cloud_utils==0.29.7 7 | icecream==2.1.3 8 | opencv-python==4.7.0.72 9 | tensorboard==2.12.1 -------------------------------------------------------------------------------- /tools/feat_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import List, Union, Tuple 6 | from collections import OrderedDict 7 | from scipy.spatial.transform import Rotation as Rot 8 | from scipy.spatial.transform import Slerp 9 | 10 | def scale_camera(cam: Union[np.ndarray, torch.Tensor], scale: Union[Tuple, float]=1): 11 | """ resize input in order to produce sampled depth map """ 12 | if type(scale) != tuple: 13 | scale = (scale, scale) 14 | if type(cam) == np.ndarray: 15 | new_cam = np.copy(cam) 16 | # focal: 17 | new_cam[1, 0, 0] = cam[1, 0, 0] * scale[0] 18 | new_cam[1, 1, 1] = cam[1, 1, 1] * scale[1] 19 | # principle point: 20 | new_cam[1, 0, 2] = cam[1, 0, 2] * scale[0] 21 | new_cam[1, 1, 2] = cam[1, 1, 2] * scale[1] 22 | elif type(cam) == torch.Tensor: 23 | new_cam = cam.clone() 24 | # focal: 25 | new_cam[..., 1, 0, 0] = cam[..., 1, 0, 0] * scale[0] 26 | new_cam[..., 1, 1, 1] = cam[..., 1, 1, 1] * scale[1] 27 | # principle point: 28 | new_cam[..., 1, 0, 2] = cam[..., 1, 0, 2] * scale[0] 29 | new_cam[..., 1, 1, 2] = cam[..., 1, 1, 2] * scale[1] 30 | else: 31 | raise TypeError 32 | return new_cam 33 | 34 | 35 | def bin_op_reduce(lst, func): 36 | result = lst[0] 37 | for i in range(1, len(lst)): 38 | result = func(result, lst[i]) 39 | return result 40 | 41 | 42 | def idx_world2cam(idx_world_homo, cam): 43 | """nhw41 -> nhw41""" 44 | idx_cam_homo = cam[:,0:1,...].unsqueeze(1) @ idx_world_homo # nhw41 45 | idx_cam_homo = idx_cam_homo / (idx_cam_homo[...,-1:,:]+1e-9) # nhw41 46 | return idx_cam_homo 47 | 48 | 49 | def idx_cam2img(idx_cam_homo, cam): 50 | """nhw41 -> nhw31""" 51 | idx_cam = idx_cam_homo[...,:3,:] / (idx_cam_homo[...,3:4,:]+1e-9) # nhw31 52 | idx_img_homo = cam[:,1:2,:3,:3].unsqueeze(1) @ idx_cam # nhw31 53 | idx_img_homo = idx_img_homo / (idx_img_homo[...,-1:,:]+1e-9) 54 | return 
idx_img_homo 55 | 56 | 57 | 58 | def normalize_for_grid_sample(input_, grid): 59 | size = torch.tensor(input_.size())[2:].flip(0).to(grid.dtype).to(grid.device).view(1,1,1,-1) # [[[w, h]]] 60 | grid_n = grid / size 61 | grid_n = (grid_n * 2 - 1).clamp(-1.1, 1.1) 62 | return grid_n 63 | 64 | 65 | def get_in_range(grid): 66 | """after normalization, keepdim=False""" 67 | masks = [] 68 | for dim in range(grid.size()[-1]): 69 | masks += [grid[..., dim]<=1, grid[..., dim]>=-1] 70 | in_range = bin_op_reduce(masks, torch.min).to(grid.dtype) 71 | return in_range 72 | 73 | 74 | def load_pair(file: str): 75 | with open(file) as f: 76 | lines = f.readlines() 77 | n_cam = int(lines[0]) 78 | pairs = {} 79 | img_ids = [] 80 | for i in range(1, 1+2*n_cam, 2): 81 | pair = [] 82 | score = [] 83 | img_id = lines[i].strip() 84 | pair_str = lines[i+1].strip().split(' ') 85 | n_pair = int(pair_str[0]) 86 | for j in range(1, 1+2*n_pair, 2): 87 | pair.append(pair_str[j]) 88 | score.append(float(pair_str[j+1])) 89 | img_ids.append(img_id) 90 | pairs[img_id] = {'id': img_id, 'index': i//2, 'pair': pair, 'score': score} 91 | pairs['id_list'] = img_ids 92 | return pairs 93 | 94 | 95 | def load_cam(file: str, max_d, interval_scale=1, override=False): 96 | """ read camera txt file """ 97 | cam = np.zeros((2, 4, 4)) 98 | with open(file) as f: 99 | words = f.read().split() 100 | # read extrinsic 101 | for i in range(0, 4): 102 | for j in range(0, 4): 103 | extrinsic_index = 4 * i + j + 1 104 | cam[0][i][j] = words[extrinsic_index] 105 | 106 | # read intrinsic 107 | for i in range(0, 3): 108 | for j in range(0, 3): 109 | intrinsic_index = 3 * i + j + 18 110 | cam[1][i][j] = words[intrinsic_index] 111 | 112 | if len(words) == 29: 113 | cam[1][3][0] = words[27] 114 | cam[1][3][1] = float(words[28]) * interval_scale 115 | cam[1][3][2] = max_d 116 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (cam[1][3][2] - 1) 117 | elif len(words) == 30: 118 | cam[1][3][0] = words[27] 119 | cam[1][3][1] = float(words[28]) * interval_scale 120 | cam[1][3][2] = words[29] 121 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (cam[1][3][2] - 1) 122 | elif len(words) == 31: 123 | if override: 124 | cam[1][3][0] = words[27] 125 | cam[1][3][1] = (float(words[30]) - float(words[27])) / (max_d - 1) 126 | cam[1][3][2] = max_d 127 | cam[1][3][3] = words[30] 128 | else: 129 | cam[1][3][0] = words[27] 130 | cam[1][3][1] = float(words[28]) * interval_scale 131 | cam[1][3][2] = words[29] 132 | cam[1][3][3] = words[30] 133 | else: 134 | cam[1][3][0] = 0 135 | cam[1][3][1] = 0 136 | cam[1][3][2] = 0 137 | cam[1][3][3] = 0 138 | 139 | return cam 140 | 141 | class ListModule(nn.Module): 142 | def __init__(self, modules: Union[List, OrderedDict]): 143 | super(ListModule, self).__init__() 144 | if isinstance(modules, OrderedDict): 145 | iterable = modules.items() 146 | elif isinstance(modules, list): 147 | iterable = enumerate(modules) 148 | else: 149 | raise TypeError('modules should be OrderedDict or List.') 150 | for name, module in iterable: 151 | if not isinstance(module, nn.Module): 152 | module = ListModule(module) 153 | if not isinstance(name, str): 154 | name = str(name) 155 | self.add_module(name, module) 156 | 157 | def __getitem__(self, idx): 158 | if idx < 0 or idx >= len(self._modules): 159 | raise IndexError('index {} is out of range'.format(idx)) 160 | it = iter(self._modules.values()) 161 | for i in range(idx): 162 | next(it) 163 | return next(it) 164 | 165 | def __iter__(self): 166 | return iter(self._modules.values()) 167 | 168 | def 
__len__(self): 169 | return len(self._modules) 170 | 171 | 172 | class BasicBlock(nn.Module): 173 | expansion = 1 174 | 175 | def __init__(self, inplanes, planes, stride=1, downsample=None, dim=2): 176 | super(BasicBlock, self).__init__() 177 | 178 | self.conv_fn = nn.Conv2d if dim == 2 else nn.Conv3d 179 | self.bn_fn = nn.BatchNorm2d if dim == 2 else nn.BatchNorm3d 180 | 181 | self.conv1 = self.conv3x3(inplanes, planes, stride) 182 | self.bn1 = self.bn_fn(planes) 183 | self.relu = nn.ReLU(inplace=True) 184 | self.conv2 = self.conv3x3(planes, planes) 185 | self.bn2 = self.bn_fn(planes) 186 | self.downsample = downsample 187 | self.stride = stride 188 | 189 | def conv1x1(self, in_planes, out_planes, stride=1): 190 | """1x1 convolution""" 191 | return self.conv_fn(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 192 | 193 | def conv3x3(self, in_planes, out_planes, stride=1): 194 | """3x3 convolution with padding""" 195 | return self.conv_fn(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 196 | 197 | def forward(self, x): 198 | residual = x 199 | 200 | out = self.conv1(x) 201 | out = self.bn1(out) 202 | out = self.relu(out) 203 | 204 | out = self.conv2(out) 205 | out = self.bn2(out) 206 | 207 | if self.downsample is not None: 208 | residual = self.downsample(x) 209 | 210 | out += residual 211 | out = self.relu(out) 212 | 213 | return out 214 | 215 | 216 | def _make_layer(inplanes, block, planes, blocks, stride=1, dim=2): 217 | downsample = None 218 | conv_fn = nn.Conv2d if dim==2 else nn.Conv3d 219 | bn_fn = nn.BatchNorm2d if dim==2 else nn.BatchNorm3d 220 | if stride != 1 or inplanes != planes * block.expansion: 221 | downsample = nn.Sequential( 222 | conv_fn(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 223 | bn_fn(planes * block.expansion) 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(inplanes, planes, stride, downsample, dim=dim)) 228 | inplanes = planes * block.expansion 229 | for _ in range(1, blocks): 230 | layers.append(block(inplanes, planes, dim=dim)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | 235 | class UNet(nn.Module): 236 | 237 | def __init__(self, inplanes: int, enc: int, dec: int, initial_scale: int, 238 | bottom_filters: List[int], filters: List[int], head_filters: List[int], 239 | prefix: str, dim: int=2): 240 | super(UNet, self).__init__() 241 | 242 | conv_fn = nn.Conv2d if dim==2 else nn.Conv3d 243 | deconv_fn = nn.ConvTranspose2d if dim==2 else nn.ConvTranspose3d 244 | current_scale = initial_scale 245 | idx = 0 246 | prev_f = inplanes 247 | 248 | self.bottom_blocks = OrderedDict() 249 | for f in bottom_filters: 250 | block = _make_layer(prev_f, BasicBlock, f, enc, 1 if idx==0 else 2, dim=dim) 251 | self.bottom_blocks[f'{prefix}{current_scale}_{idx}'] = block 252 | idx += 1 253 | current_scale *= 2 254 | prev_f = f 255 | self.bottom_blocks = ListModule(self.bottom_blocks) 256 | 257 | self.enc_blocks = OrderedDict() 258 | for f in filters: 259 | block = _make_layer(prev_f, BasicBlock, f, enc, 1 if idx == 0 else 2, dim=dim) 260 | self.enc_blocks[f'{prefix}{current_scale}_{idx}'] = block 261 | idx += 1 262 | current_scale *= 2 263 | prev_f = f 264 | self.enc_blocks = ListModule(self.enc_blocks) 265 | 266 | self.dec_blocks = OrderedDict() 267 | for f in filters[-2::-1]: 268 | block = [ 269 | deconv_fn(prev_f, f, 3, 2, 1, 1, bias=False), 270 | conv_fn(2*f, f, 3, 1, 1, bias=False), 271 | ] 272 | if dec > 0: 273 | block.append(_make_layer(f, BasicBlock, f, dec, 1, dim=dim)) 274 | 
self.dec_blocks[f'{prefix}{current_scale}_{idx}'] = block 275 | idx += 1 276 | current_scale //= 2 277 | prev_f = f 278 | self.dec_blocks = ListModule(self.dec_blocks) 279 | 280 | self.head_blocks = OrderedDict() 281 | for f in head_filters: 282 | block = [ 283 | deconv_fn(prev_f, f, 3, 2, 1, 1, bias=False) 284 | ] 285 | if dec > 0: 286 | block.append(_make_layer(f, BasicBlock, f, dec, 1, dim=dim)) 287 | block = nn.Sequential(*block) 288 | self.head_blocks[f'{prefix}{current_scale}_{idx}'] = block 289 | idx += 1 290 | current_scale //= 2 291 | prev_f = f 292 | self.head_blocks = ListModule(self.head_blocks) 293 | 294 | def forward(self, x, multi_scale=1): 295 | for b in self.bottom_blocks: 296 | x = b(x) 297 | enc_out = [] 298 | for b in self.enc_blocks: 299 | x = b(x) 300 | enc_out.append(x) 301 | dec_out = [x] 302 | for i, b in enumerate(self.dec_blocks): 303 | if len(b) == 3: deconv, post_concat, res = b 304 | elif len(b) == 2: deconv, post_concat = b 305 | x = deconv(x) 306 | x = torch.cat([x, enc_out[-2-i]], 1) 307 | x = post_concat(x) 308 | if len(b) == 3: x = res(x) 309 | dec_out.append(x) 310 | for b in self.head_blocks: 311 | x = b(x) 312 | dec_out.append(x) 313 | if multi_scale == 1: return x 314 | else: return dec_out[-multi_scale:] 315 | 316 | 317 | class FeatExt(nn.Module): 318 | 319 | def __init__(self): 320 | super(FeatExt, self).__init__() 321 | self.init_conv = nn.Sequential( 322 | nn.Conv2d(3, 16, 5, 2, 2, bias=False), 323 | nn.BatchNorm2d(16), 324 | nn.ReLU() 325 | ) 326 | self.unet = UNet(16, 2, 1, 2, [], [32, 64, 128], [], '2d', 2) 327 | self.final_conv_1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False) 328 | self.final_conv_2 = nn.Conv2d(64, 32, 3, 1, 1, bias=False) 329 | self.final_conv_3 = nn.Conv2d(32, 32, 3, 1, 1, bias=False) 330 | 331 | feat_ext_dict = {k[16:]:v for k,v in torch.load('pretrained_model/vismvsnet.pt')['state_dict'].items() if k.startswith('module.feat_ext')} 332 | self.load_state_dict(feat_ext_dict) 333 | 334 | def forward(self, x): 335 | out = self.init_conv(x) 336 | out1, out2, out3 = self.unet(out, multi_scale=3) 337 | return self.final_conv_1(out1), self.final_conv_2(out2), self.final_conv_3(out3) 338 | 339 | 340 | def gen_camera_between(pose_0, pose_1, ratio): 341 | rot_0 = pose_0[:3, :3] 342 | rot_1 = pose_1[:3, :3] 343 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 344 | key_times = [0, 1] 345 | slerp = Slerp(key_times, rots) 346 | rot = slerp(ratio) 347 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 348 | pose = pose.astype(np.float32) 349 | pose[:3, :3] = rot.as_matrix() 350 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 351 | pose = torch.from_numpy(pose).cuda() 352 | pose.requires_grad = False 353 | return pose 354 | 355 | 356 | def get_local_loss(diff_surf_pts, 357 | uncerts, 358 | feat, 359 | cam, 360 | feat_src, 361 | src_cams, 362 | size, 363 | center, 364 | network_object_mask, 365 | object_mask 366 | ): 367 | mask = network_object_mask & object_mask 368 | 369 | if (mask).sum() == 0: 370 | return torch.tensor(0.0).float().cuda() 371 | 372 | sample_mask = mask.view(feat.size()[0], -1) 373 | hit_nums = sample_mask.sum(-1) 374 | accu_nums = [0] + hit_nums.cumsum(0).tolist() 375 | slices = [slice(accu_nums[i], accu_nums[i + 1]) for i in range(len(accu_nums) - 1)] 376 | 377 | loss = [] 378 | for view_i, slice_ in enumerate(slices): 379 | if slice_.start < slice_.stop: 380 | 381 | # projection 382 | diff_surf_pts_slice = diff_surf_pts[slice_] 383 | pts_world = (diff_surf_pts_slice / 2 * size.view(1, 1) + center.view(1, 3)).view(1, 
-1, 1, 3, 1) 384 | pts_world = torch.cat([pts_world, torch.ones_like(pts_world[..., -1:, :])], dim=-2) 385 | cam_pack = torch.cat([cam[view_i:view_i + 1], src_cams[view_i]], dim=0) 386 | pts_img = idx_cam2img(idx_world2cam(pts_world, cam_pack), cam_pack) 387 | 388 | # gathering 389 | grid = pts_img[..., :2, 0] 390 | 391 | feat2_pack = torch.cat([feat[view_i:view_i + 1], feat_src[view_i]], dim=0) 392 | grid_n = normalize_for_grid_sample(feat2_pack, grid / 2) 393 | grid_in_range = get_in_range(grid_n) 394 | valid_mask = (grid_in_range[:1, ...] * grid_in_range[1:, ...]).unsqueeze(1) > 0.5 395 | gathered_feat = F.grid_sample(feat2_pack, grid_n, mode='bilinear', padding_mode='zeros', 396 | align_corners=False) 397 | 398 | # calculation 399 | gathered_norm = gathered_feat.norm(dim=1, keepdim=True) 400 | corr = (gathered_feat[:1] * gathered_feat[1:]).sum(dim=1, keepdim=True) \ 401 | / gathered_norm[:1].clamp(min=1e-9) / gathered_norm[1:].clamp(min=1e-9) 402 | corr_loss = (1 - corr).abs() 403 | if uncerts is None: 404 | diff_mask = corr_loss < 0.5 405 | sample_loss = (corr_loss * valid_mask * diff_mask).mean() 406 | else: 407 | uncert = uncerts[view_i].unsqueeze(1).unsqueeze(3) 408 | sample_loss = ((corr_loss * (-uncert).exp() + uncert) * valid_mask).mean() 409 | else: 410 | sample_loss = torch.zeros(1).float().cuda() 411 | loss.append(sample_loss) 412 | loss = sum(loss) / len(loss) 413 | return loss 414 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 
44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /tools/surface_extraction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mcubes 3 | import trimesh 4 | import torch 5 | 6 | from extensions.chamfer_dist import ChamferDistanceL2 7 | from tools.logger import print_log 8 | def as_mesh(scene_or_mesh): 9 | """ 10 | Convert a possible scene to a mesh. 11 | 12 | If conversion occurs, the returned mesh has only vertex and face data. 
13 | Suggested by https://github.com/mikedh/trimesh/issues/507 14 | """ 15 | if isinstance(scene_or_mesh, trimesh.Scene): 16 | if len(scene_or_mesh.geometry) == 0: 17 | mesh = None # empty scene 18 | else: 19 | # we lose texture information here 20 | mesh = trimesh.util.concatenate( 21 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 22 | for g in scene_or_mesh.geometry.values())) 23 | else: 24 | assert(isinstance(scene_or_mesh, trimesh.Trimesh)) 25 | mesh = scene_or_mesh 26 | return mesh 27 | 28 | def surface_extraction(ndf, grad, out_path, iter_step, b_max, b_min, resolution): 29 | v_all = [] 30 | t_all = [] 31 | threshold = 0.005 # accelerate extraction 32 | v_num = 0 33 | for i in range(resolution-1): 34 | for j in range(resolution-1): 35 | for k in range(resolution-1): 36 | ndf_loc = ndf[i:i+2] 37 | ndf_loc = ndf_loc[:,j:j+2,:] 38 | ndf_loc = ndf_loc[:,:,k:k+2] 39 | if np.min(ndf_loc) > threshold: 40 | continue 41 | grad_loc = grad[i:i+2] 42 | grad_loc = grad_loc[:,j:j+2,:] 43 | grad_loc = grad_loc[:,:,k:k+2] 44 | 45 | res = np.ones((2,2,2)) 46 | for ii in range(2): 47 | for jj in range(2): 48 | for kk in range(2): 49 | if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0: 50 | res[ii][jj][kk] = -ndf_loc[ii][jj][kk] 51 | else: 52 | res[ii][jj][kk] = ndf_loc[ii][jj][kk] 53 | 54 | if res.min()<0: 55 | vertices, triangles = mcubes.marching_cubes( 56 | res, 0.0) 57 | # print(vertices) 58 | # vertices -= 1.5 59 | # vertices /= 128 60 | vertices[:,0] += i #/ resolution 61 | vertices[:,1] += j #/ resolution 62 | vertices[:,2] += k #/ resolution 63 | triangles += v_num 64 | # vertices = 65 | # vertices[:,1] /= 3 # TODO 66 | v_all.append(vertices) 67 | t_all.append(triangles) 68 | 69 | v_num += vertices.shape[0] 70 | # print(v_num) 71 | 72 | v_all = np.concatenate(v_all) 73 | t_all = np.concatenate(t_all) 74 | # Create mesh 75 | v_all = v_all / (resolution - 1.0) * (b_max - b_min)[None, :] + b_min[None, :] 76 | 77 | mesh = trimesh.Trimesh(v_all, t_all, process=False) 78 | 79 | return mesh 80 | 81 | 82 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from random import sample 4 | import time 5 | from tkinter import Variable 6 | from shutil import copyfile 7 | import numpy as np 8 | import trimesh 9 | 10 | from scipy.spatial import cKDTree 11 | 12 | 13 | def get_aver(distances, face): 14 | return (distances[face[0]] + distances[face[1]] + distances[face[2]]) / 3.0 15 | 16 | def remove_far(gt_pts, mesh, dis_trunc=0.1, is_use_prj=False): 17 | # gt_pts: trimesh 18 | # mesh: trimesh 19 | 20 | gt_kd_tree = cKDTree(gt_pts) 21 | distances, vertex_ids = gt_kd_tree.query(mesh.vertices, p=2, distance_upper_bound=dis_trunc) 22 | faces_remaining = [] 23 | faces = mesh.faces 24 | 25 | if is_use_prj: 26 | normals = gt_pts.vertex_normals 27 | closest_points = gt_pts.vertices[vertex_ids] 28 | closest_normals = normals[vertex_ids] 29 | direction_from_surface = mesh.vertices - closest_points 30 | distances = direction_from_surface * closest_normals 31 | distances = np.sum(distances, axis=1) 32 | 33 | for i in range(faces.shape[0]): 34 | if get_aver(distances, faces[i]) < dis_trunc: 35 | faces_remaining.append(faces[i]) 36 | mesh_cleaned = mesh.copy() 37 | mesh_cleaned.faces = faces_remaining 38 | mesh_cleaned.remove_unreferenced_vertices() 39 | 40 | return mesh_cleaned 41 | 42 | def remove_outlier(gt_pts, q_pts, dis_trunc=0.003, 
is_use_prj=False): 43 | # gt_pts: trimesh 44 | # mesh: trimesh 45 | 46 | gt_kd_tree = cKDTree(gt_pts) 47 | distances, q_ids = gt_kd_tree.query(q_pts, p=2, distance_upper_bound=dis_trunc) 48 | 49 | q_pts = q_pts[distances <= dis_trunc] 50 | 51 | return q_pts -------------------------------------------------------------------------------- /udf_runner.py: -------------------------------------------------------------------------------- 143 | print_log('iter:{:8>d} cd_l1 = {} lr={}'.format(self.iter_step, loss_cd, self.optimizer.param_groups[0]['lr']), logger=logger) 144 | 145 | if self.iter_step == self.step1_maxiter or self.iter_step == self.step2_maxiter: 146 | self.save_checkpoint() 147 | 148 | if self.iter_step == self.step1_maxiter: 149 | gen_pointclouds = self.gen_extra_pointcloud(self.iter_step, self.conf.get_float('udf_train.low_range')) 150 | idx = pcu.downsample_point_cloud_poisson_disk(gen_pointclouds, num_samples=int(self.conf.get_float('udf_train.extra_points_rate')*point_gt.shape[0])) 151 | poisson_pointclouds = gen_pointclouds[idx] 152 | dense_pointclouds = np.concatenate((point_gt.detach().cpu().numpy(), poisson_pointclouds)) 153 | self.ptree = cKDTree(dense_pointclouds) 154 | self.dataset.gen_new_data(self.ptree) 155 | 156 | if self.iter_step == self.step2_maxiter: 157 | gen_pointclouds = self.gen_extra_pointcloud(self.iter_step, 1) 158 | 159 | # if self.iter_step == self.step1_maxiter or self.iter_step == self.step2_maxiter: 160 | # self.extract_mesh(resolution=args.mcube_resolution, threshold=0.0, point_gt=point_gt, iter_step=self.iter_step, logger=logger) 161 | 162 | 163 | def extract_mesh(self, resolution=64, threshold=0.0, point_gt=None, iter_step=0, logger=None): 164 | 165 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 166 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 167 | out_dir = os.path.join(self.base_exp_dir, 'mesh') 168 | os.makedirs(out_dir, exist_ok=True) 169 | 170 | mesh = extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold, \ 171 | out_dir=out_dir, iter_step=iter_step, dataname=self.dataname, logger=logger, \ 172 | query_func=lambda pts: self.udf_network.udf(pts), grad_func=lambda pts: self.udf_network.gradient(pts)) 173 | if self.conf.get_float('udf_train.far') > 0: 174 | mesh = remove_far(point_gt.detach().cpu().numpy(), mesh, self.conf.get_float('udf_train.far')) 175 | 176 | mesh.export(out_dir+'/'+str(iter_step)+'_mesh.obj') 177 | 178 | 179 | 180 | def gen_extra_pointcloud(self, iter_step, low_range): 181 | 182 | res = [] 183 | num_points = self.eval_num_points 184 | gen_nums = 0 185 | 186 | os.makedirs(os.path.join(self.base_exp_dir, 'pointcloud'), exist_ok=True) 187 | 188 | while gen_nums < num_points: 189 | 190 | points, samples, point_gt = self.dataset.get_train_data(5000) 191 | offsets = samples - points 192 | std = torch.std(offsets) 193 | 194 | extra_std = std * low_range 195 | rands = torch.normal(0.0, extra_std, size=points.shape) 196 | samples = points + torch.tensor(rands).cuda().float() 197 | 198 | samples.requires_grad = True 199 | gradients_sample = self.udf_network.gradient(samples).squeeze() # 5000x3 200 | udf_sample = self.udf_network.udf(samples) # 5000x1 201 | grad_norm = F.normalize(gradients_sample, dim=1) # 5000x3 202 | sample_moved = samples - grad_norm * udf_sample # 5000x3 203 | 204 | index = udf_sample < self.df_filter 205 | index = index.squeeze(1) 206 | sample_moved = sample_moved[index] 207 | 208 | gen_nums += sample_moved.shape[0] 209 | 210 | res.append(sample_moved.detach().cpu().numpy()) 211 | 212 | res = np.concatenate(res) 213 | res = res[:num_points] 214 | np.savetxt(os.path.join(self.base_exp_dir, 'pointcloud', 'point_cloud%d.xyz'%(iter_step)), res) 215 | 216 | res = 
remove_outlier(point_gt.detach().cpu().numpy(), res, dis_trunc=self.conf.get_float('udf_train.outlier')) 217 | return res 218 | 219 | def update_learning_rate(self, iter_step): 220 | 221 | warn_up = self.warm_up_end 222 | max_iter = self.step2_maxiter 223 | init_lr = self.learning_rate 224 | lr = (iter_step / warn_up) if iter_step < warn_up else 0.5 * (math.cos((iter_step - warn_up)/(max_iter - warn_up) * math.pi) + 1) 225 | lr = lr * init_lr 226 | 227 | for g in self.optimizer.param_groups: 228 | g['lr'] = lr 229 | 230 | def file_backup(self): 231 | dir_lis = self.conf['general.recording'] 232 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) 233 | for dir_name in dir_lis: 234 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) 235 | os.makedirs(cur_dir, exist_ok=True) 236 | files = os.listdir(dir_name) 237 | for f_name in files: 238 | if f_name[-3:] == '.py': 239 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 240 | 241 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 242 | 243 | def load_checkpoint(self, checkpoint_name): 244 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device) 245 | print(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name)) 246 | self.udf_network.load_state_dict(checkpoint['udf_network_fine']) 247 | 248 | self.iter_step = checkpoint['iter_step'] 249 | 250 | def save_checkpoint(self): 251 | checkpoint = { 252 | 'udf_network_fine': self.udf_network.state_dict(), 253 | 'iter_step': self.iter_step, 254 | } 255 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) 256 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 257 | 258 | 259 | if __name__ == '__main__': 260 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 261 | parser = argparse.ArgumentParser() 262 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 263 | parser.add_argument('--mcube_resolution', type=int, default=256) 264 | parser.add_argument('--gpu', type=int, default=0) 265 | parser.add_argument('--udf_dir', type=str, default='udf') 266 | parser.add_argument('--case', type=str, default='') 267 | args = parser.parse_args() 268 | 269 | torch.cuda.set_device(args.gpu) 270 | runner = UDFRunner(args, args.conf) 271 | 272 | runner.train() 273 | --------------------------------------------------------------------------------
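Note on the projection step used in gen_extra_pointcloud above: perturbed samples are pulled onto the learned surface by moving each point against the normalized UDF gradient, x' = x - f(x) * ∇f(x)/|∇f(x)|, and only points whose predicted distance is below df_filter are kept. Below is a minimal, self-contained sketch of that projection, using a toy analytic UDF (unsigned distance to a unit sphere) in place of the trained udf_network; toy_udf, the sphere radius, the 5000-sample batch, and the 0.05 acceptance threshold are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn.functional as F

def toy_udf(x):
    # Unsigned distance to a unit sphere centred at the origin (stand-in for udf_network.udf).
    return (x.norm(dim=-1, keepdim=True) - 1.0).abs()

# Query points scattered around the surface; gradients are enabled so autograd can supply grad f(x).
samples = (F.normalize(torch.randn(5000, 3), dim=-1) + 0.1 * torch.randn(5000, 3)).requires_grad_(True)

udf_sample = toy_udf(samples)                                   # (5000, 1) predicted distances
gradients = torch.autograd.grad(udf_sample.sum(), samples)[0]   # (5000, 3) d f / d x
grad_norm = F.normalize(gradients, dim=-1)                      # unit direction of increasing distance

# Project each sample onto the (approximate) zero level set, as in gen_extra_pointcloud.
sample_moved = samples - grad_norm * udf_sample                 # (5000, 3)

# Keep only samples the field already places close to the surface (analogue of the df_filter test).
mask = (udf_sample < 0.05).squeeze(-1)
surface_points = sample_moved[mask].detach()
print(surface_points.shape, toy_udf(surface_points).max().item())   # residual distance should be near zero

In the repository itself the same update is computed with udf_network.udf and udf_network.gradient, and the accepted points are Poisson-disk downsampled (pcu.downsample_point_cloud_poisson_disk) and merged with point_gt to build the denser cKDTree that drives the next training stage.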