├── models ├── __init__.py └── sh.py ├── colmapUtils ├── __init__.py ├── read_write_fused_vis.py ├── read_write_dense.py └── read_write_model.py ├── start.bat ├── .DS_Store ├── find_filter ├── cppNcuda │ ├── test.cpp │ ├── nvccMy.bat │ ├── test.py │ ├── setup.py │ ├── cal_score.cpp │ ├── cal_every_cuda.cu │ └── cal_every.cpp ├── coordinate_ssf_distance.mat ├── random_mask.py ├── find_best_filters.py ├── GA.py └── GA2.py ├── sampleInput ├── random_(15,14)_60.mat ├── random_(9,19)_120.mat ├── random_(15,14)_206.mat └── random_mask.py ├── filesort_int.py ├── .vscode ├── sftp.json ├── launch.json └── c_cpp_properties.json ├── .gitignore ├── dataLoader ├── __init__.py ├── IDsampler.py ├── spec_synthetic.py ├── blender.py ├── nsvf.py ├── tankstemple.py ├── ray_utils.py ├── colmap2nerf.py └── llff.py ├── configs ├── sofa_eigen.txt └── xjhdesk.txt ├── README.md ├── split_allimg2filterfixed_classify.py ├── extra ├── compute_metrics.py └── auto_run_paramsets.py ├── opt.py ├── utils.py ├── renderer.py └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /colmapUtils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /start.bat: -------------------------------------------------------------------------------- 1 | python train.py --config configs\sofa_eigen.txt 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CPREgroup/SpecNeRF-v2/HEAD/.DS_Store -------------------------------------------------------------------------------- /find_filter/cppNcuda/test.cpp: -------------------------------------------------------------------------------- 1 | #include 
2 | #include 3 | 4 | void main(){ 5 | 6 | } 7 | -------------------------------------------------------------------------------- /sampleInput/random_(15,14)_60.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CPREgroup/SpecNeRF-v2/HEAD/sampleInput/random_(15,14)_60.mat -------------------------------------------------------------------------------- /sampleInput/random_(9,19)_120.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CPREgroup/SpecNeRF-v2/HEAD/sampleInput/random_(9,19)_120.mat -------------------------------------------------------------------------------- /sampleInput/random_(15,14)_206.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CPREgroup/SpecNeRF-v2/HEAD/sampleInput/random_(15,14)_206.mat -------------------------------------------------------------------------------- /find_filter/coordinate_ssf_distance.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CPREgroup/SpecNeRF-v2/HEAD/find_filter/coordinate_ssf_distance.mat -------------------------------------------------------------------------------- /filesort_int.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | def sort_file_int(path, ext): 6 | imgpaths = sorted(glob.glob(os.path.join(path)), 7 | key=lambda x: int(x.split('_')[-1][:-1 - len(ext)])) 8 | return imgpaths 9 | -------------------------------------------------------------------------------- /.vscode/sftp.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "3090Server", 3 | "host": "10.22.149.17", 4 | "protocol": "sftp", 5 | "port": 22, 6 | "username": "ljb", 7 | "remotePath": "/", 8 | "uploadOnSave": false, 9 | "useTempFile": 
false, 10 | "openSsh": false 11 | } 12 | -------------------------------------------------------------------------------- /find_filter/cppNcuda/nvccMy.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | nvcc -I "H:\virtualenv\torch111_cuda113\Scripts","C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\include","C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\lib\x64","C:\ProgramData\NVIDIA Corporation\CUDA Samples\v11.3\common\inc" -o cal_every_cuda.exe cal_every_cuda.cu -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | *.pyc 3 | log/ 4 | logexp/ 5 | logexp_ssfRepresent/ 6 | data/ 7 | myspecdata/ 8 | multi-view-MSI/ 9 | find_filter/cppNcuda/dist 10 | find_filter/cppNcuda/build 11 | find_filter/cppNcuda/include 12 | find_filter/cppNcuda/gascore.egg-info 13 | find_filter/find_filter_res/ 14 | find_filter/find_filter_res 15 | find_filter_res/ 16 | log_filter/ 17 | start 18 | myspecdata 19 | -------------------------------------------------------------------------------- /sampleInput/random_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | 4 | 5 | num = 120 6 | shape = (9, 19) 7 | 8 | 9 | total = shape[1] * shape[0] 10 | 11 | 12 | maskv = np.array([0] * (total - num) + [1] * num) 13 | np.random.shuffle(maskv) 14 | mask = maskv.reshape(shape) 15 | 16 | print(mask) 17 | print(np.sum(mask)) 18 | sio.savemat(f'./sampleInput/random_{str(shape).replace(" ", "")}_{num}.mat', {'mask': mask}) 19 | -------------------------------------------------------------------------------- /find_filter/random_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | 4 | 5 | num = 15*14 - 4 6 | shape 
= (15, 14) 7 | 8 | 9 | total = shape[1] * shape[0] 10 | 11 | 12 | maskv = np.array([0] * (total - num) + [1] * num) 13 | np.random.shuffle(maskv) 14 | mask = maskv.reshape(shape) 15 | 16 | print(mask) 17 | print(np.sum(mask)) 18 | sio.savemat(f'./find_filter/find_filter_res/random_{str(shape).replace(" ", "")}_{num}.mat', {'mask': mask}) 19 | -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .llff import LLFFDataset 2 | from .blender import BlenderDataset 3 | from .nsvf import NSVF 4 | from .tankstemple import TanksTempleDataset 5 | from .spec_synthetic import FAKEDataset 6 | # from .spec_llff import SPECLLFFDataset 7 | 8 | 9 | dataset_dict = {'blender': BlenderDataset, 10 | 'llff':LLFFDataset, 11 | 'synthetic': FAKEDataset, 12 | 'tankstemple':TanksTempleDataset, 13 | 'nsvf':NSVF} -------------------------------------------------------------------------------- /find_filter/cppNcuda/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import numpy_test as nt 4 | import gascore 5 | 6 | # bd = 20000 7 | 8 | # a = np.ones([bd, bd], dtype=np.float32) * 2 9 | # b = np.ones_like(a) * 2 10 | 11 | # s1 = time.time() 12 | # c = nt.test(a, b) 13 | # print(time.time() - s1) 14 | # print(c) 15 | 16 | # s2 = time.time() 17 | # d = nt.test_cuda(a, b) 18 | # print(time.time() - s2) 19 | # print(d) 20 | 21 | 22 | a = [np.random.random([10, 2]) for _ in range(20)] 23 | 24 | gascore.cal_score(a) 25 | 26 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "args": [ 15 | "--config", 16 | "configs/xjhdesk.txt" 17 | ] 18 | } 19 | ] 20 | } -------------------------------------------------------------------------------- /dataLoader/IDsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | 5 | class IDDataset(Dataset): 6 | def __init__(self, total) -> None: 7 | super().__init__() 8 | 9 | self.total = total 10 | 11 | def __getitem__(self, index: Any) -> Any: 12 | return index 13 | 14 | def __len__(self): 15 | return self.total 16 | 17 | def get_simple_sampler(total, batch): 18 | id_dataset = IDDataset(total) 19 | sampler = DataLoader(id_dataset, batch, shuffle=True) 20 | 21 | return sampler 22 | 23 | if __name__ == '__main__': 24 | sampler = get_simple_sampler(100, 10) 25 | 26 | outlist = [] 27 | for i in sampler: 28 | outlist.append(i) 29 | pass 30 | 31 | outlist = [] 32 | for i in range(200): 33 | outp = next(iter(sampler)) 34 | print(i) 35 | if i < 10: 36 | outlist.append(outp) 37 | pass 38 | -------------------------------------------------------------------------------- /find_filter/cppNcuda/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension 3 | 4 | # setup( 5 | # name='numpy_test', 6 | # version='2.8', 7 | # author='lee', 8 | # description='numpy_test', 9 | # long_description='numpy_test', 10 | # ext_modules=[ 11 | # CUDAExtension( 12 | # name='numpy_test', 13 | # include_dirs=['./include', "C:\\ProgramData\\NVIDIA 
Corporation\\CUDA Samples\\v11.3\\common\\inc"], 14 | # sources=['cal_every.cpp', 'cal_every_cuda.cu'] 15 | # ) 16 | # ], 17 | # cmdclass={ 18 | # 'build_ext': BuildExtension 19 | # } 20 | # ) 21 | 22 | setup( 23 | name='gascore', 24 | version='3.1', 25 | author='lee', 26 | description='gascore', 27 | long_description='gascore', 28 | ext_modules=[ 29 | CppExtension( 30 | name='gascore', 31 | include_dirs=['E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/pybind11/include', 32 | 'E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/**', 33 | 'E:/Program Files (x86)/Python37/include'], 34 | sources=['cal_score.cpp'] 35 | ) 36 | ], 37 | cmdclass={ 38 | 'build_ext': BuildExtension 39 | } 40 | ) 41 | 42 | 43 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Win32", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/pybind11/include", 8 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/torch/csrc/api/include/torch/**", 9 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/torch/csrc/api/include/torch/detail", 10 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/torch/csrc/api/include", 11 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include", 12 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/**", 13 | "E:/virtualenv_py37/torch1_11_cuda_11_3/Lib/site-packages/torch/include/torch/csrc/utils", 14 | "E:/Program Files (x86)/Python37/include" 15 | ], 16 | "defines": [ 17 | "_DEBUG", 18 | "UNICODE", 19 | "_UNICODE" 20 | ], 21 | "compilerPath": "E:/Program Files (x86)/mingw/mingw64/bin/gcc.exe", 22 | "cStandard": "c11", 23 | "cppStandard": "c++17", 24 | "intelliSenseMode": 
"windows-gcc-x64" 25 | } 26 | ], 27 | "version": 4 28 | } -------------------------------------------------------------------------------- /configs/sofa_eigen.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = synthetic 3 | datadir = ./multi-view-MSI/filterGate/sofa_eigen 4 | basedir = ./logexp 5 | filters_folder = filters_eigen 6 | expname = rand60_sofa_nmf5-test 7 | angles = 15 8 | filters = 14 9 | img_ext = tiff 10 | img_dir_name = pose_blackbg_??.mat 11 | sample_matrix_dir = ./sampleInput/random_(15,14)_60.mat # random_(15,14)_206.mat 12 | 13 | # white_bkgd 14 | ndc_ray = 1 15 | ssf_model = dcp_nmf # rbf # neuRBF # gt # fcn 16 | rgb4shape_endIter = 0 17 | reset_para 18 | # depth_supervise 19 | # depth_batchsize_endIter = [512, 300] 20 | # distortion_loss 21 | # lsc 22 | spec_channel = 15 23 | observation_channel = 3 24 | band_start_idx = 0 25 | colIdx4RGBTrain = 0 26 | downsample_train = 1 27 | crop_hw = [400, 400] 28 | 29 | n_iters = 500 30 | batch_size = 8192 31 | chunk_size = 8192 32 | 33 | N_voxel_init = 2097156 # 128**3 34 | N_voxel_final = 134217728 # 512**3 35 | upsamp_list = [2000, 4000, 8000, 12000] # [2000,3000,4000,5500,7000] 36 | update_AlphaMask_list = [3000, 7000] # [2000,4000] 37 | 38 | N_vis = -1 39 | vis_every = 500 40 | lr_upsample_reset = 1 41 | 42 | export_mesh = 0 43 | render_only = 0 44 | # ckpt = logexp/rand206_sofaDummyRBFGateFilter_nmf5/rand206_sofaDummyRBFGateFilter_nmf5.th 45 | render_test = 1 46 | render_train = 0 47 | render_path = 1 48 | 49 | n_lamb_sigma = [10,7,7] 50 | n_lamb_sh = [36,18,18] 51 | 52 | shadingMode = MLP_Fea 53 | fea2denseAct = relu 54 | featureC = 80 55 | 56 | view_pe = 2 57 | fea_pe = 2 58 | 59 | L1_weight_inital = 8e-5 60 | L1_weight_rest = 4e-5 61 | rm_weight_mask_thre = 1e-4 62 | TV_weight_density = 0.8 63 | TV_weight_app = 0.3 64 | TV_weight_spec = 0.04 65 | -------------------------------------------------------------------------------- 
/configs/xjhdesk.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = llff 3 | datadir = ./multi-view-MSI/filter19/xjhdesk 4 | basedir = ./logexp 5 | filters_folder = filters 6 | expname = rand120_xjhdesk_RBF4 7 | angles = 9 8 | filters = 19 9 | img_ext = tiff 10 | img_dir_name = pose??img 11 | sample_matrix_dir = ./sampleInput/random_(9,19)_120.mat 12 | 13 | ndc_ray = 1 14 | ssf_model = rbf # fcn 15 | ssf_model_components = 4 16 | rgb4shape_endIter = 2000 17 | reset_para 18 | # depth_supervise 19 | # depth_batchsize_endIter = [512, 300] 20 | # distortion_loss 21 | # lsc 22 | spec_channel = 31 23 | observation_channel = 1 24 | band_start_idx = 5 25 | colIdx4RGBTrain = 0 26 | downsample_train = 2 27 | crop_hw = [1000, 1500] 28 | 29 | 30 | n_iters = 25000 31 | batch_size = 4096 32 | chunk_size = 4096 33 | 34 | N_voxel_init = 2097156 # 128**3 35 | N_voxel_final = 134217728 # 512**3 36 | upsamp_list = [4000, 6000, 9000, 12000, 15000] # [2000,3000,4000,5500,7000] 37 | update_AlphaMask_list = [5000, 8000] # [2000,4000] 38 | 39 | N_vis = -1 40 | vis_every = 500 41 | 42 | export_mesh = 0 43 | render_only = 0 44 | # ckpt = logexp/all_xjhdesk_RBF4_fea128/24999_all_xjhdesk_RBF4_fea128.pth 45 | render_test = 1 46 | render_train = 0 47 | render_path = 1 48 | # render_test_exhibition = 1 49 | # exhibition_filters_path = myspecdata/filters19_optimized/xjhdesk/exhibition/filtersets.mat 50 | # exhibition_ssfs_path = myspecdata/filters19_optimized/xjhdesk/exhibition/ssfs.mat 51 | # exhibition_lights_path = myspecdata/filters19_optimized/xjhdesk/exhibition/lightsSPD.mat 52 | # exhibition_lightorigin_path = myspecdata/filters19_optimized/xjhdesk/lightspec.mat 53 | 54 | n_lamb_sigma = [10,7,7] 55 | n_lamb_sh = [36,18,18] 56 | 57 | shadingMode = MLP_Fea 58 | fea2denseAct = relu 59 | featureC = 128 60 | 61 | view_pe = 2 62 | fea_pe = 2 63 | 64 | L1_weight_inital = 8e-5 65 | L1_weight_rest = 4e-5 66 | rm_weight_mask_thre = 1e-4 67 | 
TV_weight_density = 0.8 68 | TV_weight_app = 0.5 69 | TV_weight_spec = 0.015 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spec-NeRF: Multi-Spectral Neural Radiance Fields 2 | [Jiabao Li](https://github.com/TIMESTICKING), [Yuqi Li*](https://github.com/kylin-leo), Ciliang Sun, Chong Wang, and Jinhui Xiang 3 | 4 | [Arxiv](https://arxiv.org/abs/2310.12987). 5 | 6 | ## Intro 7 | 8 | Spec-NeRF jointly optimizes the degradation parameters and achieves high-quality multi-spectral image reconstruction results at novel views, which only requires a low-cost camera (like a phone camera but in RAW mode) and several off-the-shelf color filters. We also provide real scenarios and synthetic datasets for related studies. 9 | 10 | ## Video Demo 11 | With recovered spectral information of the scenario, we can achieve several art effects, like 12 | 13 | ### Change the filters 14 | 15 | 16 | https://github.com/CPREgroup/SpecNeRF-v2/assets/56912131/eae98987-3521-4f76-a756-9f6fe0115039 17 | 18 | 19 | ### Change the camera's SSF 20 | 21 | 22 | https://github.com/CPREgroup/SpecNeRF-v2/assets/56912131/3d270945-bbe5-4281-9a81-eccc0f09bd00 23 | 24 | 25 | 26 | ### Change the ambient light source spectrum 27 | 28 | 29 | https://github.com/CPREgroup/SpecNeRF-v2/assets/56912131/b230b331-8f53-46ac-8d8e-73232ed6f3a1 30 | 31 | 32 | 33 | ## Preliminaries 34 | 35 | We conduct our experiments based on [TensoRF](https://apchenstu.github.io/TensoRF/), **please use the branch named `public` in our repository** and feel free to report issues, we'd really appreciate it! 36 | 37 | 38 | 39 | #### Tested on Ubuntu 18 / Windows 11 + Pytorch 1.11 + cuda 11.3 40 | 41 | 42 | 43 | ## Dataset 44 | Download the two types of datasets (real senario and synthetic one) from [google drive](https://drive.google.com/file/d/1wBux0JdsimjoDJfBvOuo8lJWi9ekk3cD/view?usp=drivesdk). 
45 | 46 | For the real dataset, please first separate the images into pose-based groups since all raw images are in the folder `RAW` by running `python split_allimg2filterfixed_classify.py --scene_dir /multi-view-MSI/filter19/xjhdesk --filter_num 20 --legacy 0 --angle_num 9 --img_ext tiff` as administrator or root. 47 | 48 | This will generate nine folders named *pose?img*, each contains the images filtered by all color filters, which are the symbolic links point to the ones in `RAW` folder, so remember **do not** delete the `RAW` folder. 49 | 50 | 51 | ## Quick Start 52 | 53 | Check what's the config file in start/start.bat file and execute `./start` | `start.bat` 54 | 55 | or try 56 | 57 | `python train.py --config ./configs/.txt` 58 | 59 | 60 | -------------------------------------------------------------------------------- /find_filter/cppNcuda/cal_score.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace py=pybind11; 6 | typedef py::array_t arrf64; 7 | typedef py::array_t arrint; 8 | typedef py::detail::unchecked_reference unc; 9 | typedef py::detail::unchecked_reference unc1; 10 | 11 | double cal_idv(arrint &idv, const unc &dis_mtx_r, 12 | const unc &ssf_coor_r, const unc &cosSimi_r, 13 | const unc1 &ssf_trans_r, const double weight){ 14 | auto idv_r = idv.unchecked<2>(); 15 | double score = 0; 16 | double transsum = 0; 17 | int number = idv.shape(0); 18 | 19 | 20 | for (py::ssize_t i = 1; i < number; i++) 21 | { 22 | int r1 = idv_r(i, 0); 23 | int c1 = idv_r(i, 1); 24 | for (int j = 0; j < i; j++) 25 | { 26 | int r2 = idv_r(j, 0); 27 | int c2 = idv_r(j, 1); 28 | // std::cout << r1 << "," << c1<<","<(); 52 | auto ssf_coor_r = ssf_coor.unchecked<2>(); 53 | auto cosSimi_r = cosSimi.unchecked<2>(); 54 | auto ssf_trans_r = ssf_trans.unchecked<1>(); 55 | 56 | try 57 | { 58 | auto res = arrf64(py::len(all)); 59 | auto res_r = res.mutable_unchecked<1>(); 60 | for (auto idv_t : 
all){ 61 | arrint idv = py::reinterpret_borrow(idv_t); 62 | // py::print(idv); 63 | res_r(idx++) = cal_idv(idv, dis_mtx_r, ssf_coor_r, cosSimi_r, ssf_trans_r, weight); 64 | } 65 | return res; 66 | } 67 | catch(const std::exception& e) 68 | { 69 | std::cerr << e.what() << '\n'; 70 | } 71 | 72 | } 73 | 74 | 75 | PYBIND11_MODULE(gascore, m) { 76 | m.def("cal_score", &cal_score, "cal_score"); 77 | } 78 | 79 | -------------------------------------------------------------------------------- /find_filter/cppNcuda/cal_every_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } 7 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) 8 | { 9 | if (code != cudaSuccess) 10 | { 11 | // std::cout << "error" << std::endl; 12 | std::cout << stderr << "GPUassert: " << cudaGetErrorString(code)<>>(x, y, z, bd); 56 | // sync 57 | // cudaDeviceSynchronize(); 58 | // check 59 | // std::cout << "item " << z[2] << std::endl; 60 | // py::print(z[2]); 61 | // copy 62 | gpuErrchk(cudaMemcpy(arr3_buf, z, nBytes, cudaMemcpyDeviceToHost)); 63 | 64 | cudaFree(x); 65 | cudaFree(y); 66 | cudaFree(z); 67 | } 68 | 69 | void initarr(float *arr, int bd, float val){ 70 | for (int i = 0; i < bd * bd; i++) 71 | { 72 | arr[i] = val; 73 | } 74 | 75 | } 76 | int main(){ 77 | int bd = 10; 78 | float a[100] = {2.0}; initarr(a, bd, 2.0); 79 | float b[100] = {1.0}; initarr(b, bd, 1.0); 80 | float c[100] = {0.0}; initarr(c, bd, 0.0); 81 | 82 | std::cout << b[3] << std::endl; 83 | 84 | test_cuda_cu(a, b, c, bd); 85 | 86 | for(float i : c){ 87 | std::cout << i; 88 | } 89 | 90 | 91 | return 0; 92 | } 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /split_allimg2filterfixed_classify.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import shutil 5 | 6 | from filesort_int import sort_file_int 7 | 8 | jpgmode = False 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--scene_dir', type=str, default='data\spec_data\\test', help='scene directory') 12 | parser.add_argument('--filter_dir', type=str, default='../filters/', help='filters directory') 13 | parser.add_argument('--filter_num', type=int, default=21, help='filter number') 14 | parser.add_argument('--legacy', type=int, help='splitting based on filters (all view points in one filter folder)') 15 | parser.add_argument('--angle_num', type=int, default=16, help='shooting angle number') 16 | parser.add_argument('--img_ext', type=str, default='jpg' if jpgmode else 'dng', help='ext. of img files in scene') 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | print('='*5, 'img files should be the format of xxx_234.ext') 22 | 23 | 24 | allimg_dir = os.path.join(args.scene_dir, 'jpegs/' if jpgmode else 'RAW/') 25 | filter_num = args.filter_num 26 | angle_num = args.angle_num 27 | legacy = args.legacy 28 | 29 | imgpaths = sort_file_int(f'{allimg_dir}/*.{args.img_ext}', args.img_ext) 30 | assert filter_num * angle_num == len(imgpaths), 'numbers dont match!' 
31 | 32 | def movefile(imgsname, scene_dir): 33 | for im in imgsname: 34 | # shutil.move(im, pose_dir) 35 | _, imname = os.path.split(im) 36 | print(f'{scene_dir}/{imname}') 37 | os.symlink(im, f'{scene_dir}/{imname}') 38 | 39 | 40 | def filter_base(): 41 | for i in range(0, filter_num): 42 | imgs = imgpaths[i::filter_num] 43 | print('='*5, i) 44 | print(imgs) 45 | 46 | filter_dir = os.path.join(args.scene_dir, f'filter{i}img_jpegs/images/' if jpgmode else f'filter{i}img/images/') 47 | if not os.path.exists(filter_dir): 48 | os.makedirs(filter_dir) 49 | 50 | movefile(imgs, filter_dir) 51 | 52 | shutil.copy(os.path.join(args.scene_dir, args.filter_dir, f'./f_{i}.mat'), os.path.join(filter_dir, '../')) 53 | print('done!') 54 | 55 | if jpgmode: 56 | break 57 | 58 | def pose_base(): 59 | for i in range(0, angle_num): 60 | imgs = imgpaths[i * filter_num: (i+1) * filter_num] 61 | print('=' * 5, i) 62 | print(imgs) 63 | 64 | pose_dir = os.path.join(args.scene_dir, f'pose{i}img/images/') 65 | if not os.path.exists(pose_dir): 66 | os.makedirs(pose_dir) 67 | 68 | movefile(imgs, pose_dir) 69 | 70 | print('done!') 71 | 72 | 73 | if __name__ == '__main__': 74 | if legacy: 75 | filter_base() 76 | else: 77 | pose_base() 78 | 79 | -------------------------------------------------------------------------------- /find_filter/find_best_filters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from filesort_int import * 7 | import torch 8 | from torch import nn 9 | import scipy.io as sio 10 | 11 | device = 'cpu' # torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | def read_filters(args): 15 | files = sort_file_int(f'{args.filter_dir}/*.mat', 'mat') 16 | filters = [] 17 | for f in files: 18 | filters.append( 19 | np.diagonal(sio.loadmat(f)['filter']) 20 | ) 21 | 22 | filters = torch.FloatTensor(filters).unsqueeze(0).repeat([args.angles, 
1, 1]) 23 | 24 | return filters.permute([2, 0, 1]) 25 | 26 | 27 | 28 | class FindNet(torch.nn.Module): 29 | def __init__(self, args): 30 | super(FindNet, self).__init__() 31 | 32 | self.args = args 33 | self.cols = args.angles 34 | self.S = read_filters(args).to(device) 35 | self.filt_num = self.S.shape[-1] 36 | self.M = nn.Parameter(torch.randn(self.S.shape[1:])) 37 | pass 38 | 39 | 40 | def mySigmoid(self, x): 41 | return 1 / (1+ torch.exp(-20*(x-(0.5 + self.args.eps)))) 42 | 43 | 44 | def forward(self,x): 45 | M = self.mySigmoid(self.M) 46 | Sm = self.S * M # 31, 12, 20 47 | 48 | # temp_mul = torch.ones([31, self.filt_num]).to(device) 49 | # for i in range(self.cols): 50 | # temp_mul = temp_mul * Sm[:, i, :] 51 | 52 | Sm_pre, Sm_last = Sm[:, :-1, :], Sm[:, -1, :].unsqueeze(1) 53 | Sm_shift = torch.cat([Sm_last, Sm_pre], dim=1) 54 | temp_mul = Sm * Sm_shift 55 | 56 | 57 | L = torch.sum(temp_mul) 58 | return L, M 59 | 60 | 61 | 62 | def maind(args): 63 | mynet = FindNet(args).to(device) 64 | optimizer = torch.optim.Adam(mynet.parameters(), lr=0.001) 65 | 66 | allloss = [] 67 | 68 | for ep in range(args.epoch): 69 | L, M = mynet(0) 70 | 71 | U, S, Vh = torch.linalg.svd(M, full_matrices=True) 72 | highrank_loss = 1 / (1e-4 * torch.sum(S)) * 0.02 73 | 74 | sparse_loss = torch.linalg.norm(M,ord=1,dim=(0,1)) 75 | 76 | loss = L + highrank_loss + sparse_loss 77 | allloss.append(loss.item()) 78 | 79 | optimizer.zero_grad() 80 | loss.backward() 81 | optimizer.step() 82 | 83 | if ep % 20 == 0: 84 | print(f'=={ep}: ', 'lossall', loss.item(), 'L', L.item(), 85 | 'sparse_loss', sparse_loss.item(), 'highrank_loss', highrank_loss.item()) 86 | 87 | M = M.cpu().detach().numpy() 88 | M = np.where(M>0.5, 1, 0) 89 | print(np.sum(M)) 90 | sio.savemat(f'find_filter_res/res_2by2_3.mat', {'mask': M}) 91 | plt.imshow(M) 92 | plt.show() 93 | plt.plot(allloss) 94 | plt.show() 95 | 96 | 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | 
parser.add_argument('--filter_dir', type=str, default='myspecdata/filter20_no1/filters') 102 | parser.add_argument('--angles', type=int, default=12) 103 | parser.add_argument('--epoch', type=int, default=20000) 104 | parser.add_argument('--eps', type=float, default=0.5, help='the larger it is, the less ones') 105 | args = parser.parse_args() 106 | 107 | maind(args) 108 | 109 | -------------------------------------------------------------------------------- /find_filter/cppNcuda/cal_every.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "utils.h" 3 | 4 | namespace py=pybind11; 5 | 6 | 7 | py::array_t test(py::array_t arr1, py::array_t arr2){ 8 | py::buffer_info arr1buf = arr1.request(); 9 | py::buffer_info arr2buf = arr2.request(); 10 | // auto dir_arr1 = arr1.unchecked<2>(); 11 | // auto dir_arr2 = arr2.unchecked<2>(); 12 | 13 | auto result = py::array_t(arr1buf.size); 14 | result.resize(arr1buf.shape); 15 | // auto dir_res = result.mutable_unchecked<2>(); 16 | py::buffer_info res1 = result.request(); 17 | 18 | // for (py::size_t i = 0; i < result.shape(0); i++) 19 | // { 20 | // for (py::size_t j = 0; j < result.shape(1); j++) 21 | // { 22 | // dir_res(i, j) = dir_arr1(i, j) + dir_arr2(i, j); 23 | // } 24 | 25 | // } 26 | 27 | int ti = 1; 28 | for(int i : arr1buf.shape){ 29 | ti *= i; 30 | } 31 | 32 | float *ptr1 = static_cast(arr1buf.ptr); 33 | float *ptr2 = static_cast(arr2buf.ptr); 34 | float *ptr3 = static_cast(res1.ptr); 35 | 36 | for (size_t i = 0; i < ti; i++) 37 | { 38 | ptr3[i] = ptr1[i] + ptr2[i]; 39 | } 40 | 41 | 42 | return result; 43 | 44 | } 45 | 46 | 47 | py::array_t test_cuda(py::array_t arr1, py::array_t arr2){ 48 | py::buffer_info arr1_buf = arr1.request(), arr2_buf = arr2.request(); 49 | 50 | auto result = py::array_t(arr1.size()); 51 | result.resize(arr1_buf.shape); 52 | py::buffer_info res_buf = result.request(); 53 | 54 | // std::cout << "in cpp" << std::endl; 55 | 56 | 
test_cuda_cu((float*)arr1_buf.ptr,(float*)arr2_buf.ptr,(float*)res_buf.ptr, arr1_buf.shape[0]); 57 | 58 | return result; 59 | } 60 | 61 | // void initarr(py::buffer_info *arr, int bd){ 62 | // float *ptr = static_cast(arr->ptr); 63 | // for (py::size_t i = 0; i < bd * bd; i++) 64 | // { 65 | // ptr[i] = 2.0; 66 | // } 67 | 68 | // } 69 | 70 | // int main(){ 71 | // std::cout << "begin!" << std::endl; 72 | 73 | // int bd = 100; 74 | // py::array_t arr1 = py::array_t(bd *bd); 75 | // py::array_t arr2 = py::array_t(bd *bd); 76 | // std::cout << "begin1!" << std::endl; 77 | // py::buffer_info arr1_buf = arr1.request(), arr2_buf = arr2.request(); 78 | // arr1.resize({bd, bd}); 79 | // arr2.resize({bd, bd}); 80 | // std::cout << "begin2!" << std::endl; 81 | // initarr(&arr1_buf, bd); 82 | // initarr(&arr2_buf, bd); 83 | // std::cout << "checkarr1!" << std::endl; 84 | // std::cout << arr1.index_at(2, 2) << std::endl; 85 | 86 | // // auto result = py::array_t(arr1.size()); 87 | // // result.resize(arr1_buf.shape); 88 | // // py::buffer_info res_buf = result.request(); 89 | 90 | // // test_cuda_cu(&arr1_buf,&arr2_buf,&res_buf); 91 | // // auto res = test(arr1, arr2); 92 | // // std::cout << res.index_at(2, 2) << std::endl; 93 | 94 | // return 0; 95 | // } 96 | 97 | 98 | 99 | 100 | PYBIND11_MODULE(numpy_test, m) { 101 | m.def("test", &test, "test", py::return_value_policy::reference); 102 | m.def("test_cuda", &test_cuda, "test_cuda", py::return_value_policy::reference); 103 | } 104 | 105 | 106 | -------------------------------------------------------------------------------- /dataLoader/spec_synthetic.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | from dataLoader.llff import LLFFDataset, normalize, center_poses, get_spiral, get_ray_directions_blender, average_poses,ndc_rays_blender, get_rays 7 | from opt import args 8 | from PIL import 
Image 9 | from torchvision import transforms as T 10 | import scipy.io as sio 11 | 12 | 13 | class FAKEDataset(LLFFDataset): 14 | 15 | def __init__(self, datadir, split='train', downsample=1, is_stack=False, hold_every=8): 16 | super().__init__(datadir, split, downsample, is_stack, hold_every) 17 | 18 | 19 | def load_img(self): 20 | poses_img = [Path(args.datadir) / args.img_dir_name.replace('??', str(i)) 21 | for i in range(args.angles)] 22 | sample_matrix = self._fix_sample_matrix() 23 | 24 | W, H = self.img_wh 25 | # use first N_images-1 to train, the LAST is val 26 | all_rays = [] 27 | all_rgbs = [] 28 | all_poses = [] 29 | all_filtersIdx = [] 30 | ids4shapeTrain = [] 31 | tensor_cropper = T.CenterCrop(args.crop_hw) 32 | tensor_resizer = T.Resize([H, W], antialias=True) 33 | for r, row in enumerate(sample_matrix): 34 | images_degraded = sio.loadmat(poses_img[r])['all_degraded'] 35 | for c, aimEle in enumerate(row): 36 | if aimEle != 1: 37 | continue 38 | elif c == args.colIdx4RGBTrain: 39 | # it's for geometry training 40 | ids4shapeTrain.append(len(all_rays)) 41 | 42 | if images_degraded.ndim == 4: 43 | img = images_degraded[:, :, :, c] # for rgb image 44 | else: 45 | img = images_degraded[:, :, c:c+1] # for monochroma image 46 | img = torch.FloatTensor(img).permute(2, 0, 1) 47 | c2w = torch.FloatTensor(self.poses[r]) 48 | 49 | img = tensor_cropper(img) # c croph cropw [0-1] 50 | if self.downsample != 1.0: 51 | img = tensor_resizer(img) 52 | 53 | img = img.reshape(args.observation_channel, -1).permute(1, 0) # (h*w, 3) RGB 54 | all_rgbs.append(img) 55 | 56 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 57 | if args.ndc_ray == 1: 58 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 59 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 60 | all_rays.append(torch.cat([rays_o, rays_d], 1)) # (h*w, 6) 61 | all_poses.append(torch.LongTensor([[r]]).expand([rays_o.shape[0], -1])) 62 | 
all_filtersIdx.append(torch.LongTensor([[c]]).expand([rays_o.shape[0], -1])) 63 | 64 | print(f'{len(all_rgbs)} of images are loaded!') 65 | 66 | self.ids4shapeTrain = ids4shapeTrain 67 | self.raysnum_oneimage = all_rays[0].shape[0] 68 | if not self.is_stack: 69 | self.all_rays = torch.cat(all_rays, 0) # (len(self.meta['frames])*h*w, 3) 70 | self.all_rgbs = torch.cat(all_rgbs, 0) # (len(self.meta['frames])*h*w,3) 71 | self.all_poses = torch.cat(all_poses, 0) 72 | self.all_filtersIdx = torch.cat(all_filtersIdx, 0) 73 | else: 74 | self.all_rays = torch.stack(all_rays, 0) # (len(self.meta['frames]),h*w, 3) 75 | self.all_rgbs = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], args.observation_channel) # (len(self.meta['frames]),h,w,3) 76 | self.all_poses = torch.stack(all_poses, 0) 77 | self.all_filtersIdx = torch.stack(all_filtersIdx, 0) 78 | 79 | -------------------------------------------------------------------------------- /dataLoader/blender.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class BlenderDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.img_wh = (int(800/downsample),int(800/downsample)) 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [2.0,6.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | 
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = self.img_wh 45 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 46 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh 47 | 48 | 49 | # ray directions for all pixels, same for all images (same H, W, focal) 50 | self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3) 51 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 52 | self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float() 53 | 54 | self.image_paths = [] 55 | self.poses = [] 56 | self.all_rays = [] 57 | self.all_rgbs = [] 58 | self.all_masks = [] 59 | self.all_depth = [] 60 | self.downsample=1.0 61 | 62 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 63 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 64 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 65 | 66 | frame = self.meta['frames'][i] 67 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 68 | c2w = torch.FloatTensor(pose) 69 | self.poses += [c2w] 70 | 71 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 72 | self.image_paths += [image_path] 73 | img = Image.open(image_path) 74 | 75 | if self.downsample!=1.0: 76 | img = img.resize(self.img_wh, Image.LANCZOS) 77 | img = self.transform(img) # (4, h, w) 78 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 79 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 80 | self.all_rgbs += [img] 
81 | 82 | 83 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 84 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 85 | 86 | 87 | self.poses = torch.stack(self.poses) 88 | if not self.is_stack: 89 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 90 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 91 | 92 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 93 | else: 94 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 95 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 96 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 97 | 98 | 99 | def define_transforms(self): 100 | self.transform = T.ToTensor() 101 | 102 | def define_proj_mat(self): 103 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 104 | 105 | def world2ndc(self,points,lindisp=None): 106 | device = points.device 107 | return (points - self.center.to(device)) / self.radius.to(device) 108 | 109 | def __len__(self): 110 | return len(self.all_rgbs) 111 | 112 | def __getitem__(self, idx): 113 | 114 | if self.split == 'train': # use data in the buffers 115 | sample = {'rays': self.all_rays[idx], 116 | 'rgbs': self.all_rgbs[idx]} 117 | 118 | else: # create data for each image separately 119 | 120 | img = self.all_rgbs[idx] 121 | rays = self.all_rays[idx] 122 | mask = self.all_masks[idx] # for quantity evaluation 123 | 124 | sample = {'rays': rays, 125 | 'rgbs': img, 126 | 'mask': mask} 127 | return sample 128 | -------------------------------------------------------------------------------- /colmapUtils/read_write_fused_vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

