├── models ├── __init__.py ├── init_net │ ├── __pycache__ │ │ ├── dvgo.cpython-37.pyc │ │ ├── grid.cpython-37.pyc │ │ ├── run.cpython-37.pyc │ │ └── utils.cpython-37.pyc │ └── utils.py ├── cuda │ ├── ub360_utils.cpp │ ├── total_variation.cpp │ ├── ub360_utils_kernel.cu │ ├── total_variation_kernel.cu │ ├── adam_upd.cpp │ ├── grid_sample_1d.cpp │ ├── search_geo_hier.cpp │ ├── adam_upd_kernel.cu │ ├── search_dbasis.cpp │ ├── search_geo_adapt.cpp │ └── render_utils.cpp ├── masked_adam.py └── sh.py ├── mvs ├── mvsnet │ ├── __init__.py │ ├── mvsnet.py │ └── module.py ├── renderer.py └── networks.py ├── preprocessing ├── gmm_torch │ ├── __init__.py │ ├── example.png │ ├── README.md │ ├── LICENSE.md │ ├── utils.py │ ├── example.py │ └── test.py ├── cluster.py ├── recon_prior.py └── recon_prior_hier.py ├── image ├── main_fig.png ├── rot-toy.jpg ├── USC-Logos.png ├── ucsd_logo.png ├── Adobe-Logos.png └── visualization.png ├── .gitignore ├── dataLoader ├── __init__.py ├── data_utils.py ├── your_own_data.py ├── nsvf.py ├── tankstemple.py └── ray_utils.py ├── configs ├── 360 │ ├── 360_garden.txt │ └── 360_room.txt ├── synthetic-nerf │ ├── default │ │ ├── ship.txt │ │ ├── lego.txt │ │ ├── hotdog.txt │ │ ├── materials.txt │ │ ├── ficus.txt │ │ ├── drums.txt │ │ ├── mic.txt │ │ └── chair.txt │ └── local_vm │ │ ├── chair.txt │ │ ├── ficus.txt │ │ ├── ship.txt │ │ ├── hotdog.txt │ │ ├── lego.txt │ │ ├── mic.txt │ │ ├── drums.txt │ │ └── materials.txt ├── TanksAndTemple │ ├── Ignatius.txt │ ├── Caterpillar.txt │ ├── Family.txt │ └── Barn.txt └── scannet │ ├── scannet_101.txt │ └── scannet_241.txt ├── README.md └── utils.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mvs/mvsnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocessing/gmm_torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/main_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/main_fig.png -------------------------------------------------------------------------------- /image/rot-toy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/rot-toy.jpg -------------------------------------------------------------------------------- /image/USC-Logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/USC-Logos.png -------------------------------------------------------------------------------- /image/ucsd_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/ucsd_logo.png -------------------------------------------------------------------------------- /image/Adobe-Logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/Adobe-Logos.png -------------------------------------------------------------------------------- 
/image/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/image/visualization.png -------------------------------------------------------------------------------- /preprocessing/gmm_torch/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/preprocessing/gmm_torch/example.png -------------------------------------------------------------------------------- /models/init_net/__pycache__/dvgo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/models/init_net/__pycache__/dvgo.cpython-37.pyc -------------------------------------------------------------------------------- /models/init_net/__pycache__/grid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/models/init_net/__pycache__/grid.cpython-37.pyc -------------------------------------------------------------------------------- /models/init_net/__pycache__/run.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/models/init_net/__pycache__/run.cpython-37.pyc -------------------------------------------------------------------------------- /models/init_net/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zerg-Overmind/Strivec/HEAD/models/init_net/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | data/ 3 | d_0.2_vox.txt 4 | cyysync.sh 5 | .idea/ 6 | configs/ 7 | */__pycache__/ 8 | */*/__pycache__/ 9 | *.pyc 10 | git_token 11 | yysync.sh 12 | -------------------------------------------------------------------------------- /models/cuda/ub360_utils.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres); 8 | 9 | // C++ interface 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor cumdist_thres(torch::Tensor dist, float thres) { 16 | CHECK_INPUT(dist); 17 | return cumdist_thres_cuda(dist, thres); 18 | } 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 21 | m.def("cumdist_thres", &cumdist_thres, "Generate mask for cumulative dist."); 22 | } -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .llff import LLFFDataset 2 | from .blender import BlenderDataset 3 | from .blender import BlenderMVSDataset 4 | from .nsvf import NSVF 5 | from .tankstemple import TanksTempleDataset 6 | from .tankstempleBG import TanksTempleDatasetBG 7 | from .your_own_data import YourOwnDataset 8 | from .scannet import ScannetDataset 9 | from .indoor_data import
IndoorDataset 10 | 11 | 12 | 13 | dataset_dict = {'blender': BlenderDataset, 14 | 'llff':LLFFDataset, 15 | 'tankstemple':TanksTempleDataset, 16 | 'TanksAndTempleBG':TanksTempleDatasetBG, 17 | 'nsvf':NSVF, 18 | 'scannet':ScannetDataset, 19 | 'own_data':YourOwnDataset, 20 | 'indoor_data': IndoorDataset} 21 | 22 | mvs_dataset_dict = {'blender': BlenderMVSDataset} -------------------------------------------------------------------------------- /preprocessing/gmm_torch/README.md: -------------------------------------------------------------------------------- 1 | This repository contains an implementation of a simple **Gaussian mixture model** (GMM) fitted with Expectation-Maximization in [pytorch](http://www.pytorch.org). The interface closely follows that of [sklearn](http://scikit-learn.org). 2 | 3 | ![Example of a fit via a Gaussian Mixture model.](example.png) 4 | 5 | --- 6 | 7 | A new model is instantiated by calling `gmm.GaussianMixture(..)` and providing as arguments the number of components, as well as the tensor dimension. Note that once instantiated, the model expects tensors in a flattened shape `(n, d)`. 8 | 9 | The first step would usually be to fit the model via `model.fit(data)`, then predict with `model.predict(data)`. To reproduce the above figure, just run the provided `example.py`. 10 | 11 | Some sanity checks can be executed by calling `python test.py`. To fit data on GPUs, ensure that you first call `model.cuda()`. -------------------------------------------------------------------------------- /models/cuda/total_variation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode); 8 | 9 | 10 | // C++ interface 11 | 12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | 16 | void total_variation_add_grad(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 17 | CHECK_INPUT(param); 18 | CHECK_INPUT(grad); 19 | total_variation_add_grad_cuda(param, grad, wx, wy, wz, dense_mode); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("total_variation_add_grad", &total_variation_add_grad, "Add total variation grad"); 24 | } 25 | 26 | -------------------------------------------------------------------------------- /preprocessing/gmm_torch/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Lucas Deecke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /preprocessing/gmm_torch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calculate_matmul_n_times(n_components, mat_a, mat_b): 4 | """ 5 | Calculate matrix product of two matrices with mat_a[0] >= mat_b[0]. 6 | Bypasses torch.matmul to reduce memory footprint. 7 | args: 8 | mat_a: torch.Tensor (n, k, 1, d) 9 | mat_b: torch.Tensor (1, k, d, d) 10 | """ 11 | res = torch.zeros(mat_a.shape).to(mat_a.device) 12 | 13 | for i in range(n_components): 14 | mat_a_i = mat_a[:, i, :, :].squeeze(-2) 15 | mat_b_i = mat_b[0, i, :, :].squeeze() 16 | res[:, i, :, :] = mat_a_i.mm(mat_b_i).unsqueeze(1) 17 | 18 | return res 19 | 20 | 21 | def calculate_matmul(mat_a, mat_b): 22 | """ 23 | Calculate matrix product of two matrices with mat_a[0] >= mat_b[0]. 24 | Bypasses torch.matmul to reduce memory footprint. 25 | args: 26 | mat_a: torch.Tensor (n, k, 1, d) 27 | mat_b: torch.Tensor (n, k, d, 1) 28 | """ 29 | assert mat_a.shape[-2] == 1 and mat_b.shape[-1] == 1 30 | return torch.sum(mat_a.squeeze(-2) * mat_b.squeeze(-1), dim=2, keepdim=True) 31 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/ship.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/ship 4 | expname = ship 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] 10 | update_AlphaMask_list = [2000,4000] 11 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 12 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 13 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 14 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 17 | max_tensoRF = [4, 4, 4] 18 | 19 | N_vis = 5 20 | vis_every = 200000 21 | 22 | render_test = 1 23 | featureC = 128 24 | 25 | n_lamb_sigma = [32,24,16] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [48,48,48] 30 | data_dim_color = [27, 27, 27] 31 | 32 | 33 | model_name = StrivecCP_hier 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | view_pe = 2 39 | fea_pe = 2 40 | 41 | L1_weight_inital = 1e-5 42 | L1_weight_rest = 1e-5 43 | rm_weight_mask_thre = 1e-4 44 | ray_type=2 45 | skip_zero_grad=1 46 | gpu_ids="0" 47 | vox_res = 320 48 | pointfile = "./log/ship_point.txt" 49 | # fps_num=[0] 50 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 51 | vox_center=[1,1,1] 52 | 53 | ## dvgo 54 | use_geo = -1 -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/lego.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/lego 4 | expname
= lego 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] 10 | update_AlphaMask_list = [2000,4000] 11 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 12 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 13 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 14 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 17 | max_tensoRF = [4, 4, 4] 18 | 19 | N_vis = 5 20 | vis_every = 200000 21 | 22 | render_test = 1 23 | featureC = 128 24 | 25 | n_lamb_sigma = [32,24,16] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [48,48,48] 30 | data_dim_color = [27,27,27] 31 | 32 | model_name = StrivecCP_hier 33 | 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | view_pe = 2 39 | fea_pe = 2 40 | 41 | L1_weight_inital = 1e-5 42 | L1_weight_rest = 1e-5 43 | rm_weight_mask_thre = 1e-4 44 | ray_type=2 45 | skip_zero_grad=1 46 | gpu_ids="0" 47 | vox_res = 320 48 | 49 | pointfile= ./log/lego_points.txt 50 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 51 | vox_center=[1,1,1] 52 | 53 | ## dvgo 54 | use_geo = -1 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/hotdog.txt: -------------------------------------------------------------------------------- 1 | dataset_name = blender 2 | datadir = ./data/nerf_synthetic/hotdog 3 | expname = hotdog 4 | basedir = ./log 5 | n_iters = 30000 6 | batch_size = 4096 7 | 8 | upsamp_list = [2000,3000,4000,5500,7000] 9 | update_AlphaMask_list = [2000,4000] 10 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 11 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 12 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 13 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 14 | 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 17 | max_tensoRF = [4, 4, 4] 18 | 19 | N_vis = 5 20 | vis_every = 200000 21 | 22 | render_test = 1 23 | featureC = 128 24 | 25 | n_lamb_sigma = [32, 24, 16] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [48,48,48] 30 | data_dim_color = [27,27, 27] 31 | 32 | model_name = StrivecCP_hier 33 | 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | view_pe = 2 39 | fea_pe = 2 40 | 41 | L1_weight_inital = 1e-5 42 | L1_weight_rest = 1e-5 43 | rm_weight_mask_thre = 1e-4 44 | ray_type=2 45 | skip_zero_grad=1 46 | gpu_ids="0" 47 | vox_res = 320 48 | pointfile= ./log/hotdog_points.txt 49 | #fps_num=0 50 | vox_range= [0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 51 | vox_center=[1,1,1] 52 | 53 | ## dvgo 54 | use_geo = -1 55 | 56 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/materials.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/materials 4 | expname = materials 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 13 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 14 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 15 | local_dims_trend = [43, 
65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 16 | unit_lvl=0 # which lvl to use deciding units 17 | filterall=1 18 | max_tensoRF = [4, 4, 4] 19 | 20 | 21 | N_vis = 5 22 | vis_every = 200000 23 | 24 | render_test = 1 25 | featureC = 128 26 | 27 | n_lamb_sigma = [32,24,12] 28 | radiance_add = 1 29 | den_lvl_norm = 1 30 | rad_lvl_norm = 0 31 | n_lamb_sh = [96, 96, 96] 32 | data_dim_color = [27,27,27] 33 | 34 | model_name = StrivecCP_hier 35 | 36 | 37 | shadingMode = MLP_Fea 38 | fea2denseAct = softplus 39 | 40 | view_pe = 2 41 | fea_pe = 2 42 | 43 | L1_weight_inital = 1e-5 44 | L1_weight_rest = 1e-5 45 | rm_weight_mask_thre = 1e-4 46 | ray_type=2 47 | skip_zero_grad=1 48 | gpu_ids="0" 49 | vox_res = 320 50 | pointfile= ./log/materials_points.txt 51 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 52 | vox_center=[1,1,1] 53 | 54 | ## dvgo 55 | use_geo = -1 56 | -------------------------------------------------------------------------------- /configs/TanksAndTemple/Ignatius.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = tankstemple 3 | datadir = ./data/TanksAndTemple/Ignatius 4 | expname = Ignatius 5 | basedir = ./log 6 | 7 | n_iters = 60000 8 | 9 | batch_size = 4096 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.6, 0.6, 0.6, 0.3, 0.3, 0.3, 0.18, 0.18, 0.18] 14 | local_dims_init = [59, 59, 59, 29, 29, 29, 15, 15, 15] 15 | local_dims_final = [243, 243, 243, 121, 121, 121, 61, 61, 61] 16 | local_dims_trend = [95, 131, 167, 203, 243, 43, 65, 85, 103, 121, 23, 35, 43, 53, 61] 17 | 18 | unit_lvl=0 # which lvl to use deciding units 19 | filterall=1 20 | max_tensoRF = [4, 4, 4] 21 | 22 | 23 | N_vis = 5 24 | vis_every = 10000 25 | 26 | render_test = 1 27 | featureC = 128 28 | 29 | n_lamb_sigma = [32, 32, 32] 30 | radiance_add = 1 31 | den_lvl_norm = 0 32 | rad_lvl_norm = 0 33 | n_lamb_sh = [96, 96, 96] 34 | data_dim_color = [27, 27, 27] 35 | 36 | model_name = StrivecCP_hier 37 | 38 | shadingMode = MLP_Fea 39 | fea2denseAct = softplus 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | pointfile= ./log/Ignatius_points.txt 52 | vox_range= [0.8, 0.8, 0.8, 0.4, 0.4, 0.4, 0.20, 0.20, 0.20] 53 | vox_center=[1,1,1] 54 | 55 | ## dvgo 56 | use_geo = -1 -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/ficus.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/ficus 4 | expname = ficus 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 #4096 8 | 9 | 10 | upsamp_list = [2000,3000,4000,5500, 7000] 11 | update_AlphaMask_list = [2000,4000] 12 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 13 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 14 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 15 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 16 | unit_lvl=0 # which lvl to use deciding units 17 | filterall=1 18 | max_tensoRF = [4, 4, 4] 19 | 20 | N_vis = 5 21 | vis_every = 200000 22 | 23 | render_test = 1 24 | featureC = 128 25 | 26 | n_lamb_sigma = [32, 24, 12] 27 | radiance_add = 1 28 | den_lvl_norm = 0 29 | rad_lvl_norm = 0 30 | n_lamb_sh = [96, 96, 96] 31 | 
data_dim_color = [27, 27,27] 32 | 33 | 34 | 35 | model_name = StrivecCP_hier 36 | 37 | 38 | shadingMode = MLP_Fea 39 | fea2denseAct = softplus 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | 48 | ray_type=2 49 | skip_zero_grad=1 50 | gpu_ids="0" 51 | vox_res = 320 52 | pointfile= ./log/ficus_points.txt 53 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 54 | vox_center=[1,1,1] 55 | 56 | ## dvgo 57 | use_geo = -1 58 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/drums.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/drums 4 | expname = drums 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] 10 | update_AlphaMask_list = [2000,4000] 11 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 12 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 13 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 14 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 17 | max_tensoRF = [4, 4, 4] 18 | 19 | N_vis = 5 20 | vis_every = 200000 21 | 22 | render_test = 1 23 | featureC = 128 24 | 25 | n_lamb_sigma = [24, 12, 8] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [48,48,48] 30 | data_dim_color = [27, 27, 27] 31 | 32 | model_name = StrivecCP_hier 33 | 34 | shadingMode = MLP_Fea 35 | fea2denseAct = softplus 36 | 37 | view_pe = 2 38 | fea_pe = 2 39 | 40 | L1_weight_inital = 1e-5 41 | L1_weight_rest = 1e-5 42 | rm_weight_mask_thre = 1e-4 43 | ray_type=2 44 | skip_zero_grad=1 45 | gpu_ids="0" 46 | vox_res = 320 47 | pointfile= ./log/drums_points.txt 48 | #fps_num=0 49 | #vox_range=[0.15, 0.15, 0.15] 50 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 51 | vox_center=[1,1,1] 52 | 53 | ## dvgo 54 | use_geo = -1 55 | -------------------------------------------------------------------------------- /configs/TanksAndTemple/Caterpillar.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = tankstemple 3 | datadir = ./data/TanksAndTemple/Caterpillar 4 | expname = Caterpillar 5 | basedir = ./log 6 | 7 | n_iters = 60000 8 | 9 | batch_size = 4096 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,40000] 12 | 13 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 14 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 15 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 16 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 17 | 18 | unit_lvl=0 # which lvl to use deciding units 19 | filterall=1 20 | max_tensoRF = [4, 4, 4] 21 | 22 | 23 | N_vis = 5 24 | vis_every = 100000 25 | 26 | render_test = 1 27 | featureC = 128 28 | 29 | n_lamb_sigma = [32, 24, 12] 30 | radiance_add = 1 31 | den_lvl_norm = 1 32 | rad_lvl_norm = 0 33 | n_lamb_sh = [48, 48, 48] 34 | data_dim_color = [27, 27, 27] 35 | 36 | model_name = StrivecCP_hier 37 | 38 | shadingMode = MLP_Fea 39 | fea2denseAct = softplus 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | 
pointfile= ./log/Caterpillar_points.txt 52 | #fps_num=0 53 | #vox_range=[0.15, 0.15, 0.15] 54 | vox_range= [0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 55 | vox_center=[1,1,1] 56 | 57 | ## dvgo 58 | use_geo = -1 59 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/mic.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/mic 4 | expname = mic 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] 10 | update_AlphaMask_list = [2000,4000] 11 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] 12 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] 13 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] 14 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 17 | max_tensoRF = [4, 4, 4] 18 | 19 | 20 | N_vis = 5 21 | vis_every = 2000 22 | 23 | render_test = 1 24 | ##n_lamb_sigma = [96] 25 | ##n_lamb_sh = [288] 26 | #n_lamb_sigma = [32] 27 | #n_lamb_sh = [96] 28 | #data_dim_color = 48 29 | featureC = 128 30 | 31 | n_lamb_sigma = [32, 16, 8] 32 | radiance_add = 1 33 | den_lvl_norm = 0 34 | rad_lvl_norm = 0 35 | n_lamb_sh = [48,48,48] 36 | data_dim_color = [27, 27, 27] 37 | 38 | model_name = StrivecCP_hier 39 | 40 | shadingMode = MLP_Fea 41 | fea2denseAct = softplus 42 | 43 | view_pe = 2 44 | fea_pe = 2 45 | 46 | L1_weight_inital = 1e-5 47 | L1_weight_rest = 1e-5 48 | rm_weight_mask_thre = 1e-4 49 | ray_type=2 50 | skip_zero_grad=1 51 | gpu_ids="0" 52 | vox_res = 320 53 | pointfile= ./log/mic_points.txt 54 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] 55 | vox_center=[1,1,1] 56 | 57 | ## dvgo 58 | use_geo = -1 59 | 60 | -------------------------------------------------------------------------------- /configs/TanksAndTemple/Family.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = tankstemple 3 | datadir = ./data/TanksAndTemple/Family 4 | expname = Family 5 | basedir = ./log 6 | 7 | n_iters = 60000 8 | 9 | batch_size = 4096 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,40000] 12 | 13 | local_range = [0.6, 0.6, 0.6, 0.3, 0.3, 0.3, 0.15, 0.15, 0.15] 14 | local_dims_init = [59, 59, 59, 29, 29, 29, 15, 15, 15] 15 | local_dims_final = [243, 243, 243, 121, 121, 121, 61, 61, 61] 16 | local_dims_trend = [95, 131, 167, 203, 243, 43, 65, 85, 103, 121, 23, 35, 43, 53, 61] 17 | 18 | 19 | unit_lvl=0 # which lvl to use deciding units 20 | filterall=1 21 | max_tensoRF = [4, 4, 4] 22 | 23 | 24 | N_vis = 5 25 | vis_every = 10000 26 | 27 | render_test = 1 28 | ##n_lamb_sigma = [96] 29 | ##n_lamb_sh = [288] 30 | #n_lamb_sigma = [32] 31 | #n_lamb_sh = [96] 32 | #data_dim_color = 48 33 | featureC = 128 34 | 35 | n_lamb_sigma = [32, 24, 12] 36 | radiance_add = 1 37 | den_lvl_norm = 1 38 | rad_lvl_norm = 0 39 | n_lamb_sh = [48, 48, 48] 40 | data_dim_color = [27, 27, 27] 41 | 42 | model_name = StrivecCP_hier 43 | 44 | shadingMode = MLP_Fea 45 | fea2denseAct = softplus 46 | 47 | view_pe = 2 48 | fea_pe = 2 49 | 50 | L1_weight_inital = 1e-5 51 | L1_weight_rest = 1e-5 52 | rm_weight_mask_thre = 1e-4 53 | ray_type=2 54 | skip_zero_grad=1 55 | gpu_ids="0" 56 | vox_res = 320 57 | pointfile= ./log/Family_points.txt 58 | 59 | vox_range= [0.7, 0.7, 0.7, 0.35, 0.35, 0.35, 0.175, 0.175, 
0.175] 60 | 61 | ## dvgo 62 | use_geo = -1 63 | -------------------------------------------------------------------------------- /models/cuda/ub360_utils_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | /* 9 | helper function to skip oversampled points, 10 | especially near the foreground scene bbox boundary 11 | */ 12 | template <typename scalar_t> 13 | __global__ void cumdist_thres_cuda_kernel( 14 | scalar_t* __restrict__ dist, 15 | const float thres, 16 | const int n_rays, 17 | const int n_pts, 18 | bool* __restrict__ mask) { 19 | const int i_ray = blockIdx.x * blockDim.x + threadIdx.x; 20 | if(i_ray<n_rays) { 21 | const int i_sta = i_ray * n_pts; 22 | const int i_end = i_sta + n_pts; 23 | float cum_dist = 0; 24 | 25 | for(int i=i_sta; i<i_end; ++i) { 26 | cum_dist += dist[i]; 27 | const bool over = (cum_dist > thres); 28 | cum_dist *= float(!over); 29 | mask[i] = over; 30 | } 31 | } 32 | } 33 | 34 | torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres) { 35 | const int n_rays = dist.size(0); 36 | const int n_pts = dist.size(1); 37 | const int threads = 256; 38 | const int blocks = (n_rays + threads - 1) / threads; 39 | auto mask = torch::zeros({n_rays, n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA)); 40 | AT_DISPATCH_FLOATING_TYPES(dist.type(), "cumdist_thres_cuda", ([&] { 41 | cumdist_thres_cuda_kernel<scalar_t><<<blocks, threads>>>( 42 | dist.data<scalar_t>(), thres, 43 | n_rays, n_pts, 44 | mask.data<bool>()); 45 | })); 46 | return mask; 47 | } -------------------------------------------------------------------------------- /configs/TanksAndTemple/Barn.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = tankstemple 3 | datadir = ./data/TanksAndTemple/Barn 4 | expname = Barn 5 | basedir = ./log 6 | 7 | n_iters = 60000 8 | 9 | batch_size = 4096 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,40000] 12 | 13 | # training 14 | local_range = [0.6, 0.6, 0.6, 0.3, 0.3, 0.3, 0.15, 0.15, 0.15] 15 | local_dims_init = [59, 59, 59, 29, 29, 29, 15, 15, 15] 16 | local_dims_final = [243, 243, 243, 121, 121, 121, 61, 61, 61] 17 | local_dims_trend = [95, 131, 167, 203, 243, 43, 65, 85, 103, 121, 23, 35, 43, 53, 61] 18 | 19 | 20 | unit_lvl=0 # which lvl to use deciding units 21 | filterall=1 22 | max_tensoRF = [4, 4, 4] 23 | 24 | N_vis = 5 25 | vis_every = 100000 26 | 27 | render_train = 1 28 | render_test = 1 29 | ##n_lamb_sigma = [96] 30 | ##n_lamb_sh = [288] 31 | #n_lamb_sigma = [32] 32 | #n_lamb_sh = [96] 33 | #data_dim_color = 48 34 | featureC = 128 35 | 36 | n_lamb_sigma = [24, 16, 12] 37 | radiance_add = 1 38 | den_lvl_norm = 0 39 | rad_lvl_norm = 0 40 | n_lamb_sh = [48, 48, 48] 41 | data_dim_color = [27, 27, 27] 42 | 43 | model_name = StrivecCP_hier 44 | 45 | shadingMode = MLP_Fea 46 | fea2denseAct = softplus 47 | 48 | view_pe = 2 49 | fea_pe = 2 50 | 51 | L1_weight_inital = 1e-5 52 | L1_weight_rest = 1e-5 53 | rm_weight_mask_thre = 1e-4 54 | ray_type=2 55 | skip_zero_grad=1 56 | gpu_ids="0" 57 | vox_res = 320 58 | pointfile= ./log/barn_points.txt 59 | #fps_num=0 60 | #vox_range=[0.15, 0.15, 0.15] 61 | vox_range = [0.8, 0.8, 0.8, 0.4, 0.4, 0.4, 0.25, 0.25, 0.25] 62 | vox_center=[1,1,1] 63 | 64 | ## dvgo 65 | use_geo = -1 66 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/chair.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/chair 4 | expname = chair_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile= ./log/chair_points.txt 9 | 10 | upsamp_list =
[2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] 54 | 55 | ## dvgo 56 | use_geo = -1 57 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/ficus.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/ficus 4 | expname = ficus_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile= ./log/ficus_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo =-1 -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/ship.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/ship 4 | expname = ship_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile = 
./log/ship_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo =-1 57 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/hotdog.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/hotdog 4 | expname = hotdog_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile= ./log/hotdog_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo =-1 -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/lego.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/lego 4 | expname = lego_vm 5 | basedir = ./log 6 | n_iters =
30000 7 | batch_size = 4096 #4096 8 | pointfile= ./log/lego_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | 56 | ## dvgo 57 | use_geo =-1 -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/mic.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/mic 4 | expname = mic_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 #4096 8 | pointfile= ./log/mic_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo =-1 57 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/drums.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/drums 4 | 
expname = drums_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile= ./log/drums_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo = -1 57 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/local_vm/materials.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/materials 4 | expname = materials_vm 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | pointfile= ./log/materials_points.txt 9 | 10 | upsamp_list = [2000,3000,4000,5500,7000] 11 | update_AlphaMask_list = [2000,4000] 12 | 13 | local_range = [0.15, 0.15, 0.15] # scale1- [0.3, 0.3, 0.3] scale2- [0.15, 0.15, 0.15] scale3-[0.075, 0.075, 0.075] 14 | local_dims_init = [15, 15, 15] # scale1- [29, 29, 29] scale2- [15, 15, 15] scale3-[7, 7, 7] 15 | local_dims_final = [61, 61, 61] # scale1- [121, 121, 121] scale2- [61, 61, 61] scale3-[31, 31, 31] 16 | local_dims_trend = [23, 35, 43, 53, 61] # scale1- [43, 65, 85, 103, 121] scale2- [23, 35, 43, 53, 61] scale3-[11, 17, 21, 27, 31] 17 | unit_lvl=0 # which lvl to use deciding units 18 | filterall=1 19 | max_tensoRF = [4] 20 | 21 | N_vis = 5 22 | vis_every = 30000 23 | render_test = 1 24 | 25 | n_lamb_sigma = [12] 26 | radiance_add = 1 27 | den_lvl_norm = 0 28 | rad_lvl_norm = 0 29 | n_lamb_sh = [24] 30 | data_dim_color = [27] 31 | 32 | featureC = 128 33 | model_name = Strivec_DBase 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | vm_agg_mode=1 39 | vm_sep_mat=2 40 | 41 | view_pe = 2 42 | fea_pe = 2 43 | 44 | L1_weight_inital = 1e-5 45 | L1_weight_rest = 1e-5 46 | rm_weight_mask_thre = 1e-4 47 | ray_type=2 48 | skip_zero_grad=1 49 | gpu_ids="0" 50 | vox_res = 320 51 | # fps_num=[0] 52 | vox_range=[0.2, 0.2, 0.2] # 0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05 53 | vox_center=[1] # 1,1,1,1 54 | 55 | ## dvgo 56 | use_geo =-1 57 | -------------------------------------------------------------------------------- /configs/scannet/scannet_101.txt: -------------------------------------------------------------------------------- 1 | dataset_name
= scannet 2 | datadir = ./data/scene0101_04 3 | expname = scan_101 4 | basedir = ./log 5 | n_iters = 300000 6 | batch_size = 2048 # 4096 7 | 8 | 9 | upsamp_list = [18000,36000, 54000, 72000, 96000, 120000] 10 | update_AlphaMask_list = [36000, 72000] 11 | local_range = [2.5, 2.5, 2.5, 1.6, 1.6, 1.6, 0.8, 0.8, 0.8, 0.4, 0.4, 0.4] 12 | local_dims_init = [157, 157, 157, 101, 101, 101, 51, 51, 51, 25, 25, 25] 13 | local_dims_final = [525, 525, 525, 401, 401, 401, 201, 201, 201, 101, 101, 101] 14 | local_dims_trend = [201, 251, 301, 401, 525, 321, 441, 561, 681, 801, 161, 221, 281, 341, 401, 81, 111, 141, 171, 201, 41, 55, 71, 85, 101] 15 | vox_range = [3.0, 3.0, 3.0, 1.7, 1.7, 1.7, 0.6, 0.6, 0.6, 0.3, 0.3, 0.3] 16 | vox_center= [1, 1, 1, 1] 17 | n_lamb_sigma = [24, 24, 16, 12] 18 | n_lamb_sh = [48, 48, 48, 48] 19 | unit_lvl=0 # which lvl to use deciding units 20 | filterall=1 21 | max_tensoRF = [4,4,2,2] 22 | 23 | N_vis = 5 24 | vis_every = 200000 25 | 26 | render_test = 1 27 | #render_train = 1 28 | 29 | 30 | radiance_add = 1 31 | den_lvl_norm = 1 32 | rad_lvl_norm = 0 33 | data_dim_color = [27,27,27,27] 34 | 35 | featureC = 128 36 | model_name = StrivecCP_hier 37 | 38 | shadingMode = MLP_Fea 39 | fea2denseAct = softplus 40 | 41 | 42 | # dbasis 43 | #vm_agg_mode=1 44 | #vm_sep_mat=1 45 | 46 | view_pe = 2 47 | fea_pe = 2 48 | 49 | L1_weight_inital = 1e-5 50 | L1_weight_rest = 1e-5 51 | rm_weight_mask_thre = 1e-4 52 | ray_type=2 53 | skip_zero_grad=1 54 | gpu_ids="0" 55 | vox_res = 320 56 | pointfile = ./data/scene0101_04/101_pts_from_mesh.txt 57 | #fps_num=[0] 58 | 59 | test_margin = 10 60 | margin = 10 61 | 62 | ## dvgo 63 | use_geo = 1 64 | 65 | -------------------------------------------------------------------------------- /configs/scannet/scannet_241.txt: -------------------------------------------------------------------------------- 1 | dataset_name = scannet 2 | datadir = ./data/scene0241_01 3 | expname = scan_241 4 | basedir = ./log 5 | n_iters = 300000 6 | batch_size = 2048 #4096 7 | 8 | 9 | upsamp_list = [18000,36000, 54000, 72000, 96000, 120000] 10 | update_AlphaMask_list = [36000, 72000] 11 | local_range = [2.5, 2.5, 2.5, 1.6, 1.6, 1.6, 0.8, 0.8, 0.8, 0.4, 0.4, 0.4] 12 | local_dims_init = [157, 157, 157, 101, 101, 101, 51, 51, 51, 25, 25, 25] 13 | local_dims_final = [525, 525, 525, 401, 401, 401, 201, 201, 201, 101, 101, 101] 14 | local_dims_trend = [201, 251, 301, 401, 525, 321, 441, 561, 681, 801, 161, 221, 281, 341, 401, 81, 111, 141, 171, 201, 41, 55, 71, 85, 101] 15 | vox_range = [3.0, 3.0, 3.0, 1.7, 1.7, 1.7, 0.6, 0.6, 0.6, 0.3, 0.3, 0.3] 16 | vox_center=[1, 1, 1, 1] 17 | n_lamb_sigma = [24, 24, 16, 12] 18 | n_lamb_sh = [48, 48, 48, 48] 19 | unit_lvl=0 # which lvl to use deciding units 20 | filterall=1 21 | max_tensoRF = [4,4,2,2] 22 | 23 | N_vis = 5 24 | vis_every = 1000000 25 | 26 | render_test = 1 27 | #render_train = 1 28 | 29 | 30 | radiance_add = 1 31 | den_lvl_norm = 1 32 | rad_lvl_norm = 0 33 | data_dim_color = [27,27,27,27] 34 | 35 | featureC = 128 36 | model_name = StrivecCP_hier 37 | 38 | shadingMode = MLP_Fea 39 | fea2denseAct = softplus 40 | 41 | 42 | # dbasis 43 | #vm_agg_mode=1 44 | #vm_sep_mat=1 45 | 46 | view_pe = 2 47 | fea_pe = 2 48 | 49 | L1_weight_inital = 1e-5 50 | L1_weight_rest = 1e-5 51 | rm_weight_mask_thre = 1e-4 52 | ray_type=2 53 | skip_zero_grad=1 54 | gpu_ids="0" 55 | vox_res = 320 56 | pointfile = "./data/scene0241_01/241_pts_from_mesh.txt" 57 | #fps_num=[0] 58 | test_margin = 10 59 | margin = 10 60 | 61 | ## dvgo 62 | use_geo = 1 63 | 64 | 65 
| 66 | -------------------------------------------------------------------------------- /configs/360/360_garden.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = llff 3 | datadir = ./data/360/garden 4 | expname = 360_garden 5 | basedir = ./log 6 | n_iters = 120000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] # [2000,3000,5000,7000] [2000,3000,4000,5500,7000] 10 | 11 | local_range = [1.2, 1.2, 1.2, 0.6, 0.6, 0.6] # two scales: 1.2 0.6 12 | local_dims_init = [101, 101, 101, 51, 51, 51] # initial dimension of each scale: 101, 51 13 | local_dims_final = [323, 323, 323, 161, 161, 161] # final dimension of each scale: 323, 161 14 | local_dims_trend = [145, 189, 233, 277, 323, 71, 91, 111, 131, 161] # 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 # only the points covered by all scales are considered 17 | max_tensoRF = [4,4] # the number of local tensors considered to cover a sampled point along a ray (top-K with K=4 as mentioned in the paper) 18 | 19 | 20 | N_vis = 5 # the number of rendered views when testing/inference 21 | vis_every = 20000 # testing/inference at every 'vis_every' iters 22 | 23 | render_test = 1 24 | 25 | featureC = 128 26 | 27 | n_lamb_sigma = [48,32] 28 | radiance_add = 1 29 | den_lvl_norm = 0 30 | rad_lvl_norm = 0 31 | n_lamb_sh = [48,48] 32 | data_dim_color = [27,27] 33 | 34 | model_name = StrivecCP_hier 35 | 36 | shadingMode = MLP_Fea 37 | fea2denseAct = softplus 38 | 39 | view_pe = 2 40 | fea_pe = 2 41 | 42 | lr_init = 3e-2 43 | #lr_decay_iters = 40000 44 | #lr_decay_target_ratio = 0.5 45 | 46 | L1_weight_inital = 1e-4 47 | L1_weight_rest = 5e-5 48 | rm_weight_mask_thre = 1e-4 49 | ray_type=2 50 | skip_zero_grad=1 51 | gpu_ids="0" 52 | vox_res = 320 53 | pointfile= ./log/garden_points.vox 54 | 55 | vox_range = [1.2, 1.2, 1.2, 0.6, 0.6, 0.6] # distribute local tensors of different scales at every 'vox_range'; 1.2 for the first scale and 0.6 for the second 56 | vox_center = [0,0] 57 | 58 | 59 | ## dvgo 60 | use_geo = -1 # -1 for using dvgo to get an initial geometry, 1 for using a pre-loaded file (pointfile) 61 | 62 | ub360 = 1 63 | indoor = 0 64 | downsample_train=4 # 1297x840 65 | #pre_lrate_decay = 80 66 | #pre_num_voxels= 32768000 # 320**3 67 | pre_N_iters = 15000 68 | -------------------------------------------------------------------------------- /configs/synthetic-nerf/default/chair.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/chair 4 | expname = chair 5 | basedir = ./log 6 | n_iters = 30000 7 | batch_size = 4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,7000] 10 | update_AlphaMask_list = [2000,4000] 11 | local_range = [0.3, 0.3, 0.3, 0.15, 0.15, 0.15, 0.075, 0.075, 0.075] # three scales: 0.3, 0.15, 0.075 12 | local_dims_init = [29, 29, 29, 15, 15, 15, 7, 7, 7] # initial dimension of each scale: 29, 15, 7 13 | local_dims_final = [121, 121, 121, 61, 61, 61, 31, 31, 31] # final dimension of each scale: 121, 61, 31 14 | local_dims_trend = [43, 65, 85, 103, 121, 23, 35, 43, 53, 61, 11, 17, 21, 27, 31] # upsampling the dimension of each scale: 1st scale begins with 29, goes through 43, 65, 85, 103, and ends with 121; 2nd scale begins with 15, goes through 23, 35, 43, 53, and ends with 61; 3rd scale begins with 7, goes through 11, 17, 21, 27, and ends with 31 15 | unit_lvl=0 # which lvl to use deciding units 16 | filterall=1 # only the points covered by all scales are considered 17 | max_tensoRF = [4, 4, 4] # the number of local tensors considered to cover a sampled point along a ray (top-K with K=4 as mentioned in the paper) 18 | 19 | N_vis = 5 # the number of rendered views when testing/inference 20 | vis_every = 200000 # testing/inference at every 'vis_every' iters 21 | 22 | render_test = 1 23 | featureC = 128 24 | 25 | # number of components of each scale 26 | n_lamb_sigma = [32, 24, 16] 27 | radiance_add = 1 28 | den_lvl_norm = 0 29 | rad_lvl_norm = 0 30 | n_lamb_sh = [48, 48, 48] # this is Ours-48; to try Ours-24, use [24, 24, 24] 31 | data_dim_color = [27, 27, 27] 32 | 33 | model_name = StrivecCP_hier 34 | 35 | shadingMode = MLP_Fea 36 | fea2denseAct = softplus 37 | 38 | view_pe = 2 39 | fea_pe = 2 40 | 41 | L1_weight_inital = 1e-5 42 | L1_weight_rest = 1e-5 43 | rm_weight_mask_thre = 1e-4 44 | ray_type=2 45 | skip_zero_grad=1 46 | gpu_ids="0" 47 | vox_res = 320 48 | pointfile= ./log/chair_points.txt # your own initial geometry 49 | 50 | vox_range=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1] # distribute local tensors of different scales at every 'vox_range'; 0.4 for the first scale, 0.2 for the second, and 0.1 for the third 51 | vox_center=[1,1,1] 52 | 53 | ## dvgo 54 | use_geo = -1 # 1 when you want to use your own initial geometry specified in 'pointfile' 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /preprocessing/gmm_torch/example.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | sns.set(style="white", font="Arial") 6 | colors = sns.color_palette("Paired", n_colors=12).as_hex() 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from gmm import GaussianMixture 12 | from math import sqrt 13 | 14 | 15 | def main(): 16 | n, d = 300, 2 17 | 18 | # generate some data points .. 19 | data = torch.Tensor(n, d).normal_() 20 | # .. and shift them around to non-standard Gaussians 21 | data[:n//2] -= 1 22 | data[:n//2] *= sqrt(3) 23 | data[n//2:] += 1 24 | data[n//2:] *= sqrt(2) 25 | 26 | # Next, the Gaussian mixture is instantiated and .. 27 | n_components = 2 28 | model = GaussianMixture(n_components, d) 29 | model.fit(data) 30 | # .. used to predict the data points as they were shifted 31 | y = model.predict(data) 32 | 33 | plot(data, y) 34 | 35 | 36 | def plot(data, y): 37 | n = y.shape[0] 38 | 39 | fig, ax = plt.subplots(1, 1, figsize=(1.61803398875*4, 4)) 40 | ax.set_facecolor("#bbbbbb") 41 | ax.set_xlabel("Dimension 1") 42 | ax.set_ylabel("Dimension 2") 43 | 44 | # plot the locations of all data points .. 45 | for i, point in enumerate(data.data): 46 | if i <= n//2: 47 | # .. separating them by ground truth .. 48 | ax.scatter(*point, color="#000000", s=3, alpha=.75, zorder=n+i) 49 | else: 50 | ax.scatter(*point, color="#ffffff", s=3, alpha=.75, zorder=n+i) 51 | 52 | if y[i] == 0: 53 | # ..
as well as their predicted class 54 | ax.scatter(*point, zorder=i, color="#dbe9ff", alpha=.6, edgecolors=colors[1]) 55 | else: 56 | ax.scatter(*point, zorder=i, color="#ffdbdb", alpha=.6, edgecolors=colors[5]) 57 | 58 | handles = [plt.Line2D([0], [0], color="w", lw=4, label="Ground Truth 1"), 59 | plt.Line2D([0], [0], color="black", lw=4, label="Ground Truth 2"), 60 | plt.Line2D([0], [0], color=colors[1], lw=4, label="Predicted 1"), 61 | plt.Line2D([0], [0], color=colors[5], lw=4, label="Predicted 2")] 62 | 63 | legend = ax.legend(loc="best", handles=handles) 64 | 65 | plt.tight_layout() 66 | plt.savefig("example.pdf") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /configs/360/360_room.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = llff 3 | datadir = ./data/360/room 4 | expname = 360_room 5 | basedir = ./log 6 | n_iters = 100000 7 | batch_size = 3000 #4096 8 | 9 | upsamp_list = [2000,3000,4000,5500,6500] # [2000,3000,5000,7000] [2000,3000,4000,5500,7000] 10 | #update_AlphaMask_list = [20000,40000] 11 | update_AlphaMask_list = [200000,400000] # [2000,4000,15000,30000] 12 | #local_range = [0.15, 0.15, 0.15] # [0.05, 0.05, 0.05] # [0.1, 0.1, 0.1] 13 | #local_dims_init = [15, 15, 15] # [3, 5, 9, 17, 33] 14 | #local_dims_final = [65, 65, 65] # [40, 40, 40] 15 | #max_tensoRF = 4 16 | 17 | local_range = [1.2, 1.2, 1.2, 0.6, 0.6, 0.6] # [1.2, 1.2, 1.2, 0.6, 0.6, 0.6, 0.3, 0.3, 0.3] 18 | local_dims_init = [101, 101, 101, 51, 51, 51] # [81, 81, 81] [101, 101, 101, 51, 51, 51, 25, 25, 25] 19 | local_dims_final = [323, 323, 323, 161, 161, 161] # [323, 323, 323, 161, 161, 161, 81, 81, 81] [401, 401, 401, 161, 161, 161, 81, 81, 81] 20 | local_dims_trend = [145, 189, 233, 277, 323, 71, 91, 111, 131, 161] # [145, 189, 233, 277, 323, 71, 91, 111, 131, 161, 37, 47, 59, 71, 81] [161, 221, 281, 341, 401, 71, 91, 111, 131, 161] 21 | unit_lvl=0 # which lvl to use deciding units 22 | filterall=1 23 | max_tensoRF = [4,4] 24 | 25 | 26 | N_vis = 5 27 | vis_every = 200000 28 | 29 | render_test = 1 30 | ##n_lamb_sigma = [96] 31 | ##n_lamb_sh = [288] 32 | #n_lamb_sigma = [32] 33 | #n_lamb_sh = [96] 34 | #data_dim_color = 48 35 | featureC = 128 36 | 37 | n_lamb_sigma = [48,32] # [64, 48] 38 | radiance_add = 1 39 | den_lvl_norm = 0 40 | rad_lvl_norm = 1 41 | n_lamb_sh = [96, 96] 42 | data_dim_color = [27, 27] 43 | 44 | model_name = PointTensorCP_hier #PointTensorCP 45 | 46 | shadingMode = MLP_Fea 47 | fea2denseAct = softplus 48 | 49 | view_pe = 2 50 | fea_pe = 2 51 | 52 | lr_init = 3e-2 53 | #lr_decay_iters = 40000 54 | #lr_decay_target_ratio = 0.5 55 | 56 | L1_weight_inital = 1e-4 57 | L1_weight_rest = 5e-5 58 | rm_weight_mask_thre = 1e-4 59 | ray_type=2 60 | skip_zero_grad=1 61 | gpu_ids="0" 62 | vox_res = 320 63 | pointfile= ./log/room_points.vox 64 | #fps_num=0 65 | #vox_range=[0.15, 0.15, 0.15] 66 | vox_range= [1.2, 1.2, 1.2, 0.6, 0.6, 0.6] # [1.2, 1.2, 1.2, 0.6, 0.6, 0.6] 67 | vox_center=[0,0] 68 | 69 | 70 | 71 | ## dvgo 72 | use_geo = -1 73 | 74 | ub360 = 1 75 | indoor = 1 76 | downsample_train=2 # 1297x840 77 | #pre_lrate_decay = 80 78 | #pre_num_voxels= 32768000 # 320**3 79 | pre_N_iters = 15000 80 | -------------------------------------------------------------------------------- /models/cuda/total_variation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 
7 | 8 | template <typename scalar_t, typename bound_t> 9 | __device__ __forceinline__ scalar_t clamp(const scalar_t v, const bound_t lo, const bound_t hi) { 10 | return min(max(v, lo), hi); 11 | } 12 | 13 | template <typename scalar_t, bool dense_mode> 14 | __global__ void total_variation_add_grad_cuda_kernel( 15 | const scalar_t* __restrict__ param, 16 | scalar_t* __restrict__ grad, 17 | float wx, float wy, float wz, 18 | const size_t sz_i, const size_t sz_j, const size_t sz_k, const size_t N) { 19 | 20 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 21 | if(index<N && (dense_mode || grad[index]!=0)) { 22 | const size_t k = index % sz_k; 23 | const size_t j = index / sz_k % sz_j; 24 | const size_t i = index / sz_k / sz_j % sz_i; 25 | 26 | float grad_to_add = 0; 27 | grad_to_add += (k==0 ? 0 : wz * clamp(param[index]-param[index-1], -1.f, 1.f)); 28 | grad_to_add += (k==sz_k-1 ? 0 : wz * clamp(param[index]-param[index+1], -1.f, 1.f)); 29 | grad_to_add += (j==0 ? 0 : wy * clamp(param[index]-param[index-sz_k], -1.f, 1.f)); 30 | grad_to_add += (j==sz_j-1 ? 0 : wy * clamp(param[index]-param[index+sz_k], -1.f, 1.f)); 31 | grad_to_add += (i==0 ? 0 : wx * clamp(param[index]-param[index-sz_k*sz_j], -1.f, 1.f)); 32 | grad_to_add += (i==sz_i-1 ? 0 : wx * clamp(param[index]-param[index+sz_k*sz_j], -1.f, 1.f)); 33 | grad[index] += grad_to_add; 34 | } 35 | } 36 | 37 | void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { 38 | const size_t N = param.numel(); 39 | const size_t sz_i = param.size(2); 40 | const size_t sz_j = param.size(3); 41 | const size_t sz_k = param.size(4); 42 | const int threads = 256; 43 | const int blocks = (N + threads - 1) / threads; 44 | 45 | wx /= 6; 46 | wy /= 6; 47 | wz /= 6; 48 | 49 | if(dense_mode) { 50 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 51 | total_variation_add_grad_cuda_kernel<scalar_t,true><<<blocks, threads>>>( 52 | param.data<scalar_t>(), 53 | grad.data<scalar_t>(), 54 | wx, wy, wz, 55 | sz_i, sz_j, sz_k, N); 56 | })); 57 | } 58 | else { 59 | AT_DISPATCH_FLOATING_TYPES(param.type(), "total_variation_add_grad_cuda", ([&] { 60 | total_variation_add_grad_cuda_kernel<scalar_t,false><<<blocks, threads>>>( 61 | param.data<scalar_t>(), 62 | grad.data<scalar_t>(), 63 | wx, wy, wz, 64 | sz_i, sz_j, sz_k, N); 65 | })); 66 | } 67 | } 68 | 69 | -------------------------------------------------------------------------------- /models/cuda/adam_upd.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | void adam_upd_cuda( 8 | torch::Tensor param, 9 | torch::Tensor grad, 10 | torch::Tensor exp_avg, 11 | torch::Tensor exp_avg_sq, 12 | int step, float beta1, float beta2, float lr, float eps); 13 | 14 | void masked_adam_upd_cuda( 15 | torch::Tensor param, 16 | torch::Tensor grad, 17 | torch::Tensor exp_avg, 18 | torch::Tensor exp_avg_sq, 19 | int step, float beta1, float beta2, float lr, float eps); 20 | 21 | void adam_upd_with_perlr_cuda( 22 | torch::Tensor param, 23 | torch::Tensor grad, 24 | torch::Tensor exp_avg, 25 | torch::Tensor exp_avg_sq, 26 | torch::Tensor perlr, 27 | int step, float beta1, float beta2, float lr, float eps); 28 | 29 | 30 | // C++ interface 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | void adam_upd( 37 | torch::Tensor param, 38 | torch::Tensor grad, 39 | torch::Tensor exp_avg, 40 | torch::Tensor exp_avg_sq, 41 | int step, float beta1, float beta2, float lr, float eps) { 42 | CHECK_INPUT(param); 43 | CHECK_INPUT(grad); 44 | CHECK_INPUT(exp_avg); 45 | CHECK_INPUT(exp_avg_sq); 46 | adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 47 | step, beta1, beta2, lr, eps); 48 | } 49 | 50 | void masked_adam_upd( 51 | torch::Tensor param, 52 | torch::Tensor grad, 53 | torch::Tensor exp_avg, 54 | torch::Tensor exp_avg_sq, 55 | int step, float beta1, float beta2, float lr, float eps) { 56 | CHECK_INPUT(param); 57 | CHECK_INPUT(grad); 58 | CHECK_INPUT(exp_avg); 59 | CHECK_INPUT(exp_avg_sq); 60 | masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, 61 | step, beta1, beta2, lr, eps); 62 | } 63 | 64 | void adam_upd_with_perlr( 65 | torch::Tensor param, 66 | torch::Tensor grad, 67 | torch::Tensor exp_avg, 68 | torch::Tensor exp_avg_sq, 69 | torch::Tensor perlr, 70 | int step, float beta1, float beta2, float lr, float eps) { 71 | CHECK_INPUT(param); 72 | CHECK_INPUT(grad); 73 | CHECK_INPUT(exp_avg); 74 | CHECK_INPUT(exp_avg_sq); 75 | adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr, 76 | step, beta1, beta2, lr, eps); 77 | } 78 | 79 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 80 | m.def("adam_upd", &adam_upd, 81 | "Adam update"); 82 | m.def("masked_adam_upd", &masked_adam_upd, 83 | "Adam update ignoring zero grad"); 84 | 
m.def("adam_upd_with_perlr", &adam_upd_with_perlr, 85 | "Adam update ignoring zero grad with per-voxel lr"); 86 | } 87 | 88 | -------------------------------------------------------------------------------- /models/masked_adam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.cpp_extension import load 4 | 5 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 6 | sources=['cuda/adam_upd.cpp', 'cuda/adam_upd_kernel.cu'] 7 | adam_upd_cuda = load( 8 | name='adam_upd_cuda', 9 | sources=[os.path.join(parent_dir, path) for path in sources], 10 | verbose=True) 11 | 12 | 13 | ''' Extend Adam optimizer 14 | 1. support per-voxel learning rate 15 | 2. masked update (ignore zero grad), which speeds up training 16 | ''' 17 | class MaskedAdam(torch.optim.Optimizer): 18 | 19 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8): 20 | if not 0.0 <= lr: 21 | raise ValueError("Invalid learning rate: {}".format(lr)) 22 | if not 0.0 <= eps: 23 | raise ValueError("Invalid epsilon value: {}".format(eps)) 24 | if not 0.0 <= betas[0] < 1.0: 25 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 26 | if not 0.0 <= betas[1] < 1.0: 27 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 28 | defaults = dict(lr=lr, betas=betas, eps=eps) 29 | self.per_lr = None 30 | super(MaskedAdam, self).__init__(params, defaults) 31 | 32 | def __setstate__(self, state): 33 | super(MaskedAdam, self).__setstate__(state) 34 | 35 | def set_pervoxel_lr(self, count): 36 | assert self.param_groups[0]['params'][0].shape == count.shape 37 | self.per_lr = count.float() / count.max() 38 | 39 | @torch.no_grad() 40 | def step(self): 41 | for group in self.param_groups: 42 | lr = group['lr'] 43 | beta1, beta2 = group['betas'] 44 | eps = group['eps'] 45 | skip_zero_grad = group['skip_zero_grad'] 46 | 47 | for param in group['params']: 48 | if param.grad is not None: 49 | state = self.state[param] 50 | # Lazy state initialization 51 | if len(state) == 0: 52 | state['step'] = 0 53 | # Exponential moving average of gradient values 54 | state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) 55 | # Exponential moving average of squared gradient values 56 | state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) 57 | 58 | state['step'] += 1 59 | 60 | if self.per_lr is not None and param.shape == self.per_lr.shape: 61 | adam_upd_cuda.adam_upd_with_perlr( 62 | param, param.grad, state['exp_avg'], state['exp_avg_sq'], self.per_lr, 63 | state['step'], beta1, beta2, lr, eps) 64 | elif skip_zero_grad: 65 | adam_upd_cuda.masked_adam_upd( 66 | param, param.grad, state['exp_avg'], state['exp_avg_sq'], 67 | state['step'], beta1, beta2, lr, eps) 68 | else: 69 | adam_upd_cuda.adam_upd( 70 | param, param.grad, state['exp_avg'], state['exp_avg_sq'], 71 | state['step'], beta1, beta2, lr, eps) 72 | 73 | -------------------------------------------------------------------------------- /models/cuda/grid_sample_1d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | 5 | 6 | std::vector<torch::Tensor> grid_sample_from_tensoRF_backward_cuda(torch::Tensor local_gindx_s, torch::Tensor local_gindx_l, torch::Tensor local_gweight_s, torch::Tensor local_gweight_l, torch::Tensor final_tensoRF_id, torch::Tensor grad_planeout, torch::Tensor grad_lineout, int planesurf_num, int linesurf_num, int component_num, int res); 7 | 8 | 
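// The forward declaration below (and its _cuda implementation) is what the Python side
// calls per sample batch: given query points xyz_sampled and the precomputed coverage grid
// (tensoRF_cvrg_inds / tensoRF_count / tensoRF_topindx), it gathers for each point up to K
// covering local tensoRFs plus the interpolation indices and weights into their plane and
// line factor grids. This description is inferred from the argument names; the kernel
// source itself is not part of this listing.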
9 | std::vector<torch::Tensor> grid_sample_from_tensoRF_cuda( 10 | torch::Tensor plane, 11 | torch::Tensor line1, 12 | torch::Tensor line2, 13 | torch::Tensor line3, 14 | torch::Tensor xyz_sampled, 15 | torch::Tensor xyz_min, 16 | torch::Tensor xyz_max, 17 | torch::Tensor units, 18 | torch::Tensor lvl_units, 19 | torch::Tensor local_range, 20 | torch::Tensor local_dims, 21 | torch::Tensor tensoRF_cvrg_inds, 22 | torch::Tensor tensoRF_count, 23 | torch::Tensor tensoRF_topindx, 24 | torch::Tensor geo_xyz, 25 | const int K, 26 | const bool KNN); 27 | 28 | 29 | std::vector<torch::Tensor> cal_w_inds_cuda( 30 | torch::Tensor plane, 31 | torch::Tensor line1, 32 | torch::Tensor line2, 33 | torch::Tensor line3, 34 | torch::Tensor local_gindx_s, 35 | torch::Tensor local_gindx_l, 36 | torch::Tensor local_gweight_s, 37 | torch::Tensor local_gweight_l, 38 | torch::Tensor final_tensoRF_id); 39 | 40 | // C++ interface 41 | 42 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 43 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 44 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 45 | 46 | 47 | 48 | std::vector<torch::Tensor> grid_sample_from_tensoRF(torch::Tensor plane, torch::Tensor line1, torch::Tensor line2, torch::Tensor line3, torch::Tensor xyz_sampled, torch::Tensor xyz_min, torch::Tensor xyz_max, torch::Tensor units, torch::Tensor lvl_units, torch::Tensor local_range, torch::Tensor local_dims, torch::Tensor tensoRF_cvrg_inds, torch::Tensor tensoRF_count, torch::Tensor tensoRF_topindx, torch::Tensor geo_xyz, const int K, const bool KNN) { 49 | CHECK_INPUT(plane); 50 | CHECK_INPUT(line1); 51 | CHECK_INPUT(line2); 52 | CHECK_INPUT(line3); 53 | CHECK_INPUT(geo_xyz); 54 | assert(xyz_sampled.dim()==2); 55 | return grid_sample_from_tensoRF_cuda(plane, line1, line2, line3, xyz_sampled, xyz_min, xyz_max, units, lvl_units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, K, KNN); 56 | } 57 | 58 | 59 | 60 | std::vector<torch::Tensor> cal_w_inds(torch::Tensor plane, torch::Tensor line1, torch::Tensor line2, torch::Tensor line3, torch::Tensor local_gindx_s, torch::Tensor local_gindx_l, torch::Tensor local_gweight_s, torch::Tensor local_gweight_l, torch::Tensor final_tensoRF_id) { 61 | CHECK_INPUT(plane); 62 | CHECK_INPUT(line1); 63 | CHECK_INPUT(line2); 64 | CHECK_INPUT(line3); 65 | CHECK_INPUT(local_gindx_l); 66 | return cal_w_inds_cuda(plane, line1, line2, line3, local_gindx_s, local_gindx_l, local_gweight_s, local_gweight_l, final_tensoRF_id); 67 | } 68 | 69 | std::vector<torch::Tensor> grid_sample_from_tensoRF_backward(torch::Tensor local_gindx_s, torch::Tensor local_gindx_l, torch::Tensor local_gweight_s, torch::Tensor local_gweight_l, torch::Tensor final_tensoRF_id, torch::Tensor grad_planeout, torch::Tensor grad_lineout, int planesurf_num, int linesurf_num, int component_num, int res) { 70 | CHECK_INPUT(grad_planeout); 71 | CHECK_INPUT(grad_lineout); 72 | return grid_sample_from_tensoRF_backward_cuda(local_gindx_s, local_gindx_l, local_gweight_s, local_gweight_l, final_tensoRF_id, grad_planeout, grad_lineout, planesurf_num, linesurf_num, component_num, res); 73 | } 74 | 75 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 76 | m.def("cal_w_inds", &cal_w_inds, "Sampled points to get tensoRF with inds"); 77 | m.def("grid_sample_from_tensoRF", &grid_sample_from_tensoRF, "Sampled points to get tensoRF with gs sample"); 78 | m.def("grid_sample_from_tensoRF_backward", &grid_sample_from_tensoRF_backward, "Backward pass of the tensoRF"); 79 | } 80 | 81 | 82 | 
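The bindings above are compiled and loaded from Python with `torch.utils.cpp_extension.load`, the same JIT pattern `models/masked_adam.py` uses for the Adam kernels. A minimal sketch of that loading step (not part of the repo; the kernel filename `cuda/grid_sample_1d_kernel.cu` is an assumption):

```python
import os
from torch.utils.cpp_extension import load

parent_dir = os.path.dirname(os.path.abspath(__file__))
# The .cpp holds the pybind glue shown above; the .cu filename here is assumed.
sources = ['cuda/grid_sample_1d.cpp', 'cuda/grid_sample_1d_kernel.cu']
grid_sample_1d_cuda = load(
    name='grid_sample_1d_cuda',
    sources=[os.path.join(parent_dir, p) for p in sources],
    verbose=True)

# Every m.def(...) above then becomes a plain Python function on the module, e.g.:
# outs = grid_sample_1d_cuda.grid_sample_from_tensoRF(plane, line1, line2, line3, ...)
```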
-------------------------------------------------------------------------------- /models/cuda/search_geo_hier.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | 8 | std::vector<torch::Tensor> build_tensoRF_map_hier_cuda( 9 | torch::Tensor pnt_xyz, 10 | torch::Tensor gridSize, 11 | torch::Tensor xyz_min, 12 | torch::Tensor xyz_max, 13 | torch::Tensor units, 14 | torch::Tensor local_range, 15 | torch::Tensor local_dims, 16 | const int max_tensoRF); 17 | 18 | 19 | std::vector<torch::Tensor> sample_2_tensoRF_cvrg_hier_cuda( 20 | torch::Tensor xyz_sampled, 21 | torch::Tensor xyz_min, 22 | torch::Tensor xyz_max, 23 | torch::Tensor units, 24 | torch::Tensor lvl_units, 25 | torch::Tensor local_range, 26 | torch::Tensor local_dims, 27 | torch::Tensor tensoRF_cvrg_inds, 28 | torch::Tensor tensoRF_count, 29 | torch::Tensor tensoRF_topindx, 30 | torch::Tensor geo_xyz, 31 | const int K, 32 | const bool KNN); 33 | 34 | std::vector<torch::Tensor> sample_2_tensoRF_cvrg_hier_gs_cuda( 35 | torch::Tensor xyz_sampled, 36 | torch::Tensor xyz_min, 37 | torch::Tensor xyz_max, 38 | torch::Tensor units, 39 | torch::Tensor lvl_units, 40 | torch::Tensor local_range, 41 | torch::Tensor local_dims, 42 | torch::Tensor tensoRF_cvrg_inds, 43 | torch::Tensor tensoRF_count, 44 | torch::Tensor tensoRF_topindx, 45 | torch::Tensor geo_xyz, 46 | const int K, 47 | const bool KNN); 48 | 49 | // C++ interface 50 | 51 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 52 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 53 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 54 | 55 | 56 | 57 | std::vector<torch::Tensor> sample_2_tensoRF_cvrg_hier_gs(torch::Tensor xyz_sampled, torch::Tensor xyz_min, torch::Tensor xyz_max, torch::Tensor units, torch::Tensor lvl_units, torch::Tensor local_range, torch::Tensor local_dims, torch::Tensor tensoRF_cvrg_inds, torch::Tensor tensoRF_count, torch::Tensor tensoRF_topindx, torch::Tensor geo_xyz, const int K, const bool KNN) { 58 | CHECK_INPUT(xyz_sampled); 59 | CHECK_INPUT(geo_xyz); 60 | assert(xyz_sampled.dim()==2); 61 | return sample_2_tensoRF_cvrg_hier_gs_cuda(xyz_sampled, xyz_min, xyz_max, units, lvl_units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, K, KNN); 62 | } 63 | 64 | 65 | std::vector<torch::Tensor> sample_2_tensoRF_cvrg_hier(torch::Tensor xyz_sampled, torch::Tensor xyz_min, torch::Tensor xyz_max, torch::Tensor units, torch::Tensor lvl_units, torch::Tensor local_range, torch::Tensor local_dims, torch::Tensor tensoRF_cvrg_inds, torch::Tensor tensoRF_count, torch::Tensor tensoRF_topindx, torch::Tensor geo_xyz, const int K, const bool KNN) { 66 | CHECK_INPUT(xyz_sampled); 67 | CHECK_INPUT(geo_xyz); 68 | assert(xyz_sampled.dim()==2); 69 | return sample_2_tensoRF_cvrg_hier_cuda(xyz_sampled, xyz_min, xyz_max, units, lvl_units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, K, KNN); 70 | } 71 | 72 | std::vector<torch::Tensor> build_tensoRF_map_hier( 73 | torch::Tensor pnt_xyz, 74 | torch::Tensor gridSize, 75 | torch::Tensor xyz_min, 76 | torch::Tensor xyz_max, 77 | torch::Tensor units, 78 | torch::Tensor local_range, 79 | torch::Tensor local_dims, const int max_tensoRF) { 80 | CHECK_INPUT(pnt_xyz); 81 | CHECK_INPUT(gridSize); 82 | CHECK_INPUT(xyz_min); 83 | CHECK_INPUT(xyz_max); 84 | CHECK_INPUT(units); 85 | CHECK_INPUT(local_range); 86 | CHECK_INPUT(local_dims); 87 | assert(pnt_xyz.dim()==2); 88 | return 
build_tensoRF_map_hier_cuda(pnt_xyz, gridSize, xyz_min, xyz_max, units, local_range, local_dims, max_tensoRF); 89 | } 90 | 91 | 92 | 93 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 94 | m.def("build_tensoRF_map_hier", &build_tensoRF_map_hier, "build tensoRF indices map"); 95 | m.def("sample_2_tensoRF_cvrg_hier", &sample_2_tensoRF_cvrg_hier, "Sampled points to get tensoRF"); 96 | m.def("sample_2_tensoRF_cvrg_hier_gs", &sample_2_tensoRF_cvrg_hier_gs, "Sampled points to get tensoRF with gs sample"); 97 | } 98 | 99 | 100 | -------------------------------------------------------------------------------- /models/init_net/utils.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import numpy as np 3 | import scipy.signal 4 | from typing import List, Optional 5 | 6 | from torch import Tensor 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from ..masked_adam import MaskedAdam 12 | 13 | 14 | ''' Misc 15 | ''' 16 | mse2psnr = lambda x : -10. * torch.log10(x) 17 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 18 | 19 | def create_optimizer_or_freeze_model(model, args, global_step): 20 | decay_steps = args.pre_lrate_decay * 1000 21 | decay_factor = 0.1 ** (global_step/decay_steps) 22 | 23 | param_group = [] 24 | for k in args.__dict__.keys(): 25 | if not k.startswith('pre_lrate_'): 26 | continue 27 | k = k[len('pre_lrate_'):] 28 | 29 | if not hasattr(model, k): 30 | continue 31 | 32 | param = getattr(model, k) 33 | if param is None: 34 | print(f'create_optimizer_or_freeze_model: param {k} does not exist') 35 | continue 36 | 37 | lr = getattr(args, f'pre_lrate_{k}') * decay_factor 38 | if lr > 0: 39 | print(f'create_optimizer_or_freeze_model: param {k} lr {lr}') 40 | if isinstance(param, nn.Module): 41 | param = param.parameters() 42 | param_group.append({'params': param, 'lr': lr, 'skip_zero_grad': args.skip_zero_grad}) 43 | else: 44 | print(f'create_optimizer_or_freeze_model: param {k} freeze') 45 | param.requires_grad = False 46 | return MaskedAdam(param_group) 47 | 48 | 49 | ''' Checkpoint utils 50 | ''' 51 | def load_checkpoint(model, optimizer, ckpt_path, no_reload_optimizer): 52 | ckpt = torch.load(ckpt_path) 53 | start = ckpt['global_step'] 54 | model.load_state_dict(ckpt['model_state_dict']) 55 | if not no_reload_optimizer: 56 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 57 | return model, optimizer, start 58 | 59 | 60 | def load_model(model_class, ckpt_path): 61 | ckpt = torch.load(ckpt_path) 62 | model = model_class(**ckpt['model_kwargs']) 63 | model.load_state_dict(ckpt['model_state_dict']) 64 | return model 65 | 66 | 67 | ''' Evaluation metrics (ssim, lpips) 68 | ''' 69 | def rgb_ssim(img0, img1, max_val, 70 | filter_size=11, 71 | filter_sigma=1.5, 72 | k1=0.01, 73 | k2=0.03, 74 | return_map=False): 75 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 76 | assert len(img0.shape) == 3 77 | assert img0.shape[-1] == 3 78 | assert img0.shape == img1.shape 79 | 80 | # Construct a 1D Gaussian blur filter. 81 | hw = filter_size // 2 82 | shift = (2 * hw - filter_size + 1) / 2 83 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 84 | filt = np.exp(-0.5 * f_i) 85 | filt /= np.sum(filt) 86 | 87 | # Blur in x and y (faster than the 2D convolution). 
88 | def convolve2d(z, f): 89 | return scipy.signal.convolve2d(z, f, mode='valid') 90 | 91 | filt_fn = lambda z: np.stack([ 92 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 93 | for i in range(z.shape[-1])], -1) 94 | mu0 = filt_fn(img0) 95 | mu1 = filt_fn(img1) 96 | mu00 = mu0 * mu0 97 | mu11 = mu1 * mu1 98 | mu01 = mu0 * mu1 99 | sigma00 = filt_fn(img0**2) - mu00 100 | sigma11 = filt_fn(img1**2) - mu11 101 | sigma01 = filt_fn(img0 * img1) - mu01 102 | 103 | # Clip the variances and covariances to valid values. 104 | # Variance must be non-negative: 105 | sigma00 = np.maximum(0., sigma00) 106 | sigma11 = np.maximum(0., sigma11) 107 | sigma01 = np.sign(sigma01) * np.minimum( 108 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 109 | c1 = (k1 * max_val)**2 110 | c2 = (k2 * max_val)**2 111 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 112 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 113 | ssim_map = numer / denom 114 | ssim = np.mean(ssim_map) 115 | return ssim_map if return_map else ssim 116 | 117 | 118 | __LPIPS__ = {} 119 | def init_lpips(net_name, device): 120 | assert net_name in ['alex', 'vgg'] 121 | import lpips 122 | print(f'init_lpips: lpips_{net_name}') 123 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 124 | 125 | def rgb_lpips(np_gt, np_im, net_name, device): 126 | if net_name not in __LPIPS__: 127 | __LPIPS__[net_name] = init_lpips(net_name, device) 128 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 129 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 130 | return __LPIPS__[net_name](gt, im, normalize=True).item() 131 | 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Strivec: Sparse Tri-Vector Radiance Fields 2 | 3 | 4 | [Video](https://youtu.be/zQ5Uli553CY) 5 | | [Paper](https://arxiv.org/abs/2307.13226) 6 | 7 | ![Teaser image](image/main_fig.png) 8 | 9 | ## Overall Instructions 10 | 1. By default, we build the initial geometry with the 1st stage of DVGO, which corresponds to `use_geo = -1` in the config files. 11 | 2. The geometry can either be initialized online (the default) or loaded from another source in `.txt` form, which is enabled by setting `use_geo = 1` and `pointfile = /your/file.txt` in the config files. 12 | 3. You may ignore the `preprocessing` folder; it was only for our early experiments and is not used here. 13 | 4. You may refer to the comments in `./configs/synthetic-nerf/default/chair.txt` for the usage of the hyperparameters. 14 | 15 | For the Synthetic-NeRF dataset, we provide the initial geometry from DVGO (the default in our implementation) and from [MVS](https://drive.google.com/file/d/1m6ftmKU4lhxXQZKhkoeeWnC9F85kyMBu/view?usp=sharing). Feel free to try both (e.g., `use_geo = 1` and `pointfile = /your/mvs_file.txt`, as sketched below) to compare them. 16 | 17 | For the Scannet dataset, we use the initial geometry provided by the dataset itself. We convert the original `.ply` file into `.txt`, which you may download from [here](https://drive.google.com/file/d/1QLeHGUwAqEkrZEQPQSDvSSyzViO1ziGY/view?usp=sharing). 
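As a concrete sketch, switching any provided config from online DVGO initialization to an external point cloud only touches these two options (the path here is illustrative, not a shipped file):

```
use_geo = 1
pointfile = ./data/your_scene_points.txt
```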
18 | 19 | ## Installation 20 | 21 | ### Requirements 22 | All the code is tested in the following environment: 23 | * Ubuntu 18.04+ 24 | * Python 3.6+ 25 | * PyTorch 1.10+ 26 | * CUDA 10.2+ 27 | 28 | ## Data Preparation 29 | 30 | * [Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 31 | * [Scannet](https://drive.google.com/drive/folders/1GoxJyf_YYEGvWStD7SpcPBqhePqCGpEJ) 32 | * [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) 33 | * [Mip-NeRF360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip) 34 | 35 | And the layout should look like this: 36 | 37 | ``` 38 | Strivec 39 | ├── data 40 | │   ├── nerf_synthetic 41 | │   │   ├── default 42 | │   │   │   ├── chair 43 | │   │   │   ├── drums 44 | │   │   │   ├── ... 45 | │   │   ├── local_vm 46 | │   │   │   ├── chair 47 | │   │   │   ├── drums 48 | │   │   │   ├── ... 49 | │   ├── scene0101_04 (scannet) 50 | │   │   ├── exported 51 | │   │   ├── scene0101_04_2d-instance-filt.zip 52 | │   │   ├── ... 53 | │   ├── scene0241_01 (scannet) 54 | │   │   ├── exported 55 | │   │   ├── scene0241_01_2d-instance-filt.zip 56 | │   │   ├── ... 57 | │   ├── TanksAndTemple 58 | │   │   ├── Barn 59 | │   │   ├── Caterpillar 60 | │   │   ├── ... 61 | │   ├── 360 (Mip-NeRF360) 62 | │   │   ├── garden 63 | │   │   ├── room 64 | │   │   ├── ... 65 | ``` 66 | 67 | ## Training & Evaluation 68 | We provide not only the training and evaluation code to reproduce the results in the paper, but also the code for the ablation that uses local VM tensors instead of local CP tensors (results 69 | are [here](https://drive.google.com/drive/folders/1-OW0Qdnk4Wz-9BRr81P2mDe1aYDmjd0g?usp=sharing)). 70 | 71 | 72 | ``` 73 | # hierarchical Strivec, without rotation (grid-aligned) 74 | python train_hier.py --config ./configs/synthetic-nerf/default/chair.txt 75 | 76 | # local VM tensors instead of local CP tensors 77 | python train_dbasis.py --config ./configs/synthetic-nerf/local_vm/chair.txt 78 | 79 | ``` 80 | 81 | ## Visualization 82 | We visualize the local tensors of the different scales into `./log/your_scene/rot_tensoRF/0_lvl_k.ply`, where `k` indexes the scale. 83 | 84 | 85 | ![visual image](image/visualization.png) 86 | 87 | 88 | ## Why our local design is more robust to rotation than the original TensoRF 89 | 90 | ![Rotation toy example](image/rot-toy.jpg) 91 | 92 | Here is a toy example illustrating TensoRF-CP (TensoRF-VM is similar) with a global decomposition in the (left) axis-aligned and (right) non-axis-aligned situations. The bottom shows the grid values. In the axis-aligned case, only 1 component is needed to represent the scene (the vector bases recover the grid values by outer product). 93 | In the non-axis-aligned case, however, 3 components are needed, because the rank of the matrix changes from 1 to 3 after the scene is rotated. Our design with local low-rank tensors alleviates this issue: the local (2x2) tensors stay rank-1 both before and after the rotation, as the numerical sketch below illustrates. 
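As a quick numerical check of the rank argument above (a standalone illustration, not code from this repo): a separable 2D grid is exactly rank-1, but resampling it after an in-plane rotation generically raises its rank, which a global CP/VM decomposition must absorb with extra components.

```python
import numpy as np
from scipy.ndimage import rotate

u = np.linspace(0.0, 1.0, 64)
grid = np.outer(u, u ** 2)                   # separable => rank-1 by construction
rot = rotate(grid, angle=30, reshape=False)  # rotate the "scene", resample on the same axes

print(np.linalg.matrix_rank(grid))                               # 1
print(np.linalg.matrix_rank(rot, tol=1e-6 * np.abs(rot).max()))  # substantially larger than 1
```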
94 | 95 | 96 | ## Citation 97 | If you find our code or paper helpful, please consider citing: 98 | ``` 99 | @INPROCEEDINGS{gao2023iCCV, 100 | author = {Quankai Gao and Qiangeng Xu and Hao Su and Ulrich Neumann and Zexiang Xu}, 101 | title = {Strivec: Sparse Tri-Vector Radiance Fields}, 102 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 103 | year = {2023} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /models/cuda/adam_upd_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | 8 | template <typename scalar_t> 9 | __global__ void adam_upd_cuda_kernel( 10 | scalar_t* __restrict__ param, 11 | const scalar_t* __restrict__ grad, 12 | scalar_t* __restrict__ exp_avg, 13 | scalar_t* __restrict__ exp_avg_sq, 14 | const size_t N, 15 | const float step_size, const float beta1, const float beta2, const float eps) { 16 | 17 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 18 | if(index<N) { 19 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 20 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 21 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 22 | } 23 | } 24 | 25 | template <typename scalar_t> 26 | __global__ void masked_adam_upd_cuda_kernel( 27 | scalar_t* __restrict__ param, 28 | const scalar_t* __restrict__ grad, 29 | scalar_t* __restrict__ exp_avg, 30 | scalar_t* __restrict__ exp_avg_sq, 31 | const size_t N, 32 | const float step_size, const float beta1, const float beta2, const float eps) { 33 | 34 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 35 | if(index<N && grad[index]!=0) { 36 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 37 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 38 | param[index] -= step_size * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 39 | } 40 | } 41 | 42 | template <typename scalar_t> 43 | __global__ void adam_upd_with_perlr_cuda_kernel( 44 | scalar_t* __restrict__ param, 45 | const scalar_t* __restrict__ grad, 46 | scalar_t* __restrict__ exp_avg, 47 | scalar_t* __restrict__ exp_avg_sq, 48 | scalar_t* __restrict__ perlr, 49 | const size_t N, 50 | const float step_size, const float beta1, const float beta2, const float eps) { 51 | 52 | const size_t index = blockIdx.x * blockDim.x + threadIdx.x; 53 | if(index<N) { 54 | exp_avg[index] = beta1 * exp_avg[index] + (1-beta1) * grad[index]; 55 | exp_avg_sq[index] = beta2 * exp_avg_sq[index] + (1-beta2) * grad[index] * grad[index]; 56 | param[index] -= step_size * perlr[index] * exp_avg[index] / (sqrt(exp_avg_sq[index]) + eps); 57 | } 58 | } 59 | 60 | void adam_upd_cuda( 61 | torch::Tensor param, 62 | torch::Tensor grad, 63 | torch::Tensor exp_avg, 64 | torch::Tensor exp_avg_sq, 65 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 66 | 67 | const size_t N = param.numel(); 68 | 69 | const int threads = 256; 70 | const int blocks = (N + threads - 1) / threads; 71 | 72 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 73 | 74 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_cuda", ([&] { 75 | adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 76 | param.data<scalar_t>(), 77 | grad.data<scalar_t>(), 78 | exp_avg.data<scalar_t>(), 79 | exp_avg_sq.data<scalar_t>(), 80 | N, step_size, beta1, beta2, eps); 81 | })); 82 | } 83 | 84 | void masked_adam_upd_cuda( 85 | torch::Tensor param, 86 | torch::Tensor grad, 87 | torch::Tensor exp_avg, 88 | torch::Tensor exp_avg_sq, 89 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 90 | 91 | const size_t N = param.numel(); 92 | 93 | const int threads = 256; 94 | const int blocks = (N + threads - 1) / threads; 95 | 96 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 97 | 98 | AT_DISPATCH_FLOATING_TYPES(param.type(), "masked_adam_upd_cuda", ([&] { 99 | masked_adam_upd_cuda_kernel<scalar_t><<<blocks, threads>>>( 100 | param.data<scalar_t>(), 101 | grad.data<scalar_t>(), 102 | exp_avg.data<scalar_t>(), 103 | exp_avg_sq.data<scalar_t>(), 104 | N, step_size, beta1, beta2, eps); 105 | })); 106 | } 107 | 108 | void adam_upd_with_perlr_cuda( 109 | torch::Tensor param, 110 | torch::Tensor grad, 111 | torch::Tensor exp_avg, 112 | torch::Tensor exp_avg_sq, 113 | torch::Tensor perlr, 114 | const int step, const float beta1, const float beta2, const float lr, const float eps) { 115 | 116 | const size_t N = param.numel(); 117 | 118 | const int threads = 256; 119 | const int blocks = (N + threads - 1) / threads; 120 | 121 | const float step_size = lr * sqrt(1 - pow(beta2, (float)step)) / (1 - pow(beta1, (float)step)); 122 | 123 | AT_DISPATCH_FLOATING_TYPES(param.type(), "adam_upd_with_perlr_cuda", ([&] { 124 | adam_upd_with_perlr_cuda_kernel<scalar_t><<<blocks, threads>>>( 125 | param.data<scalar_t>(), 126 | grad.data<scalar_t>(), 127 | exp_avg.data<scalar_t>(), 128 | exp_avg_sq.data<scalar_t>(), 129 | perlr.data<scalar_t>(), 
130 | N, step_size, beta1, beta2, eps); 131 | })); 132 | } 133 | 134 | -------------------------------------------------------------------------------- /dataLoader/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | def get_cv_raydir(pixelcoords, height, width, focal, rot): 4 | # pixelcoords: H x W x 2 5 | if isinstance(focal, float): 6 | focal = [focal, focal] 7 | x = (pixelcoords[..., 0] - width / 2.0) / focal[0] 8 | y = (pixelcoords[..., 1] - height / 2.0) / focal[1] 9 | z = np.ones_like(x) 10 | dirs = np.stack([x, y, z], axis=-1) 11 | dirs = np.sum(rot[None,None,:,:] * dirs[...,None], axis=-2) # 1*1*3*3 x h*w*3*1 12 | dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5) 13 | 14 | return dirs 15 | 16 | 17 | def get_camera_rotation(eye, center, up): 18 | nz = center - eye 19 | nz /= np.linalg.norm(nz) 20 | x = np.cross(nz, up) 21 | x /= np.linalg.norm(x) 22 | y = np.cross(x, nz) 23 | return np.array([x, y, -nz]).T 24 | 25 | # 26 | # def get_blender_raydir(pixelcoords, height, width, focal, rot, dir_norm): 27 | # ## pixelcoords: H x W x 2 28 | # x = (pixelcoords[..., 0] - width / 2.0) / focal 29 | # y = (pixelcoords[..., 1] - height / 2.0) / focal 30 | # z = np.ones_like(x) 31 | # dirs = np.stack([x, -y, -z], axis=-1) 32 | # dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1) # 32, 32, 3 33 | # if dir_norm: 34 | # # print("dirs",dirs-dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)) 35 | # dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5) 36 | # # print("dirs", dirs.shape) 37 | # 38 | # return dirs 39 | 40 | 41 | def get_blender_raydir(pixelcoords, height, width, focal, rot, dir_norm): 42 | ## pixelcoords: H x W x 2 43 | x = (pixelcoords[..., 0] + 0.5 - width / 2.0) / focal 44 | y = (pixelcoords[..., 1] + 0.5 - height / 2.0) / focal 45 | z = np.ones_like(x) 46 | dirs = np.stack([x, -y, -z], axis=-1) 47 | dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1) # h*w*1*3 x 3*3 48 | if dir_norm: 49 | # print("dirs",dirs-dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)) 50 | dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5) 51 | # print("dirs", dirs.shape) 52 | 53 | return dirs 54 | 55 | def get_dtu_raydir(pixelcoords, intrinsic, rot, dir_norm): 56 | # rot is c2w 57 | ## pixelcoords: H x W x 2 58 | x = (pixelcoords[..., 0] + 0.5 - intrinsic[0, 2]) / intrinsic[0, 0] 59 | y = (pixelcoords[..., 1] + 0.5 - intrinsic[1, 2]) / intrinsic[1, 1] 60 | z = np.ones_like(x) 61 | dirs = np.stack([x, y, z], axis=-1) 62 | # dirs = np.sum(dirs[...,None,:] * rot[:,:], axis=-1) # h*w*1*3 x 3*3 63 | dirs = dirs @ rot[:,:].T # 64 | if dir_norm: 65 | # print("dirs",dirs-dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5)) 66 | dirs = dirs / (np.linalg.norm(dirs, axis=-1, keepdims=True) + 1e-5) 67 | # print("dirs", dirs.shape) 68 | 69 | return dirs 70 | 71 | 72 | def get_optix_raydir(pixelcoords, height, width, focal, eye, center, up): 73 | c2w = get_camera_rotation(eye, center, up) 74 | return get_blender_raydir(pixelcoords, height, width, focal, c2w) 75 | 76 | 77 | def flip_z(poses): 78 | z_flip_matrix = np.eye(4, dtype=np.float32) 79 | z_flip_matrix[2, 2] = -1.0 80 | return np.matmul(poses, z_flip_matrix[None,...]) 81 | 82 | 83 | def triangluation_bpa(pnts, test_pnts=None, full_comb=False): 84 | pcd = o3d.geometry.PointCloud() 85 | pcd.points = o3d.utility.Vector3dVector(pnts[:, :3]) 86 | pcd.normals = o3d.utility.Vector3dVector(pnts[:, 
:3] / np.linalg.norm(pnts[:, :3], axis=-1, keepdims=True)) 87 | 88 | # pcd.colors = o3d.utility.Vector3dVector(pnts[:, 3:6] / 255) 89 | # pcd.normals = o3d.utility.Vector3dVector(pnts[:, 6:9]) 90 | # o3d.visualization.draw_geometries([pcd]) 91 | 92 | distances = pcd.compute_nearest_neighbor_distance() 93 | avg_dist = np.mean(distances) 94 | 95 | 96 | radius = 3 * avg_dist 97 | dec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, o3d.utility.DoubleVector( 98 | [radius, radius * 2])) 99 | # dec_mesh = dec_mesh.simplify_quadric_decimation(100000) 100 | # dec_mesh.remove_degenerate_triangles() 101 | # dec_mesh.remove_duplicated_triangles() 102 | # dec_mesh.remove_duplicated_vertices() 103 | # dec_mesh.remove_non_manifold_edges() 104 | 105 | # vis_lst = [dec_mesh, pcd] 106 | # vis_lst = [dec_mesh, pcd] 107 | # o3d.visualization.draw_geometries(vis_lst) 108 | # if test_pnts is not None : 109 | # tpcd = o3d.geometry.PointCloud() 110 | # print("test_pnts",test_pnts.shape) 111 | # tpcd.points = o3d.utility.Vector3dVector(test_pnts[:, :3]) 112 | # tpcd.normals = o3d.utility.Vector3dVector(test_pnts[:, :3] / np.linalg.norm(test_pnts[:, :3], axis=-1, keepdims=True)) 113 | # o3d.visualization.draw_geometries([dec_mesh, tpcd] ) 114 | triangles = np.asarray(dec_mesh.triangles, dtype=np.int32) 115 | if full_comb: 116 | q, w, e = triangles[..., 0], triangles[..., 1], triangles[..., 2] 117 | triangles2 = np.stack([w,q,e], axis=-1) 118 | triangles3 = np.stack([e,q,w], axis=-1) 119 | triangles = np.concatenate([triangles, triangles2, triangles3], axis=0) 120 | return triangles 121 | 122 | -------------------------------------------------------------------------------- /models/cuda/search_dbasis.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | 8 | std::vector<torch::Tensor> build_cubic_tensoRF_map_hier_cuda( 9 | torch::Tensor pnt_xyz, 10 | torch::Tensor gridSize, 11 | torch::Tensor xyz_min, 12 | torch::Tensor xyz_max, 13 | torch::Tensor units, 14 | torch::Tensor radius, 15 | torch::Tensor local_range, 16 | torch::Tensor pnt_rmatrix, 17 | torch::Tensor local_dims, 18 | const int max_tensoRF); 19 | 20 | 21 | std::vector<torch::Tensor> sample_2_rot_cubic_tensoRF_cvrg_cuda( 22 | torch::Tensor xyz_sampled, 23 | torch::Tensor xyz_min, 24 | torch::Tensor xyz_max, 25 | torch::Tensor units, 26 | torch::Tensor local_range, 27 | torch::Tensor local_dims, 28 | torch::Tensor tensoRF_cvrg_inds, 29 | torch::Tensor tensoRF_count, 30 | torch::Tensor tensoRF_topindx, 31 | torch::Tensor geo_xyz, 32 | torch::Tensor geo_rot, 33 | torch::Tensor dim_cumsum_counter, 34 | const int K, 35 | const bool KNN); 36 | 37 | 38 | torch::Tensor filter_tensoRF_cuda( 39 | torch::Tensor xyz_sampled, 40 | torch::Tensor xyz_inbbox, 41 | torch::Tensor xyz_min, 42 | torch::Tensor xyz_max, 43 | torch::Tensor units, 44 | torch::Tensor local_range, 45 | torch::Tensor local_dims, 46 | torch::Tensor tensoRF_cvrg_inds, 47 | torch::Tensor tensoRF_count, 48 | torch::Tensor tensoRF_topindx, 49 | torch::Tensor geo_xyz, 50 | torch::Tensor geo_rot, 51 | const int K, 52 | const int ord_thresh); 53 | 54 | // C++ interface 55 | 56 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 57 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 58 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 59 | 60 | 61 | torch::Tensor filter_tensoRF( 62 | torch::Tensor xyz_sampled, 63 | torch::Tensor 
xyz_inbbox, 64 | torch::Tensor xyz_min, 65 | torch::Tensor xyz_max, 66 | torch::Tensor units, 67 | torch::Tensor local_range, 68 | torch::Tensor local_dims, 69 | torch::Tensor tensoRF_cvrg_inds, 70 | torch::Tensor tensoRF_count, 71 | torch::Tensor tensoRF_topindx, 72 | torch::Tensor geo_rot, 73 | torch::Tensor geo_xyz, 74 | const int K, 75 | const int ord_thresh) { 76 | CHECK_INPUT(xyz_sampled); 77 | CHECK_INPUT(geo_xyz); 78 | CHECK_INPUT(geo_rot); 79 | assert(xyz_sampled.dim()==3); 80 | return filter_tensoRF_cuda(xyz_sampled, xyz_inbbox, xyz_min, xyz_max, units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, geo_rot, K, ord_thresh); 81 | } 82 | 83 | 84 | std::vector<torch::Tensor> sample_2_rot_cubic_tensoRF_cvrg( 85 | torch::Tensor xyz_sampled, 86 | torch::Tensor xyz_min, 87 | torch::Tensor xyz_max, 88 | torch::Tensor units, 89 | torch::Tensor local_range, 90 | torch::Tensor local_dims, 91 | torch::Tensor tensoRF_cvrg_inds, 92 | torch::Tensor tensoRF_count, 93 | torch::Tensor tensoRF_topindx, 94 | torch::Tensor geo_rot, 95 | torch::Tensor geo_xyz, 96 | torch::Tensor dim_cumsum_counter, 97 | const int K, 98 | const bool KNN) { 99 | CHECK_INPUT(xyz_sampled); 100 | CHECK_INPUT(geo_xyz); 101 | CHECK_INPUT(geo_rot); 102 | CHECK_INPUT(dim_cumsum_counter); 103 | assert(xyz_sampled.dim()==2); 104 | return sample_2_rot_cubic_tensoRF_cvrg_cuda(xyz_sampled, xyz_min, xyz_max, units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, geo_rot, dim_cumsum_counter, K, KNN); 105 | } 106 | 107 | //xyz_sampled.contiguous(), self.aabb[0], self.aabb[1], self.units, self.local_range[l], self.local_dims[l], self.tensoRF_cvrg_inds[l], self.tensoRF_count[l], self.tensoRF_topindx[l], pnt_rmatrix[l], self.geo_xyz[l], self.dim_cumsum_counter, self.K_tensoRF[l], self.KNN 108 | 109 | std::vector<torch::Tensor> build_cubic_tensoRF_map_hier( 110 | torch::Tensor pnt_xyz, 111 | torch::Tensor gridSize, 112 | torch::Tensor xyz_min, 113 | torch::Tensor xyz_max, 114 | torch::Tensor units, 115 | torch::Tensor radius, 116 | torch::Tensor local_range, 117 | torch::Tensor pnt_rmatrix, 118 | torch::Tensor local_dims, 119 | const int max_tensoRF){ 120 | CHECK_INPUT(pnt_xyz); 121 | CHECK_INPUT(gridSize); 122 | CHECK_INPUT(xyz_min); 123 | CHECK_INPUT(xyz_max); 124 | CHECK_INPUT(units); 125 | CHECK_INPUT(radius); 126 | CHECK_INPUT(local_range); 127 | CHECK_INPUT(pnt_rmatrix); 128 | CHECK_INPUT(local_dims); 129 | assert(pnt_xyz.dim()==2); 130 | return build_cubic_tensoRF_map_hier_cuda(pnt_xyz, gridSize, xyz_min, xyz_max, units, radius, local_range, pnt_rmatrix, local_dims, max_tensoRF); 131 | } 132 | 133 | 134 | 135 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 136 | m.def("sample_2_rot_cubic_tensoRF_cvrg", &sample_2_rot_cubic_tensoRF_cvrg, "Sampled points to get tensoRF"); 137 | m.def("build_cubic_tensoRF_map_hier", &build_cubic_tensoRF_map_hier, "build cubic tensoRF indices map"); 138 | m.def("filter_tensoRF", &filter_tensoRF, "filter tensoRF by threshold"); 139 | } 140 | 141 | 142 | -------------------------------------------------------------------------------- /models/cuda/search_geo_adapt.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <vector> 4 | 5 | // CUDA forward declarations 6 | 7 | 8 | std::vector<torch::Tensor> build_cubic_tensoRF_map_hier_cuda( 9 | torch::Tensor pnt_xyz, 10 | torch::Tensor gridSize, 11 | torch::Tensor xyz_min, 12 | torch::Tensor xyz_max, 13 | torch::Tensor units, 14 | torch::Tensor radius, 15 | torch::Tensor local_range, 
16 | torch::Tensor pnt_rmatrix, 17 | torch::Tensor local_dims, 18 | const int max_tensoRF); 19 | 20 | 21 | std::vector<torch::Tensor> sample_2_rot_cubic_tensoRF_cvrg_cuda( 22 | torch::Tensor xyz_sampled, 23 | torch::Tensor xyz_min, 24 | torch::Tensor xyz_max, 25 | torch::Tensor units, 26 | torch::Tensor local_range, 27 | torch::Tensor local_dims, 28 | torch::Tensor tensoRF_cvrg_inds, 29 | torch::Tensor tensoRF_count, 30 | torch::Tensor tensoRF_topindx, 31 | torch::Tensor geo_xyz, 32 | torch::Tensor geo_rot, 33 | torch::Tensor dim_cumsum_counter, 34 | const int K, 35 | const bool KNN); 36 | 37 | torch::Tensor filter_tensoRF_cuda( 38 | torch::Tensor xyz_sampled, 39 | torch::Tensor xyz_inbbox, 40 | torch::Tensor xyz_min, 41 | torch::Tensor xyz_max, 42 | torch::Tensor units, 43 | torch::Tensor local_range, 44 | torch::Tensor local_dims, 45 | torch::Tensor tensoRF_cvrg_inds, 46 | torch::Tensor tensoRF_count, 47 | torch::Tensor tensoRF_topindx, 48 | torch::Tensor geo_xyz, 49 | torch::Tensor geo_rot, 50 | const int K, 51 | const int ord_thresh); 52 | 53 | // C++ interface 54 | 55 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 56 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 57 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 58 | 59 | 60 | torch::Tensor filter_tensoRF( 61 | torch::Tensor xyz_sampled, 62 | torch::Tensor xyz_inbbox, 63 | torch::Tensor xyz_min, 64 | torch::Tensor xyz_max, 65 | torch::Tensor units, 66 | torch::Tensor local_range, 67 | torch::Tensor local_dims, 68 | torch::Tensor tensoRF_cvrg_inds, 69 | torch::Tensor tensoRF_count, 70 | torch::Tensor tensoRF_topindx, 71 | torch::Tensor geo_rot, 72 | torch::Tensor geo_xyz, 73 | const int K, 74 | const int ord_thresh) { 75 | CHECK_INPUT(xyz_sampled); 76 | CHECK_INPUT(geo_xyz); 77 | CHECK_INPUT(geo_rot); 78 | assert(xyz_sampled.dim()==3); 79 | return filter_tensoRF_cuda(xyz_sampled, xyz_inbbox, xyz_min, xyz_max, units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, geo_rot, K, ord_thresh); 80 | } 81 | 82 | 83 | std::vector<torch::Tensor> sample_2_rot_cubic_tensoRF_cvrg( 84 | torch::Tensor xyz_sampled, 85 | torch::Tensor xyz_min, 86 | torch::Tensor xyz_max, 87 | torch::Tensor units, 88 | torch::Tensor local_range, 89 | torch::Tensor local_dims, 90 | torch::Tensor tensoRF_cvrg_inds, 91 | torch::Tensor tensoRF_count, 92 | torch::Tensor tensoRF_topindx, 93 | torch::Tensor geo_rot, 94 | torch::Tensor geo_xyz, 95 | torch::Tensor dim_cumsum_counter, 96 | const int K, 97 | const bool KNN) { 98 | CHECK_INPUT(xyz_sampled); 99 | CHECK_INPUT(geo_xyz); 100 | CHECK_INPUT(geo_rot); 101 | CHECK_INPUT(dim_cumsum_counter); 102 | assert(xyz_sampled.dim()==2); 103 | return sample_2_rot_cubic_tensoRF_cvrg_cuda(xyz_sampled, xyz_min, xyz_max, units, local_range, local_dims, tensoRF_cvrg_inds, tensoRF_count, tensoRF_topindx, geo_xyz, geo_rot, dim_cumsum_counter, K, KNN); 104 | } 105 | 106 | 107 | //xyz_sampled.contiguous(), self.aabb[0], self.aabb[1], self.units, self.local_range[l], self.local_dims[l], self.tensoRF_cvrg_inds[l], self.tensoRF_count[l], self.tensoRF_topindx[l], pnt_rmatrix[l], self.geo_xyz[l], self.dim_cumsum_counter, self.K_tensoRF[l], self.KNN 108 | 109 | std::vector<torch::Tensor> build_cubic_tensoRF_map_hier( 110 | torch::Tensor pnt_xyz, 111 | torch::Tensor gridSize, 112 | torch::Tensor xyz_min, 113 | torch::Tensor xyz_max, 114 | torch::Tensor units, 115 | torch::Tensor radius, 116 | torch::Tensor local_range, 117 | torch::Tensor pnt_rmatrix, 118 | 
torch::Tensor local_dims, 119 | const int max_tensoRF){ 120 | CHECK_INPUT(pnt_xyz); 121 | CHECK_INPUT(gridSize); 122 | CHECK_INPUT(xyz_min); 123 | CHECK_INPUT(xyz_max); 124 | CHECK_INPUT(units); 125 | CHECK_INPUT(radius); 126 | CHECK_INPUT(local_range); 127 | CHECK_INPUT(pnt_rmatrix); 128 | CHECK_INPUT(local_dims); 129 | assert(pnt_xyz.dim()==2); 130 | return build_cubic_tensoRF_map_hier_cuda(pnt_xyz, gridSize, xyz_min, xyz_max, units, radius, local_range, pnt_rmatrix, local_dims, max_tensoRF); 131 | } 132 | 133 | 134 | 135 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 136 | m.def("sample_2_rot_cubic_tensoRF_cvrg", &sample_2_rot_cubic_tensoRF_cvrg, "Sampled points to get tensoRF"); 137 | m.def("build_cubic_tensoRF_map_hier", &build_cubic_tensoRF_map_hier, "build cubic tensoRF indices map"); 138 | m.def("filter_tensoRF", &filter_tensoRF, "filter tensoRF by threshold"); 139 | } 140 | 141 | 142 | -------------------------------------------------------------------------------- /preprocessing/cluster.py: -------------------------------------------------------------------------------- 1 | # dbscan clustering 2 | import sys 3 | import os 4 | import pathlib 5 | sys.path.append(os.path.join(pathlib.Path(__file__).parent.absolute(), '..')) 6 | import torch 7 | import numpy as np 8 | # torch.manual_seed(0) 9 | # np.random.seed(0) 10 | import time 11 | 12 | from opt_adapt import config_parser 13 | args = config_parser() 14 | print(args) 15 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu_ids 16 | from models.apparatus import * 17 | from utils import * 18 | import numpy as np 19 | from numpy import unique 20 | from numpy import where 21 | from sklearn.datasets import make_classification 22 | from sklearn.cluster import DBSCAN, KMeans, MeanShift, SpectralClustering 23 | from sklearn.mixture import GaussianMixture 24 | # from preprocessing.gmm_BatyaGG.GMM_GMR import GMM_GMR 25 | # from preprocessing.fast_gmm.python.pygmm import GMM as fast_gmm 26 | from preprocessing.gmm_torch.gmm import GaussianMixture as torch_gmm 27 | from matplotlib import pyplot 28 | from mvs import mvs_utils, filter_utils 29 | 30 | path = "/home/xharlie/dev/cloud_tensoRF/log/ship_adapt_0.4_0.2/clusters/" 31 | 32 | import numpy as np 33 | 34 | def scatter_mean(indices, updates, shape): 35 | target = np.zeros(shape, dtype=updates.dtype) 36 | divd = np.zeros(shape, dtype=updates.dtype) 37 | indices = tuple(indices.reshape(1, -1)) 38 | np.add.at(target, indices, updates) 39 | np.add.at(divd, indices, 1) 40 | 41 | return target / divd 42 | 43 | def load_pnts(): 44 | vox_res=100 45 | pointfile="/home/xharlie/dev/cloud_tensoRF/log/ship_points.txt" 46 | xyz_world_all = torch.as_tensor(np.loadtxt(pointfile, delimiter=";"), dtype=torch.float32) 47 | xyz_world_all, _, sampled_pnt_idx = mvs_utils.construct_vox_points_closest( 48 | xyz_world_all.cuda() if len(xyz_world_all) < 99999999 else xyz_world_all[::(len(xyz_world_all) // 99999999 + 1), ...].cuda(), vox_res) 49 | return xyz_world_all[:,:3].cpu().numpy() 50 | 51 | def cluster(X, method="gmm", num=100, vis=False, tol=0.0005): 52 | # define the model 53 | # ‘full’: each component has its own general covariance matrix. 54 | # ‘tied’: all components share the same general covariance matrix. 55 | # ‘diag’: each component has its own diagonal covariance matrix. 56 | # ‘spherical’: each component has its own single variance. 
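    # Each branch below only constructs an (unfitted) clustering backend; fit_predict
    # further down assigns every point a cluster index, and the per-cluster mean xyz
    # computed at the end becomes that cluster's center. The GMM/KMeans/Spectral
    # variants need the target count `num` up front, while MeanShift and DBSCAN
    # discover the number of clusters from the data.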
57 | 58 | if method == "gmm": 59 | # 'kmeans', 'k-means++', 'random', 'random_from_data' 60 | np.random.seed(int(time.time())) 61 | model = GaussianMixture(n_components=num, covariance_type="full", init_params="k-means++", max_iter=1000, tol=tol) 62 | # elif method == "gmgg": 63 | # model = GMM_GMR(num) 64 | # elif method == "fastgmm": 65 | # model = fast_gmm(nr_mixture = num, min_covar = tol, nr_iteration = 1000, concurrency = 18) 66 | elif method == "torchgmm": 67 | model = torch_gmm(num, X.shape[-1], covariance_type="full", init_params="kmeans", max_iter=1000, tol=tol) 68 | elif method == "km": 69 | model = KMeans(n_clusters=num, init="k-means++") 70 | elif method == "ms": 71 | model = MeanShift(n_jobs=18) 72 | elif method == "sc": 73 | model = SpectralClustering(n_clusters=num, n_jobs=18) 74 | elif method == "db": 75 | model = DBSCAN(eps=0.17, min_samples=1, n_jobs=18) 76 | 77 | # fit model and predict clusters 78 | print("start clustering using ", method) 79 | cluster_inds = np.asarray(model.fit_predict(X)) 80 | print("cluster_inds", X.shape, cluster_inds.shape) 81 | # score_samples = np.asarray(model.score_samples(X)) 82 | # # scores = np.asarray(model.score(X)) 83 | # cluster_scores = np.asarray(model.predict_proba(X)) 84 | # print("cluster_scores", cluster_scores.shape) 85 | # max_cluster_scores = np.max(cluster_scores, axis=1) 86 | # max_cluster_scores_inds = np.argmax(cluster_scores, axis=1) 87 | # sum_cluster_scores = np.sum(cluster_scores, axis=1) 88 | # min_inds = np.argsort(max_cluster_scores) 89 | # # mincluster_inds = np.argmin(cluster_scores, axis=0) 90 | # print("cluster_scores", max_cluster_scores) 91 | # print("score_min", max_cluster_scores[min_inds[100]], max_cluster_scores[min_inds[1000]], max_cluster_scores[min_inds[3000]], max_cluster_scores[min_inds[10000]]) 92 | # 93 | # 94 | # # print("max_cluster_scores_inds", max_cluster_scores_inds.shape, max_cluster_scores.shape) 95 | # # max_cluster_mean = scatter_mean(max_cluster_scores_inds, max_cluster_scores, [num]) 96 | # # print("max_cluster_mean", max_cluster_mean) 97 | # # max_cluster_mean_inds = np.argsort(max_cluster_mean) 98 | # # print("smallest_cluster_inds", max_cluster_mean_inds) 99 | # retrieve unique clusters 100 | clusters = np.unique(cluster_inds) 101 | # print("finished with ", len(clusters), " clusters", clusters) 102 | cluster_xyz=np.zeros([len(clusters),3], dtype=np.float32) 103 | # create scatter plot for samples from each cluster 104 | counter = 0 105 | for i in range(len(clusters)): 106 | row_mask = cluster_inds == clusters[i] 107 | # print("row_ix", row_mask.shape) 108 | os.makedirs(path, exist_ok=True) 109 | if vis: 110 | np.savetxt(os.path.join(path, "cluster_{:04d}.txt".format(counter)), X[row_mask], delimiter=";") 111 | cluster_xyz[i] = np.mean(X[row_mask][...,:3], axis=0) 112 | counter+=1 113 | # print(np.asarray(model.means_)[...,:3], cluster_xyz) 114 | return cluster_xyz, cluster_inds, model 115 | 116 | if __name__ == '__main__': 117 | X = load_pnts() 118 | np.savetxt("/home/xharlie/dev/cloud_tensoRF/log/ship_adapt_0.4_0.2_try/clusters/ship.txt", X, delimiter=";") 119 | print("X.shape", X.shape) 120 | cluster(X, vis=True) 121 | -------------------------------------------------------------------------------- /dataLoader/your_own_data.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import 
transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class YourOwnDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.downsample = downsample 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [0.1,100.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample) 45 | self.img_wh = [w,h] 46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 48 | self.cx, self.cy = self.meta['cx'],self.meta['cy'] 49 | 50 | 51 | # ray directions for all pixels, same for all images (same H, W, focal) 52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float() 55 | 56 | self.image_paths = [] 57 | self.poses = [] 58 | self.all_rays = [] 59 | self.all_rgbs = [] 60 | self.all_masks = [] 61 | self.all_depth = [] 62 | 63 | 64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 67 | 68 | frame = self.meta['frames'][i] 69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 70 | c2w = torch.FloatTensor(pose) 71 | self.poses += [c2w] 72 | 73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 74 | self.image_paths += [image_path] 75 | img = Image.open(image_path) 76 | 77 | if self.downsample!=1.0: 78 | img = img.resize(self.img_wh, Image.LANCZOS) 79 | img = self.transform(img) # (4, h, w) 80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA 81 | if img.shape[-1]==4: 82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 83 | self.all_rgbs += [img] 84 | 85 | 86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 88 | 89 | 90 | self.poses = torch.stack(self.poses) 91 | if not self.is_stack: 92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 94 | 95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 96 | else: 97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 98 | self.all_rgbs = 
torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 100 | 101 | 102 | def define_transforms(self): 103 | self.transform = T.ToTensor() 104 | 105 | def define_proj_mat(self): 106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 107 | 108 | def world2ndc(self,points,lindisp=None): 109 | device = points.device 110 | return (points - self.center.to(device)) / self.radius.to(device) 111 | 112 | def __len__(self): 113 | return len(self.all_rgbs) 114 | 115 | def __getitem__(self, idx): 116 | 117 | if self.split == 'train': # use data in the buffers 118 | sample = {'rays': self.all_rays[idx], 119 | 'rgbs': self.all_rgbs[idx]} 120 | 121 | else: # create data for each image separately 122 | 123 | img = self.all_rgbs[idx] 124 | rays = self.all_rays[idx] 125 | mask = self.all_masks[idx] # for quantity evaluation 126 | 127 | sample = {'rays': rays, 128 | 'rgbs': img} 129 | return sample 130 | -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################## sh function ################## 4 | C0 = 0.28209479177387814 5 | C1 = 0.4886025119029199 6 | C2 = [ 7 | 1.0925484305920792, 8 | -1.0925484305920792, 9 | 0.31539156525252005, 10 | -1.0925484305920792, 11 | 0.5462742152960396 12 | ] 13 | C3 = [ 14 | -0.5900435899266435, 15 | 2.890611442640554, 16 | -0.4570457994644658, 17 | 0.3731763325901154, 18 | -0.4570457994644658, 19 | 1.445305721320277, 20 | -0.5900435899266435 21 | ] 22 | C4 = [ 23 | 2.5033429417967046, 24 | -1.7701307697799304, 25 | 0.9461746957575601, 26 | -0.6690465435572892, 27 | 0.10578554691520431, 28 | -0.6690465435572892, 29 | 0.47308734787878004, 30 | -1.7701307697799304, 31 | 0.6258357354491761, 32 | ] 33 | 34 | def eval_sh(deg, sh, dirs): 35 | """ 36 | Evaluate spherical harmonics at unit directions 37 | using hardcoded SH polynomials. 38 | Works with torch/np/jnp. 39 | ... Can be 0 or more batch dimensions. 40 | :param deg: int SH max degree. 
Currently, 0-4 supported 41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) 42 | :param dirs: torch.Tensor unit directions (..., 3) 43 | :return: (..., C) 44 | """ 45 | assert deg <= 4 and deg >= 0 46 | assert (deg + 1) ** 2 == sh.shape[-1] 47 | C = sh.shape[-2] 48 | 49 | result = C0 * sh[..., 0] 50 | if deg > 0: 51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 52 | result = (result - 53 | C1 * y * sh[..., 1] + 54 | C1 * z * sh[..., 2] - 55 | C1 * x * sh[..., 3]) 56 | if deg > 1: 57 | xx, yy, zz = x * x, y * y, z * z 58 | xy, yz, xz = x * y, y * z, x * z 59 | result = (result + 60 | C2[0] * xy * sh[..., 4] + 61 | C2[1] * yz * sh[..., 5] + 62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 63 | C2[3] * xz * sh[..., 7] + 64 | C2[4] * (xx - yy) * sh[..., 8]) 65 | 66 | if deg > 2: 67 | result = (result + 68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 69 | C3[1] * xy * z * sh[..., 10] + 70 | C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + 71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 73 | C3[5] * z * (xx - yy) * sh[..., 14] + 74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 75 | if deg > 3: 76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 85 | return result 86 | 87 | def eval_sh_bases(deg, dirs): 88 | """ 89 | Evaluate spherical harmonics bases at unit directions, 90 | without taking linear combination. 91 | At each point, the final result may then be 92 | obtained through simple multiplication. 93 | :param deg: int SH max degree. 
Currently, 0-4 supported 94 | :param dirs: torch.Tensor (..., 3) unit directions 95 | :return: torch.Tensor (..., (deg+1) ** 2) 96 | """ 97 | assert deg <= 4 and deg >= 0 98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) 99 | result[..., 0] = C0 100 | if deg > 0: 101 | x, y, z = dirs.unbind(-1) 102 | result[..., 1] = -C1 * y; 103 | result[..., 2] = C1 * z; 104 | result[..., 3] = -C1 * x; 105 | if deg > 1: 106 | xx, yy, zz = x * x, y * y, z * z 107 | xy, yz, xz = x * y, y * z, x * z 108 | result[..., 4] = C2[0] * xy; 109 | result[..., 5] = C2[1] * yz; 110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy); 111 | result[..., 7] = C2[3] * xz; 112 | result[..., 8] = C2[4] * (xx - yy); 113 | 114 | if deg > 2: 115 | result[..., 9] = C3[0] * y * (3 * xx - yy); 116 | result[..., 10] = C3[1] * xy * z; 117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy); 118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); 119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy); 120 | result[..., 14] = C3[5] * z * (xx - yy); 121 | result[..., 15] = C3[6] * x * (xx - 3 * yy); 122 | 123 | if deg > 3: 124 | result[..., 16] = C4[0] * xy * (xx - yy); 125 | result[..., 17] = C4[1] * yz * (3 * xx - yy); 126 | result[..., 18] = C4[2] * xy * (7 * zz - 1); 127 | result[..., 19] = C4[3] * yz * (7 * zz - 3); 128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); 129 | result[..., 21] = C4[5] * xz * (7 * zz - 3); 130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); 131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy); 132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); 133 | return result 134 | -------------------------------------------------------------------------------- /mvs/mvsnet/mvsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .module import * 5 | 6 | 7 | class FeatureNet(nn.Module): 8 | def __init__(self): 9 | super(FeatureNet, self).__init__() 10 | self.inplanes = 32 11 | 12 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1) 13 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1) 14 | 15 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2) 16 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1) 17 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1) 18 | 19 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2) 20 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1) 21 | self.feature = nn.Conv2d(32, 32, 3, 1, 1) 22 | 23 | def forward(self, x): 24 | x = self.conv1(self.conv0(x)) 25 | x = self.conv4(self.conv3(self.conv2(x))) 26 | x = self.feature(self.conv6(self.conv5(x))) 27 | return x 28 | 29 | 30 | class CostRegNet(nn.Module): 31 | def __init__(self): 32 | super(CostRegNet, self).__init__() 33 | self.conv0 = ConvBnReLU3D(32, 8) 34 | 35 | self.conv1 = ConvBnReLU3D(8, 16, stride=2) 36 | self.conv2 = ConvBnReLU3D(16, 16) 37 | 38 | self.conv3 = ConvBnReLU3D(16, 32, stride=2) 39 | self.conv4 = ConvBnReLU3D(32, 32) 40 | 41 | self.conv5 = ConvBnReLU3D(32, 64, stride=2) 42 | self.conv6 = ConvBnReLU3D(64, 64) 43 | 44 | self.conv7 = nn.Sequential( 45 | nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 46 | nn.BatchNorm3d(32), 47 | nn.ReLU(inplace=True)) 48 | 49 | self.conv9 = nn.Sequential( 50 | nn.ConvTranspose3d(32, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 51 | nn.BatchNorm3d(16), 52 | nn.ReLU(inplace=True)) 53 | 54 | self.conv11 = nn.Sequential( 55 | nn.ConvTranspose3d(16, 8, kernel_size=3, padding=1, 
output_padding=1, stride=2, bias=False), 56 | nn.BatchNorm3d(8), 57 | nn.ReLU(inplace=True)) 58 | 59 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 60 | 61 | def forward(self, x): 62 | conv0 = self.conv0(x) 63 | conv2 = self.conv2(self.conv1(conv0)) 64 | conv4 = self.conv4(self.conv3(conv2)) 65 | x = self.conv6(self.conv5(conv4)) 66 | x = conv4 + self.conv7(x) 67 | x = conv2 + self.conv9(x) 68 | x = conv0 + self.conv11(x) 69 | x = self.prob(x) 70 | return x 71 | 72 | 73 | class RefineNet(nn.Module): 74 | def __init__(self): 75 | super(RefineNet, self).__init__() 76 | self.conv1 = ConvBnReLU(4, 32) 77 | self.conv2 = ConvBnReLU(32, 32) 78 | self.conv3 = ConvBnReLU(32, 32) 79 | self.res = ConvBnReLU(32, 1) 80 | 81 | def forward(self, img, depth_init): 82 | concat = torch.cat((img, depth_init), dim=1) # was F.cat, which does not exist 83 | depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat)))) 84 | depth_refined = depth_init + depth_residual 85 | return depth_refined 86 | 87 | 88 | class MVSNet(nn.Module): 89 | def __init__(self, refine=False): 90 | super(MVSNet, self).__init__() 91 | self.refine = refine 92 | 93 | self.feature = FeatureNet() 94 | self.cost_regularization = CostRegNet() 95 | if self.refine: 96 | self.refine_network = RefineNet() 97 | 98 | def forward(self, imgs, proj_matrices, depth_values, features=None, prob_only=False): 99 | 100 | imgs = torch.unbind(imgs, 1) 101 | num_depth = depth_values.shape[1] 102 | num_views = len(imgs) 103 | 104 | # step 1. feature extraction 105 | # in: images; out: 32-channel feature maps 106 | if features is None: 107 | features = [self.feature(img) for img in imgs] 108 | 109 | # step 2. differentiable homography, build cost volume 110 | volume_sum = 0 111 | volume_sq_sum = 0 112 | for vid in range(num_views): 113 | # warped features 114 | warped_volume = homo_warping(features[vid], proj_matrices[:, vid], depth_values) 115 | if self.training: 116 | volume_sum = volume_sum + warped_volume 117 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 118 | else: 119 | volume_sum += warped_volume 120 | volume_sq_sum += warped_volume.pow_(2) # the memory of warped_volume has been modified 121 | del warped_volume 122 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2)) 123 | 124 | # step 3. cost volume regularization 125 | cost_reg = self.cost_regularization(volume_variance) 126 | cost_reg = cost_reg.squeeze(1) 127 | prob_volume = F.softmax(cost_reg, dim=1) 128 | if prob_only: 129 | return features, prob_volume, cost_reg 130 | depth = depth_regression(prob_volume, depth_values=depth_values) 131 | 132 | with torch.no_grad(): 133 | # photometric confidence 134 | prob_volume_sum4 = 4 * F.avg_pool3d(F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0).squeeze(1) 135 | depth_index = depth_regression(prob_volume, depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float)).long() 136 | photometric_confidence = torch.gather(prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1) 137 | 138 | # step 4. depth map refinement 139 | if not self.refine: 140 | return depth, photometric_confidence, features, prob_volume # {"depth": depth, "photometric_confidence": photometric_confidence} 141 | else: 142 | refined_depth = self.refine_network(imgs[0], depth.unsqueeze(1)) # RefineNet.forward(img, depth_init) concatenates internally; depth needs a channel dim 143 | return {"depth": depth, "refined_depth": refined_depth, "photometric_confidence": photometric_confidence} 144 | 145 | 146 | def mvsnet_loss(depth_est, depth_gt, mask): 147 | mask = mask > 0.5 148 | return F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean') # size_average is deprecated in favor of reduction 149 | -------------------------------------------------------------------------------- /mvs/mvsnet/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvBnReLU(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 8 | super(ConvBnReLU, self).__init__() 9 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 10 | self.bn = nn.BatchNorm2d(out_channels) 11 | 12 | def forward(self, x): 13 | return F.relu(self.bn(self.conv(x)), inplace=True) 14 | 15 | 16 | class ConvBn(nn.Module): 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 18 | super(ConvBn, self).__init__() 19 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 20 | self.bn = nn.BatchNorm2d(out_channels) 21 | 22 | def forward(self, x): 23 | return self.bn(self.conv(x)) 24 | 25 | 26 | class ConvBnReLU3D(nn.Module): 27 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 28 | super(ConvBnReLU3D, self).__init__() 29 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 30 | self.bn = nn.BatchNorm3d(out_channels) 31 | 32 | def forward(self, x): 33 | return F.relu(self.bn(self.conv(x)), inplace=True) 34 | 35 | 36 | def homo_warping(src_fea, proj, depth_values): 37 | # src_fea: [B, C, H, W] 38 | # proj: [B, 4, 4] 39 | # (a single composed projection; the original MVSNet passed src_proj and ref_proj separately) 40 | # depth_values: [B, Ndepth] 41 | # out: [B, C, Ndepth, H, W] 42 | batch, channels = src_fea.shape[0], src_fea.shape[1] 43 | num_depth = depth_values.shape[1] 44 | height, width = src_fea.shape[2], src_fea.shape[3] 45 | 46 | with torch.no_grad(): 47 | rot = proj[:, :3, :3] # [B,3,3] 48 | trans = proj[:, :3, 3:4] # [B,3,1] 49 | 50 | y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 51 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]) 52 | y, x = y.contiguous(), x.contiguous() 53 | y, x = y.view(height * width), x.view(height * width) 54 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 55 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 56 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 57 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, 58 | 1) # [B, 3, Ndepth, H*W] 59 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 60 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 61 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 62 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 63 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 64 | grid = proj_xy 65 | 66 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', 67 | padding_mode='zeros', align_corners=True) # explicit: pre-1.3 PyTorch (and MVSNeRF) used this behavior by default 68 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) 69 | 70 | return warped_src_fea 71 | 72 | 73 | def depth_regression(p, depth_values): 74 | # p: probability volume [B, D, H, W] 75 | # depth_values: discrete depth values [B, D] 76 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 77 | depth = torch.sum(p * depth_values, 1) 78 | return depth 79 | 80 | 81 | if __name__ == "__main__": 82 | # some testing code, just IGNORE it 83 | from datasets import find_dataset_def 84 | from torch.utils.data import DataLoader 85 | import numpy as np 86 | import cv2 87 | 88 | MVSDataset = find_dataset_def("dtu_yao") 89 | dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/train.txt', 'train', 90 | 3, 256) 91 | dataloader = DataLoader(dataset, batch_size=2) 92 | item = next(iter(dataloader)) 93 | 94 | imgs = item["imgs"][:, :, :, ::4, ::4].cuda() 95 | proj_matrices = item["proj_matrices"].cuda() 96 | mask = item["mask"].cuda() 97 | depth = item["depth"].cuda() 98 | depth_values = item["depth_values"].cuda() 99 | 100 | imgs = torch.unbind(imgs, 1) 101 | proj_matrices = torch.unbind(proj_matrices, 1) 102 | ref_img, src_imgs = imgs[0], imgs[1:] 103 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 104 | 105 | warped_imgs = homo_warping(src_imgs[0], torch.matmul(src_projs[0], torch.inverse(ref_proj)), depth_values) # homo_warping takes one composed projection (src @ ref^-1), mirroring the numpy check below 106 | 107 | cv2.imwrite('../tmp/ref.png', ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255) 108 | cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255) 109 | 110 | for i in range(warped_imgs.shape[2]): 111 | warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous() 112 | img_np = warped_img[0].detach().cpu().numpy() 113 | cv2.imwrite('../tmp/tmp{}.png'.format(i), img_np[:, :, ::-1] * 255) 114 | 115 | 116 | # generate gt 117 | def tocpu(x): 118 | return x.detach().cpu().numpy().copy() 119 | 120 | 121 | ref_img = tocpu(ref_img)[0].transpose([1, 2, 0]) 122 | src_imgs = [tocpu(x)[0].transpose([1, 2, 0]) for x in src_imgs] 123 | ref_proj_mat = tocpu(ref_proj)[0] 124 | src_proj_mats = [tocpu(x)[0] for x in src_projs] 125 | mask = tocpu(mask)[0] 126 | depth = tocpu(depth)[0] 127 | depth_values = tocpu(depth_values)[0] 128 | 129 | for i, D in enumerate(depth_values): 130 | height = ref_img.shape[0] 131 | width = ref_img.shape[1] 132 | xx, yy = np.meshgrid(np.arange(0, width), np.arange(0, height)) 133 | print("yy", yy.max(), yy.min()) 134 | yy = yy.reshape([-1]) 135 | xx = xx.reshape([-1]) 136 | X = np.vstack((xx, yy, np.ones_like(xx))) 137 | # D = depth.reshape([-1]) 138 | # print("X", "D", X.shape, D.shape) 139 | 140 | X = np.vstack((X * D, np.ones_like(xx))) 141 | X = np.matmul(np.linalg.inv(ref_proj_mat), X) 142 | X = np.matmul(src_proj_mats[0], X) 143 | X /= X[2] 144 | X = X[:2] 145 | 146 | yy = X[0].reshape([height, width]).astype(np.float32) 147 | xx = X[1].reshape([height, width]).astype(np.float32) 148 | 149 | warped = cv2.remap(src_imgs[0], yy, xx, interpolation=cv2.INTER_LINEAR) 150 | # warped[mask[:, :] < 0.5] = 0 151 | 152 | cv2.imwrite('../tmp/tmp{}_gt.png'.format(i), warped[:, :, ::-1] * 255) 153 | -------------------------------------------------------------------------------- /dataLoader/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 |
from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class NSVF(Dataset): 37 | """NSVF Generic Dataset.""" 38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False): 39 | self.root_dir = datadir 40 | self.split = split 41 | self.is_stack = is_stack 42 | self.downsample = downsample 43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 44 | self.define_transforms() 45 | 46 | self.white_bg = True 47 | self.near_far = [0.5,6.0] 48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3) 49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 50 | self.read_meta() 51 | self.define_proj_mat() 52 | 53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 55 | 56 | def bbox2corners(self): 57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 58 | for i in range(3): 59 | corners[i,[0,1],i] = corners[i,[1,0],i] 60 | return corners.view(-1,3) 61 | 62 | 63 | def read_meta(self): 64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f: 65 | focal = float(f.readline().split()[0]) 66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]]) 67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1) 68 | 69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 71 | 72 | if self.split == 'train': 73 | pose_files = [x for x in pose_files if x.startswith('0_')] 74 | img_files = [x for x in img_files if x.startswith('0_')] 75 | elif self.split == 'val': 76 | pose_files = [x for x in pose_files if x.startswith('1_')] 77 | img_files = [x for x in img_files if x.startswith('1_')] 78 | elif self.split == 'test': 79 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 80 | test_img_files = [x for x in img_files if x.startswith('2_')] 81 | if len(test_pose_files) == 0: 82 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 83 | test_img_files = [x for x in img_files if x.startswith('1_')] 84 | pose_files = test_pose_files 85 | img_files = test_img_files 86 | 87 | # ray directions for all pixels, same for all images (same H, W, focal) 88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 90 | 91 | 92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 93 | 94 | self.poses = [] 
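# Note on the buffers filled below: each entry appended to all_rays is a (h*w, 6) tensor of concatenated ray origins and directions, and each all_rgbs entry is a (h*w, 3) tensor of alpha-blended colors.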
95 | self.all_rays = [] 96 | self.all_rgbs = [] 97 | 98 | assert len(img_files) == len(pose_files) 99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 101 | img = Image.open(image_path) 102 | if self.downsample!=1.0: 103 | img = img.resize(self.img_wh, Image.LANCZOS) 104 | img = self.transform(img) # (3 or 4, h, w) 105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 3 or 4); RGBA if an alpha channel is present 106 | if img.shape[-1]==4: 107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 108 | self.all_rgbs += [img] 109 | 110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv 111 | c2w = torch.FloatTensor(c2w) 112 | self.poses.append(c2w) # C2W 113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 115 | 116 | # w2c = torch.inverse(c2w) 117 | # 118 | 119 | self.poses = torch.stack(self.poses) 120 | if 'train' == self.split: 121 | if self.is_stack: 122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (num_frames, h, w, 6) 123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (num_frames, h, w, 3) 124 | else: 125 | self.all_rays = torch.cat(self.all_rays, 0) # (num_frames*h*w, 6) 126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (num_frames*h*w, 3) 127 | else: 128 | self.all_rays = torch.stack(self.all_rays, 0) # (num_frames, h*w, 6) 129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (num_frames, h, w, 3) 130 | 131 | 132 | def define_transforms(self): 133 | self.transform = T.ToTensor() 134 | 135 | def define_proj_mat(self): 136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 137 | 138 | def world2ndc(self, points): 139 | device = points.device 140 | return (points - self.center.to(device)) / self.radius.to(device) 141 | 142 | def __len__(self): 143 | if self.split == 'train': 144 | return len(self.all_rays) 145 | return len(self.all_rgbs) 146 | 147 | def __getitem__(self, idx): 148 | 149 | if self.split == 'train': # use data in the buffers 150 | sample = {'rays': self.all_rays[idx], 151 | 'rgbs': self.all_rgbs[idx]} 152 | 153 | else: # create data for each image separately 154 | 155 | img = self.all_rgbs[idx] 156 | rays = self.all_rays[idx] 157 | 158 | sample = {'rays': rays, 159 | 'rgbs': img} 160 | return sample -------------------------------------------------------------------------------- /preprocessing/gmm_torch/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.mixture 3 | import torch 4 | 5 | from gmm import GaussianMixture 6 | 7 | import unittest 8 | 9 | 10 | class CpuCheck(unittest.TestCase): 11 | """ 12 | Basic tests for CPU. 13 | """ 14 | def testPredictClasses(self): 15 | """ 16 | Assert that torch.FloatTensor is handled correctly. 17 | """ 18 | x = torch.randn(400, 2) 19 | n_components = np.random.randint(1, 100) 20 | 21 | model = GaussianMixture(n_components, x.size(1)) 22 | model.fit(x) 23 | y = model.predict(x) 24 | 25 | # check that dimensionality of class memberships is (n) 26 | self.assertEqual(torch.Tensor(x.size(0)).size(), y.size()) 27 | 28 | 29 | def testPredictProbabilities(self): 30 | """ 31 | Assert that torch.FloatTensor is handled correctly when returning class probabilities. 32 | """ 33 | x = torch.randn(400, 2) 34 | n_components = np.random.randint(1, 100) 35 | 36 | model = GaussianMixture(n_components, x.size(1)) 37 | model.fit(x) 38 | 39 | # check that y_p has dimensions (n, k) 40 | y_p = model.predict(x, probs=True) 41 | self.assertEqual(torch.Tensor(x.size(0), n_components).size(), y_p.size()) 42 | 43 | 44 | def testEmMatchesDiagSkLearn(self): 45 | """ 46 | Assert that log-probabilities (E-step) and parameter updates (M-step) approximately match those of sklearn. 47 | """ 48 | d = 20 49 | n_components = np.random.randint(1, 100) 50 | 51 | # (n, k, d) 52 | x = torch.randn(400, 1, d) 53 | # (n, d) 54 | x_np = np.squeeze(x.data.numpy()) 55 | 56 | var_init = torch.ones(1, n_components, d) - .4 57 | 58 | model = GaussianMixture(n_components, d, var_init=var_init, covariance_type="diag") 59 | model_sk = sklearn.mixture.GaussianMixture(n_components, 60 | covariance_type="diag", 61 | init_params="random", 62 | means_init=np.squeeze(model.mu.data.numpy()), 63 | precisions_init=np.squeeze(1. / np.sqrt(var_init.data.numpy()))) 64 | 65 | model_sk._initialize_parameters(x_np, np.random.RandomState()) 66 | log_prob_sk = model_sk._estimate_log_prob(x_np) 67 | log_prob = model._estimate_log_prob(x) 68 | 69 | # Test whether log-probabilities are approximately equal 70 | np.testing.assert_almost_equal(np.squeeze(log_prob.data.numpy()), 71 | log_prob_sk, 72 | decimal=2, 73 | verbose=True) 74 | 75 | _, log_resp_sk = model_sk._e_step(x_np) 76 | _, log_resp = model._e_step(x) 77 | 78 | # Test whether E-steps are approximately equal 79 | np.testing.assert_almost_equal(np.squeeze(log_resp.data.numpy()), 80 | log_resp_sk, 81 | decimal=0, 82 | verbose=True) 83 | 84 | model_sk._m_step(x_np, log_resp_sk) # the M-step consumes responsibilities, not log-probabilities (cf. the full-covariance test below) 85 | pi_sk = model_sk.weights_ 86 | mu_sk = model_sk.means_ 87 | var_sk = model_sk.covariances_ # was model_sk.means_, a copy-paste slip 88 | 89 | pi, mu, var = model._m_step(x, log_resp) 90 | 91 | # Test whether pi .. 92 | np.testing.assert_almost_equal(np.squeeze(pi.data.numpy()), 93 | pi_sk, 94 | decimal=1, 95 | verbose=True) 96 | 97 | # .. mu .. 98 | np.testing.assert_almost_equal(np.squeeze(mu.data.numpy()), 99 | mu_sk, 100 | decimal=1, 101 | verbose=True) 102 | 103 | # .. and var are approximately equal 104 | np.testing.assert_almost_equal(np.squeeze(var.data.numpy()), 105 | var_sk, 106 | decimal=1, 107 | verbose=True) 108 | 109 | def testEmMatchesFullSkLearn(self): 110 | """ 111 | Assert that log-probabilities (E-step) and parameter updates (M-step) approximately match those of sklearn.
112 | """ 113 | d = 20 114 | n_components = np.random.randint(1, 100) 115 | 116 | # (n, k, d) 117 | x = torch.randn(400, 1, d) 118 | # (n, d) 119 | x_np = np.squeeze(x.data.numpy()) 120 | 121 | var_init = torch.eye(d,dtype=torch.float64).reshape(1, 1, d, d).repeat(1,n_components,1, 1) 122 | 123 | model = GaussianMixture(n_components, d, init_params="random", var_init=var_init, covariance_type="full") 124 | model_sk = sklearn.mixture.GaussianMixture(n_components, 125 | covariance_type="full", 126 | init_params="random", 127 | means_init=np.squeeze(model.mu.data.numpy()), 128 | precisions_init=np.squeeze(np.linalg.inv(var_init))) 129 | 130 | model_sk._initialize_parameters(x_np, np.random.RandomState()) 131 | log_prob_sk = model_sk._estimate_log_prob(x_np) 132 | log_prob = model._estimate_log_prob(x) 133 | 134 | # Test whether log-probabilities are approximately equal 135 | np.testing.assert_almost_equal(np.squeeze(log_prob.data.numpy()), 136 | log_prob_sk, 137 | decimal=2, 138 | verbose=True) 139 | 140 | _, log_resp_sk = model_sk._e_step(x_np) 141 | _, log_resp = model._e_step(x) 142 | 143 | # Test whether E-steps are approximately equal 144 | np.testing.assert_almost_equal(np.squeeze(log_resp.data.numpy()), 145 | log_resp_sk, 146 | decimal=0, 147 | verbose=True) 148 | 149 | model_sk._m_step(x_np, log_resp_sk) 150 | pi_sk = model_sk.weights_ 151 | mu_sk = model_sk.means_ 152 | var_sk = model_sk.covariances_ 153 | 154 | pi, mu, var = model._m_step(x, log_resp) 155 | 156 | # Test whether pi .. 157 | np.testing.assert_almost_equal(np.squeeze(pi.data.numpy()), 158 | pi_sk, 159 | decimal=1, 160 | verbose=True) 161 | 162 | # .. mu .. 163 | np.testing.assert_almost_equal(np.squeeze(mu.data.numpy()), 164 | mu_sk, 165 | decimal=1, 166 | verbose=True) 167 | 168 | # .. and var are approximately equal 169 | np.testing.assert_almost_equal(np.squeeze(var.data.numpy()), 170 | var_sk, 171 | decimal=1, 172 | verbose=True) 173 | 174 | 175 | class GpuCheck(unittest.TestCase): 176 | """ 177 | Basic tests for GPU. 178 | """ 179 | def testPredictClasses(self): 180 | """ 181 | Assert that torch.cuda.FloatTensor is handled correctly. 182 | """ 183 | x = torch.randn(400, 2).cuda() 184 | n_components = np.random.randint(1, 100) 185 | 186 | model = GaussianMixture(n_components, x.size(1), covariance_type="diag").cuda() 187 | model.fit(x) 188 | y = model.predict(x) 189 | 190 | # check that dimensionality of class memberships is (n) 191 | self.assertEqual(torch.Tensor(x.size(0)).size(), y.size()) 192 | 193 | 194 | def testPredictProbabilities(self): 195 | """ 196 | Assert that torch.cuda.FloatTensor is handled correctly when returning class probabilities. 
197 | """ 198 | x = torch.randn(400, 2).cuda() 199 | n_components = np.random.randint(1, 100) 200 | 201 | model = GaussianMixture(n_components, x.size(1), covariance_type="diag").cuda() 202 | model.fit(x) 203 | 204 | # check that y_p has dimensions (n, k) 205 | y_p = model.predict(x, probs=True) 206 | self.assertEqual(torch.Tensor(x.size(0), n_components).size(), y_p.size()) 207 | 208 | 209 | if __name__ == "__main__": 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /preprocessing/recon_prior.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | sys.path.append(os.path.join(pathlib.Path(__file__).parent.absolute(), '..')) 5 | import torch 6 | import torch_cluster 7 | import numpy as np 8 | from mvs import mvs_utils, filter_utils 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | from tqdm import tqdm 12 | from utils import masking, create_mvs_model 13 | from dataLoader import mvs_dataset_dict 14 | 15 | def load(pointfile): 16 | if os.path.exists(pointfile): 17 | return torch.as_tensor(np.loadtxt(pointfile, delimiter=";"), dtype=torch.float32) 18 | else: 19 | return None 20 | 21 | def gen_geo(args): 22 | geo = load(args.pointfile) 23 | if geo is None: 24 | print("Do MVS to create pointfile at ", args.pointfile) 25 | dataset = mvs_dataset_dict[args.dataset_name] 26 | mvs_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False) 27 | model = create_mvs_model(args) 28 | xyz_world_all, confidence_filtered_all = gen_points_filter(mvs_dataset, args, model) 29 | geo = torch.cat([xyz_world_all, confidence_filtered_all], dim=-1) 30 | os.makedirs(os.path.dirname(args.pointfile), exist_ok=True) 31 | np.savetxt(args.pointfile, geo.cpu().numpy(), delimiter=";") 32 | else: 33 | print("successfully loaded args.pointfile at : ", args.pointfile, geo.shape) 34 | 35 | if args.vox_range is not None and not args.pointfile[:-4].endswith("vox"): 36 | geo_xyz, confidence = geo[..., :3], geo[..., -1:] 37 | geo = mvs_utils.construct_voxrange_points_mean(geo_xyz, torch.as_tensor(args.vox_range, dtype=torch.float32, device=geo.device), vox_center=args.vox_center>0)#, space_min=torch.as_tensor([-1.5,-1.5,-1.5], device=geo.device, dtype=geo.dtype), space_max=torch.as_tensor([1.5,1.5,1.5], device=geo.device, dtype=geo.dtype)) 38 | print("after vox geo shape", geo.shape) 39 | np.savetxt(args.pointfile[:-4] + "_{}_vox".format(args.vox_range[0]) + ".txt", geo.cpu().numpy(), delimiter=";") 40 | 41 | if args.fps_num > 0 and len(geo) > args.fps_num: 42 | fps_inds = torch_cluster.fps(geo[...,:3], ratio=args.fps_num/len(geo), random_start=True) 43 | geo = geo[fps_inds, ...] 
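# torch_cluster.fps performs farthest-point sampling and returns the indices of the kept points; with ratio=args.fps_num/len(geo), roughly fps_num points survive the downsampling.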
44 | print("fps_inds", fps_inds.shape, geo.shape) 45 | np.savetxt(args.pointfile[:-4]+"_{}".format(args.fps_num)+".txt", geo.cpu().numpy(), delimiter=";") 46 | return geo.cuda() 47 | 48 | def gen_points_filter(dataset, args, model): 49 | cam_xyz_all = [] 50 | intrinsics_all = [] 51 | extrinsics_all = [] 52 | confidence_all = [] 53 | points_mask_all = [] 54 | near_fars_all = [] 55 | gpu_filter = True 56 | cpu2gpu= len(dataset.view_id_list) > 300 57 | imgs_lst, HDWD_lst, c2ws_lst, w2cs_lst, intrinsics_lst = [],[],[],[],[] 58 | 59 | with torch.no_grad(): 60 | for i in tqdm(range(0, len(dataset.view_id_list))): 61 | data = dataset.get_init_item(i) 62 | # intrinsics 1, 3, 3, 3 63 | points_xyz_lst, photometric_confidence_lst, point_mask_lst, HDWD, data_mvs, intrinsics_lst, extrinsics_lst = model.gen_points(data) 64 | 65 | c2ws, w2cs, intrinsics, near_fars = data_mvs['c2ws'], data_mvs['w2cs'], data["intrinsics"], data["near_fars"] 66 | 67 | B, N, C, H, W, _ = points_xyz_lst[0].shape 68 | # print("points_xyz_lst",points_xyz_lst[0].shape) 69 | cam_xyz_all.append((points_xyz_lst[0].cpu() if cpu2gpu else points_xyz_lst[0]) if gpu_filter else points_xyz_lst[0].cpu().numpy()) 70 | # intrinsics_lst[0] 1, 3, 3 71 | intrinsics_all.append(intrinsics_lst[0] if gpu_filter else intrinsics_lst[0]) 72 | extrinsics_all.append(extrinsics_lst[0] if gpu_filter else extrinsics_lst[0].cpu().numpy()) 73 | confidence_all.append((photometric_confidence_lst[0].cpu() if cpu2gpu else photometric_confidence_lst[0]) if gpu_filter else photometric_confidence_lst[0].cpu().numpy()) 74 | points_mask_all.append((point_mask_lst[0].cpu() if cpu2gpu else point_mask_lst[0]) if gpu_filter else point_mask_lst[0].cpu().numpy()) 75 | imgs_lst.append(data["images"].cpu()) 76 | HDWD_lst.append(HDWD) 77 | c2ws_lst.append(c2ws) 78 | w2cs_lst.append(w2cs) 79 | near_fars_all.append(near_fars[0,0]) 80 | # visualizer.save_neural_points(i, points_xyz_lst[0], None, data, save_ref=args.load_points == 0) 81 | # #################### start query embedding ################## 82 | torch.cuda.empty_cache() 83 | if gpu_filter: 84 | _, xyz_world_all, confidence_filtered_all = filter_utils.filter_by_masks_gpu(cam_xyz_all, intrinsics_all, extrinsics_all, confidence_all, points_mask_all, args, vis=True, return_w=True, cpu2gpu=cpu2gpu, near_fars_all=near_fars_all) 85 | else: 86 | _, xyz_world_all, confidence_filtered_all = filter_utils.filter_by_masks(cam_xyz_all, [intr.cpu().numpy() for intr in intrinsics_all], extrinsics_all, confidence_all, points_mask_all, args) 87 | 88 | points_vid = torch.cat([torch.ones_like(xyz_world_all[i][...,0:1]) * i for i in range(len(xyz_world_all))], dim=0) 89 | xyz_world_all = torch.cat(xyz_world_all, dim=0) if gpu_filter else torch.as_tensor( 90 | np.concatenate(xyz_world_all, axis=0), device="cuda", dtype=torch.float32) 91 | confidence_filtered_all = torch.cat(confidence_filtered_all, dim=0) if gpu_filter else torch.as_tensor(np.concatenate(confidence_filtered_all, axis=0), device="cuda", dtype=torch.float32) 92 | print("xyz_world_all", xyz_world_all.shape, points_vid.shape, confidence_filtered_all.shape) 93 | torch.cuda.empty_cache() 94 | 95 | 96 | print("%%%%%%%%%%%%% getattr(dataset, spacemin, None)", getattr(dataset, "spacemin", None)) 97 | if getattr(dataset, "spacemin", None) is not None: 98 | mask = (xyz_world_all - dataset.spacemin[None, ...].to(xyz_world_all.device)) >= 0 99 | mask *= (dataset.spacemax[None, ...].to(xyz_world_all.device) - xyz_world_all) >= 0 100 | mask = torch.prod(mask, dim=-1) > 0 101 | first_lst, 
second_lst = masking(mask, [xyz_world_all, points_vid, confidence_filtered_all], []) 102 | xyz_world_all, points_vid, confidence_filtered_all = first_lst 103 | # visualizer.save_neural_points(50, xyz_world_all, None, None, save_ref=False) 104 | # print("vis 50") 105 | if getattr(dataset, "alphas", None) is not None: 106 | vishull_mask = mvs_utils.alpha_masking(xyz_world_all, dataset.alphas, dataset.intrinsics, dataset.cam2worlds, dataset.world2cams, dataset.near_far if args.ranges[0] < -90.0 and getattr(dataset,"spacemin",None) is None else None, args=args) 107 | first_lst, second_lst = masking(vishull_mask, [xyz_world_all, points_vid, confidence_filtered_all], []) 108 | xyz_world_all, points_vid, confidence_filtered_all = first_lst 109 | print("alpha masking xyz_world_all", xyz_world_all.shape, points_vid.shape) 110 | # visualizer.save_neural_points(100, xyz_world_all, None, data, save_ref=args.load_points == 0) 111 | # print("vis 100") 112 | 113 | if args.vox_res > 0: 114 | xyz_world_all, _, sampled_pnt_idx = mvs_utils.construct_vox_points_closest(xyz_world_all.cuda() if len(xyz_world_all) < 99999999 else xyz_world_all[::(len(xyz_world_all)//99999999+1),...].cuda(), args.vox_res) 115 | points_vid = points_vid[sampled_pnt_idx,:] 116 | confidence_filtered_all = confidence_filtered_all[sampled_pnt_idx] 117 | print("after voxelize:", xyz_world_all.shape, points_vid.shape) 118 | xyz_world_all = xyz_world_all.cuda() 119 | 120 | return xyz_world_all, confidence_filtered_all[..., None] -------------------------------------------------------------------------------- /mvs/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .mvs_utils import normal_vect, index_point_feature, build_color_volume 4 | 5 | def depth2dist(z_vals, cos_angle): 6 | # z_vals: [N_ray N_sample] 7 | device = z_vals.device 8 | dists = z_vals[..., 1:] - z_vals[..., :-1] 9 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 10 | dists = dists * cos_angle.unsqueeze(-1) 11 | return dists 12 | 13 | def ndc2dist(ndc_pts, cos_angle): 14 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1) 15 | dists = torch.cat([dists, 1e10*cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples] 16 | return dists 17 | 18 | def raw2alpha(sigma, dist, net_type): 19 | 20 | alpha_softmax = F.softmax(sigma, 1) 21 | 22 | alpha = 1. - torch.exp(-sigma) 23 | 24 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)[:, :-1] 25 | weights = alpha * T # [N_rays, N_samples] 26 | return alpha, weights, alpha_softmax 27 | 28 | def batchify(fn, chunk): 29 | """Constructs a version of 'fn' that applies to smaller batches. 30 | """ 31 | if chunk is None: 32 | return fn 33 | 34 | def ret(inputs, alpha_only): 35 | if alpha_only: 36 | return torch.cat([fn.forward_alpha(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 37 | else: 38 | return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 39 | 40 | return ret 41 | 42 | def run_network_mvs(pts, viewdirs, alpha_feat, fn, embed_fn, embeddirs_fn, netchunk=1024): 43 | """ 44 | Prepares inputs and applies network 'fn'. 
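As called from rendering() below: pts is the (N_rays, N_samples, 3) tensor of NDC sample positions, viewdirs is (N_rays, 3) and is expanded to one direction per sample, and alpha_feat holds per-point features that are concatenated onto the embedded positions.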
45 | """ 46 | 47 | if embed_fn is not None: 48 | pts = embed_fn(pts) 49 | 50 | if alpha_feat is not None: 51 | pts = torch.cat((pts,alpha_feat), dim=-1) 52 | 53 | if viewdirs is not None: 54 | if viewdirs.dim()!=3: 55 | viewdirs = viewdirs[:, None].expand(-1,pts.shape[1],-1) 56 | 57 | if embeddirs_fn is not None: 58 | viewdirs = embeddirs_fn(viewdirs) 59 | pts = torch.cat([pts, viewdirs], -1) 60 | 61 | alpha_only = viewdirs is None 62 | outputs_flat = batchify(fn, netchunk)(pts, alpha_only) 63 | outputs = torch.reshape(outputs_flat, list(pts.shape[:-1]) + [outputs_flat.shape[-1]]) 64 | return outputs 65 | 66 | def raw2outputs(raw, z_vals, dists, white_bkgd=False, net_type='v2'): 67 | """Transforms model's predictions to semantically meaningful values. 68 | Args: 69 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 70 | z_vals: [num_rays, num_samples along ray]. Integration time. 71 | rays_d: [num_rays, 3]. Direction of each ray. 72 | Returns: 73 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 74 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 75 | acc_map: [num_rays]. Sum of weights along each ray. 76 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 77 | depth_map: [num_rays]. Estimated distance to object. 78 | """ 79 | 80 | device = z_vals.device 81 | 82 | rgb = raw[..., :3] # [N_rays, N_samples, 3] 83 | 84 | alpha, weights, alpha_softmax = raw2alpha(raw[..., 3], dists, net_type) # [N_rays, N_samples] 85 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 86 | depth_map = torch.sum(weights * z_vals, -1) 87 | 88 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map, device=device), depth_map / torch.sum(weights, -1)) 89 | acc_map = torch.sum(weights, -1) 90 | 91 | if white_bkgd: 92 | rgb_map = rgb_map + (1. 
- acc_map[..., None]) 93 | return rgb_map, disp_map, acc_map, weights, depth_map, alpha 94 | 95 | 96 | 97 | def gen_angle_feature(c2ws, rays_pts, rays_dir): 98 | """ 99 | Inputs: 100 | c2ws: [1,v,4,4] 101 | rays_pts: [N_rays, N_samples, 3] 102 | rays_dir: [N_rays, 3] 103 | 104 | Returns: 105 | 106 | """ 107 | N_rays, N_samples = rays_pts.shape[:2] 108 | dirs = normal_vect(rays_pts.unsqueeze(2) - c2ws[:3, :3, 3][None, None]) # [N_rays, N_samples, v, 3] 109 | angle = torch.sum(dirs[:, :, :3] * rays_dir.reshape(N_rays,1,1,3), dim=-1, keepdim=True).reshape(N_rays, N_samples, -1) 110 | return angle 111 | 112 | def gen_dir_feature(w2c_ref, rays_dir): 113 | """ 114 | Inputs: 115 | c2ws: [1,v,4,4] 116 | rays_pts: [N_rays, N_samples, 3] 117 | rays_dir: [N_rays, 3] 118 | 119 | Returns: 120 | 121 | """ 122 | dirs = rays_dir @ w2c_ref[:3,:3].t() # [N_rays, 3] 123 | return dirs 124 | 125 | def gen_pts_feats(imgs, volume_feature, rays_pts, pose_ref, rays_ndc, feat_dim, img_feat=None, img_downscale=1.0, use_color_volume=False, net_type='v0'): 126 | N_rays, N_samples = rays_pts.shape[:2] 127 | if img_feat is not None: 128 | feat_dim += img_feat.shape[1]*img_feat.shape[2] 129 | 130 | if not use_color_volume: 131 | input_feat = torch.empty((N_rays, N_samples, feat_dim), device=imgs.device, dtype=torch.float) 132 | ray_feats = index_point_feature(volume_feature, rays_ndc) if torch.is_tensor(volume_feature) else volume_feature(rays_ndc) 133 | input_feat[..., :8] = ray_feats 134 | input_feat[..., 8:] = build_color_volume(rays_pts, pose_ref, imgs, img_feat, with_mask=True, downscale=img_downscale) 135 | else: 136 | input_feat = index_point_feature(volume_feature, rays_ndc) if torch.is_tensor(volume_feature) else volume_feature(rays_ndc) 137 | return input_feat 138 | 139 | def rendering(args, pose_ref, rays_pts, rays_ndc, depth_candidates, rays_o, rays_dir, 140 | volume_feature=None, imgs=None, network_fn=None, img_feat=None, network_query_fn=None, white_bkgd=False, **kwargs): 141 | 142 | # rays angle 143 | cos_angle = torch.norm(rays_dir, dim=-1) 144 | 145 | 146 | # using direction 147 | if pose_ref is not None: 148 | angle = gen_dir_feature(pose_ref['w2cs'][0], rays_dir/cos_angle.unsqueeze(-1)) # view dir feature 149 | else: 150 | angle = rays_dir/cos_angle.unsqueeze(-1) 151 | 152 | # rays_pts 153 | input_feat = gen_pts_feats(imgs, volume_feature, rays_pts, pose_ref, rays_ndc, args.feat_dim, \ 154 | img_feat, args.img_downscale, args.use_color_volume, args.net_type) 155 | 156 | # rays_ndc = rays_ndc * 2 - 1.0 157 | # network_query_fn = lambda pts, viewdirs, rays_feats, network_fn: run_network_mvs(pts, viewdirs, rays_feats, 158 | # network_fn, 159 | # embed_fn=embed_fn, 160 | # embeddirs_fn=embeddirs_fn, 161 | # netchunk=args.netchunk) 162 | # run_network_mvs 163 | raw = network_query_fn(rays_ndc, angle, input_feat, network_fn) 164 | if raw.shape[-1]>4: 165 | input_feat = torch.cat((input_feat[...,:8],raw[...,4:]), dim=-1) 166 | 167 | dists = depth2dist(depth_candidates, cos_angle) 168 | # dists = ndc2dist(rays_ndc) 169 | rgb_map, disp_map, acc_map, weights, depth_map, alpha = raw2outputs(raw, depth_candidates, dists, white_bkgd,args.net_type) 170 | ret = {} 171 | 172 | return rgb_map, input_feat, weights, depth_map, alpha, ret 173 | 174 | def render_density(network_fn, rays_pts, density_feature, network_query_fn, chunk=1024 * 5): 175 | densities = [] 176 | device = density_feature.device 177 | for i in range(0, rays_pts.shape[0], chunk): 178 | 179 | input_feat = rays_pts[i:i + chunk].to(device) 180 | 181 | 
density = network_query_fn(input_feat, None, density_feature[i:i + chunk], network_fn) 182 | densities.append(density) 183 | 184 | return torch.cat(densities) -------------------------------------------------------------------------------- /mvs/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | 10 | def get_nonlinearity_layer(activation_type='PReLU'): 11 | if activation_type == 'ReLU': 12 | nonlinearity_layer = nn.ReLU(True) 13 | elif activation_type == 'SELU': 14 | nonlinearity_layer = nn.SELU(True) 15 | elif activation_type == 'LeakyReLU': 16 | nonlinearity_layer = nn.LeakyReLU(0.1, True) 17 | elif activation_type == 'PReLU': 18 | nonlinearity_layer = nn.PReLU() 19 | else: 20 | raise NotImplementedError('activation layer [{}] is not found'.format(activation_type)) 21 | return nonlinearity_layer 22 | 23 | 24 | def get_norm_layer(norm_type='instance'): 25 | if norm_type == 'batch': 26 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 27 | elif norm_type == 'instance': 28 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 29 | # norm_layer = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=True) 30 | elif norm_type == 'group': 31 | norm_layer = functools.partial(nn.GroupNorm, num_groups=16, affine=True) 32 | elif norm_type == 'layer': 33 | norm_layer = nn.LayerNorm 34 | elif norm_type == 'none': 35 | norm_layer = None 36 | else: 37 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 38 | return norm_layer 39 | 40 | 41 | def get_scheduler(optimizer, opt): 42 | if opt.lr_policy == 'lambda': 43 | def lambda_rule(it): 44 | lr_l = 1.0 - max(0, it - opt.niter) / float(opt.niter_decay + 1) 45 | return lr_l 46 | 47 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 48 | elif opt.lr_policy == 'step': 49 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 50 | elif opt.lr_policy == 'plateau': 51 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 52 | mode='min', 53 | factor=0.2, 54 | threshold=0.01, 55 | patience=5) 56 | elif opt.lr_policy == 'iter_exponential_decay': 57 | def lambda_rule(it): 58 | lr_l = pow(opt.lr_decay_exp, it / opt.lr_decay_iters) 59 | return lr_l 60 | 61 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 62 | 63 | elif opt.lr_policy == 'cosine_annealing': 64 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.num_epochs, eta_min=1e-7) # was CosineAnnealingLR(..., self.args.num_epochs), which is undefined in this free function; opt.num_epochs is assumed here 65 | 66 | else: 67 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 68 | return scheduler 69 | 70 | 71 | def get_xavier_multiplier(m, gain): 72 | if isinstance(m, nn.Conv1d): 73 | ksize = m.kernel_size[0] 74 | n1 = m.in_channels 75 | n2 = m.out_channels 76 | 77 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 78 | elif isinstance(m, nn.ConvTranspose1d): 79 | ksize = m.kernel_size[0] // m.stride[0] 80 | n1 = m.in_channels 81 | n2 = m.out_channels 82 | 83 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 84 | elif isinstance(m, nn.Conv2d): 85 | ksize = m.kernel_size[0] * m.kernel_size[1] 86 | n1 = m.in_channels 87 | n2 = m.out_channels 88 | 89 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 90 | elif isinstance(m, nn.ConvTranspose2d): 91 | ksize = m.kernel_size[0] *
m.kernel_size[1] // m.stride[0] // m.stride[1] 92 | n1 = m.in_channels 93 | n2 = m.out_channels 94 | 95 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 96 | elif isinstance(m, nn.Conv3d): 97 | ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] 98 | n1 = m.in_channels 99 | n2 = m.out_channels 100 | 101 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 102 | elif isinstance(m, nn.ConvTranspose3d): 103 | ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] // m.stride[0] // m.stride[ 104 | 1] // m.stride[2] 105 | n1 = m.in_channels 106 | n2 = m.out_channels 107 | 108 | std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize)) 109 | elif isinstance(m, nn.Linear): 110 | n1 = m.in_features 111 | n2 = m.out_features 112 | 113 | std = gain * np.sqrt(2.0 / (n1 + n2)) 114 | else: 115 | return None 116 | 117 | return std 118 | 119 | 120 | def xavier_uniform_(m, gain): 121 | std = get_xavier_multiplier(m, gain) 122 | m.weight.data.uniform_(-std * np.sqrt(3.0), std * np.sqrt(3.0)) 123 | 124 | 125 | def init_weights(net, init_type='xavier_uniform', gain=1): 126 | def init_func(m): 127 | classname = m.__class__.__name__ 128 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 129 | if init_type == 'xavier_uniform': 130 | xavier_uniform_(m, gain) 131 | elif init_type == 'normal': 132 | init.normal_(m.weight.data, 0.0, gain) 133 | elif init_type == 'xavier': 134 | init.xavier_normal_(m.weight.data, gain=gain) 135 | elif init_type == 'kaiming': 136 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 137 | elif init_type == 'orthogonal': 138 | init.orthogonal_(m.weight.data, gain=gain) 139 | else: 140 | raise NotImplementedError('initialization method [{}] is not implemented'.format(init_type)) 141 | if hasattr(m, 'bias') and m.bias is not None: 142 | init.constant_(m.bias.data, 0.0) 143 | elif classname.find('BatchNorm2d') != -1: 144 | init.normal_(m.weight.data, 1.0, gain) 145 | init.constant_(m.bias.data, 0.0) 146 | 147 | # if classname == 'ConvTranspose2d': 148 | # m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] 149 | # m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2] 150 | # m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] 151 | # elif classname == 'ConvTranspose3d': 152 | # m.weight.data[:, :, 0::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 153 | # m.weight.data[:, :, 0::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 154 | # m.weight.data[:, :, 0::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 155 | # m.weight.data[:, :, 1::2, 0::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 156 | # m.weight.data[:, :, 1::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 157 | # m.weight.data[:, :, 1::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 158 | # m.weight.data[:, :, 1::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] 159 | 160 | net.apply(init_func) 161 | 162 | 163 | def init_seq(s, init_type='xavier_uniform'): 164 | '''initialize sequential model''' 165 | for a, b in zip(s[:-1], s[1:]): 166 | if isinstance(b, nn.ReLU): 167 | init_weights(a, init_type, nn.init.calculate_gain('relu')) 168 | elif isinstance(b, nn.LeakyReLU): 169 | init_weights(a, init_type, nn.init.calculate_gain('leaky_relu', b.negative_slope)) 170 | else: 171 | init_weights(a, init_type) 172 | init_weights(s[-1]) 173 | 174 | 175 | def positional_encoding(positions, freqs, ori=False): 176 | '''encode positions with positional encoding 177 | positions: :math:`(...,D)` 178 | freqs: int 
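ori: bool; if True, the raw positions are concatenated in front of the sin/cos features, so the output has D + 2DF channels instead of 2DF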
179 | Return: 180 | pts: :math:`(..., 2DF)` 181 | ''' 182 | freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) 183 | ori_c = positions.shape[-1] 184 | pts = (positions[..., None] * freq_bands).reshape(positions.shape[:-1] + 185 | (freqs * positions.shape[-1], )) # (..., DF) 186 | if ori: 187 | pts = torch.cat([positions, torch.sin(pts), torch.cos(pts)], dim=-1).reshape(pts.shape[:-1]+(pts.shape[-1]*2+ori_c,)) 188 | else: 189 | pts = torch.stack([torch.sin(pts), torch.cos(pts)], dim=-1).reshape(pts.shape[:-1]+(pts.shape[-1]*2,)) 190 | return pts -------------------------------------------------------------------------------- /preprocessing/recon_prior_hier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pathlib 4 | sys.path.append(os.path.join(pathlib.Path(__file__).parent.absolute(), '..')) 5 | import torch 6 | import torch_cluster 7 | import numpy as np 8 | from mvs import mvs_utils, filter_utils 9 | torch.manual_seed(0) 10 | np.random.seed(0) 11 | from tqdm import tqdm 12 | from utils import masking, create_mvs_model 13 | from dataLoader import mvs_dataset_dict 14 | #from preprocessing.boxing import find_tensorf_box, filter_cluster_n_pnts 15 | #from preprocessing.cluster import cluster 16 | 17 | def load(pointfile): 18 | if os.path.exists(pointfile): 19 | return torch.as_tensor(np.loadtxt(pointfile, delimiter=";"), dtype=torch.float32) 20 | else: 21 | return None 22 | 23 | def gen_geo(args, geo): 24 | if not os.path.isdir(args.pointfile[:-4]): 25 | os.makedirs(args.pointfile[:-4], exist_ok=True) 26 | if args.pointfile.endswith('ply'): 27 | # points_path = os.path.join(self.root_dir, "exported/pcd.ply") 28 | geo = mvs_utils.load_ply_points(args) 29 | elif args.pointfile == 'depth': 30 | colordir = os.path.join(args.datadir, "exported/color") 31 | image_paths = [f for f in os.listdir(colordir) if os.path.isfile(os.path.join(colordir, f))] 32 | image_paths = [os.path.join(args.datadir, "exported/color/{}.jpg".format(i)) for i in 33 | range(len(image_paths))] 34 | all_id_list = mvs_utils.filter_valid_id(args, list(range(len(image_paths)))) 35 | depth_intrinsic = np.loadtxt( 36 | os.path.join(args.datadir, "exported/intrinsic/intrinsic_depth.txt")).astype(np.float32)[:3, :3] 37 | geo = mvs_utils.load_init_depth_points(args, all_id_list, depth_intrinsic, device="cuda") 38 | # np.savetxt(os.path.join(args.basedir, args.expname, "depth.txt"), geo.cpu().numpy(), delimiter=";") 39 | else: 40 | if geo is None: # no points by dvgo 41 | geo = load(args.pointfile) 42 | if geo is None: 43 | print("Do MVS to create pointfile at ", args.pointfile) 44 | dataset = mvs_dataset_dict[args.dataset_name] 45 | mvs_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False) 46 | model = create_mvs_model(args) 47 | xyz_world_all, confidence_filtered_all = gen_points_filter(mvs_dataset, args, model) 48 | geo = torch.cat([xyz_world_all, confidence_filtered_all], dim=-1) 49 | os.makedirs(os.path.dirname(args.pointfile), exist_ok=True) 50 | np.savetxt(args.pointfile, geo.cpu().numpy(), delimiter=";") 51 | else: 52 | print("successfully loaded args.pointfile at : ", args.pointfile, geo.shape) 53 | geo_lst = [] 54 | if args.vox_range is not None and not args.pointfile[:-4].endswith("vox"): 55 | geo_xyz, confidence = geo[..., :3], geo[..., -1:] 56 | for i in range(len(args.vox_range)): 57 | geo_lvl, _, _, _ = mvs_utils.construct_voxrange_points_mean(geo_xyz, 
torch.as_tensor(args.vox_range[i], dtype=torch.float32, device=geo.device), vox_center=args.vox_center[i]>0) 58 | print("after vox geo shape", geo_lvl.shape) 59 | np.savetxt(args.pointfile[:-4] + "_{}_vox".format(args.vox_range[i][0]) + ".txt", geo_lvl.cpu().numpy(), delimiter=";") 60 | geo_lst.append(geo_lvl.cuda()) 61 | 62 | if args.fps_num is not None: 63 | for i in range(len(args.fps_num)): 64 | if len(geo_lst[i]) > args.fps_num[i]: 65 | fps_inds = torch_cluster.fps(geo_lst[i][...,:3], ratio=args.fps_num[i]/len(geo_lst[i]), random_start=True) 66 | geo_lvl = geo_lst[i][fps_inds, ...] 67 | print("fps_inds", fps_inds.shape, geo_lvl.shape) 68 | np.savetxt(args.pointfile[:-4]+"_{}".format(args.fps_num)+".txt", geo_lvl.cpu().numpy(), delimiter=";") 69 | geo_lst[i] = geo_lvl.cuda() 70 | return None, geo_lst 71 | 72 | 73 | 74 | def gen_points_filter(dataset, args, model): 75 | cam_xyz_all = [] 76 | intrinsics_all = [] 77 | extrinsics_all = [] 78 | confidence_all = [] 79 | points_mask_all = [] 80 | near_fars_all = [] 81 | gpu_filter = True 82 | cpu2gpu= len(dataset.view_id_list) > 300 83 | imgs_lst, HDWD_lst, c2ws_lst, w2cs_lst, intrinsics_lst = [],[],[],[],[] 84 | 85 | with torch.no_grad(): 86 | for i in tqdm(range(0, len(dataset.view_id_list))): 87 | data = dataset.get_init_item(i) 88 | # intrinsics 1, 3, 3, 3 89 | points_xyz_lst, photometric_confidence_lst, point_mask_lst, HDWD, data_mvs, intrinsics_lst, extrinsics_lst = model.gen_points(data) 90 | 91 | c2ws, w2cs, intrinsics, near_fars = data_mvs['c2ws'], data_mvs['w2cs'], data["intrinsics"], data["near_fars"] 92 | 93 | B, N, C, H, W, _ = points_xyz_lst[0].shape 94 | # print("points_xyz_lst",points_xyz_lst[0].shape) 95 | cam_xyz_all.append((points_xyz_lst[0].cpu() if cpu2gpu else points_xyz_lst[0]) if gpu_filter else points_xyz_lst[0].cpu().numpy()) 96 | # intrinsics_lst[0] 1, 3, 3 97 | intrinsics_all.append(intrinsics_lst[0] if gpu_filter else intrinsics_lst[0]) 98 | extrinsics_all.append(extrinsics_lst[0] if gpu_filter else extrinsics_lst[0].cpu().numpy()) 99 | confidence_all.append((photometric_confidence_lst[0].cpu() if cpu2gpu else photometric_confidence_lst[0]) if gpu_filter else photometric_confidence_lst[0].cpu().numpy()) 100 | points_mask_all.append((point_mask_lst[0].cpu() if cpu2gpu else point_mask_lst[0]) if gpu_filter else point_mask_lst[0].cpu().numpy()) 101 | imgs_lst.append(data["images"].cpu()) 102 | HDWD_lst.append(HDWD) 103 | c2ws_lst.append(c2ws) 104 | w2cs_lst.append(w2cs) 105 | near_fars_all.append(near_fars[0,0]) 106 | # visualizer.save_neural_points(i, points_xyz_lst[0], None, data, save_ref=args.load_points == 0) 107 | # #################### start query embedding ################## 108 | torch.cuda.empty_cache() 109 | if gpu_filter: 110 | _, xyz_world_all, confidence_filtered_all = filter_utils.filter_by_masks_gpu(cam_xyz_all, intrinsics_all, extrinsics_all, confidence_all, points_mask_all, args, vis=True, return_w=True, cpu2gpu=cpu2gpu, near_fars_all=near_fars_all) 111 | else: 112 | _, xyz_world_all, confidence_filtered_all = filter_utils.filter_by_masks(cam_xyz_all, [intr.cpu().numpy() for intr in intrinsics_all], extrinsics_all, confidence_all, points_mask_all, args) 113 | 114 | points_vid = torch.cat([torch.ones_like(xyz_world_all[i][...,0:1]) * i for i in range(len(xyz_world_all))], dim=0) 115 | xyz_world_all = torch.cat(xyz_world_all, dim=0) if gpu_filter else torch.as_tensor( 116 | np.concatenate(xyz_world_all, axis=0), device="cuda", dtype=torch.float32) 117 | confidence_filtered_all = 
torch.cat(confidence_filtered_all, dim=0) if gpu_filter else torch.as_tensor(np.concatenate(confidence_filtered_all, axis=0), device="cuda", dtype=torch.float32) 118 | print("xyz_world_all", xyz_world_all.shape, points_vid.shape, confidence_filtered_all.shape) 119 | torch.cuda.empty_cache() 120 | 121 | 122 | print("%%%%%%%%%%%%% getattr(dataset, spacemin, None)", getattr(dataset, "spacemin", None)) 123 | if getattr(dataset, "spacemin", None) is not None: 124 | mask = (xyz_world_all - dataset.spacemin[None, ...].to(xyz_world_all.device)) >= 0 125 | mask *= (dataset.spacemax[None, ...].to(xyz_world_all.device) - xyz_world_all) >= 0 126 | mask = torch.prod(mask, dim=-1) > 0 127 | first_lst, second_lst = masking(mask, [xyz_world_all, points_vid, confidence_filtered_all], []) 128 | xyz_world_all, points_vid, confidence_filtered_all = first_lst 129 | # visualizer.save_neural_points(50, xyz_world_all, None, None, save_ref=False) 130 | # print("vis 50") 131 | if getattr(dataset, "alphas", None) is not None: 132 | vishull_mask = mvs_utils.alpha_masking(xyz_world_all, dataset.alphas, dataset.intrinsics, dataset.cam2worlds, dataset.world2cams, dataset.near_far if args.ranges[0] < -90.0 and getattr(dataset,"spacemin",None) is None else None, args=args) 133 | first_lst, second_lst = masking(vishull_mask, [xyz_world_all, points_vid, confidence_filtered_all], []) 134 | xyz_world_all, points_vid, confidence_filtered_all = first_lst 135 | print("alpha masking xyz_world_all", xyz_world_all.shape, points_vid.shape) 136 | # visualizer.save_neural_points(100, xyz_world_all, None, data, save_ref=args.load_points == 0) 137 | # print("vis 100") 138 | 139 | if args.vox_res > 0: 140 | xyz_world_all, _, sampled_pnt_idx = mvs_utils.construct_vox_points_closest(xyz_world_all.cuda() if len(xyz_world_all) < 99999999 else xyz_world_all[::(len(xyz_world_all)//99999999+1),...].cuda(), args.vox_res) 141 | points_vid = points_vid[sampled_pnt_idx,:] 142 | confidence_filtered_all = confidence_filtered_all[sampled_pnt_idx] 143 | print("after voxelize:", xyz_world_all.shape, points_vid.shape) 144 | xyz_world_all = xyz_world_all.cuda() 145 | 146 | return xyz_world_all, confidence_filtered_all[..., None] -------------------------------------------------------------------------------- /dataLoader/tankstemple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | 11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 12 | if axis == 'z': 13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h] 14 | elif axis == 'y': 15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)] 16 | else: 17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)] 18 | 19 | 20 | def cross(x, y, axis=0): 21 | T = torch if isinstance(x, torch.Tensor) else np 22 | return T.cross(x, y, axis) 23 | 24 | 25 | def normalize(x, axis=-1, order=2): 26 | if isinstance(x, torch.Tensor): 27 | l2 = x.norm(p=order, dim=axis, keepdim=True) 28 | return x / (l2 + 1e-8), l2 29 | 30 | else: 31 | l2 = np.linalg.norm(x, order, axis) 32 | l2 = np.expand_dims(l2, axis) 33 | l2[l2 == 0] = 1 34 | return x / l2, 35 | 36 | 37 | def cat(x, axis=1): 38 | if isinstance(x[0], torch.Tensor): 39 | return torch.cat(x, dim=axis) 40 | return np.concatenate(x, axis=axis) 41 | 42 
| 43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
44 |     """
45 |     This function takes a vector 'camera_position' which specifies the location
46 |     of the camera in world coordinates, and two vectors `at` and `up` which
47 |     indicate the position of the object and the up direction of the world
48 |     coordinate system respectively. The object is assumed to be centered at
49 |     the origin.
50 |     The output is a rotation matrix representing the transformation
51 |     from world coordinates -> view coordinates.
52 |     Input:
53 |         camera_position: (3,)
54 |         at: 1 x 3 or N x 3, (0, 0, 0) by default
55 |         up: 1 x 3 or N x 3, (0, 0, -1) by default (matching the code below)
56 |     """
57 | 
58 |     if at is None:
59 |         at = torch.zeros_like(camera_position)
60 |     else:
61 |         at = torch.tensor(at).type_as(camera_position)
62 |     if up is None:
63 |         up = torch.zeros_like(camera_position)
64 |         up[2] = -1
65 |     else:
66 |         up = torch.tensor(up).type_as(camera_position)
67 | 
68 |     z_axis = normalize(at - camera_position)[0]
69 |     x_axis = normalize(cross(up, z_axis))[0]
70 |     y_axis = normalize(cross(z_axis, x_axis))[0]
71 | 
72 |     R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
73 |     return R
74 | 
75 | 
76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
77 |     c2ws = []
78 |     for t in range(frames):
79 |         c2w = torch.eye(4)
80 |         cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
81 |         cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
82 |         c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
83 |         c2ws.append(c2w)
84 |     return torch.stack(c2ws)
85 | 
86 | class TanksTempleDataset(Dataset):
87 |     """Tanks and Temples dataset in NSVF format."""
88 |     def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False):
89 |         self.root_dir = datadir
90 |         self.split = split
91 |         self.is_stack = is_stack
92 |         self.downsample = downsample
93 |         self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
94 |         self.define_transforms()
95 | 
96 |         self.white_bg = True
97 |         self.near_far = [0.01,6.0]
98 |         self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2
99 | 
100 |         self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
101 |         self.read_meta()
102 |         self.define_proj_mat()
103 | 
104 |         self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
105 |         self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
106 | 
107 |     def bbox2corners(self):
108 |         corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
109 |         for i in range(3):
110 |             corners[i,[0,1],i] = corners[i,[1,0],i]
111 |         return corners.view(-1,3)
112 | 
113 | 
114 |     def read_meta(self):
115 | 
116 |         self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
117 |         self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
118 |         pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
119 |         img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
120 | 
121 |         if self.split == 'train':
122 |             pose_files = [x for x in pose_files if x.startswith('0_')]
123 |             img_files = [x for x in img_files if x.startswith('0_')]
124 |         elif self.split == 'val':
125 |             pose_files = [x for x in pose_files if x.startswith('1_')]
126 |             img_files = [x for x in img_files if x.startswith('1_')]
127 |         elif self.split == 'test':
128 |             test_pose_files = [x for x in pose_files if x.startswith('2_')]
129 |             test_img_files = [x for x in img_files if x.startswith('2_')]
130 |             if len(test_pose_files) == 0:
131 |                 test_pose_files = [x for x in pose_files if x.startswith('1_')]
132 |                 test_img_files = [x for x in img_files if x.startswith('1_')]
133 |             pose_files = test_pose_files
134 |             img_files = test_img_files
135 | 
136 |         # ray directions for all pixels, same for all images (same H, W, focal)
137 |         self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)
138 |         self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
139 | 
140 | 
141 | 
142 |         self.poses = []
143 |         self.all_rays = []
144 |         self.all_rgbs = []
145 | 
146 |         assert len(img_files) == len(pose_files)
147 |         for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
148 |             image_path = os.path.join(self.root_dir, 'rgb', img_fname)
149 |             img = Image.open(image_path)
150 |             if self.downsample!=1.0:
151 |                 img = img.resize(self.img_wh, Image.LANCZOS)
152 |             img = self.transform(img)  # (4, h, w) for RGBA, (3, h, w) for RGB
153 |             img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) or (h*w, 3)
154 |             if img.shape[-1]==4:
155 |                 img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
156 |             self.all_rgbs.append(img)
157 | 
158 | 
159 |             c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
160 |             c2w = torch.FloatTensor(c2w)
161 |             self.poses.append(c2w)  # C2W
162 |             rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
163 |             self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
164 | 
165 |         self.poses = torch.stack(self.poses)
166 | 
167 |         center = torch.mean(self.scene_bbox, dim=0)
168 |         radius = torch.norm(self.scene_bbox[1]-center)*1.2
169 |         up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
170 |         pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
171 |         self.render_path = gen_path(pos_gen, up=up,frames=200)
172 |         self.render_path[:, :3, 3] += center
173 | 
174 | 
175 | 
176 |         if 'train' == self.split:
177 |             if self.is_stack:
178 |                 self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(img_files), h, w, 6)
179 |                 self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(img_files), h, w, 3)
180 |             else:
181 |                 self.all_rays = torch.cat(self.all_rays, 0)  # (len(img_files)*h*w, 6)
182 |                 self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(img_files)*h*w, 3)
183 |         else:
184 |             self.all_rays = torch.stack(self.all_rays, 0)  # (len(img_files), h*w, 6)
185 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(img_files), h, w, 3)
186 | 
187 | 
188 |     def define_transforms(self):
189 |         self.transform = T.ToTensor()
190 | 
191 |     def define_proj_mat(self):
192 |         self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
193 | 
194 |     def world2ndc(self, points):
195 |         device = points.device
196 |         return (points - self.center.to(device)) / self.radius.to(device)
197 | 
198 |     def __len__(self):
199 |         if self.split == 'train':
200 |             return len(self.all_rays)
201 |         return len(self.all_rgbs)
202 | 
203 |     def __getitem__(self, idx):
204 | 
205 |         if self.split == 'train':  # use data in the buffers
206 |             sample = {'rays': self.all_rays[idx],
207 |                       'rgbs': self.all_rgbs[idx]}
208 | 
209 |         else:  # create data for each image separately
210 | 
211 |             img = self.all_rgbs[idx]
212 |             rays = self.all_rays[idx]
213 | 
214 |             sample = {'rays': rays,
215 |                       'rgbs': img}
216 |         return sample
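For reference, `gen_path` above accepts any position generator that maps an angle in radians to a 3D camera position. A minimal usage sketch under assumptions: it relies on the module-level imports and helpers of this file (`normalize`, `cross`, `cat`, defined earlier and not shown here), and `make_circle` below is only a hypothetical stand-in for the `circle` helper that `read_meta` calls.

    import numpy as np
    import torch

    # Hypothetical stand-in for this file's `circle` helper: a horizontal
    # circle of given radius at height h, parameterized by angle theta.
    def make_circle(radius=3.0, h=0.5):
        def pos(theta):
            return np.array([radius * np.cos(theta), h, radius * np.sin(theta)])
        return pos

    # 40 camera-to-world matrices orbiting the origin, with world up = (0, -1, 0).
    c2ws = gen_path(make_circle(radius=3.0, h=0.5), at=(0, 0, 0), up=(0, -1, 0), frames=40)
    print(c2ws.shape)  # torch.Size([40, 4, 4])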
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2,torch 2 | import numpy as np 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | import torch.nn.functional as F 6 | import scipy.signal 7 | import importlib 8 | 9 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 10 | 11 | 12 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 13 | """ 14 | depth: (H, W) 15 | """ 16 | 17 | x = np.nan_to_num(depth) # change nan to 0 18 | if minmax is None: 19 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 20 | ma = np.max(x) 21 | else: 22 | mi,ma = minmax 23 | 24 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 25 | x = (255*x).astype(np.uint8) 26 | x_ = cv2.applyColorMap(x, cmap) 27 | return x_, [mi,ma] 28 | 29 | def init_log(log, keys): 30 | for key in keys: 31 | log[key] = torch.tensor([0.0], dtype=float) 32 | return log 33 | 34 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 35 | """ 36 | depth: (H, W) 37 | """ 38 | if type(depth) is not np.ndarray: 39 | depth = depth.cpu().numpy() 40 | 41 | x = np.nan_to_num(depth) # change nan to 0 42 | if minmax is None: 43 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 44 | ma = np.max(x) 45 | else: 46 | mi,ma = minmax 47 | 48 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 49 | x = (255*x).astype(np.uint8) 50 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 51 | x_ = T.ToTensor()(x_) # (3, H, W) 52 | return x_, [mi,ma] 53 | 54 | def N_to_reso(n_voxels, bbox): 55 | xyz_min, xyz_max = bbox 56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / 3) 57 | return ((xyz_max - xyz_min) / voxel_size).long().tolist() 58 | 59 | def cal_n_samples(reso, step_ratio=0.5): 60 | return int(np.linalg.norm(reso)/step_ratio) 61 | 62 | 63 | 64 | 65 | __LPIPS__ = {} 66 | def init_lpips(net_name, device): 67 | assert net_name in ['alex', 'vgg'] 68 | import lpips 69 | print(f'init_lpips: lpips_{net_name}') 70 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 71 | 72 | def rgb_lpips(np_gt, np_im, net_name, device): 73 | if net_name not in __LPIPS__: 74 | __LPIPS__[net_name] = init_lpips(net_name, device) 75 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 76 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 77 | return __LPIPS__[net_name](gt, im, normalize=True).item() 78 | 79 | 80 | def findItem(items, target): 81 | for one in items: 82 | if one[:len(target)]==target: 83 | return one 84 | return None 85 | 86 | 87 | ''' Evaluation metrics (ssim, lpips) 88 | ''' 89 | def rgb_ssim(img0, img1, max_val, 90 | filter_size=11, 91 | filter_sigma=1.5, 92 | k1=0.01, 93 | k2=0.03, 94 | return_map=False): 95 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 96 | assert len(img0.shape) == 3 97 | assert img0.shape[-1] == 3 98 | assert img0.shape == img1.shape 99 | 100 | # Construct a 1D Gaussian blur filter. 101 | hw = filter_size // 2 102 | shift = (2 * hw - filter_size + 1) / 2 103 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 104 | filt = np.exp(-0.5 * f_i) 105 | filt /= np.sum(filt) 106 | 107 | # Blur in x and y (faster than the 2D convolution). 
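# Note: convolving with filt[:, None] and then filt[None, :] below applies the
# same 1D Gaussian along rows and then columns; this separable form is equal to
# a full 2D Gaussian blur at O(filter_size) instead of O(filter_size^2) cost per pixel.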
108 |     def convolve2d(z, f):
109 |         return scipy.signal.convolve2d(z, f, mode='valid')
110 | 
111 |     filt_fn = lambda z: np.stack([
112 |         convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
113 |         for i in range(z.shape[-1])], -1)
114 |     mu0 = filt_fn(img0)
115 |     mu1 = filt_fn(img1)
116 |     mu00 = mu0 * mu0
117 |     mu11 = mu1 * mu1
118 |     mu01 = mu0 * mu1
119 |     sigma00 = filt_fn(img0**2) - mu00
120 |     sigma11 = filt_fn(img1**2) - mu11
121 |     sigma01 = filt_fn(img0 * img1) - mu01
122 | 
123 |     # Clip the variances and covariances to valid values.
124 |     # Variance must be non-negative:
125 |     sigma00 = np.maximum(0., sigma00)
126 |     sigma11 = np.maximum(0., sigma11)
127 |     sigma01 = np.sign(sigma01) * np.minimum(
128 |         np.sqrt(sigma00 * sigma11), np.abs(sigma01))
129 |     c1 = (k1 * max_val)**2
130 |     c2 = (k2 * max_val)**2
131 |     numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
132 |     denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
133 |     ssim_map = numer / denom
134 |     ssim = np.mean(ssim_map)
135 |     return ssim_map if return_map else ssim
136 | 
137 | 
138 | import torch.nn as nn
139 | class TVLoss(nn.Module):
140 |     def __init__(self, TVLoss_weight=1):
141 |         super(TVLoss, self).__init__()
142 |         self.TVLoss_weight = TVLoss_weight
143 | 
144 |     def forward(self, x):
145 |         batch_size = x.size()[0]
146 |         h_x = x.size()[2]
147 |         w_x = x.size()[3]
148 |         count_h = self._tensor_size(x[:,:,1:,:])
149 |         count_w = self._tensor_size(x[:,:,:,1:])
150 |         h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
151 |         w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
152 |         return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
153 | 
154 |     def _tensor_size(self, t):
155 |         return t.size()[1]*t.size()[2]*t.size()[3]
156 | 
157 | import plyfile
158 | import skimage.measure
159 | def convert_sdf_samples_to_ply(
160 |     pytorch_3d_sdf_tensor,
161 |     ply_filename_out,
162 |     bbox,
163 |     level=0.5,
164 |     offset=None,
165 |     scale=None,
166 | ):
167 |     """
168 |     Convert SDF samples to a .ply mesh via marching cubes.
169 | 
170 |     :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n, n, n)
171 |     :param ply_filename_out: string, path of the .ply file to save to
172 |     :param bbox: (2, 3) min and max corners of the volume in world coordinates
173 |     :param level, offset, scale: iso-level and optional extra transform applied to the mesh points
174 | 
175 |     This function is adapted from: https://github.com/RobotLocomotion/spartan
176 |     """
177 | 
178 |     numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
179 |     voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))
180 | 
181 |     verts, faces, normals, values = skimage.measure.marching_cubes(
182 |         numpy_3d_sdf_tensor, level=level, spacing=voxel_size
183 |     )
184 |     faces = faces[...,::-1]  # invert face orientation
185 | 
186 |     # transform from voxel coordinates to world coordinates
187 |     # note x and y are flipped in the output of marching_cubes
188 |     mesh_points = np.zeros_like(verts)
189 |     mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
190 |     mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
191 |     mesh_points[:, 2] = bbox[0,2] + verts[:, 2]
192 | 
193 |     # apply additional offset and scale
194 |     if scale is not None:
195 |         mesh_points = mesh_points / scale
196 |     if offset is not None:
197 |         mesh_points = mesh_points - offset
198 | 
199 |     # try writing to the ply file
200 | 
201 |     num_verts = verts.shape[0]
202 |     num_faces = faces.shape[0]
203 | 
204 |     verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
205 | 
206 |     for i in range(0, num_verts):
207 |         verts_tuple[i] = tuple(mesh_points[i, :])
208 | 
209 | 
faces_building = [] 210 | for i in range(0, num_faces): 211 | faces_building.append(((faces[i, :].tolist(),))) 212 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 213 | 214 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 215 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 216 | 217 | ply_data = plyfile.PlyData([el_verts, el_faces]) 218 | print("saving mesh to %s" % (ply_filename_out)) 219 | ply_data.write(ply_filename_out) 220 | 221 | 222 | def masking(mask, firstdim_lst, seconddim_lst): 223 | first_lst = [item[mask, ...] if item is not None else None for item in firstdim_lst] 224 | second_lst = [item[:, mask, ...] if item is not None else None for item in seconddim_lst] 225 | return first_lst, second_lst 226 | 227 | 228 | def find_mvs_model_class_by_name(model_name): 229 | # Given the option --model [modelname], 230 | # the file "models/modelname_model.py" 231 | # will be imported. 232 | model_filename = "mvs." + model_name + "_model" 233 | modellib = importlib.import_module(model_filename) 234 | 235 | # In the file, the class called ModelNameModel() will 236 | # be instantiated. It has to be a subclass of BaseModel, 237 | # and it is case-insensitive. 238 | model = None 239 | target_model_name = model_name.replace('_', '') + 'model' 240 | for name, cls in modellib.__dict__.items(): 241 | if name.lower() == target_model_name.lower(): 242 | model = cls 243 | 244 | if model is None: 245 | print( 246 | "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 247 | % (model_filename, target_model_name)) 248 | exit(0) 249 | 250 | return model 251 | 252 | 253 | def get_option_setter(model_name): 254 | model_class = find_mvs_model_class_by_name(model_name) 255 | return model_class.modify_commandline_options 256 | 257 | 258 | def create_mvs_model(args): 259 | model = find_mvs_model_class_by_name(args.mvs_model) 260 | instance = model(args) 261 | print("model [{}] was created".format("mvs." 
+ args.mvs_model + "_model"))
262 |     return instance
--------------------------------------------------------------------------------
/models/cuda/render_utils.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <vector>
4 | 
5 | // CUDA forward declarations
6 | 
7 | std::vector<torch::Tensor> infer_t_minmax_cuda(
8 |     torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
9 |     const float near, const float far);
10 | 
11 | torch::Tensor infer_n_samples_cuda(torch::Tensor t_min, torch::Tensor t_max, const float stepdist);
12 | 
13 | std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min);
14 | 
15 | std::vector<torch::Tensor> sample_pts_on_rays_cuda(
16 |     torch::Tensor rays_o, torch::Tensor rays_d,
17 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
18 |     const float near, const float far, const float stepdist);
19 | 
20 | std::vector<torch::Tensor> sample_pts_on_rays_geo_cuda(
21 |     torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor pnt_xyz, torch::Tensor tensoRF_per_ray,
22 |     torch::Tensor half_range_sqr, torch::Tensor xyz_min, torch::Tensor xyz_max,
23 |     const float near, const float far, const float stepdist, torch::Tensor local_range, torch::Tensor local_dims, torch::Tensor local_stepsize);
24 | 
25 | std::vector<torch::Tensor> sample_pts_on_rays_dist_cuda(
26 |     torch::Tensor rays_o, torch::Tensor rays_d,
27 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
28 |     const float near, const float far, const float stepdist, torch::Tensor shift);
29 | 
30 | std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
31 |     torch::Tensor rays_o, torch::Tensor rays_d,
32 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
33 |     const int N_samples);
34 | 
35 | torch::Tensor maskcache_lookup_cuda(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift);
36 | 
37 | std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval);
38 | 
39 | torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, const float interval);
40 | 
41 | std::vector<torch::Tensor> raw2alpha_randstep_cuda(torch::Tensor density, const float shift, torch::Tensor interval);
42 | 
43 | torch::Tensor raw2alpha_randstep_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval);
44 | 
45 | std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays);
46 | torch::Tensor filter_ray_by_points_cuda(torch::Tensor xyz, torch::Tensor pnts, torch::Tensor half_range);
47 | torch::Tensor filter_ray_by_projection_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor pnts, torch::Tensor half_range_sqr);
48 | torch::Tensor alpha2weight_backward_cuda(
49 |     torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
50 |     torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
51 |     torch::Tensor grad_weights, torch::Tensor grad_last);
52 | 
53 | // C++ interface
54 | 
55 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
56 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
57 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
58 | 
59 | std::vector<torch::Tensor> infer_t_minmax(
60 |     torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
61 |     const float near, const float far) {
62 |   CHECK_INPUT(rays_o);
63 |   CHECK_INPUT(rays_d);
64 |   CHECK_INPUT(xyz_min);
65 |   CHECK_INPUT(xyz_max);
66 |   return infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
67 | }
68 | 
69 | torch::Tensor infer_n_samples(torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
70 |   CHECK_INPUT(t_min);
71 |   CHECK_INPUT(t_max);
72 |   return infer_n_samples_cuda(t_min, t_max, stepdist);
73 | }
74 | 
75 | std::vector<torch::Tensor> infer_ray_start_dir(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
76 |   CHECK_INPUT(rays_o);
77 |   CHECK_INPUT(rays_d);
78 |   CHECK_INPUT(t_min);
79 |   return infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
80 | }
81 | 
82 | std::vector<torch::Tensor> sample_pts_on_rays(
83 |     torch::Tensor rays_o, torch::Tensor rays_d,
84 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
85 |     const float near, const float far, const float stepdist) {
86 |   CHECK_INPUT(rays_o);
87 |   CHECK_INPUT(rays_d);
88 |   CHECK_INPUT(xyz_min);
89 |   CHECK_INPUT(xyz_max);
90 |   assert(rays_o.dim()==2);
91 |   assert(rays_o.size(1)==3);
92 |   return sample_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far, stepdist);
93 | }
94 | 
95 | std::vector<torch::Tensor> sample_pts_on_rays_geo(
96 |     torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor pnt_xyz, torch::Tensor tensoRF_per_ray,
97 |     torch::Tensor half_range_sqr,
98 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
99 |     const float near, const float far, const float stepdist,
100 |     torch::Tensor local_range, torch::Tensor local_dims, torch::Tensor local_stepsize) {
101 |   CHECK_INPUT(rays_o);
102 |   CHECK_INPUT(rays_d);
103 |   CHECK_INPUT(pnt_xyz);
104 |   CHECK_INPUT(tensoRF_per_ray);
105 |   CHECK_INPUT(xyz_min);
106 |   CHECK_INPUT(xyz_max);
107 |   assert(rays_o.dim()==2);
108 |   assert(rays_o.size(1)==3);
109 |   return sample_pts_on_rays_geo_cuda(rays_o, rays_d, pnt_xyz, tensoRF_per_ray, half_range_sqr, xyz_min, xyz_max, near, far, stepdist, local_range, local_dims, local_stepsize);
110 | }
111 | 
112 | 
113 | std::vector<torch::Tensor> sample_pts_on_rays_dist(
114 |     torch::Tensor rays_o, torch::Tensor rays_d,
115 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
116 |     const float near, const float far, const float stepdist, torch::Tensor shift) {
117 |   CHECK_INPUT(rays_o);
118 |   CHECK_INPUT(rays_d);
119 |   CHECK_INPUT(xyz_min);
120 |   CHECK_INPUT(xyz_max);
121 |   assert(rays_o.dim()==2);
122 |   assert(rays_o.size(1)==3);
123 |   return sample_pts_on_rays_dist_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far, stepdist, shift);
124 | }
125 | 
126 | std::vector<torch::Tensor> sample_ndc_pts_on_rays(
127 |     torch::Tensor rays_o, torch::Tensor rays_d,
128 |     torch::Tensor xyz_min, torch::Tensor xyz_max,
129 |     const int N_samples) {
130 |   CHECK_INPUT(rays_o);
131 |   CHECK_INPUT(rays_d);
132 |   CHECK_INPUT(xyz_min);
133 |   CHECK_INPUT(xyz_max);
134 |   assert(rays_o.dim()==2);
135 |   assert(rays_o.size(1)==3);
136 |   return sample_ndc_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, N_samples);
137 | }
138 | 
139 | torch::Tensor maskcache_lookup(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift) {
140 |   CHECK_INPUT(world);
141 |   CHECK_INPUT(xyz);
142 |   CHECK_INPUT(xyz2ijk_scale);
143 |   CHECK_INPUT(xyz2ijk_shift);
144 |   assert(world.dim()==3);
145 |   assert(xyz.dim()==2);
146 |   assert(xyz.size(1)==3);
147 |   return maskcache_lookup_cuda(world, xyz, xyz2ijk_scale, xyz2ijk_shift);
148 | }
149 | 
150 | std::vector<torch::Tensor> raw2alpha(torch::Tensor density, const float shift, const float interval) {
151 |   CHECK_INPUT(density);
152 |   assert(density.dim()==1);
153 |   return raw2alpha_cuda(density, shift, interval);
154 | }
155 | 
156 | torch::Tensor raw2alpha_backward(torch::Tensor exp, torch::Tensor grad_back, const float interval) {
157 |   CHECK_INPUT(exp);
158 |   CHECK_INPUT(grad_back);
159 |   return raw2alpha_backward_cuda(exp, grad_back, interval);
160 | }
161 | 
162 | std::vector<torch::Tensor> raw2alpha_randstep(torch::Tensor density, const float shift, torch::Tensor interval) {
163 |   CHECK_INPUT(density);
164 |   CHECK_INPUT(interval);
165 |   assert(density.dim()==1);
166 |   return raw2alpha_randstep_cuda(density, shift, interval);
167 | }
168 | 
169 | torch::Tensor raw2alpha_randstep_backward(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval) {
170 |   CHECK_INPUT(exp);
171 |   CHECK_INPUT(interval);
172 |   CHECK_INPUT(grad_back);
173 |   return raw2alpha_randstep_backward_cuda(exp, grad_back, interval);
174 | }
175 | 
176 | 
177 | std::vector<torch::Tensor> alpha2weight(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
178 |   CHECK_INPUT(alpha);
179 |   CHECK_INPUT(ray_id);
180 |   assert(alpha.dim()==1);
181 |   assert(ray_id.dim()==1);
182 |   assert(alpha.sizes()==ray_id.sizes());
183 |   return alpha2weight_cuda(alpha, ray_id, n_rays);
184 | }
185 | 
186 | torch::Tensor alpha2weight_backward(
187 |     torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
188 |     torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
189 |     torch::Tensor grad_weights, torch::Tensor grad_last) {
190 |   CHECK_INPUT(alpha);
191 |   CHECK_INPUT(weight);
192 |   CHECK_INPUT(T);
193 |   CHECK_INPUT(alphainv_last);
194 |   CHECK_INPUT(i_start);
195 |   CHECK_INPUT(i_end);
196 |   CHECK_INPUT(grad_weights);
197 |   CHECK_INPUT(grad_last);
198 |   return alpha2weight_backward_cuda(
199 |       alpha, weight, T, alphainv_last,
200 |       i_start, i_end, n_rays,
201 |       grad_weights, grad_last);
202 | }
203 | 
204 | torch::Tensor filter_ray_by_points(torch::Tensor xyz, torch::Tensor pnts, torch::Tensor half_range) {
205 |   CHECK_INPUT(xyz);
206 |   CHECK_INPUT(pnts);
207 |   assert(xyz.dim()==3);
208 |   assert(pnts.dim()==2);
209 |   return filter_ray_by_points_cuda(xyz, pnts, half_range);
210 | }
211 | 
212 | 
213 | torch::Tensor filter_ray_by_projection(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor pnts, torch::Tensor half_range_sqr) {
214 |   CHECK_INPUT(rays_o);
215 |   CHECK_INPUT(rays_d);
216 |   CHECK_INPUT(pnts);
217 |   assert(rays_o.dim()==2);
218 |   assert(rays_d.dim()==2);
219 |   assert(pnts.dim()==2);
220 |   return filter_ray_by_projection_cuda(rays_o, rays_d, pnts, half_range_sqr);
221 | }
222 | 
223 | 
224 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
225 |   m.def("infer_t_minmax", &infer_t_minmax, "Infer t_min and t_max of ray-bbox intersection");
226 |   m.def("infer_n_samples", &infer_n_samples, "Infer the number of points to sample on each ray");
227 |   m.def("infer_ray_start_dir", &infer_ray_start_dir, "Infer the starting point and shooting direction of each ray");
228 |   m.def("sample_pts_on_rays", &sample_pts_on_rays, "Sample points on rays");
229 |   m.def("sample_pts_on_rays_dist", &sample_pts_on_rays_dist, "Sample points on rays with shading dist");
230 |   m.def("sample_pts_on_rays_geo", &sample_pts_on_rays_geo, "Sample points on rays with point tensoRF");
231 |   m.def("sample_ndc_pts_on_rays", &sample_ndc_pts_on_rays, "Sample points on rays in NDC space");
232 |   m.def("maskcache_lookup", &maskcache_lookup, "Lookup to skip known freespace.");
233 |   m.def("raw2alpha", &raw2alpha, "Raw values [-inf, inf] to alpha [0, 1].");
234 |   m.def("raw2alpha_backward", &raw2alpha_backward, "Backward pass of the raw to alpha"); m.def("raw2alpha_randstep", &raw2alpha_randstep, "Raw values [-inf, inf] to alpha [0, 1] if the step is not the same");
235 |   m.def("raw2alpha_randstep_backward", &raw2alpha_randstep_backward, "Backward pass of the raw to alpha if the step is not the same");
236 |   m.def("alpha2weight", &alpha2weight, "Per-point alpha to accumulated blending weight");
237 |   m.def("alpha2weight_backward", &alpha2weight_backward, "Backward pass of alpha2weight");
238 |   m.def("filter_ray_by_points", &filter_ray_by_points, "Filter rays by nearby points");
239 |   m.def("filter_ray_by_projection", &filter_ray_by_projection, "Filter rays by projected distance to points");
240 | }
241 | 
242 | 
--------------------------------------------------------------------------------
/dataLoader/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch, re
2 | import numpy as np
3 | from torch import searchsorted
4 | from kornia import create_meshgrid
5 | 
6 | 
7 | # from utils import index_point_feature
8 | class SimpleSampler:
9 |     def __init__(self, total, batch):
10 |         self.total = total
11 |         self.batch = batch
12 |         self.curr = total
13 |         self.ids = None
14 | 
15 |     def nextids(self):
16 |         self.curr+=self.batch
17 |         if self.curr + self.batch > self.total:
18 |             self.ids = torch.LongTensor(np.random.permutation(self.total))
19 |             self.curr = 0
20 |         return self.ids[self.curr:self.curr+self.batch]
21 | 
22 | def depth2dist(z_vals, cos_angle):
23 |     # z_vals: [N_ray N_sample]
24 |     device = z_vals.device
25 |     dists = z_vals[..., 1:] - z_vals[..., :-1]
26 |     dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1)  # [N_rays, N_samples]
27 |     dists = dists * cos_angle.unsqueeze(-1)
28 |     return dists
29 | 
30 | 
31 | def ndc2dist(ndc_pts, cos_angle):
32 |     dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
33 |     dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]
34 |     return dists
35 | 
36 | 
37 | def get_ray_directions(H, W, focal, center=None):
38 |     """
39 |     Get ray directions for all pixels in camera coordinate.
40 |     Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
41 |                ray-tracing-generating-camera-rays/standard-coordinate-systems
42 |     Inputs:
43 |         H, W, focal: image height, width and focal length
44 |     Outputs:
45 |         directions: (H, W, 3) ray directions in camera coordinate, returned together with the (i, j) pixel coordinates as an (H, W, 2) tensor
46 |     """
47 |     grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
48 | 
49 |     i, j = grid.unbind(-1)
50 |     # the +0.5 on the grid above shifts the ray origins to pixel centers
51 |     # (see https://github.com/bmild/nerf/issues/24 on pixel centering)
52 |     cent = center if center is not None else [W / 2, H / 2]
53 |     directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)
54 | 
55 |     return directions, torch.stack([i,j], dim=-1)
56 | 
57 | 
58 | def get_ray_directions_blender(H, W, focal, center=None):
59 |     """
60 |     Get ray directions for all pixels in camera coordinate.
61 |     Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
62 |                ray-tracing-generating-camera-rays/standard-coordinate-systems
63 |     Inputs:
64 |         H, W, focal: image height, width and focal length
65 |     Outputs:
66 |         directions: (H, W, 3), the direction of the rays in camera coordinate
67 |     """
68 |     grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
69 |     i, j = grid.unbind(-1)
70 |     # the +0.5 on the grid above shifts the ray origins to pixel centers
71 |     # (see https://github.com/bmild/nerf/issues/24 on pixel centering)
72 |     cent = center if center is not None else [W / 2, H / 2]
73 |     directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
74 |                              -1)  # (H, W, 3)
75 | 
76 |     return directions
77 | 
78 | 
79 | def get_rays(directions, c2w):
80 |     """
81 |     Get ray origins and directions in world coordinate for all pixels in one image.
82 |     Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
83 |                ray-tracing-generating-camera-rays/standard-coordinate-systems
84 |     Inputs:
85 |         directions: (H, W, 3) precomputed ray directions in camera coordinate
86 |         c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
87 |     Outputs:
88 |         rays_o: (H*W, 3), the origin of the rays in world coordinate
89 |         rays_d: (H*W, 3), the direction of the rays in world coordinate (re-normalization is left commented out below)
90 |     """
91 |     # Rotate ray directions from camera coordinate to the world coordinate
92 |     rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
93 |     # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
94 |     # The origin of all rays is the camera origin in world coordinate
95 |     rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)
96 | 
97 |     rays_d = rays_d.view(-1, 3)
98 |     rays_o = rays_o.view(-1, 3)
99 | 
100 |     return rays_o, rays_d
101 | 
102 | 
103 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
104 |     # Shift ray origins to near plane
105 |     t = -(near + rays_o[..., 2]) / rays_d[..., 2]
106 |     rays_o = rays_o + t[..., None] * rays_d
107 | 
108 |     # Projection
109 |     o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
110 |     o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
111 |     o2 = 1. + 2. * near / rays_o[..., 2]
112 | 
113 |     d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
114 |     d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
115 |     d2 = -2. * near / rays_o[..., 2]
116 | 
117 |     rays_o = torch.stack([o0, o1, o2], -1)
118 |     rays_d = torch.stack([d0, d1, d2], -1)
119 | 
120 |     return rays_o, rays_d
121 | 
122 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
123 |     # Shift ray origins to near plane
124 |     t = (near - rays_o[..., 2]) / rays_d[..., 2]
125 |     rays_o = rays_o + t[..., None] * rays_d
126 | 
127 |     # Projection
128 |     o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
129 |     o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
130 |     o2 = 1. - 2. * near / rays_o[..., 2]
131 | 
132 |     d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
133 |     d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
134 |     d2 = 2.
* near / rays_o[..., 2]
135 | 
136 |     rays_o = torch.stack([o0, o1, o2], -1)
137 |     rays_d = torch.stack([d0, d1, d2], -1)
138 | 
139 |     return rays_o, rays_d
140 | 
141 | # Hierarchical sampling (section 5.2 of the NeRF paper)
142 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
143 |     device = weights.device
144 |     # Get pdf
145 |     weights = weights + 1e-5  # prevent nans
146 |     pdf = weights / torch.sum(weights, -1, keepdim=True)
147 |     cdf = torch.cumsum(pdf, -1)
148 |     cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))
149 | 
150 |     # Take uniform samples
151 |     if det:
152 |         u = torch.linspace(0., 1., steps=N_samples, device=device)
153 |         u = u.expand(list(cdf.shape[:-1]) + [N_samples])
154 |     else:
155 |         u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
156 | 
157 |     # Pytest: overwrite u with numpy's fixed random numbers
158 |     if pytest:
159 |         np.random.seed(0)
160 |         new_shape = list(cdf.shape[:-1]) + [N_samples]
161 |         if det:
162 |             u = np.linspace(0., 1., N_samples)
163 |             u = np.broadcast_to(u, new_shape)
164 |         else:
165 |             u = np.random.rand(*new_shape)
166 |         u = torch.Tensor(u)
167 | 
168 |     # Invert CDF
169 |     u = u.contiguous()
170 |     inds = searchsorted(cdf.detach(), u, right=True)
171 |     below = torch.max(torch.zeros_like(inds - 1), inds - 1)
172 |     above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
173 |     inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
174 | 
175 |     matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
176 |     cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
177 |     bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
178 | 
179 |     denom = (cdf_g[..., 1] - cdf_g[..., 0])
180 |     denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
181 |     t = (u - cdf_g[..., 0]) / denom
182 |     samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
183 | 
184 |     return samples
185 | 
186 | 
187 | def dda(rays_o, rays_d, bbox_3D):
188 |     inv_ray_d = 1.0 / (rays_d + 1e-6)
189 |     t_min = (bbox_3D[:1] - rays_o) * inv_ray_d  # N_rays 3
190 |     t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
191 |     t = torch.stack((t_min, t_max))  # 2 N_rays 3
192 |     t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
193 |     t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
194 |     return t_min, t_max
195 | 
196 | 
197 | def ray_marcher(rays,
198 |                 N_samples=64,
199 |                 lindisp=False,
200 |                 perturb=0,
201 |                 bbox_3D=None):
202 |     """
203 |     Sample points along the rays.
204 |     Inputs:
205 |         rays: (N_rays, 8), packed as [rays_o (3), rays_d (3), near (1), far (1)]
206 | 
207 |     Returns:
208 |         xyz_coarse_sampled (N_rays, N_samples, 3), rays_o, rays_d, z_vals
209 |     """
210 | 
211 |     # Decompose the inputs
212 |     N_rays = rays.shape[0]
213 |     rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)
214 |     near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)
215 | 
216 |     if bbox_3D is not None:
217 |         # compute near/far from the AABB bounds
218 |         near, far = dda(rays_o, rays_d, bbox_3D)
219 | 
220 |     # Sample depth points
221 |     z_steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)
222 |     if not lindisp:  # use linear sampling in depth space
223 |         z_vals = near * (1 - z_steps) + far * z_steps
224 |     else:  # use linear sampling in disparity space
225 |         z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
226 | 
227 |     z_vals = z_vals.expand(N_rays, N_samples)
228 | 
229 |     if perturb > 0:  # perturb sampling depths (z_vals)
230 |         z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1) interval mid points
231 |         # get intervals between samples
232 |         upper = torch.cat([z_vals_mid, z_vals[:,
-1:]], -1) 233 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 234 | 235 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 236 | z_vals = lower + (upper - lower) * perturb_rand 237 | 238 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 239 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 240 | 241 | return xyz_coarse_sampled, rays_o, rays_d, z_vals 242 | 243 | 244 | def read_pfm(filename): 245 | file = open(filename, 'rb') 246 | color = None 247 | width = None 248 | height = None 249 | scale = None 250 | endian = None 251 | 252 | header = file.readline().decode('utf-8').rstrip() 253 | if header == 'PF': 254 | color = True 255 | elif header == 'Pf': 256 | color = False 257 | else: 258 | raise Exception('Not a PFM file.') 259 | 260 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 261 | if dim_match: 262 | width, height = map(int, dim_match.groups()) 263 | else: 264 | raise Exception('Malformed PFM header.') 265 | 266 | scale = float(file.readline().rstrip()) 267 | if scale < 0: # little-endian 268 | endian = '<' 269 | scale = -scale 270 | else: 271 | endian = '>' # big-endian 272 | 273 | data = np.fromfile(file, endian + 'f') 274 | shape = (height, width, 3) if color else (height, width) 275 | 276 | data = np.reshape(data, shape) 277 | data = np.flipud(data) 278 | file.close() 279 | return data, scale 280 | 281 | 282 | def ndc_bbox(all_rays): 283 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0] 284 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0] 285 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0] 286 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0] 287 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}') 288 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max))) --------------------------------------------------------------------------------
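To make the `sample_pdf` interface above concrete, a minimal sketch with assumed shapes (all values hypothetical; run inside this module so that `sample_pdf` and its `searchsorted` import are in scope):

    import torch

    # 4 rays, 64 coarse intervals: bins holds the 65 interval edges per ray,
    # weights plays the role of per-interval densities from a coarse pass.
    bins = torch.linspace(2.0, 6.0, steps=65).expand(4, 65)
    weights = torch.rand(4, 64)

    # Draw 32 extra depths per ray, evenly spaced in CDF space (det=True).
    samples = sample_pdf(bins, weights, N_samples=32, det=True)
    print(samples.shape)  # torch.Size([4, 32]); all samples lie within [2.0, 6.0]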