├── .gitignore
├── 2D_regression.py
├── LICENSE
├── README.md
├── README_FactorField.md
├── configs
│   ├── .DS_Store
│   ├── 360_v2.yaml
│   ├── defaults.yaml
│   ├── image.yaml
│   ├── image_intro.yaml
│   ├── image_set.yaml
│   ├── nerf.yaml
│   ├── nerf_ft.yaml
│   ├── nerf_set.yaml
│   ├── sdf.yaml
│   └── tnt.yaml
├── dataLoader
│   ├── .DS_Store
│   ├── __init__.py
│   ├── blender.py
│   ├── blender_set.py
│   ├── colmap.py
│   ├── colmap2nerf.py
│   ├── dtu_objs.py
│   ├── dtu_objs2.py
│   ├── google_objs.py
│   ├── image.py
│   ├── image_set.py
│   ├── llff.py
│   ├── nsvf.py
│   ├── ray_utils.py
│   ├── sdf.py
│   ├── tankstemple.py
│   └── your_own_data.py
├── media
│   ├── Girl_with_a_Pearl_Earring.jpg
│   └── inpainting.png
├── models
│   ├── .DS_Store
│   ├── FactorFields.py
│   ├── __init__.py
│   └── sh.py
├── renderer.py
├── requirements.txt
├── run_batch.py
├── scripts
│   ├── .DS_Store
│   ├── 2D_regression.ipynb
│   ├── 2D_set_regression.ipynb
│   ├── 2D_set_regression.py
│   ├── __init__.py
│   ├── formula_demostration.ipynb
│   ├── mesh2SDF_data_process.ipynb
│   └── sdf_regression.ipynb
├── train_across_scene.py
├── train_across_scene_ft.py
├── train_per_scene.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | data/
163 | slurm/
164 | logs/
165 |
166 |
--------------------------------------------------------------------------------
/2D_regression.py:
--------------------------------------------------------------------------------
1 | import torch,imageio,sys,time,os,cmapy,scipy,scipy.signal  # scipy.signal is needed for convolve2d in rgb_ssim
2 | import numpy as np
3 | from tqdm import tqdm
4 | import matplotlib.pyplot as plt
5 | from omegaconf import OmegaConf
6 | import torch.nn.functional as F
7 |
8 | device = 'cuda'
9 |
10 | sys.path.append('..')
11 | from models.sparseCoding import sparseCoding
12 |
13 | from dataLoader import dataset_dict
14 | from torch.utils.data import DataLoader
15 |
16 |
17 | def PSNR(a, b):
18 | if type(a).__module__ == np.__name__:
19 | mse = np.mean((a - b) ** 2)
20 | else:
21 | mse = torch.mean((a - b) ** 2).item()
22 | psnr = -10.0 * np.log(mse) / np.log(10.0)
23 | return psnr
24 |
25 |
26 | def rgb_ssim(img0, img1, max_val,
27 | filter_size=11,
28 | filter_sigma=1.5,
29 | k1=0.01,
30 | k2=0.03,
31 | return_map=False):
32 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
33 | assert len(img0.shape) == 3
34 | assert img0.shape[-1] == 3
35 | assert img0.shape == img1.shape
36 |
37 | # Construct a 1D Gaussian blur filter.
38 | hw = filter_size // 2
39 | shift = (2 * hw - filter_size + 1) / 2
40 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2
41 | filt = np.exp(-0.5 * f_i)
42 | filt /= np.sum(filt)
43 |
44 | # Blur in x and y (faster than the 2D convolution).
45 | def convolve2d(z, f):
46 | return scipy.signal.convolve2d(z, f, mode='valid')
47 |
48 | filt_fn = lambda z: np.stack([
49 | convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :])
50 | for i in range(z.shape[-1])], -1)
51 | mu0 = filt_fn(img0)
52 | mu1 = filt_fn(img1)
53 | mu00 = mu0 * mu0
54 | mu11 = mu1 * mu1
55 | mu01 = mu0 * mu1
56 | sigma00 = filt_fn(img0 ** 2) - mu00
57 | sigma11 = filt_fn(img1 ** 2) - mu11
58 | sigma01 = filt_fn(img0 * img1) - mu01
59 |
60 | # Clip the variances and covariances to valid values.
61 | # Variance must be non-negative:
62 | sigma00 = np.maximum(0., sigma00)
63 | sigma11 = np.maximum(0., sigma11)
64 | sigma01 = np.sign(sigma01) * np.minimum(
65 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
66 | c1 = (k1 * max_val) ** 2
67 | c2 = (k2 * max_val) ** 2
68 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
69 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
70 | ssim_map = numer / denom
71 | ssim = np.mean(ssim_map)
72 | return ssim_map if return_map else ssim
73 |
74 |
75 | @torch.no_grad()
76 | def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240):
77 | y = torch.linspace(0, aabb[0] - 1, reso[0])
78 | x = torch.linspace(0, aabb[1] - 1, reso[1])
79 | yy, xx = torch.meshgrid((y, x), indexing='ij')
80 |
81 | idx = 0
82 | res = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1])
83 | coordiantes = torch.stack((xx, yy), dim=-1).reshape(-1, 2) + torch.tensor(
84 | shiftment) # /(torch.FloatTensor(reso[::-1])-1)*2-1
85 | for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)):
86 | feats, _ = model.get_coding(coordiante.to(model.device))
87 | y_recon = model.linear_mat(feats, is_train=False)
88 | # y_recon = torch.sum(feats,dim=-1,keepdim=True)
89 |
90 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu()
91 | idx += y_recon.shape[0]
92 | return res.view(reso[0], reso[1], -1), coordiantes
93 |
94 |
95 | def linear_to_srgb(img):
96 | limit = 0.0031308
97 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
98 |
99 |
100 | def write_image_imageio(img_file, img, colormap=None, quality=100):
101 | if colormap == 'turbo':
102 | shape = img.shape
103 | img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1)
104 | elif colormap is not None:
105 | img = cmapy.colorize((img * 255).astype('uint8'), colormap)
106 |
107 | if img.dtype != 'uint8':
108 | img = (img - np.min(img)) / (np.max(img) - np.min(img))
109 | img = (img * 255.0).astype(np.uint8)
110 |
111 | kwargs = {}
112 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
113 | if img.ndim >= 3 and img.shape[2] > 3:
114 | img = img[:, :, :3]
115 | kwargs["quality"] = quality
116 | kwargs["subsampling"] = 0
117 | imageio.imwrite(img_file, img, **kwargs)
118 |
119 | if __name__ == '__main__':
120 |
121 | torch.set_default_dtype(torch.float32)
122 | torch.manual_seed(20211202)
123 | np.random.seed(20211202)
124 |
125 | base_conf = OmegaConf.load('configs/defaults.yaml')
126 | cli_conf = OmegaConf.from_cli()
127 | second_conf = OmegaConf.load('configs/image.yaml')
128 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
129 | print(cfg)
130 |
131 |
132 | folder = cfg.defaults.expname
133 | save_root = f'/vlg-nfs/anpei/project/NeuBasis/ours/images/'
134 |
135 | dataset = dataset_dict[cfg.dataset.dataset_name]
136 |
137 | delete_region = [[290,350,48,48],[300,380,48,48],[180, 407, 48, 48], [223, 263, 48, 48], [233, 150, 48, 48], [374, 119, 48, 48], [4, 199, 48, 48], [180, 234, 48, 48], [173, 39, 48, 48], [408, 308, 48, 48], [227, 177, 48, 48], [46, 330, 48, 48], [213, 26, 48, 48], [90, 44, 48, 48], [295, 61, 48, 48]]
138 | continue_sampling = False
139 |
140 | psnrs,ssims = [],[]
141 | for i in range(1,257):
142 | cfg.dataset.datadir = f'/vlg-nfs/anpei/dataset/Images/crop//{i:04d}.png'
143 | name = os.path.basename(cfg.dataset.datadir).split('.')[0]
144 | if os.path.exists(f'{save_root}/{folder}/{int(name):04d}.png'):
145 | continue
146 |
147 |
148 | train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train',tolinear=True, perscent=1.0,HW=1024)#, continue_sampling=continue_sampling,delete_region=delete_region
149 | train_loader = DataLoader(train_dataset,
150 | num_workers=2,
151 | persistent_workers=True,
152 | batch_size=None,
153 | pin_memory=False)
154 | # train_dataset.img = train_dataset.img.to(device)
155 |
156 | cfg.model.out_dim = train_dataset.img.shape[-1]
157 | batch_size = cfg.training.batch_size
158 | n_iter = cfg.training.n_iters
159 |
160 | H,W = train_dataset.HW
161 | train_dataset.scene_bbox = [[0., 0.], [W, H]]
162 | cfg.dataset.aabb = train_dataset.scene_bbox
163 |
164 | model = sparseCoding(cfg, device)
165 | if 1==i:
166 | print(model)
167 | print('total parameters: ',model.n_parameters())
168 |
169 | # tvreg = TVLoss()
170 | # trainingSampler = SimpleSampler(len(train_dataset), cfg.training.batch_size)
171 |
172 | grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
173 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#
174 |
175 |
176 | loss_scale = 1.0
177 | lr_factor = 0.1 ** (1 / n_iter)
178 | # pbar = tqdm(range(10000))
179 | start = time.time()
180 | # for iteration in pbar:
181 | for (iteration, sample) in zip(range(10000),train_loader):
182 | loss_scale *= lr_factor
183 |
184 | # if iteration==5000:
185 | # model.coeffs = torch.nn.Parameter(F.interpolate(model.coeffs.data, size=None, scale_factor=2.0, align_corners=True,mode='bilinear'))
186 | # grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
187 | # optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#
188 | # model.set_optimizable(['mlp','basis'], False)
189 |
190 | coordiantes, pixel_rgb = sample['xy'], sample['rgb']
191 | feats,coeff = model.get_coding(coordiantes.to(device))
192 | # tv_loss = model.TV_loss(tvreg)
193 |
194 | y_recon = model.linear_mat(feats,is_train=True)
195 | # y_recon = torch.sum(feats,dim=-1,keepdim=True)
196 | loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2)
197 |
198 |
199 | psnr = -10.0 * np.log(loss.item()) / np.log(10.0)
200 | # if iteration%100==0:
201 | # pbar.set_description(
202 | # f'Iteration {iteration:05d}:'
203 | # + f' loss_dist = {loss.item():.8f}'
204 | # # + f' tv_loss = {tv_loss.item():.6f}'
205 | # + f' psnr = {psnr:.3f}'
206 | # )
207 |
208 | # loss = loss + tv_loss
209 | # loss = loss + torch.mean(coeff.abs())*1e-2
210 | loss = loss * loss_scale
211 | optimizer.zero_grad()
212 | loss.backward()
213 | optimizer.step()
214 |
215 | # if iteration%100==0:
216 | # model.normalize_basis()
217 | iteration_time = time.time()-start
218 |
219 | H,W = train_dataset.HW
220 | img,coordinate = eval_img(train_dataset.HW,[1024,1024])
221 | if continue_sampling:
222 | import torch.nn.functional as F
223 | coordinate_tmp = (coordinate.view(1,1,-1,2))/torch.tensor([W,H])*2-1.0
224 | img_gt = F.grid_sample(train_dataset.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear',
225 | align_corners=False, padding_mode='border').reshape(-1,H,W).permute(1,2,0)
226 | else:
227 | img_gt = train_dataset.img.view(H,W,-1)
228 | psnrs.append(PSNR(img.clamp(0,1.),img_gt))
229 | ssims.append(rgb_ssim(img.clamp(0,1.),img_gt,1.0))
230 | # print(PSNR(img.clamp(0,1.),img_gt),iteration_time)
231 | # plt.figure(figsize=(10, 10))
232 | # plt.imshow(linear_to_srgb(img.clamp(0,1.)))
233 |
234 | print(i, psnrs[-1], ssims[-1])
235 |
236 |
237 | os.makedirs(f'{save_root}/{folder}',exist_ok=True)
238 | write_image_imageio(f'{save_root}/{folder}/{int(name):04d}.png',linear_to_srgb(img.clamp(0,1.)))
239 | np.savetxt(f'{save_root}/{folder}/{int(name):04d}.txt',[psnrs[-1],ssims[-1],iteration_time,model.n_parameters()])
240 |
241 |
242 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Factor Fields
2 | ## [Project page](https://apchenstu.github.io/FactorFields/) | [Paper](https://arxiv.org/abs/2302.01226)
3 | This repository contains a PyTorch implementation of the papers [Factor Fields: A Unified Framework for Neural Fields and Beyond](https://arxiv.org/abs/2302.01226) and [Dictionary Fields: Learning a Neural Basis Decomposition](https://arxiv.org/abs/2302.01226). Our work presents a novel framework for modeling and representing signals.
4 | We have also observed that Dictionary Fields offer benefits such as improved **approximation quality**, **compactness**, **faster training speed**, and the ability to **generalize** to unseen images and 3D scenes.
5 |
6 |
7 | ## Installation
8 |
9 | #### Tested on Ubuntu 20.04 + Pytorch 1.13.0
10 |
11 | Install environment:
12 | ```sh
13 | conda create -n FactorFields python=3.9
14 | conda activate FactorFields
15 | conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
16 | conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
17 | pip install -r requirements.txt
18 | ```
19 |
20 | Optionally install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn); it is only needed if you want to run hash-grid-based representations.
21 | ```sh
22 | conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
23 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
24 | ```
25 |
26 |
27 | # Quick Start
28 | Please ensure that you download the corresponding dataset and extract its contents into the `data` folder (the default configs expect paths such as `./data/nerf_synthetic/lego`, `./data/TanksAndTemple/Truck`, and `./data/google_scanned_objects/`).
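The scripts and notebooks in this repository build their configuration by merging `configs/defaults.yaml` with a task-specific config and any `key=value` overrides given on the command line (see, e.g., `2D_regression.py`). A minimal sketch of this pattern, assuming it is run from the repository root:

```python
from omegaconf import OmegaConf

# Merge order: defaults < task-specific config < CLI overrides (later entries win).
base_conf = OmegaConf.load('configs/defaults.yaml')
task_conf = OmegaConf.load('configs/image.yaml')
cli_conf = OmegaConf.from_cli()  # e.g. pass `defaults.expname=my_test` on the command line
cfg = OmegaConf.merge(base_conf, task_conf, cli_conf)
print(cfg)
```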
29 |
30 | ## Image
31 | * [Data - Image Set](https://1drv.ms/u/c/0c624178fab774b7/Ebd0t_p4QWIggAx3BAAAAAABikvhj5m_rVm1-qIpYFyrFg?e=hyTeZf)
32 |
33 | The training script can be found at `scripts/2D_regression.ipynb`, and the configuration file is located at `configs/image.yaml`.
34 |
35 |
36 |
37 |
38 |
39 | ## SDF
40 | * [Data - Mesh set](https://1drv.ms/u/c/0c624178fab774b7/Ebd0t_p4QWIggAx4BAAAAAABbouT0SD3PCChlfTQJL3XzA?e=ImcsAj)
41 |
42 | The training script can be found at `scripts/sdf_regression.ipynb`, and the configuration file is located at `configs/sdf.yaml`.
43 |
44 |
45 |
46 |
47 |
48 | ## NeRF
49 | * [Data - Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
50 | * [Data - Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
51 |
52 | The training script can be found at `train_per_scene.py`:
53 |
54 | ```python
55 | python train_per_scene.py configs/nerf.yaml defaults.expname=lego dataset.datadir=./data/nerf_synthetic/lego
56 | ```
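After training, you can re-run the script in render-only mode from a saved checkpoint. A hypothetical invocation is sketched below; the flags follow `configs/nerf.yaml`, and the checkpoint path is illustrative, assuming the usual `./logs/<expname>/<expname>.th` layout used elsewhere in this repository:

```python
python train_per_scene.py configs/nerf.yaml defaults.expname=lego dataset.datadir=./data/nerf_synthetic/lego exportation.render_only=1 exportation.render_test=1 defaults.ckpt=./logs/lego/lego.th
```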
57 |
58 |
67 |
68 |
69 |
70 |
71 |
72 | ## Generalization NeRF
73 | * [Data - Google Scanned Objects](https://drive.google.com/file/d/1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2/view)
74 |
75 | ```python
76 | python train_across_scene.py configs/nerf_set.yaml
77 | ```
78 |
79 |
80 |
81 |
82 | ## More examples
83 |
84 | Command-line options, explained with a NeRF example (a combined example command follows this list):
85 | * `model.basis_dims=[4, 4, 4, 2, 2, 2]` adjusts the number of levels and channels at each level, with a total of 6 levels and 18 channels.
86 | * `model.basis_resos=[32, 51, 70, 89, 108, 128]` represents the resolution of the feature embeddings.
87 | * `model.freq_bands=[2.0, 3.2, 4.4, 5.6, 6.8, 8.0]` indicates the frequency parameters applied at each level of the coordinate transformation function.
88 | * `model.coeff_type` represents the coefficient field representations and can be one of the following: [none, x, grid, mlp, vec, cp, vm].
89 | * `model.basis_type` represents the basis field representation and can be one of the following: [none, x, grid, mlp, vec, cp, vm, hash].
90 | * `model.basis_mapping` represents the coordinate transformation and can be one of the following: [x, triangle, sawtooth, trigonometric]. Please note that if you want to use orthogonal projection, choose the cp or vm basis type, as they automatically utilize the orthogonal projection functions.
91 | * `model.total_params` controls the total model size. Note that the capacity of the basis is determined by `model.basis_resos` and `model.basis_dims`; `total_params` mainly affects the capacity of the coefficients.
92 | * `exportation.render_only` lets you render items after training by setting this flag to 1. Please also specify the `defaults.ckpt` option.
93 | * `exportation....` you can specify which of the items `[render_test, render_train, render_path, export_mesh]` to render after training by setting the corresponding flag to 1.
94 |
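For example, a hypothetical command combining several of these options (the values are illustrative, not tuned settings):

```python
python train_per_scene.py configs/nerf.yaml defaults.expname=lego-vm dataset.datadir=./data/nerf_synthetic/lego model.coeff_type=vm model.basis_type=vm model.basis_dims=[4,4,4,2,2,2] model.basis_resos=[32,51,70,89,108,128] model.freq_bands=[2.0,3.2,4.4,5.6,6.8,8.0]
```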
95 | Some pre-defined configurations (such as OccNet, DVGO, NeRF, iNGP, EG3D) can be found in `README_FactorField.md`.
96 |
97 |
98 | ## Copyright
99 | * [Summer Day](https://www.rijksmuseum.nl/en/collection/SK-A-3005) - Credit goes to Johan Hendrik Weissenbruch and rijksmuseum.
100 | * [Mars](https://solarsystem.nasa.gov/resources/933/true-colors-of-pluto/) - Credit goes to NASA.
101 | * [Albert](https://cdn.loc.gov/service/pnp/cph/3b40000/3b46000/3b46000/3b46036v.jpg) - Credit goes to Orren Jack Turner.
102 | * [Girl With a Pearl Earring](http://profoundism.com/free_licenses.html) - Renovation copyright Koorosh Orooj (CC BY-SA 4.0).
103 |
104 |
105 | ## Citation
106 | If you find our code or paper helpful, please consider citing both of these papers:
107 | ```
108 | @article{Chen2023factor,
109 | title={Factor Fields: A Unified Framework for Neural Fields and Beyond},
110 | author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas},
111 | journal={arXiv preprint arXiv:2302.01226},
112 | year={2023}
113 | }
114 |
115 | @article{Chen2023SIGGRAPH,
116 | title={{Dictionary Fields: Learning a Neural Basis Decomposition}},
117 | 	author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas},
118 | booktitle={International Conference on Computer Graphics and Interactive Techniques (SIGGRAPH)},
119 | year={2023}}
120 | ```
121 |
--------------------------------------------------------------------------------
/README_FactorField.md:
--------------------------------------------------------------------------------
1 | ## NeRF reconstruction with Dictionary Fields
2 |
3 | ```python
4 | for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
5 |     cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene} ' \
6 | f'dataset.datadir=./data/nerf_synthetic/{scene} '
7 | ```
8 |
9 | ## different model design choices
10 |
11 | ```python
12 | choice_dict = {
13 | '-grid': '', \
14 | '-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \
15 | '-noC': 'model.coeff_type=none', \
16 | '-SL':'model.basis_dims=[18] model.basis_resos=[70] model.freq_bands=[8.]', \
17 | '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
18 | '-iNGP-like': 'model.basis_type=hash model.coeff_type=none', \
19 | '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
20 | '-sinc': f'model.basis_mapping=sinc', \
21 | '-tria': f'model.basis_mapping=triangle', \
22 | '-vm': f'model.coeff_type=vm model.basis_type=vm', \
23 | '-mlpB': 'model.basis_type=mlp', \
24 | '-mlpC': 'model.coeff_type=mlp', \
25 | '-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \
26 | '-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \
27 | f'model.num_layers=8 model.hidden_dim=256 ' \
28 | f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \
29 | '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
30 | '-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
31 | '-DCT':'model.basis_type=fix-grid', \
32 | }
33 |
34 | for name in choice_dict.keys():
35 | for scene in [ 'ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
36 |
37 |         cmd = f"python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {choice_dict[name]}"
38 | ```
39 |
40 | ## Generalized NeRF
41 | You can choose any of the following design choices for testing.
42 | ```python
43 | choice_dict = {
44 | '-grid': '', \
45 | '-DVGO-like': 'model.basis_type=none model.coeff_reso=48',
46 | '-SL':'model.basis_dims=[72] model.basis_resos=[48] model.freq_bands=[6.]', \
47 | '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
48 | '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
49 | '-sinc': f'model.basis_mapping=sinc', \
50 | '-tria': f'model.basis_mapping=triangle', \
51 | '-vm': f'model.coeff_type=vm model.basis_type=vm', \
52 | '-mlpB': 'model.basis_type=mlp', \
53 | '-mlpC': 'model.coeff_type=mlp', \
54 | '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
55 | '-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
56 | '-DCT':'model.basis_type=fix-grid', \
57 | }
58 |
59 | for name in choice_dict.keys(): #
60 |     cmd = f'python train_across_scene.py configs/nerf_set.yaml defaults.expname=google-obj{name} {choice_dict[name]} ' \
61 | f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/'
62 | ```
63 |
64 | You can also fine-tune the trained model on a new scene:
65 |
66 | ```python
67 | for views in [5]:#3,
68 | for name in choice_dict.keys(): #
69 | for scene in [183]:#183,199,298,467,957,244,963,527,
70 |
71 | cmd = f'python train_across_scene_ft.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
72 |               f'{choice_dict[name]} training.n_iters=10000 ' \
73 | f'dataset.train_views={views} ' \
74 | f'dataset.train_scene_list=[{scene}] ' \
75 | f'dataset.test_scene_list=[{scene}] ' \
76 | f'dataset.datadir=./data/google_scanned_objects/ ' \
77 | f'defaults.ckpt=./logs/google-obj{name}//google-obj{name}.th'
78 | ```
79 |
80 | ## Render path after optimization
81 | ```python
82 | for views in [5]:
83 | for name in choice_dict.keys(): #
84 |         config = choice_dict[name].replace(",", "','")
85 | for scene in [183]:#183,199,298,467,957,244,963,527,681,948
86 |
87 | cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
88 | f'{config} training.n_iters=10000 ' \
89 |               f'dataset.train_views={views} exportation.render_only=True exportation.render_path=True exportation.render_test=False ' \
90 | f'dataset.train_scene_list=[{scene}] ' \
91 | f'dataset.test_scene_list=[{scene}] ' \
92 | f'dataset.datadir=./data/google_scanned_objects/ ' \
93 | f'defaults.ckpt=./logs/google_objs_{name}_{scene}_{views}_views//google_objs_{name}_{scene}_{views}_views.th'
94 | ```
--------------------------------------------------------------------------------
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/configs/.DS_Store
--------------------------------------------------------------------------------
/configs/360_v2.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_room_real_mask
4 | logdir: ./logs
5 |
6 | ckpt: null # help='specific weights npy file to reload for coarse network'
7 |
8 | model:
9 | basis_dims: [5,5,5,2,2,2]
10 | basis_resos: [ 64, 83, 102, 121, 140, 160]
11 | coeff_reso: 16
12 | coef_init: 0.01
13 | phases: [0.0]
14 |
15 | coef_mode: bilinear
16 | basis_mode: bilinear
17 |
18 | freq_bands: [ 1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
19 |
20 | kernel_mapping_type: 'sawtooth'
21 |
22 | in_dim: 3
23 | out_dim: 32
24 | num_layers: 2
25 | hidden_dim: 128
26 |
27 | dataset:
28 | # loader options
29 | dataset_name: llff # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
30 | datadir: /home/anpei/code/NeuBasis/data/360_v2/room/
31 | ndc_ray: 0
32 | is_unbound: True
33 |
34 | with_depth: 0
35 | downsample_train: 4.0
36 | downsample_test: 4.0
37 |
38 | N_vis: 5
39 | vis_every: 5000
40 |
41 | training:
42 |
43 | n_iters: 30000
44 | batch_size: 4096
45 |
46 | volume_resoInit: 128 # 128**3:
47 | volume_resoFinal: 320 # 300**3
48 |
49 | upsamp_list: [2000,3000,4000,5500]
50 | update_AlphaMask_list: [2500]
51 | shrinking_list: [-1]
52 |
53 | L1_weight_inital: 0.0
54 | L1_weight_rest: 0.0
55 |
56 | TV_weight_density: 0.0
57 | TV_weight_app: 0.00
58 |
59 | exportation:
60 | render_only: 0
61 | render_test: 1
62 | render_train: 0
63 | render_path: 0
64 | export_mesh: 0
65 | export_mesh_only: 0
66 |
67 | renderer:
68 | shadingMode: MLP_Fea
69 | num_layers: 3
70 | hidden_dim: 128
71 |
72 | fea2denseAct: 'relu'
73 | density_shift: -10
74 | distance_scale: 25.0
75 |
76 | view_pe: 6
77 | fea_pe: 2
78 |
79 | lindisp: 0
80 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
81 |
82 | step_ratio: 0.5
83 | max_samples: 1600
84 |
85 | alphaMask_thres: 0.04
86 | rayMarch_weight_thres: 1e-3
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/configs/defaults.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_lego
4 |   logdir: ./logs
5 |
6 | mode: 'reconstruction'
7 |
8 | progress_refresh_rate: 10
9 |
10 | add_timestamp: 0
11 |
12 | model:
13 | basis_dims: [4,4,4,2,2,2]
14 | basis_resos: [32,51,70,89,108,128]
15 | coeff_reso: 32
16 | total_params: 10744166
17 | T_basis: 0
18 | T_coeff: 0
19 |
20 | coef_init: 1.0
21 | coef_mode: bilinear
22 | basis_mode: bilinear
23 |
24 | freq_bands: [1.0000, 1.3300, 1.7689, 2.3526, 3.1290, 4.1616]
25 |
26 |
27 | basis_mapping: 'sawtooth'
28 | with_dropout: False
29 |
30 | in_dim: 3
31 | out_dim: 32
32 | num_layers: 2
33 | hidden_dim: 128
34 |
35 | dataset:
36 | # loader options
37 | dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
38 | datadir: ./data/nerf_synthetic/lego
39 |
40 | with_depth: 0
41 | downsample_train: 1.0
42 | downsample_test: 1.0
43 |
44 | is_unbound: False
45 |
46 | training:
47 | # training options
48 | batch_size: 4096
49 | n_iters: 30000
50 |
51 | # learning rate
52 | lr_small: 0.001
53 | lr_large: 0.02
54 |
55 |   lr_decay_iters: -1 # help='number of iterations over which the lr decays to the target ratio; -1 sets it to n_iters'
56 |   lr_decay_target_ratio: 0.1 # help='the ratio of the initial lr that the lr decays to'
57 |   lr_upsample_reset: 1 # help='reset lr to initial after upsampling'
58 |
59 | # loss
60 | L1_weight_inital: 0.0 # help='loss weight'
61 | L1_weight_rest: 0
62 | Ortho_weight: 0.0
63 | TV_weight_density: 0.0
64 | TV_weight_app: 0.0
65 |
66 |   # optimizable
67 | coeff: True
68 | basis: True
69 | linear_mat: True
70 | renderModule: True
71 |
72 |
73 |
--------------------------------------------------------------------------------
/configs/image.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_image
4 | logdir: ./logs
5 |
6 | mode: 'image'
7 |
8 | ckpt: null # help='specific weights npy file to reload for coarse network'
9 |
10 | model:
11 | basis_dims: [32,32,32,16,16,16]
12 | basis_resos: [32,51,70,89,108,128]
13 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
14 |
15 | total_params: 1426063 # albert
16 | # total_params: 61445328 # pluto
17 | # total_params: 71848800 #Girl_with_a_Pearl_Earring
18 | # total_params: 37138096 # Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg_base
19 |
20 | coeff_type: 'grid'
21 | basis_type: 'grid'
22 |
23 | coef_init: 0.001
24 |
25 | coef_mode: nearest
26 | basis_mode: nearest
27 | basis_mapping: 'sawtooth'
28 |
29 |
30 | in_dim: 2
31 | out_dim: 3
32 | num_layers: 2
33 | hidden_dim: 64
34 | with_dropout: False
35 |
36 | dataset:
37 | # loader options
38 | dataset_name: image
39 | datadir: "../data/image/albert.exr"
40 | # datadir: "../data/image//pluto.jpeg"
41 | # datadir: "../data/image//Girl_with_a_Pearl_Earring.jpeg"
42 | # datadir: "../data/image//Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg"
43 |
44 |
45 | training:
46 | n_iters: 10000
47 | batch_size: 102400
48 |
49 | # learning rate
50 | lr_small: 0.002
51 | lr_large: 0.002
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/configs/image_intro.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_image
4 | logdir: ./logs
5 |
6 | mode: 'demo'
7 |
8 | ckpt: null # help='specific weights npy file to reload for coarse network'
9 |
10 | model:
11 | in_dim: 2
12 | out_dim: 1
13 |
14 | basis_dims: [32,32,32,16,16,16]
15 | basis_resos: [32,51,70,89,108,128]
16 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
17 |
18 |
19 |
20 |
21 |
22 | # occNet
23 | coeff_type: 'none'
24 | basis_type: 'x'
25 | basis_mapping: 'x'
26 | num_layers: 8
27 | hidden_dim: 256
28 |
29 |
30 | # coef_init: 0.001
31 |
32 | # coef_mode: nearest
33 | # basis_mode: nearest
34 | # basis_mapping: 'sawtooth'
35 |
36 | with_dropout: False
37 |
38 | dataset:
39 | # loader options
40 | dataset_name: image
41 | datadir: ../data/image/cat_occupancy.png
42 |
43 |
44 | training:
45 | n_iters: 10000
46 | batch_size: 102400
47 |
48 | # learning rate
49 | lr_small: 0.0002
50 | lr_large: 0.0002
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/configs/image_set.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_image
4 | logdir: ./logs
5 |
6 | mode: 'images'
7 |
8 | ckpt: null # help='specific weights npy file to reload for coarse network'
9 |
10 | model:
11 | basis_dims: [32,32,32,16,16,16]
12 | basis_resos: [32,51,70,89,108,128]
13 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
14 |
15 |
16 | coeff_reso: 32
17 | total_params: 1024000
18 |
19 | coef_init: 0.001
20 |
21 | coef_mode: bilinear
22 | basis_mode: bilinear
23 |
24 |
25 | coeff_type: 'grid'
26 | basis_type: 'grid'
27 |
28 | in_dim: 3
29 | out_dim: 3
30 | num_layers: 2
31 | hidden_dim: 64
32 | with_dropout: True
33 |
34 | dataset:
35 | # loader options
36 | dataset_name: images
37 | datadir: data/ffhq/ffhq_512.npy
38 |
39 | training:
40 | n_iters: 300000
41 | batch_size: 40960
42 |
43 | # learning rate
44 | lr_small: 0.002
45 | lr_large: 0.002
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/configs/nerf.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | expname: basis_lego
3 | logdir: ./logs
4 |
5 | mode: 'reconstruction'
6 |
7 | ckpt: null # help='specific weights npy file to reload for coarse network'
8 |
9 | model:
10 | coeff_reso: 32
11 |
12 | basis_dims: [4,4,4,2,2,2]
13 | basis_resos: [32,51,70,89,108,128]
14 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
15 |
16 |
17 | coef_init: 1.0
18 | phases: [0.0]
19 | total_params: 5308416
20 |
21 | coef_mode: bilinear
22 | basis_mode: bilinear
23 |
24 | coeff_type: 'grid'
25 | basis_type: 'grid'
26 | basis_mapping: 'sawtooth'
27 |
28 | in_dim: 3
29 | out_dim: 32
30 | num_layers: 2
31 | hidden_dim: 64
32 |
33 | dataset:
34 | # loader options
35 | dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
36 | datadir: ./data/nerf_synthetic/lego
37 | ndc_ray: 0
38 |
39 | with_depth: 0
40 | downsample_train: 1.0
41 | downsample_test: 1.0
42 |
43 | N_vis: 5
44 | vis_every: 100000
45 | scene_reso: 768
46 |
47 | training:
48 |
49 | n_iters: 30000
50 | batch_size: 4096
51 |
52 | volume_resoInit: 128 # 128**3:
53 | volume_resoFinal: 300 # 300**3
54 |
55 | upsamp_list: [2000,3000,4000,5500,7000]
56 | update_AlphaMask_list: [2500,4000]
57 | shrinking_list: [500]
58 |
59 | L1_weight_inital: 0.0
60 | L1_weight_rest: 0.0
61 |
62 | TV_weight_density: 0.000
63 | TV_weight_app: 0.00
64 |
65 | exportation:
66 | render_only: 0
67 | render_test: 1
68 | render_train: 0
69 | render_path: 0
70 | export_mesh: 0
71 | export_mesh_only: 0
72 |
73 | renderer:
74 | shadingMode: MLP_Fea
75 | num_layers: 3
76 | hidden_dim: 128
77 |
78 | fea2denseAct: 'softplus'
79 | density_shift: -10
80 | distance_scale: 25.0
81 |
82 | view_pe: 6
83 | fea_pe: 2
84 |
85 | lindisp: 0
86 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
87 |
88 | step_ratio: 0.5
89 | max_samples: 1200
90 |
91 | alphaMask_thres: 0.02
92 | rayMarch_weight_thres: 1e-3
93 |
--------------------------------------------------------------------------------
/configs/nerf_ft.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | expname: basis
3 | logdir: ./logs
4 |
5 | mode: 'reconstructions'
6 |
7 | ckpt: null # help='specific weights npy file to reload for coarse network'
8 |
9 | model:
10 | coeff_reso: 16
11 |
12 | basis_dims: [16,16,16,8,8,8]
13 | basis_resos: [32,51,70,89,108,128]
14 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
15 |
16 | with_dropout: True
17 |
18 | coef_init: 1.0
19 | phases: [0.0]
20 | total_params: 5308416
21 |
22 | coef_mode: bilinear
23 | basis_mode: bilinear
24 |
25 | coeff_type: 'grid'
26 | basis_type: 'grid'
27 | basis_mapping: 'sawtooth'
28 |
29 | in_dim: 3
30 | out_dim: 32
31 | num_layers: 2
32 | hidden_dim: 64
33 |
34 | dataset:
35 | # loader options
36 | dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
37 | datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
38 | ndc_ray: 0
39 | train_scene_list: [100]
40 | test_scene_list: [100]
41 | train_views: 5
42 |
43 | with_depth: 0
44 | downsample_train: 1.0
45 | downsample_test: 1.0
46 |
47 | N_vis: 5
48 | vis_every: 100000
49 | scene_reso: 768
50 |
51 | training:
52 |
53 | n_iters: 5000
54 | batch_size: 4096
55 |
56 | volume_resoInit: 128 # 128**3:
57 | volume_resoFinal: 300 # 300**3
58 |
59 | upsamp_list: [2000,3000,4000]
60 | update_AlphaMask_list: [1500]
61 | shrinking_list: [-1]
62 |
63 | L1_weight_inital: 0.0
64 | L1_weight_rest: 0.0
65 |
66 | TV_weight_density: 0.000
67 | TV_weight_app: 0.00
68 |
69 |   # optimizable
70 | coeff: True
71 | basis: False
72 | linear_mat: False
73 | renderModule: False
74 |
75 | exportation:
76 | render_only: 0
77 | render_test: 1
78 | render_train: 0
79 | render_path: 0
80 | export_mesh: 0
81 | export_mesh_only: 0
82 |
83 | renderer:
84 | shadingMode: MLP_Fea
85 | num_layers: 3
86 | hidden_dim: 128
87 |
88 | fea2denseAct: 'softplus'
89 | density_shift: -10
90 | distance_scale: 25.0
91 |
92 | view_pe: 6
93 | fea_pe: 2
94 |
95 | lindisp: 0
96 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
97 |
98 | step_ratio: 0.5
99 | max_samples: 1200
100 |
101 | alphaMask_thres: 0.02
102 | rayMarch_weight_thres: 1e-3
--------------------------------------------------------------------------------
/configs/nerf_set.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | expname: basis_no_relu_lego
3 | logdir: ./logs
4 |
5 | mode: 'reconstructions'
6 |
7 | ckpt: null # help='specific weights npy file to reload for coarse network'
8 |
9 | model:
10 | coeff_reso: 16
11 |
12 | basis_dims: [16,16,16,8,8,8]
13 | # basis_resos: [32,51,70,89,108,128]
14 | basis_resos: [32,51,70,89,108,128]
15 | # freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
16 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
17 |
18 | with_dropout: True
19 |
20 | coef_init: 1.0
21 | phases: [0.0]
22 | total_params: 5308416
23 |
24 | coef_mode: bilinear
25 | basis_mode: bilinear
26 |
27 | coeff_type: 'grid'
28 | basis_type: 'grid'
29 | basis_mapping: 'sawtooth'
30 |
31 | in_dim: 3
32 | out_dim: 32
33 | num_layers: 2
34 | hidden_dim: 64
35 |
36 | dataset:
37 | # loader options
38 | dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
39 | datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
40 | ndc_ray: 0
41 | train_scene_list: [0,100]
42 | test_scene_list: [0]
43 | train_views: 100
44 |
45 | with_depth: 0
46 | downsample_train: 1.0
47 | downsample_test: 1.0
48 |
49 | N_vis: 5
50 | vis_every: 100000
51 | scene_reso: 768
52 |
53 | training:
54 |
55 | n_iters: 50000
56 | batch_size: 4096
57 |
58 | volume_resoInit: 128 # 128**3:
59 | volume_resoFinal: 256 # 300**3
60 |
61 | upsamp_list: [2000,3000,4000,5500,7000]
62 | update_AlphaMask_list: [-1]
63 | shrinking_list: [-1]
64 |
65 | L1_weight_inital: 0.0
66 | L1_weight_rest: 0.0
67 |
68 | TV_weight_density: 0.000
69 | TV_weight_app: 0.00
70 |
71 | exportation:
72 | render_only: 0
73 | render_test: 0
74 | render_train: 0
75 | render_path: 0
76 | export_mesh: 0
77 | export_mesh_only: 0
78 |
79 | renderer:
80 | shadingMode: MLP_Fea
81 | num_layers: 3
82 | hidden_dim: 128
83 |
84 | fea2denseAct: 'softplus'
85 | density_shift: -10
86 | distance_scale: 25.0
87 |
88 | view_pe: 6
89 | fea_pe: 2
90 |
91 | lindisp: 0
92 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
93 |
94 | step_ratio: 0.5
95 | max_samples: 1200
96 |
97 | alphaMask_thres: 0.005
98 | rayMarch_weight_thres: 1e-3
99 |
--------------------------------------------------------------------------------
/configs/sdf.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | expname: basis_sdf
3 | logdir: ./logs
4 |
5 | mode: 'sdf'
6 |
7 | ckpt: null # help='specific weights npy file to reload for coarse network'
8 |
9 | model:
10 | basis_dims: [4,4,4,2,2,2]
11 | basis_resos: [32,51,70,89,108,128]
12 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
13 |
14 | total_params: 5313942
15 |
16 | coeff_reso: 32
17 | coef_init: 0.05
18 |
19 | coef_mode: bilinear
20 | basis_mode: bilinear
21 |
22 |
23 | coeff_type: 'grid'
24 | basis_type: 'grid'
25 | kernel_mapping_type: 'sawtooth'
26 |
27 | in_dim: 3
28 | out_dim: 1
29 | num_layers: 1
30 | hidden_dim: 64
31 |
32 | dataset:
33 | # loader options
34 | dataset_name: sdf
35 | datadir: "../data/mesh/statuette_close.npy"
36 |
37 | scene_reso: 384
38 |
39 |
40 | training:
41 | n_iters: 10000
42 | batch_size: 40960
43 |
44 | # learning rate
45 | lr_small: 0.002
46 | lr_large: 0.02
--------------------------------------------------------------------------------
/configs/tnt.yaml:
--------------------------------------------------------------------------------
1 |
2 | defaults:
3 | expname: basis_truck
4 | logdir: ./logs
5 |
6 | ckpt: null # help='specific weights npy file to reload for coarse network'
7 |
8 | model:
9 | coeff_reso: 32
10 |
11 | # basis_dims: [8,4,2]
12 | # basis_resos: [32,64,128]
13 | # freq_bands: [3.0,4.7,6.8]
14 |
15 | ## 32.88
16 | # basis_dims: [3, 3, 3, 3, 3, 3, 3]
17 | # basis_resos: [64, 64, 64, 64, 64, 64, 64]
18 | ## freq_bands: [1.52727273, 2.58181818, 3.63636364, 4.69090909, 5.21818182, 6.27272727, 6.8]
19 | # freq_bands: [2., 3., 4.,5.,6.,7.,8.]
20 |
21 | # basis_dims: [5,5,5,2,2,2]
22 | # basis_resos: [32,51,70,89,108,128]
23 | # coeff_reso: 26
24 | # freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
25 |
26 | basis_dims: [4,4,4,2,2,2]
27 | basis_resos: [32,51,70,89,108,128]
28 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
29 | # freq_bands: [2. , 2.8, 3.6, 4.4, 5.2, 6.]
30 |
31 | coef_init: 1.0
32 | phases: [0.0]
33 | total_params: 5744166
34 |
35 | coef_mode: bilinear
36 | basis_mode: bilinear
37 |
38 | kernel_mapping_type: 'sawtooth'
39 |
40 | in_dim: 3
41 | out_dim: 32
42 | num_layers: 2
43 | hidden_dim: 64
44 |
45 | dataset:
46 | # loader options
47 | dataset_name: tankstemple # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
48 | datadir: ./data/TanksAndTemple/Truck
49 | ndc_ray: 0
50 |
51 | with_depth: 0
52 | downsample_train: 1.0
53 | downsample_test: 1.0
54 |
55 | N_vis: 5
56 | vis_every: 100000
57 | scene_reso: 768
58 |
59 | training:
60 |
61 | n_iters: 30000
62 | batch_size: 4096
63 |
64 | volume_resoInit: 128 # 128**3:
65 | volume_resoFinal: 320 # 300**3
66 |
67 | # upsamp_list: [2000,4000,7000]
68 | # update_AlphaMask_list: [2000,3000]
69 | upsamp_list: [2000,3000,4000,5500,7000]
70 | update_AlphaMask_list: [2500,4000]
71 | shrinking_list: [500]
72 |
73 | L1_weight_inital: 0.0
74 | L1_weight_rest: 0.0
75 |
76 | TV_weight_density: 0.0
77 | TV_weight_app: 0.00
78 |
79 | exportation:
80 | render_only: 0
81 | render_test: 1
82 | render_train: 0
83 | render_path: 0
84 | export_mesh: 0
85 | export_mesh_only: 0
86 |
87 | renderer:
88 | shadingMode: MLP_Fea
89 | num_layers: 3
90 | hidden_dim: 128
91 |
92 | fea2denseAct: 'softplus'
93 | density_shift: -10
94 | distance_scale: 25.0
95 |
96 | view_pe: 2
97 | fea_pe: 2
98 |
99 | lindisp: 0
100 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
101 |
102 | step_ratio: 0.5
103 | max_samples: 1200
104 |
105 | alphaMask_thres: 0.005
106 | rayMarch_weight_thres: 1e-3
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/dataLoader/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/dataLoader/.DS_Store
--------------------------------------------------------------------------------
/dataLoader/__init__.py:
--------------------------------------------------------------------------------
1 | from .llff import LLFFDataset
2 | from .blender import BlenderDataset
3 | from .nsvf import NSVF
4 | from .tankstemple import TanksTempleDataset
5 | from .your_own_data import YourOwnDataset
6 | from .image import ImageDataset
7 | from .image_set import ImageSetDataset
8 | from .colmap import ColmapDataset
9 | from .sdf import SDFDataset
10 | from .blender_set import BlenderDatasetSet
11 | from .google_objs import GoogleObjsDataset
12 | from .dtu_objs import DTUDataset
13 |
14 |
15 | dataset_dict = {'blender': BlenderDataset,
16 | 'blender_set': BlenderDatasetSet,
17 | 'llff':LLFFDataset,
18 | 'tankstemple':TanksTempleDataset,
19 | 'nsvf':NSVF,
20 | 'own_data':YourOwnDataset,
21 | 'image':ImageDataset,
22 | 'images':ImageSetDataset,
23 | 'sdf':SDFDataset,
24 | 'colmap':ColmapDataset,
25 | 'google_objs':GoogleObjsDataset,
26 | 'dtu':DTUDataset}
--------------------------------------------------------------------------------
/dataLoader/blender.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | class BlenderDataset(Dataset):
13 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
14 |
15 | # self.N_vis = N_vis
16 | self.split = split
17 | self.batch_size = batch_size
18 | self.root_dir = cfg.datadir
19 | self.is_stack = is_stack if is_stack is not None else 'train'!=split
20 | self.downsample = cfg.get(f'downsample_{self.split}')
21 | self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
22 | self.define_transforms()
23 |
24 | self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
25 | [0.73729737, 0.44876192, -0.50498052],
26 | [0.16296514, 0.60727077, 0.77760181]])
27 |
28 | self.scene_bbox = (np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])).tolist()
29 | # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
30 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
31 | self.read_meta()
32 | self.define_proj_mat()
33 |
34 | self.white_bg = True
35 | self.near_far = [2.0, 6.0]
36 |
37 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
38 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
39 |
40 | def read_depth(self, filename):
41 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
42 | return depth
43 |
44 | def read_meta(self):
45 |
46 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
47 | self.meta = json.load(f)
48 |
49 | w, h = self.img_wh
50 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
51 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
52 |
53 | # ray directions for all pixels, same for all images (same H, W, focal)
54 | self.directions = get_ray_directions(h, w, [self.focal, self.focal]) # (h, w, 3)
55 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
56 | self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()
57 |
58 | self.image_paths = []
59 | self.poses = []
60 | self.all_rays = []
61 | self.all_rgbs = []
62 | self.all_masks = []
63 | self.all_depth = []
64 | self.downsample = 1.0
65 |
66 | img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
67 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
68 | # idxs = idxs[:10] if self.split=='train' else idxs
69 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:#
70 |
71 | frame = self.meta['frames'][i]
72 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
73 | c2w = torch.FloatTensor(pose)
74 | c2w[:3,-1] /= 1.5
75 | self.poses += [c2w]
76 |
77 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
78 | self.image_paths += [image_path]
79 | img = Image.open(image_path)
80 |
81 | if self.downsample != 1.0:
82 | img = img.resize(self.img_wh, Image.LANCZOS)
83 | img = self.transform(img) # (4, h, w)
84 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
85 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
86 | self.all_rgbs += [img]
87 |
88 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
89 | # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
90 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
91 |
92 | self.poses = torch.stack(self.poses)
93 | if not self.is_stack:
94 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
95 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
96 |
97 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
98 | else:
99 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
100 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3)
101 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
102 |
103 | self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]),self.batch_size)
104 |
105 | def define_transforms(self):
106 | self.transform = T.ToTensor()
107 |
108 | def define_proj_mat(self):
109 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]
110 |
111 | # def world2ndc(self,points,lindisp=None):
112 | # device = points.device
113 | # return (points - self.center.to(device)) / self.radius.to(device)
114 |
115 | def __len__(self):
116 | return len(self.all_rgbs) if self.split=='test' else 300000
117 |
118 | def __getitem__(self, idx):
119 | idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,))
120 | sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
121 | return sample
--------------------------------------------------------------------------------
/dataLoader/blender_set.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | class BlenderDatasetSet(Dataset):
13 | def __init__(self, cfg, split='train'):
14 |
15 | # self.N_vis = N_vis
16 | self.root_dir = cfg.datadir
17 | self.split = split
18 | self.is_stack = False if 'train'==split else True
19 | self.downsample = cfg.get(f'downsample_{self.split}')
20 | self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
21 | self.define_transforms()
22 |
23 | self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
24 | [0.73729737, 0.44876192, -0.50498052],
25 | [0.16296514, 0.60727077, 0.77760181]])
26 |
27 | self.scene_bbox = (np.array([[-1.0, -1.0, -1.0, 0], [1.0, 1.0, 1.0, 2]])).tolist()
28 | # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
29 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
30 | self.read_meta()
31 | self.define_proj_mat()
32 |
33 | self.white_bg = True
34 | self.near_far = [2.0, 6.0]
35 |
36 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
37 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
38 |
39 | def read_depth(self, filename):
40 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
41 | return depth
42 |
43 | def read_meta(self):
44 |
45 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
46 | self.meta = json.load(f)
47 |
48 | w, h = self.img_wh
49 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
50 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
51 |
52 | # ray directions for all pixels, same for all images (same H, W, focal)
53 | self.directions = get_ray_directions(h, w, [self.focal, self.focal]) # (h, w, 3)
54 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
55 | self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()
56 |
57 | self.image_paths = []
58 | self.poses = []
59 | self.all_rays = []
60 | self.all_rgbs = []
61 | self.all_masks = []
62 | self.all_depth = []
63 | self.downsample = 1.0
64 |
65 | img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
66 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
67 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:#
68 |
69 | frame = self.meta['frames'][i]
70 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
71 | c2w = torch.FloatTensor(pose)
72 | c2w[:3,-1] /= 1.5
73 | self.poses += [c2w]
74 |
75 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
76 | self.image_paths += [image_path]
77 | img = Image.open(image_path)
78 |
79 | if self.downsample != 1.0:
80 | img = img.resize(self.img_wh, Image.LANCZOS)
81 | img = self.transform(img) # (4, h, w)
82 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
83 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
84 | self.all_rgbs += [img]
85 |
86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
87 | # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
88 |
89 | scene_id = torch.ones_like(rays_o[...,:1])*0
90 |             self.all_rays += [torch.cat([rays_o, rays_d, scene_id], 1)]  # (h*w, 7): ray origin, direction, scene id
91 |
92 | self.poses = torch.stack(self.poses)
93 | if not self.is_stack:
94 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
95 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
96 |
97 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
98 | else:
99 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
100 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3)
101 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
102 |
103 | def define_transforms(self):
104 | self.transform = T.ToTensor()
105 |
106 | def define_proj_mat(self):
107 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]
108 |
109 | # def world2ndc(self,points,lindisp=None):
110 | # device = points.device
111 | # return (points - self.center.to(device)) / self.radius.to(device)
112 |
113 | def __len__(self):
114 | return len(self.all_rgbs)
115 |
116 | def __getitem__(self, idx):
117 | rays = torch.cat((self.all_rays[idx],torch.tensor([0+0.5])),dim=-1)
118 | sample = {'rays': rays, 'rgbs': self.all_rgbs[idx]}
119 | return sample
--------------------------------------------------------------------------------
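The loader above derives the pinhole focal length from the Blender `camera_angle_x` field and rescales it to the loaded resolution before building one normalised ray direction per pixel. A minimal standalone sketch of that relation (numpy only; the 800 px reference width matches the Blender renders this loader assumes, and the angle value is illustrative):

import numpy as np

camera_angle_x = 0.6911          # horizontal field of view in radians (illustrative)
ref_w, target_w = 800, 400       # reference render width and the width actually loaded

focal = 0.5 * ref_w / np.tan(0.5 * camera_angle_x)   # focal length at the reference width
focal *= target_w / ref_w                            # rescale to the loaded resolution
intrinsics = np.array([[focal, 0, target_w / 2],
                       [0, focal, target_w / 2],     # square images assumed, so cy == cx
                       [0, 0, 1]])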
/dataLoader/colmap.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | from torch.utils.data import Dataset
3 |
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | class ColmapDataset(Dataset):
13 | def __init__(self, cfg, split='train'):
14 |
15 | self.cfg = cfg
16 | self.root_dir = cfg.datadir
17 | self.split = split
18 | self.is_stack = False if 'train'==split else True
19 | self.downsample = cfg.get(f'downsample_{self.split}')
20 | self.define_transforms()
21 | self.img_eval_interval = 8
22 |
23 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])#np.eye(4)#
24 | self.read_meta()
25 |
26 | self.white_bg = cfg.get('white_bg')
27 |
28 | # self.near_far = [0.1,2.0]
29 |
30 |
31 | def read_meta(self):
32 |
33 | if os.path.exists(f'{self.root_dir}/transforms.json'):
34 | self.meta = load_json(f'{self.root_dir}/transforms.json')
35 | i_test = np.arange(0, len(self.meta['frames']), self.img_eval_interval) # [np.argmin(dists)]
36 | idxs = i_test if self.split != 'train' else list(set(np.arange(len(self.meta['frames']))) - set(i_test))
37 | else:
38 | self.meta = load_json(f'{self.root_dir}/transforms_{self.split}.json')
39 | idxs = np.arange(0, len(self.meta['frames'])) # [np.argmin(dists)]
40 | inv_split = 'train' if self.split!='train' else 'test'
41 | self.meta['frames'] += load_json(f'{self.root_dir}/transforms_{inv_split}.json')['frames']
42 | print(len(self.meta['frames']),len(idxs))
43 |
44 |
45 | self.scale = self.meta.get('scale', 0.5)
46 | self.offset = torch.FloatTensor(self.meta.get('offset', [0.0,0.0,0.0]))
47 | # self.scene_bbox = (torch.tensor([[-6.,-7.,-10.0],[6.,7.,10.]])/5).tolist()
48 | # self.scene_bbox = [[-1., -1., -1.0], [1., 1., 1.]]
49 |
50 | # center, radius = torch.tensor([-0.082157, 2.415426,-3.703080]), torch.tensor([7.36916, 11.34958, 20.1616])/2
51 | # self.scene_bbox = torch.stack([center-radius, center+radius]).tolist()
52 |
53 | h, w = int(self.meta.get('h')), int(self.meta.get('w'))
54 | cx, cy = self.meta.get('cx'), self.meta.get('cy')
55 | self.focal = [self.meta.get('fl_x'), self.meta.get('fl_y')]
56 |
57 | # ray directions for all pixels, same for all images (same H, W, focal)
58 | self.directions = get_ray_directions(h, w, self.focal, center=[cx, cy]) # (h, w, 3)
59 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
60 | # self.intrinsics = torch.FloatTensor([[self.focal[0], 0, cx], [0, self.focal[1], cy], [0, 0, 1]])
61 | # self.intrinsics[:2] /= self.downsample
62 |
63 | poses = pose_from_json(self.meta, self.blender2opencv)
64 | poses, self.scene_bbox = orientation(poses, f'{self.root_dir}/colmap_text/points3D.txt')
65 |
66 | self.image_paths = []
67 | self.poses = []
68 | self.all_rays = []
69 | self.all_rgbs = []
70 | self.all_masks = []
71 | self.all_depth = []
72 |
73 | self.img_wh = [w,h]
74 |
75 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:#
76 |
77 | frame = self.meta['frames'][i]
78 | c2w = torch.FloatTensor(poses[i])
79 | # c2w[:3,3] = (c2w[:3,3]*self.scale + self.offset)*2-1
80 | self.poses += [c2w]
81 |
82 | image_path = os.path.join(self.root_dir, frame['file_path'])
83 | self.image_paths += [image_path]
84 | img = Image.open(image_path)
85 |
86 | if self.downsample != 1.0:
87 | img = img.resize(self.img_wh, Image.LANCZOS)
88 |
89 | img = self.transform(img)
90 | if img.shape[0]==4:
91 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
92 | img = img.view(3, -1).permute(1, 0)
93 | self.all_rgbs += [img]
94 |
95 |
96 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
97 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
98 |
99 | self.poses = torch.stack(self.poses)
100 | if not self.is_stack:
101 |             self.all_rays = torch.cat(self.all_rays, 0)  # (n_frames*h*w, 6)
102 |             self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_frames*h*w, 3)
103 |         else:
104 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_frames, h*w, 6)
105 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (n_frames, h, w, 3)
106 |
107 | def define_transforms(self):
108 | self.transform = T.ToTensor()
109 |
110 |
111 | def __len__(self):
112 | return len(self.all_rgbs)
113 |
114 | def __getitem__(self, idx):
115 |
116 | if self.split == 'train': # use data in the buffers
117 | sample = {'rays': self.all_rays[idx],
118 | 'rgbs': self.all_rgbs[idx]}
119 |
120 | else: # create data for each image separately
121 |
122 | img = self.all_rgbs[idx]
123 | rays = self.all_rays[idx]
124 |
125 | sample = {'rays': rays,
126 | 'rgbs': img}
127 | return sample
--------------------------------------------------------------------------------
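When a single `transforms.json` is found, the loader above keeps every `img_eval_interval`-th frame (every 8th) for evaluation and trains on the remainder. A small sketch of that split in isolation (the frame count is illustrative):

import numpy as np

n_frames, img_eval_interval = 100, 8
i_test = np.arange(0, n_frames, img_eval_interval)              # frames 0, 8, 16, ...
i_train = np.array(sorted(set(range(n_frames)) - set(i_test)))  # everything else

assert len(i_test) + len(i_train) == n_frames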
/dataLoader/colmap2nerf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | import argparse
12 | import os
13 | from pathlib import Path, PurePosixPath
14 |
15 | import numpy as np
16 | import json
17 | import sys
18 | import math
19 | import cv2
20 | import os
21 | import shutil
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
25 |
26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
27 | parser.add_argument("--video_fps", default=2)
28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
32 | parser.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL","OPENCV"], help="camera model")
33 | parser.add_argument("--colmap_camera_params", default="", help="intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist")
34 | parser.add_argument("--images", default="images", help="input path to the images")
35 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
36 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
37 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
38 | parser.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)")
39 | parser.add_argument("--out", default="transforms.json", help="output path")
40 | parser.add_argument("--vocab_path", default="", help="vocabulary tree path")
41 | args = parser.parse_args()
42 | return args
43 |
44 | def do_system(arg):
45 | print(f"==== running: {arg}")
46 | err = os.system(arg)
47 | if err:
48 | print("FATAL: command failed")
49 | sys.exit(err)
50 |
51 | def run_ffmpeg(args):
52 | if not os.path.isabs(args.images):
53 | args.images = os.path.join(os.path.dirname(args.video_in), args.images)
54 | images = args.images
55 | video = args.video_in
56 | fps = float(args.video_fps) or 1.0
57 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
58 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
59 | sys.exit(1)
60 | try:
61 | shutil.rmtree(images)
62 | except:
63 | pass
64 | do_system(f"mkdir {images}")
65 |
66 | time_slice_value = ""
67 | time_slice = args.time_slice
68 | if time_slice:
69 | start, end = time_slice.split(",")
70 | 		time_slice_value = f",select='between(t\\,{start}\\,{end})'"
71 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
72 |
73 | def run_colmap(args):
74 | db = args.colmap_db
75 | images = "\"" + args.images + "\""
76 | db_noext=str(Path(db).with_suffix(""))
77 |
78 | if args.text=="text":
79 | args.text=db_noext+"_text"
80 | text=args.text
81 | sparse=db_noext+"_sparse"
82 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
83 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
84 | sys.exit(1)
85 | if os.path.exists(db):
86 | os.remove(db)
87 | do_system(f"colmap feature_extractor --ImageReader.camera_model {args.colmap_camera_model} --ImageReader.camera_params \"{args.colmap_camera_params}\" --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
88 | match_cmd = f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}"
89 | if args.vocab_path:
90 | match_cmd += f" --VocabTreeMatching.vocab_tree_path {args.vocab_path}"
91 | do_system(match_cmd)
92 | try:
93 | shutil.rmtree(sparse)
94 | except:
95 | pass
96 | do_system(f"mkdir {sparse}")
97 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
98 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
99 | try:
100 | shutil.rmtree(text)
101 | except:
102 | pass
103 | do_system(f"mkdir {text}")
104 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")
105 |
106 |
107 | def variance_of_laplacian(image):
108 | return cv2.Laplacian(image, cv2.CV_64F).var()
109 |
110 | def sharpness(imagePath):
111 | image = cv2.imread(imagePath)
112 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
113 | fm = variance_of_laplacian(gray)
114 | return fm
115 |
116 | def qvec2rotmat(qvec):
117 | return np.array([
118 | [
119 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
120 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
121 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
122 | ], [
123 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
124 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
125 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
126 | ], [
127 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
128 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
129 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
130 | ]
131 | ])
132 |
133 | def rotmat(a, b):
134 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
135 | v = np.cross(a, b)
136 | c = np.dot(a, b)
137 | s = np.linalg.norm(v)
138 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
139 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))
140 |
141 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
142 | da = da / np.linalg.norm(da)
143 | db = db / np.linalg.norm(db)
144 | c = np.cross(da, db)
145 | denom = np.linalg.norm(c)**2
146 | t = ob - oa
147 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
148 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
149 | if ta > 0:
150 | ta = 0
151 | if tb > 0:
152 | tb = 0
153 | return (oa+ta*da+ob+tb*db) * 0.5, denom
154 |
155 | ############ orientation ##############
156 | def normalize(x):
157 | return x / np.linalg.norm(x)
158 |
159 | def rotation_matrix_from_vectors(vec1, vec2):
160 | """ Find the rotation matrix that aligns vec1 to vec2
161 | :param vec1: A 3d "source" vector
162 | :param vec2: A 3d "destination" vector
163 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
164 | """
165 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
166 | v = np.cross(a, b)
167 | if any(v): # if not all zeros then
168 | c = np.dot(a, b)
169 | s = np.linalg.norm(v)
170 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
171 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))
172 |
173 | else:
174 | return np.eye(3) # cross of all zeros only occurs on identical directions
175 |
176 | def rotation_up(poses):
177 | up = normalize(np.linalg.lstsq(poses[:, :3, 1],np.ones((poses.shape[0],1)),rcond=None)[0])
178 | rot = rotation_matrix_from_vectors(up,np.array([0.,1.,0.]))
179 | return rot
180 |
181 | def search_orientation(points):
182 | from scipy.spatial.transform import Rotation as R
183 | bbox_sizes,rot_mats,bboxs = [],[],[]
184 | for y_angle in np.linspace(-45,45,15):
185 | rotvec = np.array([0,y_angle,0])/180*np.pi
186 | rot = R.from_rotvec(rotvec).as_matrix()
187 | point_orientation = rot@points
188 | bbox = np.max(point_orientation,axis=1) - np.min(point_orientation,axis=1)
189 | bbox_sizes.append(np.prod(bbox))
190 | rot_mats.append(rot)
191 | bboxs.append(bbox)
192 | rot = rot_mats[np.argmin(bbox_sizes)]
193 | bbox = bboxs[np.argmin(bbox_sizes)]
194 | return rot,bbox
195 |
196 | def load_point_txt(path):
197 | points = []
198 | with open(path, "r") as f:
199 | for line in f:
200 | if line[0] == "#":
201 | continue
202 | els = line.split(" ")
203 | points.append([float(els[1]),float(els[2]),float(els[3])])
204 | return np.stack(points)
205 |
206 | # def load_c2ws(frames):
207 | # c2ws =
208 | # for f in frames:
209 | # f["transform_matrix"] -= totp
210 | #
211 | # def oritation(transform_matrix, point_txt):
212 |
213 |
214 | if __name__ == "__main__":
215 | args = parse_args()
216 | if args.video_in != "":
217 | run_ffmpeg(args)
218 | if args.run_colmap:
219 | run_colmap(args)
220 | AABB_SCALE = int(args.aabb_scale)
221 | SKIP_EARLY = int(args.skip_early)
222 | IMAGE_FOLDER = args.images
223 | TEXT_FOLDER = args.text
224 | OUT_PATH = args.out
225 | print(f"outputting to {OUT_PATH}...")
226 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
227 | angle_x = math.pi / 2
228 | for line in f:
229 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
230 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
231 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
232 | if line[0] == "#":
233 | continue
234 | els = line.split(" ")
235 | w = float(els[2])
236 | h = float(els[3])
237 | fl_x = float(els[4])
238 | fl_y = float(els[4])
239 | k1 = 0
240 | k2 = 0
241 | p1 = 0
242 | p2 = 0
243 | cx = w / 2
244 | cy = h / 2
245 | if els[1] == "SIMPLE_PINHOLE":
246 | cx = float(els[5])
247 | cy = float(els[6])
248 | elif els[1] == "PINHOLE":
249 | fl_y = float(els[5])
250 | cx = float(els[6])
251 | cy = float(els[7])
252 | elif els[1] == "SIMPLE_RADIAL":
253 | cx = float(els[5])
254 | cy = float(els[6])
255 | k1 = float(els[7])
256 | elif els[1] == "RADIAL":
257 | cx = float(els[5])
258 | cy = float(els[6])
259 | k1 = float(els[7])
260 | k2 = float(els[8])
261 | elif els[1] == "OPENCV":
262 | fl_y = float(els[5])
263 | cx = float(els[6])
264 | cy = float(els[7])
265 | k1 = float(els[8])
266 | k2 = float(els[9])
267 | p1 = float(els[10])
268 | p2 = float(els[11])
269 | else:
270 | print("unknown camera model ", els[1])
271 | # fl = 0.5 * w / tan(0.5 * angle_x);
272 | angle_x = math.atan(w / (fl_x * 2)) * 2
273 | angle_y = math.atan(h / (fl_y * 2)) * 2
274 | fovx = angle_x * 180 / math.pi
275 | fovy = angle_y * 180 / math.pi
276 |
277 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
278 |
279 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
280 | i = 0
281 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
282 | out = {
283 | "camera_angle_x": angle_x,
284 | "camera_angle_y": angle_y,
285 | "fl_x": fl_x,
286 | "fl_y": fl_y,
287 | "k1": k1,
288 | "k2": k2,
289 | "p1": p1,
290 | "p2": p2,
291 | "cx": cx,
292 | "cy": cy,
293 | "w": w,
294 | "h": h,
295 | "aabb_scale": AABB_SCALE,
296 | "frames": [],
297 | }
298 |
299 | up = np.zeros(3)
300 | for line in f:
301 | line = line.strip()
302 | if line[0] == "#":
303 | continue
304 | i = i + 1
305 | if i < SKIP_EARLY*2:
306 | continue
307 | if i % 2 == 1:
308 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
309 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
310 | 			# why is this requiring a relative path while using ^
311 | image_rel = os.path.relpath(IMAGE_FOLDER)
312 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
313 | b=sharpness(name)
314 | print(name, "sharpness=",b)
315 | image_id = int(elems[0])
316 | qvec = np.array(tuple(map(float, elems[1:5])))
317 | tvec = np.array(tuple(map(float, elems[5:8])))
318 | R = qvec2rotmat(-qvec)
319 | t = tvec.reshape([3,1])
320 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
321 | c2w = np.linalg.inv(m)
322 | # c2w[0:3,2] *= -1 # flip the y and z axis
323 | # c2w[0:3,1] *= -1
324 | # c2w = c2w[[1,0,2,3],:] # swap y and z
325 | # c2w[2,:] *= -1 # flip whole world upside down
326 |
327 | # up += c2w[0:3,1]
328 |
329 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
330 | out["frames"].append(frame)
331 |
332 | # up = up / np.linalg.norm(up)
333 | # print("up vector was", up)
334 | # R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
335 | # R = np.pad(R,[0,1])
336 | # R[-1, -1] = 1
337 | # for f in out["frames"]:
338 | # f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
339 |
340 | nframes = len(out["frames"])
341 |
342 | # find a central point they are all looking at
343 | print("computing center of attention...")
344 | totw = 0.0
345 | totp = np.array([0.0, 0.0, 0.0])
346 | for f in out["frames"]:
347 | mf = f["transform_matrix"][0:3,:]
348 | for g in out["frames"]:
349 | mg = g["transform_matrix"][0:3,:]
350 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
351 | if w > 0.01:
352 | totp += p*w
353 | totw += w
354 | totp /= totw
355 | print(totp) # the cameras are looking at totp
356 | for f in out["frames"]:
357 | f["transform_matrix"][0:3,3] -= totp
358 |
359 | avglen = 0.
360 | for f in out["frames"]:
361 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
362 | avglen /= nframes
363 | print("avg camera distance from origin", avglen)
364 | for f in out["frames"]:
365 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
366 |
367 |
368 |
369 | for f in out["frames"]:
370 | f["transform_matrix"] = f["transform_matrix"].tolist()
371 | print(nframes,"frames")
372 | print(f"writing {OUT_PATH}")
373 | with open(OUT_PATH, "w") as outfile:
374 | json.dump(out, outfile, indent=2)
--------------------------------------------------------------------------------
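`qvec2rotmat` turns COLMAP's (w, x, y, z) quaternions into rotation matrices, and the script then inverts the assembled world-to-camera matrix to get the camera-to-world transform written into `transforms.json`. A quick sanity check of the quaternion helper (illustrative only; it reuses the function defined in this script):

import numpy as np

# identity quaternion (w, x, y, z) = (1, 0, 0, 0) -> identity rotation
assert np.allclose(qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0])), np.eye(3))

# 180-degree rotation about z: (w, x, y, z) = (0, 0, 0, 1)
assert np.allclose(qvec2rotmat(np.array([0.0, 0.0, 0.0, 1.0])), np.diag([-1.0, -1.0, 1.0]))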
/dataLoader/dtu_objs.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import cv2 as cv
4 | import numpy as np
5 | import os
6 | from glob import glob
7 | from .ray_utils import *
8 | from torch.utils.data import Dataset
9 |
10 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
11 | def load_K_Rt_from_P(filename, P=None):
12 | if P is None:
13 | lines = open(filename).read().splitlines()
14 | if len(lines) == 4:
15 | lines = lines[1:]
16 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
17 | P = np.asarray(lines).astype(np.float32).squeeze()
18 |
19 | out = cv.decomposeProjectionMatrix(P)
20 | K = out[0]
21 | R = out[1]
22 | t = out[2]
23 |
24 | K = K / K[2, 2]
25 | intrinsics = np.eye(4)
26 | intrinsics[:3, :3] = K
27 |
28 | pose = np.eye(4, dtype=np.float32)
29 | pose[:3, :3] = R.transpose()
30 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
31 |
32 | return intrinsics, pose
33 |
34 | def fps_downsample(points, n_points_to_sample):
35 | selected_points = np.zeros((n_points_to_sample, 3))
36 | selected_idxs = []
37 | dist = np.ones(points.shape[0]) * 100
38 | for i in range(n_points_to_sample):
39 | idx = np.argmax(dist)
40 | selected_points[i] = points[idx]
41 | selected_idxs.append(idx)
42 | dist_ = ((points - selected_points[i]) ** 2).sum(-1)
43 | dist = np.minimum(dist, dist_)
44 |
45 | return selected_idxs
46 |
47 | class DTUDataset(Dataset):
48 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
49 | """
50 | img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
51 | """
52 | # self.N_vis = N_vis
53 | self.split = split
54 | self.batch_size = batch_size
55 | self.root_dir = cfg.datadir
56 | self.is_stack = is_stack if is_stack is not None else 'train'!=split
57 | self.downsample = cfg.get(f'downsample_{self.split}')
58 | self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))
59 |
60 | train_scene_idxs = sorted(cfg.train_scene_list)
61 | test_scene_idxs = cfg.test_scene_list
62 | if len(train_scene_idxs)==2:
63 | train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1]))
64 | self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs
65 | print(self.scene_idxs)
66 | self.train_views = cfg.train_views
67 | self.scene_num = len(self.scene_idxs)
68 | self.test_index = test_scene_idxs
69 | # if 'test' == self.split:
70 | # self.test_index = train_scene_idxs.index(test_scene_idxs[0])
71 |
72 | self.scene_path_list = [os.path.join(self.root_dir, f"scan{i}") for i in self.scene_idxs]
73 | # self.scene_path_list = sorted(glob(os.path.join(self.root_dir, "scan*")))
74 |
75 | self.read_meta()
76 | self.white_bg = False
77 |
78 | def read_meta(self):
79 | self.aabbs = []
80 | self.all_rgb_files,self.all_pose_files,self.all_intrinsics_files = {},{},{}
81 | for i, scene_idx in enumerate(self.scene_idxs):
82 |
83 | scene_path = self.scene_path_list[i]
84 | camera_dict = np.load(os.path.join(scene_path, 'cameras.npz'))
85 |
86 | self.all_rgb_files[scene_idx] = [
87 | os.path.join(scene_path, "image", f)
88 | for f in sorted(os.listdir(os.path.join(scene_path, "image")))
89 | ]
90 |
91 | # world_mat is a projection matrix from world to image
92 | n_images = len(self.all_rgb_files[scene_idx])
93 | world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
94 | scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
95 | object_scale_mat = camera_dict['scale_mat_0']
96 | self.aabbs.append(self.get_bbox(scale_mats_np,object_scale_mat))
97 |
98 | # W,H = self.img_wh
99 | intrinsics_scene, poses_scene = [], []
100 | for img_idx, (scale_mat, world_mat) in enumerate(zip(scale_mats_np, world_mats_np)):
101 | P = world_mat @ scale_mat
102 | P = P[:3, :4]
103 | intrinsic, c2w = load_K_Rt_from_P(None, P)
104 |
105 | c2w = torch.from_numpy(c2w).float()
106 | intrinsic = torch.from_numpy(intrinsic).float()
107 | intrinsic[:2] /= self.downsample
108 |
109 | poses_scene.append(c2w)
110 | intrinsics_scene.append(intrinsic)
111 |
112 | self.all_pose_files[scene_idx] = np.stack(poses_scene)
113 | self.all_intrinsics_files[scene_idx] = np.stack(intrinsics_scene)
114 |
115 | self.aabbs[0][0].append(0)
116 | self.aabbs[0][1].append(self.scene_num)
117 | self.scene_bbox = self.aabbs[0]
118 | print(self.scene_bbox)
119 | if self.split=='test' or self.scene_num==1:
120 | self.load_data(self.scene_idxs[0],range(49))
121 |
122 | def load_data(self, scene_idx, img_idx=None):
123 | self.all_rays = []
124 |
125 | n_views = len(self.all_pose_files[scene_idx])
126 | cam_xyzs = self.all_pose_files[scene_idx][:,:3, -1]
127 | idxs = fps_downsample(cam_xyzs, min(self.train_views, n_views)) if img_idx is None else img_idx
128 | # if "test" == self.split:
129 | # idxs = [item for item in list(range(n_views)) if item not in idxs]
130 | # if len(idxs)==0:
131 | # idxs = list(range(n_views))
132 |
133 | images_np = np.stack([cv.resize(cv.imread(self.all_rgb_files[scene_idx][idx]), self.img_wh) for idx in idxs]) / 255.0
134 | self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]]) # [n_images, H, W, 3]
135 |
136 | for c2w,intrinsic in zip(self.all_pose_files[scene_idx][idxs],self.all_intrinsics_files[scene_idx][idxs]):
137 | rays_o, rays_d = self.gen_rays_at(torch.from_numpy(intrinsic).float(), torch.from_numpy(c2w).float())
138 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
139 |
140 | if not self.is_stack:
141 |             self.all_rays = torch.cat(self.all_rays, 0)  # (n_views*h*w, 6)
142 |             self.all_rgbs = self.all_rgbs.reshape(-1, 3)
143 |         else:
144 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_views, h*w, 6)
145 |             self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (n_views, h, w, 3)
146 |
147 | # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)
148 |
149 |
150 | # def read_meta(self):
151 | #
152 | # images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
153 | # images_np = np.stack([cv.resize(cv.imread(im_name), self.img_wh) for im_name in images_lis]) / 255.0
154 | # # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
155 | # # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128
156 | #
157 | # self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]]) # [n_images, H, W, 3]
158 | # # self.all_masks = torch.from_numpy(masks_np>0) # [n_images, H, W, 3]
159 | # self.img_wh = [self.all_rgbs.shape[2], self.all_rgbs.shape[1]]
160 | #
161 | # # world_mat is a projection matrix from world to image
162 | # n_images = len(images_lis)
163 | # world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
164 | # self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
165 | #
166 | # # W,H = self.img_wh
167 | # self.all_rays = []
168 | # self.intrinsics, self.poses = [], []
169 | # for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
170 | # P = world_mat @ scale_mat
171 | # P = P[:3, :4]
172 | # intrinsic, c2w = load_K_Rt_from_P(None, P)
173 | #
174 | # c2w = torch.from_numpy(c2w).float()
175 | # intrinsic = torch.from_numpy(intrinsic).float()
176 | # intrinsic[:2] /= self.downsample
177 | #
178 | # self.poses.append(c2w)
179 | # self.intrinsics.append(intrinsic)
180 | #
181 | # rays_o, rays_d = self.gen_rays_at(intrinsic, c2w)
182 | # self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
183 | #
184 | # self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)
185 | #
186 | # # self.all_rgbs[~self.all_masks] = 1.0
187 | # if not self.is_stack:
188 | # self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
189 | # self.all_rgbs = self.all_rgbs.reshape(-1, 3)
190 | # else:
191 | # self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
192 | # self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
193 | #
194 | # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)
195 |
196 | def get_bbox(self, scale_mats_np, object_scale_mat):
197 | object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
198 | object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0])
199 | # Object scale mat: region of interest to **extract mesh**
200 | # object_scale_mat = np.load(os.path.join(scene_path, 'cameras.npz'))
201 | object_bbox_min = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
202 | object_bbox_max = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
203 | return [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
204 | # self.near_far = [2.125, 4.525]
205 |
206 | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
207 | """
208 | Generate rays at world space from one camera.
209 | """
210 | l = resolution_level
211 | W,H = self.img_wh
212 | tx = torch.linspace(0, W - 1, W // l)+0.5
213 | ty = torch.linspace(0, H - 1, H // l)+0.5
214 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
215 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
216 | intrinsic_inv = torch.inverse(intrinsic)
217 | p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
218 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
219 | rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
220 | rays_o = c2w[None, None, :3, 3].expand(rays_v.shape) # W, H, 3
221 | return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)
222 |
223 |
224 | def __len__(self):
225 | return 1000000 #len(self.all_rays)
226 |
227 | def __getitem__(self, idx):
228 | idx = torch.randint(self.scene_num,(1,)).item()
229 | if self.scene_num >= 1:
230 | scene_name = self.scene_idxs[idx]
231 | img_idx = np.random.choice(len(self.all_rgb_files[scene_name]), size=6)
232 | self.load_data(scene_name, img_idx)
233 |
234 | idxs = np.random.choice(self.all_rays.shape[0], size=self.batch_size)
235 |
236 | return {'rays': self.all_rays[idxs], 'rgbs': self.all_rgbs[idxs], 'idx': idx}
--------------------------------------------------------------------------------
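`fps_downsample` above is a plain farthest-point sampling over the camera centres, so the `train_views` selected for a scene are spread as evenly as possible around the object. A minimal usage sketch (random stand-in positions; the real loader passes the DTU camera centres):

import numpy as np

np.random.seed(0)
cam_centers = np.random.rand(20, 3)          # stand-in for 20 camera positions
view_idxs = fps_downsample(cam_centers, 5)   # indices of 5 well-separated cameras
print(view_idxs)                             # list of 5 distinct indices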
/dataLoader/dtu_objs2.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import cv2 as cv
4 | import numpy as np
5 | import os
6 | from glob import glob
7 | from .ray_utils import *
8 | from torch.utils.data import Dataset
9 |
10 |
11 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
12 | def load_K_Rt_from_P(filename, P=None):
13 | if P is None:
14 | lines = open(filename).read().splitlines()
15 | if len(lines) == 4:
16 | lines = lines[1:]
17 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
18 | P = np.asarray(lines).astype(np.float32).squeeze()
19 |
20 | out = cv.decomposeProjectionMatrix(P)
21 | K = out[0]
22 | R = out[1]
23 | t = out[2]
24 |
25 | K = K / K[2, 2]
26 | intrinsics = np.eye(4)
27 | intrinsics[:3, :3] = K
28 |
29 | pose = np.eye(4, dtype=np.float32)
30 | pose[:3, :3] = R.transpose()
31 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
32 |
33 | return intrinsics, pose
34 |
35 | class DTUDataset(Dataset):
36 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
37 | """
38 | img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
39 | """
40 | # self.N_vis = N_vis
41 | self.split = split
42 | self.batch_size = batch_size
43 | self.root_dir = cfg.datadir
44 | self.is_stack = is_stack if is_stack is not None else 'train'!=split
45 | self.downsample = cfg.get(f'downsample_{self.split}')
46 | self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))
47 |
48 | self.white_bg = False
49 | self.camera_dict = np.load(os.path.join(self.root_dir, 'cameras.npz'))
50 |
51 | self.read_meta()
52 | self.get_bbox()
53 |
54 | # def define_transforms(self):
55 | # self.transform = T.ToTensor()
56 |
57 | def get_bbox(self):
58 | object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
59 | object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0])
60 | # Object scale mat: region of interest to **extract mesh**
61 | object_scale_mat = np.load(os.path.join(self.root_dir, 'cameras.npz'))['scale_mat_0']
62 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
63 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
64 | self.scene_bbox = [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
65 | self.scene_bbox[0].append(0)
66 | self.scene_bbox[1].append(1)
67 |
68 | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
69 | """
70 | Generate rays at world space from one camera.
71 | """
72 | l = resolution_level
73 | W,H = self.img_wh
74 | tx = torch.linspace(0, W - 1, W // l)+0.5
75 | ty = torch.linspace(0, H - 1, H // l)+0.5
76 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
77 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
78 | intrinsic_inv = torch.inverse(intrinsic)
79 | p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
80 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
81 | rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
82 | rays_o = c2w[None, None, :3, 3].expand(rays_v.shape) # W, H, 3
83 | return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)
84 |
85 | def read_meta(self):
86 |
87 | images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
88 | images_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in images_lis]) / 255.0
89 | # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
90 | # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128
91 |
92 | self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[...,[2,1,0]]) # [n_images, H, W, 3]
93 | # self.all_masks = torch.from_numpy(masks_np>0) # [n_images, H, W, 3]
94 | self.img_wh = [self.all_rgbs.shape[2],self.all_rgbs.shape[1]]
95 |
96 | # world_mat is a projection matrix from world to image
97 | n_images = len(images_lis)
98 | world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
99 | self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
100 |
101 | # W,H = self.img_wh
102 | self.all_rays = []
103 | self.intrinsics, self.poses = [],[]
104 | for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
105 | P = world_mat @ scale_mat
106 | P = P[:3, :4]
107 | intrinsic, c2w = load_K_Rt_from_P(None, P)
108 |
109 | c2w = torch.from_numpy(c2w).float()
110 | intrinsic = torch.from_numpy(intrinsic).float()
111 | intrinsic[:2] /= self.downsample
112 |
113 | self.poses.append(c2w)
114 | self.intrinsics.append(intrinsic)
115 |
116 | rays_o, rays_d = self.gen_rays_at(intrinsic,c2w)
117 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
118 |
119 | self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)
120 |
121 | # self.all_rgbs[~self.all_masks] = 1.0
122 | if not self.is_stack:
123 |             self.all_rays = torch.cat(self.all_rays, 0)  # (n_images*h*w, 6)
124 |             self.all_rgbs = self.all_rgbs.reshape(-1,3)
125 |         else:
126 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_images, h*w, 6)
127 |             self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1],3)  # (n_images, h, w, 3)
128 |
129 | self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)
130 |
131 | def __len__(self):
132 | return len(self.all_rays)
133 |
134 | def __getitem__(self, idx):
135 | idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,))
136 | sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
137 | return sample
--------------------------------------------------------------------------------
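`load_K_Rt_from_P` (shared by both DTU loaders) splits a 3x4 projection matrix P = K [R | -R·C] into intrinsics and a camera-to-world pose via `cv2.decomposeProjectionMatrix`. A synthetic round trip, assuming the standard pinhole convention, illustrates what it returns; the numbers here are arbitrary:

import numpy as np
import cv2 as cv

K = np.array([[500., 0., 200.], [0., 500., 150.], [0., 0., 1.]])
R = cv.Rodrigues(np.array([0.1, 0.2, 0.3]))[0]    # world-to-camera rotation
C = np.array([1.0, 2.0, 3.0])                     # camera centre in world space
P = K @ np.hstack([R, (-R @ C)[:, None]])         # 3x4 projection matrix

intrinsics, pose = load_K_Rt_from_P(None, P.astype(np.float32))
# intrinsics[:3, :3] should match K (normalised so that K[2, 2] == 1),
# pose[:3, :3] should be R.T (camera-to-world rotation) and pose[:3, 3] the centre C.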
/dataLoader/image.py:
--------------------------------------------------------------------------------
1 | import torch,imageio,cv2
2 | from PIL import Image
3 | Image.MAX_IMAGE_PIXELS = 1000000000
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from torch.utils.data import Dataset
7 |
8 | _img_suffix = ['png','jpg','jpeg','bmp','tif']
9 |
10 | def load(path):
11 | suffix = path.split('.')[-1]
12 | if suffix in _img_suffix:
13 | img = np.array(Image.open(path))#.convert('L')
14 | scale = 256.**(1+np.log2(np.max(img))//8)-1
15 | return img/scale
16 | elif 'exr' == suffix:
17 | return imageio.imread(path)
18 | elif 'npy' == suffix:
19 | return np.load(path)
20 |
21 |
22 | def srgb_to_linear(img):
23 | limit = 0.04045
24 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)
25 |
26 | class ImageDataset(Dataset):
27 | def __init__(self, cfg, batchsize, split='train', continue_sampling=False, tolinear=False, HW=-1, perscent=1.0, delete_region=None,mask=None):
28 |
29 | datadir = cfg.datadir
30 | self.batchsize = batchsize
31 | self.continue_sampling = continue_sampling
32 | img = load(datadir).astype(np.float32)
33 | if HW>0:
34 | img = cv2.resize(img,[HW,HW])
35 |
36 | if tolinear:
37 | img = srgb_to_linear(img)
38 | self.img = torch.from_numpy(img)
39 |
40 | H,W = self.img.shape[:2]
41 |
42 | y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
43 | self.coordiante = torch.stack((x,y),-1).float()+0.5
44 |
45 | n_channel = self.img.shape[-1]
46 | self.image = self.img
47 | self.img, self.coordiante = self.img.reshape(H*W,-1), self.coordiante.reshape(H*W,2)
48 |
49 | # if continue_sampling:
50 | # coordiante_tmp = self.coordiante.view(1,1,-1,2)/torch.tensor([W,H])*2-1.0
51 | # self.img = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordiante_tmp, mode='bilinear', align_corners=True).reshape(self.img.shape[-1],-1).t()
52 |
53 |
54 | if 'train'==split:
55 | self.mask = torch.ones_like(y)>0
56 | if mask is not None:
57 | self.mask = mask>0
58 | print(torch.sum(mask)/1.0/HW/HW)
59 | elif delete_region is not None:
60 |
61 | if isinstance(delete_region[0], list):
62 | for item in delete_region:
63 | t_l_x,t_l_y,width,height = item
64 | self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
65 | else:
66 | t_l_x,t_l_y,width,height = delete_region
67 | self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
68 | else:
69 | index = torch.randperm(len(self.img))[:int(len(self.img)*perscent)]
70 | self.mask[:] = False
71 | self.mask.view(-1)[index] = True
72 | self.mask = self.mask.view(-1)
73 | self.image, self.coordiante = self.img[self.mask], self.coordiante[self.mask]
74 | else:
75 | self.image = self.img
76 |
77 |
78 | self.HW = [H,W]
79 |
80 | self.scene_bbox = [[0., 0.], [W, H]]
81 | cfg.aabb = self.scene_bbox
82 | #
83 |
84 | def __len__(self):
85 | return 10000
86 |
87 | def __getitem__(self, idx):
88 | H,W = self.HW
89 | device = self.image.device
90 | idx = torch.randint(0,len(self.image),(self.batchsize,), device=device)
91 |
92 | if self.continue_sampling:
93 | coordinate = self.coordiante[idx] + torch.rand((self.batchsize,2))-0.5
94 | coordinate_tmp = (coordinate.view(1,1,self.batchsize,2))/torch.tensor([W,H],device=device)*2-1.0
95 | rgb = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear',
96 | align_corners=False, padding_mode='border').reshape(self.img.shape[-1],-1).t()
97 | sample = {'rgb': rgb,
98 | 'xy': coordinate}
99 | else:
100 | sample = {'rgb': self.image[idx],
101 | 'xy': self.coordiante[idx]}
102 |
103 | return sample
--------------------------------------------------------------------------------
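The `load` helper above normalises integer images by guessing the bit depth from the maximum pixel value: `256.**(1 + np.log2(max)//8) - 1` gives 255 for 8-bit content and 65535 for 16-bit content. A standalone illustration of the heuristic:

import numpy as np

def infer_scale(max_val):
    return 256.0 ** (1 + np.log2(max_val) // 8) - 1

print(infer_scale(200))     # 255.0    -> treated as an 8-bit image
print(infer_scale(40000))   # 65535.0  -> treated as a 16-bit image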
/dataLoader/image_set.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from torch.utils.data import Dataset
5 |
6 | def srgb_to_linear(img):
7 | limit = 0.04045
8 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92)
9 |
10 | def load(path, HW=512):
11 | suffix = path.split('.')[-1]
12 |
13 | if 'npy' == suffix:
14 | img = np.load(path)
15 | # img = 0.3*img[...,:1] + 0.59*img[...,1:2] + 0.11*img[...,2:]
16 |
17 |     if img.shape[-2]!=HW:
18 |         # build a new array: the spatial shape changes, so assigning back into img[i] would fail
19 |         img = np.stack([cv2.resize(im, (HW, HW)) for im in img])
20 |
21 | return img
22 |
23 |
24 | class ImageSetDataset(Dataset):
25 | def __init__(self, cfg, batchsize, split='train', continue_sampling=False, HW=512, N=10, tolinear=True):
26 |
27 | datadir = cfg.datadir
28 | self.batchsize = batchsize
29 | self.continue_sampling = continue_sampling
30 | imgs = load(datadir,HW=HW)[:N]
31 |
32 |
33 | self.imgs = torch.from_numpy(imgs).float()/255
34 | if tolinear:
35 | self.imgs = srgb_to_linear(self.imgs)
36 |
37 | D,H,W = self.imgs.shape[:3]
38 |
39 | y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
40 | self.coordinate = torch.stack((x,y),-1).float()+0.5
41 |
42 | self.imgs, self.coordinate = self.imgs.reshape(D,H*W,-1), self.coordinate.reshape(H*W,2)
43 | self.DHW = [D,H,W]
44 |
45 | self.scene_bbox = [[0., 0., 0.], [W, H, D]]
46 | cfg.aabb = self.scene_bbox
47 | # self.down_scale = 512.0/H
48 |
49 |
50 | def __len__(self):
51 | return 1000000
52 |
53 | def __getitem__(self, idx):
54 | D,H,W = self.DHW
55 | pix_idx = torch.randint(0,H*W,(self.batchsize,))
56 | img_idx = torch.randint(0,D,(self.batchsize,))
57 |
58 |
59 | if self.continue_sampling:
60 | coordinate = self.coordinate[pix_idx] + torch.rand((self.batchsize,2)) - 0.5
61 | coordinate = torch.cat((coordinate,img_idx.unsqueeze(-1)+0.5),dim=-1)
62 | coordinate_tmp = (coordinate.view(1,1,1,self.batchsize,3))/torch.tensor([W,H,D])*2-1.0
63 | rgb = F.grid_sample(self.imgs.view(1,D,H,W,-1).permute(0,4,1,2,3),coordinate_tmp, mode='bilinear',
64 | align_corners=False, padding_mode='border').reshape(self.imgs.shape[-1],-1).t()
65 | # coordinate[:,:2] *= self.down_scale
66 | sample = {'rgb': rgb,
67 | 'xy': coordinate}
68 | else:
69 | sample = {'rgb': self.imgs[img_idx,pix_idx],
70 | 'xy': torch.cat((self.coordinate[pix_idx],img_idx.unsqueeze(-1)+0.5),dim=-1)}
71 | # 'xy': torch.cat((self.coordiante[pix_idx],img_idx.expand_as(pix_idx).unsqueeze(-1)),dim=-1)}
72 |
73 |
74 |
75 | return sample
--------------------------------------------------------------------------------
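With `continue_sampling` enabled, the image loaders above jitter the pixel-centre coordinates and fetch colours through `F.grid_sample`, so pixel coordinates are mapped into the [-1, 1] range that grid_sample expects; with `align_corners=False` a normalised coordinate of 0 lands exactly between pixel centres. A tiny standalone check of that convention:

import torch
import torch.nn.functional as F

img = torch.arange(4.0).view(1, 1, 2, 2)   # 2x2 single-channel image: [[0, 1], [2, 3]]
grid = torch.zeros(1, 1, 1, 2)             # one query at normalised (x, y) = (0, 0)
val = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
print(val)                                 # tensor([[[[1.5]]]]) -> mean of the four pixels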
/dataLoader/llff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import glob
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | def normalize(v):
13 | """Normalize a vector."""
14 | return v / np.linalg.norm(v)
15 |
16 |
17 | def average_poses(poses):
18 | """
19 | Calculate the average pose, which is then used to center all poses
20 | using @center_poses. Its computation is as follows:
21 | 1. Compute the center: the average of pose centers.
22 | 2. Compute the z axis: the normalized average z axis.
23 | 3. Compute axis y': the average y axis.
24 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
25 | 5. Compute the y axis: z cross product x.
26 |
27 | Note that at step 3, we cannot directly use y' as y axis since it's
28 | not necessarily orthogonal to z axis. We need to pass from x to y.
29 | Inputs:
30 | poses: (N_images, 3, 4)
31 | Outputs:
32 | pose_avg: (3, 4) the average pose
33 | """
34 | # 1. Compute the center
35 | center = poses[..., 3].mean(0) # (3)
36 |
37 | # 2. Compute the z axis
38 | z = normalize(poses[..., 2].mean(0)) # (3)
39 |
40 | # 3. Compute axis y' (no need to normalize as it's not the final output)
41 | y_ = poses[..., 1].mean(0) # (3)
42 |
43 | # 4. Compute the x axis
44 | x = normalize(np.cross(z, y_)) # (3)
45 |
46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
47 | y = np.cross(x, z) # (3)
48 |
49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
50 |
51 | return pose_avg
52 |
53 |
54 | def center_poses(poses, blender2opencv):
55 | """
56 | Center the poses so that we can use NDC.
57 | See https://github.com/bmild/nerf/issues/34
58 | Inputs:
59 | poses: (N_images, 3, 4)
60 | Outputs:
61 | poses_centered: (N_images, 3, 4) the centered poses
62 | pose_avg: (3, 4) the average pose
63 | """
64 | poses = poses @ blender2opencv
65 | pose_avg = average_poses(poses) # (3, 4)
66 | pose_avg_homo = np.eye(4)
67 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
68 |     # by simply adding 0, 0, 0, 1 as the last row
70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
71 | poses_homo = \
72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
73 |
74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
75 | # poses_centered = poses_centered @ blender2opencv
76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
77 |
78 | return poses_centered, pose_avg_homo
79 |
80 |
81 | def viewmatrix(z, up, pos):
82 | vec2 = normalize(z)
83 | vec1_avg = up
84 | vec0 = normalize(np.cross(vec1_avg, vec2))
85 | vec1 = normalize(np.cross(vec2, vec0))
86 | m = np.eye(4)
87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
88 | return m
89 |
90 |
91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
92 | render_poses = []
93 | rads = np.array(list(rads) + [1.])
94 |
95 | for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
98 | render_poses.append(viewmatrix(z, up, c))
99 | return render_poses
100 |
101 |
102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
103 | # center pose
104 | c2w = average_poses(c2ws_all)
105 |
106 | # Get average pose
107 | up = normalize(c2ws_all[:, :3, 1].sum(0))
108 |
109 | # Find a reasonable "focus depth" for this dataset
110 | dt = 0.75
111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
113 |
114 | # Get radii for spiral path
115 | zdelta = near_fars.min() * .2
116 | tt = c2ws_all[:, :3, 3]
117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
119 | return np.stack(render_poses)
120 |
121 |
122 | class LLFFDataset(Dataset):
123 | def __init__(self, cfg , split='train', hold_every=8):
124 |         """
125 |         cfg: dataset config; expects cfg.datadir and a per-split downsample factor
126 |         split: 'train' or 'test'
127 |         hold_every: every hold_every-th image is held out for evaluation
128 |         """
129 |
130 | self.root_dir = cfg.datadir
131 | self.split = split
132 | self.hold_every = hold_every
133 | self.is_stack = False if 'train' == split else True
134 | self.downsample = cfg.get(f'downsample_{self.split}')
135 | self.define_transforms()
136 |
137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
138 | self.read_meta()
139 | self.white_bg = False
140 |
141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
142 | self.near_far = [0.0, 1.0]
143 | self.scene_bbox = [[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]
144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
145 | # self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
146 | # self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
147 |
148 | def read_meta(self):
149 |
150 | print(self.root_dir)
151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17)
152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
153 | # load full resolution image then resize
154 | if self.split in ['train', 'test']:
155 | assert len(poses_bounds) == len(self.image_paths), \
156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!'
157 |
158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2)
160 | hwf = poses[:, :, -1]
161 |
162 | # Step 1: rescale focal length according to training resolution
163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
166 |
167 | # Step 2: correct poses
168 | # Original poses has rotation in form "down right back", change to "right up back"
169 | # See https://github.com/bmild/nerf/issues/34
170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
171 | # (N_images, 3, 4) exclude H, W, focal
172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
173 |
174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
175 | # See https://github.com/bmild/nerf/issues/34
176 | near_original = self.near_fars.min()
177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
178 | # the nearest depth is at 1/0.75=1.33
179 | self.near_fars /= scale_factor
180 | self.poses[..., 3] /= scale_factor
181 |
182 | # build rendering path
183 | N_views, N_rots = 120, 2
184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
185 | up = normalize(self.poses[:, :3, 1].sum(0))
186 | rads = np.percentile(np.abs(tt), 90, 0)
187 |
188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
189 |
190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to
192 | # center image
193 |
194 | # ray directions for all pixels, same for all images (same H, W, focal)
195 | W, H = self.img_wh
196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3)
197 |
198 | average_pose = average_poses(self.poses)
199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)]
201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))
202 |
203 |         # every hold_every-th image is held out for evaluation; the rest are used for training
204 | self.all_rays = []
205 | self.all_rgbs = []
206 | for i in img_list:
207 | image_path = self.image_paths[i]
208 | c2w = torch.FloatTensor(self.poses[i])
209 |
210 | img = Image.open(image_path).convert('RGB')
211 | if self.downsample != 1.0:
212 | img = img.resize(self.img_wh, Image.LANCZOS)
213 | img = self.transform(img) # (3, h, w)
214 |
215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
216 | self.all_rgbs += [img]
217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
220 |
221 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
222 |
223 | if not self.is_stack:
224 |             self.all_rays = torch.cat(self.all_rays, 0)  # (n_images*h*w, 6)
225 |             self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_images*h*w, 3)
226 |         else:
227 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_images, h*w, 6)
228 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (n_images, h, w, 3)
229 |
230 |
231 | def define_transforms(self):
232 | self.transform = T.ToTensor()
233 |
234 | def __len__(self):
235 | return len(self.all_rgbs)
236 |
237 | def __getitem__(self, idx):
238 |
239 | sample = {'rays': self.all_rays[idx],
240 | 'rgbs': self.all_rgbs[idx]}
241 |
242 | return sample
--------------------------------------------------------------------------------
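The LLFF loader re-orders the pose axes from the "down right back" convention stored in `poses_bounds.npy` to the "right up back" convention used here, and drops the trailing (H, W, focal) column. A minimal illustration of that column shuffle on a single 3x5 pose block (placeholder values):

import numpy as np

pose = np.arange(15.0).reshape(3, 5)   # columns: [down, right, back, t, hwf]
converted = np.concatenate([pose[:, 1:2], -pose[:, :1], pose[:, 2:4]], -1)
# converted columns are now [right, up (= -down), back, t]; the hwf column is gone
print(converted.shape)                 # (3, 4)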
/dataLoader/nsvf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 | class NSVF(Dataset):
37 | """NSVF Generic Dataset."""
38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
39 | self.root_dir = datadir
40 | self.split = split
41 | self.is_stack = is_stack
42 | self.downsample = downsample
43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
44 | self.define_transforms()
45 |
46 | self.white_bg = True
47 | self.near_far = [0.5,6.0]
48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
50 | self.read_meta()
51 | self.define_proj_mat()
52 |
53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
55 |
56 | def bbox2corners(self):
57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
58 | for i in range(3):
59 | corners[i,[0,1],i] = corners[i,[1,0],i]
60 | return corners.view(-1,3)
61 |
62 |
63 | def read_meta(self):
64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
65 | focal = float(f.readline().split()[0])
66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
68 |
69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
71 |
72 | if self.split == 'train':
73 | pose_files = [x for x in pose_files if x.startswith('0_')]
74 | img_files = [x for x in img_files if x.startswith('0_')]
75 | elif self.split == 'val':
76 | pose_files = [x for x in pose_files if x.startswith('1_')]
77 | img_files = [x for x in img_files if x.startswith('1_')]
78 | elif self.split == 'test':
79 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
80 | test_img_files = [x for x in img_files if x.startswith('2_')]
81 | if len(test_pose_files) == 0:
82 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
83 | test_img_files = [x for x in img_files if x.startswith('1_')]
84 | pose_files = test_pose_files
85 | img_files = test_img_files
86 |
87 | # ray directions for all pixels, same for all images (same H, W, focal)
88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
90 |
91 | frames = 200
92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,frames+1)[:-1]], 0)
93 |
94 | self.poses = []
95 | self.all_rays = []
96 | self.all_rgbs = []
97 |
98 | assert len(img_files) == len(pose_files)
99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
101 | img = Image.open(image_path)
102 | if self.downsample!=1.0:
103 | img = img.resize(self.img_wh, Image.LANCZOS)
104 | img = self.transform(img) # (4, h, w)
105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
106 | if img.shape[-1]==4:
107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
108 | self.all_rgbs += [img]
109 |
110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
111 | c2w = torch.FloatTensor(c2w)
112 | self.poses.append(c2w) # C2W
113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
114 |             self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
115 |
116 | # w2c = torch.inverse(c2w)
117 | #
118 |
119 | self.poses = torch.stack(self.poses)
120 | if 'train' == self.split:
121 | if self.is_stack:
122 |                 self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (n_frames, h, w, 6)
123 |                 self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (n_frames, h, w, 3)
124 |             else:
125 |                 self.all_rays = torch.cat(self.all_rays, 0)  # (n_frames*h*w, 6)
126 |                 self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_frames*h*w, 3)
127 | else:
128 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_frames, h*w, 6)
129 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (n_frames, h, w, 3)
130 |
131 |
132 | def define_transforms(self):
133 | self.transform = T.ToTensor()
134 |
135 | def define_proj_mat(self):
136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
137 |
138 | def world2ndc(self, points):
139 | device = points.device
140 | return (points - self.center.to(device)) / self.radius.to(device)
141 |
142 | def __len__(self):
143 | if self.split == 'train':
144 | return len(self.all_rays)
145 | return len(self.all_rgbs)
146 |
147 | def __getitem__(self, idx):
148 |
149 | if self.split == 'train': # use data in the buffers
150 | sample = {'rays': self.all_rays[idx],
151 | 'rgbs': self.all_rgbs[idx]}
152 |
153 | else: # create data for each image separately
154 |
155 | img = self.all_rgbs[idx]
156 | rays = self.all_rays[idx]
157 |
158 | sample = {'rays': rays,
159 | 'rgbs': img}
160 | return sample
--------------------------------------------------------------------------------
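A minimal usage sketch for the NSVF loader above (the scene path below is a placeholder; it assumes the standard NSVF layout with `bbox.txt`, `intrinsics.txt`, `pose/` and `rgb/` under the scene directory):

# hedged sketch, not part of the repository
from dataLoader.nsvf import NSVF

train_set = NSVF('./data/Synthetic_NSVF/Bike', split='train', downsample=1.0, is_stack=False)
print(train_set.img_wh, train_set.near_far)    # image resolution and near/far bounds
sample = train_set[0]                          # unstacked train split: per-ray samples
rays, rgbs = sample['rays'], sample['rgbs']    # (6,) ray origin+direction, (3,) RGB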
/dataLoader/sdf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 | def N_to_reso(avg_reso, bbox):
6 | xyz_min, xyz_max = bbox
7 | dim = len(xyz_min)
8 | n_voxels = avg_reso**dim
9 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
10 | return torch.ceil((xyz_max - xyz_min) / voxel_size).long().tolist()
11 |
12 | def load(path, split, dtype='points'):
13 |
14 | if 'grid' == dtype:
15 | sdf = torch.from_numpy(np.load(path).astype(np.float32))
16 | D, H, W = sdf.shape
17 | z, y, x = torch.meshgrid(torch.arange(0, D), torch.arange(0, H), torch.arange(0, W), indexing='ij')
18 | coordiante = torch.stack((x,y,z),-1).reshape(D*H*W,3)#*2-1 # normalize to [-1,1]
19 | sdf = sdf.reshape(D*H*W,-1)
20 | DHW = [D,H,W]
21 | elif 'points' == dtype:
22 | DHW = [640] * 3
23 | sdf_dict = np.load(path, allow_pickle=True).item()
24 | sdf = torch.from_numpy(sdf_dict[f'sdfs_{split}'].astype(np.float32)).reshape(-1,1)
25 | coordiante = torch.from_numpy(sdf_dict[f'points_{split}'].astype(np.float32))
26 | aabb = [[-1,-1,-1],[1,1,1]]
27 | coordiante = ((coordiante + 1) / 2 * (torch.tensor(DHW[::-1]))).reshape(-1,3)
28 | DHW = DHW[::-1]
29 | return coordiante, sdf, DHW
30 |
31 | class SDFDataset(Dataset):
32 | def __init__(self, cfg, split='train'):
33 |
34 | datadir = cfg.datadir
35 | self.coordiante, self.sdf, self.DHW = load(datadir, split)
36 |
37 | [D, H, W] = self.DHW
38 |
39 | self.scene_bbox = [[0., 0., 0.], [W, H, D]]
40 | cfg.aabb = self.scene_bbox
41 |
42 | def __len__(self):
43 | return len(self.sdf)
44 |
45 | def __getitem__(self, idx):
46 | sample = {'rgb': self.sdf[idx],
47 | 'xy': self.coordiante[idx]}
48 |
49 | return sample
--------------------------------------------------------------------------------
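A worked example of `N_to_reso` above: it spreads an average per-axis resolution over a (possibly non-cubic) bounding box so that the total voxel count stays close to `avg_reso**dim` (the box below is hypothetical, purely for illustration):

# hedged sketch of the helper defined in dataLoader/sdf.py
import torch
from dataLoader.sdf import N_to_reso

bbox = torch.tensor([[0., 0., 0.], [2., 1., 1.]])   # 2 x 1 x 1 box
print(N_to_reso(128, bbox))                         # -> [204, 102, 102], i.e. roughly 128**3 voxels in total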
/dataLoader/tankstemple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 |
11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
12 | if axis == 'z':
13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
14 | elif axis == 'y':
15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
16 | else:
17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]
18 |
19 |
20 | def cross(x, y, axis=0):
21 | T = torch if isinstance(x, torch.Tensor) else np
22 | return T.cross(x, y, axis)
23 |
24 |
25 | def normalize(x, axis=-1, order=2):
26 | if isinstance(x, torch.Tensor):
27 | l2 = x.norm(p=order, dim=axis, keepdim=True)
28 | return x / (l2 + 1e-8), l2
29 |
30 | else:
31 | l2 = np.linalg.norm(x, order, axis)
32 | l2 = np.expand_dims(l2, axis)
33 | l2[l2 == 0] = 1
34 |         return x / l2, l2
35 |
36 |
37 | def cat(x, axis=1):
38 | if isinstance(x[0], torch.Tensor):
39 | return torch.cat(x, dim=axis)
40 | return np.concatenate(x, axis=axis)
41 |
42 |
43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
44 | """
45 | This function takes a vector 'camera_position' which specifies the location
46 | of the camera in world coordinates and two vectors `at` and `up` which
47 | indicate the position of the object and the up directions of the world
48 | coordinate system respectively. The object is assumed to be centered at
49 | the origin.
50 | The output is a rotation matrix representing the transformation
51 | from world coordinates -> view coordinates.
52 | Input:
53 | camera_position: 3
54 | at: 1 x 3 or N x 3 (0, 0, 0) in default
55 | up: 1 x 3 or N x 3 (0, 1, 0) in default
56 | """
57 |
58 | if at is None:
59 | at = torch.zeros_like(camera_position)
60 | else:
61 | at = torch.tensor(at).type_as(camera_position)
62 | if up is None:
63 | up = torch.zeros_like(camera_position)
64 | up[2] = -1
65 | else:
66 | up = torch.tensor(up).type_as(camera_position)
67 |
68 | z_axis = normalize(at - camera_position)[0]
69 | x_axis = normalize(cross(up, z_axis))[0]
70 | y_axis = normalize(cross(z_axis, x_axis))[0]
71 |
72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
73 | return R
74 |
75 |
76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
77 | c2ws = []
78 | for t in range(frames):
79 | c2w = torch.eye(4)
80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
83 | c2ws.append(c2w)
84 | return torch.stack(c2ws)
85 |
86 |
87 | class TanksTempleDataset(Dataset):
88 |     """Tanks and Temples dataset in NSVF-style layout."""
89 |
90 | def __init__(self, cfg, split='train'):
91 | self.root_dir = cfg.datadir
92 | self.split = split
93 | self.is_stack = False if 'train'==split else True
94 | self.downsample = cfg.get(f'downsample_{self.split}')
95 | self.img_wh = (int(1920 / self.downsample), int(1080 / self.downsample))
96 | self.define_transforms()
97 |
98 | self.white_bg = True
99 | self.near_far = [0.01, 6.0]
100 | self.scene_bbox = (torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3) * 1.2).tolist()
101 |
102 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
103 | self.read_meta()
104 | self.define_proj_mat()
105 |
106 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
107 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
108 |
109 | def bbox2corners(self):
110 | corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1)
111 | for i in range(3):
112 | corners[i, [0, 1], i] = corners[i, [1, 0], i]
113 | return corners.view(-1, 3)
114 |
115 | def read_meta(self):
116 |
117 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
118 | self.intrinsics[:2] *= (np.array(self.img_wh) / np.array([1920, 1080])).reshape(2, 1)
119 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
120 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
121 |
122 | if self.split == 'train':
123 | pose_files = [x for x in pose_files if x.startswith('0_')]
124 | img_files = [x for x in img_files if x.startswith('0_')]
125 | elif self.split == 'val':
126 | pose_files = [x for x in pose_files if x.startswith('1_')]
127 | img_files = [x for x in img_files if x.startswith('1_')]
128 | elif self.split == 'test':
129 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
130 | test_img_files = [x for x in img_files if x.startswith('2_')]
131 | if len(test_pose_files) == 0:
132 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
133 | test_img_files = [x for x in img_files if x.startswith('1_')]
134 | pose_files = test_pose_files
135 | img_files = test_img_files
136 |
137 | # ray directions for all pixels, same for all images (same H, W, focal)
138 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0],
139 | [self.intrinsics[0, 0], self.intrinsics[1, 1]],
140 | center=self.intrinsics[:2, 2]) # (h, w, 3)
141 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
142 |
143 | self.poses = []
144 |
145 | ray_per_frame = self.img_wh[0]*self.img_wh[1]
146 | self.all_rays = torch.empty(len(pose_files),ray_per_frame,6)
147 | self.all_rgbs = torch.empty(len(pose_files),ray_per_frame,3)
148 | assert len(img_files) == len(pose_files)
149 | for i in tqdm(range(len(pose_files)),desc=f'Loading data {self.split} ({len(img_files)})'):
150 |
151 | img_fname, pose_fname = img_files[i], pose_files[i]
152 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
153 | img = Image.open(image_path)
154 | if self.downsample != 1.0:
155 | img = img.resize(self.img_wh, Image.LANCZOS)
156 | img = self.transform(img) # (4, h, w)
157 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
158 | if img.shape[-1] == 4:
159 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
160 | # self.all_rgbs.append(img)
161 |
162 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) # @ cam_trans
163 | c2w = torch.FloatTensor(c2w)
164 | self.poses.append(c2w) # C2W
165 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
166 |
167 | self.all_rays[i] = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
168 | self.all_rgbs[i] = img
169 |
170 | self.poses = torch.stack(self.poses)
171 |
172 | frames = 200
173 | scene_bbox = torch.tensor(self.scene_bbox).float()
174 | center = torch.mean(scene_bbox, dim=0)
175 | radius = torch.norm(scene_bbox[1] - center) * 1.2
176 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
177 | pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
178 | self.render_path = gen_path(pos_gen, up=up, frames=frames)
179 | self.render_path[:, :3, 3] += center
180 |
181 | if 'train' == self.split:
182 | if not self.is_stack:
183 | # self.all_rays = torch.stack(self.all_rays, 0).reshape(-1, *self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3)
184 | # self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3)
185 | # else:
186 |                 self.all_rays = self.all_rays.reshape(-1,6)  # (n_frames*h*w, 6)
187 |                 self.all_rgbs = self.all_rgbs.reshape(-1,3)  # (n_frames*h*w, 3)
188 |         else:
189 |             # self.all_rays = torch.stack(self.all_rays, 0)  # (n_frames, h*w, 6)
190 |             self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (n_frames, h, w, 3)
191 |
192 | def define_transforms(self):
193 | self.transform = T.ToTensor()
194 |
195 | def define_proj_mat(self):
196 | self.proj_mat = torch.from_numpy(self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:, :3]
197 |
198 | def world2ndc(self, points):
199 | device = points.device
200 | return (points - self.center.to(device)) / self.radius.to(device)
201 |
202 | def __len__(self):
203 | if self.split == 'train':
204 | return len(self.all_rays)
205 | return len(self.all_rgbs)
206 |
207 | def __getitem__(self, idx):
208 |
209 | if self.split == 'train': # use data in the buffers
210 | sample = {'rays': self.all_rays[idx],
211 | 'rgbs': self.all_rgbs[idx]}
212 |
213 | else: # create data for each image separately
214 |
215 | img = self.all_rgbs[idx]
216 | rays = self.all_rays[idx]
217 |
218 | sample = {'rays': rays,
219 | 'rgbs': img}
220 | return sample
--------------------------------------------------------------------------------
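The `circle`/`look_at_rotation`/`gen_path` helpers above compose into a circular render path; a small sketch (radius, height and frame count are arbitrary):

# hedged sketch using the path helpers from dataLoader/tankstemple.py
import torch
from dataLoader.tankstemple import circle, gen_path

pos_gen = circle(radius=3.0, h=0.5, axis='y')                     # camera positions on a circle around the y axis
c2ws = gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=90)
print(c2ws.shape)                                                 # torch.Size([90, 4, 4]) camera-to-world matrices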
/dataLoader/your_own_data.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class YourOwnDataset(Dataset):
14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
15 |
16 | self.N_vis = N_vis
17 | self.root_dir = datadir
18 | self.split = split
19 | self.is_stack = is_stack
20 | self.downsample = downsample
21 | self.define_transforms()
22 |
23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
25 | self.read_meta()
26 | self.define_proj_mat()
27 |
28 | self.white_bg = True
29 | self.near_far = [0.1,100.0]
30 |
31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
33 | self.downsample=downsample
34 |
35 | def read_depth(self, filename):
36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
37 | return depth
38 |
39 | def read_meta(self):
40 |
41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
42 | self.meta = json.load(f)
43 |
44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
45 | self.img_wh = [w,h]
46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
48 | self.cx, self.cy = self.meta['cx'],self.meta['cy']
49 |
50 |
51 | # ray directions for all pixels, same for all images (same H, W, focal)
52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()
55 |
56 | self.image_paths = []
57 | self.poses = []
58 | self.all_rays = []
59 | self.all_rgbs = []
60 | self.all_masks = []
61 | self.all_depth = []
62 |
63 |
64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
67 |
68 | frame = self.meta['frames'][i]
69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
70 | c2w = torch.FloatTensor(pose)
71 | self.poses += [c2w]
72 |
73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
74 | self.image_paths += [image_path]
75 | img = Image.open(image_path)
76 |
77 | if self.downsample!=1.0:
78 | img = img.resize(self.img_wh, Image.LANCZOS)
79 | img = self.transform(img) # (4, h, w)
80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA
81 | if img.shape[-1]==4:
82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
83 | self.all_rgbs += [img]
84 |
85 |
86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
88 |
89 |
90 | self.poses = torch.stack(self.poses)
91 | if not self.is_stack:
92 |             self.all_rays = torch.cat(self.all_rays, 0)  # (n_frames*h*w, 6)
93 |             self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_frames*h*w, 3)
94 |
95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
96 | else:
97 |             self.all_rays = torch.stack(self.all_rays, 0)  # (n_frames, h*w, 6)
98 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (n_frames, h, w, 3)
99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
100 |
101 |
102 | def define_transforms(self):
103 | self.transform = T.ToTensor()
104 |
105 | def define_proj_mat(self):
106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
107 |
108 | def world2ndc(self,points,lindisp=None):
109 | device = points.device
110 | return (points - self.center.to(device)) / self.radius.to(device)
111 |
112 | def __len__(self):
113 | return len(self.all_rgbs)
114 |
115 | def __getitem__(self, idx):
116 |
117 | if self.split == 'train': # use data in the buffers
118 | sample = {'rays': self.all_rays[idx],
119 | 'rgbs': self.all_rgbs[idx]}
120 |
121 | else: # create data for each image separately
122 |
123 | img = self.all_rgbs[idx]
124 | rays = self.all_rays[idx]
125 |             # mask = self.all_masks[idx]  # masks are never loaded by this loader and are unused below
126 |
127 | sample = {'rays': rays,
128 | 'rgbs': img}
129 | return sample
130 |
--------------------------------------------------------------------------------
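For reference, `read_meta` above expects a `transforms_{split}.json` with roughly the following fields (the values here are placeholders; `file_path` is resolved relative to `datadir` and a `.png` suffix is appended):

# hedged sketch of the JSON structure consumed by YourOwnDataset.read_meta
meta = {
    "w": 800, "h": 800,
    "camera_angle_x": 0.69, "camera_angle_y": 0.69,
    "cx": 400.0, "cy": 400.0,
    "frames": [
        {"file_path": "train/r_0",
         "transform_matrix": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]},
    ],
}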
/media/Girl_with_a_Pearl_Earring.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/media/Girl_with_a_Pearl_Earring.jpg
--------------------------------------------------------------------------------
/media/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/media/inpainting.png
--------------------------------------------------------------------------------
/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/models/.DS_Store
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/models/__init__.py
--------------------------------------------------------------------------------
/models/sh.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ################## sh function ##################
4 | C0 = 0.28209479177387814
5 | C1 = 0.4886025119029199
6 | C2 = [
7 | 1.0925484305920792,
8 | -1.0925484305920792,
9 | 0.31539156525252005,
10 | -1.0925484305920792,
11 | 0.5462742152960396
12 | ]
13 | C3 = [
14 | -0.5900435899266435,
15 | 2.890611442640554,
16 | -0.4570457994644658,
17 | 0.3731763325901154,
18 | -0.4570457994644658,
19 | 1.445305721320277,
20 | -0.5900435899266435
21 | ]
22 | C4 = [
23 | 2.5033429417967046,
24 | -1.7701307697799304,
25 | 0.9461746957575601,
26 | -0.6690465435572892,
27 | 0.10578554691520431,
28 | -0.6690465435572892,
29 | 0.47308734787878004,
30 | -1.7701307697799304,
31 | 0.6258357354491761,
32 | ]
33 |
34 | def eval_sh(deg, sh, dirs):
35 | """
36 | Evaluate spherical harmonics at unit directions
37 | using hardcoded SH polynomials.
38 | Works with torch/np/jnp.
39 | ... Can be 0 or more batch dimensions.
40 | :param deg: int SH max degree. Currently, 0-4 supported
41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
42 | :param dirs: torch.Tensor unit directions (..., 3)
43 | :return: (..., C)
44 | """
45 | assert deg <= 4 and deg >= 0
46 | assert (deg + 1) ** 2 == sh.shape[-1]
47 | C = sh.shape[-2]
48 |
49 | result = C0 * sh[..., 0]
50 | if deg > 0:
51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
52 | result = (result -
53 | C1 * y * sh[..., 1] +
54 | C1 * z * sh[..., 2] -
55 | C1 * x * sh[..., 3])
56 | if deg > 1:
57 | xx, yy, zz = x * x, y * y, z * z
58 | xy, yz, xz = x * y, y * z, x * z
59 | result = (result +
60 | C2[0] * xy * sh[..., 4] +
61 | C2[1] * yz * sh[..., 5] +
62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
63 | C2[3] * xz * sh[..., 7] +
64 | C2[4] * (xx - yy) * sh[..., 8])
65 |
66 | if deg > 2:
67 | result = (result +
68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
69 | C3[1] * xy * z * sh[..., 10] +
70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
73 | C3[5] * z * (xx - yy) * sh[..., 14] +
74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
75 | if deg > 3:
76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
85 | return result
86 |
87 | def eval_sh_bases(deg, dirs):
88 | """
89 | Evaluate spherical harmonics bases at unit directions,
90 | without taking linear combination.
91 |     At each point, the final result can then be
92 |     obtained by multiplying with the SH coefficients and summing.
93 | :param deg: int SH max degree. Currently, 0-4 supported
94 | :param dirs: torch.Tensor (..., 3) unit directions
95 | :return: torch.Tensor (..., (deg+1) ** 2)
96 | """
97 | assert deg <= 4 and deg >= 0
98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
99 | result[..., 0] = C0
100 | if deg > 0:
101 | x, y, z = dirs.unbind(-1)
102 | result[..., 1] = -C1 * y;
103 | result[..., 2] = C1 * z;
104 | result[..., 3] = -C1 * x;
105 | if deg > 1:
106 | xx, yy, zz = x * x, y * y, z * z
107 | xy, yz, xz = x * y, y * z, x * z
108 | result[..., 4] = C2[0] * xy;
109 | result[..., 5] = C2[1] * yz;
110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy);
111 | result[..., 7] = C2[3] * xz;
112 | result[..., 8] = C2[4] * (xx - yy);
113 |
114 | if deg > 2:
115 | result[..., 9] = C3[0] * y * (3 * xx - yy);
116 | result[..., 10] = C3[1] * xy * z;
117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy);
118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy);
119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy);
120 | result[..., 14] = C3[5] * z * (xx - yy);
121 | result[..., 15] = C3[6] * x * (xx - 3 * yy);
122 |
123 | if deg > 3:
124 | result[..., 16] = C4[0] * xy * (xx - yy);
125 | result[..., 17] = C4[1] * yz * (3 * xx - yy);
126 | result[..., 18] = C4[2] * xy * (7 * zz - 1);
127 | result[..., 19] = C4[3] * yz * (7 * zz - 3);
128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3);
129 | result[..., 21] = C4[5] * xz * (7 * zz - 3);
130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1);
131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy);
132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy));
133 | return result
134 |
--------------------------------------------------------------------------------
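A small consistency check relating the two functions above: contracting `eval_sh_bases` with a set of coefficients reproduces `eval_sh` (the coefficients and directions below are random, purely for illustration):

# hedged sketch: eval_sh vs. an explicit basis contraction
import torch
from models.sh import eval_sh, eval_sh_bases

deg = 2
dirs = torch.nn.functional.normalize(torch.randn(5, 3), dim=-1)  # unit directions (5, 3)
sh = torch.randn(5, 3, (deg + 1) ** 2)                           # coefficients (..., C, 9)
direct = eval_sh(deg, sh, dirs)                                  # (5, 3)
via_bases = (eval_sh_bases(deg, dirs)[:, None, :] * sh).sum(-1)  # (5, 3)
print(torch.allclose(direct, via_bases, atol=1e-5))              # True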
/renderer.py:
--------------------------------------------------------------------------------
1 | import torch,os,imageio,sys
2 | from tqdm.auto import tqdm
3 | from dataLoader.ray_utils import get_rays
4 | from utils import *
5 | from dataLoader.ray_utils import ndc_rays_blender
6 |
7 |
8 | def render_ray(rays, factor_fields, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'):
9 |
10 | rgbs, alphas, depth_maps, weights, coeffs = [], [], [], [], []
11 | N_rays_all = rays.shape[0]
12 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
13 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
14 |
15 | if is_train:
16 | rgb_map, depth_map, coeff = factor_fields(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
17 | coeffs.append(coeff)
18 | else:
19 | rgb_map, depth_map, _ = factor_fields(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
20 |
21 | rgbs.append(rgb_map)
22 | depth_maps.append(depth_map)
23 |
24 | if is_train:
25 | return torch.cat(rgbs), torch.cat(depth_maps), torch.cat(coeffs)
26 | else:
27 | return torch.cat(rgbs), torch.cat(depth_maps)
28 |
29 | @torch.no_grad()
30 | def evaluation(test_dataset,factor_fields, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
31 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
32 | PSNRs, rgb_maps, depth_maps = [], [], []
33 | ssims,l_alex,l_vgg=[],[],[]
34 | os.makedirs(savePath, exist_ok=True)
35 | os.makedirs(savePath+"/rgbd", exist_ok=True)
36 |
37 | try:
38 | tqdm._instances.clear()
39 | except Exception:
40 | pass
41 |
42 | torch.cuda.empty_cache()
43 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1)
44 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval))
45 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
46 |
47 | W, H = test_dataset.img_wh
48 | rays = samples.view(-1,samples.shape[-1])
49 |
50 | rgb_map, depth_map = renderer(rays, factor_fields, chunk=1024, N_samples=N_samples,
51 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
52 | rgb_map = rgb_map.clamp(0.0, 1.0)
53 |
54 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
55 |
56 | depth_map, _ = visualize_depth_numpy(depth_map.numpy())
57 | if len(test_dataset.all_rgbs):
58 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3)
59 | loss = torch.mean((rgb_map - gt_rgb) ** 2)
60 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
61 |
62 | if compute_extra_metrics:
63 | ssim = rgb_ssim(rgb_map, gt_rgb, 1)
64 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', factor_fields.device)
65 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', factor_fields.device)
66 | ssims.append(ssim)
67 | l_alex.append(l_a)
68 | l_vgg.append(l_v)
69 |
70 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
71 | gt_rgb = (gt_rgb.numpy() * 255).astype('uint8')
72 |
73 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
74 | rgb_maps.append(rgb_map)
75 | depth_maps.append(depth_map)
76 | if savePath is not None:
77 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
78 | rgb_map = np.concatenate((rgb_map, gt_rgb, depth_map), axis=1)
79 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
80 |
81 | torch.cuda.empty_cache()
82 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
83 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)
84 |
85 | n_params = factor_fields.n_parameters()
86 | if PSNRs:
87 | psnr = np.mean(np.asarray(PSNRs))
88 | if compute_extra_metrics:
89 | ssim = np.mean(np.asarray(ssims))
90 | l_a = np.mean(np.asarray(l_alex))
91 | l_v = np.mean(np.asarray(l_vgg))
92 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v, n_params]))
93 |             print(f"PSNR={psnr} SSIM={ssim} LPIPS(alex)={l_a} LPIPS(vgg)={l_v}")
94 | else:
95 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr,n_params]))
96 |
97 |
98 | return PSNRs
99 |
100 | @torch.no_grad()
101 | def evaluation_path(test_dataset,factor_fields, c2ws, renderer, savePath=None, prtx='', N_samples=-1,
102 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
103 | PSNRs, rgb_maps, depth_maps = [], [], []
104 | ssims,l_alex,l_vgg=[],[],[]
105 | os.makedirs(savePath, exist_ok=True)
106 | os.makedirs(savePath+"/rgbd", exist_ok=True)
107 |
108 | try:
109 | tqdm._instances.clear()
110 | except Exception:
111 | pass
112 |
113 | near_far = test_dataset.near_far
114 | for idx, c2w in tqdm(enumerate(c2ws)):
115 |
116 | W, H = test_dataset.img_wh
117 |
118 | c2w = torch.FloatTensor(c2w)
119 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
120 | if ndc_ray:
121 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
122 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
123 |
124 | rgb_map, depth_map = renderer(rays, factor_fields, chunk=8192, N_samples=N_samples,
125 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
126 | rgb_map = rgb_map.clamp(0.0, 1.0)
127 |
128 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
129 |
130 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
131 |
132 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
133 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
134 | rgb_maps.append(rgb_map)
135 | depth_maps.append(depth_map)
136 | if savePath is not None:
137 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
138 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
139 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
140 |
141 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
142 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)
143 |
144 | if PSNRs:
145 | psnr = np.mean(np.asarray(PSNRs))
146 | if compute_extra_metrics:
147 | ssim = np.mean(np.asarray(ssims))
148 | l_a = np.mean(np.asarray(l_alex))
149 | l_v = np.mean(np.asarray(l_vgg))
150 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
151 | else:
152 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
153 |
154 |
155 | return PSNRs
156 |
157 |
--------------------------------------------------------------------------------
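`render_ray` above simply chunks the ray batch, runs the field on each chunk, and concatenates the per-chunk RGB and depth maps; a usage sketch for a single stacked test view (it assumes a trained `model` and a `test_dataset` built with `is_stack=True`, both hypothetical here):

# hedged sketch of calling render_ray on one test view
import torch
from renderer import render_ray

W, H = test_dataset.img_wh
rays = test_dataset.all_rays[0].view(-1, 6)          # rays of the first test view, (h*w, 6)
with torch.no_grad():
    rgb, depth = render_ray(rays, model, chunk=4096, white_bg=test_dataset.white_bg)
rgb = rgb.clamp(0.0, 1.0).reshape(H, W, 3).cpu()     # back to image shape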
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio==2.26.0
2 | kornia==0.6.10
3 | lpips==0.1.4
4 | matplotlib==3.7.1
5 | omegaconf==2.3.0
6 | opencv_python==4.7.0.72
7 | plyfile==0.7.4
8 | scikit-image
9 | tqdm==4.65.0
10 | trimesh==3.20.2
11 | jupyterlab
--------------------------------------------------------------------------------
/run_batch.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import threading, queue
4 | import numpy as np
5 | import time
6 |
7 |
8 | if __name__ == '__main__':
9 |
10 | ################ per scene NeRF ################
11 | commands = {
12 | '-grid': '', \
13 | # '-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \
14 | # '-noC': 'model.coeff_type=none', \
15 | # '-SL':'model.basis_dims=[18] model.basis_resos=[70] model.freq_bands=[8.]', \
16 | # '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
17 | # '-iNGP-like': 'model.basis_type=hash model.coeff_type=none', \
18 | # '-hash': f'model.basis_type=hash model.coef_init=1.0', \
19 | # '-sinc': f'model.basis_mapping=sinc', \
20 | # '-tria': f'model.basis_mapping=triangle', \
21 | # '-vm': f'model.coeff_type=vm model.basis_type=vm', \
22 | # '-mlpB': 'model.basis_type=mlp', \
23 | # '-mlpC': 'model.coeff_type=mlp', \
24 | # '-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \
25 | # '-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \
26 | # f'model.num_layers=8 model.hidden_dim=256 ' \
27 | # f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \
28 | # '-iNGP-like-sl': 'model.basis_type=hash model.coeff_type=none model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
29 | # '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
30 | # '-DCT':'model.basis_type=fix-grid', \
31 | }
32 |
33 | ################ per scene NeRF ################
34 |     ###### uncomment the commented lines below to train on all NeRF-synthetic scenes #########
35 | cmds = []
36 | for name in commands.keys(): #
37 | # for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:#
38 | # cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {commands[name]}'
39 | # cmds.append(cmd)
40 |
41 | for scene in ['Ignatius','Truck']:#
42 | if scene != 'Ignatius':
43 | cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/TanksAndTemple/{scene} {commands[name]} ' \
44 | f' dataset.dataset_name=tankstemple '
45 | cmds.append(cmd)
46 |
47 | cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/TanksAndTemple/{scene} {commands[name]} ' \
48 | f' dataset.dataset_name=tankstemple exportation.render_only=1 exportation.render_path=1 exportation.render_test=0 ' \
49 | f' defaults.ckpt=/mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/{scene}-grid/{scene}-grid.th '
50 | cmds.append(cmd)
51 |
52 | ################ generalization NeRF ################
53 | commands = {
54 | # '-grid': '', \
55 | # '-DVGO-like': 'model.basis_type=none model.coeff_reso=48',
56 | # '-SL':'model.basis_dims=[72] model.basis_resos=[48] model.freq_bands=[6.]', \
57 | # '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
58 | # '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
59 | # '-sinc': f'model.basis_mapping=sinc', \
60 | # '-tria': f'model.basis_mapping=triangle', \
61 | # '-vm': f'model.coeff_type=vm model.basis_type=vm', \
62 | # '-mlpB': 'model.basis_type=mlp', \
63 | # '-mlpC': 'model.coeff_type=mlp', \
64 | # '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
65 | # '-DCT':'model.basis_type=fix-grid', \
66 | }
67 | # for name in commands.keys(): #
68 | # config = commands[name]
69 | # config = f'python train_across_scene2.py configs/nerf_set.yaml defaults.expname=google-obj{name} {config} ' \
70 | # f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/'
71 | # cmds.append(config)
72 |
73 |
74 | # # =========> fine tuning <================
75 | # views = 5
76 | # for name in commands.keys(): #
77 | # for scene in [183,199,298,467,957,244,963,527]:#
78 | # cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
79 | # f'dataset.datadir=/home/anpei/Dataset/google_scanned_objects/ {commands[name]} ' \
80 | # f'dataset.train_views={views} ' \
81 | # f'dataset.train_scene_list=[{scene}] ' \
82 | # f'dataset.test_scene_list=[{scene}] ' \
83 | # f'defaults.ckpt=/home/anpei/Code/NeuBasis/log/google-obj{name}//google-obj{name}.th '
84 | # cmds.append(cmd)
85 |
86 | # for scene in ['Ignatius','Barn','Truck','Family','Caterpillar']:#'Ignatius','Barn','Truck','Family','Caterpillar'
87 | # cmds.append(f'python train_basis.py configs/tnt.yaml defaults.expname=tnt_{scene} ' \
88 | # f'dataset.datadir=./data/TanksAndTemple/{scene}'
89 | # )
90 |
91 | # cmds = []
92 | # for scene in ['room']:#,'hall','kitchen','living_room','room2','sofa','meeting_room','room','salon2'
93 | # cmds.append(f'python train_basis.py configs/colmap_new.yaml defaults.expname=indoor_{scene} ' \
94 | # f'dataset.datadir=./data/indoor/real/{scene}'
95 | # # f'defaults.ckpt=/home/anpei/code/NeuBasis2/log/basis_ship/basis_ship.th exporation.render_only=True'
96 | # )
97 |
98 | # cmds = []
99 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeRF model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \
100 | # f'model.num_layers=8 model.hidden_dim=256 ' \
101 | # f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]')
102 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-grid')
103 | # cmds.append(f'python 2D_regression.py defaults.expname=NeuBasis-mlpB model.basis_type=mlp')
104 | # cmds.append(f'python 2D_regression.py defaults.expname=NeuBasis-mlpC model.coeff_type=mlp')
105 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=DVGO-like model.basis_type=none')
106 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-noC model.coeff_type=none')
107 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-sinc model.basis_mapping=sinc')
108 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-tria model.basis_mapping=triangle')
109 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-SL model.basis_dims=[144] model.basis_resos=[14] model.freq_bands=[73.14]')
110 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-DCT model.basis_type=fix-grid')
111 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-CP model.coeff_type=vec model.basis_type=cp \
112 | # model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[1024,1024,1024,1024,1024,1024] model.basis_dims=[64,64,64,32,32,32]')
113 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=iNGP-like model.basis_type=hash model.coeff_type=none')
114 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-hash model.basis_type=hash model.coef_init=0.1 basis_dims=[16,16,16,16,16,16]')
115 |
116 |     # set the available GPUs
117 | gpu_idx = [0]
118 | gpus_que = queue.Queue(len(gpu_idx))
119 | for i in gpu_idx:
120 | gpus_que.put(i)
121 |
122 | # os.makedirs(f"log/{expFolder}", exist_ok=True)
123 | def run_program(gpu, cmd):
124 | cmd = f'{cmd} '
125 | print(cmd)
126 | os.system(cmd)
127 | gpus_que.put(gpu)
128 |
129 |
130 | ths = []
131 | for i in range(len(cmds)):
132 |
133 | gpu = gpus_que.get()
134 | t = threading.Thread(target=run_program, args=(gpu, cmds[i]), daemon=True)
135 | t.start()
136 | ths.append(t)
137 |
138 | for th in ths:
139 | th.join()
140 |
141 |
142 | # import os
143 | # import numpy as np
144 | # root = f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/'
145 | # # root = '/cluster/home/anchen/root/Code/NeuBasis/log/'
146 | # scores = []
147 | # # for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
148 | # for scene in ['Caterpillar','Family','Ignatius','Truck']:
149 | # scores.append(np.loadtxt(f'{root}/{scene}-grid/imgs_test_all/mean.txt'))
150 | # # os.system(f'cp {root}/{scene}-grid/imgs_test_all/video.mp4 /mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/video/{scene}.mp4')
151 | # os.system(f'cp {root}/{scene}-grid/{scene}-grid/imgs_path_all/video.mp4 /mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/video/{scene}.mp4')
152 | # # print(np.mean(np.stack(scores),axis=0))
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/2D_set_regression.py:
--------------------------------------------------------------------------------
1 | import torch,imageio,sys,cmapy,time,os
2 | import numpy as np
3 | from tqdm import tqdm
4 | # from .autonotebook import tqdm as tqdm
5 | import matplotlib.pyplot as plt
6 | from omegaconf import OmegaConf
7 | import torch.nn.functional as F
8 |
9 | sys.path.append('..')
10 | from models.FactorFields import FactorFields
11 |
12 | from utils import SimpleSampler,TVLoss
13 | from dataLoader import dataset_dict
14 | from torch.utils.data import DataLoader
15 |
16 | device = 'cuda'
17 |
18 |
19 | def PSNR(a, b):
20 | if type(a).__module__ == np.__name__:
21 | mse = np.mean((a - b) ** 2)
22 | else:
23 | mse = torch.mean((a - b) ** 2).item()
24 | psnr = -10.0 * np.log(mse) / np.log(10.0)
25 | return psnr
26 |
27 |
28 | @torch.no_grad()
29 | def eval_img(aabb, reso, idx, shiftment=[0.5, 0.5, 0.5], chunk=10240):
30 | y = torch.linspace(0, aabb[0] - 1, reso[0])
31 | x = torch.linspace(0, aabb[1] - 1, reso[1])
32 | yy, xx = torch.meshgrid((y, x), indexing='ij')
33 | zz = torch.ones_like(xx) * idx
34 |
35 | idx = 0
36 | res = torch.empty(reso[0] * reso[1], train_dataset.imgs.shape[-1])
37 | coordiantes = torch.stack((xx, yy, zz), dim=-1).reshape(-1, 3) + torch.tensor(
38 | shiftment) # /(torch.FloatTensor(reso[::-1])-1)*2-1
39 | for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)):
40 | feats, _ = model.get_coding_imgage_set(coordiante.to(model.device))
41 | y_recon = model.linear_mat(feats, is_train=False)
42 |
43 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu()
44 | idx += y_recon.shape[0]
45 | return res.view(reso[0], reso[1], -1), coordiantes
46 |
47 |
48 | @torch.no_grad()
49 | def eval_img_single(aabb, reso, chunk=10240):
50 | y = torch.linspace(0, aabb[0] - 1, reso[0])
51 | x = torch.linspace(0, aabb[1] - 1, reso[1])
52 | yy, xx = torch.meshgrid((y, x), indexing='ij')
53 |
54 | idx = 0
55 | res = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1])
56 | coordiantes = torch.stack((xx, yy), dim=-1).reshape(-1, 2) + 0.5
57 |
58 | for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)):
59 | feats, _ = model.get_coding(coordiante.to(model.device))
60 | y_recon = model.linear_mat(feats, is_train=False)
61 |
62 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu()
63 | idx += y_recon.shape[0]
64 | return res.view(reso[0], reso[1], -1), coordiantes
65 |
66 |
67 | def linear_to_srgb(img):
68 | limit = 0.0031308
69 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)
70 |
71 |
72 | def srgb_to_linear(img):
73 | limit = 0.04045
74 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92)
75 |
76 |
77 | def write_image_imageio(img_file, img, colormap=None, quality=100):
78 | if colormap == 'turbo':
79 | shape = img.shape
80 | img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1)
81 | elif colormap is not None:
82 | img = cmapy.colorize((img * 255).astype('uint8'), colormap)
83 |
84 | if img.dtype != 'uint8':
85 | img = (img - np.min(img)) / (np.max(img) - np.min(img))
86 | img = (img * 255.0).astype(np.uint8)
87 |
88 | kwargs = {}
89 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
90 | if img.ndim >= 3 and img.shape[2] > 3:
91 | img = img[:, :, :3]
92 | kwargs["quality"] = quality
93 | kwargs["subsampling"] = 0
94 | imageio.imwrite(img_file, img, **kwargs)
95 |
96 |
97 | def interpolate(colormap, x):
98 | a = (x * 255.0).astype('uint8')
99 | b = np.clip(a + 1, 0, 255)
100 | f = x * 255.0 - a
101 |
102 | return np.stack([colormap[a][..., 0] + (colormap[b][..., 0] - colormap[a][..., 0]) * f,
103 | colormap[a][..., 1] + (colormap[b][..., 1] - colormap[a][..., 1]) * f,
104 | colormap[a][..., 2] + (colormap[b][..., 2] - colormap[a][..., 2]) * f], axis=-1)
105 |
106 | base_conf = OmegaConf.load('../configs/defaults.yaml')
107 | second_conf = OmegaConf.load('../configs/image_set.yaml')
108 | cfg = OmegaConf.merge(base_conf, second_conf)
109 |
110 | dataset = dataset_dict[cfg.dataset.dataset_name]
111 | train_dataset = dataset(cfg.dataset,cfg.training.batch_size, split='train',N=600,tolinear=True,HW=512, continue_sampling=True)
112 | train_loader = DataLoader(train_dataset,
113 | num_workers=8,
114 | persistent_workers=True,
115 | batch_size=None,
116 | pin_memory=True)
117 |
118 |
119 | batch_size = cfg.training.batch_size
120 | n_iter = cfg.training.n_iters
121 |
122 | model = FactorFields(cfg, device)
123 | tvreg = TVLoss()
124 | print(model)
125 | print('total parameters: ', model.n_parameters())
126 |
127 | grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small, lr_large=cfg.training.lr_large)
128 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) #
129 |
130 | tv_loss = 0
131 | loss_scale = 1.0
132 | lr_factor = 0.1 ** (1 / n_iter)
133 | pbar = tqdm(range(n_iter))
134 | for (iteration, sample) in zip(pbar, train_loader):
135 | loss_scale *= lr_factor
136 |
137 | coordiantes, pixel_rgb = sample['xy'], sample['rgb']
138 |
139 | basis, coeff = model.get_coding_imgage_set(coordiantes.to(device))
140 |
141 | y_recon = model.linear_mat(basis, is_train=True)
142 | # y_recon = torch.sum(basis,dim=-1,keepdim=True)
143 | l2_loss = torch.mean((y_recon.squeeze() - pixel_rgb.squeeze().to(device)) ** 2) # + 4e-3*coeff.abs().mean()
144 |
145 | # tv_loss = model.TV_loss(tvreg)
146 | loss = l2_loss # + tv_loss*10
147 |
148 | # loss = loss * loss_scale
149 | optimizer.zero_grad()
150 | loss.backward()
151 | optimizer.step()
152 | # if iteration%100==0:
153 | # model.normalize_basis()
154 |
155 | psnr = -10.0 * np.log(l2_loss.item()) / np.log(10.0)
156 | if iteration % 100 == 0:
157 | pbar.set_description(
158 | f'Iteration {iteration:05d}:'
159 | + f' loss_dist = {l2_loss.item():.8f}'
160 | + f' tv_loss = {tv_loss:.6f}'
161 | + f' psnr = {psnr:.3f}'
162 | )
163 |
164 | save_root = '../log/imageSet/ffhq_mlp_coeff_16_64_8_pe_linear_64_node_800/'
165 | os.makedirs(save_root, exist_ok=True)
166 | model.save(f'{save_root}/ckpt.th')
--------------------------------------------------------------------------------
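The `PSNR` helper at the top of the script is just `-10 * log10(MSE)`; a quick sanity check of the scale (the arrays below are synthetic):

# hedged sketch: an MSE of 1e-2 corresponds to 20 dB, 1e-4 to 40 dB
import numpy as np

a = np.zeros((4, 4))
b = np.full((4, 4), 0.1)                     # constant error of 0.1 -> MSE = 1e-2
mse = np.mean((a - b) ** 2)
print(-10.0 * np.log(mse) / np.log(10.0))    # 20.0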
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/mesh2SDF_data_process.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "47120857-b2aa-4a36-8733-e04776ca7a80",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os,sys,trimesh\n",
11 | "import numpy as np\n",
12 | "sys.path.append('/home/anpei/Code/nglod/sdf-net')\n",
13 | "\n",
14 | "os.environ['PYOPENGL_PLATFORM'] = 'egl'\n",
15 | "os.environ['MESA_GL_VERSION_OVERRIDE'] = '3.3'\n",
16 | "os.environ['MESA_GLSL_VERSION_OVERRIDE'] = '330'\n",
17 | "\n",
18 | "import torch\n",
19 |     "from lib.torchgp import load_obj, point_sample, sample_surface, compute_sdf, normalize\nfrom pysdf import SDF  # SDF(vertices, faces) used below; assumed to come from the pysdf package"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "id": "0c8bf6fa-7295-40cb-aaaa-3cfbf95350a6",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "def load_obj(path):\n",
30 | " mesh = trimesh.load(path)\n",
31 | " mesh.vertices = normalize(mesh.vertices)\n",
32 | " return mesh\n",
33 | " \n",
34 | "def normalize(V):\n",
35 | "\n",
36 | " # Normalize mesh\n",
37 | " V_max = np.max(V, axis=0)\n",
38 | " V_min = np.min(V, axis=0)\n",
39 | " V_center = (V_max + V_min) / 2.\n",
40 | " V = V - V_center\n",
41 | "\n",
42 | " # Find the max distance to origin\n",
43 | " max_dist = np.sqrt(np.max(np.sum(V**2, axis=-1)))\n",
44 | " V_scale = 1. / max_dist\n",
45 | " V *= V_scale\n",
46 | " return V\n",
47 | "\n",
48 | "\n",
49 | "def resample(V,F,num_samples, chunk=10000, sample_mode=['rand', 'near', 'near', 'near', 'near']):\n",
50 | " \"\"\"Resample SDF samples.\"\"\"\n",
51 | "\n",
52 | " points, sdfs = [],[]\n",
53 | " for _ in range(num_samples//chunk):\n",
54 | "\n",
55 | " pts = point_sample(V, F, sample_mode, chunk)\n",
56 | " sdf = compute_sdf(V, F, pts.cuda()) \n",
57 | " points.append(pts.cpu())\n",
58 | " sdfs.append(sdf.cpu())\n",
59 | " return torch.cat(points), torch.cat(sdfs)"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 3,
65 | "id": "b74c1144-f228-473e-abe3-059b9ae19d9c",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "for item in ['statuette.ply','dragon.ply','armadillo.obj','lucy.ply']:#'statuette.ply','dragon.ply','armadillo.obj'\n",
70 | " dataset_path = f'/home/anpei/Dataset/mesh/obj/{item}'\n",
71 | " mesh = load_obj(dataset_path)\n",
72 | " \n",
73 | " f = SDF(mesh.vertices, mesh.faces);\n",
74 | "\n",
75 | " V = torch.from_numpy(mesh.vertices).float().cuda()\n",
76 | " F = torch.from_numpy(mesh.faces).cuda()\n",
77 | " \n",
78 | " pts_train, _ = resample(V, F, num_samples=int(8*1024*1024/5))\n",
79 | " sdf_train = f(pts_train.numpy())\n",
80 | " \n",
81 | " pts_test, _ = resample(V, F, num_samples=int(16*1024*1024//5))\n",
82 | " sdf_test = f(pts_test.numpy())\n",
83 | "\n",
84 | " np.save(f'/home/anpei/Dataset/mesh/obj/sdf/{item[:-4]}_8M',{'points_train':pts_train.numpy(),'sdfs_train':sdf_train, \\\n",
85 | " 'points_test':pts_test.numpy(),'sdfs_test':sdf_test})\n"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "dafbf526-e79a-47a9-bf7b-6da70497c3af",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": []
95 | }
96 | ],
97 | "metadata": {
98 | "kernelspec": {
99 | "display_name": "Python 3 (ipykernel)",
100 | "language": "python",
101 | "name": "python3"
102 | },
103 | "language_info": {
104 | "codemirror_mode": {
105 | "name": "ipython",
106 | "version": 3
107 | },
108 | "file_extension": ".py",
109 | "mimetype": "text/x-python",
110 | "name": "python",
111 | "nbconvert_exporter": "python",
112 | "pygments_lexer": "ipython3",
113 | "version": "3.9.7"
114 | }
115 | },
116 | "nbformat": 4,
117 | "nbformat_minor": 5
118 | }
119 |
--------------------------------------------------------------------------------
/train_across_scene.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | from omegaconf import OmegaConf
3 | from models.FactorFields import FactorFields
4 |
5 | import json, random,time
6 | from renderer import *
7 | from utils import *
8 | from torch.utils.tensorboard import SummaryWriter
9 | import datetime
10 | from torch.utils.data import DataLoader
11 |
12 | from dataLoader import dataset_dict
13 | import sys
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 | class SimpleSampler:
18 | def __init__(self, total, batch):
19 | self.total = total
20 | self.batch = batch
21 | self.curr = total
22 | self.ids = None
23 |
24 | def nextids(self):
25 | self.curr += self.batch
26 | if self.curr + self.batch > self.total:
27 | self.ids = torch.LongTensor(np.random.permutation(self.total))
28 | self.curr = 0
29 | return self.ids[self.curr:self.curr + self.batch]
30 |
31 |
32 | @torch.no_grad()
33 | def export_mesh(cfg):
34 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
35 | model = FactorFields( ckpt['cfg'], device)
36 | model.load(ckpt)
37 |
38 | alpha, _ = model.getDenseAlpha([512]*3)
39 | convert_sdf_samples_to_ply(alpha.cpu(), f'{cfg.defaults.ckpt[:-3]}.ply', bbox=model.aabb.cpu(), level=0.2)
40 |
41 | # @torch.no_grad()
42 | # def export_mesh(cfg, downsample=1, n_views=100):
43 | # cfg.dataset.downsample_train = downsample
44 | # dataset = dataset_dict[cfg.dataset.dataset_name]
45 | # train_dataset = dataset(cfg.dataset, split='train',is_stack=True)
46 | #
47 | # ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
48 | # model = FactorFields( ckpt['cfg'], device)
49 | # model.load(ckpt)
50 | #
51 | # output_dir = f'{cfg.defaults.ckpt[:-3]}.ply'
52 | # export_tsdf_mesh(model, train_dataset, render_ray, white_bg=train_dataset.white_bg, output_dir=output_dir, n_views=n_views)
53 |
54 | @torch.no_grad()
55 | def render_test(cfg):
56 | # init dataset
57 | dataset = dataset_dict[cfg.dataset.dataset_name]
58 | test_dataset = dataset(cfg.dataset, split='test')
59 | white_bg = test_dataset.white_bg
60 | ndc_ray = cfg.dataset.ndc_ray
61 |
62 | if not os.path.exists(cfg.defaults.ckpt):
63 |         print('the ckpt path does not exist!!')
64 | return
65 |
66 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
67 | model = FactorFields( ckpt['cfg'], device)
68 | model.load(ckpt)
69 |
70 |
71 | logfolder = os.path.dirname(cfg.defaults.ckpt)
72 | if cfg.exportation.render_train:
73 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
74 | train_dataset = dataset(cfg.dataset.datadir, split='train', is_stack=True)
75 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
76 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
77 | print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
78 |
79 | if cfg.exportation.render_test:
80 | # model.upsample_volume_grid()
81 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True)
82 |         PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/',
83 |                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
84 | n_params = model.n_parameters()
85 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
86 |
87 |
88 | if cfg.exportation.render_path:
89 | c2ws = test_dataset.render_path
90 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True)
91 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/',
92 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
93 |
94 | if cfg.exportation.export_mesh:
95 | alpha, _ = model.getDenseAlpha(times=1)
96 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(),level=0.02)
97 |
98 |
99 | def reconstruction(cfg):
100 | # init dataset
101 | dataset = dataset_dict[cfg.dataset.dataset_name]
102 | train_dataset = dataset(cfg.dataset, split='train', batch_size=cfg.training.batch_size)
103 | test_dataset = dataset(cfg.dataset, split='test')
104 | white_bg = train_dataset.white_bg
105 | ndc_ray = cfg.dataset.ndc_ray
106 |
107 | trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)
108 |
109 | # init resolution
110 | upsamp_list = cfg.training.upsamp_list
111 | update_AlphaMask_list = cfg.training.update_AlphaMask_list
112 |
113 | if cfg.defaults.add_timestamp:
114 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
115 | else:
116 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}'
117 |
118 | # init log file
119 | os.makedirs(logfolder, exist_ok=True)
120 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
121 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
122 | os.makedirs(f'{logfolder}/rgba', exist_ok=True)
123 | summary_writer = SummaryWriter(logfolder)
124 |
125 | cfg.dataset.aabb = train_dataset.scene_bbox
126 |
127 | if cfg.defaults.ckpt is not None:
128 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
129 | # cfg = ckpt['cfg']
130 | model = FactorFields(cfg, device)
131 | model.load(ckpt)
132 | else:
133 | model = FactorFields(cfg, device)
134 | print(model)
135 |
136 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
137 | if cfg.training.lr_decay_iters > 0:
138 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters)
139 | else:
140 | cfg.training.lr_decay_iters = cfg.training.n_iters
141 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters)
142 |
143 | print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters)
144 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
145 |
146 |     # linearly spaced resolution schedule for the upsampling steps
147 | volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal,
148 | len(cfg.training.upsamp_list)).ceil().long().tolist()
149 | reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb)
150 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
151 |
152 | torch.cuda.empty_cache()
153 | PSNRs, PSNRs_test = [], [0]
154 |
155 | steps_inner = 16
156 | start = time.time()
157 | pbar = tqdm(range(cfg.training.n_iters//steps_inner), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout)
158 | for iteration in pbar:
159 |
160 | # train_dataset.update_index()
161 | scene_idx = torch.randint(0, len(train_dataset.all_rgb_files), (1,)).item()
162 | model.scene_idx = scene_idx
163 | for j in range(steps_inner):
164 |
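            # Alternating optimisation within each inner window (a reading of the code,
            # not a statement from the authors): the first inner step trains only the
            # per-scene coefficients with the shared factors frozen, and three steps
            # before the end the shared basis/MLP/renderer are unfrozen while the
            # coefficients are frozen.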
165 | if j%steps_inner==0:
166 | model.set_optimizable(['coef'], True)
167 | model.set_optimizable(['proj','basis','renderer'], False)
168 | elif j%steps_inner==steps_inner-3:
169 | model.set_optimizable(['coef'], False)
170 | model.set_optimizable(['mlp', 'basis','renderer'], True)
171 |
172 | data = train_dataset[scene_idx] #next(iterator)
173 | rays_train, rgb_train = data['rays'].view(-1,6), data['rgbs'].view(-1,3).to(device)
174 |
175 |
176 |             rgb_map, depth_map, coeffs = render_ray(rays_train, model, chunk=cfg.training.batch_size,
177 | N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
178 | is_train=True)
179 |
180 |             loss = torch.mean((rgb_map - rgb_train) ** 2) #+ torch.mean(coeffs.abs())*1e-4
181 |
182 | # loss
183 | optimizer.zero_grad()
184 | loss.backward()
185 | optimizer.step()
186 |
187 | loss = loss.detach().item()
188 |
189 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
190 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
191 | summary_writer.add_scalar('train/mse', loss, global_step=iteration)
192 |
193 | for param_group in optimizer.param_groups:
194 | param_group['lr'] = param_group['lr'] * lr_factor
195 |
196 | # Print the current values of the losses.
197 | pbar.set_description(
198 | f'Iteration {iteration:05d}:'
199 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
200 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
201 | + f' mse = {loss:.6f}'
202 | )
203 | PSNRs = []
204 |
205 |
206 | time_iter = time.time()-start
207 |     print(f'=======> time taken: {time_iter} <=============')
208 | os.makedirs(f'{logfolder}/imgs_test_all/', exist_ok=True)
209 | np.savetxt(f'{logfolder}/imgs_test_all/time.txt',[time_iter])
210 | model.save(f'{logfolder}/{cfg.defaults.expname}.th')
211 |
212 |     if cfg.exportation.render_train:
213 |         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
214 |         train_dataset = dataset(cfg.dataset, split='train', is_stack=True)
215 |         PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
216 |                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
217 |         print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
218 |
219 | if cfg.exportation.render_test:
220 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
221 | if 'reconstructions' in cfg.defaults.mode:
222 | model.scene_idx = test_dataset.test_index
223 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/',
224 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
225 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
226 | n_params = model.n_parameters()
227 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
228 |
229 | if cfg.exportation.export_mesh:
230 | cfg.defaults.ckpt = f'{logfolder}/{cfg.defaults.expname}.th'
231 | export_mesh(cfg)
232 |
233 | if __name__ == '__main__':
234 |
235 | torch.set_default_dtype(torch.float32)
236 | torch.manual_seed(20211202)
237 | np.random.seed(20211202)
238 |
239 | base_conf = OmegaConf.load('configs/defaults.yaml')
240 | print(sys.argv)
241 | path_config = sys.argv[1]
242 | cli_conf = OmegaConf.from_cli()
243 | second_conf = OmegaConf.load(path_config)
244 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
245 | print(cfg)
246 |
247 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path):
248 | render_test(cfg)
249 | elif cfg.exportation.export_mesh_only:
250 | export_mesh(cfg)
251 | else:
252 | reconstruction(cfg)
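    # Usage sketch (illustrative; the config name and override keys are assumptions
    # based on the files under configs/, not commands taken from the repository docs):
    #   python train_across_scene.py configs/nerf_set.yaml defaults.expname=my_run exportation.render_test=True
    # sys.argv[1] is read as the scene config, and dotted key=value overrides parsed by
    # OmegaConf.from_cli() are merged on top of configs/defaults.yaml.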
--------------------------------------------------------------------------------
/train_across_scene_ft.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | from omegaconf import OmegaConf
3 | from models.FactorFields import FactorFields
4 |
5 | import json, random,time
6 | from renderer import *
7 | from utils import *
8 | from torch.utils.tensorboard import SummaryWriter
9 | import datetime
10 | from torch.utils.data import DataLoader
11 |
12 | from dataLoader import dataset_dict
13 | import sys
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 |
17 | class SimpleSampler:
18 | def __init__(self, total, batch):
19 | self.total = total
20 | self.batch = batch
21 | self.curr = total
22 | self.ids = None
23 |
24 | def nextids(self):
25 | self.curr += self.batch
26 | if self.curr + self.batch > self.total:
27 | self.ids = torch.LongTensor(np.random.permutation(self.total))
28 | self.curr = 0
29 | return self.ids[self.curr:self.curr + self.batch]
30 |
31 |
32 | @torch.no_grad()
33 | def export_mesh(cfg):
34 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
35 | model = FactorFields( ckpt['cfg'], device)
36 | model.load(ckpt)
37 |
38 | alpha, _ = model.getDenseAlpha([512]*3)
39 | convert_sdf_samples_to_ply(alpha.cpu(), f'{cfg.defaults.ckpt[:-3]}.ply', bbox=model.aabb.cpu(), level=0.2)
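    # Sketch of what happens above (a reading of the code): getDenseAlpha evaluates the
    # model's density/alpha on a dense grid (here 512^3) and marching cubes at the given
    # iso-level turns it into a triangle mesh written next to the checkpoint.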
40 |
41 | @torch.no_grad()
42 | def render_test(cfg):
43 | # init dataset
44 | dataset = dataset_dict[cfg.dataset.dataset_name]
45 | test_dataset = dataset(cfg.dataset, split='test')
46 | white_bg = test_dataset.white_bg
47 | ndc_ray = cfg.dataset.ndc_ray
48 |
49 | if not os.path.exists(cfg.defaults.ckpt):
50 |         print('the ckpt path does not exist!!')
51 | return
52 |
53 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
54 | cfg.dataset.aabb = test_dataset.scene_bbox
55 | model = FactorFields(cfg, device)
56 |
57 | model.load(ckpt)
58 |
59 |
60 | logfolder = os.path.dirname(cfg.defaults.ckpt)
61 | if cfg.exportation.render_train:
62 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
63 |         train_dataset = dataset(cfg.dataset, split='train', is_stack=True)
64 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
65 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
66 | print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
67 |
68 | if cfg.exportation.render_test:
69 | # model.upsample_volume_grid()
70 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True)
71 |         PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/',
72 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
73 | n_params = model.n_parameters()
74 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
75 |
76 |
77 | if cfg.exportation.render_path:
78 | c2ws = test_dataset.render_path
79 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True)
80 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/',
81 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
82 |
83 | if cfg.exportation.export_mesh:
84 | alpha, _ = model.getDenseAlpha(times=1)
85 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(),level=0.02)
86 |
87 |
88 | def reconstruction(cfg):
89 | # init dataset
90 | dataset = dataset_dict[cfg.dataset.dataset_name]
91 | train_dataset = dataset(cfg.dataset, split='train', batch_size=cfg.training.batch_size)
92 | test_dataset = dataset(cfg.dataset, split='test')
93 | white_bg = train_dataset.white_bg
94 | ndc_ray = cfg.dataset.ndc_ray
95 |
96 | trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)
97 |
98 | # init resolution
99 | upsamp_list = cfg.training.upsamp_list
100 | update_AlphaMask_list = cfg.training.update_AlphaMask_list
101 |
102 | if cfg.defaults.add_timestamp:
103 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
104 | else:
105 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}'
106 |
107 | # init log file
108 | os.makedirs(logfolder, exist_ok=True)
109 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
110 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
111 | os.makedirs(f'{logfolder}/rgba', exist_ok=True)
112 | summary_writer = SummaryWriter(logfolder)
113 |
114 | cfg.dataset.aabb = train_dataset.scene_bbox
115 |
116 | if cfg.defaults.ckpt is not None:
117 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
118 | model = FactorFields(cfg, device)
119 | model.load(ckpt)
120 | else:
121 | model = FactorFields(cfg, device)
122 | print(model)
123 |
124 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
125 | if cfg.training.lr_decay_iters > 0:
126 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters)
127 | else:
128 | cfg.training.lr_decay_iters = cfg.training.n_iters
129 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters)
130 |
131 | print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters)
132 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
133 |
134 |     # linearly interpolate the volume resolution between its initial and final values over the upsampling steps
135 | volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal,
136 | len(cfg.training.upsamp_list)).ceil().long().tolist()
137 | reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb)
138 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
139 |
140 | torch.cuda.empty_cache()
141 | PSNRs, PSNRs_test = [], [0]
142 |
143 | start = time.time()
144 | iterator = iter(trainLoader)
145 | pbar = tqdm(range(cfg.training.n_iters), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout)
146 |
147 | for iteration in pbar:
148 |
149 | data = next(iterator)
150 | rays_train, rgb_train = data['rays'].view(-1,6), data['rgbs'].view(-1,3).to(device)
151 | if 'idx' in data.keys():
152 | model.scene_idx = data['idx']
153 |
154 |         rgb_map, depth_map, coeffs = render_ray(rays_train, model, chunk=cfg.training.batch_size,
155 | N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
156 | is_train=True)
157 |
158 | loss = torch.mean((rgb_map - rgb_train) ** 2)
159 |
160 | # loss
161 | total_loss = loss
162 |
163 | optimizer.zero_grad()
164 | total_loss.backward()
165 | optimizer.step()
166 |
167 | loss = loss.detach().item()
168 |
169 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
170 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
171 | summary_writer.add_scalar('train/mse', loss, global_step=iteration)
172 |
173 | for param_group in optimizer.param_groups:
174 | param_group['lr'] = param_group['lr'] * lr_factor
175 |
176 | # Print the current values of the losses.
177 | if iteration % cfg.defaults.progress_refresh_rate == 0:
178 | pbar.set_description(
179 | f'Iteration {iteration:05d}:'
180 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
181 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
182 | + f' mse = {loss:.6f}'
183 | )
184 | PSNRs = []
185 |
186 | if iteration % cfg.dataset.vis_every == cfg.dataset.vis_every - 1:
187 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_vis/',
188 | N_vis=cfg.dataset.N_vis,
189 | prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray,
190 | compute_extra_metrics=False)
191 | summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
192 |
193 | if iteration in update_AlphaMask_list or iteration in cfg.training.shrinking_list:
194 |
195 | if volume_resoList[0] < 256: # update volume resolution
196 | reso_mask = N_to_reso(volume_resoList[0]**model.in_dim, model.aabb)
197 |
198 | new_aabb = model.updateAlphaMask([cfg.model.coeff_reso]*3,is_update_alphaMask=True)
199 |
200 | if iteration in cfg.training.shrinking_list:
201 | model.shrink(new_aabb)
202 | L1_reg_weight = cfg.training.L1_weight_rest
203 |
204 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
205 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
206 | print("continuing L1_reg_weight", L1_reg_weight)
207 |
208 | if not cfg.dataset.ndc_ray and iteration == update_AlphaMask_list[0] and not cfg.dataset.is_unbound:
209 | # filter rays outside the bbox
210 | train_dataset.all_rays, train_dataset.all_rgbs = model.filtering_rays(train_dataset.all_rays, train_dataset.all_rgbs)
211 | trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)
212 | iterator = iter(trainLoader)
213 |
214 |
215 | if iteration in upsamp_list:
216 | n_voxels = volume_resoList.pop(0)
217 | reso_cur = N_to_reso(n_voxels**model.in_dim, model.aabb)
218 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
219 | model.upsample_volume_grid(reso_cur)
220 |
221 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
222 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
223 | torch.cuda.empty_cache()
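            # The optimiser is re-created here because upsample_volume_grid replaces the
            # underlying parameter tensors, so the old Adam state no longer matches them
            # (a reading of the code, not a statement from the authors).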
224 |
225 | if iteration==3000:
226 | model.cfg.training.renderModule = True
227 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
228 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
229 |
230 | time_iter = time.time()-start
231 |     print(f'=======> time taken: {time_iter} <=============')
232 | os.makedirs(f'{logfolder}/imgs_test_all/', exist_ok=True)
233 | np.savetxt(f'{logfolder}/imgs_test_all/time.txt',[time_iter])
234 | model.save(f'{logfolder}/{cfg.defaults.expname}.th')
235 |
236 |     if cfg.exportation.render_train:
237 |         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
238 |         train_dataset = dataset(cfg.dataset, split='train', is_stack=True)
239 |         PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
240 |                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
241 |         print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
242 |
243 | if cfg.exportation.render_test:
244 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
245 | if 'reconstructions' in cfg.defaults.mode:
246 | model.scene_idx = test_dataset.test_index
247 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/',
248 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
249 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
250 | n_params = model.n_parameters()
251 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
252 |
253 | if cfg.exportation.export_mesh:
254 | cfg.defaults.ckpt = f'{logfolder}/{cfg.defaults.expname}.th'
255 | export_mesh(cfg)
256 |
257 | if __name__ == '__main__':
258 |
259 | torch.set_default_dtype(torch.float32)
260 | torch.manual_seed(20211202)
261 | np.random.seed(20211202)
262 |
263 | base_conf = OmegaConf.load('configs/defaults.yaml')
264 | path_config = sys.argv[1]
265 | cli_conf = OmegaConf.from_cli()
266 | second_conf = OmegaConf.load(path_config)
267 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
268 | print(cfg)
269 |
270 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path):
271 | render_test(cfg)
272 | elif cfg.exportation.export_mesh_only:
273 | export_mesh(cfg)
274 | else:
275 | reconstruction(cfg)
--------------------------------------------------------------------------------
/train_per_scene.py:
--------------------------------------------------------------------------------
1 | from tqdm.auto import tqdm
2 | from omegaconf import OmegaConf
3 | from models.FactorFields import FactorFields
4 |
5 | import json, random,time
6 | from renderer import *
7 | from utils import *
8 | from torch.utils.tensorboard import SummaryWriter
9 | import datetime
10 |
11 | from dataLoader import dataset_dict
12 | import sys
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 | class SimpleSampler:
17 | def __init__(self, total, batch):
18 | self.total = total
19 | self.batch = batch
20 | self.curr = total
21 | self.ids = None
22 |
23 | def nextids(self):
24 | self.curr += self.batch
25 | if self.curr + self.batch > self.total:
26 | self.ids = torch.LongTensor(np.random.permutation(self.total))
27 | self.curr = 0
28 | return self.ids[self.curr:self.curr + self.batch]
29 |
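# Usage sketch for SimpleSampler (illustrative, not part of the original script):
#   sampler = SimpleSampler(total=allrays.shape[0], batch=cfg.training.batch_size)
#   ray_idx = sampler.nextids()   # a shuffled, disjoint chunk of ray indices
# The permutation is redrawn whenever the upcoming batch would run past the end.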
30 |
31 | @torch.no_grad()
32 | def export_mesh(ckpt_path):
33 | ckpt = torch.load(ckpt_path, map_location=device)
34 | cfg = ckpt['cfg']
35 | cfg.defaults.device = device
36 |     model = FactorFields(cfg, device)
37 | model.load(ckpt)
38 |
39 | alpha, _ = model.getDenseAlpha()
40 |     convert_sdf_samples_to_ply(alpha.cpu(), f'{ckpt_path[:-3]}.ply', bbox=model.aabb.cpu(), level=0.005)
41 |
42 |
43 | @torch.no_grad()
44 | def render_test(cfg):
45 | # init dataset
46 | dataset = dataset_dict[cfg.dataset.dataset_name]
47 | test_dataset = dataset(cfg.dataset, split='test')
48 | white_bg = test_dataset.white_bg
49 | ndc_ray = cfg.dataset.ndc_ray
50 |
51 | if not os.path.exists(cfg.defaults.ckpt):
52 |         print('the ckpt path does not exist!!')
53 | return
54 |
55 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
56 | model = FactorFields( ckpt['cfg'], device)
57 | model.load(ckpt)
58 |
59 |
60 | logfolder = os.path.dirname(cfg.defaults.ckpt)
61 | if cfg.exportation.render_train:
62 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
63 |         train_dataset = dataset(cfg.dataset, split='train', is_stack=True)
64 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
65 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
66 | print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
67 |
68 | if cfg.exportation.render_test:
69 | # model.upsample_volume_grid()
70 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True)
71 |         PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/',
72 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
73 | n_params = model.n_parameters()
74 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
75 |
76 |
77 | if cfg.exportation.render_path:
78 | c2ws = test_dataset.render_path
79 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True)
80 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/',
81 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
82 |
83 | if cfg.exportation.export_mesh:
84 | alpha, _ = model.getDenseAlpha()
85 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(),
86 | level=0.005)
87 |
88 |
89 | def reconstruction(cfg):
90 | # init dataset
91 | dataset = dataset_dict[cfg.dataset.dataset_name]
92 | train_dataset = dataset(cfg.dataset, split='train')
93 | test_dataset = dataset(cfg.dataset, split='test')
94 | white_bg = train_dataset.white_bg
95 | ndc_ray = cfg.dataset.ndc_ray
96 |
97 | # init resolution
98 | upsamp_list = cfg.training.upsamp_list
99 | update_AlphaMask_list = cfg.training.update_AlphaMask_list
100 |
101 | if cfg.defaults.add_timestamp:
102 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
103 | else:
104 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}'
105 |
106 | # init log file
107 | os.makedirs(logfolder, exist_ok=True)
108 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
109 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
110 | os.makedirs(f'{logfolder}/rgba', exist_ok=True)
111 | summary_writer = SummaryWriter(logfolder)
112 |
113 |     cfg.dataset.aabb = train_dataset.scene_bbox
114 |
115 | if cfg.defaults.ckpt is not None:
116 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
117 | cfg = ckpt['cfg']
118 | model = FactorFields(cfg, device)
119 | model.load(ckpt)
120 | else:
121 | model = FactorFields(cfg, device)
122 | print(model)
123 |
124 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
125 | if cfg.training.lr_decay_iters > 0:
126 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters)
127 | else:
128 | cfg.training.lr_decay_iters = cfg.training.n_iters
129 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters)
130 |
131 | print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters)
132 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
133 |
134 |     # linearly interpolate the volume resolution between its initial and final values over the upsampling steps
135 | volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal,
136 | len(cfg.training.upsamp_list)).ceil().long().tolist()
137 | reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb)
138 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
139 |
140 | torch.cuda.empty_cache()
141 | PSNRs, PSNRs_test = [], [0]
142 |
143 | allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
144 | trainingSampler = SimpleSampler(allrays.shape[0], cfg.training.batch_size)
145 |
146 |
147 | start = time.time()
148 | pbar = tqdm(range(cfg.training.n_iters), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout)
149 | for iteration in pbar:
150 |
151 | ray_idx = trainingSampler.nextids()
152 | rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)
153 |
154 | rgb_map, depth_map, _ = render_ray(rays_train, model, chunk=cfg.training.batch_size,
155 | N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
156 | is_train=True)
157 |
158 | loss = torch.mean((rgb_map - rgb_train) ** 2)
159 |
160 | optimizer.zero_grad()
161 | loss.backward()
162 | optimizer.step()
163 |
164 | loss = loss.detach().item()
165 |
166 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
167 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
168 | summary_writer.add_scalar('train/mse', loss, global_step=iteration)
169 |
170 | for param_group in optimizer.param_groups:
171 | param_group['lr'] = param_group['lr'] * lr_factor
172 |
173 | # Print the current values of the losses.
174 | if iteration % cfg.defaults.progress_refresh_rate == 0:
175 | pbar.set_description(
176 | f'Iteration {iteration:05d}:'
177 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
178 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
179 | + f' mse = {loss:.6f}'
180 | )
181 | PSNRs = []
182 |
183 | if iteration % cfg.dataset.vis_every == cfg.dataset.vis_every - 1:
184 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_vis/',
185 | N_vis=cfg.dataset.N_vis,
186 | prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray,
187 | compute_extra_metrics=False)
188 | summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
189 |
190 | if iteration in update_AlphaMask_list or iteration in cfg.training.shrinking_list:
191 |
192 | if volume_resoList[0] < 256: # update volume resolution
193 | reso_mask = N_to_reso(volume_resoList[0]**model.in_dim, model.aabb)
194 |
195 | new_aabb = model.updateAlphaMask(tuple(reso_mask), is_update_alphaMask=iteration >= 1500)
196 |
197 |
198 | if iteration in cfg.training.shrinking_list:
199 | model.shrink(new_aabb)
200 | L1_reg_weight = cfg.training.L1_weight_rest
201 |
202 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
203 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
204 | print("continuing L1_reg_weight", L1_reg_weight)
205 |
206 | if not cfg.dataset.ndc_ray and iteration == update_AlphaMask_list[0] and not cfg.dataset.is_unbound:
207 | # filter rays outside the bbox
208 | allrays, allrgbs = model.filtering_rays(allrays, allrgbs)
209 | trainingSampler = SimpleSampler(allrgbs.shape[0], cfg.training.batch_size)
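            # After the first alpha-mask update (a reading of the code above): the
            # training rays are pruned to those that intersect the occupied bounding
            # box, and the sampler is rebuilt over the smaller ray set.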
210 |
211 |
212 | if iteration in upsamp_list:
213 | n_voxels = volume_resoList.pop(0)
214 | reso_cur = N_to_reso(n_voxels**model.in_dim, model.aabb)
215 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
216 | model.upsample_volume_grid(reso_cur)
217 |
218 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
219 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
220 | torch.cuda.empty_cache()
221 |
222 | time_iter = time.time()-start
223 |     print(f'=======> time taken: {time_iter} <=============')
224 | os.makedirs(f'{logfolder}/imgs_test_all',exist_ok=True)
225 | np.savetxt(f'{logfolder}/imgs_test_all/time.txt',[time_iter])
226 | model.save(f'{logfolder}/{cfg.defaults.expname}.th')
227 |
228 | if cfg.exportation.render_test:
229 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
230 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/',
231 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
232 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
233 | n_params = model.n_parameters()
234 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================')
235 |
236 |
237 | if __name__ == '__main__':
238 |
239 | torch.set_default_dtype(torch.float32)
240 | torch.manual_seed(20211202)
241 | np.random.seed(20211202)
242 |
243 | base_conf = OmegaConf.load('configs/defaults.yaml')
244 | path_config = sys.argv[1]
245 | cli_conf = OmegaConf.from_cli()
246 | second_conf = OmegaConf.load(path_config)
247 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
248 |
249 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path):
250 | render_test(cfg)
251 | elif cfg.exportation.export_mesh_only or cfg.exportation.export_mesh:
252 |         export_mesh(cfg.defaults.ckpt)
253 | else:
254 | reconstruction(cfg)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2,torch,math
2 | import numpy as np
3 | from PIL import Image
4 | import torchvision.transforms as T
5 | import torch.nn.functional as F
6 | import scipy.signal
7 |
8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
9 |
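# mse2psnr maps a mean-squared error to PSNR in dB as -10 * log10(mse), assuming pixel
# values in [0, 1]; e.g. mse = 1e-2 -> 20 dB and mse = 1e-3 -> 30 dB.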
10 |
11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
12 | """
13 | depth: (H, W)
14 | """
15 |
16 | x = np.nan_to_num(depth) # change nan to 0
17 | if minmax is None:
18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
19 | ma = np.max(x)
20 | else:
21 | mi,ma = minmax
22 |
23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
24 | x = (255*x).astype(np.uint8)
25 | x_ = cv2.applyColorMap(x, cmap)
26 | return x_, [mi,ma]
27 |
28 | def init_log(log, keys):
29 | for key in keys:
30 | log[key] = torch.tensor([0.0], dtype=float)
31 | return log
32 |
33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
34 | """
35 | depth: (H, W)
36 | """
37 | if type(depth) is not np.ndarray:
38 | depth = depth.cpu().numpy()
39 |
40 | x = np.nan_to_num(depth) # change nan to 0
41 | if minmax is None:
42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
43 | ma = np.max(x)
44 | else:
45 | mi,ma = minmax
46 |
47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
48 | x = (255*x).astype(np.uint8)
49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
50 | x_ = T.ToTensor()(x_) # (3, H, W)
51 | return x_, [mi,ma]
52 |
53 | def N_to_reso(n_voxels, bbox):
54 | xyz_min, xyz_max = bbox
55 | dim = len(xyz_min)
56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
57 | return torch.round((xyz_max - xyz_min) / voxel_size).long().tolist()
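# Worked example (illustrative): for a unit-cube bbox and n_voxels = 128**3 the voxel
# size comes out to 1/128, so N_to_reso returns [128, 128, 128].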
58 |
59 | def N_to_vm_reso(n_voxels, bbox):
60 | xyz_min, xyz_max = bbox
61 | dim = len(xyz_min)
62 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
63 | reso = (xyz_max - xyz_min) / voxel_size
64 | assert len(reso)==3
65 | n_mat = reso[0]*reso[1] + reso[0]*reso[2] + reso[1]*reso[2]
66 | scale = math.sqrt(n_voxels/n_mat)
67 | return torch.round(reso*scale).long().tolist()
68 |
69 | def cal_n_samples(reso, step_ratio=0.5):
70 | return int(np.linalg.norm(reso)/step_ratio)
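# cal_n_samples sets the per-ray sample count to the grid diagonal divided by
# step_ratio; e.g. a 128^3 grid with step_ratio=0.5 gives int(sqrt(3)*128/0.5) = 443
# samples (illustrative numbers).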
71 |
72 | class SimpleSampler:
73 | def __init__(self, total, batch):
74 | self.total = total
75 | self.batch = batch
76 | self.curr = total
77 | self.ids = None
78 |
79 | def nextids(self):
80 | self.curr+=self.batch
81 | if self.curr + self.batch > self.total:
82 | self.ids = torch.LongTensor(np.random.permutation(self.total))
83 | self.curr = 0
84 | return self.ids[self.curr:self.curr+self.batch]
85 |
86 | __LPIPS__ = {}
87 | def init_lpips(net_name, device):
88 | assert net_name in ['alex', 'vgg']
89 | import lpips
90 | print(f'init_lpips: lpips_{net_name}')
91 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
92 |
93 | def rgb_lpips(np_gt, np_im, net_name, device):
94 | if net_name not in __LPIPS__:
95 | __LPIPS__[net_name] = init_lpips(net_name, device)
96 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
97 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
98 | return __LPIPS__[net_name](gt, im, normalize=True).item()
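# rgb_lpips expects float (H, W, 3) arrays; with normalize=True the LPIPS model treats
# inputs as lying in [0, 1] and rescales them internally (per the lpips package API).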
99 |
100 |
101 | def findItem(items, target):
102 | for one in items:
103 | if one[:len(target)]==target:
104 | return one
105 | return None
106 |
107 |
108 | ''' Evaluation metrics (ssim, lpips)
109 | '''
110 | def rgb_ssim(img0, img1, max_val,
111 | filter_size=11,
112 | filter_sigma=1.5,
113 | k1=0.01,
114 | k2=0.03,
115 | return_map=False):
116 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
117 | assert len(img0.shape) == 3
118 | assert img0.shape[-1] == 3
119 | assert img0.shape == img1.shape
120 |
121 | # Construct a 1D Gaussian blur filter.
122 | hw = filter_size // 2
123 | shift = (2 * hw - filter_size + 1) / 2
124 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
125 | filt = np.exp(-0.5 * f_i)
126 | filt /= np.sum(filt)
127 |
128 | # Blur in x and y (faster than the 2D convolution).
129 | def convolve2d(z, f):
130 | return scipy.signal.convolve2d(z, f, mode='valid')
131 |
132 | filt_fn = lambda z: np.stack([
133 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
134 | for i in range(z.shape[-1])], -1)
135 | mu0 = filt_fn(img0)
136 | mu1 = filt_fn(img1)
137 | mu00 = mu0 * mu0
138 | mu11 = mu1 * mu1
139 | mu01 = mu0 * mu1
140 | sigma00 = filt_fn(img0**2) - mu00
141 | sigma11 = filt_fn(img1**2) - mu11
142 | sigma01 = filt_fn(img0 * img1) - mu01
143 |
144 | # Clip the variances and covariances to valid values.
145 | # Variance must be non-negative:
146 | sigma00 = np.maximum(0., sigma00)
147 | sigma11 = np.maximum(0., sigma11)
148 | sigma01 = np.sign(sigma01) * np.minimum(
149 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
150 | c1 = (k1 * max_val)**2
151 | c2 = (k2 * max_val)**2
152 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
153 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
154 | ssim_map = numer / denom
155 | ssim = np.mean(ssim_map)
156 | return ssim_map if return_map else ssim
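# Usage sketch (illustrative): both inputs are float (H, W, 3) arrays on the same scale,
# e.g. score = rgb_ssim(pred_np, gt_np, max_val=1.0); identical images give SSIM = 1.0.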
157 |
158 |
159 | import torch.nn as nn
160 | class TVLoss(nn.Module):
161 | def __init__(self,TVLoss_weight=1):
162 | super(TVLoss,self).__init__()
163 | self.TVLoss_weight = TVLoss_weight
164 |
165 | def forward(self,x):
166 | batch_size = x.size()[0]
167 | h_x = x.size()[2]
168 | w_x = x.size()[3]
169 | count_h = self._tensor_size(x[:,:,1:,:])
170 | count_w = self._tensor_size(x[:,:,:,1:])
171 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
172 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
173 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
174 |
175 | def _tensor_size(self,t):
176 | return t.size()[1]*t.size()[2]*t.size()[3]
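# Usage sketch (illustrative): total-variation regularisation of a (B, C, H, W) grid, e.g.
#   reg = TVLoss(TVLoss_weight=1.0)(torch.rand(1, 8, 64, 64))
# penalises squared differences between neighbouring cells along H and W.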
177 |
178 | def marchcube_to_world(vertices, reso_WHD):
179 |     return vertices/(np.array(reso_WHD)-1)  # normalise marching-cubes vertex indices to [0, 1]
180 |
181 | import plyfile
182 | import skimage.measure
183 | def convert_sdf_samples_to_ply(
184 | pytorch_3d_sdf_tensor,
185 | ply_filename_out,
186 | bbox,
187 | level=0.5,
188 | offset=None,
189 | scale=None,
190 | ):
191 | """
192 | Convert sdf samples to .ply
193 |
194 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
195 |     :param ply_filename_out: string, path of the .ply file to save to
196 |     :param bbox: (2, 3) tensor with the min/max corners of the scene bounding box
197 |     :param level: iso-surface level passed to marching cubes
198 | 
199 |     This function is adapted from: https://github.com/RobotLocomotion/spartan
200 | """
201 |
202 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
203 | # voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))
204 |
205 | verts, faces, normals, values = skimage.measure.marching_cubes(
206 | numpy_3d_sdf_tensor, level=level
207 | )
208 | reso_WHD = numpy_3d_sdf_tensor.shape
209 | print(bbox)
210 |     verts = marchcube_to_world(verts, reso_WHD)
211 |
212 |
213 | faces = faces[...,::-1] # inverse face orientation
214 |
215 |     # transform from normalised voxel coordinates to world coordinates inside the bbox
216 |     # (the vertex axis order from marching_cubes is reversed relative to world x/y/z)
217 | mesh_points = np.zeros_like(verts)
218 | bbox = bbox.numpy()
219 | mesh_points[:, 0] = bbox[0,2] + verts[:, 0]*(bbox[1,2]-bbox[0,2])
220 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1]*(bbox[1,1]-bbox[0,1])
221 | mesh_points[:, 2] = bbox[0,0] + verts[:, 2]*(bbox[1,0]-bbox[0,0])
222 |
223 | # # apply additional offset and scale
224 | # if scale is not None:
225 | # mesh_points = mesh_points / scale
226 | # if offset is not None:
227 | # mesh_points = mesh_points - offset
228 |
229 | # try writing to the ply file
230 |
231 | num_verts = verts.shape[0]
232 | num_faces = faces.shape[0]
233 |
234 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
235 |
236 | for i in range(0, num_verts):
237 | verts_tuple[i] = tuple(mesh_points[i, :])
238 |
239 | faces_building = []
240 | for i in range(0, num_faces):
241 | faces_building.append(((faces[i, :].tolist(),)))
242 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
243 |
244 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
245 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
246 |
247 | ply_data = plyfile.PlyData([el_verts, el_faces])
248 | print("saving mesh to %s" % (ply_filename_out))
249 | ply_data.write(ply_filename_out)
250 |
--------------------------------------------------------------------------------