├── models ├── __init__.py ├── __pycache__ │ ├── sh.cpython-39.pyc │ ├── tensoRF.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── tensoRF_VQ.cpython-39.pyc │ ├── tensorBase.cpython-39.pyc │ └── weighted_vq.cpython-39.pyc ├── sh.py ├── weighted_vq.py ├── tensorBase.py └── my_vq.py ├── figures ├── teaser.png └── pipeline.png ├── dataLoader ├── __pycache__ │ ├── llff.cpython-39.pyc │ ├── nsvf.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── blender.cpython-39.pyc │ ├── ray_utils.cpython-39.pyc │ ├── tankstemple.cpython-39.pyc │ └── your_own_data.cpython-39.pyc ├── __init__.py ├── blender.py ├── your_own_data.py ├── nsvf.py ├── tankstemple.py ├── llff.py ├── ray_utils.py └── colmap2nerf.py ├── configs ├── vq │ ├── llff.txt │ ├── tt.txt │ ├── syn.txt │ └── nsvf.txt ├── flower.txt ├── truck.txt ├── wineholder.txt ├── lego.txt └── your_own_data.txt ├── eval_vq_only.py ├── README.md ├── renderer.py ├── extra ├── compute_metrics.py └── auto_run_paramsets.py ├── result.md ├── opt.py ├── autotask_vq.py ├── utils.py └── LICENSE /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/figures/teaser.png -------------------------------------------------------------------------------- /figures/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/figures/pipeline.png -------------------------------------------------------------------------------- /models/__pycache__/sh.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/sh.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/tensoRF.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/tensoRF.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/llff.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/llff.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/nsvf.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/nsvf.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/tensoRF_VQ.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/tensoRF_VQ.cpython-39.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/tensorBase.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/tensorBase.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/blender.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/blender.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/weighted_vq.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/models/__pycache__/weighted_vq.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/ray_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/ray_utils.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/tankstemple.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/tankstemple.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/your_own_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spark001/VQ-TensoRF/HEAD/dataLoader/__pycache__/your_own_data.cpython-39.pyc -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .llff import LLFFDataset 2 | from .blender import BlenderDataset 3 | from .nsvf import NSVF 4 | from .tankstemple import TanksTempleDataset 5 | from .your_own_data import YourOwnDataset 6 | 7 | 8 | 9 | dataset_dict = {'blender': BlenderDataset, 10 | 'llff':LLFFDataset, 11 | 'tankstemple':TanksTempleDataset, 12 | 'nsvf':NSVF, 13 | 'own_data':YourOwnDataset} -------------------------------------------------------------------------------- /configs/vq/llff.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = llff 3 | 4 | downsample_train = 4.0 5 | ndc_ray = 10 6 | vq_iters = 10000 7 | vq_up_interval = 10 8 | pct_high = 0.7 9 | pct_mid = 1.0 10 | 11 | n_iters = 25000 12 | batch_size = 4096 13 | 14 | N_voxel_init = 2097156 # 128**3 15 | N_voxel_final = 262144000 # 640**3 16 | upsamp_list = [2000,3000,4000,5500] 17 | update_AlphaMask_list = [2500] 18 | 19 | N_vis = -1 # vis all testing images 20 | vis_every = 10000 21 | 22 | render_test = 1 23 | render_path = 1 24 | 25 | n_lamb_sigma = [16,4,4] 26 | n_lamb_sh = [48,12,12] 27 | 28 | shadingMode = MLP_Fea 29 | 
fea2denseAct = relu 30 | 31 | view_pe = 0 32 | fea_pe = 0 33 | 34 | TV_weight_density = 1.0 35 | TV_weight_app = 1.0 36 | 37 | -------------------------------------------------------------------------------- /configs/flower.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = llff 3 | datadir = ./data/nerf_llff_data/flower 4 | expname = tensorf_flower_VM 5 | basedir = ./log 6 | 7 | downsample_train = 4.0 8 | ndc_ray = 1 9 | 10 | n_iters = 25000 11 | batch_size = 4096 12 | 13 | N_voxel_init = 2097156 # 128**3 14 | N_voxel_final = 262144000 # 640**3 15 | upsamp_list = [2000,3000,4000,5500] 16 | update_AlphaMask_list = [2500] 17 | 18 | N_vis = -1 # vis all testing images 19 | vis_every = 10000 20 | 21 | render_test = 1 22 | render_path = 1 23 | 24 | n_lamb_sigma = [16,4,4] 25 | n_lamb_sh = [48,12,12] 26 | 27 | shadingMode = MLP_Fea 28 | fea2denseAct = relu 29 | 30 | view_pe = 0 31 | fea_pe = 0 32 | 33 | TV_weight_density = 1.0 34 | TV_weight_app = 1.0 35 | 36 | -------------------------------------------------------------------------------- /configs/vq/tt.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | dataset_name = tankstemple 4 | 5 | vq_iters = 1000 6 | vq_up_interval = 1 7 | pct_high = 0.7 8 | pct_mid = 0.999 9 | 10 | n_iters = 30000 11 | batch_size = 4096 12 | 13 | N_voxel_init = 2097156 # 128**3 14 | N_voxel_final = 27000000 # 300**3 15 | upsamp_list = [2000,3000,4000,5500,7000] 16 | update_AlphaMask_list = [2000,4000] 17 | 18 | N_vis = 5 19 | vis_every = 10000 20 | 21 | render_test = 1 22 | 23 | n_lamb_sigma = [16,16,16] 24 | n_lamb_sh = [48,48,48] 25 | 26 | shadingMode = MLP_Fea 27 | fea2denseAct = softplus 28 | 29 | view_pe = 2 30 | fea_pe = 2 31 | 32 | TV_weight_density = 0.1 33 | TV_weight_app = 0.01 34 | 35 | ## please uncomment following configuration if hope to training on cp model 36 | #model_name = TensorCP 37 | #n_lamb_sigma = [96] 38 | #n_lamb_sh = [288] 39 | #N_voxel_final = 125000000 # 500**3 40 | #L1_weight_inital = 1e-5 41 | #L1_weight_rest = 1e-5 42 | 43 | -------------------------------------------------------------------------------- /configs/truck.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | dataset_name = tankstemple 4 | datadir = ./data/TanksAndTemple/Truck 5 | expname = tensorf_truck_VM 6 | basedir = ./log 7 | 8 | n_iters = 30000 9 | batch_size = 4096 10 | 11 | N_voxel_init = 2097156 # 128**3 12 | N_voxel_final = 27000000 # 300**3 13 | upsamp_list = [2000,3000,4000,5500,7000] 14 | update_AlphaMask_list = [2000,4000] 15 | 16 | N_vis = 5 17 | vis_every = 10000 18 | 19 | render_test = 1 20 | 21 | n_lamb_sigma = [16,16,16] 22 | n_lamb_sh = [48,48,48] 23 | 24 | shadingMode = MLP_Fea 25 | fea2denseAct = softplus 26 | 27 | view_pe = 2 28 | fea_pe = 2 29 | 30 | TV_weight_density = 0.1 31 | TV_weight_app = 0.01 32 | 33 | ## please uncomment following configuration if hope to training on cp model 34 | #model_name = TensorCP 35 | #n_lamb_sigma = [96] 36 | #n_lamb_sh = [288] 37 | #N_voxel_final = 125000000 # 500**3 38 | #L1_weight_inital = 1e-5 39 | #L1_weight_rest = 1e-5 40 | 41 | -------------------------------------------------------------------------------- /configs/vq/syn.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | 4 | vq_iters = 10000 5 | pct_high = 0.7 6 | pct_mid = 0.999 7 | 8 | n_iters = 30000 9 | batch_size = 4096 10 | 11 | 
N_voxel_init = 2097156 # 128**3 12 | N_voxel_final = 27000000 # 300**3 13 | upsamp_list = [2000,3000,4000,5500,7000] 14 | update_AlphaMask_list = [2000,4000] 15 | 16 | N_vis = 5 17 | vis_every = 10000 18 | 19 | render_test = 1 20 | 21 | n_lamb_sigma = [16,16,16] 22 | n_lamb_sh = [48,48,48] 23 | model_name = TensorVMSplit 24 | 25 | 26 | shadingMode = MLP_Fea 27 | fea2denseAct = softplus 28 | 29 | view_pe = 2 30 | fea_pe = 2 31 | 32 | L1_weight_inital = 8e-5 33 | L1_weight_rest = 4e-5 34 | rm_weight_mask_thre = 1e-4 35 | 36 | ## please uncomment following configuration if hope to training on cp model 37 | #model_name = TensorCP 38 | #n_lamb_sigma = [96] 39 | #n_lamb_sh = [288] 40 | #N_voxel_final = 125000000 # 500**3 41 | #L1_weight_inital = 1e-5 42 | #L1_weight_rest = 1e-5 43 | -------------------------------------------------------------------------------- /configs/wineholder.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = nsvf 3 | datadir = ./data/Synthetic_NSVF/Wineholder 4 | expname = tensorf_Wineholder_VM 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | 20 | n_lamb_sigma = [16,16,16] 21 | n_lamb_sh = [48,48,48] 22 | 23 | shadingMode = MLP_Fea 24 | fea2denseAct = softplus 25 | 26 | view_pe = 2 27 | fea_pe = 2 28 | 29 | L1_weight_inital = 8e-5 30 | L1_weight_rest = 4e-5 31 | rm_weight_mask_thre = 1e-4 32 | 33 | ## please uncomment following configuration if hope to training on cp model 34 | #model_name = TensorCP 35 | #n_lamb_sigma = [96] 36 | #n_lamb_sh = [288] 37 | #N_voxel_final = 125000000 # 500**3 38 | #L1_weight_inital = 1e-5 39 | #L1_weight_rest = 1e-5 40 | -------------------------------------------------------------------------------- /configs/vq/nsvf.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = nsvf 3 | 4 | 5 | vq_iters = 10000 6 | vq_up_interval = 10 7 | pct_high = 0.7 8 | pct_mid = 1.0 9 | #split_or_union = 1 10 | 11 | n_iters = 30000 12 | batch_size = 4096 13 | 14 | N_voxel_init = 2097156 # 128**3 15 | N_voxel_final = 27000000 # 300**3 16 | upsamp_list = [2000,3000,4000,5500,7000] 17 | update_AlphaMask_list = [2000,4000] 18 | 19 | N_vis = 5 20 | vis_every = 10000 21 | 22 | render_test = 1 23 | 24 | n_lamb_sigma = [16,16,16] 25 | n_lamb_sh = [48,48,48] 26 | 27 | shadingMode = MLP_Fea 28 | fea2denseAct = softplus 29 | 30 | view_pe = 2 31 | fea_pe = 2 32 | 33 | L1_weight_inital = 8e-5 34 | L1_weight_rest = 4e-5 35 | rm_weight_mask_thre = 1e-4 36 | 37 | ## please uncomment following configuration if hope to training on cp model 38 | #model_name = TensorCP 39 | #n_lamb_sigma = [96] 40 | #n_lamb_sh = [288] 41 | #N_voxel_final = 125000000 # 500**3 42 | #L1_weight_inital = 1e-5 43 | #L1_weight_rest = 1e-5 44 | -------------------------------------------------------------------------------- /configs/lego.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/nerf_synthetic/lego 4 | expname = tensorf_lego_VM 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = 
[2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | 20 | n_lamb_sigma = [16,16,16] 21 | n_lamb_sh = [48,48,48] 22 | model_name = TensorVMSplit 23 | 24 | 25 | shadingMode = MLP_Fea 26 | fea2denseAct = softplus 27 | 28 | view_pe = 2 29 | fea_pe = 2 30 | 31 | L1_weight_inital = 8e-5 32 | L1_weight_rest = 4e-5 33 | rm_weight_mask_thre = 1e-4 34 | 35 | ## please uncomment following configuration if hope to training on cp model 36 | #model_name = TensorCP 37 | #n_lamb_sigma = [96] 38 | #n_lamb_sh = [288] 39 | #N_voxel_final = 125000000 # 500**3 40 | #L1_weight_inital = 1e-5 41 | #L1_weight_rest = 1e-5 42 | -------------------------------------------------------------------------------- /configs/your_own_data.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = own_data 3 | datadir = ./data/xxx 4 | expname = tensorf_xxx_VM 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | 20 | n_lamb_sigma = [16,16,16] 21 | n_lamb_sh = [48,48,48] 22 | model_name = TensorVMSplit 23 | 24 | 25 | shadingMode = MLP_Fea 26 | fea2denseAct = softplus 27 | 28 | view_pe = 2 29 | fea_pe = 2 30 | 31 | view_pe = 2 32 | fea_pe = 2 33 | 34 | TV_weight_density = 0.1 35 | TV_weight_app = 0.01 36 | 37 | rm_weight_mask_thre = 1e-4 38 | 39 | ## please uncomment following configuration if hope to training on cp model 40 | #model_name = TensorCP 41 | #n_lamb_sigma = [96] 42 | #n_lamb_sh = [288] 43 | #N_voxel_final = 125000000 # 500**3 44 | #L1_weight_inital = 1e-5 45 | #L1_weight_rest = 1e-5 46 | -------------------------------------------------------------------------------- /eval_vq_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.auto import tqdm 3 | from opt import config_parser 4 | 5 | 6 | 7 | import json, random 8 | from renderer import * 9 | from utils import * 10 | import datetime 11 | 12 | from dataLoader import dataset_dict 13 | import sys 14 | 15 | 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | renderer = OctreeRender_trilinear_fast 20 | 21 | 22 | def eval_vq(args): 23 | dataset = dataset_dict[args.dataset_name] 24 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) 25 | white_bg = test_dataset.white_bg 26 | ndc_ray = args.ndc_ray 27 | 28 | assert os.path.isfile(args.ckpt), f'No checkpoint found at {args.ckpt}' 29 | 30 | ckpt = torch.load(args.ckpt, map_location=device) 31 | kwargs = ckpt['kwargs'] 32 | kwargs.update({'device':device}) 33 | tensorf = TensorVMSplitVQ(**kwargs) 34 | tensorf.extreme_load(ckpt) 35 | 36 | from functools import partial 37 | evaluation_test = partial(evaluation,test_dataset=test_dataset, args=args, 38 | renderer=renderer,white_bg = white_bg, ndc_ray=ndc_ray,device=device, 39 | compute_extra_metrics=args.autotask, im_save=args.autotask) 40 | logfolder = os.path.dirname(args.ckpt) 41 | os.makedirs(f'{logfolder}/extreme_load', exist_ok=True) 42 | TestLoad = evaluation_test(tensorf=tensorf, N_vis=-1, savePath=f"{logfolder}/extreme_load" if not args.debug else None) 43 | print(f'======>Test ExtremeLoad: AfterVQFinetune WithFinetune(Quant) {args.expname} test all psnr: {np.mean(TestLoad)} 
<========================') 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | torch.set_default_dtype(torch.float32) 49 | torch.manual_seed(20211202) 50 | np.random.seed(20211202) 51 | 52 | args = config_parser() 53 | print(args) 54 | 55 | eval_vq(args) 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compressing Volumetric Radiance Fields to 1 MB 2 | 3 | ## [Paper](https://openaccess.thecvf.com/content/CVPR2023/html/Li_Compressing_Volumetric_Radiance_Fields_to_1_MB_CVPR_2023_paper.html) 4 | 5 | 6 | **Update**: 🤗 We update our compressed models in [ModelScope](https://modelscope.cn/models/DAMOXR/cv_nerf_3d-reconstruction_vector-quantize-compression/summary), so you can test the models and render videos easily. 7 | 8 | 我们在[魔搭](https://modelscope.cn/models/DAMOXR/cv_nerf_3d-reconstruction_vector-quantize-compression/summary)上更新了压缩后的模型,更方便的支持在线测试和渲染视频。 9 | 10 | ![compression](figures/teaser.png) 11 | 12 | ![Pipeline](figures/pipeline.png) 13 | 14 | **Note**: This repository only contain VQ-TensoRF. 15 | 16 | VQ-DVGO please refer to [VQRF](https://github.com/AlgoHunt/VQRF). 17 | 18 | 19 | ## Setup 20 | 21 | - Download datasets: 22 | [NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1), 23 | [NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip), [T&T (masked)](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip), [LLFF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 24 | 25 | 26 | - Install required libraries, Please refer to [TensoRF](https://github.com/apchenstu/TensoRF) 27 | 28 | 29 | Please install the correct version of [Pytorch](https://pytorch.org/) and [torch_scatter](https://github.com/rusty1s/pytorch_scatter) for your machine. 30 | 31 | ## Directory structure for the datasets 32 | 33 | 35 | ``` 36 | data 37 | ├── nerf_synthetic # Link: https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 38 | │ └── [chair|drums|ficus|hotdog|lego|materials|mic|ship] 39 | │ ├── [train|val|test] 40 | │ │ └── r_*.png 41 | │ └── transforms_[train|val|test].json 42 | │ 43 | ├── Synthetic_NSVF # Link: https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip 44 | │ └── [Bike|Lifestyle|Palace|Robot|Spaceship|Steamtrain|Toad|Wineholder] 45 | │ ├── intrinsics.txt 46 | │ ├── rgb 47 | │ │ └── [0_train|1_val|2_test]_*.png 48 | │ └── pose 49 | │ └── [0_train|1_val|2_test]_*.txt 50 | │ 51 | ├── nerf_llff_data # Link: https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1 52 | │ └── [fern|flower|fortress|horns|leaves|orchids|room|trex] 53 | │    ├── poses_bounds.npy 54 | │    └── images 55 | │ 56 | └── TanksAndTemple # Link: https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip 57 | └── [Barn|Caterpillar|Family|Ignatius|Truck] 58 | ├── intrinsics.txt 59 | ├── rgb 60 | │ └── [0|1|2]_*.png 61 | └── pose 62 | └── [0|1|2]_*.txt 63 | 64 | ``` 65 | 66 | 67 | 68 | 69 | ## Training & VectQuantize & Testing 70 | 71 | The training script is in `vectquant.py`. 
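Besides the full example below, the vector-quantization stage itself is controlled by the options defined in `opt.py` (e.g. `--codebook_size`, `--vq_iters`, `--vq_up_interval`, `--pct_high`, `--pct_mid`) together with the presets in `configs/vq/*.txt`. A minimal sketch of how these flags combine, assuming `vectquant.py` parses its arguments with `config_parser` from `opt.py` in the same way `eval_vq_only.py` does (the values shown simply restate the defaults from `opt.py` and `configs/vq/syn.txt`, not a tuned setting):

```
python vectquant.py --config configs/vq/syn.txt --codebook_size 4096 --vq_iters 10000 --pct_high 0.7 --pct_mid 0.999
```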
72 | 73 | For example, to train a VectQuantized model on the synthetic dataset: 74 | 75 | ``` 76 | python vectquant.py --config configs/vq/syn.txt --datadir {syn_dataset_dir}/hotdog --expname hotdog --basedir ./log_reimp/syn --render_path 0 --render_only 0 --ckpt ./log_reimp/syn/hotdog/hotdog.th 77 | ``` 78 | 79 | The training script runs in three steps: 80 | * **Step 1**: Train a baseline model and save its checkpoint (following the vanilla TensoRF training pipeline). 81 | * **Step 2**: Train a VectQuantized model from the baseline checkpoint of Step 1 and save the VectQuantized checkpoint. 82 | * **Step 3**: Test the VectQuantized model checkpoint from Step 2. 83 | 84 | 85 | For more options, refer to `opt.py`. 86 | 87 | ## Autotask for a dataset 88 | 89 | `python autotask_vq.py -g "0 1 2 3" --dataset {dataset_name} --suffix v0` 90 | 91 | Modify your `data` directory in `DatasetSetting`. 92 | 93 | Set `dataset_name`; choices are ['syn', 'nsvf', 'tt', 'llff']. 94 | 95 | Set the `-g` option according to the available GPUs on your machine. 96 | 97 | > Notice: When you run the autotask script on multiple GPUs, disk I/O for data loading may become the bottleneck. 98 | 99 | ## Testing the VectQuantized model 100 | 101 | ``` 102 | python eval_vq_only.py --autotask --config configs/vq/syn.txt --datadir {syn_dataset_dir} --ckpt {VQ_model_checkpoint} 103 | 104 | ``` 105 | 106 | 107 | ## Acknowledgements 108 | This repository uses code from the following repositories. 109 | * [TensoRF](https://github.com/apchenstu/TensoRF) 110 | * [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch) 111 | 112 | ## Citation 113 | If you find our work useful in your research, please consider citing: 114 | 115 | ``` 116 | @inproceedings{li2023compressing, 117 | title={Compressing volumetric radiance fields to 1 mb}, 118 | author={Li, Lingzhi and Shen, Zhen and Wang, Zhongshu and Shen, Li and Bo, Liefeng}, 119 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 120 | pages={4222--4231}, 121 | year={2023} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /dataLoader/blender.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class BlenderDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.img_wh = (int(800/downsample),int(800/downsample)) 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [2.0,6.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self):
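        # Parses transforms_{split}.json, derives the focal length from camera_angle_x (scaled to img_wh),
        # and precomputes the normalized per-pixel ray directions plus the intrinsic matrix.
        # For every selected frame it converts the Blender pose to OpenCV convention, loads the RGBA image,
        # blends the alpha channel onto a white background, and accumulates rays (h*w, 6) and RGBs;
        # the per-image lists are then concatenated (training) or stacked per view (is_stack=True).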
40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = self.img_wh 45 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 46 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh 47 | 48 | 49 | # ray directions for all pixels, same for all images (same H, W, focal) 50 | self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3) 51 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 52 | self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float() 53 | 54 | self.image_paths = [] 55 | self.poses = [] 56 | self.all_rays = [] 57 | self.all_rgbs = [] 58 | self.all_masks = [] 59 | self.all_depth = [] 60 | self.downsample=1.0 61 | 62 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 63 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 64 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 65 | 66 | frame = self.meta['frames'][i] 67 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 68 | c2w = torch.FloatTensor(pose) 69 | self.poses += [c2w] 70 | 71 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 72 | self.image_paths += [image_path] 73 | img = Image.open(image_path) 74 | 75 | if self.downsample!=1.0: 76 | img = img.resize(self.img_wh, Image.LANCZOS) 77 | img = self.transform(img) # (4, h, w) 78 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 79 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 80 | self.all_rgbs += [img] 81 | 82 | 83 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 84 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 85 | 86 | 87 | self.poses = torch.stack(self.poses) 88 | if not self.is_stack: 89 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 90 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 91 | 92 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 93 | else: 94 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 95 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 96 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 97 | 98 | 99 | def define_transforms(self): 100 | self.transform = T.ToTensor() 101 | 102 | def define_proj_mat(self): 103 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 104 | 105 | def world2ndc(self,points,lindisp=None): 106 | device = points.device 107 | return (points - self.center.to(device)) / self.radius.to(device) 108 | 109 | def __len__(self): 110 | return len(self.all_rgbs) 111 | 112 | def __getitem__(self, idx): 113 | 114 | if self.split == 'train': # use data in the buffers 115 | sample = {'rays': self.all_rays[idx], 116 | 'rgbs': self.all_rgbs[idx]} 117 | 118 | else: # create data for each image separately 119 | 120 | img = self.all_rgbs[idx] 121 | rays = self.all_rays[idx] 122 | mask = self.all_masks[idx] # for quantity evaluation 123 | 124 | sample = {'rays': rays, 125 | 'rgbs': img, 126 | 'mask': mask} 127 | return sample 128 | -------------------------------------------------------------------------------- /dataLoader/your_own_data.py: 
-------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class YourOwnDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.downsample = downsample 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [0.1,100.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample) 45 | self.img_wh = [w,h] 46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 48 | self.cx, self.cy = self.meta['cx'],self.meta['cy'] 49 | 50 | 51 | # ray directions for all pixels, same for all images (same H, W, focal) 52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float() 55 | 56 | self.image_paths = [] 57 | self.poses = [] 58 | self.all_rays = [] 59 | self.all_rgbs = [] 60 | self.all_masks = [] 61 | self.all_depth = [] 62 | 63 | 64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 67 | 68 | frame = self.meta['frames'][i] 69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 70 | c2w = torch.FloatTensor(pose) 71 | self.poses += [c2w] 72 | 73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 74 | self.image_paths += [image_path] 75 | img = Image.open(image_path) 76 | 77 | if self.downsample!=1.0: 78 | img = img.resize(self.img_wh, Image.LANCZOS) 79 | img = self.transform(img) # (4, h, w) 80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA 81 | if img.shape[-1]==4: 82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 83 | self.all_rgbs += [img] 84 | 85 | 86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 88 | 89 | 90 | self.poses = torch.stack(self.poses) 91 | if not self.is_stack: 92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) 
# (len(self.meta['frames])*h*w, 3) 94 | 95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 96 | else: 97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 98 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 100 | 101 | 102 | def define_transforms(self): 103 | self.transform = T.ToTensor() 104 | 105 | def define_proj_mat(self): 106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 107 | 108 | def world2ndc(self,points,lindisp=None): 109 | device = points.device 110 | return (points - self.center.to(device)) / self.radius.to(device) 111 | 112 | def __len__(self): 113 | return len(self.all_rgbs) 114 | 115 | def __getitem__(self, idx): 116 | 117 | if self.split == 'train': # use data in the buffers 118 | sample = {'rays': self.all_rays[idx], 119 | 'rgbs': self.all_rgbs[idx]} 120 | 121 | else: # create data for each image separately 122 | 123 | img = self.all_rgbs[idx] 124 | rays = self.all_rays[idx] 125 | mask = self.all_masks[idx] # for quantity evaluation 126 | 127 | sample = {'rays': rays, 128 | 'rgbs': img} 129 | return sample 130 | -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################## sh function ################## 4 | C0 = 0.28209479177387814 5 | C1 = 0.4886025119029199 6 | C2 = [ 7 | 1.0925484305920792, 8 | -1.0925484305920792, 9 | 0.31539156525252005, 10 | -1.0925484305920792, 11 | 0.5462742152960396 12 | ] 13 | C3 = [ 14 | -0.5900435899266435, 15 | 2.890611442640554, 16 | -0.4570457994644658, 17 | 0.3731763325901154, 18 | -0.4570457994644658, 19 | 1.445305721320277, 20 | -0.5900435899266435 21 | ] 22 | C4 = [ 23 | 2.5033429417967046, 24 | -1.7701307697799304, 25 | 0.9461746957575601, 26 | -0.6690465435572892, 27 | 0.10578554691520431, 28 | -0.6690465435572892, 29 | 0.47308734787878004, 30 | -1.7701307697799304, 31 | 0.6258357354491761, 32 | ] 33 | 34 | def eval_sh(deg, sh, dirs): 35 | """ 36 | Evaluate spherical harmonics at unit directions 37 | using hardcoded SH polynomials. 38 | Works with torch/np/jnp. 39 | ... Can be 0 or more batch dimensions. 40 | :param deg: int SH max degree. 
Currently, 0-4 supported 41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) 42 | :param dirs: torch.Tensor unit directions (..., 3) 43 | :return: (..., C) 44 | """ 45 | assert deg <= 4 and deg >= 0 46 | assert (deg + 1) ** 2 == sh.shape[-1] 47 | C = sh.shape[-2] 48 | 49 | result = C0 * sh[..., 0] 50 | if deg > 0: 51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 52 | result = (result - 53 | C1 * y * sh[..., 1] + 54 | C1 * z * sh[..., 2] - 55 | C1 * x * sh[..., 3]) 56 | if deg > 1: 57 | xx, yy, zz = x * x, y * y, z * z 58 | xy, yz, xz = x * y, y * z, x * z 59 | result = (result + 60 | C2[0] * xy * sh[..., 4] + 61 | C2[1] * yz * sh[..., 5] + 62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 63 | C2[3] * xz * sh[..., 7] + 64 | C2[4] * (xx - yy) * sh[..., 8]) 65 | 66 | if deg > 2: 67 | result = (result + 68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 69 | C3[1] * xy * z * sh[..., 10] + 70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 73 | C3[5] * z * (xx - yy) * sh[..., 14] + 74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 75 | if deg > 3: 76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 85 | return result 86 | 87 | def eval_sh_bases(deg, dirs): 88 | """ 89 | Evaluate spherical harmonics bases at unit directions, 90 | without taking linear combination. 91 | At each point, the final result may the be 92 | obtained through simple multiplication. 93 | :param deg: int SH max degree. 
Currently, 0-4 supported 94 | :param dirs: torch.Tensor (..., 3) unit directions 95 | :return: torch.Tensor (..., (deg+1) ** 2) 96 | """ 97 | assert deg <= 4 and deg >= 0 98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) 99 | result[..., 0] = C0 100 | if deg > 0: 101 | x, y, z = dirs.unbind(-1) 102 | result[..., 1] = -C1 * y; 103 | result[..., 2] = C1 * z; 104 | result[..., 3] = -C1 * x; 105 | if deg > 1: 106 | xx, yy, zz = x * x, y * y, z * z 107 | xy, yz, xz = x * y, y * z, x * z 108 | result[..., 4] = C2[0] * xy; 109 | result[..., 5] = C2[1] * yz; 110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy); 111 | result[..., 7] = C2[3] * xz; 112 | result[..., 8] = C2[4] * (xx - yy); 113 | 114 | if deg > 2: 115 | result[..., 9] = C3[0] * y * (3 * xx - yy); 116 | result[..., 10] = C3[1] * xy * z; 117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy); 118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); 119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy); 120 | result[..., 14] = C3[5] * z * (xx - yy); 121 | result[..., 15] = C3[6] * x * (xx - 3 * yy); 122 | 123 | if deg > 3: 124 | result[..., 16] = C4[0] * xy * (xx - yy); 125 | result[..., 17] = C4[1] * yz * (3 * xx - yy); 126 | result[..., 18] = C4[2] * xy * (7 * zz - 1); 127 | result[..., 19] = C4[3] * yz * (7 * zz - 3); 128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); 129 | result[..., 21] = C4[5] * xz * (7 * zz - 3); 130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); 131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy); 132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); 133 | return result 134 | -------------------------------------------------------------------------------- /dataLoader/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class NSVF(Dataset): 37 | """NSVF Generic Dataset.""" 38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False): 39 | self.root_dir = datadir 40 | self.split = split 41 | self.is_stack = is_stack 42 | self.downsample = downsample 43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 44 | self.define_transforms() 45 | 46 | self.white_bg = True 47 | self.near_far = [0.5,6.0] 48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3) 49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 50 | self.read_meta() 51 | self.define_proj_mat() 52 | 53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 54 | self.radius = 
(self.scene_bbox[1] - self.center).float().view(1, 1, 3) 55 | 56 | def bbox2corners(self): 57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 58 | for i in range(3): 59 | corners[i,[0,1],i] = corners[i,[1,0],i] 60 | return corners.view(-1,3) 61 | 62 | 63 | def read_meta(self): 64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f: 65 | focal = float(f.readline().split()[0]) 66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]]) 67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1) 68 | 69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 71 | 72 | if self.split == 'train': 73 | pose_files = [x for x in pose_files if x.startswith('0_')] 74 | img_files = [x for x in img_files if x.startswith('0_')] 75 | elif self.split == 'val': 76 | pose_files = [x for x in pose_files if x.startswith('1_')] 77 | img_files = [x for x in img_files if x.startswith('1_')] 78 | elif self.split == 'test': 79 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 80 | test_img_files = [x for x in img_files if x.startswith('2_')] 81 | if len(test_pose_files) == 0: 82 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 83 | test_img_files = [x for x in img_files if x.startswith('1_')] 84 | pose_files = test_pose_files 85 | img_files = test_img_files 86 | 87 | # ray directions for all pixels, same for all images (same H, W, focal) 88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 90 | 91 | 92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 93 | 94 | self.poses = [] 95 | self.all_rays = [] 96 | self.all_rgbs = [] 97 | 98 | assert len(img_files) == len(pose_files) 99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 101 | img = Image.open(image_path) 102 | if self.downsample!=1.0: 103 | img = img.resize(self.img_wh, Image.LANCZOS) 104 | img = self.transform(img) # (4, h, w) 105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 106 | if img.shape[-1]==4: 107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 108 | self.all_rgbs += [img] 109 | 110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv 111 | c2w = torch.FloatTensor(c2w) 112 | self.poses.append(c2w) # C2W 113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 115 | 116 | # w2c = torch.inverse(c2w) 117 | # 118 | 119 | self.poses = torch.stack(self.poses) 120 | if 'train' == self.split: 121 | if self.is_stack: 122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 124 | else: 125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 127 | else: 128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 
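            # test/val splits keep one entry per view: rays stay stacked as (N_views, h*w, 6) and the RGBs
            # are reshaped to (N_views, h, w, 3) below, so evaluation can render and compare image-by-image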
129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 130 | 131 | 132 | def define_transforms(self): 133 | self.transform = T.ToTensor() 134 | 135 | def define_proj_mat(self): 136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 137 | 138 | def world2ndc(self, points): 139 | device = points.device 140 | return (points - self.center.to(device)) / self.radius.to(device) 141 | 142 | def __len__(self): 143 | if self.split == 'train': 144 | return len(self.all_rays) 145 | return len(self.all_rgbs) 146 | 147 | def __getitem__(self, idx): 148 | 149 | if self.split == 'train': # use data in the buffers 150 | sample = {'rays': self.all_rays[idx], 151 | 'rgbs': self.all_rgbs[idx]} 152 | 153 | else: # create data for each image separately 154 | 155 | img = self.all_rgbs[idx] 156 | rays = self.all_rays[idx] 157 | 158 | sample = {'rays': rays, 159 | 'rgbs': img} 160 | return sample -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | import torch,os,imageio,sys 2 | from tqdm.auto import tqdm 3 | from dataLoader.ray_utils import get_rays 4 | from models.tensoRF import TensorVM, TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask 5 | from models.tensoRF_VQ import TensorVMSplitVQ 6 | from utils import * 7 | from dataLoader.ray_utils import ndc_rays_blender 8 | 9 | 10 | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda', **kwargs): 11 | 12 | rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], [] 13 | N_rays_all = rays.shape[0] 14 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)): 15 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) 16 | 17 | rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples, **kwargs) 18 | 19 | rgbs.append(rgb_map) 20 | depth_maps.append(depth_map) 21 | 22 | return torch.cat(rgbs), None, torch.cat(depth_maps), None, None 23 | 24 | @torch.no_grad() 25 | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 26 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda', im_save=False): 27 | if prtx is not None: 28 | prtx = prtx + '_' 29 | result_path = f'{savePath}/{prtx}res.txt' 30 | if os.path.exists(result_path) and not args.render_path: 31 | psnr = np.loadtxt(result_path)[0] 32 | return psnr 33 | PSNRs, rgb_maps, depth_maps = [], [], [] 34 | ssims,l_alex,l_vgg=[],[],[] 35 | if savePath is not None: 36 | os.makedirs(savePath, exist_ok=True) 37 | os.makedirs(savePath+"/rgbd", exist_ok=True) 38 | 39 | try: 40 | tqdm._instances.clear() 41 | except Exception: 42 | pass 43 | 44 | near_far = test_dataset.near_far 45 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1) 46 | # img_eval_interval = max(img_eval_interval, test_dataset.all_rays.shape[0]//49) 47 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval)) 48 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout): 49 | 50 | W, H = test_dataset.img_wh 51 | rays = samples.view(-1,samples.shape[-1]) 52 | 53 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples, 54 | ndc_ray=ndc_ray, white_bg 
= white_bg, device=device) 55 | rgb_map = rgb_map.clamp(0.0, 1.0) 56 | 57 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 58 | 59 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 60 | if len(test_dataset.all_rgbs): 61 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3) 62 | loss = torch.mean((rgb_map - gt_rgb) ** 2) 63 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0)) 64 | 65 | if compute_extra_metrics: 66 | ssim = rgb_ssim(rgb_map, gt_rgb, 1) 67 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device) 68 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device) 69 | ssims.append(ssim) 70 | l_alex.append(l_a) 71 | l_vgg.append(l_v) 72 | 73 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 74 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 75 | rgb_maps.append(rgb_map) 76 | depth_maps.append(depth_map) 77 | if savePath is not None and im_save: 78 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 79 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 80 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 81 | if savePath is not None: 82 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10) 83 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10) 84 | 85 | if PSNRs: 86 | psnr = np.mean(np.asarray(PSNRs)) 87 | if compute_extra_metrics: 88 | ssim = np.mean(np.asarray(ssims)) 89 | l_a = np.mean(np.asarray(l_alex)) 90 | l_v = np.mean(np.asarray(l_vgg)) 91 | if savePath is not None: 92 | np.savetxt(result_path, np.asarray([psnr, ssim, l_a, l_v])) 93 | else: 94 | if savePath is not None: 95 | np.savetxt(result_path, np.asarray([psnr])) 96 | 97 | 98 | return PSNRs 99 | 100 | @torch.no_grad() 101 | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 102 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 103 | PSNRs, rgb_maps, depth_maps = [], [], [] 104 | ssims,l_alex,l_vgg=[],[],[] 105 | os.makedirs(savePath, exist_ok=True) 106 | os.makedirs(savePath+"/rgbd", exist_ok=True) 107 | 108 | try: 109 | tqdm._instances.clear() 110 | except Exception: 111 | pass 112 | 113 | near_far = test_dataset.near_far 114 | for idx, c2w in tqdm(enumerate(c2ws)): 115 | 116 | W, H = test_dataset.img_wh 117 | 118 | c2w = torch.FloatTensor(c2w) 119 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3) 120 | if ndc_ray: 121 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d) 122 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 123 | 124 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples, 125 | ndc_ray=ndc_ray, white_bg = white_bg, device=device) 126 | rgb_map = rgb_map.clamp(0.0, 1.0) 127 | 128 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 129 | 130 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 131 | 132 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 133 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 134 | rgb_maps.append(rgb_map) 135 | depth_maps.append(depth_map) 136 | if savePath is not None: 137 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 138 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 139 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 140 | 141 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', 
np.stack(rgb_maps), fps=30, quality=8) 142 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8) 143 | 144 | if PSNRs: 145 | psnr = np.mean(np.asarray(PSNRs)) 146 | if compute_extra_metrics: 147 | ssim = np.mean(np.asarray(ssims)) 148 | l_a = np.mean(np.asarray(l_alex)) 149 | l_v = np.mean(np.asarray(l_vgg)) 150 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 151 | else: 152 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 153 | 154 | 155 | return PSNRs 156 | 157 | -------------------------------------------------------------------------------- /extra/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import numpy as np 3 | import scipy.signal 4 | from typing import List, Optional 5 | from PIL import Image 6 | import os 7 | import torch 8 | import configargparse 9 | 10 | __LPIPS__ = {} 11 | def init_lpips(net_name, device): 12 | assert net_name in ['alex', 'vgg'] 13 | import lpips 14 | print(f'init_lpips: lpips_{net_name}') 15 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 16 | 17 | def rgb_lpips(np_gt, np_im, net_name, device): 18 | if net_name not in __LPIPS__: 19 | __LPIPS__[net_name] = init_lpips(net_name, device) 20 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 21 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 22 | return __LPIPS__[net_name](gt, im, normalize=True).item() 23 | 24 | 25 | def findItem(items, target): 26 | for one in items: 27 | if one[:len(target)]==target: 28 | return one 29 | return None 30 | 31 | 32 | ''' Evaluation metrics (ssim, lpips) 33 | ''' 34 | def rgb_ssim(img0, img1, max_val, 35 | filter_size=11, 36 | filter_sigma=1.5, 37 | k1=0.01, 38 | k2=0.03, 39 | return_map=False): 40 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 41 | assert len(img0.shape) == 3 42 | assert img0.shape[-1] == 3 43 | assert img0.shape == img1.shape 44 | 45 | # Construct a 1D Gaussian blur filter. 46 | hw = filter_size // 2 47 | shift = (2 * hw - filter_size + 1) / 2 48 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 49 | filt = np.exp(-0.5 * f_i) 50 | filt /= np.sum(filt) 51 | 52 | # Blur in x and y (faster than the 2D convolution). 53 | def convolve2d(z, f): 54 | return scipy.signal.convolve2d(z, f, mode='valid') 55 | 56 | filt_fn = lambda z: np.stack([ 57 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 58 | for i in range(z.shape[-1])], -1) 59 | mu0 = filt_fn(img0) 60 | mu1 = filt_fn(img1) 61 | mu00 = mu0 * mu0 62 | mu11 = mu1 * mu1 63 | mu01 = mu0 * mu1 64 | sigma00 = filt_fn(img0**2) - mu00 65 | sigma11 = filt_fn(img1**2) - mu11 66 | sigma01 = filt_fn(img0 * img1) - mu01 67 | 68 | # Clip the variances and covariances to valid values. 
69 | # Variance must be non-negative: 70 | sigma00 = np.maximum(0., sigma00) 71 | sigma11 = np.maximum(0., sigma11) 72 | sigma01 = np.sign(sigma01) * np.minimum( 73 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 74 | c1 = (k1 * max_val)**2 75 | c2 = (k2 * max_val)**2 76 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 77 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 78 | ssim_map = numer / denom 79 | ssim = np.mean(ssim_map) 80 | return ssim_map if return_map else ssim 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | parser = configargparse.ArgumentParser() 86 | parser.add_argument("--exp", type=str, help="folder of exps") 87 | parser.add_argument("--paramStr", type=str, help="str of params") 88 | args = parser.parse_args() 89 | 90 | 91 | # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']# 92 | # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic" 93 | # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp 94 | 95 | # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']# 96 | # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/" 97 | # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp 98 | paramStr = args.paramStr 99 | fileNum = 200 100 | 101 | 102 | expitems = os.listdir(expFolder) 103 | finalFolder = f'{expFolder}/finals/{paramStr}' 104 | outFile = f'{finalFolder}/{paramStr}_metrics.txt' 105 | os.makedirs(finalFolder, exist_ok=True) 106 | 107 | expitems.sort(reverse=True) 108 | 109 | 110 | with open(outFile, 'w') as f: 111 | all_psnr = [] 112 | all_ssim = [] 113 | all_alex = [] 114 | all_vgg = [] 115 | for dataname in datanames: 116 | 117 | 118 | gtstr = gtFolder+"/"+dataname+"/test/r_%d.png" 119 | expname = findItem(expitems, f'{paramStr}-{dataname}') 120 | print("expname: ", expname) 121 | if expname is None: 122 | print("no ",dataname, "exists") 123 | continue 124 | resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png" 125 | metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt' 126 | video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4' 127 | 128 | exist_metric=False 129 | if os.path.isfile(metric_file): 130 | metrics = np.loadtxt(metric_file) 131 | print(metrics, metrics.tolist()) 132 | if metrics.size == 4: 133 | psnr, ssim, l_a, l_v = metrics.tolist() 134 | exist_metric = True 135 | os.system(f"cp {video_file} {finalFolder}/") 136 | 137 | if not exist_metric: 138 | psnrs = [] 139 | ssims = [] 140 | l_alex = [] 141 | l_vgg = [] 142 | for i in range(fileNum): 143 | gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0 144 | gtmask = gt[...,[3]] 145 | gt = gt[...,:3] 146 | gt = gt*gtmask + (1-gtmask) 147 | img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0 148 | # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max()) 149 | 150 | 151 | psnr = -10. 
* np.log10(np.mean(np.square(img - gt))) 152 | ssim = rgb_ssim(img, gt, 1) 153 | lpips_alex = rgb_lpips(gt, img, 'alex','cuda') 154 | lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda') 155 | 156 | print(i, psnr, ssim, lpips_alex, lpips_vgg) 157 | psnrs.append(psnr) 158 | ssims.append(ssim) 159 | l_alex.append(lpips_alex) 160 | l_vgg.append(lpips_vgg) 161 | psnr = np.mean(np.array(psnrs)) 162 | ssim = np.mean(np.array(ssims)) 163 | l_a = np.mean(np.array(l_alex)) 164 | l_v = np.mean(np.array(l_vgg)) 165 | 166 | rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 167 | print(rS) 168 | f.write(rS+"\n") 169 | 170 | all_psnr.append(psnr) 171 | all_ssim.append(ssim) 172 | all_alex.append(l_a) 173 | all_vgg.append(l_v) 174 | 175 | psnr = np.mean(np.array(all_psnr)) 176 | ssim = np.mean(np.array(all_ssim)) 177 | l_a = np.mean(np.array(all_alex)) 178 | l_v = np.mean(np.array(all_vgg)) 179 | 180 | rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 181 | print(rS) 182 | f.write(rS+"\n") -------------------------------------------------------------------------------- /result.md: -------------------------------------------------------------------------------- 1 | 2 | # Command 3 | 4 | `python autotask_vq.py -g "0 1 2 3" --dataset syn` 5 | 6 | `python autotask_vq.py -g "0 1 2 3" --dataset llff` 7 | 8 | `python autotask_vq.py -g "0 1 2 3" --dataset nsvf` 9 | 10 | `python autotask_vq.py -g "0 1 2 3" --dataset tt` 11 | 12 | # Results 13 | ## Syn 14 | 15 | ``` 16 | Raw: 17 | +-----------+---------+--------+------------+-----------+---------+ 18 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 19 | +-----------+---------+--------+------------+-----------+---------+ 20 | | chair | 35.7760 | 0.9846 | 0.0095 | 0.0215 | 67.8907 | 21 | | drums | 26.0150 | 0.9370 | 0.0494 | 0.0718 | 67.7655 | 22 | | ficus | 34.1024 | 0.9826 | 0.0122 | 0.0222 | 70.7913 | 23 | | hotdog | 37.5654 | 0.9828 | 0.0129 | 0.0301 | 84.0902 | 24 | | lego | 36.5248 | 0.9834 | 0.0072 | 0.0178 | 68.8282 | 25 | | materials | 30.1193 | 0.9523 | 0.0259 | 0.0584 | 84.6609 | 26 | | mic | 34.9041 | 0.9885 | 0.0078 | 0.0144 | 66.6982 | 27 | | ship | 30.7121 | 0.8941 | 0.0831 | 0.1376 | 70.9092 | 28 | | Mean | 33.2149 | 0.9632 | 0.0260 | 0.0467 | 72.7043 | 29 | +-----------+---------+--------+------------+-----------+---------+ 30 | 31 | VQRF: 32 | +-----------+---------+--------+------------+-----------+--------+ 33 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 34 | +-----------+---------+--------+------------+-----------+--------+ 35 | | chair | 35.1320 | 0.9806 | 0.0185 | 0.0355 | 5.6044 | 36 | | drums | 25.9492 | 0.9316 | 0.0624 | 0.0983 | 4.9731 | 37 | | ficus | 33.8504 | 0.9813 | 0.0136 | 0.0288 | 6.0117 | 38 | | hotdog | 37.2531 | 0.9807 | 0.0154 | 0.0367 | 9.6663 | 39 | | lego | 35.9736 | 0.9808 | 0.0087 | 0.0242 | 6.1286 | 40 | | materials | 30.0846 | 0.9511 | 0.0279 | 0.0617 | 8.2645 | 41 | | mic | 34.3908 | 0.9861 | 0.0128 | 0.0241 | 3.9262 | 42 | | ship | 30.5233 | 0.8908 | 0.0852 | 0.1430 | 7.0612 | 43 | | Mean | 32.8946 | 0.9604 | 0.0306 | 0.0565 | 6.4545 | 44 | +-----------+---------+--------+------------+-----------+--------+ 45 | ``` 46 | 47 | 48 | ## NSVF 49 | 50 | 51 | ``` 52 | Raw: 53 | +------------+---------+--------+------------+-----------+---------+ 54 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 55 | +------------+---------+--------+------------+-----------+---------+ 56 | | Bike | 39.4477 | 0.9928 | 0.0027 | 0.0099 | 73.5326 | 57 | | Lifestyle | 34.6238 | 0.9691 | 0.0201 | 0.0464 | 
67.5628 | 58 | | Palace | 37.8154 | 0.9805 | 0.0101 | 0.0206 | 67.3745 | 59 | | Robot | 38.5671 | 0.9945 | 0.0029 | 0.0096 | 70.7333 | 60 | | Spaceship | 38.7290 | 0.9885 | 0.0087 | 0.0197 | 70.6000 | 61 | | Steamtrain | 37.9821 | 0.9910 | 0.0061 | 0.0169 | 83.7099 | 62 | | Toad | 35.1108 | 0.9795 | 0.0145 | 0.0285 | 71.4396 | 63 | | Wineholder | 31.4842 | 0.9620 | 0.0225 | 0.0482 | 67.9301 | 64 | | Mean | 36.7200 | 0.9822 | 0.0109 | 0.0250 | 71.6103 | 65 | +------------+---------+--------+------------+-----------+---------+ 66 | VQRF: 67 | +------------+---------+--------+------------+-----------+--------+ 68 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 69 | +------------+---------+--------+------------+-----------+--------+ 70 | | Bike | 38.8761 | 0.9919 | 0.0032 | 0.0123 | 6.2926 | 71 | | Lifestyle | 34.4752 | 0.9667 | 0.0219 | 0.0512 | 6.7598 | 72 | | Palace | 37.4467 | 0.9779 | 0.0113 | 0.0238 | 6.1246 | 73 | | Robot | 37.9221 | 0.9935 | 0.0034 | 0.0112 | 5.9037 | 74 | | Spaceship | 38.3696 | 0.9876 | 0.0103 | 0.0217 | 6.7477 | 75 | | Steamtrain | 37.6387 | 0.9894 | 0.0073 | 0.0223 | 7.7254 | 76 | | Toad | 33.8910 | 0.9724 | 0.0195 | 0.0398 | 8.4120 | 77 | | Wineholder | 31.3890 | 0.9602 | 0.0240 | 0.0529 | 6.0048 | 78 | | Mean | 36.2511 | 0.9800 | 0.0126 | 0.0294 | 6.7463 | 79 | +------------+---------+--------+------------+-----------+--------+ 80 | 81 | ``` 82 | 83 | 84 | ## LLFF 85 | 86 | 87 | ``` 88 | Raw: 89 | +----------+---------+--------+------------+-----------+----------+ 90 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 91 | +----------+---------+--------+------------+-----------+----------+ 92 | | fern | 24.9985 | 0.7981 | 0.1595 | 0.2513 | 179.9157 | 93 | | flower | 28.1521 | 0.8574 | 0.1043 | 0.1775 | 179.8139 | 94 | | room | 32.1206 | 0.9514 | 0.0770 | 0.1622 | 179.8682 | 95 | | leaves | 21.1095 | 0.7439 | 0.1425 | 0.2215 | 179.6977 | 96 | | horns | 28.3430 | 0.8835 | 0.1027 | 0.1801 | 179.8110 | 97 | | trex | 27.6879 | 0.9107 | 0.0793 | 0.2019 | 179.8421 | 98 | | fortress | 31.4531 | 0.8982 | 0.0666 | 0.1425 | 179.8664 | 99 | | orchids | 19.8818 | 0.6468 | 0.1914 | 0.2786 | 179.9021 | 100 | | Mean | 26.7183 | 0.8362 | 0.1154 | 0.2019 | 179.8396 | 101 | +----------+---------+--------+------------+-----------+----------+ 102 | VQRF: 103 | +----------+---------+--------+------------+-----------+---------+ 104 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 105 | +----------+---------+--------+------------+-----------+---------+ 106 | | fern | 24.7626 | 0.7861 | 0.1725 | 0.2687 | 16.4782 | 107 | | flower | 27.8500 | 0.8421 | 0.1183 | 0.2032 | 16.6604 | 108 | | room | 31.7010 | 0.9446 | 0.0900 | 0.1842 | 16.6127 | 109 | | leaves | 20.9996 | 0.7262 | 0.1580 | 0.2559 | 16.4771 | 110 | | horns | 27.8376 | 0.8613 | 0.1289 | 0.2195 | 16.3832 | 111 | | trex | 27.1907 | 0.8979 | 0.0930 | 0.2256 | 16.4172 | 112 | | fortress | 31.0268 | 0.8764 | 0.1027 | 0.1895 | 17.7717 | 113 | | orchids | 19.7541 | 0.6346 | 0.2073 | 0.2984 | 16.3075 | 114 | | Mean | 26.3903 | 0.8212 | 0.1338 | 0.2306 | 16.6385 | 115 | +----------+---------+--------+------------+-----------+---------+ 116 | 117 | 118 | ``` 119 | 120 | ## T&T 121 | 122 | ``` 123 | Raw: 124 | +-------------+---------+--------+------------+-----------+---------+ 125 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 126 | +-------------+---------+--------+------------+-----------+---------+ 127 | | Barn | 27.4822 | 0.8663 | 0.2082 | 0.2481 | 80.8193 | 128 | | Caterpillar | 25.9890 | 0.9105 | 
0.1347 | 0.1593 | 72.0683 | 129 | | Family | 34.0665 | 0.9663 | 0.0533 | 0.0601 | 67.1116 | 130 | | Ignatius | 28.3670 | 0.9485 | 0.0758 | 0.0764 | 67.2292 | 131 | | Truck | 26.9499 | 0.9131 | 0.1267 | 0.1476 | 75.9510 | 132 | | Mean | 28.5709 | 0.9209 | 0.1198 | 0.1383 | 72.6359 | 133 | +-------------+---------+--------+------------+-----------+---------+ 134 | 135 | VQRF: 136 | +-------------+---------+--------+------------+-----------+--------+ 137 | | Scene | PSNR | SSIM | LPIPS_ALEX | LPIPS_VGG | SIZE | 138 | +-------------+---------+--------+------------+-----------+--------+ 139 | | Barn | 27.1056 | 0.8569 | 0.2321 | 0.2753 | 4.9942 | 140 | | Caterpillar | 25.6002 | 0.9019 | 0.1636 | 0.1902 | 5.2829 | 141 | | Family | 33.2410 | 0.9589 | 0.0644 | 0.0760 | 4.6899 | 142 | | Ignatius | 28.0493 | 0.9428 | 0.0855 | 0.0878 | 5.5910 | 143 | | Truck | 26.6837 | 0.9033 | 0.1537 | 0.1794 | 5.3103 | 144 | | Mean | 28.1360 | 0.9127 | 0.1398 | 0.1617 | 5.1737 | 145 | +-------------+---------+--------+------------+-----------+--------+ 146 | 147 | ``` -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def config_parser(cmd=None): 4 | parser = configargparse.ArgumentParser() 5 | parser.add_argument('--config', is_config_file=True, 6 | help='config file path') 7 | parser.add_argument("--expname", type=str, 8 | help='experiment name') 9 | parser.add_argument("--basedir", type=str, default='./log', 10 | help='where to store ckpts and logs') 11 | parser.add_argument("--add_timestamp", type=int, default=0, 12 | help='add timestamp to dir') 13 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 14 | help='input data directory') 15 | parser.add_argument("--progress_refresh_rate", type=int, default=10, 16 | help='how many iterations to show psnrs or iters') 17 | 18 | parser.add_argument('--with_depth', action='store_true') 19 | parser.add_argument('--downsample_train', type=float, default=1.0) 20 | parser.add_argument('--downsample_test', type=float, default=1.0) 21 | 22 | # VQ related 23 | parser.add_argument('--model_name', type=str, default='TensorVMSplit', 24 | choices=['TensorVMSplit', 'TensorVMSplitVQ','TensorCP']) 25 | parser.add_argument("--use_cosine_sim", type=int, default=0) 26 | parser.add_argument("--codebook_size", type=int, default=4096) 27 | parser.add_argument("--codebook_dim", type=int, default=0) 28 | parser.add_argument("--autotask", action='store_true', default=False) 29 | parser.add_argument("--debug", action='store_true', default=False) 30 | parser.add_argument("--pct_mid", type=float, default=0.999) 31 | parser.add_argument("--pct_high", type=float, default=0.7) 32 | parser.add_argument("--split_or_union", type=int, default=0, help='0: split, 1: union') 33 | parser.add_argument("--lr_scale", type=float, default=0.25) 34 | parser.add_argument("--vq_iters", type=int, default=10000) 35 | parser.add_argument("--vq_up_interval", type=int, default=10) 36 | parser.add_argument("--suffix", type=str, default='v0') 37 | 38 | 39 | # loader options 40 | parser.add_argument("--batch_size", type=int, default=4096) 41 | parser.add_argument("--n_iters", type=int, default=30000) 42 | 43 | parser.add_argument('--dataset_name', type=str, default='blender', 44 | choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']) 45 | 46 | 47 | # training options 48 | # learning rate 49 | parser.add_argument("--lr_init", 
type=float, default=0.02, 50 | help='learning rate') 51 | parser.add_argument("--lr_basis", type=float, default=1e-3, 52 | help='learning rate') 53 | parser.add_argument("--lr_decay_iters", type=int, default=-1, 54 | help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters') 55 | parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1, 56 | help='the target decay ratio; after decay_iters inital lr decays to lr*ratio') 57 | parser.add_argument("--lr_upsample_reset", type=int, default=1, 58 | help='reset lr to inital after upsampling') 59 | 60 | # loss 61 | parser.add_argument("--L1_weight_inital", type=float, default=0.0, 62 | help='loss weight') 63 | parser.add_argument("--L1_weight_rest", type=float, default=0, 64 | help='loss weight') 65 | parser.add_argument("--Ortho_weight", type=float, default=0.0, 66 | help='loss weight') 67 | parser.add_argument("--TV_weight_density", type=float, default=0.0, 68 | help='loss weight') 69 | parser.add_argument("--TV_weight_app", type=float, default=0.0, 70 | help='loss weight') 71 | 72 | # model 73 | # volume options 74 | parser.add_argument("--n_lamb_sigma", type=int, action="append") 75 | parser.add_argument("--n_lamb_sh", type=int, action="append") 76 | parser.add_argument("--data_dim_color", type=int, default=27) 77 | 78 | parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001, 79 | help='mask points in ray marching') 80 | parser.add_argument("--alpha_mask_thre", type=float, default=0.0001, 81 | help='threshold for creating alpha mask volume') 82 | parser.add_argument("--distance_scale", type=float, default=25, 83 | help='scaling sampling distance for computation') 84 | parser.add_argument("--density_shift", type=float, default=-10, 85 | help='shift density in softplus; making density = 0 when feature == 0') 86 | 87 | # network decoder 88 | parser.add_argument("--shadingMode", type=str, default="MLP_PE", 89 | help='which shading mode to use') 90 | parser.add_argument("--pos_pe", type=int, default=6, 91 | help='number of pe for pos') 92 | parser.add_argument("--view_pe", type=int, default=6, 93 | help='number of pe for view') 94 | parser.add_argument("--fea_pe", type=int, default=6, 95 | help='number of pe for features') 96 | parser.add_argument("--featureC", type=int, default=128, 97 | help='hidden feature channel in MLP') 98 | 99 | 100 | 101 | parser.add_argument("--ckpt", type=str, default=None, 102 | help='specific weights npy file to reload for coarse network') 103 | parser.add_argument("--render_only", type=int, default=0) 104 | parser.add_argument("--render_test", type=int, default=0) 105 | parser.add_argument("--render_train", type=int, default=0) 106 | parser.add_argument("--render_path", type=int, default=0) 107 | parser.add_argument("--export_mesh", type=int, default=0) 108 | 109 | # rendering options 110 | parser.add_argument('--lindisp', default=False, action="store_true", 111 | help='use disparity depth sampling') 112 | parser.add_argument("--perturb", type=float, default=1., 113 | help='set to 0. for no jitter, 1. 
for jitter') 114 | parser.add_argument("--accumulate_decay", type=float, default=0.998) 115 | parser.add_argument("--fea2denseAct", type=str, default='softplus') 116 | parser.add_argument('--ndc_ray', type=int, default=0) 117 | parser.add_argument('--nSamples', type=int, default=1e6, 118 | help='sample point each ray, pass 1e6 if automatic adjust') 119 | parser.add_argument('--step_ratio',type=float,default=0.5) 120 | 121 | 122 | ## blender flags 123 | parser.add_argument("--white_bkgd", action='store_true', 124 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 125 | 126 | 127 | 128 | parser.add_argument('--N_voxel_init', 129 | type=int, 130 | default=100**3) 131 | parser.add_argument('--N_voxel_final', 132 | type=int, 133 | default=300**3) 134 | parser.add_argument("--upsamp_list", type=int, action="append") 135 | parser.add_argument("--update_AlphaMask_list", type=int, action="append") 136 | 137 | parser.add_argument('--idx_view', 138 | type=int, 139 | default=0) 140 | # logging/saving options 141 | parser.add_argument("--N_vis", type=int, default=5, 142 | help='N images to vis') 143 | parser.add_argument("--vis_every", type=int, default=10000, 144 | help='frequency of visualize the image') 145 | if cmd is not None: 146 | return parser.parse_args(cmd) 147 | else: 148 | return parser.parse_args() -------------------------------------------------------------------------------- /autotask_vq.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | from multiprocessing import Process, Queue 5 | from typing import List, Dict 6 | import subprocess 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--gpus", "-g", type=str, required=True, 11 | help="space delimited GPU id list (global id in nvidia-smi, " 12 | "not considering CUDA_VISIBLE_DEVICES)") 13 | parser.add_argument('--eval', action='store_true', default=False, 14 | help='evaluation mode (run the render_imgs script)') 15 | parser.add_argument('--dataset', type=str, default='syn', choices=['syn', 'llff','tt','nsvf']) 16 | parser.add_argument("--render_path", type=int, default=0) 17 | parser.add_argument("--render_only", type=int, default=0) 18 | parser.add_argument("--suffix", type=str, default='v0') 19 | parser.add_argument("-f","--force", action="store_true") 20 | 21 | args = parser.parse_args() 22 | 23 | PSNR_FILE_NAME = 'test_psnr.txt' 24 | 25 | def run_exp(env, config, datadir, expname, basedir, ckpt=None, suffix='v0'): 26 | base_cmd = ['python', 'vectquant.py', '--autotask', '--config', config, '--datadir', datadir, '--expname', expname, '--basedir', basedir] 27 | base_cmd = base_cmd + ['--render_path', str(args.render_path)] 28 | base_cmd = base_cmd + ['--render_only', str(args.render_only)] 29 | if ckpt is not None: 30 | base_cmd = base_cmd + ['--ckpt', ckpt] 31 | base_cmd = base_cmd + ['--suffix', suffix] 32 | # psnr_file_path = os.path.join(basedir, expname,'imgs_test_all','test_res.txt') 33 | psnr_file_path = os.path.join(basedir, expname+'_'+suffix+'_vq','extreme_load','extreme_load_res.txt') 34 | if os.path.isfile(psnr_file_path) and not args.force: 35 | print('! 
SKIP', psnr_file_path, "on ", env["CUDA_VISIBLE_DEVICES"]) 36 | return 37 | print('********************************************') 38 | opt_cmd = ' '.join(base_cmd) 39 | print(opt_cmd, "on ", env["CUDA_VISIBLE_DEVICES"]) 40 | opt_ret = subprocess.check_output(opt_cmd, shell=True, env=env).decode( 41 | sys.stdout.encoding) 42 | 43 | 44 | def process_main(device, queue): 45 | # Set CUDA_VISIBLE_DEVICES programmatically 46 | env = os.environ.copy() 47 | env["CUDA_VISIBLE_DEVICES"] = str(device) 48 | while True: 49 | task = queue.get() 50 | if len(task) == 0: 51 | break 52 | 53 | run_exp(env, **task) 54 | 55 | 56 | DatasetSetting={ 57 | "syn": { 58 | "data": "/bfs/HoloResearch/NeRFData/nerf_synthetic", 59 | "cfg": f"configs/vq/syn.txt", 60 | "scene_list":[ 'chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship'], 61 | "basedir":f"./log_reimp/syn" 62 | }, 63 | "llff":{ 64 | "data": "/bfs/HoloResearch/NeRFData/nerf_llff_data", 65 | "cfg": "configs/vq/llff.txt", 66 | "scene_list": ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'], 67 | "basedir": "./log_reimp/llff" 68 | }, 69 | "tt":{ 70 | "data": "/bfs/HoloResearch/NeRFData/TanksAndTemple", 71 | "cfg": "configs/vq/tt.txt", 72 | "scene_list": ['Barn','Caterpillar','Family','Ignatius', 'Truck'], 73 | "basedir": "./log_reimp/tt" 74 | }, 75 | "nsvf":{ 76 | "data": "/bfs/HoloResearch/NeRFData/Synthetic_NSVF", 77 | "cfg": "configs/vq/nsvf.txt", 78 | "scene_list": ['Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'], 79 | "basedir": "./log_reimp/nsvf" 80 | } 81 | } 82 | 83 | 84 | datasetting = DatasetSetting[args.dataset] 85 | all_tasks = [] 86 | for scene in datasetting["scene_list"]: 87 | task: Dict = {} 88 | task['datadir'] = f'{datasetting["data"]}/{scene}' 89 | task['expname'] = f'{scene}' 90 | task["config"] = datasetting['cfg'] 91 | task["basedir"] = datasetting["basedir"] 92 | task["ckpt"] = f"{datasetting['basedir']}/{scene}/{scene}.th" 93 | task["suffix"] = args.suffix 94 | assert os.path.exists(task['datadir']), task['datadir'] + ' does not exist' 95 | assert os.path.isfile(task['config']), task['config'] + ' does not exist' 96 | all_tasks.append(task) 97 | 98 | pqueue = Queue() 99 | for task in all_tasks: 100 | pqueue.put(task) 101 | 102 | args.gpus = list(map(int, args.gpus.split())) 103 | print('GPUS:', args.gpus) 104 | 105 | for _ in args.gpus: 106 | pqueue.put({}) 107 | 108 | all_procs = [] 109 | for i, gpu in enumerate(args.gpus): 110 | process = Process(target=process_main, args=(gpu, pqueue)) 111 | process.daemon = True 112 | process.start() 113 | all_procs.append(process) 114 | 115 | for i, gpu in enumerate(args.gpus): 116 | all_procs[i].join() 117 | 118 | 119 | 120 | class AverageMeter(object): 121 | def __init__(self, name=''): 122 | self.name=name 123 | self.reset() 124 | def reset(self): 125 | self.val=0 126 | self.sum=0 127 | self.avg=0 128 | self.count=0 129 | def update(self,val,n=1): 130 | self.val=val 131 | self.sum += val*n 132 | self.count += n 133 | self.avg=self.sum/self.count 134 | def __repr__(self) -> str: 135 | return f'{self.name}: average {self.count}: {self.avg}\n' 136 | 137 | from prettytable import PrettyTable 138 | table = PrettyTable(["Scene", "PSNR", "SSIM", "LPIPS_ALEX","LPIPS_VGG","SIZE"]) 139 | table.float_format = '.4' 140 | 141 | PSNR=AverageMeter('PSNR') 142 | SSIM=AverageMeter('SSIM') 143 | LPIPS_A=AverageMeter('LPIPS_A') 144 | LPIPS_V=AverageMeter('LPIPS_V') 145 | SIZE=AverageMeter('SIZE') 146 | for scene in 
datasetting['scene_list']: 147 | path = f'./{datasetting["basedir"]}/{scene}/imgs_test_all/test_res.txt' 148 | if not os.path.exists(path): 149 | path = f'./{datasetting["basedir"]}/{scene}/imgs_test_all/testmean.txt' 150 | 151 | with open(path, 'r') as f: 152 | lines = f.readlines() 153 | psnr = float(lines[0].strip()) 154 | ssim = float(lines[1].strip()) 155 | lpips_a = float(lines[2].strip()) 156 | lpips_v = float(lines[3].strip()) 157 | PSNR.update(psnr) 158 | SSIM.update(ssim) 159 | LPIPS_A.update(lpips_a) 160 | LPIPS_V.update(lpips_v) 161 | uncompressed_file = f'./{datasetting["basedir"]}/{scene}/{scene}.th' 162 | if os.path.exists(uncompressed_file): 163 | size = os.path.getsize(uncompressed_file)/(1024*1024) 164 | else: 165 | size = 0 166 | SIZE.update(size) 167 | table.add_row([scene, psnr, ssim, lpips_a, lpips_v,size]) 168 | table.add_row(['Mean', PSNR.avg, SSIM.avg, LPIPS_A.avg,LPIPS_V.avg, SIZE.avg]) 169 | 170 | writedir = os.path.join(datasetting["basedir"], 'merge.txt') 171 | with open(writedir, 'w') as f: 172 | f.writelines(table.get_string()) 173 | print(table) 174 | 175 | 176 | table = PrettyTable(["Scene", "PSNR", "SSIM", "LPIPS_ALEX","LPIPS_VGG","SIZE"]) 177 | table.float_format = '.4' 178 | 179 | PSNR=AverageMeter('PSNR') 180 | SSIM=AverageMeter('SSIM') 181 | LPIPS_A=AverageMeter('LPIPS_A') 182 | LPIPS_V=AverageMeter('LPIPS_V') 183 | SIZE=AverageMeter('SIZE') 184 | for scene in datasetting['scene_list']: 185 | path = f'./{datasetting["basedir"]}/{scene}_{args.suffix}_vq/extreme_load/extreme_load_res.txt' 186 | if not os.path.exists(path): 187 | path = f'./{datasetting["basedir"]}/{scene}_{args.suffix}_vq/test5/vq_quant_0.7_res.txt' 188 | 189 | with open(path, 'r') as f: 190 | lines = f.readlines() 191 | psnr = float(lines[0].strip()) 192 | ssim = float(lines[1].strip()) 193 | lpips_a = float(lines[2].strip()) 194 | lpips_v = float(lines[3].strip()) 195 | PSNR.update(psnr) 196 | SSIM.update(ssim) 197 | LPIPS_A.update(lpips_a) 198 | LPIPS_V.update(lpips_v) 199 | uncompressed_file = f'./{datasetting["basedir"]}/{scene}_{args.suffix}_vq/extreme_ckpt.pt' 200 | compressed_file = f'./{datasetting["basedir"]}/{scene}_{args.suffix}_vq/extreme_ckpt.zip' 201 | if os.path.exists(compressed_file): 202 | size = os.path.getsize(compressed_file)/(1024*1024) 203 | else: 204 | size = 0 205 | SIZE.update(size) 206 | table.add_row([scene, psnr, ssim, lpips_a, lpips_v,size]) 207 | table.add_row(['Mean', PSNR.avg, SSIM.avg, LPIPS_A.avg,LPIPS_V.avg, SIZE.avg]) 208 | 209 | writedir = os.path.join(datasetting["basedir"], f'merge_{args.suffix}_vq.txt') 210 | with open(writedir, 'w') as f: 211 | f.writelines(table.get_string()) 212 | print(table) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2,torch 2 | import numpy as np 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | import torch.nn.functional as F 6 | import scipy.signal 7 | 8 | mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 9 | 10 | 11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 12 | """ 13 | depth: (H, W) 14 | """ 15 | 16 | x = np.nan_to_num(depth) # change nan to 0 17 | if minmax is None: 18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 19 | ma = np.max(x) 20 | else: 21 | mi,ma = minmax 22 | 23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 24 | x = (255*x).astype(np.uint8) 25 | x_ = cv2.applyColorMap(x, cmap) 26 | return x_, [mi,ma] 27 | 28 | def init_log(log, keys): 29 | for key in keys: 30 | log[key] = torch.tensor([0.0], dtype=float) 31 | return log 32 | 33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 34 | """ 35 | depth: (H, W) 36 | """ 37 | if type(depth) is not np.ndarray: 38 | depth = depth.cpu().numpy() 39 | 40 | x = np.nan_to_num(depth) # change nan to 0 41 | if minmax is None: 42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 43 | ma = np.max(x) 44 | else: 45 | mi,ma = minmax 46 | 47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 48 | x = (255*x).astype(np.uint8) 49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 50 | x_ = T.ToTensor()(x_) # (3, H, W) 51 | return x_, [mi,ma] 52 | 53 | def N_to_reso(n_voxels, bbox): 54 | xyz_min, xyz_max = bbox 55 | dim = len(xyz_min) 56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 57 | return ((xyz_max - xyz_min) / voxel_size).long().tolist() 58 | 59 | def cal_n_samples(reso, step_ratio=0.5): 60 | return int(np.linalg.norm(reso)/step_ratio) 61 | 62 | 63 | 64 | 65 | __LPIPS__ = {} 66 | def init_lpips(net_name, device): 67 | assert net_name in ['alex', 'vgg'] 68 | import lpips 69 | print(f'init_lpips: lpips_{net_name}') 70 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 71 | 72 | def rgb_lpips(np_gt, np_im, net_name, device): 73 | if net_name not in __LPIPS__: 74 | __LPIPS__[net_name] = init_lpips(net_name, device) 75 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 76 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 77 | return __LPIPS__[net_name](gt, im, normalize=True).item() 78 | 79 | 80 | def findItem(items, target): 81 | for one in items: 82 | if one[:len(target)]==target: 83 | return one 84 | return None 85 | 86 | 87 | ''' Evaluation metrics (ssim, lpips) 88 | ''' 89 | def rgb_ssim(img0, img1, max_val, 90 | filter_size=11, 91 | filter_sigma=1.5, 92 | k1=0.01, 93 | k2=0.03, 94 | return_map=False): 95 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 96 | assert len(img0.shape) == 3 97 | assert img0.shape[-1] == 3 98 | assert img0.shape == img1.shape 99 | 100 | # Construct a 1D Gaussian blur filter. 101 | hw = filter_size // 2 102 | shift = (2 * hw - filter_size + 1) / 2 103 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 104 | filt = np.exp(-0.5 * f_i) 105 | filt /= np.sum(filt) 106 | 107 | # Blur in x and y (faster than the 2D convolution). 
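    # Descriptive note (added): applying the 1D Gaussian along x and then along y
    # via the nested helper below is equivalent to one 2D Gaussian blur, because
    # the Gaussian kernel is separable; it costs O(filter_size) per pixel per pass
    # instead of O(filter_size**2) for the full 2D convolution.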
108 | def convolve2d(z, f): 109 | return scipy.signal.convolve2d(z, f, mode='valid') 110 | 111 | filt_fn = lambda z: np.stack([ 112 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 113 | for i in range(z.shape[-1])], -1) 114 | mu0 = filt_fn(img0) 115 | mu1 = filt_fn(img1) 116 | mu00 = mu0 * mu0 117 | mu11 = mu1 * mu1 118 | mu01 = mu0 * mu1 119 | sigma00 = filt_fn(img0**2) - mu00 120 | sigma11 = filt_fn(img1**2) - mu11 121 | sigma01 = filt_fn(img0 * img1) - mu01 122 | 123 | # Clip the variances and covariances to valid values. 124 | # Variance must be non-negative: 125 | sigma00 = np.maximum(0., sigma00) 126 | sigma11 = np.maximum(0., sigma11) 127 | sigma01 = np.sign(sigma01) * np.minimum( 128 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 129 | c1 = (k1 * max_val)**2 130 | c2 = (k2 * max_val)**2 131 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 132 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 133 | ssim_map = numer / denom 134 | ssim = np.mean(ssim_map) 135 | return ssim_map if return_map else ssim 136 | 137 | 138 | import torch.nn as nn 139 | class TVLoss(nn.Module): 140 | def __init__(self,TVLoss_weight=1): 141 | super(TVLoss,self).__init__() 142 | self.TVLoss_weight = TVLoss_weight 143 | 144 | def forward(self,x): 145 | batch_size = x.size()[0] 146 | h_x = x.size()[2] 147 | w_x = x.size()[3] 148 | count_h = self._tensor_size(x[:,:,1:,:]) 149 | count_w = self._tensor_size(x[:,:,:,1:]) 150 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 151 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 152 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 153 | 154 | def _tensor_size(self,t): 155 | return t.size()[1]*t.size()[2]*t.size()[3] 156 | 157 | 158 | 159 | import plyfile 160 | import skimage.measure 161 | def convert_sdf_samples_to_ply( 162 | pytorch_3d_sdf_tensor, 163 | ply_filename_out, 164 | bbox, 165 | level=0.5, 166 | offset=None, 167 | scale=None, 168 | ): 169 | """ 170 | Convert sdf samples to .ply 171 | 172 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n) 173 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 174 | :voxel_size: float, the size of the voxels 175 | :ply_filename_out: string, path of the filename to save to 176 | 177 | This function adapted from: https://github.com/RobotLocomotion/spartan 178 | """ 179 | 180 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy() 181 | voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape)) 182 | 183 | verts, faces, normals, values = skimage.measure.marching_cubes( 184 | numpy_3d_sdf_tensor, level=level, spacing=voxel_size 185 | ) 186 | faces = faces[...,::-1] # inverse face orientation 187 | 188 | # transform from voxel coordinates to camera coordinates 189 | # note x and y are flipped in the output of marching_cubes 190 | mesh_points = np.zeros_like(verts) 191 | mesh_points[:, 0] = bbox[0,0] + verts[:, 0] 192 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1] 193 | mesh_points[:, 2] = bbox[0,2] + verts[:, 2] 194 | 195 | # apply additional offset and scale 196 | if scale is not None: 197 | mesh_points = mesh_points / scale 198 | if offset is not None: 199 | mesh_points = mesh_points - offset 200 | 201 | # try writing to the ply file 202 | 203 | num_verts = verts.shape[0] 204 | num_faces = faces.shape[0] 205 | 206 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 207 | 208 | for i in range(0, num_verts): 209 | verts_tuple[i] = tuple(mesh_points[i, :]) 210 | 211 | 
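    # Descriptive note (added): gather each triangle's three vertex indices into a
    # structured array with a ("vertex_indices", "i4", (3,)) field, which
    # plyfile.PlyElement.describe turns into the "face" element of the output PLY.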
faces_building = [] 212 | for i in range(0, num_faces): 213 | faces_building.append(((faces[i, :].tolist(),))) 214 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 215 | 216 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 217 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 218 | 219 | ply_data = plyfile.PlyData([el_verts, el_faces]) 220 | print("saving mesh to %s" % (ply_filename_out)) 221 | ply_data.write(ply_filename_out) 222 | 223 | 224 | class Timing: 225 | """ 226 | Timing environment 227 | usage: 228 | with Timing("message"): 229 | your commands here 230 | will print CUDA runtime in ms 231 | """ 232 | 233 | def __init__(self, name, debug=False): 234 | self.name = name 235 | self.debug = debug 236 | 237 | def __enter__(self): 238 | if not self.debug: 239 | return 240 | 241 | self.start = torch.cuda.Event(enable_timing=True) 242 | self.end = torch.cuda.Event(enable_timing=True) 243 | self.start.record() 244 | 245 | def __exit__(self, type, value, traceback): 246 | if not self.debug: 247 | return 248 | 249 | self.end.record() 250 | torch.cuda.synchronize() 251 | print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms") -------------------------------------------------------------------------------- /extra/auto_run_paramsets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading, queue 3 | import numpy as np 4 | import time 5 | 6 | 7 | def getFolderLocker(logFolder): 8 | while True: 9 | try: 10 | os.makedirs(logFolder+"/lockFolder") 11 | break 12 | except: 13 | time.sleep(0.01) 14 | 15 | def releaseFolderLocker(logFolder): 16 | os.removedirs(logFolder+"/lockFolder") 17 | 18 | def getStopFolder(logFolder): 19 | return os.path.isdir(logFolder+"/stopFolder") 20 | 21 | 22 | def get_param_str(key, val): 23 | if key == 'data_name': 24 | return f'--datadir {datafolder}/{val} ' 25 | else: 26 | return f'--{key} {val} ' 27 | 28 | def get_param_list(param_dict): 29 | param_keys = list(param_dict.keys()) 30 | param_modes = len(param_keys) 31 | param_nums = [len(param_dict[key]) for key in param_keys] 32 | 33 | param_ids = np.zeros(param_nums+[param_modes], dtype=int) 34 | for i in range(param_modes): 35 | broad_tuple = np.ones(param_modes, dtype=int).tolist() 36 | broad_tuple[i] = param_nums[i] 37 | broad_tuple = tuple(broad_tuple) 38 | print(broad_tuple) 39 | param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple) 40 | param_ids = param_ids.reshape(-1, param_modes) 41 | # print(param_ids) 42 | print(len(param_ids)) 43 | 44 | params = [] 45 | expnames = [] 46 | for i in range(param_ids.shape[0]): 47 | one = "" 48 | name = "" 49 | param_id = param_ids[i] 50 | for j in range(param_modes): 51 | key = param_keys[j] 52 | val = param_dict[key][param_id[j]] 53 | if type(key) is tuple: 54 | assert len(key) == len(val) 55 | for k in range(len(key)): 56 | one += get_param_str(key[k], val[k]) 57 | name += f'{val[k]},' 58 | name=name[:-1]+'-' 59 | else: 60 | one += get_param_str(key, val) 61 | name += f'{val}-' 62 | params.append(one) 63 | name=name.replace(' ','') 64 | print(name) 65 | expnames.append(name[:-1]) 66 | # print(params) 67 | return params, expnames 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | 78 | 79 | # nerf 80 | expFolder = "nerf/" 81 | # parameters to iterate, use tuple to couple multiple parameters 82 | datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/' 83 | param_dict = { 84 | 'data_name': ['ship', 'mic', 
'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'], 85 | 'data_dim_color': [13, 27, 54] 86 | } 87 | 88 | # n_iters = 30000 89 | # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 90 | # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \ 91 | # f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\ 92 | # f'--expname {data_name} --batch_size {batch_size} ' \ 93 | # f'--n_iters {n_iters} ' \ 94 | # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\ 95 | # f'--N_vis {5} ' \ 96 | # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \ 97 | # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \ 98 | # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \ 99 | # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \ 100 | # f'--render_test 1 ' 101 | # print(cmd) 102 | # os.system(cmd) 103 | 104 | # nsvf 105 | # expFolder = "nsvf_0227/" 106 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/' 107 | # param_dict = { 108 | # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 109 | # 'shadingMode': ['SH'], 110 | # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")], 111 | # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)], 112 | # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)], 113 | # ('n_iters','N_voxel_final'): [(30000,300**3)], 114 | # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)], 115 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")] 116 | # 117 | # } 118 | 119 | # tankstemple 120 | # expFolder = "tankstemple_0304/" 121 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/' 122 | # param_dict = { 123 | # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'], 124 | # 'shadingMode': ['MLP_Fea'], 125 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")], 126 | # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)], 127 | # ('TV_weight_density','TV_weight_app'):[(0.1,0.01)], 128 | # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)], 129 | # ('n_iters','N_voxel_final'): [(15000,300**3)], 130 | # ('dataset_name','N_vis') : [("tankstemple",5)], 131 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")] 132 | # } 133 | 134 | # llff 135 | # expFolder = "real_iconic/" 136 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/' 137 | # List = os.listdir(datafolder) 138 | # param_dict = { 139 | # 'data_name': List, 140 | # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)], 141 | # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")], 142 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 143 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 144 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)], 145 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 146 | # } 147 | 148 | # expFolder = "llff/" 149 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data' 
150 | # param_dict = { 151 | # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids' 152 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")], 153 | # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)], 154 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 155 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 156 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)], 157 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 158 | # } 159 | 160 | #setting available gpus 161 | gpus_que = queue.Queue(3) 162 | for i in [1,2,3]: 163 | gpus_que.put(i) 164 | 165 | os.makedirs(f"log/{expFolder}", exist_ok=True) 166 | 167 | def run_program(gpu, expname, param): 168 | cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \ 169 | f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \ 170 | f'{param}' \ 171 | f'> "log/{expFolder}{expname}/{expname}.txt"' 172 | print(cmd) 173 | os.system(cmd) 174 | gpus_que.put(gpu) 175 | 176 | params, expnames = get_param_list(param_dict) 177 | 178 | 179 | logFolder=f"log/{expFolder}" 180 | os.makedirs(logFolder, exist_ok=True) 181 | 182 | ths = [] 183 | for i in range(len(params)): 184 | 185 | if getStopFolder(logFolder): 186 | break 187 | 188 | 189 | targetFolder = f"log/{expFolder}{expnames[i]}" 190 | gpu = gpus_que.get() 191 | getFolderLocker(logFolder) 192 | if os.path.isdir(targetFolder): 193 | releaseFolderLocker(logFolder) 194 | gpus_que.put(gpu) 195 | continue 196 | else: 197 | os.makedirs(targetFolder, exist_ok=True) 198 | print("making",targetFolder, "running",expnames[i], params[i]) 199 | releaseFolderLocker(logFolder) 200 | 201 | 202 | t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True) 203 | t.start() 204 | ths.append(t) 205 | 206 | for th in ths: 207 | th.join() -------------------------------------------------------------------------------- /dataLoader/tankstemple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | 11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 12 | if axis == 'z': 13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h] 14 | elif axis == 'y': 15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)] 16 | else: 17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)] 18 | 19 | 20 | def cross(x, y, axis=0): 21 | T = torch if isinstance(x, torch.Tensor) else np 22 | return T.cross(x, y, axis) 23 | 24 | 25 | def normalize(x, axis=-1, order=2): 26 | if isinstance(x, torch.Tensor): 27 | l2 = x.norm(p=order, dim=axis, keepdim=True) 28 | return x / (l2 + 1e-8), l2 29 | 30 | else: 31 | l2 = np.linalg.norm(x, order, axis) 32 | l2 = np.expand_dims(l2, axis) 33 | l2[l2 == 0] = 1 34 | return x / l2, 35 | 36 | 37 | def cat(x, axis=1): 38 | if isinstance(x[0], torch.Tensor): 39 | return torch.cat(x, dim=axis) 40 | return np.concatenate(x, axis=axis) 41 | 42 | 43 | def look_at_rotation(camera_position, at=None, up=None, 
inverse=False, cv=False): 44 | """ 45 | This function takes a vector 'camera_position' which specifies the location 46 | of the camera in world coordinates and two vectors `at` and `up` which 47 | indicate the position of the object and the up directions of the world 48 | coordinate system respectively. The object is assumed to be centered at 49 | the origin. 50 | The output is a rotation matrix representing the transformation 51 | from world coordinates -> view coordinates. 52 | Input: 53 | camera_position: 3 54 | at: 1 x 3 or N x 3 (0, 0, 0) in default 55 | up: 1 x 3 or N x 3 (0, 1, 0) in default 56 | """ 57 | 58 | if at is None: 59 | at = torch.zeros_like(camera_position) 60 | else: 61 | at = torch.tensor(at).type_as(camera_position) 62 | if up is None: 63 | up = torch.zeros_like(camera_position) 64 | up[2] = -1 65 | else: 66 | up = torch.tensor(up).type_as(camera_position) 67 | 68 | z_axis = normalize(at - camera_position)[0] 69 | x_axis = normalize(cross(up, z_axis))[0] 70 | y_axis = normalize(cross(z_axis, x_axis))[0] 71 | 72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 73 | return R 74 | 75 | 76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180): 77 | c2ws = [] 78 | for t in range(frames): 79 | c2w = torch.eye(4) 80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi)) 81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True) 82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot 83 | c2ws.append(c2w) 84 | return torch.stack(c2ws) 85 | 86 | class TanksTempleDataset(Dataset): 87 | """NSVF Generic Dataset.""" 88 | def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False): 89 | self.root_dir = datadir 90 | self.split = split 91 | self.is_stack = is_stack 92 | self.downsample = downsample 93 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 94 | self.define_transforms() 95 | 96 | self.white_bg = True 97 | self.near_far = [0.01,6.0] 98 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2 99 | 100 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 101 | self.read_meta() 102 | self.define_proj_mat() 103 | 104 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 105 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 106 | 107 | def bbox2corners(self): 108 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 109 | for i in range(3): 110 | corners[i,[0,1],i] = corners[i,[1,0],i] 111 | return corners.view(-1,3) 112 | 113 | 114 | def read_meta(self): 115 | 116 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt")) 117 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1) 118 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 119 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 120 | 121 | if self.split == 'train': 122 | pose_files = [x for x in pose_files if x.startswith('0_')] 123 | img_files = [x for x in img_files if x.startswith('0_')] 124 | elif self.split == 'val': 125 | pose_files = [x for x in pose_files if x.startswith('1_')] 126 | img_files = [x for x in img_files if x.startswith('1_')] 127 | elif self.split == 'test': 128 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 129 | test_img_files = [x for x in img_files if x.startswith('2_')] 130 | if len(test_pose_files) == 0: 131 | test_pose_files = [x for x in pose_files if 
x.startswith('1_')] 132 | test_img_files = [x for x in img_files if x.startswith('1_')] 133 | pose_files = test_pose_files 134 | img_files = test_img_files 135 | 136 | # ray directions for all pixels, same for all images (same H, W, focal) 137 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 138 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 139 | 140 | 141 | 142 | self.poses = [] 143 | self.all_rays = [] 144 | self.all_rgbs = [] 145 | 146 | assert len(img_files) == len(pose_files) 147 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 148 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 149 | img = Image.open(image_path) 150 | if self.downsample!=1.0: 151 | img = img.resize(self.img_wh, Image.LANCZOS) 152 | img = self.transform(img) # (4, h, w) 153 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 154 | if img.shape[-1]==4: 155 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 156 | self.all_rgbs.append(img) 157 | 158 | 159 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans 160 | c2w = torch.FloatTensor(c2w) 161 | self.poses.append(c2w) # C2W 162 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 163 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 164 | 165 | self.poses = torch.stack(self.poses) 166 | 167 | center = torch.mean(self.scene_bbox, dim=0) 168 | radius = torch.norm(self.scene_bbox[1]-center)*1.2 169 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist() 170 | pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y') 171 | self.render_path = gen_path(pos_gen, up=up,frames=200) 172 | self.render_path[:, :3, 3] += center 173 | 174 | 175 | 176 | if 'train' == self.split: 177 | if self.is_stack: 178 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 179 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 180 | else: 181 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 182 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 183 | else: 184 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 185 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 186 | 187 | 188 | def define_transforms(self): 189 | self.transform = T.ToTensor() 190 | 191 | def define_proj_mat(self): 192 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 193 | 194 | def world2ndc(self, points): 195 | device = points.device 196 | return (points - self.center.to(device)) / self.radius.to(device) 197 | 198 | def __len__(self): 199 | if self.split == 'train': 200 | return len(self.all_rays) 201 | return len(self.all_rgbs) 202 | 203 | def __getitem__(self, idx): 204 | 205 | if self.split == 'train': # use data in the buffers 206 | sample = {'rays': self.all_rays[idx], 207 | 'rgbs': self.all_rgbs[idx]} 208 | 209 | else: # create data for each image separately 210 | 211 | img = self.all_rgbs[idx] 212 | rays = self.all_rays[idx] 213 | 214 | sample = {'rays': rays, 215 | 'rgbs': img} 216 | return sample 
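A minimal usage sketch for the loader above (an illustrative addition, not a file in the repository): it assumes a Tanks and Temples scene in the NSVF-style layout (`rgb/`, `pose/`, `intrinsics.txt`, `bbox.txt`) under a placeholder path, builds the training split, and reads one ray/color pair.

```python
# Illustrative sketch only: the dataset path below is a placeholder.
from dataLoader.tankstemple import TanksTempleDataset

train_set = TanksTempleDataset('/path/to/TanksAndTemple/Truck',
                               split='train', downsample=1.0, is_stack=False)
sample = train_set[0]
rays, rgbs = sample['rays'], sample['rgbs']
# rays: (6,) concatenated ray origin and direction; rgbs: (3,) pixel color
print(len(train_set), rays.shape, rgbs.shape)
```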
-------------------------------------------------------------------------------- /dataLoader/llff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import glob 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | def normalize(v): 13 | """Normalize a vector.""" 14 | return v / np.linalg.norm(v) 15 | 16 | 17 | def average_poses(poses): 18 | """ 19 | Calculate the average pose, which is then used to center all poses 20 | using @center_poses. Its computation is as follows: 21 | 1. Compute the center: the average of pose centers. 22 | 2. Compute the z axis: the normalized average z axis. 23 | 3. Compute axis y': the average y axis. 24 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 25 | 5. Compute the y axis: z cross product x. 26 | 27 | Note that at step 3, we cannot directly use y' as y axis since it's 28 | not necessarily orthogonal to z axis. We need to pass from x to y. 29 | Inputs: 30 | poses: (N_images, 3, 4) 31 | Outputs: 32 | pose_avg: (3, 4) the average pose 33 | """ 34 | # 1. Compute the center 35 | center = poses[..., 3].mean(0) # (3) 36 | 37 | # 2. Compute the z axis 38 | z = normalize(poses[..., 2].mean(0)) # (3) 39 | 40 | # 3. Compute axis y' (no need to normalize as it's not the final output) 41 | y_ = poses[..., 1].mean(0) # (3) 42 | 43 | # 4. Compute the x axis 44 | x = normalize(np.cross(z, y_)) # (3) 45 | 46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 47 | y = np.cross(x, z) # (3) 48 | 49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 50 | 51 | return pose_avg 52 | 53 | 54 | def center_poses(poses, blender2opencv): 55 | """ 56 | Center the poses so that we can use NDC. 57 | See https://github.com/bmild/nerf/issues/34 58 | Inputs: 59 | poses: (N_images, 3, 4) 60 | Outputs: 61 | poses_centered: (N_images, 3, 4) the centered poses 62 | pose_avg: (3, 4) the average pose 63 | """ 64 | poses = poses @ blender2opencv 65 | pose_avg = average_poses(poses) # (3, 4) 66 | pose_avg_homo = np.eye(4) 67 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 68 | pose_avg_homo = pose_avg_homo 69 | # by simply adding 0, 0, 0, 1 as the last row 70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 71 | poses_homo = \ 72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 73 | 74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 75 | # poses_centered = poses_centered @ blender2opencv 76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 77 | 78 | return poses_centered, pose_avg_homo 79 | 80 | 81 | def viewmatrix(z, up, pos): 82 | vec2 = normalize(z) 83 | vec1_avg = up 84 | vec0 = normalize(np.cross(vec1_avg, vec2)) 85 | vec1 = normalize(np.cross(vec2, vec0)) 86 | m = np.eye(4) 87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) 88 | return m 89 | 90 | 91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): 92 | render_poses = [] 93 | rads = np.array(list(rads) + [1.]) 94 | 95 | for theta in np.linspace(0., 2. 
* np.pi * N_rots, N + 1)[:-1]: 96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 98 | render_poses.append(viewmatrix(z, up, c)) 99 | return render_poses 100 | 101 | 102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120): 103 | # center pose 104 | c2w = average_poses(c2ws_all) 105 | 106 | # Get average pose 107 | up = normalize(c2ws_all[:, :3, 1].sum(0)) 108 | 109 | # Find a reasonable "focus depth" for this dataset 110 | dt = 0.75 111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 113 | 114 | # Get radii for spiral path 115 | zdelta = near_fars.min() * .2 116 | tt = c2ws_all[:, :3, 3] 117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale 118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views) 119 | return np.stack(render_poses) 120 | 121 | 122 | class LLFFDataset(Dataset): 123 | def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8): 124 | """ 125 | spheric_poses: whether the images are taken in a spheric inward-facing manner 126 | default: False (forward-facing) 127 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 128 | """ 129 | 130 | self.root_dir = datadir 131 | self.split = split 132 | self.hold_every = hold_every 133 | self.is_stack = is_stack 134 | self.downsample = downsample 135 | self.define_transforms() 136 | 137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 138 | self.read_meta() 139 | self.white_bg = False 140 | 141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])] 142 | self.near_far = [0.0, 1.0] 143 | self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]) 144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]]) 145 | self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3) 146 | self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 147 | 148 | def read_meta(self): 149 | 150 | 151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17) 152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*'))) 153 | # load full resolution image then resize 154 | if self.split in ['train', 'test']: 155 | assert len(poses_bounds) == len(self.image_paths), \ 156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!' 
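        # Descriptive note (added): poses_bounds.npy packs 17 values per image, i.e. a
        # flattened 3x5 matrix (the 3x4 camera-to-world pose with an extra
        # [height, width, focal] column) followed by the near/far depth bounds;
        # they are unpacked just below.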
157 | 158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2) 160 | hwf = poses[:, :, -1] 161 | 162 | # Step 1: rescale focal length according to training resolution 163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images 164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)]) 165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H] 166 | 167 | # Step 2: correct poses 168 | # Original poses has rotation in form "down right back", change to "right up back" 169 | # See https://github.com/bmild/nerf/issues/34 170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 171 | # (N_images, 3, 4) exclude H, W, focal 172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv) 173 | 174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 175 | # See https://github.com/bmild/nerf/issues/34 176 | near_original = self.near_fars.min() 177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 178 | # the nearest depth is at 1/0.75=1.33 179 | self.near_fars /= scale_factor 180 | self.poses[..., 3] /= scale_factor 181 | 182 | # build rendering path 183 | N_views, N_rots = 120, 2 184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 185 | up = normalize(self.poses[:, :3, 1].sum(0)) 186 | rads = np.percentile(np.abs(tt), 90, 0) 187 | 188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views) 189 | 190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to 192 | # center image 193 | 194 | # ray directions for all pixels, same for all images (same H, W, focal) 195 | W, H = self.img_wh 196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3) 197 | 198 | average_pose = average_poses(self.poses) 199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1) 200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)] 201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test)) 202 | 203 | # use first N_images-1 to train, the LAST is val 204 | self.all_rays = [] 205 | self.all_rgbs = [] 206 | for i in img_list: 207 | image_path = self.image_paths[i] 208 | c2w = torch.FloatTensor(self.poses[i]) 209 | 210 | img = Image.open(image_path).convert('RGB') 211 | if self.downsample != 1.0: 212 | img = img.resize(self.img_wh, Image.LANCZOS) 213 | img = self.transform(img) # (3, h, w) 214 | 215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 216 | self.all_rgbs += [img] 217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 220 | 221 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 222 | 223 | if not self.is_stack: 224 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 225 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3) 226 | else: 227 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h,w, 3) 228 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 229 | 230 | 231 | def define_transforms(self): 232 | 
self.transform = T.ToTensor() 233 | 234 | def __len__(self): 235 | return len(self.all_rgbs) 236 | 237 | def __getitem__(self, idx): 238 | 239 | sample = {'rays': self.all_rays[idx], 240 | 'rgbs': self.all_rgbs[idx]} 241 | 242 | return sample -------------------------------------------------------------------------------- /dataLoader/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch, re 2 | import numpy as np 3 | from torch import searchsorted 4 | from kornia import create_meshgrid 5 | 6 | 7 | # from utils import index_point_feature 8 | 9 | def depth2dist(z_vals, cos_angle): 10 | # z_vals: [N_ray N_sample] 11 | device = z_vals.device 12 | dists = z_vals[..., 1:] - z_vals[..., :-1] 13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 14 | dists = dists * cos_angle.unsqueeze(-1) 15 | return dists 16 | 17 | 18 | def ndc2dist(ndc_pts, cos_angle): 19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1) 20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples] 21 | return dists 22 | 23 | 24 | def get_ray_directions(H, W, focal, center=None): 25 | """ 26 | Get ray directions for all pixels in camera coordinate. 27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 28 | ray-tracing-generating-camera-rays/standard-coordinate-systems 29 | Inputs: 30 | H, W, focal: image height, width and focal length 31 | Outputs: 32 | directions: (H, W, 3), the direction of the rays in camera coordinate 33 | """ 34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 35 | 36 | i, j = grid.unbind(-1) 37 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 38 | # see https://github.com/bmild/nerf/issues/24 39 | cent = center if center is not None else [W / 2, H / 2] 40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) 41 | 42 | return directions 43 | 44 | 45 | def get_ray_directions_blender(H, W, focal, center=None): 46 | """ 47 | Get ray directions for all pixels in camera coordinate. 48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 49 | ray-tracing-generating-camera-rays/standard-coordinate-systems 50 | Inputs: 51 | H, W, focal: image height, width and focal length 52 | Outputs: 53 | directions: (H, W, 3), the direction of the rays in camera coordinate 54 | """ 55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5 56 | i, j = grid.unbind(-1) 57 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 58 | # see https://github.com/bmild/nerf/issues/24 59 | cent = center if center is not None else [W / 2, H / 2] 60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], 61 | -1) # (H, W, 3) 62 | 63 | return directions 64 | 65 | 66 | def get_rays(directions, c2w): 67 | """ 68 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 
69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 70 | ray-tracing-generating-camera-rays/standard-coordinate-systems 71 | Inputs: 72 | directions: (H, W, 3) precomputed ray directions in camera coordinate 73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 74 | Outputs: 75 | rays_o: (H*W, 3), the origin of the rays in world coordinate 76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 77 | """ 78 | # Rotate ray directions from camera coordinate to the world coordinate 79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 81 | # The origin of all rays is the camera origin in world coordinate 82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 83 | 84 | rays_d = rays_d.view(-1, 3) 85 | rays_o = rays_o.view(-1, 3) 86 | 87 | return rays_o, rays_d 88 | 89 | 90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d): 91 | # Shift ray origins to near plane 92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 93 | rays_o = rays_o + t[..., None] * rays_d 94 | 95 | # Projection 96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 98 | o2 = 1. + 2. * near / rays_o[..., 2] 99 | 100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 102 | d2 = -2. * near / rays_o[..., 2] 103 | 104 | rays_o = torch.stack([o0, o1, o2], -1) 105 | rays_d = torch.stack([d0, d1, d2], -1) 106 | 107 | return rays_o, rays_d 108 | 109 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 110 | # Shift ray origins to near plane 111 | t = (near - rays_o[..., 2]) / rays_d[..., 2] 112 | rays_o = rays_o + t[..., None] * rays_d 113 | 114 | # Projection 115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 117 | o2 = 1. - 2. * near / rays_o[..., 2] 118 | 119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 121 | d2 = 2. 
* near / rays_o[..., 2] 122 | 123 | rays_o = torch.stack([o0, o1, o2], -1) 124 | rays_d = torch.stack([d0, d1, d2], -1) 125 | 126 | return rays_o, rays_d 127 | 128 | # Hierarchical sampling (section 5.2) 129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 130 | device = weights.device 131 | # Get pdf 132 | weights = weights + 1e-5 # prevent nans 133 | pdf = weights / torch.sum(weights, -1, keepdim=True) 134 | cdf = torch.cumsum(pdf, -1) 135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 136 | 137 | # Take uniform samples 138 | if det: 139 | u = torch.linspace(0., 1., steps=N_samples, device=device) 140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 141 | else: 142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device) 143 | 144 | # Pytest, overwrite u with numpy's fixed random numbers 145 | if pytest: 146 | np.random.seed(0) 147 | new_shape = list(cdf.shape[:-1]) + [N_samples] 148 | if det: 149 | u = np.linspace(0., 1., N_samples) 150 | u = np.broadcast_to(u, new_shape) 151 | else: 152 | u = np.random.rand(*new_shape) 153 | u = torch.Tensor(u) 154 | 155 | # Invert CDF 156 | u = u.contiguous() 157 | inds = searchsorted(cdf.detach(), u, right=True) 158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 161 | 162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 165 | 166 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 168 | t = (u - cdf_g[..., 0]) / denom 169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 170 | 171 | return samples 172 | 173 | 174 | def dda(rays_o, rays_d, bbox_3D): 175 | inv_ray_d = 1.0 / (rays_d + 1e-6) 176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3 177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d 178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3 179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0] 180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0] 181 | return t_min, t_max 182 | 183 | 184 | def ray_marcher(rays, 185 | N_samples=64, 186 | lindisp=False, 187 | perturb=0, 188 | bbox_3D=None): 189 | """ 190 | sample points along the rays 191 | Inputs: 192 | rays: () 193 | 194 | Returns: 195 | 196 | """ 197 | 198 | # Decompose the inputs 199 | N_rays = rays.shape[0] 200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 202 | 203 | if bbox_3D is not None: 204 | # cal aabb boundles 205 | near, far = dda(rays_o, rays_d, bbox_3D) 206 | 207 | # Sample depth points 208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) 209 | if not lindisp: # use linear sampling in depth space 210 | z_vals = near * (1 - z_steps) + far * z_steps 211 | else: # use linear sampling in disparity space 212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) 213 | 214 | z_vals = z_vals.expand(N_rays, N_samples) 215 | 216 | if perturb > 0: # perturb sampling depths (z_vals) 217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points 218 | # get intervals between samples 219 | upper = torch.cat([z_vals_mid, z_vals[:, 
-1:]], -1) 220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 221 | 222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 223 | z_vals = lower + (upper - lower) * perturb_rand 224 | 225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 227 | 228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals 229 | 230 | 231 | def read_pfm(filename): 232 | file = open(filename, 'rb') 233 | color = None 234 | width = None 235 | height = None 236 | scale = None 237 | endian = None 238 | 239 | header = file.readline().decode('utf-8').rstrip() 240 | if header == 'PF': 241 | color = True 242 | elif header == 'Pf': 243 | color = False 244 | else: 245 | raise Exception('Not a PFM file.') 246 | 247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 248 | if dim_match: 249 | width, height = map(int, dim_match.groups()) 250 | else: 251 | raise Exception('Malformed PFM header.') 252 | 253 | scale = float(file.readline().rstrip()) 254 | if scale < 0: # little-endian 255 | endian = '<' 256 | scale = -scale 257 | else: 258 | endian = '>' # big-endian 259 | 260 | data = np.fromfile(file, endian + 'f') 261 | shape = (height, width, 3) if color else (height, width) 262 | 263 | data = np.reshape(data, shape) 264 | data = np.flipud(data) 265 | file.close() 266 | return data, scale 267 | 268 | 269 | def ndc_bbox(all_rays): 270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0] 271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0] 272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0] 273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0] 274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}') 275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max))) -------------------------------------------------------------------------------- /dataLoader/colmap2nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import argparse 12 | import os 13 | from pathlib import Path, PurePosixPath 14 | 15 | import numpy as np 16 | import json 17 | import sys 18 | import math 19 | import cv2 20 | import os 21 | import shutil 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") 25 | 26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") 27 | parser.add_argument("--video_fps", default=2) 28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. 
eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") 29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") 31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 32 | parser.add_argument("--images", default="images", help="input path to the images") 33 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") 34 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16") 35 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start") 36 | parser.add_argument("--out", default="transforms.json", help="output path") 37 | args = parser.parse_args() 38 | return args 39 | 40 | def do_system(arg): 41 | print(f"==== running: {arg}") 42 | err = os.system(arg) 43 | if err: 44 | print("FATAL: command failed") 45 | sys.exit(err) 46 | 47 | def run_ffmpeg(args): 48 | if not os.path.isabs(args.images): 49 | args.images = os.path.join(os.path.dirname(args.video_in), args.images) 50 | images = args.images 51 | video = args.video_in 52 | fps = float(args.video_fps) or 1.0 53 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") 54 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 55 | sys.exit(1) 56 | try: 57 | shutil.rmtree(images) 58 | except: 59 | pass 60 | do_system(f"mkdir {images}") 61 | 62 | time_slice_value = "" 63 | time_slice = args.time_slice 64 | if time_slice: 65 | start, end = time_slice.split(",") 66 | time_slice_value = f",select='between(t\,{start}\,{end})'" 67 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") 68 | 69 | def run_colmap(args): 70 | db=args.colmap_db 71 | images=args.images 72 | db_noext=str(Path(db).with_suffix("")) 73 | 74 | if args.text=="text": 75 | args.text=db_noext+"_text" 76 | text=args.text 77 | sparse=db_noext+"_sparse" 78 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") 79 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? 
(Y/n)").lower().strip()+"y")[:1] != "y": 80 | sys.exit(1) 81 | if os.path.exists(db): 82 | os.remove(db) 83 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}") 84 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}") 85 | try: 86 | shutil.rmtree(sparse) 87 | except: 88 | pass 89 | do_system(f"mkdir {sparse}") 90 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") 91 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") 92 | try: 93 | shutil.rmtree(text) 94 | except: 95 | pass 96 | do_system(f"mkdir {text}") 97 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") 98 | 99 | def variance_of_laplacian(image): 100 | return cv2.Laplacian(image, cv2.CV_64F).var() 101 | 102 | def sharpness(imagePath): 103 | image = cv2.imread(imagePath) 104 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 105 | fm = variance_of_laplacian(gray) 106 | return fm 107 | 108 | def qvec2rotmat(qvec): 109 | return np.array([ 110 | [ 111 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 114 | ], [ 115 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 116 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 117 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 118 | ], [ 119 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 120 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 121 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 122 | ] 123 | ]) 124 | 125 | def rotmat(a, b): 126 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 127 | v = np.cross(a, b) 128 | c = np.dot(a, b) 129 | s = np.linalg.norm(v) 130 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 131 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 132 | 133 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 134 | da = da / np.linalg.norm(da) 135 | db = db / np.linalg.norm(db) 136 | c = np.cross(da, db) 137 | denom = np.linalg.norm(c)**2 138 | t = ob - oa 139 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 140 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 141 | if ta > 0: 142 | ta = 0 143 | if tb > 0: 144 | tb = 0 145 | return (oa+ta*da+ob+tb*db) * 0.5, denom 146 | 147 | if __name__ == "__main__": 148 | args = parse_args() 149 | if args.video_in != "": 150 | run_ffmpeg(args) 151 | if args.run_colmap: 152 | run_colmap(args) 153 | AABB_SCALE = int(args.aabb_scale) 154 | SKIP_EARLY = int(args.skip_early) 155 | IMAGE_FOLDER = args.images 156 | TEXT_FOLDER = args.text 157 | OUT_PATH = args.out 158 | print(f"outputting to {OUT_PATH}...") 159 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: 160 | angle_x = math.pi / 2 161 | for line in f: 162 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 163 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 164 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 165 | if line[0] == "#": 166 | continue 167 | els = line.split(" ") 168 | w = float(els[2]) 169 | h = float(els[3]) 170 | fl_x = float(els[4]) 171 | fl_y = float(els[4]) 172 | k1 
= 0 173 | k2 = 0 174 | p1 = 0 175 | p2 = 0 176 | cx = w / 2 177 | cy = h / 2 178 | if els[1] == "SIMPLE_PINHOLE": 179 | cx = float(els[5]) 180 | cy = float(els[6]) 181 | elif els[1] == "PINHOLE": 182 | fl_y = float(els[5]) 183 | cx = float(els[6]) 184 | cy = float(els[7]) 185 | elif els[1] == "SIMPLE_RADIAL": 186 | cx = float(els[5]) 187 | cy = float(els[6]) 188 | k1 = float(els[7]) 189 | elif els[1] == "RADIAL": 190 | cx = float(els[5]) 191 | cy = float(els[6]) 192 | k1 = float(els[7]) 193 | k2 = float(els[8]) 194 | elif els[1] == "OPENCV": 195 | fl_y = float(els[5]) 196 | cx = float(els[6]) 197 | cy = float(els[7]) 198 | k1 = float(els[8]) 199 | k2 = float(els[9]) 200 | p1 = float(els[10]) 201 | p2 = float(els[11]) 202 | else: 203 | print("unknown camera model ", els[1]) 204 | # fl = 0.5 * w / tan(0.5 * angle_x); 205 | angle_x = math.atan(w / (fl_x * 2)) * 2 206 | angle_y = math.atan(h / (fl_y * 2)) * 2 207 | fovx = angle_x * 180 / math.pi 208 | fovy = angle_y * 180 / math.pi 209 | 210 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") 211 | 212 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: 213 | i = 0 214 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 215 | out = { 216 | "camera_angle_x": angle_x, 217 | "camera_angle_y": angle_y, 218 | "fl_x": fl_x, 219 | "fl_y": fl_y, 220 | "k1": k1, 221 | "k2": k2, 222 | "p1": p1, 223 | "p2": p2, 224 | "cx": cx, 225 | "cy": cy, 226 | "w": w, 227 | "h": h, 228 | "aabb_scale": AABB_SCALE, 229 | "frames": [], 230 | } 231 | 232 | up = np.zeros(3) 233 | for line in f: 234 | line = line.strip() 235 | if line[0] == "#": 236 | continue 237 | i = i + 1 238 | if i < SKIP_EARLY*2: 239 | continue 240 | if i % 2 == 1: 241 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 242 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) 243 | # why is this requireing a relitive path while using ^ 244 | image_rel = os.path.relpath(IMAGE_FOLDER) 245 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}") 246 | b=sharpness(name) 247 | print(name, "sharpness=",b) 248 | image_id = int(elems[0]) 249 | qvec = np.array(tuple(map(float, elems[1:5]))) 250 | tvec = np.array(tuple(map(float, elems[5:8]))) 251 | R = qvec2rotmat(-qvec) 252 | t = tvec.reshape([3,1]) 253 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 254 | c2w = np.linalg.inv(m) 255 | c2w[0:3,2] *= -1 # flip the y and z axis 256 | c2w[0:3,1] *= -1 257 | c2w = c2w[[1,0,2,3],:] # swap y and z 258 | c2w[2,:] *= -1 # flip whole world upside down 259 | 260 | up += c2w[0:3,1] 261 | 262 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w} 263 | out["frames"].append(frame) 264 | nframes = len(out["frames"]) 265 | up = up / np.linalg.norm(up) 266 | print("up vector was", up) 267 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] 268 | R = np.pad(R,[0,1]) 269 | R[-1, -1] = 1 270 | 271 | 272 | for f in out["frames"]: 273 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 274 | 275 | # find a central point they are all looking at 276 | print("computing center of attention...") 277 | totw = 0.0 278 | totp = np.array([0.0, 0.0, 0.0]) 279 | for f in out["frames"]: 280 | mf = f["transform_matrix"][0:3,:] 281 | for g in out["frames"]: 282 | mg = g["transform_matrix"][0:3,:] 283 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 284 | if w > 0.01: 285 | totp += p*w 286 | totw += w 287 | totp /= totw 288 | 
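# totp is the weight-averaged midpoint of the pairwise closest points between camera
# viewing rays (closest_point_2_lines returns a weight near 0 for near-parallel rays),
# i.e. an estimate of the point the cameras are looking at; it is subtracted from every
# camera translation below to recenter the scene around that point.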
print(totp) # the cameras are looking at totp 289 | for f in out["frames"]: 290 | f["transform_matrix"][0:3,3] -= totp 291 | 292 | avglen = 0. 293 | for f in out["frames"]: 294 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 295 | avglen /= nframes 296 | print("avg camera distance from origin", avglen) 297 | for f in out["frames"]: 298 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 299 | 300 | for f in out["frames"]: 301 | f["transform_matrix"] = f["transform_matrix"].tolist() 302 | print(nframes,"frames") 303 | print(f"writing {OUT_PATH}") 304 | with open(OUT_PATH, "w") as outfile: 305 | json.dump(out, outfile, indent=2) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2023 Zackary Shen 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /models/weighted_vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | import torch.distributed as distributed 5 | from torch.cuda.amp import autocast 6 | 7 | from einops import rearrange, repeat 8 | from contextlib import contextmanager 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def default(val, d): 14 | return val if exists(val) else d 15 | 16 | def noop(*args, **kwargs): 17 | pass 18 | 19 | def l2norm(t): 20 | return F.normalize(t, p = 2, dim = -1) 21 | 22 | def log(t, eps = 1e-20): 23 | return torch.log(t.clamp(min = eps)) 24 | 25 | def uniform_init(*shape): 26 | t = torch.empty(shape) 27 | nn.init.kaiming_uniform_(t) 28 | return t 29 | 30 | def gumbel_noise(t): 31 | noise = torch.zeros_like(t).uniform_(0, 1) 32 | return -log(-log(noise)) 33 | 34 | def gumbel_sample(t, temperature = 1., dim = -1): 35 | if temperature == 0: 36 | return t.argmax(dim = dim) 37 | 38 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 39 | 40 | def ema_inplace(moving_avg, new, decay): 41 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 42 | 43 | def laplace_smoothing(x, n_categories, eps = 1e-5): 44 | return (x + eps) / (x.sum() + n_categories * eps) 45 | 46 | def sample_vectors(samples, num): 47 | num_samples, device = samples.shape[0], samples.device 48 | if num_samples >= num: 49 | indices = torch.randperm(num_samples, device = device)[:num] 50 | else: 51 | indices = torch.randint(0, num_samples, (num,), device = device) 52 | 53 | return samples[indices] 54 | 55 | def batched_sample_vectors(samples, num): 56 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 57 | 58 | def pad_shape(shape, size, dim = 0): 59 | return [size if i == dim else s for i, s in enumerate(shape)] 60 | 61 | def sample_multinomial(total_count, probs): 62 | device = probs.device 63 | probs = probs.cpu() 64 | 65 | total_count = probs.new_full((), total_count) 66 | remainder = probs.new_ones(()) 67 | sample = torch.empty_like(probs, dtype = torch.long) 68 | 69 | for i, p in enumerate(probs): 70 | s = torch.binomial(total_count, p / remainder) 71 | sample[i] = s 72 | total_count -= s 73 | remainder -= p 74 | 75 | return sample.to(device) 76 | 77 | def all_gather_sizes(x, dim): 78 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 79 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 80 | distributed.all_gather(all_sizes, size) 81 | 82 | return torch.stack(all_sizes) 83 | 84 | def all_gather_variably_sized(x, sizes, dim = 0): 85 | rank = distributed.get_rank() 86 | all_x = [] 87 | 88 | for i, size in enumerate(sizes): 89 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 90 | distributed.broadcast(t, src = i, async_op = True) 91 | all_x.append(t) 92 | 93 | distributed.barrier() 94 | return all_x 95 | 96 | def sample_vectors_distributed(local_samples, num): 97 | rank = distributed.get_rank() 98 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 99 | 100 | if rank == 0: 101 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 102 | else: 103 | samples_per_rank = torch.empty_like(all_num_samples) 104 | 105 | distributed.broadcast(samples_per_rank, src = 0) 106 | samples_per_rank = samples_per_rank.tolist() 107 | 108 | 
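# Each rank draws its assigned share of vectors locally, then the shares are gathered
# across ranks so that every process ends up with the same pooled sample set for
# codebook (k-means) initialization.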
local_samples = batched_sample_vectors(local_samples, samples_per_rank[rank]) 109 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 110 | return torch.cat(all_samples, dim = 0) 111 | 112 | def batched_bincount(x, *, minlength): 113 | batch, dtype, device = x.shape[0], x.dtype, x.device 114 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 115 | values = torch.ones_like(x) 116 | target.scatter_add_(-1, x, values) 117 | return target 118 | 119 | def kmeans( 120 | samples, 121 | num_clusters, 122 | num_iters = 10, 123 | use_cosine_sim = False, 124 | sample_fn = batched_sample_vectors, 125 | all_reduce_fn = noop 126 | ): 127 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 128 | 129 | means = sample_fn(samples, num_clusters) 130 | 131 | for _ in range(num_iters): 132 | if use_cosine_sim: 133 | dists = samples @ rearrange(means, 'h n d -> h d n') 134 | else: 135 | dists = -torch.cdist(samples, means, p = 2) 136 | 137 | buckets = torch.argmax(dists, dim = -1) 138 | bins = batched_bincount(buckets, minlength = num_clusters) 139 | all_reduce_fn(bins) 140 | 141 | zero_mask = bins == 0 142 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 143 | 144 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 145 | 146 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 147 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 148 | all_reduce_fn(new_means) 149 | 150 | if use_cosine_sim: 151 | new_means = l2norm(new_means) 152 | 153 | means = torch.where( 154 | rearrange(zero_mask, '... -> ... 1'), 155 | means, 156 | new_means 157 | ) 158 | 159 | return means, bins 160 | 161 | def batched_embedding(indices, embeds): 162 | batch, dim = indices.shape[1], embeds.shape[-1] 163 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 164 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 165 | return embeds.gather(2, indices) 166 | 167 | # regularization losses 168 | 169 | def orthogonal_loss_fn(t): 170 | # eq (2) from https://arxiv.org/abs/2112.00384 171 | h, n = t.shape[:2] 172 | normed_codes = l2norm(t) 173 | identity = repeat(torch.eye(n, device = t.device), 'i j -> h i j', h = h) 174 | cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) 175 | return ((cosine_sim - identity) ** 2).sum() / (h * n ** 2) 176 | 177 | # distance types 178 | 179 | class EuclideanCodebook(nn.Module): 180 | def __init__( 181 | self, 182 | dim, 183 | codebook_size, 184 | num_codebooks = 1, 185 | kmeans_init = False, 186 | kmeans_iters = 10, 187 | decay = 0.8, 188 | eps = 1e-5, 189 | threshold_ema_dead_code = 2, 190 | use_ddp = False, 191 | learnable_codebook = False, 192 | sample_codebook_temp = 0 193 | ): 194 | super().__init__() 195 | self.decay = decay 196 | init_fn = uniform_init if not kmeans_init else torch.zeros 197 | embed = init_fn(num_codebooks, codebook_size, dim) 198 | 199 | self.codebook_size = codebook_size 200 | self.num_codebooks = num_codebooks 201 | 202 | self.kmeans_iters = kmeans_iters 203 | self.eps = eps 204 | self.threshold_ema_dead_code = threshold_ema_dead_code 205 | self.sample_codebook_temp = sample_codebook_temp 206 | 207 | self.sample_fn = sample_vectors_distributed if use_ddp else batched_sample_vectors 208 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 209 | 210 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 211 | self.register_buffer('cluster_size', 
torch.zeros(num_codebooks, codebook_size)) 212 | self.register_buffer('embed_avg', embed.clone()) 213 | 214 | self.learnable_codebook = learnable_codebook 215 | if learnable_codebook: 216 | self.embed = nn.Parameter(embed) 217 | else: 218 | self.register_buffer('embed', embed) 219 | 220 | @torch.jit.ignore 221 | def init_embed_(self, data): 222 | if self.initted: 223 | return 224 | 225 | embed, cluster_size = kmeans( 226 | data, 227 | self.codebook_size, 228 | self.kmeans_iters, 229 | sample_fn = self.sample_fn, 230 | all_reduce_fn = self.all_reduce_fn 231 | ) 232 | 233 | self.embed.data.copy_(embed) 234 | self.embed_avg.data.copy_(embed.clone()) 235 | self.cluster_size.data.copy_(cluster_size) 236 | self.initted.data.copy_(torch.Tensor([True])) 237 | 238 | def replace(self, batch_samples, batch_mask): 239 | batch_samples = l2norm(batch_samples) 240 | 241 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 242 | if not torch.any(mask): 243 | continue 244 | 245 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 246 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 247 | 248 | def expire_codes_(self, batch_samples, verbose): 249 | if self.threshold_ema_dead_code == 0: 250 | return 251 | 252 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 253 | 254 | if not torch.any(expired_codes): 255 | return 256 | if verbose: 257 | print(f'expire code count: {expired_codes.sum()}') 258 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 259 | self.replace(batch_samples, batch_mask = expired_codes) 260 | 261 | @autocast(enabled = False) 262 | def forward(self, x, weight=None, verbose=False): 263 | if weight is not None: 264 | weight = weight * weight.numel()/weight.sum() 265 | needs_codebook_dim = x.ndim < 4 266 | 267 | x = x.float() 268 | 269 | if needs_codebook_dim: 270 | x = rearrange(x, '... -> 1 ...') 271 | 272 | shape, dtype = x.shape, x.dtype 273 | flatten = rearrange(x, 'h ... d -> h (...) d') 274 | 275 | self.init_embed_(flatten) 276 | 277 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 278 | 279 | dist = -torch.cdist(flatten, embed, p = 2) 280 | 281 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 282 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 283 | embed_ind = embed_ind.view(*shape[:-1]) 284 | 285 | quantize = batched_embedding(embed_ind, self.embed) 286 | 287 | if self.training: 288 | 289 | if weight is not None: 290 | cluster_size = (embed_onehot*weight).sum(dim = 1) 291 | else: 292 | cluster_size = embed_onehot.sum(dim = 1) 293 | self.all_reduce_fn(cluster_size) 294 | ema_inplace(self.cluster_size, cluster_size, self.decay) 295 | 296 | if weight is not None: 297 | 298 | embed_sum = einsum('h n d, h n c -> h c d', flatten*weight, embed_onehot) 299 | else: 300 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 301 | self.all_reduce_fn(embed_sum) 302 | cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() 303 | 304 | # embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 
1') 305 | # print("embed_normalized: ",embed_normalized, 306 | # "\n embed_avg: ",self.embed_avg, 307 | # "\n cluster_size: ", cluster_size) 308 | # self.embed.data.copy_(embed_normalized) 309 | # print("before ema: self.embed:", self.embed, "embed_sum: ", embed_sum) 310 | ema_inplace(self.embed, embed_sum/rearrange(cluster_size, '... -> ... 1'), self.decay) 311 | # print("after ema: self.embed:", self.embed, "embed_sum: ", embed_sum) 312 | self.expire_codes_(x, verbose) 313 | # print("after expire: self.embed:", self.embed, "embed_sum: ", embed_sum) 314 | 315 | if needs_codebook_dim: 316 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) 317 | 318 | return quantize, embed_ind 319 | 320 | # main class 321 | 322 | class VectorQuantize(nn.Module): 323 | def __init__( 324 | self, 325 | dim, 326 | codebook_size, 327 | codebook_dim = None, 328 | heads = 1, 329 | separate_codebook_per_head = False, 330 | decay = 0.8, 331 | eps = 1e-5, 332 | kmeans_init = False, 333 | kmeans_iters = 10, 334 | use_cosine_sim = False, 335 | threshold_ema_dead_code = 0, 336 | channel_last = True, 337 | accept_image_fmap = False, 338 | commitment_weight = 1., 339 | orthogonal_reg_weight = 0., 340 | orthogonal_reg_active_codes_only = False, 341 | orthogonal_reg_max_codes = None, 342 | sample_codebook_temp = 0., 343 | sync_codebook = False 344 | ): 345 | super().__init__() 346 | self.heads = heads 347 | self.separate_codebook_per_head = separate_codebook_per_head 348 | 349 | codebook_dim = default(codebook_dim, dim) 350 | codebook_input_dim = codebook_dim * heads 351 | 352 | requires_projection = codebook_input_dim != dim 353 | self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 354 | self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 355 | 356 | self.eps = eps 357 | self.commitment_weight = commitment_weight 358 | 359 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 360 | self.orthogonal_reg_weight = orthogonal_reg_weight 361 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 362 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 363 | 364 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook 365 | 366 | self._codebook = codebook_class( 367 | dim = codebook_dim, 368 | num_codebooks = heads if separate_codebook_per_head else 1, 369 | codebook_size = codebook_size, 370 | kmeans_init = kmeans_init, 371 | kmeans_iters = kmeans_iters, 372 | decay = decay, 373 | eps = eps, 374 | threshold_ema_dead_code = threshold_ema_dead_code, 375 | use_ddp = sync_codebook, 376 | learnable_codebook = has_codebook_orthogonal_loss, 377 | sample_codebook_temp = sample_codebook_temp 378 | ) 379 | 380 | self.codebook_size = codebook_size 381 | 382 | self.accept_image_fmap = accept_image_fmap 383 | self.channel_last = channel_last 384 | 385 | @property 386 | def codebook(self): 387 | codebook = self._codebook.embed 388 | if self.separate_codebook_per_head: 389 | return codebook 390 | 391 | return rearrange(codebook, '1 ... 
-> ...') 392 | 393 | def forward(self, x, weight=None, verbose=False): 394 | shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size 395 | 396 | need_transpose = not self.channel_last and not self.accept_image_fmap 397 | 398 | if self.accept_image_fmap: 399 | height, width = x.shape[-2:] 400 | x = rearrange(x, 'b c h w -> b (h w) c') 401 | 402 | if need_transpose: 403 | x = rearrange(x, 'b d n -> b n d') 404 | 405 | x = self.project_in(x) 406 | 407 | if is_multiheaded: 408 | ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d' 409 | x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads) 410 | 411 | quantize, embed_ind = self._codebook(x, weight, verbose) 412 | 413 | if self.training: 414 | quantize = x + (quantize - x).detach() 415 | 416 | loss = torch.tensor([0.], device = device, requires_grad = self.training) 417 | 418 | if self.training: 419 | if self.commitment_weight > 0: 420 | commit_loss = F.mse_loss(quantize.detach(), x) 421 | loss = loss + commit_loss * self.commitment_weight 422 | 423 | if self.orthogonal_reg_weight > 0: 424 | codebook = self._codebook.embed 425 | 426 | if self.orthogonal_reg_active_codes_only: 427 | # only calculate orthogonal loss for the activated codes for this batch 428 | unique_code_ids = torch.unique(embed_ind) 429 | codebook = codebook[unique_code_ids] 430 | 431 | num_codes = codebook.shape[0] 432 | if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: 433 | rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] 434 | codebook = codebook[rand_ids] 435 | 436 | orthogonal_reg_loss = orthogonal_loss_fn(codebook) 437 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight 438 | 439 | if is_multiheaded: 440 | if self.separate_codebook_per_head: 441 | quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads) 442 | embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads) 443 | else: 444 | quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads) 445 | embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads) 446 | 447 | quantize = self.project_out(quantize) 448 | 449 | if need_transpose: 450 | quantize = rearrange(quantize, 'b n d -> b d n') 451 | 452 | if self.accept_image_fmap: 453 | quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) 454 | embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width) 455 | 456 | return quantize, embed_ind, loss -------------------------------------------------------------------------------- /models/tensorBase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from .sh import eval_sh_bases 5 | import numpy as np 6 | import time 7 | 8 | 9 | def positional_encoding(positions, freqs): 10 | 11 | freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) 12 | pts = (positions[..., None] * freq_bands).reshape( 13 | positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF) 14 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) 15 | return pts 16 | 17 | def raw2alpha(sigma, dist): 18 | # sigma, dist [N_rays, N_samples] 19 | alpha = 1. - torch.exp(-sigma*dist) 20 | 21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. 
- alpha + 1e-10], -1), -1) 22 | 23 | weights = alpha * T[:, :-1] # [N_rays, N_samples] 24 | return alpha, weights, T[:,-1:] 25 | 26 | 27 | def SHRender(xyz_sampled, viewdirs, features): 28 | sh_mult = eval_sh_bases(2, viewdirs)[:, None] 29 | rgb_sh = features.view(-1, 3, sh_mult.shape[-1]) 30 | rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5) 31 | return rgb 32 | 33 | 34 | def RGBRender(xyz_sampled, viewdirs, features): 35 | 36 | rgb = features 37 | return rgb 38 | 39 | class AlphaGridMask(torch.nn.Module): 40 | def __init__(self, device, aabb, alpha_volume): 41 | super(AlphaGridMask, self).__init__() 42 | self.device = device 43 | 44 | self.aabb=aabb.to(self.device) 45 | self.aabbSize = self.aabb[1] - self.aabb[0] 46 | self.invgridSize = 1.0/self.aabbSize * 2 47 | self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:]) 48 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device) 49 | 50 | def sample_alpha(self, xyz_sampled): 51 | xyz_sampled = self.normalize_coord(xyz_sampled) 52 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1) 53 | 54 | return alpha_vals 55 | 56 | def normalize_coord(self, xyz_sampled): 57 | return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1 58 | 59 | 60 | class MLPRender_Fea(torch.nn.Module): 61 | def __init__(self,inChanel, viewpe=6, feape=6, featureC=128): 62 | super(MLPRender_Fea, self).__init__() 63 | 64 | self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel 65 | self.viewpe = viewpe 66 | self.feape = feape 67 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 68 | layer2 = torch.nn.Linear(featureC, featureC) 69 | layer3 = torch.nn.Linear(featureC,3) 70 | 71 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 72 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 73 | 74 | def forward(self, pts, viewdirs, features): 75 | indata = [features, viewdirs] 76 | if self.feape > 0: 77 | indata += [positional_encoding(features, self.feape)] 78 | if self.viewpe > 0: 79 | indata += [positional_encoding(viewdirs, self.viewpe)] 80 | mlp_in = torch.cat(indata, dim=-1) 81 | rgb = self.mlp(mlp_in) 82 | rgb = torch.sigmoid(rgb) 83 | 84 | return rgb 85 | 86 | class MLPRender_PE(torch.nn.Module): 87 | def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128): 88 | super(MLPRender_PE, self).__init__() 89 | 90 | self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel # 91 | self.viewpe = viewpe 92 | self.pospe = pospe 93 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 94 | layer2 = torch.nn.Linear(featureC, featureC) 95 | layer3 = torch.nn.Linear(featureC,3) 96 | 97 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 98 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 99 | 100 | def forward(self, pts, viewdirs, features): 101 | indata = [features, viewdirs] 102 | if self.pospe > 0: 103 | indata += [positional_encoding(pts, self.pospe)] 104 | if self.viewpe > 0: 105 | indata += [positional_encoding(viewdirs, self.viewpe)] 106 | mlp_in = torch.cat(indata, dim=-1) 107 | rgb = self.mlp(mlp_in) 108 | rgb = torch.sigmoid(rgb) 109 | 110 | return rgb 111 | 112 | class MLPRender(torch.nn.Module): 113 | def __init__(self,inChanel, viewpe=6, featureC=128): 114 | super(MLPRender, self).__init__() 115 | 116 | self.in_mlpC = (3+2*viewpe*3) + inChanel 117 | self.viewpe = viewpe 118 | 119 | layer1 = 
torch.nn.Linear(self.in_mlpC, featureC) 120 | layer2 = torch.nn.Linear(featureC, featureC) 121 | layer3 = torch.nn.Linear(featureC,3) 122 | 123 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 124 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 125 | 126 | def forward(self, pts, viewdirs, features): 127 | indata = [features, viewdirs] 128 | if self.viewpe > 0: 129 | indata += [positional_encoding(viewdirs, self.viewpe)] 130 | mlp_in = torch.cat(indata, dim=-1) 131 | rgb = self.mlp(mlp_in) 132 | rgb = torch.sigmoid(rgb) 133 | 134 | return rgb 135 | 136 | 137 | 138 | class TensorBase(torch.nn.Module): 139 | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27, 140 | shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0], 141 | density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001, 142 | pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0, 143 | fea2denseAct = 'softplus', **kargs): 144 | super(TensorBase, self).__init__() 145 | 146 | self.density_n_comp = density_n_comp 147 | self.app_n_comp = appearance_n_comp 148 | self.app_dim = app_dim 149 | self.aabb = aabb 150 | self.alphaMask = alphaMask 151 | self.device=device 152 | 153 | self.density_shift = density_shift 154 | self.alphaMask_thres = alphaMask_thres 155 | self.distance_scale = distance_scale 156 | self.rayMarch_weight_thres = rayMarch_weight_thres 157 | self.fea2denseAct = fea2denseAct 158 | 159 | self.near_far = near_far 160 | self.step_ratio = step_ratio 161 | 162 | 163 | self.update_stepSize(gridSize) 164 | 165 | self.matMode = [[0,1], [0,2], [1,2]] 166 | self.vecMode = [2, 1, 0] 167 | self.comp_w = [1,1,1] 168 | 169 | 170 | self.init_svd_volume(gridSize[0], device) 171 | 172 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC 173 | self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device) 174 | 175 | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device): 176 | if shadingMode == 'MLP_PE': 177 | self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device) 178 | elif shadingMode == 'MLP_Fea': 179 | self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device) 180 | elif shadingMode == 'MLP': 181 | self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device) 182 | elif shadingMode == 'SH': 183 | self.renderModule = SHRender 184 | elif shadingMode == 'RGB': 185 | assert self.app_dim == 3 186 | self.renderModule = RGBRender 187 | else: 188 | print("Unrecognized shading module") 189 | exit() 190 | # print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe) 191 | print(self.renderModule) 192 | 193 | def update_stepSize(self, gridSize): 194 | # print("aabb", self.aabb.view(-1)) 195 | # print("grid size", gridSize) 196 | self.aabbSize = self.aabb[1] - self.aabb[0] 197 | self.invaabbSize = 2.0/self.aabbSize 198 | self.gridSize= torch.LongTensor(gridSize).to(self.device) 199 | self.units=self.aabbSize / (self.gridSize-1) 200 | self.stepSize=torch.mean(self.units)*self.step_ratio 201 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize))) 202 | self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1 203 | # print("sampling step size: ", self.stepSize) 204 | # print("sampling number: ", self.nSamples) 205 | 206 | def init_svd_volume(self, res, device): 207 | pass 208 | 209 | 
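# The stubs below are hooks for the concrete tensor decompositions (presumably the
# CP/VM models in tensoRF.py) to override with their own feature grids and optimizer
# groups; TensorBase itself only provides the shared ray sampling, shading and
# alpha-mask machinery.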
def compute_features(self, xyz_sampled): 210 | pass 211 | 212 | def compute_densityfeature(self, xyz_sampled): 213 | pass 214 | 215 | def compute_appfeature(self, xyz_sampled): 216 | pass 217 | 218 | def normalize_coord(self, xyz_sampled): 219 | return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1 220 | 221 | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001): 222 | pass 223 | 224 | def get_kwargs(self): 225 | return { 226 | 'aabb': self.aabb, 227 | 'gridSize':self.gridSize.tolist(), 228 | 'density_n_comp': self.density_n_comp, 229 | 'appearance_n_comp': self.app_n_comp, 230 | 'app_dim': self.app_dim, 231 | 232 | 'density_shift': self.density_shift, 233 | 'alphaMask_thres': self.alphaMask_thres, 234 | 'distance_scale': self.distance_scale, 235 | 'rayMarch_weight_thres': self.rayMarch_weight_thres, 236 | 'fea2denseAct': self.fea2denseAct, 237 | 238 | 'near_far': self.near_far, 239 | 'step_ratio': self.step_ratio, 240 | 241 | 'shadingMode': self.shadingMode, 242 | 'pos_pe': self.pos_pe, 243 | 'view_pe': self.view_pe, 244 | 'fea_pe': self.fea_pe, 245 | 'featureC': self.featureC 246 | } 247 | 248 | def save(self, path): 249 | kwargs = self.get_kwargs() 250 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()} 251 | if self.alphaMask is not None: 252 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy() 253 | ckpt.update({'alphaMask.shape':alpha_volume.shape}) 254 | ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))}) 255 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()}) 256 | torch.save(ckpt, path) 257 | 258 | def load(self, ckpt): 259 | if 'alphaMask.aabb' in ckpt.keys(): 260 | length = np.prod(ckpt['alphaMask.shape']) 261 | alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape'])) 262 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device)) 263 | self.load_state_dict(ckpt['state_dict']) 264 | 265 | 266 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1): 267 | N_samples = N_samples if N_samples > 0 else self.nSamples 268 | near, far = self.near_far 269 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) 270 | if is_train: 271 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) 272 | 273 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] 274 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) 275 | return rays_pts, interpx, ~mask_outbbox 276 | 277 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): 278 | N_samples = N_samples if N_samples>0 else self.nSamples 279 | stepsize = self.stepSize 280 | near, far = self.near_far 281 | vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d) 282 | rate_a = (self.aabb[1] - rays_o) / vec 283 | rate_b = (self.aabb[0] - rays_o) / vec 284 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) 285 | 286 | rng = torch.arange(N_samples)[None].float() 287 | if is_train: 288 | rng = rng.repeat(rays_d.shape[-2],1) 289 | rng += torch.rand_like(rng[:,[0]]) 290 | step = stepsize * rng.to(rays_o.device) 291 | interpx = (t_min[...,None] + step) 292 | 293 | rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None] 294 | mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1) 295 | 296 | return rays_pts, interpx, ~mask_outbbox 297 | 298 | 299 | def shrink(self, 
new_aabb, voxel_size): 300 | pass 301 | 302 | @torch.no_grad() 303 | def getDenseAlpha(self,gridSize=None): 304 | gridSize = self.gridSize if gridSize is None else gridSize 305 | 306 | samples = torch.stack(torch.meshgrid( 307 | torch.linspace(0, 1, gridSize[0]), 308 | torch.linspace(0, 1, gridSize[1]), 309 | torch.linspace(0, 1, gridSize[2]), 310 | ), -1).to(self.device) 311 | dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples 312 | 313 | # dense_xyz = dense_xyz 314 | # print(self.stepSize, self.distance_scale*self.aabbDiag) 315 | alpha = torch.zeros_like(dense_xyz[...,0]) 316 | for i in range(gridSize[0]): 317 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2])) 318 | return alpha, dense_xyz 319 | 320 | @torch.no_grad() 321 | def updateAlphaMask(self, gridSize=(200,200,200)): 322 | 323 | alpha, dense_xyz = self.getDenseAlpha(gridSize) 324 | dense_xyz = dense_xyz.transpose(0,2).contiguous() 325 | alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None] 326 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2] 327 | 328 | ks = 3 329 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1]) 330 | alpha[alpha>=self.alphaMask_thres] = 1 331 | alpha[alpha<self.alphaMask_thres] = 0 332 | 333 | self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha) 334 | 335 | valid_xyz = dense_xyz[alpha>0.5] 336 | 337 | xyz_min = valid_xyz.amin(0) 338 | xyz_max = valid_xyz.amax(0) 339 | 340 | new_aabb = torch.stack((xyz_min, xyz_max)) 341 | 342 | total = torch.sum(alpha) 343 | print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100)) 344 | return new_aabb 345 | 346 | @torch.no_grad() 347 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False): 348 | print('========> filtering rays ...') 349 | tt = time.time() 350 | 351 | N = torch.tensor(all_rays.shape[:-1]).prod() 352 | 353 | mask_filtered = [] 354 | idx_chunks = torch.split(torch.arange(N), chunk) 355 | for idx_chunk in idx_chunks: 356 | rays_chunk = all_rays[idx_chunk].to(self.device) 357 | 358 | rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6] 359 | if bbox_only: 360 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d) 361 | rate_a = (self.aabb[1] - rays_o) / vec 362 | rate_b = (self.aabb[0] - rays_o) / vec 363 | t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far) 364 | t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far) 365 | mask_inbbox = t_max > t_min 366 | 367 | else: 368 | xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False) 369 | mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1) 370 | 371 | mask_filtered.append(mask_inbbox.cpu()) 372 | 373 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1]) 374 | 375 | print(f'Ray filtering done! takes {time.time()-tt} s.
ray mask ratio: {torch.sum(mask_filtered) / N}') 376 | return all_rays[mask_filtered], all_rgbs[mask_filtered] 377 | 378 | 379 | def feature2density(self, density_features): 380 | if self.fea2denseAct == "softplus": 381 | return F.softplus(density_features+self.density_shift) 382 | elif self.fea2denseAct == "relu": 383 | return F.relu(density_features) 384 | 385 | 386 | def compute_alpha(self, xyz_locs, length=1): 387 | 388 | if self.alphaMask is not None: 389 | alphas = self.alphaMask.sample_alpha(xyz_locs) 390 | alpha_mask = alphas > 0 391 | else: 392 | alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool) 393 | 394 | 395 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device) 396 | 397 | if alpha_mask.any(): 398 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask]) 399 | sigma_feature = self.compute_densityfeature(xyz_sampled) 400 | validsigma = self.feature2density(sigma_feature) 401 | sigma[alpha_mask] = validsigma 402 | 403 | 404 | alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1]) 405 | 406 | return alpha 407 | 408 | 409 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): 410 | 411 | # sample points 412 | viewdirs = rays_chunk[:, 3:6] 413 | if ndc_ray: 414 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) 415 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 416 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True) 417 | dists = dists * rays_norm 418 | viewdirs = viewdirs / rays_norm 419 | else: 420 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) 421 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 422 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape) 423 | 424 | if self.alphaMask is not None: 425 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid]) 426 | alpha_mask = alphas > 0 427 | ray_invalid = ~ray_valid 428 | ray_invalid[ray_valid] |= (~alpha_mask) 429 | ray_valid = ~ray_invalid 430 | 431 | 432 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) 433 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) 434 | 435 | if ray_valid.any(): 436 | xyz_sampled = self.normalize_coord(xyz_sampled) 437 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid]) 438 | 439 | validsigma = self.feature2density(sigma_feature) 440 | sigma[ray_valid] = validsigma 441 | 442 | 443 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale) 444 | 445 | app_mask = weight > self.rayMarch_weight_thres 446 | 447 | if app_mask.any(): 448 | app_features = self.compute_appfeature(xyz_sampled[app_mask]) 449 | valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features) 450 | rgb[app_mask] = valid_rgbs 451 | 452 | acc_map = torch.sum(weight, -1) 453 | rgb_map = torch.sum(weight[..., None] * rgb, -2) 454 | 455 | if white_bg or (is_train and torch.rand((1,))<0.5): 456 | rgb_map = rgb_map + (1. - acc_map[..., None]) 457 | 458 | 459 | rgb_map = rgb_map.clamp(0,1) 460 | 461 | with torch.no_grad(): 462 | depth_map = torch.sum(weight * z_vals, -1) 463 | depth_map = depth_map + (1. 
- acc_map) * rays_chunk[..., -1] 464 | 465 | return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight 466 | 467 | -------------------------------------------------------------------------------- /models/my_vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | import torch.distributed as distributed 5 | from torch.cuda.amp import autocast 6 | 7 | from einops import rearrange, repeat 8 | from contextlib import contextmanager 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def default(val, d): 14 | return val if exists(val) else d 15 | 16 | def noop(*args, **kwargs): 17 | pass 18 | 19 | def l2norm(t): 20 | return F.normalize(t, p = 2, dim = -1) 21 | 22 | def log(t, eps = 1e-20): 23 | return torch.log(t.clamp(min = eps)) 24 | 25 | def uniform_init(*shape): 26 | t = torch.empty(shape) 27 | nn.init.kaiming_uniform_(t) 28 | return t 29 | 30 | def gumbel_noise(t): 31 | noise = torch.zeros_like(t).uniform_(0, 1) 32 | return -log(-log(noise)) 33 | 34 | def gumbel_sample(t, temperature = 1., dim = -1): 35 | if temperature == 0: 36 | return t.argmax(dim = dim) 37 | 38 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 39 | 40 | def ema_inplace(moving_avg, new, decay): 41 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 42 | 43 | def laplace_smoothing(x, n_categories, eps = 1e-5): 44 | return (x + eps) / (x.sum() + n_categories * eps) 45 | 46 | def sample_vectors(samples, num): 47 | num_samples, device = samples.shape[0], samples.device 48 | if num_samples >= num: 49 | indices = torch.randperm(num_samples, device = device)[:num] 50 | else: 51 | indices = torch.randint(0, num_samples, (num,), device = device) 52 | 53 | return samples[indices] 54 | 55 | def batched_sample_vectors(samples, num): 56 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 57 | 58 | def pad_shape(shape, size, dim = 0): 59 | return [size if i == dim else s for i, s in enumerate(shape)] 60 | 61 | def sample_multinomial(total_count, probs): 62 | device = probs.device 63 | probs = probs.cpu() 64 | 65 | total_count = probs.new_full((), total_count) 66 | remainder = probs.new_ones(()) 67 | sample = torch.empty_like(probs, dtype = torch.long) 68 | 69 | for i, p in enumerate(probs): 70 | s = torch.binomial(total_count, p / remainder) 71 | sample[i] = s 72 | total_count -= s 73 | remainder -= p 74 | 75 | return sample.to(device) 76 | 77 | def all_gather_sizes(x, dim): 78 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 79 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 80 | distributed.all_gather(all_sizes, size) 81 | 82 | return torch.stack(all_sizes) 83 | 84 | def all_gather_variably_sized(x, sizes, dim = 0): 85 | rank = distributed.get_rank() 86 | all_x = [] 87 | 88 | for i, size in enumerate(sizes): 89 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 90 | distributed.broadcast(t, src = i, async_op = True) 91 | all_x.append(t) 92 | 93 | distributed.barrier() 94 | return all_x 95 | 96 | def sample_vectors_distributed(local_samples, num): 97 | rank = distributed.get_rank() 98 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 99 | 100 | if rank == 0: 101 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 102 | else: 103 | samples_per_rank = torch.empty_like(all_num_samples) 104 | 105 | 
distributed.broadcast(samples_per_rank, src = 0) 106 | samples_per_rank = samples_per_rank.tolist() 107 | 108 | local_samples = batched_sample_vectors(local_samples, samples_per_rank[rank]) 109 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 110 | return torch.cat(all_samples, dim = 0) 111 | 112 | def batched_bincount(x, *, minlength): 113 | batch, dtype, device = x.shape[0], x.dtype, x.device 114 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 115 | values = torch.ones_like(x) 116 | target.scatter_add_(-1, x, values) 117 | return target 118 | 119 | def kmeans( 120 | samples, 121 | num_clusters, 122 | num_iters = 10, 123 | use_cosine_sim = False, 124 | sample_fn = batched_sample_vectors, 125 | all_reduce_fn = noop 126 | ): 127 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 128 | 129 | means = sample_fn(samples, num_clusters) 130 | 131 | for _ in range(num_iters): 132 | if use_cosine_sim: 133 | dists = samples @ rearrange(means, 'h n d -> h d n') 134 | else: 135 | dists = -torch.cdist(samples, means, p = 2) 136 | 137 | buckets = torch.argmax(dists, dim = -1) 138 | bins = batched_bincount(buckets, minlength = num_clusters) 139 | all_reduce_fn(bins) 140 | 141 | zero_mask = bins == 0 142 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 143 | 144 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 145 | 146 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 147 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 148 | all_reduce_fn(new_means) 149 | 150 | if use_cosine_sim: 151 | new_means = l2norm(new_means) 152 | 153 | means = torch.where( 154 | rearrange(zero_mask, '... -> ... 
1'), 155 | means, 156 | new_means 157 | ) 158 | 159 | return means, bins 160 | 161 | def batched_embedding(indices, embeds): 162 | batch, dim = indices.shape[1], embeds.shape[-1] 163 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 164 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 165 | return embeds.gather(2, indices) 166 | 167 | # regularization losses 168 | 169 | def orthogonal_loss_fn(t): 170 | # eq (2) from https://arxiv.org/abs/2112.00384 171 | h, n = t.shape[:2] 172 | normed_codes = l2norm(t) 173 | identity = repeat(torch.eye(n, device = t.device), 'i j -> h i j', h = h) 174 | cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) 175 | return ((cosine_sim - identity) ** 2).sum() / (h * n ** 2) 176 | 177 | # distance types 178 | 179 | class EuclideanCodebook(nn.Module): 180 | def __init__( 181 | self, 182 | dim, 183 | codebook_size, 184 | num_codebooks = 1, 185 | kmeans_init = False, 186 | kmeans_iters = 10, 187 | decay = 0.8, 188 | eps = 1e-5, 189 | threshold_ema_dead_code = 2, 190 | use_ddp = False, 191 | learnable_codebook = False, 192 | sample_codebook_temp = 0 193 | ): 194 | super().__init__() 195 | self.decay = decay 196 | init_fn = uniform_init if not kmeans_init else torch.zeros 197 | embed = init_fn(num_codebooks, codebook_size, dim) 198 | 199 | self.codebook_size = codebook_size 200 | self.num_codebooks = num_codebooks 201 | 202 | self.kmeans_iters = kmeans_iters 203 | self.eps = eps 204 | self.threshold_ema_dead_code = threshold_ema_dead_code 205 | self.sample_codebook_temp = sample_codebook_temp 206 | 207 | self.sample_fn = sample_vectors_distributed if use_ddp else batched_sample_vectors 208 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 209 | 210 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 211 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 212 | self.register_buffer('embed_avg', embed.clone()) 213 | 214 | self.learnable_codebook = learnable_codebook 215 | if learnable_codebook: 216 | self.embed = nn.Parameter(embed) 217 | else: 218 | self.register_buffer('embed', embed) 219 | 220 | @torch.jit.ignore 221 | def init_embed_(self, data): 222 | if self.initted: 223 | return 224 | 225 | embed, cluster_size = kmeans( 226 | data, 227 | self.codebook_size, 228 | self.kmeans_iters, 229 | sample_fn = self.sample_fn, 230 | all_reduce_fn = self.all_reduce_fn 231 | ) 232 | 233 | self.embed.data.copy_(embed) 234 | self.embed_avg.data.copy_(embed.clone()) 235 | self.cluster_size.data.copy_(cluster_size) 236 | self.initted.data.copy_(torch.Tensor([True])) 237 | 238 | def replace(self, batch_samples, batch_mask): 239 | batch_samples = l2norm(batch_samples) 240 | 241 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 242 | if not torch.any(mask): 243 | continue 244 | 245 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 246 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 247 | 248 | def expire_codes_(self, batch_samples): 249 | if self.threshold_ema_dead_code == 0: 250 | return 251 | 252 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 253 | 254 | if not torch.any(expired_codes): 255 | return 256 | 257 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) 
d') 258 | self.replace(batch_samples, batch_mask = expired_codes) 259 | 260 | @autocast(enabled = False) 261 | def forward(self, x): 262 | needs_codebook_dim = x.ndim < 4 263 | 264 | x = x.float() 265 | 266 | if needs_codebook_dim: 267 | x = rearrange(x, '... -> 1 ...') 268 | 269 | shape, dtype = x.shape, x.dtype 270 | flatten = rearrange(x, 'h ... d -> h (...) d') 271 | 272 | self.init_embed_(flatten) 273 | 274 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 275 | 276 | dist = -torch.cdist(flatten, embed, p = 2) 277 | 278 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 279 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 280 | embed_ind = embed_ind.view(*shape[:-1]) 281 | 282 | quantize = batched_embedding(embed_ind, self.embed) 283 | 284 | if self.training: 285 | cluster_size = embed_onehot.sum(dim = 1) 286 | self.all_reduce_fn(cluster_size) 287 | ema_inplace(self.cluster_size, cluster_size, self.decay) 288 | 289 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 290 | self.all_reduce_fn(embed_sum) 291 | cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() 292 | 293 | # embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') 294 | # print("embed_normalized: ",embed_normalized, 295 | # "\n embed_avg: ",self.embed_avg, 296 | # "\n cluster_size: ", cluster_size) 297 | # self.embed.data.copy_(embed_normalized) 298 | # print("before ema: self.embed:", self.embed, "embed_sum: ", embed_sum) 299 | ema_inplace(self.embed, embed_sum/rearrange(cluster_size, '... -> ... 1'), self.decay) 300 | # print("after ema: self.embed:", self.embed, "embed_sum: ", embed_sum) 301 | self.expire_codes_(x) 302 | # print("after expire: self.embed:", self.embed, "embed_sum: ", embed_sum) 303 | 304 | if needs_codebook_dim: 305 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) 306 | 307 | return quantize, embed_ind 308 | 309 | 310 | class CosineSimCodebook(nn.Module): 311 | def __init__( 312 | self, 313 | dim, 314 | codebook_size, 315 | num_codebooks = 1, 316 | kmeans_init = False, 317 | kmeans_iters = 10, 318 | decay = 0.8, 319 | eps = 1e-5, 320 | threshold_ema_dead_code = 2, 321 | use_ddp = False, 322 | learnable_codebook = False, 323 | sample_codebook_temp = 0. 
324 | ): 325 | super().__init__() 326 | self.decay = decay 327 | 328 | if not kmeans_init: 329 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) 330 | else: 331 | embed = torch.zeros(num_codebooks, codebook_size, dim) 332 | 333 | self.codebook_size = codebook_size 334 | self.num_codebooks = num_codebooks 335 | 336 | self.kmeans_iters = kmeans_iters 337 | self.eps = eps 338 | self.threshold_ema_dead_code = threshold_ema_dead_code 339 | self.sample_codebook_temp = sample_codebook_temp 340 | 341 | self.sample_fn = sample_vectors_distributed if use_ddp else batched_sample_vectors 342 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 343 | 344 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 345 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 346 | 347 | self.learnable_codebook = learnable_codebook 348 | if learnable_codebook: 349 | self.embed = nn.Parameter(embed) 350 | else: 351 | self.register_buffer('embed', embed) 352 | 353 | @torch.jit.ignore 354 | def init_embed_(self, data): 355 | if self.initted: 356 | return 357 | 358 | embed, cluster_size = kmeans( 359 | data, 360 | self.codebook_size, 361 | self.kmeans_iters, 362 | use_cosine_sim = True, 363 | sample_fn = self.sample_fn, 364 | all_reduce_fn = self.all_reduce_fn 365 | ) 366 | 367 | self.embed.data.copy_(embed) 368 | self.cluster_size.data.copy_(cluster_size) 369 | self.initted.data.copy_(torch.Tensor([True])) 370 | 371 | def replace(self, batch_samples, batch_mask): 372 | batch_samples = l2norm(batch_samples) 373 | 374 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 375 | if not torch.any(mask): 376 | continue 377 | 378 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 379 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 380 | 381 | def expire_codes_(self, batch_samples): 382 | if self.threshold_ema_dead_code == 0: 383 | return 384 | 385 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 386 | 387 | if not torch.any(expired_codes): 388 | return 389 | 390 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 391 | self.replace(batch_samples, batch_mask = expired_codes) 392 | 393 | @autocast(enabled = False) 394 | def forward(self, x): 395 | needs_codebook_dim = x.ndim < 4 396 | 397 | x = x.float() 398 | 399 | if needs_codebook_dim: 400 | x = rearrange(x, '... -> 1 ...') 401 | 402 | shape, dtype = x.shape, x.dtype 403 | 404 | flatten = rearrange(x, 'h ... d -> h (...) d') 405 | flatten = l2norm(flatten) 406 | 407 | self.init_embed_(flatten) 408 | 409 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 410 | embed = l2norm(embed) 411 | 412 | dist = einsum('h n d, h c d -> h n c', flatten, embed) 413 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 414 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 415 | embed_ind = embed_ind.view(*shape[:-1]) 416 | 417 | quantize = batched_embedding(embed_ind, self.embed) 418 | 419 | if self.training: 420 | bins = embed_onehot.sum(dim = 1) 421 | self.all_reduce_fn(bins) 422 | 423 | ema_inplace(self.cluster_size, bins, self.decay) 424 | 425 | zero_mask = (bins == 0) 426 | bins = bins.masked_fill(zero_mask, 1.) 427 | 428 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 429 | self.all_reduce_fn(embed_sum) 430 | 431 | embed_normalized = embed_sum / rearrange(bins, '... -> ... 
1') 432 | embed_normalized = l2norm(embed_normalized) 433 | 434 | embed_normalized = torch.where( 435 | rearrange(zero_mask, '... -> ... 1'), 436 | embed, 437 | embed_normalized 438 | ) 439 | 440 | ema_inplace(self.embed, embed_normalized, self.decay) 441 | self.expire_codes_(x) 442 | 443 | if needs_codebook_dim: 444 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) 445 | 446 | return quantize, embed_ind 447 | # main class 448 | 449 | class VectorQuantize(nn.Module): 450 | def __init__( 451 | self, 452 | dim, 453 | codebook_size, 454 | codebook_dim = None, 455 | heads = 1, 456 | separate_codebook_per_head = False, 457 | decay = 0.8, 458 | eps = 1e-5, 459 | kmeans_init = False, 460 | kmeans_iters = 10, 461 | use_cosine_sim = False, 462 | threshold_ema_dead_code = 0, 463 | channel_last = True, 464 | accept_image_fmap = False, 465 | commitment_weight = 1., 466 | orthogonal_reg_weight = 0., 467 | orthogonal_reg_active_codes_only = False, 468 | orthogonal_reg_max_codes = None, 469 | sample_codebook_temp = 0., 470 | sync_codebook = False 471 | ): 472 | super().__init__() 473 | self.heads = heads 474 | self.separate_codebook_per_head = separate_codebook_per_head 475 | 476 | codebook_dim = default(codebook_dim, dim) 477 | codebook_input_dim = codebook_dim * heads 478 | 479 | requires_projection = codebook_input_dim != dim 480 | self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 481 | self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 482 | 483 | self.eps = eps 484 | self.commitment_weight = commitment_weight 485 | 486 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 487 | self.orthogonal_reg_weight = orthogonal_reg_weight 488 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 489 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 490 | 491 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook 492 | 493 | self._codebook = codebook_class( 494 | dim = codebook_dim, 495 | num_codebooks = heads if separate_codebook_per_head else 1, 496 | codebook_size = codebook_size, 497 | kmeans_init = kmeans_init, 498 | kmeans_iters = kmeans_iters, 499 | decay = decay, 500 | eps = eps, 501 | threshold_ema_dead_code = threshold_ema_dead_code, 502 | use_ddp = sync_codebook, 503 | learnable_codebook = has_codebook_orthogonal_loss, 504 | sample_codebook_temp = sample_codebook_temp 505 | ) 506 | 507 | self.codebook_size = codebook_size 508 | 509 | self.accept_image_fmap = accept_image_fmap 510 | self.channel_last = channel_last 511 | 512 | @property 513 | def codebook(self): 514 | codebook = self._codebook.embed 515 | if self.separate_codebook_per_head: 516 | return codebook 517 | 518 | return rearrange(codebook, '1 ... 
-> ...') 519 | 520 | def forward(self, x): 521 | shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size 522 | 523 | need_transpose = not self.channel_last and not self.accept_image_fmap 524 | 525 | if self.accept_image_fmap: 526 | height, width = x.shape[-2:] 527 | x = rearrange(x, 'b c h w -> b (h w) c') 528 | 529 | if need_transpose: 530 | x = rearrange(x, 'b d n -> b n d') 531 | 532 | x = self.project_in(x) 533 | 534 | if is_multiheaded: 535 | ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d' 536 | x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads) 537 | 538 | quantize, embed_ind = self._codebook(x) 539 | 540 | if self.training: 541 | quantize = x + (quantize - x).detach() 542 | 543 | loss = torch.tensor([0.], device = device, requires_grad = self.training) 544 | 545 | if self.training: 546 | if self.commitment_weight > 0: 547 | commit_loss = F.mse_loss(quantize.detach(), x) 548 | loss = loss + commit_loss * self.commitment_weight 549 | 550 | if self.orthogonal_reg_weight > 0: 551 | codebook = self._codebook.embed 552 | 553 | if self.orthogonal_reg_active_codes_only: 554 | # only calculate orthogonal loss for the activated codes for this batch 555 | unique_code_ids = torch.unique(embed_ind) 556 | codebook = codebook[unique_code_ids] 557 | 558 | num_codes = codebook.shape[0] 559 | if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: 560 | rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] 561 | codebook = codebook[rand_ids] 562 | 563 | orthogonal_reg_loss = orthogonal_loss_fn(codebook) 564 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight 565 | 566 | if is_multiheaded: 567 | if self.separate_codebook_per_head: 568 | quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads) 569 | embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads) 570 | else: 571 | quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads) 572 | embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads) 573 | 574 | quantize = self.project_out(quantize) 575 | 576 | if need_transpose: 577 | quantize = rearrange(quantize, 'b n d -> b d n') 578 | 579 | if self.accept_image_fmap: 580 | quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) 581 | embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width) 582 | 583 | return quantize, embed_ind, loss --------------------------------------------------------------------------------
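The snippet below is an illustrative sketch, not part of the repository, of how the VectorQuantize module in models/my_vq.py can be exercised on its own as a quick sanity check. It assumes it is run from the repository root so that the models package is importable; the feature dimension, codebook size, and tensor shapes are placeholder values, not the settings used by the VQ-TensoRF configs.

import torch
from models.my_vq import VectorQuantize  # assumes the repository root is on PYTHONPATH

# Placeholder hyper-parameters for demonstration only.
vq = VectorQuantize(
    dim = 48,               # dimensionality of the vectors to quantize
    codebook_size = 256,    # number of codewords in the codebook
    decay = 0.8,            # EMA decay for the codebook updates
    kmeans_init = True,     # seed the codebook with k-means on the first batch
    commitment_weight = 1.,
)

x = torch.randn(1, 1024, 48)        # (batch, num_vectors, dim), channel-last layout
quantized, indices, loss = vq(x)    # straight-through quantization while vq.training is True

print(quantized.shape)              # torch.Size([1, 1024, 48])
print(indices.shape)                # torch.Size([1, 1024])
print(loss)                         # commitment loss (plus optional orthogonal regularizer)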