├── .gitignore
├── 1gpu.yaml
├── 8gpu.yaml
├── LICENSE
├── NeuS
│   ├── confs
│   │   └── wmask.conf
│   ├── exp_runner.py
│   ├── models
│   │   ├── dataset_mvdiff.py
│   │   ├── embedder.py
│   │   ├── fields.py
│   │   ├── fixed_poses
│   │   │   ├── 000_back_RT.txt
│   │   │   ├── 000_back_left_RT.txt
│   │   │   ├── 000_back_right_RT.txt
│   │   │   ├── 000_front_RT.txt
│   │   │   ├── 000_front_left_RT.txt
│   │   │   ├── 000_front_right_RT.txt
│   │   │   ├── 000_left_RT.txt
│   │   │   ├── 000_right_RT.txt
│   │   │   └── 000_top_RT.txt
│   │   ├── normal_utils.py
│   │   ├── ops.py
│   │   └── renderer.py
│   └── run.sh
├── README.md
├── README_zh.md
├── assets
│   ├── bug_fixed.png
│   ├── coordinate.png
│   └── fig_teaser.png
├── configs
│   ├── mvdiffusion-joint-ortho-6views.yaml
│   └── train
│       ├── stage1-mix-6views-lvis.yaml
│       └── stage2-joint-6views-lvis.yaml
├── data_lists
│   ├── lvis_invalid_uids_nineviews.json
│   └── lvis_uids_filter_by_vertex.json
├── docker
│   ├── Dockerfile
│   ├── README.md
│   └── requirements.txt
├── example_images
│   ├── 14_10_29_489_Tiger_1__1.png
│   ├── box.png
│   ├── bread.png
│   ├── cat.png
│   ├── cat_head.png
│   ├── chili.png
│   ├── duola.png
│   ├── halloween.png
│   ├── head.png
│   ├── kettle.png
│   ├── kunkun.png
│   ├── milk.png
│   ├── owl.png
│   ├── poro.png
│   ├── pumpkin.png
│   ├── skull.png
│   ├── stone.png
│   ├── teapot.png
│   └── tiger-head-3d-model-obj-stl.png
├── gradio_app_mv.py
├── gradio_app_recon.py
├── instant-nsr-pl
│   ├── README.md
│   ├── configs
│   │   └── neuralangelo-ortho-wmask.yaml
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── blender.py
│   │   ├── colmap.py
│   │   ├── colmap_utils.py
│   │   ├── dtu.py
│   │   ├── fixed_poses
│   │   │   ├── 000_back_RT.txt
│   │   │   ├── 000_back_left_RT.txt
│   │   │   ├── 000_back_right_RT.txt
│   │   │   ├── 000_front_RT.txt
│   │   │   ├── 000_front_left_RT.txt
│   │   │   ├── 000_front_right_RT.txt
│   │   │   ├── 000_left_RT.txt
│   │   │   ├── 000_right_RT.txt
│   │   │   └── 000_top_RT.txt
│   │   ├── ortho.py
│   │   └── utils.py
│   ├── launch.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── geometry.py
│   │   ├── nerf.py
│   │   ├── network_utils.py
│   │   ├── neus.py
│   │   ├── ray_utils.py
│   │   ├── texture.py
│   │   └── utils.py
│   ├── requirements.txt
│   ├── run.sh
│   ├── scripts
│   │   └── imgs2poses.py
│   ├── systems
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── criterions.py
│   │   ├── nerf.py
│   │   ├── neus.py
│   │   ├── neus_ortho.py
│   │   ├── neus_pinhole.py
│   │   └── utils.py
│   └── utils
│       ├── __init__.py
│       ├── callbacks.py
│       ├── loggers.py
│       ├── misc.py
│       ├── mixins.py
│       └── obj.py
├── mvdiffusion
│   ├── data
│   │   ├── fixed_poses
│   │   │   └── nine_views
│   │   │       ├── 000_back_RT.txt
│   │   │       ├── 000_back_left_RT.txt
│   │   │       ├── 000_back_right_RT.txt
│   │   │       ├── 000_front_RT.txt
│   │   │       ├── 000_front_left_RT.txt
│   │   │       ├── 000_front_right_RT.txt
│   │   │       ├── 000_left_RT.txt
│   │   │       ├── 000_right_RT.txt
│   │   │       └── 000_top_RT.txt
│   │   ├── normal_utils.py
│   │   ├── objaverse_dataset.py
│   │   └── single_image_dataset.py
│   ├── models
│   │   ├── transformer_mv2d.py
│   │   ├── unet_mv2d_blocks.py
│   │   └── unet_mv2d_condition.py
│   └── pipelines
│       └── pipeline_mvdiffusion_image.py
├── render_codes
│   ├── README.md
│   ├── blenderProc_ortho.py
│   ├── blenderProc_persp.py
│   ├── distributed.py
│   ├── render_batch_ortho.sh
│   ├── render_batch_persp.sh
│   ├── render_single_ortho.sh
│   ├── render_single_persp.sh
│   └── requirements.txt
├── requirements.txt
├── run_test.sh
├── run_train_stage1.sh
├── run_train_stage2.sh
├── test_mvdiffusion_seq.py
├── train_mvdiffusion_image.py
├── train_mvdiffusion_joint.py
└── utils
    └── misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | ckpts
4 | sam_pt
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # tests and logs
14 | tests/fixtures/cached_*_text.txt
15 | logs/
16 | lightning_logs/
17 | lang_code_data/
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 |
123 | # Pyre type checker
124 | .pyre/
125 |
126 | # vscode
127 | .vs
128 | .vscode
129 |
130 | # Pycharm
131 | .idea
132 |
133 | # TF code
134 | tensorflow_code
135 |
136 | # Models
137 | proc_data
138 |
139 | # examples
140 | runs
141 | /runs_old
142 | /wandb
143 | /examples/runs
144 | /examples/**/*.args
145 | /examples/rag/sweep
146 |
147 | # data
148 | /data
149 | serialization_dir
150 |
151 | # emacs
152 | *.*~
153 | debug.env
154 |
155 | # vim
156 | .*.swp
157 |
158 | #ctags
159 | tags
160 |
161 | # pre-commit
162 | .pre-commit*
163 |
164 | # .lock
165 | *.lock
166 |
167 | # DS_Store (MacOS)
168 | .DS_Store
169 | # RL pipelines may produce mp4 outputs
170 | *.mp4
171 |
172 | # dependencies
173 | /transformers
174 |
175 | # ruff
176 | .ruff_cache
177 |
178 | # ckpts
179 | *.ckpt
180 |
181 | outputs/*
182 |
183 | NeuS/exp/*
184 | NeuS/test_scenes/*
185 | NeuS/mesh2tex/*
186 | neus_configs
187 | vast/*
188 | render_results
189 | experiments/*
190 | ckpts/*
191 | neus/*
192 | instant-nsr-pl/exp/*
--------------------------------------------------------------------------------
/1gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: 'NO'
3 | downcast_bf16: 'no'
4 | gpu_ids: '0'
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: 'no'
8 | num_machines: 1
9 | num_processes: 1
10 | rdzv_backend: static
11 | same_network: true
12 | tpu_env: []
13 | tpu_use_cluster: false
14 | tpu_use_sudo: false
15 | use_cpu: false
16 |
--------------------------------------------------------------------------------
/8gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | downcast_bf16: 'no'
4 | gpu_ids: all
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: 'no'
8 | num_machines: 1
9 | num_processes: 8
10 | rdzv_backend: static
11 | same_network: true
12 | tpu_env: []
13 | tpu_use_cluster: false
14 | tpu_use_sudo: false
15 | use_cpu: false
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 xxlong0
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NeuS/confs/wmask.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = ./exp/neus/CASE_NAME/
3 | recording = [
4 | ./,
5 | ./models
6 | ]
7 | }
8 |
9 | dataset {
10 | data_dir = ./outputs/
11 | object_name = CASE_NAME
12 | object_viewidx = 1
13 | imSize = [256, 256]
14 | load_color = True
15 | stage = coarse
16 | mtype = mlp
17 | normal_system: front
18 | num_views = 6
19 | }
20 |
21 | train {
22 | learning_rate = 5e-4
23 | learning_rate_alpha = 0.05
24 | end_iter = 1000 # longer time, better result. 1w will be ok for most cases
25 |
26 | batch_size = 512
27 | validate_resolution_level = 1
28 | warm_up_end = 500
29 | anneal_end = 0
30 | use_white_bkgd = True
31 |
32 | save_freq = 5000
33 | val_freq = 5000
34 | val_mesh_freq =5000
35 | report_freq = 100
36 |
37 | color_weight = 1.0
38 | igr_weight = 0.1
39 | mask_weight = 1.0
40 | normal_weight = 1.0
41 | sparse_weight = 0.1
42 |
43 | }
44 |
45 | model {
46 | nerf {
47 | D = 8,
48 | d_in = 4,
49 | d_in_view = 3,
50 | W = 256,
51 | multires = 10,
52 | multires_view = 4,
53 | output_ch = 4,
54 | skips=[4],
55 | use_viewdirs=True
56 | }
57 |
58 | sdf_network {
59 | d_out = 257
60 | d_in = 3
61 | d_hidden = 256
62 | n_layers = 8
63 | skip_in = [4]
64 | multires = 6
65 | bias = 0.5
66 | scale = 1.0
67 | geometric_init = True
68 | weight_norm = True
69 | }
70 |
71 | variance_network {
72 | init_val = 0.3
73 | }
74 |
75 | rendering_network {
76 | d_feature = 256
77 | mode = no_view_dir
78 | d_in = 6
79 | d_out = 3
80 | d_hidden = 256
81 | n_layers = 4
82 | weight_norm = True
83 | multires_view = 0
84 | squeeze_out = True
85 | }
86 |
87 | neus_renderer {
88 | n_samples = 64
89 | n_importance = 64
90 | n_outside = 0
91 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling
92 | perturb = 1.0
93 | sdf_decay_param = 100
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
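The NeuS config above is written in HOCON and contains a `CASE_NAME` placeholder. Below is a minimal, illustrative sketch of loading and specializing such a file with `pyhocon` (pinned in `docker/requirements.txt`); the actual loading logic in `exp_runner.py` is not shown in this dump and may differ.

```python
# Illustrative sketch: load the HOCON config above and substitute CASE_NAME.
# Assumes pyhocon (pinned in docker/requirements.txt); exp_runner.py may do this differently.
from pyhocon import ConfigFactory

def load_conf(conf_path="./confs/wmask.conf", case="owl"):
    with open(conf_path) as f:
        text = f.read().replace("CASE_NAME", case)
    return ConfigFactory.parse_string(text)

conf = load_conf(case="owl")
print(conf.get_string("general.base_exp_dir"))        # ./exp/neus/owl/
print(conf.get_float("train.learning_rate"))          # 0.0005
print(conf.get_int("model.neus_renderer.n_samples"))  # 64
```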
/NeuS/models/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
6 | class Embedder:
7 | def __init__(self, **kwargs):
8 | self.kwargs = kwargs
9 | self.create_embedding_fn()
10 |
11 | def create_embedding_fn(self):
12 | embed_fns = []
13 | d = self.kwargs['input_dims']
14 | out_dim = 0
15 | if self.kwargs['include_input']:
16 | embed_fns.append(lambda x: x)
17 | out_dim += d
18 |
19 | max_freq = self.kwargs['max_freq_log2']
20 | N_freqs = self.kwargs['num_freqs']
21 |
22 | if self.kwargs['log_sampling']:
23 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
24 | else:
25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs['periodic_fns']:
29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 |
39 | def get_embedder(multires, input_dims=3):
40 | embed_kwargs = {
41 | 'include_input': True,
42 | 'input_dims': input_dims,
43 | 'max_freq_log2': multires-1,
44 | 'num_freqs': multires,
45 | 'log_sampling': True,
46 | 'periodic_fns': [torch.sin, torch.cos],
47 | }
48 |
49 | embedder_obj = Embedder(**embed_kwargs)
50 | def embed(x, eo=embedder_obj): return eo.embed(x)
51 | return embed, embedder_obj.out_dim
52 |
--------------------------------------------------------------------------------
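A short usage sketch for `get_embedder` (run from the `NeuS/` directory, with `multires=6` as used by the `sdf_network` in `confs/wmask.conf`); illustrative only, not part of the repository.

```python
# Usage sketch for get_embedder; multires=6 matches sdf_network in confs/wmask.conf.
import torch
from models.embedder import get_embedder

embed_fn, out_dim = get_embedder(6, input_dims=3)
# include_input (3) + sin/cos over 6 frequencies (2 * 6 * 3) = 39 output channels
print(out_dim)  # 39

pts = torch.rand(1024, 3)      # batch of 3D points
encoded = embed_fn(pts)
print(encoded.shape)           # torch.Size([1024, 39])
```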
/NeuS/models/fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from models.embedder import get_embedder
6 |
7 |
8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
9 | class SDFNetwork(nn.Module):
10 | def __init__(self,
11 | d_in,
12 | d_out,
13 | d_hidden,
14 | n_layers,
15 | skip_in=(4,),
16 | multires=0,
17 | bias=0.5,
18 | scale=1,
19 | geometric_init=True,
20 | weight_norm=True,
21 | inside_outside=False):
22 | super(SDFNetwork, self).__init__()
23 |
24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
25 |
26 | self.embed_fn_fine = None
27 |
28 | if multires > 0:
29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
30 | self.embed_fn_fine = embed_fn
31 | dims[0] = input_ch
32 |
33 | self.num_layers = len(dims)
34 | self.skip_in = skip_in
35 | self.scale = scale
36 |
37 | for l in range(0, self.num_layers - 1):
38 | if l + 1 in self.skip_in:
39 | out_dim = dims[l + 1] - dims[0]
40 | else:
41 | out_dim = dims[l + 1]
42 |
43 | lin = nn.Linear(dims[l], out_dim)
44 |
45 | if geometric_init:
46 | if l == self.num_layers - 2:
47 | if not inside_outside:
48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
49 | torch.nn.init.constant_(lin.bias, -bias)
50 | else:
51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
52 | torch.nn.init.constant_(lin.bias, bias)
53 | elif multires > 0 and l == 0:
54 | torch.nn.init.constant_(lin.bias, 0.0)
55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
57 | elif multires > 0 and l in self.skip_in:
58 | torch.nn.init.constant_(lin.bias, 0.0)
59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
61 | else:
62 | torch.nn.init.constant_(lin.bias, 0.0)
63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
64 |
65 | if weight_norm:
66 | lin = nn.utils.weight_norm(lin)
67 |
68 | setattr(self, "lin" + str(l), lin)
69 |
70 | self.activation = nn.Softplus(beta=100)
71 |
72 | def forward(self, inputs):
73 | inputs = inputs * self.scale
74 | if self.embed_fn_fine is not None:
75 | inputs = self.embed_fn_fine(inputs)
76 |
77 | x = inputs
78 | for l in range(0, self.num_layers - 1):
79 | lin = getattr(self, "lin" + str(l))
80 |
81 | if l in self.skip_in:
82 | x = torch.cat([x, inputs], 1) / np.sqrt(2)
83 |
84 | x = lin(x)
85 |
86 | if l < self.num_layers - 2:
87 | x = self.activation(x)
88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
89 |
90 | def sdf(self, x):
91 | return self.forward(x)[:, :1]
92 |
93 | def sdf_hidden_appearance(self, x):
94 | return self.forward(x)
95 |
96 | def gradient(self, x):
97 | x.requires_grad_(True)
98 | with torch.enable_grad():
99 | y = self.sdf(x)
100 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
101 | gradients = torch.autograd.grad(
102 | outputs=y,
103 | inputs=x,
104 | grad_outputs=d_output,
105 | create_graph=True,
106 | retain_graph=True,
107 | only_inputs=True)[0]
108 | return gradients.unsqueeze(1)
109 |
110 |
111 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
112 | class RenderingNetwork(nn.Module):
113 | def __init__(self,
114 | d_feature,
115 | mode,
116 | d_in,
117 | d_out,
118 | d_hidden,
119 | n_layers,
120 | weight_norm=True,
121 | multires_view=0,
122 | squeeze_out=True):
123 | super().__init__()
124 |
125 | self.mode = mode
126 | self.squeeze_out = squeeze_out
127 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
128 |
129 | self.embedview_fn = None
130 | if multires_view > 0:
131 | embedview_fn, input_ch = get_embedder(multires_view)
132 | self.embedview_fn = embedview_fn
133 | dims[0] += (input_ch - 3)
134 |
135 | self.num_layers = len(dims)
136 |
137 | for l in range(0, self.num_layers - 1):
138 | out_dim = dims[l + 1]
139 | lin = nn.Linear(dims[l], out_dim)
140 |
141 | if weight_norm:
142 | lin = nn.utils.weight_norm(lin)
143 |
144 | setattr(self, "lin" + str(l), lin)
145 |
146 | self.relu = nn.ReLU()
147 |
148 | def forward(self, points, normals, view_dirs, feature_vectors):
149 | if self.embedview_fn is not None:
150 | view_dirs = self.embedview_fn(view_dirs)
151 |
152 | rendering_input = None
153 |
154 | if self.mode == 'idr':
155 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
156 | elif self.mode == 'no_view_dir':
157 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
158 | elif self.mode == 'no_normal':
159 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
160 |
161 | x = rendering_input
162 |
163 | for l in range(0, self.num_layers - 1):
164 | lin = getattr(self, "lin" + str(l))
165 |
166 | x = lin(x)
167 |
168 | if l < self.num_layers - 2:
169 | x = self.relu(x)
170 |
171 | if self.squeeze_out:
172 | x = torch.sigmoid(x)
173 | return x
174 |
175 |
176 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch
177 | class NeRF(nn.Module):
178 | def __init__(self,
179 | D=8,
180 | W=256,
181 | d_in=3,
182 | d_in_view=3,
183 | multires=0,
184 | multires_view=0,
185 | output_ch=4,
186 | skips=[4],
187 | use_viewdirs=False):
188 | super(NeRF, self).__init__()
189 | self.D = D
190 | self.W = W
191 | self.d_in = d_in
192 | self.d_in_view = d_in_view
193 | self.input_ch = 3
194 | self.input_ch_view = 3
195 | self.embed_fn = None
196 | self.embed_fn_view = None
197 |
198 | if multires > 0:
199 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
200 | self.embed_fn = embed_fn
201 | self.input_ch = input_ch
202 |
203 | if multires_view > 0:
204 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view)
205 | self.embed_fn_view = embed_fn_view
206 | self.input_ch_view = input_ch_view
207 |
208 | self.skips = skips
209 | self.use_viewdirs = use_viewdirs
210 |
211 | self.pts_linears = nn.ModuleList(
212 | [nn.Linear(self.input_ch, W)] +
213 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)])
214 |
215 | ### Implementation according to the official code release
216 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
217 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])
218 |
219 | ### Implementation according to the paper
220 | # self.views_linears = nn.ModuleList(
221 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
222 |
223 | if use_viewdirs:
224 | self.feature_linear = nn.Linear(W, W)
225 | self.alpha_linear = nn.Linear(W, 1)
226 | self.rgb_linear = nn.Linear(W // 2, 3)
227 | else:
228 | self.output_linear = nn.Linear(W, output_ch)
229 |
230 | def forward(self, input_pts, input_views):
231 | if self.embed_fn is not None:
232 | input_pts = self.embed_fn(input_pts)
233 | if self.embed_fn_view is not None:
234 | input_views = self.embed_fn_view(input_views)
235 |
236 | h = input_pts
237 | for i, l in enumerate(self.pts_linears):
238 | h = self.pts_linears[i](h)
239 | h = F.relu(h)
240 | if i in self.skips:
241 | h = torch.cat([input_pts, h], -1)
242 |
243 | if self.use_viewdirs:
244 | alpha = self.alpha_linear(h)
245 | feature = self.feature_linear(h)
246 | h = torch.cat([feature, input_views], -1)
247 |
248 | for i, l in enumerate(self.views_linears):
249 | h = self.views_linears[i](h)
250 | h = F.relu(h)
251 |
252 | rgb = self.rgb_linear(h)
253 | return alpha, rgb
254 | else:
255 | assert False
256 |
257 |
258 | class SingleVarianceNetwork(nn.Module):
259 | def __init__(self, init_val):
260 | super(SingleVarianceNetwork, self).__init__()
261 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
262 |
263 | def forward(self, x):
264 | return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0)
265 |
--------------------------------------------------------------------------------
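A minimal sketch that instantiates `SDFNetwork` with the `sdf_network` hyper-parameters from `confs/wmask.conf` and queries SDF values and gradients; illustrative, not part of the repository.

```python
# Illustrative: build SDFNetwork with the sdf_network settings from confs/wmask.conf.
import torch
from models.fields import SDFNetwork

sdf_network = SDFNetwork(
    d_in=3, d_out=257, d_hidden=256, n_layers=8,
    skip_in=[4], multires=6, bias=0.5, scale=1.0,
    geometric_init=True, weight_norm=True,
)

pts = torch.rand(4096, 3) * 2.0 - 1.0    # query points in [-1, 1]^3
out = sdf_network(pts)                   # (4096, 257): SDF value + 256-d feature vector
sdf = sdf_network.sdf(pts)               # (4096, 1)
grad = sdf_network.gradient(pts)         # (4096, 1, 3): unnormalized surface normals
print(out.shape, sdf.shape, grad.shape)
```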
/NeuS/models/fixed_poses/000_back_RT.txt:
--------------------------------------------------------------------------------
1 | -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08
2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08
3 | 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_back_left_RT.txt:
--------------------------------------------------------------------------------
1 | -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07
2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07
3 | 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_back_right_RT.txt:
--------------------------------------------------------------------------------
1 | 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07
2 | -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08
3 | 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_front_RT.txt:
--------------------------------------------------------------------------------
1 | 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00
2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08
3 | -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_front_left_RT.txt:
--------------------------------------------------------------------------------
1 | -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07
2 | 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07
3 | -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_front_right_RT.txt:
--------------------------------------------------------------------------------
1 | 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07
2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07
3 | -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_left_RT.txt:
--------------------------------------------------------------------------------
1 | -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00
2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08
3 | -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_right_RT.txt:
--------------------------------------------------------------------------------
1 | 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08
2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08
3 | 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/NeuS/models/fixed_poses/000_top_RT.txt:
--------------------------------------------------------------------------------
1 | 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09
2 | -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08
3 | -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
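Each `fixed_poses/000_*_RT.txt` file above stores a 3x4 `[R|t]` camera extrinsic as plain text (three rows of four numbers). A small illustrative sketch of reading one; the dataset code may load them differently.

```python
# Illustrative: parse one fixed-pose file into a rotation R (3x3) and translation t (3,).
import numpy as np

RT = np.loadtxt("NeuS/models/fixed_poses/000_front_RT.txt")   # shape (3, 4)
R, t = RT[:, :3], RT[:, 3]
print(RT.shape)                                      # (3, 4)
print(np.allclose(R @ R.T, np.eye(3), atol=1e-4))    # True: R is (numerically) orthonormal
```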
/NeuS/models/normal_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def camNormal2worldNormal(rot_c2w, camNormal):
4 | H,W,_ = camNormal.shape
5 | normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
6 |
7 | return normal_img
8 |
9 | def worldNormal2camNormal(rot_w2c, normal_map_world):
10 | H,W,_ = normal_map_world.shape
11 | # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
12 |
13 | # faster version
14 | # Reshape the normal map into a 2D array where each row represents a normal vector
15 | normal_map_flat = normal_map_world.reshape(-1, 3)
16 |
17 | # Transform the normal vectors using the transformation matrix
18 | normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
19 |
20 | # Reshape the transformed normal map back to its original shape
21 | normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
22 |
23 | return normal_map_camera
24 |
25 | def trans_normal(normal, RT_w2c, RT_w2c_target):
26 |
27 | # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
28 | # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
29 |
30 | relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
31 | normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal)
32 |
33 | return normal_target_cam
34 |
35 | def img2normal(img):
36 | return (img/255.)*2-1
37 |
38 | def normal2img(normal):
39 | return np.uint8((normal*0.5+0.5)*255)
40 |
41 | def norm_normalize(normal, dim=-1):
42 |
43 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
44 |
45 | return normal
--------------------------------------------------------------------------------
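A usage sketch for `trans_normal`, re-expressing a normal map from one fixed view's camera frame into another's. The paths and the random input are illustrative; run from the `NeuS/` directory.

```python
# Illustrative: move a normal map from the front view's camera frame to the right view's.
import numpy as np
from models.normal_utils import trans_normal, norm_normalize, img2normal

RT_front = np.loadtxt("models/fixed_poses/000_front_RT.txt")   # (3, 4) extrinsics
RT_right = np.loadtxt("models/fixed_poses/000_right_RT.txt")

# Stand-in for a predicted normal image in [0, 255]; a real one would come from the diffusion outputs.
normal_img = np.random.randint(0, 256, (256, 256, 3)).astype(np.float32)
normal = norm_normalize(img2normal(normal_img))                # map to [-1, 1] and normalize

normal_right = trans_normal(normal, RT_front, RT_right)        # (256, 256, 3)
print(normal_right.shape)
```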
/NeuS/models/ops.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
5 | """
6 | depth: (H, W)
7 | """
8 |
9 | x = np.nan_to_num(depth) # change nan to 0
10 | if minmax is None:
11 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background)
12 | ma = np.max(x)
13 | else:
14 | mi, ma = minmax
15 |
16 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
17 | x = (255 * x).astype(np.uint8)
18 | x_ = cv2.applyColorMap(x, cmap)
19 | return x_, [mi, ma]
--------------------------------------------------------------------------------
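A quick usage sketch for `visualize_depth_numpy` with a synthetic depth map; illustrative only.

```python
# Illustrative: colorize a depth map with visualize_depth_numpy and save the result.
import cv2
import numpy as np
from models.ops import visualize_depth_numpy

depth = np.zeros((256, 256), dtype=np.float32)
depth[64:192, 64:192] = np.linspace(1.0, 2.0, 128)[None, :]  # fake foreground depths

color, (mi, ma) = visualize_depth_numpy(depth)
print(color.shape, color.dtype, mi, ma)   # (256, 256, 3) uint8 1.0 2.0
cv2.imwrite("depth_vis.png", color)
```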
/NeuS/run.sh:
--------------------------------------------------------------------------------
1 | python exp_runner.py --mode train --conf ./confs/wmask.conf --case $2 --data_dir $1
2 |
--------------------------------------------------------------------------------
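`run.sh` passes `--mode`, `--conf`, `--case`, and `--data_dir` to `exp_runner.py`. The following is a purely hypothetical argparse sketch of that command-line surface; `exp_runner.py` itself is not included in this dump and its real parser may accept different or additional options.

```python
# Hypothetical sketch of the CLI surface that NeuS/run.sh relies on; the real
# exp_runner.py is not shown here and may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mode", default="train")
parser.add_argument("--conf", default="./confs/wmask.conf")
parser.add_argument("--case", required=True, help="scene name; replaces CASE_NAME in the conf")
parser.add_argument("--data_dir", required=True, help="folder holding the generated multi-view outputs")
args = parser.parse_args()
print(args.mode, args.conf, args.case, args.data_dir)
```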
/README_zh.md:
--------------------------------------------------------------------------------
1 | **Read this in other languages: [English](README.md)**
2 |
3 | # Wonder3D
4 | Single Image to 3D using Cross-Domain Diffusion
5 | ## [Paper](https://arxiv.org/abs/2310.15008) | [Project page](https://www.xxlong.site/Wonder3D/) | [Hugging Face Demo](https://huggingface.co/spaces/flamehaze1115/Wonder3D-demo) | [Colab from @camenduru](https://github.com/camenduru/Wonder3D-colab)
6 |
7 | 
8 |
9 | Wonder3D reconstructs highly detailed textured meshes from a single-view image in just 2 to 3 minutes. It first generates consistent multi-view normal maps and the corresponding color images with a cross-domain diffusion model, and then uses a novel normal-fusion method to achieve fast and high-quality reconstruction.
10 |
11 | ## Usage
12 | ```python
13 |
14 | import torch
15 | import requests
16 | from PIL import Image
17 | import numpy as np
18 | from torchvision.utils import make_grid, save_image
19 | from diffusers import DiffusionPipeline # only tested on diffusers[torch]==0.19.3, may have conflicts with newer versions of diffusers
20 |
21 | def load_wonder3d_pipeline():
22 |
23 | pipeline = DiffusionPipeline.from_pretrained(
24 | 'flamehaze1115/wonder3d-v1.0', # or use local checkpoint './ckpts'
25 | custom_pipeline='flamehaze1115/wonder3d-pipeline',
26 | torch_dtype=torch.float16
27 | )
28 |
29 | # enable xformers
30 | pipeline.unet.enable_xformers_memory_efficient_attention()
31 |
32 | if torch.cuda.is_available():
33 | pipeline.to('cuda:0')
34 | return pipeline
35 |
36 | pipeline = load_wonder3d_pipeline()
37 |
38 | # Download an example image.
39 | cond = Image.open(requests.get("https://d.skis.ltd/nrp/sample-data/lysol.png", stream=True).raw)
40 |
41 | # The object should be located in the center and resized to 80% of image height.
42 | cond = Image.fromarray(np.array(cond)[:, :, :3])
43 |
44 | # Run the pipeline!
45 | images = pipeline(cond, num_inference_steps=20, output_type='pt', guidance_scale=1.0).images
46 |
47 | result = make_grid(images, nrow=6, padding=0, value_range=(0, 1))
48 |
49 | save_image(result, 'result.png')
50 | ```
51 |
52 | ## Collaborations
53 | Our overarching mission is to make 3D AIGC (AI-generated 3D content) faster, more affordable, and higher quality, so that everyone can easily create 3D content. Although remarkable progress has been made in recent years, we acknowledge there is still a long way to go. We warmly invite you to join the discussion and explore potential collaboration in any capacity. **If you're interested in connecting or partnering with us, please feel free to reach out by email (xxlong@connect.hku.hk)**.
54 |
55 | ## More features
56 |
57 | The repo is still under construction; thanks for your patience.
58 | - [x] Local gradio demo.
59 | - [x] Detailed tutorial.
60 | - [x] GUI demo for mesh reconstruction
61 | - [x] Windows support
62 | - [x] Docker support
63 |
64 | ## Schedule
65 | - [x] Inference code and pretrained models.
66 | - [x] Huggingface demo.
67 | - [ ] New model with higher resolution.
68 |
69 |
70 | ### Preparation for inference
71 |
72 | #### Linux System Setup.
73 | ```bash
74 | conda create -n wonder3d
75 | conda activate wonder3d
76 | pip install -r requirements.txt
77 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
78 | ```
79 | #### Windows System Setup.
80 |
81 | Please switch to the `main-windows` branch for details on the Windows setup.
82 |
83 | #### Docker Setup
84 | See [docker/README.md](docker/README.md) for details.
85 |
86 | ### Inference
87 | 1. Optional. If you have trouble connecting to Hugging Face, make sure you have downloaded the following models.
88 | Download the [checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/Ej7fMT1PwXtKvsELTvDuzuMBebQXEkmf2IwhSjBWtKAJiA) and put them into the root folder.
89 |
90 | Users in mainland China can also download from: [Aliyun Drive](https://www.alipan.com/s/T4rLUNAVq6V)
91 |
92 | ```bash
93 | Wonder3D
94 | |-- ckpts
95 | |-- unet
96 | |-- scheduler
97 | |-- vae
98 | ...
99 | ```
100 | Then modify the file ./configs/mvdiffusion-joint-ortho-6views.yaml and set `pretrained_model_name_or_path="./ckpts"`.
101 |
102 | 2. Download the [SAM](https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth) model and place it in the ``sam_pt`` folder.
103 | ```
104 | Wonder3D
105 | |-- sam_pt
106 | |-- sam_vit_h_4b8939.pth
107 | ```
108 | 3. Predict the foreground mask as the alpha channel. We use [Clipdrop](https://clipdrop.co/remove-background) to segment the foreground object interactively.
109 | You can also use `rembg` to remove the background.
110 | ```python
111 | # !pip install rembg
112 | import rembg
113 | result = rembg.remove(result)
114 | result.show()
115 | ```
116 | 4. Run Wonder3D to generate multi-view-consistent normal maps and color images. You can then check the results in the folder `./outputs`. (We use `rembg` to remove the backgrounds of the results, but the segmentation is not always perfect. Consider using [Clipdrop](https://clipdrop.co/remove-background) to obtain masks for the generated normal maps and color images, since the mask quality significantly affects the quality of the reconstructed mesh.)
117 | ```bash
118 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
119 | --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir={your_data_path} \
120 | validation_dataset.filepaths=['your_img_file'] save_dir={your_save_path}
121 | ```
122 |
123 | Example:
124 |
125 | ```bash
126 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
127 | --config configs/mvdiffusion-joint-ortho-6views.yaml validation_dataset.root_dir=./example_images \
128 | validation_dataset.filepaths=['owl.png'] save_dir=./outputs
129 | ```
130 |
131 | #### Run a local Gradio demo that only generates normals and colors, without reconstruction.
132 | ```bash
133 | python gradio_app_mv.py # generate multi-view normals and colors
134 | ```
135 |
136 | 5. Mesh Extraction
137 |
138 | #### Instant-NSR Mesh Extraction
139 |
140 | ```bash
141 | cd ./instant-nsr-pl
142 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../{your_save_path}/cropsize-{crop_size}-cfg{guidance_scale:.1f}/ dataset.scene={scene}
143 | ```
144 |
145 | Example:
146 |
147 | ```bash
148 | cd ./instant-nsr-pl
149 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=../outputs/cropsize-192-cfg1.0/ dataset.scene=owl
150 | ```
151 |
152 | The generated normal maps and color images are defined in orthographic views, so the reconstructed mesh is also in an orthographic camera space. If you use MeshLab to view the mesh, you can click "Toggle Orthographic Camera" in the "View" tab to switch to an orthographic camera.
153 |
154 | #### Run a local Gradio demo that first generates normals and colors and then performs reconstruction. There is no need to run `gradio_app_mv.py` first.
155 | ```bash
156 | python gradio_app_recon.py
157 | ```
158 |
159 | #### NeuS-based Mesh Extraction
160 |
161 | Since many users have reported problems with the Windows setup of instant-nsr-pl, we also provide a NeuS-based reconstruction, which may avoid some of the dependency issues.
162 |
163 | NeuS consumes less GPU memory, favors smooth surfaces, and requires no parameter tuning. However, NeuS takes more time and its textures may be less sharp. If you are not very sensitive to runtime, we recommend using NeuS for the optimization because of its robustness.
164 |
165 | ```bash
166 | cd ./NeuS
167 | bash run.sh output_folder_path scene_name
168 | ```
169 |
170 | ## FAQ
171 | **Tips to get better results:**
172 | 1. **Sensitive to the facing direction of the input image:** Wonder3D is sensitive to the facing direction of the input image. In our experiments, front-facing images usually lead to good reconstructions.
173 | 2. **Image resolution:** Due to resource constraints, the current implementation only supports a limited number of views (6) and a low resolution (256x256). Any image is first resized to 256x256 for generation, so images whose features remain clear and sharp after such downsampling will give good results.
174 | 3. **Handling occlusions:** Images with occlusions lead to worse reconstructions, since the 6 views cannot fully cover the whole object. Images with fewer occlusions usually produce better results.
175 | 4. **Increase the optimization steps of instant-nsr-pl:** In `instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml`, change `trainer.max_steps: 3000` to more steps, e.g. `trainer.max_steps: 10000`. Longer optimization yields better textures.
176 |
177 | **View information of the generated images:**
178 | - **Elevation and azimuth angles:** Unlike prior works such as Zero123, SyncDreamer, and One2345 that adopt an object/world coordinate system, our views are defined in the camera system of the input image. The six views lie in a plane at an elevation of 0 degrees in the input image's camera system, so we do not need to estimate the elevation of the input image. The azimuth angles of the six views are 0, 45, 90, 180, -90, and -45 degrees, respectively.
179 |
180 | **Focal length of the generated views:**
181 | - We assume the input image is captured by an orthographic camera, so the generated views are also in orthographic space. This design keeps the model's strong generalization to imaginary images, but it may sometimes suffer from focal-length (perspective) distortions on real captured images.
182 |
183 | ## Acknowledgements
184 | We have borrowed code extensively from the following repositories. Many thanks to the authors for sharing their code.
185 | - [stable diffusion](https://github.com/CompVis/stable-diffusion)
186 | - [zero123](https://github.com/cvlab-columbia/zero123)
187 | - [NeuS](https://github.com/Totoro97/NeuS)
188 | - [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer)
189 | - [instant-nsr-pl](https://github.com/bennyguo/instant-nsr-pl)
190 |
191 | ## License
192 | Wonder3D is licensed under [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0.en.html), so any downstream solutions and products (including cloud services) that incorporate Wonder3D code or models trained on it (whether pretrained or custom-trained) must be open-sourced to comply with the AGPL conditions. If you have any questions about the use of Wonder3D, please contact us first.
193 |
194 | ## Citation
195 | If you find this project useful in your work, please cite the following paper. :)
196 | ```
197 | @article{long2023wonder3d,
198 | title={Wonder3D: Single Image to 3D using Cross-Domain Diffusion},
199 | author={Long, Xiaoxiao and Guo, Yuan-Chen and Lin, Cheng and Liu, Yuan and Dou, Zhiyang and Liu, Lingjie and Ma, Yuexin and Zhang, Song-Hai and Habermann, Marc and Theobalt, Christian and others},
200 | journal={arXiv preprint arXiv:2310.15008},
201 | year={2023}
202 | }
203 | ```
204 |
--------------------------------------------------------------------------------
/assets/bug_fixed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/bug_fixed.png
--------------------------------------------------------------------------------
/assets/coordinate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/coordinate.png
--------------------------------------------------------------------------------
/assets/fig_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/assets/fig_teaser.png
--------------------------------------------------------------------------------
/configs/mvdiffusion-joint-ortho-6views.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_name_or_path: 'flamehaze1115/wonder3d-v1.0' # or './ckpts'
2 | revision: null
3 | validation_dataset:
4 | root_dir: "./example_images" # the folder that stores the testing images
5 | num_views: 6
6 | bg_color: 'white'
7 | img_wh: [256, 256]
8 | num_validation_samples: 1000
9 | crop_size: 192
10 | filepaths: ['owl.png'] # the test image names; leave it empty to test all images in the folder
11 |
12 | save_dir: 'outputs/'
13 |
14 | pred_type: 'joint'
15 | seed: 42
16 | validation_batch_size: 1
17 | dataloader_num_workers: 64
18 |
19 | local_rank: -1
20 |
21 | pipe_kwargs:
22 | camera_embedding_type: 'e_de_da_sincos'
23 | num_views: 6
24 |
25 | validation_guidance_scales: [1.0, 3.0]
26 | pipe_validation_kwargs:
27 | eta: 1.0
28 | validation_grid_nrow: 6
29 |
30 | unet_from_pretrained_kwargs:
31 | camera_embedding_type: 'e_de_da_sincos'
32 | projection_class_embeddings_input_dim: 10
33 | num_views: 6
34 | sample_size: 32
35 | cd_attention_mid: true
36 | zero_init_conv_in: false
37 | zero_init_camera_projection: false
38 |
39 | num_views: 6
40 | camera_embedding_type: 'e_de_da_sincos'
41 |
42 | enable_xformers_memory_efficient_attention: true
--------------------------------------------------------------------------------
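A minimal sketch of reading this inference config with OmegaConf (pinned in `docker/requirements.txt`); `test_mvdiffusion_seq.py` may load and validate it differently.

```python
# Minimal sketch: read the inference config with OmegaConf (listed in docker/requirements.txt).
# test_mvdiffusion_seq.py may consume it differently.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/mvdiffusion-joint-ortho-6views.yaml")
print(cfg.pretrained_model_name_or_path)   # flamehaze1115/wonder3d-v1.0
print(cfg.validation_dataset.img_wh)       # [256, 256]
print(cfg.validation_guidance_scales)      # [1.0, 3.0]
```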
/configs/train/stage1-mix-6views-lvis.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'
2 | revision: null
3 | train_dataset:
4 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
5 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
6 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
7 | num_views: 6
8 | groups_num: 1
9 | bg_color: 'three_choices'
10 | img_wh: [256, 256]
11 | validation: false
12 | num_validation_samples: 32
13 | # read_normal: true
14 | # read_color: true
15 | mix_color_normal: true
16 | validation_dataset:
17 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
18 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
19 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
20 | num_views: 6
21 | groups_num: 1
22 | bg_color: 'white'
23 | img_wh: [256, 256]
24 | validation: true
25 | num_validation_samples: 32
26 | # read_normal: true
27 | # read_color: true
28 | mix_color_normal: true
29 | validation_train_dataset:
30 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
31 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
32 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
33 | num_views: 6
34 | groups_num: 1
35 | bg_color: 'white'
36 | img_wh: [256, 256]
37 | validation: false
38 | num_validation_samples: 32
39 | num_samples: 32
40 | # read_normal: true
41 | # read_color: true
42 | mix_color_normal: true
43 |
44 | pred_type: 'mix'
45 |
46 | output_dir: 'outputs/wonder3D-mix'
47 | seed: 42
48 | train_batch_size: 32
49 | validation_batch_size: 16
50 | validation_train_batch_size: 16
51 | max_train_steps: 30000
52 | gradient_accumulation_steps: 2
53 | gradient_checkpointing: true
54 | learning_rate: 1.e-4
55 | scale_lr: false
56 | lr_scheduler: "constant_with_warmup"
57 | lr_warmup_steps: 100
58 | snr_gamma: 5.0
59 | use_8bit_adam: false
60 | allow_tf32: true
61 | use_ema: true
62 | dataloader_num_workers: 64
63 | adam_beta1: 0.9
64 | adam_beta2: 0.999
65 | adam_weight_decay: 1.e-2
66 | adam_epsilon: 1.e-08
67 | max_grad_norm: 1.0
68 | prediction_type: null
69 | vis_dir: vis
70 | logging_dir: logs
71 | mixed_precision: "fp16"
72 | report_to: 'tensorboard'
73 | local_rank: -1
74 | checkpointing_steps: 5000
75 | checkpoints_total_limit: 20
76 | resume_from_checkpoint: latest
77 | enable_xformers_memory_efficient_attention: true
78 | validation_steps: 1250
79 | validation_sanity_check: true
80 | tracker_project_name: 'mvdiffusion-image-v1'
81 |
82 | trainable_modules: null
83 | use_classifier_free_guidance: true
84 | condition_drop_rate: 0.05
85 | drop_type: 'drop_as_a_whole' # modify
86 | camera_embedding_lr_mult: 10.
87 | scale_input_latents: true
88 |
89 | pipe_kwargs:
90 | camera_embedding_type: 'e_de_da_sincos'
91 | num_views: 6
92 |
93 | validation_guidance_scales: [1., 3.]
94 | pipe_validation_kwargs:
95 | eta: 1.0
96 | validation_grid_nrow: 12
97 |
98 | unet_from_pretrained_kwargs:
99 | camera_embedding_type: 'e_de_da_sincos'
100 | projection_class_embeddings_input_dim: 10 # modify
101 | num_views: 6
102 | sample_size: 32
103 | zero_init_conv_in: true
104 | zero_init_camera_projection: false
105 | cd_attention_last: false
106 | cd_attention_mid: false
107 | multiview_attention: true
108 | sparse_mv_attention: false
109 | mvcd_attention: false
110 |
111 | num_views: 6
112 | camera_embedding_type: 'e_de_da_sincos'
113 |
--------------------------------------------------------------------------------
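The README commands override config entries with dotted `key.subkey=value` arguments (e.g. `validation_dataset.root_dir=...`). Below is a hedged sketch of merging such dot-list overrides onto the training config above with OmegaConf; the actual argument handling in `train_mvdiffusion_image.py` is not shown in this dump.

```python
# Hedged sketch: merge dotted command-line overrides onto the stage-1 training config,
# mirroring the `key.subkey=value` style used in the README commands. The real
# training script may handle its arguments differently.
import sys
from omegaconf import OmegaConf

base = OmegaConf.load("configs/train/stage1-mix-6views-lvis.yaml")
overrides = OmegaConf.from_dotlist(sys.argv[1:])   # e.g. train_dataset.root_dir=/my/renders
cfg = OmegaConf.merge(base, overrides)
print(cfg.train_dataset.root_dir, cfg.train_batch_size)
```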
/configs/train/stage2-joint-6views-lvis.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'
2 | # modify the unet path; use the stage 1 checkpoint
3 | pretrained_unet_path: 'outputs/wonder3D-mix/checkpoint-30000/'
4 | # pretrained_unet_path: null
5 | revision: null
6 | train_dataset:
7 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
8 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
9 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
10 | num_views: 6
11 | groups_num: 1
12 | bg_color: 'three_choices'
13 | img_wh: [256, 256]
14 | validation: false
15 | num_validation_samples: 32
16 | read_normal: true
17 | read_color: true
18 | validation_dataset:
19 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
20 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
21 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
22 | num_views: 6
23 | groups_num: 1
24 | bg_color: 'white'
25 | img_wh: [256, 256]
26 | validation: true
27 | num_validation_samples: 32
28 | read_normal: true
29 | read_color: true
30 | validation_train_dataset:
31 | root_dir: '/mnt/pfs/data/objaverse_renderings_ortho_9views/' # change to your path
32 | object_list: './data_lists/lvis_uids_filter_by_vertex.json'
33 | invalid_list: './data_lists/lvis_invalid_uids_nineviews.json'
34 | num_views: 6
35 | groups_num: 1
36 | bg_color: 'three_choices'
37 | img_wh: [256, 256]
38 | validation: false
39 | num_validation_samples: 32
40 | num_samples: 32
41 | read_normal: true
42 | read_color: true
43 |
44 | # output_dir: 'outputs/debug'
45 | output_dir: 'outputs/wonder3D-joint'
46 | seed: 42
47 | train_batch_size: 32 # original paper uses 32
48 | validation_batch_size: 16
49 | validation_train_batch_size: 16
50 | max_train_steps: 20000
51 | gradient_accumulation_steps: 2
52 | gradient_checkpointing: true
53 | learning_rate: 5.e-5
54 | scale_lr: false
55 | lr_scheduler: "constant_with_warmup"
56 | lr_warmup_steps: 100
57 | snr_gamma: 5.0
58 | use_8bit_adam: false
59 | allow_tf32: true
60 | use_ema: true
61 | dataloader_num_workers: 64
62 | adam_beta1: 0.9
63 | adam_beta2: 0.999
64 | adam_weight_decay: 1.e-2
65 | adam_epsilon: 1.e-08
66 | max_grad_norm: 1.0
67 | prediction_type: null
68 | vis_dir: vis
69 | logging_dir: logs
70 | mixed_precision: "fp16"
71 | report_to: 'tensorboard'
72 | local_rank: -1
73 | checkpointing_steps: 5000
74 | checkpoints_total_limit: null
75 | last_global_step: 5000
76 |
77 | resume_from_checkpoint: latest
78 | enable_xformers_memory_efficient_attention: true
79 | validation_steps: 1000
80 | validation_sanity_check: true
81 | tracker_project_name: 'mvdiffusion-image-v1'
82 |
83 | trainable_modules: ['joint_mid']
84 | use_classifier_free_guidance: true
85 | condition_drop_rate: 0.05
86 | drop_type: 'drop_as_a_whole' # modify
87 | camera_embedding_lr_mult: 10.
88 | scale_input_latents: true
89 |
90 | pipe_kwargs:
91 | camera_embedding_type: 'e_de_da_sincos'
92 | num_views: 6
93 |
94 | validation_guidance_scales: [1., 3.]
95 | pipe_validation_kwargs:
96 | eta: 1.0
97 | validation_grid_nrow: 12
98 |
99 | unet_from_pretrained_kwargs:
100 | camera_embedding_type: 'e_de_da_sincos'
101 | projection_class_embeddings_input_dim: 10 # modify
102 | num_views: 6
103 | sample_size: 32
104 | zero_init_conv_in: false
105 | zero_init_camera_projection: false
106 | cd_attention_last: false
107 | cd_attention_mid: true
108 | multiview_attention: true
109 | sparse_mv_attention: false
110 | mvcd_attention: false
111 |
112 | num_views: 6
113 | camera_embedding_type: 'e_de_da_sincos'
114 |
115 |
116 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # get the development image from nvidia cuda 11.7
2 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
3 |
4 | LABEL name="Wonder3D" \
5 | maintainer="Tiancheng " \
6 | lastupdate="2024-01-05"
7 |
8 | # create workspace folder and set it as working directory
9 | RUN mkdir -p /workspace
10 | WORKDIR /workspace
11 |
12 | # Set the timezone
13 | ENV DEBIAN_FRONTEND=noninteractive
14 | RUN apt-get update && \
15 | apt-get install -y tzdata && \
16 | ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
17 | dpkg-reconfigure --frontend noninteractive tzdata
18 |
19 | # update package lists and install git, wget, vim, libgl1-mesa-glx, and libglib2.0-0
20 | RUN apt-get update && \
21 | apt-get install -y git wget vim libgl1-mesa-glx libglib2.0-0 unzip
22 |
23 | # install conda
24 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
25 | chmod +x Miniconda3-latest-Linux-x86_64.sh && \
26 | ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \
27 | rm Miniconda3-latest-Linux-x86_64.sh
28 |
29 | # update PATH environment variable
30 | ENV PATH="/workspace/miniconda3/bin:${PATH}"
31 |
32 | # initialize conda
33 | RUN conda init bash
34 |
35 | # create and activate conda environment
36 | RUN conda create -n wonder3d python=3.8 && echo "source activate wonder3d" > ~/.bashrc
37 | ENV PATH /workspace/miniconda3/envs/wonder3d/bin:$PATH
38 |
39 |
40 | # clone the repository
41 | RUN git clone https://github.com/xxlong0/Wonder3D.git && \
42 | cd /workspace/Wonder3D
43 |
44 | # change the working directory to the repository
45 | WORKDIR /workspace/Wonder3D
46 |
47 | # install pytorch 1.13.1 and torchvision
48 | RUN pip install -r docker/requirements.txt
49 |
50 | # install the specific version of nerfacc corresponding to torch 1.13.0 and cuda 11.7, otherwise the nerfacc will freeze during cuda setup
51 | RUN pip install nerfacc==0.3.3 -f https://nerfacc-bucket.s3.us-west-2.amazonaws.com/whl/torch-1.13.0_cu117.html
52 |
53 | # installing tiny-cuda-nn during the docker build causes an error, so install it manually inside the container
54 | # RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
55 |
56 |
57 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Docker setup
2 |
3 | This docker setup is tested on Ubuntu 20.04.
4 |
5 | Make sure you are in the directory yourworkspace/Wonder3D/.
6 |
7 | Run
8 |
9 | `docker build --no-cache -t wonder3d/deploy:cuda11.7 -f docker/Dockerfile .`
10 |
11 | Then run
12 |
13 | `docker run --gpus all -it wonder3d/deploy:cuda11.7 bash`
14 |
15 |
16 | ## Nvidia Container Toolkit setup
17 |
18 | You will have trouble enabling GPU support in Docker if you haven't installed the **NVIDIA Container Toolkit** on your local machine before. You can skip this section if you have already installed it. Follow the instructions on this website to install it.
19 |
20 | https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
21 |
22 | Or you can run the following commands to install it with apt:
23 |
24 | 1. Configure the production repository:
25 |
26 | ```bash
27 | curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \
28 | && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \
29 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \
30 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
31 | ```
32 |
33 | 2. Optionally enable experimental packages, then update the package list with `sudo apt-get update`:
34 |
35 | `sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list`
36 |
37 | 3. Install the NVIDIA Container Toolkit packages:
38 |
39 | `sudo apt-get install -y nvidia-container-toolkit`
40 |
41 | Remember to restart Docker:
42 |
43 | `sudo systemctl restart docker`
44 |
45 | Now you can run the following command:
46 |
47 | `docker run --gpus all -it wonder3d/deploy:cuda11.7 bash`
48 |
49 |
50 | ## Install tiny-cuda-nn
51 |
52 | After you start the container, run the following command to install tiny-cuda-nn. Somehow this pip installation cannot be done during the docker build, so you have to do it manually after the container is started.
53 |
54 | `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`
55 |
56 |
57 | Now you should be good to go, good luck and have fun :)
58 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 |
3 | # nerfacc==0.3.3; nerfacc needs to be installed from the specific location
4 | # see installation part in this link: https://github.com/nerfstudio-project/nerfacc
5 |
6 | torch==1.13.1+cu117
7 | torchvision==0.14.1+cu117
8 | diffusers[torch]==0.19.3
9 | xformers==0.0.16
10 | transformers>=4.25.1
11 | bitsandbytes==0.35.4
12 | decord==0.6.0
13 | pytorch-lightning<2
14 | omegaconf==2.2.3
15 | trimesh==3.9.8
16 | pyhocon==0.3.57
17 | icecream==2.1.0
18 | PyMCubes==0.1.2
19 | accelerate
20 | modelcards
21 | einops
22 | ftfy
23 | piq
24 | matplotlib
25 | opencv-python
26 | imageio
27 | imageio-ffmpeg
28 | scipy
29 | pyransac3d
30 | torch_efficient_distloss
31 | tensorboard
32 | rembg
33 | segment_anything
34 | gradio==3.50.2
35 | triton
36 | rich
37 |
--------------------------------------------------------------------------------
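After installing the pinned packages above inside the container, a quick illustrative sanity check that the CUDA 11.7 build of PyTorch is active:

```python
# Illustrative sanity check after `pip install -r docker/requirements.txt`.
import torch
import torchvision
import diffusers

print(torch.__version__)           # expected: 1.13.1+cu117
print(torch.version.cuda)          # expected: 11.7
print(torch.cuda.is_available())   # True inside a GPU-enabled container
print(torchvision.__version__, diffusers.__version__)
```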
/example_images/14_10_29_489_Tiger_1__1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/14_10_29_489_Tiger_1__1.png
--------------------------------------------------------------------------------
/example_images/box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/box.png
--------------------------------------------------------------------------------
/example_images/bread.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/bread.png
--------------------------------------------------------------------------------
/example_images/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/cat.png
--------------------------------------------------------------------------------
/example_images/cat_head.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/cat_head.png
--------------------------------------------------------------------------------
/example_images/chili.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/chili.png
--------------------------------------------------------------------------------
/example_images/duola.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/duola.png
--------------------------------------------------------------------------------
/example_images/halloween.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/halloween.png
--------------------------------------------------------------------------------
/example_images/head.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/head.png
--------------------------------------------------------------------------------
/example_images/kettle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/kettle.png
--------------------------------------------------------------------------------
/example_images/kunkun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/kunkun.png
--------------------------------------------------------------------------------
/example_images/milk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/milk.png
--------------------------------------------------------------------------------
/example_images/owl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/owl.png
--------------------------------------------------------------------------------
/example_images/poro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/poro.png
--------------------------------------------------------------------------------
/example_images/pumpkin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/pumpkin.png
--------------------------------------------------------------------------------
/example_images/skull.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/skull.png
--------------------------------------------------------------------------------
/example_images/stone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/stone.png
--------------------------------------------------------------------------------
/example_images/teapot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/teapot.png
--------------------------------------------------------------------------------
/example_images/tiger-head-3d-model-obj-stl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/example_images/tiger-head-3d-model-obj-stl.png
--------------------------------------------------------------------------------
/instant-nsr-pl/README.md:
--------------------------------------------------------------------------------
1 | # Instant Neural Surface Reconstruction
2 |
3 | This repository contains a concise and extensible implementation of NeRF and NeuS for neural surface reconstruction based on Instant-NGP and the Pytorch-Lightning framework. **Training on a NeRF-Synthetic scene takes ~5min for NeRF and ~10min for NeuS on a single RTX3090.**
4 |
5 | ||NeRF in 5min|NeuS in 10 min|
6 | |---|---|---|
7 | |Rendering|||
8 | |Mesh|||
9 |
10 |
11 | ## Features
12 | **This repository aims to provide a highly efficient yet customizable boilerplate for research projects based on NeRF or NeuS.**
13 |
14 | - acceleration techniques from [Instant-NGP](https://github.com/NVlabs/instant-ngp): multiresolution hash encoding and fully fused networks by [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), occupancy grid pruning and rendering by [nerfacc](https://github.com/KAIR-BAIR/nerfacc)
15 | - out-of-the-box multi-GPU and mixed precision training by [PyTorch-Lightning](https://github.com/Lightning-AI/lightning)
16 | - hierarchical project layout that is designed to be easily customized and extended, flexible experiment configuration by [OmegaConf](https://github.com/omry/omegaconf)
17 |
18 | **Please subscribe to [#26](https://github.com/bennyguo/instant-nsr-pl/issues/26) for our latest findings on quality improvements!**
19 |
20 | ## News
21 |
22 | 🔥🔥🔥 Check out my new project on 3D content generation: https://github.com/threestudio-project/threestudio 🔥🔥🔥
23 |
24 | - 06/03/2023: Add an implementation of [Neuralangelo](https://research.nvidia.com/labs/dir/neuralangelo/). See [here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU) for details.
25 | - 03/31/2023: NeuS model now supports background modeling. You could try on the DTU dataset provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa) following [the instruction here](https://github.com/bennyguo/instant-nsr-pl#training-on-DTU).
26 | - 02/11/2023: NeRF model now supports unbounded 360 scenes with learned background. You could try on [MipNeRF 360 data](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) following [the COLMAP configuration](https://github.com/bennyguo/instant-nsr-pl#training-on-custom-colmap-data).
27 |
28 | ## Requirements
29 | **Note:**
30 | - To utilize multiresolution hash encoding or fully fused networks provided by tiny-cuda-nn, you should have at least an RTX 2080 Ti; see [https://github.com/NVlabs/tiny-cuda-nn#requirements](https://github.com/NVlabs/tiny-cuda-nn#requirements) for more details.
31 | - Multi-GPU training is currently not supported on Windows (see [#4](https://github.com/bennyguo/instant-nsr-pl/issues/4)).
32 | ### Environments
33 | - Install PyTorch>=1.10 [here](https://pytorch.org/get-started/locally/) based on the package management tool you use and your CUDA version (older PyTorch versions may work but have not been tested)
34 | - Install tiny-cuda-nn PyTorch extension: `pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch`
35 | - `pip install -r requirements.txt`
36 |
37 |
38 | ## Run
39 | ### Training on NeRF-Synthetic
40 | Download the NeRF-Synthetic data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and put it under `load/`. The file structure should be like `load/nerf_synthetic/lego`.
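As a sanity check, the extracted data should be reachable at that path; the listing below is only a sketch, naming the files the Blender-format loader in `datasets/blender.py` expects:
```bash
# the lego scene should contain transforms_{train,val,test}.json plus the per-split image folders
ls load/nerf_synthetic/lego
```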
41 |
42 | Run the launch script with `--train`, specifying the config file, the GPU(s) to be used (GPU 0 will be used by default), and the scene name:
43 | ```bash
44 | # train NeRF
45 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=example
46 |
47 | # train NeuS with mask
48 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example
49 | # train NeuS without mask
50 | python launch.py --config configs/neus-blender.yaml --gpu 0 --train dataset.scene=lego tag=example system.loss.lambda_mask=0.0
51 | ```
52 | The code snapshots, checkpoints and experiment outputs are saved to `exp/[name]/[tag]@[timestamp]`, and tensorboard logs can be found at `runs/[name]/[tag]@[timestamp]`. You can change any configuration in the YAML file by specifying arguments without `--`, for example:
53 | ```bash
54 | python launch.py --config configs/nerf-blender.yaml --gpu 0 --train dataset.scene=lego tag=iter50k seed=0 trainer.max_steps=50000
55 | ```
56 | ### Training on DTU
57 | Download preprocessed DTU data provided by [NeuS](https://drive.google.com/drive/folders/1Nlzejs4mfPuJYORLbDEUDWlc9IZIbU0C?usp=sharing) or [IDR](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa). The provided config files assume the NeuS DTU data. If you are using the IDR DTU data, please set `dataset.cameras_file=cameras.npz`. You may also need to adjust `dataset.root_dir` to point to your downloaded data location.
58 | ```bash
59 | # train NeuS on DTU without mask
60 | python launch.py --config configs/neus-dtu.yaml --gpu 0 --train
61 | # train NeuS on DTU with mask
62 | python launch.py --config configs/neus-dtu-wmask.yaml --gpu 0 --train
63 | # train NeuS on DTU with mask using tricks from Neuralangelo (experimental)
64 | python launch.py --config configs/neuralangelo-dtu-wmask.yaml --gpu 0 --train
65 | ```
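If you are training on the IDR DTU data, the overrides mentioned above can be appended to the same command. This is only a sketch; the `dataset.root_dir` value is a placeholder you should point at your extracted data:
```bash
# sketch: train on IDR-formatted DTU data (root_dir path is a placeholder)
python launch.py --config configs/neus-dtu.yaml --gpu 0 --train \
    dataset.root_dir=./load/DTU-idr/scan65 \
    dataset.cameras_file=cameras.npz
```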
66 | Notes:
67 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing.
68 | - The Neuralangelo results do not reach those reported in the original paper. Some potential improvements: more iterations; larger `system.geometry.xyz_encoding_config.update_steps`; larger `system.geometry.xyz_encoding_config.n_features_per_level`; larger `system.geometry.xyz_encoding_config.log2_hashmap_size`; adopting the curvature loss.
69 |
70 | ### Training on Custom COLMAP Data
71 | To get COLMAP data from custom images, you should have COLMAP installed (see [here](https://colmap.github.io/install.html) for installation instructions). Then put your images in the `images/` folder, and run `scripts/imgs2poses.py` specifying the path containing the `images/` folder. For example:
72 | ```bash
73 | python scripts/imgs2poses.py ./load/bmvs_dog # images are in ./load/bmvs_dog/images
74 | ```
75 | Existing data following this file structure also works as long as images are stored in `images/` and there is a `sparse/` folder for the COLMAP output, for example [the data provided by MipNeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip). An optional `masks/` folder can be provided for object mask supervision. To train on COLMAP data, please refer to the example config files `config/*-colmap.yaml`. Some notes (a command-line sketch follows these notes):
76 | - Adapt the `root_dir` and `img_wh` (or `img_downscale`) options in the config file to your data;
77 | - The scene is normalized so that cameras have a minimum distance of `1.0` to the center of the scene. Setting `model.radius=1.0` works in most cases. If not, try a smaller radius that tightly wraps your foreground object.
78 | - There are three choices to determine the scene center: `dataset.center_est_method=camera` uses the center of all camera positions as the scene center; `dataset.center_est_method=lookat` assumes the cameras are looking at the same point and calculates an approximate look-at point as the scene center; `dataset.center_est_method=point` uses the center of all points (reconstructed by COLMAP) that are bounded by cameras as the scene center. Please choose an appropriate method according to your capture.
79 | - PSNR in the testing stage is meaningless, as we simply compare to pure white images in testing.
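Putting the options above together, a typical invocation might look like the sketch below. It assumes one of the `*-colmap.yaml` example configs (written here as `configs/neus-colmap.yaml`) and uses placeholder values for the data path and overrides:
```bash
# sketch: train NeuS on custom COLMAP data with the options discussed above (values are placeholders)
python launch.py --config configs/neus-colmap.yaml --gpu 0 --train \
    dataset.root_dir=./load/bmvs_dog dataset.img_downscale=2 \
    dataset.center_est_method=lookat model.radius=1.0 tag=bmvs_dog
```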
80 |
81 | ### Testing
82 | The training procedure is by default followed by testing, which computes metrics on the test data, generates animations, and exports the geometry as triangular meshes. If you want to run testing alone, just resume the pretrained model and replace `--train` with `--test`, for example:
83 | ```bash
84 | python launch.py --config path/to/your/exp/config/parsed.yaml --resume path/to/your/exp/ckpt/epoch=0-step=20000.ckpt --gpu 0 --test
85 | ```
86 |
87 |
88 | ## Benchmarks
89 | All experiments are conducted on a single NVIDIA RTX3090.
90 |
91 | |PSNR|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.|
92 | |---|---|---|---|---|---|---|---|---|---|
93 | |NeRF Paper|33.00|25.01|30.13|36.18|32.54|29.62|32.91|28.65|31.01|
94 | |NeRF Ours (20k)|34.80|26.04|33.89|37.42|35.33|29.46|35.22|31.17|32.92|
95 | |NeuS Ours (20k, with masks)|34.04|25.26|32.47|35.94|33.78|27.67|33.43|29.50|31.51|
96 |
97 | |Training Time (mm:ss)|Chair|Drums|Ficus|Hotdog|Lego|Materials|Mic|Ship|Avg.|
98 | |---|---|---|---|---|---|---|---|---|---|
99 | |NeRF Ours (20k)|04:34|04:35|04:18|04:46|04:39|04:35|04:26|05:41|04:42|
100 | |NeuS Ours (20k, with masks)|11:25|10:34|09:51|12:11|11:37|11:46|09:59|16:25|11:44|
101 |
102 |
103 | ## TODO
104 | - [✅] Support more dataset formats, like COLMAP outputs and DTU
105 | - [✅] Support simple background model
106 | - [ ] Support GUI training and interaction
107 | - [ ] More illustrations about the framework
108 |
109 | ## Related Projects
110 | - [ngp_pl](https://github.com/kwea123/ngp_pl): Great Instant-NGP implementation in PyTorch-Lightning! Background model and GUI supported.
111 | - [Instant-NSR](https://github.com/zhaofuq/Instant-NSR): NeuS implementation using multiresolution hash encoding.
112 |
113 | ## Citation
114 | If you find this codebase useful, please consider citing:
115 | ```
116 | @misc{instant-nsr-pl,
117 | Author = {Yuan-Chen Guo},
118 | Year = {2022},
119 | Note = {https://github.com/bennyguo/instant-nsr-pl},
120 | Title = {Instant Neural Surface Reconstruction}
121 | }
122 | ```
123 |
--------------------------------------------------------------------------------
/instant-nsr-pl/configs/neuralangelo-ortho-wmask.yaml:
--------------------------------------------------------------------------------
1 | name: ${basename:${dataset.scene}}
2 | tag: ""
3 | seed: 42
4 |
5 | dataset:
6 | name: ortho
7 | root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0
8 | cam_pose_dir: null
9 | scene: scene_name
10 | imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors
11 | camera_type: ortho
12 | apply_mask: true
13 | camera_params: null
14 | view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] #['front', 'front_right', 'right', 'back', 'left', 'front_left']
15 | # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
16 |
17 | model:
18 | name: neus
19 | radius: 1.0
20 | num_samples_per_ray: 1024
21 | train_num_rays: 256
22 | max_train_num_rays: 8192
23 | grid_prune: true
24 | grid_prune_occ_thre: 0.001
25 | dynamic_ray_sampling: true
26 | batch_image_sampling: true
27 | randomized: true
28 | ray_chunk: 2048
29 | cos_anneal_end: 20000
30 | learned_background: false
31 | background_color: black
32 | variance:
33 | init_val: 0.3
34 | modulate: false
35 | geometry:
36 | name: volume-sdf
37 | radius: ${model.radius}
38 | feature_dim: 13
39 | grad_type: finite_difference
40 | finite_difference_eps: progressive
41 | isosurface:
42 | method: mc
43 | resolution: 192
44 | chunk: 2097152
45 | threshold: 0.
46 | xyz_encoding_config:
47 | otype: ProgressiveBandHashGrid
48 | n_levels: 10 # 12 modify
49 | n_features_per_level: 2
50 | log2_hashmap_size: 19
51 | base_resolution: 32
52 | per_level_scale: 1.3195079107728942
53 | include_xyz: true
54 | start_level: 4
55 | start_step: 0
56 | update_steps: 1000
57 | mlp_network_config:
58 | otype: VanillaMLP
59 | activation: ReLU
60 | output_activation: none
61 | n_neurons: 64
62 | n_hidden_layers: 1
63 | sphere_init: true
64 | sphere_init_radius: 0.5
65 | weight_norm: true
66 | texture:
67 | name: volume-radiance
68 | input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
69 | dir_encoding_config:
70 | otype: SphericalHarmonics
71 | degree: 4
72 | mlp_network_config:
73 | otype: VanillaMLP
74 | activation: ReLU
75 | output_activation: none
76 | n_neurons: 64
77 | n_hidden_layers: 2
78 | color_activation: sigmoid
79 |
80 | system:
81 | name: ortho-neus-system
82 | loss:
83 | lambda_rgb_mse: 0.5
84 | lambda_rgb_l1: 0.
85 | lambda_mask: 1.0
86 |     lambda_eikonal: 0.2 # cannot be too large, otherwise it causes holes in thin objects
87 | lambda_normal: 1.0 # cannot be too large
88 | lambda_3d_normal_smooth: 1.0
89 | # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
90 | lambda_curvature: 0.
91 | lambda_sparsity: 0.5
92 | lambda_distortion: 0.0
93 | lambda_distortion_bg: 0.0
94 | lambda_opaque: 0.0
95 | sparsity_scale: 100.0
96 | geo_aware: true
97 | rgb_p_ratio: 0.8
98 | normal_p_ratio: 0.8
99 | mask_p_ratio: 0.9
100 | optimizer:
101 | name: AdamW
102 | args:
103 | lr: 0.01
104 | betas: [0.9, 0.99]
105 | eps: 1.e-15
106 | params:
107 | geometry:
108 | lr: 0.001
109 | texture:
110 | lr: 0.01
111 | variance:
112 | lr: 0.001
113 | constant_steps: 500
114 | scheduler:
115 | name: SequentialLR
116 | interval: step
117 | milestones:
118 | - ${system.constant_steps}
119 | schedulers:
120 | - name: ConstantLR
121 | args:
122 | factor: 1.0
123 | total_iters: ${system.constant_steps}
124 | - name: ExponentialLR
125 | args:
126 | gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}
127 |
128 | checkpoint:
129 | save_top_k: -1
130 | every_n_train_steps: ${trainer.max_steps}
131 |
132 | export:
133 | chunk_size: 2097152
134 | export_vertex_color: True
135 | ortho_scale: 1.35 #modify
136 |
137 | trainer:
138 | max_steps: 3000
139 | log_every_n_steps: 100
140 | num_sanity_val_steps: 0
141 | val_check_interval: 4000
142 | limit_train_batches: 1.0
143 | limit_val_batches: 2
144 | enable_progress_bar: true
145 | precision: 16
146 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | datasets = {}
2 |
3 |
4 | def register(name):
5 | def decorator(cls):
6 | datasets[name] = cls
7 | return cls
8 | return decorator
9 |
10 |
11 | def make(name, config):
12 | dataset = datasets[name](config)
13 | return dataset
14 |
15 |
16 | from . import blender, colmap, dtu, ortho
17 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torch
8 | from torch.utils.data import Dataset, DataLoader, IterableDataset
9 | import torchvision.transforms.functional as TF
10 |
11 | import pytorch_lightning as pl
12 |
13 | import datasets
14 | from models.ray_utils import get_ray_directions
15 | from utils.misc import get_rank
16 |
17 |
18 | class BlenderDatasetBase():
19 | def setup(self, config, split):
20 | self.config = config
21 | self.split = split
22 | self.rank = get_rank()
23 |
24 | self.has_mask = True
25 | self.apply_mask = True
26 |
27 | with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f:
28 | meta = json.load(f)
29 |
30 | if 'w' in meta and 'h' in meta:
31 | W, H = int(meta['w']), int(meta['h'])
32 | else:
33 | W, H = 800, 800
34 |
35 | if 'img_wh' in self.config:
36 | w, h = self.config.img_wh
37 | assert round(W / w * h) == H
38 | elif 'img_downscale' in self.config:
39 | w, h = W // self.config.img_downscale, H // self.config.img_downscale
40 | else:
41 | raise KeyError("Either img_wh or img_downscale should be specified.")
42 |
43 | self.w, self.h = w, h
44 | self.img_wh = (self.w, self.h)
45 |
46 | self.near, self.far = self.config.near_plane, self.config.far_plane
47 |
48 | self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x']) # scaled focal length
49 |
50 | # ray directions for all pixels, same for all images (same H, W, focal)
51 | self.directions = \
52 | get_ray_directions(self.w, self.h, self.focal, self.focal, self.w//2, self.h//2).to(self.rank) # (h, w, 3)
53 |
54 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
55 |
56 | for i, frame in enumerate(meta['frames']):
57 | c2w = torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4])
58 | self.all_c2w.append(c2w)
59 |
60 | img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
61 | img = Image.open(img_path)
62 | img = img.resize(self.img_wh, Image.BICUBIC)
63 | img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4)
64 |
65 | self.all_fg_masks.append(img[..., -1]) # (h, w)
66 | self.all_images.append(img[...,:3])
67 |
68 | self.all_c2w, self.all_images, self.all_fg_masks = \
69 | torch.stack(self.all_c2w, dim=0).float().to(self.rank), \
70 | torch.stack(self.all_images, dim=0).float().to(self.rank), \
71 | torch.stack(self.all_fg_masks, dim=0).float().to(self.rank)
72 |
73 |
74 | class BlenderDataset(Dataset, BlenderDatasetBase):
75 | def __init__(self, config, split):
76 | self.setup(config, split)
77 |
78 | def __len__(self):
79 | return len(self.all_images)
80 |
81 | def __getitem__(self, index):
82 | return {
83 | 'index': index
84 | }
85 |
86 |
87 | class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
88 | def __init__(self, config, split):
89 | self.setup(config, split)
90 |
91 | def __iter__(self):
92 | while True:
93 | yield {}
94 |
95 |
96 | @datasets.register('blender')
97 | class BlenderDataModule(pl.LightningDataModule):
98 | def __init__(self, config):
99 | super().__init__()
100 | self.config = config
101 |
102 | def setup(self, stage=None):
103 | if stage in [None, 'fit']:
104 | self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split)
105 | if stage in [None, 'fit', 'validate']:
106 | self.val_dataset = BlenderDataset(self.config, self.config.val_split)
107 | if stage in [None, 'test']:
108 | self.test_dataset = BlenderDataset(self.config, self.config.test_split)
109 | if stage in [None, 'predict']:
110 | self.predict_dataset = BlenderDataset(self.config, self.config.train_split)
111 |
112 | def prepare_data(self):
113 | pass
114 |
115 | def general_loader(self, dataset, batch_size):
116 | sampler = None
117 | return DataLoader(
118 | dataset,
119 | num_workers=os.cpu_count(),
120 | batch_size=batch_size,
121 | pin_memory=True,
122 | sampler=sampler
123 | )
124 |
125 | def train_dataloader(self):
126 | return self.general_loader(self.train_dataset, batch_size=1)
127 |
128 | def val_dataloader(self):
129 | return self.general_loader(self.val_dataset, batch_size=1)
130 |
131 | def test_dataloader(self):
132 | return self.general_loader(self.test_dataset, batch_size=1)
133 |
134 | def predict_dataloader(self):
135 | return self.general_loader(self.predict_dataset, batch_size=1)
136 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/dtu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import numpy as np
5 | from PIL import Image
6 | import cv2
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import Dataset, DataLoader, IterableDataset
11 | import torchvision.transforms.functional as TF
12 |
13 | import pytorch_lightning as pl
14 |
15 | import datasets
16 | from models.ray_utils import get_ray_directions
17 | from utils.misc import get_rank
18 |
19 |
20 | def load_K_Rt_from_P(P=None):
21 | out = cv2.decomposeProjectionMatrix(P)
22 | K = out[0]
23 | R = out[1]
24 | t = out[2]
25 |
26 | K = K / K[2, 2]
27 | intrinsics = np.eye(4)
28 | intrinsics[:3, :3] = K
29 |
30 | pose = np.eye(4, dtype=np.float32)
31 | pose[:3, :3] = R.transpose()
32 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
33 |
34 | return intrinsics, pose
35 |
36 | def create_spheric_poses(cameras, n_steps=120):
37 | center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
38 | cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2)
39 | eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors
40 | rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1)
41 | up = rot_axis
42 | rot_dir = torch.cross(rot_axis, cam_center)
43 | max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max()
44 |
45 | all_c2w = []
46 | for theta in torch.linspace(-max_angle, max_angle, n_steps):
47 | cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta)
48 | l = F.normalize(center - cam_pos, p=2, dim=0)
49 | s = F.normalize(l.cross(up), p=2, dim=0)
50 | u = F.normalize(s.cross(l), p=2, dim=0)
51 | c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1)
52 | all_c2w.append(c2w)
53 |
54 | all_c2w = torch.stack(all_c2w, dim=0)
55 |
56 | return all_c2w
57 |
58 | class DTUDatasetBase():
59 | def setup(self, config, split):
60 | self.config = config
61 | self.split = split
62 | self.rank = get_rank()
63 |
64 | cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file))
65 |
66 | img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png'))
67 | H, W = img_sample.shape[0], img_sample.shape[1]
68 |
69 | if 'img_wh' in self.config:
70 | w, h = self.config.img_wh
71 | assert round(W / w * h) == H
72 | elif 'img_downscale' in self.config:
73 | w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
74 | else:
75 | raise KeyError("Either img_wh or img_downscale should be specified.")
76 |
77 | self.w, self.h = w, h
78 | self.img_wh = (w, h)
79 | self.factor = w / W
80 |
81 | mask_dir = os.path.join(self.config.root_dir, 'mask')
82 | self.has_mask = True
83 | self.apply_mask = self.config.apply_mask
84 |
85 | self.directions = []
86 | self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
87 |
88 | n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1
89 |
90 | for i in range(n_images):
91 | world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}']
92 | P = (world_mat @ scale_mat)[:3,:4]
93 | K, c2w = load_K_Rt_from_P(P)
94 | fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor
95 | directions = get_ray_directions(w, h, fx, fy, cx, cy)
96 | self.directions.append(directions)
97 |
98 | c2w = torch.from_numpy(c2w).float()
99 |
100 | # blender follows opengl camera coordinates (right up back)
101 | # NeuS DTU data coordinate system (right down front) is different from blender
102 | # https://github.com/Totoro97/NeuS/issues/9
103 | # for c2w, flip the sign of input camera coordinate yz
104 | c2w_ = c2w.clone()
105 | c2w_[:3,1:3] *= -1. # flip input sign
106 | self.all_c2w.append(c2w_[:3,:4])
107 |
108 | if self.split in ['train', 'val']:
109 | img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png')
110 | img = Image.open(img_path)
111 | img = img.resize(self.img_wh, Image.BICUBIC)
112 | img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
113 |
114 | mask_path = os.path.join(mask_dir, f'{i:03d}.png')
115 | mask = Image.open(mask_path).convert('L') # (H, W, 1)
116 | mask = mask.resize(self.img_wh, Image.BICUBIC)
117 | mask = TF.to_tensor(mask)[0]
118 |
119 | self.all_fg_masks.append(mask) # (h, w)
120 | self.all_images.append(img)
121 |
122 | self.all_c2w = torch.stack(self.all_c2w, dim=0)
123 |
124 | if self.split == 'test':
125 | self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
126 | self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
127 | self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
128 | self.directions = self.directions[0]
129 | else:
130 | self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0)
131 | self.directions = torch.stack(self.directions, dim=0)
132 |
133 | self.directions = self.directions.float().to(self.rank)
134 | self.all_c2w, self.all_images, self.all_fg_masks = \
135 | self.all_c2w.float().to(self.rank), \
136 | self.all_images.float().to(self.rank), \
137 | self.all_fg_masks.float().to(self.rank)
138 |
139 |
140 | class DTUDataset(Dataset, DTUDatasetBase):
141 | def __init__(self, config, split):
142 | self.setup(config, split)
143 |
144 | def __len__(self):
145 | return len(self.all_images)
146 |
147 | def __getitem__(self, index):
148 | return {
149 | 'index': index
150 | }
151 |
152 |
153 | class DTUIterableDataset(IterableDataset, DTUDatasetBase):
154 | def __init__(self, config, split):
155 | self.setup(config, split)
156 |
157 | def __iter__(self):
158 | while True:
159 | yield {}
160 |
161 |
162 | @datasets.register('dtu')
163 | class DTUDataModule(pl.LightningDataModule):
164 | def __init__(self, config):
165 | super().__init__()
166 | self.config = config
167 |
168 | def setup(self, stage=None):
169 | if stage in [None, 'fit']:
170 | self.train_dataset = DTUIterableDataset(self.config, 'train')
171 | if stage in [None, 'fit', 'validate']:
172 | self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train'))
173 | if stage in [None, 'test']:
174 | self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test'))
175 | if stage in [None, 'predict']:
176 | self.predict_dataset = DTUDataset(self.config, 'train')
177 |
178 | def prepare_data(self):
179 | pass
180 |
181 | def general_loader(self, dataset, batch_size):
182 | sampler = None
183 | return DataLoader(
184 | dataset,
185 | num_workers=os.cpu_count(),
186 | batch_size=batch_size,
187 | pin_memory=True,
188 | sampler=sampler
189 | )
190 |
191 | def train_dataloader(self):
192 | return self.general_loader(self.train_dataset, batch_size=1)
193 |
194 | def val_dataloader(self):
195 | return self.general_loader(self.val_dataset, batch_size=1)
196 |
197 | def test_dataloader(self):
198 | return self.general_loader(self.test_dataset, batch_size=1)
199 |
200 | def predict_dataloader(self):
201 | return self.general_loader(self.predict_dataset, batch_size=1)
202 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_back_RT.txt:
--------------------------------------------------------------------------------
1 | -1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2 | 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07
3 | 0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_back_left_RT.txt:
--------------------------------------------------------------------------------
1 | -7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
3 | -7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_back_right_RT.txt:
--------------------------------------------------------------------------------
1 | -7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
3 | 7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_front_RT.txt:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2 | 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07
3 | 0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_front_left_RT.txt:
--------------------------------------------------------------------------------
1 | 7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
3 | -7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_front_right_RT.txt:
--------------------------------------------------------------------------------
1 | 7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
2 | 0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
3 | 7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_left_RT.txt:
--------------------------------------------------------------------------------
1 | -2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16
2 | 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3 | -1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_right_RT.txt:
--------------------------------------------------------------------------------
1 | -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16
2 | 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3 | 1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/fixed_poses/000_top_RT.txt:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2 | 0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
3 | 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00
4 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/ortho.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import numpy as np
5 | from PIL import Image
6 | import cv2
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import Dataset, DataLoader, IterableDataset
11 | import torchvision.transforms.functional as TF
12 |
13 | import pytorch_lightning as pl
14 |
15 | import datasets
16 | from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
17 | from utils.misc import get_rank
18 |
19 | from glob import glob
20 | import PIL.Image
21 |
22 |
23 | def camNormal2worldNormal(rot_c2w, camNormal):
24 | H,W,_ = camNormal.shape
25 | normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
26 |
27 | return normal_img
28 |
29 | def worldNormal2camNormal(rot_w2c, worldNormal):
30 | H,W,_ = worldNormal.shape
31 | normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
32 |
33 | return normal_img
34 |
35 | def trans_normal(normal, RT_w2c, RT_w2c_target):
36 |
37 | normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
38 | normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
39 |
40 | return normal_target_cam
41 |
42 | def img2normal(img):
43 | return (img/255.)*2-1
44 |
45 | def normal2img(normal):
46 | return np.uint8((normal*0.5+0.5)*255)
47 |
48 | def norm_normalize(normal, dim=-1):
49 |
50 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
51 |
52 | return normal
53 |
54 | def RT_opengl2opencv(RT):
55 | # Build the coordinate transform matrix from world to computer vision camera
56 | # R_world2cv = R_bcam2cv@R_world2bcam
57 | # T_world2cv = R_bcam2cv@T_world2bcam
58 |
59 | R = RT[:3, :3]
60 | t = RT[:3, 3]
61 |
62 | R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
63 |
64 | R_world2cv = R_bcam2cv @ R
65 | t_world2cv = R_bcam2cv @ t
66 |
67 | RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
68 |
69 | return RT
70 |
71 | def normal_opengl2opencv(normal):
72 | H,W,C = np.shape(normal)
73 | # normal_img = np.reshape(normal, (H*W,C))
74 | R_bcam2cv = np.array([1, -1, -1], np.float32)
75 | normal_cv = normal * R_bcam2cv[None, None, :]
76 |
77 | print(np.shape(normal_cv))
78 |
79 | return normal_cv
80 |
81 | def inv_RT(RT):
82 | RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0)
83 | RT_inv = np.linalg.inv(RT_h)
84 |
85 | return RT_inv[:3, :]
86 |
87 |
88 | def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
89 | normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
90 |
91 | all_images = []
92 | all_normals = []
93 | all_normals_world = []
94 | all_masks = []
95 | all_color_masks = []
96 | all_poses = []
97 | all_w2cs = []
98 | directions = []
99 | ray_origins = []
100 |
101 | RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix
102 | RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv
103 | for idx, view in enumerate(view_types):
104 | print(os.path.join(root_dir,test_object))
105 | normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view))
106 | # Load key frame
107 | if load_color: # use bgr
108 |             image = np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]
109 |
110 | normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
111 | mask = normal[:, :, 3]
112 | normal = normal[:, :, :3]
113 |
114 | color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3]
115 | invalid_color_mask = color_mask < 255*0.5
116 | threshold = np.ones_like(image[:, :, 0]) * 250
117 | invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold)
118 | invalid_color_mask_final = invalid_color_mask & invalid_white_mask
119 | color_mask = (1 - invalid_color_mask_final) > 0
120 |
121 | # if erode_mask:
122 | # kernel = np.ones((3, 3), np.uint8)
123 | # mask = cv2.erode(mask, kernel, iterations=1)
124 |
125 | RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix
126 |
127 | normal = img2normal(normal)
128 |
129 | normal[mask==0] = [0,0,0]
130 | mask = mask> (0.5*255)
131 | if load_color:
132 | all_images.append(image)
133 |
134 | all_masks.append(mask)
135 | all_color_masks.append(color_mask)
136 | RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv
137 | all_poses.append(inv_RT(RT_cv)) # cam2world
138 | all_w2cs.append(RT_cv)
139 |
140 |         # convert the normal map from the OpenGL to the OpenCV camera convention
141 | normal_cam_cv = normal_opengl2opencv(normal)
142 |
143 | if normal_system == 'front':
144 | print("the loaded normals are defined in the system of front view")
145 | normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
146 | elif normal_system == 'self':
147 | print("the loaded normals are in their independent camera systems")
148 | normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
149 | all_normals.append(normal_cam_cv)
150 | all_normals_world.append(normal_world)
151 |
152 | if camera_type == 'ortho':
153 | origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
154 | elif camera_type == 'pinhole':
155 | dirs = get_ray_directions(W=imSize[0], H=imSize[1],
156 | fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
157 |             origins = dirs # placeholder only, to keep the same return structure as the ortho branch
158 | else:
159 |             raise Exception("unsupported camera type")
160 | ray_origins.append(origins)
161 | directions.append(dirs)
162 |
163 |
164 | if not load_color:
165 | all_images = [normal2img(x) for x in all_normals_world]
166 |
167 |
168 | return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
169 | np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
170 |
171 |
172 | class OrthoDatasetBase():
173 | def setup(self, config, split):
174 | self.config = config
175 | self.split = split
176 | self.rank = get_rank()
177 |
178 | self.data_dir = self.config.root_dir
179 | self.object_name = self.config.scene
180 | self.scene = self.config.scene
181 | self.imSize = self.config.imSize
182 | self.load_color = True
183 | self.img_wh = [self.imSize[0], self.imSize[1]]
184 | self.w = self.img_wh[0]
185 | self.h = self.img_wh[1]
186 | self.camera_type = self.config.camera_type
187 | self.camera_params = self.config.camera_params # [fx, fy, cx, cy]
188 |
189 | self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
190 |
191 | self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
192 | self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w)
193 |
194 | if self.config.cam_pose_dir is None:
195 | self.cam_pose_dir = "./datasets/fixed_poses"
196 | else:
197 | self.cam_pose_dir = self.config.cam_pose_dir
198 |
199 | self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
200 | self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
201 | self.data_dir, self.object_name, self.imSize, self.view_types,
202 | self.load_color, self.cam_pose_dir, normal_system='front',
203 | camera_type=self.camera_type, cam_params=self.camera_params)
204 |
205 | self.has_mask = True
206 | self.apply_mask = self.config.apply_mask
207 |
208 | self.all_c2w = torch.from_numpy(self.pose_all_np)
209 | self.all_images = torch.from_numpy(self.images_np) / 255.
210 | self.all_fg_masks = torch.from_numpy(self.masks_np)
211 | self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
212 | self.all_normals_world = torch.from_numpy(self.normals_world_np)
213 | self.origins = torch.from_numpy(self.origins_np)
214 | self.directions = torch.from_numpy(self.directions_np)
215 |
216 | self.directions = self.directions.float().to(self.rank)
217 | self.origins = self.origins.float().to(self.rank)
218 | self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
219 | self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
220 | self.all_c2w.float().to(self.rank), \
221 | self.all_images.float().to(self.rank), \
222 | self.all_fg_masks.float().to(self.rank), \
223 | self.all_normals_world.float().to(self.rank)
224 |
225 |
226 | class OrthoDataset(Dataset, OrthoDatasetBase):
227 | def __init__(self, config, split):
228 | self.setup(config, split)
229 |
230 | def __len__(self):
231 | return len(self.all_images)
232 |
233 | def __getitem__(self, index):
234 | return {
235 | 'index': index
236 | }
237 |
238 |
239 | class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
240 | def __init__(self, config, split):
241 | self.setup(config, split)
242 |
243 | def __iter__(self):
244 | while True:
245 | yield {}
246 |
247 |
248 | @datasets.register('ortho')
249 | class OrthoDataModule(pl.LightningDataModule):
250 | def __init__(self, config):
251 | super().__init__()
252 | self.config = config
253 |
254 | def setup(self, stage=None):
255 | if stage in [None, 'fit']:
256 | self.train_dataset = OrthoIterableDataset(self.config, 'train')
257 | if stage in [None, 'fit', 'validate']:
258 | self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
259 | if stage in [None, 'test']:
260 | self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
261 | if stage in [None, 'predict']:
262 | self.predict_dataset = OrthoDataset(self.config, 'train')
263 |
264 | def prepare_data(self):
265 | pass
266 |
267 | def general_loader(self, dataset, batch_size):
268 | sampler = None
269 | return DataLoader(
270 | dataset,
271 | num_workers=os.cpu_count(),
272 | batch_size=batch_size,
273 | pin_memory=True,
274 | sampler=sampler
275 | )
276 |
277 | def train_dataloader(self):
278 | return self.general_loader(self.train_dataset, batch_size=1)
279 |
280 | def val_dataloader(self):
281 | return self.general_loader(self.val_dataset, batch_size=1)
282 |
283 | def test_dataloader(self):
284 | return self.general_loader(self.test_dataset, batch_size=1)
285 |
286 | def predict_dataloader(self):
287 | return self.general_loader(self.predict_dataset, batch_size=1)
288 |
--------------------------------------------------------------------------------
/instant-nsr-pl/datasets/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/instant-nsr-pl/datasets/utils.py
--------------------------------------------------------------------------------
/instant-nsr-pl/launch.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import time
5 | import logging
6 | from datetime import datetime
7 |
8 |
9 | def main():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--config', required=True, help='path to config file')
12 | parser.add_argument('--gpu', default='0', help='GPU(s) to be used')
13 | parser.add_argument('--resume', default=None, help='path to the weights to be resumed')
14 | parser.add_argument(
15 | '--resume_weights_only',
16 | action='store_true',
17 | help='specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only'
18 | )
19 |
20 | group = parser.add_mutually_exclusive_group(required=True)
21 | group.add_argument('--train', action='store_true')
22 | group.add_argument('--validate', action='store_true')
23 | group.add_argument('--test', action='store_true')
24 | group.add_argument('--predict', action='store_true')
25 | # group.add_argument('--export', action='store_true') # TODO: a separate export action
26 |
27 | parser.add_argument('--exp_dir', default='./exp')
28 | parser.add_argument('--runs_dir', default='./runs')
29 | parser.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG')
30 |
31 | args, extras = parser.parse_known_args()
32 |
33 | # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
34 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
36 | n_gpus = len(args.gpu.split(','))
37 |
38 | import datasets
39 | import systems
40 | import pytorch_lightning as pl
41 | from pytorch_lightning import Trainer
42 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
43 | from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
44 | from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar
45 | from utils.misc import load_config
46 |
47 | # parse YAML config to OmegaConf
48 | config = load_config(args.config, cli_args=extras)
49 | config.cmd_args = vars(args)
50 |
51 | config.trial_name = config.get('trial_name') or (config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S'))
52 | config.exp_dir = config.get('exp_dir') or os.path.join(args.exp_dir, config.name)
53 | config.save_dir = config.get('save_dir') or os.path.join(config.exp_dir, config.trial_name, 'save')
54 | config.ckpt_dir = config.get('ckpt_dir') or os.path.join(config.exp_dir, config.trial_name, 'ckpt')
55 | config.code_dir = config.get('code_dir') or os.path.join(config.exp_dir, config.trial_name, 'code')
56 | config.config_dir = config.get('config_dir') or os.path.join(config.exp_dir, config.trial_name, 'config')
57 |
58 | logger = logging.getLogger('pytorch_lightning')
59 | if args.verbose:
60 | logger.setLevel(logging.DEBUG)
61 |
62 | if 'seed' not in config:
63 | config.seed = int(time.time() * 1000) % 1000
64 | pl.seed_everything(config.seed)
65 |
66 | dm = datasets.make(config.dataset.name, config.dataset)
67 | system = systems.make(config.system.name, config, load_from_checkpoint=None if not args.resume_weights_only else args.resume)
68 |
69 | callbacks = []
70 | if args.train:
71 | callbacks += [
72 | ModelCheckpoint(
73 | dirpath=config.ckpt_dir,
74 | **config.checkpoint
75 | ),
76 | LearningRateMonitor(logging_interval='step'),
77 | CodeSnapshotCallback(
78 | config.code_dir, use_version=False
79 | ),
80 | ConfigSnapshotCallback(
81 | config, config.config_dir, use_version=False
82 | ),
83 | CustomProgressBar(refresh_rate=1),
84 | ]
85 |
86 | loggers = []
87 | if args.train:
88 | loggers += [
89 | TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name),
90 | CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs')
91 | ]
92 |
93 | if sys.platform == 'win32':
94 | # does not support multi-gpu on windows
95 | strategy = 'dp'
96 | assert n_gpus == 1
97 | else:
98 | strategy = 'ddp_find_unused_parameters_false'
99 |
100 | trainer = Trainer(
101 | devices=n_gpus,
102 | accelerator='gpu',
103 | callbacks=callbacks,
104 | logger=loggers,
105 | strategy=strategy,
106 | **config.trainer
107 | )
108 |
109 | if args.train:
110 | if args.resume and not args.resume_weights_only:
111 | # FIXME: different behavior in pytorch-lighting>1.9 ?
112 | trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
113 | else:
114 | trainer.fit(system, datamodule=dm)
115 | trainer.test(system, datamodule=dm)
116 | elif args.validate:
117 | trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
118 | elif args.test:
119 | trainer.test(system, datamodule=dm, ckpt_path=args.resume)
120 | elif args.predict:
121 | trainer.predict(system, datamodule=dm, ckpt_path=args.resume)
122 |
123 |
124 | if __name__ == '__main__':
125 | main()
126 |
--------------------------------------------------------------------------------
/instant-nsr-pl/models/__init__.py:
--------------------------------------------------------------------------------
1 | models = {}
2 |
3 |
4 | def register(name):
5 | def decorator(cls):
6 | models[name] = cls
7 | return cls
8 | return decorator
9 |
10 |
11 | def make(name, config):
12 | model = models[name](config)
13 | return model
14 |
15 |
16 | from . import nerf, neus, geometry, texture
17 |
--------------------------------------------------------------------------------
/instant-nsr-pl/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.misc import get_rank
5 |
6 | class BaseModel(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 | self.rank = get_rank()
11 | self.setup()
12 | if self.config.get('weights', None):
13 | self.load_state_dict(torch.load(self.config.weights))
14 |
15 | def setup(self):
16 | raise NotImplementedError
17 |
18 | def update_step(self, epoch, global_step):
19 | pass
20 |
21 | def train(self, mode=True):
22 | return super().train(mode=mode)
23 |
24 | def eval(self):
25 | return super().eval()
26 |
27 | def regularizations(self, out):
28 | return {}
29 |
30 | @torch.no_grad()
31 | def export(self, export_config):
32 | return {}
33 |
--------------------------------------------------------------------------------
/instant-nsr-pl/models/nerf.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import models
8 | from models.base import BaseModel
9 | from models.utils import chunk_batch
10 | from systems.utils import update_module_step
11 | from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays
12 |
13 |
14 | @models.register('nerf')
15 | class NeRFModel(BaseModel):
16 | def setup(self):
17 | self.geometry = models.make(self.config.geometry.name, self.config.geometry)
18 | self.texture = models.make(self.config.texture.name, self.config.texture)
19 | self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32))
20 |
21 | if self.config.learned_background:
22 | self.occupancy_grid_res = 256
23 | self.near_plane, self.far_plane = 0.2, 1e4
24 | self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. # approximate
25 | self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size)
26 | self.contraction_type = ContractionType.UN_BOUNDED_SPHERE
27 | else:
28 | self.occupancy_grid_res = 128
29 | self.near_plane, self.far_plane = None, None
30 | self.cone_angle = 0.0
31 | self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
32 | self.contraction_type = ContractionType.AABB
33 |
34 | self.geometry.contraction_type = self.contraction_type
35 |
36 | if self.config.grid_prune:
37 | self.occupancy_grid = OccupancyGrid(
38 | roi_aabb=self.scene_aabb,
39 | resolution=self.occupancy_grid_res,
40 | contraction_type=self.contraction_type
41 | )
42 | self.randomized = self.config.randomized
43 | self.background_color = None
44 |
45 | def update_step(self, epoch, global_step):
46 | update_module_step(self.geometry, epoch, global_step)
47 | update_module_step(self.texture, epoch, global_step)
48 |
49 | def occ_eval_fn(x):
50 | density, _ = self.geometry(x)
51 | # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series
52 | return density[...,None] * self.render_step_size
53 |
54 | if self.training and self.config.grid_prune:
55 | self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn)
56 |
57 | def isosurface(self):
58 | mesh = self.geometry.isosurface()
59 | return mesh
60 |
61 | def forward_(self, rays):
62 | n_rays = rays.shape[0]
63 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
64 |
65 | def sigma_fn(t_starts, t_ends, ray_indices):
66 | ray_indices = ray_indices.long()
67 | t_origins = rays_o[ray_indices]
68 | t_dirs = rays_d[ray_indices]
69 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
70 | density, _ = self.geometry(positions)
71 | return density[...,None]
72 |
73 | def rgb_sigma_fn(t_starts, t_ends, ray_indices):
74 | ray_indices = ray_indices.long()
75 | t_origins = rays_o[ray_indices]
76 | t_dirs = rays_d[ray_indices]
77 | positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
78 | density, feature = self.geometry(positions)
79 | rgb = self.texture(feature, t_dirs)
80 | return rgb, density[...,None]
81 |
82 | with torch.no_grad():
83 | ray_indices, t_starts, t_ends = ray_marching(
84 | rays_o, rays_d,
85 | scene_aabb=None if self.config.learned_background else self.scene_aabb,
86 | grid=self.occupancy_grid if self.config.grid_prune else None,
87 | sigma_fn=sigma_fn,
88 | near_plane=self.near_plane, far_plane=self.far_plane,
89 | render_step_size=self.render_step_size,
90 | stratified=self.randomized,
91 | cone_angle=self.cone_angle,
92 | alpha_thre=0.0
93 | )
94 |
95 | ray_indices = ray_indices.long()
96 | t_origins = rays_o[ray_indices]
97 | t_dirs = rays_d[ray_indices]
98 | midpoints = (t_starts + t_ends) / 2.
99 | positions = t_origins + t_dirs * midpoints
100 | intervals = t_ends - t_starts
101 |
102 | density, feature = self.geometry(positions)
103 | rgb = self.texture(feature, t_dirs)
104 |
105 | weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays)
106 | opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
107 | depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
108 | comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
109 | comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)
110 |
111 | out = {
112 | 'comp_rgb': comp_rgb,
113 | 'opacity': opacity,
114 | 'depth': depth,
115 | 'rays_valid': opacity > 0,
116 | 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
117 | }
118 |
119 | if self.training:
120 | out.update({
121 | 'weights': weights.view(-1),
122 | 'points': midpoints.view(-1),
123 | 'intervals': intervals.view(-1),
124 | 'ray_indices': ray_indices.view(-1)
125 | })
126 |
127 | return out
128 |
129 | def forward(self, rays):
130 | if self.training:
131 | out = self.forward_(rays)
132 | else:
133 | out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
134 | return {
135 | **out,
136 | }
137 |
138 | def train(self, mode=True):
139 | self.randomized = mode and self.config.randomized
140 | return super().train(mode=mode)
141 |
142 | def eval(self):
143 | self.randomized = False
144 | return super().eval()
145 |
146 | def regularizations(self, out):
147 | losses = {}
148 | losses.update(self.geometry.regularizations(out))
149 | losses.update(self.texture.regularizations(out))
150 | return losses
151 |
152 | @torch.no_grad()
153 | def export(self, export_config):
154 | mesh = self.isosurface()
155 | if export_config.export_vertex_color:
156 | _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank))
157 | viewdirs = torch.zeros(feature.shape[0], 3).to(feature)
158 | viewdirs[...,2] = -1. # set the viewing directions to be -z (looking down)
159 | rgb = self.texture(feature, viewdirs).clamp(0,1)
160 | mesh['v_rgb'] = rgb.cpu()
161 | return mesh
162 |
--------------------------------------------------------------------------------
/instant-nsr-pl/models/network_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | import tinycudann as tcnn
7 |
8 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info
9 |
10 | from utils.misc import config_to_primitive, get_rank
11 | from models.utils import get_activation
12 | from systems.utils import update_module_step
13 |
14 | class VanillaFrequency(nn.Module):
15 | def __init__(self, in_channels, config):
16 | super().__init__()
17 | self.N_freqs = config['n_frequencies']
18 | self.in_channels, self.n_input_dims = in_channels, in_channels
19 | self.funcs = [torch.sin, torch.cos]
20 | self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs)
21 | self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
22 | self.n_masking_step = config.get('n_masking_step', 0)
23 | self.update_step(None, None) # mask should be updated at the beginning each step
24 |
25 | def forward(self, x):
26 | out = []
27 | for freq, mask in zip(self.freq_bands, self.mask):
28 | for func in self.funcs:
29 | out += [func(freq*x) * mask]
30 | return torch.cat(out, -1)
31 |
32 | def update_step(self, epoch, global_step):
33 | if self.n_masking_step <= 0 or global_step is None:
34 | self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
35 | else:
36 | self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2.
37 | rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}')
38 |
39 |
40 | class ProgressiveBandHashGrid(nn.Module):
41 | def __init__(self, in_channels, config):
42 | super().__init__()
43 | self.n_input_dims = in_channels
44 | encoding_config = config.copy()
45 | encoding_config['otype'] = 'HashGrid'
46 | with torch.cuda.device(get_rank()):
47 | self.encoding = tcnn.Encoding(in_channels, encoding_config)
48 | self.n_output_dims = self.encoding.n_output_dims
49 | self.n_level = config['n_levels']
50 | self.n_features_per_level = config['n_features_per_level']
51 | self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps']
52 | self.current_level = self.start_level
53 | self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank())
54 |
55 | def forward(self, x):
56 | enc = self.encoding(x)
57 | enc = enc * self.mask
58 | return enc
59 |
60 | def update_step(self, epoch, global_step):
61 | current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level)
62 | if current_level > self.current_level:
63 | rank_zero_info(f'Update grid level to {current_level}')
64 | self.current_level = current_level
65 | self.mask[:self.current_level * self.n_features_per_level] = 1.
66 |
67 |
68 | class CompositeEncoding(nn.Module):
69 | def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.):
70 | super(CompositeEncoding, self).__init__()
71 | self.encoding = encoding
72 | self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset
73 | self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims
74 |
75 | def forward(self, x, *args):
76 | return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1)
77 |
78 | def update_step(self, epoch, global_step):
79 | update_module_step(self.encoding, epoch, global_step)
80 |
81 |
82 | def get_encoding(n_input_dims, config):
83 |     # input is assumed to be in the range [0, 1]
84 | if config.otype == 'VanillaFrequency':
85 | encoding = VanillaFrequency(n_input_dims, config_to_primitive(config))
86 | elif config.otype == 'ProgressiveBandHashGrid':
87 | encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
88 | else:
89 | with torch.cuda.device(get_rank()):
90 | encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config))
91 | encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.)
92 | return encoding
93 |
94 |
95 | class VanillaMLP(nn.Module):
96 | def __init__(self, dim_in, dim_out, config):
97 | super().__init__()
98 | self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers']
99 | self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False)
100 | self.sphere_init_radius = config.get('sphere_init_radius', 0.5)
101 | self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
102 | for i in range(self.n_hidden_layers - 1):
103 | self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
104 | self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
105 | self.layers = nn.Sequential(*self.layers)
106 | self.output_activation = get_activation(config['output_activation'])
107 |
108 | @torch.cuda.amp.autocast(False)
109 | def forward(self, x):
110 | x = self.layers(x.float())
111 | x = self.output_activation(x)
112 | return x
113 |
114 | def make_linear(self, dim_in, dim_out, is_first, is_last):
115 | layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
116 | if self.sphere_init:
117 | if is_last:
118 | torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
119 | torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
120 | elif is_first:
121 | torch.nn.init.constant_(layer.bias, 0.0)
122 | torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
123 | torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
124 | else:
125 | torch.nn.init.constant_(layer.bias, 0.0)
126 | torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
127 | else:
128 | torch.nn.init.constant_(layer.bias, 0.0)
129 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
130 |
131 | if self.weight_norm:
132 | layer = nn.utils.weight_norm(layer)
133 | return layer
134 |
135 | def make_activation(self):
136 | if self.sphere_init:
137 | return nn.Softplus(beta=100)
138 | else:
139 | return nn.ReLU(inplace=True)
140 |
141 |
142 | def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network):
143 | rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.')
144 | """
145 | from https://github.com/NVlabs/tiny-cuda-nn/issues/96
146 | It's the weight matrices of each layer laid out in row-major order and then concatenated.
147 | Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP).
148 | The padded input dimensions get a constant value of 1.0,
149 | whereas the padded output dimensions are simply ignored,
150 | so the weights pertaining to those can have any value.
151 | """
152 | padto = 16 if config.otype == 'FullyFusedMLP' else 8
153 | n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto
154 | n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto
155 | data = list(network.parameters())[0].data
156 | assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2
157 | new_data = []
158 | # first layer
159 | weight = torch.zeros((config.n_neurons, n_input_dims)).to(data)
160 | torch.nn.init.constant_(weight[:, 3:], 0.0)
161 | torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
162 | new_data.append(weight.flatten())
163 | # hidden layers
164 | for i in range(config.n_hidden_layers - 1):
165 | weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data)
166 | torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
167 | new_data.append(weight.flatten())
168 | # last layer
169 | weight = torch.zeros((n_output_dims, config.n_neurons)).to(data)
170 | torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001)
171 | new_data.append(weight.flatten())
172 | new_data = torch.cat(new_data)
173 | data.copy_(new_data)
174 |
175 |
176 | def get_mlp(n_input_dims, n_output_dims, config):
177 | if config.otype == 'VanillaMLP':
178 | network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))
179 | else:
180 | with torch.cuda.device(get_rank()):
181 | network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config))
182 | if config.get('sphere_init', False):
183 | sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network)
184 | return network
185 |
186 |
187 | class EncodingWithNetwork(nn.Module):
188 | def __init__(self, encoding, network):
189 | super().__init__()
190 | self.encoding, self.network = encoding, network
191 |
192 | def forward(self, x):
193 | return self.network(self.encoding(x))
194 |
195 | def update_step(self, epoch, global_step):
196 | update_module_step(self.encoding, epoch, global_step)
197 | update_module_step(self.network, epoch, global_step)
198 |
199 |
200 | def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config):
201 |     # input is assumed to be in the range [0, 1]
202 | if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \
203 | or network_config.otype in ['VanillaMLP']:
204 | encoding = get_encoding(n_input_dims, encoding_config)
205 | network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
206 | encoding_with_network = EncodingWithNetwork(encoding, network)
207 | else:
208 | with torch.cuda.device(get_rank()):
209 | encoding_with_network = tcnn.NetworkWithInputEncoding(
210 | n_input_dims=n_input_dims,
211 | n_output_dims=n_output_dims,
212 | encoding_config=config_to_primitive(encoding_config),
213 | network_config=config_to_primitive(network_config)
214 | )
215 | return encoding_with_network
216 |
--------------------------------------------------------------------------------
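A quick sanity check of the padding rule described in the sphere_init_tcnn_network docstring above: tcnn pads input and output dimensions to multiples of 16 (FullyFusedMLP) or 8 (CutlassMLP), and the flattened parameter vector is just the concatenated weight matrices. A minimal sketch with illustrative numbers (not taken from any shipped config):

def tcnn_padded_param_count(n_input_dims, n_output_dims, n_neurons, n_hidden_layers, otype='FullyFusedMLP'):
    # Mirrors the arithmetic behind the assertion in sphere_init_tcnn_network (weights only, no biases).
    padto = 16 if otype == 'FullyFusedMLP' else 8
    n_in = n_input_dims + (padto - n_input_dims % padto) % padto
    n_out = n_output_dims + (padto - n_output_dims % padto) % padto
    return (n_in + n_out) * n_neurons + (n_hidden_layers - 1) * n_neurons ** 2

# 3 input dims and 1 output dim both pad up to 16: (16 + 16) * 64 + 1 * 64**2 = 6144
print(tcnn_padded_param_count(3, 1, n_neurons=64, n_hidden_layers=2))

--------------------------------------------------------------------------------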
/instant-nsr-pl/models/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def cast_rays(ori, dir, z_vals):
6 | return ori[..., None, :] + z_vals[..., None] * dir[..., None, :]
7 |
8 |
9 | def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True):
10 | pixel_center = 0.5 if use_pixel_centers else 0
11 | i, j = np.meshgrid(
12 | np.arange(W, dtype=np.float32) + pixel_center,
13 | np.arange(H, dtype=np.float32) + pixel_center,
14 | indexing='xy'
15 | )
16 | i, j = torch.from_numpy(i), torch.from_numpy(j)
17 |
18 | # directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3)
19 | # opencv system
20 | directions = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1) # (H, W, 3)
21 |
22 | return directions
23 |
24 |
25 | def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True):
26 | pixel_center = 0.5 if use_pixel_centers else 0
27 | i, j = np.meshgrid(
28 | np.arange(W, dtype=np.float32) + pixel_center,
29 | np.arange(H, dtype=np.float32) + pixel_center,
30 | indexing='xy'
31 | )
32 | i, j = torch.from_numpy(i), torch.from_numpy(j)
33 |
34 | origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2, torch.zeros_like(i)], dim=-1) # W, H, 3
35 | directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
36 |
37 | return origins, directions
38 |
39 |
40 | def get_rays(directions, c2w, keepdim=False):
41 | # Rotate ray directions from camera coordinate to the world coordinate
42 | # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow?
43 | assert directions.shape[-1] == 3
44 |
45 | if directions.ndim == 2: # (N_rays, 3)
46 | assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4)
47 | rays_d = (directions[:,None,:] * c2w[:,:3,:3]).sum(-1) # (N_rays, 3)
48 | rays_o = c2w[:,:,3].expand(rays_d.shape)
49 | elif directions.ndim == 3: # (H, W, 3)
50 | if c2w.ndim == 2: # (4, 4)
51 | rays_d = (directions[:,:,None,:] * c2w[None,None,:3,:3]).sum(-1) # (H, W, 3)
52 | rays_o = c2w[None,None,:,3].expand(rays_d.shape)
53 | elif c2w.ndim == 3: # (B, 4, 4)
54 | rays_d = (directions[None,:,:,None,:] * c2w[:,None,None,:3,:3]).sum(-1) # (B, H, W, 3)
55 | rays_o = c2w[:,None,None,:,3].expand(rays_d.shape)
56 |
57 | if not keepdim:
58 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
59 |
60 | return rays_o, rays_d
61 |
62 |
63 | # rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), rays_v[:, :, :, None].cuda()).squeeze() # W, H, 3
64 |
65 | # rays_o = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), q[:, :, :, None].cuda()).squeeze() # W, H, 3
66 | # rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape).cuda() + rays_o # W, H, 3
67 |
68 | def get_ortho_rays(origins, directions, c2w, keepdim=False):
69 | # Rotate ray directions from camera coordinate to the world coordinate
70 | # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow?
71 | assert directions.shape[-1] == 3
72 | assert origins.shape[-1] == 3
73 |
74 | if directions.ndim == 2: # (N_rays, 3)
75 | assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4)
76 | rays_d = torch.matmul(c2w[:, :3, :3], directions[:, :, None]).squeeze() # (N_rays, 3)
77 | rays_o = torch.matmul(c2w[:, :3, :3], origins[:, :, None]).squeeze() # (N_rays, 3)
78 | rays_o = c2w[:,:3,3].expand(rays_d.shape) + rays_o
79 | elif directions.ndim == 3: # (H, W, 3)
80 | if c2w.ndim == 2: # (4, 4)
81 | rays_d = torch.matmul(c2w[None, None, :3, :3], directions[:, :, :, None]).squeeze() # (H, W, 3)
82 | rays_o = torch.matmul(c2w[None, None, :3, :3], origins[:, :, :, None]).squeeze() # (H, W, 3)
83 | rays_o = c2w[None, None,:3,3].expand(rays_d.shape) + rays_o
84 | elif c2w.ndim == 3: # (B, 4, 4)
85 |             rays_d = torch.matmul(c2w[:,None, None, :3, :3], directions[None, :, :, :, None]).squeeze()  # (B, H, W, 3)
86 |             rays_o = torch.matmul(c2w[:,None, None, :3, :3], origins[None, :, :, :, None]).squeeze()  # (B, H, W, 3)
87 | rays_o = c2w[:,None, None, :3,3].expand(rays_d.shape) + rays_o
88 |
89 | if not keepdim:
90 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
91 |
92 | return rays_o, rays_d
93 |
--------------------------------------------------------------------------------
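A minimal shape check for the orthographic helpers above, run from the instant-nsr-pl directory so the models package resolves; the identity pose is a stand-in for a real camera-to-world matrix:

import torch
from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays

c2w = torch.eye(4)                                                     # placeholder camera-to-world pose
origins, directions = get_ortho_ray_directions_origins(W=256, H=256)   # each (H, W, 3)
rays_o, rays_d = get_ortho_rays(origins, directions, c2w)              # flattened to (H*W, 3)
print(rays_o.shape, rays_d.shape)                                      # torch.Size([65536, 3]) twice

--------------------------------------------------------------------------------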
/instant-nsr-pl/models/texture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import models
5 | from models.utils import get_activation
6 | from models.network_utils import get_encoding, get_mlp
7 | from systems.utils import update_module_step
8 |
9 |
10 | @models.register('volume-radiance')
11 | class VolumeRadiance(nn.Module):
12 | def __init__(self, config):
13 | super(VolumeRadiance, self).__init__()
14 | self.config = config
15 | self.with_viewdir = False #self.config.get('wo_viewdir', False)
16 | self.n_dir_dims = self.config.get('n_dir_dims', 3)
17 | self.n_output_dims = 3
18 |
19 | if self.with_viewdir:
20 | encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config)
21 | self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims
22 | # self.network_base = get_mlp(self.config.input_feature_dim, self.n_output_dims, self.config.mlp_network_config)
23 | else:
24 | encoding = None
25 | self.n_input_dims = self.config.input_feature_dim
26 |
27 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)
28 | self.encoding = encoding
29 | self.network = network
30 |
31 | def forward(self, features, dirs, *args):
32 |
33 | # features = features.detach()
34 | if self.with_viewdir:
35 | dirs = (dirs + 1.) / 2. # (-1, 1) => (0, 1)
36 | dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims))
37 | network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1)
38 | # network_inp_base = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1)
39 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float()
40 | # color_base = self.network_base(network_inp_base).view(*features.shape[:-1], self.n_output_dims).float()
41 | # color = color + color_base
42 | else:
43 | network_inp = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1)
44 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float()
45 |
46 | if 'color_activation' in self.config:
47 | color = get_activation(self.config.color_activation)(color)
48 | return color
49 |
50 | def update_step(self, epoch, global_step):
51 | update_module_step(self.encoding, epoch, global_step)
52 |
53 | def regularizations(self, out):
54 | return {}
55 |
56 |
57 | @models.register('volume-color')
58 | class VolumeColor(nn.Module):
59 | def __init__(self, config):
60 | super(VolumeColor, self).__init__()
61 | self.config = config
62 | self.n_output_dims = 3
63 | self.n_input_dims = self.config.input_feature_dim
64 | network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config)
65 | self.network = network
66 |
67 | def forward(self, features, *args):
68 | network_inp = features.view(-1, features.shape[-1])
69 | color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float()
70 | if 'color_activation' in self.config:
71 | color = get_activation(self.config.color_activation)(color)
72 | return color
73 |
74 | def regularizations(self, out):
75 | return {}
76 |
--------------------------------------------------------------------------------
/instant-nsr-pl/models/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from collections import defaultdict
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 | from torch.cuda.amp import custom_bwd, custom_fwd
9 |
10 | import tinycudann as tcnn
11 |
12 |
13 | def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs):
14 | B = None
15 | for arg in args:
16 | if isinstance(arg, torch.Tensor):
17 | B = arg.shape[0]
18 | break
19 | out = defaultdict(list)
20 | out_type = None
21 | for i in range(0, B, chunk_size):
22 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs)
23 | if out_chunk is None:
24 | continue
25 | out_type = type(out_chunk)
26 | if isinstance(out_chunk, torch.Tensor):
27 | out_chunk = {0: out_chunk}
28 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
29 | chunk_length = len(out_chunk)
30 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
31 | elif isinstance(out_chunk, dict):
32 | pass
33 | else:
34 |             print(f'Return value of func must be of type [torch.Tensor, list, tuple, dict], got {type(out_chunk)}.')
35 | exit(1)
36 | for k, v in out_chunk.items():
37 | v = v if torch.is_grad_enabled() else v.detach()
38 | v = v.cpu() if move_to_cpu else v
39 | out[k].append(v)
40 |
41 | if out_type is None:
42 | return
43 |
44 | out = {k: torch.cat(v, dim=0) for k, v in out.items()}
45 | if out_type is torch.Tensor:
46 | return out[0]
47 | elif out_type in [tuple, list]:
48 | return out_type([out[i] for i in range(chunk_length)])
49 | elif out_type is dict:
50 | return out
51 |
52 |
53 | class _TruncExp(Function): # pylint: disable=abstract-method
54 | # Implementation from torch-ngp:
55 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
56 | @staticmethod
57 | @custom_fwd(cast_inputs=torch.float32)
58 | def forward(ctx, x): # pylint: disable=arguments-differ
59 | ctx.save_for_backward(x)
60 | return torch.exp(x)
61 |
62 | @staticmethod
63 | @custom_bwd
64 | def backward(ctx, g): # pylint: disable=arguments-differ
65 | x = ctx.saved_tensors[0]
66 | return g * torch.exp(torch.clamp(x, max=15))
67 |
68 | trunc_exp = _TruncExp.apply
69 |
70 |
71 | def get_activation(name):
72 | if name is None:
73 | return lambda x: x
74 | name = name.lower()
75 | if name == 'none':
76 | return lambda x: x
77 | elif name.startswith('scale'):
78 | scale_factor = float(name[5:])
79 | return lambda x: x.clamp(0., scale_factor) / scale_factor
80 | elif name.startswith('clamp'):
81 | clamp_max = float(name[5:])
82 | return lambda x: x.clamp(0., clamp_max)
83 | elif name.startswith('mul'):
84 | mul_factor = float(name[3:])
85 | return lambda x: x * mul_factor
86 | elif name == 'lin2srgb':
87 | return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.)
88 | elif name == 'trunc_exp':
89 | return trunc_exp
90 | elif name.startswith('+') or name.startswith('-'):
91 | return lambda x: x + float(name)
92 | elif name == 'sigmoid':
93 | return lambda x: torch.sigmoid(x)
94 | elif name == 'tanh':
95 | return lambda x: torch.tanh(x)
96 | else:
97 | return getattr(F, name)
98 |
99 |
100 | def dot(x, y):
101 | return torch.sum(x*y, -1, keepdim=True)
102 |
103 |
104 | def reflect(x, n):
105 | return 2 * dot(x, n) * n - x
106 |
107 |
108 | def scale_anything(dat, inp_scale, tgt_scale):
109 | if inp_scale is None:
110 | inp_scale = [dat.min(), dat.max()]
111 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
112 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
113 | return dat
114 |
115 |
116 | def cleanup():
117 | gc.collect()
118 | torch.cuda.empty_cache()
119 | tcnn.free_temporary_memory()
120 |
--------------------------------------------------------------------------------
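The string convention accepted by get_activation above packs a small DSL into the name: a prefix selects the transform, the trailing characters are parsed as its parameter, and unknown names fall through to torch.nn.functional. A short sketch:

import torch
from models.utils import get_activation

x = torch.tensor([-0.5, 0.25, 2.0])
print(get_activation('clamp1.0')(x))   # clamp to [0, 1.0]          -> [0.0000, 0.2500, 1.0000]
print(get_activation('scale2.0')(x))   # clamp to [0, 2], then / 2  -> [0.0000, 0.1250, 1.0000]
print(get_activation('mul10')(x))      # multiply by 10
print(get_activation('+0.5')(x))       # add 0.5
print(get_activation('softplus')(x))   # falls through to torch.nn.functional.softplus

--------------------------------------------------------------------------------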
/instant-nsr-pl/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning<2
2 | omegaconf==2.2.3
3 | nerfacc==0.3.3
4 | matplotlib
5 | opencv-python
6 | imageio
7 | imageio-ffmpeg
8 | scipy
9 | PyMCubes
10 | pyransac3d
11 | torch_efficient_distloss
12 | tensorboard
13 |
--------------------------------------------------------------------------------
/instant-nsr-pl/run.sh:
--------------------------------------------------------------------------------
1 | python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=$1 dataset.scene=$2
--------------------------------------------------------------------------------
/instant-nsr-pl/scripts/imgs2poses.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | This file is adapted from https://github.com/Fyusion/LLFF.
4 | """
5 |
6 | import os
7 | import sys
8 | import argparse
9 | import subprocess
10 |
11 |
12 | def run_colmap(basedir, match_type):
13 | logfile_name = os.path.join(basedir, 'colmap_output.txt')
14 | logfile = open(logfile_name, 'w')
15 |
16 | feature_extractor_args = [
17 | 'colmap', 'feature_extractor',
18 | '--database_path', os.path.join(basedir, 'database.db'),
19 | '--image_path', os.path.join(basedir, 'images'),
20 | '--ImageReader.single_camera', '1'
21 | ]
22 | feat_output = ( subprocess.check_output(feature_extractor_args, universal_newlines=True) )
23 | logfile.write(feat_output)
24 | print('Features extracted')
25 |
26 | exhaustive_matcher_args = [
27 | 'colmap', match_type,
28 | '--database_path', os.path.join(basedir, 'database.db'),
29 | ]
30 |
31 | match_output = ( subprocess.check_output(exhaustive_matcher_args, universal_newlines=True) )
32 | logfile.write(match_output)
33 | print('Features matched')
34 |
35 | p = os.path.join(basedir, 'sparse')
36 | if not os.path.exists(p):
37 | os.makedirs(p)
38 |
39 | mapper_args = [
40 | 'colmap', 'mapper',
41 | '--database_path', os.path.join(basedir, 'database.db'),
42 | '--image_path', os.path.join(basedir, 'images'),
43 | '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6
44 | '--Mapper.num_threads', '16',
45 | '--Mapper.init_min_tri_angle', '4',
46 | '--Mapper.multiple_models', '0',
47 | '--Mapper.extract_colors', '0',
48 | ]
49 |
50 | map_output = ( subprocess.check_output(mapper_args, universal_newlines=True) )
51 | logfile.write(map_output)
52 | logfile.close()
53 | print('Sparse map created')
54 |
55 | print( 'Finished running COLMAP, see {} for logs'.format(logfile_name) )
56 |
57 |
58 | def gen_poses(basedir, match_type):
59 | files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']]
60 | if os.path.exists(os.path.join(basedir, 'sparse/0')):
61 | files_had = os.listdir(os.path.join(basedir, 'sparse/0'))
62 | else:
63 | files_had = []
64 | if not all([f in files_had for f in files_needed]):
65 | print( 'Need to run COLMAP' )
66 | run_colmap(basedir, match_type)
67 | else:
68 | print('Don\'t need to run COLMAP')
69 |
70 | return True
71 |
72 |
73 | if __name__=='__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--match_type', type=str,
76 | default='exhaustive_matcher', help='type of matcher used. Valid options: \
77 | exhaustive_matcher sequential_matcher. Other matchers not supported at this time')
78 | parser.add_argument('scenedir', type=str,
79 | help='input scene directory')
80 | args = parser.parse_args()
81 |
82 | if args.match_type != 'exhaustive_matcher' and args.match_type != 'sequential_matcher':
83 | print('ERROR: matcher type ' + args.match_type + ' is not valid. Aborting')
84 | sys.exit()
85 | gen_poses(args.scenedir, args.match_type)
86 |
--------------------------------------------------------------------------------
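gen_poses above only launches COLMAP when sparse/0/{cameras,images,points3D}.bin are missing. A hypothetical call (the scene path is a placeholder, the colmap binary must be on PATH, and the photos are expected under <scenedir>/images):

from scripts.imgs2poses import gen_poses

gen_poses('/path/to/scene_dir', 'exhaustive_matcher')

--------------------------------------------------------------------------------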
/instant-nsr-pl/systems/__init__.py:
--------------------------------------------------------------------------------
1 | systems = {}
2 |
3 |
4 | def register(name):
5 | def decorator(cls):
6 | systems[name] = cls
7 | return cls
8 | return decorator
9 |
10 |
11 | def make(name, config, load_from_checkpoint=None):
12 | if load_from_checkpoint is None:
13 | system = systems[name](config)
14 | else:
15 | system = systems[name].load_from_checkpoint(load_from_checkpoint, strict=False, config=config)
16 | return system
17 |
18 |
19 | from . import neus, neus_ortho, neus_pinhole
20 |
--------------------------------------------------------------------------------
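A minimal sketch of the registry pattern above, assuming the instant-nsr-pl requirements are installed (importing the package pulls in the NeuS systems): the decorator stores a class under a string key and make() instantiates it, optionally restoring a Lightning checkpoint. The 'toy-system' name and config dict are made up for illustration; real systems subclass systems.base.BaseSystem.

import systems

@systems.register('toy-system')
class ToySystem:
    def __init__(self, config):
        self.config = config

system = systems.make('toy-system', config={'lr': 1e-3})
print(type(system).__name__)   # ToySystem
# systems.make(name, config, load_from_checkpoint='path/to.ckpt') would restore from a checkpoint instead.

--------------------------------------------------------------------------------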
/instant-nsr-pl/systems/base.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 | import models
4 | from systems.utils import parse_optimizer, parse_scheduler, update_module_step
5 | from utils.mixins import SaverMixin
6 | from utils.misc import config_to_primitive, get_rank
7 |
8 |
9 | class BaseSystem(pl.LightningModule, SaverMixin):
10 | """
11 | Two ways to print to console:
12 | 1. self.print: correctly handle progress bar
13 | 2. rank_zero_info: use the logging module
14 | """
15 | def __init__(self, config):
16 | super().__init__()
17 | self.config = config
18 | self.rank = get_rank()
19 | self.prepare()
20 | self.model = models.make(self.config.model.name, self.config.model)
21 |
22 | def prepare(self):
23 | pass
24 |
25 | def forward(self, batch):
26 | raise NotImplementedError
27 |
28 | def C(self, value):
29 | if isinstance(value, int) or isinstance(value, float):
30 | pass
31 | else:
32 | value = config_to_primitive(value)
33 | if not isinstance(value, list):
34 | raise TypeError('Scalar specification only supports list, got', type(value))
35 | if len(value) == 3:
36 | value = [0] + value
37 | assert len(value) == 4
38 | start_step, start_value, end_value, end_step = value
39 | if isinstance(end_step, int):
40 | current_step = self.global_step
41 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
42 | elif isinstance(end_step, float):
43 | current_step = self.current_epoch
44 | value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
45 | return value
46 |
47 | def preprocess_data(self, batch, stage):
48 | pass
49 |
50 | """
51 | Implementing on_after_batch_transfer of DataModule does the same.
52 | But on_after_batch_transfer does not support DP.
53 | """
54 | def on_train_batch_start(self, batch, batch_idx, unused=0):
55 | self.dataset = self.trainer.datamodule.train_dataloader().dataset
56 | self.preprocess_data(batch, 'train')
57 | update_module_step(self.model, self.current_epoch, self.global_step)
58 |
59 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
60 | self.dataset = self.trainer.datamodule.val_dataloader().dataset
61 | self.preprocess_data(batch, 'validation')
62 | update_module_step(self.model, self.current_epoch, self.global_step)
63 |
64 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
65 | self.dataset = self.trainer.datamodule.test_dataloader().dataset
66 | self.preprocess_data(batch, 'test')
67 | update_module_step(self.model, self.current_epoch, self.global_step)
68 |
69 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx):
70 | self.dataset = self.trainer.datamodule.predict_dataloader().dataset
71 | self.preprocess_data(batch, 'predict')
72 | update_module_step(self.model, self.current_epoch, self.global_step)
73 |
74 | def training_step(self, batch, batch_idx):
75 | raise NotImplementedError
76 |
77 | """
78 | # aggregate outputs from different devices (DP)
79 | def training_step_end(self, out):
80 | pass
81 | """
82 |
83 | """
84 | # aggregate outputs from different iterations
85 | def training_epoch_end(self, out):
86 | pass
87 | """
88 |
89 | def validation_step(self, batch, batch_idx):
90 | raise NotImplementedError
91 |
92 | """
93 | # aggregate outputs from different devices when using DP
94 | def validation_step_end(self, out):
95 | pass
96 | """
97 |
98 | def validation_epoch_end(self, out):
99 | """
100 | Gather metrics from all devices, compute mean.
101 | Purge repeated results using data index.
102 | """
103 | raise NotImplementedError
104 |
105 | def test_step(self, batch, batch_idx):
106 | raise NotImplementedError
107 |
108 | def test_epoch_end(self, out):
109 | """
110 | Gather metrics from all devices, compute mean.
111 | Purge repeated results using data index.
112 | """
113 | raise NotImplementedError
114 |
115 | def export(self):
116 | raise NotImplementedError
117 |
118 | def configure_optimizers(self):
119 | optim = parse_optimizer(self.config.system.optimizer, self.model)
120 | ret = {
121 | 'optimizer': optim,
122 | }
123 | if 'scheduler' in self.config.system:
124 | ret.update({
125 | 'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim),
126 | })
127 | return ret
128 |
129 |
--------------------------------------------------------------------------------
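BaseSystem.C turns a scalar loss weight into either a constant or a linear ramp: a list is read as [start_step, start_value, end_value, end_step] (a 3-element list gets start_step = 0), where an integer end_step ramps over global steps and a float end_step ramps over epochs. A pure-function sketch of the step-based case:

def scheduled(value, global_step):
    # Mirrors BaseSystem.C for constants and step-based (integer end_step) schedules.
    if isinstance(value, (int, float)):
        return value
    if len(value) == 3:                       # [start_value, end_value, end_step]
        value = [0] + list(value)
    start_step, start_value, end_value, end_step = value
    t = max(min(1.0, (global_step - start_step) / (end_step - start_step)), 0.0)
    return start_value + (end_value - start_value) * t

print(scheduled(0.1, 500))                     # plain constant         -> 0.1
print(scheduled([0.0, 1.0, 1000], 250))        # ramp 0 -> 1 over 1000  -> 0.25
print(scheduled([2000, 0.0, 1.0, 4000], 500))  # ramp not started yet   -> 0.0

--------------------------------------------------------------------------------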
/instant-nsr-pl/systems/criterions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class WeightedLoss(nn.Module):
7 | @property
8 | def func(self):
9 | raise NotImplementedError
10 |
11 | def forward(self, inputs, targets, weight=None, reduction='mean'):
12 | assert reduction in ['none', 'sum', 'mean', 'valid_mean']
13 | loss = self.func(inputs, targets, reduction='none')
14 | if weight is not None:
15 | while weight.ndim < inputs.ndim:
16 | weight = weight[..., None]
17 | loss *= weight.float()
18 | if reduction == 'none':
19 | return loss
20 | elif reduction == 'sum':
21 | return loss.sum()
22 | elif reduction == 'mean':
23 | return loss.mean()
24 | elif reduction == 'valid_mean':
25 | return loss.sum() / weight.float().sum()
26 |
27 |
28 | class MSELoss(WeightedLoss):
29 | @property
30 | def func(self):
31 | return F.mse_loss
32 |
33 |
34 | class L1Loss(WeightedLoss):
35 | @property
36 | def func(self):
37 | return F.l1_loss
38 |
39 |
40 | class PSNR(nn.Module):
41 | def __init__(self):
42 | super().__init__()
43 |
44 | def forward(self, inputs, targets, valid_mask=None, reduction='mean'):
45 | assert reduction in ['mean', 'none']
46 | value = (inputs - targets)**2
47 | if valid_mask is not None:
48 | value = value[valid_mask]
49 | if reduction == 'mean':
50 | return -10 * torch.log10(torch.mean(value))
51 | elif reduction == 'none':
52 | return -10 * torch.log10(torch.mean(value, dim=tuple(range(value.ndim)[1:])))
53 |
54 |
55 | class SSIM():
56 | def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True):
57 | self.kernel_size = kernel_size
58 | self.sigma = sigma
59 | self.gaussian = gaussian
60 |
61 | if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
62 |             raise ValueError(f"Expected kernel_size to contain odd positive numbers. Got {kernel_size}.")
63 |         if any(y <= 0 for y in self.sigma):
64 |             raise ValueError(f"Expected sigma to contain positive numbers. Got {sigma}.")
65 |
66 | data_scale = data_range[1] - data_range[0]
67 | self.c1 = (k1 * data_scale)**2
68 | self.c2 = (k2 * data_scale)**2
69 | self.pad_h = (self.kernel_size[0] - 1) // 2
70 | self.pad_w = (self.kernel_size[1] - 1) // 2
71 | self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
72 |
73 | def _uniform(self, kernel_size):
74 | max, min = 2.5, -2.5
75 | ksize_half = (kernel_size - 1) * 0.5
76 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
77 | for i, j in enumerate(kernel):
78 | if min <= j <= max:
79 | kernel[i] = 1 / (max - min)
80 | else:
81 | kernel[i] = 0
82 |
83 | return kernel.unsqueeze(dim=0) # (1, kernel_size)
84 |
85 | def _gaussian(self, kernel_size, sigma):
86 | ksize_half = (kernel_size - 1) * 0.5
87 | kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
88 | gauss = torch.exp(-0.5 * (kernel / sigma).pow(2))
89 | return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
90 |
91 | def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
92 | if self.gaussian:
93 | kernel_x = self._gaussian(kernel_size[0], sigma[0])
94 | kernel_y = self._gaussian(kernel_size[1], sigma[1])
95 | else:
96 | kernel_x = self._uniform(kernel_size[0])
97 | kernel_y = self._uniform(kernel_size[1])
98 |
99 | return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size)
100 |
101 | def __call__(self, output, target, reduction='mean'):
102 | if output.dtype != target.dtype:
103 | raise TypeError(
104 | f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}."
105 | )
106 |
107 | if output.shape != target.shape:
108 | raise ValueError(
109 | f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}."
110 | )
111 |
112 | if len(output.shape) != 4 or len(target.shape) != 4:
113 | raise ValueError(
114 | f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}."
115 | )
116 |
117 | assert reduction in ['mean', 'sum', 'none']
118 |
119 | channel = output.size(1)
120 | if len(self._kernel.shape) < 4:
121 | self._kernel = self._kernel.expand(channel, 1, -1, -1)
122 |
123 | output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
124 | target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
125 |
126 | input_list = torch.cat([output, target, output * output, target * target, output * target])
127 | outputs = F.conv2d(input_list, self._kernel, groups=channel)
128 |
129 | output_list = [outputs[x * output.size(0) : (x + 1) * output.size(0)] for x in range(len(outputs))]
130 |
131 | mu_pred_sq = output_list[0].pow(2)
132 | mu_target_sq = output_list[1].pow(2)
133 | mu_pred_target = output_list[0] * output_list[1]
134 |
135 | sigma_pred_sq = output_list[2] - mu_pred_sq
136 | sigma_target_sq = output_list[3] - mu_target_sq
137 | sigma_pred_target = output_list[4] - mu_pred_target
138 |
139 | a1 = 2 * mu_pred_target + self.c1
140 | a2 = 2 * sigma_pred_target + self.c2
141 | b1 = mu_pred_sq + mu_target_sq + self.c1
142 | b2 = sigma_pred_sq + sigma_target_sq + self.c2
143 |
144 | ssim_idx = (a1 * a2) / (b1 * b2)
145 | _ssim = torch.mean(ssim_idx, (1, 2, 3))
146 |
147 | if reduction == 'none':
148 | return _ssim
149 | elif reduction == 'sum':
150 | return _ssim.sum()
151 | elif reduction == 'mean':
152 | return _ssim.mean()
153 |
154 |
155 | def binary_cross_entropy(input, target, reduction='mean'):
156 | """
157 | F.binary_cross_entropy is not numerically stable in mixed-precision training.
158 | """
159 | loss = -(target * torch.log(input) + (1 - target) * torch.log(1 - input))
160 |
161 | if reduction == 'mean':
162 | return loss.mean()
163 | elif reduction == 'none':
164 | return loss
165 |
--------------------------------------------------------------------------------
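A tiny numeric check for the PSNR criterion above: with constant images differing by 0.1, the MSE is 0.01 and the PSNR is -10 * log10(0.01) = 20 dB.

import torch
from systems.criterions import PSNR

pred = torch.full((1, 3, 4, 4), 0.6)
target = torch.full((1, 3, 4, 4), 0.5)
print(PSNR()(pred, target))   # tensor(20.0000), up to floating-point rounding

--------------------------------------------------------------------------------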
/instant-nsr-pl/systems/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_efficient_distloss import flatten_eff_distloss
5 |
6 | import pytorch_lightning as pl
7 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug
8 |
9 | import models
10 | from models.ray_utils import get_rays
11 | import systems
12 | from systems.base import BaseSystem
13 | from systems.criterions import PSNR
14 |
15 |
16 | @systems.register('nerf-system')
17 | class NeRFSystem(BaseSystem):
18 | """
19 | Two ways to print to console:
20 | 1. self.print: correctly handle progress bar
21 | 2. rank_zero_info: use the logging module
22 | """
23 | def prepare(self):
24 | self.criterions = {
25 | 'psnr': PSNR()
26 | }
27 | self.train_num_samples = self.config.model.train_num_rays * self.config.model.num_samples_per_ray
28 | self.train_num_rays = self.config.model.train_num_rays
29 |
30 | def forward(self, batch):
31 | return self.model(batch['rays'])
32 |
33 | def preprocess_data(self, batch, stage):
34 | if 'index' in batch: # validation / testing
35 | index = batch['index']
36 | else:
37 | if self.config.model.batch_image_sampling:
38 | index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device)
39 | else:
40 | index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device)
41 | if stage in ['train']:
42 | c2w = self.dataset.all_c2w[index]
43 | x = torch.randint(
44 | 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device
45 | )
46 | y = torch.randint(
47 | 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device
48 | )
49 | if self.dataset.directions.ndim == 3: # (H, W, 3)
50 | directions = self.dataset.directions[y, x]
51 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3)
52 | directions = self.dataset.directions[index, y, x]
53 | rays_o, rays_d = get_rays(directions, c2w)
54 | rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank)
55 | fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank)
56 | else:
57 | c2w = self.dataset.all_c2w[index][0]
58 | if self.dataset.directions.ndim == 3: # (H, W, 3)
59 | directions = self.dataset.directions
60 | elif self.dataset.directions.ndim == 4: # (N, H, W, 3)
61 | directions = self.dataset.directions[index][0]
62 | rays_o, rays_d = get_rays(directions, c2w)
63 | rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank)
64 | fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank)
65 |
66 | rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1)
67 |
68 | if stage in ['train']:
69 | if self.config.model.background_color == 'white':
70 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)
71 | elif self.config.model.background_color == 'random':
72 | self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank)
73 | else:
74 | raise NotImplementedError
75 | else:
76 | self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)
77 |
78 | if self.dataset.apply_mask:
79 | rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None])
80 |
81 | batch.update({
82 | 'rays': rays,
83 | 'rgb': rgb,
84 | 'fg_mask': fg_mask
85 | })
86 |
87 | def training_step(self, batch, batch_idx):
88 | out = self(batch)
89 |
90 | loss = 0.
91 |
92 | # update train_num_rays
93 | if self.config.model.dynamic_ray_sampling:
94 | train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples'].sum().item()))
95 | self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays)
96 |
97 | loss_rgb = F.smooth_l1_loss(out['comp_rgb'][out['rays_valid'][...,0]], batch['rgb'][out['rays_valid'][...,0]])
98 | self.log('train/loss_rgb', loss_rgb)
99 | loss += loss_rgb * self.C(self.config.system.loss.lambda_rgb)
100 |
101 | # distortion loss proposed in MipNeRF360
102 | # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss, but still slows down training by ~30%
103 | if self.C(self.config.system.loss.lambda_distortion) > 0:
104 | loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices'])
105 | self.log('train/loss_distortion', loss_distortion)
106 | loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion)
107 |
108 | losses_model_reg = self.model.regularizations(out)
109 | for name, value in losses_model_reg.items():
110 | self.log(f'train/loss_{name}', value)
111 | loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"])
112 | loss += loss_
113 |
114 | for name, value in self.config.system.loss.items():
115 | if name.startswith('lambda'):
116 | self.log(f'train_params/{name}', self.C(value))
117 |
118 | self.log('train/num_rays', float(self.train_num_rays), prog_bar=True)
119 |
120 | return {
121 | 'loss': loss
122 | }
123 |
124 | """
125 | # aggregate outputs from different devices (DP)
126 | def training_step_end(self, out):
127 | pass
128 | """
129 |
130 | """
131 | # aggregate outputs from different iterations
132 | def training_epoch_end(self, out):
133 | pass
134 | """
135 |
136 | def validation_step(self, batch, batch_idx):
137 | out = self(batch)
138 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb'])
139 | W, H = self.dataset.img_wh
140 | self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [
141 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
142 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
143 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}},
144 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}}
145 | ])
146 | return {
147 | 'psnr': psnr,
148 | 'index': batch['index']
149 | }
150 |
151 |
152 | """
153 | # aggregate outputs from different devices when using DP
154 | def validation_step_end(self, out):
155 | pass
156 | """
157 |
158 | def validation_epoch_end(self, out):
159 | out = self.all_gather(out)
160 | if self.trainer.is_global_zero:
161 | out_set = {}
162 | for step_out in out:
163 | # DP
164 | if step_out['index'].ndim == 1:
165 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
166 | # DDP
167 | else:
168 | for oi, index in enumerate(step_out['index']):
169 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
170 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
171 | self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)
172 |
173 | def test_step(self, batch, batch_idx):
174 | out = self(batch)
175 | psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb'])
176 | W, H = self.dataset.img_wh
177 | self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [
178 | {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
179 | {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
180 | {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}},
181 | {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}}
182 | ])
183 | return {
184 | 'psnr': psnr,
185 | 'index': batch['index']
186 | }
187 |
188 | def test_epoch_end(self, out):
189 | out = self.all_gather(out)
190 | if self.trainer.is_global_zero:
191 | out_set = {}
192 | for step_out in out:
193 | # DP
194 | if step_out['index'].ndim == 1:
195 | out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
196 | # DDP
197 | else:
198 | for oi, index in enumerate(step_out['index']):
199 | out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
200 | psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
201 | self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)
202 |
203 | self.save_img_sequence(
204 | f"it{self.global_step}-test",
205 | f"it{self.global_step}-test",
206 |             r'(\d+)\.png',
207 | save_format='mp4',
208 | fps=30
209 | )
210 |
211 | self.export()
212 |
213 | def export(self):
214 | mesh = self.model.export(self.config.export)
215 | self.save_mesh(
216 | f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj",
217 | **mesh
218 | )
219 |
--------------------------------------------------------------------------------
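The dynamic ray sampling in training_step aims for a roughly constant sample budget per batch: the proposed ray count scales the current count by target/actual samples, is blended in with a 0.9/0.1 moving average, and is capped at max_train_num_rays. A worked example with illustrative numbers (not from any shipped config):

train_num_rays = 512
train_num_samples = 512 * 1024      # train_num_rays * num_samples_per_ray
actual_samples = 400_000            # what the renderer produced this step
max_train_num_rays = 8192

proposed = int(train_num_rays * (train_num_samples / actual_samples))                 # 671
train_num_rays = min(int(train_num_rays * 0.9 + proposed * 0.1), max_train_num_rays)  # 527
print(train_num_rays)               # the budget drifts up when samples come in under target

--------------------------------------------------------------------------------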
/instant-nsr-pl/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xxlong0/Wonder3D/d894f827aa8c2917761a0dad3ab40df74c7a5b24/instant-nsr-pl/utils/__init__.py
--------------------------------------------------------------------------------
/instant-nsr-pl/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import shutil
4 | from utils.misc import dump_config, parse_version
5 |
6 |
7 | import pytorch_lightning
8 | if parse_version(pytorch_lightning.__version__) > parse_version('1.8'):
9 | from pytorch_lightning.callbacks import Callback
10 | else:
11 | from pytorch_lightning.callbacks.base import Callback
12 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
13 | from pytorch_lightning.callbacks.progress import TQDMProgressBar
14 |
15 |
16 | class VersionedCallback(Callback):
17 | def __init__(self, save_root, version=None, use_version=True):
18 | self.save_root = save_root
19 | self._version = version
20 | self.use_version = use_version
21 |
22 | @property
23 | def version(self) -> int:
24 | """Get the experiment version.
25 |
26 | Returns:
27 | The experiment version if specified else the next version.
28 | """
29 | if self._version is None:
30 | self._version = self._get_next_version()
31 | return self._version
32 |
33 | def _get_next_version(self):
34 | existing_versions = []
35 | if os.path.isdir(self.save_root):
36 | for f in os.listdir(self.save_root):
37 | bn = os.path.basename(f)
38 | if bn.startswith("version_"):
39 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
40 | existing_versions.append(int(dir_ver))
41 | if len(existing_versions) == 0:
42 | return 0
43 | return max(existing_versions) + 1
44 |
45 | @property
46 | def savedir(self):
47 | if not self.use_version:
48 | return self.save_root
49 | return os.path.join(self.save_root, self.version if isinstance(self.version, str) else f"version_{self.version}")
50 |
51 |
52 | class CodeSnapshotCallback(VersionedCallback):
53 | def __init__(self, save_root, version=None, use_version=True):
54 | super().__init__(save_root, version, use_version)
55 |
56 | def get_file_list(self):
57 | return [
58 | b.decode() for b in
59 | set(subprocess.check_output('git ls-files', shell=True).splitlines()) |
60 | set(subprocess.check_output('git ls-files --others --exclude-standard', shell=True).splitlines())
61 | ]
62 |
63 | @rank_zero_only
64 | def save_code_snapshot(self):
65 | os.makedirs(self.savedir, exist_ok=True)
66 | for f in self.get_file_list():
67 | if not os.path.exists(f) or os.path.isdir(f):
68 | continue
69 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
70 | shutil.copyfile(f, os.path.join(self.savedir, f))
71 |
72 | def on_fit_start(self, trainer, pl_module):
73 | try:
74 | self.save_code_snapshot()
75 | except:
76 | rank_zero_warn("Code snapshot is not saved. Please make sure you have git installed and are in a git repository.")
77 |
78 |
79 | class ConfigSnapshotCallback(VersionedCallback):
80 | def __init__(self, config, save_root, version=None, use_version=True):
81 | super().__init__(save_root, version, use_version)
82 | self.config = config
83 |
84 | @rank_zero_only
85 | def save_config_snapshot(self):
86 | os.makedirs(self.savedir, exist_ok=True)
87 | dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config)
88 | shutil.copyfile(self.config.cmd_args['config'], os.path.join(self.savedir, 'raw.yaml'))
89 |
90 | def on_fit_start(self, trainer, pl_module):
91 | self.save_config_snapshot()
92 |
93 |
94 | class CustomProgressBar(TQDMProgressBar):
95 | def get_metrics(self, *args, **kwargs):
96 | # don't show the version number
97 | items = super().get_metrics(*args, **kwargs)
98 | items.pop("v_num", None)
99 | return items
100 |
--------------------------------------------------------------------------------
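VersionedCallback above numbers its output directories version_0, version_1, ... by scanning the save root. A minimal sketch, assuming the requirements are installed and the placeholder save root does not exist yet:

from utils.callbacks import CodeSnapshotCallback

cb = CodeSnapshotCallback('/tmp/exp/code')   # hypothetical save root
print(cb.version)                            # 0 when no version_* directories exist
print(cb.savedir)                            # /tmp/exp/code/version_0

--------------------------------------------------------------------------------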
/instant-nsr-pl/utils/loggers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pprint
3 | import logging
4 |
5 | from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
6 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
7 |
8 |
9 | class ConsoleLogger(LightningLoggerBase):
10 | def __init__(self, log_keys=[]):
11 | super().__init__()
12 | self.log_keys = [re.compile(k) for k in log_keys]
13 | self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat
14 |
15 | def match_log_keys(self, s):
16 | return True if not self.log_keys else any(r.search(s) for r in self.log_keys)
17 |
18 | @property
19 | def name(self):
20 | return 'console'
21 |
22 | @property
23 | def version(self):
24 | return '0'
25 |
26 | @property
27 | @rank_zero_experiment
28 | def experiment(self):
29 | return logging.getLogger('pytorch_lightning')
30 |
31 | @rank_zero_only
32 | def log_hyperparams(self, params):
33 | pass
34 |
35 | @rank_zero_only
36 | def log_metrics(self, metrics, step):
37 | metrics_ = {k: v for k, v in metrics.items() if self.match_log_keys(k)}
38 | if not metrics_:
39 | return
40 | self.experiment.info(f"\nEpoch{metrics['epoch']} Step{step}\n{self.dict_printer(metrics_)}")
41 |
42 |
--------------------------------------------------------------------------------
/instant-nsr-pl/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import OmegaConf
3 | from packaging import version
4 |
5 |
6 | # ============ Register OmegaConf Resolvers ============= #
7 | OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
8 | OmegaConf.register_new_resolver('add', lambda a, b: a + b)
9 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
10 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
11 | OmegaConf.register_new_resolver('div', lambda a, b: a / b)
12 | OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
13 | OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
14 | # ======================================================= #
15 |
16 |
17 | def prompt(question):
18 | inp = input(f"{question} (y/n)").lower().strip()
19 | if inp and inp == 'y':
20 | return True
21 | if inp and inp == 'n':
22 | return False
23 | return prompt(question)
24 |
25 |
26 | def load_config(*yaml_files, cli_args=[]):
27 | yaml_confs = [OmegaConf.load(f) for f in yaml_files]
28 | cli_conf = OmegaConf.from_cli(cli_args)
29 | conf = OmegaConf.merge(*yaml_confs, cli_conf)
30 | OmegaConf.resolve(conf)
31 | return conf
32 |
33 |
34 | def config_to_primitive(config, resolve=True):
35 | return OmegaConf.to_container(config, resolve=resolve)
36 |
37 |
38 | def dump_config(path, config):
39 | with open(path, 'w') as fp:
40 | OmegaConf.save(config=config, f=fp)
41 |
42 | def get_rank():
43 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
44 | # therefore LOCAL_RANK needs to be checked first
45 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
46 | for key in rank_keys:
47 | rank = os.environ.get(key)
48 | if rank is not None:
49 | return int(rank)
50 | return 0
51 |
52 |
53 | def parse_version(ver):
54 | return version.parse(ver)
55 |
--------------------------------------------------------------------------------
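The OmegaConf resolvers registered above let YAML configs compute derived values; calc_exp_lr_decay_rate, for instance, returns the per-step gamma whose n-th power equals the overall decay factor. A small sketch (the config keys are made up; importing utils.misc is what registers the resolvers, so run from the instant-nsr-pl directory):

from omegaconf import OmegaConf
import utils.misc  # noqa: F401  -- importing the module registers the resolvers

conf = OmegaConf.create({
    'max_steps': 20000,
    'gamma': '${calc_exp_lr_decay_rate:0.1,${max_steps}}',   # 0.1 ** (1 / 20000)
    'half_steps': '${idiv:${max_steps},2}',
})
print(conf.gamma)        # ~0.999885
print(conf.half_steps)   # 10000

--------------------------------------------------------------------------------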
/instant-nsr-pl/utils/mixins.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import shutil
4 | import numpy as np
5 | import cv2
6 | import imageio
7 | from matplotlib import cm
8 | from matplotlib.colors import LinearSegmentedColormap
9 | import json
10 |
11 | import torch
12 |
13 | from utils.obj import write_obj
14 |
15 |
16 | class SaverMixin():
17 | @property
18 | def save_dir(self):
19 | return self.config.save_dir
20 |
21 | def convert_data(self, data):
22 | if isinstance(data, np.ndarray):
23 | return data
24 | elif isinstance(data, torch.Tensor):
25 | return data.cpu().numpy()
26 | elif isinstance(data, list):
27 | return [self.convert_data(d) for d in data]
28 | elif isinstance(data, dict):
29 | return {k: self.convert_data(v) for k, v in data.items()}
30 | else:
31 |             raise TypeError('Data must be of type numpy.ndarray, torch.Tensor, list or dict, got', type(data))
32 |
33 | def get_save_path(self, filename):
34 | save_path = os.path.join(self.save_dir, filename)
35 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
36 | return save_path
37 |
38 | DEFAULT_RGB_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1)}
39 | DEFAULT_UV_KWARGS = {'data_format': 'CHW', 'data_range': (0, 1), 'cmap': 'checkerboard'}
40 | DEFAULT_GRAYSCALE_KWARGS = {'data_range': None, 'cmap': 'jet'}
41 |
42 | def get_rgb_image_(self, img, data_format, data_range):
43 | img = self.convert_data(img)
44 | assert data_format in ['CHW', 'HWC']
45 | if data_format == 'CHW':
46 | img = img.transpose(1, 2, 0)
47 | img = img.clip(min=data_range[0], max=data_range[1])
48 | img = ((img - data_range[0]) / (data_range[1] - data_range[0]) * 255.).astype(np.uint8)
49 | imgs = [img[...,start:start+3] for start in range(0, img.shape[-1], 3)]
50 | imgs = [img_ if img_.shape[-1] == 3 else np.concatenate([img_, np.zeros((img_.shape[0], img_.shape[1], 3 - img_.shape[2]), dtype=img_.dtype)], axis=-1) for img_ in imgs]
51 | img = np.concatenate(imgs, axis=1)
52 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
53 | return img
54 |
55 | def save_rgb_image(self, filename, img, data_format=DEFAULT_RGB_KWARGS['data_format'], data_range=DEFAULT_RGB_KWARGS['data_range']):
56 | img = self.get_rgb_image_(img, data_format, data_range)
57 | cv2.imwrite(self.get_save_path(filename), img)
58 |
59 | def get_uv_image_(self, img, data_format, data_range, cmap):
60 | img = self.convert_data(img)
61 | assert data_format in ['CHW', 'HWC']
62 | if data_format == 'CHW':
63 | img = img.transpose(1, 2, 0)
64 | img = img.clip(min=data_range[0], max=data_range[1])
65 | img = (img - data_range[0]) / (data_range[1] - data_range[0])
66 | assert cmap in ['checkerboard', 'color']
67 | if cmap == 'checkerboard':
68 | n_grid = 64
69 | mask = (img * n_grid).astype(int)
70 | mask = (mask[...,0] + mask[...,1]) % 2 == 0
71 | img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255
72 | img[mask] = np.array([255, 0, 255], dtype=np.uint8)
73 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
74 | elif cmap == 'color':
75 | img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
76 | img_[..., 0] = (img[..., 0] * 255).astype(np.uint8)
77 | img_[..., 1] = (img[..., 1] * 255).astype(np.uint8)
78 | img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR)
79 | img = img_
80 | return img
81 |
82 | def save_uv_image(self, filename, img, data_format=DEFAULT_UV_KWARGS['data_format'], data_range=DEFAULT_UV_KWARGS['data_range'], cmap=DEFAULT_UV_KWARGS['cmap']):
83 | img = self.get_uv_image_(img, data_format, data_range, cmap)
84 | cv2.imwrite(self.get_save_path(filename), img)
85 |
86 | def get_grayscale_image_(self, img, data_range, cmap):
87 | img = self.convert_data(img)
88 | img = np.nan_to_num(img)
89 | if data_range is None:
90 | img = (img - img.min()) / (img.max() - img.min())
91 | else:
92 | img = img.clip(data_range[0], data_range[1])
93 | img = (img - data_range[0]) / (data_range[1] - data_range[0])
94 | assert cmap in [None, 'jet', 'magma']
95 |         if cmap is None:
96 | img = (img * 255.).astype(np.uint8)
97 | img = np.repeat(img[...,None], 3, axis=2)
98 | elif cmap == 'jet':
99 | img = (img * 255.).astype(np.uint8)
100 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
101 | elif cmap == 'magma':
102 | img = 1. - img
103 | base = cm.get_cmap('magma')
104 | num_bins = 256
105 | colormap = LinearSegmentedColormap.from_list(
106 | f"{base.name}{num_bins}",
107 | base(np.linspace(0, 1, num_bins)),
108 | num_bins
109 | )(np.linspace(0, 1, num_bins))[:,:3]
110 | a = np.floor(img * 255.)
111 | b = (a + 1).clip(max=255.)
112 | f = img * 255. - a
113 | a = a.astype(np.uint16).clip(0, 255)
114 | b = b.astype(np.uint16).clip(0, 255)
115 | img = colormap[a] + (colormap[b] - colormap[a]) * f[...,None]
116 | img = (img * 255.).astype(np.uint8)
117 | return img
118 |
119 | def save_grayscale_image(self, filename, img, data_range=DEFAULT_GRAYSCALE_KWARGS['data_range'], cmap=DEFAULT_GRAYSCALE_KWARGS['cmap']):
120 | img = self.get_grayscale_image_(img, data_range, cmap)
121 | cv2.imwrite(self.get_save_path(filename), img)
122 |
123 | def get_image_grid_(self, imgs):
124 | if isinstance(imgs[0], list):
125 | return np.concatenate([self.get_image_grid_(row) for row in imgs], axis=0)
126 | cols = []
127 | for col in imgs:
128 | assert col['type'] in ['rgb', 'uv', 'grayscale']
129 | if col['type'] == 'rgb':
130 | rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
131 | rgb_kwargs.update(col['kwargs'])
132 | cols.append(self.get_rgb_image_(col['img'], **rgb_kwargs))
133 | elif col['type'] == 'uv':
134 | uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
135 | uv_kwargs.update(col['kwargs'])
136 | cols.append(self.get_uv_image_(col['img'], **uv_kwargs))
137 | elif col['type'] == 'grayscale':
138 | grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
139 | grayscale_kwargs.update(col['kwargs'])
140 | cols.append(self.get_grayscale_image_(col['img'], **grayscale_kwargs))
141 | return np.concatenate(cols, axis=1)
142 |
143 | def save_image_grid(self, filename, imgs):
144 | img = self.get_image_grid_(imgs)
145 | cv2.imwrite(self.get_save_path(filename), img)
146 |
147 | def save_image(self, filename, img):
148 | img = self.convert_data(img)
149 | assert img.dtype == np.uint8
150 | if img.shape[-1] == 3:
151 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
152 | elif img.shape[-1] == 4:
153 | img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
154 | cv2.imwrite(self.get_save_path(filename), img)
155 |
156 | def save_cubemap(self, filename, img, data_range=(0, 1)):
157 | img = self.convert_data(img)
158 | assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]
159 |
160 | imgs_full = []
161 | for start in range(0, img.shape[-1], 3):
162 | img_ = img[...,start:start+3]
163 | img_ = np.stack([self.get_rgb_image_(img_[i], 'HWC', data_range) for i in range(img_.shape[0])], axis=0)
164 | size = img_.shape[1]
165 | placeholder = np.zeros((size, size, 3), dtype=np.float32)
166 | img_full = np.concatenate([
167 | np.concatenate([placeholder, img_[2], placeholder, placeholder], axis=1),
168 | np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1),
169 | np.concatenate([placeholder, img_[3], placeholder, placeholder], axis=1)
170 | ], axis=0)
171 | img_full = cv2.cvtColor(img_full, cv2.COLOR_RGB2BGR)
172 | imgs_full.append(img_full)
173 |
174 | imgs_full = np.concatenate(imgs_full, axis=1)
175 | cv2.imwrite(self.get_save_path(filename), imgs_full)
176 |
177 | def save_data(self, filename, data):
178 | data = self.convert_data(data)
179 | if isinstance(data, dict):
180 | if not filename.endswith('.npz'):
181 | filename += '.npz'
182 | np.savez(self.get_save_path(filename), **data)
183 | else:
184 | if not filename.endswith('.npy'):
185 | filename += '.npy'
186 | np.save(self.get_save_path(filename), data)
187 |
188 | def save_state_dict(self, filename, data):
189 | torch.save(data, self.get_save_path(filename))
190 |
191 | def save_img_sequence(self, filename, img_dir, matcher, save_format='gif', fps=30):
192 | assert save_format in ['gif', 'mp4']
193 | if not filename.endswith(save_format):
194 | filename += f".{save_format}"
195 | matcher = re.compile(matcher)
196 | img_dir = os.path.join(self.save_dir, img_dir)
197 | imgs = []
198 | for f in os.listdir(img_dir):
199 | if matcher.search(f):
200 | imgs.append(f)
201 | imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
202 | imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
203 |
204 | if save_format == 'gif':
205 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
206 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps, palettesize=256)
207 | elif save_format == 'mp4':
208 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
209 | imageio.mimsave(self.get_save_path(filename), imgs, fps=fps)
210 |
211 | def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None, v_rgb=None, ortho_scale=1):
212 | v_pos, t_pos_idx = self.convert_data(v_pos), self.convert_data(t_pos_idx)
213 | if v_rgb is not None:
214 | v_rgb = self.convert_data(v_rgb)
215 |
216 | if ortho_scale is not None:
217 | print("ortho scale is: ", ortho_scale)
218 | v_pos = v_pos * ortho_scale * 0.5
219 |
220 | # change to front-facing
221 | v_pos_copy = np.zeros_like(v_pos)
222 | v_pos_copy[:, 0] = v_pos[:, 0]
223 | v_pos_copy[:, 1] = v_pos[:, 2]
224 | v_pos_copy[:, 2] = v_pos[:, 1]
225 |
226 | import trimesh
227 | mesh = trimesh.Trimesh(
228 | vertices=v_pos_copy,
229 | faces=t_pos_idx,
230 | vertex_colors=v_rgb
231 | )
232 | trimesh.repair.fix_inversion(mesh)
233 | mesh.export(self.get_save_path(filename))
234 | # mesh.export(self.get_save_path(filename.replace(".obj", "-meshlab.obj")))
235 |
236 | # v_pos_copy[:, 0] = v_pos[:, 1] * -1
237 | # v_pos_copy[:, 1] = v_pos[:, 0]
238 | # v_pos_copy[:, 2] = v_pos[:, 2]
239 |
240 | # mesh = trimesh.Trimesh(
241 | # vertices=v_pos_copy,
242 | # faces=t_pos_idx,
243 | # vertex_colors=v_rgb
244 | # )
245 | # mesh.export(self.get_save_path(filename.replace(".obj", "-blender.obj")))
246 |
247 |
248 | # v_pos_copy[:, 0] = v_pos[:, 0]
249 | # v_pos_copy[:, 1] = v_pos[:, 1] * -1
250 | # v_pos_copy[:, 2] = v_pos[:, 2] * -1
251 |
252 | # mesh = trimesh.Trimesh(
253 | # vertices=v_pos_copy,
254 | # faces=t_pos_idx,
255 | # vertex_colors=v_rgb
256 | # )
257 | # mesh.export(self.get_save_path(filename.replace(".obj", "-opengl.obj")))
258 |
259 | def save_file(self, filename, src_path):
260 | shutil.copyfile(src_path, self.get_save_path(filename))
261 |
262 | def save_json(self, filename, payload):
263 | with open(self.get_save_path(filename), 'w') as f:
264 | f.write(json.dumps(payload))
265 |
--------------------------------------------------------------------------------
/instant-nsr-pl/utils/obj.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def load_obj(filename):
5 | # Read entire file
6 | with open(filename, 'r') as f:
7 | lines = f.readlines()
8 |
9 | # load vertices
10 | vertices, texcoords = [], []
11 | for line in lines:
12 | if len(line.split()) == 0:
13 | continue
14 |
15 | prefix = line.split()[0].lower()
16 | if prefix == 'v':
17 | vertices.append([float(v) for v in line.split()[1:]])
18 | elif prefix == 'vt':
19 | val = [float(v) for v in line.split()[1:]]
20 | texcoords.append([val[0], 1.0 - val[1]])
21 |
22 | uv = len(texcoords) > 0
23 | faces, tfaces = [], []
24 | for line in lines:
25 | if len(line.split()) == 0:
26 | continue
27 | prefix = line.split()[0].lower()
28 | if prefix == 'usemtl': # Track used materials
29 | pass
30 | elif prefix == 'f': # Parse face
31 | vs = line.split()[1:]
32 | nv = len(vs)
33 | vv = vs[0].split('/')
34 | v0 = int(vv[0]) - 1
35 | if uv:
36 | t0 = int(vv[1]) - 1 if vv[1] != "" else -1
37 | for i in range(nv - 2): # Triangulate polygons
38 | vv1 = vs[i + 1].split('/')
39 | v1 = int(vv1[0]) - 1
40 | vv2 = vs[i + 2].split('/')
41 | v2 = int(vv2[0]) - 1
42 | faces.append([v0, v1, v2])
43 | if uv:
44 | t1 = int(vv1[1]) - 1 if vv1[1] != "" else -1
45 | t2 = int(vv2[1]) - 1 if vv2[1] != "" else -1
46 | tfaces.append([t0, t1, t2])
47 | vertices = np.array(vertices, dtype=np.float32)
48 | faces = np.array(faces, dtype=np.int64)
49 | if uv:
50 | assert len(tfaces) == len(faces)
51 | texcoords = np.array(texcoords, dtype=np.float32)
52 | tfaces = np.array(tfaces, dtype=np.int64)
53 | else:
54 | texcoords, tfaces = None, None
55 |
56 | return vertices, faces, texcoords, tfaces
57 |
58 |
59 | def write_obj(filename, v_pos, t_pos_idx, v_tex, t_tex_idx):
60 | with open(filename, "w") as f:
61 | for v in v_pos:
62 | f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
63 |
64 | if v_tex is not None:
65 | assert(len(t_pos_idx) == len(t_tex_idx))
66 | for v in v_tex:
67 | f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
68 |
69 | # Write faces
70 | for i in range(len(t_pos_idx)):
71 | f.write("f ")
72 | for j in range(3):
73 | f.write(' %s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1)))
74 | f.write("\n")
75 |
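76 | 
77 | if __name__ == "__main__":
78 |     # Minimal round-trip sketch ("textured.obj" is a placeholder path; its faces are assumed
79 |     # to use the v/vt index format so that texcoords are parsed).
80 |     v_pos, t_pos_idx, v_tex, t_tex_idx = load_obj("textured.obj")
81 |     write_obj("copy.obj", v_pos, t_pos_idx, v_tex, t_tex_idx)
82 |     print(v_pos.shape, t_pos_idx.shape)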
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt:
--------------------------------------------------------------------------------
1 | -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08
2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08
3 | 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt:
--------------------------------------------------------------------------------
1 | -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07
2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07
3 | 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt:
--------------------------------------------------------------------------------
1 | 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07
2 | -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08
3 | 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt:
--------------------------------------------------------------------------------
1 | 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00
2 | 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08
3 | -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt:
--------------------------------------------------------------------------------
1 | -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07
2 | 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07
3 | -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt:
--------------------------------------------------------------------------------
1 | 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07
2 | 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07
3 | -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt:
--------------------------------------------------------------------------------
1 | -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00
2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08
3 | -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt:
--------------------------------------------------------------------------------
1 | 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08
2 | -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08
3 | 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt:
--------------------------------------------------------------------------------
1 | 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09
2 | -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08
3 | -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00
4 |
--------------------------------------------------------------------------------
/mvdiffusion/data/normal_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def camNormal2worldNormal(rot_c2w, camNormal):
4 | H,W,_ = camNormal.shape
5 | normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
6 |
7 | return normal_img
8 |
9 | def worldNormal2camNormal(rot_w2c, normal_map_world):
10 | H,W,_ = normal_map_world.shape
11 | # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
12 |
13 | # faster version
14 | # Reshape the normal map into a 2D array where each row represents a normal vector
15 | normal_map_flat = normal_map_world.reshape(-1, 3)
16 |
17 | # Transform the normal vectors using the transformation matrix
18 | normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
19 |
20 | # Reshape the transformed normal map back to its original shape
21 | normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
22 |
23 | return normal_map_camera
24 |
25 | def trans_normal(normal, RT_w2c, RT_w2c_target):
26 |
27 | # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
28 | # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
29 |
30 | relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
31 | normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal)
32 |
33 | return normal_target_cam
34 |
35 | def img2normal(img):
36 | return (img/255.)*2-1
37 |
38 | def normal2img(normal):
39 | return np.uint8((normal*0.5+0.5)*255)
40 |
41 | def norm_normalize(normal, dim=-1):
42 |
43 | normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
44 |
45 | return normal
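46 | 
47 | 
48 | if __name__ == "__main__":
49 |     # Minimal usage sketch (assumes the fixed_poses directory shipped with this repo
50 |     # and that the script is run from the repo root).
51 |     RT_front = np.loadtxt("mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt")
52 |     RT_right = np.loadtxt("mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt")
53 |     normals_front = norm_normalize(np.random.randn(4, 4, 3))         # toy HxWx3 normal map in the front camera frame
54 |     normals_right = trans_normal(normals_front, RT_front, RT_right)  # re-expressed in the right-view camera frame
55 |     print(normals_right.shape)  # (4, 4, 3)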
--------------------------------------------------------------------------------
/mvdiffusion/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import numpy as np
3 | from omegaconf import DictConfig, ListConfig
4 | import torch
5 | from torch.utils.data import Dataset
6 | from pathlib import Path
7 | import json
8 | from PIL import Image
9 | from torchvision import transforms
10 | from einops import rearrange
11 | from typing import Literal, Tuple, Optional, Any
12 | import cv2
13 | import random
14 |
15 | import json
16 | import os, sys
17 | import math
18 |
19 | from glob import glob
20 |
21 | import PIL.Image
22 | from .normal_utils import trans_normal, normal2img, img2normal
23 | import pdb
24 |
25 |
26 | import cv2
27 | import numpy as np
28 |
29 | def add_margin(pil_img, color=0, size=256):
30 | width, height = pil_img.size
31 | result = Image.new(pil_img.mode, (size, size), color)
32 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
33 | return result
34 |
35 | def scale_and_place_object(image, scale_factor):
36 | assert np.shape(image)[-1]==4 # RGBA
37 |
38 | # Extract the alpha channel (transparency) and the object (RGB channels)
39 | alpha_channel = image[:, :, 3]
40 |
41 | # Find the bounding box coordinates of the object
42 | coords = cv2.findNonZero(alpha_channel)
43 | x, y, width, height = cv2.boundingRect(coords)
44 |
45 | # Calculate the scale factor for resizing
46 | original_height, original_width = image.shape[:2]
47 |
48 | if width > height:
49 | size = width
50 | original_size = original_width
51 | else:
52 | size = height
53 | original_size = original_height
54 |
55 | scale_factor = min(scale_factor, size / (original_size+0.0))
56 |
57 | new_size = scale_factor * original_size
58 | scale_factor = new_size / size
59 |
60 | # Calculate the new size based on the scale factor
61 | new_width = int(width * scale_factor)
62 | new_height = int(height * scale_factor)
63 |
64 | center_x = original_width // 2
65 | center_y = original_height // 2
66 |
67 | paste_x = center_x - (new_width // 2)
68 | paste_y = center_y - (new_height // 2)
69 |
70 | # Resize the object (RGB channels) to the new size
71 | rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
72 |
73 | # Create a new RGBA image with the resized image
74 | new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
75 |
76 | new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
77 |
78 | return new_image
79 |
80 | class SingleImageDataset(Dataset):
81 | def __init__(self,
82 | root_dir: str,
83 | num_views: int,
84 | img_wh: Tuple[int, int],
85 | bg_color: str,
86 | crop_size: int = 224,
87 | single_image: Optional[PIL.Image.Image] = None,
88 | num_validation_samples: Optional[int] = None,
89 | filepaths: Optional[list] = None,
90 | cond_type: Optional[str] = None
91 | ) -> None:
92 |         """Create a dataset from a folder of images or from a single PIL image.
93 |         If a root directory is passed in, it is searched for images
94 |         ending in .png or .jpg.
95 |         """
96 | self.root_dir = root_dir
97 | self.num_views = num_views
98 | self.img_wh = img_wh
99 | self.crop_size = crop_size
100 | self.bg_color = bg_color
101 | self.cond_type = cond_type
102 |
103 | if self.num_views == 4:
104 | self.view_types = ['front', 'right', 'back', 'left']
105 | elif self.num_views == 5:
106 | self.view_types = ['front', 'front_right', 'right', 'back', 'left']
107 | elif self.num_views == 6:
108 | self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
109 |
110 | self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
111 |
112 | self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
113 |
114 | if single_image is None:
115 | if filepaths is None:
116 | # Get a list of all files in the directory
117 | file_list = os.listdir(self.root_dir)
118 | else:
119 | file_list = filepaths
120 |
121 | # Filter the files that end with .png or .jpg
122 | self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
123 | else:
124 | self.file_list = None
125 |
126 | # load all images
127 | self.all_images = []
128 | self.all_alphas = []
129 | bg_color = self.get_bg_color()
130 |
131 | if single_image is not None:
132 | image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image)
133 | self.all_images.append(image)
134 | self.all_alphas.append(alpha)
135 | else:
136 | for file in self.file_list:
137 | print(os.path.join(self.root_dir, file))
138 | image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
139 | self.all_images.append(image)
140 | self.all_alphas.append(alpha)
141 |
142 | self.all_images = self.all_images[:num_validation_samples]
143 | self.all_alphas = self.all_alphas[:num_validation_samples]
144 |
145 |
146 | def __len__(self):
147 | return len(self.all_images)
148 |
149 | def load_fixed_poses(self):
150 | poses = {}
151 | for face in self.view_types:
152 | RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
153 | poses[face] = RT
154 |
155 | return poses
156 |
157 | def cartesian_to_spherical(self, xyz):
158 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
159 | xy = xyz[:,0]**2 + xyz[:,1]**2
160 | z = np.sqrt(xy + xyz[:,2]**2)
161 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
162 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
163 | azimuth = np.arctan2(xyz[:,1], xyz[:,0])
164 | return np.array([theta, azimuth, z])
165 |
166 | def get_T(self, target_RT, cond_RT):
167 | R, T = target_RT[:3, :3], target_RT[:, -1]
168 | T_target = -R.T @ T # change to cam2world
169 |
170 | R, T = cond_RT[:3, :3], cond_RT[:, -1]
171 | T_cond = -R.T @ T
172 |
173 | theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
174 | theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
175 |
176 | d_theta = theta_target - theta_cond
177 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
178 | d_z = z_target - z_cond
179 |
180 | # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
181 | return d_theta, d_azimuth
182 |
183 | def get_bg_color(self):
184 | if self.bg_color == 'white':
185 | bg_color = np.array([1., 1., 1.], dtype=np.float32)
186 | elif self.bg_color == 'black':
187 | bg_color = np.array([0., 0., 0.], dtype=np.float32)
188 | elif self.bg_color == 'gray':
189 | bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
190 | elif self.bg_color == 'random':
191 | bg_color = np.random.rand(3)
192 | elif isinstance(self.bg_color, float):
193 | bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
194 | else:
195 | raise NotImplementedError
196 | return bg_color
197 |
198 |
199 | def load_image(self, img_path, bg_color, return_type='np', Imagefile=None):
200 | # pil always returns uint8
201 | if Imagefile is None:
202 | image_input = Image.open(img_path)
203 | else:
204 | image_input = Imagefile
205 | image_size = self.img_wh[0]
206 |
207 | if self.crop_size!=-1:
208 | alpha_np = np.asarray(image_input)[:, :, 3]
209 | coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
210 | min_x, min_y = np.min(coords, 0)
211 | max_x, max_y = np.max(coords, 0)
212 | ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
213 | h, w = ref_img_.height, ref_img_.width
214 | scale = self.crop_size / max(h, w)
215 | h_, w_ = int(scale * h), int(scale * w)
216 | ref_img_ = ref_img_.resize((w_, h_))
217 | image_input = add_margin(ref_img_, size=image_size)
218 | else:
219 | image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
220 | image_input = image_input.resize((image_size, image_size))
221 |
222 | # img = scale_and_place_object(img, self.scale_ratio)
223 | img = np.array(image_input)
224 | img = img.astype(np.float32) / 255. # [0, 1]
225 | assert img.shape[-1] == 4 # RGBA
226 |
227 | alpha = img[...,3:4]
228 | img = img[...,:3] * alpha + bg_color * (1 - alpha)
229 |
230 | if return_type == "np":
231 | pass
232 | elif return_type == "pt":
233 | img = torch.from_numpy(img)
234 | alpha = torch.from_numpy(alpha)
235 | else:
236 | raise NotImplementedError
237 |
238 | return img, alpha
239 |
240 |
241 | def __len__(self):
242 | return len(self.all_images)
243 |
244 | def __getitem__(self, index):
245 |
246 | image = self.all_images[index%len(self.all_images)]
247 | alpha = self.all_alphas[index%len(self.all_images)]
248 | if self.file_list is not None:
249 |             filename = os.path.splitext(self.file_list[index%len(self.all_images)])[0]
250 | else:
251 | filename = 'null'
252 |
253 | cond_w2c = self.fix_cam_poses['front']
254 |
255 | tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
256 |
257 | elevations = []
258 | azimuths = []
259 |
260 | img_tensors_in = [
261 | image.permute(2, 0, 1)
262 | ] * self.num_views
263 |
264 | alpha_tensors_in = [
265 | alpha.permute(2, 0, 1)
266 | ] * self.num_views
267 |
268 | for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
269 |             # elevations, azimuths
270 | elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
271 | elevations.append(elevation)
272 | azimuths.append(azimuth)
273 |
274 | img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
275 |         alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 1, H, W)
276 |
277 | elevations = torch.as_tensor(elevations).float().squeeze(1)
278 | azimuths = torch.as_tensor(azimuths).float().squeeze(1)
279 | elevations_cond = torch.as_tensor([0] * self.num_views).float()
280 |
281 | normal_class = torch.tensor([1, 0]).float()
282 | normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
283 | color_class = torch.tensor([0, 1]).float()
284 | color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
285 |
286 | camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
287 |
288 | out = {
289 | 'elevations_cond': elevations_cond,
290 | 'elevations_cond_deg': torch.rad2deg(elevations_cond),
291 | 'elevations': elevations,
292 | 'azimuths': azimuths,
293 | 'elevations_deg': torch.rad2deg(elevations),
294 | 'azimuths_deg': torch.rad2deg(azimuths),
295 | 'imgs_in': img_tensors_in,
296 | 'alphas': alpha_tensors_in,
297 | 'camera_embeddings': camera_embeddings,
298 | 'normal_task_embeddings': normal_task_embeddings,
299 | 'color_task_embeddings': color_task_embeddings,
300 | 'filename': filename,
301 | }
302 |
303 | return out
304 |
305 |
306 |
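307 | 
308 | if __name__ == "__main__":
309 |     # Minimal usage sketch: paths and sizes are illustrative, the input images must be RGBA,
310 |     # and the script must be run from the repo root so that fix_cam_pose_dir resolves.
311 |     dataset = SingleImageDataset(
312 |         root_dir="./example_images",
313 |         num_views=6,
314 |         img_wh=(256, 256),
315 |         bg_color="white",
316 |         crop_size=224,
317 |     )
318 |     sample = dataset[0]
319 |     print(sample["imgs_in"].shape)            # (6, 3, 256, 256)
320 |     print(sample["camera_embeddings"].shape)  # (6, 3)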
--------------------------------------------------------------------------------
/render_codes/README.md:
--------------------------------------------------------------------------------
1 | # Prepare the rendering data
2 |
3 | ## Environment
4 | The rendering code is mainly based on [BlenderProc](https://github.com/DLR-RM/BlenderProc). Thanks for the great tool.
5 | By default BlenderProc renders with Blender's Cycles engine, which may hang for a long time on some specific GPUs (this has been observed on the A800, for example).
6 |
7 | ```bash
8 | cd ./render_codes
9 | pip install -r requirements.txt
10 | ```
11 |
12 | ## How to use the code
13 | Here we provide two rendering scripts, `blenderProc_ortho.py` and `blenderProc_persp.py`, which render objects with an **orthographic** camera and a **perspective** camera, respectively.
14 |
15 | ### Use `blenderProc_ortho.py` to render images of a single object
16 | ```bash
17 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \
18 |     blenderProc_ortho.py \
19 |     --object_path /mnt/pfs/data/objaverse_lvis_glbs/c7/c70e8817b5a945aca8bb37e02ddbc6f9.glb --view 0 \
20 |     --output_folder ./out_renderings/ \
21 |     --object_uid c70e8817b5a945aca8bb37e02ddbc6f9 \
22 |     --ortho_scale 1.35 \
23 |     --resolution 512 \
24 |     --random_pose
25 | ```
26 |
27 | Here `--view` is a tag for the rendered images (useful since you may render the same object multiple times), `--ortho_scale` controls how large the rendered object appears in the image, and `--random_pose` randomly rotates the object before rendering.
28 |
29 |
30 | ### Use `blenderProc_persp.py` to render images of a single object
31 |
32 | ```bash
33 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \
34 |     blenderProc_persp.py \
35 |     --object_path ${the object path} --view 0 \
36 |     --output_folder ${your save path} \
37 |     --object_uid ${object_uid} --radius 2.0 \
38 |     --random_pose
39 | ```
40 |
41 | Here `--radius` denotes the distance between the camera and the object origin.
42 |
43 | ### Render objects in distributed mode
44 | See `render_batch_ortho.sh` and `render_batch_persp.sh` for the full commands; a minimal example is shown below.
45 |
46 |
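47 | A minimal launch of `distributed.py` might look like the following (the GPU list, index range, and paths are placeholders; see the batch scripts for the settings we actually use):
48 | 
49 | ```bash
50 | python distributed.py \
51 |     --mode render_ortho --view_idx 0 \
52 |     --num_gpus 2 --gpu_list 0 1 --workers_per_gpu 4 \
53 |     --start_i 0 --end_i 100 \
54 |     --input_models_path ../data_lists/lvis_uids_filter_by_vertex.json \
55 |     --objaverse_root ${your objaverse root} \
56 |     --save_folder ${your save path} \
57 |     --blender_install_path ${your blender path} \
58 |     --ortho_scale 1.35 --random_pose
59 | ```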
--------------------------------------------------------------------------------
/render_codes/distributed.py:
--------------------------------------------------------------------------------
1 | # multiprocessing render
2 | import json
3 | import multiprocessing
4 | import subprocess
5 | from dataclasses import dataclass
6 | from typing import Optional
7 | import os
8 |
9 | import boto3
10 |
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser(description='distributed rendering')
14 |
15 | parser.add_argument('--workers_per_gpu', type=int,
16 | help='number of workers per gpu.')
17 | parser.add_argument('--input_models_path', type=str,
18 | help='Path to a json file containing a list of 3D object files.')
19 | parser.add_argument('--upload_to_s3', action='store_true',
20 |                     help='Whether to upload the rendered images to S3.')
21 | parser.add_argument('--log_to_wandb', action='store_true',
22 |                     help='Whether to log the progress to wandb.')
23 | parser.add_argument('--num_gpus', type=int, default=-1,
24 | help='number of gpus to use. -1 means all available gpus.')
25 | parser.add_argument('--gpu_list', nargs='+', type=int,
26 |                     help='the available gpus')
27 |
28 | parser.add_argument('--mode', type=str, default='render_ortho',
29 |                     choices=['render_ortho', 'render_persp'],
30 |                     help='use orthographic camera or perspective camera')
31 |
32 | parser.add_argument('--start_i', type=int, default=0,
33 | help='the index of first object to be rendered.')
34 |
35 | parser.add_argument('--end_i', type=int, default=-1,
36 | help='the index of the last object to be rendered.')
37 |
38 | parser.add_argument('--objaverse_root', type=str, default='/ghome/l5/xxlong/.objaverse/hf-objaverse-v1',
39 |                     help='root directory of the downloaded objaverse .glb files.')
40 | 
41 | parser.add_argument('--save_folder', type=str, default=None,
42 |                     help='directory to save the rendered outputs to.')
43 |
44 | parser.add_argument('--blender_install_path', type=str, default=None,
45 | help='blender path.')
46 |
47 | parser.add_argument('--view_idx', type=int, default=2,
48 |                     help='index tag of this rendering pass (an object may be rendered multiple times).')
49 |
50 | parser.add_argument('--ortho_scale', type=float, default=1.25,
51 | help='ortho rendering usage; how large the object is')
52 |
53 | parser.add_argument('--random_pose', action='store_true',
54 | help='whether randomly rotate the poses to be rendered')
55 |
56 | args = parser.parse_args()
57 |
58 |
59 | view_idx = args.view_idx
60 |
61 | VIEWS = ["front", "back", "right", "left", "front_right", "front_left", "back_right", "back_left"]
62 |
63 | def check_task_finish(render_dir, view_index):
64 | files_type = ['rgb', 'normals']
65 | flag = True
66 | view_index = "%03d" % view_index
67 | if os.path.exists(render_dir):
68 | for t in files_type:
69 | for face in VIEWS:
70 | fpath = os.path.join(render_dir, f'{t}_{view_index}_{face}.webp')
71 | # print(fpath)
72 | if not os.path.exists(fpath):
73 | flag = False
74 | else:
75 | flag = False
76 |
77 | return flag
78 |
79 | def worker(
80 | queue: multiprocessing.JoinableQueue,
81 | count: multiprocessing.Value,
82 | gpu: int,
83 | s3: Optional[boto3.client],
84 | ) -> None:
85 | while True:
86 | item = queue.get()
87 | if item is None:
88 | break
89 |
90 | view_path = os.path.join(args.save_folder, item.split('/')[-1][:2], item.split('/')[-1][:-4])
91 | print(view_path)
92 | if 'render' in args.mode:
93 | if check_task_finish(view_path, view_idx):
94 | queue.task_done()
95 | print('========', item, 'rendered', '========')
96 |
97 | continue
98 | else:
99 | os.makedirs(view_path, exist_ok = True)
100 |
101 | # Perform some operation on the item
102 | print(item, gpu)
103 |
104 | if args.mode == 'render_ortho':
105 | command = (
106 | f" CUDA_VISIBLE_DEVICES={gpu} "
107 | f" blenderproc run --blender-install-path {args.blender_install_path} blenderProc_ortho.py"
108 | f" --object_path {item} --view {view_idx}"
109 | f" --output_folder {args.save_folder}"
110 | f" --ortho_scale {args.ortho_scale} "
111 | )
112 | if args.random_pose:
113 | print("random pose to render")
114 | command += f" --random_pose"
115 | elif args.mode == 'render_persp':
116 | command = (
117 | f" CUDA_VISIBLE_DEVICES={gpu} "
118 | f" blenderproc run --blender-install-path {args.blender_install_path} blenderProc_persp.py"
119 | f" --object_path {item} --view {view_idx}"
120 | f" --output_folder {args.save_folder}"
121 | )
122 | if args.random_pose:
123 | print("random pose to render")
124 | command += f" --random_pose"
125 |
126 | print(command)
127 | subprocess.run(command, shell=True)
128 |
129 | with count.get_lock():
130 | count.value += 1
131 |
132 | queue.task_done()
133 |
134 |
135 | if __name__ == "__main__":
136 | # args = tyro.cli(Args)
137 |
138 | s3 = boto3.client("s3") if args.upload_to_s3 else None
139 | queue = multiprocessing.JoinableQueue()
140 | count = multiprocessing.Value("i", 0)
141 |
142 | # Start worker processes on each of the GPUs
143 | for gpu_i in range(args.num_gpus):
144 | for worker_i in range(args.workers_per_gpu):
145 | worker_i = gpu_i * args.workers_per_gpu + worker_i
146 | process = multiprocessing.Process(
147 | target=worker, args=(queue, count, args.gpu_list[gpu_i], s3)
148 | )
149 | process.daemon = True
150 | process.start()
151 |
152 | # Add items to the queue
153 | if args.input_models_path is not None:
154 | with open(args.input_models_path, "r") as f:
155 | model_paths = json.load(f)
156 |
157 |         args.end_i = len(model_paths) if args.end_i == -1 or args.end_i > len(model_paths) else args.end_i
158 |
159 | for item in model_paths[args.start_i:args.end_i]:
160 |
161 | if os.path.exists(os.path.join(args.objaverse_root, os.path.basename(item))):
162 | obj_path = os.path.join(args.objaverse_root, os.path.basename(item))
163 | elif os.path.exists(os.path.join(args.objaverse_root, item)):
164 | obj_path = os.path.join(args.objaverse_root, item)
165 | else:
166 | obj_path = os.path.join(args.objaverse_root, item[:2], item+".glb")
167 | queue.put(obj_path)
168 |
169 | # Wait for all tasks to be completed
170 | queue.join()
171 |
172 | # Add sentinels to the queue to stop the worker processes
173 | for i in range(args.num_gpus * args.workers_per_gpu):
174 | queue.put(None)
175 |
--------------------------------------------------------------------------------
/render_codes/render_batch_ortho.sh:
--------------------------------------------------------------------------------
1 | python distributed.py \
2 | --num_gpus 8 --gpu_list 0 1 2 3 4 5 6 7 --mode render_ortho \
3 | --workers_per_gpu 10 --view_idx $1 \
4 | --start_i $2 --end_i $3 --ortho_scale 1.35 \
5 | --input_models_path ../data_lists/lvis_uids_filter_by_vertex.json \
6 | --objaverse_root /data/objaverse \
7 | --save_folder data/obj_lvis_13views \
8 | --blender_install_path /workplace/blender \
9 | --random_pose
10 |
--------------------------------------------------------------------------------
/render_codes/render_batch_persp.sh:
--------------------------------------------------------------------------------
1 | python distributed.py \
2 | --num_gpus 8 --gpu_list 0 1 2 3 4 5 6 7 --mode render_persp \
3 | --workers_per_gpu 10 --view_idx $1 \
4 | --start_i $2 --end_i $3 \
5 | --input_models_path ../data_lists/lvis_uids_filter_by_vertex.json \
6 | --save_folder /data/nineviews-pinhole \
7 | --objaverse_root /objaverse \
8 | --blender_install_path /workplace/blender \
9 | --random_pose
10 |
--------------------------------------------------------------------------------
/render_codes/render_single_ortho.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \
3 | blenderProc_ortho.py \
4 | --object_path /mnt/pfs/data/objaverse_lvis_glbs/c7/c70e8817b5a945aca8bb37e02ddbc6f9.glb --view 0 \
5 | --output_folder ./out_renderings/ \
6 | --object_uid c70e8817b5a945aca8bb37e02ddbc6f9 \
7 | --ortho_scale 1.35 \
8 | --resolution 512
9 | # --reset_object_euler
--------------------------------------------------------------------------------
/render_codes/render_single_persp.sh:
--------------------------------------------------------------------------------
1 | blenderproc run --blender-install-path /mnt/pfs/users/longxiaoxiao/workplace/blender \
2 | blenderProc_persp.py \
3 | --object_path ${the object path} --view 0 \
4 | --output_folder ${your save path} \
5 | --object_uid ${object_uid} --radius 2.0 \
6 | --random_pose
--------------------------------------------------------------------------------
/render_codes/requirements.txt:
--------------------------------------------------------------------------------
1 |
2 | blenderproc==2.5.0
3 | boto3==1.26.105
4 | docstring-parser==0.14.1
5 | h5py==3.8.0
6 | imageio==2.9.0
7 | imageio-ffmpeg==0.4.2
8 | markdown==3.4.3
9 | matplotlib==3.7.1
10 | multiprocess==0.70.13
11 | objaverse==0.0.7
12 | omegaconf==2.1.1
13 | opencv-python==4.5.5.64
14 | opencv-python-headless==4.7.0.72
15 | pillow==9.4.0
16 | progressbar==2.5
17 | scikit-image==0.20.0
18 | termcolor==2.2.0
19 | tqdm==4.65.0
20 | tyro==0.3.38
21 | vtk==9.2.6
22 | wandb==0.14.0
23 | zipp==3.15.0
24 | natsort
25 | mathutils==3.3.0
26 | argparse
27 |
28 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | torch==1.13.1
3 | torchvision==0.14.1
4 | diffusers[torch]==0.19.3
5 | xformers==0.0.16
6 | transformers>=4.25.1
7 | bitsandbytes==0.35.4
8 | decord==0.6.0
9 | pytorch-lightning<2
10 | omegaconf==2.2.3
11 | nerfacc==0.3.3
12 | trimesh==3.9.8
13 | pyhocon==0.3.57
14 | icecream==2.1.0
15 | PyMCubes==0.1.2
16 | accelerate
17 | modelcards
18 | einops
19 | ftfy
20 | piq
21 | matplotlib
22 | opencv-python
23 | imageio
24 | imageio-ffmpeg
25 | scipy
26 | pyransac3d
27 | torch_efficient_distloss
28 | tensorboard
29 | rembg
30 | segment_anything
31 | gradio==3.50.2
32 |
--------------------------------------------------------------------------------
/run_test.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py --config configs/mvdiffusion-joint-ortho-6views.yaml
--------------------------------------------------------------------------------
/run_train_stage1.sh:
--------------------------------------------------------------------------------
1 |
2 | # stage 1
3 | accelerate launch --config_file 1gpu.yaml train_mvdiffusion_image.py --config configs/train/stage1-mix-6views-lvis.yaml
--------------------------------------------------------------------------------
/run_train_stage2.sh:
--------------------------------------------------------------------------------
1 | # stage 2
2 | accelerate launch --config_file 1gpu.yaml train_mvdiffusion_joint.py --config configs/train/stage2-joint-6views-lvis.yaml
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import OmegaConf
3 | from packaging import version
4 |
5 |
6 | # ============ Register OmegaConf Resolvers ============= #
7 | OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
8 | OmegaConf.register_new_resolver('add', lambda a, b: a + b)
9 | OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
10 | OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
11 | OmegaConf.register_new_resolver('div', lambda a, b: a / b)
12 | OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
13 | OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
14 | # ======================================================= #
15 |
16 |
17 | def prompt(question):
18 | inp = input(f"{question} (y/n)").lower().strip()
19 | if inp and inp == 'y':
20 | return True
21 | if inp and inp == 'n':
22 | return False
23 | return prompt(question)
24 |
25 |
26 | def load_config(*yaml_files, cli_args=[]):
27 | yaml_confs = [OmegaConf.load(f) for f in yaml_files]
28 | cli_conf = OmegaConf.from_cli(cli_args)
29 | conf = OmegaConf.merge(*yaml_confs, cli_conf)
30 | OmegaConf.resolve(conf)
31 | return conf
32 |
33 |
34 | def config_to_primitive(config, resolve=True):
35 | return OmegaConf.to_container(config, resolve=resolve)
36 |
37 |
38 | def dump_config(path, config):
39 | with open(path, 'w') as fp:
40 | OmegaConf.save(config=config, f=fp)
41 |
42 | def get_rank():
43 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
44 | # therefore LOCAL_RANK needs to be checked first
45 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
46 | for key in rank_keys:
47 | rank = os.environ.get(key)
48 | if rank is not None:
49 | return int(rank)
50 | return 0
51 |
52 |
53 | def parse_version(ver):
54 | return version.parse(ver)
55 |
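56 | 
57 | if __name__ == "__main__":
58 |     # Minimal sketch of the custom resolvers registered above (values are illustrative).
59 |     cfg = OmegaConf.create({"n_steps": 10000, "decay": "${calc_exp_lr_decay_rate:0.1,${n_steps}}"})
60 |     print(cfg.decay)  # 0.1 ** (1. / 10000)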
--------------------------------------------------------------------------------