import os
import collections
import numpy as np
import pandas as pd
from pyntcloud import PyntCloud

try:
    # Package-relative import: works when loaded as
    # ``colmapUtils.read_write_fused_vis``.
    from .read_write_model import read_next_bytes, write_next_bytes
except ImportError:
    # Fallback for the original COLMAP usage, i.e. running this file as a
    # standalone script next to read_write_model.py.
    from read_write_model import read_next_bytes, write_next_bytes


# One fused point: position/colour/normal taken from the .ply file plus the
# list of images that see it, taken from the companion .ply.vis file.
MeshPoint = collections.namedtuple(
    "MeshingPoint", ["position", "color", "normal", "num_visible_images", "visible_image_idxs"])


def read_fused(path_to_fused_ply, path_to_fused_ply_vis):
    """Read a COLMAP fused point cloud together with its visibility file.

    see: src/mvs/meshing.cc
        void ReadDenseReconstruction(const std::string& path

    Args:
        path_to_fused_ply: path to the fused ``.ply`` file; it must expose the
            vertex properties x/y/z, nx/ny/nz and red/green/blue.
        path_to_fused_ply_vis: path to the binary ``.ply.vis`` file.

    Returns:
        list of MeshPoint, one entry per point listed in the ``.vis`` file.

    Raises:
        AssertionError: if either path is not an existing file.
    """
    assert os.path.isfile(path_to_fused_ply)
    assert os.path.isfile(path_to_fused_ply_vis)

    point_cloud = PyntCloud.from_file(path_to_fused_ply)
    xyz_arr = point_cloud.points.loc[:, ["x", "y", "z"]].to_numpy()
    normal_arr = point_cloud.points.loc[:, ["nx", "ny", "nz"]].to_numpy()
    color_arr = point_cloud.points.loc[:, ["red", "green", "blue"]].to_numpy()

    mesh_points = []
    with open(path_to_fused_ply_vis, "rb") as fid:
        # Layout: an 8-byte point count (struct "Q"), then for every point a
        # 4-byte count (struct "I") followed by that many 4-byte image indices.
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for i in range(num_points):
            num_visible_images = read_next_bytes(fid, 4, "I")[0]
            visible_image_idxs = read_next_bytes(
                fid, num_bytes=4*num_visible_images,
                format_char_sequence="I"*num_visible_images)
            visible_image_idxs = np.array(tuple(map(int, visible_image_idxs)))
            # Assumes the i-th .vis record belongs to the i-th .ply vertex.
            mesh_points.append(MeshPoint(
                position=xyz_arr[i],
                color=color_arr[i],
                normal=normal_arr[i],
                num_visible_images=num_visible_images,
                visible_image_idxs=visible_image_idxs))
    return mesh_points


def write_fused_ply(mesh_points, path_to_fused_ply):
    """Write position, normal and colour of ``mesh_points`` to a ``.ply`` file.

    Visibility information is not stored here; it goes to the ``.ply.vis``
    file written by write_fused_ply_vis.
    """
    columns = ["x", "y", "z", "nx", "ny", "nz", "red", "green", "blue"]
    points_data_frame = pd.DataFrame(
        np.zeros((len(mesh_points), len(columns))),
        columns=columns)

    positions = np.asarray([point.position for point in mesh_points])
    normals = np.asarray([point.normal for point in mesh_points])
    colors = np.asarray([point.color for point in mesh_points])

    points_data_frame.loc[:, ["x", "y", "z"]] = positions
    points_data_frame.loc[:, ["nx", "ny", "nz"]] = normals
    points_data_frame.loc[:, ["red", "green", "blue"]] = colors

    # The frame was created from float64 zeros; cast each column back to the
    # dtype of the array it came from.
    points_data_frame = points_data_frame.astype({
        "x": positions.dtype, "y": positions.dtype, "z": positions.dtype,
        "red": colors.dtype, "green": colors.dtype, "blue": colors.dtype,
        "nx": normals.dtype, "ny": normals.dtype, "nz": normals.dtype})

    point_cloud = PyntCloud(points_data_frame)
    point_cloud.to_file(path_to_fused_ply)


def write_fused_ply_vis(mesh_points, path_to_fused_ply_vis):
    """Write the binary visibility (``.ply.vis``) file for ``mesh_points``.

    see: src/mvs/fusion.cc
        void WritePointsVisibility(const std::string& path,
                                   const std::vector<std::vector<int>>& points_visibility)
    """
    with open(path_to_fused_ply_vis, "wb") as fid:
        write_next_bytes(fid, len(mesh_points), "Q")
        for point in mesh_points:
            write_next_bytes(fid, point.num_visible_images, "I")
            format_char_sequence = "I"*point.num_visible_images
            write_next_bytes(fid, [*point.visible_image_idxs], format_char_sequence)


def write_fused(points, path_to_fused_ply, path_to_fused_ply_vis):
    """Write both the ``.ply`` point cloud and its ``.ply.vis`` visibility file."""
    write_fused_ply(points, path_to_fused_ply)
    write_fused_ply_vis(points, path_to_fused_ply_vis)

# --------------------------------------------------------------------------------
# /models/sh.py:
# --------------------------------------------------------------------------------
import torch

################## sh function ##################
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]

