├── .gitignore ├── 1gpu.yaml ├── 8gpu.yaml ├── LICENSE ├── NeuS ├── confs │ └── wmask.conf ├── exp_runner.py ├── models │ ├── dataset_mvdiff.py │ ├── embedder.py │ ├── fields.py │ ├── fixed_poses │ │ ├── 000_back_RT.txt │ │ ├── 000_back_left_RT.txt │ │ ├── 000_back_right_RT.txt │ │ ├── 000_front_RT.txt │ │ ├── 000_front_left_RT.txt │ │ ├── 000_front_right_RT.txt │ │ ├── 000_left_RT.txt │ │ ├── 000_right_RT.txt │ │ └── 000_top_RT.txt │ ├── normal_utils.py │ ├── ops.py │ └── renderer.py └── run.sh ├── README.md ├── README_zh.md ├── assets ├── bug_fixed.png ├── coordinate.png └── fig_teaser.png ├── configs ├── mvdiffusion-joint-ortho-6views.yaml └── train │ ├── stage1-mix-6views-lvis.yaml │ └── stage2-joint-6views-lvis.yaml ├── data_lists ├── lvis_invalid_uids_nineviews.json └── lvis_uids_filter_by_vertex.json ├── docker ├── Dockerfile ├── README.md └── requirements.txt ├── example_images ├── 14_10_29_489_Tiger_1__1.png ├── box.png ├── bread.png ├── cat.png ├── cat_head.png ├── chili.png ├── duola.png ├── halloween.png ├── head.png ├── kettle.png ├── kunkun.png ├── milk.png ├── owl.png ├── poro.png ├── pumpkin.png ├── skull.png ├── stone.png ├── teapot.png └── tiger-head-3d-model-obj-stl.png ├── gradio_app_mv.py ├── gradio_app_recon.py ├── instant-nsr-pl ├── README.md ├── configs │ └── neuralangelo-ortho-wmask.yaml ├── datasets │ ├── __init__.py │ ├── blender.py │ ├── colmap.py │ ├── colmap_utils.py │ ├── dtu.py │ ├── fixed_poses │ │ ├── 000_back_RT.txt │ │ ├── 000_back_left_RT.txt │ │ ├── 000_back_right_RT.txt │ │ ├── 000_front_RT.txt │ │ ├── 000_front_left_RT.txt │ │ ├── 000_front_right_RT.txt │ │ ├── 000_left_RT.txt │ │ ├── 000_right_RT.txt │ │ └── 000_top_RT.txt │ ├── ortho.py │ └── utils.py ├── launch.py ├── models │ ├── __init__.py │ ├── base.py │ ├── geometry.py │ ├── nerf.py │ ├── network_utils.py │ ├── neus.py │ ├── ray_utils.py │ ├── texture.py │ └── utils.py ├── requirements.txt ├── run.sh ├── scripts │ └── imgs2poses.py ├── systems │ ├── __init__.py │ ├── base.py │ ├── criterions.py │ ├── nerf.py │ ├── neus.py │ ├── neus_ortho.py │ ├── neus_pinhole.py │ └── utils.py └── utils │ ├── __init__.py │ ├── callbacks.py │ ├── loggers.py │ ├── misc.py │ ├── mixins.py │ └── obj.py ├── mvdiffusion ├── data │ ├── fixed_poses │ │ └── nine_views │ │ │ ├── 000_back_RT.txt │ │ │ ├── 000_back_left_RT.txt │ │ │ ├── 000_back_right_RT.txt │ │ │ ├── 000_front_RT.txt │ │ │ ├── 000_front_left_RT.txt │ │ │ ├── 000_front_right_RT.txt │ │ │ ├── 000_left_RT.txt │ │ │ ├── 000_right_RT.txt │ │ │ └── 000_top_RT.txt │ ├── normal_utils.py │ ├── objaverse_dataset.py │ └── single_image_dataset.py ├── models │ ├── transformer_mv2d.py │ ├── unet_mv2d_blocks.py │ └── unet_mv2d_condition.py └── pipelines │ └── pipeline_mvdiffusion_image.py ├── render_codes ├── README.md ├── blenderProc_ortho.py ├── blenderProc_persp.py ├── distributed.py ├── render_batch_ortho.sh ├── render_batch_persp.sh ├── render_single_ortho.sh ├── render_single_persp.sh └── requirements.txt ├── requirements.txt ├── run_test.sh ├── run_train_stage1.sh ├── run_train_stage2.sh ├── test_mvdiffusion_seq.py ├── train_mvdiffusion_image.py ├── train_mvdiffusion_joint.py └── utils └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | ckpts 4 | sam_pt 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # tests and logs 14 | tests/fixtures/cached_*_text.txt 15 
| logs/ 16 | lightning_logs/ 17 | lang_code_data/ 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # vscode 127 | .vs 128 | .vscode 129 | 130 | # Pycharm 131 | .idea 132 | 133 | # TF code 134 | tensorflow_code 135 | 136 | # Models 137 | proc_data 138 | 139 | # examples 140 | runs 141 | /runs_old 142 | /wandb 143 | /examples/runs 144 | /examples/**/*.args 145 | /examples/rag/sweep 146 | 147 | # data 148 | /data 149 | serialization_dir 150 | 151 | # emacs 152 | *.*~ 153 | debug.env 154 | 155 | # vim 156 | .*.swp 157 | 158 | #ctags 159 | tags 160 | 161 | # pre-commit 162 | .pre-commit* 163 | 164 | # .lock 165 | *.lock 166 | 167 | # DS_Store (MacOS) 168 | .DS_Store 169 | # RL pipelines may produce mp4 outputs 170 | *.mp4 171 | 172 | # dependencies 173 | /transformers 174 | 175 | # ruff 176 | .ruff_cache 177 | 178 | # ckpts 179 | *.ckpt 180 | 181 | outputs/* 182 | 183 | NeuS/exp/* 184 | NeuS/test_scenes/* 185 | NeuS/mesh2tex/* 186 | neus_configs 187 | vast/* 188 | render_results 189 | experiments/* 190 | ckpts/* 191 | neus/* 192 | instant-nsr-pl/exp/* -------------------------------------------------------------------------------- /1gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: 'NO' 3 | downcast_bf16: 'no' 4 | gpu_ids: '0' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: 'no' 8 | num_machines: 1 9 | num_processes: 1 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /8gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: 
MULTI_GPU 3 | downcast_bf16: 'no' 4 | gpu_ids: all 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: 'no' 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 xxlong0 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NeuS/confs/wmask.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp/neus/CASE_NAME/ 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./outputs/ 11 | object_name = CASE_NAME 12 | object_viewidx = 1 13 | imSize = [256, 256] 14 | load_color = True 15 | stage = coarse 16 | mtype = mlp 17 | normal_system: front 18 | num_views = 6 19 | } 20 | 21 | train { 22 | learning_rate = 5e-4 23 | learning_rate_alpha = 0.05 24 | end_iter = 1000 # longer time, better result. 
1w will be ok for most cases 25 | 26 | batch_size = 512 27 | validate_resolution_level = 1 28 | warm_up_end = 500 29 | anneal_end = 0 30 | use_white_bkgd = True 31 | 32 | save_freq = 5000 33 | val_freq = 5000 34 | val_mesh_freq =5000 35 | report_freq = 100 36 | 37 | color_weight = 1.0 38 | igr_weight = 0.1 39 | mask_weight = 1.0 40 | normal_weight = 1.0 41 | sparse_weight = 0.1 42 | 43 | } 44 | 45 | model { 46 | nerf { 47 | D = 8, 48 | d_in = 4, 49 | d_in_view = 3, 50 | W = 256, 51 | multires = 10, 52 | multires_view = 4, 53 | output_ch = 4, 54 | skips=[4], 55 | use_viewdirs=True 56 | } 57 | 58 | sdf_network { 59 | d_out = 257 60 | d_in = 3 61 | d_hidden = 256 62 | n_layers = 8 63 | skip_in = [4] 64 | multires = 6 65 | bias = 0.5 66 | scale = 1.0 67 | geometric_init = True 68 | weight_norm = True 69 | } 70 | 71 | variance_network { 72 | init_val = 0.3 73 | } 74 | 75 | rendering_network { 76 | d_feature = 256 77 | mode = no_view_dir 78 | d_in = 6 79 | d_out = 3 80 | d_hidden = 256 81 | n_layers = 4 82 | weight_norm = True 83 | multires_view = 0 84 | squeeze_out = True 85 | } 86 | 87 | neus_renderer { 88 | n_samples = 64 89 | n_importance = 64 90 | n_outside = 0 91 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 92 | perturb = 1.0 93 | sdf_decay_param = 100 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /NeuS/models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': True, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /NeuS/models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | d_in, 12 | d_out, 13 | d_hidden, 14 | n_layers, 15 | skip_in=(4,), 16 | multires=0, 17 | bias=0.5, 18 | scale=1, 19 | geometric_init=True, 20 | weight_norm=True, 21 | inside_outside=False): 22 | super(SDFNetwork, self).__init__() 23 | 24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 25 | 26 | self.embed_fn_fine = None 27 | 28 | if multires > 0: 29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 30 | self.embed_fn_fine = embed_fn 31 | dims[0] = input_ch 32 | 33 | self.num_layers = len(dims) 34 | self.skip_in = skip_in 35 | self.scale = scale 36 | 37 | for l in range(0, self.num_layers - 1): 38 | if l + 1 in self.skip_in: 39 | out_dim = dims[l + 1] - dims[0] 40 | else: 41 | out_dim = dims[l + 1] 42 | 43 | lin = nn.Linear(dims[l], out_dim) 44 | 45 | if geometric_init: 46 | if l == self.num_layers - 2: 47 | if not inside_outside: 48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 49 | torch.nn.init.constant_(lin.bias, -bias) 50 | else: 51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 52 | torch.nn.init.constant_(lin.bias, bias) 53 | elif multires > 0 and l == 0: 54 | torch.nn.init.constant_(lin.bias, 0.0) 55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 57 | elif multires > 0 and l in self.skip_in: 58 | torch.nn.init.constant_(lin.bias, 0.0) 59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 61 | else: 62 | torch.nn.init.constant_(lin.bias, 0.0) 63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 64 | 65 | if weight_norm: 66 | lin = nn.utils.weight_norm(lin) 67 | 68 | setattr(self, "lin" + str(l), lin) 69 | 70 | self.activation = nn.Softplus(beta=100) 71 | 72 | def forward(self, inputs): 73 | inputs = inputs * self.scale 74 | if self.embed_fn_fine is not None: 75 | inputs = self.embed_fn_fine(inputs) 76 | 77 | x = inputs 78 | for l in range(0, self.num_layers - 1): 79 | lin = getattr(self, "lin" + str(l)) 80 | 81 | if l in self.skip_in: 82 | x = 
torch.cat([x, inputs], 1) / np.sqrt(2) 83 | 84 | x = lin(x) 85 | 86 | if l < self.num_layers - 2: 87 | x = self.activation(x) 88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 89 | 90 | def sdf(self, x): 91 | return self.forward(x)[:, :1] 92 | 93 | def sdf_hidden_appearance(self, x): 94 | return self.forward(x) 95 | 96 | def gradient(self, x): 97 | x.requires_grad_(True) 98 | with torch.enable_grad(): 99 | y = self.sdf(x) 100 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 101 | gradients = torch.autograd.grad( 102 | outputs=y, 103 | inputs=x, 104 | grad_outputs=d_output, 105 | create_graph=True, 106 | retain_graph=True, 107 | only_inputs=True)[0] 108 | return gradients.unsqueeze(1) 109 | 110 | 111 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 112 | class RenderingNetwork(nn.Module): 113 | def __init__(self, 114 | d_feature, 115 | mode, 116 | d_in, 117 | d_out, 118 | d_hidden, 119 | n_layers, 120 | weight_norm=True, 121 | multires_view=0, 122 | squeeze_out=True): 123 | super().__init__() 124 | 125 | self.mode = mode 126 | self.squeeze_out = squeeze_out 127 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 128 | 129 | self.embedview_fn = None 130 | if multires_view > 0: 131 | embedview_fn, input_ch = get_embedder(multires_view) 132 | self.embedview_fn = embedview_fn 133 | dims[0] += (input_ch - 3) 134 | 135 | self.num_layers = len(dims) 136 | 137 | for l in range(0, self.num_layers - 1): 138 | out_dim = dims[l + 1] 139 | lin = nn.Linear(dims[l], out_dim) 140 | 141 | if weight_norm: 142 | lin = nn.utils.weight_norm(lin) 143 | 144 | setattr(self, "lin" + str(l), lin) 145 | 146 | self.relu = nn.ReLU() 147 | 148 | def forward(self, points, normals, view_dirs, feature_vectors): 149 | if self.embedview_fn is not None: 150 | view_dirs = self.embedview_fn(view_dirs) 151 | 152 | rendering_input = None 153 | 154 | if self.mode == 'idr': 155 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 156 | elif self.mode == 'no_view_dir': 157 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 158 | elif self.mode == 'no_normal': 159 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 160 | 161 | x = rendering_input 162 | 163 | for l in range(0, self.num_layers - 1): 164 | lin = getattr(self, "lin" + str(l)) 165 | 166 | x = lin(x) 167 | 168 | if l < self.num_layers - 2: 169 | x = self.relu(x) 170 | 171 | if self.squeeze_out: 172 | x = torch.sigmoid(x) 173 | return x 174 | 175 | 176 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch 177 | class NeRF(nn.Module): 178 | def __init__(self, 179 | D=8, 180 | W=256, 181 | d_in=3, 182 | d_in_view=3, 183 | multires=0, 184 | multires_view=0, 185 | output_ch=4, 186 | skips=[4], 187 | use_viewdirs=False): 188 | super(NeRF, self).__init__() 189 | self.D = D 190 | self.W = W 191 | self.d_in = d_in 192 | self.d_in_view = d_in_view 193 | self.input_ch = 3 194 | self.input_ch_view = 3 195 | self.embed_fn = None 196 | self.embed_fn_view = None 197 | 198 | if multires > 0: 199 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 200 | self.embed_fn = embed_fn 201 | self.input_ch = input_ch 202 | 203 | if multires_view > 0: 204 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 205 | self.embed_fn_view = embed_fn_view 206 | self.input_ch_view = input_ch_view 207 | 208 | self.skips = skips 209 | self.use_viewdirs = 
use_viewdirs 210 | 211 | self.pts_linears = nn.ModuleList( 212 | [nn.Linear(self.input_ch, W)] + 213 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 214 | 215 | ### Implementation according to the official code release 216 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 217 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 218 | 219 | ### Implementation according to the paper 220 | # self.views_linears = nn.ModuleList( 221 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 222 | 223 | if use_viewdirs: 224 | self.feature_linear = nn.Linear(W, W) 225 | self.alpha_linear = nn.Linear(W, 1) 226 | self.rgb_linear = nn.Linear(W // 2, 3) 227 | else: 228 | self.output_linear = nn.Linear(W, output_ch) 229 | 230 | def forward(self, input_pts, input_views): 231 | if self.embed_fn is not None: 232 | input_pts = self.embed_fn(input_pts) 233 | if self.embed_fn_view is not None: 234 | input_views = self.embed_fn_view(input_views) 235 | 236 | h = input_pts 237 | for i, l in enumerate(self.pts_linears): 238 | h = self.pts_linears[i](h) 239 | h = F.relu(h) 240 | if i in self.skips: 241 | h = torch.cat([input_pts, h], -1) 242 | 243 | if self.use_viewdirs: 244 | alpha = self.alpha_linear(h) 245 | feature = self.feature_linear(h) 246 | h = torch.cat([feature, input_views], -1) 247 | 248 | for i, l in enumerate(self.views_linears): 249 | h = self.views_linears[i](h) 250 | h = F.relu(h) 251 | 252 | rgb = self.rgb_linear(h) 253 | return alpha, rgb 254 | else: 255 | assert False 256 | 257 | 258 | class SingleVarianceNetwork(nn.Module): 259 | def __init__(self, init_val): 260 | super(SingleVarianceNetwork, self).__init__() 261 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) 262 | 263 | def forward(self, x): 264 | return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0) 265 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_back_RT.txt: -------------------------------------------------------------------------------- 1 | -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08 2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08 3 | 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_back_left_RT.txt: -------------------------------------------------------------------------------- 1 | -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07 2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07 3 | 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_back_right_RT.txt: -------------------------------------------------------------------------------- 1 | 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07 2 | -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08 3 | 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 
-1.838477969169616699e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_front_RT.txt: -------------------------------------------------------------------------------- 1 | 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00 2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08 3 | -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_front_left_RT.txt: -------------------------------------------------------------------------------- 1 | -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07 2 | 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07 3 | -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_front_right_RT.txt: -------------------------------------------------------------------------------- 1 | 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07 2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07 3 | -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_left_RT.txt: -------------------------------------------------------------------------------- 1 | -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00 2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08 3 | -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_right_RT.txt: -------------------------------------------------------------------------------- 1 | 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08 2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08 3 | 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/fixed_poses/000_top_RT.txt: -------------------------------------------------------------------------------- 1 | 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09 2 | -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08 3 | -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /NeuS/models/normal_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def camNormal2worldNormal(rot_c2w, camNormal): 4 | H,W,_ = camNormal.shape 5 | normal_img = np.matmul(rot_c2w[None, :, :], 
camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 6 | 7 | return normal_img 8 | 9 | def worldNormal2camNormal(rot_w2c, normal_map_world): 10 | H,W,_ = normal_map_world.shape 11 | # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 12 | 13 | # faster version 14 | # Reshape the normal map into a 2D array where each row represents a normal vector 15 | normal_map_flat = normal_map_world.reshape(-1, 3) 16 | 17 | # Transform the normal vectors using the transformation matrix 18 | normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T) 19 | 20 | # Reshape the transformed normal map back to its original shape 21 | normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape) 22 | 23 | return normal_map_camera 24 | 25 | def trans_normal(normal, RT_w2c, RT_w2c_target): 26 | 27 | # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) 28 | # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) 29 | 30 | relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) 31 | normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal) 32 | 33 | return normal_target_cam 34 | 35 | def img2normal(img): 36 | return (img/255.)*2-1 37 | 38 | def normal2img(normal): 39 | return np.uint8((normal*0.5+0.5)*255) 40 | 41 | def norm_normalize(normal, dim=-1): 42 | 43 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) 44 | 45 | return normal -------------------------------------------------------------------------------- /NeuS/models/ops.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 5 | """ 6 | depth: (H, W) 7 | """ 8 | 9 | x = np.nan_to_num(depth) # change nan to 0 10 | if minmax is None: 11 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) 12 | ma = np.max(x) 13 | else: 14 | mi, ma = minmax 15 | 16 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 17 | x = (255 * x).astype(np.uint8) 18 | x_ = cv2.applyColorMap(x, cmap) 19 | return x_, [mi, ma] -------------------------------------------------------------------------------- /NeuS/run.sh: -------------------------------------------------------------------------------- 1 | python exp_runner.py --mode train --conf ./confs/wmask.conf --case $2 --data_dir $1 2 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | **其他语言版本 [English](README.md)** 2 | 3 | # Wonder3D 4 | Single Image to 3D using Cross-Domain Diffusion 5 | ## [Paper](https://arxiv.org/abs/2310.15008) | [Project page](https://www.xxlong.site/Wonder3D/) | [Hugging Face Demo](https://huggingface.co/spaces/flamehaze1115/Wonder3D-demo) | [Colab from @camenduru](https://github.com/camenduru/Wonder3D-colab) 6 | 7 | ![](assets/fig_teaser.png) 8 | 9 | Wonder3D仅需2至3分钟即可从单视图图像中重建出高度详细的纹理网格。Wonder3D首先通过跨域扩散模型生成一致的多视图法线图与相应的彩色图像,然后利用一种新颖的法线融合方法实现快速且高质量的重建。 10 | 11 | ## Usage 使用 12 | ```bash 13 | 14 | import torch 15 | import requests 16 | from PIL import Image 17 | import numpy as np 18 | from torchvision.utils import make_grid, save_image 19 | from diffusers import DiffusionPipeline # only tested on diffusers[torch]==0.19.3, may have conflicts with newer versions of diffusers 20 | 21 | def load_wonder3d_pipeline(): 22 | 23 | pipeline = 
DiffusionPipeline.from_pretrained( 24 | 'flamehaze1115/wonder3d-v1.0', # or use local checkpoint './ckpts' 25 | custom_pipeline='flamehaze1115/wonder3d-pipeline', 26 | torch_dtype=torch.float16 27 | ) 28 | 29 | # enable xformers 30 | pipeline.unet.enable_xformers_memory_efficient_attention() 31 | 32 | if torch.cuda.is_available(): 33 | pipeline.to('cuda:0') 34 | return pipeline 35 | 36 | pipeline = load_wonder3d_pipeline() 37 | 38 | # Download an example image. 39 | cond = Image.open(requests.get("https://d.skis.ltd/nrp/sample-data/lysol.png", stream=True).raw) 40 | 41 | # The object should be located in the center and resized to 80% of image height. 42 | cond = Image.fromarray(np.array(cond)[:, :, :3]) 43 | 44 | # Run the pipeline! 45 | images = pipeline(cond, num_inference_steps=20, output_type='pt', guidance_scale=1.0).images 46 | 47 | result = make_grid(images, nrow=6, ncol=2, padding=0, value_range=(0, 1)) 48 | 49 | save_image(result, 'result.png') 50 | ``` 51 | 52 | ## Collaborations 合作 53 | 我们的总体使命是提高3D人工智能图形生成(3D AIGC)的速度、可负担性和质量,使所有人都能够轻松创建3D内容。尽管近年来取得了显著的进展,我们承认前方仍有很长的路要走。我们热切邀请您参与讨论并在任何方面探索潜在的合作机会。**如果您有兴趣与我们联系或合作,请随时通过电子邮件(xxlong@connect.hku.hk)联系我们**。 54 | 55 | ## More features 56 | 57 | The repo is still being under construction, thanks for your patience. 58 | - [x] Local gradio demo. 59 | - [x] Detailed tutorial. 60 | - [x] GUI demo for mesh reconstruction 61 | - [x] Windows support 62 | - [x] Docker support 63 | 64 | ## Schedule 65 | - [x] Inference code and pretrained models. 66 | - [x] Huggingface demo. 67 | - [ ] New model with higher resolution. 68 | 69 | 70 | ### Preparation for inference 测试准备 71 | 72 | #### Linux System Setup. 73 | ```angular2html 74 | conda create -n wonder3d 75 | conda activate wonder3d 76 | pip install -r requirements.txt 77 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 78 | ``` 79 | #### Windows System Setup. 80 | 81 | 请切换到`main-windows`分支以查看Windows设置的详细信息。 82 | 83 | #### Docker Setup 84 | 详见 [docker/README.MD](docker/README.md) 85 | 86 | ### Inference 87 | 1. 可选。如果您在连接到Hugging Face时遇到问题,请确保已下载以下模型。 88 | 下载[checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/Ej7fMT1PwXtKvsELTvDuzuMBebQXEkmf2IwhSjBWtKAJiA)并放入根文件夹中。 89 | 90 | 国内用户可下载: [阿里云盘](https://www.alipan.com/s/T4rLUNAVq6V) 91 | 92 | ```bash 93 | Wonder3D 94 | |-- ckpts 95 | |-- unet 96 | |-- scheduler 97 | |-- vae 98 | ... 99 | ``` 100 | 然后更改文件 ./configs/mvdiffusion-joint-ortho-6views.yaml, 设置 `pretrained_model_name_or_path="./ckpts"` 101 | 102 | 2. 下载模型 [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) . 放置在 ``sam_pt`` 文件夹. 103 | ``` 104 | Wonder3D 105 | |-- sam_pt 106 | |-- sam_vit_h_4b8939.pth 107 | ``` 108 | 3. 预测前景蒙版作为阿尔法通道。我们使用[Clipdrop](https://clipdrop.co/remove-background)来交互地分割前景对象。 109 | 您还可以使用`rembg`来去除背景。 110 | ```bash 111 | # !pip install rembg 112 | import rembg 113 | result = rembg.remove(result) 114 | result.show() 115 | ``` 116 | 4. 
运行Wonder3D以生成多视角一致的法线图和彩色图像。然后,您可以在文件夹`./outputs`中检查结果(我们使用`rembg`去除结果的背景,但分割并不总是完美的。可以考虑使用[Clipdrop](https://clipdrop.co/remove-background)获取生成的法线图和彩色图像的蒙版,因为蒙版的质量将显著影响重建的网格质量)。 117 | ```bash 118 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \ 119 | --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir={your_data_path} \ 120 | validation_dataset.filepaths=['your_img_file'] save_dir={your_save_path} 121 | ``` 122 | 123 | 示例: 124 | 125 | ```bash 126 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \ 127 | --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir=./example_images \ 128 | validation_dataset.filepaths=['owl.png'] save_dir=./outputs 129 | ``` 130 | 131 | #### 运行本地的Gradio演示。仅生成法线和颜色,无需进行重建。 132 | ```bash 133 | python gradio_app_mv.py # generate multi-view normals and colors 134 | ``` 135 | 136 | 5. Mesh Extraction 137 | 138 | #### Instant-NSR Mesh Extraction 139 | 140 | ```bash 141 | cd ./instant-nsr-pl 142 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{your_save_path}/cropsize-{crop_size}-cfg{guidance_scale:.1f}/ dataset.scene={scene} 143 | ``` 144 | 145 | 示例: 146 | 147 | ```bash 148 | cd ./instant-nsr-pl 149 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../outputs/cropsize-192-cfg1.0/ dataset.scene=owl 150 | ``` 151 | 152 | 我们生成的法线图和彩色图像是在正交视图中定义的,因此重建的网格也处于正交摄像机空间。如果您使用MeshLab查看网格,可以在“View”选项卡中单击“Toggle Orthographic Camera”切换到正交相机。 153 | 154 | #### 运行本地的Gradio演示。首先生成法线和颜色,然后进行重建。无需首先执行`gradio_app_mv.py`。 155 | ```bash 156 | python gradio_app_recon.py 157 | ``` 158 | 159 | #### NeuS-based Mesh Extraction 160 | 161 | 由于许多用户对于instant-nsr-pl的Windows设置提出了抱怨,我们提供了基于NeuS的重建,这可能消除了一些要求方面的问题。 162 | 163 | NeuS消耗较少的GPU内存,对平滑表面有利,无需参数调整。然而,NeuS需要更多时间,其纹理可能不够清晰。如果您对时间不太敏感,我们建议由于其稳健性而使用NeuS进行优化。 164 | 165 | ```bash 166 | cd ./NeuS 167 | bash run.sh output_folder_path scene_name 168 | ``` 169 | 170 | ## 常见问题 171 | **获取更好结果的提示:** 172 | 1. **图片朝向方向敏感:** Wonder3D对输入图像的面向方向敏感。通过实验证明,面向前方的图像通常会导致良好的重建结果。 173 | 2. **图像分辨率:** 受资源限制,当前实现仅支持有限的视图(6个视图)和低分辨率(256x256)。任何图像都将首先调整大小为256x256进行生成,因此在这样的降采样后仍然保持清晰而锐利特征的图像将导致良好的结果。 174 | 3. **处理遮挡:** 具有遮挡的图像会导致更差的重建,因为6个视图无法完全覆盖整个对象。具有较少遮挡的图像通常会产生更好的结果。 175 | 4. **增加instant-nsr-pl中的优化步骤:** 在instant-nsr-pl中增加优化步骤。在`instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml`中修改`trainer.max_steps: 3000`为更多步骤,例如`trainer.max_steps: 10000`。更长的优化步骤会导致更好的纹理。 176 | 177 | **生成视图信息:** 178 | - **仰角和方位角度:** 与Zero123、SyncDreamer和One2345等先前作品采用对象世界系统不同,我们的视图是在输入图像的相机系统中定义的。六个视图在输入图像的相机系统中的平面上,仰角为0度。因此,我们不需要为输入图像估算仰角。六个视图的方位角度分别为0、45、90、180、-90、-45。 179 | 180 | **生成视图的焦距:** 181 | - 我们假设输入图像是由正交相机捕获的,因此生成的视图也在正交空间中。这种设计使得我们的模型能够在虚构图像上保持强大的泛化能力,但有时可能在实际捕获的图像上受到焦距镜头畸变的影响。 182 | 183 | ## 致谢 184 | We have intensively borrow codes from the following repositories. Many thanks to the authors for sharing their codes. 
185 | - [stable diffusion](https://github.com/CompVis/stable-diffusion) 186 | - [zero123](https://github.com/cvlab-columbia/zero123) 187 | - [NeuS](https://github.com/Totoro97/NeuS) 188 | - [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer) 189 | - [instant-nsr-pl](https://github.com/bennyguo/instant-nsr-pl) 190 | 191 | ## 协议 192 | Wonder3D采用[AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html)许可,因此任何包含Wonder3D代码或其中训练的模型(无论是预训练还是定制训练)的下游解决方案和产品(包括云服务)都应该开源以符合AGPL条件。如果您对Wonder3D的使用有任何疑问,请首先与我们联系。 193 | 194 | ## 引用 195 | 如果您在项目中发现这个项目对您有用,请引用以下工作。 :) 196 | ``` 197 | @article{long2023wonder3d, 198 | title={Wonder3D: Single Image to 3D using Cross-Domain Diffusion}, 199 | author={Long, Xiaoxiao and Guo, Yuan-Chen and Lin, Cheng and Liu, Yuan and Dou, Zhiyang and Liu, Lingjie and Ma, Yuexin and Zhang, Song-Hai and Habermann, Marc and Theobalt, Christian and others}, 200 | journal={arXiv preprint arXiv:2310.15008}, 201 | year={2023} 202 | } 203 | ``` 204 | -------------------------------------------------------------------------------- /assets/bug_fixed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/bug_fixed.png -------------------------------------------------------------------------------- /assets/coordinate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/coordinate.png -------------------------------------------------------------------------------- /assets/fig_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/fig_teaser.png -------------------------------------------------------------------------------- /configs/mvdiffusion-joint-ortho-6views.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: 'flamehaze1115/wonder3d-v1.0' # or './ckpts' 2 | revision: null 3 | validation_dataset: 4 | root_dir: "./example_images" # the folder path stores testing images 5 | num_views: 6 6 | bg_color: 'white' 7 | img_wh: [256, 256] 8 | num_validation_samples: 1000 9 | crop_size: 192 10 | filepaths: ['owl.png'] # the test image names. 
leave it empty, test all images in the folder 11 | 12 | save_dir: 'outputs/' 13 | 14 | pred_type: 'joint' 15 | seed: 42 16 | validation_batch_size: 1 17 | dataloader_num_workers: 64 18 | 19 | local_rank: -1 20 | 21 | pipe_kwargs: 22 | camera_embedding_type: 'e_de_da_sincos' 23 | num_views: 6 24 | 25 | validation_guidance_scales: [1.0, 3.0] 26 | pipe_validation_kwargs: 27 | eta: 1.0 28 | validation_grid_nrow: 6 29 | 30 | unet_from_pretrained_kwargs: 31 | camera_embedding_type: 'e_de_da_sincos' 32 | projection_class_embeddings_input_dim: 10 33 | num_views: 6 34 | sample_size: 32 35 | cd_attention_mid: true 36 | zero_init_conv_in: false 37 | zero_init_camera_projection: false 38 | 39 | num_views: 6 40 | camera_embedding_type: 'e_de_da_sincos' 41 | 42 | enable_xformers_memory_efficient_attention: true -------------------------------------------------------------------------------- /configs/train/stage1-mix-6views-lvis.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers' 2 | revision: null 3 | train_dataset: 4 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 5 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 6 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 7 | num_views: 6 8 | groups_num: 1 9 | bg_color: 'three_choices' 10 | img_wh: [256, 256] 11 | validation: false 12 | num_validation_samples: 32 13 | # read_normal: true 14 | # read_color: true 15 | mix_color_normal: true 16 | validation_dataset: 17 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 18 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 19 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 20 | num_views: 6 21 | groups_num: 1 22 | bg_color: 'white' 23 | img_wh: [256, 256] 24 | validation: true 25 | num_validation_samples: 32 26 | # read_normal: true 27 | # read_color: true 28 | mix_color_normal: true 29 | validation_train_dataset: 30 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 31 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 32 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 33 | num_views: 6 34 | groups_num: 1 35 | bg_color: 'white' 36 | img_wh: [256, 256] 37 | validation: false 38 | num_validation_samples: 32 39 | num_samples: 32 40 | # read_normal: true 41 | # read_color: true 42 | mix_color_normal: true 43 | 44 | pred_type: 'mix' 45 | 46 | output_dir: 'outputs/wonder3D-mix' 47 | seed: 42 48 | train_batch_size: 32 49 | validation_batch_size: 16 50 | validation_train_batch_size: 16 51 | max_train_steps: 30000 52 | gradient_accumulation_steps: 2 53 | gradient_checkpointing: true 54 | learning_rate: 1.e-4 55 | scale_lr: false 56 | lr_scheduler: "constant_with_warmup" 57 | lr_warmup_steps: 100 58 | snr_gamma: 5.0 59 | use_8bit_adam: false 60 | allow_tf32: true 61 | use_ema: true 62 | dataloader_num_workers: 64 63 | adam_beta1: 0.9 64 | adam_beta2: 0.999 65 | adam_weight_decay: 1.e-2 66 | adam_epsilon: 1.e-08 67 | max_grad_norm: 1.0 68 | prediction_type: null 69 | vis_dir: vis 70 | logging_dir: logs 71 | mixed_precision: "fp16" 72 | report_to: 'tensorboard' 73 | local_rank: -1 74 | checkpointing_steps: 5000 75 | checkpoints_total_limit: 20 76 | resume_from_checkpoint: latest 77 | enable_xformers_memory_efficient_attention: true 78 | validation_steps: 1250 79 | validation_sanity_check: true 80 | tracker_project_name: 
'mvdiffusion-image-v1' 81 | 82 | trainable_modules: null 83 | use_classifier_free_guidance: true 84 | condition_drop_rate: 0.05 85 | drop_type: 'drop_as_a_whole' # modify 86 | camera_embedding_lr_mult: 10. 87 | scale_input_latents: true 88 | 89 | pipe_kwargs: 90 | camera_embedding_type: 'e_de_da_sincos' 91 | num_views: 6 92 | 93 | validation_guidance_scales: [1., 3.] 94 | pipe_validation_kwargs: 95 | eta: 1.0 96 | validation_grid_nrow: 12 97 | 98 | unet_from_pretrained_kwargs: 99 | camera_embedding_type: 'e_de_da_sincos' 100 | projection_class_embeddings_input_dim: 10 # modify 101 | num_views: 6 102 | sample_size: 32 103 | zero_init_conv_in: true 104 | zero_init_camera_projection: false 105 | cd_attention_last: false 106 | cd_attention_mid: false 107 | multiview_attention: true 108 | sparse_mv_attention: false 109 | mvcd_attention: false 110 | 111 | num_views: 6 112 | camera_embedding_type: 'e_de_da_sincos' 113 | -------------------------------------------------------------------------------- /configs/train/stage2-joint-6views-lvis.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers' 2 | # modify the unet path; use the stage 1 checkpoint 3 | pretrained_unet_path: 'outputs/wonder3D-mix/checkpoint-30000/' 4 | # pretrained_unet_path: null 5 | revision: null 6 | train_dataset: 7 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 8 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 9 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 10 | num_views: 6 11 | groups_num: 1 12 | bg_color: 'three_choices' 13 | img_wh: [256, 256] 14 | validation: false 15 | num_validation_samples: 32 16 | read_normal: true 17 | read_color: true 18 | validation_dataset: 19 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 20 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 21 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 22 | num_views: 6 23 | groups_num: 1 24 | bg_color: 'white' 25 | img_wh: [256, 256] 26 | validation: true 27 | num_validation_samples: 32 28 | read_normal: true 29 | read_color: true 30 | validation_train_dataset: 31 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path 32 | object_list: './data_lists/lvis_uids_filter_by_vertex.json' 33 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json' 34 | num_views: 6 35 | groups_num: 1 36 | bg_color: 'three_choices' 37 | img_wh: [256, 256] 38 | validation: false 39 | num_validation_samples: 32 40 | num_samples: 32 41 | read_normal: true 42 | read_color: true 43 | 44 | # output_dir: 'outputs/debug' 45 | output_dir: 'outputs/wonder3D-joint' 46 | seed: 42 47 | train_batch_size: 32 # original paper uses 32 48 | validation_batch_size: 16 49 | validation_train_batch_size: 16 50 | max_train_steps: 20000 51 | gradient_accumulation_steps: 2 52 | gradient_checkpointing: true 53 | learning_rate: 5.e-5 54 | scale_lr: false 55 | lr_scheduler: "constant_with_warmup" 56 | lr_warmup_steps: 100 57 | snr_gamma: 5.0 58 | use_8bit_adam: false 59 | allow_tf32: true 60 | use_ema: true 61 | dataloader_num_workers: 64 62 | adam_beta1: 0.9 63 | adam_beta2: 0.999 64 | adam_weight_decay: 1.e-2 65 | adam_epsilon: 1.e-08 66 | max_grad_norm: 1.0 67 | prediction_type: null 68 | vis_dir: vis 69 | logging_dir: logs 70 | mixed_precision: "fp16" 71 | report_to: 'tensorboard' 72 | local_rank: -1 73 | checkpointing_steps: 
5000 74 | checkpoints_total_limit: null 75 | last_global_step: 5000 76 | 77 | resume_from_checkpoint: latest 78 | enable_xformers_memory_efficient_attention: true 79 | validation_steps: 1000 80 | validation_sanity_check: true 81 | tracker_project_name: 'mvdiffusion-image-v1' 82 | 83 | trainable_modules: ['joint_mid'] 84 | use_classifier_free_guidance: true 85 | condition_drop_rate: 0.05 86 | drop_type: 'drop_as_a_whole' # modify 87 | camera_embedding_lr_mult: 10. 88 | scale_input_latents: true 89 | 90 | pipe_kwargs: 91 | camera_embedding_type: 'e_de_da_sincos' 92 | num_views: 6 93 | 94 | validation_guidance_scales: [1., 3.] 95 | pipe_validation_kwargs: 96 | eta: 1.0 97 | validation_grid_nrow: 12 98 | 99 | unet_from_pretrained_kwargs: 100 | camera_embedding_type: 'e_de_da_sincos' 101 | projection_class_embeddings_input_dim: 10 # modify 102 | num_views: 6 103 | sample_size: 32 104 | zero_init_conv_in: false 105 | zero_init_camera_projection: false 106 | cd_attention_last: false 107 | cd_attention_mid: true 108 | multiview_attention: true 109 | sparse_mv_attention: false 110 | mvcd_attention: false 111 | 112 | num_views: 6 113 | camera_embedding_type: 'e_de_da_sincos' 114 | 115 | 116 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # get the development image from nvidia cuda 11.7 2 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 3 | 4 | LABEL name="Wonder3D" \ 5 | maintainer="Tiancheng " \ 6 | lastupdate="2024-01-05" 7 | 8 | # create workspace folder and set it as working directory 9 | RUN mkdir -p /workspace 10 | WORKDIR /workspace 11 | 12 | # Set the timezone 13 | ENV DEBIAN_FRONTEND=noninteractive 14 | RUN apt-get update && \ 15 | apt-get install -y tzdata && \ 16 | ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 17 | dpkg-reconfigure --frontend noninteractive tzdata 18 | 19 | # update package lists and install git, wget, vim, libgl1-mesa-glx, and libglib2.0-0 20 | RUN apt-get update && \ 21 | apt-get install -y git wget vim libgl1-mesa-glx libglib2.0-0 unzip 22 | 23 | # install conda 24 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 25 | chmod +x Miniconda3-latest-Linux-x86_64.sh && \ 26 | ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \ 27 | rm Miniconda3-latest-Linux-x86_64.sh 28 | 29 | # update PATH environment variable 30 | ENV PATH="/workspace/miniconda3/bin:${PATH}" 31 | 32 | # initialize conda 33 | RUN conda init bash 34 | 35 | # create and activate conda environment 36 | RUN conda create -n wonder3d python=3.8 && echo "source activate wonder3d" > ~/.bashrc 37 | ENV PATH /workspace/miniconda3/envs/wonder3d/bin:$PATH 38 | 39 | 40 | # clone the repository 41 | RUN git clone https://github.com/xxlong0/Wonder3D.git && \ 42 | cd /workspace/Wonder3D 43 | 44 | # change the working directory to the repository 45 | WORKDIR /workspace/Wonder3D 46 | 47 | # install pytorch 1.13.1 and torchvision 48 | RUN pip install -r docker/requirements.txt 49 | 50 | # install the specific version of nerfacc corresponding to torch 1.13.0 and cuda 11.7, otherwise the nerfacc will freeze during cuda setup 51 | RUN pip install nerfacc==0.3.3 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.0_cu117.html 52 | 53 | # install tiny cuda during docker setup will cause error, need to install it manually in the container 54 | # RUN pip install 
git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 55 | 56 | 57 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker setup 2 | 3 | This Docker setup is tested on Ubuntu 20.04. 4 | 5 | Make sure you are under the directory yourworkspace/Wonder3D/ 6 | 7 | Run 8 | 9 | `docker build --no-cache -t wonder3d/deploy:cuda11.7 -f docker/Dockerfile .` 10 | 11 | Then run 12 | 13 | `docker run --gpus all -it wonder3d/deploy:cuda11.7 bash` 14 | 15 | 16 | ## NVIDIA Container Toolkit setup 17 | 18 | You will have trouble enabling GPU support for Docker if you haven't installed the **NVIDIA Container Toolkit** on your local machine before. You can skip this section if you have already installed it. Follow the instructions on this website to install it: 19 | 20 | https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html 21 | 22 | Or you can run the following commands to install it with apt: 23 | 24 | 1. Configure the production repository: 25 | 26 | ```bash 27 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ 28 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ 29 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ 30 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 31 | ``` 32 | 33 | 2. Optionally enable experimental packages, then update the packages list from the repository: 34 | 35 | `sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list && sudo apt-get update` 36 | 37 | 3. Install the NVIDIA Container Toolkit packages: 38 | 39 | `sudo apt-get install -y nvidia-container-toolkit` 40 | 41 | Remember to restart the Docker daemon: 42 | 43 | `sudo systemctl restart docker` 44 | 45 | Now you can run the following command: 46 | 47 | `docker run --gpus all -it wonder3d/deploy:cuda11.7 bash` 48 | 49 | 50 | ## Install tiny-cuda-nn 51 | 52 | After you start the container, run the following command to install tiny-cuda-nn. For some reason this pip installation cannot be done during the docker build, so you have to do it manually after the container is started.
53 | 54 | `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch` 55 | 56 | 57 | Now you should be good to go, good luck and have fun :) 58 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | 3 | # nerfacc==0.3.3, nefacc needs to be installed from the specific location 4 | # see installation part in this link: https://github.com/nerfstudio-project/nerfacc 5 | 6 | torch==1.13.1+cu117 7 | torchvision==0.14.1+cu117 8 | diffusers[torch]==0.19.3 9 | xformers==0.0.16 10 | transformers>=4.25.1 11 | bitsandbytes==0.35.4 12 | decord==0.6.0 13 | pytorch-lightning<2 14 | omegaconf==2.2.3 15 | trimesh==3.9.8 16 | pyhocon==0.3.57 17 | icecream==2.1.0 18 | PyMCubes==0.1.2 19 | accelerate 20 | modelcards 21 | einops 22 | ftfy 23 | piq 24 | matplotlib 25 | opencv-python 26 | imageio 27 | imageio-ffmpeg 28 | scipy 29 | pyransac3d 30 | torch_efficient_distloss 31 | tensorboard 32 | rembg 33 | segment_anything 34 | gradio==3.50.2 35 | triton 36 | rich 37 | -------------------------------------------------------------------------------- /example_images/14_10_29_489_Tiger_1__1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/14_10_29_489_Tiger_1__1.png -------------------------------------------------------------------------------- /example_images/box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/box.png -------------------------------------------------------------------------------- /example_images/bread.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/bread.png -------------------------------------------------------------------------------- /example_images/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/cat.png -------------------------------------------------------------------------------- /example_images/cat_head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/cat_head.png -------------------------------------------------------------------------------- /example_images/chili.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/chili.png -------------------------------------------------------------------------------- /example_images/duola.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/duola.png -------------------------------------------------------------------------------- /example_images/halloween.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/halloween.png -------------------------------------------------------------------------------- /example_images/head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/head.png -------------------------------------------------------------------------------- /example_images/kettle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/kettle.png -------------------------------------------------------------------------------- /example_images/kunkun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/kunkun.png -------------------------------------------------------------------------------- /example_images/milk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/milk.png -------------------------------------------------------------------------------- /example_images/owl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/owl.png -------------------------------------------------------------------------------- /example_images/poro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/poro.png -------------------------------------------------------------------------------- /example_images/pumpkin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/pumpkin.png -------------------------------------------------------------------------------- /example_images/skull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/skull.png -------------------------------------------------------------------------------- /example_images/stone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/stone.png -------------------------------------------------------------------------------- /example_images/teapot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/teapot.png -------------------------------------------------------------------------------- /example_images/tiger-head-3d-model-obj-stl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/tiger-head-3d-model-obj-stl.png 
-------------------------------------------------------------------------------- /instant-nsr-pl/README.md: --------------------------------------------------------------------------------
1 | # Instant Neural Surface Reconstruction
2 | 
3 | This repository contains a concise and extensible implementation of NeRF and NeuS for neural surface reconstruction based on Instant-NGP and the PyTorch-Lightning framework. **Training on a NeRF-Synthetic scene takes ~5 min for NeRF and ~10 min for NeuS on a single RTX3090.**
4 | 
5 | ||NeRF in 5 min|NeuS in 10 min|
6 | |---|---|---|
7 | |Rendering|![rendering-nerf](https://user-images.githubusercontent.com/19284678/199078178-b719676b-7e60-47f1-813b-c0b533f5480d.png)|![rendering-neus](https://user-images.githubusercontent.com/19284678/199078300-ebcf249d-b05e-431f-b035-da354705d8db.png)|
8 | |Mesh|![mesh-nerf](https://user-images.githubusercontent.com/19284678/199078661-b5cd569a-c22b-4220-9c11-d5fd13a52fb8.png)|![mesh-neus](https://user-images.githubusercontent.com/19284678/199078481-164e36a6-6d55-45cc-aaf3-795a114e4a38.png)|
9 | 
10 | 
11 | ## Features
12 | **This repository aims to provide a highly efficient yet customizable boilerplate for research projects based on NeRF or NeuS.**
13 | 
14 | - acceleration techniques from [Instant-NGP](https://github.com/NVlabs/instant-ngp): multiresolution hash encoding and fully fused networks by [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), occupancy grid pruning and rendering by [nerfacc](https://github.com/KAIR-BAIR/nerfacc)
15 | - out-of-the-box multi-GPU and mixed precision training by [PyTorch-Lightning](https://github.com/Lightning-AI/lightning)
16 | - hierarchical project layout that is designed to be easily customized and extended, flexible experiment configuration by [OmegaConf](https://github.com/omry/omegaconf)
17 | 
18 | **Please subscribe to [#26](https://github.com/bennyguo/instant-nsr-pl/issues/26) for our latest findings on quality improvements!**
19 | 
20 | ## News
21 | 
22 | 🔥🔥🔥 Check out my new project on 3D content generation: https://github.com/threestudio-project/threestudio 🔥🔥🔥
23 | 
24 | - 06/03/2023: Added an implementation of [Neuralangelo](https://research.nvidia.com/labs/dir/neuralangelo/). See [here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU) for details.
25 | - 03/31/2023: The NeuS model now supports background modeling. You can try it on the DTU dataset provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa) following [the instructions here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU).
26 | - 02/11/2023: The NeRF model now supports unbounded 360 scenes with a learned background. You can try it on [MipNeRF 360 data](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) following [the COLMAP configuration](https://github.com/bennyguo/instant-nsr-pl#training-on-custom-colmap-data).
27 | 
28 | ## Requirements
29 | **Note:**
30 | - To use the multiresolution hash encoding or fully fused networks provided by tiny-cuda-nn, you should have at least an RTX 2080Ti; see [https://github.com/NVlabs/tiny-cuda-nn#requirements](https://github.com/NVlabs/tiny-cuda-nn#requirements) for more details.
31 | - Multi-GPU training is currently not supported on Windows (see [#4](https://github.com/bennyguo/instant-nsr-pl/issues/4)).
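If you are unsure whether your GPU meets the tiny-cuda-nn requirement above, the short check below (a convenience sketch added for this write-up, not part of the original setup) prints the compute capability that PyTorch sees; the fully fused MLPs need roughly an RTX 20-series card or newer (compute capability 7.5+).

```python
import torch

if torch.cuda.is_available():
    # tiny-cuda-nn's fully fused MLPs require a Turing-class GPU or newer,
    # i.e. CUDA compute capability >= 7.5 (e.g. RTX 2080Ti and above)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"{torch.cuda.get_device_name(0)}: compute capability {major}.{minor}")
else:
    print("No CUDA device visible - tiny-cuda-nn will not work here.")
```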
32 | ### Environments
33 | - Install PyTorch>=1.10 [here](https://pytorch.org/get-started/locally/) based on the package management tool you use and your CUDA version (older PyTorch versions may work but have not been tested)
34 | - Install the tiny-cuda-nn PyTorch extension: `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`
35 | - `pip install -r requirements.txt`
36 | 
37 | 
38 | ## Run
39 | ### Training on NeRF-Synthetic
40 | Download the NeRF-Synthetic data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `load/`. The file structure should look like `load/nerf_synthetic/lego`.
41 | 
42 | Run the launch script with `--train`, specifying the config file, the GPU(s) to be used (GPU 0 is used by default), and the scene name:
43 | ```bash
44 | # train NeRF
45 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=example
46 | 
47 | # train NeuS with mask
48 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example
49 | # train NeuS without mask
50 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example system.loss.lambda_mask=0.0
51 | ```
52 | Code snapshots, checkpoints, and experiment outputs are saved to `exp/[name]/[tag]@[timestamp]`, and TensorBoard logs can be found at `runs/[name]/[tag]@[timestamp]`. You can change any configuration in the YAML file by specifying arguments without `--`, for example:
53 | ```bash
54 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=iter50k seed=0 trainer.max_steps=50000
55 | ```
56 | ### Training on DTU
57 | Download the preprocessed DTU data provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa). The provided config files assume the NeuS DTU data. If you are using the IDR DTU data, please set `dataset.cameras_file=cameras.npz`. You may also need to adjust `dataset.root_dir` to point to your downloaded data location.
58 | ```bash
59 | # train NeuS on DTU without mask
60 | python launch.py --config configs/neus-dtu.yaml --gpu 0 --train
61 | # train NeuS on DTU with mask
62 | python launch.py --config configs/neus-dtu-wmask.yaml --gpu 0 --train
63 | # train NeuS on DTU with mask using tricks from Neuralangelo (experimental)
64 | python launch.py --config configs/neuralangelo-dtu-wmask.yaml --gpu 0 --train
65 | ```
66 | Notes:
67 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing.
68 | - The results of Neuralangelo cannot match those reported in the original paper. Some potential improvements: more iterations; larger `system.geometry.xyz_encoding_config.update_steps`; larger `system.geometry.xyz_encoding_config.n_features_per_level`; larger `system.geometry.xyz_encoding_config.log2_hashmap_size`; adopting the curvature loss.
69 | 
70 | ### Training on Custom COLMAP Data
71 | To get COLMAP data from custom images, you should have COLMAP installed (see [here](https://colmap.github.io/install.html) for installation instructions). Then put your images in the `images/` folder and run `scripts/imgs2poses.py`, specifying the path containing the `images/` folder. For example:
72 | ```bash
73 | python scripts/imgs2poses.py ./load/bmvs_dog # images are in ./load/bmvs_dog/images
74 | ```
75 | Existing data following this file structure also works as long as the images are stored in `images/` and there is a `sparse/` folder for the COLMAP output, for example [the data provided by MipNeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip). An optional `masks/` folder can be provided for object mask supervision. To train on COLMAP data, please refer to the example config files `config/*-colmap.yaml`. Some notes:
76 | - Adapt the `root_dir` and `img_wh` (or `img_downscale`) options in the config file to your data;
77 | - The scene is normalized so that cameras have a minimum distance of `1.0` from the center of the scene. Setting `model.radius=1.0` works in most cases. If not, try a smaller radius that wraps tightly around your foreground object.
78 | - There are three choices for determining the scene center: `dataset.center_est_method=camera` uses the center of all camera positions as the scene center; `dataset.center_est_method=lookat` assumes the cameras are looking at the same point and calculates an approximate look-at point as the scene center; `dataset.center_est_method=point` uses the center of all points (reconstructed by COLMAP) that are bounded by cameras as the scene center. Please choose an appropriate method according to your capture.
79 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing.
80 | 
81 | ### Testing
82 | The training procedure is by default followed by testing, which computes metrics on test data, generates animations, and exports the geometry as triangular meshes. If you want to run testing alone, just resume the pretrained model and replace `--train` with `--test`, for example:
83 | ```bash
84 | python launch.py --config path/to/your/exp/config/parsed.yaml --resume path/to/your/exp/ckpt/epoch=0-step=20000.ckpt --gpu 0 --test
85 | ```
86 | 
87 | 
88 | ## Benchmarks
89 | All experiments are conducted on a single NVIDIA RTX3090.
90 | 
91 | |PSNR|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.|
92 | |---|---|---|---|---|---|---|---|---|---|
93 | |NeRF Paper|33.00|25.01|30.13|36.18|32.54|29.62|32.91|28.65|31.01|
94 | |NeRF Ours (20k)|34.80|26.04|33.89|37.42|35.33|29.46|35.22|31.17|32.92|
95 | |NeuS Ours (20k, with masks)|34.04|25.26|32.47|35.94|33.78|27.67|33.43|29.50|31.51|
96 | 
97 | |Training Time (mm:ss)|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.|
98 | |---|---|---|---|---|---|---|---|---|---|
99 | |NeRF Ours (20k)|04:34|04:35|04:18|04:46|04:39|04:35|04:26|05:41|04:42|
100 | |NeuS Ours (20k, with masks)|11:25|10:34|09:51|12:11|11:37|11:46|09:59|16:25|11:44|
101 | 
102 | 
103 | ## TODO
104 | - [✅] Support more dataset formats, like COLMAP outputs and DTU
105 | - [✅] Support simple background model
106 | - [ ] Support GUI training and interaction
107 | - [ ] More illustrations about the framework
108 | 
109 | ## Related Projects
110 | - [ngp_pl](https://github.com/kwea123/ngp_pl): Great Instant-NGP implementation in PyTorch-Lightning! Background model and GUI supported.
111 | - [Instant-NSR](https://github.com/zhaofuq/Instant-NSR): NeuS implementation using multiresolution hash encoding.
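The command-line overrides used throughout the Run section above (tokens such as `dataset.scene=lego` or `trainer.max_steps=50000`, passed without `--`) are OmegaConf dot-list entries merged on top of the YAML config. The snippet below is a minimal, self-contained sketch of that mechanism, not the repository's actual implementation (the `load_config` helper in `utils/misc.py`, which `launch.py` calls with the leftover CLI tokens, wraps something similar).

```python
from omegaconf import OmegaConf

# stand-in for a YAML config loaded from disk (OmegaConf.load in real usage)
base = OmegaConf.create({"dataset": {"scene": "chair"}, "trainer": {"max_steps": 20000}})

# leftover command-line tokens, e.g. `dataset.scene=lego trainer.max_steps=50000`
cli = OmegaConf.from_dotlist(["dataset.scene=lego", "trainer.max_steps=50000"])

config = OmegaConf.merge(base, cli)  # command-line values take precedence
print(config.dataset.scene, config.trainer.max_steps)  # -> lego 50000
```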
112 | 113 | ## Citation 114 | If you find this codebase useful, please consider citing: 115 | ``` 116 | @misc{instant-nsr-pl, 117 | Author = {Yuan-Chen Guo}, 118 | Year = {2022}, 119 | Note = {https://github.com/bennyguo/instant-nsr-pl}, 120 | Title = {Instant Neural Surface Reconstruction} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml: -------------------------------------------------------------------------------- 1 | name: ${basename:${dataset.scene}} 2 | tag: "" 3 | seed: 42 4 | 5 | dataset: 6 | name: ortho 7 | root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0 8 | cam_pose_dir: null 9 | scene: scene_name 10 | imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors 11 | camera_type: ortho 12 | apply_mask: true 13 | camera_params: null 14 | view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] #['front', 'front_right', 'right', 'back', 'left', 'front_left'] 15 | # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 16 | 17 | model: 18 | name: neus 19 | radius: 1.0 20 | num_samples_per_ray: 1024 21 | train_num_rays: 256 22 | max_train_num_rays: 8192 23 | grid_prune: true 24 | grid_prune_occ_thre: 0.001 25 | dynamic_ray_sampling: true 26 | batch_image_sampling: true 27 | randomized: true 28 | ray_chunk: 2048 29 | cos_anneal_end: 20000 30 | learned_background: false 31 | background_color: black 32 | variance: 33 | init_val: 0.3 34 | modulate: false 35 | geometry: 36 | name: volume-sdf 37 | radius: ${model.radius} 38 | feature_dim: 13 39 | grad_type: finite_difference 40 | finite_difference_eps: progressive 41 | isosurface: 42 | method: mc 43 | resolution: 192 44 | chunk: 2097152 45 | threshold: 0. 46 | xyz_encoding_config: 47 | otype: ProgressiveBandHashGrid 48 | n_levels: 10 # 12 modify 49 | n_features_per_level: 2 50 | log2_hashmap_size: 19 51 | base_resolution: 32 52 | per_level_scale: 1.3195079107728942 53 | include_xyz: true 54 | start_level: 4 55 | start_step: 0 56 | update_steps: 1000 57 | mlp_network_config: 58 | otype: VanillaMLP 59 | activation: ReLU 60 | output_activation: none 61 | n_neurons: 64 62 | n_hidden_layers: 1 63 | sphere_init: true 64 | sphere_init_radius: 0.5 65 | weight_norm: true 66 | texture: 67 | name: volume-radiance 68 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input 69 | dir_encoding_config: 70 | otype: SphericalHarmonics 71 | degree: 4 72 | mlp_network_config: 73 | otype: VanillaMLP 74 | activation: ReLU 75 | output_activation: none 76 | n_neurons: 64 77 | n_hidden_layers: 2 78 | color_activation: sigmoid 79 | 80 | system: 81 | name: ortho-neus-system 82 | loss: 83 | lambda_rgb_mse: 0.5 84 | lambda_rgb_l1: 0. 85 | lambda_mask: 1.0 86 | lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects 87 | lambda_normal: 1.0 # cannot be too large 88 | lambda_3d_normal_smooth: 1.0 89 | # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup 90 | lambda_curvature: 0. 
91 | lambda_sparsity: 0.5 92 | lambda_distortion: 0.0 93 | lambda_distortion_bg: 0.0 94 | lambda_opaque: 0.0 95 | sparsity_scale: 100.0 96 | geo_aware: true 97 | rgb_p_ratio: 0.8 98 | normal_p_ratio: 0.8 99 | mask_p_ratio: 0.9 100 | optimizer: 101 | name: AdamW 102 | args: 103 | lr: 0.01 104 | betas: [0.9, 0.99] 105 | eps: 1.e-15 106 | params: 107 | geometry: 108 | lr: 0.001 109 | texture: 110 | lr: 0.01 111 | variance: 112 | lr: 0.001 113 | constant_steps: 500 114 | scheduler: 115 | name: SequentialLR 116 | interval: step 117 | milestones: 118 | - ${system.constant_steps} 119 | schedulers: 120 | - name: ConstantLR 121 | args: 122 | factor: 1.0 123 | total_iters: ${system.constant_steps} 124 | - name: ExponentialLR 125 | args: 126 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} 127 | 128 | checkpoint: 129 | save_top_k: -1 130 | every_n_train_steps: ${trainer.max_steps} 131 | 132 | export: 133 | chunk_size: 2097152 134 | export_vertex_color: True 135 | ortho_scale: 1.35 #modify 136 | 137 | trainer: 138 | max_steps: 3000 139 | log_every_n_steps: 100 140 | num_sanity_val_steps: 0 141 | val_check_interval: 4000 142 | limit_train_batches: 1.0 143 | limit_val_batches: 2 144 | enable_progress_bar: true 145 | precision: 16 146 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | datasets = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | datasets[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config): 12 | dataset = datasets[name](config) 13 | return dataset 14 | 15 | 16 | from . import blender, colmap, dtu, ortho 17 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader, IterableDataset 9 | import torchvision.transforms.functional as TF 10 | 11 | import pytorch_lightning as pl 12 | 13 | import datasets 14 | from models.ray_utils import get_ray_directions 15 | from utils.misc import get_rank 16 | 17 | 18 | class BlenderDatasetBase(): 19 | def setup(self, config, split): 20 | self.config = config 21 | self.split = split 22 | self.rank = get_rank() 23 | 24 | self.has_mask = True 25 | self.apply_mask = True 26 | 27 | with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f: 28 | meta = json.load(f) 29 | 30 | if 'w' in meta and 'h' in meta: 31 | W, H = int(meta['w']), int(meta['h']) 32 | else: 33 | W, H = 800, 800 34 | 35 | if 'img_wh' in self.config: 36 | w, h = self.config.img_wh 37 | assert round(W / w * h) == H 38 | elif 'img_downscale' in self.config: 39 | w, h = W // self.config.img_downscale, H // self.config.img_downscale 40 | else: 41 | raise KeyError("Either img_wh or img_downscale should be specified.") 42 | 43 | self.w, self.h = w, h 44 | self.img_wh = (self.w, self.h) 45 | 46 | self.near, self.far = self.config.near_plane, self.config.far_plane 47 | 48 | self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) # scaled focal length 49 | 50 | # ray directions for all pixels, same for all images (same H, W, focal) 51 | self.directions = \ 52 | get_ray_directions(self.w, self.h, self.focal, self.focal, 
self.w//2, self.h//2).to(self.rank) # (h, w, 3) 53 | 54 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] 55 | 56 | for i, frame in enumerate(meta['frames']): 57 | c2w = torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4]) 58 | self.all_c2w.append(c2w) 59 | 60 | img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png") 61 | img = Image.open(img_path) 62 | img = img.resize(self.img_wh, Image.BICUBIC) 63 | img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) 64 | 65 | self.all_fg_masks.append(img[..., -1]) # (h, w) 66 | self.all_images.append(img[...,:3]) 67 | 68 | self.all_c2w, self.all_images, self.all_fg_masks = \ 69 | torch.stack(self.all_c2w, dim=0).float().to(self.rank), \ 70 | torch.stack(self.all_images, dim=0).float().to(self.rank), \ 71 | torch.stack(self.all_fg_masks, dim=0).float().to(self.rank) 72 | 73 | 74 | class BlenderDataset(Dataset, BlenderDatasetBase): 75 | def __init__(self, config, split): 76 | self.setup(config, split) 77 | 78 | def __len__(self): 79 | return len(self.all_images) 80 | 81 | def __getitem__(self, index): 82 | return { 83 | 'index': index 84 | } 85 | 86 | 87 | class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): 88 | def __init__(self, config, split): 89 | self.setup(config, split) 90 | 91 | def __iter__(self): 92 | while True: 93 | yield {} 94 | 95 | 96 | @datasets.register('blender') 97 | class BlenderDataModule(pl.LightningDataModule): 98 | def __init__(self, config): 99 | super().__init__() 100 | self.config = config 101 | 102 | def setup(self, stage=None): 103 | if stage in [None, 'fit']: 104 | self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split) 105 | if stage in [None, 'fit', 'validate']: 106 | self.val_dataset = BlenderDataset(self.config, self.config.val_split) 107 | if stage in [None, 'test']: 108 | self.test_dataset = BlenderDataset(self.config, self.config.test_split) 109 | if stage in [None, 'predict']: 110 | self.predict_dataset = BlenderDataset(self.config, self.config.train_split) 111 | 112 | def prepare_data(self): 113 | pass 114 | 115 | def general_loader(self, dataset, batch_size): 116 | sampler = None 117 | return DataLoader( 118 | dataset, 119 | num_workers=os.cpu_count(), 120 | batch_size=batch_size, 121 | pin_memory=True, 122 | sampler=sampler 123 | ) 124 | 125 | def train_dataloader(self): 126 | return self.general_loader(self.train_dataset, batch_size=1) 127 | 128 | def val_dataloader(self): 129 | return self.general_loader(self.val_dataset, batch_size=1) 130 | 131 | def test_dataloader(self): 132 | return self.general_loader(self.test_dataset, batch_size=1) 133 | 134 | def predict_dataloader(self): 135 | return self.general_loader(self.predict_dataset, batch_size=1) 136 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/dtu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader, IterableDataset 11 | import torchvision.transforms.functional as TF 12 | 13 | import pytorch_lightning as pl 14 | 15 | import datasets 16 | from models.ray_utils import get_ray_directions 17 | from utils.misc import get_rank 18 | 19 | 20 | def load_K_Rt_from_P(P=None): 21 | out = cv2.decomposeProjectionMatrix(P) 22 | K = out[0] 23 | R = out[1] 24 | t = out[2] 
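    # cv2.decomposeProjectionMatrix returns the 3x3 intrinsics K, the 3x3
    # world-to-camera rotation R, and the camera centre t in homogeneous
    # coordinates (4x1); K is normalised below and R/t are reassembled into
    # a camera-to-world pose.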
25 | 26 | K = K / K[2, 2] 27 | intrinsics = np.eye(4) 28 | intrinsics[:3, :3] = K 29 | 30 | pose = np.eye(4, dtype=np.float32) 31 | pose[:3, :3] = R.transpose() 32 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 33 | 34 | return intrinsics, pose 35 | 36 | def create_spheric_poses(cameras, n_steps=120): 37 | center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) 38 | cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2) 39 | eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors 40 | rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1) 41 | up = rot_axis 42 | rot_dir = torch.cross(rot_axis, cam_center) 43 | max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max() 44 | 45 | all_c2w = [] 46 | for theta in torch.linspace(-max_angle, max_angle, n_steps): 47 | cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta) 48 | l = F.normalize(center - cam_pos, p=2, dim=0) 49 | s = F.normalize(l.cross(up), p=2, dim=0) 50 | u = F.normalize(s.cross(l), p=2, dim=0) 51 | c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) 52 | all_c2w.append(c2w) 53 | 54 | all_c2w = torch.stack(all_c2w, dim=0) 55 | 56 | return all_c2w 57 | 58 | class DTUDatasetBase(): 59 | def setup(self, config, split): 60 | self.config = config 61 | self.split = split 62 | self.rank = get_rank() 63 | 64 | cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file)) 65 | 66 | img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png')) 67 | H, W = img_sample.shape[0], img_sample.shape[1] 68 | 69 | if 'img_wh' in self.config: 70 | w, h = self.config.img_wh 71 | assert round(W / w * h) == H 72 | elif 'img_downscale' in self.config: 73 | w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) 74 | else: 75 | raise KeyError("Either img_wh or img_downscale should be specified.") 76 | 77 | self.w, self.h = w, h 78 | self.img_wh = (w, h) 79 | self.factor = w / W 80 | 81 | mask_dir = os.path.join(self.config.root_dir, 'mask') 82 | self.has_mask = True 83 | self.apply_mask = self.config.apply_mask 84 | 85 | self.directions = [] 86 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] 87 | 88 | n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1 89 | 90 | for i in range(n_images): 91 | world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}'] 92 | P = (world_mat @ scale_mat)[:3,:4] 93 | K, c2w = load_K_Rt_from_P(P) 94 | fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor 95 | directions = get_ray_directions(w, h, fx, fy, cx, cy) 96 | self.directions.append(directions) 97 | 98 | c2w = torch.from_numpy(c2w).float() 99 | 100 | # blender follows opengl camera coordinates (right up back) 101 | # NeuS DTU data coordinate system (right down front) is different from blender 102 | # https://github.com/Totoro97/NeuS/issues/9 103 | # for c2w, flip the sign of input camera coordinate yz 104 | c2w_ = c2w.clone() 105 | c2w_[:3,1:3] *= -1. 
# flip input sign 106 | self.all_c2w.append(c2w_[:3,:4]) 107 | 108 | if self.split in ['train', 'val']: 109 | img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png') 110 | img = Image.open(img_path) 111 | img = img.resize(self.img_wh, Image.BICUBIC) 112 | img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] 113 | 114 | mask_path = os.path.join(mask_dir, f'{i:03d}.png') 115 | mask = Image.open(mask_path).convert('L') # (H, W, 1) 116 | mask = mask.resize(self.img_wh, Image.BICUBIC) 117 | mask = TF.to_tensor(mask)[0] 118 | 119 | self.all_fg_masks.append(mask) # (h, w) 120 | self.all_images.append(img) 121 | 122 | self.all_c2w = torch.stack(self.all_c2w, dim=0) 123 | 124 | if self.split == 'test': 125 | self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) 126 | self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) 127 | self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) 128 | self.directions = self.directions[0] 129 | else: 130 | self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0) 131 | self.directions = torch.stack(self.directions, dim=0) 132 | 133 | self.directions = self.directions.float().to(self.rank) 134 | self.all_c2w, self.all_images, self.all_fg_masks = \ 135 | self.all_c2w.float().to(self.rank), \ 136 | self.all_images.float().to(self.rank), \ 137 | self.all_fg_masks.float().to(self.rank) 138 | 139 | 140 | class DTUDataset(Dataset, DTUDatasetBase): 141 | def __init__(self, config, split): 142 | self.setup(config, split) 143 | 144 | def __len__(self): 145 | return len(self.all_images) 146 | 147 | def __getitem__(self, index): 148 | return { 149 | 'index': index 150 | } 151 | 152 | 153 | class DTUIterableDataset(IterableDataset, DTUDatasetBase): 154 | def __init__(self, config, split): 155 | self.setup(config, split) 156 | 157 | def __iter__(self): 158 | while True: 159 | yield {} 160 | 161 | 162 | @datasets.register('dtu') 163 | class DTUDataModule(pl.LightningDataModule): 164 | def __init__(self, config): 165 | super().__init__() 166 | self.config = config 167 | 168 | def setup(self, stage=None): 169 | if stage in [None, 'fit']: 170 | self.train_dataset = DTUIterableDataset(self.config, 'train') 171 | if stage in [None, 'fit', 'validate']: 172 | self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train')) 173 | if stage in [None, 'test']: 174 | self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test')) 175 | if stage in [None, 'predict']: 176 | self.predict_dataset = DTUDataset(self.config, 'train') 177 | 178 | def prepare_data(self): 179 | pass 180 | 181 | def general_loader(self, dataset, batch_size): 182 | sampler = None 183 | return DataLoader( 184 | dataset, 185 | num_workers=os.cpu_count(), 186 | batch_size=batch_size, 187 | pin_memory=True, 188 | sampler=sampler 189 | ) 190 | 191 | def train_dataloader(self): 192 | return self.general_loader(self.train_dataset, batch_size=1) 193 | 194 | def val_dataloader(self): 195 | return self.general_loader(self.val_dataset, batch_size=1) 196 | 197 | def test_dataloader(self): 198 | return self.general_loader(self.test_dataset, batch_size=1) 199 | 200 | def predict_dataloader(self): 201 | return self.general_loader(self.predict_dataset, batch_size=1) 202 | -------------------------------------------------------------------------------- 
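A note on the pose handling in `dtu.py` above: the in-place flip `c2w_[:3, 1:3] *= -1.` negates the y and z camera axes, converting the OpenCV-style (right/down/front) camera-to-world matrices of the NeuS DTU data into the OpenGL/Blender-style (right/up/back) convention that the rest of the code expects; the operation is its own inverse. A standalone sketch of the same transform (the function name is ours, for illustration only):

```python
import numpy as np

def flip_camera_yz(c2w: np.ndarray) -> np.ndarray:
    """Negate the y/z camera axes of a 3x4 or 4x4 camera-to-world matrix,
    switching between the OpenGL (right/up/back) and OpenCV (right/down/front)
    camera conventions. Equivalent to `c2w_[:3, 1:3] *= -1.` in dtu.py."""
    out = np.array(c2w, dtype=np.float64, copy=True)
    out[:3, 1:3] *= -1.0
    return out
```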
/instant-nsr-pl/datasets/fixed_poses/000_back_RT.txt: -------------------------------------------------------------------------------- 1 | -1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 2 | 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07 3 | 0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_back_left_RT.txt: -------------------------------------------------------------------------------- 1 | -7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 3 | -7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_back_right_RT.txt: -------------------------------------------------------------------------------- 1 | -7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 3 | 7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_front_RT.txt: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 2 | 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07 3 | 0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_front_left_RT.txt: -------------------------------------------------------------------------------- 1 | 7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08 3 | -7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_front_right_RT.txt: -------------------------------------------------------------------------------- 1 | 7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08 3 | 7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_left_RT.txt: -------------------------------------------------------------------------------- 1 | -2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16 2 | 0.000000000000000000e+00 -2.220446049250313081e-16 
1.000000000000000000e+00 0.000000000000000000e+00 3 | -1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_right_RT.txt: -------------------------------------------------------------------------------- 1 | -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16 2 | 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 3 | 1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/fixed_poses/000_top_RT.txt: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 2 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 3 | 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00 4 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/ortho.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader, IterableDataset 11 | import torchvision.transforms.functional as TF 12 | 13 | import pytorch_lightning as pl 14 | 15 | import datasets 16 | from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions 17 | from utils.misc import get_rank 18 | 19 | from glob import glob 20 | import PIL.Image 21 | 22 | 23 | def camNormal2worldNormal(rot_c2w, camNormal): 24 | H,W,_ = camNormal.shape 25 | normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 26 | 27 | return normal_img 28 | 29 | def worldNormal2camNormal(rot_w2c, worldNormal): 30 | H,W,_ = worldNormal.shape 31 | normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 32 | 33 | return normal_img 34 | 35 | def trans_normal(normal, RT_w2c, RT_w2c_target): 36 | 37 | normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) 38 | normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) 39 | 40 | return normal_target_cam 41 | 42 | def img2normal(img): 43 | return (img/255.)*2-1 44 | 45 | def normal2img(normal): 46 | return np.uint8((normal*0.5+0.5)*255) 47 | 48 | def norm_normalize(normal, dim=-1): 49 | 50 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) 51 | 52 | return normal 53 | 54 | def RT_opengl2opencv(RT): 55 | # Build the coordinate transform matrix from world to computer vision camera 56 | # R_world2cv = R_bcam2cv@R_world2bcam 57 | # T_world2cv = R_bcam2cv@T_world2bcam 58 | 59 | R = RT[:3, :3] 60 | t = RT[:3, 3] 61 | 62 | R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) 63 | 64 | R_world2cv = R_bcam2cv @ R 65 | t_world2cv = R_bcam2cv @ t 66 | 67 | RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) 68 | 69 | return RT 70 | 71 | def normal_opengl2opencv(normal): 72 | H,W,C = np.shape(normal) 73 | # normal_img = 
np.reshape(normal, (H*W,C)) 74 | R_bcam2cv = np.array([1, -1, -1], np.float32) 75 | normal_cv = normal * R_bcam2cv[None, None, :] 76 | 77 | print(np.shape(normal_cv)) 78 | 79 | return normal_cv 80 | 81 | def inv_RT(RT): 82 | RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0) 83 | RT_inv = np.linalg.inv(RT_h) 84 | 85 | return RT_inv[:3, :] 86 | 87 | 88 | def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None, 89 | normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None): 90 | 91 | all_images = [] 92 | all_normals = [] 93 | all_normals_world = [] 94 | all_masks = [] 95 | all_color_masks = [] 96 | all_poses = [] 97 | all_w2cs = [] 98 | directions = [] 99 | ray_origins = [] 100 | 101 | RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix 102 | RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv 103 | for idx, view in enumerate(view_types): 104 | print(os.path.join(root_dir,test_object)) 105 | normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view)) 106 | # Load key frame 107 | if load_color: # use bgr 108 | image =np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3] 109 | 110 | normal = np.array(PIL.Image.open(normal_filepath).resize(imSize)) 111 | mask = normal[:, :, 3] 112 | normal = normal[:, :, :3] 113 | 114 | color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3] 115 | invalid_color_mask = color_mask < 255*0.5 116 | threshold = np.ones_like(image[:, :, 0]) * 250 117 | invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold) 118 | invalid_color_mask_final = invalid_color_mask & invalid_white_mask 119 | color_mask = (1 - invalid_color_mask_final) > 0 120 | 121 | # if erode_mask: 122 | # kernel = np.ones((3, 3), np.uint8) 123 | # mask = cv2.erode(mask, kernel, iterations=1) 124 | 125 | RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix 126 | 127 | normal = img2normal(normal) 128 | 129 | normal[mask==0] = [0,0,0] 130 | mask = mask> (0.5*255) 131 | if load_color: 132 | all_images.append(image) 133 | 134 | all_masks.append(mask) 135 | all_color_masks.append(color_mask) 136 | RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv 137 | all_poses.append(inv_RT(RT_cv)) # cam2world 138 | all_w2cs.append(RT_cv) 139 | 140 | # whether to 141 | normal_cam_cv = normal_opengl2opencv(normal) 142 | 143 | if normal_system == 'front': 144 | print("the loaded normals are defined in the system of front view") 145 | normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv) 146 | elif normal_system == 'self': 147 | print("the loaded normals are in their independent camera systems") 148 | normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv) 149 | all_normals.append(normal_cam_cv) 150 | all_normals_world.append(normal_world) 151 | 152 | if camera_type == 'ortho': 153 | origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1]) 154 | elif camera_type == 'pinhole': 155 | dirs = get_ray_directions(W=imSize[0], H=imSize[1], 156 | fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3]) 157 | origins = dirs # occupy a position 158 | else: 159 | raise Exception("not support camera type") 160 | ray_origins.append(origins) 161 | 
directions.append(dirs) 162 | 163 | 164 | if not load_color: 165 | all_images = [normal2img(x) for x in all_normals_world] 166 | 167 | 168 | return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \ 169 | np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks) 170 | 171 | 172 | class OrthoDatasetBase(): 173 | def setup(self, config, split): 174 | self.config = config 175 | self.split = split 176 | self.rank = get_rank() 177 | 178 | self.data_dir = self.config.root_dir 179 | self.object_name = self.config.scene 180 | self.scene = self.config.scene 181 | self.imSize = self.config.imSize 182 | self.load_color = True 183 | self.img_wh = [self.imSize[0], self.imSize[1]] 184 | self.w = self.img_wh[0] 185 | self.h = self.img_wh[1] 186 | self.camera_type = self.config.camera_type 187 | self.camera_params = self.config.camera_params # [fx, fy, cx, cy] 188 | 189 | self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] 190 | 191 | self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1) 192 | self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w) 193 | 194 | if self.config.cam_pose_dir is None: 195 | self.cam_pose_dir = "./datasets/fixed_poses" 196 | else: 197 | self.cam_pose_dir = self.config.cam_pose_dir 198 | 199 | self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \ 200 | self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction( 201 | self.data_dir, self.object_name, self.imSize, self.view_types, 202 | self.load_color, self.cam_pose_dir, normal_system='front', 203 | camera_type=self.camera_type, cam_params=self.camera_params) 204 | 205 | self.has_mask = True 206 | self.apply_mask = self.config.apply_mask 207 | 208 | self.all_c2w = torch.from_numpy(self.pose_all_np) 209 | self.all_images = torch.from_numpy(self.images_np) / 255. 
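        # images come back from load_a_prediction as uint8 in [0, 255] and are scaled to [0, 1] here;
        # the remaining buffers (masks, world-space normals, per-view ray origins/directions)
        # are converted to float tensors and moved to the training device below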
210 | self.all_fg_masks = torch.from_numpy(self.masks_np) 211 | self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np) 212 | self.all_normals_world = torch.from_numpy(self.normals_world_np) 213 | self.origins = torch.from_numpy(self.origins_np) 214 | self.directions = torch.from_numpy(self.directions_np) 215 | 216 | self.directions = self.directions.float().to(self.rank) 217 | self.origins = self.origins.float().to(self.rank) 218 | self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank) 219 | self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \ 220 | self.all_c2w.float().to(self.rank), \ 221 | self.all_images.float().to(self.rank), \ 222 | self.all_fg_masks.float().to(self.rank), \ 223 | self.all_normals_world.float().to(self.rank) 224 | 225 | 226 | class OrthoDataset(Dataset, OrthoDatasetBase): 227 | def __init__(self, config, split): 228 | self.setup(config, split) 229 | 230 | def __len__(self): 231 | return len(self.all_images) 232 | 233 | def __getitem__(self, index): 234 | return { 235 | 'index': index 236 | } 237 | 238 | 239 | class OrthoIterableDataset(IterableDataset, OrthoDatasetBase): 240 | def __init__(self, config, split): 241 | self.setup(config, split) 242 | 243 | def __iter__(self): 244 | while True: 245 | yield {} 246 | 247 | 248 | @datasets.register('ortho') 249 | class OrthoDataModule(pl.LightningDataModule): 250 | def __init__(self, config): 251 | super().__init__() 252 | self.config = config 253 | 254 | def setup(self, stage=None): 255 | if stage in [None, 'fit']: 256 | self.train_dataset = OrthoIterableDataset(self.config, 'train') 257 | if stage in [None, 'fit', 'validate']: 258 | self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train')) 259 | if stage in [None, 'test']: 260 | self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test')) 261 | if stage in [None, 'predict']: 262 | self.predict_dataset = OrthoDataset(self.config, 'train') 263 | 264 | def prepare_data(self): 265 | pass 266 | 267 | def general_loader(self, dataset, batch_size): 268 | sampler = None 269 | return DataLoader( 270 | dataset, 271 | num_workers=os.cpu_count(), 272 | batch_size=batch_size, 273 | pin_memory=True, 274 | sampler=sampler 275 | ) 276 | 277 | def train_dataloader(self): 278 | return self.general_loader(self.train_dataset, batch_size=1) 279 | 280 | def val_dataloader(self): 281 | return self.general_loader(self.val_dataset, batch_size=1) 282 | 283 | def test_dataloader(self): 284 | return self.general_loader(self.test_dataset, batch_size=1) 285 | 286 | def predict_dataloader(self): 287 | return self.general_loader(self.predict_dataset, batch_size=1) 288 | -------------------------------------------------------------------------------- /instant-nsr-pl/datasets/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/instant-nsr-pl/datasets/utils.py -------------------------------------------------------------------------------- /instant-nsr-pl/launch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import time 5 | import logging 6 | from datetime import datetime 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--config', required=True, help='path to config file') 12 | parser.add_argument('--gpu', default='0', help='GPU(s) to be used') 13 | 
parser.add_argument('--resume', default=None, help='path to the weights to be resumed') 14 | parser.add_argument( 15 | '--resume_weights_only', 16 | action='store_true', 17 | help='specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only' 18 | ) 19 | 20 | group = parser.add_mutually_exclusive_group(required=True) 21 | group.add_argument('--train', action='store_true') 22 | group.add_argument('--validate', action='store_true') 23 | group.add_argument('--test', action='store_true') 24 | group.add_argument('--predict', action='store_true') 25 | # group.add_argument('--export', action='store_true') # TODO: a separate export action 26 | 27 | parser.add_argument('--exp_dir', default='./exp') 28 | parser.add_argument('--runs_dir', default='./runs') 29 | parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG') 30 | 31 | args, extras = parser.parse_known_args() 32 | 33 | # set CUDA_VISIBLE_DEVICES then import pytorch-lightning 34 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 36 | n_gpus = len(args.gpu.split(',')) 37 | 38 | import datasets 39 | import systems 40 | import pytorch_lightning as pl 41 | from pytorch_lightning import Trainer 42 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 43 | from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger 44 | from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar 45 | from utils.misc import load_config 46 | 47 | # parse YAML config to OmegaConf 48 | config = load_config(args.config, cli_args=extras) 49 | config.cmd_args = vars(args) 50 | 51 | config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S')) 52 | config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name) 53 | config.save_dir = config.get('save_dir') or os.path.join(config.exp_dir, config.trial_name, 'save') 54 | config.ckpt_dir = config.get('ckpt_dir') or os.path.join(config.exp_dir, config.trial_name, 'ckpt') 55 | config.code_dir = config.get('code_dir') or os.path.join(config.exp_dir, config.trial_name, 'code') 56 | config.config_dir = config.get('config_dir') or os.path.join(config.exp_dir, config.trial_name, 'config') 57 | 58 | logger = logging.getLogger('pytorch_lightning') 59 | if args.verbose: 60 | logger.setLevel(logging.DEBUG) 61 | 62 | if 'seed' not in config: 63 | config.seed = int(time.time() * 1000) % 1000 64 | pl.seed_everything(config.seed) 65 | 66 | dm = datasets.make(config.dataset.name, config.dataset) 67 | system = systems.make(config.system.name, config, load_from_checkpoint=None if not args.resume_weights_only else args.resume) 68 | 69 | callbacks = [] 70 | if args.train: 71 | callbacks += [ 72 | ModelCheckpoint( 73 | dirpath=config.ckpt_dir, 74 | **config.checkpoint 75 | ), 76 | LearningRateMonitor(logging_interval='step'), 77 | CodeSnapshotCallback( 78 | config.code_dir, use_version=False 79 | ), 80 | ConfigSnapshotCallback( 81 | config, config.config_dir, use_version=False 82 | ), 83 | CustomProgressBar(refresh_rate=1), 84 | ] 85 | 86 | loggers = [] 87 | if args.train: 88 | loggers += [ 89 | TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name), 90 | CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs') 91 | ] 92 | 93 | if sys.platform == 'win32': 94 | # does not support multi-gpu on windows 95 | strategy = 'dp' 96 | assert n_gpus 
== 1 97 | else: 98 | strategy = 'ddp_find_unused_parameters_false' 99 | 100 | trainer = Trainer( 101 | devices=n_gpus, 102 | accelerator='gpu', 103 | callbacks=callbacks, 104 | logger=loggers, 105 | strategy=strategy, 106 | **config.trainer 107 | ) 108 | 109 | if args.train: 110 | if args.resume and not args.resume_weights_only: 111 | # FIXME: different behavior in pytorch-lighting>1.9 ? 112 | trainer.fit(system, datamodule=dm, ckpt_path=args.resume) 113 | else: 114 | trainer.fit(system, datamodule=dm) 115 | trainer.test(system, datamodule=dm) 116 | elif args.validate: 117 | trainer.validate(system, datamodule=dm, ckpt_path=args.resume) 118 | elif args.test: 119 | trainer.test(system, datamodule=dm, ckpt_path=args.resume) 120 | elif args.predict: 121 | trainer.predict(system, datamodule=dm, ckpt_path=args.resume) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/__init__.py: -------------------------------------------------------------------------------- 1 | models = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | models[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config): 12 | model = models[name](config) 13 | return model 14 | 15 | 16 | from . import nerf, neus, geometry, texture 17 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.misc import get_rank 5 | 6 | class BaseModel(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | self.rank = get_rank() 11 | self.setup() 12 | if self.config.get('weights', None): 13 | self.load_state_dict(torch.load(self.config.weights)) 14 | 15 | def setup(self): 16 | raise NotImplementedError 17 | 18 | def update_step(self, epoch, global_step): 19 | pass 20 | 21 | def train(self, mode=True): 22 | return super().train(mode=mode) 23 | 24 | def eval(self): 25 | return super().eval() 26 | 27 | def regularizations(self, out): 28 | return {} 29 | 30 | @torch.no_grad() 31 | def export(self, export_config): 32 | return {} 33 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/nerf.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import models 8 | from models.base import BaseModel 9 | from models.utils import chunk_batch 10 | from systems.utils import update_module_step 11 | from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays 12 | 13 | 14 | @models.register('nerf') 15 | class NeRFModel(BaseModel): 16 | def setup(self): 17 | self.geometry = models.make(self.config.geometry.name, self.config.geometry) 18 | self.texture = models.make(self.config.texture.name, self.config.texture) 19 | self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) 20 | 21 | if self.config.learned_background: 22 | self.occupancy_grid_res = 256 23 | self.near_plane, self.far_plane = 0.2, 1e4 24 | self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. 
# approximate 25 | self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size) 26 | self.contraction_type = ContractionType.UN_BOUNDED_SPHERE 27 | else: 28 | self.occupancy_grid_res = 128 29 | self.near_plane, self.far_plane = None, None 30 | self.cone_angle = 0.0 31 | self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray 32 | self.contraction_type = ContractionType.AABB 33 | 34 | self.geometry.contraction_type = self.contraction_type 35 | 36 | if self.config.grid_prune: 37 | self.occupancy_grid = OccupancyGrid( 38 | roi_aabb=self.scene_aabb, 39 | resolution=self.occupancy_grid_res, 40 | contraction_type=self.contraction_type 41 | ) 42 | self.randomized = self.config.randomized 43 | self.background_color = None 44 | 45 | def update_step(self, epoch, global_step): 46 | update_module_step(self.geometry, epoch, global_step) 47 | update_module_step(self.texture, epoch, global_step) 48 | 49 | def occ_eval_fn(x): 50 | density, _ = self.geometry(x) 51 | # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series 52 | return density[...,None] * self.render_step_size 53 | 54 | if self.training and self.config.grid_prune: 55 | self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn) 56 | 57 | def isosurface(self): 58 | mesh = self.geometry.isosurface() 59 | return mesh 60 | 61 | def forward_(self, rays): 62 | n_rays = rays.shape[0] 63 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 64 | 65 | def sigma_fn(t_starts, t_ends, ray_indices): 66 | ray_indices = ray_indices.long() 67 | t_origins = rays_o[ray_indices] 68 | t_dirs = rays_d[ray_indices] 69 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2. 70 | density, _ = self.geometry(positions) 71 | return density[...,None] 72 | 73 | def rgb_sigma_fn(t_starts, t_ends, ray_indices): 74 | ray_indices = ray_indices.long() 75 | t_origins = rays_o[ray_indices] 76 | t_dirs = rays_d[ray_indices] 77 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2. 78 | density, feature = self.geometry(positions) 79 | rgb = self.texture(feature, t_dirs) 80 | return rgb, density[...,None] 81 | 82 | with torch.no_grad(): 83 | ray_indices, t_starts, t_ends = ray_marching( 84 | rays_o, rays_d, 85 | scene_aabb=None if self.config.learned_background else self.scene_aabb, 86 | grid=self.occupancy_grid if self.config.grid_prune else None, 87 | sigma_fn=sigma_fn, 88 | near_plane=self.near_plane, far_plane=self.far_plane, 89 | render_step_size=self.render_step_size, 90 | stratified=self.randomized, 91 | cone_angle=self.cone_angle, 92 | alpha_thre=0.0 93 | ) 94 | 95 | ray_indices = ray_indices.long() 96 | t_origins = rays_o[ray_indices] 97 | t_dirs = rays_d[ray_indices] 98 | midpoints = (t_starts + t_ends) / 2. 
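        # each sample is placed at the midpoint of its marched interval; the interval
        # lengths computed next are only returned during training, where they can feed
        # sample-based regularisers (e.g. distortion-style losses)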
99 | positions = t_origins + t_dirs * midpoints 100 | intervals = t_ends - t_starts 101 | 102 | density, feature = self.geometry(positions) 103 | rgb = self.texture(feature, t_dirs) 104 | 105 | weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) 106 | opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) 107 | depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) 108 | comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) 109 | comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) 110 | 111 | out = { 112 | 'comp_rgb': comp_rgb, 113 | 'opacity': opacity, 114 | 'depth': depth, 115 | 'rays_valid': opacity > 0, 116 | 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) 117 | } 118 | 119 | if self.training: 120 | out.update({ 121 | 'weights': weights.view(-1), 122 | 'points': midpoints.view(-1), 123 | 'intervals': intervals.view(-1), 124 | 'ray_indices': ray_indices.view(-1) 125 | }) 126 | 127 | return out 128 | 129 | def forward(self, rays): 130 | if self.training: 131 | out = self.forward_(rays) 132 | else: 133 | out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) 134 | return { 135 | **out, 136 | } 137 | 138 | def train(self, mode=True): 139 | self.randomized = mode and self.config.randomized 140 | return super().train(mode=mode) 141 | 142 | def eval(self): 143 | self.randomized = False 144 | return super().eval() 145 | 146 | def regularizations(self, out): 147 | losses = {} 148 | losses.update(self.geometry.regularizations(out)) 149 | losses.update(self.texture.regularizations(out)) 150 | return losses 151 | 152 | @torch.no_grad() 153 | def export(self, export_config): 154 | mesh = self.isosurface() 155 | if export_config.export_vertex_color: 156 | _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank)) 157 | viewdirs = torch.zeros(feature.shape[0], 3).to(feature) 158 | viewdirs[...,2] = -1. 
# set the viewing directions to be -z (looking down) 159 | rgb = self.texture(feature, viewdirs).clamp(0,1) 160 | mesh['v_rgb'] = rgb.cpu() 161 | return mesh 162 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/network_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import tinycudann as tcnn 7 | 8 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info 9 | 10 | from utils.misc import config_to_primitive, get_rank 11 | from models.utils import get_activation 12 | from systems.utils import update_module_step 13 | 14 | class VanillaFrequency(nn.Module): 15 | def __init__(self, in_channels, config): 16 | super().__init__() 17 | self.N_freqs = config['n_frequencies'] 18 | self.in_channels, self.n_input_dims = in_channels, in_channels 19 | self.funcs = [torch.sin, torch.cos] 20 | self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs) 21 | self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs) 22 | self.n_masking_step = config.get('n_masking_step', 0) 23 | self.update_step(None, None) # mask should be updated at the beginning each step 24 | 25 | def forward(self, x): 26 | out = [] 27 | for freq, mask in zip(self.freq_bands, self.mask): 28 | for func in self.funcs: 29 | out += [func(freq*x) * mask] 30 | return torch.cat(out, -1) 31 | 32 | def update_step(self, epoch, global_step): 33 | if self.n_masking_step <= 0 or global_step is None: 34 | self.mask = torch.ones(self.N_freqs, dtype=torch.float32) 35 | else: 36 | self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2. 37 | rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}') 38 | 39 | 40 | class ProgressiveBandHashGrid(nn.Module): 41 | def __init__(self, in_channels, config): 42 | super().__init__() 43 | self.n_input_dims = in_channels 44 | encoding_config = config.copy() 45 | encoding_config['otype'] = 'HashGrid' 46 | with torch.cuda.device(get_rank()): 47 | self.encoding = tcnn.Encoding(in_channels, encoding_config) 48 | self.n_output_dims = self.encoding.n_output_dims 49 | self.n_level = config['n_levels'] 50 | self.n_features_per_level = config['n_features_per_level'] 51 | self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps'] 52 | self.current_level = self.start_level 53 | self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank()) 54 | 55 | def forward(self, x): 56 | enc = self.encoding(x) 57 | enc = enc * self.mask 58 | return enc 59 | 60 | def update_step(self, epoch, global_step): 61 | current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level) 62 | if current_level > self.current_level: 63 | rank_zero_info(f'Update grid level to {current_level}') 64 | self.current_level = current_level 65 | self.mask[:self.current_level * self.n_features_per_level] = 1. 
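        # coarse-to-fine schedule: starting from `start_level`, one additional hash-grid
        # level is unmasked every `update_steps` training steps until all `n_levels`
        # levels are active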
66 | 67 | 68 | class CompositeEncoding(nn.Module): 69 | def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.): 70 | super(CompositeEncoding, self).__init__() 71 | self.encoding = encoding 72 | self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset 73 | self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims 74 | 75 | def forward(self, x, *args): 76 | return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1) 77 | 78 | def update_step(self, epoch, global_step): 79 | update_module_step(self.encoding, epoch, global_step) 80 | 81 | 82 | def get_encoding(n_input_dims, config): 83 | # input suppose to be range [0, 1] 84 | if config.otype == 'VanillaFrequency': 85 | encoding = VanillaFrequency(n_input_dims, config_to_primitive(config)) 86 | elif config.otype == 'ProgressiveBandHashGrid': 87 | encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config)) 88 | else: 89 | with torch.cuda.device(get_rank()): 90 | encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config)) 91 | encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.) 92 | return encoding 93 | 94 | 95 | class VanillaMLP(nn.Module): 96 | def __init__(self, dim_in, dim_out, config): 97 | super().__init__() 98 | self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers'] 99 | self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False) 100 | self.sphere_init_radius = config.get('sphere_init_radius', 0.5) 101 | self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()] 102 | for i in range(self.n_hidden_layers - 1): 103 | self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()] 104 | self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)] 105 | self.layers = nn.Sequential(*self.layers) 106 | self.output_activation = get_activation(config['output_activation']) 107 | 108 | @torch.cuda.amp.autocast(False) 109 | def forward(self, x): 110 | x = self.layers(x.float()) 111 | x = self.output_activation(x) 112 | return x 113 | 114 | def make_linear(self, dim_in, dim_out, is_first, is_last): 115 | layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality 116 | if self.sphere_init: 117 | if is_last: 118 | torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) 119 | torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001) 120 | elif is_first: 121 | torch.nn.init.constant_(layer.bias, 0.0) 122 | torch.nn.init.constant_(layer.weight[:, 3:], 0.0) 123 | torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out)) 124 | else: 125 | torch.nn.init.constant_(layer.bias, 0.0) 126 | torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out)) 127 | else: 128 | torch.nn.init.constant_(layer.bias, 0.0) 129 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') 130 | 131 | if self.weight_norm: 132 | layer = nn.utils.weight_norm(layer) 133 | return layer 134 | 135 | def make_activation(self): 136 | if self.sphere_init: 137 | return nn.Softplus(beta=100) 138 | else: 139 | return nn.ReLU(inplace=True) 140 | 141 | 142 | def sphere_init_tcnn_network(n_input_dims, 
n_output_dims, config, network): 143 | rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.') 144 | """ 145 | from https://github.com/NVlabs/tiny-cuda-nn/issues/96 146 | It's the weight matrices of each layer laid out in row-major order and then concatenated. 147 | Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP). 148 | The padded input dimensions get a constant value of 1.0, 149 | whereas the padded output dimensions are simply ignored, 150 | so the weights pertaining to those can have any value. 151 | """ 152 | padto = 16 if config.otype == 'FullyFusedMLP' else 8 153 | n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto 154 | n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto 155 | data = list(network.parameters())[0].data 156 | assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2 157 | new_data = [] 158 | # first layer 159 | weight = torch.zeros((config.n_neurons, n_input_dims)).to(data) 160 | torch.nn.init.constant_(weight[:, 3:], 0.0) 161 | torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) 162 | new_data.append(weight.flatten()) 163 | # hidden layers 164 | for i in range(config.n_hidden_layers - 1): 165 | weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data) 166 | torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) 167 | new_data.append(weight.flatten()) 168 | # last layer 169 | weight = torch.zeros((n_output_dims, config.n_neurons)).to(data) 170 | torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001) 171 | new_data.append(weight.flatten()) 172 | new_data = torch.cat(new_data) 173 | data.copy_(new_data) 174 | 175 | 176 | def get_mlp(n_input_dims, n_output_dims, config): 177 | if config.otype == 'VanillaMLP': 178 | network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config)) 179 | else: 180 | with torch.cuda.device(get_rank()): 181 | network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config)) 182 | if config.get('sphere_init', False): 183 | sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network) 184 | return network 185 | 186 | 187 | class EncodingWithNetwork(nn.Module): 188 | def __init__(self, encoding, network): 189 | super().__init__() 190 | self.encoding, self.network = encoding, network 191 | 192 | def forward(self, x): 193 | return self.network(self.encoding(x)) 194 | 195 | def update_step(self, epoch, global_step): 196 | update_module_step(self.encoding, epoch, global_step) 197 | update_module_step(self.network, epoch, global_step) 198 | 199 | 200 | def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config): 201 | # input suppose to be range [0, 1] 202 | if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \ 203 | or network_config.otype in ['VanillaMLP']: 204 | encoding = get_encoding(n_input_dims, encoding_config) 205 | network = get_mlp(encoding.n_output_dims, n_output_dims, network_config) 206 | encoding_with_network = EncodingWithNetwork(encoding, network) 207 | else: 208 | with torch.cuda.device(get_rank()): 209 | encoding_with_network = tcnn.NetworkWithInputEncoding( 210 | n_input_dims=n_input_dims, 211 | n_output_dims=n_output_dims, 212 | encoding_config=config_to_primitive(encoding_config), 213 | network_config=config_to_primitive(network_config) 214 | ) 215 | 
return encoding_with_network 216 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def cast_rays(ori, dir, z_vals): 6 | return ori[..., None, :] + z_vals[..., None] * dir[..., None, :] 7 | 8 | 9 | def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): 10 | pixel_center = 0.5 if use_pixel_centers else 0 11 | i, j = np.meshgrid( 12 | np.arange(W, dtype=np.float32) + pixel_center, 13 | np.arange(H, dtype=np.float32) + pixel_center, 14 | indexing='xy' 15 | ) 16 | i, j = torch.from_numpy(i), torch.from_numpy(j) 17 | 18 | # directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) 19 | # opencv system 20 | directions = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1) # (H, W, 3) 21 | 22 | return directions 23 | 24 | 25 | def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True): 26 | pixel_center = 0.5 if use_pixel_centers else 0 27 | i, j = np.meshgrid( 28 | np.arange(W, dtype=np.float32) + pixel_center, 29 | np.arange(H, dtype=np.float32) + pixel_center, 30 | indexing='xy' 31 | ) 32 | i, j = torch.from_numpy(i), torch.from_numpy(j) 33 | 34 | origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2, torch.zeros_like(i)], dim=-1) # W, H, 3 35 | directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 36 | 37 | return origins, directions 38 | 39 | 40 | def get_rays(directions, c2w, keepdim=False): 41 | # Rotate ray directions from camera coordinate to the world coordinate 42 | # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? 43 | assert directions.shape[-1] == 3 44 | 45 | if directions.ndim == 2: # (N_rays, 3) 46 | assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) 47 | rays_d = (directions[:,None,:] * c2w[:,:3,:3]).sum(-1) # (N_rays, 3) 48 | rays_o = c2w[:,:,3].expand(rays_d.shape) 49 | elif directions.ndim == 3: # (H, W, 3) 50 | if c2w.ndim == 2: # (4, 4) 51 | rays_d = (directions[:,:,None,:] * c2w[None,None,:3,:3]).sum(-1) # (H, W, 3) 52 | rays_o = c2w[None,None,:,3].expand(rays_d.shape) 53 | elif c2w.ndim == 3: # (B, 4, 4) 54 | rays_d = (directions[None,:,:,None,:] * c2w[:,None,None,:3,:3]).sum(-1) # (B, H, W, 3) 55 | rays_o = c2w[:,None,None,:,3].expand(rays_d.shape) 56 | 57 | if not keepdim: 58 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 59 | 60 | return rays_o, rays_d 61 | 62 | 63 | # rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), rays_v[:, :, :, None].cuda()).squeeze() # W, H, 3 64 | 65 | # rays_o = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), q[:, :, :, None].cuda()).squeeze() # W, H, 3 66 | # rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape).cuda() + rays_o # W, H, 3 67 | 68 | def get_ortho_rays(origins, directions, c2w, keepdim=False): 69 | # Rotate ray directions from camera coordinate to the world coordinate 70 | # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? 
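    # Orthographic note: unlike get_rays above, where every ray starts at the single
    # pinhole camera center, an orthographic camera emits parallel rays. The per-pixel
    # origins produced by get_ortho_ray_directions_origins span the image plane in
    # [-1, 1]^2 and the directions are all +z in camera space, so both tensors are
    # rotated into world space by c2w[..., :3, :3] and the origins are then offset by
    # the camera translation c2w[..., :3, 3].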
71 | assert directions.shape[-1] == 3 72 | assert origins.shape[-1] == 3 73 | 74 | if directions.ndim == 2: # (N_rays, 3) 75 | assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) 76 | rays_d = torch.matmul(c2w[:, :3, :3], directions[:, :, None]).squeeze() # (N_rays, 3) 77 | rays_o = torch.matmul(c2w[:, :3, :3], origins[:, :, None]).squeeze() # (N_rays, 3) 78 | rays_o = c2w[:,:3,3].expand(rays_d.shape) + rays_o 79 | elif directions.ndim == 3: # (H, W, 3) 80 | if c2w.ndim == 2: # (4, 4) 81 | rays_d = torch.matmul(c2w[None, None, :3, :3], directions[:, :, :, None]).squeeze() # (H, W, 3) 82 | rays_o = torch.matmul(c2w[None, None, :3, :3], origins[:, :, :, None]).squeeze() # (H, W, 3) 83 | rays_o = c2w[None, None,:3,3].expand(rays_d.shape) + rays_o 84 | elif c2w.ndim == 3: # (B, 4, 4) 85 | rays_d = torch.matmul(c2w[:,None, None, :3, :3], directions[None, :, :, :, None]).squeeze() # # (B, H, W, 3) 86 | rays_o = torch.matmul(c2w[:,None, None, :3, :3], origins[None, :, :, :, None]).squeeze() # # (B, H, W, 3) 87 | rays_o = c2w[:,None, None, :3,3].expand(rays_d.shape) + rays_o 88 | 89 | if not keepdim: 90 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 91 | 92 | return rays_o, rays_d 93 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/texture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import models 5 | from models.utils import get_activation 6 | from models.network_utils import get_encoding, get_mlp 7 | from systems.utils import update_module_step 8 | 9 | 10 | @models.register('volume-radiance') 11 | class VolumeRadiance(nn.Module): 12 | def __init__(self, config): 13 | super(VolumeRadiance, self).__init__() 14 | self.config = config 15 | self.with_viewdir = False #self.config.get('wo_viewdir', False) 16 | self.n_dir_dims = self.config.get('n_dir_dims', 3) 17 | self.n_output_dims = 3 18 | 19 | if self.with_viewdir: 20 | encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config) 21 | self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims 22 | # self.network_base = get_mlp(self.config.input_feature_dim, self.n_output_dims, self.config.mlp_network_config) 23 | else: 24 | encoding = None 25 | self.n_input_dims = self.config.input_feature_dim 26 | 27 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) 28 | self.encoding = encoding 29 | self.network = network 30 | 31 | def forward(self, features, dirs, *args): 32 | 33 | # features = features.detach() 34 | if self.with_viewdir: 35 | dirs = (dirs + 1.) / 2. 
# (-1, 1) => (0, 1) 36 | dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims)) 37 | network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) 38 | # network_inp_base = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) 39 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() 40 | # color_base = self.network_base(network_inp_base).view(*features.shape[:-1], self.n_output_dims).float() 41 | # color = color + color_base 42 | else: 43 | network_inp = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) 44 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() 45 | 46 | if 'color_activation' in self.config: 47 | color = get_activation(self.config.color_activation)(color) 48 | return color 49 | 50 | def update_step(self, epoch, global_step): 51 | update_module_step(self.encoding, epoch, global_step) 52 | 53 | def regularizations(self, out): 54 | return {} 55 | 56 | 57 | @models.register('volume-color') 58 | class VolumeColor(nn.Module): 59 | def __init__(self, config): 60 | super(VolumeColor, self).__init__() 61 | self.config = config 62 | self.n_output_dims = 3 63 | self.n_input_dims = self.config.input_feature_dim 64 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) 65 | self.network = network 66 | 67 | def forward(self, features, *args): 68 | network_inp = features.view(-1, features.shape[-1]) 69 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() 70 | if 'color_activation' in self.config: 71 | color = get_activation(self.config.color_activation)(color) 72 | return color 73 | 74 | def regularizations(self, out): 75 | return {} 76 | -------------------------------------------------------------------------------- /instant-nsr-pl/models/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from collections import defaultdict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | from torch.cuda.amp import custom_bwd, custom_fwd 9 | 10 | import tinycudann as tcnn 11 | 12 | 13 | def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): 14 | B = None 15 | for arg in args: 16 | if isinstance(arg, torch.Tensor): 17 | B = arg.shape[0] 18 | break 19 | out = defaultdict(list) 20 | out_type = None 21 | for i in range(0, B, chunk_size): 22 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) 23 | if out_chunk is None: 24 | continue 25 | out_type = type(out_chunk) 26 | if isinstance(out_chunk, torch.Tensor): 27 | out_chunk = {0: out_chunk} 28 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): 29 | chunk_length = len(out_chunk) 30 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} 31 | elif isinstance(out_chunk, dict): 32 | pass 33 | else: 34 | print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') 35 | exit(1) 36 | for k, v in out_chunk.items(): 37 | v = v if torch.is_grad_enabled() else v.detach() 38 | v = v.cpu() if move_to_cpu else v 39 | out[k].append(v) 40 | 41 | if out_type is None: 42 | return 43 | 44 | out = {k: torch.cat(v, dim=0) for k, v in out.items()} 45 | if out_type is torch.Tensor: 46 | 
return out[0] 47 | elif out_type in [tuple, list]: 48 | return out_type([out[i] for i in range(chunk_length)]) 49 | elif out_type is dict: 50 | return out 51 | 52 | 53 | class _TruncExp(Function): # pylint: disable=abstract-method 54 | # Implementation from torch-ngp: 55 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 56 | @staticmethod 57 | @custom_fwd(cast_inputs=torch.float32) 58 | def forward(ctx, x): # pylint: disable=arguments-differ 59 | ctx.save_for_backward(x) 60 | return torch.exp(x) 61 | 62 | @staticmethod 63 | @custom_bwd 64 | def backward(ctx, g): # pylint: disable=arguments-differ 65 | x = ctx.saved_tensors[0] 66 | return g * torch.exp(torch.clamp(x, max=15)) 67 | 68 | trunc_exp = _TruncExp.apply 69 | 70 | 71 | def get_activation(name): 72 | if name is None: 73 | return lambda x: x 74 | name = name.lower() 75 | if name == 'none': 76 | return lambda x: x 77 | elif name.startswith('scale'): 78 | scale_factor = float(name[5:]) 79 | return lambda x: x.clamp(0., scale_factor) / scale_factor 80 | elif name.startswith('clamp'): 81 | clamp_max = float(name[5:]) 82 | return lambda x: x.clamp(0., clamp_max) 83 | elif name.startswith('mul'): 84 | mul_factor = float(name[3:]) 85 | return lambda x: x * mul_factor 86 | elif name == 'lin2srgb': 87 | return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) 88 | elif name == 'trunc_exp': 89 | return trunc_exp 90 | elif name.startswith('+') or name.startswith('-'): 91 | return lambda x: x + float(name) 92 | elif name == 'sigmoid': 93 | return lambda x: torch.sigmoid(x) 94 | elif name == 'tanh': 95 | return lambda x: torch.tanh(x) 96 | else: 97 | return getattr(F, name) 98 | 99 | 100 | def dot(x, y): 101 | return torch.sum(x*y, -1, keepdim=True) 102 | 103 | 104 | def reflect(x, n): 105 | return 2 * dot(x, n) * n - x 106 | 107 | 108 | def scale_anything(dat, inp_scale, tgt_scale): 109 | if inp_scale is None: 110 | inp_scale = [dat.min(), dat.max()] 111 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 112 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 113 | return dat 114 | 115 | 116 | def cleanup(): 117 | gc.collect() 118 | torch.cuda.empty_cache() 119 | tcnn.free_temporary_memory() 120 | -------------------------------------------------------------------------------- /instant-nsr-pl/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning<2 2 | omegaconf==2.2.3 3 | nerfacc==0.3.3 4 | matplotlib 5 | opencv-python 6 | imageio 7 | imageio-ffmpeg 8 | scipy 9 | PyMCubes 10 | pyransac3d 11 | torch_efficient_distloss 12 | tensorboard 13 | -------------------------------------------------------------------------------- /instant-nsr-pl/run.sh: -------------------------------------------------------------------------------- 1 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=$1 dataset.scene=$2 -------------------------------------------------------------------------------- /instant-nsr-pl/scripts/imgs2poses.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This file is adapted from https://github.com/Fyusion/LLFF. 
4 | """ 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import subprocess 10 | 11 | 12 | def run_colmap(basedir, match_type): 13 | logfile_name = os.path.join(basedir, 'colmap_output.txt') 14 | logfile = open(logfile_name, 'w') 15 | 16 | feature_extractor_args = [ 17 | 'colmap', 'feature_extractor', 18 | '--database_path', os.path.join(basedir, 'database.db'), 19 | '--image_path', os.path.join(basedir, 'images'), 20 | '--ImageReader.single_camera', '1' 21 | ] 22 | feat_output = ( subprocess.check_output(feature_extractor_args, universal_newlines=True) ) 23 | logfile.write(feat_output) 24 | print('Features extracted') 25 | 26 | exhaustive_matcher_args = [ 27 | 'colmap', match_type, 28 | '--database_path', os.path.join(basedir, 'database.db'), 29 | ] 30 | 31 | match_output = ( subprocess.check_output(exhaustive_matcher_args, universal_newlines=True) ) 32 | logfile.write(match_output) 33 | print('Features matched') 34 | 35 | p = os.path.join(basedir, 'sparse') 36 | if not os.path.exists(p): 37 | os.makedirs(p) 38 | 39 | mapper_args = [ 40 | 'colmap', 'mapper', 41 | '--database_path', os.path.join(basedir, 'database.db'), 42 | '--image_path', os.path.join(basedir, 'images'), 43 | '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6 44 | '--Mapper.num_threads', '16', 45 | '--Mapper.init_min_tri_angle', '4', 46 | '--Mapper.multiple_models', '0', 47 | '--Mapper.extract_colors', '0', 48 | ] 49 | 50 | map_output = ( subprocess.check_output(mapper_args, universal_newlines=True) ) 51 | logfile.write(map_output) 52 | logfile.close() 53 | print('Sparse map created') 54 | 55 | print( 'Finished running COLMAP, see {} for logs'.format(logfile_name) ) 56 | 57 | 58 | def gen_poses(basedir, match_type): 59 | files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']] 60 | if os.path.exists(os.path.join(basedir, 'sparse/0')): 61 | files_had = os.listdir(os.path.join(basedir, 'sparse/0')) 62 | else: 63 | files_had = [] 64 | if not all([f in files_had for f in files_needed]): 65 | print( 'Need to run COLMAP' ) 66 | run_colmap(basedir, match_type) 67 | else: 68 | print('Don\'t need to run COLMAP') 69 | 70 | return True 71 | 72 | 73 | if __name__=='__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--match_type', type=str, 76 | default='exhaustive_matcher', help='type of matcher used. Valid options: \ 77 | exhaustive_matcher sequential_matcher. Other matchers not supported at this time') 78 | parser.add_argument('scenedir', type=str, 79 | help='input scene directory') 80 | args = parser.parse_args() 81 | 82 | if args.match_type != 'exhaustive_matcher' and args.match_type != 'sequential_matcher': 83 | print('ERROR: matcher type ' + args.match_type + ' is not valid. Aborting') 84 | sys.exit() 85 | gen_poses(args.scenedir, args.match_type) 86 | -------------------------------------------------------------------------------- /instant-nsr-pl/systems/__init__.py: -------------------------------------------------------------------------------- 1 | systems = {} 2 | 3 | 4 | def register(name): 5 | def decorator(cls): 6 | systems[name] = cls 7 | return cls 8 | return decorator 9 | 10 | 11 | def make(name, config, load_from_checkpoint=None): 12 | if load_from_checkpoint is None: 13 | system = systems[name](config) 14 | else: 15 | system = systems[name].load_from_checkpoint(load_from_checkpoint, strict=False, config=config) 16 | return system 17 | 18 | 19 | from . 
import neus, neus_ortho, neus_pinhole 20 | -------------------------------------------------------------------------------- /instant-nsr-pl/systems/base.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | import models 4 | from systems.utils import parse_optimizer, parse_scheduler, update_module_step 5 | from utils.mixins import SaverMixin 6 | from utils.misc import config_to_primitive, get_rank 7 | 8 | 9 | class BaseSystem(pl.LightningModule, SaverMixin): 10 | """ 11 | Two ways to print to console: 12 | 1. self.print: correctly handle progress bar 13 | 2. rank_zero_info: use the logging module 14 | """ 15 | def __init__(self, config): 16 | super().__init__() 17 | self.config = config 18 | self.rank = get_rank() 19 | self.prepare() 20 | self.model = models.make(self.config.model.name, self.config.model) 21 | 22 | def prepare(self): 23 | pass 24 | 25 | def forward(self, batch): 26 | raise NotImplementedError 27 | 28 | def C(self, value): 29 | if isinstance(value, int) or isinstance(value, float): 30 | pass 31 | else: 32 | value = config_to_primitive(value) 33 | if not isinstance(value, list): 34 | raise TypeError('Scalar specification only supports list, got', type(value)) 35 | if len(value) == 3: 36 | value = [0] + value 37 | assert len(value) == 4 38 | start_step, start_value, end_value, end_step = value 39 | if isinstance(end_step, int): 40 | current_step = self.global_step 41 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) 42 | elif isinstance(end_step, float): 43 | current_step = self.current_epoch 44 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) 45 | return value 46 | 47 | def preprocess_data(self, batch, stage): 48 | pass 49 | 50 | """ 51 | Implementing on_after_batch_transfer of DataModule does the same. 52 | But on_after_batch_transfer does not support DP. 
53 | """ 54 | def on_train_batch_start(self, batch, batch_idx, unused=0): 55 | self.dataset = self.trainer.datamodule.train_dataloader().dataset 56 | self.preprocess_data(batch, 'train') 57 | update_module_step(self.model, self.current_epoch, self.global_step) 58 | 59 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): 60 | self.dataset = self.trainer.datamodule.val_dataloader().dataset 61 | self.preprocess_data(batch, 'validation') 62 | update_module_step(self.model, self.current_epoch, self.global_step) 63 | 64 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx): 65 | self.dataset = self.trainer.datamodule.test_dataloader().dataset 66 | self.preprocess_data(batch, 'test') 67 | update_module_step(self.model, self.current_epoch, self.global_step) 68 | 69 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): 70 | self.dataset = self.trainer.datamodule.predict_dataloader().dataset 71 | self.preprocess_data(batch, 'predict') 72 | update_module_step(self.model, self.current_epoch, self.global_step) 73 | 74 | def training_step(self, batch, batch_idx): 75 | raise NotImplementedError 76 | 77 | """ 78 | # aggregate outputs from different devices (DP) 79 | def training_step_end(self, out): 80 | pass 81 | """ 82 | 83 | """ 84 | # aggregate outputs from different iterations 85 | def training_epoch_end(self, out): 86 | pass 87 | """ 88 | 89 | def validation_step(self, batch, batch_idx): 90 | raise NotImplementedError 91 | 92 | """ 93 | # aggregate outputs from different devices when using DP 94 | def validation_step_end(self, out): 95 | pass 96 | """ 97 | 98 | def validation_epoch_end(self, out): 99 | """ 100 | Gather metrics from all devices, compute mean. 101 | Purge repeated results using data index. 102 | """ 103 | raise NotImplementedError 104 | 105 | def test_step(self, batch, batch_idx): 106 | raise NotImplementedError 107 | 108 | def test_epoch_end(self, out): 109 | """ 110 | Gather metrics from all devices, compute mean. 111 | Purge repeated results using data index. 
112 | """ 113 | raise NotImplementedError 114 | 115 | def export(self): 116 | raise NotImplementedError 117 | 118 | def configure_optimizers(self): 119 | optim = parse_optimizer(self.config.system.optimizer, self.model) 120 | ret = { 121 | 'optimizer': optim, 122 | } 123 | if 'scheduler' in self.config.system: 124 | ret.update({ 125 | 'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim), 126 | }) 127 | return ret 128 | 129 | -------------------------------------------------------------------------------- /instant-nsr-pl/systems/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class WeightedLoss(nn.Module): 7 | @property 8 | def func(self): 9 | raise NotImplementedError 10 | 11 | def forward(self, inputs, targets, weight=None, reduction='mean'): 12 | assert reduction in ['none', 'sum', 'mean', 'valid_mean'] 13 | loss = self.func(inputs, targets, reduction='none') 14 | if weight is not None: 15 | while weight.ndim < inputs.ndim: 16 | weight = weight[..., None] 17 | loss *= weight.float() 18 | if reduction == 'none': 19 | return loss 20 | elif reduction == 'sum': 21 | return loss.sum() 22 | elif reduction == 'mean': 23 | return loss.mean() 24 | elif reduction == 'valid_mean': 25 | return loss.sum() / weight.float().sum() 26 | 27 | 28 | class MSELoss(WeightedLoss): 29 | @property 30 | def func(self): 31 | return F.mse_loss 32 | 33 | 34 | class L1Loss(WeightedLoss): 35 | @property 36 | def func(self): 37 | return F.l1_loss 38 | 39 | 40 | class PSNR(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, inputs, targets, valid_mask=None, reduction='mean'): 45 | assert reduction in ['mean', 'none'] 46 | value = (inputs - targets)**2 47 | if valid_mask is not None: 48 | value = value[valid_mask] 49 | if reduction == 'mean': 50 | return -10 * torch.log10(torch.mean(value)) 51 | elif reduction == 'none': 52 | return -10 * torch.log10(torch.mean(value, dim=tuple(range(value.ndim)[1:]))) 53 | 54 | 55 | class SSIM(): 56 | def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True): 57 | self.kernel_size = kernel_size 58 | self.sigma = sigma 59 | self.gaussian = gaussian 60 | 61 | if any(x % 2 == 0 or x <= 0 for x in self.kernel_size): 62 | raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.") 63 | if any(y <= 0 for y in self.sigma): 64 | raise ValueError(f"Expected sigma to have positive number. 
Got {sigma}.") 65 | 66 | data_scale = data_range[1] - data_range[0] 67 | self.c1 = (k1 * data_scale)**2 68 | self.c2 = (k2 * data_scale)**2 69 | self.pad_h = (self.kernel_size[0] - 1) // 2 70 | self.pad_w = (self.kernel_size[1] - 1) // 2 71 | self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) 72 | 73 | def _uniform(self, kernel_size): 74 | max, min = 2.5, -2.5 75 | ksize_half = (kernel_size - 1) * 0.5 76 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 77 | for i, j in enumerate(kernel): 78 | if min <= j <= max: 79 | kernel[i] = 1 / (max - min) 80 | else: 81 | kernel[i] = 0 82 | 83 | return kernel.unsqueeze(dim=0) # (1, kernel_size) 84 | 85 | def _gaussian(self, kernel_size, sigma): 86 | ksize_half = (kernel_size - 1) * 0.5 87 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 88 | gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) 89 | return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) 90 | 91 | def _gaussian_or_uniform_kernel(self, kernel_size, sigma): 92 | if self.gaussian: 93 | kernel_x = self._gaussian(kernel_size[0], sigma[0]) 94 | kernel_y = self._gaussian(kernel_size[1], sigma[1]) 95 | else: 96 | kernel_x = self._uniform(kernel_size[0]) 97 | kernel_y = self._uniform(kernel_size[1]) 98 | 99 | return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size) 100 | 101 | def __call__(self, output, target, reduction='mean'): 102 | if output.dtype != target.dtype: 103 | raise TypeError( 104 | f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}." 105 | ) 106 | 107 | if output.shape != target.shape: 108 | raise ValueError( 109 | f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}." 110 | ) 111 | 112 | if len(output.shape) != 4 or len(target.shape) != 4: 113 | raise ValueError( 114 | f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}." 
115 | ) 116 | 117 | assert reduction in ['mean', 'sum', 'none'] 118 | 119 | channel = output.size(1) 120 | if len(self._kernel.shape) < 4: 121 | self._kernel = self._kernel.expand(channel, 1, -1, -1) 122 | 123 | output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") 124 | target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") 125 | 126 | input_list = torch.cat([output, target, output * output, target * target, output * target]) 127 | outputs = F.conv2d(input_list, self._kernel, groups=channel) 128 | 129 | output_list = [outputs[x * output.size(0) : (x + 1) * output.size(0)] for x in range(len(outputs))] 130 | 131 | mu_pred_sq = output_list[0].pow(2) 132 | mu_target_sq = output_list[1].pow(2) 133 | mu_pred_target = output_list[0] * output_list[1] 134 | 135 | sigma_pred_sq = output_list[2] - mu_pred_sq 136 | sigma_target_sq = output_list[3] - mu_target_sq 137 | sigma_pred_target = output_list[4] - mu_pred_target 138 | 139 | a1 = 2 * mu_pred_target + self.c1 140 | a2 = 2 * sigma_pred_target + self.c2 141 | b1 = mu_pred_sq + mu_target_sq + self.c1 142 | b2 = sigma_pred_sq + sigma_target_sq + self.c2 143 | 144 | ssim_idx = (a1 * a2) / (b1 * b2) 145 | _ssim = torch.mean(ssim_idx, (1, 2, 3)) 146 | 147 | if reduction == 'none': 148 | return _ssim 149 | elif reduction == 'sum': 150 | return _ssim.sum() 151 | elif reduction == 'mean': 152 | return _ssim.mean() 153 | 154 | 155 | def binary_cross_entropy(input, target, reduction='mean'): 156 | """ 157 | F.binary_cross_entropy is not numerically stable in mixed-precision training. 158 | """ 159 | loss = -(target * torch.log(input) + (1 - target) * torch.log(1 - input)) 160 | 161 | if reduction == 'mean': 162 | return loss.mean() 163 | elif reduction == 'none': 164 | return loss 165 | -------------------------------------------------------------------------------- /instant-nsr-pl/systems/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_efficient_distloss import flatten_eff_distloss 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug 8 | 9 | import models 10 | from models.ray_utils import get_rays 11 | import systems 12 | from systems.base import BaseSystem 13 | from systems.criterions import PSNR 14 | 15 | 16 | @systems.register('nerf-system') 17 | class NeRFSystem(BaseSystem): 18 | """ 19 | Two ways to print to console: 20 | 1. self.print: correctly handle progress bar 21 | 2. 
rank_zero_info: use the logging module 22 | """ 23 | def prepare(self): 24 | self.criterions = { 25 | 'psnr': PSNR() 26 | } 27 | self.train_num_samples = self.config.model.train_num_rays * self.config.model.num_samples_per_ray 28 | self.train_num_rays = self.config.model.train_num_rays 29 | 30 | def forward(self, batch): 31 | return self.model(batch['rays']) 32 | 33 | def preprocess_data(self, batch, stage): 34 | if 'index' in batch: # validation / testing 35 | index = batch['index'] 36 | else: 37 | if self.config.model.batch_image_sampling: 38 | index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) 39 | else: 40 | index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) 41 | if stage in ['train']: 42 | c2w = self.dataset.all_c2w[index] 43 | x = torch.randint( 44 | 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device 45 | ) 46 | y = torch.randint( 47 | 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device 48 | ) 49 | if self.dataset.directions.ndim == 3: # (H, W, 3) 50 | directions = self.dataset.directions[y, x] 51 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 52 | directions = self.dataset.directions[index, y, x] 53 | rays_o, rays_d = get_rays(directions, c2w) 54 | rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 55 | fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) 56 | else: 57 | c2w = self.dataset.all_c2w[index][0] 58 | if self.dataset.directions.ndim == 3: # (H, W, 3) 59 | directions = self.dataset.directions 60 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3) 61 | directions = self.dataset.directions[index][0] 62 | rays_o, rays_d = get_rays(directions, c2w) 63 | rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) 64 | fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) 65 | 66 | rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) 67 | 68 | if stage in ['train']: 69 | if self.config.model.background_color == 'white': 70 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 71 | elif self.config.model.background_color == 'random': 72 | self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) 73 | else: 74 | raise NotImplementedError 75 | else: 76 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) 77 | 78 | if self.dataset.apply_mask: 79 | rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) 80 | 81 | batch.update({ 82 | 'rays': rays, 83 | 'rgb': rgb, 84 | 'fg_mask': fg_mask 85 | }) 86 | 87 | def training_step(self, batch, batch_idx): 88 | out = self(batch) 89 | 90 | loss = 0. 
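        # Overview of the terms accumulated below: when dynamic_ray_sampling is enabled,
        # the ray budget for the next batch is rescaled so the total sample count stays
        # near train_num_samples, smoothed with an exponential moving average, roughly
        #   target = train_num_rays * train_num_samples / num_samples_taken
        #   train_num_rays <- min(0.9 * train_num_rays + 0.1 * target, max_train_num_rays)
        # (the 0.9 / 0.1 weights are those used in the code below). The loss then sums a
        # smooth-L1 RGB term, an optional MipNeRF360-style distortion term, and the model's
        # regularizers, each weighted by its scheduled lambda_* coefficient via self.C(...).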
91 | 92 | # update train_num_rays 93 | if self.config.model.dynamic_ray_sampling: 94 | train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples'].sum().item())) 95 | self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) 96 | 97 | loss_rgb = F.smooth_l1_loss(out['comp_rgb'][out['rays_valid'][...,0]], batch['rgb'][out['rays_valid'][...,0]]) 98 | self.log('train/loss_rgb', loss_rgb) 99 | loss += loss_rgb * self.C(self.config.system.loss.lambda_rgb) 100 | 101 | # distortion loss proposed in MipNeRF360 102 | # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss, but still slows down training by ~30% 103 | if self.C(self.config.system.loss.lambda_distortion) > 0: 104 | loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) 105 | self.log('train/loss_distortion', loss_distortion) 106 | loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) 107 | 108 | losses_model_reg = self.model.regularizations(out) 109 | for name, value in losses_model_reg.items(): 110 | self.log(f'train/loss_{name}', value) 111 | loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) 112 | loss += loss_ 113 | 114 | for name, value in self.config.system.loss.items(): 115 | if name.startswith('lambda'): 116 | self.log(f'train_params/{name}', self.C(value)) 117 | 118 | self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) 119 | 120 | return { 121 | 'loss': loss 122 | } 123 | 124 | """ 125 | # aggregate outputs from different devices (DP) 126 | def training_step_end(self, out): 127 | pass 128 | """ 129 | 130 | """ 131 | # aggregate outputs from different iterations 132 | def training_epoch_end(self, out): 133 | pass 134 | """ 135 | 136 | def validation_step(self, batch, batch_idx): 137 | out = self(batch) 138 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) 139 | W, H = self.dataset.img_wh 140 | self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ 141 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 142 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 143 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 144 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} 145 | ]) 146 | return { 147 | 'psnr': psnr, 148 | 'index': batch['index'] 149 | } 150 | 151 | 152 | """ 153 | # aggregate outputs from different devices when using DP 154 | def validation_step_end(self, out): 155 | pass 156 | """ 157 | 158 | def validation_epoch_end(self, out): 159 | out = self.all_gather(out) 160 | if self.trainer.is_global_zero: 161 | out_set = {} 162 | for step_out in out: 163 | # DP 164 | if step_out['index'].ndim == 1: 165 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 166 | # DDP 167 | else: 168 | for oi, index in enumerate(step_out['index']): 169 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 170 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 171 | self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) 172 | 173 | def test_step(self, batch, batch_idx): 174 | out = self(batch) 175 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) 176 | W, H = self.dataset.img_wh 177 | 
self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ 178 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 179 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, 180 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, 181 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} 182 | ]) 183 | return { 184 | 'psnr': psnr, 185 | 'index': batch['index'] 186 | } 187 | 188 | def test_epoch_end(self, out): 189 | out = self.all_gather(out) 190 | if self.trainer.is_global_zero: 191 | out_set = {} 192 | for step_out in out: 193 | # DP 194 | if step_out['index'].ndim == 1: 195 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} 196 | # DDP 197 | else: 198 | for oi, index in enumerate(step_out['index']): 199 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} 200 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) 201 | self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) 202 | 203 | self.save_img_sequence( 204 | f"it{self.global_step}-test", 205 | f"it{self.global_step}-test", 206 | '(\d+)\.png', 207 | save_format='mp4', 208 | fps=30 209 | ) 210 | 211 | self.export() 212 | 213 | def export(self): 214 | mesh = self.model.export(self.config.export) 215 | self.save_mesh( 216 | f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", 217 | **mesh 218 | ) 219 | -------------------------------------------------------------------------------- /instant-nsr-pl/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/instant-nsr-pl/utils/__init__.py -------------------------------------------------------------------------------- /instant-nsr-pl/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import shutil 4 | from utils.misc import dump_config, parse_version 5 | 6 | 7 | import pytorch_lightning 8 | if parse_version(pytorch_lightning.__version__) > parse_version('1.8'): 9 | from pytorch_lightning.callbacks import Callback 10 | else: 11 | from pytorch_lightning.callbacks.base import Callback 12 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 13 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 14 | 15 | 16 | class VersionedCallback(Callback): 17 | def __init__(self, save_root, version=None, use_version=True): 18 | self.save_root = save_root 19 | self._version = version 20 | self.use_version = use_version 21 | 22 | @property 23 | def version(self) -> int: 24 | """Get the experiment version. 25 | 26 | Returns: 27 | The experiment version if specified else the next version. 
28 | """ 29 | if self._version is None: 30 | self._version = self._get_next_version() 31 | return self._version 32 | 33 | def _get_next_version(self): 34 | existing_versions = [] 35 | if os.path.isdir(self.save_root): 36 | for f in os.listdir(self.save_root): 37 | bn = os.path.basename(f) 38 | if bn.startswith("version_"): 39 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 40 | existing_versions.append(int(dir_ver)) 41 | if len(existing_versions) == 0: 42 | return 0 43 | return max(existing_versions) + 1 44 | 45 | @property 46 | def savedir(self): 47 | if not self.use_version: 48 | return self.save_root 49 | return os.path.join(self.save_root, self.version if isinstance(self.version, str) else f"version_{self.version}") 50 | 51 | 52 | class CodeSnapshotCallback(VersionedCallback): 53 | def __init__(self, save_root, version=None, use_version=True): 54 | super().__init__(save_root, version, use_version) 55 | 56 | def get_file_list(self): 57 | return [ 58 | b.decode() for b in 59 | set(subprocess.check_output('git ls-files', shell=True).splitlines()) | 60 | set(subprocess.check_output('git ls-files --others --exclude-standard', shell=True).splitlines()) 61 | ] 62 | 63 | @rank_zero_only 64 | def save_code_snapshot(self): 65 | os.makedirs(self.savedir, exist_ok=True) 66 | for f in self.get_file_list(): 67 | if not os.path.exists(f) or os.path.isdir(f): 68 | continue 69 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 70 | shutil.copyfile(f, os.path.join(self.savedir, f)) 71 | 72 | def on_fit_start(self, trainer, pl_module): 73 | try: 74 | self.save_code_snapshot() 75 | except: 76 | rank_zero_warn("Code snapshot is not saved. Please make sure you have git installed and are in a git repository.") 77 | 78 | 79 | class ConfigSnapshotCallback(VersionedCallback): 80 | def __init__(self, config, save_root, version=None, use_version=True): 81 | super().__init__(save_root, version, use_version) 82 | self.config = config 83 | 84 | @rank_zero_only 85 | def save_config_snapshot(self): 86 | os.makedirs(self.savedir, exist_ok=True) 87 | dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config) 88 | shutil.copyfile(self.config.cmd_args['config'], os.path.join(self.savedir, 'raw.yaml')) 89 | 90 | def on_fit_start(self, trainer, pl_module): 91 | self.save_config_snapshot() 92 | 93 | 94 | class CustomProgressBar(TQDMProgressBar): 95 | def get_metrics(self, *args, **kwargs): 96 | # don't show the version number 97 | items = super().get_metrics(*args, **kwargs) 98 | items.pop("v_num", None) 99 | return items 100 | -------------------------------------------------------------------------------- /instant-nsr-pl/utils/loggers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pprint 3 | import logging 4 | 5 | from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment 6 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 7 | 8 | 9 | class ConsoleLogger(LightningLoggerBase): 10 | def __init__(self, log_keys=[]): 11 | super().__init__() 12 | self.log_keys = [re.compile(k) for k in log_keys] 13 | self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat 14 | 15 | def match_log_keys(self, s): 16 | return True if not self.log_keys else any(r.search(s) for r in self.log_keys) 17 | 18 | @property 19 | def name(self): 20 | return 'console' 21 | 22 | @property 23 | def version(self): 24 | return '0' 25 | 26 | @property 27 | @rank_zero_experiment 28 | def 
experiment(self): 29 | return logging.getLogger('pytorch_lightning') 30 | 31 | @rank_zero_only 32 | def log_hyperparams(self, params): 33 | pass 34 | 35 | @rank_zero_only 36 | def log_metrics(self, metrics, step): 37 | metrics_ = {k: v for k, v in metrics.items() if self.match_log_keys(k)} 38 | if not metrics_: 39 | return 40 | self.experiment.info(f"\nEpoch{metrics['epoch']} Step{step}\n{self.dict_printer(metrics_)}") 41 | 42 | -------------------------------------------------------------------------------- /instant-nsr-pl/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | from packaging import version 4 | 5 | 6 | # ============ Register OmegaConf Recolvers ============= # 7 | OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) 8 | OmegaConf.register_new_resolver('add', lambda a, b: a + b) 9 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b) 10 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b) 11 | OmegaConf.register_new_resolver('div', lambda a, b: a / b) 12 | OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) 13 | OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) 14 | # ======================================================= # 15 | 16 | 17 | def prompt(question): 18 | inp = input(f"{question} (y/n)").lower().strip() 19 | if inp and inp == 'y': 20 | return True 21 | if inp and inp == 'n': 22 | return False 23 | return prompt(question) 24 | 25 | 26 | def load_config(*yaml_files, cli_args=[]): 27 | yaml_confs = [OmegaConf.load(f) for f in yaml_files] 28 | cli_conf = OmegaConf.from_cli(cli_args) 29 | conf = OmegaConf.merge(*yaml_confs, cli_conf) 30 | OmegaConf.resolve(conf) 31 | return conf 32 | 33 | 34 | def config_to_primitive(config, resolve=True): 35 | return OmegaConf.to_container(config, resolve=resolve) 36 | 37 | 38 | def dump_config(path, config): 39 | with open(path, 'w') as fp: 40 | OmegaConf.save(config=config, f=fp) 41 | 42 | def get_rank(): 43 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 44 | # therefore LOCAL_RANK needs to be checked first 45 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 46 | for key in rank_keys: 47 | rank = os.environ.get(key) 48 | if rank is not None: 49 | return int(rank) 50 | return 0 51 | 52 | 53 | def parse_version(ver): 54 | return version.parse(ver) 55 | -------------------------------------------------------------------------------- /instant-nsr-pl/utils/mixins.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | import numpy as np 5 | import cv2 6 | import imageio 7 | from matplotlib import cm 8 | from matplotlib.colors import LinearSegmentedColormap 9 | import json 10 | 11 | import torch 12 | 13 | from utils.obj import write_obj 14 | 15 | 16 | class SaverMixin(): 17 | @property 18 | def save_dir(self): 19 | return self.config.save_dir 20 | 21 | def convert_data(self, data): 22 | if isinstance(data, np.ndarray): 23 | return data 24 | elif isinstance(data, torch.Tensor): 25 | return data.cpu().numpy() 26 | elif isinstance(data, list): 27 | return [self.convert_data(d) for d in data] 28 | elif isinstance(data, dict): 29 | return {k: self.convert_data(v) for k, v in data.items()} 30 | else: 31 | raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data)) 32 | 33 | def 
get_save_path(self, filename): 34 | save_path = os.path.join(self.save_dir, filename) 35 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 36 | return save_path 37 | 38 | DEFAULT_RGB_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1)} 39 | DEFAULT_UV_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1), 'cmap': 'checkerboard'} 40 | DEFAULT_GRAYSCALE_KWARGS = {'data_range': None, 'cmap': 'jet'} 41 | 42 | def get_rgb_image_(self, img, data_format, data_range): 43 | img = self.convert_data(img) 44 | assert data_format in ['CHW', 'HWC'] 45 | if data_format == 'CHW': 46 | img = img.transpose(1, 2, 0) 47 | img = img.clip(min=data_range[0], max=data_range[1]) 48 | img = ((img - data_range[0]) / (data_range[1] - data_range[0]) * 255.).astype(np.uint8) 49 | imgs = [img[...,start:start+3] for start in range(0, img.shape[-1], 3)] 50 | imgs = [img_ if img_.shape[-1] == 3 else np.concatenate([img_, np.zeros((img_.shape[0], img_.shape[1], 3 - img_.shape[2]), dtype=img_.dtype)], axis=-1) for img_ in imgs] 51 | img = np.concatenate(imgs, axis=1) 52 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 53 | return img 54 | 55 | def save_rgb_image(self, filename, img, data_format=DEFAULT_RGB_KWARGS['data_format'], data_range=DEFAULT_RGB_KWARGS['data_range']): 56 | img = self.get_rgb_image_(img, data_format, data_range) 57 | cv2.imwrite(self.get_save_path(filename), img) 58 | 59 | def get_uv_image_(self, img, data_format, data_range, cmap): 60 | img = self.convert_data(img) 61 | assert data_format in ['CHW', 'HWC'] 62 | if data_format == 'CHW': 63 | img = img.transpose(1, 2, 0) 64 | img = img.clip(min=data_range[0], max=data_range[1]) 65 | img = (img - data_range[0]) / (data_range[1] - data_range[0]) 66 | assert cmap in ['checkerboard', 'color'] 67 | if cmap == 'checkerboard': 68 | n_grid = 64 69 | mask = (img * n_grid).astype(int) 70 | mask = (mask[...,0] + mask[...,1]) % 2 == 0 71 | img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 72 | img[mask] = np.array([255, 0, 255], dtype=np.uint8) 73 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 74 | elif cmap == 'color': 75 | img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) 76 | img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) 77 | img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) 78 | img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) 79 | img = img_ 80 | return img 81 | 82 | def save_uv_image(self, filename, img, data_format=DEFAULT_UV_KWARGS['data_format'], data_range=DEFAULT_UV_KWARGS['data_range'], cmap=DEFAULT_UV_KWARGS['cmap']): 83 | img = self.get_uv_image_(img, data_format, data_range, cmap) 84 | cv2.imwrite(self.get_save_path(filename), img) 85 | 86 | def get_grayscale_image_(self, img, data_range, cmap): 87 | img = self.convert_data(img) 88 | img = np.nan_to_num(img) 89 | if data_range is None: 90 | img = (img - img.min()) / (img.max() - img.min()) 91 | else: 92 | img = img.clip(data_range[0], data_range[1]) 93 | img = (img - data_range[0]) / (data_range[1] - data_range[0]) 94 | assert cmap in [None, 'jet', 'magma'] 95 | if cmap == None: 96 | img = (img * 255.).astype(np.uint8) 97 | img = np.repeat(img[...,None], 3, axis=2) 98 | elif cmap == 'jet': 99 | img = (img * 255.).astype(np.uint8) 100 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 101 | elif cmap == 'magma': 102 | img = 1. 
- img 103 | base = cm.get_cmap('magma') 104 | num_bins = 256 105 | colormap = LinearSegmentedColormap.from_list( 106 | f"{base.name}{num_bins}", 107 | base(np.linspace(0, 1, num_bins)), 108 | num_bins 109 | )(np.linspace(0, 1, num_bins))[:,:3] 110 | a = np.floor(img * 255.) 111 | b = (a + 1).clip(max=255.) 112 | f = img * 255. - a 113 | a = a.astype(np.uint16).clip(0, 255) 114 | b = b.astype(np.uint16).clip(0, 255) 115 | img = colormap[a] + (colormap[b] - colormap[a]) * f[...,None] 116 | img = (img * 255.).astype(np.uint8) 117 | return img 118 | 119 | def save_grayscale_image(self, filename, img, data_range=DEFAULT_GRAYSCALE_KWARGS['data_range'], cmap=DEFAULT_GRAYSCALE_KWARGS['cmap']): 120 | img = self.get_grayscale_image_(img, data_range, cmap) 121 | cv2.imwrite(self.get_save_path(filename), img) 122 | 123 | def get_image_grid_(self, imgs): 124 | if isinstance(imgs[0], list): 125 | return np.concatenate([self.get_image_grid_(row) for row in imgs], axis=0) 126 | cols = [] 127 | for col in imgs: 128 | assert col['type'] in ['rgb', 'uv', 'grayscale'] 129 | if col['type'] == 'rgb': 130 | rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() 131 | rgb_kwargs.update(col['kwargs']) 132 | cols.append(self.get_rgb_image_(col['img'], **rgb_kwargs)) 133 | elif col['type'] == 'uv': 134 | uv_kwargs = self.DEFAULT_UV_KWARGS.copy() 135 | uv_kwargs.update(col['kwargs']) 136 | cols.append(self.get_uv_image_(col['img'], **uv_kwargs)) 137 | elif col['type'] == 'grayscale': 138 | grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() 139 | grayscale_kwargs.update(col['kwargs']) 140 | cols.append(self.get_grayscale_image_(col['img'], **grayscale_kwargs)) 141 | return np.concatenate(cols, axis=1) 142 | 143 | def save_image_grid(self, filename, imgs): 144 | img = self.get_image_grid_(imgs) 145 | cv2.imwrite(self.get_save_path(filename), img) 146 | 147 | def save_image(self, filename, img): 148 | img = self.convert_data(img) 149 | assert img.dtype == np.uint8 150 | if img.shape[-1] == 3: 151 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 152 | elif img.shape[-1] == 4: 153 | img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) 154 | cv2.imwrite(self.get_save_path(filename), img) 155 | 156 | def save_cubemap(self, filename, img, data_range=(0, 1)): 157 | img = self.convert_data(img) 158 | assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] 159 | 160 | imgs_full = [] 161 | for start in range(0, img.shape[-1], 3): 162 | img_ = img[...,start:start+3] 163 | img_ = np.stack([self.get_rgb_image_(img_[i], 'HWC', data_range) for i in range(img_.shape[0])], axis=0) 164 | size = img_.shape[1] 165 | placeholder = np.zeros((size, size, 3), dtype=np.float32) 166 | img_full = np.concatenate([ 167 | np.concatenate([placeholder, img_[2], placeholder, placeholder], axis=1), 168 | np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), 169 | np.concatenate([placeholder, img_[3], placeholder, placeholder], axis=1) 170 | ], axis=0) 171 | img_full = cv2.cvtColor(img_full, cv2.COLOR_RGB2BGR) 172 | imgs_full.append(img_full) 173 | 174 | imgs_full = np.concatenate(imgs_full, axis=1) 175 | cv2.imwrite(self.get_save_path(filename), imgs_full) 176 | 177 | def save_data(self, filename, data): 178 | data = self.convert_data(data) 179 | if isinstance(data, dict): 180 | if not filename.endswith('.npz'): 181 | filename += '.npz' 182 | np.savez(self.get_save_path(filename), **data) 183 | else: 184 | if not filename.endswith('.npy'): 185 | filename += '.npy' 186 | np.save(self.get_save_path(filename), data) 187 | 188 | def 
save_state_dict(self, filename, data): 189 | torch.save(data, self.get_save_path(filename)) 190 | 191 | def save_img_sequence(self, filename, img_dir, matcher, save_format='gif', fps=30): 192 | assert save_format in ['gif', 'mp4'] 193 | if not filename.endswith(save_format): 194 | filename += f".{save_format}" 195 | matcher = re.compile(matcher) 196 | img_dir = os.path.join(self.save_dir, img_dir) 197 | imgs = [] 198 | for f in os.listdir(img_dir): 199 | if matcher.search(f): 200 | imgs.append(f) 201 | imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) 202 | imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] 203 | 204 | if save_format == 'gif': 205 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 206 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps, palettesize=256) 207 | elif save_format == 'mp4': 208 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 209 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps) 210 | 211 | def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None, v_rgb=None, ortho_scale=1): 212 | v_pos, t_pos_idx = self.convert_data(v_pos), self.convert_data(t_pos_idx) 213 | if v_rgb is not None: 214 | v_rgb = self.convert_data(v_rgb) 215 | 216 | if ortho_scale is not None: 217 | print("ortho scale is: ", ortho_scale) 218 | v_pos = v_pos * ortho_scale * 0.5 219 | 220 | # change to front-facing 221 | v_pos_copy = np.zeros_like(v_pos) 222 | v_pos_copy[:, 0] = v_pos[:, 0] 223 | v_pos_copy[:, 1] = v_pos[:, 2] 224 | v_pos_copy[:, 2] = v_pos[:, 1] 225 | 226 | import trimesh 227 | mesh = trimesh.Trimesh( 228 | vertices=v_pos_copy, 229 | faces=t_pos_idx, 230 | vertex_colors=v_rgb 231 | ) 232 | trimesh.repair.fix_inversion(mesh) 233 | mesh.export(self.get_save_path(filename)) 234 | # mesh.export(self.get_save_path(filename.replace(".obj", "-meshlab.obj"))) 235 | 236 | # v_pos_copy[:, 0] = v_pos[:, 1] * -1 237 | # v_pos_copy[:, 1] = v_pos[:, 0] 238 | # v_pos_copy[:, 2] = v_pos[:, 2] 239 | 240 | # mesh = trimesh.Trimesh( 241 | # vertices=v_pos_copy, 242 | # faces=t_pos_idx, 243 | # vertex_colors=v_rgb 244 | # ) 245 | # mesh.export(self.get_save_path(filename.replace(".obj", "-blender.obj"))) 246 | 247 | 248 | # v_pos_copy[:, 0] = v_pos[:, 0] 249 | # v_pos_copy[:, 1] = v_pos[:, 1] * -1 250 | # v_pos_copy[:, 2] = v_pos[:, 2] * -1 251 | 252 | # mesh = trimesh.Trimesh( 253 | # vertices=v_pos_copy, 254 | # faces=t_pos_idx, 255 | # vertex_colors=v_rgb 256 | # ) 257 | # mesh.export(self.get_save_path(filename.replace(".obj", "-opengl.obj"))) 258 | 259 | def save_file(self, filename, src_path): 260 | shutil.copyfile(src_path, self.get_save_path(filename)) 261 | 262 | def save_json(self, filename, payload): 263 | with open(self.get_save_path(filename), 'w') as f: 264 | f.write(json.dumps(payload)) 265 | -------------------------------------------------------------------------------- /instant-nsr-pl/utils/obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def load_obj(filename): 5 | # Read entire file 6 | with open(filename, 'r') as f: 7 | lines = f.readlines() 8 | 9 | # load vertices 10 | vertices, texcoords = [], [] 11 | for line in lines: 12 | if len(line.split()) == 0: 13 | continue 14 | 15 | prefix = line.split()[0].lower() 16 | if prefix == 'v': 17 | vertices.append([float(v) for v in line.split()[1:]]) 18 | elif prefix == 'vt': 19 | val = [float(v) for v in line.split()[1:]] 20 | texcoords.append([val[0], 1.0 - val[1]]) 21 | 22 | 
uv = len(texcoords) > 0 23 | faces, tfaces = [], [] 24 | for line in lines: 25 | if len(line.split()) == 0: 26 | continue 27 | prefix = line.split()[0].lower() 28 | if prefix == 'usemtl': # Track used materials 29 | pass 30 | elif prefix == 'f': # Parse face 31 | vs = line.split()[1:] 32 | nv = len(vs) 33 | vv = vs[0].split('/') 34 | v0 = int(vv[0]) - 1 35 | if uv: 36 | t0 = int(vv[1]) - 1 if vv[1] != "" else -1 37 | for i in range(nv - 2): # Triangulate polygons 38 | vv1 = vs[i + 1].split('/') 39 | v1 = int(vv1[0]) - 1 40 | vv2 = vs[i + 2].split('/') 41 | v2 = int(vv2[0]) - 1 42 | faces.append([v0, v1, v2]) 43 | if uv: 44 | t1 = int(vv1[1]) - 1 if vv1[1] != "" else -1 45 | t2 = int(vv2[1]) - 1 if vv2[1] != "" else -1 46 | tfaces.append([t0, t1, t2]) 47 | vertices = np.array(vertices, dtype=np.float32) 48 | faces = np.array(faces, dtype=np.int64) 49 | if uv: 50 | assert len(tfaces) == len(faces) 51 | texcoords = np.array(texcoords, dtype=np.float32) 52 | tfaces = np.array(tfaces, dtype=np.int64) 53 | else: 54 | texcoords, tfaces = None, None 55 | 56 | return vertices, faces, texcoords, tfaces 57 | 58 | 59 | def write_obj(filename, v_pos, t_pos_idx, v_tex, t_tex_idx): 60 | with open(filename, "w") as f: 61 | for v in v_pos: 62 | f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) 63 | 64 | if v_tex is not None: 65 | assert(len(t_pos_idx) == len(t_tex_idx)) 66 | for v in v_tex: 67 | f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) 68 | 69 | # Write faces 70 | for i in range(len(t_pos_idx)): 71 | f.write("f ") 72 | for j in range(3): 73 | f.write(' %s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1))) 74 | f.write("\n") 75 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt: -------------------------------------------------------------------------------- 1 | -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08 2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08 3 | 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt: -------------------------------------------------------------------------------- 1 | -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07 2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07 3 | 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt: -------------------------------------------------------------------------------- 1 | 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07 2 | -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08 3 | 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt: -------------------------------------------------------------------------------- 1 | 
5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00 2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08 3 | -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt: -------------------------------------------------------------------------------- 1 | -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07 2 | 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07 3 | -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt: -------------------------------------------------------------------------------- 1 | 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07 2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07 3 | -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt: -------------------------------------------------------------------------------- 1 | -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00 2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08 3 | -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt: -------------------------------------------------------------------------------- 1 | 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08 2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08 3 | 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt: -------------------------------------------------------------------------------- 1 | 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09 2 | -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08 3 | -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00 4 | -------------------------------------------------------------------------------- /mvdiffusion/data/normal_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def camNormal2worldNormal(rot_c2w, camNormal): 4 | H,W,_ = camNormal.shape 5 | normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 6 | 7 | return normal_img 8 | 9 | def worldNormal2camNormal(rot_w2c, normal_map_world): 10 | H,W,_ = 
normal_map_world.shape 11 | # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) 12 | 13 | # faster version 14 | # Reshape the normal map into a 2D array where each row represents a normal vector 15 | normal_map_flat = normal_map_world.reshape(-1, 3) 16 | 17 | # Transform the normal vectors using the transformation matrix 18 | normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T) 19 | 20 | # Reshape the transformed normal map back to its original shape 21 | normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape) 22 | 23 | return normal_map_camera 24 | 25 | def trans_normal(normal, RT_w2c, RT_w2c_target): 26 | 27 | # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) 28 | # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) 29 | 30 | relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3])) 31 | normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal) 32 | 33 | return normal_target_cam 34 | 35 | def img2normal(img): 36 | return (img/255.)*2-1 37 | 38 | def normal2img(normal): 39 | return np.uint8((normal*0.5+0.5)*255) 40 | 41 | def norm_normalize(normal, dim=-1): 42 | 43 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) 44 | 45 | return normal -------------------------------------------------------------------------------- /mvdiffusion/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | from omegaconf import DictConfig, ListConfig 4 | import torch 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import json 8 | from PIL import Image 9 | from torchvision import transforms 10 | from einops import rearrange 11 | from typing import Literal, Tuple, Optional, Any 12 | import cv2 13 | import random 14 | 15 | import json 16 | import os, sys 17 | import math 18 | 19 | from glob import glob 20 | 21 | import PIL.Image 22 | from .normal_utils import trans_normal, normal2img, img2normal 23 | import pdb 24 | 25 | 26 | import cv2 27 | import numpy as np 28 | 29 | def add_margin(pil_img, color=0, size=256): 30 | width, height = pil_img.size 31 | result = Image.new(pil_img.mode, (size, size), color) 32 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) 33 | return result 34 | 35 | def scale_and_place_object(image, scale_factor): 36 | assert np.shape(image)[-1]==4 # RGBA 37 | 38 | # Extract the alpha channel (transparency) and the object (RGB channels) 39 | alpha_channel = image[:, :, 3] 40 | 41 | # Find the bounding box coordinates of the object 42 | coords = cv2.findNonZero(alpha_channel) 43 | x, y, width, height = cv2.boundingRect(coords) 44 | 45 | # Calculate the scale factor for resizing 46 | original_height, original_width = image.shape[:2] 47 | 48 | if width > height: 49 | size = width 50 | original_size = original_width 51 | else: 52 | size = height 53 | original_size = original_height 54 | 55 | scale_factor = min(scale_factor, size / (original_size+0.0)) 56 | 57 | new_size = scale_factor * original_size 58 | scale_factor = new_size / size 59 | 60 | # Calculate the new size based on the scale factor 61 | new_width = int(width * scale_factor) 62 | new_height = int(height * scale_factor) 63 | 64 | center_x = original_width // 2 65 | center_y = original_height // 2 66 | 67 | paste_x = center_x - (new_width // 2) 68 | paste_y = center_y - (new_height // 2) 69 | 70 | # 
Resize the object (RGB channels) to the new size 71 | rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height)) 72 | 73 | # Create a new RGBA image with the resized image 74 | new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8) 75 | 76 | new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object 77 | 78 | return new_image 79 | 80 | class SingleImageDataset(Dataset): 81 | def __init__(self, 82 | root_dir: str, 83 | num_views: int, 84 | img_wh: Tuple[int, int], 85 | bg_color: str, 86 | crop_size: int = 224, 87 | single_image: Optional[PIL.Image.Image] = None, 88 | num_validation_samples: Optional[int] = None, 89 | filepaths: Optional[list] = None, 90 | cond_type: Optional[str] = None 91 | ) -> None: 92 | """Create a dataset from a folder of images. 93 | If you pass in a root directory it will be searched for images 94 | ending in ext (ext can be a list) 95 | """ 96 | self.root_dir = root_dir 97 | self.num_views = num_views 98 | self.img_wh = img_wh 99 | self.crop_size = crop_size 100 | self.bg_color = bg_color 101 | self.cond_type = cond_type 102 | 103 | if self.num_views == 4: 104 | self.view_types = ['front', 'right', 'back', 'left'] 105 | elif self.num_views == 5: 106 | self.view_types = ['front', 'front_right', 'right', 'back', 'left'] 107 | elif self.num_views == 6: 108 | self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] 109 | 110 | self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views" 111 | 112 | self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix 113 | 114 | if single_image is None: 115 | if filepaths is None: 116 | # Get a list of all files in the directory 117 | file_list = os.listdir(self.root_dir) 118 | else: 119 | file_list = filepaths 120 | 121 | # Filter the files that end with .png or .jpg 122 | self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))] 123 | else: 124 | self.file_list = None 125 | 126 | # load all images 127 | self.all_images = [] 128 | self.all_alphas = [] 129 | bg_color = self.get_bg_color() 130 | 131 | if single_image is not None: 132 | image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image) 133 | self.all_images.append(image) 134 | self.all_alphas.append(alpha) 135 | else: 136 | for file in self.file_list: 137 | print(os.path.join(self.root_dir, file)) 138 | image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt') 139 | self.all_images.append(image) 140 | self.all_alphas.append(alpha) 141 | 142 | self.all_images = self.all_images[:num_validation_samples] 143 | self.all_alphas = self.all_alphas[:num_validation_samples] 144 | 145 | 146 | def __len__(self): 147 | return len(self.all_images) 148 | 149 | def load_fixed_poses(self): 150 | poses = {} 151 | for face in self.view_types: 152 | RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face))) 153 | poses[face] = RT 154 | 155 | return poses 156 | 157 | def cartesian_to_spherical(self, xyz): 158 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) 159 | xy = xyz[:,0]**2 + xyz[:,1]**2 160 | z = np.sqrt(xy + xyz[:,2]**2) 161 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down 162 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up 163 | azimuth = np.arctan2(xyz[:,1], xyz[:,0]) 164 | return np.array([theta, azimuth, z]) 165 | 166 | def get_T(self, target_RT, cond_RT): 167 | R, T = 
target_RT[:3, :3], target_RT[:, -1] 168 | T_target = -R.T @ T # change to cam2world 169 | 170 | R, T = cond_RT[:3, :3], cond_RT[:, -1] 171 | T_cond = -R.T @ T 172 | 173 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) 174 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) 175 | 176 | d_theta = theta_target - theta_cond 177 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 178 | d_z = z_target - z_cond 179 | 180 | # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) 181 | return d_theta, d_azimuth 182 | 183 | def get_bg_color(self): 184 | if self.bg_color == 'white': 185 | bg_color = np.array([1., 1., 1.], dtype=np.float32) 186 | elif self.bg_color == 'black': 187 | bg_color = np.array([0., 0., 0.], dtype=np.float32) 188 | elif self.bg_color == 'gray': 189 | bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) 190 | elif self.bg_color == 'random': 191 | bg_color = np.random.rand(3) 192 | elif isinstance(self.bg_color, float): 193 | bg_color = np.array([self.bg_color] * 3, dtype=np.float32) 194 | else: 195 | raise NotImplementedError 196 | return bg_color 197 | 198 | 199 | def load_image(self, img_path, bg_color, return_type='np', Imagefile=None): 200 | # pil always returns uint8 201 | if Imagefile is None: 202 | image_input = Image.open(img_path) 203 | else: 204 | image_input = Imagefile 205 | image_size = self.img_wh[0] 206 | 207 | if self.crop_size!=-1: 208 | alpha_np = np.asarray(image_input)[:, :, 3] 209 | coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] 210 | min_x, min_y = np.min(coords, 0) 211 | max_x, max_y = np.max(coords, 0) 212 | ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) 213 | h, w = ref_img_.height, ref_img_.width 214 | scale = self.crop_size / max(h, w) 215 | h_, w_ = int(scale * h), int(scale * w) 216 | ref_img_ = ref_img_.resize((w_, h_)) 217 | image_input = add_margin(ref_img_, size=image_size) 218 | else: 219 | image_input = add_margin(image_input, size=max(image_input.height, image_input.width)) 220 | image_input = image_input.resize((image_size, image_size)) 221 | 222 | # img = scale_and_place_object(img, self.scale_ratio) 223 | img = np.array(image_input) 224 | img = img.astype(np.float32) / 255. 
# [0, 1] 225 | assert img.shape[-1] == 4 # RGBA 226 | 227 | alpha = img[...,3:4] 228 | img = img[...,:3] * alpha + bg_color * (1 - alpha) 229 | 230 | if return_type == "np": 231 | pass 232 | elif return_type == "pt": 233 | img = torch.from_numpy(img) 234 | alpha = torch.from_numpy(alpha) 235 | else: 236 | raise NotImplementedError 237 | 238 | return img, alpha 239 | 240 | 241 | def __len__(self): 242 | return len(self.all_images) 243 | 244 | def __getitem__(self, index): 245 | 246 | image = self.all_images[index%len(self.all_images)] 247 | alpha = self.all_alphas[index%len(self.all_images)] 248 | if self.file_list is not None: 249 | filename = self.file_list[index%len(self.all_images)].replace(".png", "") 250 | else: 251 | filename = 'null' 252 | 253 | cond_w2c = self.fix_cam_poses['front'] 254 | 255 | tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types] 256 | 257 | elevations = [] 258 | azimuths = [] 259 | 260 | img_tensors_in = [ 261 | image.permute(2, 0, 1) 262 | ] * self.num_views 263 | 264 | alpha_tensors_in = [ 265 | alpha.permute(2, 0, 1) 266 | ] * self.num_views 267 | 268 | for view, tgt_w2c in zip(self.view_types, tgt_w2cs): 269 | # elevations, azimuths 270 | elevation, azimuth = self.get_T(tgt_w2c, cond_w2c) 271 | elevations.append(elevation) 272 | azimuths.append(azimuth) 273 | 274 | img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W) 275 | alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W) 276 | 277 | elevations = torch.as_tensor(elevations).float().squeeze(1) 278 | azimuths = torch.as_tensor(azimuths).float().squeeze(1) 279 | elevations_cond = torch.as_tensor([0] * self.num_views).float() 280 | 281 | normal_class = torch.tensor([1, 0]).float() 282 | normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2) 283 | color_class = torch.tensor([0, 1]).float() 284 | color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2) 285 | 286 | camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3) 287 | 288 | out = { 289 | 'elevations_cond': elevations_cond, 290 | 'elevations_cond_deg': torch.rad2deg(elevations_cond), 291 | 'elevations': elevations, 292 | 'azimuths': azimuths, 293 | 'elevations_deg': torch.rad2deg(elevations), 294 | 'azimuths_deg': torch.rad2deg(azimuths), 295 | 'imgs_in': img_tensors_in, 296 | 'alphas': alpha_tensors_in, 297 | 'camera_embeddings': camera_embeddings, 298 | 'normal_task_embeddings': normal_task_embeddings, 299 | 'color_task_embeddings': color_task_embeddings, 300 | 'filename': filename, 301 | } 302 | 303 | return out 304 | 305 | 306 | -------------------------------------------------------------------------------- /render_codes/README.md: -------------------------------------------------------------------------------- 1 | # Prepare the rendering data 2 | 3 | ## Environment 4 | The rendering code is mainly based on [BlenderProc](https://github.com/DLR-RM/BlenderProc); thanks to its authors for the great tool. 5 | BlenderProc renders with Blender's Cycles engine by default, which may hang for a long time on certain GPUs (e.g., the A800; this has already been observed). 6 | 7 | ```bash 8 | cd ./render_codes 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## How to use the code 13 | Here we provide two rendering scripts, `blenderProc_ortho.py` and `blenderProc_persp.py`, which render objects with an **orthographic** camera and a **perspective** camera, respectively. 
14 | 15 | ### Use `blenderProc_ortho.py` to render images of a single object 16 | ```bash 17 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \ 18 | blenderProc_ortho.py \ 19 | --object_path /mnt/pfs/data/objaverse_lvis_glbs/c7/c70e8817b5a945aca8bb37e02ddbc6f9.glb --view 0 \ 20 | --output_folder ./out_renderings/ \ 21 | --object_uid c70e8817b5a945aca8bb37e02ddbc6f9 \ 22 | --ortho_scale 1.35 \ 23 | --resolution 512 \ 24 | --random_pose 25 | ``` 26 | 27 | Here `--view` is a tag for the rendered images (useful when the same object is rendered multiple times), `--ortho_scale` controls how large the object appears in the rendered image, and `--random_pose` randomly rotates the object before rendering. 28 | 29 | 30 | ### Use `blenderProc_persp.py` to render images of a single object 31 | 32 | ```bash 33 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \ 34 | blenderProc_persp.py \ 35 | --object_path ${the object path} --view 0 \ 36 | --output_folder ${your save path} \ 37 | --object_uid ${object_uid} --radius 2.0 \ 38 | --random_pose 39 | ``` 40 | 41 | Here `--radius` denotes the distance between the camera and the object origin. 42 | 43 | ### Render objects in distributed mode 44 | See `render_batch_ortho.sh` and `render_batch_persp.sh` for example commands. 45 | 46 | -------------------------------------------------------------------------------- /render_codes/distributed.py: -------------------------------------------------------------------------------- 1 | # multiprocessing render 2 | import json 3 | import multiprocessing 4 | import subprocess 5 | from dataclasses import dataclass 6 | from typing import Optional 7 | import os 8 | 9 | import boto3 10 | 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description='distributed rendering') 14 | 15 | parser.add_argument('--workers_per_gpu', type=int, 16 | help='number of workers per gpu.') 17 | parser.add_argument('--input_models_path', type=str, 18 | help='Path to a json file containing a list of 3D object files.') 19 | parser.add_argument('--upload_to_s3', type=bool, default=False, 20 | help='Whether to upload the rendered images to S3.') 21 | parser.add_argument('--log_to_wandb', type=bool, default=False, 22 | help='Whether to log the progress to wandb.') 23 | parser.add_argument('--num_gpus', type=int, default=-1, 24 | help='number of gpus to use. 
-1 means all available gpus.') 25 | parser.add_argument('--gpu_list', nargs='+', type=int, 26 | help='the available gpus') 27 | 28 | parser.add_argument('--mode', type=str, default='render_ortho', 29 | choices=['render_ortho', 'render_persp'], 30 | help='use an orthographic camera or a perspective camera') 31 | 32 | parser.add_argument('--start_i', type=int, default=0, 33 | help='the index of first object to be rendered.') 34 | 35 | parser.add_argument('--end_i', type=int, default=-1, 36 | help='the index of the last object to be rendered.') 37 | 38 | parser.add_argument('--objaverse_root', type=str, default='/ghome/l5/xxlong/.objaverse/hf-objaverse-v1', 39 | help='Root directory of the downloaded Objaverse .glb files.') 40 | 41 | parser.add_argument('--save_folder', type=str, default=None, 42 | help='Folder to save the rendered images.') 43 | 44 | parser.add_argument('--blender_install_path', type=str, default=None, 45 | help='blender installation path.') 46 | 47 | parser.add_argument('--view_idx', type=int, default=2, 48 | help='index (tag) of this rendering pass.') 49 | 50 | parser.add_argument('--ortho_scale', type=float, default=1.25, 51 | help='orthographic rendering only; how large the object appears in the image') 52 | 53 | parser.add_argument('--random_pose', action='store_true', 54 | help='whether to randomly rotate the object before rendering') 55 | 56 | args = parser.parse_args() 57 | 58 | 59 | view_idx = args.view_idx 60 | 61 | VIEWS = ["front", "back", "right", "left", "front_right", "front_left", "back_right", "back_left"] 62 | 63 | def check_task_finish(render_dir, view_index): # True only if every expected rgb/normal webp exists 64 | files_type = ['rgb', 'normals'] 65 | flag = True 66 | view_index = "%03d" % view_index 67 | if os.path.exists(render_dir): 68 | for t in files_type: 69 | for face in VIEWS: 70 | fpath = os.path.join(render_dir, f'{t}_{view_index}_{face}.webp') 71 | # print(fpath) 72 | if not os.path.exists(fpath): 73 | flag = False 74 | else: 75 | flag = False 76 | 77 | return flag 78 | 79 | def worker( 80 | queue: multiprocessing.JoinableQueue, 81 | count: multiprocessing.Value, 82 | gpu: int, 83 | s3: Optional[boto3.client], 84 | ) -> None: 85 | while True: 86 | item = queue.get() 87 | if item is None: # sentinel: no more work 88 | break 89 | 90 | view_path = os.path.join(args.save_folder, item.split('/')[-1][:2], item.split('/')[-1][:-4]) 91 | print(view_path) 92 | if 'render' in args.mode: 93 | if check_task_finish(view_path, view_idx): 94 | queue.task_done() 95 | print('========', item, 'rendered', '========') 96 | 97 | continue 98 | else: 99 | os.makedirs(view_path, exist_ok = True) 100 | 101 | # Perform some operation on the item 102 | print(item, gpu) 103 | 104 | if args.mode == 'render_ortho': 105 | command = ( 106 | f" CUDA_VISIBLE_DEVICES={gpu} " 107 | f" blenderproc run --blender-install-path {args.blender_install_path} blenderProc_ortho.py" 108 | f" --object_path {item} --view {view_idx}" 109 | f" --output_folder {args.save_folder}" 110 | f" --ortho_scale {args.ortho_scale} " 111 | ) 112 | if args.random_pose: 113 | print("random pose to render") 114 | command += f" --random_pose" 115 | elif args.mode == 'render_persp': 116 | command = ( 117 | f" CUDA_VISIBLE_DEVICES={gpu} " 118 | f" blenderproc run --blender-install-path {args.blender_install_path} blenderProc_persp.py" 119 | f" --object_path {item} --view {view_idx}" 120 | f" --output_folder {args.save_folder}" 121 | ) 122 | if args.random_pose: 123 | print("random pose to render") 124 | command += f" --random_pose" 125 | 126 | print(command) 127 | subprocess.run(command, shell=True) 128 | 129 | with 
count.get_lock(): 130 | count.value += 1 131 | 132 | queue.task_done() 133 | 134 | 135 | if __name__ == "__main__": 136 | # args = tyro.cli(Args) 137 | 138 | s3 = boto3.client("s3") if args.upload_to_s3 else None 139 | queue = multiprocessing.JoinableQueue() 140 | count = multiprocessing.Value("i", 0) 141 | 142 | # Start worker processes on each of the GPUs 143 | for gpu_i in range(args.num_gpus): 144 | for worker_i in range(args.workers_per_gpu): 145 | worker_i = gpu_i * args.workers_per_gpu + worker_i 146 | process = multiprocessing.Process( 147 | target=worker, args=(queue, count, args.gpu_list[gpu_i], s3) 148 | ) 149 | process.daemon = True 150 | process.start() 151 | 152 | # Add items to the queue 153 | if args.input_models_path is not None: 154 | with open(args.input_models_path, "r") as f: 155 | model_paths = json.load(f) 156 | 157 | args.end_i = len(model_paths) if args.end_i > len(model_paths) else args.end_i 158 | 159 | for item in model_paths[args.start_i:args.end_i]: 160 | 161 | if os.path.exists(os.path.join(args.objaverse_root, os.path.basename(item))): 162 | obj_path = os.path.join(args.objaverse_root, os.path.basename(item)) 163 | elif os.path.exists(os.path.join(args.objaverse_root, item)): 164 | obj_path = os.path.join(args.objaverse_root, item) 165 | else: 166 | obj_path = os.path.join(args.objaverse_root, item[:2], item+".glb") 167 | queue.put(obj_path) 168 | 169 | # Wait for all tasks to be completed 170 | queue.join() 171 | 172 | # Add sentinels to the queue to stop the worker processes 173 | for i in range(args.num_gpus * args.workers_per_gpu): 174 | queue.put(None) 175 | -------------------------------------------------------------------------------- /render_codes/render_batch_ortho.sh: -------------------------------------------------------------------------------- 1 | python distributed.py \ 2 | --num_gpus 8 --gpu_list 0 1 2 3 4 5 6 7 --mode render_ortho \ 3 | --workers_per_gpu 10 --view_idx $1 \ 4 | --start_i $2 --end_i $3 --ortho_scale 1.35 \ 5 | --input_models_path ../data_lists/lvis_uids_filter_by_vertex.json \ 6 | --objaverse_root /data/objaverse \ 7 | --save_folder data/obj_lvis_13views \ 8 | --blender_install_path /workplace/blender \ 9 | --random_pose 10 | -------------------------------------------------------------------------------- /render_codes/render_batch_persp.sh: -------------------------------------------------------------------------------- 1 | python distributed.py \ 2 | --num_gpus 8 --gpu_list 0 1 2 3 4 5 6 7 --mode render_persp \ 3 | --workers_per_gpu 10 --view_idx $1 \ 4 | --start_i $2 --end_i $3 \ 5 | --input_models_path ../data_lists/lvis/lvis_uids_filter_by_vertex.json \ 6 | --save_folder /data/nineviews-pinhole \ 7 | --objaverse_root /objaverse \ 8 | --blender_install_path /workplace/blender \ 9 | --random_pose 10 | -------------------------------------------------------------------------------- /render_codes/render_single_ortho.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \ 3 | blenderProc_nineviews_ortho.py \ 4 | --object_path /mnt/pfs/data/objaverse_lvis_glbs/c7/c70e8817b5a945aca8bb37e02ddbc6f9.glb --view 0 \ 5 | --output_folder ./out_renderings/ \ 6 | --object_uid c70e8817b5a945aca8bb37e02ddbc6f9 \ 7 | --ortho_scale 1.35 \ 8 | --resolution 512 \ 9 | # --reset_object_euler -------------------------------------------------------------------------------- 
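The `render_batch_*.sh` scripts above pass `--input_models_path` to `distributed.py`, which simply `json.load`s that file and treats each entry as an Objaverse uid or object path, resolving bare uids to `<objaverse_root>/<uid[:2]>/<uid>.glb`. Below is a minimal sketch of how such a list might be produced; the uids and the output filename are illustrative placeholders, not files shipped with the repository.

```python
import json

# Minimal sketch: write the JSON list consumed by distributed.py (--input_models_path).
# Bare uids are resolved by distributed.py to <objaverse_root>/<uid[:2]>/<uid>.glb;
# paths relative to --objaverse_root also work. The uids/filename here are placeholders.
uids = [
    "c70e8817b5a945aca8bb37e02ddbc6f9",
    # ... add the rest of your object uids ...
]

with open("my_uid_list.json", "w") as f:
    json.dump(uids, f, indent=2)
```

The resulting file can then be passed as `--input_models_path my_uid_list.json`, with `--objaverse_root` pointing at the directory holding the downloaded `.glb` files.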
/render_codes/render_single_persp.sh: -------------------------------------------------------------------------------- 1 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \ 2 | blenderProc_persp.py \ 3 | --object_path ${the object path} --view 0 \ 4 | --output_folder ${your save path} \ 5 | --object_uid ${object_uid} --radius 2.0 \ 6 | --random_pose -------------------------------------------------------------------------------- /render_codes/requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | blenderproc==2.5.0 3 | boto3==1.26.105 4 | docstring-parser==0.14.1 5 | h5py==3.8.0 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | markdown==3.4.3 9 | matplotlib==3.7.1 10 | multiprocess==0.70.13 11 | objaverse==0.0.7 12 | omegaconf==2.1.1 13 | opencv-python==4.5.5.64 14 | opencv-python-headless==4.7.0.72 15 | pillow==9.4.0 16 | progressbar==2.5 17 | scikit-image==0.20.0 18 | termcolor==2.2.0 19 | tqdm==4.65.0 20 | tyro==0.3.38 21 | vtk==9.2.6 22 | wandb==0.14.0 23 | zipp==3.15.0 24 | natsort 25 | mathutils==3.3.0 26 | argparse 27 | 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==1.13.1 3 | torchvision==0.14.1 4 | diffusers[torch]==0.19.3 5 | xformers==0.0.16 6 | transformers>=4.25.1 7 | bitsandbytes==0.35.4 8 | decord==0.6.0 9 | pytorch-lightning<2 10 | omegaconf==2.2.3 11 | nerfacc==0.3.3 12 | trimesh==3.9.8 13 | pyhocon==0.3.57 14 | icecream==2.1.0 15 | PyMCubes==0.1.2 16 | accelerate 17 | modelcards 18 | einops 19 | ftfy 20 | piq 21 | matplotlib 22 | opencv-python 23 | imageio 24 | imageio-ffmpeg 25 | scipy 26 | pyransac3d 27 | torch_efficient_distloss 28 | tensorboard 29 | rembg 30 | segment_anything 31 | gradio==3.50.2 32 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py --config configs/mvdiffusion-joint-ortho-6views.yaml -------------------------------------------------------------------------------- /run_train_stage1.sh: -------------------------------------------------------------------------------- 1 | 2 | # stage 1 3 | accelerate launch --config_file 1gpu.yaml train_mvdiffusion_image.py --config configs/train/stage1-mix-6views-lvis.yaml -------------------------------------------------------------------------------- /run_train_stage2.sh: -------------------------------------------------------------------------------- 1 | # stage 2 2 | accelerate launch --config_file 1gpu.yaml train_mvdiffusion_joint.py --config configs/train/stage2-joint-6views-lvis.yaml -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | from packaging import version 4 | 5 | 6 | # ============ Register OmegaConf Resolvers ============= # 7 | OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) 8 | OmegaConf.register_new_resolver('add', lambda a, b: a + b) 9 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b) 10 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b) 11 | OmegaConf.register_new_resolver('div', 
lambda a, b: a / b) 12 | OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) 13 | OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) 14 | # ======================================================= # 15 | 16 | 17 | def prompt(question): 18 | inp = input(f"{question} (y/n)").lower().strip() 19 | if inp and inp == 'y': 20 | return True 21 | if inp and inp == 'n': 22 | return False 23 | return prompt(question) 24 | 25 | 26 | def load_config(*yaml_files, cli_args=[]): 27 | yaml_confs = [OmegaConf.load(f) for f in yaml_files] 28 | cli_conf = OmegaConf.from_cli(cli_args) 29 | conf = OmegaConf.merge(*yaml_confs, cli_conf) 30 | OmegaConf.resolve(conf) 31 | return conf 32 | 33 | 34 | def config_to_primitive(config, resolve=True): 35 | return OmegaConf.to_container(config, resolve=resolve) 36 | 37 | 38 | def dump_config(path, config): 39 | with open(path, 'w') as fp: 40 | OmegaConf.save(config=config, f=fp) 41 | 42 | def get_rank(): 43 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 44 | # therefore LOCAL_RANK needs to be checked first 45 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 46 | for key in rank_keys: 47 | rank = os.environ.get(key) 48 | if rank is not None: 49 | return int(rank) 50 | return 0 51 | 52 | 53 | def parse_version(ver): 54 | return version.parse(ver) 55 | --------------------------------------------------------------------------------
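To make the configuration utilities above concrete, here is a minimal usage sketch of `load_config` together with the registered resolvers; the `demo.yaml` file, its keys, and the CLI override are hypothetical and shown only for illustration.

```python
# Hypothetical usage of utils/misc.py; demo.yaml and its keys are illustrative only.
#
# demo.yaml might contain:
#   batch_size: 4
#   total_batch_size: ${mul:${batch_size},8}      # batch_size * 8 via the 'mul' resolver
#   lr_decay: ${calc_exp_lr_decay_rate:0.1,10}    # 0.1 ** (1 / 10)
#   run_name: ${basename:/tmp/exps/run_001}       # 'run_001'
from utils.misc import load_config, config_to_primitive

cfg = load_config("demo.yaml", cli_args=["batch_size=2"])  # CLI dotlist entries override the YAML
print(config_to_primitive(cfg))  # plain dict with all interpolations resolved
```

Because `load_config` calls `OmegaConf.resolve` after merging, derived entries such as `total_batch_size` are computed from the overridden values.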