├── .gitignore
├── LICENSE
├── README.md
├── confs
│   └── dtu.conf
├── exp_runner.py
├── extensions
│   └── chamfer_dist
│       ├── __init__.py
│       ├── chamfer.cu
│       ├── chamfer_cuda.cpp
│       ├── setup.py
│       └── test.py
├── media
│   ├── comparison.png
│   └── pipeline.png
├── models
│   ├── dataset.py
│   ├── embedder.py
│   ├── fields.py
│   ├── renderer.py
│   ├── udf_dataset.py
│   ├── udf_embedder.py
│   └── udf_fields.py
├── pretrained_model
│   └── vismvsnet.pt
├── requirements.txt
├── tools
│   ├── feat_utils.py
│   ├── logger.py
│   ├── surface_extraction.py
│   └── utils.py
└── udf_runner.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # IPython Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # dotenv
81 | .env
82 |
83 | # virtualenv
84 | venv/
85 | ENV/
86 |
87 | # Spyder project settings
88 | .spyderproject
89 |
90 | # Rope project settings
91 | .ropeproject
92 | ### VirtualEnv template
93 | # Virtualenv
94 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
95 | .Python
96 | [Bb]in
97 | [Ii]nclude
98 | [Ll]ib
99 | [Ll]ib64
100 | [Ll]ocal
101 | [Ss]cripts
102 | pyvenv.cfg
103 | .venv
104 | pip-selfcheck.json
105 | ### JetBrains template
106 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
107 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
108 |
109 | # User-specific stuff:
110 | .idea/workspace.xml
111 | .idea/tasks.xml
112 | .idea/dictionaries
113 | .idea/vcs.xml
114 | .idea/jsLibraryMappings.xml
115 |
116 | # Sensitive or high-churn files:
117 | .idea/dataSources.ids
118 | .idea/dataSources.xml
119 | .idea/dataSources.local.xml
120 | .idea/sqlDataSources.xml
121 | .idea/dynamic.xml
122 | .idea/uiDesigner.xml
123 |
124 | # Gradle:
125 | .idea/gradle.xml
126 | .idea/libraries
127 |
128 | # Mongo Explorer plugin:
129 | .idea/mongoSettings.xml
130 |
131 | .idea/
132 |
133 | ## File-based project format:
134 | *.iws
135 |
136 | ## Plugin-specific files:
137 |
138 | # IntelliJ
139 | /out/
140 |
141 | # mpeltonen/sbt-idea plugin
142 | .idea_modules/
143 |
144 | # JIRA plugin
145 | atlassian-ide-plugin.xml
146 |
147 | # Crashlytics plugin (for Android Studio and IntelliJ)
148 | com_crashlytics_export_strings.xml
149 | crashlytics.properties
150 | crashlytics-build.properties
151 | fabric.properties
152 |
153 | data/
154 | exp/
155 | .data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Han Huang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuSurf
2 | Implementation of AAAI'24 paper *NeuSurf: On-Surface Priors for Neural Surface Reconstruction from Sparse Input Views*
3 |
4 | ### [Project Page](https://alvin528.github.io/NeuSurf/) | [Paper](https://arxiv.org/abs/2312.13977) | [Data](https://drive.google.com/drive/folders/18AZw4zi3fNQ-NKttNeVBp9cTja8NBnSA?usp=drive_link) | [Mesh Results](https://drive.google.com/drive/folders/1PVDJNa68OQm7Cisz2_CVNGQb4Zp40HSm?usp=sharing)
5 |
6 | ![comparison](./media/comparison.png)
7 |
8 | ## Overview
9 |
10 | ![pipeline](./media/pipeline.png)
11 |
12 | ## Installation
13 |
14 | Our code is implemented in Python 3.10, PyTorch 2.0.0 and CUDA 11.7.
15 | - Install Python dependencies
16 | ```
17 | conda create -n neusurf python=3.10
18 | conda activate neusurf
19 | pip install torch==2.0.0 torchvision==0.15.1
20 | pip install -r requirements.txt
21 | ```
22 | - Compile the C++/CUDA extension (a quick import check follows below)
23 | ```
24 | cd extensions/chamfer_dist
25 | python setup.py install
26 | ```
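If the build succeeds, the compiled module is importable as `chamfer` (the name set in `setup.py`) and is wrapped by `extensions/chamfer_dist/__init__.py`. A minimal sanity check from the repository root (an illustrative sketch; it assumes a CUDA-capable GPU is available):

```python
import torch
from extensions.chamfer_dist import ChamferDistanceL1

# Two random point clouds of shape [batch, num_points, 3] on the GPU
xyz1 = torch.rand(1, 1024, 3).cuda()
xyz2 = torch.rand(1, 2048, 3).cuda()
print(ChamferDistanceL1()(xyz1, xyz2))  # prints a scalar distance tensor
```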
27 |
28 | ## Dataset
29 |
30 | Data structure:
31 |
32 | ```
33 | data
34 | |-- DTU_pixelnerf
35 |     |-- <case_name>
36 |         |-- cameras_sphere.npz
37 |         |-- pcd
38 |             |-- <case_name>.ply
39 |         |-- cam4feat
40 |             |-- pair.txt
41 |             |-- cam_00000000_flow3.txt
42 |             |-- cam_00000001_flow3.txt
43 |             ...
44 |         |-- image
45 |             |-- 000000.png
46 |             |-- 000001.png
47 |             ...
48 |         |-- mask
49 |             |-- 000.png
50 |             |-- 001.png
51 |             ...
52 | |-- DTU_sparseneus
53 | |-- blendedmvs_sparse
54 | ```
55 |
56 | You can directly download the processed data [here](https://drive.google.com/drive/folders/18AZw4zi3fNQ-NKttNeVBp9cTja8NBnSA?usp=drive_link).
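For reference, `models/dataset.py` reads one `world_mat_i` and one `scale_mat_i` entry per image from `cameras_sphere.npz`, and `exp_runner.py` loads the on-surface prior point cloud from `pcd/<case_name>.ply`. A quick way to inspect a downloaded scene (an illustrative sketch; replace `<case_name>` with an actual scan folder):

```python
import numpy as np

cams = np.load('data/DTU_pixelnerf/<case_name>/cameras_sphere.npz')
# Expect paired keys such as 'world_mat_0' and 'scale_mat_0', one pair per image
print(sorted(cams.files)[:6])
```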
57 |
58 | ## Running
59 |
60 | - Training
61 |
62 | ```
63 | export CUDA_VISIBLE_DEVICES=0
64 | python exp_runner.py --mode train --conf ./confs/dtu.conf --case <case_name>
65 | ```
66 |
67 | - Extract mesh
68 |
69 | ```
70 | export CUDA_VISIBLE_DEVICES=0
71 | python exp_runner.py --mode validate_mesh --conf ./confs/dtu.conf --case <case_name> --is_continue
72 | ```
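Checkpoints, validation renderings, and extracted meshes are written under `./exp/DTU_pixelnerf/<case_name>/womask_sphere/` (the `base_exp_dir` set in `confs/dtu.conf`). In `validate_mesh` mode, the mesh is saved as a PLY file in the `meshes/` subfolder, named by iteration and marching-cubes resolution.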
73 |
74 |
75 |
76 | ## Citation
77 |
78 | If you find our work useful in your research, please consider citing:
79 |
80 | ```bibtex
81 | @inproceedings{huang2024neusurf,
82 | title={NeuSurf: On-Surface Priors for Neural Surface Reconstruction from Sparse Input Views},
83 | author={Huang, Han and Wu, Yulun and Zhou, Junsheng and Gao, Ge and Gu, Ming and Liu, Yu-Shen},
84 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
85 | volume={38},
86 | number={3},
87 | pages={2312--2320},
88 | year={2024}
89 | }
90 | ```
91 |
92 | ## Acknowledgement
93 |
94 | This implementation is based on [CAP-UDF](https://github.com/junshengzhou/CAP-UDF/), [D-NeuS](https://github.com/fraunhoferhhi/D-NeuS) and [Vis-MVSNet](https://github.com/jzhangbs/Vis-MVSNet). Thanks to the authors of these great works.
95 |
--------------------------------------------------------------------------------
/confs/dtu.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_dir = ./exp/DTU_pixelnerf/CASE_NAME/
3 | base_exp_dir = ./exp/DTU_pixelnerf/CASE_NAME/womask_sphere
4 | recording = [
5 | ./,
6 | ./models
7 | ]
8 | }
9 |
10 | udf_dataset {
11 | data_dir = ./data/DTU_pixelnerf/CASE_NAME/
12 | }
13 |
14 | dataset {
15 | data_dir = ./data/DTU_pixelnerf/CASE_NAME/
16 | render_cameras_name = cameras_sphere.npz
17 | object_cameras_name = cameras_sphere.npz
18 | feat_map_h = 384
19 | feat_map_w = 512
20 | }
21 |
22 | udf_train {
23 | learning_rate = 0.001
24 | step1_maxiter = 40000
25 | step2_maxiter = 60000
26 | warm_up_end = 1000
27 | eval_num_points = 1000000
28 | df_filter = 0.01
29 | far = -1
30 | outlier = 0.002
31 | extra_points_rate = 1
32 | low_range = 1.1
33 |
34 | batch_size = 5000
35 | batch_size_step2 = 20000
36 |
37 | save_freq = 5000
38 | val_freq = 2500
39 | val_mesh_freq = 2500
40 | report_freq = 5000
41 |
42 | igr_weight = 0.1
43 | mask_weight = 0.0
44 | load_ckpt = none
45 | }
46 |
47 | train {
48 | learning_rate = 5e-4
49 | learning_rate_alpha = 0.05
50 | end_iter = 300000
51 |
52 | batch_size = 512
53 | validate_resolution_level = 4
54 | warm_up_end = 5000
55 | anneal_end = 50000
56 | use_white_bkgd = False
57 |
58 | save_freq = 1000
59 | val_freq = 1000
60 | val_mesh_freq = 1000
61 | report_freq = 1000
62 |
63 | igr_weight = 0.1
64 | mask_weight = 0.0
65 |
66 | udf_thresh = 5e-2
67 |
68 | phase_delim = [0.16667, 0.5]
69 | local_weight = [0.0, 0.5, 0.05]
70 | pseudo_reg_weight = [0.01, 0.1, 0.01]
71 | depth_from_inside_only = [False, True, True]
72 | }
73 |
74 | udf_model {
75 | ckpt = 60000
76 |
77 | udf_network {
78 | d_out = 1
79 | d_in = 3
80 | d_hidden = 256
81 | n_layers = 8
82 | skip_in = [4]
83 | multires = 0
84 | bias = 0.5
85 | scale = 1.0
86 | geometric_init = True
87 | weight_norm = True
88 | }
89 | }
90 |
91 | model {
92 | nerf {
93 | D = 8,
94 | d_in = 4,
95 | d_in_view = 3,
96 | W = 256,
97 | multires = 10,
98 | multires_view = 4,
99 | output_ch = 4,
100 | skips=[4],
101 | use_viewdirs=True
102 | }
103 |
104 | sdf_network {
105 | d_out = 257
106 | d_in = 3
107 | d_hidden = 256
108 | n_layers = 8
109 | skip_in = [4]
110 | multires = 6
111 | bias = 0.5
112 | scale = 1.0
113 | geometric_init = True
114 | weight_norm = True
115 | }
116 |
117 | variance_network {
118 | init_val = 0.3
119 | }
120 |
121 | rendering_network {
122 | d_feature = 256
123 | mode = idr
124 | d_in = 9
125 | d_out = 3
126 | d_hidden = 256
127 | n_layers = 4
128 | weight_norm = True
129 | multires_view = 4
130 | squeeze_out = True
131 | }
132 |
133 | neus_renderer {
134 | n_samples = 64
135 | n_importance = 64
136 | n_outside = 32
137 | up_sample_steps = 4
138 | perturb = 1.0
139 | }
140 | }
141 |
--------------------------------------------------------------------------------
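A note on the three-element lists in the `train` block above: `local_weight`, `pseudo_reg_weight`, and `depth_from_inside_only` are phase-scheduled. `exp_runner.py` picks the first, second, or third entry depending on whether the current training progress `iter_step / end_iter` falls below `phase_delim[0]`, between the two delimiters, or above `phase_delim[1]`. A small sketch of that selection rule, mirroring `get_param_in_phase` in `exp_runner.py` with the values from this config:

```python
phase_delim = [0.16667, 0.5]      # train.phase_delim
local_weight = [0.0, 0.5, 0.05]   # train.local_weight

def get_param_in_phase(param_list, phase):
    """Return the scheduled value for training progress `phase` in [0, 1]."""
    if phase < phase_delim[0]:
        return param_list[0]
    elif phase < phase_delim[1]:
        return param_list[1]
    return param_list[2]

print(get_param_in_phase(local_weight, 0.3))  # 0.5: the local loss is active mid-training
```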
/exp_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import argparse
4 | import numpy as np
5 | import cv2 as cv
6 | import trimesh
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.tensorboard import SummaryWriter
10 | from shutil import copyfile
11 | from tqdm import tqdm
12 | from pyhocon import ConfigFactory
13 | from models.dataset import Dataset
14 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF
15 | from models.renderer import NeuSRenderer
16 |
17 | from models.udf_fields import UDFNetwork
18 | from udf_runner import UDFRunner
19 |
20 |
21 | class Runner:
22 | def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False, ckpt=None):
23 | self.device = torch.device('cuda')
24 |
25 | # Configuration
26 | self.conf_path = conf_path
27 | f = open(self.conf_path)
28 | conf_text = f.read()
29 | conf_text = conf_text.replace('CASE_NAME', case)
30 | f.close()
31 |
32 | self.conf = ConfigFactory.parse_string(conf_text)
33 | self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case)
34 | self.base_exp_dir = self.conf['general.base_exp_dir']
35 | os.makedirs(self.base_exp_dir, exist_ok=True)
36 | self.dataset = Dataset(self.conf['dataset'])
37 | self.iter_step = 0
38 |
39 | # Training parameters
40 | self.end_iter = self.conf.get_int('train.end_iter')
41 | self.save_freq = self.conf.get_int('train.save_freq')
42 | self.report_freq = self.conf.get_int('train.report_freq')
43 | self.val_freq = self.conf.get_int('train.val_freq')
44 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
45 | self.batch_size = self.conf.get_int('train.batch_size')
46 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
47 | self.learning_rate = self.conf.get_float('train.learning_rate')
48 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
49 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
50 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0)
51 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
52 |
53 | # Weights
54 | self.igr_weight = self.conf.get_float('train.igr_weight')
55 | self.mask_weight = self.conf.get_float('train.mask_weight')
56 | self.is_continue = is_continue
57 | self.mode = mode
58 | self.model_list = []
59 | self.writer = None
60 | self.ckpt = ckpt
61 |
62 | self.udf_thresh = self.conf.get_float('train.udf_thresh')
63 |
64 | self.phase_delims = self.conf.get_list('train.phase_delim')
65 | self.pseudo_reg_weights = self.conf.get_list('train.pseudo_reg_weight')
66 | self.local_weights = self.conf.get_list('train.local_weight')
67 | self.depth_from_inside_only_s = self.conf.get_list('train.depth_from_inside_only')
68 |
69 | def get_param_in_phase(param_list, phase):
70 | if phase < self.phase_delims[0]:
71 | return param_list[0]
72 | elif phase < self.phase_delims[1]:
73 | return param_list[1]
74 | else:
75 | return param_list[2]
76 | self.get_param_in_phase = get_param_in_phase
77 |
78 | # Networks
79 | params_to_train = []
80 | self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device)
81 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device)
82 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)
83 | self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device)
84 | params_to_train += list(self.nerf_outside.parameters())
85 | params_to_train += list(self.sdf_network.parameters())
86 | params_to_train += list(self.deviation_network.parameters())
87 | params_to_train += list(self.color_network.parameters())
88 |
89 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
90 |
91 | self.renderer = NeuSRenderer(self.nerf_outside,
92 | self.sdf_network,
93 | self.deviation_network,
94 | self.color_network,
95 | self.dataset,
96 | **self.conf['model.neus_renderer'])
97 |
98 | pointcloud = trimesh.load('{}/pcd/{}.ply'.format(self.conf['dataset.data_dir'], case)).vertices
99 | pointcloud = np.asarray(pointcloud)
100 | self.shape_scale = np.max([np.max(pointcloud[:,0])-np.min(pointcloud[:,0]),np.max(pointcloud[:,1])-np.min(pointcloud[:,1]),np.max(pointcloud[:,2])-np.min(pointcloud[:,2])])
101 | self.shape_center = [(np.max(pointcloud[:,0])+np.min(pointcloud[:,0]))/2, (np.max(pointcloud[:,1])+np.min(pointcloud[:,1]))/2, (np.max(pointcloud[:,2])+np.min(pointcloud[:,2]))/2]
102 | with torch.no_grad():
103 | self.shape_center = torch.Tensor(self.shape_center)
104 |
105 | scale_pcd = False
106 | if scale_pcd:
107 | pointcloud = (pointcloud - self.dataset.scale_mats_np[0][:3, 3][None]) / self.dataset.scale_mats_np[0][0, 0]
108 | self.pointcloud = torch.tensor(pointcloud, requires_grad=False, dtype=torch.float32).to(self.device)
109 |
110 | self.udf_network = UDFNetwork(**self.conf['udf_model.udf_network']).to(self.device)
111 | udf_ckpt_path = '{}/udf/checkpoints/ckpt_{:0>6}.pth'.format(self.conf['general.base_dir'].replace('CASE_NAME', case),
112 | self.conf['udf_model.ckpt'])
113 | self.udf_network.load_state_dict(torch.load(udf_ckpt_path, map_location=self.device)['udf_network_fine'])
114 | self.udf_network.eval()
115 | for p in self.udf_network.parameters():
116 | p.requires_grad = False
117 | logging.info('UDF network successfully loaded')
118 |
119 | # Load checkpoint
120 | latest_model_name = None
121 | if is_continue:
122 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints'))
123 | model_list = []
124 | for model_name in model_list_raw:
125 | if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter:
126 | model_list.append(model_name)
127 | model_list.sort()
128 | if self.ckpt == 'latest':
129 | latest_model_name = model_list[-1]
130 | else:
131 | latest_model_name = 'ckpt_{:0>6}.pth'.format(self.ckpt)
132 |
133 | if latest_model_name is not None:
134 | logging.info('Find checkpoint: {}'.format(latest_model_name))
135 | self.load_checkpoint(latest_model_name)
136 |
137 | # Backup codes and configs for debug
138 | if self.mode[:5] == 'train':
139 | self.file_backup()
140 |
141 | def init_params(self):
142 | self.iter_step = 0
143 | self.learning_rate = self.conf.get_float('train.learning_rate')
144 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
145 |
146 | params_to_train = []
147 | params_to_train += list(self.nerf_outside.parameters())
148 | params_to_train += list(self.sdf_network.parameters())
149 | params_to_train += list(self.deviation_network.parameters())
150 | params_to_train += list(self.color_network.parameters())
151 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
152 |
153 | def train(self, prior_initialization=False):
154 | if not self.is_continue:
155 | self.init_params()
156 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
157 | self.update_learning_rate()
158 | res_step = self.end_iter - self.iter_step
159 | if prior_initialization:
160 | res_step = 5000 - self.iter_step
161 | image_perm = self.get_image_perm()
162 |
163 | for iter_i in tqdm(range(res_step)):
164 | main_img_idx = image_perm[self.iter_step % len(image_perm)]
165 | data, sample = self.dataset.gen_random_rays_at(main_img_idx, self.batch_size)
166 |
167 | rays_o, rays_d, true_rgb = data[:, :3], data[:, 3: 6], data[:, 6: 9]
168 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d)
169 | model_input = {}
170 | for attr in ['depth_cams','size', 'center', 'feat', 'feat_src','cam', 'src_cams', 'rays_d_norm']:
171 | model_input[attr] = sample[attr].cuda()
172 | for attr in ['H', 'W', 'src_idxs']:
173 | model_input[attr] = sample[attr]
174 |
175 | background_rgb = None
176 | if self.use_white_bkgd:
177 | background_rgb = torch.ones([1, 3])
178 |
179 | mask = torch.ones_like(true_rgb[...,0:1])
180 | mask_sum = mask.sum() + 1e-6
181 | train_phase = self.iter_step/self.end_iter
182 |
183 | random_pts = self.pointcloud[torch.randperm(self.pointcloud.shape[0])[:512]]
184 | if prior_initialization:
185 | random_pts = None
186 | render_out = self.renderer.render(rays_o, rays_d, near, far, main_img_idx,
187 | t=self.iter_step + 1,
188 | random_pcd=random_pts,
189 | background_rgb=background_rgb,
190 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
191 | model_input = model_input,
192 | depth_from_inside_only = self.get_param_in_phase(self.depth_from_inside_only_s, train_phase),
193 | )
194 |
195 | color_fine = render_out['color_fine']
196 | s_val = render_out['s_val']
197 | cdf_fine = render_out['cdf_fine']
198 | gradient_error = render_out['gradient_error']
199 | weight_max = render_out['weight_max']
200 | weight_sum = render_out['weight_sum']
201 | pseudo_pts_reg_loss = render_out['pseudo_pts_loss']
202 | local_loss = render_out['local_loss']
203 |
204 | query_sdf = render_out['sdf']
205 | query_pts = render_out['query_pts']
206 | with torch.no_grad():
207 | udf = self.shape_scale * self.udf_network.udf((query_pts - self.shape_center) / self.shape_scale)
208 | udf = udf.reshape(query_sdf.size())
209 |
210 | udf[udf < self.udf_thresh] = 0.0
211 | udf_residual = torch.abs(torch.abs(query_sdf) - udf)
212 | udf_residual[(udf > self.udf_thresh)] = 0.0
213 | global_loss = F.l1_loss(udf_residual, torch.zeros_like(udf_residual))
214 |
215 | random_pcd = self.pointcloud[torch.randperm(self.pointcloud.shape[0])[:30000]]
216 | sdf = self.renderer.sdf_network.sdf(random_pcd)
217 | pcd_loss = F.l1_loss(sdf, torch.zeros_like(sdf),
218 | reduction='sum') / random_pcd.shape[0]
219 |
220 | # Loss
221 | color_error = (color_fine - true_rgb) * mask
222 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum
223 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt())
224 |
225 | eikonal_loss = gradient_error
226 |
227 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask)
228 |
229 | pseudo_pts_reg_loss = self.get_param_in_phase(self.pseudo_reg_weights, train_phase) * pseudo_pts_reg_loss
230 | local_loss = self.get_param_in_phase(self.local_weights, train_phase) * local_loss
231 |
232 | loss = eikonal_loss * self.igr_weight +\
233 | mask_loss * self.mask_weight +\
234 | global_loss * 0.1
235 |
236 | if not prior_initialization:
237 | loss += color_fine_loss +\
238 | pcd_loss +\
239 | pseudo_pts_reg_loss +\
240 | local_loss
241 |
242 | self.optimizer.zero_grad()
243 | loss.backward()
244 | self.optimizer.step()
245 |
246 | self.iter_step += 1
247 |
248 | self.writer.add_scalar('Loss/loss', loss, self.iter_step)
249 | self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step)
250 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step)
251 | self.writer.add_scalar('Loss/pseudo_pts_reg_loss', pseudo_pts_reg_loss, self.iter_step)
252 | self.writer.add_scalar('Loss/local_loss', local_loss, self.iter_step)
253 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step)
254 | self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step)
255 | self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step)
256 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step)
257 |
258 | if self.iter_step % self.report_freq == 0:
259 | print(self.base_exp_dir)
260 | print('iter:{:>8d} loss = {} lr={}'.format(self.iter_step, loss, self.optimizer.param_groups[0]['lr']))
261 |
262 | if self.iter_step % self.save_freq == 0:
263 | self.save_checkpoint()
264 |
265 | if self.iter_step % self.val_freq == 0:
266 | self.validate_image()
267 |
268 | if self.iter_step % self.val_mesh_freq == 0:
269 | self.validate_mesh()
270 |
271 | self.update_learning_rate()
272 |
273 | if self.iter_step % len(image_perm) == 0:
274 | image_perm = self.get_image_perm()
275 |
276 | def get_image_perm(self):
277 | return torch.randperm(self.dataset.n_images)
278 |
279 | def get_cos_anneal_ratio(self):
280 | if self.anneal_end == 0.0:
281 | return 1.0
282 | else:
283 | return np.min([1.0, self.iter_step / self.anneal_end])
284 |
285 | def update_learning_rate(self):
286 | if self.iter_step < self.warm_up_end:
287 | learning_factor = self.iter_step / self.warm_up_end
288 | else:
289 | alpha = self.learning_rate_alpha
290 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)
291 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
292 |
293 | for g in self.optimizer.param_groups:
294 | g['lr'] = self.learning_rate * learning_factor
295 |
296 | def file_backup(self):
297 | dir_lis = self.conf['general.recording']
298 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
299 | for dir_name in dir_lis:
300 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
301 | os.makedirs(cur_dir, exist_ok=True)
302 | files = os.listdir(dir_name)
303 | for f_name in files:
304 | if f_name[-3:] == '.py':
305 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
306 |
307 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
308 |
309 | def load_checkpoint(self, checkpoint_name):
310 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
311 | self.nerf_outside.load_state_dict(checkpoint['nerf'])
312 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
313 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
314 | self.color_network.load_state_dict(checkpoint['color_network_fine'])
315 | self.optimizer.load_state_dict(checkpoint['optimizer'])
316 | self.iter_step = checkpoint['iter_step']
317 |
318 | logging.info('End')
319 |
320 | def save_checkpoint(self):
321 | checkpoint = {
322 | 'nerf': self.nerf_outside.state_dict(),
323 | 'sdf_network_fine': self.sdf_network.state_dict(),
324 | 'variance_network_fine': self.deviation_network.state_dict(),
325 | 'color_network_fine': self.color_network.state_dict(),
326 | 'optimizer': self.optimizer.state_dict(),
327 | 'iter_step': self.iter_step,
328 | }
329 |
330 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
331 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
332 |
333 | def validate_image(self, idx=-1, resolution_level=-1):
334 | if idx < 0:
335 | idx = np.random.randint(self.dataset.n_images)
336 |
337 | print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))
338 |
339 | if resolution_level < 0:
340 | resolution_level = self.validate_resolution_level
341 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
342 | H, W, _ = rays_o.shape
343 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
344 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
345 |
346 | out_rgb_fine = []
347 | out_normal_fine = []
348 |
349 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
350 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
351 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
352 |
353 | render_out = self.renderer.render(rays_o_batch,
354 | rays_d_batch,
355 | near,
356 | far,
357 | None,
358 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
359 | background_rgb=background_rgb, t=self.iter_step + 1)
360 |
361 | def feasible(key): return (key in render_out) and (render_out[key] is not None)
362 |
363 | if feasible('color_fine'):
364 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
365 | if feasible('gradients') and feasible('weights'):
366 | n_samples = self.renderer.n_samples + self.renderer.n_importance
367 | normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
368 | if feasible('inside_sphere'):
369 | normals = normals * render_out['inside_sphere'][..., None]
370 | normals = normals.sum(dim=1).detach().cpu().numpy()
371 | out_normal_fine.append(normals)
372 | del render_out
373 |
374 | img_fine = None
375 | if len(out_rgb_fine) > 0:
376 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
377 |
378 | normal_img = None
379 | if len(out_normal_fine) > 0:
380 | normal_img = np.concatenate(out_normal_fine, axis=0)
381 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
382 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None])
383 | .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)
384 |
385 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True)
386 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True)
387 |
388 | for i in range(img_fine.shape[-1]):
389 | if len(out_rgb_fine) > 0:
390 | cv.imwrite(os.path.join(self.base_exp_dir,
391 | 'validations_fine',
392 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
393 | np.concatenate([img_fine[..., i],
394 | self.dataset.image_at(idx, resolution_level=resolution_level)]))
395 | if len(out_normal_fine) > 0:
396 | cv.imwrite(os.path.join(self.base_exp_dir,
397 | 'normals',
398 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
399 | normal_img[..., i])
400 |
401 | def rendering_image(self, idx=-1, resolution_level=-1):
402 | if idx < 0:
403 | idx = np.random.randint(self.dataset.n_images)
404 |
405 | if resolution_level < 0:
406 | resolution_level = self.validate_resolution_level
407 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
408 | H, W, _ = rays_o.shape
409 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
410 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
411 |
412 | out_rgb_fine = []
413 |
414 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
415 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
416 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
417 |
418 | render_out = self.renderer.render(rays_o_batch,
419 | rays_d_batch,
420 | near,
421 | far,
422 | None,
423 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
424 | background_rgb=background_rgb, t=self.iter_step + 1)
425 |
426 | def feasible(key): return (key in render_out) and (render_out[key] is not None)
427 |
428 | if feasible('color_fine'):
429 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
430 |
431 | img_fine = None
432 | if len(out_rgb_fine) > 0:
433 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
434 |
435 |
436 | os.makedirs(os.path.join(self.base_exp_dir, 'rendering_image'), exist_ok=True)
437 |
438 |
439 | for i in range(img_fine.shape[-1]):
440 | if len(out_rgb_fine) > 0:
441 | cv.imwrite(os.path.join(self.base_exp_dir,
442 | 'rendering_image',
443 | '{}.png'.format(idx)),
444 | img_fine[..., i]
445 | )
446 |
447 | def output_rendering_image(self, resolution=-1):
448 | for i in range(self.dataset.n_images):
449 | self.rendering_image(idx=i, resolution_level=resolution)
450 |
451 |
452 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
453 | """
454 | Interpolate view between two cameras.
455 | """
456 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level)
457 | H, W, _ = rays_o.shape
458 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
459 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
460 |
461 | out_rgb_fine = []
462 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
463 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
464 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
465 |
466 | render_out = self.renderer.render(rays_o_batch,
467 | rays_d_batch,
468 | near,
469 | far,
470 | None,
471 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
472 | background_rgb=background_rgb, t=self.iter_step + 1)
473 |
474 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
475 |
476 | del render_out
477 |
478 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8)
479 | return img_fine
480 |
481 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
482 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
483 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)
484 |
485 | vertices, triangles =\
486 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
487 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True)
488 |
489 | if world_space:
490 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]
491 |
492 | mesh = trimesh.Trimesh(vertices, triangles)
493 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}_{}.ply'.format(self.iter_step, resolution)))
494 |
495 | logging.info('End')
496 |
497 | def interpolate_view(self, img_idx_0, img_idx_1):
498 | images = []
499 | n_frames = 60
500 | for i in range(n_frames):
501 | print(i)
502 | images.append(self.render_novel_image(img_idx_0,
503 | img_idx_1,
504 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5,
505 | resolution_level=4))
506 | for i in range(n_frames):
507 | images.append(images[n_frames - i - 1])
508 |
509 | fourcc = cv.VideoWriter_fourcc(*'mp4v')
510 | video_dir = os.path.join(self.base_exp_dir, 'render')
511 | os.makedirs(video_dir, exist_ok=True)
512 | h, w, _ = images[0].shape
513 | writer = cv.VideoWriter(os.path.join(video_dir,
514 | '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)),
515 | fourcc, 30, (w, h))
516 |
517 | for image in images:
518 | writer.write(image)
519 |
520 | writer.release()
521 |
522 |
523 | if __name__ == '__main__':
524 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
525 |
526 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
527 | logging.basicConfig(level=logging.DEBUG, format=FORMAT)
528 |
529 | parser = argparse.ArgumentParser()
530 | parser.add_argument('--conf', type=str, default='./conf')
531 | parser.add_argument('--udf_dir', type=str, default='udf')
532 | parser.add_argument('--mode', type=str, default='train')
533 | parser.add_argument('--mcube_threshold', type=float, default=0.0)
534 | parser.add_argument('--volume_resolution', type=int, default=512)
535 | parser.add_argument('--is_continue', default=False, action="store_true")
536 | parser.add_argument('--gpu', type=int, default=0)
537 | parser.add_argument('--case', type=str, default='')
538 | parser.add_argument('--world_space', default=False, action="store_true")
539 | parser.add_argument('--ckpt', type=str, default='latest')
540 |
541 | args = parser.parse_args()
542 |
543 | torch.cuda.set_device(args.gpu)
544 | udf_runner = UDFRunner(args, args.conf)
545 |
546 | if args.mode == 'train':
547 | if not os.path.exists(f'{udf_runner.base_exp_dir}/checkpoints/ckpt_060000.pth'):
548 | udf_runner.train()
549 |
550 | runner = Runner(args.conf, args.mode, args.case, args.is_continue, args.ckpt)
551 |
552 | if args.mode == 'train':
553 | if not args.is_continue:
554 | base_exp_dir = runner.base_exp_dir
555 | os.makedirs(base_exp_dir, exist_ok=True)
556 | runner.base_exp_dir = os.path.join(base_exp_dir, 'prior_initialization')
557 | os.makedirs(runner.base_exp_dir, exist_ok=True)
558 | runner.train(prior_initialization=True)
559 | runner.base_exp_dir = base_exp_dir
560 | runner.train()
561 | runner.validate_mesh(world_space=True, resolution=args.volume_resolution, threshold=args.mcube_threshold)
562 | elif args.mode == 'validate_mesh':
563 | runner.validate_mesh(world_space=args.world_space, resolution=args.volume_resolution, threshold=args.mcube_threshold)
564 | elif args.mode == 'render_image':
565 | runner.output_rendering_image(resolution=1)
566 | elif args.mode.startswith('interpolate'): # Interpolate views given two image indices
567 | _, img_idx_0, img_idx_1 = args.mode.split('_')
568 | img_idx_0 = int(img_idx_0)
569 | img_idx_1 = int(img_idx_1)
570 | runner.interpolate_view(img_idx_0, img_idx_1)
571 |
--------------------------------------------------------------------------------
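As the `__main__` block above shows, a full `--mode train` run has three stages. First, `UDFRunner` trains the UDF prior (skipped if `ckpt_060000.pth` already exists). Next, unless `--is_continue` is set, the SDF network is warmed up for 5000 iterations with `prior_initialization=True`, optimizing only the eikonal, mask, and global UDF-consistency terms and writing results to a `prior_initialization/` subfolder. Finally, the full objective, including the color, point-cloud, pseudo-point regularization, and local losses, is trained to `end_iter`, and a mesh is extracted with `validate_mesh`.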
/extensions/chamfer_dist/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import chamfer
4 |
5 |
6 | class ChamferFunction(torch.autograd.Function):
7 | @staticmethod
8 | def forward(ctx, xyz1, xyz2):
9 | dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)
10 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
11 |
12 | return dist1, dist2
13 |
14 | @staticmethod
15 | def backward(ctx, grad_dist1, grad_dist2):
16 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
17 | grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
18 | return grad_xyz1, grad_xyz2
19 |
20 |
21 | class ChamferDistanceL2(torch.nn.Module):
22 | ''' Chamfer Distance L2
23 | '''
24 | def __init__(self, ignore_zeros=False):
25 | super().__init__()
26 | self.ignore_zeros = ignore_zeros
27 |
28 | def forward(self, xyz1, xyz2):
29 | batch_size = xyz1.size(0)
30 | if batch_size == 1 and self.ignore_zeros:
31 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
32 | non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
33 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
34 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
35 |
36 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
37 | return torch.mean(dist1) + torch.mean(dist2)
38 |
39 | class ChamferDistanceL2_split(torch.nn.Module):
40 | ''' Chamfer Distance L2 (split)
41 | '''
42 | def __init__(self, ignore_zeros=False):
43 | super().__init__()
44 | self.ignore_zeros = ignore_zeros
45 |
46 | def forward(self, xyz1, xyz2):
47 | batch_size = xyz1.size(0)
48 | if batch_size == 1 and self.ignore_zeros:
49 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
50 | non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
51 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
52 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
53 |
54 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
55 | return torch.mean(dist1), torch.mean(dist2)
56 |
57 | class ChamferDistanceL1(torch.nn.Module):
58 | ''' Chamfer Distance L1
59 | '''
60 | def __init__(self, ignore_zeros=False):
61 | super().__init__()
62 | self.ignore_zeros = ignore_zeros
63 |
64 | def forward(self, xyz1, xyz2):
65 | batch_size = xyz1.size(0)
66 | if batch_size == 1 and self.ignore_zeros:
67 | non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
68 | non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
69 | xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
70 | xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
71 |
72 | dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
73 | # import pdb
74 | # pdb.set_trace()
75 | dist1 = torch.sqrt(dist1)
76 | dist2 = torch.sqrt(dist2)
77 | return (torch.mean(dist1) + torch.mean(dist2))/2
78 |
79 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 |
5 | #include <vector>
6 |
7 | __global__ void chamfer_dist_kernel(int batch_size,
8 | int n,
9 | const float* xyz1,
10 | int m,
11 | const float* xyz2,
12 | float* dist,
13 | int* indexes) {
14 | const int batch = 512;
15 | __shared__ float buf[batch * 3];
16 | for (int i = blockIdx.x; i < batch_size; i += gridDim.x) {
17 | for (int k2 = 0; k2 < m; k2 += batch) {
18 | int end_k = min(m, k2 + batch) - k2;
19 | for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
20 | buf[j] = xyz2[(i * m + k2) * 3 + j];
21 | }
22 | __syncthreads();
23 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
24 | j += blockDim.x * gridDim.y) {
25 | float x1 = xyz1[(i * n + j) * 3 + 0];
26 | float y1 = xyz1[(i * n + j) * 3 + 1];
27 | float z1 = xyz1[(i * n + j) * 3 + 2];
28 | float best_dist = 0;
29 | int best_dist_index = 0;
30 | int end_ka = end_k - (end_k & 3);
31 | if (end_ka == batch) {
32 | for (int k = 0; k < batch; k += 4) {
33 | {
34 | float x2 = buf[k * 3 + 0] - x1;
35 | float y2 = buf[k * 3 + 1] - y1;
36 | float z2 = buf[k * 3 + 2] - z1;
37 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
38 |
39 | if (k == 0 || dist < best_dist) {
40 | best_dist = dist;
41 | best_dist_index = k + k2;
42 | }
43 | }
44 | {
45 | float x2 = buf[k * 3 + 3] - x1;
46 | float y2 = buf[k * 3 + 4] - y1;
47 | float z2 = buf[k * 3 + 5] - z1;
48 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
49 | if (dist < best_dist) {
50 | best_dist = dist;
51 | best_dist_index = k + k2 + 1;
52 | }
53 | }
54 | {
55 | float x2 = buf[k * 3 + 6] - x1;
56 | float y2 = buf[k * 3 + 7] - y1;
57 | float z2 = buf[k * 3 + 8] - z1;
58 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
59 | if (dist < best_dist) {
60 | best_dist = dist;
61 | best_dist_index = k + k2 + 2;
62 | }
63 | }
64 | {
65 | float x2 = buf[k * 3 + 9] - x1;
66 | float y2 = buf[k * 3 + 10] - y1;
67 | float z2 = buf[k * 3 + 11] - z1;
68 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
69 | if (dist < best_dist) {
70 | best_dist = dist;
71 | best_dist_index = k + k2 + 3;
72 | }
73 | }
74 | }
75 | } else {
76 | for (int k = 0; k < end_ka; k += 4) {
77 | {
78 | float x2 = buf[k * 3 + 0] - x1;
79 | float y2 = buf[k * 3 + 1] - y1;
80 | float z2 = buf[k * 3 + 2] - z1;
81 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
82 | if (k == 0 || dist < best_dist) {
83 | best_dist = dist;
84 | best_dist_index = k + k2;
85 | }
86 | }
87 | {
88 | float x2 = buf[k * 3 + 3] - x1;
89 | float y2 = buf[k * 3 + 4] - y1;
90 | float z2 = buf[k * 3 + 5] - z1;
91 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
92 | if (dist < best_dist) {
93 | best_dist = dist;
94 | best_dist_index = k + k2 + 1;
95 | }
96 | }
97 | {
98 | float x2 = buf[k * 3 + 6] - x1;
99 | float y2 = buf[k * 3 + 7] - y1;
100 | float z2 = buf[k * 3 + 8] - z1;
101 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
102 | if (dist < best_dist) {
103 | best_dist = dist;
104 | best_dist_index = k + k2 + 2;
105 | }
106 | }
107 | {
108 | float x2 = buf[k * 3 + 9] - x1;
109 | float y2 = buf[k * 3 + 10] - y1;
110 | float z2 = buf[k * 3 + 11] - z1;
111 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
112 | if (dist < best_dist) {
113 | best_dist = dist;
114 | best_dist_index = k + k2 + 3;
115 | }
116 | }
117 | }
118 | }
119 | for (int k = end_ka; k < end_k; k++) {
120 | float x2 = buf[k * 3 + 0] - x1;
121 | float y2 = buf[k * 3 + 1] - y1;
122 | float z2 = buf[k * 3 + 2] - z1;
123 | float dist = x2 * x2 + y2 * y2 + z2 * z2;
124 | if (k == 0 || dist < best_dist) {
125 | best_dist = dist;
126 | best_dist_index = k + k2;
127 | }
128 | }
129 | if (k2 == 0 || dist[(i * n + j)] > best_dist) {
130 | dist[(i * n + j)] = best_dist;
131 | indexes[(i * n + j)] = best_dist_index;
132 | }
133 | }
134 | __syncthreads();
135 | }
136 | }
137 | }
138 |
139 | std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
140 | torch::Tensor xyz2) {
141 | const int batch_size = xyz1.size(0);
142 | const int n = xyz1.size(1); // num_points point cloud A
143 | const int m = xyz2.size(1); // num_points point cloud B
144 | torch::Tensor dist1 =
145 | torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));
146 | torch::Tensor dist2 =
147 | torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));
148 | torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
149 | torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
150 |
151 | chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
152 | batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
153 | dist1.data_ptr<float>(), idx1.data_ptr<int>());
154 | chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
155 | batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
156 | dist2.data_ptr<float>(), idx2.data_ptr<int>());
157 |
158 | cudaError_t err = cudaGetLastError();
159 | if (err != cudaSuccess) {
160 | printf("Error in chamfer_cuda_forward: %s\n", cudaGetErrorString(err));
161 | }
162 | return {dist1, dist2, idx1, idx2};
163 | }
164 |
165 | __global__ void chamfer_dist_grad_kernel(int b,
166 | int n,
167 | const float* xyz1,
168 | int m,
169 | const float* xyz2,
170 | const float* grad_dist1,
171 | const int* idx1,
172 | float* grad_xyz1,
173 | float* grad_xyz2) {
174 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
175 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
176 | j += blockDim.x * gridDim.y) {
177 | float x1 = xyz1[(i * n + j) * 3 + 0];
178 | float y1 = xyz1[(i * n + j) * 3 + 1];
179 | float z1 = xyz1[(i * n + j) * 3 + 2];
180 | int j2 = idx1[i * n + j];
181 | float x2 = xyz2[(i * m + j2) * 3 + 0];
182 | float y2 = xyz2[(i * m + j2) * 3 + 1];
183 | float z2 = xyz2[(i * m + j2) * 3 + 2];
184 | float g = grad_dist1[i * n + j] * 2;
185 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
186 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
187 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
188 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
189 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
190 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
191 | }
192 | }
193 | }
194 |
195 | std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
196 | torch::Tensor xyz2,
197 | torch::Tensor idx1,
198 | torch::Tensor idx2,
199 | torch::Tensor grad_dist1,
200 | torch::Tensor grad_dist2) {
201 | const int batch_size = xyz1.size(0);
202 | const int n = xyz1.size(1); // num_points point cloud A
203 | const int m = xyz2.size(1); // num_points point cloud B
204 | torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));
205 | torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));
206 |
207 | chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
208 | batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
209 | grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(),
210 | grad_xyz1.data_ptr<float>(), grad_xyz2.data_ptr<float>());
211 | chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
212 | batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
213 | grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(),
214 | grad_xyz2.data_ptr<float>(), grad_xyz1.data_ptr<float>());
215 |
216 | cudaError_t err = cudaGetLastError();
217 | if (err != cudaSuccess) {
218 | printf("Error in chamfer_cuda_backward: %s\n", cudaGetErrorString(err));
219 | }
220 | return {grad_xyz1, grad_xyz2};
221 | }
222 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/chamfer_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
5 | torch::Tensor xyz2);
6 |
7 | std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
8 | torch::Tensor xyz2,
9 | torch::Tensor idx1,
10 | torch::Tensor idx2,
11 | torch::Tensor grad_dist1,
12 | torch::Tensor grad_dist2);
13 |
14 | std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
15 | torch::Tensor xyz2) {
16 | return chamfer_cuda_forward(xyz1, xyz2);
17 | }
18 |
19 | std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
20 | torch::Tensor xyz2,
21 | torch::Tensor idx1,
22 | torch::Tensor idx2,
23 | torch::Tensor grad_dist1,
24 | torch::Tensor grad_dist2) {
25 | return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
26 | }
27 |
28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
29 | m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
30 | m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
31 | }
32 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(name='chamfer',
5 | version='2.0.0',
6 | ext_modules=[
7 | CUDAExtension('chamfer', [
8 | 'chamfer_cuda.cpp',
9 | 'chamfer.cu',
10 | ]),
11 | ],
12 | cmdclass={'build_ext': BuildExtension})
13 |
--------------------------------------------------------------------------------
/extensions/chamfer_dist/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import unittest
5 |
6 |
7 | from torch.autograd import gradcheck
8 |
9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
10 | from extensions.chamfer_dist import ChamferFunction
11 |
12 |
13 | class ChamferDistanceTestCase(unittest.TestCase):
14 | def test_chamfer_dist(self):
15 | x = torch.rand(4, 64, 3).double()
16 | y = torch.rand(4, 128, 3).double()
17 | x.requires_grad = True
18 | y.requires_grad = True
19 | print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))
20 |
21 |
22 |
23 | if __name__ == '__main__':
24 | unittest.main()
29 |
--------------------------------------------------------------------------------
/media/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/media/comparison.png
--------------------------------------------------------------------------------
/media/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/media/pipeline.png
--------------------------------------------------------------------------------
/models/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import cv2 as cv
4 | import numpy as np
5 | import os
6 | from glob import glob
7 | from scipy.spatial.transform import Rotation as Rot
8 | from scipy.spatial.transform import Slerp
9 | from tools.feat_utils import load_pair, load_cam, scale_camera, FeatExt
10 |
11 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
12 | def load_K_Rt_from_P(filename, P=None):
13 | if P is None:
14 | lines = open(filename).read().splitlines()
15 | if len(lines) == 4:
16 | lines = lines[1:]
17 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
18 | P = np.asarray(lines).astype(np.float32).squeeze()
19 |
20 | out = cv.decomposeProjectionMatrix(P)
21 | K = out[0]
22 | R = out[1]
23 | t = out[2]
24 |
25 | K = K / K[2, 2]
26 | intrinsics = np.eye(4)
27 | intrinsics[:3, :3] = K
28 |
29 | pose = np.eye(4, dtype=np.float32)
30 | pose[:3, :3] = R.transpose()
31 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
32 |
33 | return intrinsics, pose
34 |
35 |
36 | class Dataset:
37 | def __init__(self, conf):
38 | super(Dataset, self).__init__()
39 | print('Load data: Begin')
40 | self.device = torch.device('cuda')
41 | self.conf = conf
42 |
43 | self.data_dir = conf.get_string('data_dir')
44 | self.render_cameras_name = conf.get_string('render_cameras_name')
45 | self.object_cameras_name = conf.get_string('object_cameras_name')
46 |
47 | self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
48 | self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)
49 |
50 | camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
51 | self.camera_dict = camera_dict
52 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))
53 | self.n_images = len(self.images_lis)
54 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0
55 | # world_mat is a projection matrix from world to image
56 | self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
57 |
58 | self.scale_mats_np = []
59 |
60 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
61 | self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
62 |
63 | self.intrinsics_all = []
64 | self.pose_all = []
65 |
66 | for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np):
67 | P = world_mat @ scale_mat
68 | P = P[:3, :4]
69 | intrinsics, pose = load_K_Rt_from_P(None, P)
70 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
71 | self.pose_all.append(torch.from_numpy(pose).float())
72 |
73 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3]
74 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4]
75 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4]
76 | self.focal = self.intrinsics_all[0][0, 0]
77 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4]
78 | self.H, self.W = self.images.shape[1], self.images.shape[2]
79 | self.image_pixels = self.H * self.W
80 | self.pair = load_pair(f'{self.data_dir}/cam4feat/pair.txt')
81 | self.num_src = 2
82 | self.depth_cams = torch.stack(
83 | [torch.from_numpy(
84 | load_cam(f'{self.data_dir}/cam4feat/cam_{self.pair["id_list"][i].zfill(8)}_flow3.txt', 256, 1)).to(torch.float32)
85 | for i in range(self.n_images)], dim=0)
86 | self.feat_img_scale = 2
87 | self.cams_hd = torch.stack( # upsample of 2 from depth_cams, not 1200 * 1600
88 | [scale_camera(self.depth_cams[i], self.feat_img_scale) for i in range(self.n_images)] # NOTE: hard code
89 | )
90 | self.img_res = self.images.shape[-3:-1]
91 | # [n_images, 3, 768, 1024]
92 | self.rgb_2xd = torch.stack([
93 | F.interpolate(
94 | self.images[idx].reshape(-1,3).permute(1, 0).view(1, 3, *self.img_res), # 1200 x 1600
95 | size=(self.conf.feat_map_h * self.feat_img_scale, self.conf.feat_map_w * self.feat_img_scale),
96 | mode='bilinear', align_corners=False)[0]
97 | for idx in range(self.n_images)
98 | ], dim=0) # v3hw
99 | mean = torch.tensor([0.485, 0.456, 0.406]).float().cpu()
100 | std = torch.tensor([0.229, 0.224, 0.225]).float().cpu()
101 | self.rgb_2xd = (self.rgb_2xd / 2 + 0.5 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
102 | self.size = torch.from_numpy(self.scale_mats_np[0]).float()[0, 0] * 2
103 | self.center = torch.from_numpy(self.scale_mats_np[0]).float()[:3, 3]
104 |
105 | feat_ext = FeatExt().cuda()
106 | feat_ext.eval()
107 | for p in feat_ext.parameters():
108 | p.requires_grad = False
109 | feats = []
110 | for start_i in range(0, self.n_images):
111 | eval_batch = self.rgb_2xd[start_i:start_i + 1]
112 | feat2 = feat_ext(eval_batch.cuda())[2] # .detach().cpu()
113 | feats.append(feat2)
114 | self.feats = torch.cat(feats, dim=0)
115 | self.feats.requires_grad = False
116 |
117 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
118 | object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0])
119 | # Object scale mat: region of interest to **extract mesh**
120 | object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
121 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
122 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
123 | self.object_bbox_min = object_bbox_min[:3, 0]
124 | self.object_bbox_max = object_bbox_max[:3, 0]
125 |
126 | print('Load data: End')
127 |
128 | def gen_rays_at(self, img_idx, resolution_level=1):
129 | """
130 | Generate rays at world space from one camera.
131 | """
132 | l = resolution_level
133 | tx = torch.linspace(0, self.W - 1, self.W // l)
134 | ty = torch.linspace(0, self.H - 1, self.H // l)
135 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
136 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
137 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
138 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
139 | rays_d = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3
140 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_d.shape) # W, H, 3
141 | return rays_o.transpose(0, 1), rays_d.transpose(0, 1)
142 |
143 | def gen_random_rays_at(self, img_idx, batch_size):
144 | """
145 | Generate random rays at world space from one camera.
146 | """
147 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
148 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
149 | color = self.images[img_idx.to(self.images.device)][(pixels_y.to(self.images.device), pixels_x.to(self.images.device))] # batch_size, 3
150 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
151 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3
152 | rays_d_norm = torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)
153 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
154 | rays_d = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_d[:, :, None]).squeeze() # batch_size, 3
155 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_d.shape) # batch_size, 3
156 |
157 | id = self.pair['id_list'][img_idx]
158 | src_ids = self.pair[id]['pair']
159 | src_idxs = [self.pair[src_id]['index'] for src_id in src_ids][:self.num_src]
160 |
161 | sample = {}
162 | sample['depth_cams'] = self.depth_cams[[img_idx]]
163 | sample['size'] = self.size
164 | sample['center'] = self.center
165 | sample["feat"] = self.feats[img_idx]
166 | sample["feat_src"] = self.feats[src_idxs]
167 | sample["cam"] = self.cams_hd[img_idx]
168 | sample["src_cams"] = self.cams_hd[src_idxs]
169 | sample['rays_d_norm'] = rays_d_norm
170 | sample['H'] = self.H
171 | sample['W'] = self.W
172 | sample['src_idxs'] = src_idxs
173 |
174 | return torch.cat([rays_o.cpu(), rays_d.cpu(), color], dim=-1).cuda(), sample  # batch_size, 9
175 |
176 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
177 | """
178 | Interpolate pose between two cameras.
179 | """
180 | l = resolution_level
181 | tx = torch.linspace(0, self.W - 1, self.W // l)
182 | ty = torch.linspace(0, self.H - 1, self.H // l)
183 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
184 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
185 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
186 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
187 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
188 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
189 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
190 | pose_0 = np.linalg.inv(pose_0)
191 | pose_1 = np.linalg.inv(pose_1)
192 | rot_0 = pose_0[:3, :3]
193 | rot_1 = pose_1[:3, :3]
194 | rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
195 | key_times = [0, 1]
196 | slerp = Slerp(key_times, rots)
197 | rot = slerp(ratio)
198 | pose = np.diag([1.0, 1.0, 1.0, 1.0])
199 | pose = pose.astype(np.float32)
200 | pose[:3, :3] = rot.as_matrix()
201 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
202 | pose = np.linalg.inv(pose)
203 | rot = torch.from_numpy(pose[:3, :3]).cuda()
204 | trans = torch.from_numpy(pose[:3, 3]).cuda()
205 | rays_d = torch.matmul(rot[None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3
206 | rays_o = trans[None, None, :3].expand(rays_d.shape) # W, H, 3
207 | return rays_o.transpose(0, 1), rays_d.transpose(0, 1)
208 |
209 | def gen_rays_between_from_pts(self, idx_0, idx_1, ratio, pts, resolution_level=1):
210 | """
211 | Interpolate pose between two cameras.
212 | """
213 | l = resolution_level
214 | tx = torch.linspace(0, self.W - 1, self.W // l)
215 | ty = torch.linspace(0, self.H - 1, self.H // l)
216 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
217 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
218 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
219 | rays_d = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
220 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
221 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
222 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
223 | pose_0 = np.linalg.inv(pose_0)
224 | pose_1 = np.linalg.inv(pose_1)
225 | rot_0 = pose_0[:3, :3]
226 | rot_1 = pose_1[:3, :3]
227 | rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
228 | key_times = [0, 1]
229 | slerp = Slerp(key_times, rots)
230 | rot = slerp(ratio)
231 | pose = np.diag([1.0, 1.0, 1.0, 1.0])
232 | pose = pose.astype(np.float32)
233 | pose[:3, :3] = rot.as_matrix()
234 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
235 | pose = np.linalg.inv(pose)
236 | rot = torch.from_numpy(pose[:3, :3]).cuda()
237 | trans = torch.from_numpy(pose[:3, 3]).cuda()
238 | # rays_d = torch.matmul(rot[None, None, :3, :3], rays_d[:, :, :, None]).squeeze() # W, H, 3
239 | # import pdb; pdb.set_trace()
240 | rays_o = trans[None, None, :3].expand(pts.shape) # 1, N, 3
241 | rays_d = F.normalize(pts - rays_o, dim=-1) # 1, N, 3
242 | return rays_o.squeeze(0), rays_d.squeeze(0)
243 |
244 | def near_far_from_sphere(self, rays_o, rays_d):
245 | a = torch.sum(rays_d**2, dim=-1, keepdim=True)
246 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
247 | mid = 0.5 * (-b) / a
248 | near = mid - 1.0
249 | far = mid + 1.0
250 | return near, far
251 |
252 | def image_at(self, idx, resolution_level):
253 | img = cv.imread(self.images_lis[idx])
254 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)
255 |
256 |
--------------------------------------------------------------------------------
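The near_far_from_sphere helper at the end of models/dataset.py brackets each ray by the unit sphere around the origin: with a = ||d||^2 and b = 2 * dot(o, d), the ray's closest point to the origin lies at t = -b / (2a), and near/far are taken one unit before and after it. A minimal numeric check, using a toy camera position that is not part of the repository:

import torch

rays_o = torch.tensor([[0.0, 0.0, -3.0]])   # toy camera 3 units behind the origin
rays_d = torch.tensor([[0.0, 0.0, 1.0]])    # unit direction looking at the origin

a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)            # = 1 for unit directions
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)  # = -6
mid = 0.5 * (-b) / a                                        # closest point to the origin at t = 3
near, far = mid - 1.0, mid + 1.0                            # 2 and 4, bracketing the unit sphere
print(near.item(), far.item())  # 2.0 4.0
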
/models/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | # Borrowed from https://github.com/bmild/nerf.
7 | class Embedder:
8 | def __init__(self, **kwargs):
9 | self.kwargs = kwargs
10 | self.create_embedding_fn()
11 |
12 | def create_embedding_fn(self):
13 | embed_fns = []
14 | d = self.kwargs['input_dims']
15 | out_dim = 0
16 | if self.kwargs['include_input']:
17 | embed_fns.append(lambda x: x)
18 | out_dim += d
19 |
20 | max_freq = self.kwargs['max_freq_log2']
21 | N_freqs = self.kwargs['num_freqs']
22 |
23 | if self.kwargs['log_sampling']:
24 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
25 | else:
26 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
27 |
28 | for freq in freq_bands:
29 | for p_fn in self.kwargs['periodic_fns']:
30 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
31 | out_dim += d
32 |
33 | self.embed_fns = embed_fns
34 | self.out_dim = out_dim
35 |
36 | def embed(self, inputs):
37 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
38 |
39 |
40 | def get_embedder(multires, input_dims=3):
41 | embed_kwargs = {
42 | 'include_input': True,
43 | 'input_dims': input_dims,
44 | 'max_freq_log2': multires-1,
45 | 'num_freqs': multires,
46 | 'log_sampling': True,
47 | 'periodic_fns': [torch.sin, torch.cos],
48 | }
49 |
50 | embedder_obj = Embedder(**embed_kwargs)
51 | def embed(x, eo=embedder_obj): return eo.embed(x)
52 | return embed, embedder_obj.out_dim
53 |
--------------------------------------------------------------------------------
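A quick shape check for the positional encoding in models/embedder.py above. With include_input=True the embedding dimension is input_dims * (1 + 2 * num_freqs); the multires value below is illustrative and the repository root is assumed to be on PYTHONPATH:

import torch
from models.embedder import get_embedder

embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
x = torch.rand(128, 3)              # a batch of 3D points
y = embed_fn(x)                     # concatenation of x, sin(2^k x), cos(2^k x) for k = 0..5
assert y.shape == (128, out_dim) and out_dim == 3 * (1 + 2 * 6)  # 39 channels
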
/models/fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from models.embedder import get_embedder
6 |
7 |
8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
9 | class SDFNetwork(nn.Module):
10 | def __init__(self,
11 | d_in,
12 | d_out,
13 | d_hidden,
14 | n_layers,
15 | skip_in=(4,),
16 | multires=0,
17 | bias=0.5,
18 | scale=1,
19 | geometric_init=True,
20 | weight_norm=True,
21 | inside_outside=False):
22 | super(SDFNetwork, self).__init__()
23 |
24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
25 |
26 | self.embed_fn_fine = None
27 |
28 | if multires > 0:
29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
30 | self.embed_fn_fine = embed_fn
31 | dims[0] = input_ch
32 |
33 | self.num_layers = len(dims)
34 | self.skip_in = skip_in
35 | self.scale = scale
36 |
37 | for l in range(0, self.num_layers - 1):
38 | if l + 1 in self.skip_in:
39 | out_dim = dims[l + 1] - dims[0]
40 | else:
41 | out_dim = dims[l + 1]
42 |
43 | lin = nn.Linear(dims[l], out_dim)
44 |
45 | if geometric_init:
46 | if l == self.num_layers - 2:
47 | if not inside_outside:
48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
49 | torch.nn.init.constant_(lin.bias, -bias)
50 | else:
51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
52 | torch.nn.init.constant_(lin.bias, bias)
53 | elif multires > 0 and l == 0:
54 | torch.nn.init.constant_(lin.bias, 0.0)
55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
57 | elif multires > 0 and l in self.skip_in:
58 | torch.nn.init.constant_(lin.bias, 0.0)
59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
61 | else:
62 | torch.nn.init.constant_(lin.bias, 0.0)
63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
64 |
65 | if weight_norm:
66 | lin = nn.utils.weight_norm(lin)
67 |
68 | setattr(self, "lin" + str(l), lin)
69 |
70 | self.activation = nn.Softplus(beta=100)
71 |
72 | def forward(self, inputs):
73 | inputs = inputs * self.scale
74 | if self.embed_fn_fine is not None:
75 | inputs = self.embed_fn_fine(inputs)
76 |
77 | x = inputs
78 | for l in range(0, self.num_layers - 1):
79 | lin = getattr(self, "lin" + str(l))
80 |
81 | if l in self.skip_in:
82 | x = torch.cat([x, inputs], 1) / np.sqrt(2)
83 |
84 | x = lin(x)
85 |
86 | if l < self.num_layers - 2:
87 | x = self.activation(x)
88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
89 |
90 | def sdf(self, x):
91 | return self.forward(x)[:, :1]
92 |
93 | def sdf_hidden_appearance(self, x):
94 | return self.forward(x)
95 |
96 | def gradient(self, x):
97 | x.requires_grad_(True)
98 | y = self.sdf(x)
99 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
100 | gradients = torch.autograd.grad(
101 | outputs=y,
102 | inputs=x,
103 | grad_outputs=d_output,
104 | create_graph=True,
105 | retain_graph=True,
106 | only_inputs=True)[0]
107 | return gradients.unsqueeze(1)
108 |
109 |
110 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
111 | class RenderingNetwork(nn.Module):
112 | def __init__(self,
113 | d_feature,
114 | mode,
115 | d_in,
116 | d_out,
117 | d_hidden,
118 | n_layers,
119 | weight_norm=True,
120 | multires_view=0,
121 | squeeze_out=True):
122 | super().__init__()
123 |
124 | self.mode = mode
125 | self.squeeze_out = squeeze_out
126 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
127 |
128 | self.embedview_fn = None
129 | if multires_view > 0:
130 | embedview_fn, input_ch = get_embedder(multires_view)
131 | self.embedview_fn = embedview_fn
132 | dims[0] += (input_ch - 3)
133 |
134 | self.num_layers = len(dims)
135 |
136 | for l in range(0, self.num_layers - 1):
137 | out_dim = dims[l + 1]
138 | lin = nn.Linear(dims[l], out_dim)
139 |
140 | if weight_norm:
141 | lin = nn.utils.weight_norm(lin)
142 |
143 | setattr(self, "lin" + str(l), lin)
144 |
145 | self.relu = nn.ReLU()
146 |
147 | def forward(self, points, normals, view_dirs, feature_vectors):
148 | if self.embedview_fn is not None:
149 | view_dirs = self.embedview_fn(view_dirs)
150 |
151 | rendering_input = None
152 |
153 | if self.mode == 'idr':
154 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
155 | elif self.mode == 'no_view_dir':
156 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
157 | elif self.mode == 'no_normal':
158 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
159 |
160 | x = rendering_input
161 |
162 | for l in range(0, self.num_layers - 1):
163 | lin = getattr(self, "lin" + str(l))
164 |
165 | x = lin(x)
166 |
167 | if l < self.num_layers - 2:
168 | x = self.relu(x)
169 |
170 | if self.squeeze_out:
171 | x = torch.sigmoid(x)
172 | return x
173 |
174 |
175 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch
176 | class NeRF(nn.Module):
177 | def __init__(self,
178 | D=8,
179 | W=256,
180 | d_in=3,
181 | d_in_view=3,
182 | multires=0,
183 | multires_view=0,
184 | output_ch=4,
185 | skips=[4],
186 | use_viewdirs=False):
187 | super(NeRF, self).__init__()
188 | self.D = D
189 | self.W = W
190 | self.d_in = d_in
191 | self.d_in_view = d_in_view
192 | self.input_ch = 3
193 | self.input_ch_view = 3
194 | self.embed_fn = None
195 | self.embed_fn_view = None
196 |
197 | if multires > 0:
198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
199 | self.embed_fn = embed_fn
200 | self.input_ch = input_ch
201 |
202 | if multires_view > 0:
203 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view)
204 | self.embed_fn_view = embed_fn_view
205 | self.input_ch_view = input_ch_view
206 |
207 | self.skips = skips
208 | self.use_viewdirs = use_viewdirs
209 |
210 | self.pts_linears = nn.ModuleList(
211 | [nn.Linear(self.input_ch, W)] +
212 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)])
213 |
214 | ### Implementation according to the official code release
215 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
216 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])
217 |
218 | ### Implementation according to the paper
219 | # self.views_linears = nn.ModuleList(
220 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
221 |
222 | if use_viewdirs:
223 | self.feature_linear = nn.Linear(W, W)
224 | self.alpha_linear = nn.Linear(W, 1)
225 | self.rgb_linear = nn.Linear(W // 2, 3)
226 | else:
227 | self.output_linear = nn.Linear(W, output_ch)
228 |
229 | def forward(self, input_pts, input_views):
230 | if self.embed_fn is not None:
231 | input_pts = self.embed_fn(input_pts)
232 | if self.embed_fn_view is not None:
233 | input_views = self.embed_fn_view(input_views)
234 |
235 | h = input_pts
236 | for i, l in enumerate(self.pts_linears):
237 | h = self.pts_linears[i](h)
238 | h = F.relu(h)
239 | if i in self.skips:
240 | h = torch.cat([input_pts, h], -1)
241 |
242 | if self.use_viewdirs:
243 | alpha = self.alpha_linear(h)
244 | feature = self.feature_linear(h)
245 | h = torch.cat([feature, input_views], -1)
246 |
247 | for i, l in enumerate(self.views_linears):
248 | h = self.views_linears[i](h)
249 | h = F.relu(h)
250 |
251 | rgb = self.rgb_linear(h)
252 | return alpha, rgb
253 | else:
254 | assert False
255 |
256 |
257 | class SingleVarianceNetwork(nn.Module):
258 | def __init__(self, init_val):
259 | super(SingleVarianceNetwork, self).__init__()
260 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
261 |
262 | def forward(self, x):
263 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0)
264 |
--------------------------------------------------------------------------------
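A minimal shape sketch for the IDR-style SDFNetwork and the SingleVarianceNetwork above. The hyper-parameters are illustrative, not necessarily the ones used in confs/dtu.conf, and the repository root is assumed to be on PYTHONPATH:

import torch
from models.fields import SDFNetwork, SingleVarianceNetwork

sdf_net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8,
                     skip_in=(4,), multires=6)
pts = torch.rand(1024, 3)
out = sdf_net(pts)              # column 0: signed distance, columns 1..256: geometry feature
sdf = sdf_net.sdf(pts)          # (1024, 1), equal to out[:, :1]
grad = sdf_net.gradient(pts)    # (1024, 1, 3), analytic gradient via autograd

dev_net = SingleVarianceNetwork(init_val=0.3)
inv_s = dev_net(torch.zeros(1, 3))   # exp(10 * variance), one value broadcast per query
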
/models/renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import mcubes
5 | import tools.feat_utils as feat_utils
6 |
7 |
8 | # interpolate SDF zero-crossing points
9 | def find_surface_points(sdf, d_all, device='cuda'):
10 | # shape of sdf and d_all: only inside
11 | sdf_bool_1 = sdf[...,1:] * sdf[...,:-1] < 0
12 | # only find backward facing surface points, not forward facing
13 | sdf_bool_2 = sdf[...,1:] < sdf[...,:-1]
14 | sdf_bool = torch.logical_and(sdf_bool_1, sdf_bool_2)
15 |
16 | max, max_indices = torch.max(sdf_bool, dim=2)
17 | network_mask = max > 0
18 | d_surface = torch.zeros_like(network_mask, device=device).float()
19 |
20 | sdf_0 = torch.gather(sdf[network_mask], 1, max_indices[network_mask][..., None]).squeeze()
21 | sdf_1 = torch.gather(sdf[network_mask], 1, max_indices[network_mask][..., None]+1).squeeze()
22 | d_0 = torch.gather(d_all[network_mask], 1, max_indices[network_mask][..., None]).squeeze()
23 | d_1 = torch.gather(d_all[network_mask], 1, max_indices[network_mask][..., None]+1).squeeze()
24 | d_surface[network_mask] = (sdf_0 * d_1 - sdf_1 * d_0) / (sdf_0-sdf_1)
25 |
26 | return d_surface, network_mask
27 |
28 | def extract_fields(bound_min, bound_max, resolution, query_func):
29 | N = 64
30 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
31 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
32 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
33 |
34 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
35 | with torch.no_grad():
36 | for xi, xs in enumerate(X):
37 | for yi, ys in enumerate(Y):
38 | for zi, zs in enumerate(Z):
39 | xx, yy, zz = torch.meshgrid(xs, ys, zs)
40 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
41 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
42 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
43 | return u
44 |
45 |
46 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
47 | print('threshold: {}'.format(threshold))
48 | u = extract_fields(bound_min, bound_max, resolution, query_func)
49 | vertices, triangles = mcubes.marching_cubes(u, threshold)
50 | b_max_np = bound_max.detach().cpu().numpy()
51 | b_min_np = bound_min.detach().cpu().numpy()
52 |
53 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
54 | return vertices, triangles
55 |
56 |
57 | def sample_pdf(bins, weights, n_samples, det=False):
58 | # This implementation is from NeRF
59 | # Get pdf
60 | weights = weights + 1e-5 # prevent nans
61 | pdf = weights / torch.sum(weights, -1, keepdim=True)
62 | cdf = torch.cumsum(pdf, -1)
63 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
64 | # Take uniform samples
65 | if det:
66 | u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
67 | u = u.expand(list(cdf.shape[:-1]) + [n_samples])
68 | else:
69 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
70 |
71 | # Invert CDF
72 | u = u.contiguous()
73 | inds = torch.searchsorted(cdf, u, right=True)
74 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
75 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
76 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
77 |
78 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
79 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
80 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
81 |
82 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
83 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
84 | t = (u - cdf_g[..., 0]) / denom
85 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
86 |
87 | return samples
88 |
89 |
90 | class NeuSRenderer:
91 | def __init__(self,
92 | nerf,
93 | sdf_network,
94 | deviation_network,
95 | color_network,
96 | dataset,
97 | n_samples,
98 | n_importance,
99 | n_outside,
100 | up_sample_steps,
101 | perturb):
102 | self.nerf = nerf
103 | self.sdf_network = sdf_network
104 | self.deviation_network = deviation_network
105 | self.color_network = color_network
106 | self.dataset = dataset
107 | self.n_samples = n_samples
108 | self.n_importance = n_importance
109 | self.n_outside = n_outside
110 | self.up_sample_steps = up_sample_steps
111 | self.perturb = perturb
112 |
113 | self.feat_ext = feat_utils.FeatExt().cuda()
114 | self.feat_ext.eval()
115 | for p in self.feat_ext.parameters():
116 | p.requires_grad = False
117 |
118 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
119 | """
120 | Render background
121 | """
122 | batch_size, n_samples = z_vals.shape
123 |
124 | # Section length
125 | dists = z_vals[..., 1:] - z_vals[..., :-1]
126 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
127 | mid_z_vals = z_vals + dists * 0.5
128 |
129 |
130 | # Section midpoints
131 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3
132 |
133 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
134 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4
135 |
136 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
137 |
138 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
139 | dirs = dirs.reshape(-1, 3)
140 |
141 | density, sampled_color = nerf(pts, dirs)
142 | sampled_color = torch.sigmoid(sampled_color)
143 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
144 | alpha = alpha.reshape(batch_size, n_samples)
145 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
146 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
147 | color = (weights[:, :, None] * sampled_color).sum(dim=1)
148 | if background_rgb is not None:
149 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
150 |
151 | return {
152 | 'color': color,
153 | 'sampled_color': sampled_color,
154 | 'alpha': alpha,
155 | 'weights': weights,
156 | 'mid_z_vals_out' : mid_z_vals
157 | }
158 |
159 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
160 | """
161 |         Up sampling given a fixed inv_s
162 | """
163 | batch_size, n_samples = z_vals.shape
164 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
165 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
166 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
167 | sdf = sdf.reshape(batch_size, n_samples)
168 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
169 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
170 | mid_sdf = (prev_sdf + next_sdf) * 0.5
171 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
172 |
173 | # ----------------------------------------------------------------------------------------------------------
174 | # Use min value of [ cos, prev_cos ]
175 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
176 | # robust when meeting situations like below:
177 | #
178 | # SDF
179 | # ^
180 | # |\ -----x----...
181 | # | \ /
182 | # | x x
183 | # |---\----/-------------> 0 level
184 | # | \ /
185 | # | \/
186 | # |
187 | # ----------------------------------------------------------------------------------------------------------
188 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
189 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
190 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
191 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
192 |
193 | dist = (next_z_vals - prev_z_vals)
194 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
195 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5
196 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
197 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
198 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
199 | weights = alpha * torch.cumprod(
200 | torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
201 |
202 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
203 | return z_samples
204 |
205 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
206 | batch_size, n_samples = z_vals.shape
207 | _, n_importance = new_z_vals.shape
208 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
209 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
210 | z_vals, index = torch.sort(z_vals, dim=-1)
211 |
212 | if not last:
213 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
214 | sdf = torch.cat([sdf, new_sdf], dim=-1)
215 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
216 | index = index.reshape(-1)
217 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
218 |
219 | return z_vals, sdf
220 |
221 | def render_color(self,
222 | rays_o,
223 | rays_d,
224 | z_vals,
225 | sample_dist,
226 | sdf_network,
227 | deviation_network,
228 | color_network,
229 | background_alpha=None,
230 | background_sampled_color=None,
231 | background_rgb=None,
232 | cos_anneal_ratio=0.0,
233 | ):
234 |
235 | batch_size, n_samples = z_vals.shape
236 |
237 | # Section length
238 | dists = z_vals[..., 1:] - z_vals[..., :-1]
239 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
240 | mid_z_vals = z_vals + dists * 0.5
241 |
242 | # Section midpoints
243 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
244 | dirs = rays_d[:, None, :].expand(pts.shape)
245 | pts = pts.reshape(-1, 3)
246 | dirs = dirs.reshape(-1, 3)
247 |
248 | sdf_nn_output = sdf_network(pts)
249 | sdf = sdf_nn_output[:, :1]
250 | feature_vector = sdf_nn_output[:, 1:]
251 |
252 | gradients = sdf_network.gradient(pts).squeeze()
253 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
254 |
255 | # inv_s in the code == s in the paper == 1 / standard deviation
256 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
257 | inv_s = inv_s.expand(batch_size * n_samples, 1)
258 |
259 | true_cos = (dirs * gradients).sum(-1, keepdim=True)
260 |
261 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
262 | # the cos value "not dead" at the beginning training iterations, for better convergence.
263 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
264 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
265 |
266 | # Estimate signed distances at section points
267 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
268 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
269 |
270 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
271 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
272 |
273 | p = prev_cdf - next_cdf
274 | c = prev_cdf
275 |
276 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
277 |
278 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
279 | inside_sphere = (pts_norm < 1.0).float().detach()
280 |
281 | # Render with background
282 | if background_alpha is not None:
283 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
284 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
285 | sampled_color = sampled_color * inside_sphere[:, :, None] +\
286 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
287 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
288 |
289 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
290 | weights_sum = weights.sum(dim=-1, keepdim=True)
291 |
292 | color = (sampled_color * weights[:, :, None]).sum(dim=1)
293 | if background_rgb is not None: # Fixed background, usually black
294 | color = color + background_rgb * (1.0 - weights_sum)
295 |
296 | return color
297 |
298 | def render_core(self,
299 | rays_o,
300 | rays_d,
301 | z_vals,
302 | sample_dist,
303 | sdf_network,
304 | deviation_network,
305 | color_network,
306 | model_input = None,
307 | background_alpha=None,
308 | background_sampled_color=None,
309 | background_rgb=None,
310 | mid_z_vals_out=None,
311 | cos_anneal_ratio=0.0,
312 | depth_from_inside_only=None,
313 | ):
314 |
315 | batch_size, n_samples = z_vals.shape
316 |
317 | # Section length
318 | dists = z_vals[..., 1:] - z_vals[..., :-1]
319 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
320 | mid_z_vals = z_vals + dists * 0.5
321 |
322 | # Section midpoints
323 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
324 | dirs = rays_d[:, None, :].expand(pts.shape)
325 | pts = pts.reshape(-1, 3)
326 | dirs = dirs.reshape(-1, 3)
327 |
328 | query_pts = pts.clone()
329 |
330 | sdf_nn_output = sdf_network(pts)
331 | sdf = sdf_nn_output[:, :1]
332 | feature_vector = sdf_nn_output[:, 1:]
333 |
334 | gradients = sdf_network.gradient(pts).squeeze()
335 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
336 |
337 | # inv_s in the code == s in the paper == 1 / standard deviation
338 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
339 | inv_s = inv_s.expand(batch_size * n_samples, 1)
340 |
341 | true_cos = (dirs * gradients).sum(-1, keepdim=True)
342 |
343 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
344 | # the cos value "not dead" at the beginning training iterations, for better convergence.
345 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
346 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
347 |
348 | # Estimate signed distances at section points
349 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
350 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
351 |
352 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
353 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
354 |
355 | p = prev_cdf - next_cdf
356 | c = prev_cdf
357 |
358 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
359 | alpha_in = alpha
360 |
361 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
362 | inside_sphere = (pts_norm < 1.0).float().detach()
363 | relax_inside_sphere = (pts_norm < 1.2).float().detach()
364 |
365 | # Render with background
366 | if background_alpha is not None:
367 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
368 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
369 | sampled_color = sampled_color * inside_sphere[:, :, None] +\
370 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
371 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
372 |
373 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
374 | weights_sum = weights.sum(dim=-1, keepdim=True)
375 | if depth_from_inside_only:
376 | weights_in = alpha_in * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha_in + 1e-7], -1), -1)[:, :-1]
377 | weights_in_sum = weights_in.sum(dim=-1, keepdim=True)
378 |
379 | color = (sampled_color * weights[:, :, None]).sum(dim=1)
380 | if background_rgb is not None: # Fixed background, usually black
381 | color = color + background_rgb * (1.0 - weights_sum)
382 |
383 | # Eikonal loss
384 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
385 | dim=-1) - 1.0) ** 2
386 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
387 |
388 | if model_input is not None:
389 | if background_alpha is not None:
390 | z_final = mid_z_vals_out
391 | else:
392 | z_final = mid_z_vals
393 | if depth_from_inside_only:
394 | z_final = mid_z_vals
395 |
396 | if depth_from_inside_only:
397 | dist_map = torch.sum(weights_in / (weights_in.sum(-1, keepdim=True)+1e-10) * z_final, -1)
398 | else:
399 | dist_map = torch.sum(weights / (weights.sum(-1, keepdim=True)+1e-10) * z_final, -1)
400 |
401 | sdf_all = sdf.reshape(batch_size,n_samples).unsqueeze(0)
402 | d_all = mid_z_vals.unsqueeze(0)
403 | d_surface, network_mask = find_surface_points(sdf_all, d_all)
404 | d_surface = d_surface.squeeze(0)
405 | network_mask = network_mask.squeeze(0)
406 |
407 | object_mask = network_mask
408 |
409 | point_surface = rays_o + rays_d * d_surface[:,None]
410 | point_surface_wmask = point_surface[network_mask & object_mask]
411 |
412 | points_rendered = rays_o + rays_d * dist_map[:,None]
413 | sdf_rendered_points = sdf_network(points_rendered)[:, :1]
414 | sdf_rendered_points_wmask = sdf_rendered_points[object_mask]
415 | sdf_rendered_points_0 = torch.zeros_like(sdf_rendered_points_wmask)
416 | pseudo_pts_loss = F.l1_loss(sdf_rendered_points_wmask, sdf_rendered_points_0, reduction='mean')
417 |
418 | return {
419 | 'color': color,
420 | 'sdf': sdf,
421 | 'dists': dists,
422 | 'gradients': gradients.reshape(batch_size, n_samples, 3),
423 | 's_val': 1.0 / inv_s,
424 | 'mid_z_vals': mid_z_vals,
425 | 'weights': weights,
426 | 'cdf': c.reshape(batch_size, n_samples),
427 | 'gradient_error': gradient_error,
428 | 'inside_sphere': inside_sphere,
429 | 'pseudo_pts_loss': pseudo_pts_loss,
430 | 'query_pts': query_pts,
431 | 'point_surface': point_surface_wmask,
432 | 'network_mask': network_mask,
433 | 'object_mask': object_mask,
434 | }
435 | else:
436 | return {
437 | 'color': color,
438 | 'sdf': sdf,
439 | 'dists': dists,
440 | 'gradients': gradients.reshape(batch_size, n_samples, 3),
441 | 's_val': 1.0 / inv_s,
442 | 'mid_z_vals': mid_z_vals,
443 | 'weights': weights,
444 | 'cdf': c.reshape(batch_size, n_samples),
445 | 'gradient_error': gradient_error,
446 | 'inside_sphere': inside_sphere,
447 | 'pseudo_pts_loss': torch.tensor(0.0).float(),
448 | 'query_pts': query_pts,
449 | }
450 |
451 | def render(self,
452 | rays_o,
453 | rays_d,
454 | near,
455 | far,
456 | main_img_idx,
457 | t,
458 | random_pcd=None,
459 | perturb_overwrite=-1,
460 | background_rgb=None,
461 | cos_anneal_ratio=0.0,
462 | model_input=None,
463 | depth_from_inside_only=False,
464 | ):
465 | batch_size = len(rays_o)
466 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
467 | z_vals = torch.linspace(0.0, 1.0, self.n_samples)
468 | z_vals = near + (far - near) * z_vals[None, :]
469 |
470 | z_vals_outside = None
471 | if self.n_outside > 0:
472 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
473 |
474 | n_samples = self.n_samples
475 | perturb = self.perturb
476 |
477 | if perturb_overwrite >= 0:
478 | perturb = perturb_overwrite
479 | if perturb > 0:
480 | t_rand = (torch.rand([batch_size, 1]) - 0.5)
481 | z_vals = z_vals + t_rand * 2.0 / self.n_samples
482 |
483 | if self.n_outside > 0:
484 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
485 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
486 | lower = torch.cat([z_vals_outside[..., :1], mids], -1)
487 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
488 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
489 |
490 | if self.n_outside > 0:
491 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
492 |
493 | background_alpha = None
494 | background_sampled_color = None
495 | mid_z_vals_out = None
496 |
497 | # Up sample
498 | if self.n_importance > 0:
499 | with torch.no_grad():
500 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
501 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
502 |
503 | for i in range(self.up_sample_steps):
504 | new_z_vals = self.up_sample(rays_o,
505 | rays_d,
506 | z_vals,
507 | sdf,
508 | self.n_importance // self.up_sample_steps,
509 | 64 * 2**i)
510 | z_vals, sdf = self.cat_z_vals(rays_o,
511 | rays_d,
512 | z_vals,
513 | new_z_vals,
514 | sdf,
515 | last=(i + 1 == self.up_sample_steps))
516 |
517 | n_samples = self.n_samples + self.n_importance
518 |
519 | # Background model
520 | if self.n_outside > 0:
521 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
522 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
523 |
524 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
525 |
526 | background_sampled_color = ret_outside['sampled_color']
527 | background_alpha = ret_outside['alpha']
528 | mid_z_vals_out= ret_outside['mid_z_vals_out']
529 |
530 | # Render core
531 | ret_fine = self.render_core(rays_o,
532 | rays_d,
533 | z_vals,
534 | sample_dist,
535 | self.sdf_network,
536 | self.deviation_network,
537 | self.color_network,
538 | model_input = model_input,
539 | background_rgb=background_rgb,
540 | background_alpha=background_alpha,
541 | background_sampled_color=background_sampled_color,
542 | mid_z_vals_out= mid_z_vals_out,
543 | cos_anneal_ratio=cos_anneal_ratio,
544 | depth_from_inside_only=depth_from_inside_only)
545 |
546 | color_fine = ret_fine['color']
547 | weights = ret_fine['weights']
548 | weights_sum = weights.sum(dim=-1, keepdim=True)
549 | gradients = ret_fine['gradients']
550 | s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
551 |
552 | local_loss = torch.tensor(0).float()
553 | if model_input is not None:
554 | output = {
555 | 'color_fine': color_fine,
556 | 's_val': s_val,
557 | 'cdf_fine': ret_fine['cdf'],
558 | 'weight_sum': weights_sum,
559 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
560 | 'gradients': gradients,
561 | 'weights': weights,
562 | 'gradient_error': ret_fine['gradient_error'],
563 | 'inside_sphere': ret_fine['inside_sphere'],
564 | 'pseudo_pts_loss': ret_fine['pseudo_pts_loss'],
565 | 'sdf': ret_fine['sdf'],
566 | 'query_pts': ret_fine['query_pts'],
567 | }
568 |
569 | point_surface_wmask = ret_fine['point_surface']
570 | network_mask = ret_fine['network_mask']
571 | object_mask = ret_fine['object_mask']
572 |
573 | size, center = model_input['size'].unsqueeze(0), model_input['center'].unsqueeze(0)
574 | size = size[:1]
575 | center = center[:1]
576 |
577 | cam = model_input['cam'] # 2, 4, 4
578 | src_cams = model_input['src_cams'] # m, 2, 4, 4
579 | feat_src = model_input['feat_src']
580 |
581 | if (t % 100 == 0) and (random_pcd is not None):
582 | ''' unseen view rendering '''
583 | random_pcd = random_pcd.view(1, -1, 3)
584 | random_pcd.requires_grad = False
585 |
586 | src_img_idx = model_input['src_idxs'][0]
587 | rays_o, rays_d = self.dataset.gen_rays_between_from_pts(main_img_idx,
588 | src_img_idx,
589 | 0.5,
590 | random_pcd,
591 | )
592 |
593 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d)
594 |
595 | batch_size = len(rays_o)
596 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
597 | z_vals = torch.linspace(0.0, 1.0, self.n_samples)
598 | z_vals = near + (far - near) * z_vals[None, :]
599 |
600 | z_vals_outside = None
601 | if self.n_outside > 0:
602 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
603 |
604 | n_samples = self.n_samples
605 | perturb = self.perturb
606 |
607 | if perturb_overwrite >= 0:
608 | perturb = perturb_overwrite
609 | if perturb > 0:
610 | t_rand = (torch.rand([batch_size, 1]) - 0.5)
611 | z_vals = z_vals + t_rand * 2.0 / self.n_samples
612 |
613 | if self.n_outside > 0:
614 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
615 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
616 | lower = torch.cat([z_vals_outside[..., :1], mids], -1)
617 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
618 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
619 |
620 | if self.n_outside > 0:
621 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
622 |
623 | background_alpha = None
624 | background_sampled_color = None
625 | mid_z_vals_out = None
626 |
627 | # Up sample
628 | if self.n_importance > 0:
629 | with torch.no_grad():
630 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
631 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
632 |
633 | for i in range(self.up_sample_steps):
634 | new_z_vals = self.up_sample(rays_o,
635 | rays_d,
636 | z_vals,
637 | sdf,
638 | self.n_importance // self.up_sample_steps,
639 | 64 * 2**i)
640 | z_vals, sdf = self.cat_z_vals(rays_o,
641 | rays_d,
642 | z_vals,
643 | new_z_vals,
644 | sdf,
645 | last=(i + 1 == self.up_sample_steps))
646 |
647 | n_samples = self.n_samples + self.n_importance
648 |
649 | # Background model
650 | if self.n_outside > 0:
651 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
652 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
653 |
654 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
655 |
656 | background_sampled_color = ret_outside['sampled_color']
657 | background_alpha = ret_outside['alpha']
658 | mid_z_vals_out= ret_outside['mid_z_vals_out']
659 |
660 | color = self.render_color(rays_o,
661 | rays_d,
662 | z_vals,
663 | sample_dist,
664 | self.sdf_network,
665 | self.deviation_network,
666 | self.color_network,
667 | background_rgb=background_rgb,
668 | background_alpha=background_alpha,
669 | background_sampled_color=background_sampled_color,
670 | cos_anneal_ratio=cos_anneal_ratio,
671 | )
672 |
673 | us_pose = cam.clone()
674 | us_pose[0] = feat_utils.gen_camera_between(cam[0].cpu().numpy(), src_cams[0, 0].cpu().numpy(), 0.5)
675 | us_pose.requires_grad = False
676 | us_pose = us_pose.unsqueeze(0) # 2, 4, 4
677 |
678 | us_rgb = torch.zeros([1, 3, 768, 1024]).cuda()
679 |
680 | pts_world = random_pcd.view(1, -1, 1, 3, 1)
681 | pts_world = torch.cat([pts_world, torch.ones_like(pts_world[..., -1:, :])], dim=-2)
682 | pts_img = feat_utils.idx_cam2img(feat_utils.idx_world2cam(pts_world, us_pose), us_pose).view(1, -1, 3) # 1, N, 3
683 | us_uv = pts_img[..., :2] / pts_img[..., 2:3]
684 | us_uv = us_uv.round().long()
685 |
686 | color_mask = ((us_uv[..., 0] > -1) & (us_uv[..., 0] < 1024) & (us_uv[..., 1] > -1) & (us_uv[..., 1] < 768)).squeeze(0)
687 |
688 | us_uv = us_uv[0, color_mask] # M, 2
689 | color = color[color_mask] # M, 3
690 |
691 | _, cnts = torch.unique(us_uv, sorted=False, return_counts=True, dim=0)
692 | cnts = torch.cat((torch.tensor([0]).long().cuda(), cnts))
693 | unique_index = torch.cumsum(cnts, dim=0)
694 | unique_index = unique_index[:-1]
695 |
696 | us_uv = us_uv[unique_index]
697 | color = color[unique_index].transpose(0, 1)
698 |
699 | us_rgb[0, :, us_uv[:, 1], us_uv[:, 0]] = color
700 |
701 | us_feat = self.feat_ext(us_rgb)[2]
702 |
703 | local_loss += feat_utils.get_local_loss(random_pcd.reshape(-1, 3), None, us_feat,
704 | us_pose, feat_src.unsqueeze(0), src_cams.unsqueeze(0),
705 | 2 * torch.ones_like(size).cuda(), torch.zeros_like(center).cuda(),
706 | color_mask.reshape(-1), color_mask.reshape(-1))
707 |
708 | local_loss += feat_utils.get_local_loss(point_surface_wmask, None, model_input['feat'].unsqueeze(0),
709 | cam.unsqueeze(0), feat_src.unsqueeze(0), src_cams.unsqueeze(0),
710 | size, center, network_mask.reshape(-1),
711 | object_mask.reshape(-1))
712 |
713 | output['local_loss'] = local_loss
714 | return output
715 |
716 | else:
717 | return {
718 | 'color_fine': color_fine,
719 | 's_val': s_val,
720 | 'cdf_fine': ret_fine['cdf'],
721 | 'weight_sum': weights_sum,
722 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
723 | 'gradients': gradients,
724 | 'weights': weights,
725 | 'gradient_error': ret_fine['gradient_error'],
726 | 'inside_sphere': ret_fine['inside_sphere'],
727 | 'pseudo_pts_loss': ret_fine['pseudo_pts_loss'],
728 | 'local_loss': local_loss,
729 | 'sdf': ret_fine['sdf'],
730 | 'query_pts': ret_fine['query_pts'],
731 | }
732 |
733 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
734 | return extract_geometry(bound_min,
735 | bound_max,
736 | resolution=resolution,
737 | threshold=threshold,
738 | query_func=lambda pts: -self.sdf_network.sdf(pts))
739 |
--------------------------------------------------------------------------------
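A stand-alone sketch of the opacity computation used by render_core above: the SDF at each section's two endpoints is estimated from the mid-point SDF and the annealed ray/gradient cosine, then turned into an alpha value through the logistic CDF with sharpness inv_s and composited front to back. Toy tensors only; nothing here is wired to the actual networks:

import torch
import torch.nn.functional as F

batch_size, n_samples = 2, 8
sdf = 0.1 * torch.randn(batch_size * n_samples, 1)      # mid-point SDF values
true_cos = -torch.rand(batch_size * n_samples, 1)       # dirs . gradients, negative when entering a surface
dists = torch.full((batch_size * n_samples, 1), 2.0 / n_samples)
inv_s = torch.full((batch_size * n_samples, 1), 64.0)   # s in the paper = 1 / standard deviation
cos_anneal_ratio = 1.0                                  # fully annealed

iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio)
             + F.relu(-true_cos) * cos_anneal_ratio)    # always non-positive
estimated_prev_sdf = sdf - iter_cos * dists * 0.5
estimated_next_sdf = sdf + iter_cos * dists * 0.5
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)

alpha = alpha.reshape(batch_size, n_samples)
trans = torch.cumprod(torch.cat([torch.ones(batch_size, 1), 1.0 - alpha + 1e-7], dim=-1), dim=-1)[:, :-1]
weights = alpha * trans                                 # per-sample compositing weights
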
/models/udf_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import os
5 | from scipy.spatial import cKDTree
6 | import trimesh
7 |
8 | def search_nearest_point(point_batch, point_gt):
9 | num_point_batch, num_point_gt = point_batch.shape[0], point_gt.shape[0]
10 | point_batch = point_batch.unsqueeze(1).repeat(1, num_point_gt, 1)
11 | point_gt = point_gt.unsqueeze(0).repeat(num_point_batch, 1, 1)
12 |
13 | distances = torch.sqrt(torch.sum((point_batch-point_gt) ** 2, axis=-1) + 1e-12)
14 | dis_idx = torch.argmin(distances, axis=1).detach().cpu().numpy()
15 |
16 | return dis_idx
17 |
18 | def process_data(data_dir, dataname):
19 | if os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.ply'):
20 | pointcloud = trimesh.load(os.path.join(data_dir, 'pcd', dataname) + '.ply').vertices
21 | pointcloud = np.asarray(pointcloud)
22 | elif os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.xyz'):
23 | pointcloud = np.loadtxt(os.path.join(data_dir, 'pcd', dataname) + '.xyz')
24 | elif os.path.exists(os.path.join(data_dir, 'pcd', dataname) + '.npy'):
25 | pointcloud = np.load(os.path.join(data_dir, 'pcd', dataname) + '.npy')
26 | else:
27 |         print('Only .ply, .xyz or .npy data are supported. Please adjust your data format.')
28 | exit()
29 | shape_scale = np.max([np.max(pointcloud[:,0])-np.min(pointcloud[:,0]),np.max(pointcloud[:,1])-np.min(pointcloud[:,1]),np.max(pointcloud[:,2])-np.min(pointcloud[:,2])])
30 | shape_center = [(np.max(pointcloud[:,0])+np.min(pointcloud[:,0]))/2, (np.max(pointcloud[:,1])+np.min(pointcloud[:,1]))/2, (np.max(pointcloud[:,2])+np.min(pointcloud[:,2]))/2]
31 | pointcloud = pointcloud - shape_center
32 | pointcloud = pointcloud / shape_scale
33 |
34 | POINT_NUM = pointcloud.shape[0] // 60
35 | POINT_NUM_GT = pointcloud.shape[0] // 60 * 60
36 | QUERY_EACH = 1000000//POINT_NUM_GT
37 |
38 | point_idx = np.random.choice(pointcloud.shape[0], POINT_NUM_GT, replace = False)
39 | pointcloud = pointcloud[point_idx,:]
40 | ptree = cKDTree(pointcloud)
41 | sigmas = []
42 | for p in np.array_split(pointcloud,100,axis=0):
43 | d = ptree.query(p,51)
44 | sigmas.append(d[0][:,-1])
45 |
46 | sigmas = np.concatenate(sigmas)
47 | sample = []
48 | sample_near = []
49 |
50 | for i in range(QUERY_EACH):
51 | scale = 0.25 if 0.25 * np.sqrt(POINT_NUM_GT / 20000) < 0.25 else 0.25 * np.sqrt(POINT_NUM_GT / 20000)
52 | tt = pointcloud + scale*np.expand_dims(sigmas,-1) * np.random.normal(0.0, 1.0, size=pointcloud.shape)
53 | sample.append(tt)
54 | tt = tt.reshape(-1,POINT_NUM,3)
55 |
56 | sample_near_tmp = []
57 | for j in range(tt.shape[0]):
58 | nearest_idx = search_nearest_point(torch.tensor(tt[j]).float().cuda(), torch.tensor(pointcloud).float().cuda())
59 | nearest_points = pointcloud[nearest_idx]
60 | nearest_points = np.asarray(nearest_points).reshape(-1,3)
61 | sample_near_tmp.append(nearest_points)
62 | sample_near_tmp = np.asarray(sample_near_tmp)
63 | sample_near_tmp = sample_near_tmp.reshape(-1,3)
64 | sample_near.append(sample_near_tmp)
65 |
66 | sample = np.asarray(sample)
67 | sample_near = np.asarray(sample_near)
68 |
69 | os.makedirs(os.path.join(data_dir, 'query_data'), exist_ok=True)
70 | np.savez(os.path.join(data_dir, 'query_data', dataname)+'.npz', sample = sample, point = pointcloud, sample_near = sample_near)
71 |
72 | class Dataset:
73 | def __init__(self, conf, dataname):
74 | super(Dataset, self).__init__()
75 | print('Load data: Begin')
76 | self.device = torch.device('cuda')
77 | self.conf = conf
78 |
79 | self.data_dir = conf.get_string('data_dir').replace('CASE_NAME', dataname)
80 | print(self.data_dir)
81 | self.data_name = dataname + '.npz'
82 |
83 | if os.path.exists(os.path.join(self.data_dir, 'query_data', self.data_name)):
84 |             print('Query data exists. Loading data...')
85 | else:
86 | print('Query data not found. Processing data...')
87 | process_data(self.data_dir, dataname)
88 |
89 | load_data = np.load(os.path.join(self.data_dir, 'query_data', self.data_name))
90 |
91 | self.point = np.asarray(load_data['sample_near']).reshape(-1,3)
92 | self.sample = np.asarray(load_data['sample']).reshape(-1,3)
93 | self.point_gt = np.asarray(load_data['point']).reshape(-1,3)
94 | self.sample_points_num = self.sample.shape[0]-1
95 |
96 | self.object_bbox_min = np.array([np.min(self.point[:,0]), np.min(self.point[:,1]), np.min(self.point[:,2])]) -0.05
97 | self.object_bbox_max = np.array([np.max(self.point[:,0]), np.max(self.point[:,1]), np.max(self.point[:,2])]) +0.05
98 | print('bd:',self.object_bbox_min,self.object_bbox_max)
99 |
100 | self.point = torch.from_numpy(self.point).to(self.device).float()
101 | self.sample = torch.from_numpy(self.sample).to(self.device).float()
102 | self.point_gt = torch.from_numpy(self.point_gt).to(self.device).float()
103 |
104 | print('NP Load data: End')
105 |
106 | def get_train_data(self, batch_size):
107 | index_coarse = np.random.choice(10, 1)
108 | index_fine = np.random.choice(self.sample_points_num//10, batch_size, replace = False)
109 | index = index_fine * 10 + index_coarse # for accelerating random choice operation
110 | points = self.point[index]
111 | sample = self.sample[index]
112 | return points, sample, self.point_gt
113 |
114 | def gen_new_data(self, tree):
115 | distance, index = tree.query(self.sample.detach().cpu().numpy(), 1)
116 | self.point_new = tree.data[index]
117 | self.point_new = torch.from_numpy(self.point_new).to(self.device).float()
118 |
119 |
120 | def get_train_data_step2(self, batch_size):
121 | index_coarse = np.random.choice(10, 1)
122 | index_fine = np.random.choice(self.sample_points_num//10, batch_size, replace = False)
123 | index = index_fine * 10 + index_coarse # for accelerating random choice operation
124 | points = self.point_new[index]
125 | sample = self.sample[index]
126 | return points, sample, self.point_gt
127 |
128 |
--------------------------------------------------------------------------------
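A short sketch of the query-point sampling in process_data above: each surface point gets a per-point noise scale sigma_i, the distance to its 51st k-d tree hit (the query point itself is the first hit), and query samples are drawn as Gaussian perturbations of the cloud with that scale. Random toy cloud; the file I/O and the CUDA nearest-neighbour pairing of the original are omitted:

import numpy as np
from scipy.spatial import cKDTree

pointcloud = np.random.rand(6000, 3)
ptree = cKDTree(pointcloud)
dists, _ = ptree.query(pointcloud, k=51)
sigmas = dists[:, -1]                                   # distance to the 51st hit per point
scale = max(0.25, 0.25 * np.sqrt(len(pointcloud) / 20000.0))
samples = pointcloud + scale * sigmas[:, None] * np.random.normal(size=pointcloud.shape)
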
/models/udf_embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
6 | class Embedder:
7 | def __init__(self, **kwargs):
8 | self.kwargs = kwargs
9 | self.create_embedding_fn()
10 |
11 | def create_embedding_fn(self):
12 | embed_fns = []
13 | d = self.kwargs['input_dims']
14 | out_dim = 0
15 | if self.kwargs['include_input']:
16 | embed_fns.append(lambda x: x)
17 | out_dim += d
18 |
19 | max_freq = self.kwargs['max_freq_log2']
20 | N_freqs = self.kwargs['num_freqs']
21 |
22 | if self.kwargs['log_sampling']:
23 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
24 | else:
25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs['periodic_fns']:
29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 |
39 | def get_embedder(multires, input_dims=3):
40 | embed_kwargs = {
41 | 'include_input': True,
42 | 'input_dims': input_dims,
43 | 'max_freq_log2': multires-1,
44 | 'num_freqs': multires,
45 | 'log_sampling': True,
46 | 'periodic_fns': [torch.sin, torch.cos],
47 | }
48 |
49 | embedder_obj = Embedder(**embed_kwargs)
50 | def embed(x, eo=embedder_obj): return eo.embed(x)
51 | return embed, embedder_obj.out_dim
52 |
--------------------------------------------------------------------------------
/models/udf_fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from models.embedder import get_embedder
6 |
7 | class UDFNetwork(nn.Module):
8 | def __init__(self,
9 | d_in,
10 | d_out,
11 | d_hidden,
12 | n_layers,
13 | skip_in=(4,),
14 | multires=0,
15 | bias=0.5,
16 | scale=1,
17 | geometric_init=True,
18 | weight_norm=True,
19 | inside_outside=False):
20 | super(UDFNetwork, self).__init__()
21 |
22 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
23 |
24 | self.embed_fn_fine = None
25 |
26 | if multires > 0:
27 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
28 | self.embed_fn_fine = embed_fn
29 | dims[0] = input_ch
30 |
31 | self.num_layers = len(dims)
32 | self.skip_in = skip_in
33 | self.scale = scale
34 |
35 | for l in range(0, self.num_layers - 1):
36 | if l + 1 in self.skip_in:
37 | out_dim = dims[l + 1] - dims[0]
38 | else:
39 | out_dim = dims[l + 1]
40 |
41 | lin = nn.Linear(dims[l], out_dim)
42 |
43 | if geometric_init:
44 | if l == self.num_layers - 2:
45 | if not inside_outside:
46 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
47 | torch.nn.init.constant_(lin.bias, -bias)
48 | else:
49 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
50 | torch.nn.init.constant_(lin.bias, bias)
51 | elif multires > 0 and l == 0:
52 | torch.nn.init.constant_(lin.bias, 0.0)
53 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
54 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
55 | elif multires > 0 and l in self.skip_in:
56 | torch.nn.init.constant_(lin.bias, 0.0)
57 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
58 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
59 | else:
60 | torch.nn.init.constant_(lin.bias, 0.0)
61 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
62 |
63 | if weight_norm:
64 | lin = nn.utils.weight_norm(lin)
65 |
66 | setattr(self, "lin" + str(l), lin)
67 |
68 | #self.activation = nn.Softplus(beta=100)
69 | self.activation = nn.ReLU()
70 |
71 | self.act_last = nn.Sigmoid()
72 |
73 | def forward(self, inputs):
74 | inputs = inputs * self.scale
75 | if self.embed_fn_fine is not None:
76 | inputs = self.embed_fn_fine(inputs)
77 |
78 | x = inputs
79 | for l in range(0, self.num_layers - 1):
80 | lin = getattr(self, "lin" + str(l))
81 |
82 | if l in self.skip_in:
83 | x = torch.cat([x, inputs], 1) / np.sqrt(2)
84 |
85 | x = lin(x)
86 |
87 | if l < self.num_layers - 2:
88 | x = self.activation(x)
89 |
90 | # x = self.act_last(x)
91 | res = torch.abs(x)
92 | # res = 1 - torch.exp(-x)
93 | return res / self.scale
94 |
95 | def udf(self, x):
96 | return self.forward(x)
97 |
98 | def udf_hidden_appearance(self, x):
99 | return self.forward(x)
100 |
101 | def gradient(self, x):
102 | x.requires_grad_(True)
103 | y = self.udf(x)
104 | # y.requires_grad_(True)
105 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
106 | gradients = torch.autograd.grad(
107 | outputs=y,
108 | inputs=x,
109 | grad_outputs=d_output,
110 | create_graph=True,
111 | retain_graph=True,
112 | only_inputs=True)[0]
113 | return gradients.unsqueeze(1)
114 |
115 |
--------------------------------------------------------------------------------
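A shape and sign check for UDFNetwork above: unlike SDFNetwork, the final output is passed through torch.abs, so the predicted distance is non-negative everywhere and the surface is its zero level set. The hyper-parameters are illustrative, and the repository root is assumed to be on PYTHONPATH:

import torch
from models.udf_fields import UDFNetwork

udf_net = UDFNetwork(d_in=3, d_out=1, d_hidden=256, n_layers=8,
                     skip_in=(4,), multires=6)
pts = torch.rand(512, 3)
dist = udf_net.udf(pts)          # (512, 1), non-negative because of torch.abs
grad = udf_net.gradient(pts)     # (512, 1, 3), analytic gradient of the predicted distance
assert (dist >= 0).all()
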
/pretrained_model/vismvsnet.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yulunwu0108/NeuSurf/9c5b3bc8e78e3dc31bcd2ee0af3c967bdf907944/pretrained_model/vismvsnet.pt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.65.0
2 | pyhocon==0.3.57
3 | trimesh==3.22.5
4 | PyMCubes==0.1.4
5 | scipy==1.10.1
6 | point_cloud_utils==0.29.7
7 | icecream==2.1.3
8 | opencv-python==4.7.0.72
9 | tensorboard==2.12.1
--------------------------------------------------------------------------------
/tools/feat_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from typing import List, Union, Tuple
6 | from collections import OrderedDict
7 | from scipy.spatial.transform import Rotation as Rot
8 | from scipy.spatial.transform import Slerp
9 |
10 | def scale_camera(cam: Union[np.ndarray, torch.Tensor], scale: Union[Tuple, float]=1):
11 | """ resize input in order to produce sampled depth map """
12 | if type(scale) != tuple:
13 | scale = (scale, scale)
14 | if type(cam) == np.ndarray:
15 | new_cam = np.copy(cam)
16 | # focal:
17 | new_cam[1, 0, 0] = cam[1, 0, 0] * scale[0]
18 | new_cam[1, 1, 1] = cam[1, 1, 1] * scale[1]
19 |         # principal point:
20 | new_cam[1, 0, 2] = cam[1, 0, 2] * scale[0]
21 | new_cam[1, 1, 2] = cam[1, 1, 2] * scale[1]
22 | elif type(cam) == torch.Tensor:
23 | new_cam = cam.clone()
24 | # focal:
25 | new_cam[..., 1, 0, 0] = cam[..., 1, 0, 0] * scale[0]
26 | new_cam[..., 1, 1, 1] = cam[..., 1, 1, 1] * scale[1]
27 |         # principal point:
28 | new_cam[..., 1, 0, 2] = cam[..., 1, 0, 2] * scale[0]
29 | new_cam[..., 1, 1, 2] = cam[..., 1, 1, 2] * scale[1]
30 | else:
31 | raise TypeError
32 | return new_cam
33 |
34 |
35 | def bin_op_reduce(lst, func):
36 | result = lst[0]
37 | for i in range(1, len(lst)):
38 | result = func(result, lst[i])
39 | return result
40 |
41 |
42 | def idx_world2cam(idx_world_homo, cam):
43 | """nhw41 -> nhw41"""
44 | idx_cam_homo = cam[:,0:1,...].unsqueeze(1) @ idx_world_homo # nhw41
45 | idx_cam_homo = idx_cam_homo / (idx_cam_homo[...,-1:,:]+1e-9) # nhw41
46 | return idx_cam_homo
47 |
48 |
49 | def idx_cam2img(idx_cam_homo, cam):
50 | """nhw41 -> nhw31"""
51 | idx_cam = idx_cam_homo[...,:3,:] / (idx_cam_homo[...,3:4,:]+1e-9) # nhw31
52 | idx_img_homo = cam[:,1:2,:3,:3].unsqueeze(1) @ idx_cam # nhw31
53 | idx_img_homo = idx_img_homo / (idx_img_homo[...,-1:,:]+1e-9)
54 | return idx_img_homo
55 |
56 |
57 |
58 | def normalize_for_grid_sample(input_, grid):
59 | size = torch.tensor(input_.size())[2:].flip(0).to(grid.dtype).to(grid.device).view(1,1,1,-1) # [[[w, h]]]
60 | grid_n = grid / size
61 | grid_n = (grid_n * 2 - 1).clamp(-1.1, 1.1)
62 | return grid_n
63 |
64 |
65 | def get_in_range(grid):
66 | """after normalization, keepdim=False"""
67 | masks = []
68 | for dim in range(grid.size()[-1]):
69 | masks += [grid[..., dim]<=1, grid[..., dim]>=-1]
70 | in_range = bin_op_reduce(masks, torch.min).to(grid.dtype)
71 | return in_range
72 |
73 |
74 | def load_pair(file: str):
75 | with open(file) as f:
76 | lines = f.readlines()
77 | n_cam = int(lines[0])
78 | pairs = {}
79 | img_ids = []
80 | for i in range(1, 1+2*n_cam, 2):
81 | pair = []
82 | score = []
83 | img_id = lines[i].strip()
84 | pair_str = lines[i+1].strip().split(' ')
85 | n_pair = int(pair_str[0])
86 | for j in range(1, 1+2*n_pair, 2):
87 | pair.append(pair_str[j])
88 | score.append(float(pair_str[j+1]))
89 | img_ids.append(img_id)
90 | pairs[img_id] = {'id': img_id, 'index': i//2, 'pair': pair, 'score': score}
91 | pairs['id_list'] = img_ids
92 | return pairs
93 |
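# Illustrative note (added, not part of this file): load_pair above expects the
# MVSNet-style pair file layout, inferred from the indices parsed in the loop:
#   <num_images>
#   <image_id>
#   <num_src_views> <src_id_0> <score_0> <src_id_1> <score_1> ...
#   ... (the last two lines repeated once per reference image)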
94 |
95 | def load_cam(file: str, max_d, interval_scale=1, override=False):
96 | """ read camera txt file """
97 | cam = np.zeros((2, 4, 4))
98 | with open(file) as f:
99 | words = f.read().split()
100 | # read extrinsic
101 | for i in range(0, 4):
102 | for j in range(0, 4):
103 | extrinsic_index = 4 * i + j + 1
104 | cam[0][i][j] = words[extrinsic_index]
105 |
106 | # read intrinsic
107 | for i in range(0, 3):
108 | for j in range(0, 3):
109 | intrinsic_index = 3 * i + j + 18
110 | cam[1][i][j] = words[intrinsic_index]
111 |
112 | if len(words) == 29:
113 | cam[1][3][0] = words[27]
114 | cam[1][3][1] = float(words[28]) * interval_scale
115 | cam[1][3][2] = max_d
116 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (cam[1][3][2] - 1)
117 | elif len(words) == 30:
118 | cam[1][3][0] = words[27]
119 | cam[1][3][1] = float(words[28]) * interval_scale
120 | cam[1][3][2] = words[29]
121 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (cam[1][3][2] - 1)
122 | elif len(words) == 31:
123 | if override:
124 | cam[1][3][0] = words[27]
125 | cam[1][3][1] = (float(words[30]) - float(words[27])) / (max_d - 1)
126 | cam[1][3][2] = max_d
127 | cam[1][3][3] = words[30]
128 | else:
129 | cam[1][3][0] = words[27]
130 | cam[1][3][1] = float(words[28]) * interval_scale
131 | cam[1][3][2] = words[29]
132 | cam[1][3][3] = words[30]
133 | else:
134 | cam[1][3][0] = 0
135 | cam[1][3][1] = 0
136 | cam[1][3][2] = 0
137 | cam[1][3][3] = 0
138 |
139 | return cam
140 |
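# Illustrative note (added, not part of this file): load_cam above parses an
# MVSNet-style cam file; the layout inferred from the word indices it reads is:
#   extrinsic
#   <4 x 4 world-to-camera matrix, row major>                  -> cam[0]
#   intrinsic
#   <3 x 3 intrinsic matrix, row major>                        -> cam[1][:3, :3]
#   <depth_min> <depth_interval> [<depth_num>] [<depth_max>]   -> cam[1][3]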
141 | class ListModule(nn.Module):
142 | def __init__(self, modules: Union[List, OrderedDict]):
143 | super(ListModule, self).__init__()
144 | if isinstance(modules, OrderedDict):
145 | iterable = modules.items()
146 | elif isinstance(modules, list):
147 | iterable = enumerate(modules)
148 | else:
149 | raise TypeError('modules should be OrderedDict or List.')
150 | for name, module in iterable:
151 | if not isinstance(module, nn.Module):
152 | module = ListModule(module)
153 | if not isinstance(name, str):
154 | name = str(name)
155 | self.add_module(name, module)
156 |
157 | def __getitem__(self, idx):
158 | if idx < 0 or idx >= len(self._modules):
159 | raise IndexError('index {} is out of range'.format(idx))
160 | it = iter(self._modules.values())
161 | for i in range(idx):
162 | next(it)
163 | return next(it)
164 |
165 | def __iter__(self):
166 | return iter(self._modules.values())
167 |
168 | def __len__(self):
169 | return len(self._modules)
170 |
171 |
172 | class BasicBlock(nn.Module):
173 | expansion = 1
174 |
175 | def __init__(self, inplanes, planes, stride=1, downsample=None, dim=2):
176 | super(BasicBlock, self).__init__()
177 |
178 | self.conv_fn = nn.Conv2d if dim == 2 else nn.Conv3d
179 | self.bn_fn = nn.BatchNorm2d if dim == 2 else nn.BatchNorm3d
180 |
181 | self.conv1 = self.conv3x3(inplanes, planes, stride)
182 | self.bn1 = self.bn_fn(planes)
183 | self.relu = nn.ReLU(inplace=True)
184 | self.conv2 = self.conv3x3(planes, planes)
185 | self.bn2 = self.bn_fn(planes)
186 | self.downsample = downsample
187 | self.stride = stride
188 |
189 | def conv1x1(self, in_planes, out_planes, stride=1):
190 | """1x1 convolution"""
191 | return self.conv_fn(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
192 |
193 | def conv3x3(self, in_planes, out_planes, stride=1):
194 | """3x3 convolution with padding"""
195 | return self.conv_fn(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
196 |
197 | def forward(self, x):
198 | residual = x
199 |
200 | out = self.conv1(x)
201 | out = self.bn1(out)
202 | out = self.relu(out)
203 |
204 | out = self.conv2(out)
205 | out = self.bn2(out)
206 |
207 | if self.downsample is not None:
208 | residual = self.downsample(x)
209 |
210 | out += residual
211 | out = self.relu(out)
212 |
213 | return out
214 |
215 |
216 | def _make_layer(inplanes, block, planes, blocks, stride=1, dim=2):
217 | downsample = None
218 | conv_fn = nn.Conv2d if dim==2 else nn.Conv3d
219 | bn_fn = nn.BatchNorm2d if dim==2 else nn.BatchNorm3d
220 | if stride != 1 or inplanes != planes * block.expansion:
221 | downsample = nn.Sequential(
222 | conv_fn(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
223 | bn_fn(planes * block.expansion)
224 | )
225 |
226 | layers = []
227 | layers.append(block(inplanes, planes, stride, downsample, dim=dim))
228 | inplanes = planes * block.expansion
229 | for _ in range(1, blocks):
230 | layers.append(block(inplanes, planes, dim=dim))
231 |
232 | return nn.Sequential(*layers)
233 |
234 |
235 | class UNet(nn.Module):
236 |
237 | def __init__(self, inplanes: int, enc: int, dec: int, initial_scale: int,
238 | bottom_filters: List[int], filters: List[int], head_filters: List[int],
239 | prefix: str, dim: int=2):
240 | super(UNet, self).__init__()
241 |
242 | conv_fn = nn.Conv2d if dim==2 else nn.Conv3d
243 | deconv_fn = nn.ConvTranspose2d if dim==2 else nn.ConvTranspose3d
244 | current_scale = initial_scale
245 | idx = 0
246 | prev_f = inplanes
247 |
248 | self.bottom_blocks = OrderedDict()
249 | for f in bottom_filters:
250 | block = _make_layer(prev_f, BasicBlock, f, enc, 1 if idx==0 else 2, dim=dim)
251 | self.bottom_blocks[f'{prefix}{current_scale}_{idx}'] = block
252 | idx += 1
253 | current_scale *= 2
254 | prev_f = f
255 | self.bottom_blocks = ListModule(self.bottom_blocks)
256 |
257 | self.enc_blocks = OrderedDict()
258 | for f in filters:
259 | block = _make_layer(prev_f, BasicBlock, f, enc, 1 if idx == 0 else 2, dim=dim)
260 | self.enc_blocks[f'{prefix}{current_scale}_{idx}'] = block
261 | idx += 1
262 | current_scale *= 2
263 | prev_f = f
264 | self.enc_blocks = ListModule(self.enc_blocks)
265 |
266 | self.dec_blocks = OrderedDict()
267 | for f in filters[-2::-1]:
268 | block = [
269 | deconv_fn(prev_f, f, 3, 2, 1, 1, bias=False),
270 | conv_fn(2*f, f, 3, 1, 1, bias=False),
271 | ]
272 | if dec > 0:
273 | block.append(_make_layer(f, BasicBlock, f, dec, 1, dim=dim))
274 | self.dec_blocks[f'{prefix}{current_scale}_{idx}'] = block
275 | idx += 1
276 | current_scale //= 2
277 | prev_f = f
278 | self.dec_blocks = ListModule(self.dec_blocks)
279 |
280 | self.head_blocks = OrderedDict()
281 | for f in head_filters:
282 | block = [
283 | deconv_fn(prev_f, f, 3, 2, 1, 1, bias=False)
284 | ]
285 | if dec > 0:
286 | block.append(_make_layer(f, BasicBlock, f, dec, 1, dim=dim))
287 | block = nn.Sequential(*block)
288 | self.head_blocks[f'{prefix}{current_scale}_{idx}'] = block
289 | idx += 1
290 | current_scale //= 2
291 | prev_f = f
292 | self.head_blocks = ListModule(self.head_blocks)
293 |
294 | def forward(self, x, multi_scale=1):
295 | for b in self.bottom_blocks:
296 | x = b(x)
297 | enc_out = []
298 | for b in self.enc_blocks:
299 | x = b(x)
300 | enc_out.append(x)
301 | dec_out = [x]
302 | for i, b in enumerate(self.dec_blocks):
303 | if len(b) == 3: deconv, post_concat, res = b
304 | elif len(b) == 2: deconv, post_concat = b
305 | x = deconv(x)
306 | x = torch.cat([x, enc_out[-2-i]], 1)
307 | x = post_concat(x)
308 | if len(b) == 3: x = res(x)
309 | dec_out.append(x)
310 | for b in self.head_blocks:
311 | x = b(x)
312 | dec_out.append(x)
313 | if multi_scale == 1: return x
314 | else: return dec_out[-multi_scale:]
315 |
316 |
317 | class FeatExt(nn.Module):
318 |
319 | def __init__(self):
320 | super(FeatExt, self).__init__()
321 | self.init_conv = nn.Sequential(
322 | nn.Conv2d(3, 16, 5, 2, 2, bias=False),
323 | nn.BatchNorm2d(16),
324 | nn.ReLU()
325 | )
326 | self.unet = UNet(16, 2, 1, 2, [], [32, 64, 128], [], '2d', 2)
327 | self.final_conv_1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
328 | self.final_conv_2 = nn.Conv2d(64, 32, 3, 1, 1, bias=False)
329 | self.final_conv_3 = nn.Conv2d(32, 32, 3, 1, 1, bias=False)
330 |
331 | feat_ext_dict = {k[16:]:v for k,v in torch.load('pretrained_model/vismvsnet.pt')['state_dict'].items() if k.startswith('module.feat_ext')}
332 | self.load_state_dict(feat_ext_dict)
333 |
334 | def forward(self, x):
335 | out = self.init_conv(x)
336 | out1, out2, out3 = self.unet(out, multi_scale=3)
337 | return self.final_conv_1(out1), self.final_conv_2(out2), self.final_conv_3(out3)
338 |
339 |
340 | def gen_camera_between(pose_0, pose_1, ratio):
341 | rot_0 = pose_0[:3, :3]
342 | rot_1 = pose_1[:3, :3]
343 | rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
344 | key_times = [0, 1]
345 | slerp = Slerp(key_times, rots)
346 | rot = slerp(ratio)
347 | pose = np.diag([1.0, 1.0, 1.0, 1.0])
348 | pose = pose.astype(np.float32)
349 | pose[:3, :3] = rot.as_matrix()
350 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
351 | pose = torch.from_numpy(pose).cuda()
352 | pose.requires_grad = False
353 | return pose
354 |
355 |
356 | def get_local_loss(diff_surf_pts,
357 | uncerts,
358 | feat,
359 | cam,
360 | feat_src,
361 | src_cams,
362 | size,
363 | center,
364 | network_object_mask,
365 | object_mask
366 | ):
367 | mask = network_object_mask & object_mask
368 |
369 | if (mask).sum() == 0:
370 | return torch.tensor(0.0).float().cuda()
371 |
372 | sample_mask = mask.view(feat.size()[0], -1)
373 | hit_nums = sample_mask.sum(-1)
374 | accu_nums = [0] + hit_nums.cumsum(0).tolist()
375 | slices = [slice(accu_nums[i], accu_nums[i + 1]) for i in range(len(accu_nums) - 1)]
376 |
377 | loss = []
378 | for view_i, slice_ in enumerate(slices):
379 | if slice_.start < slice_.stop:
380 |
381 | # projection
382 | diff_surf_pts_slice = diff_surf_pts[slice_]
383 | pts_world = (diff_surf_pts_slice / 2 * size.view(1, 1) + center.view(1, 3)).view(1, -1, 1, 3, 1)
384 | pts_world = torch.cat([pts_world, torch.ones_like(pts_world[..., -1:, :])], dim=-2)
385 | cam_pack = torch.cat([cam[view_i:view_i + 1], src_cams[view_i]], dim=0)
386 | pts_img = idx_cam2img(idx_world2cam(pts_world, cam_pack), cam_pack)
387 |
388 | # gathering
389 | grid = pts_img[..., :2, 0]
390 |
391 | feat2_pack = torch.cat([feat[view_i:view_i + 1], feat_src[view_i]], dim=0)
392 | grid_n = normalize_for_grid_sample(feat2_pack, grid / 2)
393 | grid_in_range = get_in_range(grid_n)
394 | valid_mask = (grid_in_range[:1, ...] * grid_in_range[1:, ...]).unsqueeze(1) > 0.5
395 | gathered_feat = F.grid_sample(feat2_pack, grid_n, mode='bilinear', padding_mode='zeros',
396 | align_corners=False)
397 |
398 | # calculation
399 | gathered_norm = gathered_feat.norm(dim=1, keepdim=True)
400 | corr = (gathered_feat[:1] * gathered_feat[1:]).sum(dim=1, keepdim=True) \
401 | / gathered_norm[:1].clamp(min=1e-9) / gathered_norm[1:].clamp(min=1e-9)
402 | corr_loss = (1 - corr).abs()
403 | if uncerts is None:
404 | diff_mask = corr_loss < 0.5
405 | sample_loss = (corr_loss * valid_mask * diff_mask).mean()
406 | else:
407 | uncert = uncerts[view_i].unsqueeze(1).unsqueeze(3)
408 | sample_loss = ((corr_loss * (-uncert).exp() + uncert) * valid_mask).mean()
409 | else:
410 | sample_loss = torch.zeros(1).float().cuda()
411 | loss.append(sample_loss)
412 | loss = sum(loss) / len(loss)
413 | return loss
414 |
--------------------------------------------------------------------------------
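
The projection helpers above are chained inside get_local_loss to warp surface points into each view's feature map and gather features with grid_sample (get_local_loss additionally divides the pixel grid by 2 because FeatExt outputs half-resolution features). A minimal sketch of that chain with toy cameras and random features, assuming the repository root is on PYTHONPATH; all values are placeholders:

    import torch
    import torch.nn.functional as F
    from tools.feat_utils import idx_world2cam, idx_cam2img, normalize_for_grid_sample

    n, h, w = 1, 1, 128                        # treat the samples as a 1 x 128 "image" of points
    cam = torch.eye(4).repeat(n, 2, 1, 1)      # cam[:, 0] = extrinsic, cam[:, 1] = intrinsic
    cam[:, 1, 0, 0] = cam[:, 1, 1, 1] = 100.0  # toy focal lengths
    cam[:, 1, 0, 2] = cam[:, 1, 1, 2] = 64.0   # toy principal point

    pts = torch.rand(n, h, w, 3, 1)
    pts[..., 2, :] += 2.0                      # keep the toy points in front of the camera
    pts_h = torch.cat([pts, torch.ones_like(pts[..., :1, :])], dim=-2)   # homogeneous, nhw41
    pix = idx_cam2img(idx_world2cam(pts_h, cam), cam)                    # pixel coords, nhw31
    grid = pix[..., :2, 0]                                               # (n, h, w, 2)

    feat = torch.randn(n, 32, 128, 128)                                  # a feature map to sample
    grid_n = normalize_for_grid_sample(feat, grid)                       # map pixels to [-1, 1]
    sampled = F.grid_sample(feat, grid_n, mode='bilinear',
                            padding_mode='zeros', align_corners=False)   # (n, 32, h, w)
    print(sampled.shape)
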
/tools/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is the top-level package name,
11 | e.g., "mmdet3d".
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str, optional): The name of the root logger, also used as a
17 | filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter (constructed but never attached to the logger, so it has no effect)
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the stdlib FileHandler opens files in append mode ('a') by
80 | # default. The file_mode argument (default 'w') provides an interface
81 | # to override that behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_root_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only available when `logger` is a Logger
113 | object or "root".
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
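
A small usage sketch for the logging helpers above (the logger name, file name, and messages are arbitrary examples):

    import logging
    from tools.logger import get_root_logger, print_log

    logger = get_root_logger(log_file='example.log', log_level=logging.INFO, name='udf')
    print_log('training started', logger=logger)   # routed through the named logger
    print_log('quick check')                       # logger=None falls back to plain print()
    print_log('skipped', logger='silent')          # suppressed entirely
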
/tools/surface_extraction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mcubes
3 | import trimesh
4 | import torch
5 |
6 | from extensions.chamfer_dist import ChamferDistanceL2
7 | from tools.logger import print_log
8 | def as_mesh(scene_or_mesh):
9 | """
10 | Convert a possible scene to a mesh.
11 |
12 | If conversion occurs, the returned mesh has only vertex and face data.
13 | Suggested by https://github.com/mikedh/trimesh/issues/507
14 | """
15 | if isinstance(scene_or_mesh, trimesh.Scene):
16 | if len(scene_or_mesh.geometry) == 0:
17 | mesh = None # empty scene
18 | else:
19 | # we lose texture information here
20 | mesh = trimesh.util.concatenate(
21 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
22 | for g in scene_or_mesh.geometry.values()))
23 | else:
24 | assert(isinstance(scene_or_mesh, trimesh.Trimesh))
25 | mesh = scene_or_mesh
26 | return mesh
27 |
28 | def surface_extraction(ndf, grad, out_path, iter_step, b_max, b_min, resolution):
29 | v_all = []
30 | t_all = []
31 | threshold = 0.005 # accelerate extraction
32 | v_num = 0
33 | for i in range(resolution-1):
34 | for j in range(resolution-1):
35 | for k in range(resolution-1):
36 | ndf_loc = ndf[i:i+2]
37 | ndf_loc = ndf_loc[:,j:j+2,:]
38 | ndf_loc = ndf_loc[:,:,k:k+2]
39 | if np.min(ndf_loc) > threshold:
40 | continue
41 | grad_loc = grad[i:i+2]
42 | grad_loc = grad_loc[:,j:j+2,:]
43 | grad_loc = grad_loc[:,:,k:k+2]
44 |
45 | res = np.ones((2,2,2))
46 | for ii in range(2):
47 | for jj in range(2):
48 | for kk in range(2):
49 | if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0:
50 | res[ii][jj][kk] = -ndf_loc[ii][jj][kk]
51 | else:
52 | res[ii][jj][kk] = ndf_loc[ii][jj][kk]
53 |
54 | if res.min()<0:
55 | vertices, triangles = mcubes.marching_cubes(
56 | res, 0.0)
57 | # print(vertices)
58 | # vertices -= 1.5
59 | # vertices /= 128
60 | vertices[:,0] += i #/ resolution
61 | vertices[:,1] += j #/ resolution
62 | vertices[:,2] += k #/ resolution
63 | triangles += v_num
64 | # vertices =
65 | # vertices[:,1] /= 3 # TODO
66 | v_all.append(vertices)
67 | t_all.append(triangles)
68 |
69 | v_num += vertices.shape[0]
70 | # print(v_num)
71 |
72 | v_all = np.concatenate(v_all)
73 | t_all = np.concatenate(t_all)
74 | # Create mesh
75 | v_all = v_all / (resolution - 1.0) * (b_max - b_min)[None, :] + b_min[None, :]
76 |
77 | mesh = trimesh.Trimesh(v_all, t_all, process=False)
78 |
79 | return mesh
80 |
81 |
82 |
--------------------------------------------------------------------------------
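
The key step in surface_extraction above is turning the unsigned distances of each 2x2x2 cell into signed values before marching cubes: a corner whose gradient points opposite to the reference corner's gradient is treated as lying on the other side of the surface, so its distance is negated. A single-cell toy illustration of that sign assignment (the distances and gradients are made up):

    import numpy as np
    import mcubes

    ndf_loc = np.full((2, 2, 2), 0.1)      # unsigned distances at the cell corners
    grad_loc = np.ones((2, 2, 2, 3))       # toy gradients; flip one corner to the "other side"
    grad_loc[1, 1, 1] = [-1.0, -1.0, -1.0]

    res = np.ones((2, 2, 2))
    for ii in range(2):
        for jj in range(2):
            for kk in range(2):
                sign = -1.0 if np.dot(grad_loc[0, 0, 0], grad_loc[ii, jj, kk]) < 0 else 1.0
                res[ii, jj, kk] = sign * ndf_loc[ii, jj, kk]

    vertices, triangles = mcubes.marching_cubes(res, 0.0)   # non-empty only because res changes sign
    print(vertices.shape, triangles.shape)
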
/tools/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from random import sample
4 | import time
5 | from tkinter import Variable
6 | from shutil import copyfile
7 | import numpy as np
8 | import trimesh
9 |
10 | from scipy.spatial import cKDTree
11 |
12 |
13 | def get_aver(distances, face):
14 | return (distances[face[0]] + distances[face[1]] + distances[face[2]]) / 3.0
15 |
16 | def remove_far(gt_pts, mesh, dis_trunc=0.1, is_use_prj=False):
17 | # gt_pts: trimesh
18 | # mesh: trimesh
19 |
20 | gt_kd_tree = cKDTree(gt_pts)
21 | distances, vertex_ids = gt_kd_tree.query(mesh.vertices, p=2, distance_upper_bound=dis_trunc)
22 | faces_remaining = []
23 | faces = mesh.faces
24 |
25 | if is_use_prj:
26 | normals = gt_pts.vertex_normals
27 | closest_points = gt_pts.vertices[vertex_ids]
28 | closest_normals = normals[vertex_ids]
29 | direction_from_surface = mesh.vertices - closest_points
30 | distances = direction_from_surface * closest_normals
31 | distances = np.sum(distances, axis=1)
32 |
33 | for i in range(faces.shape[0]):
34 | if get_aver(distances, faces[i]) < dis_trunc:
35 | faces_remaining.append(faces[i])
36 | mesh_cleaned = mesh.copy()
37 | mesh_cleaned.faces = faces_remaining
38 | mesh_cleaned.remove_unreferenced_vertices()
39 |
40 | return mesh_cleaned
41 |
42 | def remove_outlier(gt_pts, q_pts, dis_trunc=0.003, is_use_prj=False):
43 | # gt_pts: ground truth points, (N, 3) array
44 | # q_pts: query points to be filtered, (M, 3) array
45 |
46 | gt_kd_tree = cKDTree(gt_pts)
47 | distances, q_ids = gt_kd_tree.query(q_pts, p=2, distance_upper_bound=dis_trunc)
48 |
49 | q_pts = q_pts[distances < dis_trunc]
--------------------------------------------------------------------------------
/udf_runner.py:
--------------------------------------------------------------------------------
143 | print_log('iter:{:8>d} cd_l1 = {} lr={}'.format(self.iter_step, loss_cd, self.optimizer.param_groups[0]['lr']), logger=logger)
144 |
145 | if self.iter_step == self.step1_maxiter or self.iter_step == self.step2_maxiter:
146 | self.save_checkpoint()
147 |
148 | if self.iter_step == self.step1_maxiter:
149 | gen_pointclouds = self.gen_extra_pointcloud(self.iter_step, self.conf.get_float('udf_train.low_range'))
150 | idx = pcu.downsample_point_cloud_poisson_disk(gen_pointclouds, num_samples=int(self.conf.get_float('udf_train.extra_points_rate')*point_gt.shape[0]))
151 | poisson_pointclouds = gen_pointclouds[idx]
152 | dense_pointclouds = np.concatenate((point_gt.detach().cpu().numpy(), poisson_pointclouds))
153 | self.ptree = cKDTree(dense_pointclouds)
154 | self.dataset.gen_new_data(self.ptree)
155 |
156 | if self.iter_step == self.step2_maxiter:
157 | gen_pointclouds = self.gen_extra_pointcloud(self.iter_step, 1)
158 |
159 | # if self.iter_step == self.step1_maxiter or self.iter_step == self.step2_maxiter:
160 | # self.extract_mesh(resolution=args.mcube_resolution, threshold=0.0, point_gt=point_gt, iter_step=self.iter_step, logger=logger)
161 |
162 |
163 | def extract_mesh(self, resolution=64, threshold=0.0, point_gt=None, iter_step=0, logger=None):
164 |
165 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
166 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)
167 | out_dir = os.path.join(self.base_exp_dir, 'mesh')
168 | os.makedirs(out_dir, exist_ok=True)
169 |
170 | mesh = extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold, \
171 | out_dir=out_dir, iter_step=iter_step, dataname=self.dataname, logger=logger, \
172 | query_func=lambda pts: self.udf_network.udf(pts), grad_func=lambda pts: self.udf_network.gradient(pts))
173 | if self.conf.get_float('udf_train.far') > 0:
174 | mesh = remove_far(point_gt.detach().cpu().numpy(), mesh, self.conf.get_float('udf_train.far'))
175 |
176 | mesh.export(out_dir+'/'+str(iter_step)+'_mesh.obj')
177 |
178 |
179 |
180 | def gen_extra_pointcloud(self, iter_step, low_range):
181 |
182 | res = []
183 | num_points = self.eval_num_points
184 | gen_nums = 0
185 |
186 | os.makedirs(os.path.join(self.base_exp_dir, 'pointcloud'), exist_ok=True)
187 |
188 | while gen_nums < num_points:
189 |
190 | points, samples, point_gt = self.dataset.get_train_data(5000)
191 | offsets = samples - points
192 | std = torch.std(offsets)
193 |
194 | extra_std = std * low_range
195 | rands = torch.normal(0.0, extra_std, size=points.shape)
196 | samples = points + torch.tensor(rands).cuda().float()
197 |
198 | samples.requires_grad = True
199 | gradients_sample = self.udf_network.gradient(samples).squeeze() # 5000x3
200 | udf_sample = self.udf_network.udf(samples) # 5000x1
201 | grad_norm = F.normalize(gradients_sample, dim=1) # 5000x3
202 | sample_moved = samples - grad_norm * udf_sample # 5000x3
203 |
204 | index = udf_sample < self.df_filter
205 | index = index.squeeze(1)
206 | sample_moved = sample_moved[index]
207 |
208 | gen_nums += sample_moved.shape[0]
209 |
210 | res.append(sample_moved.detach().cpu().numpy())
211 |
212 | res = np.concatenate(res)
213 | res = res[:num_points]
214 | np.savetxt(os.path.join(self.base_exp_dir, 'pointcloud', 'point_cloud%d.xyz'%(iter_step)), res)
215 |
216 | res = remove_outlier(point_gt.detach().cpu().numpy(), res, dis_trunc=self.conf.get_float('udf_train.outlier'))
217 | return res
218 |
219 | def update_learning_rate(self, iter_step):
220 |
221 | warn_up = self.warm_up_end
222 | max_iter = self.step2_maxiter
223 | init_lr = self.learning_rate
224 | lr = (iter_step / warn_up) if iter_step < warn_up else 0.5 * (math.cos((iter_step - warn_up)/(max_iter - warn_up) * math.pi) + 1)
225 | lr = lr * init_lr
226 |
227 | for g in self.optimizer.param_groups:
228 | g['lr'] = lr
229 |
230 | def file_backup(self):
231 | dir_lis = self.conf['general.recording']
232 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
233 | for dir_name in dir_lis:
234 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
235 | os.makedirs(cur_dir, exist_ok=True)
236 | files = os.listdir(dir_name)
237 | for f_name in files:
238 | if f_name[-3:] == '.py':
239 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
240 |
241 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
242 |
243 | def load_checkpoint(self, checkpoint_name):
244 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
245 | print(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name))
246 | self.udf_network.load_state_dict(checkpoint['udf_network_fine'])
247 |
248 | self.iter_step = checkpoint['iter_step']
249 |
250 | def save_checkpoint(self):
251 | checkpoint = {
252 | 'udf_network_fine': self.udf_network.state_dict(),
253 | 'iter_step': self.iter_step,
254 | }
255 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
256 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
257 |
258 |
259 | if __name__ == '__main__':
260 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
261 | parser = argparse.ArgumentParser()
262 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf')
263 | parser.add_argument('--mcube_resolution', type=int, default=256)
264 | parser.add_argument('--gpu', type=int, default=0)
265 | parser.add_argument('--udf_dir', type=str, default='udf')
266 | parser.add_argument('--case', type=str, default='')
267 | args = parser.parse_args()
268 |
269 | torch.cuda.set_device(args.gpu)
270 | runner = UDFRunner(args, args.conf)
271 |
272 | runner.train()
273 |
--------------------------------------------------------------------------------
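
For reference, the warm-up plus cosine schedule implemented in update_learning_rate above can be reproduced standalone; the iteration counts and base learning rate below are arbitrary examples, not the repository's configuration:

    import math

    def lr_at(iter_step, warm_up=1000, max_iter=40000, init_lr=1e-3):
        # linear warm-up to init_lr, then cosine decay to zero at max_iter
        if iter_step < warm_up:
            factor = iter_step / warm_up
        else:
            factor = 0.5 * (math.cos((iter_step - warm_up) / (max_iter - warm_up) * math.pi) + 1)
        return factor * init_lr

    print(lr_at(500), lr_at(1000), lr_at(40000))   # ramping up, at peak, decayed to ~0
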