def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    :param deg: int SH max degree. Currently, 0-4 supported
    :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
    :param dirs: torch.Tensor unit directions (..., 3)
    :return: (..., C)
    """
    assert deg <= 4 and deg >= 0
    assert (deg + 1) ** 2 == sh.shape[-1]

    result = C0 * sh[..., 0]
    if deg > 0:
        # Slicing with 0:1 etc. keeps a trailing axis of size 1, so each
        # direction component broadcasts against the channel axis of sh[..., k].
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                        C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                        C3[1] * xy * z * sh[..., 10] +
                        C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                        C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                        C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                        C3[5] * z * (xx - yy) * sh[..., 14] +
                        C3[6] * x * (xx - 3 * yy) * sh[..., 15])
                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

def eval_sh_bases(deg, dirs):
    """
    Evaluate spherical harmonics bases at unit directions,
    without taking linear combination.
    At each point, the final result may then be
    obtained through simple multiplication.
    :param deg: int SH max degree. Currently, 0-4 supported
    :param dirs: torch.Tensor (..., 3) unit directions
    :return: torch.Tensor (..., (deg+1) ** 2)
    """
    assert deg <= 4 and deg >= 0
    # Every one of the (deg + 1) ** 2 slots is written below, so an
    # uninitialised buffer is safe.
    result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
    result[..., 0] = C0
    if deg > 0:
        x, y, z = dirs.unbind(-1)
        result[..., 1] = -C1 * y
        result[..., 2] = C1 * z
        result[..., 3] = -C1 * x
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result[..., 4] = C2[0] * xy
            result[..., 5] = C2[1] * yz
            result[..., 6] = C2[2] * (2.0 * zz - xx - yy)
            result[..., 7] = C2[3] * xz
            result[..., 8] = C2[4] * (xx - yy)

            if deg > 2:
                result[..., 9] = C3[0] * y * (3 * xx - yy)
                result[..., 10] = C3[1] * xy * z
                result[..., 11] = C3[2] * y * (4 * zz - xx - yy)
                result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
                result[..., 13] = C3[4] * x * (4 * zz - xx - yy)
                result[..., 14] = C3[5] * z * (xx - yy)
                result[..., 15] = C3[6] * x * (xx - 3 * yy)

                if deg > 3:
                    result[..., 16] = C4[0] * xy * (xx - yy)
                    result[..., 17] = C4[1] * yz * (3 * xx - yy)
                    result[..., 18] = C4[2] * xy * (7 * zz - 1)
                    result[..., 19] = C4[3] * yz * (7 * zz - 3)
                    result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3)
                    result[..., 21] = C4[5] * xz * (7 * zz - 3)
                    result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1)
                    result[..., 23] = C4[7] * xz * (xx - 3 * yy)
                    result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
    return result

# --------------------------------------------------------------------------------
# /colmapUtils/read_write_dense.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python

# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 24 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 25 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 26 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 27 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 28 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 29 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 30 | # POSSIBILITY OF SUCH DAMAGE. 31 | # 32 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 33 | 34 | import argparse 35 | import numpy as np 36 | import os 37 | import struct 38 | 39 | 40 | def read_array(path): 41 | with open(path, "rb") as fid: 42 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 43 | usecols=(0, 1, 2), dtype=int) 44 | fid.seek(0) 45 | num_delimiter = 0 46 | byte = fid.read(1) 47 | while True: 48 | if byte == b"&": 49 | num_delimiter += 1 50 | if num_delimiter >= 3: 51 | break 52 | byte = fid.read(1) 53 | array = np.fromfile(fid, np.float32) 54 | array = array.reshape((width, height, channels), order="F") 55 | return np.transpose(array, (1, 0, 2)).squeeze() 56 | 57 | 58 | def write_array(array, path): 59 | """ 60 | see: src/mvs/mat.h 61 | void Mat::Write(const std::string& path) 62 | """ 63 | assert array.dtype == np.float32 64 | if len(array.shape) == 2: 65 | height, width = array.shape 66 | channels = 1 67 | elif len(array.shape) == 3: 68 | height, width, channels = array.shape 69 | else: 70 | assert False 71 | 72 | with open(path, "w") as fid: 73 | fid.write(str(width) + "&" + str(height) + "&" + str(channels) + "&") 74 | 75 | with open(path, "ab") as fid: 76 | if len(array.shape) == 2: 77 | array_trans = np.transpose(array, (1, 0)) 78 | elif len(array.shape) == 3: 79 | array_trans = np.transpose(array, (1, 0, 2)) 80 | else: 81 | assert False 82 | data_1d = array_trans.reshape(-1, 
order="F") 83 | data_list = data_1d.tolist() 84 | endian_character = "<" 85 | format_char_sequence = "".join(["f"] * len(data_list)) 86 | byte_data = struct.pack(endian_character + format_char_sequence, *data_list) 87 | fid.write(byte_data) 88 | 89 | 90 | def parse_args(): 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("-d", "--depth_map", 93 | help="path to depth map", type=str, required=True) 94 | parser.add_argument("-n", "--normal_map", 95 | help="path to normal map", type=str, required=True) 96 | parser.add_argument("--min_depth_percentile", 97 | help="minimum visualization depth percentile", 98 | type=float, default=5) 99 | parser.add_argument("--max_depth_percentile", 100 | help="maximum visualization depth percentile", 101 | type=float, default=95) 102 | args = parser.parse_args() 103 | return args 104 | 105 | 106 | def main(): 107 | args = parse_args() 108 | 109 | if args.min_depth_percentile > args.max_depth_percentile: 110 | raise ValueError("min_depth_percentile should be less than or equal " 111 | "to the max_depth_perceintile.") 112 | 113 | # Read depth and normal maps corresponding to the same image. 114 | if not os.path.exists(args.depth_map): 115 | raise FileNotFoundError("File not found: {}".format(args.depth_map)) 116 | 117 | if not os.path.exists(args.normal_map): 118 | raise FileNotFoundError("File not found: {}".format(args.normal_map)) 119 | 120 | depth_map = read_array(args.depth_map) 121 | normal_map = read_array(args.normal_map) 122 | 123 | min_depth, max_depth = np.percentile( 124 | depth_map, [args.min_depth_percentile, args.max_depth_percentile]) 125 | depth_map[depth_map < min_depth] = min_depth 126 | depth_map[depth_map > max_depth] = max_depth 127 | 128 | import pylab as plt 129 | 130 | # Visualize the depth map. 131 | plt.figure() 132 | plt.imshow(depth_map) 133 | plt.title("depth map") 134 | 135 | # Visualize the normal map. 
136 | plt.figure() 137 | plt.imshow(normal_map) 138 | plt.title("normal map") 139 | 140 | plt.show() 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /find_filter/GA.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import scipy.io as sio 4 | import filesort_int 5 | 6 | 7 | def log(i, a, b): 8 | print("epoch --> ", 9 | str(i + 1).rjust(5, " "), " max:", 10 | str(round(a, 4)).rjust(8, " "), "mean:", 11 | str(round(b, 4)).rjust(8, " "), "alpha:", 12 | str(round(a / b, 4)).rjust(8, " ")) 13 | 14 | 15 | class GeneSolve: 16 | ## 初始定义,后续引用GeneSolve类方法为(初始种群数,最大迭代数,交叉概率,变异概率,最大适应度/平均适应度(扰动率趋于平稳则越接近1越好)) 17 | def __init__(self, pose_path, ssf_path, pop_size, epoch, cross_prob, mutate_prob, alpha, poses, aim_poses, print_batch=10): 18 | self.aim_poses = aim_poses 19 | self.ssf_path = ssf_path 20 | self.pose_path = pose_path 21 | self.aim_number = aim_poses[0] * aim_poses[1] 22 | self.pop_size = pop_size 23 | self.epoch = epoch 24 | self.cross_prob = cross_prob 25 | self.mutate_prob = mutate_prob 26 | self.print_batch = print_batch 27 | self.alpha = alpha 28 | self.poses = poses 29 | self.width = poses[0] * poses[1] 30 | self.best = None 31 | self.coor = self.prepare_poses() # np N x 3 32 | self.fai = self.prepare_ssfs() # 12 x 20 x 31 33 | self.dis_mtx = self.cal_distance() # np N x N 34 | 35 | # 产生初始种群 36 | genes = [] 37 | for _ in range(self.pop_size): 38 | tts = [] 39 | for _ in range(self.poses[0]): 40 | tt = ['0']*(self.poses[1]-self.aim_poses[1]) + ['1']*self.aim_poses[1] 41 | np.random.shuffle(tt) 42 | tts.append(tt) 43 | genes.append(np.array(tts)) 44 | self.genes = np.array(genes) 45 | 46 | def prepare_ssfs(self): 47 | fs = filesort_int.sort_file_int(self.ssf_path, 'mat') 48 | filters_list = np.array([ 49 | np.asarray(np.diagonal(sio.loadmat(x)['filter']), dtype=np.float32) for x in fs 50 | 
]) 51 | fai = filters_list[np.newaxis, ...].repeat(self.poses[0], axis=0) 52 | 53 | return fai 54 | 55 | def prepare_poses(self): 56 | coor = [] 57 | poses_bounds = np.load(self.pose_path) 58 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 59 | for i in range(poses.shape[0]): 60 | coor.append(poses[i, :, 3]) 61 | 62 | return coor 63 | 64 | def cal_distance(self): 65 | n = len(self.coor) 66 | dis_mtx = np.zeros((n, n)) 67 | cal_dis = lambda a, b: np.sqrt(np.sum((self.coor[a]-self.coor[b]) ** 2)) 68 | for r in range(n): 69 | for c in range(r+1, n): 70 | dis_mtx[r, c] = cal_dis(r, c) 71 | 72 | dis_mtx += dis_mtx.T 73 | return dis_mtx 74 | 75 | 76 | def inter_cross(self): 77 | """对染色体进行交叉操作""" 78 | ready_index = list(range(self.pop_size)) 79 | while len(ready_index) >= 2: 80 | d1 = random.choice(ready_index) 81 | ready_index.remove(d1) 82 | d2 = random.choice(ready_index) 83 | ready_index.remove(d2) 84 | if np.random.uniform(0, 1) <= self.cross_prob: 85 | loc = random.choice(range(1, self.poses[1] - 1)) 86 | temp = self.genes[d1][:, loc:] 87 | self.genes[d1][:, loc:] = self.genes[d2][:, loc:] 88 | self.genes[d2][:, loc:] = temp 89 | 90 | 91 | def mutate(self): 92 | """基因突变""" 93 | ready_index = list(range(self.pop_size)) 94 | for i in ready_index: 95 | t0 = self.genes[i] 96 | if np.random.uniform(0, 1) <= self.mutate_prob: 97 | loc = random.choice(range(0, self.width)) 98 | r, c = loc // self.poses[1], loc % self.poses[1] 99 | t0[r, c] = str(1 - int(t0[r, c])) 100 | # fix the total number to ** 101 | for rr in range(self.poses[0]): 102 | t0[rr, :] = self.fix_total_number(list(t0[rr, :])) 103 | self.genes[i] = t0 104 | 105 | 106 | def fix_total_number(self, genes): 107 | ones_num = genes.count('1') 108 | genes = np.array(genes) 109 | if ones_num != self.aim_poses[1]: 110 | to_what = '1' if ones_num < self.aim_poses[1] else '0' 111 | 112 | tt = list(np.argwhere(genes == '1').squeeze()) 113 | ones_idx = random.sample(tt, abs(ones_num - 
self.aim_poses[1])) 114 | genes[ones_idx] = to_what 115 | 116 | return genes 117 | 118 | 119 | def get_adjust(self): 120 | """计算适应度(只有在计算适应度的时候要反函数,其余过程全都是随机的二进制编码)""" 121 | x = self.get_decode() 122 | return x * np.sin(x) + 12 123 | 124 | def get_decode(self): 125 | """编码,从表现型到基因型的映射""" 126 | aimssfs = [ 127 | self.fai[np.where(idx == '0', False, True)] for idx in self.genes 128 | ] 129 | 130 | return aimssfs 131 | 132 | 133 | def cycle_select(self): 134 | """通过轮盘赌来进行选择""" 135 | adjusts = self.get_adjust() 136 | if self.best is None or np.max(adjusts) > self.best[1]: 137 | self.best = self.genes[np.argmax(adjusts)], np.max(adjusts) 138 | p = adjusts / np.sum(adjusts) 139 | cu_p = [] 140 | for i in range(self.pop_size): 141 | cu_p.append(np.sum(p[0:i])) 142 | cu_p = np.array(cu_p) 143 | r0 = np.random.uniform(0, 1, self.pop_size) 144 | sel = [max(list(np.where(r > cu_p)[0]) + [0]) for r in r0] 145 | # 保留最优的个体 146 | if np.max(adjusts[sel]) < self.best[1]: 147 | self.genes[sel[np.argmin(adjusts[sel])]] = self.best[0] 148 | self.genes = self.genes[sel] 149 | 150 | def evolve(self): 151 | """逐代演化""" 152 | for i in range(self.epoch): 153 | self.cycle_select() #种群选取 154 | self.inter_cross() #染色体交叉 155 | self.mutate() #计算适应度 156 | a, b = np.max(self.get_adjust()), np.mean(self.get_adjust()) 157 | if i % self.print_batch == self.print_batch - 1 or i == 0: 158 | log(i, a, b) 159 | if a / b < self.alpha: 160 | log(i, a, b) 161 | print("进化终止,算法已收敛!共进化 ", i + 1, " 代!") 162 | break 163 | 164 | 165 | if __name__ == '__main__': 166 | gs = GeneSolve('../myspecdata/filter20_no1/newDoraemon/poses_bounds.npy', 167 | '../myspecdata/filter20_no1/filters/*.mat', 168 | 100, 500, 0.65, 0.1, 1.2, (12, 20), (10, 3)) 169 | gs.evolve() 170 | 171 | -------------------------------------------------------------------------------- /dataLoader/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from 
tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class NSVF(Dataset): 37 | """NSVF Generic Dataset.""" 38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False): 39 | self.root_dir = datadir 40 | self.split = split 41 | self.is_stack = is_stack 42 | self.downsample = downsample 43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 44 | self.define_transforms() 45 | 46 | self.white_bg = True 47 | self.near_far = [0.5,6.0] 48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3) 49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 50 | self.read_meta() 51 | self.define_proj_mat() 52 | 53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 55 | 56 | def bbox2corners(self): 57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 58 | for i in range(3): 59 | corners[i,[0,1],i] = corners[i,[1,0],i] 60 | return corners.view(-1,3) 61 | 62 | 63 | def read_meta(self): 64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f: 65 | focal = 
float(f.readline().split()[0]) 66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]]) 67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1) 68 | 69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 71 | 72 | if self.split == 'train': 73 | pose_files = [x for x in pose_files if x.startswith('0_')] 74 | img_files = [x for x in img_files if x.startswith('0_')] 75 | elif self.split == 'val': 76 | pose_files = [x for x in pose_files if x.startswith('1_')] 77 | img_files = [x for x in img_files if x.startswith('1_')] 78 | elif self.split == 'test': 79 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 80 | test_img_files = [x for x in img_files if x.startswith('2_')] 81 | if len(test_pose_files) == 0: 82 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 83 | test_img_files = [x for x in img_files if x.startswith('1_')] 84 | pose_files = test_pose_files 85 | img_files = test_img_files 86 | 87 | # ray directions for all pixels, same for all images (same H, W, focal) 88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 90 | 91 | 92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 93 | 94 | self.poses = [] 95 | self.all_rays = [] 96 | self.all_rgbs = [] 97 | 98 | assert len(img_files) == len(pose_files) 99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 101 | img = Image.open(image_path) 102 | if self.downsample!=1.0: 103 | img = img.resize(self.img_wh, Image.LANCZOS) 104 | img = self.transform(img) # (4, 
h, w) 105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 106 | if img.shape[-1]==4: 107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 108 | self.all_rgbs += [img] 109 | 110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv 111 | c2w = torch.FloatTensor(c2w) 112 | self.poses.append(c2w) # C2W 113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 115 | 116 | # w2c = torch.inverse(c2w) 117 | # 118 | 119 | self.poses = torch.stack(self.poses) 120 | if 'train' == self.split: 121 | if self.is_stack: 122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 124 | else: 125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 127 | else: 128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 130 | 131 | 132 | def define_transforms(self): 133 | self.transform = T.ToTensor() 134 | 135 | def define_proj_mat(self): 136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 137 | 138 | def world2ndc(self, points): 139 | device = points.device 140 | return (points - self.center.to(device)) / self.radius.to(device) 141 | 142 | def __len__(self): 143 | if self.split == 'train': 144 | return len(self.all_rays) 145 | return len(self.all_rgbs) 146 | 147 | def __getitem__(self, idx): 148 | 149 | if self.split == 'train': # use data in the buffers 150 | sample = {'rays': self.all_rays[idx], 151 | 'rgbs': self.all_rgbs[idx]} 
152 | 153 | else: # create data for each image separately 154 | 155 | img = self.all_rgbs[idx] 156 | rays = self.all_rays[idx] 157 | 158 | sample = {'rays': rays, 159 | 'rgbs': img} 160 | return sample -------------------------------------------------------------------------------- /extra/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import numpy as np 3 | import scipy.signal 4 | from typing import List, Optional 5 | from PIL import Image 6 | import os 7 | import torch 8 | import configargparse 9 | 10 | __LPIPS__ = {} 11 | def init_lpips(net_name, device): 12 | assert net_name in ['alex', 'vgg'] 13 | import lpips 14 | print(f'init_lpips: lpips_{net_name}') 15 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 16 | 17 | def rgb_lpips(np_gt, np_im, net_name, device): 18 | if net_name not in __LPIPS__: 19 | __LPIPS__[net_name] = init_lpips(net_name, device) 20 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 21 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 22 | return __LPIPS__[net_name](gt, im, normalize=True).item() 23 | 24 | 25 | def findItem(items, target): 26 | for one in items: 27 | if one[:len(target)]==target: 28 | return one 29 | return None 30 | 31 | 32 | ''' Evaluation metrics (ssim, lpips) 33 | ''' 34 | def rgb_ssim(img0, img1, max_val, 35 | filter_size=11, 36 | filter_sigma=1.5, 37 | k1=0.01, 38 | k2=0.03, 39 | return_map=False): 40 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 41 | assert len(img0.shape) == 3 42 | # assert img0.shape[-1] == 3 43 | assert img0.shape == img1.shape 44 | 45 | # Construct a 1D Gaussian blur filter. 
46 | hw = filter_size // 2 47 | shift = (2 * hw - filter_size + 1) / 2 48 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 49 | filt = np.exp(-0.5 * f_i) 50 | filt /= np.sum(filt) 51 | 52 | # Blur in x and y (faster than the 2D convolution). 53 | def convolve2d(z, f): 54 | return scipy.signal.convolve2d(z, f, mode='valid') 55 | 56 | filt_fn = lambda z: np.stack([ 57 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 58 | for i in range(z.shape[-1])], -1) 59 | mu0 = filt_fn(img0) 60 | mu1 = filt_fn(img1) 61 | mu00 = mu0 * mu0 62 | mu11 = mu1 * mu1 63 | mu01 = mu0 * mu1 64 | sigma00 = filt_fn(img0**2) - mu00 65 | sigma11 = filt_fn(img1**2) - mu11 66 | sigma01 = filt_fn(img0 * img1) - mu01 67 | 68 | # Clip the variances and covariances to valid values. 69 | # Variance must be non-negative: 70 | sigma00 = np.maximum(0., sigma00) 71 | sigma11 = np.maximum(0., sigma11) 72 | sigma01 = np.sign(sigma01) * np.minimum( 73 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 74 | c1 = (k1 * max_val)**2 75 | c2 = (k2 * max_val)**2 76 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 77 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 78 | ssim_map = numer / denom 79 | ssim = np.mean(ssim_map) 80 | return ssim_map if return_map else ssim 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | parser = configargparse.ArgumentParser() 86 | parser.add_argument("--exp", type=str, help="folder of exps") 87 | parser.add_argument("--paramStr", type=str, help="str of params") 88 | args = parser.parse_args() 89 | 90 | 91 | # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']# 92 | # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic" 93 | # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp 94 | 95 | # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']# 96 | # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/" 97 | # expFolder = 
"/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp 98 | paramStr = args.paramStr 99 | fileNum = 200 100 | 101 | 102 | expitems = os.listdir(expFolder) 103 | finalFolder = f'{expFolder}/finals/{paramStr}' 104 | outFile = f'{finalFolder}/{paramStr}_metrics.txt' 105 | os.makedirs(finalFolder, exist_ok=True) 106 | 107 | expitems.sort(reverse=True) 108 | 109 | 110 | with open(outFile, 'w') as f: 111 | all_psnr = [] 112 | all_ssim = [] 113 | all_alex = [] 114 | all_vgg = [] 115 | for dataname in datanames: 116 | 117 | 118 | gtstr = gtFolder+"/"+dataname+"/test/r_%d.png" 119 | expname = findItem(expitems, f'{paramStr}-{dataname}') 120 | print("expname: ", expname) 121 | if expname is None: 122 | print("no ",dataname, "exists") 123 | continue 124 | resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png" 125 | metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt' 126 | video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4' 127 | 128 | exist_metric=False 129 | if os.path.isfile(metric_file): 130 | metrics = np.loadtxt(metric_file) 131 | print(metrics, metrics.tolist()) 132 | if metrics.size == 4: 133 | psnr, ssim, l_a, l_v = metrics.tolist() 134 | exist_metric = True 135 | os.system(f"cp {video_file} {finalFolder}/") 136 | 137 | if not exist_metric: 138 | psnrs = [] 139 | ssims = [] 140 | l_alex = [] 141 | l_vgg = [] 142 | for i in range(fileNum): 143 | gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0 144 | gtmask = gt[...,[3]] 145 | gt = gt[...,:3] 146 | gt = gt*gtmask + (1-gtmask) 147 | img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0 148 | # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max()) 149 | 150 | 151 | psnr = -10. 
* np.log10(np.mean(np.square(img - gt))) 152 | ssim = rgb_ssim(img, gt, 1) 153 | lpips_alex = rgb_lpips(gt, img, 'alex','cuda') 154 | lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda') 155 | 156 | print(i, psnr, ssim, lpips_alex, lpips_vgg) 157 | psnrs.append(psnr) 158 | ssims.append(ssim) 159 | l_alex.append(lpips_alex) 160 | l_vgg.append(lpips_vgg) 161 | psnr = np.mean(np.array(psnrs)) 162 | ssim = np.mean(np.array(ssims)) 163 | l_a = np.mean(np.array(l_alex)) 164 | l_v = np.mean(np.array(l_vgg)) 165 | 166 | rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 167 | print(rS) 168 | f.write(rS+"\n") 169 | 170 | all_psnr.append(psnr) 171 | all_ssim.append(ssim) 172 | all_alex.append(l_a) 173 | all_vgg.append(l_v) 174 | 175 | psnr = np.mean(np.array(all_psnr)) 176 | ssim = np.mean(np.array(all_ssim)) 177 | l_a = np.mean(np.array(all_alex)) 178 | l_v = np.mean(np.array(all_vgg)) 179 | 180 | rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 181 | print(rS) 182 | f.write(rS+"\n") -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def config_parser(cmd=None): 4 | parser = configargparse.ArgumentParser() 5 | parser.add_argument('--config', is_config_file=True, 6 | help='config file path') 7 | parser.add_argument("--expname", type=str, 8 | help='experiment name') 9 | parser.add_argument("--basedir", type=str, default='./log', 10 | help='where to store ckpts and logs') 11 | parser.add_argument("--add_timestamp", type=int, default=0, 12 | help='add timestamp to dir') 13 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 14 | help='input data directory') 15 | parser.add_argument("--progress_refresh_rate", type=int, default=10, 16 | help='how many iterations to show psnrs or iters') 17 | 18 | parser.add_argument('--with_depth', action='store_true') 19 | 
parser.add_argument('--downsample_train', type=float, default=1.0) 20 | parser.add_argument('--downsample_test', type=float, default=1.0) 21 | 22 | parser.add_argument('--model_name', type=str, default='TensorVMSplit', 23 | choices=['TensorVMSplit', 'TensorCP']) 24 | 25 | # loader options 26 | parser.add_argument("--batch_size", type=int, default=4096) 27 | parser.add_argument("--chunk_size", type=int, default=2048, help='the size of putting into the gpu') 28 | parser.add_argument("--n_iters", type=int, default=30000) 29 | 30 | parser.add_argument('--dataset_name', type=str, default='blender', 31 | choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data', 'synthetic']) 32 | 33 | 34 | # training options 35 | # learning rate 36 | parser.add_argument("--lr_init", type=float, default=0.02, 37 | help='learning rate') 38 | parser.add_argument("--lr_basis", type=float, default=1e-3, 39 | help='learning rate') 40 | parser.add_argument("--lr_decay_iters", type=int, default=-1, 41 | help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters') 42 | parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1, 43 | help='the target decay ratio; after decay_iters inital lr decays to lr*ratio') 44 | parser.add_argument("--lr_upsample_reset", type=int, default=1, 45 | help='reset lr to inital after upsampling') 46 | 47 | # loss 48 | parser.add_argument("--L1_weight_inital", type=float, default=0.0, 49 | help='loss weight') 50 | parser.add_argument("--L1_weight_rest", type=float, default=0, 51 | help='loss weight') 52 | parser.add_argument("--Ortho_weight", type=float, default=0.0, 53 | help='loss weight') 54 | parser.add_argument("--TV_weight_density", type=float, default=0.0, 55 | help='loss weight') 56 | parser.add_argument("--TV_weight_app", type=float, default=0.0, 57 | help='loss weight') 58 | 59 | # model 60 | # volume options 61 | parser.add_argument("--n_lamb_sigma", type=int, action="append") 62 | 
parser.add_argument("--n_lamb_sh", type=int, action="append") 63 | parser.add_argument("--data_dim_color", type=int, default=27) 64 | 65 | parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001, 66 | help='mask points in ray marching') 67 | parser.add_argument("--alpha_mask_thre", type=float, default=0.0001, 68 | help='threshold for creating alpha mask volume') 69 | parser.add_argument("--distance_scale", type=float, default=25, 70 | help='scaling sampling distance for computation') 71 | parser.add_argument("--density_shift", type=float, default=-10, 72 | help='shift density in softplus; making density = 0 when feature == 0') 73 | 74 | # network decoder 75 | parser.add_argument("--shadingMode", type=str, default="MLP_PE", 76 | help='which shading mode to use') 77 | parser.add_argument("--pos_pe", type=int, default=6, 78 | help='number of pe for pos') 79 | parser.add_argument("--view_pe", type=int, default=6, 80 | help='number of pe for view') 81 | parser.add_argument("--fea_pe", type=int, default=6, 82 | help='number of pe for features') 83 | parser.add_argument("--featureC", type=int, default=128, 84 | help='hidden feature channel in MLP') 85 | 86 | 87 | 88 | parser.add_argument("--ckpt", type=str, default=None, 89 | help='specific weights npy file to reload for coarse network') 90 | parser.add_argument("--render_only", type=int, default=0) 91 | parser.add_argument("--render_test", type=int, default=0) 92 | parser.add_argument("--render_train", type=int, default=0) 93 | parser.add_argument("--render_path", type=int, default=0) 94 | parser.add_argument("--export_mesh", type=int, default=0) 95 | 96 | # rendering options 97 | parser.add_argument('--lindisp', default=False, action="store_true", 98 | help='use disparity depth sampling') 99 | parser.add_argument("--perturb", type=float, default=1., 100 | help='set to 0. for no jitter, 1. 
for jitter') 101 | parser.add_argument("--accumulate_decay", type=float, default=0.998) 102 | parser.add_argument("--fea2denseAct", type=str, default='softplus') 103 | parser.add_argument('--ndc_ray', type=int, default=0) 104 | parser.add_argument('--nSamples', type=int, default=1e6, 105 | help='sample point each ray, pass 1e6 if automatic adjust') 106 | parser.add_argument('--step_ratio',type=float,default=0.5) 107 | 108 | 109 | ## blender flags 110 | parser.add_argument("--white_bkgd", action='store_true', 111 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 112 | 113 | 114 | 115 | parser.add_argument('--N_voxel_init', 116 | type=int, 117 | default=100**3) 118 | parser.add_argument('--N_voxel_final', 119 | type=int, 120 | default=300**3) 121 | parser.add_argument("--upsamp_list", type=int, action="append") 122 | parser.add_argument("--update_AlphaMask_list", type=int, action="append") 123 | 124 | parser.add_argument('--idx_view', 125 | type=int, 126 | default=0) 127 | # logging/saving options 128 | parser.add_argument("--N_vis", type=int, default=5, 129 | help='N images to vis') 130 | parser.add_argument("--vis_every", type=int, default=10000, 131 | help='frequency of visualize the image') 132 | 133 | parser.add_argument('--angles',type=int,default=10) 134 | parser.add_argument('--filters',type=int,default=20) 135 | parser.add_argument('--img_dir_name',type=str,default='pose??img') 136 | parser.add_argument('--spec_channel', type=int, default=30) 137 | parser.add_argument('--observation_channel', type=int, default=3) 138 | parser.add_argument('--depth_supervise', action='store_true') 139 | parser.add_argument('--reset_para', action='store_true') 140 | parser.add_argument('--rgb4shape_endIter', type=int,default=0) 141 | parser.add_argument('--depth_batchsize_endIter', type=int, action="append", default=[0, 0]) 142 | parser.add_argument('--sample_matrix_dir',type=str,default='null') 143 | 
parser.add_argument('--img_ext',type=str,default='dng') 144 | parser.add_argument('--ssf_model',type=str,default='fcn', help='fcn or rbf') 145 | parser.add_argument('--ssf_model_components',type=int,default=3) 146 | parser.add_argument('--distortion_loss', action='store_true') 147 | parser.add_argument('--band_start_idx',type=int,default=3) 148 | parser.add_argument('--colIdx4RGBTrain', help='col index in fixed sample matrix',type=int,default=0) 149 | parser.add_argument('--lsc', action='store_true') 150 | parser.add_argument("--crop_hw", type=int, action="append") 151 | parser.add_argument("--TV_weight_spec", type=float, default=0.1) 152 | parser.add_argument("--filters_folder", type=str, default='filters') 153 | parser.add_argument('--render_test_exhibition',type=int,default=0) 154 | parser.add_argument("--exhibition_filters_path", type=str, default='none.mat') 155 | parser.add_argument("--exhibition_ssfs_path", type=str, default='none.mat') 156 | parser.add_argument("--exhibition_lights_path", type=str, default='none.mat') 157 | parser.add_argument("--exhibition_lightorigin_path", type=str, default='none.mat') 158 | 159 | 160 | if cmd is not None: 161 | return parser.parse_args(cmd) 162 | else: 163 | return parser.parse_args() 164 | 165 | args = config_parser() 166 | 167 | -------------------------------------------------------------------------------- /extra/auto_run_paramsets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading, queue 3 | import numpy as np 4 | import time 5 | 6 | 7 | def getFolderLocker(logFolder): 8 | while True: 9 | try: 10 | os.makedirs(logFolder+"/lockFolder") 11 | break 12 | except: 13 | time.sleep(0.01) 14 | 15 | def releaseFolderLocker(logFolder): 16 | os.removedirs(logFolder+"/lockFolder") 17 | 18 | def getStopFolder(logFolder): 19 | return os.path.isdir(logFolder+"/stopFolder") 20 | 21 | 22 | def get_param_str(key, val): 23 | if key == 'data_name': 24 | return 
f'--datadir {datafolder}/{val} ' 25 | else: 26 | return f'--{key} {val} ' 27 | 28 | def get_param_list(param_dict): 29 | param_keys = list(param_dict.keys()) 30 | param_modes = len(param_keys) 31 | param_nums = [len(param_dict[key]) for key in param_keys] 32 | 33 | param_ids = np.zeros(param_nums+[param_modes], dtype=int) 34 | for i in range(param_modes): 35 | broad_tuple = np.ones(param_modes, dtype=int).tolist() 36 | broad_tuple[i] = param_nums[i] 37 | broad_tuple = tuple(broad_tuple) 38 | print(broad_tuple) 39 | param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple) 40 | param_ids = param_ids.reshape(-1, param_modes) 41 | # print(param_ids) 42 | print(len(param_ids)) 43 | 44 | params = [] 45 | expnames = [] 46 | for i in range(param_ids.shape[0]): 47 | one = "" 48 | name = "" 49 | param_id = param_ids[i] 50 | for j in range(param_modes): 51 | key = param_keys[j] 52 | val = param_dict[key][param_id[j]] 53 | if type(key) is tuple: 54 | assert len(key) == len(val) 55 | for k in range(len(key)): 56 | one += get_param_str(key[k], val[k]) 57 | name += f'{val[k]},' 58 | name=name[:-1]+'-' 59 | else: 60 | one += get_param_str(key, val) 61 | name += f'{val}-' 62 | params.append(one) 63 | name=name.replace(' ','') 64 | print(name) 65 | expnames.append(name[:-1]) 66 | # print(params) 67 | return params, expnames 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | 78 | 79 | # nerf 80 | expFolder = "nerf/" 81 | # parameters to iterate, use tuple to couple multiple parameters 82 | datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/' 83 | param_dict = { 84 | 'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'], 85 | 'data_dim_color': [13, 27, 54] 86 | } 87 | 88 | # n_iters = 30000 89 | # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 90 | # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \ 91 | # f'--dataset_name nsvf --datadir 
/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\ 92 | # f'--expname {data_name} --batch_size {batch_size} ' \ 93 | # f'--n_iters {n_iters} ' \ 94 | # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\ 95 | # f'--N_vis {5} ' \ 96 | # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \ 97 | # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \ 98 | # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \ 99 | # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \ 100 | # f'--render_test 1 ' 101 | # print(cmd) 102 | # os.system(cmd) 103 | 104 | # nsvf 105 | # expFolder = "nsvf_0227/" 106 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/' 107 | # param_dict = { 108 | # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 109 | # 'shadingMode': ['SH'], 110 | # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")], 111 | # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)], 112 | # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)], 113 | # ('n_iters','N_voxel_final'): [(30000,300**3)], 114 | # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)], 115 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")] 116 | # 117 | # } 118 | 119 | # tankstemple 120 | # expFolder = "tankstemple_0304/" 121 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/' 122 | # param_dict = { 123 | # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'], 124 | # 'shadingMode': ['MLP_Fea'], 125 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")], 126 | # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)], 127 | # 
('TV_weight_density','TV_weight_app'):[(0.1,0.01)], 128 | # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)], 129 | # ('n_iters','N_voxel_final'): [(15000,300**3)], 130 | # ('dataset_name','N_vis') : [("tankstemple",5)], 131 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")] 132 | # } 133 | 134 | # llff 135 | # expFolder = "real_iconic/" 136 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/' 137 | # List = os.listdir(datafolder) 138 | # param_dict = { 139 | # 'data_name': List, 140 | # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)], 141 | # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")], 142 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 143 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 144 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)], 145 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 146 | # } 147 | 148 | # expFolder = "llff/" 149 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data' 150 | # param_dict = { 151 | # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids' 152 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")], 153 | # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)], 154 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 155 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 156 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)], 157 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 158 | # } 159 | 
160 | #setting available gpus 161 | gpus_que = queue.Queue(3) 162 | for i in [1,2,3]: 163 | gpus_que.put(i) 164 | 165 | os.makedirs(f"log/{expFolder}", exist_ok=True) 166 | 167 | def run_program(gpu, expname, param): 168 | cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \ 169 | f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \ 170 | f'{param}' \ 171 | f'> "log/{expFolder}{expname}/{expname}.txt"' 172 | print(cmd) 173 | os.system(cmd) 174 | gpus_que.put(gpu) 175 | 176 | params, expnames = get_param_list(param_dict) 177 | 178 | 179 | logFolder=f"log/{expFolder}" 180 | os.makedirs(logFolder, exist_ok=True) 181 | 182 | ths = [] 183 | for i in range(len(params)): 184 | 185 | if getStopFolder(logFolder): 186 | break 187 | 188 | 189 | targetFolder = f"log/{expFolder}{expnames[i]}" 190 | gpu = gpus_que.get() 191 | getFolderLocker(logFolder) 192 | if os.path.isdir(targetFolder): 193 | releaseFolderLocker(logFolder) 194 | gpus_que.put(gpu) 195 | continue 196 | else: 197 | os.makedirs(targetFolder, exist_ok=True) 198 | print("making",targetFolder, "running",expnames[i], params[i]) 199 | releaseFolderLocker(logFolder) 200 | 201 | 202 | t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True) 203 | t.start() 204 | ths.append(t) 205 | 206 | for th in ths: 207 | th.join() -------------------------------------------------------------------------------- /dataLoader/tankstemple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | 11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 12 | if axis == 'z': 13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h] 14 | elif axis == 'y': 15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * 
np.sin(r * t + t0)] 16 | else: 17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)] 18 | 19 | 20 | def cross(x, y, axis=0): 21 | T = torch if isinstance(x, torch.Tensor) else np 22 | return T.cross(x, y, axis) 23 | 24 | 25 | def normalize(x, axis=-1, order=2): 26 | if isinstance(x, torch.Tensor): 27 | l2 = x.norm(p=order, dim=axis, keepdim=True) 28 | return x / (l2 + 1e-8), l2 29 | 30 | else: 31 | l2 = np.linalg.norm(x, order, axis) 32 | l2 = np.expand_dims(l2, axis) 33 | l2[l2 == 0] = 1 34 | return x / l2, 35 | 36 | 37 | def cat(x, axis=1): 38 | if isinstance(x[0], torch.Tensor): 39 | return torch.cat(x, dim=axis) 40 | return np.concatenate(x, axis=axis) 41 | 42 | 43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): 44 | """ 45 | This function takes a vector 'camera_position' which specifies the location 46 | of the camera in world coordinates and two vectors `at` and `up` which 47 | indicate the position of the object and the up directions of the world 48 | coordinate system respectively. The object is assumed to be centered at 49 | the origin. 50 | The output is a rotation matrix representing the transformation 51 | from world coordinates -> view coordinates. 
52 | Input: 53 | camera_position: 3 54 | at: 1 x 3 or N x 3 (0, 0, 0) in default 55 | up: 1 x 3 or N x 3 (0, 1, 0) in default 56 | """ 57 | 58 | if at is None: 59 | at = torch.zeros_like(camera_position) 60 | else: 61 | at = torch.tensor(at).type_as(camera_position) 62 | if up is None: 63 | up = torch.zeros_like(camera_position) 64 | up[2] = -1 65 | else: 66 | up = torch.tensor(up).type_as(camera_position) 67 | 68 | z_axis = normalize(at - camera_position)[0] 69 | x_axis = normalize(cross(up, z_axis))[0] 70 | y_axis = normalize(cross(z_axis, x_axis))[0] 71 | 72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 73 | return R 74 | 75 | 76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180): 77 | c2ws = [] 78 | for t in range(frames): 79 | c2w = torch.eye(4) 80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi)) 81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True) 82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot 83 | c2ws.append(c2w) 84 | return torch.stack(c2ws) 85 | 86 | class TanksTempleDataset(Dataset): 87 | """NSVF Generic Dataset.""" 88 | def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False): 89 | self.root_dir = datadir 90 | self.split = split 91 | self.is_stack = is_stack 92 | self.downsample = downsample 93 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 94 | self.define_transforms() 95 | 96 | self.white_bg = True 97 | self.near_far = [0.01,6.0] 98 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2 99 | 100 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 101 | self.read_meta() 102 | self.define_proj_mat() 103 | 104 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 105 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 106 | 107 | def bbox2corners(self): 108 | corners = 
self.scene_bbox.unsqueeze(0).repeat(4,1,1) 109 | for i in range(3): 110 | corners[i,[0,1],i] = corners[i,[1,0],i] 111 | return corners.view(-1,3) 112 | 113 | 114 | def read_meta(self): 115 | 116 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt")) 117 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1) 118 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 119 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 120 | 121 | if self.split == 'train': 122 | pose_files = [x for x in pose_files if x.startswith('0_')] 123 | img_files = [x for x in img_files if x.startswith('0_')] 124 | elif self.split == 'val': 125 | pose_files = [x for x in pose_files if x.startswith('1_')] 126 | img_files = [x for x in img_files if x.startswith('1_')] 127 | elif self.split == 'test': 128 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 129 | test_img_files = [x for x in img_files if x.startswith('2_')] 130 | if len(test_pose_files) == 0: 131 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 132 | test_img_files = [x for x in img_files if x.startswith('1_')] 133 | pose_files = test_pose_files 134 | img_files = test_img_files 135 | 136 | # ray directions for all pixels, same for all images (same H, W, focal) 137 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 138 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 139 | 140 | 141 | 142 | self.poses = [] 143 | self.all_rays = [] 144 | self.all_rgbs = [] 145 | 146 | assert len(img_files) == len(pose_files) 147 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 148 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 149 | img = Image.open(image_path) 150 | if self.downsample!=1.0: 151 | img = 
img.resize(self.img_wh, Image.LANCZOS) 152 | img = self.transform(img) # (4, h, w) 153 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 154 | if img.shape[-1]==4: 155 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 156 | self.all_rgbs.append(img) 157 | 158 | 159 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans 160 | c2w = torch.FloatTensor(c2w) 161 | self.poses.append(c2w) # C2W 162 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 163 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 164 | 165 | self.poses = torch.stack(self.poses) 166 | 167 | center = torch.mean(self.scene_bbox, dim=0) 168 | radius = torch.norm(self.scene_bbox[1]-center)*1.2 169 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist() 170 | pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y') 171 | self.render_path = gen_path(pos_gen, up=up,frames=200) 172 | self.render_path[:, :3, 3] += center 173 | 174 | 175 | 176 | if 'train' == self.split: 177 | if self.is_stack: 178 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 179 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 180 | else: 181 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 182 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 183 | else: 184 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 185 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 186 | 187 | 188 | def define_transforms(self): 189 | self.transform = T.ToTensor() 190 | 191 | def define_proj_mat(self): 192 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 193 | 194 | def world2ndc(self, points): 195 | device = 
points.device 196 | return (points - self.center.to(device)) / self.radius.to(device) 197 | 198 | def __len__(self): 199 | if self.split == 'train': 200 | return len(self.all_rays) 201 | return len(self.all_rgbs) 202 | 203 | def __getitem__(self, idx): 204 | 205 | if self.split == 'train': # use data in the buffers 206 | sample = {'rays': self.all_rays[idx], 207 | 'rgbs': self.all_rgbs[idx]} 208 | 209 | else: # create data for each image separately 210 | 211 | img = self.all_rgbs[idx] 212 | rays = self.all_rays[idx] 213 | 214 | sample = {'rays': rays, 215 | 'rgbs': img} 216 | return sample -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import cv2,torch 4 | import numpy as np 5 | from PIL import Image 6 | import torchvision.transforms as T 7 | import torch.nn.functional as F 8 | import scipy.signal 9 | from opt import args 10 | 11 | mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 12 | 13 | def get_filter_path(id_file): 14 | with open(id_file) as f: 15 | res = f.readlines() 16 | res = list(map(lambda x: x.strip(), res)) 17 | filterset, ids = res[0], res[1:] 18 | filterpaths = list(map( 19 | lambda x: str(Path(filterset) / f'{x}.mat'), 20 | ids 21 | )) 22 | return filterpaths 23 | 24 | 25 | if __name__ == '__main__': 26 | get_filter_path(r'E:\pythonProject\python3 v2\SpecNeRF-v2\myspecdata\filters19_optimized\filters\ids.txt') 27 | 28 | def positionencoding1D(W, L): 29 | 30 | x_linspace = (np.linspace(0, W - 1, W) / W) * 2 - 1 31 | 32 | x_el = [] 33 | 34 | x_el_hf = [] 35 | 36 | pe_1d = np.zeros((W, 2*L+1)) 37 | # cache the values so you don't have to do function calls at every pixel 38 | for el in range(0, L): 39 | val = 2 ** el 40 | 41 | x = np.sin(val * np.pi * x_linspace) 42 | x_el.append(x) 43 | 44 | x = np.cos(val * np.pi * x_linspace) 45 | x_el_hf.append(x) 46 | 47 | 48 | for x_i in range(0, W): 49 | 50 | p_enc = [] 51 | 52 | for li in range(0, L): 53 | p_enc.append(x_el[li][x_i]) 54 | p_enc.append(x_el_hf[li][x_i]) 55 | 56 | p_enc.append(x_linspace[x_i]) 57 | 58 | pe_1d[x_i] = np.array(p_enc) 59 | 60 | return pe_1d.astype('float32') 61 | 62 | 63 | def norm0to1(x): 64 | xmin, _ = x.min(dim=1, keepdim=True) 65 | xmax, _ = x.max(dim=1, keepdim=True) 66 | x_norm = (x - xmin) / (xmax - xmin) 67 | 68 | return x_norm 69 | 70 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 71 | """ 72 | depth: (H, W) 73 | """ 74 | 75 | x = np.nan_to_num(depth) # change nan to 0 76 | if minmax is None: 77 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 78 | ma = np.max(x) 79 | else: 80 | mi,ma = minmax 81 | 82 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 83 | x = (255*x).astype(np.uint8) 84 | x_ = cv2.applyColorMap(x, cmap) 85 | return x_, [mi,ma] 86 | 87 | 88 | def visualize_depth_numpy_mono(depth, minmax=None): 89 | """ 90 | depth: (H, W) 91 | """ 92 | 93 | x = 
np.nan_to_num(depth) # change nan to 0 94 | if minmax is None: 95 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 96 | ma = np.max(x) 97 | else: 98 | mi,ma = minmax 99 | 100 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 101 | x = (255*x).astype(np.uint8) 102 | 103 | return x[..., np.newaxis] 104 | 105 | 106 | def init_log(log, keys): 107 | for key in keys: 108 | log[key] = torch.tensor([0.0], dtype=float) 109 | return log 110 | 111 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 112 | """ 113 | depth: (H, W) 114 | """ 115 | if type(depth) is not np.ndarray: 116 | depth = depth.cpu().numpy() 117 | 118 | x = np.nan_to_num(depth) # change nan to 0 119 | if minmax is None: 120 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 121 | ma = np.max(x) 122 | else: 123 | mi,ma = minmax 124 | 125 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 126 | x = (255*x).astype(np.uint8) 127 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 128 | x_ = T.ToTensor()(x_) # (3, H, W) 129 | return x_, [mi,ma] 130 | 131 | def N_to_reso(n_voxels, bbox): 132 | xyz_min, xyz_max = bbox 133 | dim = len(xyz_min) 134 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 135 | return ((xyz_max - xyz_min) / voxel_size).long().tolist() 136 | 137 | def cal_n_samples(reso, step_ratio=0.5): 138 | return int(np.linalg.norm(reso)/step_ratio) 139 | 140 | 141 | 142 | 143 | __LPIPS__ = {} 144 | def init_lpips(net_name, device): 145 | assert net_name in ['alex', 'vgg'] 146 | import lpips 147 | print(f'init_lpips: lpips_{net_name}') 148 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 149 | 150 | def rgb_lpips(np_gt, np_im, net_name, device): 151 | if net_name not in __LPIPS__: 152 | __LPIPS__[net_name] = init_lpips(net_name, device) 153 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 154 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 155 | return 
__LPIPS__[net_name](gt, im, normalize=True).item() 156 | 157 | 158 | def findItem(items, target): 159 | for one in items: 160 | if one[:len(target)]==target: 161 | return one 162 | return None 163 | 164 | 165 | ''' Evaluation metrics (ssim, lpips) 166 | ''' 167 | def rgb_ssim(img0, img1, max_val, 168 | filter_size=11, 169 | filter_sigma=1.5, 170 | k1=0.01, 171 | k2=0.03, 172 | return_map=False): 173 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 174 | assert len(img0.shape) == 3 175 | # assert img0.shape[-1] == 3 176 | assert img0.shape == img1.shape 177 | 178 | # Construct a 1D Gaussian blur filter. 179 | hw = filter_size // 2 180 | shift = (2 * hw - filter_size + 1) / 2 181 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 182 | filt = np.exp(-0.5 * f_i) 183 | filt /= np.sum(filt) 184 | 185 | # Blur in x and y (faster than the 2D convolution). 186 | def convolve2d(z, f): 187 | return scipy.signal.convolve2d(z, f, mode='valid') 188 | 189 | filt_fn = lambda z: np.stack([ 190 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 191 | for i in range(z.shape[-1])], -1) 192 | mu0 = filt_fn(img0) 193 | mu1 = filt_fn(img1) 194 | mu00 = mu0 * mu0 195 | mu11 = mu1 * mu1 196 | mu01 = mu0 * mu1 197 | sigma00 = filt_fn(img0**2) - mu00 198 | sigma11 = filt_fn(img1**2) - mu11 199 | sigma01 = filt_fn(img0 * img1) - mu01 200 | 201 | # Clip the variances and covariances to valid values. 
202 | # Variance must be non-negative: 203 | sigma00 = np.maximum(0., sigma00) 204 | sigma11 = np.maximum(0., sigma11) 205 | sigma01 = np.sign(sigma01) * np.minimum( 206 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 207 | c1 = (k1 * max_val)**2 208 | c2 = (k2 * max_val)**2 209 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 210 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 211 | ssim_map = numer / denom 212 | ssim = np.mean(ssim_map) 213 | return ssim_map if return_map else ssim 214 | 215 | 216 | def calculate_sum_diff(tensor_arr): 217 | diff_matrix = tensor_arr.view(-1, 1) - tensor_arr 218 | upper_triangle = torch.triu(diff_matrix, diagonal=1).abs() 219 | result = torch.sum(upper_triangle) 220 | return result 221 | 222 | 223 | import torch.nn as nn 224 | class TVLoss(nn.Module): 225 | def __init__(self,TVLoss_weight=1): 226 | super(TVLoss,self).__init__() 227 | self.TVLoss_weight = TVLoss_weight 228 | 229 | def forward(self,x): 230 | batch_size = x.size()[0] 231 | h_x = x.size()[2] 232 | w_x = x.size()[3] 233 | count_h = self._tensor_size(x[:,:,1:,:]) 234 | count_w = self._tensor_size(x[:,:,:,1:]) 235 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 236 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 237 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 238 | 239 | def _tensor_size(self,t): 240 | return t.size()[1]*t.size()[2]*t.size()[3] 241 | 242 | 243 | def TVloss_Spectral(specmap): 244 | # specmap n x 31 245 | # shifted = torch.roll(specmap, shifts=1, dims=1) 246 | # return (((shifted - specmap) / (specmap.detach() + 1e-7)) ** 2).mean() 247 | return (((specmap[:, 1:] - specmap[:, :-1]) / (specmap[:, :-1].detach() + 1e-4)) ** 2).mean() 248 | 249 | 250 | def TVloss_SSF(ssf): 251 | return (((ssf[1:] - ssf[:-1]) / (ssf[:-1].detach() + 1e-4)) ** 2).mean() 252 | 253 | 254 | 255 | import plyfile 256 | import skimage.measure 257 | def convert_sdf_samples_to_ply( 258 | pytorch_3d_sdf_tensor, 259 | ply_filename_out, 260 | 
bbox, 261 | level=0.5, 262 | offset=None, 263 | scale=None, 264 | ): 265 | """ 266 | Convert sdf samples to .ply 267 | 268 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n) 269 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 270 | :voxel_size: float, the size of the voxels 271 | :ply_filename_out: string, path of the filename to save to 272 | 273 | This function adapted from: https://github.com/RobotLocomotion/spartan 274 | """ 275 | 276 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy() 277 | voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape)) 278 | 279 | verts, faces, normals, values = skimage.measure.marching_cubes( 280 | numpy_3d_sdf_tensor, level=level, spacing=voxel_size 281 | ) 282 | faces = faces[...,::-1] # inverse face orientation 283 | 284 | # transform from voxel coordinates to camera coordinates 285 | # note x and y are flipped in the output of marching_cubes 286 | mesh_points = np.zeros_like(verts) 287 | mesh_points[:, 0] = bbox[0,0] + verts[:, 0] 288 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1] 289 | mesh_points[:, 2] = bbox[0,2] + verts[:, 2] 290 | 291 | # apply additional offset and scale 292 | if scale is not None: 293 | mesh_points = mesh_points / scale 294 | if offset is not None: 295 | mesh_points = mesh_points - offset 296 | 297 | # try writing to the ply file 298 | 299 | num_verts = verts.shape[0] 300 | num_faces = faces.shape[0] 301 | 302 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 303 | 304 | for i in range(0, num_verts): 305 | verts_tuple[i] = tuple(mesh_points[i, :]) 306 | 307 | faces_building = [] 308 | for i in range(0, num_faces): 309 | faces_building.append(((faces[i, :].tolist(),))) 310 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 311 | 312 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 313 | el_faces = plyfile.PlyElement.describe(faces_tuple, 
"face") 314 | 315 | ply_data = plyfile.PlyData([el_verts, el_faces]) 316 | print("saving mesh to %s" % (ply_filename_out)) 317 | ply_data.write(ply_filename_out) 318 | -------------------------------------------------------------------------------- /dataLoader/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch, re 2 | import numpy as np 3 | from torch import searchsorted 4 | from kornia import create_meshgrid 5 | 6 | 7 | # from utils import index_point_feature 8 | 9 | def depth2dist(z_vals, cos_angle): 10 | # z_vals: [N_ray N_sample] 11 | device = z_vals.device 12 | dists = z_vals[..., 1:] - z_vals[..., :-1] 13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 14 | dists = dists * cos_angle.unsqueeze(-1) 15 | return dists 16 | 17 | 18 | def ndc2dist(ndc_pts, cos_angle): 19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1) 20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples] 21 | return dists 22 | 23 | 24 | def get_ray_directions(H, W, focal, center=None): 25 | """ 26 | Get ray directions for all pixels in camera coordinate. 
27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 28 | ray-tracing-generating-camera-rays/standard-coordinate-systems 29 | Inputs: 30 | H, W, focal: image height, width and focal length 31 | Outputs: 32 | directions: (H, W, 3), the direction of the rays in camera coordinate 33 | """ 34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 35 | 36 | i, j = grid.unbind(-1) 37 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 38 | # see https://github.com/bmild/nerf/issues/24 39 | cent = center if center is not None else [W / 2, H / 2] 40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) 41 | 42 | return directions 43 | 44 | 45 | def get_ray_directions_blender(H, W, focal, center=None): 46 | """ 47 | Get ray directions for all pixels in camera coordinate. 48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 49 | ray-tracing-generating-camera-rays/standard-coordinate-systems 50 | Inputs: 51 | H, W, focal: image height, width and focal length 52 | Outputs: 53 | directions: (H, W, 3), the direction of the rays in camera coordinate 54 | """ 55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5 56 | i, j = grid.unbind(-1) 57 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 58 | # see https://github.com/bmild/nerf/issues/24 59 | cent = center if center is not None else [W / 2, H / 2] 60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], 61 | -1) # (H, W, 3) 62 | 63 | return directions 64 | 65 | 66 | def get_rays_by_coord_np(H, W, focal, c2w, coords): 67 | i, j = (coords[:,0] -W*0.5)/focal[0], -(coords[:,1] -H*0.5)/focal[1] 68 | dirs = np.stack([i,j,-np.ones_like(i)],-1) 69 | rays_d = dirs @ c2w[:3, :3].T # (H, W, 3) 70 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 71 | return 
torch.from_numpy(rays_o), torch.from_numpy(rays_d) 72 | 73 | 74 | def get_rays(directions, c2w): 75 | """ 76 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 77 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 78 | ray-tracing-generating-camera-rays/standard-coordinate-systems 79 | Inputs: 80 | directions: (H, W, 3) precomputed ray directions in camera coordinate 81 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 82 | Outputs: 83 | rays_o: (H*W, 3), the origin of the rays in world coordinate 84 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 85 | """ 86 | # Rotate ray directions from camera coordinate to the world coordinate 87 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 88 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 89 | # The origin of all rays is the camera origin in world coordinate 90 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 91 | 92 | rays_d = rays_d.view(-1, 3) 93 | rays_o = rays_o.view(-1, 3) 94 | 95 | return rays_o, rays_d 96 | 97 | 98 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d): 99 | # Shift ray origins to near plane 100 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 101 | rays_o = rays_o + t[..., None] * rays_d 102 | 103 | # Projection 104 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 105 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 106 | o2 = 1. + 2. * near / rays_o[..., 2] 107 | 108 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 109 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 110 | d2 = -2. 
* near / rays_o[..., 2] 111 | 112 | rays_o = torch.stack([o0, o1, o2], -1) 113 | rays_d = torch.stack([d0, d1, d2], -1) 114 | 115 | return rays_o, rays_d 116 | 117 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 118 | # Shift ray origins to near plane 119 | t = (near - rays_o[..., 2]) / rays_d[..., 2] 120 | rays_o = rays_o + t[..., None] * rays_d 121 | 122 | # Projection 123 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 124 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 125 | o2 = 1. - 2. * near / rays_o[..., 2] 126 | 127 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 128 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 129 | d2 = 2. * near / rays_o[..., 2] 130 | 131 | rays_o = torch.stack([o0, o1, o2], -1) 132 | rays_d = torch.stack([d0, d1, d2], -1) 133 | 134 | return rays_o, rays_d 135 | 136 | # Hierarchical sampling (section 5.2) 137 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 138 | device = weights.device 139 | # Get pdf 140 | weights = weights + 1e-5 # prevent nans 141 | pdf = weights / torch.sum(weights, -1, keepdim=True) 142 | cdf = torch.cumsum(pdf, -1) 143 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 144 | 145 | # Take uniform samples 146 | if det: 147 | u = torch.linspace(0., 1., steps=N_samples, device=device) 148 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 149 | else: 150 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device) 151 | 152 | # Pytest, overwrite u with numpy's fixed random numbers 153 | if pytest: 154 | np.random.seed(0) 155 | new_shape = list(cdf.shape[:-1]) + [N_samples] 156 | if det: 157 | u = np.linspace(0., 1., N_samples) 158 | u = np.broadcast_to(u, new_shape) 159 | else: 160 | u = np.random.rand(*new_shape) 161 | u = torch.Tensor(u) 162 | 163 | # Invert CDF 164 | u = u.contiguous() 165 | inds = 
searchsorted(cdf.detach(), u, right=True) 166 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 167 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 168 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 169 | 170 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 171 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 172 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 173 | 174 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 175 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 176 | t = (u - cdf_g[..., 0]) / denom 177 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 178 | 179 | return samples 180 | 181 | 182 | def dda(rays_o, rays_d, bbox_3D): 183 | inv_ray_d = 1.0 / (rays_d + 1e-6) 184 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3 185 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d 186 | t = torch.stack((t_min, t_max)) # 2 N_rays 3 187 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0] 188 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0] 189 | return t_min, t_max 190 | 191 | 192 | def ray_marcher(rays, 193 | N_samples=64, 194 | lindisp=False, 195 | perturb=0, 196 | bbox_3D=None): 197 | """ 198 | sample points along the rays 199 | Inputs: 200 | rays: () 201 | 202 | Returns: 203 | 204 | """ 205 | 206 | # Decompose the inputs 207 | N_rays = rays.shape[0] 208 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 209 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 210 | 211 | if bbox_3D is not None: 212 | # cal aabb boundles 213 | near, far = dda(rays_o, rays_d, bbox_3D) 214 | 215 | # Sample depth points 216 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) 217 | if not lindisp: # use linear sampling in depth space 218 | z_vals = near * (1 - z_steps) + far * z_steps 219 | else: # use linear sampling in disparity space 220 | z_vals = 
1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) 221 | 222 | z_vals = z_vals.expand(N_rays, N_samples) 223 | 224 | if perturb > 0: # perturb sampling depths (z_vals) 225 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points 226 | # get intervals between samples 227 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1) 228 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 229 | 230 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 231 | z_vals = lower + (upper - lower) * perturb_rand 232 | 233 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 234 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 235 | 236 | return xyz_coarse_sampled, rays_o, rays_d, z_vals 237 | 238 | 239 | def read_pfm(filename): 240 | file = open(filename, 'rb') 241 | color = None 242 | width = None 243 | height = None 244 | scale = None 245 | endian = None 246 | 247 | header = file.readline().decode('utf-8').rstrip() 248 | if header == 'PF': 249 | color = True 250 | elif header == 'Pf': 251 | color = False 252 | else: 253 | raise Exception('Not a PFM file.') 254 | 255 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 256 | if dim_match: 257 | width, height = map(int, dim_match.groups()) 258 | else: 259 | raise Exception('Malformed PFM header.') 260 | 261 | scale = float(file.readline().rstrip()) 262 | if scale < 0: # little-endian 263 | endian = '<' 264 | scale = -scale 265 | else: 266 | endian = '>' # big-endian 267 | 268 | data = np.fromfile(file, endian + 'f') 269 | shape = (height, width, 3) if color else (height, width) 270 | 271 | data = np.reshape(data, shape) 272 | data = np.flipud(data) 273 | file.close() 274 | return data, scale 275 | 276 | 277 | def ndc_bbox(all_rays): 278 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0] 279 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0] 280 | far_min = 
torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0] 281 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0] 282 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}') 283 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max))) -------------------------------------------------------------------------------- /dataLoader/colmap2nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import argparse 12 | import os 13 | from pathlib import Path, PurePosixPath 14 | 15 | import numpy as np 16 | import json 17 | import sys 18 | import math 19 | import cv2 20 | import os 21 | import shutil 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") 25 | 26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") 27 | parser.add_argument("--video_fps", default=2) 28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. 
eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") 29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") 31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 32 | parser.add_argument("--images", default="images", help="input path to the images") 33 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") 34 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16") 35 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start") 36 | parser.add_argument("--out", default="transforms.json", help="output path") 37 | args = parser.parse_args() 38 | return args 39 | 40 | def do_system(arg): 41 | print(f"==== running: {arg}") 42 | err = os.system(arg) 43 | if err: 44 | print("FATAL: command failed") 45 | sys.exit(err) 46 | 47 | def run_ffmpeg(args): 48 | if not os.path.isabs(args.images): 49 | args.images = os.path.join(os.path.dirname(args.video_in), args.images) 50 | images = args.images 51 | video = args.video_in 52 | fps = float(args.video_fps) or 1.0 53 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") 54 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? 
(Y/n)").lower().strip()+"y")[:1] != "y": 55 | sys.exit(1) 56 | try: 57 | shutil.rmtree(images) 58 | except: 59 | pass 60 | do_system(f"mkdir {images}") 61 | 62 | time_slice_value = "" 63 | time_slice = args.time_slice 64 | if time_slice: 65 | start, end = time_slice.split(",") 66 | time_slice_value = f",select='between(t\,{start}\,{end})'" 67 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") 68 | 69 | def run_colmap(args): 70 | db=args.colmap_db 71 | images=args.images 72 | db_noext=str(Path(db).with_suffix("")) 73 | 74 | if args.text=="text": 75 | args.text=db_noext+"_text" 76 | text=args.text 77 | sparse=db_noext+"_sparse" 78 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") 79 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 80 | sys.exit(1) 81 | if os.path.exists(db): 82 | os.remove(db) 83 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}") 84 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}") 85 | try: 86 | shutil.rmtree(sparse) 87 | except: 88 | pass 89 | do_system(f"mkdir {sparse}") 90 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") 91 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") 92 | try: 93 | shutil.rmtree(text) 94 | except: 95 | pass 96 | do_system(f"mkdir {text}") 97 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") 98 | 99 | def variance_of_laplacian(image): 100 | return cv2.Laplacian(image, cv2.CV_64F).var() 101 | 102 | def sharpness(imagePath): 
103 | image = cv2.imread(imagePath) 104 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 105 | fm = variance_of_laplacian(gray) 106 | return fm 107 | 108 | def qvec2rotmat(qvec): 109 | return np.array([ 110 | [ 111 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 114 | ], [ 115 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 116 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 117 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 118 | ], [ 119 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 120 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 121 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 122 | ] 123 | ]) 124 | 125 | def rotmat(a, b): 126 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 127 | v = np.cross(a, b) 128 | c = np.dot(a, b) 129 | s = np.linalg.norm(v) 130 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 131 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 132 | 133 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 134 | da = da / np.linalg.norm(da) 135 | db = db / np.linalg.norm(db) 136 | c = np.cross(da, db) 137 | denom = np.linalg.norm(c)**2 138 | t = ob - oa 139 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 140 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 141 | if ta > 0: 142 | ta = 0 143 | if tb > 0: 144 | tb = 0 145 | return (oa+ta*da+ob+tb*db) * 0.5, denom 146 | 147 | if __name__ == "__main__": 148 | args = parse_args() 149 | if args.video_in != "": 150 | run_ffmpeg(args) 151 | if args.run_colmap: 152 | run_colmap(args) 153 | AABB_SCALE = int(args.aabb_scale) 154 | SKIP_EARLY = int(args.skip_early) 155 | IMAGE_FOLDER = args.images 156 | TEXT_FOLDER = args.text 157 | OUT_PATH = args.out 158 | print(f"outputting to {OUT_PATH}...") 159 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: 160 
| angle_x = math.pi / 2 161 | for line in f: 162 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 163 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 164 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 165 | if line[0] == "#": 166 | continue 167 | els = line.split(" ") 168 | w = float(els[2]) 169 | h = float(els[3]) 170 | fl_x = float(els[4]) 171 | fl_y = float(els[4]) 172 | k1 = 0 173 | k2 = 0 174 | p1 = 0 175 | p2 = 0 176 | cx = w / 2 177 | cy = h / 2 178 | if els[1] == "SIMPLE_PINHOLE": 179 | cx = float(els[5]) 180 | cy = float(els[6]) 181 | elif els[1] == "PINHOLE": 182 | fl_y = float(els[5]) 183 | cx = float(els[6]) 184 | cy = float(els[7]) 185 | elif els[1] == "SIMPLE_RADIAL": 186 | cx = float(els[5]) 187 | cy = float(els[6]) 188 | k1 = float(els[7]) 189 | elif els[1] == "RADIAL": 190 | cx = float(els[5]) 191 | cy = float(els[6]) 192 | k1 = float(els[7]) 193 | k2 = float(els[8]) 194 | elif els[1] == "OPENCV": 195 | fl_y = float(els[5]) 196 | cx = float(els[6]) 197 | cy = float(els[7]) 198 | k1 = float(els[8]) 199 | k2 = float(els[9]) 200 | p1 = float(els[10]) 201 | p2 = float(els[11]) 202 | else: 203 | print("unknown camera model ", els[1]) 204 | # fl = 0.5 * w / tan(0.5 * angle_x); 205 | angle_x = math.atan(w / (fl_x * 2)) * 2 206 | angle_y = math.atan(h / (fl_y * 2)) * 2 207 | fovx = angle_x * 180 / math.pi 208 | fovy = angle_y * 180 / math.pi 209 | 210 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") 211 | 212 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: 213 | i = 0 214 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 215 | out = { 216 | "camera_angle_x": angle_x, 217 | "camera_angle_y": angle_y, 218 | "fl_x": fl_x, 219 | "fl_y": fl_y, 220 | "k1": k1, 221 | "k2": k2, 222 | "p1": p1, 223 | "p2": p2, 224 | "cx": cx, 225 | "cy": cy, 226 | "w": w, 227 | "h": h, 228 | "aabb_scale": AABB_SCALE, 229 | "frames": 
[], 230 | } 231 | 232 | up = np.zeros(3) 233 | for line in f: 234 | line = line.strip() 235 | if line[0] == "#": 236 | continue 237 | i = i + 1 238 | if i < SKIP_EARLY*2: 239 | continue 240 | if i % 2 == 1: 241 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 242 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) 243 | # why is this requireing a relitive path while using ^ 244 | image_rel = os.path.relpath(IMAGE_FOLDER) 245 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}") 246 | b=sharpness(name) 247 | print(name, "sharpness=",b) 248 | image_id = int(elems[0]) 249 | qvec = np.array(tuple(map(float, elems[1:5]))) 250 | tvec = np.array(tuple(map(float, elems[5:8]))) 251 | R = qvec2rotmat(-qvec) 252 | t = tvec.reshape([3,1]) 253 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 254 | c2w = np.linalg.inv(m) 255 | c2w[0:3,2] *= -1 # flip the y and z axis 256 | c2w[0:3,1] *= -1 257 | c2w = c2w[[1,0,2,3],:] # swap y and z 258 | c2w[2,:] *= -1 # flip whole world upside down 259 | 260 | up += c2w[0:3,1] 261 | 262 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w} 263 | out["frames"].append(frame) 264 | nframes = len(out["frames"]) 265 | up = up / np.linalg.norm(up) 266 | print("up vector was", up) 267 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] 268 | R = np.pad(R,[0,1]) 269 | R[-1, -1] = 1 270 | 271 | 272 | for f in out["frames"]: 273 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 274 | 275 | # find a central point they are all looking at 276 | print("computing center of attention...") 277 | totw = 0.0 278 | totp = np.array([0.0, 0.0, 0.0]) 279 | for f in out["frames"]: 280 | mf = f["transform_matrix"][0:3,:] 281 | for g in out["frames"]: 282 | mg = g["transform_matrix"][0:3,:] 283 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 284 | if w > 0.01: 285 | totp += p*w 286 | totw += w 287 | totp /= totw 288 | 
print(totp) # the cameras are looking at totp 289 | for f in out["frames"]: 290 | f["transform_matrix"][0:3,3] -= totp 291 | 292 | avglen = 0. 293 | for f in out["frames"]: 294 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 295 | avglen /= nframes 296 | print("avg camera distance from origin", avglen) 297 | for f in out["frames"]: 298 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 299 | 300 | for f in out["frames"]: 301 | f["transform_matrix"] = f["transform_matrix"].tolist() 302 | print(nframes,"frames") 303 | print(f"writing {OUT_PATH}") 304 | with open(OUT_PATH, "w") as outfile: 305 | json.dump(out, outfile, indent=2) -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch,os,imageio,sys 3 | from tqdm.auto import tqdm 4 | from dataLoader.ray_utils import get_rays 5 | from models.tensoRF import TensorVM, TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask 6 | from utils import * 7 | from dataLoader.ray_utils import ndc_rays_blender 8 | from dataLoader.llff import LLFFDataset 9 | from opt import args 10 | import scipy.io as sio 11 | 12 | def OctreeRender_trilinear_fast(rays, tensorf, chunk=args.chunk_size, N_samples=-1, ndc_ray=False, white_bg=True, \ 13 | is_train=False, device='cuda', **kargs): 14 | poseids, filterids = kargs['poseids'], kargs['filterids'] 15 | filters = LLFFDataset.filters_back 16 | 17 | rgbs, alphas, depth_maps, weights, uncertainties, dist_losses, spec_maps = [], [], [], [], [], [], [] 18 | N_rays_all = rays.shape[0] 19 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)): 20 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) 21 | poseids_chunk = poseids[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) 22 | filters_chunk = filters[filterids[chunk_idx * chunk:(chunk_idx + 1) * chunk].reshape(-1)].to(device) 23 | 
24 | rgb_map, depth_map, dist_loss, spec_map, phi = tensorf(rays_chunk, poseids_chunk, filters_chunk, is_train=is_train, white_bg=white_bg, \ 25 | ndc_ray=ndc_ray, N_samples=N_samples) 26 | 27 | rgbs.append(rgb_map) 28 | depth_maps.append(depth_map) 29 | dist_losses.append(dist_loss) 30 | spec_maps.append(spec_map) 31 | 32 | return None if args.render_test_exhibition else torch.cat(rgbs), None, torch.cat(depth_maps), None, None, \ 33 | torch.cat(dist_losses).mean(), torch.cat(spec_maps), phi 34 | 35 | @torch.no_grad() 36 | def evaluation(test_dataset:LLFFDataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 37 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 38 | PSNRs, rgb_maps, depth_maps = [], [], [] 39 | ssims,l_alex,l_vgg=[],[],[] 40 | os.makedirs(savePath, exist_ok=True) 41 | os.makedirs(savePath+"/rgbd", exist_ok=True) 42 | os.makedirs(savePath+"/spec", exist_ok=True) 43 | # delete old spec.mat 44 | del_list = os.listdir(f'{savePath}/spec') 45 | for f in del_list: 46 | file_path = os.path.join(f'{savePath}/spec', f) 47 | if os.path.isfile(file_path): 48 | os.remove(file_path) 49 | 50 | 51 | try: 52 | tqdm._instances.clear() 53 | except Exception: 54 | pass 55 | 56 | near_far = test_dataset.near_far 57 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1) 58 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval)) 59 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout): 60 | 61 | W, H = test_dataset.img_wh 62 | rays = samples.view(-1,samples.shape[-1]) 63 | 64 | rgb_map, _, depth_map, _, _, _, spec_map, _ = \ 65 | renderer(rays, tensorf, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, device=device, \ 66 | poseids=test_dataset.all_poses[idxs[idx]], filterids=test_dataset.all_filtersIdx[idxs[idx]]) 67 | rgb_map = rgb_map.clamp(0.0, 1.0) 68 | 69 | rgb_map, depth_map, spec_map = rgb_map.reshape(H, W, 
args.observation_channel).cpu(), depth_map.reshape(H, W).cpu(), spec_map.reshape(H, W, args.spec_channel).cpu().numpy() 70 | 71 | depth_map_raw = depth_map.numpy().copy() 72 | if rgb_map.shape[-1] != 1: 73 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 74 | else: 75 | depth_map = visualize_depth_numpy_mono(depth_map.numpy(),near_far) 76 | 77 | if len(test_dataset.all_rgbs): 78 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, args.observation_channel) 79 | loss = torch.mean((rgb_map - gt_rgb) ** 2) 80 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0)) 81 | 82 | if compute_extra_metrics: 83 | ssim = rgb_ssim(rgb_map, gt_rgb, 1) 84 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device) 85 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device) 86 | ssims.append(ssim) 87 | l_alex.append(l_a) 88 | l_vgg.append(l_v) 89 | 90 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 91 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 92 | rgb_maps.append(rgb_map) 93 | depth_maps.append(depth_map) 94 | if savePath is not None: 95 | # imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 96 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 97 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 98 | # save spec.mat 99 | sio.savemat(f'{savePath}/spec/{prtx}{idx:03d}.mat', {'spec': spec_map}) 100 | sio.savemat(f'{savePath}/rgbd/{prtx}{idx:03d}_depth.mat', {'depth': depth_map_raw}) 101 | 102 | # imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10) 103 | # imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10) 104 | 105 | if PSNRs: 106 | psnr = np.mean(np.asarray(PSNRs)) 107 | if compute_extra_metrics: 108 | ssim = np.mean(np.asarray(ssims)) 109 | l_a = np.mean(np.asarray(l_alex)) 110 | l_v = np.mean(np.asarray(l_vgg)) 111 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 112 
| else: 113 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 114 | 115 | 116 | return PSNRs 117 | 118 | @torch.no_grad() 119 | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 120 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 121 | PSNRs, rgb_maps, depth_maps = [], [], [] 122 | ssims,l_alex,l_vgg=[],[],[] 123 | os.makedirs(savePath, exist_ok=True) 124 | os.makedirs(savePath+"/rgbd", exist_ok=True) 125 | os.makedirs(savePath+"/spec", exist_ok=True) 126 | 127 | try: 128 | tqdm._instances.clear() 129 | except Exception: 130 | pass 131 | 132 | near_far = test_dataset.near_far 133 | ones_filtersIdx = torch.LongTensor([[0]]) 134 | for idx, c2w in tqdm(enumerate(c2ws)): 135 | 136 | W, H = test_dataset.img_wh 137 | 138 | c2w = torch.FloatTensor(c2w) 139 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3) 140 | if ndc_ray: 141 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d) 142 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 143 | 144 | rgb_map, _, depth_map, _, _, _, spec_map, _ = \ 145 | renderer(rays, tensorf, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, device=device, \ 146 | poseids=ones_filtersIdx.expand((rays.shape[0], -1)), filterids=ones_filtersIdx.expand((rays.shape[0], -1))) 147 | rgb_map = rgb_map.clamp(0.0, 1.0) 148 | 149 | rgb_map, depth_map, spec_map = rgb_map.reshape(H, W, args.observation_channel).cpu(), depth_map.reshape(H, W).cpu(), spec_map.reshape(H, W, args.spec_channel).cpu().numpy() 150 | 151 | if rgb_map.shape[-1] != 1: 152 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 153 | else: 154 | depth_map = visualize_depth_numpy_mono(depth_map.numpy(),near_far) 155 | 156 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 157 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 158 | rgb_maps.append(rgb_map) 159 | depth_maps.append(depth_map) 160 | if 
savePath is not None: 161 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 162 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 163 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 164 | sio.savemat(f'{savePath}/spec/{prtx}{idx:03d}.mat', {'spec': spec_map}) 165 | 166 | 167 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8) 168 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8) 169 | 170 | if PSNRs: 171 | psnr = np.mean(np.asarray(PSNRs)) 172 | if compute_extra_metrics: 173 | ssim = np.mean(np.asarray(ssims)) 174 | l_a = np.mean(np.asarray(l_alex)) 175 | l_v = np.mean(np.asarray(l_vgg)) 176 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 177 | else: 178 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 179 | 180 | 181 | return PSNRs 182 | 183 | 184 | @torch.no_grad() 185 | def exhibition(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 186 | white_bg=False, ndc_ray=False, device='cuda', scale=False, **kargs): 187 | rgb_maps, depth_maps = [], [] 188 | os.makedirs(savePath, exist_ok=True) 189 | os.makedirs(savePath+"/rgbd", exist_ok=True) 190 | os.makedirs(savePath+"/spec", exist_ok=True) 191 | 192 | try: 193 | tqdm._instances.clear() 194 | except Exception: 195 | pass 196 | 197 | near_far = test_dataset.near_far 198 | ones_filtersIdx = torch.LongTensor([[0]]) 199 | filtersets = kargs['filtersets'] 200 | ssfs = kargs['ssfs'] 201 | lights = kargs['lights'] 202 | try: 203 | lightOrigin = torch.from_numpy(sio.loadmat(args.exhibition_lightorigin_path)['light_spec'][:, args.band_start_idx:]).cuda() 204 | lightOrigin = (1 / lightOrigin.mean()) * lightOrigin # normalize to one mean 205 | except Exception as e: 206 | print('Just warning you, seems you did not provide lightorgin file.') 207 | for idx, c2w in tqdm(enumerate(c2ws)): 208 | times1 = math.ceil(len(c2ws) / 
len(filtersets)) 209 | if idx % times1 == 0: 210 | filter = torch.from_numpy(filtersets[idx // times1]).cuda() 211 | 212 | times2 = math.ceil(len(c2ws) / len(ssfs)) 213 | if idx % times2 == 0: 214 | ssf = torch.from_numpy(ssfs[idx // times2]).cuda() 215 | 216 | if len(lights) != 0: 217 | if idx % math.ceil(len(c2ws) / len(lights)) == 0: 218 | light = torch.from_numpy(lights[idx // math.ceil(len(c2ws) / len(lights))]).cuda() 219 | else: 220 | light = None 221 | 222 | W, H = test_dataset.img_wh 223 | 224 | c2w = torch.FloatTensor(c2w) 225 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3) 226 | if ndc_ray: 227 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d) 228 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 229 | 230 | _, _, depth_map, _, _, _, spec_map, _ = \ 231 | renderer(rays, tensorf, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, device=device, \ 232 | poseids=ones_filtersIdx.expand((rays.shape[0], -1)), filterids=ones_filtersIdx.expand((rays.shape[0], -1))) 233 | 234 | if light is not None: 235 | spec_map = spec_map * light / lightOrigin 236 | rgb_map = ((spec_map * filter) @ ssf) 237 | if scale: 238 | rgb_map = (0.6 / torch.quantile(rgb_map.reshape(-1), 0.95)) * rgb_map 239 | rgb_map = rgb_map.clamp(0,1) 240 | 241 | rgb_map, depth_map, spec_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu(), spec_map.reshape(H, W, args.spec_channel).cpu().numpy() 242 | 243 | if rgb_map.shape[-1] != 1: 244 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 245 | else: 246 | depth_map = visualize_depth_numpy_mono(depth_map.numpy(),near_far) 247 | 248 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 249 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 250 | rgb_maps.append(rgb_map) 251 | depth_maps.append(depth_map) 252 | if savePath is not None: 253 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 254 | 255 | 
imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8) 256 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8) 257 | -------------------------------------------------------------------------------- /find_filter/GA2.py: -------------------------------------------------------------------------------- 1 | from calendar import c 2 | import itertools 3 | import os 4 | import random 5 | import sys 6 | sys.path.append(r'E:\pythonProject\python3\myutils_v2') 7 | sys.path.append(r'..\myutils_v2') 8 | sys.path.append(r'../myutils') 9 | sys.path.append(os.getcwd()) 10 | import time 11 | import traceback 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import numba as nb 16 | import scipy.io as sio 17 | from tqdm.auto import tqdm 18 | 19 | import filesort_int 20 | # from Online_breakpoint_debug import Online_breakpoint_debug 21 | from multiprocessing.dummy import Pool 22 | from scipy.spatial.transform import Rotation 23 | from Images import Plt_subplot_in_loop_or_ion 24 | 25 | # myd = Online_breakpoint_debug() 26 | # myd.start() 27 | 28 | def execute_time(func): 29 | # 定义嵌套函数,用来打印出装饰的函数的执行时间 30 | def wrapper(*args, **kwargs): 31 | # 定义开始时间 32 | start = time.time() 33 | # 执行函数 34 | func_return = func(*args, **kwargs) 35 | # 定义结束时间 36 | end = time.time() 37 | # 打印方法名称和其执行时间 38 | print('{}() execute time: {}s'.format(func.__name__, end-start)) 39 | # 返回func的返回值 40 | return func_return 41 | 42 | # 返回嵌套的函数 43 | return wrapper 44 | 45 | def normalization(a): 46 | ran = a.max() - a.min() 47 | a = (a - a.min()) / ran 48 | return a, ran 49 | 50 | 51 | def log(pbar, i, a, b): 52 | d = "epoch --> " + \ 53 | str(i + 1).rjust(5, " "), " max:" + \ 54 | str(round(a, 4)).rjust(8, " "), "mean:" + \ 55 | str(round(b, 4)).rjust(8, " "), "alpha:" + \ 56 | str(round(a / b, 4)).rjust(8, " ") 57 | # pbar.set_description(str(d)) 58 | print(d) 59 | 60 | 61 | class GeneSolve: 62 | ## 
初始定义,后续引用GeneSolve类方法为(初始种群数,最大迭代数,交叉概率,变异概率,最大适应度/平均适应度(扰动率趋于平稳则越接近1越好)) 63 | def __init__(self, pose_path, ssf_path, pop_size, epoch, cross_prob, mutate_prob, alpha, 64 | poses, aim_number, sigma, print_batch=2): 65 | 66 | assert pop_size % 2 == 0, 'pop size must be even number' 67 | 68 | self.sigma = sigma 69 | self.poses = poses 70 | self.ssf_path = ssf_path 71 | self.pose_path = pose_path 72 | self.aim_number = aim_number 73 | self.pop_size = pop_size 74 | self.epoch = epoch 75 | self.cross_prob = cross_prob 76 | self.mutate_prob = mutate_prob 77 | self.print_batch = print_batch 78 | self.alpha = alpha 79 | self.width = poses[0] * poses[1] 80 | self.best = None 81 | self.coor, self.eulers, self.view_dirs = self.prepare_poses() # list N x 3, 12 x 3, 12 x 3 82 | self.fai = self.prepare_ssfs() # np 12 x 20 x 31 83 | self.ssf_coor = self.cal_ssfCoorelation() 84 | self.dis_mtx = self.cal_distance() # np N x N 85 | self.cosSimi = self.cal_cos_simi() 86 | 87 | self._idxs_4_inter_cross = np.arange(0, stop=pop_size).astype(dtype=np.int) 88 | 89 | # sio.savemat('coordinate_ssf_distance.mat', { 90 | # 'coor': np.array(self.coor), 91 | # 'ssfs': self.fai, 92 | # 'distance': self.dis_mtx 93 | # }) 94 | def cross2idx(pair, loc, genes): 95 | d1, d2 = pair 96 | d1_a, d1_b = genes[d1, 0:loc], genes[d1, loc:] 97 | d2_a, d2_b = genes[d2, 0:loc], genes[d2, loc:] 98 | genes[d1] = np.append(d1_a, d2_b) 99 | genes[d2] = np.append(d2_a, d1_b) 100 | 101 | self.crossfunc = np.vectorize(cross2idx, excluded=['genes'], 102 | signature='(j),()->()', cache=True) 103 | 104 | @nb.jit(nopython=True, parallel=True) 105 | def gen(pop_size, width): 106 | # 产生初始种群 107 | genes = [] 108 | for _ in range(pop_size): 109 | tt = np.array([0] * (width - aim_number) + [1] * aim_number) 110 | np.random.shuffle(tt) 111 | genes.append(tt) 112 | return genes 113 | 114 | self.genes = np.vstack(gen(self.pop_size, self.width)) 115 | 116 | def get_meshgrid_comb(): 117 | n = number 118 | init_id = list(range(1, 
n)) 119 | res_id = init_id[:] 120 | for i in range(1, n): 121 | res_id += list(map(lambda x: x + n*i, init_id[i:])) 122 | return res_id 123 | self.comb_meshgrid_id = get_meshgrid_comb() 124 | 125 | pass 126 | 127 | def cal_ssfCoorelation(self): 128 | n = self.poses[1] 129 | ssf_norm = np.linalg.norm(np.array(self.filters_list), axis=-1) 130 | 131 | @nb.njit(parallel=True) 132 | def cal(n, ssfs, ssf_norm): 133 | ssfCor_norm = np.zeros((n, n)) 134 | for i in range(n): 135 | for j in range(i + 1): 136 | # tt = np.dot(ssfs[i], ssfs[j]) / (ssf_norm[i] * ssf_norm[j]) 137 | ''' use var normed ''' 138 | tt = np.var(ssfs[i] / ssf_norm[i] + ssfs[j] / ssf_norm[j]) 139 | ssfCor_norm[i, j] = ssfCor_norm[j, i] = tt 140 | return ssfCor_norm 141 | 142 | return cal(n, self.filters_list, ssf_norm) 143 | 144 | def cal_cos_simi(self): 145 | num = self.view_dirs @ self.view_dirs.T 146 | normed = np.linalg.norm(self.view_dirs, axis=-1).reshape(-1, 1) 147 | all_normed = normed * normed.T 148 | 149 | cosSimi = np.abs(num / all_normed) / 2 + 0.5 150 | cosSimi = cosSimi ** cosSim_gamma 151 | 152 | return cosSimi 153 | 154 | def prepare_ssfs(self): 155 | fs = filesort_int.sort_file_int(self.ssf_path, 'mat')[1:] 156 | filters_list = np.array([ 157 | np.asarray(np.diagonal(sio.loadmat(x)['filter']), dtype=np.float32) for x in fs 158 | ]) 159 | fai = filters_list[np.newaxis, ...].repeat(self.poses[0], axis=0) 160 | 161 | self.filters_list = filters_list # 19 x 25 162 | self.filters_trans_mean = filters_list.mean(axis=1) 163 | 164 | return fai 165 | 166 | def prepare_poses(self): 167 | coor = [] 168 | view_dirs = [] 169 | 170 | foo_v = np.array([0.3, 0.3, 0.7]).reshape([-1, 1]) 171 | hypothetic_vect = foo_v / np.linalg.norm(foo_v) 172 | 173 | poses_bounds = np.load(self.pose_path) 174 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 175 | # Original poses has rotation in form "down right back", change to "right up back" 176 | poses = np.concatenate([poses[..., 1:2], 
-poses[..., :1], poses[..., 2:4]], -1) 177 | 178 | for i in range(poses.shape[0]): 179 | coor.append(poses[i, :, 3]) 180 | view_dirs.append(poses[i, :, :3] @ hypothetic_vect) 181 | 182 | # matrix to angles, 183 | r = Rotation.from_matrix(poses[..., :3]) 184 | euler_angles = r.as_euler('zxy', True) 185 | 186 | return coor, euler_angles, np.array(view_dirs).squeeze() 187 | 188 | def cal_distance(self): 189 | n = len(self.coor) 190 | map_dis = lambda d: np.exp((-d ** 2) / self.sigma ** 2) 191 | 192 | @nb.njit(parallel=True) 193 | def cal(coor): 194 | cal_dis = lambda a, b: np.sqrt(np.sum((coor[a] - coor[b]) ** 2)) 195 | dis_mtx = np.zeros((n, n)) 196 | for r in range(n): 197 | for c in range(r + 1, n): 198 | tt = cal_dis(r, c) 199 | dis_mtx[r, c] = dis_mtx[c, r] = tt 200 | return dis_mtx 201 | 202 | dis_mtx = cal(self.coor) 203 | 204 | dis_mtx, _ = normalization(dis_mtx) 205 | dis_mtx = map_dis(6 * dis_mtx) 206 | 207 | return dis_mtx 208 | 209 | 210 | def inter_cross(self): 211 | np.random.shuffle(self._idxs_4_inter_cross) 212 | idxs = self._idxs_4_inter_cross.reshape([-1, 2]) 213 | idxs2change = idxs[np.random.random(idxs.shape[0]) <= self.cross_prob] 214 | locs = np.random.randint(1, self.width-2, idxs2change.shape[0]) 215 | 216 | self.crossfunc(pair=idxs2change, loc=locs, genes=self.genes) 217 | 218 | 219 | def mutate(self): 220 | """基因突变""" 221 | ready_index = list(range(self.pop_size)) 222 | for i in ready_index: 223 | t0 = self.genes[i] 224 | if np.random.uniform(0, 1) <= self.mutate_prob: 225 | loc = random.choice(range(0, self.width)) 226 | t0[loc] = 1 - t0[loc] 227 | # fix the total number to ** 228 | self.genes[i] = self.fix_total_number(t0) 229 | 230 | def fix_total_number(self, genes): 231 | ones_num = genes.sum() 232 | if ones_num != self.aim_number: 233 | to_what = 1 if ones_num < self.aim_number else 0 234 | 235 | tt = list(np.argwhere(genes == (1 - to_what)).squeeze()) 236 | try: 237 | ones_idx = random.sample(tt, abs(ones_num - self.aim_number)) 238 | 
except Exception as e: 239 | print(traceback.format_exc()) 240 | # myd.goin(locals()) 241 | genes[ones_idx] = to_what 242 | 243 | return genes 244 | 245 | 246 | def _get_combination(self, arr): 247 | comb = np.array(np.meshgrid(arr, arr)).T.reshape(-1, 2) 248 | comb2 = comb[self.comb_meshgrid_id, :] 249 | 250 | return comb2 251 | 252 | def _pad_filterid(self, arr): 253 | arr = np.unique(arr) 254 | arr_pad = np.pad(arr, (0, self.poses[1] - arr.shape[0]), "constant", constant_values=self.poses[1]) 255 | return arr_pad 256 | 257 | def get_adjust(self): 258 | """编码,从表现型到基因型的映射""" 259 | indexes = np.argwhere(self.genes == 1)[:, 1] 260 | r_idx, c_idx = np.unravel_index(indexes, self.fai.shape[:2]) 261 | r_idx, c_idx = r_idx.reshape([-1, number]), c_idx.reshape([-1, number]) # e.g. 40000 x 40 262 | 263 | """计算适应度(只有在计算适应度的时候要反函数,其余过程全都是随机的二进制编码)""" 264 | r_comb = np.apply_along_axis(self._get_combination, axis=1, arr=r_idx).reshape(-1, 2) # e.g. (40000 x 780) x 2 265 | c_comb = np.apply_along_axis(self._get_combination, axis=1, arr=c_idx).reshape(-1, 2) 266 | 267 | distance = self.dis_mtx[r_comb[:, 0], r_comb[:, 1]] 268 | ssfCoorela = self.ssf_coor[c_comb[:, 0], c_comb[:, 1]] 269 | viewDir_rela = self.cosSimi[r_comb[:, 0], r_comb[:, 1]] 270 | score = (distance * ssfCoorela * viewDir_rela).reshape(self.pop_size, -1).sum(1) 271 | 272 | # cal ssf var 273 | ssflist_with0 = np.vstack((self.filters_list, [0] * self.filters_list.shape[1])) 274 | c_idx_unique = np.apply_along_axis(self._pad_filterid, axis=1, arr=c_idx) 275 | ssfs = ssflist_with0[c_idx_unique.flatten(), :].reshape([self.pop_size, self.poses[1], -1]) 276 | ssfs_var = np.var(np.sum(ssfs, axis=1), axis=1) * ssfs_var_weight 277 | 278 | score = 1 / score + 1 / ssfs_var 279 | 280 | return score 281 | 282 | 283 | def cycle_select(self): 284 | """通过轮盘赌来进行选择""" 285 | adjusts = self.get_adjust() 286 | if self.best is None or np.max(adjusts) >= self.best[1]: 287 | self.best = self.genes[np.argmax(adjusts)], np.max(adjusts) 
288 | p = adjusts / np.sum(adjusts) 289 | 290 | cu_p = np.cumsum(p) 291 | 292 | r0 = np.random.uniform(0, 1, self.pop_size) 293 | wheel_choose_func = np.vectorize(np.searchsorted, [np.int], excluded=['a'], cache=True) 294 | sel = wheel_choose_func(a=cu_p, v=r0) #[wheel_choose(r, cu_p) for r in r0] 295 | 296 | # sel = [np.append(np.where(r > cu_p)[0], 0).max() for r in r0] 297 | # sel = list(map(lambda r:np.append(np.where(r > cu_p)[0], 0).max(), r0)) 298 | # 保留最优的个体 299 | if np.max(adjusts[sel]) < self.best[1]: 300 | self.genes[sel[np.argmin(adjusts[sel])]] = self.best[0] 301 | self.genes = self.genes[sel] 302 | 303 | def evolve(self): 304 | def show_best(best): 305 | score, gene = best 306 | g = gene.reshape(self.poses) 307 | plt.imshow(g) 308 | plt.show() 309 | print(g.sum()) 310 | 311 | sio.savemat(f'find_filter_res/GA_{self.aim_number}_{score}_{random.randint(0, 50)}.mat', {'mask': g}) 312 | 313 | """逐代演化""" 314 | best = (0, None) 315 | pbar = tqdm(range(self.epoch), miniters=self.print_batch, file=sys.stdout) 316 | with Plt_subplot_in_loop_or_ion(1, 1) as myplt: 317 | for i in pbar: 318 | self.cycle_select() # 种群选取 319 | self.inter_cross() # 染色体交叉 320 | self.mutate() 321 | 322 | adjust_val = self.get_adjust() # 计算适应度 323 | best_gene, a, b = self.genes[np.argmax(adjust_val)], np.max(adjust_val), np.mean(adjust_val) 324 | 325 | if i % self.print_batch == 0: 326 | g = best_gene.reshape(self.poses) 327 | myplt.repaint() 328 | myplt.ax_cur().imshow(g) 329 | myplt.pause() 330 | # save a temp 331 | sio.savemat(f'{outfile}/GA_{self.aim_number}_ep{i}_{a}.mat', 332 | {'mask': g}) 333 | 334 | if a >= best[0]: 335 | best = (a, best_gene) 336 | 337 | if i % self.print_batch == self.print_batch - 1 or i == 0: 338 | log(pbar, i, a, b) 339 | if a / b < self.alpha: 340 | log(pbar, i, a, b) 341 | print("进化终止,算法已收敛!共进化 ", i + 1, " 代!") 342 | break 343 | 344 | show_best(best) 345 | 346 | 347 | if __name__ == '__main__': 348 | outfile = 
f'find_filter/find_filter_res/sofaEigenGate_sigma0d5_dir1d5_num70' # lab_trans_sigma0d3_wei1d0_num40 349 | if not os.path.exists(outfile): 350 | os.makedirs(outfile) 351 | 352 | sigma = 0.5 # dis 353 | number = 70 354 | cosSim_gamma = 1.5 # view dir exp 355 | ssfs_var_weight = 1 356 | print(f'running sigma = {sigma}') 357 | print(f'running number = {number}') 358 | t1 = time.time() 359 | gs = GeneSolve(r'myspecdata/filterset_gate/sofa_eigen/poses_bounds.npy', 360 | r'myspecdata/filterset_gate/filters_eigen/*.mat', 361 | 20000, 1000, 0.65, 0.1, 1.05, (15, 14), number, np.sqrt(sigma)) 362 | gs.evolve() 363 | t2 = time.time() 364 | 365 | print(f'cost {(t2 - t1) / 60.0} mins') 366 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm.auto import tqdm 4 | from opt import config_parser, args 5 | from dataLoader.IDsampler import get_simple_sampler 6 | from shutil import copy 7 | import scipy.io as sio 8 | import json, random 9 | from renderer import * 10 | from utils import * 11 | from tensorboardX import SummaryWriter 12 | import datetime 13 | from dataLoader.llff import LLFFDataset, get_dataset4RGBtraining 14 | from dataLoader import dataset_dict 15 | 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | renderer = OctreeRender_trilinear_fast 20 | 21 | lastfilename = 'null' 22 | def saveModel(model, filepath): 23 | global lastfilename 24 | try: 25 | model.save(filepath) 26 | except Exception as e: 27 | print(traceback.format_exc()) 28 | lastfilename = filepath 29 | return 30 | 31 | # code runs to here, only when the saving was successful 32 | # clear last 10 round model 33 | if os.path.exists(lastfilename): 34 | os.remove(lastfilename) 35 | lastfilename = filepath 36 | 37 | class SimpleSampler: 38 | def __init__(self, total, batch): 39 | self.total = total 40 | self.batch = 
batch 41 | self.curr = total 42 | self.ids = None 43 | 44 | def nextids(self): 45 | self.curr+=self.batch 46 | if self.curr + self.batch > self.total: 47 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 48 | self.curr = 0 49 | return self.ids[self.curr:self.curr+self.batch] 50 | 51 | def set_batch(self, batch): 52 | self.batch = batch 53 | 54 | @torch.no_grad() 55 | def export_mesh(args): 56 | 57 | ckpt = torch.load(args.ckpt, map_location=device) 58 | kwargs = ckpt['kwargs'] 59 | kwargs.update({'device': device}) 60 | tensorf = eval(args.model_name)(**kwargs) 61 | tensorf.load(ckpt) 62 | 63 | alpha,_ = tensorf.getDenseAlpha() 64 | convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005) 65 | 66 | 67 | @torch.no_grad() 68 | def render_test(args): 69 | # init dataset 70 | dataset = dataset_dict[args.dataset_name] 71 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) 72 | white_bg = test_dataset.white_bg 73 | ndc_ray = args.ndc_ray 74 | 75 | if not os.path.exists(args.ckpt): 76 | print('the ckpt path does not exists!!') 77 | return 78 | 79 | ckpt = torch.load(args.ckpt, map_location=device) 80 | kwargs = ckpt['kwargs'] 81 | kwargs.update({'device': device}) 82 | tensorf = eval(args.model_name)(**kwargs) 83 | tensorf.load(ckpt) 84 | 85 | logfolder = os.path.dirname(args.ckpt) 86 | if args.render_train: 87 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 88 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True) 89 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', 90 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 91 | print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================') 92 | 93 | if args.render_test: 94 | os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True) 95 | 
evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', 96 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 97 | 98 | if args.render_path: 99 | c2ws = test_dataset.render_path 100 | os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True) 101 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/', 102 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 103 | 104 | if args.render_test_exhibition: 105 | filtersets = [np.array([1/31] * 31)] # sio.loadmat(args.exhibition_filters_path)['filtersets'][0].tolist() 106 | ssfs = sio.loadmat(args.exhibition_ssfs_path)['ssfs'][0].tolist()[:1] 107 | lights = sio.loadmat(args.exhibition_lights_path)['spds'][0].tolist() 108 | c2ws = test_dataset.render_path 109 | os.makedirs(f'{logfolder}/{args.expname}/imgs_ambient_exhibition', exist_ok=True) 110 | exhibition(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_ambient_exhibition/',scale=True, 111 | N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray, filtersets=filtersets, ssfs=ssfs, lights=lights) 112 | 113 | 114 | def reconstruction(args): 115 | 116 | # init dataset 117 | dataset = dataset_dict[args.dataset_name] 118 | train_dataset: LLFFDataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False) 119 | test_dataset: LLFFDataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) 120 | white_bg = train_dataset.white_bg 121 | near_far = train_dataset.near_far 122 | ndc_ray = args.ndc_ray 123 | 124 | # init resolution 125 | upsamp_list = args.upsamp_list 126 | update_AlphaMask_list = args.update_AlphaMask_list 127 | n_lamb_sigma = args.n_lamb_sigma 128 | n_lamb_sh = args.n_lamb_sh 129 | 130 | 131 | if args.add_timestamp: 132 | logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}' 133 | else: 134 
| logfolder = f'{args.basedir}/{args.expname}' 135 | 136 | 137 | # init log file 138 | os.makedirs(logfolder, exist_ok=True) 139 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True) 140 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True) 141 | os.makedirs(f'{logfolder}/rgba', exist_ok=True) 142 | os.makedirs(f'{logfolder}/SSFs', exist_ok=True) 143 | os.makedirs(f'{logfolder}/SSFs/RBFparams', exist_ok=True) 144 | # copy config file 145 | copy(args.config, logfolder) 146 | 147 | summary_writer = SummaryWriter(logfolder) 148 | 149 | 150 | 151 | # init parameters 152 | # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0]) 153 | aabb = train_dataset.scene_bbox.to(device) 154 | reso_cur = N_to_reso(args.N_voxel_init, aabb) 155 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio)) 156 | 157 | 158 | if args.ckpt is not None: 159 | ckpt = torch.load(args.ckpt, map_location=device) 160 | kwargs = ckpt['kwargs'] 161 | kwargs.update({'device':device}) 162 | tensorf = eval(args.model_name)(**kwargs) 163 | tensorf.load(ckpt) 164 | else: 165 | tensorf = eval(args.model_name)(aabb, reso_cur, device, 166 | density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far, 167 | shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale, 168 | pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct) 169 | 170 | tensorf: TensorVMSplit = tensorf.cuda() 171 | 172 | grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis) 173 | if args.lr_decay_iters > 0: 174 | lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters) 175 | else: 176 | args.lr_decay_iters = args.n_iters 177 | lr_factor = args.lr_decay_target_ratio**(1/args.n_iters) 178 | 179 | print("lr decay", args.lr_decay_target_ratio, 
args.lr_decay_iters) 180 | 181 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99)) 182 | criterian = torch.nn.MSELoss() 183 | 184 | #linear in logrithmic space 185 | N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:] 186 | 187 | 188 | torch.cuda.empty_cache() 189 | PSNRs,PSNRs_test = [],[0] 190 | 191 | if args.rgb4shape_endIter: 192 | allrays_4shape, allrgbs_4shape, allposesID_4shape, all_filterID_4shape = get_dataset4RGBtraining(train_dataset) 193 | rgb4shapeSampler = SimpleSampler(allrays_4shape.shape[0], args.batch_size - args.depth_supervise * args.depth_batchsize_endIter[0]) 194 | 195 | allrays, allrgbs, allposesID, all_filterID = train_dataset.all_rays, train_dataset.all_rgbs, train_dataset.all_poses, train_dataset.all_filtersIdx 196 | if not args.ndc_ray: 197 | allrays, allrgbs, mask_filtered = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True) 198 | allposesID, all_filterID = allposesID[mask_filtered], all_filterID[mask_filtered] 199 | trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size - args.depth_supervise * args.depth_batchsize_endIter[0]) 200 | 201 | if args.depth_supervise: 202 | depthrays, depthweights, depthvalue = \ 203 | train_dataset.depth_rays, train_dataset.depth_weight, train_dataset.depth_value 204 | depthSampler = SimpleSampler(depthrays.shape[0], args.depth_batchsize_endIter[0]) 205 | 206 | Ortho_reg_weight = args.Ortho_weight 207 | print("initial Ortho_reg_weight", Ortho_reg_weight) 208 | 209 | L1_reg_weight = args.L1_weight_inital 210 | print("initial L1_reg_weight", L1_reg_weight) 211 | TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app 212 | tvreg = TVLoss() 213 | print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}") 214 | 215 | 216 | pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout) 217 | for iteration in pbar: 218 | 
if iteration < args.rgb4shape_endIter: 219 | ray_idx = rgb4shapeSampler.nextids() 220 | rays_train, rgb_train, poseID_train, filterID_train = allrays_4shape[ray_idx], allrgbs_4shape[ray_idx].to(device), allposesID_4shape[ray_idx], all_filterID_4shape[ray_idx] 221 | else: 222 | ray_idx = trainingSampler.nextids() 223 | rays_train, rgb_train, poseID_train, filterID_train = allrays[ray_idx], allrgbs[ray_idx].to(device), allposesID[ray_idx], all_filterID[ray_idx] 224 | 225 | if args.reset_para and args.rgb4shape_endIter == iteration: 226 | #reset lr 227 | for param_group in optimizer.param_groups: 228 | if hasattr(param_group, 'myname'): 229 | if param_group['myname'] in ['appLine', 'appPlane']: 230 | param_group['lr'] = args.lr_init * 1.2 231 | else: 232 | param_group['lr'] = args.lr_basis * 1.2 233 | 234 | if args.depth_supervise: 235 | depth_rays_idx = depthSampler.nextids() 236 | depth_rays_train, depth_wei_train, depth_val_train = \ 237 | depthrays[depth_rays_idx], depthweights[depth_rays_idx].to(device), depthvalue[depth_rays_idx].to(device) 238 | 239 | fake_poseid = torch.LongTensor([[0]]).expand((depth_rays_idx.shape[0], -1)) 240 | fake_filterID = fake_poseid 241 | rays_train = torch.cat([rays_train, depth_rays_train]) 242 | poseID_train = torch.cat([poseID_train, fake_poseid]) 243 | filterID_train = torch.cat([filterID_train, fake_filterID]) 244 | 245 | #rgb_map, alphas_map, depth_map, weights, uncertainty 246 | rgb_map, alphas_map, depth_map, weights, uncertainty, dist_loss, spec_map, ssf = \ 247 | renderer(rays_train, tensorf, N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, \ 248 | is_train=True, poseids=poseID_train, filterids=filterID_train) 249 | 250 | if args.depth_supervise: 251 | rgb_batch = args.batch_size - args.depth_batchsize_endIter[0] 252 | # disentangle 253 | rgb_map = rgb_map[:rgb_batch] 254 | depth_map, depth_supervise = depth_map[:rgb_batch], depth_map[rgb_batch:] 255 | 256 | # depth map and loss 257 | 
depth_est_mapped = tensorf.depth_linear(depth_supervise) 258 | depth_loss = torch.mean((torch.abs(depth_val_train - depth_est_mapped)) * depth_wei_train) 259 | depth_loss_print = depth_loss.detach().item() 260 | summary_writer.add_scalar('train/depth_loss', depth_loss_print, global_step=iteration) 261 | # depth supervise only used for several rounds 262 | if iteration + 1 == args.depth_batchsize_endIter[1]: 263 | args.depth_supervise = False 264 | trainingSampler.set_batch(args.batch_size) 265 | else: 266 | depth_loss = depth_loss_print = 0 267 | 268 | if args.distortion_loss: 269 | dist_loss = 0.1 * dist_loss 270 | summary_writer.add_scalar('train/dist_loss', dist_loss, global_step=iteration) 271 | 272 | 273 | rgbloss = (((rgb_map - rgb_train) / (rgb_map.detach() + 0.001)) ** 2).mean() 274 | psnrloss = criterian(rgb_map, rgb_train).detach().item() # temp 275 | 276 | # loss 277 | total_loss = rgbloss + depth_loss + dist_loss 278 | # if 'rbf' in args.ssf_model.lower(): 279 | # loss_ssfTV = TVloss_SSF(ssf) 280 | # total_loss += loss_ssfTV * 0.1 281 | # RBFmeans = tensorf.ssfnet.params[:, 0] 282 | # spreadloss = 1 / calculate_sum_diff(RBFmeans) 283 | # total_loss += spreadloss * 0.01 284 | 285 | if args.TV_weight_spec > 0: 286 | loss_specTV = TVloss_Spectral(spec_map) 287 | total_loss += loss_specTV * args.TV_weight_spec 288 | summary_writer.add_scalar('train/specTV', loss_specTV.detach().item(), global_step=iteration) 289 | else: 290 | loss_specTV = 0 291 | 292 | if Ortho_reg_weight > 0: 293 | loss_reg = tensorf.vector_comp_diffs() 294 | total_loss += Ortho_reg_weight*loss_reg 295 | summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration) 296 | if L1_reg_weight > 0: 297 | loss_reg_L1 = tensorf.density_L1() 298 | total_loss += L1_reg_weight*loss_reg_L1 299 | summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration) 300 | 301 | if TV_weight_density>0: 302 | TV_weight_density *= lr_factor 303 | loss_tv = 
tensorf.TV_loss_density(tvreg) * TV_weight_density 304 | total_loss = total_loss + loss_tv 305 | summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration) 306 | if TV_weight_app>0: 307 | TV_weight_app *= lr_factor 308 | loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app 309 | total_loss = total_loss + loss_tv 310 | summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration) 311 | 312 | optimizer.zero_grad() 313 | total_loss.backward() 314 | optimizer.step() 315 | 316 | PSNRs.append(-10.0 * np.log(psnrloss) / np.log(10.0)) 317 | summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration) 318 | summary_writer.add_scalar('train/mse', rgbloss, global_step=iteration) 319 | 320 | 321 | for param_group in optimizer.param_groups: 322 | param_group['lr'] = param_group['lr'] * lr_factor 323 | 324 | # Print the current values of the losses. 325 | if iteration % args.progress_refresh_rate == 0: 326 | # print('\ndepth linear para (a, b) is', tensorf.depth_linear.a, tensorf.depth_linear.b, '\n') 327 | 328 | pbar.set_description( 329 | f'Iteration {iteration:05d}:' 330 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}' 331 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}' 332 | + f' mse = {rgbloss:.6f}' 333 | + f' depth_loss = {depth_loss_print:.6f}' 334 | + f' dist_loss = {dist_loss:.6f}' 335 | + f' loss_specTV = {loss_specTV:.6f}' 336 | ) 337 | PSNRs = [] 338 | 339 | 340 | if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0: 341 | saveModel(tensorf, f'{logfolder}/{iteration}_{args.expname}.pth') 342 | sio.savemat(f'{logfolder}/SSFs/ssf_{iteration}.mat', 343 | {'ssf': ssf.cpu().detach().numpy()}) 344 | if 'rbf' in args.ssf_model.lower(): 345 | sio.savemat(f'{logfolder}/SSFs/RBFparams/param_{iteration}.mat', 346 | {'rbf': tensorf.ssfnet.params.data.cpu().detach().numpy()}) 347 | 348 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', 
N_vis=args.N_vis, 349 | prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False) 350 | summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration) 351 | 352 | 353 | 354 | if iteration in update_AlphaMask_list: 355 | 356 | if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution 357 | reso_mask = reso_cur 358 | new_aabb = tensorf.updateAlphaMask(tuple(reso_mask)) 359 | if iteration == update_AlphaMask_list[0]: 360 | tensorf.shrink(new_aabb) 361 | # tensorVM.alphaMask = None 362 | L1_reg_weight = args.L1_weight_rest 363 | print("continuing L1_reg_weight", L1_reg_weight) 364 | 365 | 366 | if not args.ndc_ray and iteration == update_AlphaMask_list[1]: 367 | # filter rays outside the bbox 368 | allrays, allrgbs, mask_filtered = tensorf.filtering_rays(allrays,allrgbs) 369 | allposesID, all_filterID = allposesID[mask_filtered], all_filterID[mask_filtered] 370 | trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size - args.depth_supervise * args.depth_batchsize_endIter[0]) 371 | 372 | 373 | if iteration in upsamp_list: 374 | n_voxels = N_voxel_list.pop(0) 375 | reso_cur = N_to_reso(n_voxels, tensorf.aabb) 376 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio)) 377 | tensorf.upsample_volume_grid(reso_cur) 378 | 379 | if args.lr_upsample_reset: 380 | print("reset lr to initial") 381 | lr_scale = 1 #0.1 ** (iteration / args.n_iters) 382 | else: 383 | lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters) 384 | grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale) 385 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 386 | 387 | 388 | tensorf.save(f'{logfolder}/{args.expname}.th') 389 | 390 | 391 | if args.render_train: 392 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 393 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True) 394 
| PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', 395 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 396 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 397 | 398 | if args.render_test: 399 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True) 400 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/', 401 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 402 | summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration) 403 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 404 | 405 | if args.render_path: 406 | c2ws = test_dataset.render_path 407 | # c2ws = test_dataset.poses 408 | print('========>',c2ws.shape) 409 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True) 410 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/', 411 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 412 | 413 | 414 | if __name__ == '__main__': 415 | 416 | torch.set_default_dtype(torch.float32) 417 | torch.manual_seed(20211202) 418 | np.random.seed(20211202) 419 | 420 | print(args) 421 | 422 | if args.export_mesh: 423 | export_mesh(args) 424 | exit(0) 425 | 426 | if args.render_only and (args.render_test or args.render_path or args.render_test_exhibition): 427 | render_test(args) 428 | else: 429 | reconstruction(args) 430 | 431 | -------------------------------------------------------------------------------- /dataLoader/llff.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import traceback 3 | import cv2 4 | import torch 5 | from torch.utils.data import Dataset 6 | import glob 7 | import numpy as np 8 | import os 9 | from PIL import Image 10 | from torchvision import transforms as T 
11 | import rawpy 12 | from rawpy._rawpy import ColorSpace 13 | from .ray_utils import * 14 | from opt import args 15 | from colmapUtils.read_write_model import read_images_binary, read_points3d_binary 16 | # from colmapUtils.read_write_dense import read_points3d_binary 17 | import scipy.io as sio 18 | 19 | 20 | def normalize(v): 21 | """Normalize a vector.""" 22 | return v / np.linalg.norm(v) 23 | 24 | 25 | def average_poses(poses): 26 | """ 27 | Calculate the average pose, which is then used to center all poses 28 | using @center_poses. Its computation is as follows: 29 | 1. Compute the center: the average of pose centers. 30 | 2. Compute the z axis: the normalized average z axis. 31 | 3. Compute axis y': the average y axis. 32 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 33 | 5. Compute the y axis: z cross product x. 34 | 35 | Note that at step 3, we cannot directly use y' as y axis since it's 36 | not necessarily orthogonal to z axis. We need to pass from x to y. 37 | Inputs: 38 | poses: (N_images, 3, 4) 39 | Outputs: 40 | pose_avg: (3, 4) the average pose 41 | """ 42 | # 1. Compute the center 43 | center = poses[..., 3].mean(0) # (3) 44 | 45 | # 2. Compute the z axis 46 | z = normalize(poses[..., 2].mean(0)) # (3) 47 | 48 | # 3. Compute axis y' (no need to normalize as it's not the final output) 49 | y_ = poses[..., 1].mean(0) # (3) 50 | 51 | # 4. Compute the x axis 52 | x = normalize(np.cross(z, y_)) # (3) 53 | 54 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 55 | y = np.cross(x, z) # (3) 56 | 57 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 58 | 59 | return pose_avg 60 | 61 | 62 | def center_poses(poses, blender2opencv): 63 | """ 64 | Center the poses so that we can use NDC. 
65 | See https://github.com/bmild/nerf/issues/34 66 | Inputs: 67 | poses: (N_images, 3, 4) 68 | Outputs: 69 | poses_centered: (N_images, 3, 4) the centered poses 70 | pose_avg: (3, 4) the average pose 71 | """ 72 | poses = poses @ blender2opencv 73 | pose_avg = average_poses(poses) # (3, 4) 74 | pose_avg_homo = np.eye(4) 75 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 76 | pose_avg_homo = pose_avg_homo 77 | # by simply adding 0, 0, 0, 1 as the last row 78 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 79 | poses_homo = \ 80 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 81 | 82 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 83 | # poses_centered = poses_centered @ blender2opencv 84 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 85 | 86 | return poses_centered, pose_avg_homo 87 | 88 | 89 | def viewmatrix(z, up, pos): 90 | vec2 = normalize(z) 91 | vec1_avg = up 92 | vec0 = normalize(np.cross(vec1_avg, vec2)) 93 | vec1 = normalize(np.cross(vec2, vec0)) 94 | m = np.eye(4) 95 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) 96 | return m 97 | 98 | 99 | def render_path_spiral(c2w, up, rads, focal, zrate, N_rots=2, N=120): 100 | render_poses = [] 101 | rads = np.array(list(rads) + [1.]) 102 | 103 | for theta in np.linspace(0., 2. 
* np.pi * N_rots, N + 1)[:-1]: 104 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 105 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 106 | render_poses.append(viewmatrix(z, up, c)) 107 | return render_poses 108 | 109 | 110 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120, n_rot=2, focal=None): 111 | # center pose 112 | c2w = average_poses(c2ws_all) 113 | 114 | # Get average pose 115 | up = normalize(c2ws_all[:, :3, 1].sum(0)) 116 | 117 | # Find a reasonable "focus depth" for this dataset 118 | dt = 0.75 119 | if near_fars is not None: 120 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 121 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 122 | 123 | # Get radii for spiral path 124 | tt = c2ws_all[:, :3, 3] 125 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale 126 | render_poses = render_path_spiral(c2w, up, rads, focal, zrate=.5, N=N_views, N_rots=n_rot) 127 | return np.stack(render_poses) 128 | 129 | 130 | def _find_test_sample(mtx): 131 | mtxcp = mtx[...] 
132 | mtxcp[:, 0] = 1 133 | zero_inds = np.argwhere(mtxcp == 0) 134 | choose_zero_inds = np.random.choice(np.array(range(zero_inds.shape[0])), 3).tolist() 135 | # choose_zero_inds = [[0, 0], [5, 0], [11, 0]] 136 | 137 | sample_mtx = np.zeros_like(mtx) 138 | for idx in choose_zero_inds: 139 | sample_mtx[tuple(zero_inds[idx])] = 1 140 | 141 | return sample_mtx 142 | 143 | class LLFFDataset: 144 | # black = T.ToTensor()(sio.loadmat('./myspecdata/decorner/meanblack.mat')['data']) 145 | filters_back = [] 146 | depth_mean = 1 147 | 148 | def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8): 149 | """ 150 | spheric_poses: whether the images are taken in a spheric inward-facing manner 151 | default: False (forward-facing) 152 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 153 | """ 154 | 155 | self.root_dir = datadir 156 | self.split = split 157 | self.hold_every = hold_every 158 | self.is_stack = is_stack 159 | self.downsample = downsample 160 | 161 | if args.lsc: 162 | # fix the vignetting effect 163 | LLFFDataset.white = T.ToTensor()(sio.loadmat('./meanwhite_max.mat')['data']) 164 | 165 | 166 | self.parameter_setting() 167 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 168 | self.prepare_filters(LLFFDataset) 169 | self.read_meta() # read poses, also for depth data and rays 170 | self.load_img() # read images and form rays 171 | self.white_bg = args.white_bkgd 172 | 173 | 174 | def parameter_setting(self): 175 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])] 176 | self.near_far = [0.0, 1.0] if args.ndc_ray == 1 else [0.01, 6.0] 177 | self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]) if args.ndc_ray == 1 else \ 178 | torch.tensor([[-7.0, -7.0, -5], [7.0, 7.0, 5]]) 179 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]]) 180 | self.center = 
torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3) 181 | self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 182 | 183 | 184 | @staticmethod 185 | def prepare_filters(cls): 186 | bandstart = args.band_start_idx 187 | if not isinstance(cls.filters_back, list): 188 | return 189 | for i in range(0, args.filters + 1): 190 | fi = torch.FloatTensor( 191 | np.diagonal(sio.loadmat(os.path.join(args.datadir, f'../{args.filters_folder}/f_{i}.mat'))['filter']) 192 | )[bandstart: bandstart + args.spec_channel] 193 | 194 | cls.filters_back.append(fi) 195 | cls.filters_back = torch.stack(cls.filters_back) 196 | 197 | 198 | def read_meta(self): 199 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17) 200 | 201 | # load full resolution image then resize 202 | if self.split in ['train', 'test']: 203 | assert len(poses_bounds) == args.angles, \ 204 | 'Mismatch between number of args.angles and number of poses! Please rerun COLMAP!' 205 | 206 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 207 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2) 208 | 209 | if args.depth_supervise: 210 | self.depth_list = self.load_colmap_depth(poses, self.near_fars, 211 | Path(self.root_dir), self.downsample) 212 | 213 | croph, cropw = args.crop_hw 214 | poses[:,0, -1] = croph 215 | poses[:,1, -1] = cropw 216 | 217 | # Step 1: rescale focal length according to training resolution 218 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images 219 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)]) 220 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H] 221 | 222 | # Step 2: correct poses 223 | # Original poses has rotation in form "down right back", change to "right up back" 224 | # See https://github.com/bmild/nerf/issues/34 225 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 226 | # (N_images, 
3, 4) exclude H, W, focal 227 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv) 228 | 229 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 230 | # See https://github.com/bmild/nerf/issues/34 231 | near_original = self.near_fars.min() 232 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 233 | # the nearest depth is at 1/0.75=1.33 234 | self.near_fars /= scale_factor 235 | self.poses[..., 3] /= scale_factor 236 | 237 | # build rendering path 238 | N_views, N_rots = 60, 1 # 120, 2 239 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views, n_rot=N_rots, rads_scale=0.3) 240 | 241 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 242 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to 243 | # center image 244 | 245 | # ray directions for all pixels, same for all images (same H, W, focal) 246 | W, H = self.img_wh 247 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3) 248 | 249 | if args.depth_supervise: 250 | self.depth_rays = self.get_depth_rays() 251 | self.combine_depthimages() 252 | 253 | 254 | def _fix_sample_matrix(self): 255 | self.training_matrix = np.hstack((np.array([0] * args.angles)[:, np.newaxis], 256 | sio.loadmat(args.sample_matrix_dir)['mask'])) # first column is pure rgb image 257 | self.training_matrix[:, args.colIdx4RGBTrain] = 1 # the column image is used for geometry training 258 | sample_matrix = self.training_matrix if self.split == 'train' else _find_test_sample(self.training_matrix) 259 | print(f'{sample_matrix.sum()} of images are loading...') 260 | 261 | return sample_matrix 262 | 263 | 264 | def load_img(self): 265 | rays_savePath = Path(args.datadir) / f"rays_idgeo{args.colIdx4RGBTrain}_ndc{args.ndc_ray}_{self.split}_ds{self.downsample}_mtx{os.path.split(args.sample_matrix_dir)[1][:-4]}.pth" 266 | folders = [Path(args.datadir) / args.img_dir_name.replace('??', str(i)) 267 
| for i in range(args.angles)] 268 | sample_matrix = self._fix_sample_matrix() 269 | 270 | W, H = self.img_wh 271 | # use first N_images-1 to train, the LAST is val 272 | if os.path.exists(rays_savePath): 273 | data = torch.load(rays_savePath) 274 | all_rays, all_rgbs, all_poses, all_filtersIdx, ids4shapeTrain = \ 275 | data['rays'], data['rgbs'], data['poses'], data['filterids'], data['id4geo'] 276 | else: 277 | all_rays = [] 278 | all_rgbs = [] 279 | all_poses = [] 280 | all_filtersIdx = [] 281 | ids4shapeTrain = [] 282 | tensor_cropper = T.CenterCrop(args.crop_hw) 283 | tensor_resizer = T.Resize([H, W], antialias=True) 284 | for r, row in enumerate(sample_matrix): 285 | image_paths = sorted(glob.glob(str(folders[r] / f"images/*{args.img_ext}"))) 286 | for c, aimEle in enumerate(row): 287 | if aimEle != 1: 288 | continue 289 | elif c == args.colIdx4RGBTrain: 290 | # it's for geometry training 291 | ids4shapeTrain.append(len(all_rays)) 292 | 293 | image_path = image_paths[c] 294 | c2w = torch.FloatTensor(self.poses[r]) 295 | 296 | img = self.read_non_raw(image_path) # c h w [0-1] 297 | if args.lsc: 298 | img = LLFFDataset.img_correction(img) # lens shade correction & black level correction 299 | img = tensor_cropper(img) # c croph cropw [0-1] 300 | if self.downsample != 1.0: 301 | img = tensor_resizer(img) 302 | 303 | img = img.view(args.observation_channel, -1).permute(1, 0) # (h*w, 3) RGB 304 | all_rgbs.append(img) 305 | 306 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 307 | if args.ndc_ray == 1: 308 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 309 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 310 | all_rays.append(torch.cat([rays_o, rays_d], 1)) # (h*w, 6) 311 | all_poses.append(torch.LongTensor([[r]]).expand([rays_o.shape[0], -1])) 312 | all_filtersIdx.append(torch.LongTensor([[c]]).expand([rays_o.shape[0], -1])) 313 | 314 | torch.save({ 315 | 'rgbs': all_rgbs, 316 | 'rays': all_rays, 317 
| 'poses': all_poses, 318 | 'filterids': all_filtersIdx, 319 | 'id4geo': ids4shapeTrain 320 | }, rays_savePath) 321 | print(f'{len(all_rgbs)} of images are loaded!') 322 | 323 | self.ids4shapeTrain = ids4shapeTrain 324 | self.raysnum_oneimage = all_rays[0].shape[0] 325 | if not self.is_stack: 326 | self.all_rays = torch.cat(all_rays, 0) # (len(self.meta['frames])*h*w, 3) 327 | self.all_rgbs = torch.cat(all_rgbs, 0) # (len(self.meta['frames])*h*w,3) 328 | self.all_poses = torch.cat(all_poses, 0) 329 | self.all_filtersIdx = torch.cat(all_filtersIdx, 0) 330 | else: 331 | self.all_rays = torch.stack(all_rays, 0) # (len(self.meta['frames]),h*w, 3) 332 | self.all_rgbs = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], args.observation_channel) # (len(self.meta['frames]),h,w,3) 333 | self.all_poses = torch.stack(all_poses, 0) 334 | self.all_filtersIdx = torch.stack(all_filtersIdx, 0) 335 | 336 | 337 | def load_colmap_depth(self, poses, bds_raw, basedir, factor=8, bd_factor=.75): 338 | croph, cropw = args.crop_hw 339 | images = read_images_binary(Path(basedir) / 'sparse' / '0' / 'images.bin') 340 | points = read_points3d_binary(Path(basedir) / 'sparse' / '0' / 'points3D.bin') 341 | 342 | Errs = np.array([point3D.error for point3D in points.values()]) 343 | Err_mean = np.mean(Errs) 344 | print("Mean Projection Error:", Err_mean) 345 | 346 | # print(bds_raw.shape) 347 | # Rescale if bd_factor is provided 348 | sc = 1. if bd_factor is None else 1. / (bds_raw.min() * bd_factor) 349 | 350 | H, W, focal = poses[0, :, -1] 351 | near = np.ndarray.min(bds_raw) * .9 * sc 352 | far = np.ndarray.max(bds_raw) * 1. 
* sc 353 | print('near/far:', near, far) 354 | 355 | data_list = [] 356 | for id_im in range(1, len(images) + 1): 357 | # if id_im - 1 not in image_pick4depth: 358 | # continue 359 | depth_list = [] 360 | coord_list = [] 361 | weight_list = [] 362 | for i in range(len(images[id_im].xys)): 363 | point2D = images[id_im].xys[i] # w h 364 | if np.abs(point2D[0] - 0.5*W) > cropw * 0.5 or np.abs(point2D[1] - 0.5*H) > croph * 0.5: 365 | # outside of the crop border 366 | continue 367 | id_3D = images[id_im].point3D_ids[i] 368 | if id_3D == -1: 369 | continue 370 | point3D = points[id_3D].xyz 371 | depth = ((-poses[id_im - 1, :3, 2]).T @ (point3D - poses[id_im - 1, :3, 3])) * sc 372 | if depth < bds_raw[id_im - 1, 0] * sc or depth > bds_raw[id_im - 1, 1] * sc: 373 | continue 374 | err = points[id_3D].error 375 | weight = 2 * np.exp(-(err / Err_mean) ** 2) 376 | depth_list.append(depth) 377 | # re-position the coords relative to the crop size 378 | coord_list.append((point2D - [0.5*(W-cropw), 0.5*(H-croph)]) / factor) 379 | weight_list.append(weight) 380 | if len(depth_list) > 0: 381 | print(id_im, len(depth_list), np.min(depth_list), np.max(depth_list), np.mean(depth_list)) 382 | data_list.append( 383 | {"depth": np.array(depth_list, dtype=np.float32), 384 | "coord": np.array(coord_list, dtype=np.float32), 385 | "weight": np.array(weight_list, dtype=np.float32)}) 386 | else: 387 | print(id_im, len(depth_list)) 388 | # json.dump(data_list, open(data_file, "w")) 389 | return data_list 390 | 391 | 392 | def get_depth_rays(self): 393 | W, H = self.img_wh 394 | 395 | data_list = [] 396 | for ind, img_d in enumerate(self.depth_list): 397 | coord = torch.from_numpy(img_d['coord']) 398 | rays_o, rays_d = get_rays_by_coord_np(H, W, self.focal, self.poses[ind], coord) 399 | if args.ndc_ray == 1: 400 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 401 | rayso_d = torch.cat([rays_o.float(), rays_d.float()], 1) # (h*w, 6) 402 | 403 | 
data_list.append(rayso_d) 404 | return data_list 405 | 406 | 407 | def combine_depthimages(self): 408 | num = len(self.depth_list) 409 | depth_rays = [] 410 | depth_weight = [] 411 | depth_value = [] 412 | 413 | for i in range(num): 414 | d_list = self.depth_list[i] 415 | d_rays = self.depth_rays[i] 416 | 417 | depth_rays.append(d_rays) 418 | depth_weight.append(torch.from_numpy(d_list['weight'][:, np.newaxis])) 419 | depth_value.append(torch.from_numpy(d_list['depth'][:, np.newaxis])) 420 | 421 | self.depth_rays = torch.cat(depth_rays, 0) 422 | self.depth_weight = torch.cat(depth_weight, 0) 423 | self.depth_value = torch.cat(depth_value, 0) 424 | 425 | LLFFDataset.depth_mean = self.depth_value.mean() 426 | 427 | def read_non_raw(self, image_path): 428 | is_dng = image_path.endswith('.dng') 429 | if is_dng: 430 | try: 431 | with rawpy.imread(image_path) as raw: 432 | rgb = raw.postprocess(user_wb=(1, 1, 1, 1), output_color=ColorSpace.raw, 433 | no_auto_bright=True, output_bps=16, gamma=(1, 1)) 434 | except: 435 | print(traceback.format_exc()) 436 | img = self.transform(rgb) # c h w [0~2^16] 437 | elif image_path.endswith('.tiff'): 438 | rgb = Image.open(image_path) 439 | im_arr = np.array(rgb, dtype=np.uint16)[..., np.newaxis] 440 | img = self.transform(im_arr) # c h w [0-1] 441 | else: 442 | img = Image.open(image_path).convert('RGB') 443 | img = self.transform(img, 255.) 
# c h w [0-1] 444 | 445 | return img 446 | 447 | @classmethod 448 | def img_correction(cls, img): 449 | return torch.clamp_max(torch.clamp_min(img - 0.014, 0.002) / cls.white, 1) 450 | # return (img - 0.014) / cls.white 451 | 452 | def transform(self, img, maxbits_num=65535.): 453 | return torch.FloatTensor(img / maxbits_num).permute(2, 0, 1) 454 | 455 | 456 | def get_dataset4RGBtraining(dataset: LLFFDataset): 457 | allrays, allrgbs, allposesID, all_filterID = [], [], [], [] 458 | id_in_group = dataset.ids4shapeTrain 459 | group_length = dataset.raysnum_oneimage 460 | 461 | for i in id_in_group: 462 | slc = slice(i * group_length, (i + 1) * group_length) 463 | 464 | allrays.append(dataset.all_rays[slc]) 465 | allrgbs.append(dataset.all_rgbs[slc]) 466 | allposesID.append(dataset.all_poses[slc]) 467 | all_filterID.append(dataset.all_filtersIdx[slc]) 468 | 469 | return torch.cat(allrays), torch.cat(allrgbs), torch.cat(allposesID), torch.cat(all_filterID) 470 | -------------------------------------------------------------------------------- /colmapUtils/read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 
13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | import argparse 38 | 39 | 40 | CameraModel = collections.namedtuple( 41 | "CameraModel", ["model_id", "model_name", "num_params"]) 42 | Camera = collections.namedtuple( 43 | "Camera", ["id", "model", "width", "height", "params"]) 44 | BaseImage = collections.namedtuple( 45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 46 | Point3D = collections.namedtuple( 47 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 48 | 49 | 50 | class Image(BaseImage): 51 | def qvec2rotmat(self): 52 | return qvec2rotmat(self.qvec) 53 | 54 | 55 | CAMERA_MODELS = { 56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 63 | CameraModel(model_id=7, model_name="FOV", num_params=5), 64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 67 | } 68 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 69 | for camera_model in CAMERA_MODELS]) 70 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 71 | for camera_model in CAMERA_MODELS]) 72 | 73 | 74 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 75 | """Read and unpack the next bytes from a binary file. 76 | :param fid: 77 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc. 78 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 79 | :param endian_character: Any of {@, =, <, >, !} 80 | :return: Tuple of read and unpacked values. 81 | """ 82 | data = fid.read(num_bytes) 83 | return struct.unpack(endian_character + format_char_sequence, data) 84 | 85 | 86 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 87 | """pack and write to a binary file. 88 | :param fid: 89 | :param data: data to send, if multiple elements are sent at the same time, 90 | they should be encapsuled either in a list or a tuple 91 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 92 | should be the same length as the data list or tuple 93 | :param endian_character: Any of {@, =, <, >, !} 94 | """ 95 | if isinstance(data, (list, tuple)): 96 | bytes = struct.pack(endian_character + format_char_sequence, *data) 97 | else: 98 | bytes = struct.pack(endian_character + format_char_sequence, data) 99 | fid.write(bytes) 100 | 101 | 102 | def read_cameras_text(path): 103 | """ 104 | see: src/base/reconstruction.cc 105 | void Reconstruction::WriteCamerasText(const std::string& path) 106 | void Reconstruction::ReadCamerasText(const std::string& path) 107 | """ 108 | cameras = {} 109 | with open(path, "r") as fid: 110 | while True: 111 | line = fid.readline() 112 | if not line: 113 | break 114 | line = line.strip() 115 | if len(line) > 0 and line[0] != "#": 116 | elems = line.split() 117 | camera_id = int(elems[0]) 118 | model = elems[1] 119 | width = int(elems[2]) 120 | height = int(elems[3]) 121 | params = np.array(tuple(map(float, elems[4:]))) 122 | cameras[camera_id] = Camera(id=camera_id, model=model, 123 | width=width, height=height, 124 | params=params) 125 | return cameras 126 | 127 | 128 | def read_cameras_binary(path_to_model_file): 129 | """ 130 | see: src/base/reconstruction.cc 131 | void Reconstruction::WriteCamerasBinary(const std::string& path) 132 | void 
Reconstruction::ReadCamerasBinary(const std::string& path) 133 | """ 134 | cameras = {} 135 | with open(path_to_model_file, "rb") as fid: 136 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 137 | for _ in range(num_cameras): 138 | camera_properties = read_next_bytes( 139 | fid, num_bytes=24, format_char_sequence="iiQQ") 140 | camera_id = camera_properties[0] 141 | model_id = camera_properties[1] 142 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 143 | width = camera_properties[2] 144 | height = camera_properties[3] 145 | num_params = CAMERA_MODEL_IDS[model_id].num_params 146 | params = read_next_bytes(fid, num_bytes=8*num_params, 147 | format_char_sequence="d"*num_params) 148 | cameras[camera_id] = Camera(id=camera_id, 149 | model=model_name, 150 | width=width, 151 | height=height, 152 | params=np.array(params)) 153 | assert len(cameras) == num_cameras 154 | return cameras 155 | 156 | 157 | def write_cameras_text(cameras, path): 158 | """ 159 | see: src/base/reconstruction.cc 160 | void Reconstruction::WriteCamerasText(const std::string& path) 161 | void Reconstruction::ReadCamerasText(const std::string& path) 162 | """ 163 | HEADER = "# Camera list with one line of data per camera:\n" 164 | "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" 165 | "# Number of cameras: {}\n".format(len(cameras)) 166 | with open(path, "w") as fid: 167 | fid.write(HEADER) 168 | for _, cam in cameras.items(): 169 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 170 | line = " ".join([str(elem) for elem in to_write]) 171 | fid.write(line + "\n") 172 | 173 | 174 | def write_cameras_binary(cameras, path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void Reconstruction::WriteCamerasBinary(const std::string& path) 178 | void Reconstruction::ReadCamerasBinary(const std::string& path) 179 | """ 180 | with open(path_to_model_file, "wb") as fid: 181 | write_next_bytes(fid, len(cameras), "Q") 182 | for _, cam in cameras.items(): 183 | 
model_id = CAMERA_MODEL_NAMES[cam.model].model_id 184 | camera_properties = [cam.id, 185 | model_id, 186 | cam.width, 187 | cam.height] 188 | write_next_bytes(fid, camera_properties, "iiQQ") 189 | for p in cam.params: 190 | write_next_bytes(fid, float(p), "d") 191 | return cameras 192 | 193 | 194 | def read_images_text(path): 195 | """ 196 | see: src/base/reconstruction.cc 197 | void Reconstruction::ReadImagesText(const std::string& path) 198 | void Reconstruction::WriteImagesText(const std::string& path) 199 | """ 200 | images = {} 201 | with open(path, "r") as fid: 202 | while True: 203 | line = fid.readline() 204 | if not line: 205 | break 206 | line = line.strip() 207 | if len(line) > 0 and line[0] != "#": 208 | elems = line.split() 209 | image_id = int(elems[0]) 210 | qvec = np.array(tuple(map(float, elems[1:5]))) 211 | tvec = np.array(tuple(map(float, elems[5:8]))) 212 | camera_id = int(elems[8]) 213 | image_name = elems[9] 214 | elems = fid.readline().split() 215 | xys = np.column_stack([tuple(map(float, elems[0::3])), 216 | tuple(map(float, elems[1::3]))]) 217 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 218 | images[image_id] = Image( 219 | id=image_id, qvec=qvec, tvec=tvec, 220 | camera_id=camera_id, name=image_name, 221 | xys=xys, point3D_ids=point3D_ids) 222 | return images 223 | 224 | 225 | def read_images_binary(path_to_model_file): 226 | """ 227 | see: src/base/reconstruction.cc 228 | void Reconstruction::ReadImagesBinary(const std::string& path) 229 | void Reconstruction::WriteImagesBinary(const std::string& path) 230 | """ 231 | images = {} 232 | with open(path_to_model_file, "rb") as fid: 233 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 234 | for _ in range(num_reg_images): 235 | binary_image_properties = read_next_bytes( 236 | fid, num_bytes=64, format_char_sequence="idddddddi") 237 | image_id = binary_image_properties[0] 238 | qvec = np.array(binary_image_properties[1:5]) 239 | tvec = np.array(binary_image_properties[5:8]) 240 | 
camera_id = binary_image_properties[8] 241 | image_name = "" 242 | current_char = read_next_bytes(fid, 1, "c")[0] 243 | while current_char != b"\x00": # look for the ASCII 0 entry 244 | image_name += current_char.decode("utf-8") 245 | current_char = read_next_bytes(fid, 1, "c")[0] 246 | num_points2D = read_next_bytes(fid, num_bytes=8, 247 | format_char_sequence="Q")[0] 248 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 249 | format_char_sequence="ddq"*num_points2D) 250 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 251 | tuple(map(float, x_y_id_s[1::3]))]) 252 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 253 | images[image_id] = Image( 254 | id=image_id, qvec=qvec, tvec=tvec, 255 | camera_id=camera_id, name=image_name, 256 | xys=xys, point3D_ids=point3D_ids) 257 | return images 258 | 259 | 260 | def write_images_text(images, path): 261 | """ 262 | see: src/base/reconstruction.cc 263 | void Reconstruction::ReadImagesText(const std::string& path) 264 | void Reconstruction::WriteImagesText(const std::string& path) 265 | """ 266 | if len(images) == 0: 267 | mean_observations = 0 268 | else: 269 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 270 | HEADER = "# Image list with two lines of data per image:\n" 271 | "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" 272 | "# POINTS2D[] as (X, Y, POINT3D_ID)\n" 273 | "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations) 274 | 275 | with open(path, "w") as fid: 276 | fid.write(HEADER) 277 | for _, img in images.items(): 278 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 279 | first_line = " ".join(map(str, image_header)) 280 | fid.write(first_line + "\n") 281 | 282 | points_strings = [] 283 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 284 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 285 | fid.write(" ".join(points_strings) + "\n") 286 | 
287 | 288 | def write_images_binary(images, path_to_model_file): 289 | """ 290 | see: src/base/reconstruction.cc 291 | void Reconstruction::ReadImagesBinary(const std::string& path) 292 | void Reconstruction::WriteImagesBinary(const std::string& path) 293 | """ 294 | with open(path_to_model_file, "wb") as fid: 295 | write_next_bytes(fid, len(images), "Q") 296 | for _, img in images.items(): 297 | write_next_bytes(fid, img.id, "i") 298 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 299 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 300 | write_next_bytes(fid, img.camera_id, "i") 301 | for char in img.name: 302 | write_next_bytes(fid, char.encode("utf-8"), "c") 303 | write_next_bytes(fid, b"\x00", "c") 304 | write_next_bytes(fid, len(img.point3D_ids), "Q") 305 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 306 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 307 | 308 | 309 | def read_points3D_text(path): 310 | """ 311 | see: src/base/reconstruction.cc 312 | void Reconstruction::ReadPoints3DText(const std::string& path) 313 | void Reconstruction::WritePoints3DText(const std::string& path) 314 | """ 315 | points3D = {} 316 | with open(path, "r") as fid: 317 | while True: 318 | line = fid.readline() 319 | if not line: 320 | break 321 | line = line.strip() 322 | if len(line) > 0 and line[0] != "#": 323 | elems = line.split() 324 | point3D_id = int(elems[0]) 325 | xyz = np.array(tuple(map(float, elems[1:4]))) 326 | rgb = np.array(tuple(map(int, elems[4:7]))) 327 | error = float(elems[7]) 328 | image_ids = np.array(tuple(map(int, elems[8::2]))) 329 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 330 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 331 | error=error, image_ids=image_ids, 332 | point2D_idxs=point2D_idxs) 333 | return points3D 334 | 335 | 336 | def read_points3d_binary(path_to_model_file): 337 | """ 338 | see: src/base/reconstruction.cc 339 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 340 | void 
Reconstruction::WritePoints3DBinary(const std::string& path) 341 | """ 342 | points3D = {} 343 | with open(path_to_model_file, "rb") as fid: 344 | num_points = read_next_bytes(fid, 8, "Q")[0] 345 | for _ in range(num_points): 346 | binary_point_line_properties = read_next_bytes( 347 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 348 | point3D_id = binary_point_line_properties[0] 349 | xyz = np.array(binary_point_line_properties[1:4]) 350 | rgb = np.array(binary_point_line_properties[4:7]) 351 | error = np.array(binary_point_line_properties[7]) 352 | track_length = read_next_bytes( 353 | fid, num_bytes=8, format_char_sequence="Q")[0] 354 | track_elems = read_next_bytes( 355 | fid, num_bytes=8*track_length, 356 | format_char_sequence="ii"*track_length) 357 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 358 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 359 | points3D[point3D_id] = Point3D( 360 | id=point3D_id, xyz=xyz, rgb=rgb, 361 | error=error, image_ids=image_ids, 362 | point2D_idxs=point2D_idxs) 363 | return points3D 364 | 365 | 366 | def write_points3D_text(points3D, path): 367 | """ 368 | see: src/base/reconstruction.cc 369 | void Reconstruction::ReadPoints3DText(const std::string& path) 370 | void Reconstruction::WritePoints3DText(const std::string& path) 371 | """ 372 | if len(points3D) == 0: 373 | mean_track_length = 0 374 | else: 375 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) 376 | HEADER = "# 3D point list with one line of data per point:\n" 377 | "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" 378 | "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length) 379 | 380 | with open(path, "w") as fid: 381 | fid.write(HEADER) 382 | for _, pt in points3D.items(): 383 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 384 | fid.write(" ".join(map(str, point_header)) + " ") 385 | track_strings = [] 386 | for image_id, 
point2D in zip(pt.image_ids, pt.point2D_idxs): 387 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 388 | fid.write(" ".join(track_strings) + "\n") 389 | 390 | 391 | def write_points3d_binary(points3D, path_to_model_file): 392 | """ 393 | see: src/base/reconstruction.cc 394 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 395 | void Reconstruction::WritePoints3DBinary(const std::string& path) 396 | """ 397 | with open(path_to_model_file, "wb") as fid: 398 | write_next_bytes(fid, len(points3D), "Q") 399 | for _, pt in points3D.items(): 400 | write_next_bytes(fid, pt.id, "Q") 401 | write_next_bytes(fid, pt.xyz.tolist(), "ddd") 402 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 403 | write_next_bytes(fid, pt.error, "d") 404 | track_length = pt.image_ids.shape[0] 405 | write_next_bytes(fid, track_length, "Q") 406 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 407 | write_next_bytes(fid, [image_id, point2D_id], "ii") 408 | 409 | 410 | def detect_model_format(path, ext): 411 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ 412 | os.path.isfile(os.path.join(path, "images" + ext)) and \ 413 | os.path.isfile(os.path.join(path, "points3D" + ext)): 414 | print("Detected model format: '" + ext + "'") 415 | return True 416 | 417 | return False 418 | 419 | 420 | def read_model(path, ext=""): 421 | # try to detect the extension automatically 422 | if ext == "": 423 | if detect_model_format(path, ".bin"): 424 | ext = ".bin" 425 | elif detect_model_format(path, ".txt"): 426 | ext = ".txt" 427 | else: 428 | print("Provide model format: '.bin' or '.txt'") 429 | return 430 | 431 | if ext == ".txt": 432 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 433 | images = read_images_text(os.path.join(path, "images" + ext)) 434 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 435 | else: 436 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 437 | images = 
read_images_binary(os.path.join(path, "images" + ext)) 438 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 439 | return cameras, images, points3D 440 | 441 | 442 | def write_model(cameras, images, points3D, path, ext=".bin"): 443 | if ext == ".txt": 444 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 445 | write_images_text(images, os.path.join(path, "images" + ext)) 446 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 447 | else: 448 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 449 | write_images_binary(images, os.path.join(path, "images" + ext)) 450 | write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) 451 | return cameras, images, points3D 452 | 453 | 454 | def qvec2rotmat(qvec): 455 | return np.array([ 456 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 457 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 458 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 459 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 460 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 461 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 462 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 463 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 464 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 465 | 466 | 467 | def rotmat2qvec(R): 468 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 469 | K = np.array([ 470 | [Rxx - Ryy - Rzz, 0, 0, 0], 471 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 472 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 473 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 474 | eigvals, eigvecs = np.linalg.eigh(K) 475 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 476 | if qvec[0] < 0: 477 | qvec *= -1 478 | return qvec 479 | 480 | 481 | def main(): 482 | parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models") 483 | parser.add_argument("--input_model", help="path to input model folder") 484 | parser.add_argument("--input_format", 
choices=[".bin", ".txt"], 485 | help="input model format", default="") 486 | parser.add_argument("--output_model", 487 | help="path to output model folder") 488 | parser.add_argument("--output_format", choices=[".bin", ".txt"], 489 | help="outut model format", default=".txt") 490 | args = parser.parse_args() 491 | 492 | cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) 493 | 494 | print("num_cameras:", len(cameras)) 495 | print("num_images:", len(images)) 496 | print("num_points3D:", len(points3D)) 497 | 498 | if args.output_model is not None: 499 | write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) 500 | 501 | 502 | if __name__ == "__main__": 503 | main() 504 | --------------------------------------------------------------------------------