├── .gitignore ├── 2D_regression.py ├── LICENSE ├── README.md ├── README_FactorField.md ├── configs ├── .DS_Store ├── 360_v2.yaml ├── defaults.yaml ├── image.yaml ├── image_intro.yaml ├── image_set.yaml ├── nerf.yaml ├── nerf_ft.yaml ├── nerf_set.yaml ├── sdf.yaml └── tnt.yaml ├── dataLoader ├── .DS_Store ├── __init__.py ├── blender.py ├── blender_set.py ├── colmap.py ├── colmap2nerf.py ├── dtu_objs.py ├── dtu_objs2.py ├── google_objs.py ├── image.py ├── image_set.py ├── llff.py ├── nsvf.py ├── ray_utils.py ├── sdf.py ├── tankstemple.py └── your_own_data.py ├── media ├── Girl_with_a_Pearl_Earring.jpg └── inpainting.png ├── models ├── .DS_Store ├── FactorFields.py ├── __init__.py └── sh.py ├── renderer.py ├── requirements.txt ├── run_batch.py ├── scripts ├── .DS_Store ├── 2D_regression.ipynb ├── 2D_set_regression.ipynb ├── 2D_set_regression.py ├── __init__.py ├── formula_demostration.ipynb ├── mesh2SDF_data_process.ipynb └── sdf_regression.ipynb ├── train_across_scene.py ├── train_across_scene_ft.py ├── train_per_scene.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | data/ 163 | slurm/ 164 | logs/ 165 | 166 | -------------------------------------------------------------------------------- /2D_regression.py: -------------------------------------------------------------------------------- 1 | import torch,imageio,sys,time,os,cmapy,scipy 2 | import numpy as np 3 | from tqdm import tqdm 4 | import matplotlib.pyplot as plt 5 | from omegaconf import OmegaConf 6 | import torch.nn.functional as F 7 | 8 | device = 'cuda' 9 | 10 | sys.path.append('..') 11 | from models.sparseCoding import sparseCoding 12 | 13 | from dataLoader import dataset_dict 14 | from torch.utils.data import DataLoader 15 | 16 | 17 | def PSNR(a, b): 18 | if type(a).__module__ == np.__name__: 19 | mse = np.mean((a - b) ** 2) 20 | else: 21 | mse = torch.mean((a - b) ** 2).item() 22 | psnr = -10.0 * np.log(mse) / np.log(10.0) 23 | return psnr 24 | 25 | 26 | def rgb_ssim(img0, img1, max_val, 27 | filter_size=11, 28 | filter_sigma=1.5, 29 | k1=0.01, 30 | k2=0.03, 31 | return_map=False): 32 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 33 | assert len(img0.shape) == 3 34 | assert img0.shape[-1] == 3 35 | assert img0.shape == img1.shape 36 | 37 | # Construct a 1D Gaussian blur filter. 38 | hw = filter_size // 2 39 | shift = (2 * hw - filter_size + 1) / 2 40 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2 41 | filt = np.exp(-0.5 * f_i) 42 | filt /= np.sum(filt) 43 | 44 | # Blur in x and y (faster than the 2D convolution). 
45 | def convolve2d(z, f): 46 | return scipy.signal.convolve2d(z, f, mode='valid') 47 | 48 | filt_fn = lambda z: np.stack([ 49 | convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :]) 50 | for i in range(z.shape[-1])], -1) 51 | mu0 = filt_fn(img0) 52 | mu1 = filt_fn(img1) 53 | mu00 = mu0 * mu0 54 | mu11 = mu1 * mu1 55 | mu01 = mu0 * mu1 56 | sigma00 = filt_fn(img0 ** 2) - mu00 57 | sigma11 = filt_fn(img1 ** 2) - mu11 58 | sigma01 = filt_fn(img0 * img1) - mu01 59 | 60 | # Clip the variances and covariances to valid values. 61 | # Variance must be non-negative: 62 | sigma00 = np.maximum(0., sigma00) 63 | sigma11 = np.maximum(0., sigma11) 64 | sigma01 = np.sign(sigma01) * np.minimum( 65 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 66 | c1 = (k1 * max_val) ** 2 67 | c2 = (k2 * max_val) ** 2 68 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 69 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 70 | ssim_map = numer / denom 71 | ssim = np.mean(ssim_map) 72 | return ssim_map if return_map else ssim 73 | 74 | 75 | @torch.no_grad() 76 | def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240): 77 | y = torch.linspace(0, aabb[0] - 1, reso[0]) 78 | x = torch.linspace(0, aabb[1] - 1, reso[1]) 79 | yy, xx = torch.meshgrid((y, x), indexing='ij') 80 | 81 | idx = 0 82 | res = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1]) 83 | coordiantes = torch.stack((xx, yy), dim=-1).reshape(-1, 2) + torch.tensor( 84 | shiftment) # /(torch.FloatTensor(reso[::-1])-1)*2-1 85 | for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)): 86 | feats, _ = model.get_coding(coordiante.to(model.device)) 87 | y_recon = model.linear_mat(feats, is_train=False) 88 | # y_recon = torch.sum(feats,dim=-1,keepdim=True) 89 | 90 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu() 91 | idx += y_recon.shape[0] 92 | return res.view(reso[0], reso[1], -1), coordiantes 93 | 94 | 95 | def linear_to_srgb(img): 96 | limit = 0.0031308 97 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 98 | 99 | 100 | def write_image_imageio(img_file, img, colormap=None, quality=100): 101 | if colormap == 'turbo': 102 | shape = img.shape 103 | img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1) 104 | elif colormap is not None: 105 | img = cmapy.colorize((img * 255).astype('uint8'), colormap) 106 | 107 | if img.dtype != 'uint8': 108 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 109 | img = (img * 255.0).astype(np.uint8) 110 | 111 | kwargs = {} 112 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]: 113 | if img.ndim >= 3 and img.shape[2] > 3: 114 | img = img[:, :, :3] 115 | kwargs["quality"] = quality 116 | kwargs["subsampling"] = 0 117 | imageio.imwrite(img_file, img, **kwargs) 118 | 119 | if __name__ == '__main__': 120 | 121 | torch.set_default_dtype(torch.float32) 122 | torch.manual_seed(20211202) 123 | np.random.seed(20211202) 124 | 125 | base_conf = OmegaConf.load('configs/defaults.yaml') 126 | cli_conf = OmegaConf.from_cli() 127 | second_conf = OmegaConf.load('configs/image.yaml') 128 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf) 129 | print(cfg) 130 | 131 | 132 | folder = cfg.defaults.expname 133 | save_root = f'/vlg-nfs/anpei/project/NeuBasis/ours/images/' 134 | 135 | dataset = dataset_dict[cfg.dataset.dataset_name] 136 | 137 | delete_region = [[290,350,48,48],[300,380,48,48],[180, 407, 48, 48], [223, 263, 48, 48], [233, 150, 48, 48], [374, 119, 48, 48], [4, 199, 48, 48], [180, 234, 48, 48], [173, 39, 48, 48], [408, 308, 48, 48], 
[227, 177, 48, 48], [46, 330, 48, 48], [213, 26, 48, 48], [90, 44, 48, 48], [295, 61, 48, 48]] 138 | continue_sampling = False 139 | 140 | psnrs,ssims = [],[] 141 | for i in range(1,257): 142 | cfg.dataset.datadir = f'/vlg-nfs/anpei/dataset/Images/crop//{i:04d}.png' 143 | name = os.path.basename(cfg.dataset.datadir).split('.')[0] 144 | if os.path.exists(f'{save_root}/{folder}/{int(name):04d}.png'): 145 | continue 146 | 147 | 148 | train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train',tolinear=True, perscent=1.0,HW=1024)#, continue_sampling=continue_sampling,delete_region=delete_region 149 | train_loader = DataLoader(train_dataset, 150 | num_workers=2, 151 | persistent_workers=True, 152 | batch_size=None, 153 | pin_memory=False) 154 | # train_dataset.img = train_dataset.img.to(device) 155 | 156 | cfg.model.out_dim = train_dataset.img.shape[-1] 157 | batch_size = cfg.training.batch_size 158 | n_iter = cfg.training.n_iters 159 | 160 | H,W = train_dataset.HW 161 | train_dataset.scene_bbox = [[0., 0.], [W, H]] 162 | cfg.dataset.aabb = train_dataset.scene_bbox 163 | 164 | model = sparseCoding(cfg, device) 165 | if 1==i: 166 | print(model) 167 | print('total parameters: ',model.n_parameters()) 168 | 169 | # tvreg = TVLoss() 170 | # trainingSampler = SimpleSampler(len(train_dataset), cfg.training.batch_size) 171 | 172 | grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large) 173 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))# 174 | 175 | 176 | loss_scale = 1.0 177 | lr_factor = 0.1 ** (1 / n_iter) 178 | # pbar = tqdm(range(10000)) 179 | start = time.time() 180 | # for iteration in pbar: 181 | for (iteration, sample) in zip(range(10000),train_loader): 182 | loss_scale *= lr_factor 183 | 184 | # if iteration==5000: 185 | # model.coeffs = torch.nn.Parameter(F.interpolate(model.coeffs.data, size=None, scale_factor=2.0, align_corners=True,mode='bilinear')) 186 | # grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large) 187 | # optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))# 188 | # model.set_optimizable(['mlp','basis'], False) 189 | 190 | coordiantes, pixel_rgb = sample['xy'], sample['rgb'] 191 | feats,coeff = model.get_coding(coordiantes.to(device)) 192 | # tv_loss = model.TV_loss(tvreg) 193 | 194 | y_recon = model.linear_mat(feats,is_train=True) 195 | # y_recon = torch.sum(feats,dim=-1,keepdim=True) 196 | loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2) 197 | 198 | 199 | psnr = -10.0 * np.log(loss.item()) / np.log(10.0) 200 | # if iteration%100==0: 201 | # pbar.set_description( 202 | # f'Iteration {iteration:05d}:' 203 | # + f' loss_dist = {loss.item():.8f}' 204 | # # + f' tv_loss = {tv_loss.item():.6f}' 205 | # + f' psnr = {psnr:.3f}' 206 | # ) 207 | 208 | # loss = loss + tv_loss 209 | # loss = loss + torch.mean(coeff.abs())*1e-2 210 | loss = loss * loss_scale 211 | optimizer.zero_grad() 212 | loss.backward() 213 | optimizer.step() 214 | 215 | # if iteration%100==0: 216 | # model.normalize_basis() 217 | iteration_time = time.time()-start 218 | 219 | H,W = train_dataset.HW 220 | img,coordinate = eval_img(train_dataset.HW,[1024,1024]) 221 | if continue_sampling: 222 | import torch.nn.functional as F 223 | coordinate_tmp = (coordinate.view(1,1,-1,2))/torch.tensor([W,H])*2-1.0 224 | img_gt = F.grid_sample(train_dataset.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear', 225 | align_corners=False, 
padding_mode='border').reshape(-1,H,W).permute(1,2,0) 226 | else: 227 | img_gt = train_dataset.img.view(H,W,-1) 228 | psnrs.append(PSNR(img.clamp(0,1.),img_gt)) 229 | ssims.append(rgb_ssim(img.clamp(0,1.),img_gt,1.0)) 230 | # print(PSNR(img.clamp(0,1.),img_gt),iteration_time) 231 | # plt.figure(figsize=(10, 10)) 232 | # plt.imshow(linear_to_srgb(img.clamp(0,1.))) 233 | 234 | print(i, psnrs[-1], ssims[-1]) 235 | 236 | 237 | os.makedirs(f'{save_root}/{folder}',exist_ok=True) 238 | write_image_imageio(f'{save_root}/{folder}/{int(name):04d}.png',linear_to_srgb(img.clamp(0,1.))) 239 | np.savetxt(f'{save_root}/{folder}/{int(name):04d}.txt',[psnrs[-1],ssims[-1],iteration_time,model.n_parameters()]) 240 | 241 | 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Factor Fields 2 | ## [Project page](https://apchenstu.github.io/FactorFields/) | [Paper](https://arxiv.org/abs/2302.01226) 3 | This repository contains a PyTorch implementation of the papers [Factor Fields: A Unified Framework for Neural Fields and Beyond](https://arxiv.org/abs/2302.01226) and [Dictionary Fields: Learning a Neural Basis Decomposition](https://arxiv.org/abs/2302.01226). Our work presents a novel framework for modeling and representing signals. 4 | We have also observed that Dictionary Fields offer benefits such as improved **approximation quality**, **compactness**, **faster training speed**, and the ability to **generalize** to unseen images and 3D scenes.

5 | 6 | 7 | ## Installation 8 | 9 | #### Tested on Ubuntu 20.04 + Pytorch 1.13.0 10 | 11 | Install environment: 12 | ```sh 13 | conda create -n FactorFields python=3.9 14 | conda activate FactorFields 15 | conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit 16 | conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | Optionally install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), only needed if you want to run hash grid based representations. 21 | ```sh 22 | conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit 23 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 24 | ``` 25 | 26 | 27 | # Quick Start 28 | Please ensure that you download the corresponding dataset and extract its contents into the `data` folder. 29 | 30 | ## Image 31 | * [Data - Image Set](https://1drv.ms/u/c/0c624178fab774b7/Ebd0t_p4QWIggAx3BAAAAAABikvhj5m_rVm1-qIpYFyrFg?e=hyTeZf) 32 | 33 | The training script can be found at `scripts/2D_regression.ipynb`, and the configuration file is located at `configs/image.yaml`. 34 | 35 |

36 | Girl with a Pearl Earring 37 |

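The notebook follows essentially the same recipe as `2D_regression.py` in the repository root. Below is a condensed sketch of that loop for a single image; class, method, and argument names are taken from that script (your checkout may expose the model class via `models/FactorFields.py` instead of `models.sparseCoding`), and the image path is a placeholder.

```python
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from models.sparseCoding import sparseCoding   # as imported in 2D_regression.py
from dataLoader import dataset_dict

device = 'cuda'

# Merge the shared defaults with the image-specific settings, then point at an image.
cfg = OmegaConf.merge(OmegaConf.load('configs/defaults.yaml'),
                      OmegaConf.load('configs/image.yaml'))
cfg.dataset.datadir = './data/image/albert.exr'   # placeholder path

train_dataset = dataset_dict[cfg.dataset.dataset_name](cfg.dataset, cfg.training.batch_size,
                                                       split='train', tolinear=True)
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2)

cfg.model.out_dim = train_dataset.img.shape[-1]
H, W = train_dataset.HW
train_dataset.scene_bbox = [[0., 0.], [W, H]]
cfg.dataset.aabb = train_dataset.scene_bbox

model = sparseCoding(cfg, device)
grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,
                                      lr_large=cfg.training.lr_large)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

for iteration, sample in zip(range(cfg.training.n_iters), train_loader):
    xy, rgb = sample['xy'].to(device), sample['rgb'].to(device)
    feats, _ = model.get_coding(xy)                # coefficient-weighted basis features
    pred = model.linear_mat(feats, is_train=True)  # project features to RGB
    loss = torch.mean((pred.squeeze() - rgb.squeeze()) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```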
38 | 39 | ## SDF 40 | * [Data - Mesh set](https://1drv.ms/u/c/0c624178fab774b7/Ebd0t_p4QWIggAx4BAAAAAABbouT0SD3PCChlfTQJL3XzA?e=ImcsAj) 41 | 42 | The training script can be found at `scripts/sdf_regression.ipynb`, and the configuration file is located at `configs/sdf.yaml`. 43 | 44 | GIF 45 | 46 | 47 | 48 | ## NeRF 49 | * [Data - Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 50 | * [Data-Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) 51 | 52 | The training script can be found at `train_per_scene.py`: 53 | 54 | ```python 55 | python train_per_scene.py configs/nerf.yaml defaults.expname=lego dataset.datadir=./data/nerf_synthetic/lego 56 | ``` 57 | 58 | GIF 67 | Inpainting 68 |

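The command above generalizes directly to the other Synthetic-NeRF scenes and to the Tanks&Temples scenes via `configs/tnt.yaml`. A hypothetical batch sketch in the style of `README_FactorField.md` (the scene list follows that file; the Truck path matches the default already set in `configs/tnt.yaml`):

```python
import os

# Synthetic-NeRF scenes, as listed in README_FactorField.md.
for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
    os.system(f'python train_per_scene.py configs/nerf.yaml '
              f'defaults.expname={scene} dataset.datadir=./data/nerf_synthetic/{scene}')

# A Tanks&Temples scene; configs/tnt.yaml already points at ./data/TanksAndTemple/Truck.
os.system('python train_per_scene.py configs/tnt.yaml '
          'defaults.expname=truck dataset.datadir=./data/TanksAndTemple/Truck')
```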
69 | 70 | 71 | 72 | ## Generalization NeRF 73 | * [Data - Google Scanned Objects](https://drive.google.com/file/d/1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2/view) 74 | 75 | ```python 76 | python train_across_scene.py configs/nerf_set.yaml 77 | ``` 78 | 79 | GIF 80 | 81 | 82 | ## More examples 83 | 84 | Command explanation with a NeRF example: 85 | * `model.basis_dims=[4, 4, 4, 2, 2, 2]` adjusts the number of levels and channels at each level, with a total of 6 levels and 18 channels. 86 | * `model.basis_resos=[32, 51, 70, 89, 108, 128]` represents the resolution of the feature embeddings. 87 | * `model.freq_bands=[2.0, 3.2, 4.4, 5.6, 6.8, 8.0]` indicates the frequency parameters applied at each level of the coordinate transformation function. 88 | * `model.coeff_type` selects the coefficient field representation and can be one of the following: [none, x, grid, mlp, vec, cp, vm]. 89 | * `model.basis_type` selects the basis field representation and can be one of the following: [none, x, grid, mlp, vec, cp, vm, hash]. 90 | * `model.basis_mapping` selects the coordinate transformation and can be one of the following: [x, triangle, sawtooth, trigonometric]. Please note that if you want to use orthogonal projection, choose the cp or vm basis type, as they automatically utilize the orthogonal projection functions. 91 | * `model.total_params` controls the total model size. Note that the model's capacity is also determined by `model.basis_resos` and `model.basis_dims`; the `total_params` parameter mainly affects the capacity of the coefficients. 92 | * `exportation.render_only` lets you render items after training by setting this flag to 1. Please also specify the `defaults.ckpt` path. 93 | * `exportation....` specifies which of the items `[render_test, render_train, render_path, export_mesh]` are rendered or exported after training; enable an item by setting the corresponding flag to 1. 94 | 95 | Some pre-defined configurations (such as occNet, DVGO, nerf, iNGP, EG3D) can be found in `README_FactorField.md`; a combined example is sketched after the copyright section below. 96 | 97 | 98 | ## Copyright 99 | * [Summer Day](https://www.rijksmuseum.nl/en/collection/SK-A-3005) - Credit goes to Johan Hendrik Weissenbruch and the Rijksmuseum. 100 | * [Mars](https://solarsystem.nasa.gov/resources/933/true-colors-of-pluto/) - Credit goes to NASA. 101 | * [Albert](https://cdn.loc.gov/service/pnp/cph/3b40000/3b46000/3b46000/3b46036v.jpg) - Credit goes to Orren Jack Turner. 102 | * [Girl With a Pearl Earring](http://profoundism.com/free_licenses.html) - Renovation copyright Koorosh Orooj (CC BY-SA 4.0).
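As referenced in the *More examples* section, the model overrides and the exportation flags can be combined freely on the command line. A hypothetical sketch that trains the iNGP-like variant from `README_FactorField.md` and then re-runs it in render-only mode; the checkpoint path follows the `./logs/<expname>/<expname>.th` pattern used in `README_FactorField.md` and is an assumption here.

```python
import os

scene = 'lego'
variant = 'model.basis_type=hash model.coeff_type=none'  # '-iNGP-like' entry in README_FactorField.md

# Train the scene with the chosen design variant.
os.system(f'python train_per_scene.py configs/nerf.yaml '
          f'defaults.expname={scene}-hash dataset.datadir=./data/nerf_synthetic/{scene} {variant}')

# Reload the checkpoint, render only the test views, and export a mesh.
os.system(f'python train_per_scene.py configs/nerf.yaml '
          f'defaults.expname={scene}-hash dataset.datadir=./data/nerf_synthetic/{scene} {variant} '
          f'exportation.render_only=1 exportation.render_test=1 exportation.export_mesh=1 '
          f'defaults.ckpt=./logs/{scene}-hash/{scene}-hash.th')
```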
103 | 104 | 105 | ## Citation 106 | If you find our code or paper helpful, please consider citing both of these papers: 107 | ``` 108 | @article{Chen2023factor, 109 | title={Factor Fields: A Unified Framework for Neural Fields and Beyond}, 110 | author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas}, 111 | journal={arXiv preprint arXiv:2302.01226}, 112 | year={2023} 113 | } 114 | 115 | @article{Chen2023SIGGRAPH, 116 | title={{Dictionary Fields: Learning a Neural Basis Decomposition}}, 117 | author={Anpei, Chen and Zexiang, Xu and Xinyue, Wei and Siyu, Tang and Hao, Su and Andreas, Geiger}, 118 | booktitle={International Conference on Computer Graphics and Interactive Techniques (SIGGRAPH)}, 119 | year={2023}} 120 | ``` 121 | -------------------------------------------------------------------------------- /README_FactorField.md: -------------------------------------------------------------------------------- 1 | ## NeRF reconstruction with Dictionary Fields 2 | 3 | ```python 4 | for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']: 5 | cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene} ' \ 6 | f'dataset.datadir=./data/nerf_synthetic/{scene} ' 7 | ``` 8 | 9 | ## Different model design choices 10 | 11 | ```python 12 | choice_dict = { 13 | '-grid': '', \ 14 | '-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \ 15 | '-noC': 'model.coeff_type=none', \ 16 | '-SL':'model.basis_dims=[18] model.basis_resos=[70] model.freq_bands=[8.]', \ 17 | '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \ 18 | '-iNGP-like': 'model.basis_type=hash model.coeff_type=none', \ 19 | '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \ 20 | '-sinc': f'model.basis_mapping=sinc', \ 21 | '-tria': f'model.basis_mapping=triangle', \ 22 | '-vm': f'model.coeff_type=vm model.basis_type=vm', \ 23 | '-mlpB': 'model.basis_type=mlp', \ 24 | '-mlpC': 'model.coeff_type=mlp', \ 25 | '-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \ 26 | '-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \ 27 | f'model.num_layers=8 model.hidden_dim=256 ' \ 28 | f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \ 29 | '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \ 30 | '-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \ 31 | '-DCT':'model.basis_type=fix-grid', \ 32 | } 33 | 34 | for name in choice_dict.keys(): 35 | for scene in [ 'ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']: 36 | 37 | cmd = f"python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {choice_dict[name]}" 38 | ``` 39 | 40 | ## Generalized NeRF 41 | You can choose any of the following design choices for testing.
42 | ```python 43 | choice_dict = { 44 | '-grid': '', \ 45 | '-DVGO-like': 'model.basis_type=none model.coeff_reso=48', 46 | '-SL':'model.basis_dims=[72] model.basis_resos=[48] model.freq_bands=[6.]', \ 47 | '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \ 48 | '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \ 49 | '-sinc': f'model.basis_mapping=sinc', \ 50 | '-tria': f'model.basis_mapping=triangle', \ 51 | '-vm': f'model.coeff_type=vm model.basis_type=vm', \ 52 | '-mlpB': 'model.basis_type=mlp', \ 53 | '-mlpC': 'model.coeff_type=mlp', \ 54 | '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \ 55 | '-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \ 56 | '-DCT':'model.basis_type=fix-grid', \ 57 | } 58 | 59 | for name in choice_dict.keys(): # 60 | cmd = f'python train_across_scene.py configs/nerf_set.yaml defaults.expname=google-obj{name} {choice_dict[name]} ' \ 61 | f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/' 62 | ``` 63 | 64 | You can also fine-tune the trained model on a new scene: 65 | 66 | ```python 67 | for views in [5]:#3, 68 | for name in choice_dict.keys(): # 69 | for scene in [183]:#183,199,298,467,957,244,963,527, 70 | 71 | cmd = f'python train_across_scene_ft.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \ 72 | f'{choice_dict[name]} training.n_iters=10000 ' \ 73 | f'dataset.train_views={views} ' \ 74 | f'dataset.train_scene_list=[{scene}] ' \ 75 | f'dataset.test_scene_list=[{scene}] ' \ 76 | f'dataset.datadir=./data/google_scanned_objects/ ' \ 77 | f'defaults.ckpt=./logs/google-obj{name}//google-obj{name}.th' 78 | ``` 79 | 80 | ## Render path after optimization 81 | ```python 82 | for views in [5]: 83 | for name in choice_dict.keys(): # 84 | config = choice_dict[name].replace(",", "','") 85 | for scene in [183]:#183,199,298,467,957,244,963,527,681,948 86 | 87 | cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \ 88 | f'{config} training.n_iters=10000 ' \ 89 | f'dataset.train_views={views} exportation.render_only=True exportation.render_path=True exportation.render_test=False ' \ 90 | f'dataset.train_scene_list=[{scene}] ' \ 91 | f'dataset.test_scene_list=[{scene}] ' \ 92 | f'dataset.datadir=./data/google_scanned_objects/ ' \ 93 | f'defaults.ckpt=./logs/google_objs_{name}_{scene}_{views}_views//google_objs_{name}_{scene}_{views}_views.th' 94 | ``` -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/configs/.DS_Store -------------------------------------------------------------------------------- /configs/360_v2.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_room_real_mask 4 | logdir: ./logs 5 | 6 | ckpt: null # help='specific weights npy file to reload for coarse network' 7 | 8 | model: 9 | basis_dims: [5,5,5,2,2,2] 10 | basis_resos: [ 64, 83, 102, 121, 140, 160] 11 | coeff_reso: 16 12 | coef_init: 0.01 13 | phases: [0.0] 14 | 15 |
coef_mode: bilinear 16 | basis_mode: bilinear 17 | 18 | freq_bands: [ 1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.] 19 | 20 | kernel_mapping_type: 'sawtooth' 21 | 22 | in_dim: 3 23 | out_dim: 32 24 | num_layers: 2 25 | hidden_dim: 128 26 | 27 | dataset: 28 | # loader options 29 | dataset_name: llff # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 30 | datadir: /home/anpei/code/NeuBasis/data/360_v2/room/ 31 | ndc_ray: 0 32 | is_unbound: True 33 | 34 | with_depth: 0 35 | downsample_train: 4.0 36 | downsample_test: 4.0 37 | 38 | N_vis: 5 39 | vis_every: 5000 40 | 41 | training: 42 | 43 | n_iters: 30000 44 | batch_size: 4096 45 | 46 | volume_resoInit: 128 # 128**3: 47 | volume_resoFinal: 320 # 300**3 48 | 49 | upsamp_list: [2000,3000,4000,5500] 50 | update_AlphaMask_list: [2500] 51 | shrinking_list: [-1] 52 | 53 | L1_weight_inital: 0.0 54 | L1_weight_rest: 0.0 55 | 56 | TV_weight_density: 0.0 57 | TV_weight_app: 0.00 58 | 59 | exportation: 60 | render_only: 0 61 | render_test: 1 62 | render_train: 0 63 | render_path: 0 64 | export_mesh: 0 65 | export_mesh_only: 0 66 | 67 | renderer: 68 | shadingMode: MLP_Fea 69 | num_layers: 3 70 | hidden_dim: 128 71 | 72 | fea2denseAct: 'relu' 73 | density_shift: -10 74 | distance_scale: 25.0 75 | 76 | view_pe: 6 77 | fea_pe: 2 78 | 79 | lindisp: 0 80 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter' 81 | 82 | step_ratio: 0.5 83 | max_samples: 1600 84 | 85 | alphaMask_thres: 0.04 86 | rayMarch_weight_thres: 1e-3 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_lego 4 | logedir: ./logs 5 | 6 | mode: 'reconstruction' 7 | 8 | progress_refresh_rate: 10 9 | 10 | add_timestamp: 0 11 | 12 | model: 13 | basis_dims: [4,4,4,2,2,2] 14 | basis_resos: [32,51,70,89,108,128] 15 | coeff_reso: 32 16 | total_params: 10744166 17 | T_basis: 0 18 | T_coeff: 0 19 | 20 | coef_init: 1.0 21 | coef_mode: bilinear 22 | basis_mode: bilinear 23 | 24 | freq_bands: [1.0000, 1.3300, 1.7689, 2.3526, 3.1290, 4.1616] 25 | 26 | 27 | basis_mapping: 'sawtooth' 28 | with_dropout: False 29 | 30 | in_dim: 3 31 | out_dim: 32 32 | num_layers: 2 33 | hidden_dim: 128 34 | 35 | dataset: 36 | # loader options 37 | dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 38 | datadir: ./data/nerf_synthetic/lego 39 | 40 | with_depth: 0 41 | downsample_train: 1.0 42 | downsample_test: 1.0 43 | 44 | is_unbound: False 45 | 46 | training: 47 | # training options 48 | batch_size: 4096 49 | n_iters: 30000 50 | 51 | # learning rate 52 | lr_small: 0.001 53 | lr_large: 0.02 54 | 55 | lr_decay_iters: -1 56 | lr_decay_target_ratio: 0.1 # help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters' 57 | lr_upsample_reset: 1 # help='reset lr to inital after upsampling' 58 | 59 | # loss 60 | L1_weight_inital: 0.0 # help='loss weight' 61 | L1_weight_rest: 0 62 | Ortho_weight: 0.0 63 | TV_weight_density: 0.0 64 | TV_weight_app: 0.0 65 | 66 | # optimiziable 67 | coeff: True 68 | basis: True 69 | linear_mat: True 70 | renderModule: True 71 | 72 | 73 | -------------------------------------------------------------------------------- /configs/image.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_image 4 | logdir: ./logs 5 | 6 | 
mode: 'image' 7 | 8 | ckpt: null # help='specific weights npy file to reload for coarse network' 9 | 10 | model: 11 | basis_dims: [32,32,32,16,16,16] 12 | basis_resos: [32,51,70,89,108,128] 13 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 14 | 15 | total_params: 1426063 # albert 16 | # total_params: 61445328 # pluto 17 | # total_params: 71848800 #Girl_with_a_Pearl_Earring 18 | # total_params: 37138096 # Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg_base 19 | 20 | coeff_type: 'grid' 21 | basis_type: 'grid' 22 | 23 | coef_init: 0.001 24 | 25 | coef_mode: nearest 26 | basis_mode: nearest 27 | basis_mapping: 'sawtooth' 28 | 29 | 30 | in_dim: 2 31 | out_dim: 3 32 | num_layers: 2 33 | hidden_dim: 64 34 | with_dropout: False 35 | 36 | dataset: 37 | # loader options 38 | dataset_name: image 39 | datadir: "../data/image/albert.exr" 40 | # datadir: "../data/image//pluto.jpeg" 41 | # datadir: "../data/image//Girl_with_a_Pearl_Earring.jpeg" 42 | # datadir: "../data/image//Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg" 43 | 44 | 45 | training: 46 | n_iters: 10000 47 | batch_size: 102400 48 | 49 | # learning rate 50 | lr_small: 0.002 51 | lr_large: 0.002 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /configs/image_intro.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_image 4 | logdir: ./logs 5 | 6 | mode: 'demo' 7 | 8 | ckpt: null # help='specific weights npy file to reload for coarse network' 9 | 10 | model: 11 | in_dim: 2 12 | out_dim: 1 13 | 14 | basis_dims: [32,32,32,16,16,16] 15 | basis_resos: [32,51,70,89,108,128] 16 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 17 | 18 | 19 | 20 | 21 | 22 | # occNet 23 | coeff_type: 'none' 24 | basis_type: 'x' 25 | basis_mapping: 'x' 26 | num_layers: 8 27 | hidden_dim: 256 28 | 29 | 30 | # coef_init: 0.001 31 | 32 | # coef_mode: nearest 33 | # basis_mode: nearest 34 | # basis_mapping: 'sawtooth' 35 | 36 | with_dropout: False 37 | 38 | dataset: 39 | # loader options 40 | dataset_name: image 41 | datadir: ../data/image/cat_occupancy.png 42 | 43 | 44 | training: 45 | n_iters: 10000 46 | batch_size: 102400 47 | 48 | # learning rate 49 | lr_small: 0.0002 50 | lr_large: 0.0002 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /configs/image_set.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_image 4 | logdir: ./logs 5 | 6 | mode: 'images' 7 | 8 | ckpt: null # help='specific weights npy file to reload for coarse network' 9 | 10 | model: 11 | basis_dims: [32,32,32,16,16,16] 12 | basis_resos: [32,51,70,89,108,128] 13 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 
14 | 15 | 16 | coeff_reso: 32 17 | total_params: 1024000 18 | 19 | coef_init: 0.001 20 | 21 | coef_mode: bilinear 22 | basis_mode: bilinear 23 | 24 | 25 | coeff_type: 'grid' 26 | basis_type: 'grid' 27 | 28 | in_dim: 3 29 | out_dim: 3 30 | num_layers: 2 31 | hidden_dim: 64 32 | with_dropout: True 33 | 34 | dataset: 35 | # loader options 36 | dataset_name: images 37 | datadir: data/ffhq/ffhq_512.npy 38 | 39 | training: 40 | n_iters: 300000 41 | batch_size: 40960 42 | 43 | # learning rate 44 | lr_small: 0.002 45 | lr_large: 0.002 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/nerf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | expname: basis_lego 3 | logdir: ./logs 4 | 5 | mode: 'reconstruction' 6 | 7 | ckpt: null # help='specific weights npy file to reload for coarse network' 8 | 9 | model: 10 | coeff_reso: 32 11 | 12 | basis_dims: [4,4,4,2,2,2] 13 | basis_resos: [32,51,70,89,108,128] 14 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 15 | 16 | 17 | coef_init: 1.0 18 | phases: [0.0] 19 | total_params: 5308416 20 | 21 | coef_mode: bilinear 22 | basis_mode: bilinear 23 | 24 | coeff_type: 'grid' 25 | basis_type: 'grid' 26 | basis_mapping: 'sawtooth' 27 | 28 | in_dim: 3 29 | out_dim: 32 30 | num_layers: 2 31 | hidden_dim: 64 32 | 33 | dataset: 34 | # loader options 35 | dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 36 | datadir: ./data/nerf_synthetic/lego 37 | ndc_ray: 0 38 | 39 | with_depth: 0 40 | downsample_train: 1.0 41 | downsample_test: 1.0 42 | 43 | N_vis: 5 44 | vis_every: 100000 45 | scene_reso: 768 46 | 47 | training: 48 | 49 | n_iters: 30000 50 | batch_size: 4096 51 | 52 | volume_resoInit: 128 # 128**3: 53 | volume_resoFinal: 300 # 300**3 54 | 55 | upsamp_list: [2000,3000,4000,5500,7000] 56 | update_AlphaMask_list: [2500,4000] 57 | shrinking_list: [500] 58 | 59 | L1_weight_inital: 0.0 60 | L1_weight_rest: 0.0 61 | 62 | TV_weight_density: 0.000 63 | TV_weight_app: 0.00 64 | 65 | exportation: 66 | render_only: 0 67 | render_test: 1 68 | render_train: 0 69 | render_path: 0 70 | export_mesh: 0 71 | export_mesh_only: 0 72 | 73 | renderer: 74 | shadingMode: MLP_Fea 75 | num_layers: 3 76 | hidden_dim: 128 77 | 78 | fea2denseAct: 'softplus' 79 | density_shift: -10 80 | distance_scale: 25.0 81 | 82 | view_pe: 6 83 | fea_pe: 2 84 | 85 | lindisp: 0 86 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter' 87 | 88 | step_ratio: 0.5 89 | max_samples: 1200 90 | 91 | alphaMask_thres: 0.02 92 | rayMarch_weight_thres: 1e-3 93 | -------------------------------------------------------------------------------- /configs/nerf_ft.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | expname: basis 3 | logdir: ./logs 4 | 5 | mode: 'reconstructions' 6 | 7 | ckpt: null # help='specific weights npy file to reload for coarse network' 8 | 9 | model: 10 | coeff_reso: 16 11 | 12 | basis_dims: [16,16,16,8,8,8] 13 | basis_resos: [32,51,70,89,108,128] 14 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 
15 | 16 | with_dropout: True 17 | 18 | coef_init: 1.0 19 | phases: [0.0] 20 | total_params: 5308416 21 | 22 | coef_mode: bilinear 23 | basis_mode: bilinear 24 | 25 | coeff_type: 'grid' 26 | basis_type: 'grid' 27 | basis_mapping: 'sawtooth' 28 | 29 | in_dim: 3 30 | out_dim: 32 31 | num_layers: 2 32 | hidden_dim: 64 33 | 34 | dataset: 35 | # loader options 36 | dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 37 | datadir: /vlg-nfs/anpei/dataset/google_scanned_objects 38 | ndc_ray: 0 39 | train_scene_list: [100] 40 | test_scene_list: [100] 41 | train_views: 5 42 | 43 | with_depth: 0 44 | downsample_train: 1.0 45 | downsample_test: 1.0 46 | 47 | N_vis: 5 48 | vis_every: 100000 49 | scene_reso: 768 50 | 51 | training: 52 | 53 | n_iters: 5000 54 | batch_size: 4096 55 | 56 | volume_resoInit: 128 # 128**3: 57 | volume_resoFinal: 300 # 300**3 58 | 59 | upsamp_list: [2000,3000,4000] 60 | update_AlphaMask_list: [1500] 61 | shrinking_list: [-1] 62 | 63 | L1_weight_inital: 0.0 64 | L1_weight_rest: 0.0 65 | 66 | TV_weight_density: 0.000 67 | TV_weight_app: 0.00 68 | 69 | # optimiziable 70 | coeff: True 71 | basis: False 72 | linear_mat: False 73 | renderModule: False 74 | 75 | exportation: 76 | render_only: 0 77 | render_test: 1 78 | render_train: 0 79 | render_path: 0 80 | export_mesh: 0 81 | export_mesh_only: 0 82 | 83 | renderer: 84 | shadingMode: MLP_Fea 85 | num_layers: 3 86 | hidden_dim: 128 87 | 88 | fea2denseAct: 'softplus' 89 | density_shift: -10 90 | distance_scale: 25.0 91 | 92 | view_pe: 6 93 | fea_pe: 2 94 | 95 | lindisp: 0 96 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter' 97 | 98 | step_ratio: 0.5 99 | max_samples: 1200 100 | 101 | alphaMask_thres: 0.02 102 | rayMarch_weight_thres: 1e-3 -------------------------------------------------------------------------------- /configs/nerf_set.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | expname: basis_no_relu_lego 3 | logdir: ./logs 4 | 5 | mode: 'reconstructions' 6 | 7 | ckpt: null # help='specific weights npy file to reload for coarse network' 8 | 9 | model: 10 | coeff_reso: 16 11 | 12 | basis_dims: [16,16,16,8,8,8] 13 | # basis_resos: [32,51,70,89,108,128] 14 | basis_resos: [32,51,70,89,108,128] 15 | # freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.] 16 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 
17 | 18 | with_dropout: True 19 | 20 | coef_init: 1.0 21 | phases: [0.0] 22 | total_params: 5308416 23 | 24 | coef_mode: bilinear 25 | basis_mode: bilinear 26 | 27 | coeff_type: 'grid' 28 | basis_type: 'grid' 29 | basis_mapping: 'sawtooth' 30 | 31 | in_dim: 3 32 | out_dim: 32 33 | num_layers: 2 34 | hidden_dim: 64 35 | 36 | dataset: 37 | # loader options 38 | dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 39 | datadir: /vlg-nfs/anpei/dataset/google_scanned_objects 40 | ndc_ray: 0 41 | train_scene_list: [0,100] 42 | test_scene_list: [0] 43 | train_views: 100 44 | 45 | with_depth: 0 46 | downsample_train: 1.0 47 | downsample_test: 1.0 48 | 49 | N_vis: 5 50 | vis_every: 100000 51 | scene_reso: 768 52 | 53 | training: 54 | 55 | n_iters: 50000 56 | batch_size: 4096 57 | 58 | volume_resoInit: 128 # 128**3: 59 | volume_resoFinal: 256 # 300**3 60 | 61 | upsamp_list: [2000,3000,4000,5500,7000] 62 | update_AlphaMask_list: [-1] 63 | shrinking_list: [-1] 64 | 65 | L1_weight_inital: 0.0 66 | L1_weight_rest: 0.0 67 | 68 | TV_weight_density: 0.000 69 | TV_weight_app: 0.00 70 | 71 | exportation: 72 | render_only: 0 73 | render_test: 0 74 | render_train: 0 75 | render_path: 0 76 | export_mesh: 0 77 | export_mesh_only: 0 78 | 79 | renderer: 80 | shadingMode: MLP_Fea 81 | num_layers: 3 82 | hidden_dim: 128 83 | 84 | fea2denseAct: 'softplus' 85 | density_shift: -10 86 | distance_scale: 25.0 87 | 88 | view_pe: 6 89 | fea_pe: 2 90 | 91 | lindisp: 0 92 | perturb: 1 # help='set to 0. for no jitter, 1. for jitter' 93 | 94 | step_ratio: 0.5 95 | max_samples: 1200 96 | 97 | alphaMask_thres: 0.005 98 | rayMarch_weight_thres: 1e-3 99 | -------------------------------------------------------------------------------- /configs/sdf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | expname: basis_sdf 3 | logdir: ./logs 4 | 5 | mode: 'sdf' 6 | 7 | ckpt: null # help='specific weights npy file to reload for coarse network' 8 | 9 | model: 10 | basis_dims: [4,4,4,2,2,2] 11 | basis_resos: [32,51,70,89,108,128] 12 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 13 | 14 | total_params: 5313942 15 | 16 | coeff_reso: 32 17 | coef_init: 0.05 18 | 19 | coef_mode: bilinear 20 | basis_mode: bilinear 21 | 22 | 23 | coeff_type: 'grid' 24 | basis_type: 'grid' 25 | kernel_mapping_type: 'sawtooth' 26 | 27 | in_dim: 3 28 | out_dim: 1 29 | num_layers: 1 30 | hidden_dim: 64 31 | 32 | dataset: 33 | # loader options 34 | dataset_name: sdf 35 | datadir: "../data/mesh/statuette_close.npy" 36 | 37 | scene_reso: 384 38 | 39 | 40 | training: 41 | n_iters: 10000 42 | batch_size: 40960 43 | 44 | # learning rate 45 | lr_small: 0.002 46 | lr_large: 0.02 -------------------------------------------------------------------------------- /configs/tnt.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | expname: basis_truck 4 | logdir: ./logs 5 | 6 | ckpt: null # help='specific weights npy file to reload for coarse network' 7 | 8 | model: 9 | coeff_reso: 32 10 | 11 | # basis_dims: [8,4,2] 12 | # basis_resos: [32,64,128] 13 | # freq_bands: [3.0,4.7,6.8] 14 | 15 | ## 32.88 16 | # basis_dims: [3, 3, 3, 3, 3, 3, 3] 17 | # basis_resos: [64, 64, 64, 64, 64, 64, 64] 18 | ## freq_bands: [1.52727273, 2.58181818, 3.63636364, 4.69090909, 5.21818182, 6.27272727, 6.8] 19 | # freq_bands: [2., 3., 4.,5.,6.,7.,8.] 
20 | 21 | # basis_dims: [5,5,5,2,2,2] 22 | # basis_resos: [32,51,70,89,108,128] 23 | # coeff_reso: 26 24 | # freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.] 25 | 26 | basis_dims: [4,4,4,2,2,2] 27 | basis_resos: [32,51,70,89,108,128] 28 | freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.] 29 | # freq_bands: [2. , 2.8, 3.6, 4.4, 5.2, 6.] 30 | 31 | coef_init: 1.0 32 | phases: [0.0] 33 | total_params: 5744166 34 | 35 | coef_mode: bilinear 36 | basis_mode: bilinear 37 | 38 | kernel_mapping_type: 'sawtooth' 39 | 40 | in_dim: 3 41 | out_dim: 32 42 | num_layers: 2 43 | hidden_dim: 64 44 | 45 | dataset: 46 | # loader options 47 | dataset_name: tankstemple # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'] 48 | datadir: ./data/TanksAndTemple/Truck 49 | ndc_ray: 0 50 | 51 | with_depth: 0 52 | downsample_train: 1.0 53 | downsample_test: 1.0 54 | 55 | N_vis: 5 56 | vis_every: 100000 57 | scene_reso: 768 58 | 59 | training: 60 | 61 | n_iters: 30000 62 | batch_size: 4096 63 | 64 | volume_resoInit: 128 # 128**3: 65 | volume_resoFinal: 320 # 300**3 66 | 67 | # upsamp_list: [2000,4000,7000] 68 | # update_AlphaMask_list: [2000,3000] 69 | upsamp_list: [2000,3000,4000,5500,7000] 70 | update_AlphaMask_list: [2500,4000] 71 | shrinking_list: [500] 72 | 73 | L1_weight_inital: 0.0 74 | L1_weight_rest: 0.0 75 | 76 | TV_weight_density: 0.0 77 | TV_weight_app: 0.00 78 | 79 | exportation: 80 | render_only: 0 81 | render_test: 1 82 | render_train: 0 83 | render_path: 0 84 | export_mesh: 0 85 | export_mesh_only: 0 86 | 87 | renderer: 88 | shadingMode: MLP_Fea 89 | num_layers: 3 90 | hidden_dim: 128 91 | 92 | fea2denseAct: 'softplus' 93 | density_shift: -10 94 | distance_scale: 25.0 95 | 96 | view_pe: 2 97 | fea_pe: 2 98 | 99 | lindisp: 0 100 | perturb: 1 # help='set to 0. for no jitter, 1. 
for jitter' 101 | 102 | step_ratio: 0.5 103 | max_samples: 1200 104 | 105 | alphaMask_thres: 0.005 106 | rayMarch_weight_thres: 1e-3 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /dataLoader/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/dataLoader/.DS_Store -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .llff import LLFFDataset 2 | from .blender import BlenderDataset 3 | from .nsvf import NSVF 4 | from .tankstemple import TanksTempleDataset 5 | from .your_own_data import YourOwnDataset 6 | from .image import ImageDataset 7 | from .image_set import ImageSetDataset 8 | from .colmap import ColmapDataset 9 | from .sdf import SDFDataset 10 | from .blender_set import BlenderDatasetSet 11 | from .google_objs import GoogleObjsDataset 12 | from .dtu_objs import DTUDataset 13 | 14 | 15 | dataset_dict = {'blender': BlenderDataset, 16 | 'blender_set': BlenderDatasetSet, 17 | 'llff':LLFFDataset, 18 | 'tankstemple':TanksTempleDataset, 19 | 'nsvf':NSVF, 20 | 'own_data':YourOwnDataset, 21 | 'image':ImageDataset, 22 | 'images':ImageSetDataset, 23 | 'sdf':SDFDataset, 24 | 'colmap':ColmapDataset, 25 | 'google_objs':GoogleObjsDataset, 26 | 'dtu':DTUDataset} -------------------------------------------------------------------------------- /dataLoader/blender.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | class BlenderDataset(Dataset): 13 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None): 14 | 15 | # self.N_vis = N_vis 16 | self.split = split 17 | self.batch_size = batch_size 18 | self.root_dir = cfg.datadir 19 | self.is_stack = is_stack if is_stack is not None else 'train'!=split 20 | self.downsample = cfg.get(f'downsample_{self.split}') 21 | self.img_wh = (int(800 / self.downsample), int(800 / self.downsample)) 22 | self.define_transforms() 23 | 24 | self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659], 25 | [0.73729737, 0.44876192, -0.50498052], 26 | [0.16296514, 0.60727077, 0.77760181]]) 27 | 28 | self.scene_bbox = (np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])).tolist() 29 | # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]] 30 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 31 | self.read_meta() 32 | self.define_proj_mat() 33 | 34 | self.white_bg = True 35 | self.near_far = [2.0, 6.0] 36 | 37 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 38 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 39 | 40 | def read_depth(self, filename): 41 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 42 | return depth 43 | 44 | def read_meta(self): 45 | 46 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 47 | self.meta = json.load(f) 48 | 49 | w, h = self.img_wh 50 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 51 | 
self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh 52 | 53 | # ray directions for all pixels, same for all images (same H, W, focal) 54 | self.directions = get_ray_directions(h, w, [self.focal, self.focal]) # (h, w, 3) 55 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 56 | self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float() 57 | 58 | self.image_paths = [] 59 | self.poses = [] 60 | self.all_rays = [] 61 | self.all_rgbs = [] 62 | self.all_masks = [] 63 | self.all_depth = [] 64 | self.downsample = 1.0 65 | 66 | img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 67 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 68 | # idxs = idxs[:10] if self.split=='train' else idxs 69 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:# 70 | 71 | frame = self.meta['frames'][i] 72 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 73 | c2w = torch.FloatTensor(pose) 74 | c2w[:3,-1] /= 1.5 75 | self.poses += [c2w] 76 | 77 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 78 | self.image_paths += [image_path] 79 | img = Image.open(image_path) 80 | 81 | if self.downsample != 1.0: 82 | img = img.resize(self.img_wh, Image.LANCZOS) 83 | img = self.transform(img) # (4, h, w) 84 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 85 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 86 | self.all_rgbs += [img] 87 | 88 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 89 | # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot 90 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 91 | 92 | self.poses = torch.stack(self.poses) 93 | if not self.is_stack: 94 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 95 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 96 | 97 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 98 | else: 99 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 100 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3) 101 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 102 | 103 | self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]),self.batch_size) 104 | 105 | def define_transforms(self): 106 | self.transform = T.ToTensor() 107 | 108 | def define_proj_mat(self): 109 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3] 110 | 111 | # def world2ndc(self,points,lindisp=None): 112 | # device = points.device 113 | # return (points - self.center.to(device)) / self.radius.to(device) 114 | 115 | def __len__(self): 116 | return len(self.all_rgbs) if self.split=='test' else 300000 117 | 118 | def __getitem__(self, idx): 119 | idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,)) 120 | sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]} 121 | return sample -------------------------------------------------------------------------------- /dataLoader/blender_set.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL 
import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | class BlenderDatasetSet(Dataset): 13 | def __init__(self, cfg, split='train'): 14 | 15 | # self.N_vis = N_vis 16 | self.root_dir = cfg.datadir 17 | self.split = split 18 | self.is_stack = False if 'train'==split else True 19 | self.downsample = cfg.get(f'downsample_{self.split}') 20 | self.img_wh = (int(800 / self.downsample), int(800 / self.downsample)) 21 | self.define_transforms() 22 | 23 | self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659], 24 | [0.73729737, 0.44876192, -0.50498052], 25 | [0.16296514, 0.60727077, 0.77760181]]) 26 | 27 | self.scene_bbox = (np.array([[-1.0, -1.0, -1.0, 0], [1.0, 1.0, 1.0, 2]])).tolist() 28 | # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]] 29 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 30 | self.read_meta() 31 | self.define_proj_mat() 32 | 33 | self.white_bg = True 34 | self.near_far = [2.0, 6.0] 35 | 36 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 37 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 38 | 39 | def read_depth(self, filename): 40 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 41 | return depth 42 | 43 | def read_meta(self): 44 | 45 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 46 | self.meta = json.load(f) 47 | 48 | w, h = self.img_wh 49 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 50 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh 51 | 52 | # ray directions for all pixels, same for all images (same H, W, focal) 53 | self.directions = get_ray_directions(h, w, [self.focal, self.focal]) # (h, w, 3) 54 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 55 | self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float() 56 | 57 | self.image_paths = [] 58 | self.poses = [] 59 | self.all_rays = [] 60 | self.all_rgbs = [] 61 | self.all_masks = [] 62 | self.all_depth = [] 63 | self.downsample = 1.0 64 | 65 | img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 66 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 67 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:# 68 | 69 | frame = self.meta['frames'][i] 70 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 71 | c2w = torch.FloatTensor(pose) 72 | c2w[:3,-1] /= 1.5 73 | self.poses += [c2w] 74 | 75 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 76 | self.image_paths += [image_path] 77 | img = Image.open(image_path) 78 | 79 | if self.downsample != 1.0: 80 | img = img.resize(self.img_wh, Image.LANCZOS) 81 | img = self.transform(img) # (4, h, w) 82 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 83 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 84 | self.all_rgbs += [img] 85 | 86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 87 | # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot 88 | 89 | scene_id = torch.ones_like(rays_o[...,:1])*0 90 | self.all_rays += [torch.cat([rays_o, rays_d, scene_id], 1)] # (h*w, 6) 91 | 92 | self.poses = torch.stack(self.poses) 93 | if not self.is_stack: 94 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 95 | self.all_rgbs = 
torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 96 | 97 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 98 | else: 99 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 100 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3) 101 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 102 | 103 | def define_transforms(self): 104 | self.transform = T.ToTensor() 105 | 106 | def define_proj_mat(self): 107 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3] 108 | 109 | # def world2ndc(self,points,lindisp=None): 110 | # device = points.device 111 | # return (points - self.center.to(device)) / self.radius.to(device) 112 | 113 | def __len__(self): 114 | return len(self.all_rgbs) 115 | 116 | def __getitem__(self, idx): 117 | rays = torch.cat((self.all_rays[idx],torch.tensor([0+0.5])),dim=-1) 118 | sample = {'rays': rays, 'rgbs': self.all_rgbs[idx]} 119 | return sample -------------------------------------------------------------------------------- /dataLoader/colmap.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | from torch.utils.data import Dataset 3 | 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | class ColmapDataset(Dataset): 13 | def __init__(self, cfg, split='train'): 14 | 15 | self.cfg = cfg 16 | self.root_dir = cfg.datadir 17 | self.split = split 18 | self.is_stack = False if 'train'==split else True 19 | self.downsample = cfg.get(f'downsample_{self.split}') 20 | self.define_transforms() 21 | self.img_eval_interval = 8 22 | 23 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])#np.eye(4)# 24 | self.read_meta() 25 | 26 | self.white_bg = cfg.get('white_bg') 27 | 28 | # self.near_far = [0.1,2.0] 29 | 30 | 31 | def read_meta(self): 32 | 33 | if os.path.exists(f'{self.root_dir}/transforms.json'): 34 | self.meta = load_json(f'{self.root_dir}/transforms.json') 35 | i_test = np.arange(0, len(self.meta['frames']), self.img_eval_interval) # [np.argmin(dists)] 36 | idxs = i_test if self.split != 'train' else list(set(np.arange(len(self.meta['frames']))) - set(i_test)) 37 | else: 38 | self.meta = load_json(f'{self.root_dir}/transforms_{self.split}.json') 39 | idxs = np.arange(0, len(self.meta['frames'])) # [np.argmin(dists)] 40 | inv_split = 'train' if self.split!='train' else 'test' 41 | self.meta['frames'] += load_json(f'{self.root_dir}/transforms_{inv_split}.json')['frames'] 42 | print(len(self.meta['frames']),len(idxs)) 43 | 44 | 45 | self.scale = self.meta.get('scale', 0.5) 46 | self.offset = torch.FloatTensor(self.meta.get('offset', [0.0,0.0,0.0])) 47 | # self.scene_bbox = (torch.tensor([[-6.,-7.,-10.0],[6.,7.,10.]])/5).tolist() 48 | # self.scene_bbox = [[-1., -1., -1.0], [1., 1., 1.]] 49 | 50 | # center, radius = torch.tensor([-0.082157, 2.415426,-3.703080]), torch.tensor([7.36916, 11.34958, 20.1616])/2 51 | # self.scene_bbox = torch.stack([center-radius, center+radius]).tolist() 52 | 53 | h, w = int(self.meta.get('h')), int(self.meta.get('w')) 54 | cx, cy = self.meta.get('cx'), self.meta.get('cy') 55 | self.focal = [self.meta.get('fl_x'), self.meta.get('fl_y')] 56 | 57 | # ray directions for all pixels, same for all images (same H, W, focal) 58 | 
self.directions = get_ray_directions(h, w, self.focal, center=[cx, cy]) # (h, w, 3) 59 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 60 | # self.intrinsics = torch.FloatTensor([[self.focal[0], 0, cx], [0, self.focal[1], cy], [0, 0, 1]]) 61 | # self.intrinsics[:2] /= self.downsample 62 | 63 | poses = pose_from_json(self.meta, self.blender2opencv) 64 | poses, self.scene_bbox = orientation(poses, f'{self.root_dir}/colmap_text/points3D.txt') 65 | 66 | self.image_paths = [] 67 | self.poses = [] 68 | self.all_rays = [] 69 | self.all_rgbs = [] 70 | self.all_masks = [] 71 | self.all_depth = [] 72 | 73 | self.img_wh = [w,h] 74 | 75 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:# 76 | 77 | frame = self.meta['frames'][i] 78 | c2w = torch.FloatTensor(poses[i]) 79 | # c2w[:3,3] = (c2w[:3,3]*self.scale + self.offset)*2-1 80 | self.poses += [c2w] 81 | 82 | image_path = os.path.join(self.root_dir, frame['file_path']) 83 | self.image_paths += [image_path] 84 | img = Image.open(image_path) 85 | 86 | if self.downsample != 1.0: 87 | img = img.resize(self.img_wh, Image.LANCZOS) 88 | 89 | img = self.transform(img) 90 | if img.shape[0]==4: 91 | img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB 92 | img = img.view(3, -1).permute(1, 0) 93 | self.all_rgbs += [img] 94 | 95 | 96 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 97 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 98 | 99 | self.poses = torch.stack(self.poses) 100 | if not self.is_stack: 101 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 102 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 103 | else: 104 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 105 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3) 106 | 107 | def define_transforms(self): 108 | self.transform = T.ToTensor() 109 | 110 | 111 | def __len__(self): 112 | return len(self.all_rgbs) 113 | 114 | def __getitem__(self, idx): 115 | 116 | if self.split == 'train': # use data in the buffers 117 | sample = {'rays': self.all_rays[idx], 118 | 'rgbs': self.all_rgbs[idx]} 119 | 120 | else: # create data for each image separately 121 | 122 | img = self.all_rgbs[idx] 123 | rays = self.all_rays[idx] 124 | 125 | sample = {'rays': rays, 126 | 'rgbs': img} 127 | return sample -------------------------------------------------------------------------------- /dataLoader/colmap2nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
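#
# Illustrative usage (a sketch only: the flags are the ones defined in parse_args() below, while the
# image path and working directory are placeholders rather than values taken from this repository):
#
#   python dataLoader/colmap2nerf.py --run_colmap --colmap_matcher exhaustive \
#       --images path/to/images --out transforms.json
#
# The script writes a transforms.json holding the shared intrinsics (camera_angle_x/_y, fl_x, fl_y,
# k1, k2, p1, p2, cx, cy, w, h, aabb_scale) plus a "frames" list with a file_path, a sharpness score
# and a 4x4 camera-to-world transform_matrix per image.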
10 | 11 | import argparse 12 | import os 13 | from pathlib import Path, PurePosixPath 14 | 15 | import numpy as np 16 | import json 17 | import sys 18 | import math 19 | import cv2 20 | import os 21 | import shutil 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") 25 | 26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") 27 | parser.add_argument("--video_fps", default=2) 28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") 29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") 31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 32 | parser.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL","OPENCV"], help="camera model") 33 | parser.add_argument("--colmap_camera_params", default="", help="intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist") 34 | parser.add_argument("--images", default="images", help="input path to the images") 35 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") 36 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16") 37 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start") 38 | parser.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)") 39 | parser.add_argument("--out", default="transforms.json", help="output path") 40 | parser.add_argument("--vocab_path", default="", help="vocabulary tree path") 41 | args = parser.parse_args() 42 | return args 43 | 44 | def do_system(arg): 45 | print(f"==== running: {arg}") 46 | err = os.system(arg) 47 | if err: 48 | print("FATAL: command failed") 49 | sys.exit(err) 50 | 51 | def run_ffmpeg(args): 52 | if not os.path.isabs(args.images): 53 | args.images = os.path.join(os.path.dirname(args.video_in), args.images) 54 | images = args.images 55 | video = args.video_in 56 | fps = float(args.video_fps) or 1.0 57 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") 58 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? 
(Y/n)").lower().strip()+"y")[:1] != "y": 59 | sys.exit(1) 60 | try: 61 | shutil.rmtree(images) 62 | except: 63 | pass 64 | do_system(f"mkdir {images}") 65 | 66 | time_slice_value = "" 67 | time_slice = args.time_slice 68 | if time_slice: 69 | start, end = time_slice.split(",") 70 | time_slice_value = f",select='between(t\,{start}\,{end})'" 71 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") 72 | 73 | def run_colmap(args): 74 | db = args.colmap_db 75 | images = "\"" + args.images + "\"" 76 | db_noext=str(Path(db).with_suffix("")) 77 | 78 | if args.text=="text": 79 | args.text=db_noext+"_text" 80 | text=args.text 81 | sparse=db_noext+"_sparse" 82 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") 83 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 84 | sys.exit(1) 85 | if os.path.exists(db): 86 | os.remove(db) 87 | do_system(f"colmap feature_extractor --ImageReader.camera_model {args.colmap_camera_model} --ImageReader.camera_params \"{args.colmap_camera_params}\" --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}") 88 | match_cmd = f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}" 89 | if args.vocab_path: 90 | match_cmd += f" --VocabTreeMatching.vocab_tree_path {args.vocab_path}" 91 | do_system(match_cmd) 92 | try: 93 | shutil.rmtree(sparse) 94 | except: 95 | pass 96 | do_system(f"mkdir {sparse}") 97 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") 98 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") 99 | try: 100 | shutil.rmtree(text) 101 | except: 102 | pass 103 | do_system(f"mkdir {text}") 104 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") 105 | 106 | 107 | def variance_of_laplacian(image): 108 | return cv2.Laplacian(image, cv2.CV_64F).var() 109 | 110 | def sharpness(imagePath): 111 | image = cv2.imread(imagePath) 112 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 113 | fm = variance_of_laplacian(gray) 114 | return fm 115 | 116 | def qvec2rotmat(qvec): 117 | return np.array([ 118 | [ 119 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 120 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 121 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 122 | ], [ 123 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 124 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 125 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 126 | ], [ 127 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 128 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 129 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 130 | ] 131 | ]) 132 | 133 | def rotmat(a, b): 134 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 135 | v = np.cross(a, b) 136 | c = np.dot(a, b) 137 | s = np.linalg.norm(v) 138 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 139 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 140 | 141 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 142 | da = da / np.linalg.norm(da) 143 | db = db / np.linalg.norm(db) 144 | c = np.cross(da, db) 145 | denom = np.linalg.norm(c)**2 146 
| t = ob - oa 147 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 148 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 149 | if ta > 0: 150 | ta = 0 151 | if tb > 0: 152 | tb = 0 153 | return (oa+ta*da+ob+tb*db) * 0.5, denom 154 | 155 | ############ orientation ############## 156 | def normalize(x): 157 | return x / np.linalg.norm(x) 158 | 159 | def rotation_matrix_from_vectors(vec1, vec2): 160 | """ Find the rotation matrix that aligns vec1 to vec2 161 | :param vec1: A 3d "source" vector 162 | :param vec2: A 3d "destination" vector 163 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2. 164 | """ 165 | a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3) 166 | v = np.cross(a, b) 167 | if any(v): # if not all zeros then 168 | c = np.dot(a, b) 169 | s = np.linalg.norm(v) 170 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 171 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2)) 172 | 173 | else: 174 | return np.eye(3) # cross of all zeros only occurs on identical directions 175 | 176 | def rotation_up(poses): 177 | up = normalize(np.linalg.lstsq(poses[:, :3, 1],np.ones((poses.shape[0],1)),rcond=None)[0]) 178 | rot = rotation_matrix_from_vectors(up,np.array([0.,1.,0.])) 179 | return rot 180 | 181 | def search_orientation(points): 182 | from scipy.spatial.transform import Rotation as R 183 | bbox_sizes,rot_mats,bboxs = [],[],[] 184 | for y_angle in np.linspace(-45,45,15): 185 | rotvec = np.array([0,y_angle,0])/180*np.pi 186 | rot = R.from_rotvec(rotvec).as_matrix() 187 | point_orientation = rot@points 188 | bbox = np.max(point_orientation,axis=1) - np.min(point_orientation,axis=1) 189 | bbox_sizes.append(np.prod(bbox)) 190 | rot_mats.append(rot) 191 | bboxs.append(bbox) 192 | rot = rot_mats[np.argmin(bbox_sizes)] 193 | bbox = bboxs[np.argmin(bbox_sizes)] 194 | return rot,bbox 195 | 196 | def load_point_txt(path): 197 | points = [] 198 | with open(path, "r") as f: 199 | for line in f: 200 | if line[0] == "#": 201 | continue 202 | els = line.split(" ") 203 | points.append([float(els[1]),float(els[2]),float(els[3])]) 204 | return np.stack(points) 205 | 206 | # def load_c2ws(frames): 207 | # c2ws = 208 | # for f in frames: 209 | # f["transform_matrix"] -= totp 210 | # 211 | # def oritation(transform_matrix, point_txt): 212 | 213 | 214 | if __name__ == "__main__": 215 | args = parse_args() 216 | if args.video_in != "": 217 | run_ffmpeg(args) 218 | if args.run_colmap: 219 | run_colmap(args) 220 | AABB_SCALE = int(args.aabb_scale) 221 | SKIP_EARLY = int(args.skip_early) 222 | IMAGE_FOLDER = args.images 223 | TEXT_FOLDER = args.text 224 | OUT_PATH = args.out 225 | print(f"outputting to {OUT_PATH}...") 226 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: 227 | angle_x = math.pi / 2 228 | for line in f: 229 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 230 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 231 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 232 | if line[0] == "#": 233 | continue 234 | els = line.split(" ") 235 | w = float(els[2]) 236 | h = float(els[3]) 237 | fl_x = float(els[4]) 238 | fl_y = float(els[4]) 239 | k1 = 0 240 | k2 = 0 241 | p1 = 0 242 | p2 = 0 243 | cx = w / 2 244 | cy = h / 2 245 | if els[1] == "SIMPLE_PINHOLE": 246 | cx = float(els[5]) 247 | cy = float(els[6]) 248 | elif els[1] == "PINHOLE": 249 | fl_y = float(els[5]) 250 | cx = float(els[6]) 251 | cy = float(els[7]) 252 | elif 
els[1] == "SIMPLE_RADIAL": 253 | cx = float(els[5]) 254 | cy = float(els[6]) 255 | k1 = float(els[7]) 256 | elif els[1] == "RADIAL": 257 | cx = float(els[5]) 258 | cy = float(els[6]) 259 | k1 = float(els[7]) 260 | k2 = float(els[8]) 261 | elif els[1] == "OPENCV": 262 | fl_y = float(els[5]) 263 | cx = float(els[6]) 264 | cy = float(els[7]) 265 | k1 = float(els[8]) 266 | k2 = float(els[9]) 267 | p1 = float(els[10]) 268 | p2 = float(els[11]) 269 | else: 270 | print("unknown camera model ", els[1]) 271 | # fl = 0.5 * w / tan(0.5 * angle_x); 272 | angle_x = math.atan(w / (fl_x * 2)) * 2 273 | angle_y = math.atan(h / (fl_y * 2)) * 2 274 | fovx = angle_x * 180 / math.pi 275 | fovy = angle_y * 180 / math.pi 276 | 277 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") 278 | 279 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: 280 | i = 0 281 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 282 | out = { 283 | "camera_angle_x": angle_x, 284 | "camera_angle_y": angle_y, 285 | "fl_x": fl_x, 286 | "fl_y": fl_y, 287 | "k1": k1, 288 | "k2": k2, 289 | "p1": p1, 290 | "p2": p2, 291 | "cx": cx, 292 | "cy": cy, 293 | "w": w, 294 | "h": h, 295 | "aabb_scale": AABB_SCALE, 296 | "frames": [], 297 | } 298 | 299 | up = np.zeros(3) 300 | for line in f: 301 | line = line.strip() 302 | if line[0] == "#": 303 | continue 304 | i = i + 1 305 | if i < SKIP_EARLY*2: 306 | continue 307 | if i % 2 == 1: 308 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 309 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) 310 | # why is this requireing a relitive path while using ^ 311 | image_rel = os.path.relpath(IMAGE_FOLDER) 312 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}") 313 | b=sharpness(name) 314 | print(name, "sharpness=",b) 315 | image_id = int(elems[0]) 316 | qvec = np.array(tuple(map(float, elems[1:5]))) 317 | tvec = np.array(tuple(map(float, elems[5:8]))) 318 | R = qvec2rotmat(-qvec) 319 | t = tvec.reshape([3,1]) 320 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 321 | c2w = np.linalg.inv(m) 322 | # c2w[0:3,2] *= -1 # flip the y and z axis 323 | # c2w[0:3,1] *= -1 324 | # c2w = c2w[[1,0,2,3],:] # swap y and z 325 | # c2w[2,:] *= -1 # flip whole world upside down 326 | 327 | # up += c2w[0:3,1] 328 | 329 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w} 330 | out["frames"].append(frame) 331 | 332 | # up = up / np.linalg.norm(up) 333 | # print("up vector was", up) 334 | # R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] 335 | # R = np.pad(R,[0,1]) 336 | # R[-1, -1] = 1 337 | # for f in out["frames"]: 338 | # f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 339 | 340 | nframes = len(out["frames"]) 341 | 342 | # find a central point they are all looking at 343 | print("computing center of attention...") 344 | totw = 0.0 345 | totp = np.array([0.0, 0.0, 0.0]) 346 | for f in out["frames"]: 347 | mf = f["transform_matrix"][0:3,:] 348 | for g in out["frames"]: 349 | mg = g["transform_matrix"][0:3,:] 350 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 351 | if w > 0.01: 352 | totp += p*w 353 | totw += w 354 | totp /= totw 355 | print(totp) # the cameras are looking at totp 356 | for f in out["frames"]: 357 | f["transform_matrix"][0:3,3] -= totp 358 | 359 | avglen = 0. 
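# The loop below accumulates the average camera distance from the origin (the translations were
# already re-centered on the shared "center of attention" totp above), then rescales every camera
# position by 4.0 / avglen so the scene ends up "nerf sized". For example, an avglen of 8.0 would
# shrink all camera positions by a factor of 0.5.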
360 | for f in out["frames"]: 361 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 362 | avglen /= nframes 363 | print("avg camera distance from origin", avglen) 364 | for f in out["frames"]: 365 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 366 | 367 | 368 | 369 | for f in out["frames"]: 370 | f["transform_matrix"] = f["transform_matrix"].tolist() 371 | print(nframes,"frames") 372 | print(f"writing {OUT_PATH}") 373 | with open(OUT_PATH, "w") as outfile: 374 | json.dump(out, outfile, indent=2) -------------------------------------------------------------------------------- /dataLoader/dtu_objs.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from .ray_utils import * 8 | from torch.utils.data import Dataset 9 | 10 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 11 | def load_K_Rt_from_P(filename, P=None): 12 | if P is None: 13 | lines = open(filename).read().splitlines() 14 | if len(lines) == 4: 15 | lines = lines[1:] 16 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 17 | P = np.asarray(lines).astype(np.float32).squeeze() 18 | 19 | out = cv.decomposeProjectionMatrix(P) 20 | K = out[0] 21 | R = out[1] 22 | t = out[2] 23 | 24 | K = K / K[2, 2] 25 | intrinsics = np.eye(4) 26 | intrinsics[:3, :3] = K 27 | 28 | pose = np.eye(4, dtype=np.float32) 29 | pose[:3, :3] = R.transpose() 30 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 31 | 32 | return intrinsics, pose 33 | 34 | def fps_downsample(points, n_points_to_sample): 35 | selected_points = np.zeros((n_points_to_sample, 3)) 36 | selected_idxs = [] 37 | dist = np.ones(points.shape[0]) * 100 38 | for i in range(n_points_to_sample): 39 | idx = np.argmax(dist) 40 | selected_points[i] = points[idx] 41 | selected_idxs.append(idx) 42 | dist_ = ((points - selected_points[i]) ** 2).sum(-1) 43 | dist = np.minimum(dist, dist_) 44 | 45 | return selected_idxs 46 | 47 | class DTUDataset(Dataset): 48 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None): 49 | """ 50 | img_wh should be set to a tuple ex: (1152, 864) to enable test mode! 
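        (Note: in this loader img_wh is not a constructor argument; it is derived below from the
        per-split downsample factor in cfg, and the scenes come from cfg.train_scene_list /
        cfg.test_scene_list.)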
51 | """ 52 | # self.N_vis = N_vis 53 | self.split = split 54 | self.batch_size = batch_size 55 | self.root_dir = cfg.datadir 56 | self.is_stack = is_stack if is_stack is not None else 'train'!=split 57 | self.downsample = cfg.get(f'downsample_{self.split}') 58 | self.img_wh = (int(400 / self.downsample), int(300 / self.downsample)) 59 | 60 | train_scene_idxs = sorted(cfg.train_scene_list) 61 | test_scene_idxs = cfg.test_scene_list 62 | if len(train_scene_idxs)==2: 63 | train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1])) 64 | self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs 65 | print(self.scene_idxs) 66 | self.train_views = cfg.train_views 67 | self.scene_num = len(self.scene_idxs) 68 | self.test_index = test_scene_idxs 69 | # if 'test' == self.split: 70 | # self.test_index = train_scene_idxs.index(test_scene_idxs[0]) 71 | 72 | self.scene_path_list = [os.path.join(self.root_dir, f"scan{i}") for i in self.scene_idxs] 73 | # self.scene_path_list = sorted(glob(os.path.join(self.root_dir, "scan*"))) 74 | 75 | self.read_meta() 76 | self.white_bg = False 77 | 78 | def read_meta(self): 79 | self.aabbs = [] 80 | self.all_rgb_files,self.all_pose_files,self.all_intrinsics_files = {},{},{} 81 | for i, scene_idx in enumerate(self.scene_idxs): 82 | 83 | scene_path = self.scene_path_list[i] 84 | camera_dict = np.load(os.path.join(scene_path, 'cameras.npz')) 85 | 86 | self.all_rgb_files[scene_idx] = [ 87 | os.path.join(scene_path, "image", f) 88 | for f in sorted(os.listdir(os.path.join(scene_path, "image"))) 89 | ] 90 | 91 | # world_mat is a projection matrix from world to image 92 | n_images = len(self.all_rgb_files[scene_idx]) 93 | world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 94 | scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 95 | object_scale_mat = camera_dict['scale_mat_0'] 96 | self.aabbs.append(self.get_bbox(scale_mats_np,object_scale_mat)) 97 | 98 | # W,H = self.img_wh 99 | intrinsics_scene, poses_scene = [], [] 100 | for img_idx, (scale_mat, world_mat) in enumerate(zip(scale_mats_np, world_mats_np)): 101 | P = world_mat @ scale_mat 102 | P = P[:3, :4] 103 | intrinsic, c2w = load_K_Rt_from_P(None, P) 104 | 105 | c2w = torch.from_numpy(c2w).float() 106 | intrinsic = torch.from_numpy(intrinsic).float() 107 | intrinsic[:2] /= self.downsample 108 | 109 | poses_scene.append(c2w) 110 | intrinsics_scene.append(intrinsic) 111 | 112 | self.all_pose_files[scene_idx] = np.stack(poses_scene) 113 | self.all_intrinsics_files[scene_idx] = np.stack(intrinsics_scene) 114 | 115 | self.aabbs[0][0].append(0) 116 | self.aabbs[0][1].append(self.scene_num) 117 | self.scene_bbox = self.aabbs[0] 118 | print(self.scene_bbox) 119 | if self.split=='test' or self.scene_num==1: 120 | self.load_data(self.scene_idxs[0],range(49)) 121 | 122 | def load_data(self, scene_idx, img_idx=None): 123 | self.all_rays = [] 124 | 125 | n_views = len(self.all_pose_files[scene_idx]) 126 | cam_xyzs = self.all_pose_files[scene_idx][:,:3, -1] 127 | idxs = fps_downsample(cam_xyzs, min(self.train_views, n_views)) if img_idx is None else img_idx 128 | # if "test" == self.split: 129 | # idxs = [item for item in list(range(n_views)) if item not in idxs] 130 | # if len(idxs)==0: 131 | # idxs = list(range(n_views)) 132 | 133 | images_np = np.stack([cv.resize(cv.imread(self.all_rgb_files[scene_idx][idx]), self.img_wh) for idx in idxs]) / 255.0 134 | self.all_rgbs = 
torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]]) # [n_images, H, W, 3] 135 | 136 | for c2w,intrinsic in zip(self.all_pose_files[scene_idx][idxs],self.all_intrinsics_files[scene_idx][idxs]): 137 | rays_o, rays_d = self.gen_rays_at(torch.from_numpy(intrinsic).float(), torch.from_numpy(c2w).float()) 138 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 139 | 140 | if not self.is_stack: 141 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 142 | self.all_rgbs = self.all_rgbs.reshape(-1, 3) 143 | else: 144 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 145 | self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 146 | 147 | # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size) 148 | 149 | 150 | # def read_meta(self): 151 | # 152 | # images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png'))) 153 | # images_np = np.stack([cv.resize(cv.imread(im_name), self.img_wh) for im_name in images_lis]) / 255.0 154 | # # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png'))) 155 | # # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128 156 | # 157 | # self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]]) # [n_images, H, W, 3] 158 | # # self.all_masks = torch.from_numpy(masks_np>0) # [n_images, H, W, 3] 159 | # self.img_wh = [self.all_rgbs.shape[2], self.all_rgbs.shape[1]] 160 | # 161 | # # world_mat is a projection matrix from world to image 162 | # n_images = len(images_lis) 163 | # world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 164 | # self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 165 | # 166 | # # W,H = self.img_wh 167 | # self.all_rays = [] 168 | # self.intrinsics, self.poses = [], [] 169 | # for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)): 170 | # P = world_mat @ scale_mat 171 | # P = P[:3, :4] 172 | # intrinsic, c2w = load_K_Rt_from_P(None, P) 173 | # 174 | # c2w = torch.from_numpy(c2w).float() 175 | # intrinsic = torch.from_numpy(intrinsic).float() 176 | # intrinsic[:2] /= self.downsample 177 | # 178 | # self.poses.append(c2w) 179 | # self.intrinsics.append(intrinsic) 180 | # 181 | # rays_o, rays_d = self.gen_rays_at(intrinsic, c2w) 182 | # self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 183 | # 184 | # self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses) 185 | # 186 | # # self.all_rgbs[~self.all_masks] = 1.0 187 | # if not self.is_stack: 188 | # self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 189 | # self.all_rgbs = self.all_rgbs.reshape(-1, 3) 190 | # else: 191 | # self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 192 | # self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 193 | # 194 | # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size) 195 | 196 | def get_bbox(self, scale_mats_np, object_scale_mat): 197 | object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0]) 198 | object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0]) 199 | # Object scale mat: region of interest to **extract mesh** 200 | # object_scale_mat = np.load(os.path.join(scene_path, 'cameras.npz')) 201 | object_bbox_min = np.linalg.inv(scale_mats_np[0]) @ 
object_scale_mat @ object_bbox_min[:, None] 202 | object_bbox_max = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None] 203 | return [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()] 204 | # self.near_far = [2.125, 4.525] 205 | 206 | def gen_rays_at(self, intrinsic, c2w, resolution_level=1): 207 | """ 208 | Generate rays at world space from one camera. 209 | """ 210 | l = resolution_level 211 | W,H = self.img_wh 212 | tx = torch.linspace(0, W - 1, W // l)+0.5 213 | ty = torch.linspace(0, H - 1, H // l)+0.5 214 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 215 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 216 | intrinsic_inv = torch.inverse(intrinsic) 217 | p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 218 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 219 | rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 220 | rays_o = c2w[None, None, :3, 3].expand(rays_v.shape) # W, H, 3 221 | return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3) 222 | 223 | 224 | def __len__(self): 225 | return 1000000 #len(self.all_rays) 226 | 227 | def __getitem__(self, idx): 228 | idx = torch.randint(self.scene_num,(1,)).item() 229 | if self.scene_num >= 1: 230 | scene_name = self.scene_idxs[idx] 231 | img_idx = np.random.choice(len(self.all_rgb_files[scene_name]), size=6) 232 | self.load_data(scene_name, img_idx) 233 | 234 | idxs = np.random.choice(self.all_rays.shape[0], size=self.batch_size) 235 | 236 | return {'rays': self.all_rays[idxs], 'rgbs': self.all_rgbs[idxs], 'idx': idx} -------------------------------------------------------------------------------- /dataLoader/dtu_objs2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from .ray_utils import * 8 | from torch.utils.data import Dataset 9 | 10 | 11 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 12 | def load_K_Rt_from_P(filename, P=None): 13 | if P is None: 14 | lines = open(filename).read().splitlines() 15 | if len(lines) == 4: 16 | lines = lines[1:] 17 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 18 | P = np.asarray(lines).astype(np.float32).squeeze() 19 | 20 | out = cv.decomposeProjectionMatrix(P) 21 | K = out[0] 22 | R = out[1] 23 | t = out[2] 24 | 25 | K = K / K[2, 2] 26 | intrinsics = np.eye(4) 27 | intrinsics[:3, :3] = K 28 | 29 | pose = np.eye(4, dtype=np.float32) 30 | pose[:3, :3] = R.transpose() 31 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 32 | 33 | return intrinsics, pose 34 | 35 | class DTUDataset(Dataset): 36 | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None): 37 | """ 38 | img_wh should be set to a tuple ex: (1152, 864) to enable test mode! 
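        # Note: img_wh is likewise not passed in here; it is computed below from the per-split
        # downsample factor in cfg and later overwritten with the loaded image resolution in read_meta().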
39 | """ 40 | # self.N_vis = N_vis 41 | self.split = split 42 | self.batch_size = batch_size 43 | self.root_dir = cfg.datadir 44 | self.is_stack = is_stack if is_stack is not None else 'train'!=split 45 | self.downsample = cfg.get(f'downsample_{self.split}') 46 | self.img_wh = (int(400 / self.downsample), int(300 / self.downsample)) 47 | 48 | self.white_bg = False 49 | self.camera_dict = np.load(os.path.join(self.root_dir, 'cameras.npz')) 50 | 51 | self.read_meta() 52 | self.get_bbox() 53 | 54 | # def define_transforms(self): 55 | # self.transform = T.ToTensor() 56 | 57 | def get_bbox(self): 58 | object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0]) 59 | object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0]) 60 | # Object scale mat: region of interest to **extract mesh** 61 | object_scale_mat = np.load(os.path.join(self.root_dir, 'cameras.npz'))['scale_mat_0'] 62 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None] 63 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None] 64 | self.scene_bbox = [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()] 65 | self.scene_bbox[0].append(0) 66 | self.scene_bbox[1].append(1) 67 | 68 | def gen_rays_at(self, intrinsic, c2w, resolution_level=1): 69 | """ 70 | Generate rays at world space from one camera. 71 | """ 72 | l = resolution_level 73 | W,H = self.img_wh 74 | tx = torch.linspace(0, W - 1, W // l)+0.5 75 | ty = torch.linspace(0, H - 1, H // l)+0.5 76 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 77 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 78 | intrinsic_inv = torch.inverse(intrinsic) 79 | p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 80 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 81 | rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 82 | rays_o = c2w[None, None, :3, 3].expand(rays_v.shape) # W, H, 3 83 | return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3) 84 | 85 | def read_meta(self): 86 | 87 | images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png'))) 88 | images_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in images_lis]) / 255.0 89 | # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png'))) 90 | # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128 91 | 92 | self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[...,[2,1,0]]) # [n_images, H, W, 3] 93 | # self.all_masks = torch.from_numpy(masks_np>0) # [n_images, H, W, 3] 94 | self.img_wh = [self.all_rgbs.shape[2],self.all_rgbs.shape[1]] 95 | 96 | # world_mat is a projection matrix from world to image 97 | n_images = len(images_lis) 98 | world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 99 | self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] 100 | 101 | # W,H = self.img_wh 102 | self.all_rays = [] 103 | self.intrinsics, self.poses = [],[] 104 | for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)): 105 | P = world_mat @ scale_mat 106 | P = P[:3, :4] 107 | intrinsic, c2w = load_K_Rt_from_P(None, P) 108 | 109 | c2w = torch.from_numpy(c2w).float() 110 | intrinsic = torch.from_numpy(intrinsic).float() 111 | intrinsic[:2] /= self.downsample 112 | 113 | self.poses.append(c2w) 
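            # c2w is the camera-to-world matrix recovered by load_K_Rt_from_P from P = world_mat @ scale_mat;
            # intrinsic[:2] was already divided by the downsample factor a few lines above.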
114 | self.intrinsics.append(intrinsic) 115 | 116 | rays_o, rays_d = self.gen_rays_at(intrinsic,c2w) 117 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 118 | 119 | self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses) 120 | 121 | # self.all_rgbs[~self.all_masks] = 1.0 122 | if not self.is_stack: 123 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 124 | self.all_rgbs = self.all_rgbs.reshape(-1,3) 125 | else: 126 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 127 | self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1],3) # (len(self.meta['frames]),h,w,3) 128 | 129 | self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size) 130 | 131 | def __len__(self): 132 | return len(self.all_rays) 133 | 134 | def __getitem__(self, idx): 135 | idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,)) 136 | sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]} 137 | return sample -------------------------------------------------------------------------------- /dataLoader/image.py: -------------------------------------------------------------------------------- 1 | import torch,imageio,cv2 2 | from PIL import Image 3 | Image.MAX_IMAGE_PIXELS = 1000000000 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | 8 | _img_suffix = ['png','jpg','jpeg','bmp','tif'] 9 | 10 | def load(path): 11 | suffix = path.split('.')[-1] 12 | if suffix in _img_suffix: 13 | img = np.array(Image.open(path))#.convert('L') 14 | scale = 256.**(1+np.log2(np.max(img))//8)-1 15 | return img/scale 16 | elif 'exr' == suffix: 17 | return imageio.imread(path) 18 | elif 'npy' == suffix: 19 | return np.load(path) 20 | 21 | 22 | def srgb_to_linear(img): 23 | limit = 0.04045 24 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) 25 | 26 | class ImageDataset(Dataset): 27 | def __init__(self, cfg, batchsize, split='train', continue_sampling=False, tolinear=False, HW=-1, perscent=1.0, delete_region=None,mask=None): 28 | 29 | datadir = cfg.datadir 30 | self.batchsize = batchsize 31 | self.continue_sampling = continue_sampling 32 | img = load(datadir).astype(np.float32) 33 | if HW>0: 34 | img = cv2.resize(img,[HW,HW]) 35 | 36 | if tolinear: 37 | img = srgb_to_linear(img) 38 | self.img = torch.from_numpy(img) 39 | 40 | H,W = self.img.shape[:2] 41 | 42 | y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij') 43 | self.coordiante = torch.stack((x,y),-1).float()+0.5 44 | 45 | n_channel = self.img.shape[-1] 46 | self.image = self.img 47 | self.img, self.coordiante = self.img.reshape(H*W,-1), self.coordiante.reshape(H*W,2) 48 | 49 | # if continue_sampling: 50 | # coordiante_tmp = self.coordiante.view(1,1,-1,2)/torch.tensor([W,H])*2-1.0 51 | # self.img = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordiante_tmp, mode='bilinear', align_corners=True).reshape(self.img.shape[-1],-1).t() 52 | 53 | 54 | if 'train'==split: 55 | self.mask = torch.ones_like(y)>0 56 | if mask is not None: 57 | self.mask = mask>0 58 | print(torch.sum(mask)/1.0/HW/HW) 59 | elif delete_region is not None: 60 | 61 | if isinstance(delete_region[0], list): 62 | for item in delete_region: 63 | t_l_x,t_l_y,width,height = item 64 | self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False 65 | else: 66 | t_l_x,t_l_y,width,height = delete_region 67 | 
self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False 68 | else: 69 | index = torch.randperm(len(self.img))[:int(len(self.img)*perscent)] 70 | self.mask[:] = False 71 | self.mask.view(-1)[index] = True 72 | self.mask = self.mask.view(-1) 73 | self.image, self.coordiante = self.img[self.mask], self.coordiante[self.mask] 74 | else: 75 | self.image = self.img 76 | 77 | 78 | self.HW = [H,W] 79 | 80 | self.scene_bbox = [[0., 0.], [W, H]] 81 | cfg.aabb = self.scene_bbox 82 | # 83 | 84 | def __len__(self): 85 | return 10000 86 | 87 | def __getitem__(self, idx): 88 | H,W = self.HW 89 | device = self.image.device 90 | idx = torch.randint(0,len(self.image),(self.batchsize,), device=device) 91 | 92 | if self.continue_sampling: 93 | coordinate = self.coordiante[idx] + torch.rand((self.batchsize,2))-0.5 94 | coordinate_tmp = (coordinate.view(1,1,self.batchsize,2))/torch.tensor([W,H],device=device)*2-1.0 95 | rgb = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear', 96 | align_corners=False, padding_mode='border').reshape(self.img.shape[-1],-1).t() 97 | sample = {'rgb': rgb, 98 | 'xy': coordinate} 99 | else: 100 | sample = {'rgb': self.image[idx], 101 | 'xy': self.coordiante[idx]} 102 | 103 | return sample -------------------------------------------------------------------------------- /dataLoader/image_set.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset 5 | 6 | def srgb_to_linear(img): 7 | limit = 0.04045 8 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92) 9 | 10 | def load(path, HW=512): 11 | suffix = path.split('.')[-1] 12 | 13 | if 'npy' == suffix: 14 | img = np.load(path) 15 | # img = 0.3*img[...,:1] + 0.59*img[...,1:2] + 0.11*img[...,2:] 16 | 17 | if img.shape[-2]!=HW: 18 | for i in range(img.shape[0]): 19 | img[i] = cv2.resize(img[i],[HW,HW]) 20 | 21 | return img 22 | 23 | 24 | class ImageSetDataset(Dataset): 25 | def __init__(self, cfg, batchsize, split='train', continue_sampling=False, HW=512, N=10, tolinear=True): 26 | 27 | datadir = cfg.datadir 28 | self.batchsize = batchsize 29 | self.continue_sampling = continue_sampling 30 | imgs = load(datadir,HW=HW)[:N] 31 | 32 | 33 | self.imgs = torch.from_numpy(imgs).float()/255 34 | if tolinear: 35 | self.imgs = srgb_to_linear(self.imgs) 36 | 37 | D,H,W = self.imgs.shape[:3] 38 | 39 | y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij') 40 | self.coordinate = torch.stack((x,y),-1).float()+0.5 41 | 42 | self.imgs, self.coordinate = self.imgs.reshape(D,H*W,-1), self.coordinate.reshape(H*W,2) 43 | self.DHW = [D,H,W] 44 | 45 | self.scene_bbox = [[0., 0., 0.], [W, H, D]] 46 | cfg.aabb = self.scene_bbox 47 | # self.down_scale = 512.0/H 48 | 49 | 50 | def __len__(self): 51 | return 1000000 52 | 53 | def __getitem__(self, idx): 54 | D,H,W = self.DHW 55 | pix_idx = torch.randint(0,H*W,(self.batchsize,)) 56 | img_idx = torch.randint(0,D,(self.batchsize,)) 57 | 58 | 59 | if self.continue_sampling: 60 | coordinate = self.coordinate[pix_idx] + torch.rand((self.batchsize,2)) - 0.5 61 | coordinate = torch.cat((coordinate,img_idx.unsqueeze(-1)+0.5),dim=-1) 62 | coordinate_tmp = (coordinate.view(1,1,1,self.batchsize,3))/torch.tensor([W,H,D])*2-1.0 63 | rgb = F.grid_sample(self.imgs.view(1,D,H,W,-1).permute(0,4,1,2,3),coordinate_tmp, mode='bilinear', 64 | align_corners=False, 
padding_mode='border').reshape(self.imgs.shape[-1],-1).t() 65 | # coordinate[:,:2] *= self.down_scale 66 | sample = {'rgb': rgb, 67 | 'xy': coordinate} 68 | else: 69 | sample = {'rgb': self.imgs[img_idx,pix_idx], 70 | 'xy': torch.cat((self.coordinate[pix_idx],img_idx.unsqueeze(-1)+0.5),dim=-1)} 71 | # 'xy': torch.cat((self.coordiante[pix_idx],img_idx.expand_as(pix_idx).unsqueeze(-1)),dim=-1)} 72 | 73 | 74 | 75 | return sample -------------------------------------------------------------------------------- /dataLoader/llff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import glob 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | def normalize(v): 13 | """Normalize a vector.""" 14 | return v / np.linalg.norm(v) 15 | 16 | 17 | def average_poses(poses): 18 | """ 19 | Calculate the average pose, which is then used to center all poses 20 | using @center_poses. Its computation is as follows: 21 | 1. Compute the center: the average of pose centers. 22 | 2. Compute the z axis: the normalized average z axis. 23 | 3. Compute axis y': the average y axis. 24 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 25 | 5. Compute the y axis: z cross product x. 26 | 27 | Note that at step 3, we cannot directly use y' as y axis since it's 28 | not necessarily orthogonal to z axis. We need to pass from x to y. 29 | Inputs: 30 | poses: (N_images, 3, 4) 31 | Outputs: 32 | pose_avg: (3, 4) the average pose 33 | """ 34 | # 1. Compute the center 35 | center = poses[..., 3].mean(0) # (3) 36 | 37 | # 2. Compute the z axis 38 | z = normalize(poses[..., 2].mean(0)) # (3) 39 | 40 | # 3. Compute axis y' (no need to normalize as it's not the final output) 41 | y_ = poses[..., 1].mean(0) # (3) 42 | 43 | # 4. Compute the x axis 44 | x = normalize(np.cross(z, y_)) # (3) 45 | 46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 47 | y = np.cross(x, z) # (3) 48 | 49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 50 | 51 | return pose_avg 52 | 53 | 54 | def center_poses(poses, blender2opencv): 55 | """ 56 | Center the poses so that we can use NDC. 
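    Concretely, the average pose from average_poses() is lifted to a 4x4 homogeneous matrix and its
    inverse is applied to every pose, so the "average camera" ends up at the origin with identity
    orientation.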
57 | See https://github.com/bmild/nerf/issues/34 58 | Inputs: 59 | poses: (N_images, 3, 4) 60 | Outputs: 61 | poses_centered: (N_images, 3, 4) the centered poses 62 | pose_avg: (3, 4) the average pose 63 | """ 64 | poses = poses @ blender2opencv 65 | pose_avg = average_poses(poses) # (3, 4) 66 | pose_avg_homo = np.eye(4) 67 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 68 | pose_avg_homo = pose_avg_homo 69 | # by simply adding 0, 0, 0, 1 as the last row 70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 71 | poses_homo = \ 72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 73 | 74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 75 | # poses_centered = poses_centered @ blender2opencv 76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 77 | 78 | return poses_centered, pose_avg_homo 79 | 80 | 81 | def viewmatrix(z, up, pos): 82 | vec2 = normalize(z) 83 | vec1_avg = up 84 | vec0 = normalize(np.cross(vec1_avg, vec2)) 85 | vec1 = normalize(np.cross(vec2, vec0)) 86 | m = np.eye(4) 87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) 88 | return m 89 | 90 | 91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): 92 | render_poses = [] 93 | rads = np.array(list(rads) + [1.]) 94 | 95 | for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]: 96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 98 | render_poses.append(viewmatrix(z, up, c)) 99 | return render_poses 100 | 101 | 102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120): 103 | # center pose 104 | c2w = average_poses(c2ws_all) 105 | 106 | # Get average pose 107 | up = normalize(c2ws_all[:, :3, 1].sum(0)) 108 | 109 | # Find a reasonable "focus depth" for this dataset 110 | dt = 0.75 111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 113 | 114 | # Get radii for spiral path 115 | zdelta = near_fars.min() * .2 116 | tt = c2ws_all[:, :3, 3] 117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale 118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views) 119 | return np.stack(render_poses) 120 | 121 | 122 | class LLFFDataset(Dataset): 123 | def __init__(self, cfg , split='train', hold_every=8): 124 | """ 125 | spheric_poses: whether the images are taken in a spheric inward-facing manner 126 | default: False (forward-facing) 127 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 128 | """ 129 | 130 | self.root_dir = cfg.datadir 131 | self.split = split 132 | self.hold_every = hold_every 133 | self.is_stack = False if 'train' == split else True 134 | self.downsample = cfg.get(f'downsample_{self.split}') 135 | self.define_transforms() 136 | 137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 138 | self.read_meta() 139 | self.white_bg = False 140 | 141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])] 142 | self.near_far = [0.0, 1.0] 143 | self.scene_bbox = [[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]] 144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]]) 145 | # self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3) 146 | # self.invradius = 
1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 147 | 148 | def read_meta(self): 149 | 150 | print(self.root_dir) 151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17) 152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*'))) 153 | # load full resolution image then resize 154 | if self.split in ['train', 'test']: 155 | assert len(poses_bounds) == len(self.image_paths), \ 156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!' 157 | 158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2) 160 | hwf = poses[:, :, -1] 161 | 162 | # Step 1: rescale focal length according to training resolution 163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images 164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)]) 165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H] 166 | 167 | # Step 2: correct poses 168 | # Original poses has rotation in form "down right back", change to "right up back" 169 | # See https://github.com/bmild/nerf/issues/34 170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 171 | # (N_images, 3, 4) exclude H, W, focal 172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv) 173 | 174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 175 | # See https://github.com/bmild/nerf/issues/34 176 | near_original = self.near_fars.min() 177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 178 | # the nearest depth is at 1/0.75=1.33 179 | self.near_fars /= scale_factor 180 | self.poses[..., 3] /= scale_factor 181 | 182 | # build rendering path 183 | N_views, N_rots = 120, 2 184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 185 | up = normalize(self.poses[:, :3, 1].sum(0)) 186 | rads = np.percentile(np.abs(tt), 90, 0) 187 | 188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views) 189 | 190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to 192 | # center image 193 | 194 | # ray directions for all pixels, same for all images (same H, W, focal) 195 | W, H = self.img_wh 196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3) 197 | 198 | average_pose = average_poses(self.poses) 199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1) 200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)] 201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test)) 202 | 203 | # use first N_images-1 to train, the LAST is val 204 | self.all_rays = [] 205 | self.all_rgbs = [] 206 | for i in img_list: 207 | image_path = self.image_paths[i] 208 | c2w = torch.FloatTensor(self.poses[i]) 209 | 210 | img = Image.open(image_path).convert('RGB') 211 | if self.downsample != 1.0: 212 | img = img.resize(self.img_wh, Image.LANCZOS) 213 | img = self.transform(img) # (3, h, w) 214 | 215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 216 | self.all_rgbs += [img] 217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 220 | 221 | 
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 222 | 223 | if not self.is_stack: 224 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 225 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3) 226 | else: 227 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h,w, 3) 228 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 229 | 230 | 231 | def define_transforms(self): 232 | self.transform = T.ToTensor() 233 | 234 | def __len__(self): 235 | return len(self.all_rgbs) 236 | 237 | def __getitem__(self, idx): 238 | 239 | sample = {'rays': self.all_rays[idx], 240 | 'rgbs': self.all_rgbs[idx]} 241 | 242 | return sample -------------------------------------------------------------------------------- /dataLoader/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class NSVF(Dataset): 37 | """NSVF Generic Dataset.""" 38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False): 39 | self.root_dir = datadir 40 | self.split = split 41 | self.is_stack = is_stack 42 | self.downsample = downsample 43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 44 | self.define_transforms() 45 | 46 | self.white_bg = True 47 | self.near_far = [0.5,6.0] 48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3) 49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 50 | self.read_meta() 51 | self.define_proj_mat() 52 | 53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 55 | 56 | def bbox2corners(self): 57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 58 | for i in range(3): 59 | corners[i,[0,1],i] = corners[i,[1,0],i] 60 | return corners.view(-1,3) 61 | 62 | 63 | def read_meta(self): 64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f: 65 | focal = float(f.readline().split()[0]) 66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]]) 67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1) 68 | 69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 71 | 72 | if self.split == 'train': 73 | pose_files = [x for x in pose_files if x.startswith('0_')] 74 | img_files = [x for x in img_files if x.startswith('0_')] 75 | elif self.split == 
'val': 76 | pose_files = [x for x in pose_files if x.startswith('1_')] 77 | img_files = [x for x in img_files if x.startswith('1_')] 78 | elif self.split == 'test': 79 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 80 | test_img_files = [x for x in img_files if x.startswith('2_')] 81 | if len(test_pose_files) == 0: 82 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 83 | test_img_files = [x for x in img_files if x.startswith('1_')] 84 | pose_files = test_pose_files 85 | img_files = test_img_files 86 | 87 | # ray directions for all pixels, same for all images (same H, W, focal) 88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 90 | 91 | frames = 200 92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,frames+1)[:-1]], 0) 93 | 94 | self.poses = [] 95 | self.all_rays = [] 96 | self.all_rgbs = [] 97 | 98 | assert len(img_files) == len(pose_files) 99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 101 | img = Image.open(image_path) 102 | if self.downsample!=1.0: 103 | img = img.resize(self.img_wh, Image.LANCZOS) 104 | img = self.transform(img) # (4, h, w) 105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 106 | if img.shape[-1]==4: 107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 108 | self.all_rgbs += [img] 109 | 110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv 111 | c2w = torch.FloatTensor(c2w) 112 | self.poses.append(c2w) # C2W 113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 115 | 116 | # w2c = torch.inverse(c2w) 117 | # 118 | 119 | self.poses = torch.stack(self.poses) 120 | if 'train' == self.split: 121 | if self.is_stack: 122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 124 | else: 125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 127 | else: 128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 130 | 131 | 132 | def define_transforms(self): 133 | self.transform = T.ToTensor() 134 | 135 | def define_proj_mat(self): 136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 137 | 138 | def world2ndc(self, points): 139 | device = points.device 140 | return (points - self.center.to(device)) / self.radius.to(device) 141 | 142 | def __len__(self): 143 | if self.split == 'train': 144 | return len(self.all_rays) 145 | return len(self.all_rgbs) 146 | 147 | def __getitem__(self, idx): 148 | 149 | if self.split == 'train': # use data in the buffers 150 | sample = {'rays': self.all_rays[idx], 151 | 'rgbs': self.all_rgbs[idx]} 152 | 153 | else: # create data for each image separately 154 | 155 
| img = self.all_rgbs[idx] 156 | rays = self.all_rays[idx] 157 | 158 | sample = {'rays': rays, 159 | 'rgbs': img} 160 | return sample -------------------------------------------------------------------------------- /dataLoader/sdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | def N_to_reso(avg_reso, bbox): 6 | xyz_min, xyz_max = bbox 7 | dim = len(xyz_min) 8 | n_voxels = avg_reso**dim 9 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 10 | return torch.ceil((xyz_max - xyz_min) / voxel_size).long().tolist() 11 | 12 | def load(path, split, dtype='points'): 13 | 14 | if 'grid' == dtype: 15 | sdf = torch.from_numpy(np.load(path).astype(np.float32)) 16 | D, H, W = sdf.shape 17 | z, y, x = torch.meshgrid(torch.arange(0, D), torch.arange(0, H), torch.arange(0, W), indexing='ij') 18 | coordiante = torch.stack((x,y,z),-1).reshape(D*H*W,3)#*2-1 # normalize to [-1,1] 19 | sdf = sdf.reshape(D*H*W,-1) 20 | DHW = [D,H,W] 21 | elif 'points' == dtype: 22 | DHW = [640] * 3 23 | sdf_dict = np.load(path, allow_pickle=True).item() 24 | sdf = torch.from_numpy(sdf_dict[f'sdfs_{split}'].astype(np.float32)).reshape(-1,1) 25 | coordiante = torch.from_numpy(sdf_dict[f'points_{split}'].astype(np.float32)) 26 | aabb = [[-1,-1,-1],[1,1,1]] 27 | coordiante = ((coordiante + 1) / 2 * (torch.tensor(DHW[::-1]))).reshape(-1,3) 28 | DHW = DHW[::-1] 29 | return coordiante, sdf, DHW 30 | 31 | class SDFDataset(Dataset): 32 | def __init__(self, cfg, split='train'): 33 | 34 | datadir = cfg.datadir 35 | self.coordiante, self.sdf, self.DHW = load(datadir, split) 36 | 37 | [D, H, W] = self.DHW 38 | 39 | self.scene_bbox = [[0., 0., 0.], [W, H, D]] 40 | cfg.aabb = self.scene_bbox 41 | 42 | def __len__(self): 43 | return len(self.sdf) 44 | 45 | def __getitem__(self, idx): 46 | sample = {'rgb': self.sdf[idx], 47 | 'xy': self.coordiante[idx]} 48 | 49 | return sample -------------------------------------------------------------------------------- /dataLoader/tankstemple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | 11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 12 | if axis == 'z': 13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h] 14 | elif axis == 'y': 15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)] 16 | else: 17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)] 18 | 19 | 20 | def cross(x, y, axis=0): 21 | T = torch if isinstance(x, torch.Tensor) else np 22 | return T.cross(x, y, axis) 23 | 24 | 25 | def normalize(x, axis=-1, order=2): 26 | if isinstance(x, torch.Tensor): 27 | l2 = x.norm(p=order, dim=axis, keepdim=True) 28 | return x / (l2 + 1e-8), l2 29 | 30 | else: 31 | l2 = np.linalg.norm(x, order, axis) 32 | l2 = np.expand_dims(l2, axis) 33 | l2[l2 == 0] = 1 34 | return x / l2, 35 | 36 | 37 | def cat(x, axis=1): 38 | if isinstance(x[0], torch.Tensor): 39 | return torch.cat(x, dim=axis) 40 | return np.concatenate(x, axis=axis) 41 | 42 | 43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): 44 | """ 45 | This function takes a vector 'camera_position' which specifies the location 46 | of the camera in 
world coordinates and two vectors `at` and `up` which 47 | indicate the position of the object and the up directions of the world 48 | coordinate system respectively. The object is assumed to be centered at 49 | the origin. 50 | The output is a rotation matrix representing the transformation 51 | from world coordinates -> view coordinates. 52 | Input: 53 | camera_position: 3 54 | at: 1 x 3 or N x 3 (0, 0, 0) in default 55 | up: 1 x 3 or N x 3 (0, 1, 0) in default 56 | """ 57 | 58 | if at is None: 59 | at = torch.zeros_like(camera_position) 60 | else: 61 | at = torch.tensor(at).type_as(camera_position) 62 | if up is None: 63 | up = torch.zeros_like(camera_position) 64 | up[2] = -1 65 | else: 66 | up = torch.tensor(up).type_as(camera_position) 67 | 68 | z_axis = normalize(at - camera_position)[0] 69 | x_axis = normalize(cross(up, z_axis))[0] 70 | y_axis = normalize(cross(z_axis, x_axis))[0] 71 | 72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 73 | return R 74 | 75 | 76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180): 77 | c2ws = [] 78 | for t in range(frames): 79 | c2w = torch.eye(4) 80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi)) 81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True) 82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot 83 | c2ws.append(c2w) 84 | return torch.stack(c2ws) 85 | 86 | 87 | class TanksTempleDataset(Dataset): 88 | """NSVF Generic Dataset.""" 89 | 90 | def __init__(self, cfg, split='train'): 91 | self.root_dir = cfg.datadir 92 | self.split = split 93 | self.is_stack = False if 'train'==split else True 94 | self.downsample = cfg.get(f'downsample_{self.split}') 95 | self.img_wh = (int(1920 / self.downsample), int(1080 / self.downsample)) 96 | self.define_transforms() 97 | 98 | self.white_bg = True 99 | self.near_far = [0.01, 6.0] 100 | self.scene_bbox = (torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3) * 1.2).tolist() 101 | 102 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 103 | self.read_meta() 104 | self.define_proj_mat() 105 | 106 | # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 107 | # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 108 | 109 | def bbox2corners(self): 110 | corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1) 111 | for i in range(3): 112 | corners[i, [0, 1], i] = corners[i, [1, 0], i] 113 | return corners.view(-1, 3) 114 | 115 | def read_meta(self): 116 | 117 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt")) 118 | self.intrinsics[:2] *= (np.array(self.img_wh) / np.array([1920, 1080])).reshape(2, 1) 119 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 120 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 121 | 122 | if self.split == 'train': 123 | pose_files = [x for x in pose_files if x.startswith('0_')] 124 | img_files = [x for x in img_files if x.startswith('0_')] 125 | elif self.split == 'val': 126 | pose_files = [x for x in pose_files if x.startswith('1_')] 127 | img_files = [x for x in img_files if x.startswith('1_')] 128 | elif self.split == 'test': 129 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 130 | test_img_files = [x for x in img_files if x.startswith('2_')] 131 | if len(test_pose_files) == 0: 132 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 133 | test_img_files = [x for x in img_files if x.startswith('1_')] 134 
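# Note on the NSVF-style split convention used by the filename filters above:
# '0_*' files form the train split, '1_*' the validation split and '2_*' the
# test split; when a scene ships without any '2_*' images, the validation
# files are reused as the test set.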
| pose_files = test_pose_files 135 | img_files = test_img_files 136 | 137 | # ray directions for all pixels, same for all images (same H, W, focal) 138 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], 139 | [self.intrinsics[0, 0], self.intrinsics[1, 1]], 140 | center=self.intrinsics[:2, 2]) # (h, w, 3) 141 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 142 | 143 | self.poses = [] 144 | 145 | ray_per_frame = self.img_wh[0]*self.img_wh[1] 146 | self.all_rays = torch.empty(len(pose_files),ray_per_frame,6) 147 | self.all_rgbs = torch.empty(len(pose_files),ray_per_frame,3) 148 | assert len(img_files) == len(pose_files) 149 | for i in tqdm(range(len(pose_files)),desc=f'Loading data {self.split} ({len(img_files)})'): 150 | 151 | img_fname, pose_fname = img_files[i], pose_files[i] 152 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 153 | img = Image.open(image_path) 154 | if self.downsample != 1.0: 155 | img = img.resize(self.img_wh, Image.LANCZOS) 156 | img = self.transform(img) # (4, h, w) 157 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 158 | if img.shape[-1] == 4: 159 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 160 | # self.all_rgbs.append(img) 161 | 162 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) # @ cam_trans 163 | c2w = torch.FloatTensor(c2w) 164 | self.poses.append(c2w) # C2W 165 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 166 | 167 | self.all_rays[i] = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 168 | self.all_rgbs[i] = img 169 | 170 | self.poses = torch.stack(self.poses) 171 | 172 | frames = 200 173 | scene_bbox = torch.tensor(self.scene_bbox).float() 174 | center = torch.mean(scene_bbox, dim=0) 175 | radius = torch.norm(scene_bbox[1] - center) * 1.2 176 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist() 177 | pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y') 178 | self.render_path = gen_path(pos_gen, up=up, frames=frames) 179 | self.render_path[:, :3, 3] += center 180 | 181 | if 'train' == self.split: 182 | if not self.is_stack: 183 | # self.all_rays = torch.stack(self.all_rays, 0).reshape(-1, *self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 184 | # self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 185 | # else: 186 | self.all_rays = self.all_rays.reshape(-1,6) # (len(self.meta['frames])*h*w, 3) 187 | self.all_rgbs = self.all_rgbs.reshape(-1,3) # (len(self.meta['frames])*h*w, 3) 188 | else: 189 | # self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 190 | self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 191 | 192 | def define_transforms(self): 193 | self.transform = T.ToTensor() 194 | 195 | def define_proj_mat(self): 196 | self.proj_mat = torch.from_numpy(self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:, :3] 197 | 198 | def world2ndc(self, points): 199 | device = points.device 200 | return (points - self.center.to(device)) / self.radius.to(device) 201 | 202 | def __len__(self): 203 | if self.split == 'train': 204 | return len(self.all_rays) 205 | return len(self.all_rgbs) 206 | 207 | def __getitem__(self, idx): 208 | 209 | if self.split == 'train': # use data in the buffers 210 | sample = {'rays': self.all_rays[idx], 211 | 'rgbs': self.all_rgbs[idx]} 212 | 213 | else: # create data for each image separately 214 | 215 | img = 
self.all_rgbs[idx] 216 | rays = self.all_rays[idx] 217 | 218 | sample = {'rays': rays, 219 | 'rgbs': img} 220 | return sample -------------------------------------------------------------------------------- /dataLoader/your_own_data.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class YourOwnDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.downsample = downsample 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [0.1,100.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample) 45 | self.img_wh = [w,h] 46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 48 | self.cx, self.cy = self.meta['cx'],self.meta['cy'] 49 | 50 | 51 | # ray directions for all pixels, same for all images (same H, W, focal) 52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float() 55 | 56 | self.image_paths = [] 57 | self.poses = [] 58 | self.all_rays = [] 59 | self.all_rgbs = [] 60 | self.all_masks = [] 61 | self.all_depth = [] 62 | 63 | 64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 67 | 68 | frame = self.meta['frames'][i] 69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 70 | c2w = torch.FloatTensor(pose) 71 | self.poses += [c2w] 72 | 73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 74 | self.image_paths += [image_path] 75 | img = Image.open(image_path) 76 | 77 | if self.downsample!=1.0: 78 | img = img.resize(self.img_wh, Image.LANCZOS) 79 | img = self.transform(img) # (4, h, w) 80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA 81 | if img.shape[-1]==4: 82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 83 | self.all_rgbs += [img] 84 | 85 | 86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 87 | self.all_rays += [torch.cat([rays_o, 
rays_d], 1)] # (h*w, 6) 88 | 89 | 90 | self.poses = torch.stack(self.poses) 91 | if not self.is_stack: 92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 94 | 95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 96 | else: 97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 98 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 100 | 101 | 102 | def define_transforms(self): 103 | self.transform = T.ToTensor() 104 | 105 | def define_proj_mat(self): 106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 107 | 108 | def world2ndc(self,points,lindisp=None): 109 | device = points.device 110 | return (points - self.center.to(device)) / self.radius.to(device) 111 | 112 | def __len__(self): 113 | return len(self.all_rgbs) 114 | 115 | def __getitem__(self, idx): 116 | 117 | if self.split == 'train': # use data in the buffers 118 | sample = {'rays': self.all_rays[idx], 119 | 'rgbs': self.all_rgbs[idx]} 120 | 121 | else: # create data for each image separately 122 | 123 | img = self.all_rgbs[idx] 124 | rays = self.all_rays[idx] 125 | mask = self.all_masks[idx] # for quantity evaluation 126 | 127 | sample = {'rays': rays, 128 | 'rgbs': img} 129 | return sample 130 | -------------------------------------------------------------------------------- /media/Girl_with_a_Pearl_Earring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/media/Girl_with_a_Pearl_Earring.jpg -------------------------------------------------------------------------------- /media/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/media/inpainting.png -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/models/.DS_Store -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/models/__init__.py -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################## sh function ################## 4 | C0 = 0.28209479177387814 5 | C1 = 0.4886025119029199 6 | C2 = [ 7 | 1.0925484305920792, 8 | -1.0925484305920792, 9 | 0.31539156525252005, 10 | -1.0925484305920792, 11 | 0.5462742152960396 12 | ] 13 | C3 = [ 14 | -0.5900435899266435, 15 | 2.890611442640554, 16 | -0.4570457994644658, 17 | 0.3731763325901154, 18 | -0.4570457994644658, 19 | 1.445305721320277, 20 | -0.5900435899266435 21 | ] 22 | C4 = [ 23 | 2.5033429417967046, 24 | -1.7701307697799304, 25 | 
0.9461746957575601, 26 | -0.6690465435572892, 27 | 0.10578554691520431, 28 | -0.6690465435572892, 29 | 0.47308734787878004, 30 | -1.7701307697799304, 31 | 0.6258357354491761, 32 | ] 33 | 34 | def eval_sh(deg, sh, dirs): 35 | """ 36 | Evaluate spherical harmonics at unit directions 37 | using hardcoded SH polynomials. 38 | Works with torch/np/jnp. 39 | ... Can be 0 or more batch dimensions. 40 | :param deg: int SH max degree. Currently, 0-4 supported 41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) 42 | :param dirs: torch.Tensor unit directions (..., 3) 43 | :return: (..., C) 44 | """ 45 | assert deg <= 4 and deg >= 0 46 | assert (deg + 1) ** 2 == sh.shape[-1] 47 | C = sh.shape[-2] 48 | 49 | result = C0 * sh[..., 0] 50 | if deg > 0: 51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 52 | result = (result - 53 | C1 * y * sh[..., 1] + 54 | C1 * z * sh[..., 2] - 55 | C1 * x * sh[..., 3]) 56 | if deg > 1: 57 | xx, yy, zz = x * x, y * y, z * z 58 | xy, yz, xz = x * y, y * z, x * z 59 | result = (result + 60 | C2[0] * xy * sh[..., 4] + 61 | C2[1] * yz * sh[..., 5] + 62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 63 | C2[3] * xz * sh[..., 7] + 64 | C2[4] * (xx - yy) * sh[..., 8]) 65 | 66 | if deg > 2: 67 | result = (result + 68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 69 | C3[1] * xy * z * sh[..., 10] + 70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 73 | C3[5] * z * (xx - yy) * sh[..., 14] + 74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 75 | if deg > 3: 76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 85 | return result 86 | 87 | def eval_sh_bases(deg, dirs): 88 | """ 89 | Evaluate spherical harmonics bases at unit directions, 90 | without taking linear combination. 91 | At each point, the final result may the be 92 | obtained through simple multiplication. 93 | :param deg: int SH max degree. 
Currently, 0-4 supported 94 | :param dirs: torch.Tensor (..., 3) unit directions 95 | :return: torch.Tensor (..., (deg+1) ** 2) 96 | """ 97 | assert deg <= 4 and deg >= 0 98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) 99 | result[..., 0] = C0 100 | if deg > 0: 101 | x, y, z = dirs.unbind(-1) 102 | result[..., 1] = -C1 * y; 103 | result[..., 2] = C1 * z; 104 | result[..., 3] = -C1 * x; 105 | if deg > 1: 106 | xx, yy, zz = x * x, y * y, z * z 107 | xy, yz, xz = x * y, y * z, x * z 108 | result[..., 4] = C2[0] * xy; 109 | result[..., 5] = C2[1] * yz; 110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy); 111 | result[..., 7] = C2[3] * xz; 112 | result[..., 8] = C2[4] * (xx - yy); 113 | 114 | if deg > 2: 115 | result[..., 9] = C3[0] * y * (3 * xx - yy); 116 | result[..., 10] = C3[1] * xy * z; 117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy); 118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); 119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy); 120 | result[..., 14] = C3[5] * z * (xx - yy); 121 | result[..., 15] = C3[6] * x * (xx - 3 * yy); 122 | 123 | if deg > 3: 124 | result[..., 16] = C4[0] * xy * (xx - yy); 125 | result[..., 17] = C4[1] * yz * (3 * xx - yy); 126 | result[..., 18] = C4[2] * xy * (7 * zz - 1); 127 | result[..., 19] = C4[3] * yz * (7 * zz - 3); 128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); 129 | result[..., 21] = C4[5] * xz * (7 * zz - 3); 130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); 131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy); 132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); 133 | return result 134 | -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | import torch,os,imageio,sys 2 | from tqdm.auto import tqdm 3 | from dataLoader.ray_utils import get_rays 4 | from utils import * 5 | from dataLoader.ray_utils import ndc_rays_blender 6 | 7 | 8 | def render_ray(rays, factor_fields, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'): 9 | 10 | rgbs, alphas, depth_maps, weights, coeffs = [], [], [], [], [] 11 | N_rays_all = rays.shape[0] 12 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)): 13 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) 14 | 15 | if is_train: 16 | rgb_map, depth_map, coeff = factor_fields(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples) 17 | coeffs.append(coeff) 18 | else: 19 | rgb_map, depth_map, _ = factor_fields(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples) 20 | 21 | rgbs.append(rgb_map) 22 | depth_maps.append(depth_map) 23 | 24 | if is_train: 25 | return torch.cat(rgbs), torch.cat(depth_maps), torch.cat(coeffs) 26 | else: 27 | return torch.cat(rgbs), torch.cat(depth_maps) 28 | 29 | @torch.no_grad() 30 | def evaluation(test_dataset,factor_fields, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 31 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 32 | PSNRs, rgb_maps, depth_maps = [], [], [] 33 | ssims,l_alex,l_vgg=[],[],[] 34 | os.makedirs(savePath, exist_ok=True) 35 | os.makedirs(savePath+"/rgbd", exist_ok=True) 36 | 37 | try: 38 | tqdm._instances.clear() 39 | except Exception: 40 | pass 41 | 42 | torch.cuda.empty_cache() 43 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] 
// N_vis,1) 44 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval)) 45 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout): 46 | 47 | W, H = test_dataset.img_wh 48 | rays = samples.view(-1,samples.shape[-1]) 49 | 50 | rgb_map, depth_map = renderer(rays, factor_fields, chunk=1024, N_samples=N_samples, 51 | ndc_ray=ndc_ray, white_bg = white_bg, device=device) 52 | rgb_map = rgb_map.clamp(0.0, 1.0) 53 | 54 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 55 | 56 | depth_map, _ = visualize_depth_numpy(depth_map.numpy()) 57 | if len(test_dataset.all_rgbs): 58 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3) 59 | loss = torch.mean((rgb_map - gt_rgb) ** 2) 60 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0)) 61 | 62 | if compute_extra_metrics: 63 | ssim = rgb_ssim(rgb_map, gt_rgb, 1) 64 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', factor_fields.device) 65 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', factor_fields.device) 66 | ssims.append(ssim) 67 | l_alex.append(l_a) 68 | l_vgg.append(l_v) 69 | 70 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 71 | gt_rgb = (gt_rgb.numpy() * 255).astype('uint8') 72 | 73 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 74 | rgb_maps.append(rgb_map) 75 | depth_maps.append(depth_map) 76 | if savePath is not None: 77 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 78 | rgb_map = np.concatenate((rgb_map, gt_rgb, depth_map), axis=1) 79 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 80 | 81 | torch.cuda.empty_cache() 82 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10) 83 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10) 84 | 85 | n_params = factor_fields.n_parameters() 86 | if PSNRs: 87 | psnr = np.mean(np.asarray(PSNRs)) 88 | if compute_extra_metrics: 89 | ssim = np.mean(np.asarray(ssims)) 90 | l_a = np.mean(np.asarray(l_alex)) 91 | l_v = np.mean(np.asarray(l_vgg)) 92 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v, n_params])) 93 | print(f"PSNR={psnr} SSIM={ssim} {l_a} {l_v} ") 94 | else: 95 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr,n_params])) 96 | 97 | 98 | return PSNRs 99 | 100 | @torch.no_grad() 101 | def evaluation_path(test_dataset,factor_fields, c2ws, renderer, savePath=None, prtx='', N_samples=-1, 102 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 103 | PSNRs, rgb_maps, depth_maps = [], [], [] 104 | ssims,l_alex,l_vgg=[],[],[] 105 | os.makedirs(savePath, exist_ok=True) 106 | os.makedirs(savePath+"/rgbd", exist_ok=True) 107 | 108 | try: 109 | tqdm._instances.clear() 110 | except Exception: 111 | pass 112 | 113 | near_far = test_dataset.near_far 114 | for idx, c2w in tqdm(enumerate(c2ws)): 115 | 116 | W, H = test_dataset.img_wh 117 | 118 | c2w = torch.FloatTensor(c2w) 119 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3) 120 | if ndc_ray: 121 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d) 122 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 123 | 124 | rgb_map, depth_map = renderer(rays, factor_fields, chunk=8192, N_samples=N_samples, 125 | ndc_ray=ndc_ray, white_bg = white_bg, device=device) 126 | rgb_map = rgb_map.clamp(0.0, 1.0) 127 | 128 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 129 | 130 | 
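# Note: unlike `evaluation` above, the render path has no ground-truth images,
# so PSNRs stays empty here and only the predicted RGB and colorized depth
# frames are written out; the depth colorization below is additionally passed
# the dataset's near/far bounds.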
depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 131 | 132 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 133 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 134 | rgb_maps.append(rgb_map) 135 | depth_maps.append(depth_map) 136 | if savePath is not None: 137 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 138 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 139 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 140 | 141 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8) 142 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8) 143 | 144 | if PSNRs: 145 | psnr = np.mean(np.asarray(PSNRs)) 146 | if compute_extra_metrics: 147 | ssim = np.mean(np.asarray(ssims)) 148 | l_a = np.mean(np.asarray(l_alex)) 149 | l_v = np.mean(np.asarray(l_vgg)) 150 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 151 | else: 152 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 153 | 154 | 155 | return PSNRs 156 | 157 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.26.0 2 | kornia==0.6.10 3 | lpips==0.1.4 4 | matplotlib==3.7.1 5 | omegaconf==2.3.0 6 | opencv_python==4.7.0.72 7 | plyfile==0.7.4 8 | scikit-image 9 | tqdm==4.65.0 10 | trimesh==3.20.2 11 | jupyterlab -------------------------------------------------------------------------------- /run_batch.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import threading, queue 4 | import numpy as np 5 | import time 6 | 7 | 8 | if __name__ == '__main__': 9 | 10 | ################ per scene NeRF ################ 11 | commands = { 12 | '-grid': '', \ 13 | # '-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \ 14 | # '-noC': 'model.coeff_type=none', \ 15 | # '-SL':'model.basis_dims=[18] model.basis_resos=[70] model.freq_bands=[8.]', \ 16 | # '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \ 17 | # '-iNGP-like': 'model.basis_type=hash model.coeff_type=none', \ 18 | # '-hash': f'model.basis_type=hash model.coef_init=1.0', \ 19 | # '-sinc': f'model.basis_mapping=sinc', \ 20 | # '-tria': f'model.basis_mapping=triangle', \ 21 | # '-vm': f'model.coeff_type=vm model.basis_type=vm', \ 22 | # '-mlpB': 'model.basis_type=mlp', \ 23 | # '-mlpC': 'model.coeff_type=mlp', \ 24 | # '-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \ 25 | # '-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \ 26 | # f'model.num_layers=8 model.hidden_dim=256 ' \ 27 | # f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \ 28 | # '-iNGP-like-sl': 'model.basis_type=hash model.coeff_type=none model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \ 29 | # '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] 
model.basis_resos=[64] ', \ 30 | # '-DCT':'model.basis_type=fix-grid', \ 31 | } 32 | 33 | ################ per scene NeRF ################ 34 | ###### uncomment the following five lines if you want to train on all scenes ######### 35 | cmds = [] 36 | for name in commands.keys(): # 37 | # for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:# 38 | # cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {commands[name]}' 39 | # cmds.append(cmd) 40 | 41 | for scene in ['Ignatius','Truck']:# 42 | if scene != 'Ignatius': 43 | cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/TanksAndTemple/{scene} {commands[name]} ' \ 44 | f' dataset.dataset_name=tankstemple ' 45 | cmds.append(cmd) 46 | 47 | cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/TanksAndTemple/{scene} {commands[name]} ' \ 48 | f' dataset.dataset_name=tankstemple exportation.render_only=1 exportation.render_path=1 exportation.render_test=0 ' \ 49 | f' defaults.ckpt=/mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/{scene}-grid/{scene}-grid.th ' 50 | cmds.append(cmd) 51 | 52 | ################ generalization NeRF ################ 53 | commands = { 54 | # '-grid': '', \ 55 | # '-DVGO-like': 'model.basis_type=none model.coeff_reso=48', 56 | # '-SL':'model.basis_dims=[72] model.basis_resos=[48] model.freq_bands=[6.]', \ 57 | # '-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \ 58 | # '-hash': f'model.basis_type=hash model.coef_init=1.0 ', \ 59 | # '-sinc': f'model.basis_mapping=sinc', \ 60 | # '-tria': f'model.basis_mapping=triangle', \ 61 | # '-vm': f'model.coeff_type=vm model.basis_type=vm', \ 62 | # '-mlpB': 'model.basis_type=mlp', \ 63 | # '-mlpC': 'model.coeff_type=mlp', \ 64 | # '-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] 
model.basis_resos=[64] ', \ 65 | # '-DCT':'model.basis_type=fix-grid', \ 66 | } 67 | # for name in commands.keys(): # 68 | # config = commands[name] 69 | # config = f'python train_across_scene2.py configs/nerf_set.yaml defaults.expname=google-obj{name} {config} ' \ 70 | # f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/' 71 | # cmds.append(config) 72 | 73 | 74 | # # =========> fine tuning <================ 75 | # views = 5 76 | # for name in commands.keys(): # 77 | # for scene in [183,199,298,467,957,244,963,527]:# 78 | # cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \ 79 | # f'dataset.datadir=/home/anpei/Dataset/google_scanned_objects/ {commands[name]} ' \ 80 | # f'dataset.train_views={views} ' \ 81 | # f'dataset.train_scene_list=[{scene}] ' \ 82 | # f'dataset.test_scene_list=[{scene}] ' \ 83 | # f'defaults.ckpt=/home/anpei/Code/NeuBasis/log/google-obj{name}//google-obj{name}.th ' 84 | # cmds.append(cmd) 85 | 86 | # for scene in ['Ignatius','Barn','Truck','Family','Caterpillar']:#'Ignatius','Barn','Truck','Family','Caterpillar' 87 | # cmds.append(f'python train_basis.py configs/tnt.yaml defaults.expname=tnt_{scene} ' \ 88 | # f'dataset.datadir=./data/TanksAndTemple/{scene}' 89 | # ) 90 | 91 | # cmds = [] 92 | # for scene in ['room']:#,'hall','kitchen','living_room','room2','sofa','meeting_room','room','salon2' 93 | # cmds.append(f'python train_basis.py configs/colmap_new.yaml defaults.expname=indoor_{scene} ' \ 94 | # f'dataset.datadir=./data/indoor/real/{scene}' 95 | # # f'defaults.ckpt=/home/anpei/code/NeuBasis2/log/basis_ship/basis_ship.th exporation.render_only=True' 96 | # ) 97 | 98 | # cmds = [] 99 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeRF model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \ 100 | # f'model.num_layers=8 model.hidden_dim=256 ' \ 101 | # f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]') 102 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-grid') 103 | # cmds.append(f'python 2D_regression.py defaults.expname=NeuBasis-mlpB model.basis_type=mlp') 104 | # cmds.append(f'python 2D_regression.py defaults.expname=NeuBasis-mlpC model.coeff_type=mlp') 105 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=DVGO-like model.basis_type=none') 106 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-noC model.coeff_type=none') 107 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-sinc model.basis_mapping=sinc') 108 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-tria model.basis_mapping=triangle') 109 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-SL model.basis_dims=[144] model.basis_resos=[14] model.freq_bands=[73.14]') 110 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-DCT model.basis_type=fix-grid') 111 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-CP model.coeff_type=vec model.basis_type=cp \ 112 | # model.freq_bands=[1.,1.,1.,1.,1.,1.] 
model.basis_resos=[1024,1024,1024,1024,1024,1024] model.basis_dims=[64,64,64,32,32,32]') 113 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=iNGP-like model.basis_type=hash model.coeff_type=none') 114 | # cmds.append(f'python 2D_regression.py configs/image.yaml defaults.expname=NeuBasis-hash model.basis_type=hash model.coef_init=0.1 basis_dims=[16,16,16,16,16,16]') 115 | 116 | #setting available gpus 117 | gpu_idx = [0] 118 | gpus_que = queue.Queue(len(gpu_idx)) 119 | for i in gpu_idx: 120 | gpus_que.put(i) 121 | 122 | # os.makedirs(f"log/{expFolder}", exist_ok=True) 123 | def run_program(gpu, cmd): 124 | cmd = f'{cmd} ' 125 | print(cmd) 126 | os.system(cmd) 127 | gpus_que.put(gpu) 128 | 129 | 130 | ths = [] 131 | for i in range(len(cmds)): 132 | 133 | gpu = gpus_que.get() 134 | t = threading.Thread(target=run_program, args=(gpu, cmds[i]), daemon=True) 135 | t.start() 136 | ths.append(t) 137 | 138 | for th in ths: 139 | th.join() 140 | 141 | 142 | # import os 143 | # import numpy as np 144 | # root = f'/mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/' 145 | # # root = '/cluster/home/anchen/root/Code/NeuBasis/log/' 146 | # scores = [] 147 | # # for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']: 148 | # for scene in ['Caterpillar','Family','Ignatius','Truck']: 149 | # scores.append(np.loadtxt(f'{root}/{scene}-grid/imgs_test_all/mean.txt')) 150 | # # os.system(f'cp {root}/{scene}-grid/imgs_test_all/video.mp4 /mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/video/{scene}.mp4') 151 | # os.system(f'cp {root}/{scene}-grid/{scene}-grid/imgs_path_all/video.mp4 /mnt/qb/home/geiger/zyu30/Projects/Anpei/Code/factor-fields/logs/video/{scene}.mp4') 152 | # # print(np.mean(np.stack(scores),axis=0)) -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/2D_set_regression.py: -------------------------------------------------------------------------------- 1 | import torch,imageio,sys,cmapy,time,os 2 | import numpy as np 3 | from tqdm import tqdm 4 | # from .autonotebook import tqdm as tqdm 5 | import matplotlib.pyplot as plt 6 | from omegaconf import OmegaConf 7 | import torch.nn.functional as F 8 | 9 | sys.path.append('..') 10 | from models.FactorFields import FactorFields 11 | 12 | from utils import SimpleSampler,TVLoss 13 | from dataLoader import dataset_dict 14 | from torch.utils.data import DataLoader 15 | 16 | device = 'cuda' 17 | 18 | 19 | def PSNR(a, b): 20 | if type(a).__module__ == np.__name__: 21 | mse = np.mean((a - b) ** 2) 22 | else: 23 | mse = torch.mean((a - b) ** 2).item() 24 | psnr = -10.0 * np.log(mse) / np.log(10.0) 25 | return psnr 26 | 27 | 28 | @torch.no_grad() 29 | def eval_img(aabb, reso, idx, shiftment=[0.5, 0.5, 0.5], chunk=10240): 30 | y = torch.linspace(0, aabb[0] - 1, reso[0]) 31 | x = torch.linspace(0, aabb[1] - 1, reso[1]) 32 | yy, xx = torch.meshgrid((y, x), indexing='ij') 33 | zz = torch.ones_like(xx) * idx 34 | 35 | idx = 0 36 | res = torch.empty(reso[0] * reso[1], train_dataset.imgs.shape[-1]) 37 | coordiantes = torch.stack((xx, yy, zz), dim=-1).reshape(-1, 3) + torch.tensor( 38 | shiftment) # /(torch.FloatTensor(reso[::-1])-1)*2-1 39 | 
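# The full query grid rarely fits through the model in one pass, so the
# flattened (x, y, image-index) coordinates are evaluated chunk by chunk below
# and written into the preallocated CPU buffer `res`. With the default
# `shiftment` of 0.5 the samples sit at pixel centres, e.g. pixel (row 0,
# col 0) of the image with index k is queried at coordinate (0.5, 0.5, k + 0.5).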
for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)): 40 | feats, _ = model.get_coding_imgage_set(coordiante.to(model.device)) 41 | y_recon = model.linear_mat(feats, is_train=False) 42 | 43 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu() 44 | idx += y_recon.shape[0] 45 | return res.view(reso[0], reso[1], -1), coordiantes 46 | 47 | 48 | @torch.no_grad() 49 | def eval_img_single(aabb, reso, chunk=10240): 50 | y = torch.linspace(0, aabb[0] - 1, reso[0]) 51 | x = torch.linspace(0, aabb[1] - 1, reso[1]) 52 | yy, xx = torch.meshgrid((y, x), indexing='ij') 53 | 54 | idx = 0 55 | res = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1]) 56 | coordiantes = torch.stack((xx, yy), dim=-1).reshape(-1, 2) + 0.5 57 | 58 | for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)): 59 | feats, _ = model.get_coding(coordiante.to(model.device)) 60 | y_recon = model.linear_mat(feats, is_train=False) 61 | 62 | res[idx:idx + y_recon.shape[0]] = y_recon.cpu() 63 | idx += y_recon.shape[0] 64 | return res.view(reso[0], reso[1], -1), coordiantes 65 | 66 | 67 | def linear_to_srgb(img): 68 | limit = 0.0031308 69 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 70 | 71 | 72 | def srgb_to_linear(img): 73 | limit = 0.04045 74 | return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92) 75 | 76 | 77 | def write_image_imageio(img_file, img, colormap=None, quality=100): 78 | if colormap == 'turbo': 79 | shape = img.shape 80 | img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1) 81 | elif colormap is not None: 82 | img = cmapy.colorize((img * 255).astype('uint8'), colormap) 83 | 84 | if img.dtype != 'uint8': 85 | img = (img - np.min(img)) / (np.max(img) - np.min(img)) 86 | img = (img * 255.0).astype(np.uint8) 87 | 88 | kwargs = {} 89 | if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]: 90 | if img.ndim >= 3 and img.shape[2] > 3: 91 | img = img[:, :, :3] 92 | kwargs["quality"] = quality 93 | kwargs["subsampling"] = 0 94 | imageio.imwrite(img_file, img, **kwargs) 95 | 96 | 97 | def interpolate(colormap, x): 98 | a = (x * 255.0).astype('uint8') 99 | b = np.clip(a + 1, 0, 255) 100 | f = x * 255.0 - a 101 | 102 | return np.stack([colormap[a][..., 0] + (colormap[b][..., 0] - colormap[a][..., 0]) * f, 103 | colormap[a][..., 1] + (colormap[b][..., 1] - colormap[a][..., 1]) * f, 104 | colormap[a][..., 2] + (colormap[b][..., 2] - colormap[a][..., 2]) * f], axis=-1) 105 | 106 | base_conf = OmegaConf.load('../configs/defaults.yaml') 107 | second_conf = OmegaConf.load('../configs/image_set.yaml') 108 | cfg = OmegaConf.merge(base_conf, second_conf) 109 | 110 | dataset = dataset_dict[cfg.dataset.dataset_name] 111 | train_dataset = dataset(cfg.dataset,cfg.training.batch_size, split='train',N=600,tolinear=True,HW=512, continue_sampling=True) 112 | train_loader = DataLoader(train_dataset, 113 | num_workers=8, 114 | persistent_workers=True, 115 | batch_size=None, 116 | pin_memory=True) 117 | 118 | 119 | batch_size = cfg.training.batch_size 120 | n_iter = cfg.training.n_iters 121 | 122 | model = FactorFields(cfg, device) 123 | tvreg = TVLoss() 124 | print(model) 125 | print('total parameters: ', model.n_parameters()) 126 | 127 | grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small, lr_large=cfg.training.lr_large) 128 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) # 129 | 130 | tv_loss = 0 131 | loss_scale = 1.0 132 | lr_factor = 0.1 ** (1 / n_iter) 133 | pbar = tqdm(range(n_iter)) 134 | for 
(iteration, sample) in zip(pbar, train_loader): 135 | loss_scale *= lr_factor 136 | 137 | coordiantes, pixel_rgb = sample['xy'], sample['rgb'] 138 | 139 | basis, coeff = model.get_coding_imgage_set(coordiantes.to(device)) 140 | 141 | y_recon = model.linear_mat(basis, is_train=True) 142 | # y_recon = torch.sum(basis,dim=-1,keepdim=True) 143 | l2_loss = torch.mean((y_recon.squeeze() - pixel_rgb.squeeze().to(device)) ** 2) # + 4e-3*coeff.abs().mean() 144 | 145 | # tv_loss = model.TV_loss(tvreg) 146 | loss = l2_loss # + tv_loss*10 147 | 148 | # loss = loss * loss_scale 149 | optimizer.zero_grad() 150 | loss.backward() 151 | optimizer.step() 152 | # if iteration%100==0: 153 | # model.normalize_basis() 154 | 155 | psnr = -10.0 * np.log(l2_loss.item()) / np.log(10.0) 156 | if iteration % 100 == 0: 157 | pbar.set_description( 158 | f'Iteration {iteration:05d}:' 159 | + f' loss_dist = {l2_loss.item():.8f}' 160 | + f' tv_loss = {tv_loss:.6f}' 161 | + f' psnr = {psnr:.3f}' 162 | ) 163 | 164 | save_root = '../log/imageSet/ffhq_mlp_coeff_16_64_8_pe_linear_64_node_800/' 165 | os.makedirs(save_root, exist_ok=True) 166 | model.save(f'{save_root}/ckpt.th') -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/factor-fields/21ea155d70efce5f96399830cb424c444c977948/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/mesh2SDF_data_process.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "47120857-b2aa-4a36-8733-e04776ca7a80", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os,sys,trimesh\n", 11 | "import numpy as np\n", 12 | "sys.path.append('/home/anpei/Code/nglod/sdf-net')\n", 13 | "\n", 14 | "os.environ['PYOPENGL_PLATFORM'] = 'egl'\n", 15 | "os.environ['MESA_GL_VERSION_OVERRIDE'] = '3.3'\n", 16 | "os.environ['MESA_GLSL_VERSION_OVERRIDE'] = '330'\n", 17 | "\n", 18 | "import torch\n", 19 | "from lib.torchgp import load_obj, point_sample, sample_surface, compute_sdf, normalize" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "0c8bf6fa-7295-40cb-aaaa-3cfbf95350a6", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def load_obj(path):\n", 30 | " mesh = trimesh.load(path)\n", 31 | " mesh.vertices = normalize(mesh.vertices)\n", 32 | " return mesh\n", 33 | " \n", 34 | "def normalize(V):\n", 35 | "\n", 36 | " # Normalize mesh\n", 37 | " V_max = np.max(V, axis=0)\n", 38 | " V_min = np.min(V, axis=0)\n", 39 | " V_center = (V_max + V_min) / 2.\n", 40 | " V = V - V_center\n", 41 | "\n", 42 | " # Find the max distance to origin\n", 43 | " max_dist = np.sqrt(np.max(np.sum(V**2, axis=-1)))\n", 44 | " V_scale = 1. 
/ max_dist\n", 45 | " V *= V_scale\n", 46 | " return V\n", 47 | "\n", 48 | "\n", 49 | "def resample(V,F,num_samples, chunk=10000, sample_mode=['rand', 'near', 'near', 'near', 'near']):\n", 50 | " \"\"\"Resample SDF samples.\"\"\"\n", 51 | "\n", 52 | " points, sdfs = [],[]\n", 53 | " for _ in range(num_samples//chunk):\n", 54 | "\n", 55 | " pts = point_sample(V, F, sample_mode, chunk)\n", 56 | " sdf = compute_sdf(V, F, pts.cuda()) \n", 57 | " points.append(pts.cpu())\n", 58 | " sdfs.append(sdf.cpu())\n", 59 | " return torch.cat(points), torch.cat(sdfs)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "id": "b74c1144-f228-473e-abe3-059b9ae19d9c", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "for item in ['statuette.ply','dragon.ply','armadillo.obj','lucy.ply']:#'statuette.ply','dragon.ply','armadillo.obj'\n", 70 | " dataset_path = f'/home/anpei/Dataset/mesh/obj/{item}'\n", 71 | " mesh = load_obj(dataset_path)\n", 72 | " \n", 73 | " f = SDF(mesh.vertices, mesh.faces);\n", 74 | "\n", 75 | " V = torch.from_numpy(mesh.vertices).float().cuda()\n", 76 | " F = torch.from_numpy(mesh.faces).cuda()\n", 77 | " \n", 78 | " pts_train, _ = resample(V, F, num_samples=int(8*1024*1024/5))\n", 79 | " sdf_train = f(pts_train.numpy())\n", 80 | " \n", 81 | " pts_test, _ = resample(V, F, num_samples=int(16*1024*1024//5))\n", 82 | " sdf_test = f(pts_test.numpy())\n", 83 | "\n", 84 | " np.save(f'/home/anpei/Dataset/mesh/obj/sdf/{item[:-4]}_8M',{'points_train':pts_train.numpy(),'sdfs_train':sdf_train, \\\n", 85 | " 'points_test':pts_test.numpy(),'sdfs_test':sdf_test})\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "dafbf526-e79a-47a9-bf7b-6da70497c3af", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 3 (ipykernel)", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.9.7" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 5 118 | } 119 | -------------------------------------------------------------------------------- /train_across_scene.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from omegaconf import OmegaConf 3 | from models.FactorFields import FactorFields 4 | 5 | import json, random,time 6 | from renderer import * 7 | from utils import * 8 | from torch.utils.tensorboard import SummaryWriter 9 | import datetime 10 | from torch.utils.data import DataLoader 11 | 12 | from dataLoader import dataset_dict 13 | import sys 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | class SimpleSampler: 18 | def __init__(self, total, batch): 19 | self.total = total 20 | self.batch = batch 21 | self.curr = total 22 | self.ids = None 23 | 24 | def nextids(self): 25 | self.curr += self.batch 26 | if self.curr + self.batch > self.total: 27 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 28 | self.curr = 0 29 | return self.ids[self.curr:self.curr + self.batch] 30 | 31 | 32 | @torch.no_grad() 33 | def export_mesh(cfg): 34 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 35 | model = FactorFields( 
ckpt['cfg'], device) 36 | model.load(ckpt) 37 | 38 | alpha, _ = model.getDenseAlpha([512]*3) 39 | convert_sdf_samples_to_ply(alpha.cpu(), f'{cfg.defaults.ckpt[:-3]}.ply', bbox=model.aabb.cpu(), level=0.2) 40 | 41 | # @torch.no_grad() 42 | # def export_mesh(cfg, downsample=1, n_views=100): 43 | # cfg.dataset.downsample_train = downsample 44 | # dataset = dataset_dict[cfg.dataset.dataset_name] 45 | # train_dataset = dataset(cfg.dataset, split='train',is_stack=True) 46 | # 47 | # ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 48 | # model = FactorFields( ckpt['cfg'], device) 49 | # model.load(ckpt) 50 | # 51 | # output_dir = f'{cfg.defaults.ckpt[:-3]}.ply' 52 | # export_tsdf_mesh(model, train_dataset, render_ray, white_bg=train_dataset.white_bg, output_dir=output_dir, n_views=n_views) 53 | 54 | @torch.no_grad() 55 | def render_test(cfg): 56 | # init dataset 57 | dataset = dataset_dict[cfg.dataset.dataset_name] 58 | test_dataset = dataset(cfg.dataset, split='test') 59 | white_bg = test_dataset.white_bg 60 | ndc_ray = cfg.dataset.ndc_ray 61 | 62 | if not os.path.exists(cfg.defaults.ckpt): 63 | print('the ckpt path does not exists!!') 64 | return 65 | 66 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 67 | model = FactorFields( ckpt['cfg'], device) 68 | model.load(ckpt) 69 | 70 | 71 | logfolder = os.path.dirname(cfg.defaults.ckpt) 72 | if cfg.exportation.render_train: 73 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 74 | train_dataset = dataset(cfg.dataset.datadir, split='train', is_stack=True) 75 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/', 76 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 77 | print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================') 78 | 79 | if cfg.exportation.render_test: 80 | # model.upsample_volume_grid() 81 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True) 82 | evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/', 83 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 84 | n_params = model.n_parameters() 85 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 86 | 87 | 88 | if cfg.exportation.render_path: 89 | c2ws = test_dataset.render_path 90 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True) 91 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/', 92 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 93 | 94 | if cfg.exportation.export_mesh: 95 | alpha, _ = model.getDenseAlpha(times=1) 96 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(),level=0.02) 97 | 98 | 99 | def reconstruction(cfg): 100 | # init dataset 101 | dataset = dataset_dict[cfg.dataset.dataset_name] 102 | train_dataset = dataset(cfg.dataset, split='train', batch_size=cfg.training.batch_size) 103 | test_dataset = dataset(cfg.dataset, split='test') 104 | white_bg = train_dataset.white_bg 105 | ndc_ray = cfg.dataset.ndc_ray 106 | 107 | trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True) 108 | 109 | # init resolution 110 | upsamp_list = cfg.training.upsamp_list 111 | update_AlphaMask_list = cfg.training.update_AlphaMask_list 112 | 113 | if cfg.defaults.add_timestamp: 114 | 
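# Appending a "-YYYYMMDD-HHMMSS" timestamp to the experiment name keeps
# repeated runs of the same expname from overwriting each other's logs,
# e.g. <logdir>/<expname>-20230101-120000 (illustrative name only).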
logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}' 115 | else: 116 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}' 117 | 118 | # init log file 119 | os.makedirs(logfolder, exist_ok=True) 120 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True) 121 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True) 122 | os.makedirs(f'{logfolder}/rgba', exist_ok=True) 123 | summary_writer = SummaryWriter(logfolder) 124 | 125 | cfg.dataset.aabb = train_dataset.scene_bbox 126 | 127 | if cfg.defaults.ckpt is not None: 128 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 129 | # cfg = ckpt['cfg'] 130 | model = FactorFields(cfg, device) 131 | model.load(ckpt) 132 | else: 133 | model = FactorFields(cfg, device) 134 | print(model) 135 | 136 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large) 137 | if cfg.training.lr_decay_iters > 0: 138 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters) 139 | else: 140 | cfg.training.lr_decay_iters = cfg.training.n_iters 141 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters) 142 | 143 | print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters) 144 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 145 | 146 | # linear in logrithmic space 147 | volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal, 148 | len(cfg.training.upsamp_list)).ceil().long().tolist() 149 | reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb) 150 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio)) 151 | 152 | torch.cuda.empty_cache() 153 | PSNRs, PSNRs_test = [], [0] 154 | 155 | steps_inner = 16 156 | start = time.time() 157 | pbar = tqdm(range(cfg.training.n_iters//steps_inner), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout) 158 | for iteration in pbar: 159 | 160 | # train_dataset.update_index() 161 | scene_idx = torch.randint(0, len(train_dataset.all_rgb_files), (1,)).item() 162 | model.scene_idx = scene_idx 163 | for j in range(steps_inner): 164 | 165 | if j%steps_inner==0: 166 | model.set_optimizable(['coef'], True) 167 | model.set_optimizable(['proj','basis','renderer'], False) 168 | elif j%steps_inner==steps_inner-3: 169 | model.set_optimizable(['coef'], False) 170 | model.set_optimizable(['mlp', 'basis','renderer'], True) 171 | 172 | data = train_dataset[scene_idx] #next(iterator) 173 | rays_train, rgb_train = data['rays'].view(-1,6), data['rgbs'].view(-1,3).to(device) 174 | 175 | 176 | rgb_map, depth_map, coefffs = render_ray(rays_train, model, chunk=cfg.training.batch_size, 177 | N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device, 178 | is_train=True) 179 | 180 | loss = torch.mean((rgb_map - rgb_train) ** 2) #+ torch.mean(coefffs.abs())*1e-4 181 | 182 | # loss 183 | optimizer.zero_grad() 184 | loss.backward() 185 | optimizer.step() 186 | 187 | loss = loss.detach().item() 188 | 189 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0)) 190 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration) 191 | summary_writer.add_scalar('train/mse', loss, global_step=iteration) 192 | 193 | for param_group in optimizer.param_groups: 194 | param_group['lr'] = param_group['lr'] * lr_factor 195 | 196 | # Print the current values of the losses. 
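# Note: in the inner loop above only the per-scene coefficients are optimized
# for most of the `steps_inner` steps, while the shared basis/MLP components
# are unfrozen only for the last few steps (see the set_optimizable calls),
# so the averaged train_psnr reported here mixes both phases.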
197 | pbar.set_description( 198 | f'Iteration {iteration:05d}:' 199 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}' 200 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}' 201 | + f' mse = {loss:.6f}' 202 | ) 203 | PSNRs = [] 204 | 205 | 206 | time_iter = time.time()-start 207 | print(f'=======> time takes: {time_iter} <=============') 208 | os.makedirs(f'{logfolder}/imgs_test_all/', exist_ok=True) 209 | np.savetxt(f'{logfolder}/imgs_test_all/time.txt',[time_iter]) 210 | model.save(f'{logfolder}/{cfg.defaults.expname}.th') 211 | 212 | if cfg.render_train: 213 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 214 | train_dataset = dataset(cfg.defaults.datadir, split='train', downsample=args.downsample_train, is_stack=True) 215 | PSNRs_test = evaluation(train_dataset,model, args, renderer, f'{logfolder}/imgs_train_all/', 216 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 217 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 218 | 219 | if cfg.exportation.render_test: 220 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True) 221 | if 'reconstructions' in cfg.defaults.mode: 222 | model.scene_idx = test_dataset.test_index 223 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/', 224 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 225 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration) 226 | n_params = model.n_parameters() 227 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 228 | 229 | if cfg.exportation.export_mesh: 230 | cfg.defaults.ckpt = f'{logfolder}/{cfg.defaults.expname}.th' 231 | export_mesh(cfg) 232 | 233 | if __name__ == '__main__': 234 | 235 | torch.set_default_dtype(torch.float32) 236 | torch.manual_seed(20211202) 237 | np.random.seed(20211202) 238 | 239 | base_conf = OmegaConf.load('configs/defaults.yaml') 240 | print(sys.argv) 241 | path_config = sys.argv[1] 242 | cli_conf = OmegaConf.from_cli() 243 | second_conf = OmegaConf.load(path_config) 244 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf) 245 | print(cfg) 246 | 247 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path): 248 | render_test(cfg) 249 | elif cfg.exportation.export_mesh_only: 250 | export_mesh(cfg) 251 | else: 252 | reconstruction(cfg) -------------------------------------------------------------------------------- /train_across_scene_ft.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from omegaconf import OmegaConf 3 | from models.FactorFields import FactorFields 4 | 5 | import json, random,time 6 | from renderer import * 7 | from utils import * 8 | from torch.utils.tensorboard import SummaryWriter 9 | import datetime 10 | from torch.utils.data import DataLoader 11 | 12 | from dataLoader import dataset_dict 13 | import sys 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | class SimpleSampler: 18 | def __init__(self, total, batch): 19 | self.total = total 20 | self.batch = batch 21 | self.curr = total 22 | self.ids = None 23 | 24 | def nextids(self): 25 | self.curr += self.batch 26 | if self.curr + self.batch > self.total: 27 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 28 | self.curr = 0 29 | return self.ids[self.curr:self.curr + self.batch] 30 | 31 | 
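# Illustrative usage sketch of the SimpleSampler defined above (the names
# `sampler`, `all_rays` and `step` are hypothetical, for clarity only): it
# hands out random index batches without replacement and reshuffles once the
# next batch would run past the end of the current permutation.
#
#   sampler = SimpleSampler(total=100_000, batch=4096)
#   for step in range(1000):
#       ids = sampler.nextids()      # LongTensor with 4096 unique indices
#       rays_batch = all_rays[ids]   # gather the corresponding rays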
32 | @torch.no_grad() 33 | def export_mesh(cfg): 34 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 35 | model = FactorFields( ckpt['cfg'], device) 36 | model.load(ckpt) 37 | 38 | alpha, _ = model.getDenseAlpha([512]*3) 39 | convert_sdf_samples_to_ply(alpha.cpu(), f'{cfg.defaults.ckpt[:-3]}.ply', bbox=model.aabb.cpu(), level=0.2) 40 | 41 | @torch.no_grad() 42 | def render_test(cfg): 43 | # init dataset 44 | dataset = dataset_dict[cfg.dataset.dataset_name] 45 | test_dataset = dataset(cfg.dataset, split='test') 46 | white_bg = test_dataset.white_bg 47 | ndc_ray = cfg.dataset.ndc_ray 48 | 49 | if not os.path.exists(cfg.defaults.ckpt): 50 | print('the ckpt path does not exists!!') 51 | return 52 | 53 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 54 | cfg.dataset.aabb = test_dataset.scene_bbox 55 | model = FactorFields(cfg, device) 56 | 57 | model.load(ckpt) 58 | 59 | 60 | logfolder = os.path.dirname(cfg.defaults.ckpt) 61 | if cfg.exportation.render_train: 62 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 63 | train_dataset = dataset(cfg.dataset.datadir, split='train', is_stack=True) 64 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/', 65 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 66 | print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================') 67 | 68 | if cfg.exportation.render_test: 69 | # model.upsample_volume_grid() 70 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True) 71 | evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/', 72 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 73 | n_params = model.n_parameters() 74 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 75 | 76 | 77 | if cfg.exportation.render_path: 78 | c2ws = test_dataset.render_path 79 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True) 80 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/', 81 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 82 | 83 | if cfg.exportation.export_mesh: 84 | alpha, _ = model.getDenseAlpha(times=1) 85 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(),level=0.02) 86 | 87 | 88 | def reconstruction(cfg): 89 | # init dataset 90 | dataset = dataset_dict[cfg.dataset.dataset_name] 91 | train_dataset = dataset(cfg.dataset, split='train', batch_size=cfg.training.batch_size) 92 | test_dataset = dataset(cfg.dataset, split='test') 93 | white_bg = train_dataset.white_bg 94 | ndc_ray = cfg.dataset.ndc_ray 95 | 96 | trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True) 97 | 98 | # init resolution 99 | upsamp_list = cfg.training.upsamp_list 100 | update_AlphaMask_list = cfg.training.update_AlphaMask_list 101 | 102 | if cfg.defaults.add_timestamp: 103 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}' 104 | else: 105 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}' 106 | 107 | # init log file 108 | os.makedirs(logfolder, exist_ok=True) 109 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True) 110 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True) 111 | os.makedirs(f'{logfolder}/rgba', 
112 |     summary_writer = SummaryWriter(logfolder)
113 |
114 |     cfg.dataset.aabb = train_dataset.scene_bbox
115 |
116 |     if cfg.defaults.ckpt is not None:
117 |         ckpt = torch.load(cfg.defaults.ckpt, map_location=device)
118 |         model = FactorFields(cfg, device)
119 |         model.load(ckpt)
120 |     else:
121 |         model = FactorFields(cfg, device)
122 |     print(model)
123 |
124 |     grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
125 |     if cfg.training.lr_decay_iters > 0:
126 |         lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters)
127 |     else:
128 |         cfg.training.lr_decay_iters = cfg.training.n_iters
129 |         lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters)
130 |
131 |     print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters)
132 |     optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
133 |
134 |     # resolution upsampling schedule: linearly interpolated between volume_resoInit and volume_resoFinal
135 |     volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal,
136 |                                      len(cfg.training.upsamp_list)).ceil().long().tolist()
137 |     reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb)
138 |     nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
139 |
140 |     torch.cuda.empty_cache()
141 |     PSNRs, PSNRs_test = [], [0]
142 |
143 |     start = time.time()
144 |     iterator = iter(trainLoader)
145 |     pbar = tqdm(range(cfg.training.n_iters), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout)
146 |
147 |     for iteration in pbar:
148 |
149 |         data = next(iterator)
150 |         rays_train, rgb_train = data['rays'].view(-1, 6), data['rgbs'].view(-1, 3).to(device)
151 |         if 'idx' in data.keys():
152 |             model.scene_idx = data['idx']
153 |
154 |         rgb_map, depth_map, coeffs = render_ray(rays_train, model, chunk=cfg.training.batch_size,
155 |                                                 N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
156 |                                                 is_train=True)
157 |
158 |         loss = torch.mean((rgb_map - rgb_train) ** 2)
159 |
160 |         # loss
161 |         total_loss = loss
162 |
163 |         optimizer.zero_grad()
164 |         total_loss.backward()
165 |         optimizer.step()
166 |
167 |         loss = loss.detach().item()
168 |
169 |         PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
170 |         summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
171 |         summary_writer.add_scalar('train/mse', loss, global_step=iteration)
172 |
173 |         for param_group in optimizer.param_groups:
174 |             param_group['lr'] = param_group['lr'] * lr_factor
175 |
176 |         # Print the current values of the losses.
177 |         if iteration % cfg.defaults.progress_refresh_rate == 0:
178 |             pbar.set_description(
179 |                 f'Iteration {iteration:05d}:'
180 |                 + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
181 |                 + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
182 |                 + f' mse = {loss:.6f}'
183 |             )
184 |             PSNRs = []
185 |
186 |         if iteration % cfg.dataset.vis_every == cfg.dataset.vis_every - 1:
187 |             PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_vis/',
188 |                                     N_vis=cfg.dataset.N_vis,
189 |                                     prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray,
190 |                                     compute_extra_metrics=False)
191 |             summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
192 |
193 |         if iteration in update_AlphaMask_list or iteration in cfg.training.shrinking_list:
194 |
195 |             if volume_resoList[0] < 256:  # update volume resolution
196 |                 reso_mask = N_to_reso(volume_resoList[0]**model.in_dim, model.aabb)
197 |
198 |             new_aabb = model.updateAlphaMask([cfg.model.coeff_reso]*3, is_update_alphaMask=True)
199 |
200 |             if iteration in cfg.training.shrinking_list:
201 |                 model.shrink(new_aabb)
202 |                 L1_reg_weight = cfg.training.L1_weight_rest
203 |
204 |                 grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
205 |                 optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
206 |                 print("continuing L1_reg_weight", L1_reg_weight)
207 |
208 |             if not cfg.dataset.ndc_ray and iteration == update_AlphaMask_list[0] and not cfg.dataset.is_unbound:
209 |                 # filter rays outside the bbox
210 |                 train_dataset.all_rays, train_dataset.all_rgbs = model.filtering_rays(train_dataset.all_rays, train_dataset.all_rgbs)
211 |                 trainLoader = DataLoader(train_dataset, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)
212 |                 iterator = iter(trainLoader)
213 |
214 |
215 |         if iteration in upsamp_list:
216 |             n_voxels = volume_resoList.pop(0)
217 |             reso_cur = N_to_reso(n_voxels**model.in_dim, model.aabb)
218 |             nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio))
219 |             model.upsample_volume_grid(reso_cur)
220 |
221 |             grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
222 |             optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
223 |             torch.cuda.empty_cache()
224 |
225 |         if iteration == 3000:
226 |             model.cfg.training.renderModule = True
227 |             grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large)
228 |             optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
229 |
230 |     time_iter = time.time() - start
231 |     print(f'=======> time takes: {time_iter} <=============')
232 |     os.makedirs(f'{logfolder}/imgs_test_all/', exist_ok=True)
233 |     np.savetxt(f'{logfolder}/imgs_test_all/time.txt', [time_iter])
234 |     model.save(f'{logfolder}/{cfg.defaults.expname}.th')
235 |
236 |     if cfg.exportation.render_train:
237 |         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
238 |         train_dataset = dataset(cfg.dataset, split='train', is_stack=True)
239 |         PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/',
240 |                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
241 |         print(f'======> {cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
242 |
243 |     if cfg.exportation.render_test:
244 |         os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
245 |         if 'reconstructions' in cfg.defaults.mode:
246 |             model.scene_idx = test_dataset.test_index
247 |         PSNRs_test =
evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/', 248 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 249 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration) 250 | n_params = model.n_parameters() 251 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 252 | 253 | if cfg.exportation.export_mesh: 254 | cfg.defaults.ckpt = f'{logfolder}/{cfg.defaults.expname}.th' 255 | export_mesh(cfg) 256 | 257 | if __name__ == '__main__': 258 | 259 | torch.set_default_dtype(torch.float32) 260 | torch.manual_seed(20211202) 261 | np.random.seed(20211202) 262 | 263 | base_conf = OmegaConf.load('configs/defaults.yaml') 264 | path_config = sys.argv[1] 265 | cli_conf = OmegaConf.from_cli() 266 | second_conf = OmegaConf.load(path_config) 267 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf) 268 | print(cfg) 269 | 270 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path): 271 | render_test(cfg) 272 | elif cfg.exportation.export_mesh_only: 273 | export_mesh(cfg) 274 | else: 275 | reconstruction(cfg) -------------------------------------------------------------------------------- /train_per_scene.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | from omegaconf import OmegaConf 3 | from models.FactorFields import FactorFields 4 | 5 | import json, random,time 6 | from renderer import * 7 | from utils import * 8 | from torch.utils.tensorboard import SummaryWriter 9 | import datetime 10 | 11 | from dataLoader import dataset_dict 12 | import sys 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | class SimpleSampler: 17 | def __init__(self, total, batch): 18 | self.total = total 19 | self.batch = batch 20 | self.curr = total 21 | self.ids = None 22 | 23 | def nextids(self): 24 | self.curr += self.batch 25 | if self.curr + self.batch > self.total: 26 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 27 | self.curr = 0 28 | return self.ids[self.curr:self.curr + self.batch] 29 | 30 | 31 | @torch.no_grad() 32 | def export_mesh(ckpt_path): 33 | ckpt = torch.load(ckpt_path, map_location=device) 34 | cfg = ckpt['cfg'] 35 | cfg.defaults.device = device 36 | model = FactorFields(cfg) 37 | model.load(ckpt) 38 | 39 | alpha, _ = model.getDenseAlpha() 40 | convert_sdf_samples_to_ply(alpha.cpu(), f'{ckpt.defaults.ckpt[:-3]}.ply', bbox=model.aabb.cpu(), level=0.005) 41 | 42 | 43 | @torch.no_grad() 44 | def render_test(cfg): 45 | # init dataset 46 | dataset = dataset_dict[cfg.dataset.dataset_name] 47 | test_dataset = dataset(cfg.dataset, split='test') 48 | white_bg = test_dataset.white_bg 49 | ndc_ray = cfg.dataset.ndc_ray 50 | 51 | if not os.path.exists(cfg.defaults.ckpt): 52 | print('the ckpt path does not exists!!') 53 | return 54 | 55 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 56 | model = FactorFields( ckpt['cfg'], device) 57 | model.load(ckpt) 58 | 59 | 60 | logfolder = os.path.dirname(cfg.defaults.ckpt) 61 | if cfg.exportation.render_train: 62 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 63 | train_dataset = dataset(cfg.dataset.datadir, split='train', is_stack=True) 64 | PSNRs_test = evaluation(train_dataset, model, render_ray, f'{logfolder}/imgs_train_all/', 65 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 66 | print(f'======> 
{cfg.defaults.expname} train all psnr: {np.mean(PSNRs_test)} <========================') 67 | 68 | if cfg.exportation.render_test: 69 | # model.upsample_volume_grid() 70 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_test_all', exist_ok=True) 71 | evaluation(test_dataset, model, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_test_all/', 72 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 73 | n_params = model.n_parameters() 74 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 75 | 76 | 77 | if cfg.exportation.render_path: 78 | c2ws = test_dataset.render_path 79 | os.makedirs(f'{logfolder}/{cfg.defaults.expname}/imgs_path_all', exist_ok=True) 80 | evaluation_path(test_dataset, model, c2ws, render_ray, f'{logfolder}/{cfg.defaults.expname}/imgs_path_all/', 81 | N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 82 | 83 | if cfg.exportation.export_mesh: 84 | alpha, _ = model.getDenseAlpha() 85 | convert_sdf_samples_to_ply(alpha.cpu(), f'{logfolder}/{cfg.defaults.expname}.ply', bbox=model.aabb.cpu(), 86 | level=0.005) 87 | 88 | 89 | def reconstruction(cfg): 90 | # init dataset 91 | dataset = dataset_dict[cfg.dataset.dataset_name] 92 | train_dataset = dataset(cfg.dataset, split='train') 93 | test_dataset = dataset(cfg.dataset, split='test') 94 | white_bg = train_dataset.white_bg 95 | ndc_ray = cfg.dataset.ndc_ray 96 | 97 | # init resolution 98 | upsamp_list = cfg.training.upsamp_list 99 | update_AlphaMask_list = cfg.training.update_AlphaMask_list 100 | 101 | if cfg.defaults.add_timestamp: 102 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}' 103 | else: 104 | logfolder = f'{cfg.defaults.logdir}/{cfg.defaults.expname}' 105 | 106 | # init log file 107 | os.makedirs(logfolder, exist_ok=True) 108 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True) 109 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True) 110 | os.makedirs(f'{logfolder}/rgba', exist_ok=True) 111 | summary_writer = SummaryWriter(logfolder) 112 | 113 | cfg.dataset.aabb = aabb = train_dataset.scene_bbox 114 | 115 | if cfg.defaults.ckpt is not None: 116 | ckpt = torch.load(cfg.defaults.ckpt, map_location=device) 117 | cfg = ckpt['cfg'] 118 | model = FactorFields(cfg, device) 119 | model.load(ckpt) 120 | else: 121 | model = FactorFields(cfg, device) 122 | print(model) 123 | 124 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large) 125 | if cfg.training.lr_decay_iters > 0: 126 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.lr_decay_iters) 127 | else: 128 | cfg.training.lr_decay_iters = cfg.training.n_iters 129 | lr_factor = cfg.training.lr_decay_target_ratio ** (1 / cfg.training.n_iters) 130 | 131 | print("lr decay", cfg.training.lr_decay_target_ratio, cfg.training.lr_decay_iters) 132 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 133 | 134 | # linear in logrithmic space 135 | volume_resoList = torch.linspace(cfg.training.volume_resoInit, cfg.training.volume_resoFinal, 136 | len(cfg.training.upsamp_list)).ceil().long().tolist() 137 | reso_cur = N_to_reso(cfg.training.volume_resoInit**model.in_dim, model.aabb) 138 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio)) 139 | 140 | torch.cuda.empty_cache() 141 | PSNRs, PSNRs_test = [], [0] 142 | 143 | allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs 144 | trainingSampler = 
SimpleSampler(allrays.shape[0], cfg.training.batch_size) 145 | 146 | 147 | start = time.time() 148 | pbar = tqdm(range(cfg.training.n_iters), miniters=cfg.defaults.progress_refresh_rate, file=sys.stdout) 149 | for iteration in pbar: 150 | 151 | ray_idx = trainingSampler.nextids() 152 | rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device) 153 | 154 | rgb_map, depth_map, _ = render_ray(rays_train, model, chunk=cfg.training.batch_size, 155 | N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device, 156 | is_train=True) 157 | 158 | loss = torch.mean((rgb_map - rgb_train) ** 2) 159 | 160 | optimizer.zero_grad() 161 | loss.backward() 162 | optimizer.step() 163 | 164 | loss = loss.detach().item() 165 | 166 | PSNRs.append(-10.0 * np.log(loss) / np.log(10.0)) 167 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration) 168 | summary_writer.add_scalar('train/mse', loss, global_step=iteration) 169 | 170 | for param_group in optimizer.param_groups: 171 | param_group['lr'] = param_group['lr'] * lr_factor 172 | 173 | # Print the current values of the losses. 174 | if iteration % cfg.defaults.progress_refresh_rate == 0: 175 | pbar.set_description( 176 | f'Iteration {iteration:05d}:' 177 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}' 178 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}' 179 | + f' mse = {loss:.6f}' 180 | ) 181 | PSNRs = [] 182 | 183 | if iteration % cfg.dataset.vis_every == cfg.dataset.vis_every - 1: 184 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_vis/', 185 | N_vis=cfg.dataset.N_vis, 186 | prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, 187 | compute_extra_metrics=False) 188 | summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration) 189 | 190 | if iteration in update_AlphaMask_list or iteration in cfg.training.shrinking_list: 191 | 192 | if volume_resoList[0] < 256: # update volume resolution 193 | reso_mask = N_to_reso(volume_resoList[0]**model.in_dim, model.aabb) 194 | 195 | new_aabb = model.updateAlphaMask(tuple(reso_mask), is_update_alphaMask=iteration >= 1500) 196 | 197 | 198 | if iteration in cfg.training.shrinking_list: 199 | model.shrink(new_aabb) 200 | L1_reg_weight = cfg.training.L1_weight_rest 201 | 202 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large) 203 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 204 | print("continuing L1_reg_weight", L1_reg_weight) 205 | 206 | if not cfg.dataset.ndc_ray and iteration == update_AlphaMask_list[0] and not cfg.dataset.is_unbound: 207 | # filter rays outside the bbox 208 | allrays, allrgbs = model.filtering_rays(allrays, allrgbs) 209 | trainingSampler = SimpleSampler(allrgbs.shape[0], cfg.training.batch_size) 210 | 211 | 212 | if iteration in upsamp_list: 213 | n_voxels = volume_resoList.pop(0) 214 | reso_cur = N_to_reso(n_voxels**model.in_dim, model.aabb) 215 | nSamples = min(cfg.renderer.max_samples, cal_n_samples(reso_cur, cfg.renderer.step_ratio)) 216 | model.upsample_volume_grid(reso_cur) 217 | 218 | grad_vars = model.get_optparam_groups(cfg.training.lr_small, cfg.training.lr_large) 219 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 220 | torch.cuda.empty_cache() 221 | 222 | time_iter = time.time()-start 223 | print(f'=======> time takes: {time_iter} <=============') 224 | os.makedirs(f'{logfolder}/imgs_test_all',exist_ok=True) 225 | np.savetxt(f'{logfolder}/imgs_test_all/time.txt',[time_iter]) 226 | 
model.save(f'{logfolder}/{cfg.defaults.expname}.th') 227 | 228 | if cfg.exportation.render_test: 229 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True) 230 | PSNRs_test = evaluation(test_dataset, model, render_ray, f'{logfolder}/imgs_test_all/', 231 | N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device) 232 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration) 233 | n_params = model.n_parameters() 234 | print(f'======> {cfg.defaults.expname} test all psnr: {np.mean(PSNRs_test)} n_params: {n_params} <========================') 235 | 236 | 237 | if __name__ == '__main__': 238 | 239 | torch.set_default_dtype(torch.float32) 240 | torch.manual_seed(20211202) 241 | np.random.seed(20211202) 242 | 243 | base_conf = OmegaConf.load('configs/defaults.yaml') 244 | path_config = sys.argv[1] 245 | cli_conf = OmegaConf.from_cli() 246 | second_conf = OmegaConf.load(path_config) 247 | cfg = OmegaConf.merge(base_conf, second_conf, cli_conf) 248 | 249 | if cfg.exportation.render_only and (cfg.exportation.render_test or cfg.exportation.render_path): 250 | render_test(cfg) 251 | elif cfg.exportation.export_mesh_only or cfg.exportation.export_mesh: 252 | export_mesh(cfg) 253 | else: 254 | reconstruction(cfg) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2,torch,math 2 | import numpy as np 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | import torch.nn.functional as F 6 | import scipy.signal 7 | 8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 9 | 10 | 11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 12 | """ 13 | depth: (H, W) 14 | """ 15 | 16 | x = np.nan_to_num(depth) # change nan to 0 17 | if minmax is None: 18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 19 | ma = np.max(x) 20 | else: 21 | mi,ma = minmax 22 | 23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 24 | x = (255*x).astype(np.uint8) 25 | x_ = cv2.applyColorMap(x, cmap) 26 | return x_, [mi,ma] 27 | 28 | def init_log(log, keys): 29 | for key in keys: 30 | log[key] = torch.tensor([0.0], dtype=float) 31 | return log 32 | 33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 34 | """ 35 | depth: (H, W) 36 | """ 37 | if type(depth) is not np.ndarray: 38 | depth = depth.cpu().numpy() 39 | 40 | x = np.nan_to_num(depth) # change nan to 0 41 | if minmax is None: 42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 43 | ma = np.max(x) 44 | else: 45 | mi,ma = minmax 46 | 47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 48 | x = (255*x).astype(np.uint8) 49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 50 | x_ = T.ToTensor()(x_) # (3, H, W) 51 | return x_, [mi,ma] 52 | 53 | def N_to_reso(n_voxels, bbox): 54 | xyz_min, xyz_max = bbox 55 | dim = len(xyz_min) 56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 57 | return torch.round((xyz_max - xyz_min) / voxel_size).long().tolist() 58 | 59 | def N_to_vm_reso(n_voxels, bbox): 60 | xyz_min, xyz_max = bbox 61 | dim = len(xyz_min) 62 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 63 | reso = (xyz_max - xyz_min) / voxel_size 64 | assert len(reso)==3 65 | n_mat = reso[0]*reso[1] + reso[0]*reso[2] + reso[1]*reso[2] 66 | scale = math.sqrt(n_voxels/n_mat) 67 | return torch.round(reso*scale).long().tolist() 68 | 69 | def cal_n_samples(reso, 
step_ratio=0.5): 70 | return int(np.linalg.norm(reso)/step_ratio) 71 | 72 | class SimpleSampler: 73 | def __init__(self, total, batch): 74 | self.total = total 75 | self.batch = batch 76 | self.curr = total 77 | self.ids = None 78 | 79 | def nextids(self): 80 | self.curr+=self.batch 81 | if self.curr + self.batch > self.total: 82 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 83 | self.curr = 0 84 | return self.ids[self.curr:self.curr+self.batch] 85 | 86 | __LPIPS__ = {} 87 | def init_lpips(net_name, device): 88 | assert net_name in ['alex', 'vgg'] 89 | import lpips 90 | print(f'init_lpips: lpips_{net_name}') 91 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 92 | 93 | def rgb_lpips(np_gt, np_im, net_name, device): 94 | if net_name not in __LPIPS__: 95 | __LPIPS__[net_name] = init_lpips(net_name, device) 96 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 97 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 98 | return __LPIPS__[net_name](gt, im, normalize=True).item() 99 | 100 | 101 | def findItem(items, target): 102 | for one in items: 103 | if one[:len(target)]==target: 104 | return one 105 | return None 106 | 107 | 108 | ''' Evaluation metrics (ssim, lpips) 109 | ''' 110 | def rgb_ssim(img0, img1, max_val, 111 | filter_size=11, 112 | filter_sigma=1.5, 113 | k1=0.01, 114 | k2=0.03, 115 | return_map=False): 116 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 117 | assert len(img0.shape) == 3 118 | assert img0.shape[-1] == 3 119 | assert img0.shape == img1.shape 120 | 121 | # Construct a 1D Gaussian blur filter. 122 | hw = filter_size // 2 123 | shift = (2 * hw - filter_size + 1) / 2 124 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 125 | filt = np.exp(-0.5 * f_i) 126 | filt /= np.sum(filt) 127 | 128 | # Blur in x and y (faster than the 2D convolution). 129 | def convolve2d(z, f): 130 | return scipy.signal.convolve2d(z, f, mode='valid') 131 | 132 | filt_fn = lambda z: np.stack([ 133 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 134 | for i in range(z.shape[-1])], -1) 135 | mu0 = filt_fn(img0) 136 | mu1 = filt_fn(img1) 137 | mu00 = mu0 * mu0 138 | mu11 = mu1 * mu1 139 | mu01 = mu0 * mu1 140 | sigma00 = filt_fn(img0**2) - mu00 141 | sigma11 = filt_fn(img1**2) - mu11 142 | sigma01 = filt_fn(img0 * img1) - mu01 143 | 144 | # Clip the variances and covariances to valid values. 
145 | # Variance must be non-negative: 146 | sigma00 = np.maximum(0., sigma00) 147 | sigma11 = np.maximum(0., sigma11) 148 | sigma01 = np.sign(sigma01) * np.minimum( 149 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 150 | c1 = (k1 * max_val)**2 151 | c2 = (k2 * max_val)**2 152 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 153 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 154 | ssim_map = numer / denom 155 | ssim = np.mean(ssim_map) 156 | return ssim_map if return_map else ssim 157 | 158 | 159 | import torch.nn as nn 160 | class TVLoss(nn.Module): 161 | def __init__(self,TVLoss_weight=1): 162 | super(TVLoss,self).__init__() 163 | self.TVLoss_weight = TVLoss_weight 164 | 165 | def forward(self,x): 166 | batch_size = x.size()[0] 167 | h_x = x.size()[2] 168 | w_x = x.size()[3] 169 | count_h = self._tensor_size(x[:,:,1:,:]) 170 | count_w = self._tensor_size(x[:,:,:,1:]) 171 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 172 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 173 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 174 | 175 | def _tensor_size(self,t): 176 | return t.size()[1]*t.size()[2]*t.size()[3] 177 | 178 | def marchcude_to_world(vertices, reso_WHD): 179 | return vertices/(np.array(reso_WHD)-1) 180 | 181 | import plyfile 182 | import skimage.measure 183 | def convert_sdf_samples_to_ply( 184 | pytorch_3d_sdf_tensor, 185 | ply_filename_out, 186 | bbox, 187 | level=0.5, 188 | offset=None, 189 | scale=None, 190 | ): 191 | """ 192 | Convert sdf samples to .ply 193 | 194 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n) 195 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 196 | :voxel_size: float, the size of the voxels 197 | :ply_filename_out: string, path of the filename to save to 198 | 199 | This function adapted from: https://github.com/RobotLocomotion/spartan 200 | """ 201 | 202 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy() 203 | # voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape)) 204 | 205 | verts, faces, normals, values = skimage.measure.marching_cubes( 206 | numpy_3d_sdf_tensor, level=level 207 | ) 208 | reso_WHD = numpy_3d_sdf_tensor.shape 209 | print(bbox) 210 | verts = marchcude_to_world(verts, reso_WHD) 211 | 212 | 213 | faces = faces[...,::-1] # inverse face orientation 214 | 215 | # transform from voxel coordinates to camera coordinates 216 | # note x and y are flipped in the output of marching_cubes 217 | mesh_points = np.zeros_like(verts) 218 | bbox = bbox.numpy() 219 | mesh_points[:, 0] = bbox[0,2] + verts[:, 0]*(bbox[1,2]-bbox[0,2]) 220 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1]*(bbox[1,1]-bbox[0,1]) 221 | mesh_points[:, 2] = bbox[0,0] + verts[:, 2]*(bbox[1,0]-bbox[0,0]) 222 | 223 | # # apply additional offset and scale 224 | # if scale is not None: 225 | # mesh_points = mesh_points / scale 226 | # if offset is not None: 227 | # mesh_points = mesh_points - offset 228 | 229 | # try writing to the ply file 230 | 231 | num_verts = verts.shape[0] 232 | num_faces = faces.shape[0] 233 | 234 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 235 | 236 | for i in range(0, num_verts): 237 | verts_tuple[i] = tuple(mesh_points[i, :]) 238 | 239 | faces_building = [] 240 | for i in range(0, num_faces): 241 | faces_building.append(((faces[i, :].tolist(),))) 242 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 243 | 244 | el_verts = 
plyfile.PlyElement.describe(verts_tuple, "vertex") 245 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 246 | 247 | ply_data = plyfile.PlyData([el_verts, el_faces]) 248 | print("saving mesh to %s" % (ply_filename_out)) 249 | ply_data.write(ply_filename_out) 250 | --------------------------------------------------------------------------------
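Usage note: the mesh-export path shared by the training scripts ends in convert_sdf_samples_to_ply above. The sketch below mirrors export_mesh() in train_across_scene_ft.py (load a checkpoint, query a dense alpha grid with FactorFields.getDenseAlpha, then extract a mesh with marching cubes via convert_sdf_samples_to_ply). The checkpoint path is hypothetical; the 512^3 grid and level=0.2 follow that script, whereas train_per_scene.py uses level=0.005.

    import torch
    from models.FactorFields import FactorFields
    from utils import convert_sdf_samples_to_ply

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt_path = 'logs/example/example.th'        # hypothetical path, as saved by model.save()

    ckpt = torch.load(ckpt_path, map_location=device)
    model = FactorFields(ckpt['cfg'], device)    # rebuild the model from the stored config
    model.load(ckpt)

    with torch.no_grad():
        alpha, _ = model.getDenseAlpha([512] * 3)            # dense alpha/density grid
    convert_sdf_samples_to_ply(alpha.cpu(), ckpt_path[:-3] + '.ply',
                               bbox=model.aabb.cpu(), level=0.2)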